Mirror of https://github.com/simon987/od-database.git
Started working on post-crawl callbacks and basic auth for crawl servers
crawl_server/callbacks.py  (new file, 61 lines)
@@ -0,0 +1,61 @@
from crawl_server.database import Task
from crawl_server.reddit_bot import RedditBot
import praw


class PostCrawlCallback:

    def __init__(self, task: Task):
        self.task = task

    def run(self):
        raise NotImplementedError


class PostCrawlCallbackFactory:

    @staticmethod
    def get_callback(task: Task):

        if task.callback_type == "reddit_post":
            return RedditPostCallback(task)

        elif task.callback_type == "reddit_comment":
            return RedditCommentCallback(task)

        elif task.callback_type == "discord":
            return DiscordCallback(task)


class RedditCallback(PostCrawlCallback):

    def __init__(self, task: Task):
        super().__init__(task)

        reddit = praw.Reddit('opendirectories-bot',
                             user_agent='github.com/simon987/od-database (by /u/Hexahedr_n)')
        self.reddit_bot = RedditBot("crawled.txt", reddit)

    def run(self):
        raise NotImplementedError


class RedditPostCallback(RedditCallback):

    def run(self):
        print("Reddit post callback for task " + str(self.task))
        pass


class RedditCommentCallback(RedditCallback):

    def run(self):
        print("Reddit comment callback for task " + str(self.task))
        pass


class DiscordCallback(PostCrawlCallback):

    def run(self):
        print("Discord callback for task " + str(self.task))
        pass
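
The factory above is what the task manager invokes once a crawl finishes (see the task_manager.py hunk further down). A minimal sketch of that flow; the Task keyword arguments used here are assumptions for illustration, the real constructor lives in crawl_server/database.py:

from crawl_server.database import Task
from crawl_server.callbacks import PostCrawlCallbackFactory

# Hypothetical Task construction; only callback_type/callback_args are implied by the factory above
task = Task(website_id=1, url="http://example.com/files/",
            callback_type="reddit_post", callback_args="{}")

callback = PostCrawlCallbackFactory.get_callback(task)
if callback:        # get_callback returns None for an unknown callback_type
    callback.run()  # the concrete callbacks are still print-only stubs in this commit
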
@@ -41,9 +41,12 @@ class Task:
             "callback_args": json.dumps(self.callback_args)
         }
 
-    def __repr__(self):
+    def __str__(self):
         return json.dumps(self.to_json())
 
+    def __repr__(self):
+        return self.__str__()
+
 
 class TaskManagerDatabase:
 
crawl_server/reddit_bot.py  (new file, 88 lines)
@@ -0,0 +1,88 @@
import os
import time
import praw
import humanfriendly


class RedditBot:

    bottom_line = "^(Beep boop. I am a bot that calculates the file sizes & count of " \
                  "open directories posted in /r/opendirectories/)"

    def __init__(self, log_file: str, reddit: praw.Reddit):

        self.log_file = log_file

        self.crawled = []
        self.load_from_file()
        self.reddit = reddit

    def log_crawl(self, post_id):

        self.load_from_file()
        self.crawled.append(post_id)

        with open(self.log_file, "w") as f:
            for post_id in self.crawled:
                f.write(post_id + "\n")

    def has_crawled(self, post_id):
        self.load_from_file()
        return post_id in self.crawled

    def load_from_file(self):
        if not os.path.isfile(self.log_file):
            self.crawled = []
        else:
            with open(self.log_file, "r") as f:
                self.crawled = list(filter(None, f.read().split("\n")))

    def reply(self, reddit_obj, comment: str):

        while True:
            try:
                # Double check has_crawled
                if not self.has_crawled(reddit_obj.id):
                    reddit_obj.reply(comment)
                    self.log_crawl(reddit_obj.id)
                    print("Reply to " + reddit_obj.id)
                break
            except Exception as e:
                print("Waiting 5 minutes: " + str(e))
                time.sleep(300)
                continue

    @staticmethod
    def get_comment(stats: dict, website_id, message: str = ""):
        comment = message + " \n" if len(message) > 0 else ""

        for stat in stats:
            comment += stat + " \n" if len(stat) > 0 else ""
            comment += RedditBot.format_stats(stats[stat])

        comment += "[Full Report](https://od-database.simon987.net/website/" + str(website_id) + "/)"
        comment += " | [Link list](https://od-database.simon987.net/website/" + str(website_id) + "/links) \n"
        comment += "*** \n"
        comment += RedditBot.bottom_line

        return comment

    @staticmethod
    def format_stats(stats):

        result = " \n"
        result += "File types | Count | Total Size\n"
        result += ":-- | :-- | :-- \n"
        counter = 0
        for mime in stats["mime_stats"]:
            result += mime[2]
            result += " | " + str(mime[1])
            result += " | " + humanfriendly.format_size(mime[0]) + " \n"

            counter += 1
            if counter >= 3:
                break

        result += "**Total** | **" + str(stats["total_count"]) + "** | **"
        result += humanfriendly.format_size(stats["total_size"]) + "** \n\n"
        return result
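
For reference, get_comment expects a dict that maps a section label to per-site statistics, and format_stats reads each mime_stats row as (total size, file count, mime type). A short sketch with invented numbers, assuming that shape:

from crawl_server.reddit_bot import RedditBot

# Invented sample data; only the dict shape is taken from format_stats above
stats = {
    "http://example.com/pub/": {
        "mime_stats": [
            (123456789, 420, "video/mp4"),     # (total size in bytes, file count, mime type)
            (23456789, 1337, "image/jpeg"),
            (3456789, 50, "application/pdf"),
        ],
        "total_count": 1807,
        "total_size": 150370367,
    }
}

comment = RedditBot.get_comment(stats, website_id=1, message="Here are the stats:")
# comment is reddit markdown: one table per entry, the report/link-list links, then bottom_line
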
@@ -36,7 +36,7 @@ class HttpDirectory(RemoteDirectory):
 
     def __init__(self, url):
         super().__init__(url)
-        self.parser = etree.HTMLParser(collect_ids=False)
+        self.parser = etree.HTMLParser(collect_ids=False, encoding="utf-8")
 
     def list_dir(self, path) -> list:
         results = []
@@ -1,19 +1,33 @@
-from flask import Flask, request, abort, Response, send_from_directory
+from flask import Flask, request, abort, Response
+from flask_httpauth import HTTPTokenAuth
 import json
-from crawl_server.task_manager import TaskManager, Task, TaskResult
+from crawl_server.task_manager import TaskManager, Task
 import os
+import config
 app = Flask(__name__)
+auth = HTTPTokenAuth(scheme="Token")
 
-tm = TaskManager("tm_db.sqlite3", 2)
+tokens = [config.CRAWL_SERVER_TOKEN]
+
+tm = TaskManager("tm_db.sqlite3", 8)
 
 
+@auth.verify_token
+def verify_token(token):
+    print(token)
+    if token in tokens:
+        return True
+
+
 @app.route("/task/")
+@auth.login_required
 def get_tasks():
     json_str = json.dumps([task.to_json() for task in tm.get_tasks()])
     return Response(json_str, mimetype="application/json")
 
 
 @app.route("/task/put", methods=["POST"])
+@auth.login_required
 def task_put():
 
     if request.json:
@@ -34,12 +48,14 @@ def task_put():
 
 
 @app.route("/task/completed", methods=["GET"])
+@auth.login_required
 def get_completed_tasks():
     json_str = json.dumps([result.to_json() for result in tm.get_non_indexed_results()])
     return json_str
 
 
 @app.route("/task/current", methods=["GET"])
+@auth.login_required
 def get_current_tasks():
 
     current_tasks = tm.get_current_tasks()
@@ -47,6 +63,7 @@ def get_current_tasks():
 
 
 @app.route("/file_list/<int:website_id>/")
+@auth.login_required
 def get_file_list(website_id):
 
     file_name = "./crawled/" + str(website_id) + ".json"
@@ -62,4 +79,4 @@ def get_file_list(website_id):
 
 
 if __name__ == "__main__":
-    app.run(port=5001)
+    app.run(port=5002)
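
With HTTPTokenAuth(scheme="Token") and the @auth.login_required decorators, every crawl-server endpoint now expects the shared token in the Authorization header and answers 401 otherwise. A minimal client-side sketch, assuming the token value matches config.CRAWL_SERVER_TOKEN on the server:

import requests

CRAWL_SERVER_TOKEN = "change-me"  # assumption: same value as config.CRAWL_SERVER_TOKEN on the server
headers = {"Authorization": "Token " + CRAWL_SERVER_TOKEN}

# List the queued tasks on the crawl server started above (port 5002)
r = requests.get("http://localhost:5002/task/", headers=headers)
print(r.status_code, r.json())
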
@@ -4,11 +4,12 @@ from multiprocessing import Manager
 from apscheduler.schedulers.background import BackgroundScheduler
 from datetime import datetime
 from crawl_server.crawler import RemoteDirectoryCrawler
+from crawl_server.callbacks import PostCrawlCallbackFactory
 
 
 class TaskManager:
 
-    def __init__(self, db_path, max_processes=4):
+    def __init__(self, db_path, max_processes=2):
         self.db_path = db_path
         self.db = TaskManagerDatabase(db_path)
         self.pool = ProcessPoolExecutor(max_workers=max_processes)
@@ -53,7 +54,7 @@ class TaskManager:
 
         print("Starting task " + task.url)
 
-        crawler = RemoteDirectoryCrawler(task.url, 30)
+        crawler = RemoteDirectoryCrawler(task.url, 100)
         crawl_result = crawler.crawl_directory("./crawled/" + str(task.website_id) + ".json")
 
         result.file_count = crawl_result.file_count
@@ -62,6 +63,11 @@ class TaskManager:
         result.end_time = datetime.utcnow()
         print("End task " + task.url)
 
+        callback = PostCrawlCallbackFactory.get_callback(task)
+        if callback:
+            callback.run()
+            print("Executed callback")
+
         return result, db_path, current_tasks
 
     @staticmethod