mirror of https://github.com/simon987/sist2.git synced 2025-04-08 13:06:47 +00:00

Rework user scripts, update DB schema to support embeddings

simon987 2023-08-19 15:46:19 -04:00
parent 27188b6fa0
commit 857f3315c2
62 changed files with 1842 additions and 1250 deletions

@ -146,17 +146,17 @@ sist2 v3.0.7+ supports SQLite search backend. The SQLite search backend has
fewer features and generally comparable query performance for medium-size fewer features and generally comparable query performance for medium-size
indices, but it uses much less memory and is easier to set up. indices, but it uses much less memory and is easier to set up.
| | SQLite | Elasticsearch | | | SQLite | Elasticsearch |
|----------------------------------------------|:----------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------:| |----------------------------------------------|:---------------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------:|
| Requires separate search engine installation | | ✓ | | Requires separate search engine installation | | ✓ |
| Memory footprint | ~20MB | >500MB | | Memory footprint | ~20MB | >500MB |
| Query syntax | [fts5](https://www.sqlite.org/fts5.html) | [query_string](https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-query-string-query.html#query-string-syntax) | | Query syntax | [fts5](https://www.sqlite.org/fts5.html) | [query_string](https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-query-string-query.html#query-string-syntax) |
| Fuzzy search | | ✓ | | Fuzzy search | | ✓ |
| Media Types tree real-time updating | | ✓ | | Media Types tree real-time updating | | ✓ |
| Search in file `path` | | ✓ | | Search in file `path` | [WIP](https://github.com/simon987/sist2/issues/402) | ✓ |
| Manual tagging | ✓ | ✓ | | Manual tagging | | ✓ |
| User scripts | | ✓ | | User scripts | | ✓ |
| Media Type breakdown for search results | | ✓ | | Media Type breakdown for search results | | ✓ |
### NER ### NER
@ -206,7 +206,7 @@ docker run --rm --entrypoint cat my-sist2-image /root/sist2 > sist2-x64-linux
3. Install vcpkg dependencies 3. Install vcpkg dependencies
```bash ```bash
vcpkg install curl[core,openssl] sqlite3[core,fts5] cpp-jwt pcre cjson brotli libarchive[core,bzip2,libxml2,lz4,lzma,lzo] pthread tesseract libxml2 libmupdf gtest mongoose libmagic libraw gumbo ffmpeg[core,avcodec,avformat,swscale,swresample,webp,opus,mp3lame,vpx,ffprobe,zlib] vcpkg install openblas curl[core,openssl] sqlite3[core,fts5] cpp-jwt pcre cjson brotli libarchive[core,bzip2,libxml2,lz4,lzma,lzo] pthread tesseract libxml2 libmupdf gtest mongoose libmagic libraw gumbo ffmpeg[core,avcodec,avformat,swscale,swresample,webp,opus,mp3lame,vpx,ffprobe,zlib]
``` ```
4. Build 4. Build

@ -5,7 +5,6 @@ Usage: sist2 scan [OPTION]... PATH
or: sist2 index [OPTION]... INDEX or: sist2 index [OPTION]... INDEX
or: sist2 sqlite-index [OPTION]... INDEX or: sist2 sqlite-index [OPTION]... INDEX
or: sist2 web [OPTION]... INDEX... or: sist2 web [OPTION]... INDEX...
or: sist2 exec-script [OPTION]... INDEX
Lightning-fast file system indexer and search tool. Lightning-fast file system indexer and search tool.
@ -74,13 +73,6 @@ Web options
--dev Serve html & js files from disk (for development) --dev Serve html & js files from disk (for development)
--lang=<str> Default UI language. Can be changed by the user --lang=<str> Default UI language. Can be changed by the user
Exec-script options
--es-url=<str> Elasticsearch url. DEFAULT: http://localhost:9200
--es-insecure-ssl Do not verify SSL connections to Elasticsearch.
--es-index=<str> Elasticsearch index name. DEFAULT: sist2
--script-file=<str> Path to user script.
--async-script Execute user script asynchronously.
Made by simon987 <me@simon987.net>. Released under GPL-3.0 Made by simon987 <me@simon987.net>. Released under GPL-3.0
``` ```
@ -183,11 +175,6 @@ Using a version >=7.14.0 is recommended to enable the following features:
When using a legacy version of ES, a notice will be displayed next to the sist2 version in the web UI. When using a legacy version of ES, a notice will be displayed next to the sist2 version in the web UI.
If you don't care about the features above, you can ignore it or disable it in the configuration page. If you don't care about the features above, you can ignore it or disable it in the configuration page.
## exec-script
The `exec-script` command is used to execute a user script for an index that has already been imported to Elasticsearch with the `index` command. Note that the documents will not be reset to their default state before each execution as the `index` command does: if you make undesired changes to the documents by accident, you will need to run `index` again to revert to the original state.
# Tagging # Tagging
### Manual tagging ### Manual tagging

@ -1,18 +1,47 @@
## User scripts ## User scripts
*This document is under construction, more in-depth guide coming soon* User scripts are used to augment your sist2 index with additional metadata, neural network embeddings, tags etc.
Since version 3.2.0, user scripts are written in Python and are run against the sist2 index file. User scripts do not
need a connection to the search backend.
You can create a user script based on a template from the sist2-admin interface:
![sist2-admin-scripts](sist2-admin-scripts.png)
User scripts leverage the [sist2-python](https://github.com/simon987/sist2-python) library to interface with the
index file*. You can find sist2-python documentation and examples
here: [sist2-python.readthedocs.io](https://sist2-python.readthedocs.io/).
If you are not using the sist2-admin interface, you can run user scripts manually from the command line:
```
pip install git+https://github.com/simon987/sist2-python.git
python my_script.py /path/to/my_index.sist2
```
\* It is possible to manually update the index using raw SQL queries, but the database schema is not stable and
can change at any time; it is recommended to use the more stable sist2-python wrapper instead.
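As a rough illustration of the footnote above, the sketch below reads tags with raw SQL through Python's built-in `sqlite3` module. It assumes a simplified version of the `document` and `tag` tables from the schema in this commit (which, as noted, may change at any time), and it builds a throwaway in-memory database rather than opening a real index. The helper names and sample values are hypothetical and are not part of sist2 or sist2-python.

```python
import sqlite3

# Assumption: simplified copy of the `document` and `tag` tables from the
# schema in this commit. The real schema is not stable and may differ.
SCHEMA = """
CREATE TABLE document (
    id TEXT PRIMARY KEY NOT NULL CHECK (length(id) = 32),
    mtime INTEGER NOT NULL,
    size INTEGER NOT NULL,
    json_data TEXT NOT NULL
);
CREATE TABLE tag (
    id TEXT NOT NULL,
    tag TEXT NOT NULL,
    PRIMARY KEY (id, tag)
);
"""


def build_demo_db():
    """Create an in-memory database with one made-up document and two tags."""
    db = sqlite3.connect(":memory:")
    db.executescript(SCHEMA)
    doc_id = "0" * 32  # placeholder id with the required length of 32
    db.execute(
        "INSERT INTO document (id, mtime, size, json_data) VALUES (?, ?, ?, ?)",
        (doc_id, 1000000, 10000, '{"name": "test"}'),
    )
    db.executemany(
        "INSERT INTO tag (id, tag) VALUES (?, ?)",
        [(doc_id, "genre.rock"), (doc_id, "year.2023")],
    )
    db.commit()
    return db


def tags_by_document(db):
    """Return {document id: list of tags} using a plain SQL join."""
    rows = db.execute(
        "SELECT d.id, t.tag FROM document d "
        "JOIN tag t ON t.id = d.id ORDER BY d.id, t.tag"
    )
    result = {}
    for doc_id, tag in rows:
        result.setdefault(doc_id, []).append(tag)
    return result


if __name__ == "__main__":
    print(tags_by_document(build_demo_db()))
```

For anything beyond quick inspection, the sist2-python wrapper remains the recommended route, since it shields scripts from schema changes.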
<hr>
<details>
<summary>Legacy user scripts (sist2 version < 3.2.0)</summary>
During the `index` step, you can use the `--script-file <script>` option to During the `index` step, you can use the `--script-file <script>` option to
modify documents or add user tags. This option is mainly used to modify documents or add user tags. This option is mainly used to
implement automatic tagging based on file attributes. implement automatic tagging based on file attributes.
The scripting language used The scripting language used
([Painless Scripting Language](https://www.elastic.co/guide/en/elasticsearch/painless/7.4/index.html)) ([Painless Scripting Language](https://www.elastic.co/guide/en/elasticsearch/painless/7.4/index.html))
is very similar to Java, but you should be able to create user scripts is very similar to Java, but you should be able to create user scripts
without programming experience at all if you're somewhat familiar with without programming experience at all if you're somewhat familiar with
regex. regex.
This is the base structure of the documents we're working with: This is the base structure of the documents we're working with:
```json ```json
{ {
"_id": "e171405c-fdb5-4feb-bb32-82637bc32084", "_id": "e171405c-fdb5-4feb-bb32-82637bc32084",
@ -34,7 +63,8 @@ This is the base structure of the documents we're working with:
**Example script** **Example script**
This script checks if the `genre` attribute exists, if it does This script checks if the `genre` attribute exists, if it does
it adds the `genre.<genre>` tag. it adds the `genre.<genre>` tag.
```Java ```Java
ArrayList tags = ctx._source.tag = new ArrayList(); ArrayList tags = ctx._source.tag = new ArrayList();
@ -47,21 +77,23 @@ You can use `.` to create a hierarchical tag tree:
![scripting/genre_example](genre_example.png) ![scripting/genre_example](genre_example.png)
To use regular expressions, you need to add this line in `/etc/elasticsearch/elasticsearch.yml` To use regular expressions, you need to add this line in `/etc/elasticsearch/elasticsearch.yml`
```yaml ```yaml
script.painless.regex.enabled: true script.painless.regex.enabled: true
``` ```
Or, if you're using docker add `-e "script.painless.regex.enabled=true"` Or, if you're using docker add `-e "script.painless.regex.enabled=true"`
**Tag color** **Tag color**
You can specify the color for an individual tag by appending an You can specify the color for an individual tag by appending an
hexadecimal color code (`#RRGGBBAA`) to the tag name. hexadecimal color code (`#RRGGBBAA`) to the tag name.
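
The tag naming rules described here (`.` for hierarchy, an optional `#RRGGBBAA` suffix for color) can be sketched as a small string helper. `make_tag` is a hypothetical function written for illustration only; it is not part of sist2 or sist2-python.

```python
import re

# Matches a color code of the form #RRGGBBAA (8 hex digits)
_COLOR_RE = re.compile(r"#[0-9a-fA-F]{8}")


def make_tag(*parts, color=None):
    """Join parts with '.' and optionally append a #RRGGBBAA color code."""
    tag = ".".join(parts)
    if color is not None:
        if not _COLOR_RE.fullmatch(color):
            raise ValueError("color must look like #RRGGBBAA")
        tag += color
    return tag


if __name__ == "__main__":
    print(make_tag("genre", "rock"))
    print(make_tag("genre", "rock", color="#FF0000FF"))
```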
### Examples ### Examples
If `(20XX)` is in the file name, add the `year.<year>` tag: If `(20XX)` is in the file name, add the `year.<year>` tag:
```Java ```Java
ArrayList tags = ctx._source.tag = new ArrayList(); ArrayList tags = ctx._source.tag = new ArrayList();
@ -72,6 +104,7 @@ if (m.find()) {
``` ```
Use default *Calibre* folder structure to infer author. Use default *Calibre* folder structure to infer author.
```Java ```Java
ArrayList tags = ctx._source.tag = new ArrayList(); ArrayList tags = ctx._source.tag = new ArrayList();
@ -84,8 +117,9 @@ if (ctx._source.name.contains("-") && ctx._source.extension == "pdf") {
} }
``` ```
If the file matches a specific pattern `AAAA-000 fName1 lName1, <fName2 lName2>...`, add the `actress.<actress>` and If the file matches a specific pattern `AAAA-000 fName1 lName1, <fName2 lName2>...`, add the `actress.<actress>` and
`studio.<studio>` tag: `studio.<studio>` tag:
```Java ```Java
ArrayList tags = ctx._source.tag = new ArrayList(); ArrayList tags = ctx._source.tag = new ArrayList();
@ -102,16 +136,18 @@ if (m.find()) {
``` ```
Set the name of the last folder (`/path/to/<studio>/file.mp4`) to `studio.<studio>` tag Set the name of the last folder (`/path/to/<studio>/file.mp4`) to `studio.<studio>` tag
```Java ```Java
ArrayList tags = ctx._source.tag = new ArrayList(); ArrayList tags = ctx._source.tag = new ArrayList();
if (ctx._source.path != "") { if (ctx._source.path != "") {
String[] names = ctx._source.path.splitOnToken('/'); String[] names = ctx._source.path.splitOnToken('/');
tags.add("studio." + names[names.length-1]); tags.add("studio." + names[names.length-1]);
} }
``` ```
Parse `EXIF:F Number` tag Parse `EXIF:F Number` tag
```Java ```Java
if (ctx._source?.exif_fnumber != null) { if (ctx._source?.exif_fnumber != null) {
String[] values = ctx._source.exif_fnumber.splitOnToken(' '); String[] values = ctx._source.exif_fnumber.splitOnToken(' ');
@ -124,6 +160,7 @@ if (ctx._source?.exif_fnumber != null) {
``` ```
Display year and months from `EXIF:DateTime` tag Display year and months from `EXIF:DateTime` tag
```Java ```Java
if (ctx._source?.exif_datetime != null) { if (ctx._source?.exif_datetime != null) {
SimpleDateFormat parser = new SimpleDateFormat("yyyy:MM:dd HH:mm:ss"); SimpleDateFormat parser = new SimpleDateFormat("yyyy:MM:dd HH:mm:ss");
@ -140,3 +177,6 @@ if (ctx._source?.exif_datetime != null) {
} }
``` ```
</details>

@ -202,6 +202,46 @@
}, },
"modified_by": { "modified_by": {
"type": "text" "type": "text"
},
"emb.384.*": {
"type": "dense_vector",
"dims": 384
},
"emb.idx_384.*": {
"type": "dense_vector",
"dims": 384,
"index": true,
"similarity": "cosine"
},
"emb.idx_512.clip": {
"type": "dense_vector",
"dims": 512,
"index": true,
"similarity": "cosine"
},
"emb.512.*": {
"type": "dense_vector",
"dims": 512
},
"emb.idx_768.*": {
"type": "dense_vector",
"dims": 768,
"index": true,
"similarity": "cosine"
},
"emb.768.*": {
"type": "dense_vector",
"dims": 768
},
"emb.idx_1024.*": {
"type": "dense_vector",
"dims": 1024,
"index": true,
"similarity": "cosine"
},
"emb.1024.*": {
"type": "dense_vector",
"dims": 1024
} }
} }
} }
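
The indexed `emb.idx_*` fields above declare `"similarity": "cosine"`. For reference, this is what cosine similarity computes for two vectors of equal length, as a pure-Python sketch. It only illustrates the metric and makes no claim about how Elasticsearch implements it; the function name is chosen for illustration.

```python
import math


def cosine_similarity(a, b):
    """Return dot(a, b) / (|a| * |b|) for two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same number of dimensions")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)


if __name__ == "__main__":
    # Two toy 384-dimensional vectors, matching the smallest `dims` above
    u = [1.0] * 384
    v = [1.0] * 192 + [0.0] * 192
    print(cosine_similarity(u, v))
```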

@ -0,0 +1,131 @@
import sqlite3
import orjson as json
import os
import string
from hashlib import md5
import random
from tqdm import tqdm
schema = """
CREATE TABLE thumbnail (
id TEXT NOT NULL CHECK (
length(id) = 32
),
num INTEGER NOT NULL,
data BLOB NOT NULL,
PRIMARY KEY(id, num)
) WITHOUT ROWID;
CREATE TABLE version (
id INTEGER PRIMARY KEY AUTOINCREMENT,
date TEXT NOT NULL DEFAULT (CURRENT_TIMESTAMP)
);
CREATE TABLE document (
id TEXT PRIMARY KEY NOT NULL CHECK (
length(id) = 32
),
marked INTEGER NOT NULL DEFAULT (1),
version INTEGER NOT NULL REFERENCES version(id),
mtime INTEGER NOT NULL,
size INTEGER NOT NULL,
json_data TEXT NOT NULL CHECK (
json_valid(json_data)
)
);
CREATE TABLE delete_list (
id TEXT PRIMARY KEY CHECK (
length(id) = 32
)
) WITHOUT ROWID;
CREATE TABLE tag (
id TEXT NOT NULL,
tag TEXT NOT NULL,
PRIMARY KEY (id, tag)
);
CREATE TABLE document_sidecar (
id TEXT PRIMARY KEY NOT NULL, json_data TEXT NOT NULL
) WITHOUT ROWID;
CREATE TABLE descriptor (
id TEXT NOT NULL, version_major INTEGER NOT NULL,
version_minor INTEGER NOT NULL, version_patch INTEGER NOT NULL,
root TEXT NOT NULL, name TEXT NOT NULL,
rewrite_url TEXT, timestamp INTEGER NOT NULL
);
CREATE TABLE stats_treemap (
path TEXT NOT NULL, size INTEGER NOT NULL
);
CREATE TABLE stats_size_agg (
bucket INTEGER NOT NULL, count INTEGER NOT NULL
);
CREATE TABLE stats_date_agg (
bucket INTEGER NOT NULL, count INTEGER NOT NULL
);
CREATE TABLE stats_mime_agg (
mime TEXT NOT NULL, size INTEGER NOT NULL,
count INTEGER NOT NULL
);
CREATE TABLE embedding (
id TEXT REFERENCES document(id),
model_id INTEGER NOT NULL references model(id),
start INTEGER NOT NULL,
end INTEGER,
embedding BLOB NOT NULL,
PRIMARY KEY (id, model_id, start)
);
CREATE TABLE model (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL UNIQUE CHECK (
length(name) < 16
),
url TEXT,
path TEXT NOT NULL UNIQUE,
size INTEGER NOT NULL,
type TEXT NOT NULL CHECK (
type IN ('flat', 'nested')
)
);
"""
content = "".join(random.choices(string.ascii_letters, k=500))
def gen_document():
    # Returns [id, json_data] for one synthetic document
    return [
        md5(random.randbytes(8)).hexdigest(),
        # orjson.dumps returns bytes; decode so the value is stored as TEXT
        json.dumps({
            "content": content,
            "mime": "image/jpeg",
            "extension": "jpeg",
            "name": "test",
            "path": "",
        }).decode()
    ]
if __name__ == "__main__":
    DB_NAME = "big_index.sist2"
    SIZE = 30_000_000
    # Start from a clean file; os.remove raises if the file does not exist
    if os.path.exists(DB_NAME):
        os.remove(DB_NAME)
    db = sqlite3.connect(DB_NAME)
    db.executescript(schema)
    db.executescript("""
    PRAGMA journal_mode = OFF;
    PRAGMA synchronous = 0;
    """)
    for _ in tqdm(range(SIZE), total=SIZE):
        db.execute(
            "INSERT INTO document (id, version, mtime, size, json_data) VALUES (?, 1, 1000000, 10000, ?)",
            gen_document()
        )
    # 1. Enable rowid from document
    # 2. CREATE TABLE marked (
    #        id INTEGER PRIMARY KEY,
    #        marked int
    #    );
    # 3. Set FK for document_sidecar, embedding, tag, thumbnail
    # 4. Toggle FK if debug
    db.commit()

@ -1,3 +1,3 @@
docker run --rm -it --name "sist2-dev-es"\ docker run --rm -it --name "sist2-dev-es3"\
-p 9200:9200 -e "discovery.type=single-node" \ -p 9200:9200 -e "discovery.type=single-node" \
-e "ES_JAVA_OPTS=-Xms8g -Xmx8g" elasticsearch:7.17.9 -e "ES_JAVA_OPTS=-Xms8g -Xmx8g" elasticsearch:7.17.9

@ -1,3 +1,3 @@
docker run --rm -it --name "sist2-dev-es"\ docker run --rm -it --name "sist2-dev-es3"\
-p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" \ -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" \
-e "ES_JAVA_OPTS=-Xms8g -Xmx8g" elasticsearch:8.7.0 -e "ES_JAVA_OPTS=-Xms8g -Xmx8g" elasticsearch:8.7.0

@ -7,7 +7,7 @@
to the <a href="https://github.com/simon987/sist2/issues/new/choose" target="_blank">issue tracker on to the <a href="https://github.com/simon987/sist2/issues/new/choose" target="_blank">issue tracker on
Github</a>. Thank you! Github</a>. Thank you!
</b-alert> </b-alert>
<router-view/> <router-view v-if="$store.state.sist2AdminInfo"/>
</b-container> </b-container>
</div> </div>
</template> </template>
@ -71,10 +71,12 @@ html, body {
.info-icon { .info-icon {
width: 1rem; width: 1rem;
min-width: 1rem;
margin-right: 0.2rem; margin-right: 0.2rem;
cursor: pointer; cursor: pointer;
line-height: 1rem; line-height: 1rem;
height: 1rem; height: 1rem;
min-height: 1rem;
background-image: url(); background-image: url();
filter: brightness(45%); filter: brightness(45%);
display: block; display: block;

@ -139,6 +139,38 @@ class Sist2AdminApi {
deleteTaskLogs(taskId) { deleteTaskLogs(taskId) {
return axios.post(`${this.baseUrl}/api/task/${taskId}/delete_logs`); return axios.post(`${this.baseUrl}/api/task/${taskId}/delete_logs`);
} }
getUserScripts() {
return axios.get(`${this.baseUrl}/api/user_script`);
}
getUserScript(name) {
return axios.get(`${this.baseUrl}/api/user_script/${name}`);
}
createUserScript(name, template) {
return axios.post(`${this.baseUrl}/api/user_script/${name}`, null, {
params: {
template: template
}
});
}
updateUserScript(name, data) {
return axios.put(`${this.baseUrl}/api/user_script/${name}`, data);
}
deleteUserScript(name) {
return axios.delete(`${this.baseUrl}/api/user_script/${name}`);
}
testUserScript(name, job) {
return axios.get(`${this.baseUrl}/api/user_script/${name}/run`, {
params: {
job: job
}
});
}
} }
export default new Sist2AdminApi() export default new Sist2AdminApi()
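
The user script endpoints added to `Sist2AdminApi` above can also be summarized as a small request builder. This is a sketch in Python using only the standard library; the paths and HTTP methods are taken from the client code in this diff, while the function name, the `action` keys and the base URL in the usage line are assumptions made for illustration. It only constructs `Request` objects and does not send them.

```python
from urllib.parse import quote, urlencode
from urllib.request import Request

# Method and path for each call, mirroring the axios calls in this diff
ROUTES = {
    "list": ("GET", "/api/user_script"),
    "get": ("GET", "/api/user_script/{name}"),
    "create": ("POST", "/api/user_script/{name}"),    # query param: template
    "update": ("PUT", "/api/user_script/{name}"),     # body not handled here
    "delete": ("DELETE", "/api/user_script/{name}"),
    "run": ("GET", "/api/user_script/{name}/run"),    # query param: job
}


def user_script_request(base_url, action, name=None, **params):
    """Build (but do not send) a request for one user script endpoint."""
    method, template = ROUTES[action]
    path = template.format(name=quote(name or "", safe=""))
    url = base_url.rstrip("/") + path
    if params:
        url += "?" + urlencode(params)
    return Request(url, method=method)


if __name__ == "__main__":
    # Hypothetical base URL; adjust to wherever sist2-admin is listening
    req = user_script_request("http://localhost:8080", "run", "my script", job="job1")
    print(req.get_method(), req.full_url)
```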

@ -0,0 +1,34 @@
<template>
<b-progress v-if="loading" striped animated value="100"></b-progress>
<span v-else-if="jobs.length === 0"></span>
<b-form-select v-else :options="jobs" text-field="name" value-field="name"
@change="$emit('change', $event)" :value="$t('selectJob')"></b-form-select>
</template>
<script>
import Sist2AdminApi from "@/Sist2AdminApi";
export default {
name: "JobSelect",
mounted() {
Sist2AdminApi.getJobs().then(resp => {
this._jobs = resp.data;
this.loading = false;
});
},
computed: {
jobs() {
return [
{name: this.$t("selectJob"), disabled: true},
...this._jobs.filter(job => job.index_path)
]
}
},
data() {
return {
loading: true,
_jobs: null
}
}
}
</script>

@ -0,0 +1,18 @@
<template>
<b-list-group-item action :to="`/userScript/${script.name}`">
<div class="d-flex w-100 justify-content-between">
<h5 class="mb-1">
{{ script.name }}
</h5>
</div>
</b-list-group-item>
</template>
<script>
export default {
name: "UserScriptListItem",
props: ["script"],
}
</script>

@ -0,0 +1,88 @@
<template>
<b-progress v-if="loading" striped animated value="100"></b-progress>
<b-row v-else>
<b-col cols="6">
<h5>Selected scripts</h5>
<b-list-group>
<b-list-group-item v-for="script in selectedScripts" :key="script"
button
@click="onRemoveScript(script)"
class="d-flex justify-content-between align-items-center">
{{ script }}
<b-button-group>
<b-button variant="light" @click.stop="moveUpScript(script)"></b-button>
<b-button variant="light" @click.stop="moveDownScript(script)"></b-button>
</b-button-group>
</b-list-group-item>
</b-list-group>
</b-col>
<b-col cols="6">
<h5>Available scripts</h5>
<b-list-group>
<b-list-group-item v-for="script in availableScripts" :key="script" button
@click="onSelectScript(script)">
{{ script }}
</b-list-group-item>
</b-list-group>
</b-col>
</b-row>
<!-- <b-checkbox-group v-else :options="scripts" stacked :checked="selectedScripts"-->
<!-- @input="$emit('change', $event)"></b-checkbox-group>-->
</template>
<script>
import Sist2AdminApi from "@/Sist2AdminApi";
export default {
name: "UserScriptPicker",
props: ["selectedScripts"],
data() {
return {
loading: true,
scripts: []
}
},
computed: {
availableScripts() {
return this.scripts.filter(script => !this.selectedScripts.includes(script))
}
},
mounted() {
Sist2AdminApi.getUserScripts().then(resp => {
this.scripts = resp.data.map(script => script.name);
this.loading = false;
});
},
methods: {
onSelectScript(name) {
this.selectedScripts.push(name);
this.$emit("change", this.selectedScripts)
},
onRemoveScript(name) {
this.selectedScripts.splice(this.selectedScripts.indexOf(name), 1);
this.$emit("change", this.selectedScripts);
},
moveUpScript(name) {
const index = this.selectedScripts.indexOf(name);
if (index > 0) {
this.selectedScripts.splice(index, 1);
this.selectedScripts.splice(index - 1, 0, name);
}
this.$emit("change", this.selectedScripts);
},
moveDownScript(name) {
const index = this.selectedScripts.indexOf(name);
if (index < this.selectedScripts.length - 1) {
this.selectedScripts.splice(index, 1);
this.selectedScripts.splice(index + 1, 0, name);
}
this.$emit("change", this.selectedScripts);
}
}
}
</script>
<style scoped>
</style>

@ -54,8 +54,18 @@ export default {
frontendTab: "Frontend", frontendTab: "Frontend",
backendTab: "Backend", backendTab: "Backend",
scripts: "User Scripts",
script: "User Script",
testScript: "Test/debug User Script",
newScriptName: "New script name",
scriptType: "Script type",
scriptCode: "Script code (Python)",
scriptOptions: "User scripts",
gitRepository: "Git repository URL",
extraArgs: "Extra command line arguments",
selectJobs: "Available jobs", selectJobs: "Available jobs",
selectJob: "Select a job",
webOptions: { webOptions: {
title: "Web options", title: "Web options",
lang: "UI Language", lang: "UI Language",

@ -6,12 +6,18 @@ import Tasks from "@/views/Tasks";
import Frontend from "@/views/Frontend"; import Frontend from "@/views/Frontend";
import Tail from "@/views/Tail"; import Tail from "@/views/Tail";
import SearchBackend from "@/views/SearchBackend.vue"; import SearchBackend from "@/views/SearchBackend.vue";
import UserScript from "@/views/UserScript.vue";
Vue.use(VueRouter); Vue.use(VueRouter);
const routes = [ const routes = [
{ {
path: "/", path: "/task",
name: "Tasks",
component: Tasks
},
{
path: "/:tab?",
name: "Home", name: "Home",
component: Home component: Home
}, },
@ -20,11 +26,6 @@ const routes = [
name: "Job", name: "Job",
component: Job component: Job
}, },
{
path: "/task/",
name: "Tasks",
component: Tasks
},
{ {
path: "/frontend/:name", path: "/frontend/:name",
name: "Frontend", name: "Frontend",
@ -35,6 +36,11 @@ const routes = [
name: "SearchBackend", name: "SearchBackend",
component: SearchBackend component: SearchBackend
}, },
{
path: "/userScript/:name",
name: "UserScript",
component: UserScript
},
{ {
path: "/log/:taskId", path: "/log/:taskId",
name: "Tail", name: "Tail",

@ -1,6 +1,6 @@
<template> <template>
<div> <div>
<b-tabs content-class="mt-3"> <b-tabs content-class="mt-3" v-model="tab" @input="onTabChange($event)">
<b-tab :title="$t('backendTab')"> <b-tab :title="$t('backendTab')">
<b-card> <b-card>
@ -25,7 +25,6 @@
<SearchBackendListItem v-for="backend in backends" <SearchBackendListItem v-for="backend in backends"
:key="backend.name" :backend="backend"></SearchBackendListItem> :key="backend.name" :backend="backend"></SearchBackendListItem>
</b-list-group> </b-list-group>
</b-card> </b-card>
<br/> <br/>
@ -36,12 +35,12 @@
<b-col> <b-col>
<b-input id="new-job" v-model="newJobName" :placeholder="$t('newJobName')"></b-input> <b-input id="new-job" v-model="newJobName" :placeholder="$t('newJobName')"></b-input>
<b-popover <b-popover
:show.sync="showHelp" :show.sync="showHelp"
target="new-job" target="new-job"
placement="top" placement="top"
triggers="manual" triggers="manual"
variant="primary" variant="primary"
:content="$t('newJobHelp')" :content="$t('newJobHelp')"
></b-popover> ></b-popover>
</b-col> </b-col>
<b-col> <b-col>
@ -59,6 +58,37 @@
</b-list-group> </b-list-group>
</b-card> </b-card>
</b-tab> </b-tab>
<b-tab :title="$t('scripts')">
<b-progress v-if="scriptsLoading" striped animated value="100"></b-progress>
<b-card v-else>
<b-card-title>{{ $t("scripts") }}</b-card-title>
<label>Select template</label>
<b-form-radio-group stacked :options="scriptTemplates" v-model="scriptTemplate"></b-form-radio-group>
<br>
<b-row>
<b-col>
<b-form-input v-model="newScriptName" :disabled="!scriptTemplate" :placeholder="$t('newScriptName')"></b-form-input>
</b-col>
<b-col>
<b-button variant="primary" @click="createScript()"
:disabled="!scriptNameValid(newScriptName)">
{{ $t("create") }}
</b-button>
</b-col>
</b-row>
<hr/>
<b-list-group>
<UserScriptListItem v-for="script in scripts"
:key="script.name" :script="script"></UserScriptListItem>
</b-list-group>
</b-card>
</b-tab>
<b-tab :title="$t('frontendTab')"> <b-tab :title="$t('frontendTab')">
<b-card> <b-card>
@ -96,10 +126,11 @@ import {formatBindAddress} from "@/util";
import Sist2AdminApi from "@/Sist2AdminApi"; import Sist2AdminApi from "@/Sist2AdminApi";
import FrontendListItem from "@/components/FrontendListItem"; import FrontendListItem from "@/components/FrontendListItem";
import SearchBackendListItem from "@/components/SearchBackendListItem.vue"; import SearchBackendListItem from "@/components/SearchBackendListItem.vue";
import UserScriptListItem from "@/components/UserScriptListItem.vue";
export default { export default {
name: "Jobs", name: "Jobs",
components: {SearchBackendListItem, JobListItem, FrontendListItem}, components: {UserScriptListItem, SearchBackendListItem, JobListItem, FrontendListItem},
data() { data() {
return { return {
jobsLoading: true, jobsLoading: true,
@ -115,11 +146,24 @@ export default {
backendsLoading: true, backendsLoading: true,
newBackendName: "", newBackendName: "",
showHelp: false scripts: [],
scriptTemplates: [],
newScriptName: "",
scriptTemplate: null,
scriptsLoading: true,
showHelp: false,
tab: 0
} }
}, },
mounted() { mounted() {
this.loading = true; this.loading = true;
if (this.$route.params.tab) {
console.log("mounted " + this.$route.params.tab)
window.setTimeout(() => {
this.tab = Math.round(Number(this.$route.params.tab));
}, 1)
}
this.reload(); this.reload();
}, },
methods: { methods: {
@ -144,11 +188,20 @@ export default {
return /^[a-zA-Z0-9-_,.; ]+$/.test(name); return /^[a-zA-Z0-9-_,.; ]+$/.test(name);
}, },
scriptNameValid(name) {
if (this.scripts.some(script => script.name === name)) {
return false;
}
if (name.length > 16) {
return false;
}
return /^[a-zA-Z0-9-_,.; ]+$/.test(name);
},
reload() { reload() {
Sist2AdminApi.getJobs().then(resp => { Sist2AdminApi.getJobs().then(resp => {
this.jobs = resp.data; this.jobs = resp.data;
this.jobsLoading = false; this.jobsLoading = false;
this.showHelp = this.jobs.length === 0; this.showHelp = this.jobs.length === 0;
}); });
Sist2AdminApi.getFrontends().then(resp => { Sist2AdminApi.getFrontends().then(resp => {
@ -159,6 +212,11 @@ export default {
this.backends = resp.data; this.backends = resp.data;
this.backendsLoading = false; this.backendsLoading = false;
}) })
Sist2AdminApi.getUserScripts().then(resp => {
this.scripts = resp.data;
this.scriptTemplates = this.$store.state.sist2AdminInfo.user_script_templates;
this.scriptsLoading = false;
})
}, },
createJob() { createJob() {
Sist2AdminApi.createJob(this.newJobName).then(this.reload); Sist2AdminApi.createJob(this.newJobName).then(this.reload);
@ -168,6 +226,14 @@ export default {
}, },
createBackend() { createBackend() {
Sist2AdminApi.createBackend(this.newBackendName).then(this.reload); Sist2AdminApi.createBackend(this.newBackendName).then(this.reload);
},
createScript() {
Sist2AdminApi.createUserScript(this.newScriptName, this.scriptTemplate).then(this.reload)
},
onTabChange(tab) {
if (this.$route.params.tab != tab) {
this.$router.push({params: {tab: tab}})
}
} }
} }
} }

@ -30,6 +30,13 @@
<SearchBackendSelect :value="job.index_options.search_backend" <SearchBackendSelect :value="job.index_options.search_backend"
@change="onBackendSelect($event)"></SearchBackendSelect> @change="onBackendSelect($event)"></SearchBackendSelect>
</b-card> </b-card>
<br/>
<h4>{{ $t("scriptOptions") }}</h4>
<b-card>
<UserScriptPicker :selected-scripts="job.user_scripts"
@change="onScriptChange($event)"></UserScriptPicker>
</b-card>
<br/> <br/>
@ -48,10 +55,12 @@ import ScanOptions from "@/components/ScanOptions";
import Sist2AdminApi from "@/Sist2AdminApi"; import Sist2AdminApi from "@/Sist2AdminApi";
import JobOptions from "@/components/JobOptions"; import JobOptions from "@/components/JobOptions";
import SearchBackendSelect from "@/components/SearchBackendSelect.vue"; import SearchBackendSelect from "@/components/SearchBackendSelect.vue";
import UserScriptPicker from "@/components/UserScriptPicker.vue";
export default { export default {
name: "Job", name: "Job",
components: { components: {
UserScriptPicker,
SearchBackendSelect, SearchBackendSelect,
ScanOptions, ScanOptions,
JobOptions JobOptions
@ -95,6 +104,10 @@ export default {
onBackendSelect(backend) { onBackendSelect(backend) {
this.job.index_options.search_backend = backend; this.job.index_options.search_backend = backend;
this.update(); this.update();
},
onScriptChange(scripts) {
this.job.user_scripts = scripts;
this.update();
} }
}, },
mounted() { mounted() {

@ -44,9 +44,6 @@
<label>{{ $t("backendOptions.batchSize") }}</label> <label>{{ $t("backendOptions.batchSize") }}</label>
<b-form-input v-model="backend.batch_size" type="number" min="1" @change="update()"></b-form-input> <b-form-input v-model="backend.batch_size" type="number" min="1" @change="update()"></b-form-input>
<label>{{ $t("backendOptions.script") }}</label>
<b-form-textarea v-model="backend.script" rows="6" @change="update()"></b-form-textarea>
</template> </template>
<template v-else> <template v-else>
<label>{{ $t("backendOptions.searchIndex") }}</label> <label>{{ $t("backendOptions.searchIndex") }}</label>

@ -92,6 +92,9 @@ export default {
if ("stderr" in message) { if ("stderr" in message) {
message.level = "ERROR"; message.level = "ERROR";
message.message = message["stderr"]; message.message = message["stderr"];
} else if ("stdout" in message) {
message.level = "INFO";
message.message = message["stdout"];
} else { } else {
message.level = "ADMIN"; message.level = "ADMIN";
message.message = message["sist2-admin"]; message.message = message["sist2-admin"];

@ -0,0 +1,117 @@
<template>
<b-progress v-if="loading" striped animated value="100"></b-progress>
<b-card v-else>
<b-card-title>
{{ $route.params.name }}
{{ $t("script") }}
</b-card-title>
<div class="mb-3">
<b-button variant="danger" @click="deleteScript()">{{ $t("delete") }}</b-button>
</div>
<b-card>
<h5>{{ $t("testScript") }}</h5>
<b-row>
<b-col cols="11">
<JobSelect @change="onJobSelect($event)"></JobSelect>
</b-col>
<b-col cols="1">
<b-button :disabled="!selectedTestJob" variant="primary" @click="testScript()">{{ $t("test") }}
</b-button>
</b-col>
</b-row>
</b-card>
<br/>
<label>{{ $t("scriptType") }}</label>
<b-form-select :options="['git', 'simple']" v-model="script.type" @change="update()"></b-form-select>
<template v-if="script.type === 'git'">
<label>{{ $t("gitRepository") }}</label>
<b-form-input v-model="script.git_repository" placeholder="https://github.com/example/example.git"
@change="update()"></b-form-input>
<label>{{ $t("extraArgs") }}</label>
<b-form-input v-model="script.extra_args" @change="update()" class="text-monospace"></b-form-input>
</template>
<template v-if="script.type === 'simple'">
<label>{{ $t("scriptCode") }}</label>
<p>Find sist2-python documentation <a href="https://sist2-python.readthedocs.io/" target="_blank">here</a></p>
<b-textarea rows="15" class="text-monospace" v-model="script.script" @change="update()" spellcheck="false"></b-textarea>
</template>
<template v-if="script.type === 'local'">
<!-- TODO-->
</template>
</b-card>
</template>
<script>
import Sist2AdminApi from "@/Sist2AdminApi";
import JobOptions from "@/components/JobOptions.vue";
import JobCheckboxGroup from "@/components/JobCheckboxGroup.vue";
import JobSelect from "@/components/JobSelect.vue";
export default {
name: "UserScript",
components: {JobSelect, JobCheckboxGroup, JobOptions},
data() {
return {
loading: true,
script: null,
selectedTestJob: null
}
},
methods: {
update() {
Sist2AdminApi.updateUserScript(this.name, this.script);
},
onJobSelect(job) {
this.selectedTestJob = job;
},
deleteScript() {
Sist2AdminApi.deleteUserScript(this.name)
.then(() => {
this.$router.push("/");
})
.catch(err => {
this.$bvToast.toast("Cannot delete user script " +
"because it is referenced by a job", {
title: "Error",
variant: "danger",
toaster: "b-toaster-bottom-right"
});
})
},
testScript() {
Sist2AdminApi.testUserScript(this.name, this.selectedTestJob)
.then(() => {
this.$bvToast.toast(this.$t("runJobConfirmation"), {
title: this.$t("runJobConfirmationTitle"),
variant: "success",
toaster: "b-toaster-bottom-right"
});
})
}
},
mounted() {
Sist2AdminApi.getUserScript(this.name).then(resp => {
this.script = resp.data;
this.loading = false;
});
},
computed: {
name() {
return this.$route.params.name;
},
},
}
</script>

@ -2,4 +2,6 @@ fastapi
git+https://github.com/simon987/hexlib.git git+https://github.com/simon987/hexlib.git
uvicorn uvicorn
websockets websockets
pycron pycron
GitPython
git+https://github.com/simon987/sist2-python.git

@ -18,12 +18,13 @@ from websockets.exceptions import ConnectionClosed
import cron import cron
from config import LOG_FOLDER, logger, WEBSERVER_PORT, DATA_FOLDER, SIST2_BINARY from config import LOG_FOLDER, logger, WEBSERVER_PORT, DATA_FOLDER, SIST2_BINARY
from jobs import Sist2Job, Sist2ScanTask, TaskQueue, Sist2IndexTask, JobStatus from jobs import Sist2Job, Sist2ScanTask, TaskQueue, Sist2IndexTask, JobStatus, Sist2UserScriptTask
from notifications import Subscribe, Notifications from notifications import Subscribe, Notifications
from sist2 import Sist2, Sist2SearchBackend from sist2 import Sist2, Sist2SearchBackend
from state import migrate_v1_to_v2, RUNNING_FRONTENDS, TESSERACT_LANGS, DB_SCHEMA_VERSION, migrate_v3_to_v4, \ from state import migrate_v1_to_v2, RUNNING_FRONTENDS, TESSERACT_LANGS, DB_SCHEMA_VERSION, migrate_v3_to_v4, \
get_log_files_to_remove, delete_log_file, create_default_search_backends get_log_files_to_remove, delete_log_file, create_default_search_backends
from web import Sist2Frontend from web import Sist2Frontend
from script import UserScript, SCRIPT_TEMPLATES
sist2 = Sist2(SIST2_BINARY, DATA_FOLDER) sist2 = Sist2(SIST2_BINARY, DATA_FOLDER)
db = PersistentState(dbfile=os.path.join(DATA_FOLDER, "state.db")) db = PersistentState(dbfile=os.path.join(DATA_FOLDER, "state.db"))
@ -52,7 +53,8 @@ async def home():
async def api(): async def api():
return { return {
"tesseract_langs": TESSERACT_LANGS, "tesseract_langs": TESSERACT_LANGS,
"logs_folder": LOG_FOLDER "logs_folder": LOG_FOLDER,
"user_script_templates": list(SCRIPT_TEMPLATES.keys())
} }
@ -113,8 +115,6 @@ async def update_job(name: str, new_job: Sist2Job):
async def update_frontend(name: str, frontend: Sist2Frontend): async def update_frontend(name: str, frontend: Sist2Frontend):
db["frontends"][name] = frontend db["frontends"][name] = frontend
# TODO: Check etag
return "ok" return "ok"
@ -150,9 +150,21 @@ def _run_job(job: Sist2Job):
db["jobs"][job.name] = job db["jobs"][job.name] = job
scan_task = Sist2ScanTask(job, f"Scan [{job.name}]") scan_task = Sist2ScanTask(job, f"Scan [{job.name}]")
index_task = Sist2IndexTask(job, f"Index [{job.name}]", depends_on=scan_task)
index_depends_on = scan_task
script_tasks = []
for script_name in job.user_scripts:
script = db["user_scripts"][script_name]
task = Sist2UserScriptTask(script, job, f"Script <{script_name}> [{job.name}]", depends_on=scan_task)
script_tasks.append(task)
index_depends_on = task
index_task = Sist2IndexTask(job, f"Index [{job.name}]", depends_on=index_depends_on)
task_queue.submit(scan_task) task_queue.submit(scan_task)
for task in script_tasks:
task_queue.submit(task)
task_queue.submit(index_task) task_queue.submit(index_task)
@ -167,6 +179,22 @@ async def run_job(name: str):
return "ok" return "ok"
@app.get("/api/user_script/{name:str}/run")
def run_user_script(name: str, job: str):
script = db["user_scripts"][name]
if not script:
raise HTTPException(status_code=404)
job = db["jobs"][job]
if not job:
raise HTTPException(status_code=404)
script_task = Sist2UserScriptTask(script, job, f"Script <{name}> [{job.name}]")
task_queue.submit(script_task)
return "ok"
@app.get("/api/job/{name:str}/logs_to_delete") @app.get("/api/job/{name:str}/logs_to_delete")
async def task_history(n: int, name: str): async def task_history(n: int, name: str):
return get_log_files_to_remove(db, name, n) return get_log_files_to_remove(db, name, n)
@ -239,7 +267,7 @@ def check_es_version(es_url: str, insecure: bool):
es_url = f"{url.scheme}://{url.hostname}:{url.port}" es_url = f"{url.scheme}://{url.hostname}:{url.port}"
else: else:
auth = None auth = None
r = requests.get(es_url, verify=insecure, auth=auth) r = requests.get(es_url, verify=not insecure, auth=auth)
except SSLError: except SSLError:
return { return {
"ok": False, "ok": False,
@ -375,6 +403,59 @@ def create_search_backend(name: str):
return backend return backend
@app.delete("/api/user_script/{name:str}")
def delete_user_script(name: str):
if db["user_scripts"][name] is None:
raise HTTPException(status_code=404)
if any(name in job.user_scripts for job in db["jobs"]):
raise HTTPException(status_code=400, detail="in use (job)")
script: UserScript = db["user_scripts"][name]
script.delete_dir()
del db["user_scripts"][name]
return "ok"
@app.post("/api/user_script/{name:str}")
def create_user_script(name: str, template: str):
if db["user_scripts"][name] is not None:
raise HTTPException(status_code=400, detail="already exists")
script = SCRIPT_TEMPLATES[template](name)
db["user_scripts"][name] = script
return script
@app.get("/api/user_script")
async def get_user_scripts():
return list(db["user_scripts"])
@app.get("/api/user_script/{name:str}")
async def get_user_script(name: str):
script = db["user_scripts"][name]
if not script:
raise HTTPException(status_code=404)
return script
@app.put("/api/user_script/{name:str}")
async def update_user_script(name: str, script: UserScript):
previous_version: UserScript = db["user_scripts"][name]
if previous_version and previous_version.git_repository != script.git_repository:
script.force_clone = True
db["user_scripts"][name] = script
return "ok"
def tail(filepath: str, n: int): def tail(filepath: str, n: int):
with open(filepath) as file: with open(filepath) as file:
@ -479,7 +560,8 @@ if __name__ == '__main__':
migrate_v3_to_v4(db) migrate_v3_to_v4(db)
if db["sist2_admin"]["info"]["version"] != DB_SCHEMA_VERSION: if db["sist2_admin"]["info"]["version"] != DB_SCHEMA_VERSION:
raise Exception(f"Incompatible database version for {db.dbfile}") raise Exception(f"Incompatible database {db.dbfile}. "
f"Automatic migration is not available, please delete the database file to continue.")
start_frontends() start_frontends()
cron.initialize(db, _run_job) cron.initialize(db, _run_job)
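The task ordering introduced in `_run_job` above can be sketched in isolation. This is a minimal illustration, not the project's implementation: `Task` and `build_tasks` are hypothetical stand-ins for `Sist2Task` and the body of `_run_job`, and only the `depends_on` wiring is modelled. As in the hunk, every script task depends on the scan task, and the index task depends on the last script task (or on the scan task when the job has no scripts).

```python
from dataclasses import dataclass, field
from typing import List, Optional
from uuid import UUID, uuid4


@dataclass
class Task:
    # Hypothetical stand-in for Sist2Task: just enough fields to show ordering.
    display_name: str
    depends_on: Optional[UUID] = None
    id: UUID = field(default_factory=uuid4)


def build_tasks(job_name: str, script_names: List[str]) -> List[Task]:
    """Return tasks in submission order: scan, user scripts, index."""
    scan = Task(f"Scan [{job_name}]")

    index_depends_on = scan
    scripts = []
    for name in script_names:
        # Each script task waits for the scan task, as in the diff.
        task = Task(f"Script <{name}> [{job_name}]", depends_on=scan.id)
        scripts.append(task)
        # The index task ends up waiting for the last script only.
        index_depends_on = task

    index = Task(f"Index [{job_name}]", depends_on=index_depends_on.id)
    return [scan, *scripts, index]
```

With two scripts, the second script and the first script both wait only on the scan; whether they also run one after the other depends on how the task queue schedules them, which this sketch does not model.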

@ -9,9 +9,11 @@ MAX_LOG_SIZE = 1 * 1024 * 1024
SIST2_BINARY = os.environ.get("SIST2_BINARY", "/root/sist2") SIST2_BINARY = os.environ.get("SIST2_BINARY", "/root/sist2")
DATA_FOLDER = os.environ.get("DATA_FOLDER", "/sist2-admin/") DATA_FOLDER = os.environ.get("DATA_FOLDER", "/sist2-admin/")
LOG_FOLDER = os.path.join(DATA_FOLDER, "logs") LOG_FOLDER = os.path.join(DATA_FOLDER, "logs")
SCRIPT_FOLDER = os.path.join(DATA_FOLDER, "scripts")
WEBSERVER_PORT = 8080 WEBSERVER_PORT = 8080
os.makedirs(LOG_FOLDER, exist_ok=True) os.makedirs(LOG_FOLDER, exist_ok=True)
os.makedirs(SCRIPT_FOLDER, exist_ok=True)
os.makedirs(DATA_FOLDER, exist_ok=True) os.makedirs(DATA_FOLDER, exist_ok=True)
logger = logging.Logger("sist2-admin") logger = logging.Logger("sist2-admin")

@ -1,13 +1,18 @@
import json import json
import logging import logging
import os.path import os.path
import shlex
import signal import signal
import uuid import uuid
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from io import TextIOWrapper
from logging import FileHandler from logging import FileHandler
from subprocess import Popen
import subprocess
from threading import Lock, Thread from threading import Lock, Thread
from time import sleep from time import sleep
from typing import List
from uuid import uuid4, UUID from uuid import uuid4, UUID
from hexlib.db import PersistentState from hexlib.db import PersistentState
@ -18,6 +23,7 @@ from notifications import Notifications
from sist2 import ScanOptions, IndexOptions, Sist2 from sist2 import ScanOptions, IndexOptions, Sist2
from state import RUNNING_FRONTENDS, get_log_files_to_remove, delete_log_file from state import RUNNING_FRONTENDS, get_log_files_to_remove, delete_log_file
from web import Sist2Frontend from web import Sist2Frontend
from script import UserScript
class JobStatus(Enum): class JobStatus(Enum):
@ -32,6 +38,8 @@ class Sist2Job(BaseModel):
scan_options: ScanOptions scan_options: ScanOptions
index_options: IndexOptions index_options: IndexOptions
user_scripts: List[str] = []
cron_expression: str cron_expression: str
schedule_enabled: bool = False schedule_enabled: bool = False
@ -182,7 +190,7 @@ class Sist2IndexTask(Sist2Task):
duration = self.ended - self.started duration = self.ended - self.started
ok = return_code == 0 ok = return_code in (0, 1)
if ok: if ok:
self.restart_running_frontends(db, sist2) self.restart_running_frontends(db, sist2)
@ -231,6 +239,65 @@ class Sist2IndexTask(Sist2Task):
self._logger.info(json.dumps({"sist2-admin": f"Restart frontend {pid=} {frontend_name=}"})) self._logger.info(json.dumps({"sist2-admin": f"Restart frontend {pid=} {frontend_name=}"}))
class Sist2UserScriptTask(Sist2Task):
def __init__(self, user_script: UserScript, job: Sist2Job, display_name: str, depends_on: Sist2Task = None):
super().__init__(job, display_name, depends_on=depends_on.id if depends_on else None)
self.user_script = user_script
def run(self, sist2: Sist2, db: PersistentState):
super().run(sist2, db)
try:
self.user_script.setup(self.log_callback)
except Exception as e:
logger.error(f"Setup for {self.user_script.name} failed: ")
logger.exception(e)
self.log_callback({"sist2-admin": f"Setup for {self.user_script.name} failed: {e}"})
return -1
executable = self.user_script.get_executable()
index_path = os.path.join(DATA_FOLDER, self.job.index_path)
extra_args = self.user_script.extra_args
args = [
executable,
index_path,
*shlex.split(extra_args)
]
self.log_callback({"sist2-admin": f"Starting user script with {executable=}, {index_path=}, {extra_args=}"})
proc = Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=self.user_script.script_dir())
self.pid = proc.pid
t_stderr = Thread(target=self._consume_logs, args=(self.log_callback, proc, "stderr", False))
t_stderr.start()
self._consume_logs(self.log_callback, proc, "stdout", True)
self.ended = datetime.utcnow()
return 0
@staticmethod
def _consume_logs(logs_cb, proc, stream, wait):
pipe_wrapper = TextIOWrapper(getattr(proc, stream), encoding="utf8", errors="ignore")
try:
for line in pipe_wrapper:
if line.strip() == "":
continue
if line.startswith("$PROGRESS"):
progress = json.loads(line[len("$PROGRESS "):])
logs_cb({"progress": progress})
continue
logs_cb({stream: line})
finally:
if wait:
proc.wait()
pipe_wrapper.close()
class TaskQueue: class TaskQueue:
def __init__(self, sist2: Sist2, db: PersistentState, notifications: Notifications): def __init__(self, sist2: Sist2, db: PersistentState, notifications: Notifications):
self._lock = Lock() self._lock = Lock()
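The line handling in `_consume_logs` above can be read as a small per-line protocol: blank lines are dropped, lines starting with `$PROGRESS` carry a JSON payload, and everything else is forwarded under the stream name. A minimal sketch of that rule as a pure function follows; `parse_script_line` is a hypothetical helper, and the payload keys used in the usage note are assumptions, since the diff does not specify the payload shape.

```python
import json
from typing import Optional


def parse_script_line(line: str, stream: str = "stdout") -> Optional[dict]:
    """Map one line of script output to a log message, or None to skip it."""
    if line.strip() == "":
        return None
    if line.startswith("$PROGRESS"):
        # Everything after the "$PROGRESS " prefix is parsed as JSON.
        return {"progress": json.loads(line[len("$PROGRESS "):])}
    # Ordinary output is forwarded unchanged, keyed by the stream name.
    return {stream: line}
```

Under this reading, a user script could report progress by printing a line such as `$PROGRESS {"done": 1, "count": 4}` to stdout; the `done` and `count` keys are illustrative only.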

@ -0,0 +1,126 @@
import os
import shutil
import stat
import subprocess
from enum import Enum
from git import Repo
from pydantic import BaseModel
from config import SCRIPT_FOLDER
class ScriptType(Enum):
LOCAL = "local"
SIMPLE = "simple"
GIT = "git"
def set_executable(file):
os.chmod(file, os.stat(file).st_mode | stat.S_IEXEC)
def _initialize_git_repository(url, path, log_cb, force_clone):
log_cb({"sist2-admin": f"Cloning {url}"})
if force_clone or not os.path.exists(os.path.join(path, ".git")):
if force_clone:
shutil.rmtree(path, ignore_errors=True)
Repo.clone_from(url, path)
else:
repo = Repo(path)
repo.remote("origin").pull()
setup_script = os.path.join(path, "setup.sh")
if os.path.exists(setup_script):
log_cb({"sist2-admin": f"Executing setup script {setup_script}"})
set_executable(setup_script)
result = subprocess.run([setup_script], cwd=path, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
for line in result.stdout.split(b"\n"):
if line:
log_cb({"stdout": line.decode()})
log_cb({"stdout": f"Executed setup script {setup_script}, return code = {result.returncode}"})
if result.returncode != 0:
raise Exception("Error when running setup script!")
log_cb({"sist2-admin": f"Initialized git repository in {path}"})
class UserScript(BaseModel):
name: str
type: ScriptType
git_repository: str = None
force_clone: bool = False
script: str = None
extra_args: str = ""
def script_dir(self):
return os.path.join(SCRIPT_FOLDER, self.name)
def setup(self, log_cb):
os.makedirs(self.script_dir(), exist_ok=True)
if self.type == ScriptType.GIT:
_initialize_git_repository(self.git_repository, self.script_dir(), log_cb, self.force_clone)
self.force_clone = False
elif self.type == ScriptType.SIMPLE:
self._setup_simple()
set_executable(self.get_executable())
def _setup_simple(self):
with open(self.get_executable(), "w") as f:
f.write(
"#!/bin/bash\n"
"python run.py \"$@\""
)
with open(os.path.join(self.script_dir(), "run.py"), "w") as f:
f.write(self.script)
def get_executable(self):
return os.path.join(self.script_dir(), "run.sh")
def delete_dir(self):
shutil.rmtree(self.script_dir(), ignore_errors=True)
SCRIPT_TEMPLATES = {
"CLIP - Generate embeddings to predict the most relevant image based on the text prompt": lambda name: UserScript(
name=name,
type=ScriptType.GIT,
git_repository="https://github.com/simon987/sist2-script-clip",
extra_args="--num-tags=1 --tags-file=general.txt --color=#dcd7ff"
),
"Whisper - Speech to text with OpenAI Whisper": lambda name: UserScript(
name=name,
type=ScriptType.GIT,
git_repository="https://github.com/simon987/sist2-script-whisper",
extra_args="--model=base --num-threads=4 --color=#51da4c --tag"
),
"Hamburger - Simple script example": lambda name: UserScript(
name=name,
type=ScriptType.SIMPLE,
script=
'from sist2 import Sist2Index\n'
'import sys\n'
'\n'
'index = Sist2Index(sys.argv[1])\n'
'for doc in index.document_iter():\n'
' doc.json_data["tag"] = ["hamburger.#00FF00"]\n'
' index.update_document(doc)\n'
'\n'
'index.sync_tag_table()\n'
'index.commit()\n'
'\n'
'print("Done!")\n'
),
"(Blank)": lambda name: UserScript(
name=name,
type=ScriptType.SIMPLE,
script=""
)
}
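The `extra_args` strings in the templates above are turned into an argument list by `shlex.split` in `Sist2UserScriptTask.run`. The sketch below shows how such a string would be split, assuming the same argument order as in that method (executable, index path, then the extra arguments); `build_argv` is a hypothetical helper and the paths in the usage note are made up.

```python
import shlex
from typing import List


def build_argv(executable: str, index_path: str, extra_args: str) -> List[str]:
    """Build the command line for a user script from its extra_args string."""
    # shlex.split honours shell-style quoting; with its default settings a
    # '#' is not treated as a comment, so colour values like #dcd7ff survive.
    return [executable, index_path, *shlex.split(extra_args)]
```

For example, `build_argv("/scripts/clip/run.sh", "/data/idx.sist2", "--num-tags=1 --color=#dcd7ff")` keeps `--color=#dcd7ff` as a single argument, and a quoted value such as `--tags-file="my tags.txt"` is passed as one argument with the quotes removed.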

@ -41,8 +41,6 @@ class Sist2SearchBackend(BaseModel):
es_insecure_ssl: bool = False es_insecure_ssl: bool = False
es_index: str = "sist2" es_index: str = "sist2"
threads: int = 1 threads: int = 1
script: str = ""
script_file: str = None
batch_size: int = 70 batch_size: int = 70
@staticmethod @staticmethod
@ -74,8 +72,6 @@ class IndexOptions(BaseModel):
f"--es-index={search_backend.es_index}", f"--es-index={search_backend.es_index}",
f"--batch-size={search_backend.batch_size}"] f"--batch-size={search_backend.batch_size}"]
if search_backend.script_file:
args.append(f"--script-file={search_backend.script_file}")
if search_backend.es_insecure_ssl: if search_backend.es_insecure_ssl:
args.append(f"--es-insecure-ssl") args.append(f"--es-insecure-ssl")
if self.incremental_index: if self.incremental_index:
@ -249,13 +245,6 @@ class Sist2:
def index(self, options: IndexOptions, search_backend: Sist2SearchBackend, logs_cb): def index(self, options: IndexOptions, search_backend: Sist2SearchBackend, logs_cb):
if search_backend.script and search_backend.backend_type == SearchBackendType("elasticsearch"):
with NamedTemporaryFile("w", prefix="sist2-admin", suffix=".painless", delete=False) as f:
f.write(search_backend.script)
search_backend.script_file = f.name
else:
search_backend.script_file = None
args = [ args = [
self.bin_path, self.bin_path,
*options.args(search_backend), *options.args(search_backend),

@ -14,7 +14,7 @@ RUNNING_FRONTENDS: Dict[str, int] = {}
TESSERACT_LANGS = get_tesseract_langs() TESSERACT_LANGS = get_tesseract_langs()
DB_SCHEMA_VERSION = "4" DB_SCHEMA_VERSION = "5"
from pydantic import BaseModel from pydantic import BaseModel

File diff suppressed because it is too large

@ -1,6 +1,6 @@
{ {
"name": "sist2", "name": "sist2",
"version": "2.11.0", "version": "1.0.0",
"private": true, "private": true,
"scripts": { "scripts": {
"serve": "vue-cli-service serve", "serve": "vue-cli-service serve",
@ -9,7 +9,6 @@
"dependencies": { "dependencies": {
"@auth0/auth0-spa-js": "^2.0.2", "@auth0/auth0-spa-js": "^2.0.2",
"@egjs/vue-infinitegrid": "3.3.0", "@egjs/vue-infinitegrid": "3.3.0",
"@tensorflow/tfjs": "^4.4.0",
"axios": "^0.25.0", "axios": "^0.25.0",
"bootstrap-vue": "^2.21.2", "bootstrap-vue": "^2.21.2",
"core-js": "^3.6.5", "core-js": "^3.6.5",
@ -45,8 +44,8 @@
"portal-vue": "^2.1.7", "portal-vue": "^2.1.7",
"sass": "^1.26.11", "sass": "^1.26.11",
"sass-loader": "^10.0.2", "sass-loader": "^10.0.2",
"typescript": "~4.1.5", "typescript": "^4.9.5",
"vue-cli-plugin-bootstrap-vue": "~0.7.0", "vue-cli-plugin-bootstrap-vue": "~0.8.2",
"vue-template-compiler": "^2.6.11" "vue-template-compiler": "^2.6.11"
}, },
"browserslist": [ "browserslist": [

@ -308,15 +308,21 @@ html, body {
.info-icon { .info-icon {
width: 1rem; width: 1rem;
min-width: 1rem;
margin-right: 0.2rem; margin-right: 0.2rem;
cursor: pointer; cursor: pointer;
line-height: 1rem; line-height: 1rem;
height: 1rem; height: 1rem;
min-height: 1rem;
background-image: url(); background-image: url();
filter: brightness(45%); filter: brightness(45%);
display: block; display: block;
} }
.theme-black .info-icon {
filter: brightness(80%);
}
.tabs { .tabs {
margin-top: 10px; margin-top: 10px;
} }

@ -25,6 +25,7 @@ export interface Index {
id: string id: string
idPrefix: string idPrefix: string
timestamp: number timestamp: number
models: []
} }
export interface EsHit { export interface EsHit {
@ -117,6 +118,15 @@ class Sist2Api {
return this.sist2Info.searchBackend; return this.sist2Info.searchBackend;
} }
models() {
const allModels = this.sist2Info.indices
.map(idx => idx.models)
.flat();
return allModels
.filter((v, i, a) => a.findIndex(v2 => (v2.id === v.id)) === i)
}
getSist2Info(): Promise<any> { getSist2Info(): Promise<any> {
return axios.get(`${this.baseUrl}i`).then(resp => { return axios.get(`${this.baseUrl}i`).then(resp => {
const indices = resp.data.indices as Index[]; const indices = resp.data.indices as Index[];
@ -127,7 +137,8 @@ class Sist2Api {
name: idx.name, name: idx.name,
timestamp: idx.timestamp, timestamp: idx.timestamp,
version: idx.version, version: idx.version,
idPrefix: getIdPrefix(indices, idx.id) models: idx.models,
idPrefix: getIdPrefix(indices, idx.id),
} as Index; } as Index;
}); });
@ -618,6 +629,15 @@ class Sist2Api {
} }
} }
if ("knn" in query) {
query.query = {
bool: {
must: []
}
};
delete query.knn;
}
if ("function_score" in query.query) { if ("function_score" in query.query) {
query.query = query.query.function_score.query; query.query = query.query.function_score.query;
} }
@ -702,6 +722,11 @@ class Sist2Api {
return result; return result;
}); });
} }
getEmbeddings(indexId, docId, modelId) {
return axios.post(`${this.baseUrl}/e/${indexId}/${docId}/${modelId.toString().padStart(3, '0')}`)
.then(resp => (resp.data));
}
} }
export default new Sist2Api(""); export default new Sist2Api("");

@ -1,5 +1,5 @@
import store from "./store"; import store from "./store";
import {EsHit, Index} from "@/Sist2Api"; import sist2Api, {EsHit, Index} from "@/Sist2Api";
const SORT_MODES = { const SORT_MODES = {
score: { score: {
@ -79,8 +79,10 @@ class Sist2ElasticsearchQuery {
const selectedIndexIds = getters.selectedIndices.map((idx: Index) => idx.id) const selectedIndexIds = getters.selectedIndices.map((idx: Index) => idx.id)
const selectedMimeTypes = getters.selectedMimeTypes; const selectedMimeTypes = getters.selectedMimeTypes;
const selectedTags = getters.selectedTags; const selectedTags = getters.selectedTags;
const sortMode = getters.embedding ? "score" : getters.sortMode;
const legacyES = store.state.sist2Info.esVersionLegacy; const legacyES = store.state.sist2Info.esVersionLegacy;
const hasKnn = store.state.sist2Info.esVersionHasKnn;
const filters = [ const filters = [
{terms: {index: selectedIndexIds}} {terms: {index: selectedIndexIds}}
@ -162,14 +164,14 @@ class Sist2ElasticsearchQuery {
const q = { const q = {
_source: { _source: {
excludes: ["content", "_tie"] excludes: ["content", "_tie", "emb.*"]
}, },
query: { query: {
bool: { bool: {
filter: filters, filter: filters,
} }
}, },
sort: SORT_MODES[getters.sortMode].mode, sort: SORT_MODES[sortMode].mode,
size: size, size: size,
} as any; } as any;
@ -181,14 +183,57 @@ class Sist2ElasticsearchQuery {
} }
if (!empty && !blankSearch) { if (!empty && !blankSearch) {
q.query.bool.must = query; if (getters.embedding) {
filters.push(query)
} else {
q.query.bool.must = query;
}
}
if (getters.embedding) {
delete q.query;
const field = "emb." + sist2Api.models().find(m => m.id == getters.embeddingsModel).path;
if (hasKnn) {
// Use knn (8.8+)
q.knn = {
field: field,
query_vector: getters.embedding,
k: 600,
num_candidates: 600,
filter: filters
}
} else {
// Use brute-force as a fallback
filters.push({exists: {field: field}});
q.query = {
function_score: {
query: {
bool: {
must: filters,
}
},
script_score: {
script: {
source: `cosineSimilarity(params.query_vector, "${field}") + 1.0`,
params: {query_vector: getters.embedding}
}
}
}
}
}
} }
if (after) { if (after) {
q.search_after = [SORT_MODES[getters.sortMode].key(after), after["_id"]]; q.search_after = [SORT_MODES[sortMode].key(after), after["_id"]];
} }
if (getters.optHighlight) { if (getters.optHighlight && !getters.embedding) {
q.highlight = { q.highlight = {
pre_tags: ["<mark>"], pre_tags: ["<mark>"],
post_tags: ["</mark>"], post_tags: ["</mark>"],
@ -214,7 +259,7 @@ class Sist2ElasticsearchQuery {
} }
} }
if (getters.sortMode === "random") { if (sortMode === "random") {
q.query = { q.query = {
function_score: { function_score: {
query: { query: {
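The embedding branch added to this query builder chooses between a `knn` clause and a brute-force `script_score` fallback. The following sketch restates that choice as a Python function that returns the relevant part of the request body. It is an illustration only: `embedding_query` is a hypothetical name, Python is used purely for demonstration (the original is TypeScript), and unlike the original it copies `filters` instead of appending to it in place.

```python
from typing import Any, Dict, List


def embedding_query(field: str, vector: List[float],
                    filters: List[Dict[str, Any]], has_knn: bool) -> Dict[str, Any]:
    """Return the knn clause, or the script_score fallback query."""
    if has_knn:
        # Approximate nearest-neighbour search, with the filters applied inside knn.
        return {
            "knn": {
                "field": field,
                "query_vector": vector,
                "k": 600,
                "num_candidates": 600,
                "filter": filters,
            }
        }

    # Fallback: score every document that has the embedding field.
    must = [*filters, {"exists": {"field": field}}]
    return {
        "query": {
            "function_score": {
                "query": {"bool": {"must": must}},
                "script_score": {
                    "script": {
                        # + 1.0 keeps the score non-negative, as in the diff.
                        "source": f'cosineSimilarity(params.query_vector, "{field}") + 1.0',
                        "params": {"query_vector": vector},
                    }
                },
            }
        }
    }
```

In both branches the result is ranked by similarity to the query vector, which is why the diff forces the sort mode to `score` whenever an embedding is set.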

@ -103,7 +103,7 @@ class Sist2ElasticsearchQuery {
q["highlightContextSize"] = Number(getters.optFragmentSize); q["highlightContextSize"] = Number(getters.optFragmentSize);
} }
if (getters.embeddingText) { if (getters.embedding) {
q["model"] = getters.embeddingsModel; q["model"] = getters.embeddingsModel;
q["embedding"] = getters.embedding; q["embedding"] = getters.embedding;
q["sort"] = "embedding"; q["sort"] = "embedding";

@ -45,7 +45,8 @@ export default {
items.push( items.push(
{key: "esVersion", value: this.$store.state.sist2Info.esVersion}, {key: "esVersion", value: this.$store.state.sist2Info.esVersion},
{key: "esVersionSupported", value: this.$store.state.sist2Info.esVersionSupported}, {key: "esVersionSupported", value: this.$store.state.sist2Info.esVersionSupported},
{key: "esVersionLegacy", value: this.$store.state.sist2Info.esVersionLegacy} {key: "esVersionLegacy", value: this.$store.state.sist2Info.esVersionLegacy},
{key: "esVersionHasKnn", value: this.$store.state.sist2Info.esVersionHasKnn},
); );
} }

@ -24,6 +24,7 @@
<!-- Title line --> <!-- Title line -->
<div style="display: flex"> <div style="display: flex">
<span class="info-icon" @click="onInfoClick()"></span> <span class="info-icon" @click="onInfoClick()"></span>
<MLIcon v-if="doc._source.embedding" clickable @click="onEmbeddingClick()"></MLIcon>
<DocFileTitle :doc="doc"></DocFileTitle> <DocFileTitle :doc="doc"></DocFileTitle>
</div> </div>
@ -49,10 +50,12 @@ import DocInfoModal from "@/components/DocInfoModal.vue";
import ContentDiv from "@/components/ContentDiv.vue"; import ContentDiv from "@/components/ContentDiv.vue";
import FullThumbnail from "@/components/FullThumbnail"; import FullThumbnail from "@/components/FullThumbnail";
import FeaturedFieldsLine from "@/components/FeaturedFieldsLine"; import FeaturedFieldsLine from "@/components/FeaturedFieldsLine";
import MLIcon from "@/components/icons/MlIcon.vue";
import Sist2Api from "@/Sist2Api";
export default { export default {
components: {FeaturedFieldsLine, FullThumbnail, ContentDiv, DocInfoModal, DocFileTitle, TagContainer}, components: {MLIcon, FeaturedFieldsLine, FullThumbnail, ContentDiv, DocInfoModal, DocFileTitle, TagContainer},
props: ["doc", "width"], props: ["doc", "width"],
data() { data() {
return { return {
@ -71,6 +74,13 @@ export default {
onInfoClick() { onInfoClick() {
this.showInfo = true; this.showInfo = true;
}, },
onEmbeddingClick() {
Sist2Api.getEmbeddings(this.doc._source.index, this.doc._id, this.$store.state.embeddingsModel).then(embeddings => {
this.$store.commit("setEmbeddingText", "");
this.$store.commit("setEmbedding", embeddings);
this.$store.commit("setEmbeddingDoc", this.doc);
})
},
async onThumbnailClick() { async onThumbnailClick() {
this.$store.commit("setUiLightboxSlide", this.doc._seq); this.$store.commit("setUiLightboxSlide", this.doc._seq);
await this.$store.dispatch("showLightbox"); await this.$store.dispatch("showLightbox");

@ -1,63 +1,70 @@
<template> <template>
<a :href="`f/${doc._source.index}/${doc._id}`" class="file-title-anchor" target="_blank"> <a :href="`f/${doc._source.index}/${doc._id}`"
<div class="file-title" :title="doc._source.path + '/' + doc._source.name + ext(doc)" :class="doc._source.embedding ? 'file-title-anchor-with-embedding' : 'file-title-anchor'" target="_blank">
v-html="fileName() + ext(doc)"></div> <div class="file-title" :title="doc._source.path + '/' + doc._source.name + ext(doc)"
</a> v-html="fileName() + ext(doc)"></div>
</a>
</template> </template>
<script> <script>
import {ext} from "@/util"; import {ext} from "@/util";
export default { export default {
name: "DocFileTitle", name: "DocFileTitle",
props: ["doc"], props: ["doc"],
methods: { methods: {
ext: ext, ext: ext,
fileName() { fileName() {
if (!this.doc.highlight) { if (!this.doc.highlight) {
return this.doc._source.name; return this.doc._source.name;
} }
if (this.doc.highlight["name.nGram"]) { if (this.doc.highlight["name.nGram"]) {
return this.doc.highlight["name.nGram"]; return this.doc.highlight["name.nGram"];
} }
if (this.doc.highlight.name) { if (this.doc.highlight.name) {
return this.doc.highlight.name; return this.doc.highlight.name;
} }
return this.doc._source.name; return this.doc._source.name;
}
} }
}
} }
</script> </script>
<style scoped> <style scoped>
.file-title-anchor { .file-title-anchor {
max-width: calc(100% - 1.2rem); max-width: calc(100% - 1.2rem);
}
.file-title-anchor-with-embedding {
max-width: calc(100% - 2.2rem);
} }
.file-title { .file-title {
width: 100%; width: 100%;
line-height: 1rem; max-width: 100%;
height: 1.1rem; line-height: 1rem;
white-space: nowrap; height: 1.1rem;
text-overflow: ellipsis; white-space: nowrap;
overflow: hidden; text-overflow: ellipsis;
font-size: 16px; overflow: hidden;
font-family: "Source Sans Pro", sans-serif; font-size: 16px;
font-weight: bold; font-family: "Source Sans Pro", sans-serif;
font-weight: bold;
} }
.theme-black .file-title { .theme-black .file-title {
color: #ddd; color: #ddd;
} }
.theme-black .file-title:hover { .theme-black .file-title:hover {
color: #fff; color: #fff;
} }
.theme-light .file-title { .theme-light .file-title {
color: black; color: black;
} }
.doc-card .file-title { .doc-card .file-title {
font-size: 12px; font-size: 12px;
} }
</style> </style>

@ -1,63 +1,64 @@
<template> <template>
<b-list-group-item class="flex-column align-items-start mb-2" :class="{'sub-document': doc._props.isSubDocument}" <b-list-group-item class="flex-column align-items-start mb-2" :class="{'sub-document': doc._props.isSubDocument}"
@mouseenter="onTnEnter()" @mouseleave="onTnLeave()"> @mouseenter="onTnEnter()" @mouseleave="onTnLeave()">
<!-- Info modal--> <!-- Info modal-->
<DocInfoModal :show="showInfo" :doc="doc" @close="showInfo = false"></DocInfoModal> <DocInfoModal :show="showInfo" :doc="doc" @close="showInfo = false"></DocInfoModal>
<div class="media ml-2"> <div class="media ml-2">
<!-- Thumbnail--> <!-- Thumbnail-->
<div v-if="doc._props.hasThumbnail" class="align-self-start mr-2 wrapper-sm"> <div v-if="doc._props.hasThumbnail" class="align-self-start mr-2 wrapper-sm">
<div class="img-wrapper"> <div class="img-wrapper">
<div v-if="doc._props.isPlayableVideo" class="play"> <div v-if="doc._props.isPlayableVideo" class="play">
<svg viewBox="0 0 494.942 494.942" xmlns="http://www.w3.org/2000/svg"> <svg viewBox="0 0 494.942 494.942" xmlns="http://www.w3.org/2000/svg">
<path d="m35.353 0 424.236 247.471-424.236 247.471z"/> <path d="m35.353 0 424.236 247.471-424.236 247.471z"/>
</svg> </svg>
</div> </div>
<img v-if="doc._props.isPlayableImage || doc._props.isPlayableVideo" <img v-if="doc._props.isPlayableImage || doc._props.isPlayableVideo"
:src="(doc._props.isGif && hover) ? `f/${doc._source.index}/${doc._id}` : `t/${doc._source.index}/${doc._id}`" :src="(doc._props.isGif && hover) ? `f/${doc._source.index}/${doc._id}` : `t/${doc._source.index}/${doc._id}`"
alt="" alt=""
class="pointer fit-sm" @click="onThumbnailClick()"> class="pointer fit-sm" @click="onThumbnailClick()">
<img v-else :src="`t/${doc._source.index}/${doc._id}`" alt="" <img v-else :src="`t/${doc._source.index}/${doc._id}`" alt=""
class="fit-sm"> class="fit-sm">
</div> </div>
</div> </div>
<div v-else class="file-icon-wrapper" style=""> <div v-else class="file-icon-wrapper" style="">
<FileIcon></FileIcon> <FileIcon></FileIcon>
</div> </div>
<!-- Doc line--> <!-- Doc line-->
<div class="doc-line ml-3"> <div class="doc-line ml-3">
<div style="display: flex"> <div style="display: flex">
<span class="info-icon" @click="showInfo = true"></span> <span class="info-icon" @click="showInfo = true"></span>
<DocFileTitle :doc="doc"></DocFileTitle> <MLIcon v-if="doc._source.embedding" clickable @click="onEmbeddingClick()"></MLIcon>
</div> <DocFileTitle :doc="doc"></DocFileTitle>
</div>
<!-- Content highlight --> <!-- Content highlight -->
<ContentDiv :doc="doc"></ContentDiv> <ContentDiv :doc="doc"></ContentDiv>
<div class="path-row"> <div class="path-row">
<div class="path-line" v-html="path()"></div> <div class="path-line" v-html="path()"></div>
<TagContainer :hit="doc"></TagContainer> <TagContainer :hit="doc"></TagContainer>
</div> </div>
<div v-if="doc._source.pages || doc._source.author" class="path-row text-muted"> <div v-if="doc._source.pages || doc._source.author" class="path-row text-muted">
<span v-if="doc._source.pages">{{ doc._source.pages }} {{ <span v-if="doc._source.pages">{{ doc._source.pages }} {{
doc._source.pages > 1 ? $t("pages") : $t("page") doc._source.pages > 1 ? $t("pages") : $t("page")
}}</span> }}</span>
<span v-if="doc._source.author && doc._source.pages" class="mx-1">-</span> <span v-if="doc._source.author && doc._source.pages" class="mx-1">-</span>
<span v-if="doc._source.author">{{ doc._source.author }}</span> <span v-if="doc._source.author">{{ doc._source.author }}</span>
</div> </div>
<!-- Featured line --> <!-- Featured line -->
<div style="display: flex"> <div style="display: flex">
<FeaturedFieldsLine :doc="doc"></FeaturedFieldsLine> <FeaturedFieldsLine :doc="doc"></FeaturedFieldsLine>
</div>
</div>
</div> </div>
</div> </b-list-group-item>
</div>
</b-list-group-item>
</template> </template>
<script> <script>
@ -67,131 +68,140 @@ import DocInfoModal from "@/components/DocInfoModal";
import ContentDiv from "@/components/ContentDiv"; import ContentDiv from "@/components/ContentDiv";
import FileIcon from "@/components/icons/FileIcon"; import FileIcon from "@/components/icons/FileIcon";
import FeaturedFieldsLine from "@/components/FeaturedFieldsLine"; import FeaturedFieldsLine from "@/components/FeaturedFieldsLine";
import MLIcon from "@/components/icons/MlIcon.vue";
import Sist2Api from "@/Sist2Api";
export default { export default {
name: "DocListItem", name: "DocListItem",
components: {FileIcon, ContentDiv, DocInfoModal, DocFileTitle, TagContainer, FeaturedFieldsLine}, components: {MLIcon, FileIcon, ContentDiv, DocInfoModal, DocFileTitle, TagContainer, FeaturedFieldsLine},
props: ["doc"], props: ["doc"],
data() { data() {
return { return {
hover: false, hover: false,
showInfo: false showInfo: false
} }
},
methods: {
async onThumbnailClick() {
this.$store.commit("setUiLightboxSlide", this.doc._seq);
await this.$store.dispatch("showLightbox");
}, },
path() { methods: {
if (!this.doc.highlight) { async onThumbnailClick() {
return this.doc._source.path + "/" this.$store.commit("setUiLightboxSlide", this.doc._seq);
} await this.$store.dispatch("showLightbox");
if (this.doc.highlight["path.text"]) { },
return this.doc.highlight["path.text"] + "/" onEmbeddingClick() {
} Sist2Api.getEmbeddings(this.doc._source.index, this.doc._id, this.$store.state.embeddingsModel).then(embeddings => {
this.$store.commit("setEmbeddingText", "");
this.$store.commit("setEmbedding", embeddings);
this.$store.commit("setEmbeddingDoc", this.doc);
})
},
path() {
if (!this.doc.highlight) {
return this.doc._source.path + "/"
}
if (this.doc.highlight["path.text"]) {
return this.doc.highlight["path.text"] + "/"
}
if (this.doc.highlight["path.nGram"]) { if (this.doc.highlight["path.nGram"]) {
return this.doc.highlight["path.nGram"] + "/" return this.doc.highlight["path.nGram"] + "/"
} }
return this.doc._source.path + "/" return this.doc._source.path + "/"
}, },
onTnEnter() { onTnEnter() {
this.hover = true; this.hover = true;
}, },
onTnLeave() { onTnLeave() {
this.hover = false; this.hover = false;
}, },
} }
} }
</script> </script>
<style scoped> <style scoped>
.sub-document { .sub-document {
background: #AB47BC1F !important; background: #AB47BC1F !important;
} }
.theme-black .sub-document { .theme-black .sub-document {
background: #37474F !important; background: #37474F !important;
} }
.list-group { .list-group {
margin-top: 1em; margin-top: 1em;
} }
.list-group-item { .list-group-item {
padding: .25rem 0.5rem; padding: .25rem 0.5rem;
box-shadow: 0 0.125rem 0.25rem rgb(0 0 0 / 8%) !important; box-shadow: 0 0.125rem 0.25rem rgb(0 0 0 / 8%) !important;
border-radius: 0; border-radius: 0;
border: none; border: none;
} }
.path-row { .path-row {
display: -ms-flexbox; display: -ms-flexbox;
display: flex; display: flex;
-ms-flex-align: start; -ms-flex-align: start;
align-items: flex-start; align-items: flex-start;
} }
.path-line { .path-line {
color: #808080; color: #808080;
text-overflow: ellipsis; text-overflow: ellipsis;
overflow: hidden; overflow: hidden;
white-space: nowrap; white-space: nowrap;
margin-right: 0.3em; margin-right: 0.3em;
} }
.theme-black .path-line { .theme-black .path-line {
color: #bbb; color: #bbb;
} }
.play { .play {
position: absolute; position: absolute;
width: 18px; width: 18px;
height: 18px; height: 18px;
left: 50%; left: 50%;
top: 50%; top: 50%;
transform: translate(-50%, -50%); transform: translate(-50%, -50%);
pointer-events: none; pointer-events: none;
} }
.play svg { .play svg {
fill: rgba(0, 0, 0, 0.7); fill: rgba(0, 0, 0, 0.7);
} }
.list-group-item .img-wrapper { .list-group-item .img-wrapper {
width: 88px; width: 88px;
height: 88px; height: 88px;
position: relative; position: relative;
} }
.fit-sm { .fit-sm {
max-height: 100%; max-height: 100%;
max-width: 100%; max-width: 100%;
width: auto; width: auto;
height: auto; height: auto;
position: absolute; position: absolute;
top: 0; top: 0;
bottom: 0; bottom: 0;
left: 0; left: 0;
right: 0; right: 0;
margin: auto; margin: auto;
/*box-shadow: 2px 2px 4px rgba(0, 0, 0, 0.12);*/ /*box-shadow: 2px 2px 4px rgba(0, 0, 0, 0.12);*/
} }
.doc-line { .doc-line {
max-width: calc(100% - 88px - 1.5rem); max-width: calc(100% - 88px - 1.5rem);
flex: 1; flex: 1;
vertical-align: middle; vertical-align: middle;
margin-top: auto; margin-top: auto;
margin-bottom: auto; margin-bottom: auto;
} }
.file-icon-wrapper { .file-icon-wrapper {
width: calc(88px + .5rem); width: calc(88px + .5rem);
height: 88px; height: 88px;
position: relative; position: relative;
} }
</style> </style>

@ -1,29 +1,38 @@
<template> <template>
<div> <div>
<b-progress v-if="modelLoading" :value="modelLoadingProgress" max="1" class="mb-1" variant="warning" <b-progress v-if="modelLoading && [0, 1].includes(modelLoadingProgress)" max="1" class="mb-1" variant="primary"
striped animated :value="1">
</b-progress>
<b-progress v-else-if="modelLoading" :value="modelLoadingProgress" max="1" class="mb-1" variant="warning"
show-progress> show-progress>
</b-progress> </b-progress>
<b-input-group> <div style="display: flex">
<b-form-input :value="embeddingText" <b-select :options="modelOptions()" class="mr-2 input-prepend" :value="modelName"
:placeholder="$t('embeddingsSearchPlaceholder')" @change="onModelChange($event)"></b-select>
@input="onInput($event)"
:disabled="modelLoading" <b-input-group>
></b-form-input> <b-form-input :value="embeddingText"
:placeholder="$store.state.embeddingDoc ? ' ' : $t('embeddingsSearchPlaceholder')"
@input="onInput($event)"
:disabled="modelLoading"
:style="{'pointer-events': $store.state.embeddingDoc ? 'none' : undefined}"
></b-form-input>
<b-badge v-if="$store.state.embeddingDoc" pill variant="primary" class="overlay-badge" href="#"
@click="onBadgeClick()">{{ docName }}
</b-badge>
<template #prepend>
</template>
<template #append>
<b-input-group-text>
<MLIcon class="ml-append" big></MLIcon>
</b-input-group-text>
</template>
</b-input-group>
</div>
<!-- TODO: dropdown of available models-->
<!-- <template #prepend>-->
<!-- <b-input-group-text>-->
<!-- <b-form-checkbox :checked="fuzzy" title="Toggle fuzzy searching" @change="setFuzzy($event)">-->
<!-- {{ $t("searchBar.fuzzy") }}-->
<!-- </b-form-checkbox>-->
<!-- </b-input-group-text>-->
<!-- </template>-->
<template #append>
<b-input-group-text>
<MLIcon></MLIcon>
</b-input-group-text>
</template>
</b-input-group>
</div> </div>
</template> </template>
@ -32,6 +41,7 @@ import {mapGetters, mapMutations} from "vuex";
import {CLIPTransformerModel} from "@/ml/CLIPTransformerModel" import {CLIPTransformerModel} from "@/ml/CLIPTransformerModel"
import _debounce from "lodash/debounce"; import _debounce from "lodash/debounce";
import MLIcon from "@/components/icons/MlIcon.vue"; import MLIcon from "@/components/icons/MlIcon.vue";
import Sist2AdminApi from "@/Sist2Api";
export default { export default {
components: {MLIcon}, components: {MLIcon},
@ -40,7 +50,8 @@ export default {
modelLoading: false, modelLoading: false,
modelLoadingProgress: 0, modelLoadingProgress: 0,
modelLoaded: false, modelLoaded: false,
model: null model: null,
modelName: null
} }
}, },
computed: { computed: {
@ -49,9 +60,18 @@ export default {
embeddingText: "embeddingText", embeddingText: "embeddingText",
fuzzy: "fuzzy", fuzzy: "fuzzy",
}), }),
docName() {
const ext = this.$store.state.embeddingDoc._source.extension;
return this.$store.state.embeddingDoc._source.name +
(ext ? "." + ext : "")
}
}, },
mounted() { mounted() {
this.onInput = _debounce(this._onInput, 300, {leading: false}); // Set default model
this.modelName = Sist2AdminApi.models()[0].name;
this.onModelChange(this.modelName);
this.onInput = _debounce(this._onInput, 450, {leading: false});
}, },
methods: { methods: {
...mapMutations({ ...mapMutations({
@ -61,11 +81,6 @@ export default {
}), }),
async loadModel() { async loadModel() {
this.modelLoading = true; this.modelLoading = true;
this.model = new CLIPTransformerModel(
// TODO: add a config for this (?)
"https://github.com/simon987/sist2-models/raw/main/clip/models/clip-vit-base-patch32-q8.onnx",
"https://github.com/simon987/sist2-models/raw/main/clip/models/tokenizer.json",
);
await this.model.init(async progress => { await this.model.init(async progress => {
this.modelLoadingProgress = progress; this.modelLoadingProgress = progress;
@ -74,26 +89,67 @@ export default {
this.modelLoaded = true; this.modelLoaded = true;
}, },
async _onInput(text) { async _onInput(text) {
if (!this.modelLoaded) { try {
await this.loadModel();
this.setEmbeddingModel(1); // TODO if (!this.modelLoaded) {
await this.loadModel();
}
if (text.length === 0) {
this.setEmbeddingText("");
this.setEmbedding(null);
return;
}
const embeddings = await this.model.predict(text);
this.setEmbeddingText(text);
this.setEmbedding(embeddings);
} catch (e) {
alert(e)
} }
if (text.length === 0) {
this.setEmbeddingText("");
this.setEmbedding(null);
return;
}
const embeddings = await this.model.predict(text);
this.setEmbeddingText(text);
this.setEmbedding(embeddings);
}, },
mounted() { modelOptions() {
return Sist2AdminApi.models().map(model => model.name);
},
onModelChange(name) {
this.modelLoaded = false;
this.modelLoadingProgress = 0;
const modelInfo = Sist2AdminApi.models().find(m => m.name === name);
if (modelInfo.name === "CLIP") {
const tokenizerUrl = new URL("./tokenizer.json", modelInfo.url).href;
this.model = new CLIPTransformerModel(modelInfo.url, tokenizerUrl)
this.setEmbeddingModel(modelInfo.id);
} else {
throw new Error("Unknown model: " + name);
}
},
onBadgeClick() {
this.$store.commit("setEmbedding", null);
this.$store.commit("setEmbeddingDoc", null);
} }
} }
} }
</script> </script>
<style> <style>
.overlay-badge {
position: absolute;
z-index: 1;
left: 0.375rem;
top: 8px;
line-height: 1.1rem;
overflow: hidden;
max-width: 200px;
text-overflow: ellipsis;
}
.input-prepend {
max-width: 100px;
}
.theme-black .ml-append {
filter: brightness(0.95) !important;
}
</style> </style>
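In `onModelChange`, the tokenizer location is derived from the model URL with `new URL("./tokenizer.json", modelInfo.url)`. The sketch below illustrates how that relative resolution behaves. The `example.com` URL and the `tokenizerUrlFor` helper are hypothetical and are not part of this commit.

```javascript
// Hypothetical helper mirroring the URL resolution used in onModelChange().
function tokenizerUrlFor(modelUrl) {
    // A relative reference replaces the last path segment of the base URL.
    return new URL("./tokenizer.json", modelUrl).href;
}

console.log(tokenizerUrlFor("https://example.com/models/clip/model.onnx"));
// prints "https://example.com/models/clip/tokenizer.json"
```

Under this scheme the tokenizer file appears to be expected next to the model file, on the same host and in the same directory.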

@ -1,42 +1,46 @@
<template> <template>
<div v-if="isMobile"> <div v-if="isMobile">
<b-form-select <b-form-select
:value="selectedIndicesIds" :value="selectedIndicesIds"
@change="onSelect($event)" @change="onSelect($event)"
:options="indices" multiple :select-size="6" text-field="name" :options="indices" multiple :select-size="6" text-field="name"
value-field="id"></b-form-select> value-field="id"></b-form-select>
</div> </div>
<div v-else> <div v-else>
<div class="d-flex justify-content-between align-content-center"> <div class="d-flex justify-content-between align-content-center">
<span> <span>
{{ selectedIndices.length }} {{ selectedIndices.length }}
{{ selectedIndices.length === 1 ? $t("indexPicker.selectedIndex") : $t("indexPicker.selectedIndices") }} {{ selectedIndices.length === 1 ? $t("indexPicker.selectedIndex") : $t("indexPicker.selectedIndices") }}
</span> </span>
<div> <div>
<b-button variant="link" @click="selectAll()"> {{ $t("indexPicker.selectAll") }}</b-button> <b-button variant="link" @click="selectAll()"> {{ $t("indexPicker.selectAll") }}</b-button>
<b-button variant="link" @click="selectNone()"> {{ $t("indexPicker.selectNone") }}</b-button> <b-button variant="link" @click="selectNone()"> {{ $t("indexPicker.selectNone") }}</b-button>
</div> </div>
</div>
<b-list-group id="index-picker-desktop" class="unselectable">
<b-list-group-item
v-for="idx in indices"
@click="toggleIndex(idx, $event)"
@click.shift="shiftClick(idx, $event)"
class="d-flex justify-content-between align-items-center list-group-item-action pointer"
:class="{active: lastClickIndex === idx}"
>
<div class="d-flex">
<b-checkbox style="pointer-events: none" :checked="isSelected(idx)"></b-checkbox>
{{ idx.name }}
<span class="text-muted timestamp-text ml-2">{{ formatIdxDate(idx.timestamp) }}</span>
</div> </div>
<b-badge class="version-badge">v{{ idx.version }}</b-badge>
</b-list-group-item> <b-list-group id="index-picker-desktop" class="unselectable">
</b-list-group> <b-list-group-item
</div> v-for="idx in indices"
@click="toggleIndex(idx, $event)"
@click.shift="shiftClick(idx, $event)"
class="d-flex justify-content-between align-items-center list-group-item-action pointer"
:class="{active: lastClickIndex === idx}"
>
<div class="d-flex">
<b-checkbox style="pointer-events: none" :checked="isSelected(idx)"></b-checkbox>
{{ idx.name }}
<div style="vertical-align: center; margin-left: 5px">
<MLIcon small style="top: -1px; position: relative"></MLIcon>
</div>
<span class="text-muted timestamp-text ml-2"
style="top: 1px; position: relative">{{ formatIdxDate(idx.timestamp) }}</span>
</div>
<b-badge class="version-badge">v{{ idx.version }}</b-badge>
</b-list-group-item>
</b-list-group>
</div>
</template> </template>
<script lang="ts"> <script lang="ts">
@ -44,148 +48,150 @@ import SmallBadge from "./SmallBadge.vue"
import {mapActions, mapGetters} from "vuex"; import {mapActions, mapGetters} from "vuex";
import Vue from "vue"; import Vue from "vue";
import {format} from "date-fns"; import {format} from "date-fns";
import MLIcon from "@/components/icons/MlIcon.vue";
export default Vue.extend({ export default Vue.extend({
components: { components: {
SmallBadge MLIcon,
}, SmallBadge
data() {
return {
loading: true,
lastClickIndex: null
}
},
computed: {
...mapGetters([
"indices", "selectedIndices"
]),
selectedIndicesIds() {
return this.selectedIndices.map(idx => idx.id)
}, },
isMobile() { data() {
return window.innerWidth <= 650; return {
} loading: true,
}, lastClickIndex: null
methods: {
...mapActions({
setSelectedIndices: "setSelectedIndices"
}),
shiftClick(index, e) {
if (this.lastClickIndex === null) {
return;
}
const select = this.isSelected(this.lastClickIndex);
let leftBoundary = this.indices.indexOf(this.lastClickIndex);
let rightBoundary = this.indices.indexOf(index);
if (rightBoundary < leftBoundary) {
let tmp = leftBoundary;
leftBoundary = rightBoundary;
rightBoundary = tmp;
}
for (let i = leftBoundary; i <= rightBoundary; i++) {
if (select) {
if (!this.isSelected(this.indices[i])) {
this.setSelectedIndices([this.indices[i], ...this.selectedIndices]);
}
} else {
this.setSelectedIndices(this.selectedIndices.filter(idx => idx !== this.indices[i]));
} }
}
}, },
selectAll() { computed: {
this.setSelectedIndices(this.indices); ...mapGetters([
"indices", "selectedIndices"
]),
selectedIndicesIds() {
return this.selectedIndices.map(idx => idx.id)
},
isMobile() {
return window.innerWidth <= 650;
}
}, },
selectNone() { methods: {
this.setSelectedIndices([]); ...mapActions({
}, setSelectedIndices: "setSelectedIndices"
onSelect(value) { }),
this.setSelectedIndices(this.indices.filter(idx => value.includes(idx.id))); shiftClick(index, e) {
}, if (this.lastClickIndex === null) {
formatIdxDate(timestamp: number): string { return;
return format(new Date(timestamp * 1000), "yyyy-MM-dd"); }
},
toggleIndex(index, e) {
if (e.shiftKey) {
return;
}
this.lastClickIndex = index; const select = this.isSelected(this.lastClickIndex);
if (this.isSelected(index)) {
this.setSelectedIndices(this.selectedIndices.filter(idx => idx.id != index.id)); let leftBoundary = this.indices.indexOf(this.lastClickIndex);
} else { let rightBoundary = this.indices.indexOf(index);
this.setSelectedIndices([index, ...this.selectedIndices]);
} if (rightBoundary < leftBoundary) {
let tmp = leftBoundary;
leftBoundary = rightBoundary;
rightBoundary = tmp;
}
for (let i = leftBoundary; i <= rightBoundary; i++) {
if (select) {
if (!this.isSelected(this.indices[i])) {
this.setSelectedIndices([this.indices[i], ...this.selectedIndices]);
}
} else {
this.setSelectedIndices(this.selectedIndices.filter(idx => idx !== this.indices[i]));
}
}
},
selectAll() {
this.setSelectedIndices(this.indices);
},
selectNone() {
this.setSelectedIndices([]);
},
onSelect(value) {
this.setSelectedIndices(this.indices.filter(idx => value.includes(idx.id)));
},
formatIdxDate(timestamp: number): string {
return format(new Date(timestamp * 1000), "yyyy-MM-dd");
},
toggleIndex(index, e) {
if (e.shiftKey) {
return;
}
this.lastClickIndex = index;
if (this.isSelected(index)) {
this.setSelectedIndices(this.selectedIndices.filter(idx => idx.id != index.id));
} else {
this.setSelectedIndices([index, ...this.selectedIndices]);
}
},
isSelected(index) {
return this.selectedIndices.find(idx => idx.id == index.id) != null;
}
}, },
isSelected(index) {
return this.selectedIndices.find(idx => idx.id == index.id) != null;
}
},
}) })
</script> </script>
<style scoped> <style scoped>
.timestamp-text { .timestamp-text {
line-height: 24px; line-height: 24px;
font-size: 80%; font-size: 80%;
} }
.theme-black .version-badge { .theme-black .version-badge {
color: #eee !important; color: #eee !important;
background: none; background: none;
} }
.version-badge { .version-badge {
color: #222 !important; color: #222 !important;
background: none; background: none;
} }
.list-group-item { .list-group-item {
padding: 0.2em 0.4em; padding: 0.2em 0.4em;
} }
#index-picker-desktop { #index-picker-desktop {
overflow-y: auto; overflow-y: auto;
max-height: 132px; max-height: 132px;
} }
.btn-link:focus { .btn-link:focus {
box-shadow: none; box-shadow: none;
} }
.unselectable { .unselectable {
user-select: none; user-select: none;
-ms-user-select: none; -ms-user-select: none;
-moz-user-select: none; -moz-user-select: none;
-webkit-user-select: none; -webkit-user-select: none;
} }
.list-group-item.active { .list-group-item.active {
z-index: 2; z-index: 2;
background-color: inherit; background-color: inherit;
color: inherit; color: inherit;
} }
.theme-black .list-group-item { .theme-black .list-group-item {
border: 1px solid rgba(255,255,255, 0.1); border: 1px solid rgba(255, 255, 255, 0.1);
} }
.theme-black .list-group-item:first-child { .theme-black .list-group-item:first-child {
border: 1px solid rgba(255,255,255, 0.05); border: 1px solid rgba(255, 255, 255, 0.05);
} }
.theme-black .list-group-item.active { .theme-black .list-group-item.active {
z-index: 2; z-index: 2;
background-color: inherit; background-color: inherit;
color: inherit; color: inherit;
border: 1px solid rgba(255,255,255, 0.3); border: 1px solid rgba(255, 255, 255, 0.3);
border-radius: 0; border-radius: 0;
} }
.theme-black .list-group { .theme-black .list-group {
border-radius: 0; border-radius: 0;
} }
</style> </style>

@ -1,5 +1,5 @@
 <template>
-    <b-dropdown variant="primary" :disabled="$store.getters.embeddingText !== ''">
+    <b-dropdown variant="primary" :disabled="$store.getters.embedding !== null">
         <b-dropdown-item :class="{'dropdown-active': sort === 'score'}" @click="onSelect('score')">{{
                 $t("sort.relevance")
             }}

@ -1,6 +1,6 @@
<template> <template>
<svg height="20px" width="20px" xmlns="http://www.w3.org/2000/svg" <svg class="ml-icon" :class="{'m-icon': 1, 'ml-icon-big': big, 'ml-icon-clickable': clickable}" xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 512 512" xml:space="preserve"> viewBox="0 0 512 512" xml:space="preserve" fill="currentColor" stroke="currentColor" @click="$emit('click')">
<g> <g>
<path class="st0" d="M167.314,14.993C167.314,6.712,160.602,0,152.332,0h-5.514c-8.27,0-14.982,6.712-14.982,14.993v41.466h35.478 <path class="st0" d="M167.314,14.993C167.314,6.712,160.602,0,152.332,0h-5.514c-8.27,0-14.982,6.712-14.982,14.993v41.466h35.478
V14.993z"/> V14.993z"/>
@ -42,9 +42,35 @@
<script> <script>
export default { export default {
name: "MLIcon" name: "MLIcon",
props: {
"big": Boolean,
"clickable": Boolean
}
} }
</script> </script>
<style scoped> <style scoped>
.ml-icon-clickable {
cursor: pointer;
}
.ml-icon-big {
width: 24px !important;
height: 24px !important;
}
.ml-icon {
width: 1rem;
min-width: 1rem;
margin-right: 0.2rem;
line-height: 1rem;
height: 1rem;
min-height: 1rem;
filter: brightness(45%);
}
.theme-black .ml-icon {
filter: brightness(80%);
}
</style> </style>

@ -1,4 +1,3 @@
-import '@babel/polyfill'
 import 'mutationobserver-shim'
 import Vue from 'vue'
 import './plugins/bootstrap-vue'

@ -2,6 +2,7 @@ import * as ort from "onnxruntime-web";
import {BPETokenizer} from "@/ml/BPETokenizer"; import {BPETokenizer} from "@/ml/BPETokenizer";
import axios from "axios"; import axios from "axios";
import {downloadToBuffer, ORT_WASM_PATHS} from "@/ml/mlUtils"; import {downloadToBuffer, ORT_WASM_PATHS} from "@/ml/mlUtils";
import ModelStore from "@/ml/ModelStore";
export class CLIPTransformerModel { export class CLIPTransformerModel {
@ -21,9 +22,17 @@ export class CLIPTransformerModel {
     async loadModel(onProgress) {
         ort.env.wasm.wasmPaths = ORT_WASM_PATHS;
-        const buf = await downloadToBuffer(this._modelUrl, onProgress);
-        this._model = await ort.InferenceSession.create(buf.buffer, {executionProviders: ["wasm"]});
+        ort.env.wasm.numThreads = 2;
+        let buf = await ModelStore.get(this._modelUrl);
+        if (!buf) {
+            buf = await downloadToBuffer(this._modelUrl, onProgress);
+            await ModelStore.set(this._modelUrl, buf);
+        }
+        this._model = await ort.InferenceSession.create(buf.buffer, {
+            executionProviders: ["wasm"],
+        });
     }
async loadTokenizer() { async loadTokenizer() {
@ -34,11 +43,11 @@ export class CLIPTransformerModel {
     async predict(text) {
         const tokenized = this._tokenizer.encode(text);
-        const feeds = {
+        const inputs = {
             input_ids: new ort.Tensor("int32", tokenized, [1, 77])
         };
-        const results = await this._model.run(feeds);
+        const results = await this._model.run(inputs);
         return Array.from(
             Object.values(results)

@ -0,0 +1,67 @@
class ModelStore {
_ok;
_db;
_resolve;
_loadingPromise;
    constructor() {
        // Create the promise first so that either handler can settle it.
        this._loadingPromise = new Promise(resolve => this._resolve = resolve);
        const request = window.indexedDB.open("ModelStore", 1);
        request.onerror = () => {
            // IndexedDB is unavailable: resolve anyway so get()/set() do not hang.
            this._ok = false;
            this._resolve();
        }
        request.onupgradeneeded = event => {
            const db = event.target.result;
            db.createObjectStore("models");
        }
        request.onsuccess = () => {
            this._ok = true;
            this._db = request.result;
            this._resolve();
        }
    }
    async get(key) {
        await this._loadingPromise;
        if (!this._ok) {
            // No database: report a cache miss so the caller downloads the model.
            return null;
        }
        // A lookup only needs a read-only transaction.
        const req = this._db.transaction(["models"], "readonly")
            .objectStore("models")
            .get(key);
        return new Promise(resolve => {
            req.onsuccess = event => {
                resolve(event.target.result);
            };
            req.onerror = event => {
                console.log("ERROR:");
                console.log(event);
                resolve(null);
            };
        });
    }
    async set(key, val) {
        await this._loadingPromise;
        if (!this._ok) {
            // No database: nothing was stored.
            return false;
        }
        const req = this._db.transaction(["models"], "readwrite")
            .objectStore("models")
            .put(val, key);
        return new Promise(resolve => {
            req.onsuccess = () => {
                resolve(true);
            };
            req.onerror = () => {
                resolve(false);
            };
        });
    }
}
export default new ModelStore();
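The way `CLIPTransformerModel.loadModel` uses this store is a cache-then-download pattern: look the model up by URL, and only download and store it on a miss. A minimal sketch of that pattern follows. `MemoryStore`, `getOrDownload`, `demo` and the `example.com` URL are hypothetical stand-ins, with an in-memory `Map` replacing IndexedDB so that the snippet runs outside a browser.

```javascript
// In-memory stand-in for ModelStore (assumption: same async get/set shape).
class MemoryStore {
    constructor() {
        this._items = new Map();
    }

    async get(key) {
        return this._items.has(key) ? this._items.get(key) : null;
    }

    async set(key, val) {
        this._items.set(key, val);
        return true;
    }
}

// Return the cached buffer when present, otherwise download and cache it.
async function getOrDownload(store, url, download) {
    let buf = await store.get(url);
    if (!buf) {
        buf = await download(url);
        await store.set(url, buf);
    }
    return buf;
}

// Usage: the fake downloader should only run on the first call.
async function demo() {
    const store = new MemoryStore();
    let downloads = 0;
    const fakeDownload = async () => {
        downloads += 1;
        return new Uint8Array([1, 2, 3]);
    };
    await getOrDownload(store, "https://example.com/model.onnx", fakeDownload);
    await getOrDownload(store, "https://example.com/model.onnx", fakeDownload);
    return downloads;
}
```

Keying the cache on the model URL means a model published under a new URL is fetched again, while an updated file at the same URL would keep being served from the cache.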

@ -17,7 +17,6 @@ export async function downloadToBuffer(url, onProgress) {
             break;
         }
-        console.log(`Sending ${value.length} bytes into ${buf.length} at offset ${cursor} (${buf.length - cursor} free)`)
         buf.set(value, cursor);
         cursor += value.length;

@ -25,6 +25,7 @@ export default new Vuex.Store({
         searchText: "",
         embeddingText: "",
         embedding: null,
+        embeddingDoc: null,
         pathText: "",
         sortMode: "score",
@ -133,7 +134,8 @@ export default new Vuex.Store({
         setDateBoundsMax: (state, val) => state.dateBoundsMax = val,
         setSearchText: (state, val) => state.searchText = val,
         setEmbeddingText: (state, val) => state.embeddingText = val,
-        setEmbedding: (state, val) => state.embedding= val,
+        setEmbedding: (state, val) => state.embedding = val,
+        setEmbeddingDoc: (state, val) => state.embeddingDoc = val,
         setFuzzy: (state, val) => state.fuzzy = val,
         setLastQueryResult: (state, val) => state.lastQueryResults = val,
         setFirstQueryResult: (state, val) => state.firstQueryResults = val,

@ -13,7 +13,7 @@
         <b-card v-show="!uiLoading && !showEsConnectionError" id="search-panel">
             <SearchBar @show-help="showHelp=true"></SearchBar>
-            <EmbeddingsSearchBar class="mt-3"></EmbeddingsSearchBar>
+            <EmbeddingsSearchBar v-if="hasEmbeddings" class="mt-3"></EmbeddingsSearchBar>
             <b-row>
                 <b-col style="height: 70px;" sm="6">
                     <SizeSlider></SizeSlider>
@ -172,6 +172,12 @@ export default Vue.extend({
setDateBoundsMax: "setDateBoundsMax", setDateBoundsMax: "setDateBoundsMax",
setTags: "setTags", setTags: "setTags",
}), }),
hasEmbeddings() {
if (!this.loading) {
return false;
}
return Sist2Api.models().length > 0;
},
showErrorToast() { showErrorToast() {
this.$bvToast.toast( this.$bvToast.toast(
this.$t("toast.esConnErr"), this.$t("toast.esConnErr"),
@ -203,6 +209,7 @@ export default Vue.extend({
                 await this.handleSearch(resp);
                 this.searchBusy = false;
             }).catch(err => {
+                console.log(err)
                 if (err.response.status === 500 && this.$store.state.optQueryMode === "advanced") {
                     this.showSyntaxErrorToast();
                 } else {

@ -1,3 +1,5 @@
const TerserPlugin = require("terser-webpack-plugin");
module.exports = { module.exports = {
filenameHashing: false, filenameHashing: false,
productionSourceMap: false, productionSourceMap: false,
@ -6,5 +8,19 @@ module.exports = {
index: { index: {
entry: "src/main.js" entry: "src/main.js"
} }
},
configureWebpack: config => {
config.optimization.minimizer = [new TerserPlugin({
terserOptions: {
compress: {
passes: 2,
module: true,
hoist_funs: true,
// https://github.com/microsoft/onnxruntime/issues/16984
unused: false,
},
mangle: true,
}
})]
} }
} }

@ -38,11 +38,6 @@ scan_args_t *scan_args_create() {
return args; return args;
} }
exec_args_t *exec_args_create() {
exec_args_t *args = calloc(sizeof(exec_args_t), 1);
return args;
}
void scan_args_destroy(scan_args_t *args) { void scan_args_destroy(scan_args_t *args) {
if (args->name != NULL) { if (args->name != NULL) {
free(args->name); free(args->name);
@ -74,17 +69,9 @@ void web_args_destroy(web_args_t *args) {
free(args); free(args);
} }
void exec_args_destroy(exec_args_t *args) {
if (args->index_path != NULL) {
free(args->index_path);
}
free(args);
}
 void sqlite_index_args_destroy(sqlite_index_args_t *args) {
-    // TODO
+    free(args->index_path);
+    free(args);
 }
int scan_args_validate(scan_args_t *args, int argc, const char **argv) { int scan_args_validate(scan_args_t *args, int argc, const char **argv) {
@ -626,43 +613,3 @@ web_args_t *web_args_create() {
web_args_t *args = calloc(sizeof(web_args_t), 1); web_args_t *args = calloc(sizeof(web_args_t), 1);
return args; return args;
} }
int exec_args_validate(exec_args_t *args, int argc, const char **argv) {
if (argc < 2) {
fprintf(stderr, "Required positional argument: PATH.\n");
return 1;
}
char *index_path = abspath(argv[1]);
if (index_path == NULL) {
LOG_FATALF("cli.c", "Invalid index PATH argument. File not found: %s", argv[1]);
} else {
args->index_path = index_path;
}
if (args->es_url == NULL) {
args->es_url = DEFAULT_ES_URL;
}
if (args->es_index == NULL) {
args->es_index = DEFAULT_ES_INDEX;
}
if (args->script_path == NULL) {
LOG_FATAL("cli.c", "--script-file argument is required");
}
if (load_external_file(args->script_path, &args->script) != 0) {
return 1;
}
LOG_DEBUGF("cli.c", "arg script_path=%s", args->script_path);
char log_buf[5000];
strncpy(log_buf, args->script, sizeof(log_buf));
*(log_buf + sizeof(log_buf) - 1) = '\0';
LOG_DEBUGF("cli.c", "arg script=%s", log_buf);
return 0;
}

@ -102,16 +102,6 @@ typedef struct web_args {
search_backend_t search_backend; search_backend_t search_backend;
} web_args_t; } web_args_t;
typedef struct exec_args {
char *es_url;
char *es_index;
int es_insecure_ssl;
char *index_path;
const char *script_path;
int async_script;
char *script;
} exec_args_t;
index_args_t *index_args_create(); index_args_t *index_args_create();
sqlite_index_args_t *sqlite_index_args_create(); sqlite_index_args_t *sqlite_index_args_create();
@ -128,12 +118,6 @@ int sqlite_index_args_validate(sqlite_index_args_t *args, int argc, const char *
int web_args_validate(web_args_t *args, int argc, const char **argv); int web_args_validate(web_args_t *args, int argc, const char **argv);
exec_args_t *exec_args_create();
void exec_args_destroy(exec_args_t *args);
int exec_args_validate(exec_args_t *args, int argc, const char **argv);
void sqlite_index_args_destroy(sqlite_index_args_t *args); void sqlite_index_args_destroy(sqlite_index_args_t *args);

@ -180,6 +180,10 @@ void database_open(database_t *db) {
db->db, "SELECT * FROM model", -1, db->db, "SELECT * FROM model", -1,
&db->get_models, NULL)); &db->get_models, NULL));
CRASH_IF_NOT_SQLITE_OK(sqlite3_prepare_v2(
db->db, "SELECT embedding FROM embedding WHERE id=? AND model_id=? AND start=0", -1,
&db->get_embedding, NULL));
// Create functions // Create functions
sqlite3_create_function( sqlite3_create_function(
db->db, db->db,
@ -194,11 +198,11 @@ void database_open(database_t *db) {
     sqlite3_create_function(
             db->db,
-            "embedding_to_json",
-            5,
+            "emb_to_json",
+            1,
             SQLITE_UTF8,
             NULL,
-            embedding_to_json_func,
+            emb_to_json_func,
             NULL,
             NULL
     );
@ -494,29 +498,31 @@ database_iterator_t *database_create_document_iterator(database_t *db) {
sqlite3_stmt *stmt; sqlite3_stmt *stmt;
sqlite3_prepare_v2(db->db, "WITH doc (j) AS (SELECT CASE" CRASH_IF_NOT_SQLITE_OK(
" WHEN sc.json_data IS NULL THEN" sqlite3_prepare_v2(
" CASE" db->db,
" WHEN t.tag IS NULL THEN" "WITH doc (j) AS (SELECT CASE"
" json_set(document.json_data, '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime)" " WHEN emb.embedding IS NULL THEN"
" ELSE" " json_set(document.json_data, "
" json_set(document.json_data, '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime, '$.tag', json_group_array(t.tag))" " '$._id', document.id, "
" END" " '$.size', document.size, "
" ELSE" " '$.mtime', document.mtime, "
" CASE" " '$.tag', json_group_array((SELECT tag FROM tag WHERE document.id = tag.id)))"
" WHEN t.tag IS NULL THEN" " ELSE"
" json_patch(json_set(document.json_data, '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime), sc.json_data)" " json_set(document.json_data,"
" ELSE" " '$._id', document.id,"
// This will overwrite any tags specified in the sidecar file! " '$.size', document.size,"
// TODO: concatenate the two arrays? " '$.mtime', document.mtime,"
" json_set(json_patch(document.json_data, sc.json_data), '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime, '$.tag', json_group_array(t.tag))" " '$.tag', json_group_array((SELECT tag FROM tag WHERE document.id = tag.id)),"
" END" " '$.emb', json_group_object(m.path, json(emb_to_json(emb.embedding))),"
" END" " '$.embedding', 1)"
" FROM document" " END"
" LEFT JOIN document_sidecar sc ON document.id = sc.id" " FROM document"
" LEFT JOIN tag t ON document.id = t.id" " LEFT JOIN embedding emb ON document.id = emb.id"
" GROUP BY document.id)" " LEFT JOIN model m ON emb.model_id = m.id"
" SELECT json_set(j, '$.index', (SELECT id FROM descriptor)) FROM doc", -1, &stmt, NULL); " GROUP BY document.id)"
" SELECT json_set(j, '$.index', (SELECT id FROM descriptor)) FROM doc",
-1, &stmt, NULL));
database_iterator_t *iter = malloc(sizeof(database_iterator_t)); database_iterator_t *iter = malloc(sizeof(database_iterator_t));
@ -526,6 +532,13 @@ database_iterator_t *database_create_document_iterator(database_t *db) {
return iter; return iter;
} }
void remove_tag_if_null(cJSON *doc) {
cJSON *tags = cJSON_GetObjectItem(doc, "tag");
if (tags != NULL && cJSON_IsNull(cJSON_GetArrayItem(tags, 0))) {
cJSON_DeleteItemFromObject(doc, "tag");
}
}
cJSON *database_document_iter(database_iterator_t *iter) { cJSON *database_document_iter(database_iterator_t *iter) {
if (iter->stmt == NULL) { if (iter->stmt == NULL) {
@ -537,7 +550,12 @@ cJSON *database_document_iter(database_iterator_t *iter) {
     if (ret == SQLITE_ROW) {
         const char *json_string = (const char *) sqlite3_column_text(iter->stmt, 0);
-        return cJSON_Parse(json_string);
+        cJSON *doc = cJSON_Parse(json_string);
+        remove_tag_if_null(doc);
+        return doc;
     }
     if (ret != SQLITE_DONE) {

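The new `remove_tag_if_null` helper drops the `tag` array when its first element is JSON `null`, which is presumably what the reworked query's `json_group_array(...)` yields for a document without tags. The sketch below shows the same cleanup on a plain object. It is an illustrative JavaScript analogue of the C helper, and the sample documents are made up.

```javascript
// Illustrative analogue of remove_tag_if_null(): drop `tag` when it is [null, ...].
function removeTagIfNull(doc) {
    const tags = doc.tag;
    if (Array.isArray(tags) && tags.length > 0 && tags[0] === null) {
        delete doc.tag;
    }
    return doc;
}

// A document with no real tags loses the placeholder array.
console.log(removeTagIfNull({_id: "abc", tag: [null]}));
// A document with real tags is left untouched.
console.log(removeTagIfNull({_id: "def", tag: ["genre.jazz"]}));
```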
@ -85,6 +85,7 @@ typedef struct database {
sqlite3_stmt *write_thumbnail_stmt; sqlite3_stmt *write_thumbnail_stmt;
sqlite3_stmt *get_document; sqlite3_stmt *get_document;
sqlite3_stmt *get_models; sqlite3_stmt *get_models;
sqlite3_stmt *get_embedding;
sqlite3_stmt *delete_tag_stmt; sqlite3_stmt *delete_tag_stmt;
sqlite3_stmt *write_tag_stmt; sqlite3_stmt *write_tag_stmt;
@ -144,6 +145,8 @@ void database_write_document(database_t *db, document_t *doc, const char *json_d
database_iterator_t *database_create_document_iterator(database_t *db); database_iterator_t *database_create_document_iterator(database_t *db);
void emb_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv);
cJSON *database_document_iter(database_iterator_t *); cJSON *database_document_iter(database_iterator_t *);
#define database_document_iter_foreach(element, iter) \ #define database_document_iter_foreach(element, iter) \
@ -235,8 +238,10 @@ cJSON *database_get_document(database_t *db, char *doc_id);
void cosine_sim_func(sqlite3_context *ctx, int argc, sqlite3_value **argv); void cosine_sim_func(sqlite3_context *ctx, int argc, sqlite3_value **argv);
void embedding_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv);
cJSON *database_get_models(database_t *db); cJSON *database_get_models(database_t *db);
int database_fts_get_model_size(database_t *db, int model_id);
cJSON *database_get_embedding(database_t *db, char *doc_id, int model_id);
#endif #endif

@ -1,5 +1,6 @@
#include <openblas/cblas.h> #include <openblas/cblas.h>
#include "database.h" #include "database.h"
#include "src/ctx.h"
static float cosine_sim(int n, const float *a, const float *b) { static float cosine_sim(int n, const float *a, const float *b) {
@ -33,38 +34,6 @@ void cosine_sim_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
sqlite3_result_double(ctx, result); sqlite3_result_double(ctx, result);
} }
void embedding_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
// emb, type, start, end, size
if (argc != 5) {
sqlite3_result_error(ctx, "Invalid parameters", -1);
}
const float *embedding = sqlite3_value_blob(argv[0]);
const char *type = (const char *) sqlite3_value_text(argv[1]);
int size = sqlite3_value_int(argv[4]);
if (strcmp(type, "flat") == 0) {
cJSON *json = cJSON_CreateFloatArray(embedding, size);
char *json_str = cJSON_PrintBuffered(json, size * 22, FALSE);
cJSON_Delete(json);
sqlite3_result_text(ctx, json_str, -1, SQLITE_TRANSIENT);
free(json_str);
} else {
int start = sqlite3_value_int(argv[2]);
int end = sqlite3_value_int(argv[3]);
sqlite3_result_error(ctx, "Nested embeddings not implemented yet", -1);
}
}
cJSON *database_get_models(database_t *db) { cJSON *database_get_models(database_t *db) {
cJSON *json = cJSON_CreateArray(); cJSON *json = cJSON_CreateArray();
sqlite3_stmt *stmt = db->get_models; sqlite3_stmt *stmt = db->get_models;
@ -72,7 +41,12 @@ cJSON *database_get_models(database_t *db) {
int ret; int ret;
do { do {
ret = sqlite3_step(stmt); ret = sqlite3_step(stmt);
CRASH_IF_STMT_FAIL(ret); if (ret == SQLITE_BUSY) {
// Database is busy (probably scanning)
LOG_WARNING("database_embeddings.c",
"Database is busy, could not fetch list of models");
break;
}
if (ret == SQLITE_DONE) { if (ret == SQLITE_DONE) {
break; break;
@ -82,7 +56,7 @@ cJSON *database_get_models(database_t *db) {
cJSON_AddNumberToObject(row, "id", sqlite3_column_int(stmt, 0)); cJSON_AddNumberToObject(row, "id", sqlite3_column_int(stmt, 0));
cJSON_AddStringToObject(row, "name", (const char *) sqlite3_column_text(stmt, 1)); cJSON_AddStringToObject(row, "name", (const char *) sqlite3_column_text(stmt, 1));
cJSON_AddStringToObject(row, "url", (const char *) sqlite3_column_int64(stmt, 2)); cJSON_AddStringToObject(row, "url", (const char *) sqlite3_column_text(stmt, 2));
cJSON_AddStringToObject(row, "path", (const char *) sqlite3_column_text(stmt, 3)); cJSON_AddStringToObject(row, "path", (const char *) sqlite3_column_text(stmt, 3));
cJSON_AddNumberToObject(row, "size", sqlite3_column_int(stmt, 4)); cJSON_AddNumberToObject(row, "size", sqlite3_column_int(stmt, 4));
cJSON_AddStringToObject(row, "type", (const char *) sqlite3_column_text(stmt, 5)); cJSON_AddStringToObject(row, "type", (const char *) sqlite3_column_text(stmt, 5));
@ -90,5 +64,44 @@ cJSON *database_get_models(database_t *db) {
cJSON_AddItemToArray(json, row); cJSON_AddItemToArray(json, row);
} while (TRUE); } while (TRUE);
sqlite3_reset(stmt);
return json; return json;
} }
cJSON *database_get_embedding(database_t *db, char *doc_id, int model_id) {
sqlite3_bind_text(db->get_embedding, 1, doc_id, -1, SQLITE_STATIC);
sqlite3_bind_int(db->get_embedding, 2, model_id);
int ret = sqlite3_step(db->get_embedding);
CRASH_IF_STMT_FAIL(ret);
if (ret == SQLITE_DONE) {
sqlite3_reset(db->get_embedding);
return NULL;
}
float *embedding = (float *) sqlite3_column_blob(db->get_embedding, 0);
size_t size = sqlite3_column_bytes(db->get_embedding, 0) / sizeof(float);
cJSON *json = cJSON_CreateFloatArray(embedding, (int) size);
sqlite3_reset(db->get_embedding);
return json;
}
void emb_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
if (argc != 1) {
sqlite3_result_error(ctx, "Invalid parameters", -1);
return;
}
float *embedding = (float *) sqlite3_value_blob(argv[0]);
int size = sqlite3_value_bytes(argv[0]) / (int) sizeof(float);
cJSON *json = cJSON_CreateFloatArray(embedding, size);
char *json_str = cJSON_PrintUnformatted(json);
sqlite3_result_text(ctx, json_str, -1, SQLITE_TRANSIENT);
free(json_str);
cJSON_Delete(json);
}
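The `emb_to_json_func` SQL function above turns a raw float BLOB into JSON array text by way of cJSON. The same idea can be sketched without cJSON, as a rough illustration only: `floats_to_json` is a hypothetical helper that is not part of sist2, and it uses `%g`, which keeps about six significant digits and prints `nan`/`inf` for non-finite values (neither is valid JSON).

```c
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical helper (not in sist2): serialize `count` floats as a compact
 * JSON array. Returns a malloc'd string that the caller must free, or NULL
 * if the allocation fails. */
static char *floats_to_json(const float *values, int count) {
    assert(count >= 0);

    /* "%g" prints a float in at most 12 characters (e.g. "-1.17549e-38"),
     * so 16 bytes per element leaves room for the separating comma. */
    size_t cap = (size_t) count * 16 + 3;
    char *out = malloc(cap);
    if (out == NULL) {
        return NULL;
    }

    size_t pos = 0;
    out[pos++] = '[';
    for (int i = 0; i < count; i++) {
        /* Prefix every element except the first with a comma */
        pos += (size_t) snprintf(out + pos, cap - pos, "%s%g",
                                 i > 0 ? "," : "", values[i]);
    }
    out[pos++] = ']';
    out[pos] = '\0';
    return out;
}
```

As in the function above, the element count would come from the BLOB's byte length divided by `sizeof(float)`.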

@ -67,6 +67,11 @@ void database_fts_index(database_t *db) {
"REPLACE INTO fts.embedding (id, model_id, start, end, embedding)" "REPLACE INTO fts.embedding (id, model_id, start, end, embedding)"
" SELECT id, model_id, start, end, embedding FROM embedding", NULL, NULL, NULL)); " SELECT id, model_id, start, end, embedding FROM embedding", NULL, NULL, NULL));
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db,
"INSERT INTO fts.model (id, size)"
" SELECT id, size FROM model WHERE TRUE ON CONFLICT (id) DO NOTHING", NULL, NULL, NULL));
// TODO: delete old embeddings // TODO: delete old embeddings
LOG_DEBUG("database_fts.c", "Deleting old documents"); LOG_DEBUG("database_fts.c", "Deleting old documents");
@ -518,20 +523,15 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
const char *json_object_sql; const char *json_object_sql;
if (highlight && query_where != NULL) { if (highlight && query_where != NULL) {
json_object_sql = "json_remove(json_set(doc.json_data," json_object_sql = "json_set(json_remove(doc.json_data, '$.content'),"
"'$.index', doc.index_id," "'$.index', doc.index_id,"
"'$.embedding', (CASE WHEN emb.id IS NOT NULL THEN 1 ELSE 0 END),"
"'$._highlight.name', snippet(search, 0, '<mark>', '</mark>', '', ?6)," "'$._highlight.name', snippet(search, 0, '<mark>', '</mark>', '', ?6),"
"'$._highlight.content', snippet(search, 1, '<mark>', '</mark>', '', ?6))," "'$._highlight.content', snippet(search, 1, '<mark>', '</mark>', '', ?6))";
"'$.content')";
} else { } else {
json_object_sql = "json_remove(json_set(doc.json_data," json_object_sql = "json_set(json_remove(doc.json_data, '$.content'),"
"'$.index', doc.index_id)," "'$.index', doc.index_id,"
"'$.content')"; "'$.embedding', (CASE WHEN emb.id IS NOT NULL THEN 1 ELSE 0 END))";
}
const char *embedding_join = "";
if (embedding) {
embedding_join = "LEFT JOIN embedding emb ON emb.id = doc.id AND emb.model_id=?9";
} }
char *sql; char *sql;
@ -544,12 +544,11 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
" %s, %s as sort_var, doc.ROWID" " %s, %s as sort_var, doc.ROWID"
" FROM search" " FROM search"
" INNER JOIN document_index doc on doc.ROWID = search.ROWID" " INNER JOIN document_index doc on doc.ROWID = search.ROWID"
" %s" " LEFT JOIN embedding emb on emb.id = doc.id"
" WHERE %s" " WHERE %s"
" ORDER BY sort_var%s, doc.ROWID" " ORDER BY sort_var%s, doc.ROWID"
" LIMIT ?2", " LIMIT ?2",
json_object_sql, get_sort_var(sort), json_object_sql, get_sort_var(sort),
embedding_join,
where, where,
sort_asc ? "" : " DESC"); sort_asc ? "" : " DESC");
@ -567,12 +566,11 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
"SELECT" "SELECT"
" %s, %s as sort_var, doc.ROWID" " %s, %s as sort_var, doc.ROWID"
" FROM document_index doc" " FROM document_index doc"
" %s" " LEFT JOIN embedding emb on emb.id = doc.id"
" WHERE %s" " WHERE %s"
" ORDER BY sort_var%s,doc.ROWID" " ORDER BY sort_var%s,doc.ROWID"
" LIMIT ?2", " LIMIT ?2",
json_object_sql, get_sort_var(sort), json_object_sql, get_sort_var(sort),
embedding_join,
where, where,
sort_asc ? "" : " DESC"); sort_asc ? "" : " DESC");
@ -624,7 +622,7 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
if (after_where) { if (after_where) {
if (sort == FTS_SORT_NAME || sort == FTS_SORT_ID) { if (sort == FTS_SORT_NAME || sort == FTS_SORT_ID) {
sqlite3_bind_text(stmt, 3, after[0], -1, SQLITE_STATIC); sqlite3_bind_text(stmt, 3, after[0], -1, SQLITE_STATIC);
} else if (sort == FTS_SORT_SCORE) { } else if (sort == FTS_SORT_SCORE || sort == FTS_SORT_EMBEDDING) {
sqlite3_bind_double(stmt, 3, strtod(after[0], NULL)); sqlite3_bind_double(stmt, 3, strtod(after[0], NULL));
} else { } else {
sqlite3_bind_int64(stmt, 3, strtol(after[0], NULL, 10)); sqlite3_bind_int64(stmt, 3, strtol(after[0], NULL, 10));

@ -49,12 +49,8 @@ const char *FtsDatabaseSchema =
");" ");"
"" ""
"CREATE TABLE IF NOT EXISTS model (" "CREATE TABLE IF NOT EXISTS model ("
" id INTEGER PRIMARY KEY," " id INTEGER PRIMARY KEY CHECK (id > 0 AND id < 1000),"
" name TEXT NOT NULL UNIQUE CHECK ( length(name) < 16 )," " size INTEGER NOT NULL"
" url TEXT,"
" path TEXT NOT NULL UNIQUE,"
" size INTEGER NOT NULL,"
" type TEXT NOT NULL CHECK ( type IN ('flat', 'nested') )"
");" ");"
"" ""
"CREATE TRIGGER IF NOT EXISTS tag_write_trigger" "CREATE TRIGGER IF NOT EXISTS tag_write_trigger"
@ -183,5 +179,14 @@ const char *IndexDatabaseSchema =
" end INTEGER," " end INTEGER,"
" embedding BLOB NOT NULL," " embedding BLOB NOT NULL,"
" PRIMARY KEY (id, model_id, start)" " PRIMARY KEY (id, model_id, start)"
");"
""
"CREATE TABLE model ("
" id INTEGER PRIMARY KEY CHECK (id > 0 AND id < 1000),"
" name TEXT NOT NULL UNIQUE CHECK ( length(name) < 16 ),"
" url TEXT,"
" path TEXT NOT NULL UNIQUE,"
" size INTEGER NOT NULL,"
" type TEXT NOT NULL CHECK ( type IN ('flat', 'nested') )"
");"; ");";

@ -98,61 +98,6 @@ void index_json(cJSON *document, const char doc_id[SIST_DOC_ID_LEN]) {
free(bulk_line); free(bulk_line);
} }
void execute_update_script(const char *script, int async, const char index_id[SIST_INDEX_ID_LEN]) {
if (Indexer == NULL) {
Indexer = create_indexer(IndexCtx.es_url, IndexCtx.es_index);
}
cJSON *body = cJSON_CreateObject();
cJSON *script_obj = cJSON_AddObjectToObject(body, "script");
cJSON_AddStringToObject(script_obj, "lang", "painless");
cJSON_AddStringToObject(script_obj, "source", script);
cJSON *query = cJSON_AddObjectToObject(body, "query");
cJSON *term_obj = cJSON_AddObjectToObject(query, "term");
cJSON_AddStringToObject(term_obj, "index", index_id);
char *str = cJSON_PrintUnformatted(body);
char url[4096];
if (async) {
snprintf(url, sizeof(url), "%s/%s/_update_by_query?wait_for_completion=false", Indexer->es_url,
Indexer->es_index);
} else {
snprintf(url, sizeof(url), "%s/%s/_update_by_query", Indexer->es_url, Indexer->es_index);
}
response_t *r = web_post(url, str, IndexCtx.es_insecure_ssl);
if (!async) {
LOG_INFOF("elastic.c", "Executed user script <%d>", r->status_code);
}
cJSON *resp = cJSON_Parse(r->body);
cJSON_free(str);
cJSON_Delete(body);
free_response(r);
cJSON *error = cJSON_GetObjectItem(resp, "error");
if (error != NULL) {
char *error_str = cJSON_Print(error);
LOG_ERRORF("elastic.c", "User script error: \n%s", error_str);
cJSON_free(error_str);
}
if (async) {
cJSON *task = cJSON_GetObjectItem(resp, "task");
if (task == NULL) {
LOG_FATALF("elastic.c", "FIXME: Could not get task id: %s", r->body);
}
LOG_INFOF("elastic.c", "User script queued: %s/_tasks/%s", Indexer->es_url, task->valuestring);
}
cJSON_Delete(resp);
}
void *create_bulk_buffer(int max, int *count, size_t *buf_len, int legacy) { void *create_bulk_buffer(int max, int *count, size_t *buf_len, int legacy) {
es_bulk_line_t *line = Indexer->line_head; es_bulk_line_t *line = Indexer->line_head;
*count = 0; *count = 0;
@ -403,7 +348,7 @@ es_indexer_t *create_indexer(const char *url, const char *index) {
return indexer; return indexer;
} }
void finish_indexer(char *script, int async_script, char *index_id) { void finish_indexer(char *index_id) {
char url[4096]; char url[4096];
@ -412,16 +357,6 @@ void finish_indexer(char *script, int async_script, char *index_id) {
LOG_INFOF("elastic.c", "Refresh index <%d>", r->status_code); LOG_INFOF("elastic.c", "Refresh index <%d>", r->status_code);
free_response(r); free_response(r);
if (script != NULL) {
execute_update_script(script, async_script, index_id);
free(script);
snprintf(url, sizeof(url), "%s/%s/_refresh", IndexCtx.es_url, IndexCtx.es_index);
r = web_post(url, "", IndexCtx.es_insecure_ssl);
LOG_INFOF("elastic.c", "Refresh index <%d>", r->status_code);
free_response(r);
}
snprintf(url, sizeof(url), "%s/%s/_forcemerge", IndexCtx.es_url, IndexCtx.es_index); snprintf(url, sizeof(url), "%s/%s/_forcemerge", IndexCtx.es_url, IndexCtx.es_index);
r = web_post(url, "", IndexCtx.es_insecure_ssl); r = web_post(url, "", IndexCtx.es_insecure_ssl);
LOG_INFOF("elastic.c", "Merge index <%d>", r->status_code); LOG_INFOF("elastic.c", "Merge index <%d>", r->status_code);

@ -24,6 +24,8 @@ typedef struct {
#define IS_SUPPORTED_ES_VERSION(es_version) ((es_version) != NULL && VERSION_GE((es_version), 6, 8) && VERSION_LT((es_version), 9, 0)) #define IS_SUPPORTED_ES_VERSION(es_version) ((es_version) != NULL && VERSION_GE((es_version), 6, 8) && VERSION_LT((es_version), 9, 0))
#define IS_LEGACY_VERSION(es_version) ((es_version) != NULL && VERSION_LT((es_version), 7, 14)) #define IS_LEGACY_VERSION(es_version) ((es_version) != NULL && VERSION_LT((es_version), 7, 14))
#define HAS_KNN(es_version) ((es_version) != NULL && VERSION_GE((es_version), 8, 0))
__always_inline __always_inline
static const char *format_es_version(es_version_t *version) { static const char *format_es_version(es_version_t *version) {
@ -51,7 +53,7 @@ void delete_document(const char *document_id);
es_indexer_t *create_indexer(const char *url, const char *index); es_indexer_t *create_indexer(const char *url, const char *index);
void elastic_cleanup(); void elastic_cleanup();
void finish_indexer(char *script, int async_script, char *index_id); void finish_indexer(char *index_id);
void elastic_init(int force_reset, const char* user_mappings, const char* user_settings); void elastic_init(int force_reset, const char* user_mappings, const char* user_settings);
@ -61,6 +63,4 @@ char *elastic_get_status();
es_version_t *elastic_get_version(const char *es_url, int insecure); es_version_t *elastic_get_version(const char *es_url, int insecure);
void execute_update_script(const char *script, int async, const char index_id[SIST_INDEX_ID_LEN]);
#endif #endif
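The `HAS_KNN` macro added in this header gates approximate kNN search on Elasticsearch 8.0 or newer, using `VERSION_GE`, whose definition is not shown in this diff. A minimal sketch of how such a gate can work, assuming a plain major/minor comparison; `version_t`, `version_ge`, `has_knn` and `is_legacy` are stand-in names for illustration, not the project's definitions:

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-in for the es_version_t struct used in the diff */
typedef struct {
    int major;
    int minor;
    int patch;
} version_t;

/* Returns 1 if `v` is at least major.minor, 0 otherwise (or if v is NULL) */
static int version_ge(const version_t *v, int major, int minor) {
    assert(major >= 0 && minor >= 0);
    if (v == NULL) {
        return 0;
    }
    return v->major > major || (v->major == major && v->minor >= minor);
}

/* Mirrors HAS_KNN: non-NULL and >= 8.0 */
static int has_knn(const version_t *v) {
    return version_ge(v, 8, 0);
}

/* Mirrors IS_LEGACY_VERSION: non-NULL and < 7.14 */
static int is_legacy(const version_t *v) {
    return v != NULL && !version_ge(v, 7, 14);
}
```

A NULL version makes every check return false, which matches the `(es_version) != NULL &&` guard in the macros above.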

@ -24,7 +24,6 @@ static const char *const usage[] = {
"sist2 index [OPTION]... INDEX", "sist2 index [OPTION]... INDEX",
"sist2 sqlite-index [OPTION]... INDEX", "sist2 sqlite-index [OPTION]... INDEX",
"sist2 web [OPTION]... INDEX...", "sist2 web [OPTION]... INDEX...",
"sist2 exec-script [OPTION]... INDEX",
NULL, NULL,
}; };
@ -349,7 +348,7 @@ void sist2_index(index_args_t *args) {
tpool_destroy(IndexCtx.pool); tpool_destroy(IndexCtx.pool);
if (IndexCtx.needs_es_connection) { if (IndexCtx.needs_es_connection) {
finish_indexer(args->script, args->async_script, desc->id); finish_indexer(desc->id);
} }
free(desc); free(desc);
} }
@ -370,25 +369,6 @@ void sist2_sqlite_index(sqlite_index_args_t *args) {
database_close(search_db, FALSE); database_close(search_db, FALSE);
} }
void sist2_exec_script(exec_args_t *args) {
LogCtx.verbose = TRUE;
IndexCtx.es_url = args->es_url;
IndexCtx.es_index = args->es_index;
IndexCtx.es_insecure_ssl = args->es_insecure_ssl;
IndexCtx.needs_es_connection = TRUE;
database_t *db = database_create(args->index_path, INDEX_DATABASE);
database_open(db);
index_descriptor_t *desc = database_read_index_descriptor(db);
LOG_DEBUGF("main.c", "Index version %s", desc->version);
execute_update_script(args->script, args->async_script, desc->id);
free(args->script);
database_close(db, FALSE);
}
void sist2_web(web_args_t *args) { void sist2_web(web_args_t *args) {
WebCtx.es_url = args->es_url; WebCtx.es_url = args->es_url;
@ -464,15 +444,9 @@ int set_to_negative_if_value_is_zero(UNUSED(struct argparse *self), const struct
int main(int argc, const char *argv[]) { int main(int argc, const char *argv[]) {
setlocale(LC_ALL, ""); setlocale(LC_ALL, "");
// database_t *db = database_create("clip.sist2", INDEX_DATABASE);
// database_open(db);
// database_test(db);
// exit(0);
scan_args_t *scan_args = scan_args_create(); scan_args_t *scan_args = scan_args_create();
index_args_t *index_args = index_args_create(); index_args_t *index_args = index_args_create();
web_args_t *web_args = web_args_create(); web_args_t *web_args = web_args_create();
exec_args_t *exec_args = exec_args_create();
sqlite_index_args_t *sqlite_index_args = sqlite_index_args_create(); sqlite_index_args_t *sqlite_index_args = sqlite_index_args_create();
int arg_version = 0; int arg_version = 0;
@ -481,7 +455,6 @@ int main(int argc, const char *argv[]) {
int common_es_insecure_ssl = 0; int common_es_insecure_ssl = 0;
char *common_es_index = NULL; char *common_es_index = NULL;
char *common_script_path = NULL; char *common_script_path = NULL;
int common_async_script = 0;
int common_threads = 0; int common_threads = 0;
int common_optimize_database = 0; int common_optimize_database = 0;
char *common_search_index = NULL; char *common_search_index = NULL;
@ -556,7 +529,6 @@ int main(int argc, const char *argv[]) {
OPT_STRING(0, "script-file", &common_script_path, "Path to user script."), OPT_STRING(0, "script-file", &common_script_path, "Path to user script."),
OPT_STRING(0, "mappings-file", &index_args->es_mappings_path, "Path to Elasticsearch mappings."), OPT_STRING(0, "mappings-file", &index_args->es_mappings_path, "Path to Elasticsearch mappings."),
OPT_STRING(0, "settings-file", &index_args->es_settings_path, "Path to Elasticsearch settings."), OPT_STRING(0, "settings-file", &index_args->es_settings_path, "Path to Elasticsearch settings."),
OPT_BOOLEAN(0, "async-script", &common_async_script, "Execute user script asynchronously."),
OPT_INTEGER(0, "batch-size", &index_args->batch_size, "Index batch size. DEFAULT: 70"), OPT_INTEGER(0, "batch-size", &index_args->batch_size, "Index batch size. DEFAULT: 70"),
OPT_BOOLEAN('f', "force-reset", &index_args->force_reset, "Reset Elasticsearch mappings and settings."), OPT_BOOLEAN('f', "force-reset", &index_args->force_reset, "Reset Elasticsearch mappings and settings."),
@ -567,7 +539,6 @@ int main(int argc, const char *argv[]) {
OPT_STRING(0, "es-url", &common_es_url, "Elasticsearch url. DEFAULT: http://localhost:9200"), OPT_STRING(0, "es-url", &common_es_url, "Elasticsearch url. DEFAULT: http://localhost:9200"),
OPT_BOOLEAN(0, "es-insecure-ssl", &common_es_insecure_ssl, OPT_BOOLEAN(0, "es-insecure-ssl", &common_es_insecure_ssl,
"Do not verify SSL connections to Elasticsearch."), "Do not verify SSL connections to Elasticsearch."),
// TODO: change arg name (?)
OPT_STRING(0, "search-index", &common_search_index, "Path to SQLite search index."), OPT_STRING(0, "search-index", &common_search_index, "Path to SQLite search index."),
OPT_STRING(0, "es-index", &common_es_index, "Elasticsearch index name. DEFAULT: sist2"), OPT_STRING(0, "es-index", &common_es_index, "Elasticsearch index name. DEFAULT: sist2"),
OPT_STRING(0, "bind", &web_args->listen_address, OPT_STRING(0, "bind", &web_args->listen_address,
@ -583,14 +554,6 @@ int main(int argc, const char *argv[]) {
OPT_BOOLEAN(0, "dev", &web_args->dev, "Serve html & js files from disk (for development)"), OPT_BOOLEAN(0, "dev", &web_args->dev, "Serve html & js files from disk (for development)"),
OPT_STRING(0, "lang", &web_args->lang, "Default UI language. Can be changed by the user"), OPT_STRING(0, "lang", &web_args->lang, "Default UI language. Can be changed by the user"),
OPT_GROUP("Exec-script options"),
OPT_STRING(0, "es-url", &common_es_url, "Elasticsearch url. DEFAULT: http://localhost:9200"),
OPT_BOOLEAN(0, "es-insecure-ssl", &common_es_insecure_ssl,
"Do not verify SSL connections to Elasticsearch."),
OPT_STRING(0, "es-index", &common_es_index, "Elasticsearch index name. DEFAULT: sist2"),
OPT_STRING(0, "script-file", &common_script_path, "Path to user script."),
OPT_BOOLEAN(0, "async-script", &common_async_script, "Execute user script asynchronously."),
OPT_END(), OPT_END(),
}; };
@ -614,22 +577,16 @@ int main(int argc, const char *argv[]) {
web_args->es_url = common_es_url; web_args->es_url = common_es_url;
index_args->es_url = common_es_url; index_args->es_url = common_es_url;
exec_args->es_url = common_es_url;
web_args->es_index = common_es_index; web_args->es_index = common_es_index;
index_args->es_index = common_es_index; index_args->es_index = common_es_index;
exec_args->es_index = common_es_index;
web_args->es_insecure_ssl = common_es_insecure_ssl; web_args->es_insecure_ssl = common_es_insecure_ssl;
index_args->es_insecure_ssl = common_es_insecure_ssl; index_args->es_insecure_ssl = common_es_insecure_ssl;
exec_args->es_insecure_ssl = common_es_insecure_ssl;
index_args->script_path = common_script_path; index_args->script_path = common_script_path;
exec_args->script_path = common_script_path;
index_args->threads = common_threads; index_args->threads = common_threads;
scan_args->threads = common_threads; scan_args->threads = common_threads;
exec_args->async_script = common_async_script;
index_args->async_script = common_async_script;
scan_args->optimize_database = common_optimize_database; scan_args->optimize_database = common_optimize_database;
@ -671,14 +628,6 @@ int main(int argc, const char *argv[]) {
} }
sist2_web(web_args); sist2_web(web_args);
} else if (strcmp(argv[0], "exec-script") == 0) {
int err = exec_args_validate(exec_args, argc, argv);
if (err != 0) {
goto end;
}
sist2_exec_script(exec_args);
} else { } else {
argparse_usage(&argparse); argparse_usage(&argparse);
LOG_FATALF("main.c", "Invalid command: '%s'\n", argv[0]); LOG_FATALF("main.c", "Invalid command: '%s'\n", argv[0]);
@ -689,7 +638,6 @@ int main(int argc, const char *argv[]) {
scan_args_destroy(scan_args); scan_args_destroy(scan_args);
index_args_destroy(index_args); index_args_destroy(index_args);
web_args_destroy(web_args); web_args_destroy(web_args);
exec_args_destroy(exec_args);
sqlite_index_args_destroy(sqlite_index_args); sqlite_index_args_destroy(sqlite_index_args);
return 0; return 0;

@ -16,6 +16,7 @@ typedef struct {
typedef struct tpool { typedef struct tpool {
pthread_t threads[256]; pthread_t threads[256];
void *start_thread_args[256];
int num_threads; int num_threads;
int print_progress; int print_progress;
@ -293,6 +294,8 @@ void tpool_destroy(tpool_t *pool) {
void *_; void *_;
pthread_join(thread, &_); pthread_join(thread, &_);
} }
free(pool->start_thread_args[i]);
} }
pthread_mutex_destroy(&pool->shm->ipc_ctx.mutex); pthread_mutex_destroy(&pool->shm->ipc_ctx.mutex);
@ -320,6 +323,7 @@ tpool_t *tpool_create(int thread_cnt, int print_progress) {
pool->shm->waiting = FALSE; pool->shm->waiting = FALSE;
pool->shm->job_type = JOB_UNDEFINED; pool->shm->job_type = JOB_UNDEFINED;
memset(pool->threads, 0, sizeof(pool->threads)); memset(pool->threads, 0, sizeof(pool->threads));
memset(pool->start_thread_args, 0, sizeof(pool->start_thread_args));
pool->print_progress = print_progress; pool->print_progress = print_progress;
sprintf(pool->shm->ipc_database_filepath, "/dev/shm/sist2-ipc-%d.sqlite", getpid()); sprintf(pool->shm->ipc_database_filepath, "/dev/shm/sist2-ipc-%d.sqlite", getpid());
@ -361,6 +365,7 @@ void tpool_start(tpool_t *pool) {
arg->pool = pool; arg->pool = pool;
pthread_create(&pool->threads[i], NULL, tpool_worker, arg); pthread_create(&pool->threads[i], NULL, tpool_worker, arg);
pool->start_thread_args[i] = arg;
} }
// Only open the database when all workers are done initializing // Only open the database when all workers are done initializing

@ -36,9 +36,52 @@ static struct mg_http_serve_opts IndexServeOpts = {
.ssi_pattern = NULL, .ssi_pattern = NULL,
.root_dir = NULL, .root_dir = NULL,
.mime_types = "", .mime_types = "",
.extra_headers = HTTP_SERVER_HEADER "Cross-Origin-Embedder-Policy: require-corp\r\nCross-Origin-Opener-Policy: same-origin\r\n" .extra_headers = HTTP_SERVER_HEADER HTTP_CROSS_ORIGIN_HEADERS
}; };
void get_embedding(struct mg_connection *nc, struct mg_http_message *hm) {
if (WebCtx.search_backend == ES_SEARCH_BACKEND && WebCtx.es_version != NULL && !HAS_KNN(WebCtx.es_version)) {
LOG_WARNINGF("serve.c",
"Your Elasticsearch version (%d.%d.%d) does not support approximate kNN search and will"
" fall back to a brute-force search. Please install ES 8.x.x+ for better search performance.",
WebCtx.es_version->major, WebCtx.es_version->minor, WebCtx.es_version->patch);
}
if (hm->uri.len != SIST_INDEX_ID_LEN + SIST_DOC_ID_LEN + 2 + 4) {
LOG_DEBUGF("serve.c", "Invalid embedding path: %.*s", (int) hm->uri.len, hm->uri.ptr);
HTTP_REPLY_NOT_FOUND
return;
}
char doc_id[SIST_DOC_ID_LEN];
char index_id[SIST_INDEX_ID_LEN];
memcpy(index_id, hm->uri.ptr + 3, SIST_INDEX_ID_LEN);
*(index_id + SIST_INDEX_ID_LEN - 1) = '\0';
memcpy(doc_id, hm->uri.ptr + 3 + SIST_INDEX_ID_LEN, SIST_DOC_ID_LEN);
*(doc_id + SIST_DOC_ID_LEN - 1) = '\0';
int model_id = (int) strtol(hm->uri.ptr + SIST_INDEX_ID_LEN + SIST_DOC_ID_LEN + 3, NULL, 10);
database_t *db = web_get_database(index_id);
if (db == NULL) {
LOG_DEBUGF("serve.c", "Could not get database for index: %s", index_id);
HTTP_REPLY_NOT_FOUND
return;
}
cJSON *json = database_get_embedding(db, doc_id, model_id);
if (json == NULL) {
HTTP_REPLY_NOT_FOUND
return;
}
mg_send_json(nc, json);
cJSON_Delete(json);
}
void stats_files(struct mg_connection *nc, struct mg_http_message *hm) { void stats_files(struct mg_connection *nc, struct mg_http_message *hm) {
if (hm->uri.len != SIST_INDEX_ID_LEN + 7) { if (hm->uri.len != SIST_INDEX_ID_LEN + 7) {
@ -316,6 +359,7 @@ void index_info(struct mg_connection *nc) {
cJSON_AddBoolToObject(json, "esVersionSupported", IS_SUPPORTED_ES_VERSION(WebCtx.es_version)); cJSON_AddBoolToObject(json, "esVersionSupported", IS_SUPPORTED_ES_VERSION(WebCtx.es_version));
cJSON_AddBoolToObject(json, "esVersionLegacy", IS_LEGACY_VERSION(WebCtx.es_version)); cJSON_AddBoolToObject(json, "esVersionLegacy", IS_LEGACY_VERSION(WebCtx.es_version));
cJSON_AddBoolToObject(json, "esVersionHasKnn", HAS_KNN(WebCtx.es_version));
cJSON_AddStringToObject(json, "lang", WebCtx.lang); cJSON_AddStringToObject(json, "lang", WebCtx.lang);
cJSON_AddBoolToObject(json, "auth0Enabled", WebCtx.auth0_enabled); cJSON_AddBoolToObject(json, "auth0Enabled", WebCtx.auth0_enabled);
@ -708,6 +752,9 @@ static void ev_router(struct mg_connection *nc, int ev, void *ev_data, UNUSED(vo
return; return;
} }
tag(nc, hm); tag(nc, hm);
} else if (mg_http_match_uri(hm, "/e/*/*/*")) {
get_embedding(nc, hm);
return;
} else { } else {
HTTP_REPLY_NOT_FOUND HTTP_REPLY_NOT_FOUND
} }
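The new `/e/*/*/*` route expects a fixed-length URI of the form `/e/<index_id>/<doc_id>/<model_id>`; the length check in `get_embedding` implies a three-character model id. The sketch below shows one way to parse such a path defensively from a buffer that is not NUL-terminated. The ID lengths (`INDEX_ID_LEN`, `DOC_ID_LEN`) are illustrative assumptions, not the real `SIST_INDEX_ID_LEN` / `SIST_DOC_ID_LEN` values, and `parse_embedding_uri` is a hypothetical helper.

```c
#include <assert.h>
#include <stdlib.h>
#include <string.h>

#define INDEX_ID_LEN 9   /* assumption: 8 characters + NUL */
#define DOC_ID_LEN 33    /* assumption: 32 characters + NUL */
#define MODEL_DIGITS 3   /* model ids are below 1000 in the schema */

/* Hypothetical helper: parse "/e/<index_id>/<doc_id>/<model_id>".
 * `uri` does not need to be NUL-terminated. Returns 0 on success, -1 on error. */
static int parse_embedding_uri(const char *uri, size_t len,
                               char index_id[INDEX_ID_LEN],
                               char doc_id[DOC_ID_LEN],
                               int *model_id) {
    assert(uri != NULL && index_id != NULL && doc_id != NULL && model_id != NULL);

    const size_t expected = 3 + (INDEX_ID_LEN - 1) + 1 + (DOC_ID_LEN - 1) + 1 + MODEL_DIGITS;
    if (len != expected || memcmp(uri, "/e/", 3) != 0) {
        return -1;
    }

    const char *p = uri + 3;
    memcpy(index_id, p, INDEX_ID_LEN - 1);
    index_id[INDEX_ID_LEN - 1] = '\0';
    p += INDEX_ID_LEN - 1;
    if (*p++ != '/') {
        return -1;
    }

    memcpy(doc_id, p, DOC_ID_LEN - 1);
    doc_id[DOC_ID_LEN - 1] = '\0';
    p += DOC_ID_LEN - 1;
    if (*p++ != '/') {
        return -1;
    }

    /* Copy the digits into a NUL-terminated buffer before calling strtol,
     * so parsing never reads past the end of the URI. */
    char digits[MODEL_DIGITS + 1];
    memcpy(digits, p, MODEL_DIGITS);
    digits[MODEL_DIGITS] = '\0';

    char *end = NULL;
    long value = strtol(digits, &end, 10);
    if (*end != '\0' || value <= 0) {
        return -1;
    }
    *model_id = (int) value;
    return 0;
}
```

Copying the model id into its own buffer is a design choice of this sketch; it avoids relying on whatever bytes follow the URI in the request buffer.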

@ -262,8 +262,8 @@ fts_search_req_t *get_search_req(struct mg_http_message *hm) {
: DEFAULT_HIGHLIGHT_CONTEXT_SIZE; : DEFAULT_HIGHLIGHT_CONTEXT_SIZE;
req->model = req_model.val ? req_model.val->valueint : 0; req->model = req_model.val ? req_model.val->valueint : 0;
req->embedding = req_model.val req->embedding = req_model.val
? get_float_buffer(req_embedding.val, &req->embedding_size) ? get_float_buffer(req_embedding.val, &req->embedding_size)
: NULL; : NULL;
cJSON_Delete(json); cJSON_Delete(json);

@ -3,7 +3,7 @@
void web_serve_asset_index_html(struct mg_connection *nc) { void web_serve_asset_index_html(struct mg_connection *nc) {
web_send_headers(nc, 200, sizeof(index_html), "Content-Type: text/html"); web_send_headers(nc, 200, sizeof(index_html), HTTP_CROSS_ORIGIN_HEADERS "Content-Type: text/html");
mg_send(nc, index_html, sizeof(index_html)); mg_send(nc, index_html, sizeof(index_html));
} }

@ -7,6 +7,8 @@
#include <mongoose.h> #include <mongoose.h>
#define HTTP_SERVER_HEADER "Server: sist2/" VERSION "\r\n" #define HTTP_SERVER_HEADER "Server: sist2/" VERSION "\r\n"
// See https://web.dev/coop-coep/
#define HTTP_CROSS_ORIGIN_HEADERS "Cross-Origin-Embedder-Policy: require-corp\r\nCross-Origin-Opener-Policy: same-origin\r\n"
index_t *web_get_index_by_id(const char *index_id); index_t *web_get_index_by_id(const char *index_id);