diff --git a/app.py b/app.py
index 97fa422..045f814 100644
--- a/app.py
+++ b/app.py
@@ -8,7 +8,7 @@ from flask_recaptcha import ReCaptcha
 import od_util
 import config
 from flask_caching import Cache
-from task import TaskDispatcher, Task
+from task import TaskDispatcher, Task, CrawlServer
 from search.search import ElasticSearchEngine
 
 app = Flask(__name__)
@@ -349,8 +349,9 @@ def admin_dashboard():
 
         tokens = db.get_tokens()
         blacklist = db.get_blacklist()
+        crawl_servers = db.get_crawl_servers()
 
-        return render_template("dashboard.html", api_tokens=tokens, blacklist=blacklist)
+        return render_template("dashboard.html", api_tokens=tokens, blacklist=blacklist, crawl_servers=crawl_servers)
 
     else:
         return abort(403)
@@ -416,6 +417,37 @@ def admin_crawl_logs():
         return abort(403)
 
 
+@app.route("/crawl_server/add", methods=["POST"])
+def admin_add_crawl_server():
+    if "username" in session:
+
+        server = CrawlServer(
+            request.form.get("url"),
+            request.form.get("name"),
+            request.form.get("slots"),
+            request.form.get("token")
+        )
+
+        db.add_crawl_server(server)
+        flash("Added crawl server", "success")
+        return redirect("/dashboard")
+
+    else:
+        return abort(403)
+
+
+@app.route("/crawl_server/<int:server_id>/delete")
+def admin_delete_crawl_server(server_id):
+    if "username" in session:
+
+        db.remove_crawl_server(server_id)
+        flash("Deleted crawl server", "success")
+        return redirect("/dashboard")
+
+    else:
+        abort(403)
+
+
 if __name__ == '__main__':
     if config.USE_SSL:
         context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
diff --git a/crawl_server/crawler.py b/crawl_server/crawler.py
index 358786a..4e7af05 100644
--- a/crawl_server/crawler.py
+++ b/crawl_server/crawler.py
@@ -84,46 +84,48 @@ class RemoteDirectoryCrawler:
         self.crawled_paths = list()
 
     def crawl_directory(self, out_file: str) -> CrawlResult:
-        try:
-            directory = RemoteDirectoryFactory.get_directory(self.url)
-            path, root_listing = directory.list_dir("")
-            self.crawled_paths.append(path)
-            directory.close()
-        except TimeoutError:
-            return CrawlResult(0, "timeout")
+        try:
+            try:
+                directory = RemoteDirectoryFactory.get_directory(self.url)
+                path, root_listing = directory.list_dir("")
+                self.crawled_paths.append(path)
+                directory.close()
+            except TimeoutError:
+                return CrawlResult(0, "timeout")
 
-        in_q = Queue(maxsize=0)
-        files_q = Queue(maxsize=0)
-        for f in root_listing:
-            if f.is_dir:
-                in_q.put(os.path.join(f.path, f.name, ""))
-            else:
-                files_q.put(f)
+            in_q = Queue(maxsize=0)
+            files_q = Queue(maxsize=0)
+            for f in root_listing:
+                if f.is_dir:
+                    in_q.put(os.path.join(f.path, f.name, ""))
+                else:
+                    files_q.put(f)
 
-        threads = []
-        for i in range(self.max_threads):
-            worker = Thread(target=RemoteDirectoryCrawler._process_listings, args=(self, self.url, in_q, files_q))
-            threads.append(worker)
-            worker.start()
+            threads = []
+            for i in range(self.max_threads):
+                worker = Thread(target=RemoteDirectoryCrawler._process_listings, args=(self, self.url, in_q, files_q))
+                threads.append(worker)
+                worker.start()
 
-        files_written = []  # Pass array to worker to get result
-        file_writer_thread = Thread(target=RemoteDirectoryCrawler._log_to_file, args=(files_q, out_file, files_written))
-        file_writer_thread.start()
+            files_written = []  # Pass array to worker to get result
+            file_writer_thread = Thread(target=RemoteDirectoryCrawler._log_to_file, args=(files_q, out_file, files_written))
+            file_writer_thread.start()
 
