import * as ort from "onnxruntime-web";
import {BPETokenizer} from "@/ml/BPETokenizer";
import axios from "axios";
import {downloadToBuffer, ORT_WASM_PATHS} from "@/ml/mlUtils";
import ModelStore from "@/ml/ModelStore";
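
// Client-side wrapper around a CLIP text transformer exported to ONNX and its
// BPE tokenizer; embeds text queries in the browser via onnxruntime-web.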
export class CLIPTransformerModel {
    _modelUrl = null;
    _tokenizerUrl = null;
    _model = null;
    _tokenizer = null;

    constructor(modelUrl, tokenizerUrl) {
        this._modelUrl = modelUrl;
        this._tokenizerUrl = tokenizerUrl;
    }

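    // Load the tokenizer and the model concurrently; onProgress only tracks
    // the model download.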
    async init(onProgress) {
        await Promise.all([this.loadTokenizer(), this.loadModel(onProgress)]);
    }

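    // Fetch the ONNX model, preferring a cached copy from ModelStore and
    // caching the download on a miss, then create an inference session on
    // the WASM execution provider.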
    async loadModel(onProgress) {
        ort.env.wasm.wasmPaths = ORT_WASM_PATHS;
        ort.env.wasm.numThreads = 2;

        let buf = await ModelStore.get(this._modelUrl);
        if (!buf) {
            buf = await downloadToBuffer(this._modelUrl, onProgress);
            await ModelStore.set(this._modelUrl, buf);
        }

        this._model = await ort.InferenceSession.create(buf.buffer, {
            executionProviders: ["wasm"],
        });
    }

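    // The tokenizer definition is fetched as JSON and provides the BPE
    // vocabulary (encoder) and merge ranks (bpe_ranks).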
    async loadTokenizer() {
        const resp = await axios.get(this._tokenizerUrl);
        this._tokenizer = new BPETokenizer(resp.data.encoder, resp.data.bpe_ranks);
    }

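    // Embed a text query. The token ids are fed as a fixed-shape [1, 77]
    // tensor (CLIP's context length), and the 512-dimensional text embedding
    // is picked out of the session outputs by its size.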
    async predict(text) {
        const tokenized = this._tokenizer.encode(text);

        const inputs = {
            input_ids: new ort.Tensor("int32", tokenized, [1, 77])
        };

        const results = await this._model.run(inputs);

        return Array.from(
            Object.values(results)
                .find(result => result.size === 512).data
        );
    }
}
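
// A minimal usage sketch, not part of the class above. The URLs are
// placeholders, and the argument passed to the onProgress callback depends
// on downloadToBuffer in @/ml/mlUtils:
//
//   const clip = new CLIPTransformerModel(modelUrl, tokenizerUrl);
//   await clip.init(progress => console.log("model download:", progress));
//   const embedding = await clip.predict("a photo of a cat"); // length-512 array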