mirror of
https://github.com/simon987/sist2.git
synced 2025-12-18 01:39:05 +00:00
Rework user scripts, update DB schema to support embeddings
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#include <openblas/cblas.h>
|
||||
#include "database.h"
|
||||
#include "src/ctx.h"
|
||||
|
||||
|
||||
static float cosine_sim(int n, const float *a, const float *b) {
|
||||
@@ -33,38 +34,6 @@ void cosine_sim_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
|
||||
sqlite3_result_double(ctx, result);
|
||||
}
|
||||
|
||||
void embedding_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
|
||||
|
||||
// emb, type, start, end, size
|
||||
|
||||
if (argc != 5) {
|
||||
sqlite3_result_error(ctx, "Invalid parameters", -1);
|
||||
}
|
||||
|
||||
const float *embedding = sqlite3_value_blob(argv[0]);
|
||||
const char *type = (const char *) sqlite3_value_text(argv[1]);
|
||||
|
||||
int size = sqlite3_value_int(argv[4]);
|
||||
|
||||
if (strcmp(type, "flat") == 0) {
|
||||
|
||||
cJSON *json = cJSON_CreateFloatArray(embedding, size);
|
||||
|
||||
char *json_str = cJSON_PrintBuffered(json, size * 22, FALSE);
|
||||
|
||||
cJSON_Delete(json);
|
||||
|
||||
sqlite3_result_text(ctx, json_str, -1, SQLITE_TRANSIENT);
|
||||
free(json_str);
|
||||
|
||||
} else {
|
||||
int start = sqlite3_value_int(argv[2]);
|
||||
int end = sqlite3_value_int(argv[3]);
|
||||
|
||||
sqlite3_result_error(ctx, "Nested embeddings not implemented yet", -1);
|
||||
}
|
||||
}
|
||||
|
||||
cJSON *database_get_models(database_t *db) {
|
||||
cJSON *json = cJSON_CreateArray();
|
||||
sqlite3_stmt *stmt = db->get_models;
|
||||
@@ -72,7 +41,12 @@ cJSON *database_get_models(database_t *db) {
|
||||
int ret;
|
||||
do {
|
||||
ret = sqlite3_step(stmt);
|
||||
CRASH_IF_STMT_FAIL(ret);
|
||||
if (ret == SQLITE_BUSY) {
|
||||
// Database is busy (probably scanning)
|
||||
LOG_WARNING("database_embeddings.c",
|
||||
"Database is busy, could not fetch list of models");
|
||||
break;
|
||||
}
|
||||
|
||||
if (ret == SQLITE_DONE) {
|
||||
break;
|
||||
@@ -82,7 +56,7 @@ cJSON *database_get_models(database_t *db) {
|
||||
|
||||
cJSON_AddNumberToObject(row, "id", sqlite3_column_int(stmt, 0));
|
||||
cJSON_AddStringToObject(row, "name", (const char *) sqlite3_column_text(stmt, 1));
|
||||
cJSON_AddStringToObject(row, "url", (const char *) sqlite3_column_int64(stmt, 2));
|
||||
cJSON_AddStringToObject(row, "url", (const char *) sqlite3_column_text(stmt, 2));
|
||||
cJSON_AddStringToObject(row, "path", (const char *) sqlite3_column_text(stmt, 3));
|
||||
cJSON_AddNumberToObject(row, "size", sqlite3_column_int(stmt, 4));
|
||||
cJSON_AddStringToObject(row, "type", (const char *) sqlite3_column_text(stmt, 5));
|
||||
@@ -90,5 +64,44 @@ cJSON *database_get_models(database_t *db) {
|
||||
cJSON_AddItemToArray(json, row);
|
||||
} while (TRUE);
|
||||
|
||||
sqlite3_reset(stmt);
|
||||
|
||||
return json;
|
||||
}
|
||||
|
||||
cJSON *database_get_embedding(database_t *db, char *doc_id, int model_id) {
|
||||
|
||||
sqlite3_bind_text(db->get_embedding, 1, doc_id, -1, SQLITE_STATIC);
|
||||
sqlite3_bind_int(db->get_embedding, 2, model_id);
|
||||
int ret = sqlite3_step(db->get_embedding);
|
||||
CRASH_IF_STMT_FAIL(ret);
|
||||
|
||||
if (ret == SQLITE_DONE) {
|
||||
sqlite3_reset(db->get_embedding);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
float *embedding = (float *) sqlite3_column_blob(db->get_embedding, 0);
|
||||
size_t size = sqlite3_column_bytes(db->get_embedding, 0) / sizeof(float);
|
||||
|
||||
cJSON *json = cJSON_CreateFloatArray(embedding, (int) size);
|
||||
sqlite3_reset(db->get_embedding);
|
||||
|
||||
return json;
|
||||
}
|
||||
|
||||
void emb_to_json_func(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
|
||||
if (argc != 1) {
|
||||
sqlite3_result_error(ctx, "Invalid parameters", -1);
|
||||
}
|
||||
|
||||
float *embedding = (float *) sqlite3_value_blob(argv[0]);
|
||||
int size = sqlite3_value_bytes(argv[0]) / 4;
|
||||
|
||||
cJSON *json = cJSON_CreateFloatArray(embedding, size);
|
||||
char *json_str = cJSON_PrintUnformatted(json);
|
||||
|
||||
sqlite3_result_text(ctx, json_str, -1, SQLITE_TRANSIENT);
|
||||
free(json_str);
|
||||
cJSON_Delete(json);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user