mirror of
https://github.com/simon987/sist2.git
synced 2025-04-16 00:46:43 +00:00
316 lines
9.2 KiB
Python
316 lines
9.2 KiB
Python
import json
|
|
import logging
|
|
import os.path
|
|
import shutil
|
|
import signal
|
|
import uuid
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from hashlib import md5
|
|
from logging import FileHandler
|
|
from threading import Lock, Thread
|
|
from time import sleep
|
|
from uuid import uuid4, UUID
|
|
|
|
from hexlib.db import PersistentState
|
|
from pydantic import BaseModel, validator
|
|
|
|
from config import logger, LOG_FOLDER
|
|
from notifications import Notifications
|
|
from sist2 import ScanOptions, IndexOptions, Sist2, Sist2Index
|
|
from state import RUNNING_FRONTENDS
|
|
from web import Sist2Frontend
|
|
|
|
|
|
class JobStatus(Enum):
    """Lifecycle state of a Sist2Job; the string values are what gets persisted."""

    # Default status of a freshly constructed job (see Sist2Job.status).
    CREATED = "created"
    # Not assigned in this module; presumably set when a run begins — confirm with callers.
    STARTED = "started"
    # Set by Sist2IndexTask after `sist2 index` exits with code 0.
    INDEXED = "indexed"
    # Set when a run does not complete successfully.
    FAILED = "failed"
|
|
|
|
|
|
class Sist2Job(BaseModel):
    """A scan + index job managed by sist2-admin.

    Sist2ScanTask and Sist2IndexTask mutate instances as runs complete and store
    them under ``db["jobs"][name]["job"]``.
    """

    name: str
    scan_options: ScanOptions
    index_options: IndexOptions

    # Cron schedule string; create_default() uses "0 0 * * *" (daily at midnight).
    cron_expression: str
    schedule_enabled: bool = False

    # Index that was current before the latest successful scan; Sist2IndexTask
    # deletes it once the new index has been indexed successfully.
    previous_index: str = None
    # Output path of the latest successful scan; reused as the base for incremental scans.
    last_index: str = None
    last_index_date: datetime = None
    status: JobStatus = JobStatus("created")
    last_modified: datetime
    # Recomputed by validate_etag on construction; any supplied value is ignored.
    etag: str = None
    # When True the next scan is non-incremental; reset after a successful scan.
    do_full_scan: bool = False

    def __init__(self, **kwargs):
        # Thin pass-through to BaseModel.__init__ (no extra behaviour).
        super().__init__(**kwargs)

    @staticmethod
    def create_default(name: str):
        """Return a new job named *name* scanning "/" with default options, scheduled daily."""
        return Sist2Job(
            name=name,
            scan_options=ScanOptions(path="/"),
            index_options=IndexOptions(),
            last_modified=datetime.now(),
            cron_expression="0 0 * * *"
        )

    @validator("etag", always=True)
    def validate_etag(cls, value, values):
        """Derive the etag as an MD5 of name, scan/index options and cron expression.

        ``schedule_enabled`` and the run-state fields are deliberately not part of
        the hash. ``always=True`` makes this run even when no etag is supplied.
        """
        required = ("name", "scan_options", "index_options", "cron_expression")
        if any(key not in values for key in required):
            # One of the inputs failed its own validation and is absent from
            # `values`. Return the input untouched so pydantic reports that field's
            # ValidationError instead of this validator escaping with a KeyError
            # (pydantic v1 only converts ValueError/TypeError/AssertionError).
            return value

        s = values["name"] + values["scan_options"].json() + values["index_options"].json() + values["cron_expression"]
        return md5(s.encode()).hexdigest()