-        in_q.join()
-        files_q.join()
-        print("Done")
+            in_q.join()
+            files_q.join()
+            print("Done")
 
-        # Kill threads
-        for _ in threads:
-            in_q.put(None)
-        for t in threads:
-            t.join()
-        files_q.put(None)
-        file_writer_thread.join()
+            # Kill threads
+            for _ in threads:
+                in_q.put(None)
+            for t in threads:
+                t.join()
+            files_q.put(None)
+            file_writer_thread.join()
 
-        return CrawlResult(files_written[0], "success")
+            return CrawlResult(files_written[0], "success")
+        except Exception as e:
+            return CrawlResult(0, str(e) + " \nType:" + str(type(e)))
 
     def _process_listings(self, url: str, in_q: Queue, files_q: Queue):
diff --git a/crawl_server/server.py b/crawl_server/server.py
index b7b65fc..99fd03a 100644
--- a/crawl_server/server.py
+++ b/crawl_server/server.py
@@ -7,15 +7,14 @@ import config
 
 app = Flask(__name__)
 auth = HTTPTokenAuth(scheme="Token")
-tokens = [config.CRAWL_SERVER_TOKEN]
+token = config.CRAWL_SERVER_TOKEN
 
 tm = TaskManager("tm_db.sqlite3", 32)
 
 
 @auth.verify_token
-def verify_token(token):
-    if token in tokens:
-        return True
+def verify_token(provided_token):
+    return token == provided_token
 
 
 @app.route("/task/")
@@ -99,4 +98,4 @@ def get_stats():
 
 
 if __name__ == "__main__":
-    app.run(port=5001, host="0.0.0.0")
+    app.run(port=config.CRAWL_SERVER_PORT, host="0.0.0.0", ssl_context="adhoc")
diff --git a/database.py b/database.py
index 88ec7f5..ffc8993 100644
--- a/database.py
+++ b/database.py
@@ -4,6 +4,7 @@ from urllib.parse import urlparse
 import os
 import bcrypt
 import uuid
+import task
 
 
 class InvalidQueryException(Exception):
@@ -277,6 +278,33 @@ class Database:
 
             cursor.execute("SELECT * FROM BlacklistedWebsite")
             return [BlacklistedWebsite(r[0], r[1]) for r in cursor.fetchall()]
 
+    def add_crawl_server(self, server: task.CrawlServer):
+
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+
+            cursor.execute("INSERT INTO CrawlServer (url, name, slots, token) VALUES (?,?,?,?)",
+                           (server.url, server.name, server.slots, server.token))
+            conn.commit()
+
+    def remove_crawl_server(self, server_id):
+
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+
+            cursor.execute("DELETE FROM CrawlServer WHERE id=?", (server_id, ))
+            conn.commit()
+
+    def get_crawl_servers(self) -> list:
+
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+
+            cursor.execute("SELECT url, name, slots, token, id FROM CrawlServer")
+
+            return [task.CrawlServer(r[0], r[1], r[2], r[3], r[4]) for r in cursor.fetchall()]
+
+
diff --git a/init_script.sql b/init_script.sql
index 9b24f92..0b07f12 100644
--- a/init_script.sql
+++ b/init_script.sql
@@ -23,3 +23,11 @@ CREATE TABLE BlacklistedWebsite (
   id INTEGER PRIMARY KEY NOT NULL,
   url TEXT
 );
+
+CREATE TABLE CrawlServer (
+  id INTEGER PRIMARY KEY NOT NULL,
+  url TEXT,
+  name TEXT,
+  token TEXT,
+  slots INTEGER
+)
diff --git a/requirements.txt b/requirements.txt
index 2aebe99..fadfad1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,3 +15,4 @@ python-dateutil
 flask_httpauth
 ujson
 timeout_decorator
