sist2/sist2-vue/src/ml/BertNerModel.js

import BertTokenizer from "@/ml/BertTokenizer";
import * as tf from "@tensorflow/tfjs";
import axios from "axios";

// Token-classification (NER) model: wraps a TensorFlow.js graph model and a
// BERT tokenizer, and maps predicted label ids back onto whole words.
export default class BertNerModel {
    vocabUrl;
    modelUrl;
    id2label;

    _tokenizer;
    _model;

    inputSize = 128;
    _previousWordId = null;

    constructor(vocabUrl, modelUrl, id2label) {
        this.vocabUrl = vocabUrl;
        this.modelUrl = modelUrl;
        this.id2label = id2label;
    }
    async init(onProgress) {
        await Promise.all([this.loadTokenizer(), this.loadModel(onProgress)]);
    }

    async loadTokenizer() {
        const vocab = (await axios.get(this.vocabUrl)).data;
        this._tokenizer = new BertTokenizer(vocab);
    }

    async loadModel(onProgress) {
        this._model = await tf.loadGraphModel(this.modelUrl, {onProgress});
    }
    // Map per-token labels back onto whole words: skip padding/special tokens
    // (wordId === -1) and keep only the label of the first sub-word of each word.
    alignLabels(labels, wordIds, words) {
        const result = [];

        for (let i = 0; i < this.inputSize; i++) {
            const label = labels[i];
            const wordId = wordIds[i];

            if (wordId === -1) {
                continue;
            }
            if (wordId === this._previousWordId) {
                continue;
            }

            result.push({
                word: words[wordId].text,
                wordIndex: words[wordId].index,
                label: label
            });
            this._previousWordId = wordId;
        }

        return result;
    }
    async predict(text, callback) {
        this._previousWordId = null;
        const encoded = this._tokenizer.encodeText(text, this.inputSize);

        // The text is encoded into one or more fixed-size (inputSize) chunks;
        // each chunk is run through the model and the callback is invoked with
        // the aligned word/label pairs for that chunk.
        for (const chunk of encoded.inputChunks) {
            const rawResult = tf.tidy(() => this._model.execute({
                input_ids: tf.tensor2d(chunk.inputIds, [1, this.inputSize], "int32"),
                token_type_ids: tf.tensor2d(chunk.segmentIds, [1, this.inputSize], "int32"),
                attention_mask: tf.tensor2d(chunk.inputMask, [1, this.inputSize], "int32"),
            }));

            // argMax over the logits gives one label id per token.
            const labelIds = tf.argMax(rawResult, -1);
            const labelIdsArray = await labelIds.array();
            const labels = labelIdsArray[0].map(id => this.id2label[id]);

            rawResult.dispose();
            labelIds.dispose();

            callback(this.alignLabels(labels, chunk.wordIds, encoded.words));
        }
    }
}
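
// Example usage: a minimal sketch. The URLs and the id2label map below are
// placeholders for illustration, not the values actually used by sist2.
//
//   const model = new BertNerModel(
//       "/ml/ner/vocab.txt",               // hypothetical vocabulary URL
//       "/ml/ner/model.json",              // hypothetical TF.js graph model URL
//       {0: "O", 1: "B-PER", 2: "I-PER"},  // hypothetical label id -> name map
//   );
//   await model.init(progress => console.log(`loading: ${Math.round(progress * 100)}%`));
//   await model.predict("Jane Doe lives in Montreal.", entities => {
//       // Called once per 128-token chunk with [{word, wordIndex, label}, ...]
//       console.log(entities);
//   });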