|
|
|
|
|
|
class Sist2TaskProgress:
    """Progress snapshot of a running task.

    Sist2Task.log_callback builds instances via ``Sist2TaskProgress(**record["progress"])``,
    so the constructor's keyword names double as the keys expected in that dict.
    """

    def __init__(self, done: int = 0, count: int = 0, index_size: int = 0, tn_size: int = 0, waiting: bool = False):
        self.done, self.count = done, count
        self.index_size = index_size
        # Note: the constructor argument is `tn_size` but the attribute is `store_size`.
        self.store_size = tn_size
        self.waiting = waiting

    def percent(self):
        """Return ``done / count`` as a fraction (not scaled to 100); 0 when count is falsy."""
        if not self.count:
            return 0
        return self.done / self.count
|
|
|
|
|
|
class Sist2Task:
    """Base class for one unit of work executed by TaskQueue on behalf of a Sist2Job.

    Every instance gets a fresh UUID and its own log file ``sist2-<id>.log`` in
    LOG_FOLDER, which receives the non-progress records emitted by the sist2 process.
    """

    def __init__(self, job: Sist2Job, display_name: str, depends_on: uuid.UUID = None):
        """
        :param job: the job this task operates on (subclasses mutate it in ``run``).
        :param display_name: human-readable name used in log messages.
        :param depends_on: id of a task that must appear in the finished-task table
            before TaskQueue starts this one, or None for no dependency.
        """
        self.job = job
        self.display_name = display_name

        self.progress = Sist2TaskProgress()
        self.id = uuid4()

        # Populated while the task executes.
        self.pid = None
        self.started = None
        self.ended = None

        self.depends_on = depends_on

        # A standalone Logger (not registered via logging.getLogger) with a single
        # handler that appends to this task's own log file.
        log_path = os.path.join(LOG_FOLDER, f"sist2-{self.id}.log")
        self._logger = logging.Logger(name=f"{self.id}")
        self._logger.addHandler(FileHandler(log_path))

    def json(self):
        """Return a dict summary of the task; values are raw objects, not JSON-encoded."""
        return dict(
            id=self.id,
            job=self.job,
            display_name=self.display_name,
            progress=self.progress,
            started=self.started,
            ended=self.ended,
            depends_on=self.depends_on,
        )

    def log_callback(self, log_json):
        """Consume one parsed log record from the sist2 process.

        A record containing a "progress" key replaces ``self.progress``; any other
        record is written to the task's log file as a JSON line.
        """
        if "progress" not in log_json:
            if self._logger:
                self._logger.info(json.dumps(log_json))
            return

        self.progress = Sist2TaskProgress(**log_json["progress"])

    def run(self, sist2: Sist2, db: PersistentState):
        """Stamp the start time and announce the task; subclasses do the real work."""
        self.started = datetime.now()
        logger.info(f"Started task {self.display_name}")
|
|
|
|
|
|
class Sist2ScanTask(Sist2Task):
    """Task that runs ``sist2 scan`` for its job and records the resulting index."""

    def run(self, sist2: Sist2, db: PersistentState):
        """Scan the job's path, incrementally on top of the last index when possible.

        On success, the job's previous_index/last_index/last_index_date are updated,
        do_full_scan is cleared and the job is saved to ``db["jobs"]``. On failure,
        the job is marked FAILED and saved.

        :return: exit code of the sist2 scan process (0 means success).
        """
        super().run(sist2, db)

        self.job.scan_options.name = self.job.name

        # Build on the previous index only if it still exists on disk and no full
        # rescan was requested.
        if self.job.last_index and os.path.exists(self.job.last_index) and not self.job.do_full_scan:
            self.job.scan_options.incremental = self.job.last_index
        else:
            self.job.scan_options.incremental = None

        def set_pid(pid):
            # Record the scan process id so the task can be signalled (TaskQueue.kill_task).
            self.pid = pid

        return_code = sist2.scan(self.job.scan_options, logs_cb=self.log_callback, set_pid_cb=set_pid)
        self.ended = datetime.now()

        if return_code != 0:
            self._logger.error(json.dumps({"sist2-admin": f"Process returned non-zero exit code ({return_code})"}))
            logger.info(f"Task {self.display_name} failed ({return_code})")

            # The queue skips the dependent index task when this task fails, so the
            # failure must be recorded here or the job's status would never show it.
            self.job.status = JobStatus.FAILED
            db["jobs"][self.job.name] = {"job": self.job}
        else:
            index = Sist2Index(self.job.scan_options.output)

            # Save latest index
            self.job.previous_index = self.job.last_index

            self.job.last_index = index.path
            self.job.last_index_date = datetime.now()
            self.job.do_full_scan = False
            db["jobs"][self.job.name] = {"job": self.job}
            self._logger.info(json.dumps({"sist2-admin": f"Save last_index={self.job.last_index}"}))

        logger.info(f"Completed {self.display_name} ({return_code=})")

        return return_code