+urllib3
\ No newline at end of file
diff --git a/search/search.py b/search/search.py
index 2410ce1..459a4cc 100644
--- a/search/search.py
+++ b/search/search.py
@@ -168,7 +168,7 @@ class ElasticSearchEngine(SearchEngine):
                     "path": {"pre_tags": ["<mark>"], "post_tags": ["</mark>"]}
                 }
             },
-            "size": per_page, "from": page * per_page}, index=self.index_name)
+            "size": per_page, "from": min(page * per_page, 10000 - per_page)}, index=self.index_name)
 
         return page
 
diff --git a/task.py b/task.py
index 5366acc..1f1c830 100644
--- a/task.py
+++ b/task.py
@@ -1,31 +1,41 @@
+import random
+
 from apscheduler.schedulers.background import BackgroundScheduler
 from search.search import ElasticSearchEngine
 from crawl_server.database import Task, TaskResult
 import requests
 from requests.exceptions import ConnectionError
 import json
-import config
-from database import Database
+import database
+from concurrent.futures import ThreadPoolExecutor
+import urllib3
+
+urllib3.disable_warnings()
 
 
 class CrawlServer:
 
-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": "Token " + config.CRAWL_SERVER_TOKEN,
-    }
-
-    def __init__(self, url, name):
+    def __init__(self, url, name, slots, token, server_id=None):
         self.url = url
         self.name = name
+        self.slots = slots
+        self.used_slots = 0
+        self.token = token
+        self.id = server_id
+
+    def _generate_headers(self):
+        return {
+            "Content-Type": "application/json",
+            "Authorization": "Token " + self.token,
+        }
 
     def queue_task(self, task: Task) -> bool:
 
         print("Sending task to crawl server " + self.url)
         try:
             payload = json.dumps(task.to_json())
-            r = requests.post(self.url + "/task/put", headers=CrawlServer.headers, data=payload)
-            print(r)
+            r = requests.post(self.url + "/task/put", headers=self._generate_headers(), data=payload, verify=False)
+            print(r)  # TODO: If the task could not be added, fallback to another server
             return r.status_code == 200
         except ConnectionError:
             return False
@@ -33,40 +43,63 @@ class CrawlServer:
 
     def fetch_completed_tasks(self) -> list:
 
         try:
-            r = requests.get(self.url + "/task/completed", headers=CrawlServer.headers)
+            r = requests.get(self.url + "/task/completed", headers=self._generate_headers(), verify=False)
+            if r.status_code != 200:
+                print("Problem while fetching completed tasks for '" + self.name + "': " + str(r.status_code))
+                print(r.text)
+                return []
             return [
                 TaskResult(r["status_code"], r["file_count"], r["start_time"], r["end_time"], r["website_id"])
                 for r in json.loads(r.text)]
         except ConnectionError:
-            print("Crawl server cannot be reached " + self.url)
+            print("Crawl server cannot be reached @ " + self.url)
             return []
 
-    def fetch_queued_tasks(self) -> list:
+    def fetch_queued_tasks(self):
 
         try:
-            r = requests.get(self.url + "/task/", headers=CrawlServer.headers)
+            r = requests.get(self.url + "/task/", headers=self._generate_headers(), verify=False)
+
+            if r.status_code != 200:
+                print("Problem while fetching queued tasks for '" + self.name + "' " + str(r.status_code))
+                print(r.text)
+                return None
+
             return [
                 Task(t["website_id"], t["url"], t["priority"], t["callback_type"], t["callback_args"])
                 for t in json.loads(r.text)
             ]
         except ConnectionError:
-            return []
+            return None
 
     def fetch_current_tasks(self):
 
         try:
-            r = requests.get(self.url + "/task/current", headers=CrawlServer.headers)
+            r = requests.get(self.url + "/task/current", headers=self._generate_headers(), verify=False)
+
+            if r.status_code != 200:
+                print("Problem while fetching current tasks for '" + self.name + "' " + str(r.status_code))
+                print(r.text)
+                return None
+
             return [
                 Task(t["website_id"], t["url"], t["priority"], t["callback_type"], t["callback_args"])
                 for t in json.loads(r.text)
             ]
         except ConnectionError:
