mirror of
https://github.com/simon987/sist2.git
synced 2025-12-12 15:08:53 +00:00
Rework user scripts, update DB schema to support embeddings
This commit is contained in:
@@ -18,12 +18,13 @@ from websockets.exceptions import ConnectionClosed
|
||||
|
||||
import cron
|
||||
from config import LOG_FOLDER, logger, WEBSERVER_PORT, DATA_FOLDER, SIST2_BINARY
|
||||
from jobs import Sist2Job, Sist2ScanTask, TaskQueue, Sist2IndexTask, JobStatus
|
||||
from jobs import Sist2Job, Sist2ScanTask, TaskQueue, Sist2IndexTask, JobStatus, Sist2UserScriptTask
|
||||
from notifications import Subscribe, Notifications
|
||||
from sist2 import Sist2, Sist2SearchBackend
|
||||
from state import migrate_v1_to_v2, RUNNING_FRONTENDS, TESSERACT_LANGS, DB_SCHEMA_VERSION, migrate_v3_to_v4, \
|
||||
get_log_files_to_remove, delete_log_file, create_default_search_backends
|
||||
from web import Sist2Frontend
|
||||
from script import UserScript, SCRIPT_TEMPLATES
|
||||
|
||||
sist2 = Sist2(SIST2_BINARY, DATA_FOLDER)
|
||||
db = PersistentState(dbfile=os.path.join(DATA_FOLDER, "state.db"))
|
||||
@@ -52,7 +53,8 @@ async def home():
|
||||
async def api():
|
||||
return {
|
||||
"tesseract_langs": TESSERACT_LANGS,
|
||||
"logs_folder": LOG_FOLDER
|
||||
"logs_folder": LOG_FOLDER,
|
||||
"user_script_templates": list(SCRIPT_TEMPLATES.keys())
|
||||
}
|
||||
|
||||
|
||||
@@ -113,8 +115,6 @@ async def update_job(name: str, new_job: Sist2Job):
|
||||
async def update_frontend(name: str, frontend: Sist2Frontend):
|
||||
db["frontends"][name] = frontend
|
||||
|
||||
# TODO: Check etag
|
||||
|
||||
return "ok"
|
||||
|
||||
|
||||
@@ -150,9 +150,21 @@ def _run_job(job: Sist2Job):
|
||||
db["jobs"][job.name] = job
|
||||
|
||||
scan_task = Sist2ScanTask(job, f"Scan [{job.name}]")
|
||||
index_task = Sist2IndexTask(job, f"Index [{job.name}]", depends_on=scan_task)
|
||||
|
||||
index_depends_on = scan_task
|
||||
script_tasks = []
|
||||
for script_name in job.user_scripts:
|
||||
script = db["user_scripts"][script_name]
|
||||
|
||||
task = Sist2UserScriptTask(script, job, f"Script <{script_name}> [{job.name}]", depends_on=scan_task)
|
||||
script_tasks.append(task)
|
||||
index_depends_on = task
|
||||
|
||||
index_task = Sist2IndexTask(job, f"Index [{job.name}]", depends_on=index_depends_on)
|
||||
|
||||
task_queue.submit(scan_task)
|
||||
for task in script_tasks:
|
||||
task_queue.submit(task)
|
||||
task_queue.submit(index_task)
|
||||
|
||||
|
||||
@@ -167,6 +179,22 @@ async def run_job(name: str):
|
||||
return "ok"
|
||||
|
||||
|
||||
@app.get("/api/user_script/{name:str}/run")
def run_user_script(name: str, job: str):
    """Queue a standalone run of user script ``name`` against job ``job``.

    Both values are used as keys into the ``user_scripts`` and ``jobs``
    tables. A falsy lookup result for either one produces a 404.

    Returns:
        The string "ok" once the script task has been submitted.
    """
    user_script = db["user_scripts"][name]
    if not user_script:
        raise HTTPException(status_code=404)

    # Keep the looked-up job object apart from the ``job`` query parameter.
    target_job = db["jobs"][job]
    if not target_job:
        raise HTTPException(status_code=404)

    task_queue.submit(
        Sist2UserScriptTask(user_script, target_job, f"Script <{name}> [{target_job.name}]")
    )

    return "ok"
