Mirror of https://github.com/simon987/od-database.git
Tasks can now be queued from the web interface. Tasks are dispatched to the crawl server(s)
task.py | 126
@@ -1,81 +1,91 @@
 from apscheduler.schedulers.background import BackgroundScheduler
-import os
-from database import Website
-from multiprocessing import Value, Process
-from database import Database
+from crawl_server.database import Task, TaskResult
+import requests
 import json
-from reddit_bot import RedditBot
-import praw
 
 
-class TaskManager:
+class CrawlServer:
+
+    headers = {
+        "Content-Type": "application/json"
+    }
+
+    def __init__(self, url):
+        self.url = url
+
+    def queue_task(self, task: Task) -> bool:
+
+        print("Sending task to crawl server " + self.url)
+        payload = json.dumps(task.to_json())
+        r = requests.post(self.url + "/task/put", headers=CrawlServer.headers, data=payload)
+        print(r)
+        return r.status_code == 200
+
+    def get_completed_tasks(self) -> list:
+
+        r = requests.get(self.url + "/task/completed")
+        return []
+
+    def get_queued_tasks(self) -> list:
+
+        r = requests.get(self.url + "/task/")
+        return [
+            Task(t["website_id"], t["url"], t["priority"], t["callback_type"], t["callback_args"])
+            for t in json.loads(r.text)
+        ]
+
+    def get_current_tasks(self):
+
+        r = requests.get(self.url + "/task/current")
+        return [
+            Task(t["website_id"], t["url"], t["priority"], t["callback_type"], t["callback_args"])
+            for t in json.loads(r.text)
+        ]
+
+
+class TaskDispatcher:
 
     def __init__(self):
-        self.busy = Value("i", 0)
-        self.current_website = None
-        self.current_task = None
-
-        reddit = praw.Reddit('opendirectories-bot',
-                             user_agent='github.com/simon987/od-database v1.0 (by /u/Hexahedr_n)')
-        self.reddit_bot = RedditBot("crawled.txt", reddit)
-
-        self.db = Database("db.sqlite3")
         scheduler = BackgroundScheduler()
-        scheduler.add_job(self.check_new_task, "interval", seconds=1)
+        scheduler.add_job(self.check_completed_tasks, "interval", seconds=1)
         scheduler.start()
 
-    def check_new_task(self):
-        if self.current_task is None:
-            task = self.db.dequeue()
-
-            if task:
-                website_id, post_id, comment_id = task
-                website = self.db.get_website_by_id(website_id)
-                self.current_task = Process(target=self.execute_task,
-                                            args=(website, self.busy, post_id, comment_id))
-                self.current_website = website
-                self.current_task.start()
-
-        elif self.busy.value == 0:
-            self.current_task.terminate()
-            self.current_task = None
-            self.current_website = None
-
-    def execute_task(self, website: Website, busy: Value, post_id: str, comment_id: str):
-        busy.value = 1
-        if os.path.exists("data.json"):
-            os.remove("data.json")
-        print("Started crawling task")
-        process = CrawlerProcess(get_project_settings())
-        process.crawl("od_links", base_url=website.url)
-        process.start()
-        print("Done crawling")
-
-        self.db.import_json("data.json", website)
-        os.remove("data.json")
-        print("Imported in SQLite3")
-
-        # TODO: Extract 'callbacks' for posts and comments in a function
-        if post_id:
-            # Reply to post
-            stats = self.db.get_website_stats(website.id)
-            comment = self.reddit_bot.get_comment({"": stats}, website.id)
-            print(comment)
-            if "total_size" in stats and stats["total_size"] > 10000000:
-                post = self.reddit_bot.reddit.submission(post_id)
-                self.reddit_bot.reply(post, comment)
-                pass
-            else:
-                self.reddit_bot.log_crawl(post_id)
-
-        elif comment_id:
-            # Reply to comment
-            stats = self.db.get_website_stats(website.id)
-            comment = self.reddit_bot.get_comment({"There you go!": stats}, website.id)
-            print(comment)
-            reddit_comment = self.reddit_bot.reddit.comment(comment_id)
-            self.reddit_bot.reply(reddit_comment, comment)
-
-        busy.value = 0
-        print("Done crawling task")
+        # TODO load from config
+        self.crawl_servers = [
+            CrawlServer("http://localhost:5001"),
+        ]
+
+    def check_completed_tasks(self):
+        return self._get_available_crawl_server().get_completed_tasks()
+
+    def dispatch_task(self, task: Task):
+        self._get_available_crawl_server().queue_task(task)
+
+    def _get_available_crawl_server(self) -> CrawlServer:
+        # TODO: Load balancing & health check for crawl servers
+        return self.crawl_servers[0]
+
+    def get_queued_tasks(self) -> list:
+
+        queued_tasks = []
+
+        for server in self.crawl_servers:
+            queued_tasks.extend(server.get_queued_tasks())
+
+        return queued_tasks
+
+    def get_current_tasks(self) -> list:
+
+        current_tasks = []
+        for server in self.crawl_servers:
+            current_tasks.extend(server.get_current_tasks())
+
+        return current_tasks
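For context, a minimal sketch of how the web interface side might queue a crawl task through TaskDispatcher. The Flask route, its URL, and the priority/callback values are illustrative assumptions and are not part of this commit; only Task, TaskDispatcher, dispatch_task and Database.get_website_by_id come from the code above.

# Hypothetical web-interface glue, assuming the Flask app and the
# Database/Task classes referenced in the diff above.
from flask import Flask, redirect

from database import Database
from crawl_server.database import Task
from task import TaskDispatcher

app = Flask(__name__)
db = Database("db.sqlite3")
task_dispatcher = TaskDispatcher()


@app.route("/website/<int:website_id>/queue", methods=["POST"])
def queue_website(website_id):
    # Look up the submitted website and wrap it in a Task; the argument
    # order (website_id, url, priority, callback_type, callback_args)
    # mirrors how CrawlServer rebuilds Task objects from JSON.
    website = db.get_website_by_id(website_id)
    task = Task(website.id, website.url, 1, "", "{}")  # illustrative priority/callback values

    # Forward the task to an available crawl server over HTTP.
    task_dispatcher.dispatch_task(task)
    return redirect("/website/" + str(website_id))

Keeping the HTTP details inside CrawlServer means the web interface only ever talks to TaskDispatcher, so additional crawl servers can later be added behind _get_available_crawl_server without touching the routes.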