-            return []
+            return None
 
     def fetch_website_files(self, website_id) -> str:
 
         try:
-            r = requests.get(self.url + "/file_list/" + str(website_id) + "/", stream=True, headers=CrawlServer.headers)
+            r = requests.get(self.url + "/file_list/" + str(website_id) + "/", stream=True,
+                             headers=self._generate_headers(), verify=False)
+
+            if r.status_code != 200:
+                print("Problem while fetching website files for '" + self.name + "': " + str(r.status_code))
+                print(r.text)
+                return ""
+
             for line in r.iter_lines(chunk_size=1024 * 256):
                 yield line
         except ConnectionError:
@@ -75,7 +108,8 @@ class CrawlServer:
 
     def free_website_files(self, website_id) -> bool:
 
         try:
-            r = requests.get(self.url + "/file_list/" + str(website_id) + "/free", headers=CrawlServer.headers)
+            r = requests.get(self.url + "/file_list/" + str(website_id) + "/free", headers=self._generate_headers(),
+                             verify=False)
             return r.status_code == 200
         except ConnectionError as e:
             print(e)
@@ -84,16 +118,29 @@ class CrawlServer:
 
     def fetch_crawl_logs(self):
 
         try:
-            r = requests.get(self.url + "/task/logs/", headers=CrawlServer.headers)
+            r = requests.get(self.url + "/task/logs/", headers=self._generate_headers(), verify=False)
+
+            if r.status_code != 200:
+                print("Problem while fetching crawl logs for '" + self.name + "': " + str(r.status_code))
+                print(r.text)
+                return []
+
             return [
-                TaskResult(r["status_code"], r["file_count"], r["start_time"], r["end_time"], r["website_id"], r["indexed_time"])
+                TaskResult(r["status_code"], r["file_count"], r["start_time"],
+                           r["end_time"], r["website_id"], r["indexed_time"])
                 for r in json.loads(r.text)]
         except ConnectionError:
             return []
 
     def fetch_stats(self):
         try:
-            r = requests.get(self.url + "/stats/", headers=CrawlServer.headers)
+            r = requests.get(self.url + "/stats/", headers=self._generate_headers(), verify=False)
+
+            if r.status_code != 200:
+                print("Problem while fetching stats for '" + self.name + "': " + str(r.status_code))
+                print(r.text)
+                return []
+
             return json.loads(r.text)
         except ConnectionError:
             return {}
@@ -107,16 +154,11 @@ class TaskDispatcher:
         scheduler.start()
 
         self.search = ElasticSearchEngine("od-database")
-        self.db = Database("db.sqlite3")
-
-        # TODO load from config
-        self.crawl_servers = [
-            CrawlServer("http://localhost:5001", "OVH_VPS_SSD2 #1"),
-        ]
+        self.db = database.Database("db.sqlite3")
 
     def check_completed_tasks(self):
 
-        for server in self.crawl_servers:
+        for server in self.db.get_crawl_servers():
             for task in server.fetch_completed_tasks():
                 print("Completed task")
                 # All files are overwritten
@@ -135,24 +177,63 @@ class TaskDispatcher:
         self._get_available_crawl_server().queue_task(task)
 
     def _get_available_crawl_server(self) -> CrawlServer:
-        # TODO: Load balancing & health check for crawl servers
-        return self.crawl_servers[0]
+
+        queued_tasks_by_server = self._get_current_tasks_by_server()
+        server_with_most_free_slots = None
+        most_free_slots = 0
+
+        for server in queued_tasks_by_server:
+            free_slots = server.slots - len(queued_tasks_by_server[server])
+            if free_slots > most_free_slots:
+                server_with_most_free_slots = server
+                most_free_slots = free_slots
+
+        if server_with_most_free_slots:
+            print("Dispatching task to '" +
+                  server_with_most_free_slots.name + "' " +
+                  str(most_free_slots) + " free out of " + str(server_with_most_free_slots.slots))
+            return server_with_most_free_slots
+
+        return self.db.get_crawl_servers()[0]
 
     def get_queued_tasks(self) -> list:
 
