This commit is contained in:
2023-07-24 19:36:20 -04:00
parent f56cfb0f2f
commit 27188b6fa0
29 changed files with 1008 additions and 75 deletions

View File

@@ -163,7 +163,8 @@ void database_open(database_t *db) {
&db->write_document_sidecar_stmt, NULL));
CRASH_IF_NOT_SQLITE_OK(sqlite3_prepare_v2(
db->db,
"REPLACE INTO document (id, mtime, size, json_data, version) VALUES (?, ?, ?, ?, (SELECT max(id) FROM version));", -1,
"REPLACE INTO document (id, mtime, size, json_data, version) VALUES (?, ?, ?, ?, (SELECT max(id) FROM version));",
-1,
&db->write_document_stmt, NULL));
CRASH_IF_NOT_SQLITE_OK(sqlite3_prepare_v2(
db->db,
@@ -175,6 +176,10 @@ void database_open(database_t *db) {
db->db, "SELECT json_data FROM document WHERE id=?", -1,
&db->get_document, NULL));
CRASH_IF_NOT_SQLITE_OK(sqlite3_prepare_v2(
db->db, "SELECT * FROM model", -1,
&db->get_models, NULL));
// Create functions
sqlite3_create_function(
db->db,
@@ -186,6 +191,17 @@ void database_open(database_t *db) {
NULL,
NULL
);
sqlite3_create_function(
db->db,
"embedding_to_json",
5,
SQLITE_UTF8,
NULL,
embedding_to_json_func,
NULL,
NULL
);
} else if (db->type == IPC_CONSUMER_DATABASE) {
sqlite3_create_function(
@@ -248,6 +264,10 @@ void database_open(database_t *db) {
db->db, "SELECT tag, count(*) FROM tag GROUP BY tag", -1,
&db->fts_get_tags, NULL));
CRASH_IF_NOT_SQLITE_OK(sqlite3_prepare_v2(
db->db, "SELECT size FROM model WHERE id=?", -1,
&db->fts_model_size, NULL));
CRASH_IF_NOT_SQLITE_OK(sqlite3_prepare_v2(
db->db, "SELECT path, count FROM path_index"
" WHERE (index_id=?1 OR ?1 IS NULL) AND depth BETWEEN ? AND ?"
@@ -302,6 +322,17 @@ void database_open(database_t *db) {
NULL,
NULL
);
sqlite3_create_function(
db->db,
"cosine_sim",
3,
SQLITE_UTF8,
NULL,
cosine_sim_func,
NULL,
NULL
);
}
if (db->type == FTS_DATABASE || db->type == INDEX_DATABASE) {
@@ -463,8 +494,6 @@ database_iterator_t *database_create_document_iterator(database_t *db) {
sqlite3_stmt *stmt;
// TODO optimization: remove mtime, size, _id from json_data
sqlite3_prepare_v2(db->db, "WITH doc (j) AS (SELECT CASE"
" WHEN sc.json_data IS NULL THEN"
" CASE"
@@ -800,4 +829,4 @@ cJSON *database_get_document(database_t *db, char *doc_id) {
void database_increment_version(database_t *db) {
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db, "INSERT INTO version DEFAULT VALUES", NULL, NULL, NULL));
}
}

View File

@@ -41,6 +41,7 @@ typedef enum {
FTS_SORT_RANDOM,
FTS_SORT_NAME,
FTS_SORT_ID,
FTS_SORT_EMBEDDING
} fts_sort_t;
typedef struct {
@@ -83,6 +84,7 @@ typedef struct database {
sqlite3_stmt *write_document_sidecar_stmt;
sqlite3_stmt *write_thumbnail_stmt;
sqlite3_stmt *get_document;
sqlite3_stmt *get_models;
sqlite3_stmt *delete_tag_stmt;
sqlite3_stmt *write_tag_stmt;
@@ -100,6 +102,8 @@ typedef struct database {
sqlite3_stmt *fts_get_document;
sqlite3_stmt *fts_suggest_tag;
sqlite3_stmt *fts_get_tags;
sqlite3_stmt *fts_model_size;
char **tag_array;
@@ -210,7 +214,8 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
long size_max, long date_min, long date_max, int page_size,
char **index_ids, char **mime_types, char **tags, int sort_asc,
fts_sort_t sort, int seed, char **after, int fetch_aggregations,
int highlight, int highlight_context_size);
int highlight, int highlight_context_size, int model,
const float *embedding, int embedding_size);
void database_write_tag(database_t *db, char *doc_id, char *tag);
@@ -228,4 +233,10 @@ cJSON *database_fts_get_tags(database_t *db);
cJSON *database_get_document(database_t *db, char *doc_id);
void cosine_sim_func(sqlite3_context *ctx, int argc, sqlite3_value **argv);
void embedding_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv);
cJSON *database_get_models(database_t *db);
#endif

View File

@@ -0,0 +1,94 @@
#include <openblas/cblas.h>
#include "database.h"
static float cosine_sim(int n, const float *a, const float *b) {
float dot_product = cblas_sdot(n, a, 1, b, 1);
float norm_a = cblas_snrm2(n, a, 1);
float norm_b = cblas_snrm2(n, b, 1);
return dot_product / (norm_a * norm_b);
}
void cosine_sim_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
if (argc != 3) {
sqlite3_result_error(ctx, "Invalid parameters", -1);
}
int n = sqlite3_value_int(argv[0]);
const float *a = sqlite3_value_blob(argv[1]);
const float *b = sqlite3_value_blob(argv[2]);
if (a == NULL || b == NULL) {
sqlite3_result_double(ctx, -1);
return;
}
float result = cosine_sim(n, a, b);
if (result != result) {
result = -1;
}
sqlite3_result_double(ctx, result);
}
void embedding_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
// emb, type, start, end, size
if (argc != 5) {
sqlite3_result_error(ctx, "Invalid parameters", -1);
}
const float *embedding = sqlite3_value_blob(argv[0]);
const char *type = (const char *) sqlite3_value_text(argv[1]);
int size = sqlite3_value_int(argv[4]);
if (strcmp(type, "flat") == 0) {
cJSON *json = cJSON_CreateFloatArray(embedding, size);
char *json_str = cJSON_PrintBuffered(json, size * 22, FALSE);
cJSON_Delete(json);
sqlite3_result_text(ctx, json_str, -1, SQLITE_TRANSIENT);
free(json_str);
} else {
int start = sqlite3_value_int(argv[2]);
int end = sqlite3_value_int(argv[3]);
sqlite3_result_error(ctx, "Nested embeddings not implemented yet", -1);
}
}
cJSON *database_get_models(database_t *db) {
cJSON *json = cJSON_CreateArray();
sqlite3_stmt *stmt = db->get_models;
int ret;
do {
ret = sqlite3_step(stmt);
CRASH_IF_STMT_FAIL(ret);
if (ret == SQLITE_DONE) {
break;
}
cJSON *row = cJSON_CreateObject();
cJSON_AddNumberToObject(row, "id", sqlite3_column_int(stmt, 0));
cJSON_AddStringToObject(row, "name", (const char *) sqlite3_column_text(stmt, 1));
cJSON_AddStringToObject(row, "url", (const char *) sqlite3_column_int64(stmt, 2));
cJSON_AddStringToObject(row, "path", (const char *) sqlite3_column_text(stmt, 3));
cJSON_AddNumberToObject(row, "size", sqlite3_column_int(stmt, 4));
cJSON_AddStringToObject(row, "type", (const char *) sqlite3_column_text(stmt, 5));
cJSON_AddItemToArray(json, row);
} while (TRUE);
return json;
}

View File

@@ -37,7 +37,7 @@ int database_fts_get_max_path_depth(database_t *db) {
void database_fts_index(database_t *db) {
LOG_INFO("database_fts.c", "Creating content table.");
LOG_INFO("database_fts.c", "Creating content table");
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db,
@@ -47,21 +47,12 @@ void database_fts_index(database_t *db) {
" document.json_data ->> 'path' as path,"
" mtime,"
" document.json_data ->> 'mime' as mime,"
" CASE"
" WHEN sc.json_data IS NULL THEN"
" json_set(document.json_data, "
" '$._id',document.id,"
" '$.size',document.size, "
" '$.mtime',document.mtime)"
" ELSE json_patch("
" json_set(document.json_data,"
" '$._id',document.id,"
" '$.size',document.size,"
" '$.mtime', document.mtime),"
" sc.json_data) END"
" FROM document"
" LEFT JOIN document_sidecar sc ON document.id = sc.id"
" GROUP BY document.id)"
" )"
" INSERT"
" INTO fts.document_index (id, index_id, size, name, path, mtime, mime, json_data)"
" SELECT * FROM docs WHERE true"
@@ -69,7 +60,16 @@ void database_fts_index(database_t *db) {
" size=excluded.size, mtime=excluded.mtime, mime=excluded.mime, json_data=excluded.json_data;",
NULL, NULL, NULL));
LOG_DEBUG("database_fts.c", "Deleting old documents.");
LOG_DEBUG("database_fts.c", "Copying embeddings");
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db,
"REPLACE INTO fts.embedding (id, model_id, start, end, embedding)"
" SELECT id, model_id, start, end, embedding FROM embedding", NULL, NULL, NULL));
// TODO: delete old embeddings
LOG_DEBUG("database_fts.c", "Deleting old documents");
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db,
@@ -144,7 +144,7 @@ void database_fts_index(database_t *db) {
"INSERT INTO path_index (path, index_id, count, depth) SELECT path, index_id, total, depth FROM path_tmp",
NULL, NULL, NULL));
LOG_DEBUG("database_fts.c", "Generating search index.");
LOG_DEBUG("database_fts.c", "Generating search index");
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db, "INSERT INTO search(search) VALUES ('delete-all')",
@@ -157,7 +157,7 @@ void database_fts_index(database_t *db) {
}
void database_fts_optimize(database_t *db) {
LOG_INFO("database_fts.c", "Optimizing search index.");
LOG_INFO("database_fts.c", "Optimizing search index");
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db,
@@ -408,6 +408,8 @@ const char *get_sort_var(fts_sort_t sort) {
return "doc.name";
case FTS_SORT_ID:
return "doc.id";
case FTS_SORT_EMBEDDING:
return "cosine_sim(?7, ?8, emb.embedding)";
default:
return NULL;
}
@@ -459,11 +461,36 @@ char *get_after_where(char **after, fts_sort_t sort, int sort_asc) {
return "(sort_var, doc.ROWID) < (?3, ?4)";
}
int database_fts_get_model_size(database_t *db, int model_id) {
sqlite3_bind_int(db->fts_model_size, 1, model_id);
int ret = sqlite3_step(db->fts_model_size);
CRASH_IF_STMT_FAIL(ret);
if (ret == SQLITE_DONE) {
return -1;
}
int size = sqlite3_column_int(db->fts_model_size, 0);
sqlite3_reset(db->fts_model_size);
return size;
}
cJSON *database_fts_search(database_t *db, const char *query, const char *path, long size_min,
long size_max, long date_min, long date_max, int page_size,
char **index_ids, char **mime_types, char **tags, int sort_asc,
fts_sort_t sort, int seed, char **after, int fetch_aggregations,
int highlight, int highlight_context_size) {
int highlight, int highlight_context_size, int model,
const float *embedding, int embedding_size) {
if (embedding) {
int model_embedding_size = database_fts_get_model_size(db, model);
if (model_embedding_size != embedding_size) {
LOG_WARNINGF("database_fts.c", "Received invalid embedding size for model %s: %d, expected %d",
model, embedding_size, model_embedding_size);
return NULL;
}
}
char path_glob[PATH_MAX * 2];
snprintf(path_glob, sizeof(path_glob), "%s/*", path);
@@ -502,6 +529,11 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
"'$.content')";
}
const char *embedding_join = "";
if (embedding) {
embedding_join = "LEFT JOIN embedding emb ON emb.id = doc.id AND emb.model_id=?9";
}
char *sql;
char *agg_sql;
@@ -512,12 +544,14 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
" %s, %s as sort_var, doc.ROWID"
" FROM search"
" INNER JOIN document_index doc on doc.ROWID = search.ROWID"
" %s"
" WHERE %s"
" ORDER BY sort_var%s, doc.ROWID"
" LIMIT ?2",
json_object_sql, get_sort_var(sort),
embedding_join,
where,
sort_asc ? "" : "DESC");
sort_asc ? "" : " DESC");
if (fetch_aggregations) {
asprintf(&agg_sql,
@@ -533,10 +567,12 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
"SELECT"
" %s, %s as sort_var, doc.ROWID"
" FROM document_index doc"
" %s"
" WHERE %s"
" ORDER BY sort_var%s,doc.ROWID"
" LIMIT ?2",
json_object_sql, get_sort_var(sort),
embedding_join,
where,
sort_asc ? "" : " DESC");
@@ -569,7 +605,6 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
if (tags) {
db->tag_array = tags;
}
if (size_min > 0) {
sqlite3_bind_int64(stmt, sqlite3_bind_parameter_index(stmt, "@size_min"), size_min);
}
@@ -602,6 +637,11 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
if (highlight) {
sqlite3_bind_int(stmt, 6, highlight_context_size);
}
if (embedding) {
sqlite3_bind_int(stmt, 7, embedding_size);
sqlite3_bind_blob(stmt, 8, embedding, (int) sizeof(float) * embedding_size, SQLITE_STATIC);
sqlite3_bind_int(stmt, 9, model);
}
cJSON *json = cJSON_CreateObject();
cJSON *hits_hits = cJSON_CreateArray();

View File

@@ -38,6 +38,25 @@ const char *FtsDatabaseSchema =
");"
"CREATE INDEX IF NOT EXISTS tag_tag_idx ON tag(tag);"
"CREATE INDEX IF NOT EXISTS tag_id_idx ON tag(id);"
""
"CREATE TABLE IF NOT EXISTS embedding ("
" id TEXT REFERENCES document(id),"
" model_id INTEGER NOT NULL REFERENCES model(id),"
" start INTEGER NOT NULL,"
" end INTEGER,"
" embedding BLOB NOT NULL,"
" PRIMARY KEY (id, model_id, start)"
");"
""
"CREATE TABLE IF NOT EXISTS model ("
" id INTEGER PRIMARY KEY,"
" name TEXT NOT NULL UNIQUE CHECK ( length(name) < 16 ),"
" url TEXT,"
" path TEXT NOT NULL UNIQUE,"
" size INTEGER NOT NULL,"
" type TEXT NOT NULL CHECK ( type IN ('flat', 'nested') )"
");"
""
"CREATE TRIGGER IF NOT EXISTS tag_write_trigger"
" AFTER INSERT ON tag"
" BEGIN"
@@ -155,5 +174,14 @@ const char *IndexDatabaseSchema =
" mime TEXT NOT NULL,"
" size INTEGER NOT NULL,"
" count INTEGER NOT NULL"
");"
""
"CREATE TABLE embedding ("
" id TEXT REFERENCES document(id),"
" model_id INTEGER NOT NULL references model(id),"
" start INTEGER NOT NULL,"
" end INTEGER,"
" embedding BLOB NOT NULL,"
" PRIMARY KEY (id, model_id, start)"
");";

View File

@@ -1,6 +1,7 @@
#ifndef WALK_H
#define WALK_H
#undef _XOPEN_SOURCE
#define _XOPEN_SOURCE 500
int walk_directory_tree(const char *);

View File

@@ -321,6 +321,8 @@ void sist2_index(index_args_t *args) {
strcpy(doc_id, cJSON_GetObjectItem(json, "_id")->valuestring);
cJSON_DeleteItemFromObject(json, "_id");
// TODO: delete tag if empty
if (args->print) {
print_json(json, doc_id);
} else {
@@ -462,6 +464,11 @@ int set_to_negative_if_value_is_zero(UNUSED(struct argparse *self), const struct
int main(int argc, const char *argv[]) {
setlocale(LC_ALL, "");
// database_t *db = database_create("clip.sist2", INDEX_DATABASE);
// database_open(db);
// database_test(db);
// exit(0);
scan_args_t *scan_args = scan_args_create();
index_args_t *index_args = index_args_create();
web_args_t *web_args = web_args_create();

View File

@@ -87,7 +87,7 @@ static void buf2hex(const unsigned char *buf, size_t buflen, char *hex_string) {
*s = '\0';
}
static void md5_hexdigest(void *data, size_t size, char *output) {
static void md5_hexdigest(const void *data, size_t size, char *output) {
EVP_MD_CTX *md_ctx = EVP_MD_CTX_new();
EVP_DigestInit_ex(md_ctx, EVP_md5(), NULL);
@@ -120,7 +120,7 @@ struct timespec timespec_add(struct timespec ts1, long usec);
#define pthread_cond_timedwait_ms(cond, mutex, delay_ms) do {\
struct timespec now; \
clock_gettime(CLOCK_REALTIME, &now); \
struct timespec end_time = timespec_add(now, MILLISECOND * delay_ms); \
struct timespec end_time = timespec_add(now, MILLISECOND * (delay_ms)); \
pthread_cond_timedwait(cond, mutex, &end_time); \
} while (0)

View File

@@ -28,7 +28,15 @@ static struct mg_http_serve_opts DefaultServeOpts = {
.fs = NULL,
.ssi_pattern = NULL,
.root_dir = NULL,
.mime_types = ""
.mime_types = HTTP_SERVER_HEADER
};
static struct mg_http_serve_opts IndexServeOpts = {
.fs = NULL,
.ssi_pattern = NULL,
.root_dir = NULL,
.mime_types = "",
.extra_headers = HTTP_SERVER_HEADER "Cross-Origin-Embedder-Policy: require-corp\r\nCross-Origin-Opener-Policy: same-origin\r\n"
};
void stats_files(struct mg_connection *nc, struct mg_http_message *hm) {
@@ -67,7 +75,7 @@ void stats_files(struct mg_connection *nc, struct mg_http_message *hm) {
void serve_index_html(struct mg_connection *nc, struct mg_http_message *hm) {
if (WebCtx.dev) {
mg_http_serve_file(nc, hm, "sist2-vue/dist/index.html", &DefaultServeOpts);
mg_http_serve_file(nc, hm, "sist2-vue/dist/index.html", &IndexServeOpts);
} else {
web_serve_asset_index_html(nc);
}
@@ -334,6 +342,9 @@ void index_info(struct mg_connection *nc) {
cJSON_AddStringToObject(idx_json, "rewriteUrl", idx->desc.rewrite_url);
cJSON_AddNumberToObject(idx_json, "timestamp", (double) idx->desc.timestamp);
cJSON_AddItemToArray(arr, idx_json);
cJSON *models = database_get_models(idx->db);
cJSON_AddItemToObject(idx_json, "models", models);
}
if (WebCtx.search_backend == SQLITE_SEARCH_BACKEND) {

View File

@@ -32,6 +32,9 @@ typedef struct {
int fetch_aggregations;
int highlight;
int highlight_context_size;
int model;
float *embedding;
int embedding_size;
} fts_search_req_t;
fts_sort_t get_sort_mode(const cJSON *req_sort) {
@@ -45,11 +48,27 @@ fts_sort_t get_sort_mode(const cJSON *req_sort) {
return FTS_SORT_RANDOM;
} else if (strcmp(req_sort->valuestring, "name") == 0) {
return FTS_SORT_NAME;
} else if (strcmp(req_sort->valuestring, "embedding") == 0) {
return FTS_SORT_EMBEDDING;
}
return FTS_SORT_INVALID;
}
float *get_float_buffer(cJSON *arr, int *size) {
*size = cJSON_GetArraySize(arr);
float *floats = malloc(sizeof(float) * *size);
cJSON *elem;
int i = 0;
cJSON_ArrayForEach(elem, arr) {
floats[i] = (float) elem->valuedouble;
i += 1;
}
return floats;
}
static json_value get_json_string(cJSON *object, const char *name) {
@@ -89,6 +108,25 @@ static json_value get_json_bool(cJSON *object, const char *name) {
return (json_value) {item, FALSE};
}
static json_value get_json_float_array(cJSON *object, const char *name) {
cJSON *item = cJSON_GetObjectItem(object, name);
if (item == NULL || cJSON_IsNull(item)) {
return (json_value) {NULL, FALSE};
}
if (!cJSON_IsArray(item) || cJSON_GetArraySize(item) == 0) {
return (json_value) {NULL, TRUE};
}
cJSON *elem;
cJSON_ArrayForEach(elem, item) {
if (!cJSON_IsNumber(elem)) {
return (json_value) {NULL, TRUE};
}
}
return (json_value) {item, FALSE};
}
static json_value get_json_array(cJSON *object, const char *name) {
cJSON *item = cJSON_GetObjectItem(object, name);
if (item == NULL || cJSON_IsNull(item)) {
@@ -131,7 +169,7 @@ fts_search_req_t *get_search_req(struct mg_http_message *hm) {
json_value req_query, req_path, req_size_min, req_size_max, req_date_min, req_date_max, req_page_size,
req_index_ids, req_mime_types, req_tags, req_sort_asc, req_sort, req_seed, req_after,
req_fetch_aggregations, req_highlight, req_highlight_context_size;
req_fetch_aggregations, req_highlight, req_highlight_context_size, req_embedding, req_model;
if (!cJSON_IsObject(json) ||
(req_query = get_json_string(json, "query")).invalid ||
@@ -150,6 +188,8 @@ fts_search_req_t *get_search_req(struct mg_http_message *hm) {
(req_mime_types = get_json_array(json, "mimeTypes")).invalid ||
(req_highlight = get_json_bool(json, "highlight")).invalid ||
(req_highlight_context_size = get_json_number(json, "highlightContextSize")).invalid ||
(req_embedding = get_json_float_array(json, "embedding")).invalid ||
(req_model = get_json_number(json, "model")).invalid ||
(req_tags = get_json_array(json, "tags")).invalid) {
cJSON_Delete(json);
return NULL;
@@ -190,7 +230,11 @@ fts_search_req_t *get_search_req(struct mg_http_message *hm) {
cJSON_Delete(json);
return NULL;
}
if (req_highlight_context_size.val->valueint < 0) {
if (req_highlight_context_size.val && req_highlight_context_size.val->valueint < 0) {
cJSON_Delete(json);
return NULL;
}
if (req_model.val && !req_embedding.val || !req_model.val && req_embedding.val) {
cJSON_Delete(json);
return NULL;
}
@@ -216,6 +260,10 @@ fts_search_req_t *get_search_req(struct mg_http_message *hm) {
req->highlight_context_size = req_highlight_context_size.val
? req_highlight_context_size.val->valueint
: DEFAULT_HIGHLIGHT_CONTEXT_SIZE;
req->model = req_model.val ? req_model.val->valueint : 0;
req->embedding = req_model.val
? get_float_buffer(req_embedding.val, &req->embedding_size)
: NULL;
cJSON_Delete(json);
@@ -238,6 +286,10 @@ void destroy_search_req(fts_search_req_t *req) {
destroy_array(req->mime_types);
destroy_array(req->tags);
if (req->embedding) {
free(req->embedding);
}
free(req);
}
@@ -331,7 +383,13 @@ void fts_search(struct mg_connection *nc, struct mg_http_message *hm) {
req->page_size, req->index_ids, req->mime_types,
req->tags, req->sort_asc, req->sort, req->seed,
req->after, req->fetch_aggregations, req->highlight,
req->highlight_context_size);
req->highlight_context_size, req->model,
req->embedding, req->embedding_size);
if (json == NULL) {
HTTP_REPLY_BAD_REQUEST
return;
}
destroy_search_req(req);
mg_send_json(nc, json);