Started working on post-crawl callbacks and basic auth for crawl servers

Simon
2018-06-14 15:05:56 -04:00
parent 1bd58468eb
commit 83ca579ec7
13 changed files with 142 additions and 56 deletions

crawl_server/callbacks.py Normal file

@@ -0,0 +1,61 @@
from crawl_server.database import Task
from crawl_server.reddit_bot import RedditBot
import praw


class PostCrawlCallback:

    def __init__(self, task: Task):
        self.task = task

    def run(self):
        raise NotImplementedError


class PostCrawlCallbackFactory:

    @staticmethod
    def get_callback(task: Task):

        if task.callback_type == "reddit_post":
            return RedditPostCallback(task)

        elif task.callback_type == "reddit_comment":
            return RedditCommentCallback(task)

        elif task.callback_type == "discord":
            return DiscordCallback(task)


class RedditCallback(PostCrawlCallback):

    def __init__(self, task: Task):
        super().__init__(task)

        reddit = praw.Reddit('opendirectories-bot',
                             user_agent='github.com/simon987/od-database (by /u/Hexahedr_n)')
        self.reddit_bot = RedditBot("crawled.txt", reddit)

    def run(self):
        raise NotImplementedError


class RedditPostCallback(RedditCallback):

    def run(self):
        print("Reddit post callback for task " + str(self.task))
        pass


class RedditCommentCallback(RedditCallback):

    def run(self):
        print("Reddit comment callback for task " + str(self.task))
        pass


class DiscordCallback(PostCrawlCallback):

    def run(self):
        print("Discord callback for task " + str(self.task))
        pass
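The factory dispatches on task.callback_type and returns None for unknown types; a minimal sketch of the intended call site, mirroring the crawl_server/task_manager.py hunk further down (the helper function name here is made up for illustration):

from crawl_server.callbacks import PostCrawlCallbackFactory
from crawl_server.database import Task


def run_post_crawl_callback(task: Task):
    # get_callback returns None when task.callback_type is not one of
    # "reddit_post", "reddit_comment" or "discord", so the crawl simply
    # finishes without a post-crawl action.
    callback = PostCrawlCallbackFactory.get_callback(task)
    if callback:
        callback.run()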

crawl_server/database.py

@@ -41,9 +41,12 @@ class Task:
"callback_args": json.dumps(self.callback_args)
}
def __repr__(self):
def __str__(self):
return json.dumps(self.to_json())
def __repr__(self):
return self.__str__()
class TaskManagerDatabase:

crawl_server/reddit_bot.py Normal file

@@ -0,0 +1,88 @@
import os
import time

import praw
import humanfriendly


class RedditBot:

    bottom_line = "^(Beep boop. I am a bot that calculates the file sizes & count of " \
                  "open directories posted in /r/opendirectories/)"

    def __init__(self, log_file: str, reddit: praw.Reddit):
        self.log_file = log_file

        self.crawled = []
        self.load_from_file()

        self.reddit = reddit

    def log_crawl(self, post_id):

        self.load_from_file()
        self.crawled.append(post_id)

        with open(self.log_file, "w") as f:
            for post_id in self.crawled:
                f.write(post_id + "\n")

    def has_crawled(self, post_id):
        self.load_from_file()
        return post_id in self.crawled

    def load_from_file(self):
        if not os.path.isfile(self.log_file):
            self.crawled = []
        else:
            with open(self.log_file, "r") as f:
                self.crawled = list(filter(None, f.read().split("\n")))

    def reply(self, reddit_obj, comment: str):

        while True:
            try:
                # Double check has_crawled
                if not self.has_crawled(reddit_obj.id):
                    reddit_obj.reply(comment)
                    self.log_crawl(reddit_obj.id)
                    print("Reply to " + reddit_obj.id)
                break
            except Exception as e:
                print("Waiting 5 minutes: " + str(e))
                time.sleep(300)
                continue

    @staticmethod
    def get_comment(stats: dict, website_id, message: str = ""):
        comment = message + " \n" if len(message) > 0 else ""

        for stat in stats:
            comment += stat + " \n" if len(stat) > 0 else ""
            comment += RedditBot.format_stats(stats[stat])

        comment += "[Full Report](https://od-database.simon987.net/website/" + str(website_id) + "/)"
        comment += " | [Link list](https://od-database.simon987.net/website/" + str(website_id) + "/links) \n"
        comment += "*** \n"
        comment += RedditBot.bottom_line
        return comment

    @staticmethod
    def format_stats(stats):

        result = " \n"
        result += "File types | Count | Total Size\n"
        result += ":-- | :-- | :-- \n"

        counter = 0
        for mime in stats["mime_stats"]:
            result += mime[2]
            result += " | " + str(mime[1])
            result += " | " + humanfriendly.format_size(mime[0]) + " \n"

            counter += 1
            if counter >= 3:
                break

        result += "**Total** | **" + str(stats["total_count"]) + "** | **"
        result += humanfriendly.format_size(stats["total_size"]) + "** \n\n"
        return result
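For reference, format_stats reads each entry of stats["mime_stats"] as a (total size, file count, mime type) tuple alongside "total_count" and "total_size"; a minimal sketch of feeding get_comment, where the category key and all numbers are made up for illustration:

from crawl_server.reddit_bot import RedditBot

# Hypothetical stats payload; the key names mirror what format_stats reads,
# the values themselves are invented for this example.
stats = {
    "http://example.com/files/": {
        "mime_stats": [
            (1024 ** 3, 120, "video/x-matroska"),   # (total size, file count, mime type)
            (512 * 1024 ** 2, 300, "image/jpeg"),
            (64 * 1024 ** 2, 80, "application/pdf"),
        ],
        "total_count": 500,
        "total_size": 2 * 1024 ** 3,
    }
}

comment = RedditBot.get_comment(stats, website_id=42, message="Scan finished")
print(comment)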


@@ -36,7 +36,7 @@ class HttpDirectory(RemoteDirectory):
     def __init__(self, url):
         super().__init__(url)
-        self.parser = etree.HTMLParser(collect_ids=False)
+        self.parser = etree.HTMLParser(collect_ids=False, encoding="utf-8")

     def list_dir(self, path) -> list:
         results = []


@@ -1,19 +1,33 @@
-from flask import Flask, request, abort, Response, send_from_directory
+from flask import Flask, request, abort, Response
+from flask_httpauth import HTTPTokenAuth
 import json
-from crawl_server.task_manager import TaskManager, Task, TaskResult
+from crawl_server.task_manager import TaskManager, Task
 import os
+import config

 app = Flask(__name__)
+auth = HTTPTokenAuth(scheme="Token")
-tm = TaskManager("tm_db.sqlite3", 2)
+tokens = [config.CRAWL_SERVER_TOKEN]
+tm = TaskManager("tm_db.sqlite3", 8)


+@auth.verify_token
+def verify_token(token):
+    print(token)
+    if token in tokens:
+        return True


 @app.route("/task/")
+@auth.login_required
 def get_tasks():
     json_str = json.dumps([task.to_json() for task in tm.get_tasks()])
     return Response(json_str, mimetype="application/json")


 @app.route("/task/put", methods=["POST"])
+@auth.login_required
 def task_put():
     if request.json:
@@ -34,12 +48,14 @@ def task_put():
@app.route("/task/completed", methods=["GET"])
@auth.login_required
def get_completed_tasks():
json_str = json.dumps([result.to_json() for result in tm.get_non_indexed_results()])
return json_str
@app.route("/task/current", methods=["GET"])
@auth.login_required
def get_current_tasks():
current_tasks = tm.get_current_tasks()
@@ -47,6 +63,7 @@ def get_current_tasks():
@app.route("/file_list/<int:website_id>/")
@auth.login_required
def get_file_list(website_id):
file_name = "./crawled/" + str(website_id) + ".json"
@@ -62,4 +79,4 @@ def get_file_list(website_id):
 if __name__ == "__main__":
-    app.run(port=5001)
+    app.run(port=5002)
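With HTTPTokenAuth(scheme="Token"), every protected endpoint now requires an Authorization header; a minimal sketch of how a client might query the crawl server (host, port and token value are placeholders, the real token comes from config.CRAWL_SERVER_TOKEN):

import requests

# Placeholder values for illustration only.
CRAWL_SERVER = "http://localhost:5002"
TOKEN = "change-me"

# flask_httpauth's HTTPTokenAuth(scheme="Token") expects this header form.
headers = {"Authorization": "Token " + TOKEN}

r = requests.get(CRAWL_SERVER + "/task/current", headers=headers)
print(r.status_code, r.json())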

crawl_server/task_manager.py

@@ -4,11 +4,12 @@ from multiprocessing import Manager
 from apscheduler.schedulers.background import BackgroundScheduler
 from datetime import datetime
 from crawl_server.crawler import RemoteDirectoryCrawler
+from crawl_server.callbacks import PostCrawlCallbackFactory


 class TaskManager:

-    def __init__(self, db_path, max_processes=4):
+    def __init__(self, db_path, max_processes=2):
         self.db_path = db_path
         self.db = TaskManagerDatabase(db_path)
         self.pool = ProcessPoolExecutor(max_workers=max_processes)
@@ -53,7 +54,7 @@ class TaskManager:
print("Starting task " + task.url)
crawler = RemoteDirectoryCrawler(task.url, 30)
crawler = RemoteDirectoryCrawler(task.url, 100)
crawl_result = crawler.crawl_directory("./crawled/" + str(task.website_id) + ".json")
result.file_count = crawl_result.file_count
@@ -62,6 +63,11 @@ class TaskManager:
         result.end_time = datetime.utcnow()
         print("End task " + task.url)

+        callback = PostCrawlCallbackFactory.get_callback(task)
+        if callback:
+            callback.run()
+            print("Executed callback")
+
         return result, db_path, current_tasks

     @staticmethod