|
|
|
|
|
|
class Sist2IndexTask(Sist2Task):
    """Task that runs ``sist2 index`` on the output of a preceding scan task."""

    def __init__(self, job: Sist2Job, display_name: str, depends_on: Sist2Task):
        """:param depends_on: the scan task whose output is indexed; its id becomes this task's dependency."""
        super().__init__(job, display_name, depends_on=depends_on.id)

    def run(self, sist2: Sist2, db: PersistentState):
        """Index the scan output, then clean up and refresh frontends on success.

        On success, the previous index directory is removed and every running
        frontend is restarted. In both cases, the job status is set to INDEXED or
        FAILED and the job is saved to ``db["jobs"]``.

        :return: exit code of the sist2 index process (0 means success).
        """
        super().run(sist2, db)

        self.job.index_options.path = self.job.scan_options.output

        return_code = sist2.index(self.job.index_options, logs_cb=self.log_callback)
        self.ended = datetime.now()

        duration = self.ended - self.started

        ok = return_code == 0

        if ok:
            # Remove old index. Never delete the index that is now current (that
            # would happen if a scan wrote to the same path as the previous one).
            if self.job.previous_index is not None and self.job.previous_index != self.job.last_index:
                self._logger.info(json.dumps({"sist2-admin": f"Remove {self.job.previous_index=}"}))
                try:
                    shutil.rmtree(self.job.previous_index)
                except FileNotFoundError:
                    pass

            self.restart_running_frontends(db, sist2)

        # Update status
        self.job.status = JobStatus("indexed") if ok else JobStatus("failed")
        db["jobs"][self.job.name] = {"job": self.job}

        self._logger.info(json.dumps({"sist2-admin": f"Sist2Scan task finished {return_code=}, {duration=}"}))

        logger.info(f"Completed {self.display_name} ({return_code=})")

        return return_code

    def restart_running_frontends(self, db: PersistentState, sist2: Sist2):
        """Restart every running web frontend so it serves its jobs' current last_index."""
        # Iterate over a snapshot: RUNNING_FRONTENDS is shared module state and is
        # written to inside the loop.
        for frontend_name, pid in list(RUNNING_FRONTENDS.items()):
            frontend = db["frontends"][frontend_name]["frontend"]
            frontend: Sist2Frontend

            try:
                os.kill(pid, signal.SIGTERM)
            except ProcessLookupError:
                # The frontend already exited; still start a fresh one below.
                pass
            try:
                # Reap this specific child. os.wait() would reap whichever child
                # exits first, possibly a different process.
                os.waitpid(pid, 0)
            except ChildProcessError:
                pass

            # Materialize the list: a lazy map() can only be iterated once.
            frontend.web_options.indices = [db["jobs"][j]["job"].last_index for j in frontend.jobs]

            pid = sist2.web(frontend.web_options, frontend.name)
            RUNNING_FRONTENDS[frontend_name] = pid

            self._logger.info(json.dumps({"sist2-admin": f"Restart frontend {pid=} {frontend_name=}"}))
