Rework user scripts, update DB schema to support embeddings

This commit is contained in:
2023-08-19 15:46:19 -04:00
parent 27188b6fa0
commit 857f3315c2
62 changed files with 1842 additions and 1250 deletions

View File

@@ -180,6 +180,10 @@ void database_open(database_t *db) {
db->db, "SELECT * FROM model", -1,
&db->get_models, NULL));
CRASH_IF_NOT_SQLITE_OK(sqlite3_prepare_v2(
db->db, "SELECT embedding FROM embedding WHERE id=? AND model_id=? AND start=0", -1,
&db->get_embedding, NULL));
// Create functions
sqlite3_create_function(
db->db,
@@ -194,11 +198,11 @@ void database_open(database_t *db) {
sqlite3_create_function(
db->db,
"embedding_to_json",
5,
"emb_to_json",
1,
SQLITE_UTF8,
NULL,
embedding_to_json_func,
emb_to_json_func,
NULL,
NULL
);
@@ -494,29 +498,31 @@ database_iterator_t *database_create_document_iterator(database_t *db) {
sqlite3_stmt *stmt;
sqlite3_prepare_v2(db->db, "WITH doc (j) AS (SELECT CASE"
" WHEN sc.json_data IS NULL THEN"
" CASE"
" WHEN t.tag IS NULL THEN"
" json_set(document.json_data, '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime)"
" ELSE"
" json_set(document.json_data, '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime, '$.tag', json_group_array(t.tag))"
" END"
" ELSE"
" CASE"
" WHEN t.tag IS NULL THEN"
" json_patch(json_set(document.json_data, '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime), sc.json_data)"
" ELSE"
// This will overwrite any tags specified in the sidecar file!
// TODO: concatenate the two arrays?
" json_set(json_patch(document.json_data, sc.json_data), '$._id', document.id, '$.size', document.size, '$.mtime', document.mtime, '$.tag', json_group_array(t.tag))"
" END"
" END"
" FROM document"
" LEFT JOIN document_sidecar sc ON document.id = sc.id"
" LEFT JOIN tag t ON document.id = t.id"
" GROUP BY document.id)"
" SELECT json_set(j, '$.index', (SELECT id FROM descriptor)) FROM doc", -1, &stmt, NULL);
CRASH_IF_NOT_SQLITE_OK(
sqlite3_prepare_v2(
db->db,
"WITH doc (j) AS (SELECT CASE"
" WHEN emb.embedding IS NULL THEN"
" json_set(document.json_data, "
" '$._id', document.id, "
" '$.size', document.size, "
" '$.mtime', document.mtime, "
" '$.tag', json_group_array((SELECT tag FROM tag WHERE document.id = tag.id)))"
" ELSE"
" json_set(document.json_data,"
" '$._id', document.id,"
" '$.size', document.size,"
" '$.mtime', document.mtime,"
" '$.tag', json_group_array((SELECT tag FROM tag WHERE document.id = tag.id)),"
" '$.emb', json_group_object(m.path, json(emb_to_json(emb.embedding))),"
" '$.embedding', 1)"
" END"
" FROM document"
" LEFT JOIN embedding emb ON document.id = emb.id"
" LEFT JOIN model m ON emb.model_id = m.id"
" GROUP BY document.id)"
" SELECT json_set(j, '$.index', (SELECT id FROM descriptor)) FROM doc",
-1, &stmt, NULL));
database_iterator_t *iter = malloc(sizeof(database_iterator_t));
@@ -526,6 +532,13 @@ database_iterator_t *database_create_document_iterator(database_t *db) {
return iter;
}
void remove_tag_if_null(cJSON *doc) {
cJSON *tags = cJSON_GetObjectItem(doc, "tag");
if (tags != NULL && cJSON_IsNull(cJSON_GetArrayItem(tags, 0))) {
cJSON_DeleteItemFromObject(doc, "tag");
}
}
cJSON *database_document_iter(database_iterator_t *iter) {
if (iter->stmt == NULL) {
@@ -537,7 +550,12 @@ cJSON *database_document_iter(database_iterator_t *iter) {
if (ret == SQLITE_ROW) {
const char *json_string = (const char *) sqlite3_column_text(iter->stmt, 0);
return cJSON_Parse(json_string);
cJSON *doc = cJSON_Parse(json_string);
remove_tag_if_null(doc);
return doc;
}
if (ret != SQLITE_DONE) {

View File

@@ -85,6 +85,7 @@ typedef struct database {
sqlite3_stmt *write_thumbnail_stmt;
sqlite3_stmt *get_document;
sqlite3_stmt *get_models;
sqlite3_stmt *get_embedding;
sqlite3_stmt *delete_tag_stmt;
sqlite3_stmt *write_tag_stmt;
@@ -144,6 +145,8 @@ void database_write_document(database_t *db, document_t *doc, const char *json_d
database_iterator_t *database_create_document_iterator(database_t *db);
void emb_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv);
cJSON *database_document_iter(database_iterator_t *);
#define database_document_iter_foreach(element, iter) \
@@ -235,8 +238,10 @@ cJSON *database_get_document(database_t *db, char *doc_id);
void cosine_sim_func(sqlite3_context *ctx, int argc, sqlite3_value **argv);
void embedding_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv);
cJSON *database_get_models(database_t *db);
int database_fts_get_model_size(database_t *db, int model_id);
cJSON *database_get_embedding(database_t *db, char *doc_id, int model_id);
#endif

View File

@@ -1,5 +1,6 @@
#include <openblas/cblas.h>
#include "database.h"
#include "src/ctx.h"
static float cosine_sim(int n, const float *a, const float *b) {
@@ -33,38 +34,6 @@ void cosine_sim_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
sqlite3_result_double(ctx, result);
}
void embedding_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
// emb, type, start, end, size
if (argc != 5) {
sqlite3_result_error(ctx, "Invalid parameters", -1);
}
const float *embedding = sqlite3_value_blob(argv[0]);
const char *type = (const char *) sqlite3_value_text(argv[1]);
int size = sqlite3_value_int(argv[4]);
if (strcmp(type, "flat") == 0) {
cJSON *json = cJSON_CreateFloatArray(embedding, size);
char *json_str = cJSON_PrintBuffered(json, size * 22, FALSE);
cJSON_Delete(json);
sqlite3_result_text(ctx, json_str, -1, SQLITE_TRANSIENT);
free(json_str);
} else {
int start = sqlite3_value_int(argv[2]);
int end = sqlite3_value_int(argv[3]);
sqlite3_result_error(ctx, "Nested embeddings not implemented yet", -1);
}
}
cJSON *database_get_models(database_t *db) {
cJSON *json = cJSON_CreateArray();
sqlite3_stmt *stmt = db->get_models;
@@ -72,7 +41,12 @@ cJSON *database_get_models(database_t *db) {
int ret;
do {
ret = sqlite3_step(stmt);
CRASH_IF_STMT_FAIL(ret);
if (ret == SQLITE_BUSY) {
// Database is busy (probably scanning)
LOG_WARNING("database_embeddings.c",
"Database is busy, could not fetch list of models");
break;
}
if (ret == SQLITE_DONE) {
break;
@@ -82,7 +56,7 @@ cJSON *database_get_models(database_t *db) {
cJSON_AddNumberToObject(row, "id", sqlite3_column_int(stmt, 0));
cJSON_AddStringToObject(row, "name", (const char *) sqlite3_column_text(stmt, 1));
cJSON_AddStringToObject(row, "url", (const char *) sqlite3_column_int64(stmt, 2));
cJSON_AddStringToObject(row, "url", (const char *) sqlite3_column_text(stmt, 2));
cJSON_AddStringToObject(row, "path", (const char *) sqlite3_column_text(stmt, 3));
cJSON_AddNumberToObject(row, "size", sqlite3_column_int(stmt, 4));
cJSON_AddStringToObject(row, "type", (const char *) sqlite3_column_text(stmt, 5));
@@ -90,5 +64,44 @@ cJSON *database_get_models(database_t *db) {
cJSON_AddItemToArray(json, row);
} while (TRUE);
sqlite3_reset(stmt);
return json;
}
cJSON *database_get_embedding(database_t *db, char *doc_id, int model_id) {
sqlite3_bind_text(db->get_embedding, 1, doc_id, -1, SQLITE_STATIC);
sqlite3_bind_int(db->get_embedding, 2, model_id);
int ret = sqlite3_step(db->get_embedding);
CRASH_IF_STMT_FAIL(ret);
if (ret == SQLITE_DONE) {
sqlite3_reset(db->get_embedding);
return NULL;
}
float *embedding = (float *) sqlite3_column_blob(db->get_embedding, 0);
size_t size = sqlite3_column_bytes(db->get_embedding, 0) / sizeof(float);
cJSON *json = cJSON_CreateFloatArray(embedding, (int) size);
sqlite3_reset(db->get_embedding);
return json;
}
void emb_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
if (argc != 1) {
sqlite3_result_error(ctx, "Invalid parameters", -1);
}
float *embedding = (float *) sqlite3_value_blob(argv[0]);
int size = sqlite3_value_bytes(argv[0]) / 4;
cJSON *json = cJSON_CreateFloatArray(embedding, size);
char *json_str = cJSON_PrintUnformatted(json);
sqlite3_result_text(ctx, json_str, -1, SQLITE_TRANSIENT);
free(json_str);
cJSON_Delete(json);
}

View File

@@ -67,6 +67,11 @@ void database_fts_index(database_t *db) {
"REPLACE INTO fts.embedding (id, model_id, start, end, embedding)"
" SELECT id, model_id, start, end, embedding FROM embedding", NULL, NULL, NULL));
CRASH_IF_NOT_SQLITE_OK(sqlite3_exec(
db->db,
"INSERT INTO fts.model (id, size)"
" SELECT id, size FROM model WHERE TRUE ON CONFLICT (id) DO NOTHING", NULL, NULL, NULL));
// TODO: delete old embeddings
LOG_DEBUG("database_fts.c", "Deleting old documents");
@@ -518,20 +523,15 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
const char *json_object_sql;
if (highlight && query_where != NULL) {
json_object_sql = "json_remove(json_set(doc.json_data,"
json_object_sql = "json_set(json_remove(doc.json_data, '$.content'),"
"'$.index', doc.index_id,"
"'$.embedding', (CASE WHEN emb.id IS NOT NULL THEN 1 ELSE 0 END),"
"'$._highlight.name', snippet(search, 0, '<mark>', '</mark>', '', ?6),"
"'$._highlight.content', snippet(search, 1, '<mark>', '</mark>', '', ?6)),"
"'$.content')";
"'$._highlight.content', snippet(search, 1, '<mark>', '</mark>', '', ?6))";
} else {
json_object_sql = "json_remove(json_set(doc.json_data,"
"'$.index', doc.index_id),"
"'$.content')";
}
const char *embedding_join = "";
if (embedding) {
embedding_join = "LEFT JOIN embedding emb ON emb.id = doc.id AND emb.model_id=?9";
json_object_sql = "json_set(json_remove(doc.json_data, '$.content'),"
"'$.index', doc.index_id,"
"'$.embedding', (CASE WHEN emb.id IS NOT NULL THEN 1 ELSE 0 END))";
}
char *sql;
@@ -544,12 +544,11 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
" %s, %s as sort_var, doc.ROWID"
" FROM search"
" INNER JOIN document_index doc on doc.ROWID = search.ROWID"
" %s"
" LEFT JOIN embedding emb on emb.id = doc.id"
" WHERE %s"
" ORDER BY sort_var%s, doc.ROWID"
" LIMIT ?2",
json_object_sql, get_sort_var(sort),
embedding_join,
where,
sort_asc ? "" : " DESC");
@@ -567,12 +566,11 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
"SELECT"
" %s, %s as sort_var, doc.ROWID"
" FROM document_index doc"
" %s"
" LEFT JOIN embedding emb on emb.id = doc.id"
" WHERE %s"
" ORDER BY sort_var%s,doc.ROWID"
" LIMIT ?2",
json_object_sql, get_sort_var(sort),
embedding_join,
where,
sort_asc ? "" : " DESC");
@@ -624,7 +622,7 @@ cJSON *database_fts_search(database_t *db, const char *query, const char *path,
if (after_where) {
if (sort == FTS_SORT_NAME || sort == FTS_SORT_ID) {
sqlite3_bind_text(stmt, 3, after[0], -1, SQLITE_STATIC);
} else if (sort == FTS_SORT_SCORE) {
} else if (sort == FTS_SORT_SCORE || sort == FTS_SORT_EMBEDDING) {
sqlite3_bind_double(stmt, 3, strtod(after[0], NULL));
} else {
sqlite3_bind_int64(stmt, 3, strtol(after[0], NULL, 10));

View File

@@ -49,12 +49,8 @@ const char *FtsDatabaseSchema =
");"
""
"CREATE TABLE IF NOT EXISTS model ("
" id INTEGER PRIMARY KEY,"
" name TEXT NOT NULL UNIQUE CHECK ( length(name) < 16 ),"
" url TEXT,"
" path TEXT NOT NULL UNIQUE,"
" size INTEGER NOT NULL,"
" type TEXT NOT NULL CHECK ( type IN ('flat', 'nested') )"
" id INTEGER PRIMARY KEY CHECK (id > 0 AND id < 1000),"
" size INTEGER NOT NULL"
");"
""
"CREATE TRIGGER IF NOT EXISTS tag_write_trigger"
@@ -183,5 +179,14 @@ const char *IndexDatabaseSchema =
" end INTEGER,"
" embedding BLOB NOT NULL,"
" PRIMARY KEY (id, model_id, start)"
");"
""
"CREATE TABLE model ("
" id INTEGER PRIMARY KEY CHECK (id > 0 AND id < 1000),"
" name TEXT NOT NULL UNIQUE CHECK ( length(name) < 16 ),"
" url TEXT,"
" path TEXT NOT NULL UNIQUE,"
" size INTEGER NOT NULL,"
" type TEXT NOT NULL CHECK ( type IN ('flat', 'nested') )"
");";