|
||||
|
||||
|
||||
@app.get("/api/job/{name:str}/logs_to_delete")
|
||||
async def task_history(n: int, name: str):
|
||||
return get_log_files_to_remove(db, name, n)
|
||||
@@ -239,7 +267,7 @@ def check_es_version(es_url: str, insecure: bool):
|
||||
es_url = f"{url.scheme}://{url.hostname}:{url.port}"
|
||||
else:
|
||||
auth = None
|
||||
r = requests.get(es_url, verify=insecure, auth=auth)
|
||||
r = requests.get(es_url, verify=not insecure, auth=auth)
|
||||
except SSLError:
|
||||
return {
|
||||
"ok": False,
|
||||
@@ -375,6 +403,59 @@ def create_search_backend(name: str):
|
||||
return backend
|
||||
|
||||
|
||||
@app.delete("/api/user_script/{name:str}")
def delete_user_script(name: str):
    """Delete user script ``name`` and remove its directory.

    Raises:
        HTTPException: 404 when nothing is stored under ``name``; 400 when at
            least one job still lists ``name`` in its ``user_scripts``.

    Returns:
        The string "ok" after the script has been removed.
    """
    script: UserScript = db["user_scripts"][name]
    if script is None:
        # Must be raised: a *returned* HTTPException is handled as an ordinary
        # response value instead of producing the 404 error response.
        raise HTTPException(status_code=404)

    if any(name in job.user_scripts for job in db["jobs"]):
        raise HTTPException(status_code=400, detail="in use (job)")

    script.delete_dir()
    del db["user_scripts"][name]

    return "ok"
|
||||
|
||||
|
||||
@app.post("/api/user_script/{name:str}")
def create_user_script(name: str, template: str):
    """Create user script ``name`` from the template called ``template``.

    Raises:
        HTTPException: 400 when a script is already stored under ``name`` or
            when ``template`` is not a key of ``SCRIPT_TEMPLATES``.

    Returns:
        The newly created script object.
    """
    if db["user_scripts"][name] is not None:
        # Must be raised: a *returned* HTTPException is handled as an ordinary
        # response value, so the duplicate would not be rejected.
        raise HTTPException(status_code=400, detail="already exists")

    factory = SCRIPT_TEMPLATES.get(template)
    if factory is None:
        # Unknown template names previously surfaced as an unhandled KeyError.
        raise HTTPException(status_code=400, detail="unknown template")

    script = factory(name)
    db["user_scripts"][name] = script

    return script
|
||||
|
||||
|
||||
@app.get("/api/user_script")
async def get_user_scripts():
    """Return everything yielded by iterating the ``user_scripts`` table, as a list."""
    all_scripts = [*db["user_scripts"]]
    return all_scripts
|
||||
|
||||
|
||||
@app.get("/api/user_script/{name:str}")
async def get_user_script(name: str):
    """Return the user script stored under ``name``.

    Raises:
        HTTPException: 404 when the lookup yields a falsy value.
    """
    script = db["user_scripts"][name]
    if script:
        return script

    raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@app.put("/api/user_script/{name:str}")
async def update_user_script(name: str, script: UserScript):
    """Store ``script`` under ``name``, replacing any previous version.

    If a previous version exists and its ``git_repository`` differs from the
    new one, ``force_clone`` is set on the new script before it is saved.

    Returns:
        The string "ok".
    """
    existing: UserScript = db["user_scripts"][name]

    if existing:
        repository_changed = existing.git_repository != script.git_repository
        if repository_changed:
            script.force_clone = True

    db["user_scripts"][name] = script

    return "ok"
|
||||
|
||||
|
||||
def tail(filepath: str, n: int):
|
||||
with open(filepath) as file:
|
||||
|
||||
@@ -479,7 +560,8 @@ if __name__ == '__main__':
|
||||
migrate_v3_to_v4(db)
|
||||
|
||||
if db["sist2_admin"]["info"]["version"] != DB_SCHEMA_VERSION:
|
||||
raise Exception(f"Incompatible database version for {db.dbfile}")
|
||||
raise Exception(f"Incompatible database {db.dbfile}. "
|
||||
f"Automatic migration is not available, please delete the database file to continue.")
|
||||
|
||||
start_frontends()
|
||||
cron.initialize(db, _run_job)
|
||||
|
||||
@@ -9,9 +9,11 @@ MAX_LOG_SIZE = 1 * 1024 * 1024
|
||||
SIST2_BINARY = os.environ.get("SIST2_BINARY", "/root/sist2")
|
||||
DATA_FOLDER = os.environ.get("DATA_FOLDER", "/sist2-admin/")
|
||||
LOG_FOLDER = os.path.join(DATA_FOLDER, "logs")
|
||||
SCRIPT_FOLDER = os.path.join(DATA_FOLDER, "scripts")
|
||||
WEBSERVER_PORT = 8080
|
||||
|
||||
os.makedirs(LOG_FOLDER, exist_ok=True)
|
||||
os.makedirs(SCRIPT_FOLDER, exist_ok=True)
|
||||
os.makedirs(DATA_FOLDER, exist_ok=True)
|
||||
|
||||
logger = logging.Logger("sist2-admin")
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import json
|
||||
import logging
|
||||
import os.path
|
||||
import shlex
|
||||
import signal
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from io import TextIOWrapper
|
||||
from logging import FileHandler
|
||||
from subprocess import Popen
|
||||
import subprocess
|
||||
from threading import Lock, Thread
|
||||
from time import sleep
|
||||
from typing import List
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from hexlib.db import PersistentState
|
||||
@@ -18,6 +23,7 @@ from notifications import Notifications
|
||||
from sist2 import ScanOptions, IndexOptions, Sist2
|
||||
from state import RUNNING_FRONTENDS, get_log_files_to_remove, delete_log_file
|
||||
from web import Sist2Frontend
|
||||
from script import UserScript
|
||||
|
||||
|
||||
class JobStatus(Enum):
|
||||
@@ -32,6 +38,8 @@ class Sist2Job(BaseModel):
|
||||
scan_options: ScanOptions
|
||||
index_options: IndexOptions
|
||||
|
||||
user_scripts: List[str] = []
|
||||
|
||||
cron_expression: str
|
||||
schedule_enabled: bool = False
|
||||
|
||||
@@ -182,7 +190,7 @@ class Sist2IndexTask(Sist2Task):
|
||||
|
||||
duration = self.ended - self.started
|
||||
|
||||
ok = return_code == 0
|
||||
ok = return_code in (0, 1)
|
||||
|
||||
if ok:
|
||||
self.restart_running_frontends(db, sist2)
|
||||
@@ -231,6 +239,65 @@ class Sist2IndexTask(Sist2Task):
|
||||
self._logger.info(json.dumps({"sist2-admin": f"Restart frontend {pid=} {frontend_name=}"}))
|
||||
|
||||
|
||||
class Sist2UserScriptTask(Sist2Task):
    """Task that sets up a user script and runs its executable against a job's index."""

    def __init__(self, user_script: UserScript, job: Sist2Job, display_name: str, depends_on: Sist2Task = None):
        """Create the task.

        :param user_script: script to set up and execute
        :param job: job whose ``index_path`` is passed to the script
        :param display_name: forwarded to the base task constructor
        :param depends_on: optional task; only its ``id`` is forwarded to the base constructor
        """
        super().__init__(job, display_name, depends_on=depends_on.id if depends_on else None)
        self.user_script = user_script

    def run(self, sist2: Sist2, db: PersistentState):
        """Set up the script, execute it and relay its output to ``self.log_callback``.

        The executable is started in the script directory with the index path
        (``DATA_FOLDER`` joined with ``job.index_path``) as first argument,
        followed by the shell-split ``extra_args``.

        :return: -1 when setup raises, otherwise the exit code of the script process
        """
        super().run(sist2, db)

        try:
            self.user_script.setup(self.log_callback)
        except Exception as e:
            logger.error(f"Setup for {self.user_script.name} failed: ")
            logger.exception(e)
            self.log_callback({"sist2-admin": f"Setup for {self.user_script.name} failed: {e}"})
            return -1

        executable = self.user_script.get_executable()
        index_path = os.path.join(DATA_FOLDER, self.job.index_path)
        extra_args = self.user_script.extra_args

        args = [
            executable,
            index_path,
            *shlex.split(extra_args)
        ]

        self.log_callback({"sist2-admin": f"Starting user script with {executable=}, {index_path=}, {extra_args=}"})

        proc = Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.user_script.script_dir())
        self.pid = proc.pid

        # stderr is drained on a helper thread so neither pipe can fill up and
        # block the child while stdout is being read on this thread.
        t_stderr = Thread(target=self._consume_logs, args=(self.log_callback, proc, "stderr", False))
        t_stderr.start()

        try:
            self._consume_logs(self.log_callback, proc, "stdout", True)
        finally:
            # Make sure every stderr line has been forwarded before finishing.
            t_stderr.join()

        return_code = proc.wait()

        self.ended = datetime.utcnow()

        # Report the real exit status instead of an unconditional 0, so a
        # crashed script is not indistinguishable from a successful one.
        return return_code

    @staticmethod
    def _consume_logs(logs_cb, proc, stream, wait):
        """Forward every non-blank line of ``proc.<stream>`` to ``logs_cb``.

        Lines starting with ``$PROGRESS`` are parsed as JSON and emitted as
        ``{"progress": ...}``; a line whose payload is not valid JSON is
        forwarded like regular output instead of aborting the consumer.

        :param logs_cb: callable receiving one dict per log event
        :param proc: the running process
        :param stream: name of the pipe attribute to read, "stdout" or "stderr"
        :param wait: when true, wait for the process to exit before closing the pipe
        """
        pipe_wrapper = TextIOWrapper(getattr(proc, stream), encoding="utf8", errors="ignore")
        try:
            for line in pipe_wrapper:
                if line.strip() == "":
                    continue
                if line.startswith("$PROGRESS"):
                    try:
                        progress = json.loads(line[len("$PROGRESS "):])
                    except json.JSONDecodeError:
                        # Malformed progress payload: fall through and log the raw line.
                        pass
                    else:
                        logs_cb({"progress": progress})
                        continue
                logs_cb({stream: line})
        finally:
            if wait:
                proc.wait()
            pipe_wrapper.close()