-        queued_tasks = []
+        queued_tasks_by_server = self._get_queued_tasks_by_server()
+        for queued_tasks in queued_tasks_by_server.values():
+            for task in queued_tasks:
+                yield task
 
-        for server in self.crawl_servers:
-            queued_tasks.extend(server.fetch_queued_tasks())
+    def _get_queued_tasks_by_server(self) -> dict:
+
+        queued_tasks = dict()
+        pool = ThreadPoolExecutor(max_workers=10)
+        crawl_servers = self.db.get_crawl_servers()
+        responses = list(pool.map(lambda server: server.fetch_queued_tasks(), crawl_servers))
+        pool.shutdown()
+
+        for i, server in enumerate(crawl_servers):
+            if responses[i] is not None:
+                queued_tasks[server] = responses[i]
 
         return queued_tasks
 
-    def get_current_tasks(self) -> list:
-        # TODO mem cache this
+    def get_current_tasks(self):
 
-        current_tasks = []
-        for server in self.crawl_servers:
-            current_tasks.extend(server.fetch_current_tasks())
+        current_tasks_by_server = self._get_current_tasks_by_server()
+        for current_tasks in current_tasks_by_server.values():
+            for task in current_tasks:
+                yield task
+
+    def _get_current_tasks_by_server(self) -> dict:
+
+        current_tasks = dict()
+        pool = ThreadPoolExecutor(max_workers=10)
+        crawl_servers = self.db.get_crawl_servers()
+        responses = list(pool.map(lambda s: s.fetch_current_tasks(), crawl_servers))
+        pool.shutdown()
+
+        for i, server in enumerate(crawl_servers):
+            if responses[i] is not None:
+                current_tasks[server] = responses[i]
 
         return current_tasks
 
@@ -160,7 +241,7 @@ class TaskDispatcher:
 
         task_logs = dict()
 
-        for server in self.crawl_servers:
+        for server in self.db.get_crawl_servers():
             task_logs[server.name] = server.fetch_crawl_logs()
 
         return task_logs
@@ -169,11 +250,9 @@ class TaskDispatcher:
 
         stats = dict()
 
-        for server in self.crawl_servers:
+        for server in self.db.get_crawl_servers():
             server_stats = server.fetch_stats()
             if server_stats:
                 stats[server.name] = server_stats
 
         return stats
-
-
diff --git a/templates/dashboard.html b/templates/dashboard.html
index efcd5bf..d05a236 100644
--- a/templates/dashboard.html
+++ b/templates/dashboard.html
@@ -7,6 +7,48 @@
             <div class="card-header">Dashboard</div>
             <div class="card-body">
+
+                <h3>Crawl servers</h3>
+                <table class="table">
+                    <thead>
+                    <tr>
+                        <th>Url</th>
+                        <th>Name</th>
+                        <th>Slots</th>
+                        <th>Action</th>
+                    </tr>
+                    </thead>
+                    <tbody>
+                    {% for server in crawl_servers %}
+                        <tr>
+                            <td>{{ server.url }}</td>
+                            <td>{{ server.name }}</td>
+                            <td>{{ server.slots }}</td>
+                            <td><a href="/crawl_server/{{ server.id }}/delete">Delete</a></td>
+                        </tr>
+                    {% endfor %}
+                    </tbody>
+                </table>
+
+                <form action="/crawl_server/add" method="post">
+                    <div class="form-group">
+                        <label for="url">Url</label>
+                        <input class="form-control" type="text" name="url" id="url">
+                    </div>
+                    <div class="form-group">
+                        <label for="name">Name</label>
+                        <input class="form-control" type="text" name="name" id="name">
+                    </div>
+                    <div class="form-group">
+                        <label for="slots">Slots</label>
+                        <input class="form-control" type="number" name="slots" id="slots">
+                    </div>
+                    <div class="form-group">
+                        <label for="token">Token</label>
+                        <input class="form-control" type="text" name="token" id="token">
+                    </div>
+                    <input type="submit" class="btn btn-primary" value="Add crawl server">
+                </form>
 
                 <h3>API Keys</h3>