mirror of
https://github.com/simon987/sist2.git
synced 2025-12-17 17:29:07 +00:00
wip
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import BertTokenizer from "@/ml/BertTokenizer";
|
||||
import * as tf from "@tensorflow/tfjs";
|
||||
import axios from "axios";
|
||||
import {chunk as _chunk} from "underscore";
|
||||
import * as ort from "onnxruntime-web";
|
||||
import {argMax, downloadToBuffer, ORT_WASM_PATHS} from "@/ml/mlUtils";
|
||||
|
||||
export default class BertNerModel {
|
||||
vocabUrl;
|
||||
@@ -29,7 +31,10 @@ export default class BertNerModel {
|
||||
}
|
||||
|
||||
async loadModel(onProgress) {
|
||||
this._model = await tf.loadGraphModel(this.modelUrl, {onProgress});
|
||||
ort.env.wasm.wasmPaths = ORT_WASM_PATHS;
|
||||
const buf = await downloadToBuffer(this.modelUrl, onProgress);
|
||||
|
||||
this._model = await ort.InferenceSession.create(buf.buffer, {executionProviders: ["wasm"]});
|
||||
}
|
||||
|
||||
alignLabels(labels, wordIds, words) {
|
||||
@@ -57,21 +62,28 @@ export default class BertNerModel {
|
||||
|
||||
async predict(text, callback) {
|
||||
this._previousWordId = null;
|
||||
const encoded = this._tokenizer.encodeText(text, this.inputSize)
|
||||
const encoded = this._tokenizer.encodeText(text, this.inputSize);
|
||||
|
||||
let i = 0;
|
||||
for (let chunk of encoded.inputChunks) {
|
||||
const rawResult = tf.tidy(() => this._model.execute({
|
||||
input_ids: tf.tensor2d(chunk.inputIds, [1, this.inputSize], "int32"),
|
||||
token_type_ids: tf.tensor2d(chunk.segmentIds, [1, this.inputSize], "int32"),
|
||||
attention_mask: tf.tensor2d(chunk.inputMask, [1, this.inputSize], "int32"),
|
||||
}));
|
||||
|
||||
const labelIds = await tf.argMax(rawResult, -1);
|
||||
const labelIdsArray = await labelIds.array();
|
||||
const labels = labelIdsArray[0].map(id => this.id2label[id]);
|
||||
rawResult.dispose()
|
||||
const results = await this._model.run({
|
||||
input_ids: new ort.Tensor("int32", chunk.inputIds, [1, this.inputSize]),
|
||||
token_type_ids: new ort.Tensor("int32", chunk.segmentIds, [1, this.inputSize]),
|
||||
attention_mask: new ort.Tensor("int32", chunk.inputMask, [1, this.inputSize]),
|
||||
});
|
||||
|
||||
callback(this.alignLabels(labels, chunk.wordIds, encoded.words))
|
||||
const labelIds = _chunk(results["output"].data, this.id2label.length).map(argMax);
|
||||
const labels = labelIds.map(id => this.id2label[id]);
|
||||
|
||||
callback(this.alignLabels(labels, chunk.wordIds, encoded.words));
|
||||
|
||||
i += 1;
|
||||
|
||||
// give browser some time to repaint
|
||||
if (i % 2 === 0) {
|
||||
await new Promise(resolve => setTimeout(resolve, 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user