Mirror of https://github.com/simon987/od-database.git (synced 2025-12-13 14:59:02 +00:00)
Crawl tasks are now fetched by the crawlers instead of pushed by the server
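In the new flow, each crawl server polls the main od-database server for queued tasks and posts results back, instead of exposing its own HTTP API for the web server to push tasks into. Below is a minimal sketch of that loop, condensed from the fetch_task/push_result code added in this commit; the process pool, the APScheduler job and the error handling from task_manager.py are left out, and config is assumed to be the project's config module providing SERVER_URL and API_TOKEN as referenced in the diff.

import json
import time
import requests
import config  # assumed: provides SERVER_URL and API_TOKEN, as used in the diff below


def fetch_task():
    # Ask the main server for the next queued task. The task_manager.py code added
    # below treats any non-200 response as an empty queue and returns None.
    r = requests.post(config.SERVER_URL + "/task/get", data={"token": config.API_TOKEN})
    return json.loads(r.text) if r.status_code == 200 else None


def push_result(result_json, file_list_path):
    # Report a finished crawl and upload the file list it produced.
    payload = {"token": config.API_TOKEN, "result": json.dumps(result_json)}
    with open(file_list_path, "rb") as f:
        requests.post(config.SERVER_URL + "/task/complete", data=payload,
                      files={"file_list": f})


while True:
    task = fetch_task()
    if task:
        pass  # crawl task["url"] here, then call push_result(...)
    time.sleep(1)

The real implementation keeps the same two endpoints but runs each crawl in a multiprocessing pool and re-checks for work every second from an APScheduler job, as the task_manager.py and run.py hunks below show.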
@@ -1,5 +1,6 @@
 import logging
-from logging import FileHandler
+import sys
+from logging import FileHandler, StreamHandler

 logger = logging.getLogger("default")
 logger.setLevel(logging.DEBUG)
@@ -8,3 +9,4 @@ formatter = logging.Formatter('%(asctime)s %(levelname)-5s %(message)s')
 file_handler = FileHandler("crawl_server.log")
 file_handler.setFormatter(formatter)
 logger.addHandler(file_handler)
+logger.addHandler(StreamHandler(sys.stdout))
@@ -1,61 +0,0 @@
from crawl_server.database import Task
from crawl_server.reddit_bot import RedditBot
import praw


class PostCrawlCallback:

    def __init__(self, task: Task):
        self.task = task

    def run(self):
        raise NotImplementedError


class PostCrawlCallbackFactory:

    @staticmethod
    def get_callback(task: Task):

        if task.callback_type == "reddit_post":
            return RedditPostCallback(task)

        elif task.callback_type == "reddit_comment":
            return RedditCommentCallback(task)

        elif task.callback_type == "discord":
            return DiscordCallback(task)


class RedditCallback(PostCrawlCallback):

    def __init__(self, task: Task):
        super().__init__(task)

        reddit = praw.Reddit('opendirectories-bot',
                             user_agent='github.com/simon987/od-database (by /u/Hexahedr_n)')
        self.reddit_bot = RedditBot("crawled.txt", reddit)

    def run(self):
        raise NotImplementedError


class RedditPostCallback(RedditCallback):

    def run(self):
        print("Reddit post callback for task " + str(self.task))
        pass


class RedditCommentCallback(RedditCallback):

    def run(self):
        print("Reddit comment callback for task " + str(self.task))
        pass


class DiscordCallback(PostCrawlCallback):

    def run(self):
        print("Discord callback for task " + str(self.task))
        pass
@@ -1,145 +0,0 @@
from crawl_server import logger
import os
import json
import sqlite3


class TaskResult:

    def __init__(self, status_code=None, file_count=0, start_time=0,
                 end_time=0, website_id=0, indexed_time=0, server_name=""):
        self.status_code = status_code
        self.file_count = file_count
        self.start_time = start_time
        self.end_time = end_time
        self.website_id = website_id
        self.indexed_time = indexed_time
        self.server_name = server_name

    def to_json(self):
        return {
            "status_code": self.status_code,
            "file_count": self.file_count,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "website_id": self.website_id,
            "indexed_time": self.indexed_time
        }


class Task:

    def __init__(self, website_id: int, url: str, priority: int = 1,
                 callback_type: str = None, callback_args: str = None):
        self.website_id = website_id
        self.url = url
        self.priority = priority
        self.callback_type = callback_type
        self.callback_args = json.loads(callback_args) if callback_args else {}

    def to_json(self):
        return {
            "website_id": self.website_id,
            "url": self.url,
            "priority": self.priority,
            "callback_type": self.callback_type,
            "callback_args": json.dumps(self.callback_args)
        }

    def __str__(self):
        return json.dumps(self.to_json())

    def __repr__(self):
        return self.__str__()


class TaskManagerDatabase:

    def __init__(self, db_path):
        self.db_path = db_path

        if not os.path.exists(db_path):
            self.init_database()
            logger.info("Initialised database")

    def init_database(self):

        with open("task_db_init.sql", "r") as f:
            init_script = f.read()

        with sqlite3.connect(self.db_path) as conn:
            conn.executescript(init_script)
            conn.commit()

    def pop_task(self):

        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            cursor.execute("SELECT id, website_id, url, priority, callback_type, callback_args"
                           " FROM Queue ORDER BY priority DESC, Queue.id ASC LIMIT 1")
            task = cursor.fetchone()

            if task:
                cursor.execute("DELETE FROM Queue WHERE id=?", (task[0],))
                conn.commit()
                return Task(task[1], task[2], task[3], task[4], task[5])
            else:
                return None

    def pop_all_tasks(self):

        tasks = self.get_tasks()

        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            cursor.execute("DELETE FROM Queue")
        return tasks

    def put_task(self, task: Task):

        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            cursor.execute("INSERT INTO Queue (website_id, url, priority, callback_type, callback_args) "
                           "VALUES (?,?,?,?,?)",
                           (task.website_id, task.url, task.priority,
                            task.callback_type, json.dumps(task.callback_args)))
            conn.commit()

    def get_tasks(self):

        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            cursor.execute("SELECT website_id, url, priority, callback_type, callback_args FROM Queue")
            tasks = cursor.fetchall()

            return [Task(t[0], t[1], t[2], t[3], t[4]) for t in tasks]

    def log_result(self, result: TaskResult):

        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            cursor.execute("INSERT INTO TaskResult (website_id, status_code, file_count, start_time, end_time) "
                           "VALUES (?,?,?,?,?)", (result.website_id, result.status_code, result.file_count,
                                                  result.start_time, result.end_time))
            conn.commit()

    def get_non_indexed_results(self):
        """Get a list of new TaskResults since the last call of this method"""

        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            cursor.execute("SELECT status_code, file_count, start_time, end_time, website_id"
                           " FROM TaskResult WHERE indexed_time IS NULL")
            db_result = cursor.fetchall()

            cursor.execute("UPDATE TaskResult SET indexed_time=CURRENT_TIMESTAMP WHERE indexed_time IS NULL")
            conn.commit()

            return [TaskResult(r[0], r[1], r[2], r[3], r[4]) for r in db_result]
crawl_server/run.py (new file)
@@ -0,0 +1,8 @@
from crawl_server.task_manager import TaskManager
import time
import config

tm = TaskManager(config.CRAWL_SERVER_PROCESSES)

while True:
    time.sleep(1)
@@ -1,104 +0,0 @@
from flask import Flask, request, abort, Response, send_file
from flask_httpauth import HTTPTokenAuth
import json
from crawl_server import logger
from crawl_server.task_manager import TaskManager, Task
import os
import config
app = Flask(__name__)
auth = HTTPTokenAuth(scheme="Token")

token = config.CRAWL_SERVER_TOKEN

tm = TaskManager("tm_db.sqlite3", config.CRAWL_SERVER_PROCESSES)


@auth.verify_token
def verify_token(provided_token):
    return token == provided_token


@app.route("/task/")
@auth.login_required
def get_tasks():
    json_str = json.dumps([task.to_json() for task in tm.get_tasks()])
    return Response(json_str, mimetype="application/json")


@app.route("/task/put", methods=["POST"])
@auth.login_required
def task_put():

    if request.json:
        try:
            website_id = request.json["website_id"]
            url = request.json["url"]
            priority = request.json["priority"]
            callback_type = request.json["callback_type"]
            callback_args = request.json["callback_args"]
        except KeyError as e:
            logger.error("Invalid task put request from " + request.remote_addr + " missing key: " + str(e))
            return abort(400)

        task = Task(website_id, url, priority, callback_type, callback_args)
        tm.put_task(task)
        logger.info("Submitted new task to queue: " + str(task.to_json()))
        return '{"ok": "true"}'

    return abort(400)


@app.route("/task/completed", methods=["GET"])
@auth.login_required
def get_completed_tasks():
    json_str = json.dumps([result.to_json() for result in tm.get_non_indexed_results()])
    logger.debug("Webserver has requested list of newly completed tasks from " + request.remote_addr)
    return Response(json_str, mimetype="application/json")


@app.route("/task/current", methods=["GET"])
@auth.login_required
def get_current_tasks():

    current_tasks = tm.get_current_tasks()
    logger.debug("Webserver has requested list of current tasks from " + request.remote_addr)
    return json.dumps([t.to_json() for t in current_tasks])


@app.route("/file_list/<int:website_id>/")
@auth.login_required
def get_file_list(website_id):

    file_name = "./crawled/" + str(website_id) + ".json"
    if os.path.exists(file_name):
        logger.info("Webserver requested file list of website with id" + str(website_id))
        return send_file(file_name)
    else:
        logger.error("Webserver requested file list of non-existent or empty website with id: " + str(website_id))
        return abort(404)


@app.route("/file_list/<int:website_id>/free")
@auth.login_required
def free_file_list(website_id):
    file_name = "./crawled/" + str(website_id) + ".json"
    if os.path.exists(file_name):
        os.remove(file_name)
        logger.debug("Webserver indicated that the files for the website with id " +
                     str(website_id) + " are safe to delete")
        return '{"ok": "true"}'
    else:
        return abort(404)


@app.route("/task/pop_all")
@auth.login_required
def pop_queued_tasks():

    json_str = json.dumps([task.to_json() for task in tm.pop_tasks()])
    logger.info("Webserver poped all queued tasks")
    return Response(json_str, mimetype="application/json")


if __name__ == "__main__":
    app.run(port=config.CRAWL_SERVER_PORT, host="0.0.0.0", ssl_context="adhoc")
@@ -1,19 +0,0 @@

CREATE TABLE Queue (
    id INTEGER PRIMARY KEY,
    website_id INTEGER,
    url TEXT,
    priority INTEGER,
    callback_type TEXT,
    callback_args TEXT
);

CREATE TABLE TaskResult (
    id INTEGER PRIMARY KEY,
    website_id INT,
    status_code TEXT,
    file_count INT,
    start_time TIMESTAMP,
    end_time TIMESTAMP,
    indexed_time TIMESTAMP DEFAULT NULL
);
@@ -1,6 +1,8 @@
 from crawl_server import logger
+from tasks import TaskResult, Task
 import config
-from crawl_server.database import TaskManagerDatabase, Task, TaskResult
+import requests
+import json
 from multiprocessing import Manager, Pool
 from apscheduler.schedulers.background import BackgroundScheduler
 from datetime import datetime
@@ -9,9 +11,7 @@ from crawl_server.crawler import RemoteDirectoryCrawler

 class TaskManager:

-    def __init__(self, db_path, max_processes=2):
-        self.db_path = db_path
-        self.db = TaskManagerDatabase(db_path)
+    def __init__(self, max_processes=2):
         self.pool = Pool(maxtasksperchild=1, processes=max_processes)
         self.max_processes = max_processes
         manager = Manager()
@@ -21,41 +21,68 @@ class TaskManager:
         scheduler.add_job(self.execute_queued_task, "interval", seconds=1)
         scheduler.start()

-    def put_task(self, task: Task):
-        self.db.put_task(task)
+    def fetch_task(self):
+        try:
+            payload = {
+                "token": config.API_TOKEN
+            }
+            r = requests.post(config.SERVER_URL + "/task/get", data=payload)

-    def get_tasks(self):
-        return self.db.get_tasks()
+            if r.status_code == 200:
+                text = r.text
+                logger.info("Fetched task from server : " + text)
+                task_json = json.loads(text)
+                return Task(task_json["website_id"], task_json["url"])

-    def pop_tasks(self):
-        return self.db.pop_all_tasks()
+            return None

-    def get_current_tasks(self):
-        return self.current_tasks
+        except Exception as e:
+            raise e

-    def get_non_indexed_results(self):
-        return self.db.get_non_indexed_results()
+    @staticmethod
+    def push_result(task_result: TaskResult):

+        try:

+            payload = {
+                "token": config.API_TOKEN,
+                "result": json.dumps(task_result.to_json())
+            }

+            files = {
+                # "file_list": open("./crawled/" + str(task_result.website_id) + ".json")
+                "file_list": open("./local.json")
+            }

+            r = requests.post(config.SERVER_URL + "/task/complete", data=payload, files=files)

+            logger.info("RESPONSE: " + r.text)

+        except Exception as e:
+            raise e

     def execute_queued_task(self):

         if len(self.current_tasks) <= self.max_processes:
-            task = self.db.pop_task()

+            task = self.fetch_task()

             if task:
                 logger.info("Submitted " + task.url + " to process pool")
                 self.current_tasks.append(task)

                 self.pool.apply_async(
                     TaskManager.run_task,
-                    args=(task, self.db_path, self.current_tasks),
+                    args=(task, self.current_tasks),
                     callback=TaskManager.task_complete,
                     error_callback=TaskManager.task_error
                 )

     @staticmethod
-    def run_task(task, db_path, current_tasks):
+    def run_task(task, current_tasks):

         result = TaskResult()
-        result.start_time = datetime.utcnow()
+        result.start_time = datetime.utcnow().timestamp()
         result.website_id = task.website_id

         logger.info("Starting task " + task.url)
@@ -67,15 +94,10 @@
         result.file_count = crawl_result.file_count
         result.status_code = crawl_result.status_code

-        result.end_time = datetime.utcnow()
+        result.end_time = datetime.utcnow().timestamp()
         logger.info("End task " + task.url)

-        # TODO: Figure out the callbacks
-        # callback = PostCrawlCallbackFactory.get_callback(task)
-        # if callback:
-        #     callback.run()

-        return result, db_path, current_tasks
+        return result, current_tasks

     @staticmethod
     def task_error(result):
@@ -85,14 +107,13 @@
     @staticmethod
     def task_complete(result):

-        task_result, db_path, current_tasks = result
+        task_result, current_tasks = result

-        logger.info("Task completed, logger result to database")
+        logger.info("Task completed, sending result to server")
         logger.info("Status code: " + task_result.status_code)
         logger.info("File count: " + str(task_result.file_count))

-        db = TaskManagerDatabase(db_path)
-        db.log_result(task_result)
+        TaskManager.push_result(task_result)

         for i, task in enumerate(current_tasks):
             if task.website_id == task_result.website_id: