This commit is contained in:
2023-07-24 19:36:20 -04:00
parent f56cfb0f2f
commit 27188b6fa0
29 changed files with 1008 additions and 75 deletions

View File

@@ -1,6 +1,8 @@
import BertTokenizer from "@/ml/BertTokenizer";
import * as tf from "@tensorflow/tfjs";
import axios from "axios";
import {chunk as _chunk} from "underscore";
import * as ort from "onnxruntime-web";
import {argMax, downloadToBuffer, ORT_WASM_PATHS} from "@/ml/mlUtils";
export default class BertNerModel {
vocabUrl;
@@ -29,7 +31,10 @@ export default class BertNerModel {
}
async loadModel(onProgress) {
this._model = await tf.loadGraphModel(this.modelUrl, {onProgress});
ort.env.wasm.wasmPaths = ORT_WASM_PATHS;
const buf = await downloadToBuffer(this.modelUrl, onProgress);
this._model = await ort.InferenceSession.create(buf.buffer, {executionProviders: ["wasm"]});
}
alignLabels(labels, wordIds, words) {
@@ -57,21 +62,28 @@ export default class BertNerModel {
async predict(text, callback) {
this._previousWordId = null;
const encoded = this._tokenizer.encodeText(text, this.inputSize)
const encoded = this._tokenizer.encodeText(text, this.inputSize);
let i = 0;
for (let chunk of encoded.inputChunks) {
const rawResult = tf.tidy(() => this._model.execute({
input_ids: tf.tensor2d(chunk.inputIds, [1, this.inputSize], "int32"),
token_type_ids: tf.tensor2d(chunk.segmentIds, [1, this.inputSize], "int32"),
attention_mask: tf.tensor2d(chunk.inputMask, [1, this.inputSize], "int32"),
}));
const labelIds = await tf.argMax(rawResult, -1);
const labelIdsArray = await labelIds.array();
const labels = labelIdsArray[0].map(id => this.id2label[id]);
rawResult.dispose()
const results = await this._model.run({
input_ids: new ort.Tensor("int32", chunk.inputIds, [1, this.inputSize]),
token_type_ids: new ort.Tensor("int32", chunk.segmentIds, [1, this.inputSize]),
attention_mask: new ort.Tensor("int32", chunk.inputMask, [1, this.inputSize]),
});
callback(this.alignLabels(labels, chunk.wordIds, encoded.words))
const labelIds = _chunk(results["output"].data, this.id2label.length).map(argMax);
const labels = labelIds.map(id => this.id2label[id]);
callback(this.alignLabels(labels, chunk.wordIds, encoded.words));
i += 1;
// give browser some time to repaint
if (i % 2 === 0) {
await new Promise(resolve => setTimeout(resolve, 0));
}
}
}
}