|
|
|
|
|
|
class TaskQueue:
    """Runs submitted Sist2Tasks one at a time on background threads.

    A daemon scheduler thread polls the queue once per second and starts the first
    task whose dependency (if any) has finished. A task whose dependency finished
    with a non-zero return code is dropped.
    """

    def __init__(self, sist2: Sist2, db: PersistentState, notifications: Notifications):
        self._lock = Lock()

        self._sist2 = sist2
        self._db = db
        self._notifications = notifications

        # task id -> {"task": Sist2Task, "thread": Thread} for currently running tasks.
        self._tasks = {}
        # Submitted tasks that have not been started yet, in submission order.
        self._queue = []
        # Number of running tasks; the scheduler never lets it exceed 1.
        self._sem = 0

        self._thread = Thread(target=self._check_new_task, daemon=True)
        self._thread.start()

    def _tasks_failed(self):
        """Return the ids of finished tasks whose recorded return code is non-zero."""
        return {uuid.UUID(row["id"]) for row in self._db["task_done"].sql("WHERE return_code != 0")}

    def _tasks_done(self):
        """Return the ids of all finished tasks, successful or not."""
        return {uuid.UUID(row["id"]) for row in self._db["task_done"]}

    def _check_new_task(self):
        """Scheduler loop: once per second, start at most one eligible queued task."""
        while True:
            with self._lock:
                for task in list(self._queue):
                    task: Sist2Task

                    if self._sem >= 1:
                        break

                    if not task.depends_on or task.depends_on in self._tasks_done():
                        self._queue.remove(task)

                        if task.depends_on in self._tasks_failed():
                            # The task which we depend on failed, continue
                            logger.info(f"Skipping task {task.display_name}: dependency {task.depends_on} failed")
                            continue

                        self._sem += 1

                        t = Thread(target=self._run_task, args=(task,))

                        self._tasks[task.id] = {
                            "task": task,
                            "thread": t,
                        }

                        t.start()
                        break
            sleep(1)

    def tasks(self):
        """Return a snapshot list of the currently running tasks."""
        # Read under the lock: worker threads add/remove entries concurrently.
        with self._lock:
            return [entry["task"] for entry in self._tasks.values()]

    def kill_task(self, task_id):
        """Send SIGTERM to the process of the running task *task_id* (a UUID string).

        :return: True if the signal was sent; False if no such task is running, it
            has no process id (e.g. index tasks never report one), or its process
            no longer exists.
        """
        with self._lock:
            task = self._tasks.get(UUID(task_id))

        if not task:
            return False

        pid = task["task"].pid
        if pid is None:
            logger.info(f"Cannot kill task {task_id}: no process id recorded")
            return False

        logger.info(f"Killing task {task_id} (pid={pid})")
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            return False
        return True

    def _run_task(self, task: Sist2Task):
        """Worker-thread body: run *task*, then record its result and free the slot."""
        try:
            task_result = task.run(self._sist2, self._db)
        except Exception:
            # Thread boundary: without this, an exception would leave _sem incremented
            # and the task stuck in _tasks, so the queue would never start anything again.
            logger.exception(f"Task {task.display_name} raised an exception")
            # Non-zero so the task counts as failed and its dependents are skipped.
            task_result = -1
            if task.ended is None:
                task.ended = datetime.now()

        with self._lock:
            del self._tasks[task.id]
            self._sem -= 1

            self._db["task_done"][task.id] = {
                "ended": task.ended,
                "started": task.started,
                "name": task.display_name,
                "return_code": task_result
            }
            if isinstance(task, Sist2IndexTask):
                self._notifications.notify({
                    "message": "notifications.indexCompleted",
                    "job": task.job.name
                })

    def submit(self, task: Sist2Task):
        """Append *task* to the queue; the scheduler thread starts it when eligible."""
        logger.info(f"Submitted task to queue {task.display_name}")

        with self._lock:
            self._queue.append(task)
|