|
||||
|
||||
|
||||
class TaskQueue:
|
||||
def __init__(self, sist2: Sist2, db: PersistentState, notifications: Notifications):
|
||||
self._lock = Lock()
|
||||
|
||||
126
sist2-admin/sist2_admin/script.py
Normal file
126
sist2-admin/sist2_admin/script.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import os
|
||||
import shutil
|
||||
import stat
|
||||
import subprocess
|
||||
from enum import Enum
|
||||
|
||||
from git import Repo
|
||||
from pydantic import BaseModel
|
||||
|
||||
from config import SCRIPT_FOLDER
|
||||
|
||||
|
||||
class ScriptType(Enum):
    """Kinds of user script; ``UserScript.setup`` branches on this value."""

    # No type-specific provisioning in UserScript.setup; only run.sh is marked executable.
    LOCAL = "local"
    # UserScript.setup writes run.sh and run.py from the script's inline source.
    SIMPLE = "simple"
    # UserScript.setup clones or pulls the script's git repository.
    GIT = "git"
|
||||
|
||||
|
||||
def set_executable(file):
    """Add the owner-execute bit to ``file`` while keeping its other mode bits."""
    current_mode = os.stat(file).st_mode
    os.chmod(file, current_mode | stat.S_IEXEC)
|
||||
|
||||
|
||||
def _initialize_git_repository(url, path, log_cb, force_clone):
    """Clone or update the repository at ``path`` and run its optional ``setup.sh``.

    The repository is cloned when ``force_clone`` is set or when ``path`` has
    no ``.git`` entry; otherwise ``origin`` is pulled. With ``force_clone``
    the existing directory is removed first.

    :param url: repository URL to clone from
    :param path: local checkout directory
    :param log_cb: callable receiving one dict per log event
    :param force_clone: discard any existing checkout and clone again
    :raises Exception: when ``setup.sh`` exits with a non-zero return code
    """
    log_cb({"sist2-admin": f"Cloning {url}"})

    if force_clone or not os.path.exists(os.path.join(path, ".git")):
        if force_clone:
            shutil.rmtree(path, ignore_errors=True)
        Repo.clone_from(url, path)
    else:
        repo = Repo(path)
        repo.remote("origin").pull()

    setup_script = os.path.join(path, "setup.sh")
    # Test for the file itself: the joined path is a non-empty string and
    # therefore always truthy, which made repositories without a setup.sh
    # fail when the missing file was chmod-ed.
    if os.path.isfile(setup_script):
        log_cb({"sist2-admin": f"Executing setup script {setup_script}"})

        set_executable(setup_script)
        result = subprocess.run([setup_script], cwd=path, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        for line in result.stdout.split(b"\n"):
            if line:
                # Undecodable bytes in the script output must not abort the setup.
                log_cb({"stdout": line.decode(errors="replace")})

        log_cb({"stdout": f"Executed setup script {setup_script}, return code = {result.returncode}"})

        if result.returncode != 0:
            raise Exception("Error when running setup script!")

    log_cb({"sist2-admin": f"Initialized git repository in {path}"})
|
||||
|
||||
|
||||
class UserScript(BaseModel):
    """A user script definition plus helpers to provision it under ``SCRIPT_FOLDER``."""

    # Also used as the directory name below SCRIPT_FOLDER.
    # NOTE(review): the name is joined into a filesystem path as-is; confirm
    # that callers reject values such as "..".
    name: str
    type: ScriptType
    # Repository URL, used when type is GIT
    git_repository: str = None
    # When set, the next GIT setup discards the checkout and clones again
    force_clone: bool = False
    # Inline source written to run.py when type is SIMPLE
    script: str = None
    # Extra command line arguments, stored as a single string
    extra_args: str = ""

    def script_dir(self):
        """Return the directory that holds this script's files."""
        return os.path.join(SCRIPT_FOLDER, self.name)

    def setup(self, log_cb):
        """Provision the script directory according to ``type`` and mark ``run.sh`` executable.

        :param log_cb: callable receiving one dict per log event (used by the GIT setup)
        """
        os.makedirs(self.script_dir(), exist_ok=True)

        if self.type == ScriptType.GIT:
            _initialize_git_repository(self.git_repository, self.script_dir(), log_cb, self.force_clone)
            self.force_clone = False
        elif self.type == ScriptType.SIMPLE:
            self._setup_simple()

        set_executable(self.get_executable())

    def _setup_simple(self):
        """Write the ``run.sh`` launcher and the ``run.py`` source for a SIMPLE script."""
        with open(self.get_executable(), "w") as f:
            f.write(
                "#!/bin/bash\n"
                "python run.py \"$@\""
            )

        with open(os.path.join(self.script_dir(), "run.py"), "w") as f:
            # ``script`` defaults to None; write an empty file rather than
            # failing with a TypeError in that case.
            f.write(self.script or "")

    def get_executable(self):
        """Return the path of the ``run.sh`` entry point inside the script directory."""
        return os.path.join(self.script_dir(), "run.sh")

    def delete_dir(self):
        """Remove the script directory, ignoring errors."""
        shutil.rmtree(self.script_dir(), ignore_errors=True)
|
||||
|
||||
|
||||
def _clip_template(name):
    """Build the git-backed CLIP script definition for ``name``."""
    return UserScript(
        name=name,
        type=ScriptType.GIT,
        git_repository="https://github.com/simon987/sist2-script-clip",
        extra_args="--num-tags=1 --tags-file=general.txt --color=#dcd7ff"
    )


def _whisper_template(name):
    """Build the git-backed Whisper script definition for ``name``."""
    return UserScript(
        name=name,
        type=ScriptType.GIT,
        git_repository="https://github.com/simon987/sist2-script-whisper",
        extra_args="--model=base --num-threads=4 --color=#51da4c --tag"
    )


# NOTE(review): the loop body below is indented by one space, as it appears in
# this view of the file; confirm against the original source.
_HAMBURGER_SOURCE = (
    'from sist2 import Sist2Index\n'
    'import sys\n'
    '\n'
    'index = Sist2Index(sys.argv[1])\n'
    'for doc in index.document_iter():\n'
    ' doc.json_data["tag"] = ["hamburger.#00FF00"]\n'
    ' index.update_document(doc)\n'
    '\n'
    'index.sync_tag_table()\n'
    'index.commit()\n'
    '\n'
    'print("Done!")\n'
)


def _hamburger_template(name):
    """Build the inline example script definition for ``name``."""
    return UserScript(
        name=name,
        type=ScriptType.SIMPLE,
        script=_HAMBURGER_SOURCE
    )


def _blank_template(name):
    """Build an inline script definition with empty source for ``name``."""
    return UserScript(
        name=name,
        type=ScriptType.SIMPLE,
        script=""
    )


# Maps the template label to a factory that takes the new script's name.
SCRIPT_TEMPLATES = {
    "CLIP - Generate embeddings to predict the most relevant image based on the text prompt": _clip_template,
    "Whisper - Speech to text with OpenAI Whisper": _whisper_template,
    "Hamburger - Simple script example": _hamburger_template,
    "(Blank)": _blank_template,
}
|
||||
@@ -41,8 +41,6 @@ class Sist2SearchBackend(BaseModel):
|
||||
es_insecure_ssl: bool = False
|
||||
es_index: str = "sist2"
|
||||
threads: int = 1
|
||||
script: str = ""
|
||||
script_file: str = None
|
||||
batch_size: int = 70
|
||||
|
||||
@staticmethod
|
||||
@@ -74,8 +72,6 @@ class IndexOptions(BaseModel):
|
||||
f"--es-index={search_backend.es_index}",
|
||||
f"--batch-size={search_backend.batch_size}"]
|
||||
|
||||
if search_backend.script_file:
|
||||
args.append(f"--script-file={search_backend.script_file}")
|
||||
if search_backend.es_insecure_ssl:
|
||||
args.append(f"--es-insecure-ssl")
|
||||
if self.incremental_index:
|
||||
@@ -249,13 +245,6 @@ class Sist2:
|
||||
|
||||
def index(self, options: IndexOptions, search_backend: Sist2SearchBackend, logs_cb):
|
||||
|
||||
if search_backend.script and search_backend.backend_type == SearchBackendType("elasticsearch"):
|
||||
with NamedTemporaryFile("w", prefix="sist2-admin", suffix=".painless", delete=False) as f:
|
||||
f.write(search_backend.script)
|
||||
search_backend.script_file = f.name
|
||||
else:
|
||||
search_backend.script_file = None
|
||||
|
||||
args = [
|
||||
self.bin_path,
|
||||
*options.args(search_backend),
|
||||
|
||||
@@ -14,7 +14,7 @@ RUNNING_FRONTENDS: Dict[str, int] = {}
|
||||
|
||||
TESSERACT_LANGS = get_tesseract_langs()
|
||||
|
||||
DB_SCHEMA_VERSION = "4"
|
||||
DB_SCHEMA_VERSION = "5"
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
Reference in New Issue
Block a user