mirror of
https://github.com/simon987/sist2.git
synced 2025-12-12 06:58:54 +00:00
Rework user scripts, update DB schema to support embeddings
This commit is contained in:
57
src/cli.c
57
src/cli.c
@@ -38,11 +38,6 @@ scan_args_t *scan_args_create() {
|
||||
return args;
|
||||
}
|
||||
|
||||
exec_args_t *exec_args_create() {
|
||||
exec_args_t *args = calloc(sizeof(exec_args_t), 1);
|
||||
return args;
|
||||
}
|
||||
|
||||
void scan_args_destroy(scan_args_t *args) {
|
||||
if (args->name != NULL) {
|
||||
free(args->name);
|
||||
@@ -74,17 +69,9 @@ void web_args_destroy(web_args_t *args) {
|
||||
free(args);
|
||||
}
|
||||
|
||||
void exec_args_destroy(exec_args_t *args) {
|
||||
|
||||
if (args->index_path != NULL) {
|
||||
free(args->index_path);
|
||||
}
|
||||
|
||||
free(args);
|
||||
}
|
||||
|
||||
void sqlite_index_args_destroy(sqlite_index_args_t *args) {
|
||||
// TODO
|
||||
free(args->index_path);
|
||||
free(args);
|
||||
}
|
||||
|
||||
int scan_args_validate(scan_args_t *args, int argc, const char **argv) {
|
||||
@@ -626,43 +613,3 @@ web_args_t *web_args_create() {
|
||||
web_args_t *args = calloc(sizeof(web_args_t), 1);
|
||||
return args;
|
||||
}
|
||||
|
||||
int exec_args_validate(exec_args_t *args, int argc, const char **argv) {
|
||||
|
||||
if (argc < 2) {
|
||||
fprintf(stderr, "Required positional argument: PATH.\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
char *index_path = abspath(argv[1]);
|
||||
if (index_path == NULL) {
|
||||
LOG_FATALF("cli.c", "Invalid index PATH argument. File not found: %s", argv[1]);
|
||||
} else {
|
||||
args->index_path = index_path;
|
||||
}
|
||||
|
||||
if (args->es_url == NULL) {
|
||||
args->es_url = DEFAULT_ES_URL;
|
||||
}
|
||||
|
||||
if (args->es_index == NULL) {
|
||||
args->es_index = DEFAULT_ES_INDEX;
|
||||
}
|
||||
|
||||
if (args->script_path == NULL) {
|
||||
LOG_FATAL("cli.c", "--script-file argument is required");
|
||||
}
|
||||
|
||||
if (load_external_file(args->script_path, &args->script) != 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
LOG_DEBUGF("cli.c", "arg script_path=%s", args->script_path);
|
||||
|
||||
char log_buf[5000];
|
||||
strncpy(log_buf, args->script, sizeof(log_buf));
|
||||
*(log_buf + sizeof(log_buf) - 1) = '\0';
|
||||
LOG_DEBUGF("cli.c", "arg script=%s", log_buf);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
16
src/cli.h
16
src/cli.h
@@ -102,16 +102,6 @@ typedef struct web_args {
|
||||
search_backend_t search_backend;
|
||||
} web_args_t;
|
||||
|
||||
typedef struct exec_args {
|
||||
char *es_url;
|
||||
char *es_index;
|
||||
int es_insecure_ssl;
|
||||
char *index_path;
|
||||
const char *script_path;
|
||||
int async_script;
|
||||
char *script;
|
||||
} exec_args_t;
|
||||
|
||||
index_args_t *index_args_create();
|
||||
|
||||
sqlite_index_args_t *sqlite_index_args_create();
|
||||
@@ -128,12 +118,6 @@ int sqlite_index_args_validate(sqlite_index_args_t *args, int argc, const char *
|
||||
|
||||
int web_args_validate(web_args_t *args, int argc, const char **argv);
|
||||
|
||||
exec_args_t *exec_args_create();
|
||||
|
||||
void exec_args_destroy(exec_args_t *args);
|
||||
|
||||
int exec_args_validate(exec_args_t *args, int argc, const char **argv);
|
||||
|
||||
void sqlite_index_args_destroy(sqlite_index_args_t *args);
|
||||
|
||||
|
||||
|
||||
@@ -180,6 +180,10 @@ void database_open(database_t *db) {
|
||||
db->db, "SELECT * FROM model", -1,
|
||||
&db->get_models, NULL));
|
||||
|
||||
CRASH_IF_NOT_SQLITE_OK(sqlite3_prepare_v2(
|
||||
db->db, "SELECT embedding FROM embedding WHERE id=? AND model_id=? AND start=0", -1,
|
||||
&db->get_embedding, NULL));
|
||||
|
||||
// Create functions
|
||||
sqlite3_create_function(
|
||||
db->db,
|
||||
@@ -194,11 +198,11 @@ void database_open(database_t *db) {
|
||||
|
||||
sqlite3_create_function(
|
||||
db->db,
|
||||
"embedding_to_json",
|
||||
5,
|
||||
"emb_to_json",
|
||||
1,
|
||||
SQLITE_UTF8,
|
||||
NULL,
|
||||
embedding_to_json_func,
|
||||
emb_to_json_func,
|
||||
NULL,
|
||||
NULL
|
||||
);
|
||||
@@ -494,29 +498,31 @@ database_iterator_t *database_create_document_iterator(database_t *db) {
|
||||
|
||||
sqlite3_stmt *stmt;
|
||||
|
||||
sqlite3_prepare_v2(db->db, "WITH doc (j) AS (SELECT CASE"
|
||||
" WHEN sc.json_data IS NULL THEN"
|
||||
" CASE"
|
||||
" WHEN t.tag IS NULL THEN"
|
||||
" json_set(document.json_data, '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime)"
|
||||
" ELSE"
|
||||
" json_set(document.json_data, '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime, '$.tag', json_group_array(t.tag))"
|
||||
" END"
|
||||
" ELSE"
|
||||
" CASE"
|
||||
" WHEN t.tag IS NULL THEN"
|
||||
" json_patch(json_set(document.json_data, '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime), sc.json_data)"
|
||||
" ELSE"
|
||||
// This will overwrite any tags specified in the sidecar file!
|
||||
// TODO: concatenate the two arrays?
|
||||
" json_set(json_patch(document.json_data, sc.json_data), '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime, '$.tag', json_group_array(t.tag))"
|
||||
" END"
|
||||
" END"
|
||||
" FROM document"
|
||||
" LEFT JOIN document_sidecar sc ON document.id = sc.id"
|
||||
" LEFT JOIN tag t ON document.id = t.id"
|
||||
" GROUP BY document.id)"
|
||||
" SELECT json_set(j, '$.index', (SELECT id FROM descriptor)) FROM doc", -1, &stmt, NULL);
|
||||
CRASH_IF_NOT_SQLITE_OK(
|
||||
sqlite3_prepare_v2(
|
||||
db->db,
|
||||
"WITH doc (j) AS (SELECT CASE"
|
||||
" WHEN emb.embedding IS NULL THEN"
|
||||
" json_set(document.json_data, "
|
||||
" '$._id', document.id, "
|
||||
" '$.size', document.size, "
|
||||
" '$.mtime', document.mtime, "
|
||||
" '$.tag', json_group_array((SELECT tag FROM tag WHERE document.id = tag.id)))"
|
||||
" ELSE"
|
||||
" json_set(document.json_data,"
|
||||
" '$._id', document.id,"
|
||||
" '$.size', document.size,"
|
||||
" '$.mtime', document.mtime,"
|
||||
" '$.tag', json_group_array((SELECT tag FROM tag WHERE document.id = tag.id)),"
|
||||
" '$.emb', json_group_object(m.path, json(emb_to_json(emb.embedding))),"
|
||||
" '$.embedding', 1)"
|
||||
" END"
|
||||
" FROM document"
|
||||
" LEFT JOIN embedding emb ON document.id = emb.id"
|
||||
" LEFT JOIN model m ON emb.model_id = m.id"
|
||||
" GROUP BY document.id)"
|
||||
" SELECT json_set(j, '$.index', (SELECT id FROM descriptor)) FROM doc",
|
||||
-1, &stmt, NULL));
|
||||
|
||||
database_iterator_t *iter = malloc(sizeof(database_iterator_t));
|
||||
|
||||
@@ -526,6 +532,13 @@ database_iterator_t *database_create_document_iterator(database_t *db) {
|
||||
return iter;
|
||||
}
|
||||
|
||||
void remove_tag_if_null(cJSON *doc) {
|
||||
cJSON *tags = cJSON_GetObjectItem(doc, "tag");
|
||||
if (tags != NULL && cJSON_IsNull(cJSON_GetArrayItem(tags, 0))) {
|
||||
cJSON_DeleteItemFromObject(doc, "tag");
|
||||
}
|
||||
}
|
||||
|
||||
cJSON *database_document_iter(database_iterator_t *iter) {
|
||||
|
||||
if (iter->stmt == NULL) {
|
||||
@@ -537,7 +550,12 @@ cJSON *database_document_iter(database_iterator_t *iter) {
|
||||
|
||||
if (ret == SQLITE_ROW) {
|
||||
const char *json_string = (const char *) sqlite3_column_text(iter->stmt, 0);
|
||||
return cJSON_Parse(json_string);
|
||||
|
||||
cJSON *doc = cJSON_Parse(json_string);
|
||||
|
||||
remove_tag_if_null(doc);
|
||||
|
||||
return doc;
|
||||
}
|
||||
|
||||
if (ret != SQLITE_DONE) {
|
||||
|
||||
@@ -85,6 +85,7 @@ typedef struct database {
|
||||
sqlite3_stmt *write_thumbnail_stmt;
|
||||
sqlite3_stmt *get_document;
|
||||
sqlite3_stmt *get_models;
|
||||
sqlite3_stmt *get_embedding;
|
||||
|
||||
sqlite3_stmt *delete_tag_stmt;
|
||||
sqlite3_stmt *write_tag_stmt;
|
||||
@@ -144,6 +145,8 @@ void database_write_document(database_t *db, document_t *doc, const char *json_d
|
||||
|
||||
database_iterator_t *database_create_document_iterator(database_t *db);
|
||||
|
||||
void emb_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv);
|
||||
|
||||
cJSON *database_document_iter(database_iterator_t *);
|
||||
|
||||
#define database_document_iter_foreach(element, iter) \
|
||||
@@ -235,8 +238,10 @@ cJSON *database_get_document(database_t *db, char *doc_id);
|
||||
|
||||
void cosine_sim_func(sqlite3_context *ctx, int argc, sqlite3_value **argv);
|
||||
|
||||
void embedding_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv);
|
||||
|
||||
cJSON *database_get_models(database_t *db);
|
||||
|
||||
int database_fts_get_model_size(database_t *db, int model_id);
|
||||
|
||||
cJSON *database_get_embedding(database_t *db, char *doc_id, int model_id);
|
||||
|
||||
#endif
|
||||
@@ -1,5 +1,6 @@
|
||||
#include <openblas/cblas.h>
|
||||
#include "database.h"
|
||||
#include "src/ctx.h"
|
||||
|
||||
|
||||
static float cosine_sim(int n, const float *a, const float *b) {
|
||||
@@ -33,38 +34,6 @@ void cosine_sim_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
|
||||
sqlite3_result_double(ctx, result);
|
||||
}
|
||||
|
||||
void embedding_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
|
||||
|
||||
// emb, type, start, end, size
|
||||
|
||||
if (argc != 5) {
|
||||
sqlite3_result_error(ctx, "Invalid parameters", -1);
|
||||
}
|
||||
|
||||
const float *embedding = sqlite3_value_blob(argv[0]);
|
||||
const char *type = (const char *) sqlite3_value_text(argv[1]);
|
||||
|
||||
int size = sqlite3_value_int(argv[4]);
|
||||
|
||||
if (strcmp(type, "flat") == 0) {
|
||||
|
||||
cJSON *json = cJSON_CreateFloatArray(embedding, size);
|
||||
|
||||
char *json_str = cJSON_PrintBuffered(json, size * 22, FALSE);
|
||||
|
||||
cJSON_Delete(json);
|
||||
|
||||
sqlite3_result_text(ctx, json_str, -1, SQLITE_TRANSIENT);
|
||||
free(json_str);
|
||||
|
||||
} else {
|
||||
int start = sqlite3_value_int(argv[2]);
|
||||
int end = sqlite3_value_int(argv[3]);
|
||||
|
||||
sqlite3_result_error(ctx, "Nested embeddings not implemented yet", -1);
|
||||
}
|
||||
}
|
||||
|
||||
cJSON *database_get_models(database_t *db) {
|
||||
cJSON *json = cJSON_CreateArray();
|
||||
sqlite3_stmt *stmt = db->get_models;
|
||||
@@ -72,7 +41,12 @@ cJSON *database_get_models(database_t *db) {
|
||||
int ret;
|
||||
do {
|
||||
ret = sqlite3_step(stmt);
|
||||
CRASH_IF_STMT_FAIL(ret);
|
||||
if (ret == SQLITE_BUSY) {
|
||||
// Database is busy (probably scanning)
|
||||
LOG_WARNING("database_embeddings.c",
|
||||
"Database is busy, could not fetch list of models");
|
||||
break;
|
||||
}
|
||||
|
||||
if (ret == SQLITE_DONE) {
|
||||
break;
|
||||
@@ -82,7 +56,7 @@ cJSON *database_get_models(database_t *db) {
|
||||
|
||||
cJSON_AddNumberToObject(row, "id", sqlite3_column_int(stmt, 0));
|
||||
cJSON_AddStringToObject(row, "name", (const char *) sqlite3_column_text(stmt, 1));
|
||||
cJSON_AddStringToObject(row, "url", (const char *) sqlite3_column_int64(stmt, 2));
|
||||
cJSON_AddStringToObject(row, "url", (const char *) sqlite3_column_text(stmt, 2));
|
||||
cJSON_AddStringToObject(row, "path", (const char *) sqlite3_column_text(stmt, 3));
|
||||
cJSON_AddNumberToObject(row, "size", sqlite3_column_int(stmt, 4));
|
||||
cJSON_AddStringToObject(row, "type", (const char *) sqlite3_column_text(stmt, 5));
|
||||
@@ -90,5 +64,44 @@ cJSON *database_get_models(database_t *db) {
|
||||
cJSON_AddItemToArray(json, row);
|
||||
} while (TRUE);
|
||||
|
||||
sqlite3_reset(stmt);
|
||||
|
||||
return json;
|
||||
}
|
||||
|
||||
cJSON *database_get_embedding(database_t *db, char *doc_id, int model_id) {
|
||||
|
||||
sqlite3_bind_text(db->get_embedding, 1, doc_id, -1, SQLITE_STATIC);
|
||||
sqlite3_bind_int(db->get_embedding, 2, model_id);
|
||||
int ret = sqlite3_step(db->get_embedding);
|
||||
CRASH_IF_STMT_FAIL(ret);
|
||||
|
||||
if (ret == SQLITE_DONE) {
|
||||
sqlite3_reset(db->get_embedding);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
float *embedding = (float *) sqlite3_column_blob(db->get_embedding, 0);
|
||||
size_t size = sqlite3_column_bytes(db->get_embedding, 0) / sizeof(float);
|
||||
|
||||
cJSON *json = cJSON_CreateFloatArray(embedding, (int) size);
|
||||
sqlite3_reset(db->get_embedding);
|
||||
|
||||
return json;
|
||||
}
|
||||
|
||||
void emb_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
|
||||
if (argc != 1) {
|
||||
sqlite3_result_error(ctx, "Invalid parameters", -1);
|
||||
}
|
||||
|
||||
float *embedding = (float *) sqlite3_value_blob(argv[0]);
|
||||
int size = sqlite3_value_bytes(argv[0]) / 4;
|
||||
|
||||
cJSON *json = cJSON_CreateFloatArray(embedding, size);
|
||||
char *json_str = cJSON_PrintUnformatted(json);
|
||||
|
||||
sqlite3_result_text(ctx, json_str, -1, SQLITE_TRANSIENT);
|
||||
free(json_str);
|
||||
cJSON_Delete(json);
|
||||
}
|
||||
|
||||
@@ -67,6 +67,11 @@ void database_fts_index(database_t *db) {
|
||||
"REPLACE INTO fts.embedding (id, model_id, start, end, embedding)"
|
||||
" SELECT id, model_id, start, end, embedding FROM embedding", NULL, NULL, NULL));
|
||||
|
||||
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
|
||||
db->db,
|
||||
"INSERT INTO fts.model (id, size)"
|
||||
" SELECT id, size FROM model WHERE TRUE ON CONFLICT (id) DO NOTHING", NULL, NULL, NULL));
|
||||
|
||||
// TODO: delete old embeddings
|
||||
|
||||
LOG_DEBUG("database_fts.c", "Deleting old documents");
|
||||
@@ -518,20 +523,15 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
|
||||
|
||||
const char *json_object_sql;
|
||||
if (highlight && query_where != NULL) {
|
||||
json_object_sql = "json_remove(json_set(doc.json_data,"
|
||||
json_object_sql = "json_set(json_remove(doc.json_data, '$.content'),"
|
||||
"'$.index', doc.index_id,"
|
||||
"'$.embedding', (CASE WHEN emb.id IS NOT NULL THEN 1 ELSE 0 END),"
|
||||
"'$._highlight.name', snippet(search, 0, '<mark>', '</mark>', '', ?6),"
|
||||
"'$._highlight.content', snippet(search, 1, '<mark>', '</mark>', '', ?6)),"
|
||||
"'$.content')";
|
||||
"'$._highlight.content', snippet(search, 1, '<mark>', '</mark>', '', ?6))";
|
||||
} else {
|
||||
json_object_sql = "json_remove(json_set(doc.json_data,"
|
||||
"'$.index', doc.index_id),"
|
||||
"'$.content')";
|
||||
}
|
||||
|
||||
const char *embedding_join = "";
|
||||
if (embedding) {
|
||||
embedding_join = "LEFT JOIN embedding emb ON emb.id = doc.id AND emb.model_id=?9";
|
||||
json_object_sql = "json_set(json_remove(doc.json_data, '$.content'),"
|
||||
"'$.index', doc.index_id,"
|
||||
"'$.embedding', (CASE WHEN emb.id IS NOT NULL THEN 1 ELSE 0 END))";
|
||||
}
|
||||
|
||||
char *sql;
|
||||
@@ -544,12 +544,11 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
|
||||
" %s, %s as sort_var, doc.ROWID"
|
||||
" FROM search"
|
||||
" INNER JOIN document_index doc on doc.ROWID = search.ROWID"
|
||||
" %s"
|
||||
" LEFT JOIN embedding emb on emb.id = doc.id"
|
||||
" WHERE %s"
|
||||
" ORDER BY sort_var%s, doc.ROWID"
|
||||
" LIMIT ?2",
|
||||
json_object_sql, get_sort_var(sort),
|
||||
embedding_join,
|
||||
where,
|
||||
sort_asc ? "" : " DESC");
|
||||
|
||||
@@ -567,12 +566,11 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
|
||||
"SELECT"
|
||||
" %s, %s as sort_var, doc.ROWID"
|
||||
" FROM document_index doc"
|
||||
" %s"
|
||||
" LEFT JOIN embedding emb on emb.id = doc.id"
|
||||
" WHERE %s"
|
||||
" ORDER BY sort_var%s,doc.ROWID"
|
||||
" LIMIT ?2",
|
||||
json_object_sql, get_sort_var(sort),
|
||||
embedding_join,
|
||||
where,
|
||||
sort_asc ? "" : " DESC");
|
||||
|
||||
@@ -624,7 +622,7 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
|
||||
if (after_where) {
|
||||
if (sort == FTS_SORT_NAME || sort == FTS_SORT_ID) {
|
||||
sqlite3_bind_text(stmt, 3, after[0], -1, SQLITE_STATIC);
|
||||
} else if (sort == FTS_SORT_SCORE) {
|
||||
} else if (sort == FTS_SORT_SCORE || sort == FTS_SORT_EMBEDDING) {
|
||||
sqlite3_bind_double(stmt, 3, strtod(after[0], NULL));
|
||||
} else {
|
||||
sqlite3_bind_int64(stmt, 3, strtol(after[0], NULL, 10));
|
||||
|
||||
@@ -49,12 +49,8 @@ const char *FtsDatabaseSchema =
|
||||
");"
|
||||
""
|
||||
"CREATE TABLE IF NOT EXISTS model ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" name TEXT NOT NULL UNIQUE CHECK ( length(name) < 16 ),"
|
||||
" url TEXT,"
|
||||
" path TEXT NOT NULL UNIQUE,"
|
||||
" size INTEGER NOT NULL,"
|
||||
" type TEXT NOT NULL CHECK ( type IN ('flat', 'nested') )"
|
||||
" id INTEGER PRIMARY KEY CHECK (id > 0 AND id < 1000),"
|
||||
" size INTEGER NOT NULL"
|
||||
");"
|
||||
""
|
||||
"CREATE TRIGGER IF NOT EXISTS tag_write_trigger"
|
||||
@@ -183,5 +179,14 @@ const char *IndexDatabaseSchema =
|
||||
" end INTEGER,"
|
||||
" embedding BLOB NOT NULL,"
|
||||
" PRIMARY KEY (id, model_id, start)"
|
||||
");"
|
||||
""
|
||||
"CREATE TABLE model ("
|
||||
" id INTEGER PRIMARY KEY CHECK (id > 0 AND id < 1000),"
|
||||
" name TEXT NOT NULL UNIQUE CHECK ( length(name) < 16 ),"
|
||||
" url TEXT,"
|
||||
" path TEXT NOT NULL UNIQUE,"
|
||||
" size INTEGER NOT NULL,"
|
||||
" type TEXT NOT NULL CHECK ( type IN ('flat', 'nested') )"
|
||||
");";
|
||||
|
||||
|
||||
@@ -98,61 +98,6 @@ void index_json(cJSON *document, const char doc_id[SIST_DOC_ID_LEN]) {
|
||||
free(bulk_line);
|
||||
}
|
||||
|
||||
void execute_update_script(const char *script, int async, const char index_id[SIST_INDEX_ID_LEN]) {
|
||||
|
||||
if (Indexer == NULL) {
|
||||
Indexer = create_indexer(IndexCtx.es_url, IndexCtx.es_index);
|
||||
}
|
||||
|
||||
cJSON *body = cJSON_CreateObject();
|
||||
cJSON *script_obj = cJSON_AddObjectToObject(body, "script");
|
||||
cJSON_AddStringToObject(script_obj, "lang", "painless");
|
||||
cJSON_AddStringToObject(script_obj, "source", script);
|
||||
|
||||
cJSON *query = cJSON_AddObjectToObject(body, "query");
|
||||
cJSON *term_obj = cJSON_AddObjectToObject(query, "term");
|
||||
cJSON_AddStringToObject(term_obj, "index", index_id);
|
||||
|
||||
char *str = cJSON_PrintUnformatted(body);
|
||||
|
||||
char url[4096];
|
||||
if (async) {
|
||||
snprintf(url, sizeof(url), "%s/%s/_update_by_query?wait_for_completion=false", Indexer->es_url,
|
||||
Indexer->es_index);
|
||||
} else {
|
||||
snprintf(url, sizeof(url), "%s/%s/_update_by_query", Indexer->es_url, Indexer->es_index);
|
||||
}
|
||||
response_t *r = web_post(url, str, IndexCtx.es_insecure_ssl);
|
||||
if (!async) {
|
||||
LOG_INFOF("elastic.c", "Executed user script <%d>", r->status_code);
|
||||
}
|
||||
cJSON *resp = cJSON_Parse(r->body);
|
||||
|
||||
cJSON_free(str);
|
||||
cJSON_Delete(body);
|
||||
free_response(r);
|
||||
|
||||
cJSON *error = cJSON_GetObjectItem(resp, "error");
|
||||
if (error != NULL) {
|
||||
char *error_str = cJSON_Print(error);
|
||||
|
||||
LOG_ERRORF("elastic.c", "User script error: \n%s", error_str);
|
||||
cJSON_free(error_str);
|
||||
}
|
||||
|
||||
if (async) {
|
||||
cJSON *task = cJSON_GetObjectItem(resp, "task");
|
||||
|
||||
if (task == NULL) {
|
||||
LOG_FATALF("elastic.c", "FIXME: Could not get task id: %s", r->body);
|
||||
}
|
||||
|
||||
LOG_INFOF("elastic.c", "User script queued: %s/_tasks/%s", Indexer->es_url, task->valuestring);
|
||||
}
|
||||
|
||||
cJSON_Delete(resp);
|
||||
}
|
||||
|
||||
void *create_bulk_buffer(int max, int *count, size_t *buf_len, int legacy) {
|
||||
es_bulk_line_t *line = Indexer->line_head;
|
||||
*count = 0;
|
||||
@@ -403,7 +348,7 @@ es_indexer_t *create_indexer(const char *url, const char *index) {
|
||||
return indexer;
|
||||
}
|
||||
|
||||
void finish_indexer(char *script, int async_script, char *index_id) {
|
||||
void finish_indexer(char *index_id) {
|
||||
|
||||
char url[4096];
|
||||
|
||||
@@ -412,16 +357,6 @@ void finish_indexer(char *script, int async_script, char *index_id) {
|
||||
LOG_INFOF("elastic.c", "Refresh index <%d>", r->status_code);
|
||||
free_response(r);
|
||||
|
||||
if (script != NULL) {
|
||||
execute_update_script(script, async_script, index_id);
|
||||
free(script);
|
||||
|
||||
snprintf(url, sizeof(url), "%s/%s/_refresh", IndexCtx.es_url, IndexCtx.es_index);
|
||||
r = web_post(url, "", IndexCtx.es_insecure_ssl);
|
||||
LOG_INFOF("elastic.c", "Refresh index <%d>", r->status_code);
|
||||
free_response(r);
|
||||
}
|
||||
|
||||
snprintf(url, sizeof(url), "%s/%s/_forcemerge", IndexCtx.es_url, IndexCtx.es_index);
|
||||
r = web_post(url, "", IndexCtx.es_insecure_ssl);
|
||||
LOG_INFOF("elastic.c", "Merge index <%d>", r->status_code);
|
||||
|
||||
@@ -24,6 +24,8 @@ typedef struct {
|
||||
|
||||
#define IS_SUPPORTED_ES_VERSION(es_version) ((es_version) != NULL && VERSION_GE((es_version), 6, 8) && VERSION_LT((es_version), 9, 0))
|
||||
#define IS_LEGACY_VERSION(es_version) ((es_version) != NULL && VERSION_LT((es_version), 7, 14))
|
||||
#define HAS_KNN(es_version) ((es_version) != NULL && VERSION_GE((es_version), 8, 0))
|
||||
|
||||
|
||||
__always_inline
|
||||
static const char *format_es_version(es_version_t *version) {
|
||||
@@ -51,7 +53,7 @@ void delete_document(const char *document_id);
|
||||
es_indexer_t *create_indexer(const char *url, const char *index);
|
||||
|
||||
void elastic_cleanup();
|
||||
void finish_indexer(char *script, int async_script, char *index_id);
|
||||
void finish_indexer(char *index_id);
|
||||
|
||||
void elastic_init(int force_reset, const char* user_mappings, const char* user_settings);
|
||||
|
||||
@@ -61,6 +63,4 @@ char *elastic_get_status();
|
||||
|
||||
es_version_t *elastic_get_version(const char *es_url, int insecure);
|
||||
|
||||
void execute_update_script(const char *script, int async, const char index_id[SIST_INDEX_ID_LEN]);
|
||||
|
||||
#endif
|
||||
|
||||
54
src/main.c
54
src/main.c
@@ -24,7 +24,6 @@ static const char *const usage[] = {
|
||||
"sist2 index [OPTION]... INDEX",
|
||||
"sist2 sqlite-index [OPTION]... INDEX",
|
||||
"sist2 web [OPTION]... INDEX...",
|
||||
"sist2 exec-script [OPTION]... INDEX",
|
||||
NULL,
|
||||
};
|
||||
|
||||
@@ -349,7 +348,7 @@ void sist2_index(index_args_t *args) {
|
||||
tpool_destroy(IndexCtx.pool);
|
||||
|
||||
if (IndexCtx.needs_es_connection) {
|
||||
finish_indexer(args->script, args->async_script, desc->id);
|
||||
finish_indexer(desc->id);
|
||||
}
|
||||
free(desc);
|
||||
}
|
||||
@@ -370,25 +369,6 @@ void sist2_sqlite_index(sqlite_index_args_t *args) {
|
||||
database_close(search_db, FALSE);
|
||||
}
|
||||
|
||||
void sist2_exec_script(exec_args_t *args) {
|
||||
LogCtx.verbose = TRUE;
|
||||
|
||||
IndexCtx.es_url = args->es_url;
|
||||
IndexCtx.es_index = args->es_index;
|
||||
IndexCtx.es_insecure_ssl = args->es_insecure_ssl;
|
||||
IndexCtx.needs_es_connection = TRUE;
|
||||
|
||||
database_t *db = database_create(args->index_path, INDEX_DATABASE);
|
||||
database_open(db);
|
||||
|
||||
index_descriptor_t *desc = database_read_index_descriptor(db);
|
||||
LOG_DEBUGF("main.c", "Index version %s", desc->version);
|
||||
|
||||
execute_update_script(args->script, args->async_script, desc->id);
|
||||
free(args->script);
|
||||
database_close(db, FALSE);
|
||||
}
|
||||
|
||||
void sist2_web(web_args_t *args) {
|
||||
|
||||
WebCtx.es_url = args->es_url;
|
||||
@@ -464,15 +444,9 @@ int set_to_negative_if_value_is_zero(UNUSED(struct argparse *self), const struct
|
||||
int main(int argc, const char *argv[]) {
|
||||
setlocale(LC_ALL, "");
|
||||
|
||||
// database_t *db = database_create("clip.sist2", INDEX_DATABASE);
|
||||
// database_open(db);
|
||||
// database_test(db);
|
||||
// exit(0);
|
||||
|
||||
scan_args_t *scan_args = scan_args_create();
|
||||
index_args_t *index_args = index_args_create();
|
||||
web_args_t *web_args = web_args_create();
|
||||
exec_args_t *exec_args = exec_args_create();
|
||||
sqlite_index_args_t *sqlite_index_args = sqlite_index_args_create();
|
||||
|
||||
int arg_version = 0;
|
||||
@@ -481,7 +455,6 @@ int main(int argc, const char *argv[]) {
|
||||
int common_es_insecure_ssl = 0;
|
||||
char *common_es_index = NULL;
|
||||
char *common_script_path = NULL;
|
||||
int common_async_script = 0;
|
||||
int common_threads = 0;
|
||||
int common_optimize_database = 0;
|
||||
char *common_search_index = NULL;
|
||||
@@ -556,7 +529,6 @@ int main(int argc, const char *argv[]) {
|
||||
OPT_STRING(0, "script-file", &common_script_path, "Path to user script."),
|
||||
OPT_STRING(0, "mappings-file", &index_args->es_mappings_path, "Path to Elasticsearch mappings."),
|
||||
OPT_STRING(0, "settings-file", &index_args->es_settings_path, "Path to Elasticsearch settings."),
|
||||
OPT_BOOLEAN(0, "async-script", &common_async_script, "Execute user script asynchronously."),
|
||||
OPT_INTEGER(0, "batch-size", &index_args->batch_size, "Index batch size. DEFAULT: 70"),
|
||||
OPT_BOOLEAN('f', "force-reset", &index_args->force_reset, "Reset Elasticsearch mappings and settings."),
|
||||
|
||||
@@ -567,7 +539,6 @@ int main(int argc, const char *argv[]) {
|
||||
OPT_STRING(0, "es-url", &common_es_url, "Elasticsearch url. DEFAULT: http://localhost:9200"),
|
||||
OPT_BOOLEAN(0, "es-insecure-ssl", &common_es_insecure_ssl,
|
||||
"Do not verify SSL connections to Elasticsearch."),
|
||||
// TODO: change arg name (?)
|
||||
OPT_STRING(0, "search-index", &common_search_index, "Path to SQLite search index."),
|
||||
OPT_STRING(0, "es-index", &common_es_index, "Elasticsearch index name. DEFAULT: sist2"),
|
||||
OPT_STRING(0, "bind", &web_args->listen_address,
|
||||
@@ -583,14 +554,6 @@ int main(int argc, const char *argv[]) {
|
||||
OPT_BOOLEAN(0, "dev", &web_args->dev, "Serve html & js files from disk (for development)"),
|
||||
OPT_STRING(0, "lang", &web_args->lang, "Default UI language. Can be changed by the user"),
|
||||
|
||||
OPT_GROUP("Exec-script options"),
|
||||
OPT_STRING(0, "es-url", &common_es_url, "Elasticsearch url. DEFAULT: http://localhost:9200"),
|
||||
OPT_BOOLEAN(0, "es-insecure-ssl", &common_es_insecure_ssl,
|
||||
"Do not verify SSL connections to Elasticsearch."),
|
||||
OPT_STRING(0, "es-index", &common_es_index, "Elasticsearch index name. DEFAULT: sist2"),
|
||||
OPT_STRING(0, "script-file", &common_script_path, "Path to user script."),
|
||||
OPT_BOOLEAN(0, "async-script", &common_async_script, "Execute user script asynchronously."),
|
||||
|
||||
OPT_END(),
|
||||
};
|
||||
|
||||
@@ -614,22 +577,16 @@ int main(int argc, const char *argv[]) {
|
||||
|
||||
web_args->es_url = common_es_url;
|
||||
index_args->es_url = common_es_url;
|
||||
exec_args->es_url = common_es_url;
|
||||
|
||||
web_args->es_index = common_es_index;
|
||||
index_args->es_index = common_es_index;
|
||||
exec_args->es_index = common_es_index;
|
||||
|
||||
web_args->es_insecure_ssl = common_es_insecure_ssl;
|
||||
index_args->es_insecure_ssl = common_es_insecure_ssl;
|
||||
exec_args->es_insecure_ssl = common_es_insecure_ssl;
|
||||
|
||||
index_args->script_path = common_script_path;
|
||||
exec_args->script_path = common_script_path;
|
||||
index_args->threads = common_threads;
|
||||
scan_args->threads = common_threads;
|
||||
exec_args->async_script = common_async_script;
|
||||
index_args->async_script = common_async_script;
|
||||
|
||||
scan_args->optimize_database = common_optimize_database;
|
||||
|
||||
@@ -671,14 +628,6 @@ int main(int argc, const char *argv[]) {
|
||||
}
|
||||
sist2_web(web_args);
|
||||
|
||||
} else if (strcmp(argv[0], "exec-script") == 0) {
|
||||
|
||||
int err = exec_args_validate(exec_args, argc, argv);
|
||||
if (err != 0) {
|
||||
goto end;
|
||||
}
|
||||
sist2_exec_script(exec_args);
|
||||
|
||||
} else {
|
||||
argparse_usage(&argparse);
|
||||
LOG_FATALF("main.c", "Invalid command: '%s'\n", argv[0]);
|
||||
@@ -689,7 +638,6 @@ int main(int argc, const char *argv[]) {
|
||||
scan_args_destroy(scan_args);
|
||||
index_args_destroy(index_args);
|
||||
web_args_destroy(web_args);
|
||||
exec_args_destroy(exec_args);
|
||||
sqlite_index_args_destroy(sqlite_index_args);
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -16,6 +16,7 @@ typedef struct {
|
||||
|
||||
typedef struct tpool {
|
||||
pthread_t threads[256];
|
||||
void *start_thread_args[256];
|
||||
int num_threads;
|
||||
|
||||
int print_progress;
|
||||
@@ -293,6 +294,8 @@ void tpool_destroy(tpool_t *pool) {
|
||||
void *_;
|
||||
pthread_join(thread, &_);
|
||||
}
|
||||
|
||||
free(pool->start_thread_args[i]);
|
||||
}
|
||||
|
||||
pthread_mutex_destroy(&pool->shm->ipc_ctx.mutex);
|
||||
@@ -320,6 +323,7 @@ tpool_t *tpool_create(int thread_cnt, int print_progress) {
|
||||
pool->shm->waiting = FALSE;
|
||||
pool->shm->job_type = JOB_UNDEFINED;
|
||||
memset(pool->threads, 0, sizeof(pool->threads));
|
||||
memset(pool->start_thread_args, 0, sizeof(pool->start_thread_args));
|
||||
pool->print_progress = print_progress;
|
||||
sprintf(pool->shm->ipc_database_filepath, "/dev/shm/sist2-ipc-%d.sqlite", getpid());
|
||||
|
||||
@@ -361,6 +365,7 @@ void tpool_start(tpool_t *pool) {
|
||||
arg->pool = pool;
|
||||
|
||||
pthread_create(&pool->threads[i], NULL, tpool_worker, arg);
|
||||
pool->start_thread_args[i] = arg;
|
||||
}
|
||||
|
||||
// Only open the database when all workers are done initializing
|
||||
|
||||
@@ -36,9 +36,52 @@ static struct mg_http_serve_opts IndexServeOpts = {
|
||||
.ssi_pattern = NULL,
|
||||
.root_dir = NULL,
|
||||
.mime_types = "",
|
||||
.extra_headers = HTTP_SERVER_HEADER "Cross-Origin-Embedder-Policy: require-corp\r\nCross-Origin-Opener-Policy: same-origin\r\n"
|
||||
.extra_headers = HTTP_SERVER_HEADER HTTP_CROSS_ORIGIN_HEADERS
|
||||
};
|
||||
|
||||
void get_embedding(struct mg_connection *nc, struct mg_http_message *hm) {
|
||||
|
||||
if (WebCtx.search_backend == ES_SEARCH_BACKEND && WebCtx.es_version != NULL && !HAS_KNN(WebCtx.es_version)) {
|
||||
LOG_WARNINGF("serve.c",
|
||||
"Your Elasticsearch version (%d.%d.%d) does not support approximate kNN search and will"
|
||||
" fallback to a brute-force search. Please install ES 8.x.x+ for better search performance.",
|
||||
WebCtx.es_version->major, WebCtx.es_version->minor, WebCtx.es_version->patch);
|
||||
}
|
||||
|
||||
if (hm->uri.len != SIST_INDEX_ID_LEN + SIST_DOC_ID_LEN + 2 + 4) {
|
||||
LOG_DEBUGF("serve.c", "Invalid thumbnail path: %.*s", (int) hm->uri.len, hm->uri.ptr);
|
||||
HTTP_REPLY_NOT_FOUND
|
||||
return;
|
||||
}
|
||||
|
||||
char doc_id[SIST_DOC_ID_LEN];
|
||||
char index_id[SIST_INDEX_ID_LEN];
|
||||
|
||||
memcpy(index_id, hm->uri.ptr + 3, SIST_INDEX_ID_LEN);
|
||||
*(index_id + SIST_INDEX_ID_LEN - 1) = '\0';
|
||||
memcpy(doc_id, hm->uri.ptr + 3 + SIST_INDEX_ID_LEN, SIST_DOC_ID_LEN);
|
||||
*(doc_id + SIST_DOC_ID_LEN - 1) = '\0';
|
||||
|
||||
int model_id = (int) strtol(hm->uri.ptr + SIST_INDEX_ID_LEN + SIST_DOC_ID_LEN + 3, NULL, 10);
|
||||
|
||||
database_t *db = web_get_database(index_id);
|
||||
if (db == NULL) {
|
||||
LOG_DEBUGF("serve.c", "Could not get database for index: %s", index_id);
|
||||
HTTP_REPLY_NOT_FOUND
|
||||
return;
|
||||
}
|
||||
|
||||
cJSON *json = database_get_embedding(db, doc_id, model_id);
|
||||
|
||||
if (json == NULL) {
|
||||
HTTP_REPLY_NOT_FOUND
|
||||
return;
|
||||
}
|
||||
|
||||
mg_send_json(nc, json);
|
||||
cJSON_Delete(json);
|
||||
}
|
||||
|
||||
void stats_files(struct mg_connection *nc, struct mg_http_message *hm) {
|
||||
|
||||
if (hm->uri.len != SIST_INDEX_ID_LEN + 7) {
|
||||
@@ -316,6 +359,7 @@ void index_info(struct mg_connection *nc) {
|
||||
|
||||
cJSON_AddBoolToObject(json, "esVersionSupported", IS_SUPPORTED_ES_VERSION(WebCtx.es_version));
|
||||
cJSON_AddBoolToObject(json, "esVersionLegacy", IS_LEGACY_VERSION(WebCtx.es_version));
|
||||
cJSON_AddBoolToObject(json, "esVersionHasKnn", HAS_KNN(WebCtx.es_version));
|
||||
cJSON_AddStringToObject(json, "lang", WebCtx.lang);
|
||||
|
||||
cJSON_AddBoolToObject(json, "auth0Enabled", WebCtx.auth0_enabled);
|
||||
@@ -708,6 +752,9 @@ static void ev_router(struct mg_connection *nc, int ev, void *ev_data, UNUSED(vo
|
||||
return;
|
||||
}
|
||||
tag(nc, hm);
|
||||
} else if (mg_http_match_uri(hm, "/e/*/*/*")) {
|
||||
get_embedding(nc, hm);
|
||||
return;
|
||||
} else {
|
||||
HTTP_REPLY_NOT_FOUND
|
||||
}
|
||||
|
||||
@@ -262,8 +262,8 @@ fts_search_req_t *get_search_req(struct mg_http_message *hm) {
|
||||
: DEFAULT_HIGHLIGHT_CONTEXT_SIZE;
|
||||
req->model = req_model.val ? req_model.val->valueint : 0;
|
||||
req->embedding = req_model.val
|
||||
? get_float_buffer(req_embedding.val, &req->embedding_size)
|
||||
: NULL;
|
||||
? get_float_buffer(req_embedding.val, &req->embedding_size)
|
||||
: NULL;
|
||||
|
||||
cJSON_Delete(json);
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
|
||||
void web_serve_asset_index_html(struct mg_connection *nc) {
|
||||
web_send_headers(nc, 200, sizeof(index_html), "Content-Type: text/html");
|
||||
web_send_headers(nc, 200, sizeof(index_html), HTTP_CROSS_ORIGIN_HEADERS "Content-Type: text/html");
|
||||
mg_send(nc, index_html, sizeof(index_html));
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
#include <mongoose.h>
|
||||
|
||||
#define HTTP_SERVER_HEADER "Server: sist2/" VERSION "\r\n"
|
||||
// See https://web.dev/coop-coep/
|
||||
#define HTTP_CROSS_ORIGIN_HEADERS "Cross-Origin-Embedder-Policy: require-corp\r\nCross-Origin-Opener-Policy: same-origin\r\n"
|
||||
|
||||
index_t *web_get_index_by_id(const char *index_id);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user