Tasks can now be queued from the web interface and are dispatched to the crawl server(s).

Simon 2018-06-12 13:44:03 -04:00
parent 6d48f1f780
commit d61fd75890
14 changed files with 169 additions and 409 deletions
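The flow added by this commit: app.py builds a Task for each submitted URL and hands it to TaskDispatcher (task.py), which serializes it to JSON and POSTs it to a crawl server's /task/put endpoint. A minimal sketch of that HTTP call, assuming a crawl server listening on localhost:5001 as configured in this commit; the example URL is hypothetical:

    import json
    import requests

    # Same payload shape as Task.to_json() in crawl_server/database.py
    payload = json.dumps({
        "website_id": 123,
        "url": "http://example.com/files/",  # hypothetical open directory URL
        "priority": 1,
        "callback_type": "",
        "callback_args": "{}"
    })

    # This mirrors what CrawlServer.queue_task() in task.py does for each dispatched task
    r = requests.post("http://localhost:5001/task/put",
                      headers={"Content-Type": "application/json"},
                      data=payload)
    print(r.status_code)  # 200 means the crawl server accepted and queued the task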

app.py

@@ -1,16 +1,13 @@
 from flask import Flask, render_template, redirect, request, flash, abort, Response, send_from_directory, session
 import os
-import json
 import time
 import ssl
 from database import Database, Website, InvalidQueryException
 from flask_recaptcha import ReCaptcha
 import od_util
-import sqlite3
 import config
 from flask_caching import Cache
-from task import TaskManager
+from task import TaskDispatcher, Task
 app = Flask(__name__)
 recaptcha = ReCaptcha(app=app,
@@ -23,7 +20,7 @@ app.jinja_env.globals.update(truncate_path=od_util.truncate_path)
 app.jinja_env.globals.update(get_color=od_util.get_color)
 app.jinja_env.globals.update(get_mime=od_util.get_mime)
-tm = TaskManager()
+taskDispatcher = TaskDispatcher()
 @app.template_filter("datetime_format")
@@ -68,8 +65,9 @@ def website_json_chart(website_id):
     website = db.get_website_by_id(website_id)
+    print("FIXME: website_json_chart")
     if website:
-        stats = Response(json.dumps(db.get_website_stats(website_id)), mimetype="application/json")
+        stats = {}
         return stats
     else:
         abort(404)
@@ -81,7 +79,9 @@ def website_links(website_id):
     website = db.get_website_by_id(website_id)
     if website:
-        return Response("\n".join(db.get_website_links(website_id)), mimetype="text/plain")
+        print("FIXME: website_links")
+        links = []
+        return Response("\n".join(links), mimetype="text/plain")
     else:
         abort(404)
@@ -107,7 +107,9 @@ def search():
     if q:
         try:
-            hits = db.search(q, per_page, page, sort_order)
+            # hits = sea.search(q, per_page, page, sort_order)
+            print("FIXME: Search")
+            hits = []
         except InvalidQueryException as e:
             flash("<strong>Invalid query:</strong> " + str(e), "warning")
             return redirect("/search")
@@ -127,21 +129,16 @@ def contribute():
 @app.route("/")
 def home():
-    if tm.busy.value == 1:
-        current_website = tm.current_website.url
-    else:
-        current_website = None
-    try:
-        stats = db.get_stats()
-    except sqlite3.OperationalError:
-        stats = None
+    # TODO get stats
+    stats = {}
+    current_website = "TODO"
     return render_template("home.html", stats=stats, current_website=current_website)
 @app.route("/submit")
 def submit():
-    return render_template("submit.html", queue=db.queue(), recaptcha=recaptcha)
+    queued_websites = taskDispatcher.get_queued_tasks()
+    return render_template("submit.html", queue=queued_websites, recaptcha=recaptcha)
 def try_enqueue(url):
@@ -172,7 +169,9 @@ def try_enqueue(url):
               "this is an error, please <a href='/contribute'>contact me</a>.", "danger"
     web_id = db.insert_website(Website(url, str(request.remote_addr), str(request.user_agent)))
-    db.enqueue(web_id)
+    task = Task(web_id, url, priority=1)
+    taskDispatcher.dispatch_task(task)
     return "The website has been added to the queue", "success"
@@ -219,7 +218,6 @@ def enqueue_bulk():
     return redirect("/submit")
 @app.route("/admin")
 def admin_login_form():
     if "username" in session:


@@ -47,8 +47,8 @@ class RemoteDirectory:
 class RemoteDirectoryFactory:
-    from crawler.ftp import FtpDirectory
-    from crawler.http import HttpDirectory
+    from crawl_server.remote_ftp import FtpDirectory
+    from crawl_server.remote_http import HttpDirectory
     DIR_ENGINES = (FtpDirectory, HttpDirectory)
     @staticmethod


@@ -5,12 +5,21 @@ import sqlite3
 class TaskResult:
-    def __init__(self):
-        self.status_code: str = None
-        self.file_count = 0
-        self.start_time = None
-        self.end_time = None
-        self.website_id = None
+    def __init__(self, status_code=None, file_count=0, start_time=0, end_time=0, website_id=0):
+        self.status_code = status_code
+        self.file_count = file_count
+        self.start_time = start_time
+        self.end_time = end_time
+        self.website_id = website_id
+    def to_json(self):
+        return {
+            "status_code": self.status_code,
+            "file_count": self.file_count,
+            "start_time": self.start_time,
+            "end_time": self.end_time,
+            "website_id": self.website_id
+        }
 class Task:
@@ -24,13 +33,16 @@ class Task:
         self.callback_args = json.loads(callback_args) if callback_args else {}
     def to_json(self):
-        return ({
+        return {
             "website_id": self.website_id,
             "url": self.url,
             "priority": self.priority,
             "callback_type": self.callback_type,
             "callback_args": json.dumps(self.callback_args)
-        })
+        }
+    def __repr__(self):
+        return json.dumps(self.to_json())
 class TaskManagerDatabase:
@@ -96,3 +108,17 @@ class TaskManagerDatabase:
                            "VALUES (?,?,?,?,?)", (result.website_id, result.status_code, result.file_count,
                                                   result.start_time, result.end_time))
             conn.commit()
+    def get_non_indexed_results(self):
+        """Get a list of new TaskResults since the last call of this method"""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute("SELECT status_code, file_count, start_time, end_time, website_id"
+                           " FROM TaskResult WHERE indexed_time IS NULL")
+            db_result = cursor.fetchall()
+            cursor.execute("UPDATE TaskResult SET indexed_time = CURRENT_TIMESTAMP")
+            return [TaskResult(r[0], r[1], r[2], r[3], r[4]) for r in db_result]


@@ -8,7 +8,7 @@ import ftputil.error
 from ftputil.session import session_factory
 import random
 import timeout_decorator
-from crawler.crawler import RemoteDirectory, File, TooManyConnectionsError
+from crawl_server.crawler import RemoteDirectory, File, TooManyConnectionsError
 class FtpDirectory(RemoteDirectory):


@@ -3,7 +3,7 @@ from urllib.parse import urljoin, unquote
 import os
 from lxml import etree
 from itertools import repeat
-from crawler.crawler import RemoteDirectory, File
+from crawl_server.crawler import RemoteDirectory, File
 import requests
 from requests.exceptions import RequestException
 from multiprocessing.pool import ThreadPool


@@ -1,16 +1,11 @@
 from flask import Flask, request, abort, Response
 import json
-from crawl_server.task_manager import TaskManager, Task
+from crawl_server.task_manager import TaskManager, Task, TaskResult
 app = Flask(__name__)
 tm = TaskManager("tm_db.sqlite3")
-@app.route("/")
-def hello():
-    return "Hello World!"
 @app.route("/task/")
 def get_tasks():
     json_str = json.dumps([task.to_json() for task in tm.get_tasks()])
@@ -37,5 +32,18 @@ def task_put():
         return abort(400)
+@app.route("/task/completed", methods=["GET"])
+def get_completed_tasks():
+    json_str = json.dumps([result.to_json() for result in tm.get_non_indexed_results()])
+    return json_str
+@app.route("/task/current", methods=["GET"])
+def get_current_tasks():
+    current_tasks = tm.get_current_tasks()
+    return json.dumps([task.to_json() for task in current_tasks])
 if __name__ == "__main__":
-    app.run()
+    app.run(port=5001)


@@ -2,7 +2,7 @@ from crawl_server.database import TaskManagerDatabase, Task, TaskResult
 from multiprocessing import Pool
 from apscheduler.schedulers.background import BackgroundScheduler
 from datetime import datetime
-from crawler.crawler import RemoteDirectoryCrawler
+from crawl_server.crawler import RemoteDirectoryCrawler
 class TaskManager:
@@ -12,8 +12,10 @@ class TaskManager:
         self.db = TaskManagerDatabase(db_path)
         self.pool = Pool(processes=max_processes)
+        self.current_tasks = []
         scheduler = BackgroundScheduler()
-        scheduler.add_job(self.execute_queued_task, "interval", seconds=1)
+        scheduler.add_job(self.execute_queued_task, "interval", seconds=5)
         scheduler.start()
     def put_task(self, task: Task):
@@ -22,11 +24,21 @@ class TaskManager:
     def get_tasks(self):
         return self.db.get_tasks()
+    def get_current_tasks(self):
+        return self.current_tasks
+    def get_non_indexed_results(self):
+        return self.db.get_non_indexed_results()
     def execute_queued_task(self):
         task = self.db.pop_task()
         if task:
+            self.current_tasks.append(task)
             print("pooled " + task.url)
             self.pool.apply_async(
                 TaskManager.run_task,
                 args=(task, self.db_path),
@@ -68,8 +80,9 @@ class TaskManager:
     @staticmethod
     def task_error(err):
-        print("ERROR")
+        print("FIXME: Task failed (This should not happen)")
         print(err)
+        raise err



@@ -20,16 +20,6 @@ class Website:
         self.id = website_id
-class File:
-    def __init__(self, website_id: int, path: str, mime: str, name: str, size: int):
-        self.mime = mime
-        self.size = size
-        self.name = name
-        self.path = path
-        self.website_id = website_id
 class ApiToken:
     def __init__(self, token, description):
@@ -39,13 +29,6 @@ class ApiToken:
 class Database:
-    SORT_ORDERS = {
-        "score": "ORDER BY rank",
-        "size_asc": "ORDER BY size ASC",
-        "size_dsc": "ORDER BY size DESC",
-        "none": ""
-    }
     def __init__(self, db_path):
         self.db_path = db_path
@@ -75,60 +58,6 @@ class Database:
             return website_id
-    def insert_files(self, files: list):
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            # Insert Paths first
-            website_paths = dict()
-            for file in files:
-                if file.path not in website_paths:
-                    cursor.execute("INSERT INTO WebsitePath (website_id, path) VALUES (?,?)",
-                                   (file.website_id, file.path))
-                    cursor.execute("SELECT LAST_INSERT_ROWID()")
-                    website_paths[file.path] = cursor.fetchone()[0]
-            # Then FileTypes
-            mimetypes = dict()
-            cursor.execute("SELECT * FROM FileType")
-            db_mimetypes = cursor.fetchall()
-            for db_mimetype in db_mimetypes:
-                mimetypes[db_mimetype[1]] = db_mimetype[0]
-            for file in files:
-                if file.mime not in mimetypes:
-                    cursor.execute("INSERT INTO FileType (mime) VALUES (?)", (file.mime, ))
-                    cursor.execute("SELECT LAST_INSERT_ROWID()")
-                    mimetypes[file.mime] = cursor.fetchone()[0]
-            conn.commit()
-            # Then insert files
-            cursor.executemany("INSERT INTO File (path_id, name, size, mime_id) VALUES (?,?,?,?)",
-                               [(website_paths[x.path], x.name, x.size, mimetypes[x.mime]) for x in files])
-            # Update date
-            if len(files) > 0:
-                cursor.execute("UPDATE Website SET last_modified=CURRENT_TIMESTAMP WHERE id = ?",
-                               (files[0].website_id, ))
-            conn.commit()
-    def import_json(self, json_file, website: Website):
-        if not self.get_website_by_url(website.url):
-            website_id = self.insert_website(website)
-        else:
-            website_id = website.id
-        with open(json_file, "r") as f:
-            try:
-                self.insert_files([File(website_id, x["path"], os.path.splitext(x["name"])[1].lower(), x["name"], x["size"])
-                                   for x in json.load(f)])
-            except Exception as e:
-                print(e)
-                print("Couldn't read json file!")
-                pass
     def get_website_by_url(self, url):
         with sqlite3.connect(self.db_path) as conn:
@@ -158,152 +87,6 @@ class Database:
             else:
                 return None
-    def enqueue(self, website_id, reddit_post_id=None, reddit_comment_id=None, priority=1):
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            if reddit_post_id:
-                cursor.execute("INSERT OR IGNORE INTO Queue (website_id, reddit_post_id, priority) VALUES (?,?,?)",
-                               (website_id, reddit_post_id, priority))
-            else:
-                cursor.execute("INSERT OR IGNORE INTO Queue (website_id, reddit_comment_id, priority) VALUES (?,?,?)",
-                               (website_id, reddit_comment_id, priority))
-            conn.commit()
-    def dequeue(self):
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute("SELECT website_id, reddit_post_id, reddit_comment_id"
-                           " FROM Queue ORDER BY priority DESC, Queue.id ASC LIMIT 1")
-            website = cursor.fetchone()
-            if website:
-                cursor.execute("DELETE FROM Queue WHERE website_id=?", (website[0],))
-                return website[0], website[1], website[2]
-            else:
-                return None
-    def queue(self):
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute("SELECT url, logged_ip, logged_useragent, last_modified "
-                           "FROM Queue INNER JOIN Website ON website_id=Website.id "
-                           "ORDER BY Queue.priority DESC, Queue.id ASC")
-            return [Website(x[0], x[1], x[2], x[3]) for x in cursor.fetchall()]
-    def get_stats(self):
-        stats = {}
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute("SELECT COUNT(*), SUM(size) FROM File")
-            db_files = cursor.fetchone()
-            stats["file_count"] = db_files[0]
-            stats["file_size"] = db_files[1]
-            cursor.execute("SELECT COUNT(DISTINCT website_id), COUNT(id) FROM WebsitePath")
-            db_websites = cursor.fetchone()
-            stats["website_count"] = db_websites[0]
-            stats["website_paths"] = db_websites[1]
-            return stats
-    def search(self, q, limit: int = 50, offset: int = 0, sort_order="score"):
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            try:
-                order_by = Database.SORT_ORDERS.get(sort_order, "")
-                cursor.execute("SELECT size, Website.url, WebsitePath.path, File.name, Website.id FROM File_index "
-                               "INNER JOIN File ON File.id = File_index.rowid "
-                               "INNER JOIN WebsitePath ON File.path_id = WebsitePath.id "
-                               "INNER JOIN Website ON website_id = Website.id "
-                               "WHERE File_index MATCH ? " +
-                               order_by + " LIMIT ? OFFSET ?",
-                               (q, limit, offset * limit))
-            except sqlite3.OperationalError as e:
-                raise InvalidQueryException(str(e))
-            return cursor.fetchall()
-    def get_website_stats(self, website_id):
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute("SELECT SUM(File.size), COUNT(*) FROM File "
-                           "INNER JOIN WebsitePath Path on File.path_id = Path.id "
-                           "WHERE Path.website_id = ?", (website_id, ))
-            file_sum, file_count = cursor.fetchone()
-            cursor.execute("SELECT SUM(File.size) as total_size, COUNT(File.id), FileType.mime FROM File "
-                           "INNER JOIN FileType ON FileType.id = File.mime_id "
-                           "INNER JOIN WebsitePath Path on File.path_id = Path.id "
-                           "WHERE Path.website_id = ? "
-                           "GROUP BY FileType.id ORDER BY total_size DESC", (website_id, ))
-            db_mime_stats = cursor.fetchall()
-            cursor.execute("SELECT Website.url, Website.last_modified FROM Website WHERE id = ?", (website_id, ))
-            website_url, website_date = cursor.fetchone()
-            return {
-                "total_size": file_sum if file_sum else 0,
-                "total_count": file_count if file_count else 0,
-                "base_url": website_url,
-                "report_time": website_date,
-                "mime_stats": db_mime_stats
-            }
-    def get_subdir_stats(self, website_id: int, path: str):
-        """Get stats of a sub directory. path must not start with / and must end with /"""
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute("SELECT SUM(File.size), COUNT(*) FROM File "
-                           "INNER JOIN WebsitePath Path on File.path_id = Path.id "
-                           "WHERE Path.website_id = ? AND Path.path LIKE ?", (website_id, path + "%"))
-            file_sum, file_count = cursor.fetchone()
-            cursor.execute("SELECT SUM(File.size) as total_size, COUNT(File.id), FileType.mime FROM File "
-                           "INNER JOIN FileType ON FileType.id = File.mime_id "
-                           "INNER JOIN WebsitePath Path on File.path_id = Path.id "
-                           "WHERE Path.website_id = ? AND Path.path LIKE ? "
-                           "GROUP BY FileType.id ORDER BY total_size DESC", (website_id, path + "%"))
-            db_mime_stats = cursor.fetchall()
-            cursor.execute("SELECT Website.url, Website.last_modified FROM Website WHERE id = ?", (website_id, ))
-            website_url, website_date = cursor.fetchone()
-            return {
-                "total_size": file_sum if file_sum else 0,
-                "total_count": file_count if file_count else 0,
-                "base_url": website_url,
-                "report_time": website_date,
-                "mime_stats": db_mime_stats
-            }
-    def get_website_links(self, website_id):
-        """Get all download links of a website"""
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            website = self.get_website_by_id(website_id)
-            cursor.execute("SELECT File.name, WebsitePath.path FROM File "
-                           "INNER JOIN WebsitePath on File.path_id = WebsitePath.id "
-                           "WHERE WebsitePath.website_id = ?", (website.id, ))
-            return [website.url + x[1] + ("/" if len(x[1]) > 0 else "") + x[0] for x in cursor.fetchall()]
     def get_websites(self, per_page, page: int):
         """Get all websites"""
         with sqlite3.connect(self.db_path) as conn:
@@ -325,29 +108,13 @@ class Database:
     def website_has_been_scanned(self, url):
         """Check if a website has at least 1 file"""
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            website_id = self.website_exists(url)
-            if website_id:
-                cursor.execute("SELECT COUNT(Path.id) FROM Website "
-                               "INNER JOIN WebsitePath Path on Website.id = Path.website_id "
-                               "WHERE Website.id = ?", (website_id, ))
-                return cursor.fetchone()[0] > 0
-        return None
+        # TODO: Check with SearchEngine
+        print("FIXME: website_has_been_scanned")
     def clear_website(self, website_id):
         """Remove all files from a website and update its last_updated date"""
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute("DELETE FROM File WHERE File.path_id IN (SELECT WebsitePath.id "
-                           "FROM WebsitePath WHERE WebsitePath.website_id=?)", (website_id, ))
-            cursor.execute("DELETE FROM WebsitePath WHERE website_id=?", (website_id, ))
-            cursor.execute("UPDATE Website SET last_modified=CURRENT_TIMESTAMP WHERE id=?", (website_id, ))
-            conn.commit()
+        # TODO: Check with SearchEngine
+        print("FIXME: clear_website")
     def get_websites_older(self, delta: datetime.timedelta):
         """Get websites last updated before a given date"""
@@ -358,17 +125,6 @@ class Database:
             cursor.execute("SELECT Website.id FROM Website WHERE last_modified < ?", (date, ))
             return [x[0] for x in cursor.fetchall()]
-    def get_websites_smaller(self, size: int):
-        """Get the websites with total size smaller than specified"""
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute("SELECT Website.id FROM Website "
-                           "INNER JOIN WebsitePath Path on Website.id = Path.website_id "
-                           "INNER JOIN File F on Path.id = F.path_id "
-                           "GROUP BY Website.id HAVING SUM(F.size) < ?", (size, ))
-            return cursor.fetchall()
     def delete_website(self, website_id):
         with sqlite3.connect(self.db_path) as conn:


@@ -4,12 +4,12 @@ import json
 payload = json.dumps({
     "website_id": 123,
-    "url": "http://124.158.108.137/ebooks/",
+    "url": "https://frenchy.ga/",
     "priority": 2,
     "callback_type": "",
     "callback_args": "{}"
 })
-r = requests.post("http://localhost:5000/task/put",
+r = requests.post("http://localhost:5001/task/put",
                   headers={"Content-Type": "application/json"},
                   data=payload)


@@ -9,39 +9,6 @@ CREATE TABLE Website (
     last_modified INTEGER DEFAULT CURRENT_TIMESTAMP
 );
-CREATE TABLE WebsitePath (
-    id INTEGER PRIMARY KEY NOT NULL,
-    website_id INTEGER,
-    path TEXT,
-    FOREIGN KEY (website_id) REFERENCES Website(id)
-);
-CREATE TABLE FileType (
-    id INTEGER PRIMARY KEY NOT NULL,
-    mime TEXT
-);
-CREATE TABLE File (
-    id INTEGER PRIMARY KEY NOT NULL,
-    path_id INTEGER,
-    mime_id INTEGER,
-    name TEXT,
-    size INTEGER,
-    FOREIGN KEY (path_id) REFERENCES WebsitePath(id),
-    FOREIGN KEY (mime_id) REFERENCES FileType(id)
-);
-CREATE TABLE Queue (
-    id INTEGER PRIMARY KEY NOT NULL,
-    website_id INTEGER UNIQUE,
-    reddit_post_id TEXT,
-    reddit_comment_id TEXT,
-    priority INTEGER
-);
 CREATE TABLE Admin (
     username TEXT PRIMARY KEY NOT NULL,
     password TEXT
@@ -51,24 +18,3 @@ CREATE TABLE ApiToken (
     token TEXT PRIMARY KEY NOT NULL,
     description TEXT
 );
--- Full Text Index
-CREATE VIRTUAL TABLE File_index USING fts5 (
-    name,
-    path,
-    tokenize=porter
-);
-CREATE TRIGGER after_File_index_insert AFTER INSERT ON File BEGIN
-    INSERT INTO File_index (rowid, name, path)
-        SELECT File.id, File.name, WebsitePath.path
-        FROM File
-        INNER JOIN WebsitePath on File.path_id = WebsitePath.id
-        WHERE File.id = new.id;
-END;
-CREATE TRIGGER after_File_index_delete AFTER DELETE ON File BEGIN
-    DELETE FROM File_index WHERE rowid = old.id;
-END;

task.py

@@ -1,81 +1,91 @@
 from apscheduler.schedulers.background import BackgroundScheduler
-import os
-from database import Website
-from multiprocessing import Value, Process
-from database import Database
+from crawl_server.database import Task, TaskResult
+import requests
+import json
 from reddit_bot import RedditBot
 import praw
-class TaskManager:
+class CrawlServer:
+    headers = {
+        "Content-Type": "application/json"
+    }
+    def __init__(self, url):
+        self.url = url
+    def queue_task(self, task: Task) -> bool:
+        print("Sending task to crawl server " + self.url)
+        payload = json.dumps(task.to_json())
+        r = requests.post(self.url + "/task/put", headers=CrawlServer.headers, data=payload)
+        print(r)
+        return r.status_code == 200
+    def get_completed_tasks(self) -> list:
+        r = requests.get(self.url + "/task/completed")
+        return []
+    def get_queued_tasks(self) -> list:
+        r = requests.get(self.url + "/task/")
+        return [
+            Task(t["website_id"], t["url"], t["priority"], t["callback_type"], t["callback_args"])
+            for t in json.loads(r.text)
+        ]
+    def get_current_tasks(self):
+        r = requests.get(self.url + "/task/current")
+        return [
+            Task(t["website_id"], t["url"], t["priority"], t["callback_type"], t["callback_args"])
+            for t in json.loads(r.text)
+        ]
+class TaskDispatcher:
     def __init__(self):
-        self.busy = Value("i", 0)
-        self.current_website = None
-        self.current_task = None
         reddit = praw.Reddit('opendirectories-bot',
                              user_agent='github.com/simon987/od-database v1.0 (by /u/Hexahedr_n)')
         self.reddit_bot = RedditBot("crawled.txt", reddit)
-        self.db = Database("db.sqlite3")
         scheduler = BackgroundScheduler()
-        scheduler.add_job(self.check_new_task, "interval", seconds=1)
+        scheduler.add_job(self.check_completed_tasks, "interval", seconds=1)
         scheduler.start()
-    def check_new_task(self):
-        if self.current_task is None:
-            task = self.db.dequeue()
-            if task:
-                website_id, post_id, comment_id = task
-                website = self.db.get_website_by_id(website_id)
-                self.current_task = Process(target=self.execute_task,
-                                            args=(website, self.busy, post_id, comment_id))
-                self.current_website = website
-                self.current_task.start()
-        elif self.busy.value == 0:
-            self.current_task.terminate()
-            self.current_task = None
-            self.current_website = None
-    def execute_task(self, website: Website, busy: Value, post_id: str, comment_id: str):
-        busy.value = 1
-        if os.path.exists("data.json"):
-            os.remove("data.json")
-        print("Started crawling task")
-        process = CrawlerProcess(get_project_settings())
-        process.crawl("od_links", base_url=website.url)
-        process.start()
-        print("Done crawling")
-        self.db.import_json("data.json", website)
-        os.remove("data.json")
-        print("Imported in SQLite3")
-        # TODO: Extract 'callbacks' for posts and comments in a function
-        if post_id:
-            # Reply to post
-            stats = self.db.get_website_stats(website.id)
-            comment = self.reddit_bot.get_comment({"": stats}, website.id)
-            print(comment)
-            if "total_size" in stats and stats["total_size"] > 10000000:
-                post = self.reddit_bot.reddit.submission(post_id)
-                self.reddit_bot.reply(post, comment)
-                pass
-            else:
-                self.reddit_bot.log_crawl(post_id)
-        elif comment_id:
-            # Reply to comment
-            stats = self.db.get_website_stats(website.id)
-            comment = self.reddit_bot.get_comment({"There you go!": stats}, website.id)
-            print(comment)
-            reddit_comment = self.reddit_bot.reddit.comment(comment_id)
-            self.reddit_bot.reply(reddit_comment, comment)
-        busy.value = 0
-        print("Done crawling task")
+        # TODO load from config
+        self.crawl_servers = [
+            CrawlServer("http://localhost:5001"),
+        ]
+    def check_completed_tasks(self):
+        return self._get_available_crawl_server().get_completed_tasks()
+    def dispatch_task(self, task: Task):
+        self._get_available_crawl_server().queue_task(task)
+    def _get_available_crawl_server(self) -> CrawlServer:
+        # TODO: Load balancing & health check for crawl servers
+        return self.crawl_servers[0]
+    def get_queued_tasks(self) -> list:
+        queued_tasks = []
+        for server in self.crawl_servers:
+            queued_tasks.extend(server.get_queued_tasks())
+        return queued_tasks
+    def get_current_tasks(self) -> list:
+        current_tasks = []
+        for server in self.crawl_servers:
+            current_tasks.extend(server.get_current_tasks())
+        return current_tasks


@@ -14,5 +14,6 @@ CREATE TABLE TaskResult (
     status_code TEXT,
     file_count INT,
     start_time INT,
-    end_time INT
+    end_time INT,
+    indexed_time INT DEFAULT NULL
 );


@@ -71,15 +71,17 @@
             <thead>
                 <tr>
                     <th>Url</th>
-                    <th>Date added</th>
+                    <th>Priority</th>
+                    <th>Task type</th>
                 </tr>
             </thead>
             <tbody>
-            {% for w in queue %}
+            {% for task in queue %}
                 <tr>
-                    <td title="{{ w.url }}">{{ w.url | truncate(70)}}</td>
-                    <td>{{ w.last_modified }} UTC</td>
+                    <td title="{{ task.url }}">{{ task.url | truncate(70)}}</td>
+                    <td>{{ task.priority }}</td>
+                    <td>{{ task.callback_type if task.callback_type else "NORMAL" }}</td>
                 </tr>
            {% endfor %}
            </tbody>
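
The submit page template above lists the Task objects returned by TaskDispatcher.get_queued_tasks(), which polls every configured crawl server over HTTP and concatenates the results. A standalone sketch of the underlying call for a single server, assuming the localhost:5001 instance used throughout this commit:

    import json
    import requests

    # GET /task/ returns the queue in the format produced by Task.to_json()
    r = requests.get("http://localhost:5001/task/")
    for t in json.loads(r.text):
        # Same fields rendered by the queue table on the submit page
        print(t["url"], t["priority"], t["callback_type"] if t["callback_type"] else "NORMAL")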