This commit is contained in:
2023-07-24 19:36:20 -04:00
parent f56cfb0f2f
commit 27188b6fa0
29 changed files with 1008 additions and 75 deletions

View File

@@ -37,7 +37,7 @@ int database_fts_get_max_path_depth(database_t *db) {
void database_fts_index(database_t *db) {
LOG_INFO("database_fts.c", "Creating content table.");
LOG_INFO("database_fts.c", "Creating content table");
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db,
@@ -47,21 +47,12 @@ void database_fts_index(database_t *db) {
" document.json_data ->> 'path' as path,"
" mtime,"
" document.json_data ->> 'mime' as mime,"
" CASE"
" WHEN sc.json_data IS NULL THEN"
" json_set(document.json_data, "
" '$._id',document.id,"
" '$.size',document.size, "
" '$.mtime',document.mtime)"
" ELSE json_patch("
" json_set(document.json_data,"
" '$._id',document.id,"
" '$.size',document.size,"
" '$.mtime', document.mtime),"
" sc.json_data) END"
" FROM document"
" LEFT JOIN document_sidecar sc ON document.id = sc.id"
" GROUP BY document.id)"
" )"
" INSERT"
" INTO fts.document_index (id, index_id, size, name, path, mtime, mime, json_data)"
" SELECT * FROM docs WHERE true"
@@ -69,7 +60,16 @@ void database_fts_index(database_t *db) {
" size=excluded.size, mtime=excluded.mtime, mime=excluded.mime, json_data=excluded.json_data;",
NULL, NULL, NULL));
LOG_DEBUG("database_fts.c", "Deleting old documents.");
LOG_DEBUG("database_fts.c", "Copying embeddings");
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db,
"REPLACE INTO fts.embedding (id, model_id, start, end, embedding)"
" SELECT id, model_id, start, end, embedding FROM embedding", NULL, NULL, NULL));
// TODO: delete old embeddings
LOG_DEBUG("database_fts.c", "Deleting old documents");
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db,
@@ -144,7 +144,7 @@ void database_fts_index(database_t *db) {
"INSERT INTO path_index (path, index_id, count, depth) SELECT path, index_id, total, depth FROM path_tmp",
NULL, NULL, NULL));
LOG_DEBUG("database_fts.c", "Generating search index.");
LOG_DEBUG("database_fts.c", "Generating search index");
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db, "INSERT INTO search(search) VALUES ('delete-all')",
@@ -157,7 +157,7 @@ void database_fts_index(database_t *db) {
}
void database_fts_optimize(database_t *db) {
LOG_INFO("database_fts.c", "Optimizing search index.");
LOG_INFO("database_fts.c", "Optimizing search index");
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db,
@@ -408,6 +408,8 @@ const char *get_sort_var(fts_sort_t sort) {
return "doc.name";
case FTS_SORT_ID:
return "doc.id";
case FTS_SORT_EMBEDDING:
return "cosine_sim(?7, ?8, emb.embedding)";
default:
return NULL;
}
@@ -459,11 +461,36 @@ char *get_after_where(char **after, fts_sort_t sort, int sort_asc) {
return "(sort_var, doc.ROWID) < (?3, ?4)";
}
/**
 * Run the prepared statement db->fts_model_size for the given model and
 * return the integer in the first column of its first result row.
 * (Presumably this is the model's embedding dimension -- confirm against the
 * SQL text the statement is prepared from.)
 *
 * @param db       Database handle holding the prepared fts_model_size statement.
 * @param model_id Value bound to parameter 1 of the statement.
 * @return First column of the first row as an int, or -1 when the statement
 *         finishes without producing a row.
 */
int database_fts_get_model_size(database_t *db, int model_id) {
    sqlite3_bind_int(db->fts_model_size, 1, model_id);

    int ret = sqlite3_step(db->fts_model_size);
    CRASH_IF_STMT_FAIL(ret);

    int size = -1;
    if (ret != SQLITE_DONE) {
        size = sqlite3_column_int(db->fts_model_size, 0);
    }

    // Always reset, including on the "no row" path: SQLite rejects
    // sqlite3_bind_*() with SQLITE_MISUSE on a statement that was stepped more
    // recently than it was reset, so skipping the reset would make the next
    // call silently run with the previous model_id still bound.
    sqlite3_reset(db->fts_model_size);

    return size;
}
cJSON *database_fts_search(database_t *db, const char *query, const char *path, long size_min,
long size_max, long date_min, long date_max, int page_size,
char **index_ids, char **mime_types, char **tags, int sort_asc,
fts_sort_t sort, int seed, char **after, int fetch_aggregations,
int highlight, int highlight_context_size) {
int highlight, int highlight_context_size, int model,
const float *embedding, int embedding_size) {
if (embedding) {
int model_embedding_size = database_fts_get_model_size(db, model);
if (model_embedding_size != embedding_size) {
LOG_WARNINGF("database_fts.c", "Received invalid embedding size for model %s: %d, expected %d",
model, embedding_size, model_embedding_size);
return NULL;
}
}
char path_glob[PATH_MAX * 2];
snprintf(path_glob, sizeof(path_glob), "%s/*", path);
@@ -502,6 +529,11 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
"'$.content')";
}
const char *embedding_join = "";
if (embedding) {
embedding_join = "LEFT JOIN embedding emb ON emb.id = doc.id AND emb.model_id=?9";
}
char *sql;
char *agg_sql;
@@ -512,12 +544,14 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
" %s, %s as sort_var, doc.ROWID"
" FROM search"
" INNER JOIN document_index doc on doc.ROWID = search.ROWID"
" %s"
" WHERE %s"
" ORDER BY sort_var%s, doc.ROWID"
" LIMIT ?2",
json_object_sql, get_sort_var(sort),
embedding_join,
where,
sort_asc ? "" : "DESC");
sort_asc ? "" : " DESC");
if (fetch_aggregations) {
asprintf(&agg_sql,
@@ -533,10 +567,12 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
"SELECT"
" %s, %s as sort_var, doc.ROWID"
" FROM document_index doc"
" %s"
" WHERE %s"
" ORDER BY sort_var%s,doc.ROWID"
" LIMIT ?2",
json_object_sql, get_sort_var(sort),
embedding_join,
where,
sort_asc ? "" : " DESC");
@@ -569,7 +605,6 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
if (tags) {
db->tag_array = tags;
}
if (size_min > 0) {
sqlite3_bind_int64(stmt, sqlite3_bind_parameter_index(stmt, "@size_min"), size_min);
}
@@ -602,6 +637,11 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
if (highlight) {
sqlite3_bind_int(stmt, 6, highlight_context_size);
}
if (embedding) {
sqlite3_bind_int(stmt, 7, embedding_size);
sqlite3_bind_blob(stmt, 8, embedding, (int) sizeof(float) * embedding_size, SQLITE_STATIC);
sqlite3_bind_int(stmt, 9, model);
}
cJSON *json = cJSON_CreateObject();
cJSON *hits_hits = cJSON_CreateArray();