mirror of
https://github.com/simon987/sist2.git
synced 2025-04-21 11:16:46 +00:00
77 lines
2.2 KiB
JavaScript
77 lines
2.2 KiB
JavaScript
import BertTokenizer from "@/ml/BertTokenizer";
|
|
import * as tf from "@tensorflow/tfjs";
|
|
import axios from "axios";
|
|
|
|
/**
 * Token-classification (NER) wrapper around a TensorFlow.js graph model
 * and a BERT tokenizer.
 *
 * Typical usage: construct, `await init()`, then call `predict()`.
 */
export default class BertNerModel {
    /** URL the tokenizer vocabulary is fetched from (see `loadTokenizer`). */
    vocabUrl;
    /** URL the graph model is loaded from (see `loadModel`). */
    modelUrl;

    /** Lookup from predicted label id to label, indexed as `id2label[id]`. */
    id2label;
    /** BertTokenizer instance; set by `loadTokenizer`. */
    _tokenizer;
    /** Graph model returned by `tf.loadGraphModel`; set by `loadModel`. */
    _model;
    /** Sequence length of each chunk fed to the model (tensors are `[1, inputSize]`). */
    inputSize = 128;

    /**
     * Word id of the last entry emitted by `alignLabels`. It is reset at the
     * start of every `predict` call and carried over between chunks of the
     * same call, so a word id repeated right after a chunk boundary is skipped.
     */
    _previousWordId = null;

    /**
     * @param vocabUrl URL of the tokenizer vocabulary
     * @param modelUrl URL of the TensorFlow.js graph model
     * @param id2label lookup from label id to label
     */
    constructor(vocabUrl, modelUrl, id2label) {
        this.vocabUrl = vocabUrl;
        this.modelUrl = modelUrl;
        this.id2label = id2label;
    }

    /**
     * Loads the tokenizer and the model concurrently.
     *
     * @param onProgress progress callback forwarded to `tf.loadGraphModel`
     */
    async init(onProgress) {
        await Promise.all([this.loadTokenizer(), this.loadModel(onProgress)]);
    }

    /** Fetches the vocabulary from `vocabUrl` and builds the tokenizer from it. */
    async loadTokenizer() {
        const vocab = (await axios.get(this.vocabUrl)).data;
        this._tokenizer = new BertTokenizer(vocab);
    }

    /**
     * Loads the graph model from `modelUrl`.
     *
     * @param onProgress progress callback forwarded to `tf.loadGraphModel`
     */
    async loadModel(onProgress) {
        this._model = await tf.loadGraphModel(this.modelUrl, {onProgress});
    }

    /**
     * Collapses per-token labels into one entry per word.
     *
     * Positions whose word id is -1 are skipped (presumably special/padding
     * tokens - confirm against BertTokenizer), as are positions whose word id
     * equals the previously emitted one, so only the first token of a word
     * contributes its label.
     *
     * @param labels label for each of the `inputSize` token positions
     * @param wordIds word id for each token position, -1 for "no word"
     * @param words array indexed by word id; entries expose `text` and `index`
     * @returns {Array<{word, wordIndex, label}>} one entry per newly seen word
     */
    alignLabels(labels, wordIds, words) {
        const result = [];

        for (let i = 0; i < this.inputSize; i++) {
            const wordId = wordIds[i];

            if (wordId === -1 || wordId === this._previousWordId) {
                continue;
            }

            const word = words[wordId];
            result.push({word: word.text, wordIndex: word.index, label: labels[i]});
            this._previousWordId = wordId;
        }

        return result;
    }

    /**
     * Runs the model over `text` chunk by chunk and reports the aligned
     * labels of each chunk through `callback`.
     *
     * @param text text to analyse
     * @param callback invoked once per chunk with the result of `alignLabels`
     * @throws {Error} when called before the tokenizer and model are loaded
     */
    async predict(text, callback) {
        if (!this._tokenizer || !this._model) {
            throw new Error("BertNerModel.predict() called before init()");
        }

        this._previousWordId = null;
        const encoded = this._tokenizer.encodeText(text, this.inputSize);
        const shape = [1, this.inputSize];

        for (const chunk of encoded.inputChunks) {
            // Everything allocated inside tidy (input tensors, logits) is freed
            // when it returns; only the argMax result survives.
            const labelIds = tf.tidy(() => {
                const logits = this._model.execute({
                    input_ids: tf.tensor2d(chunk.inputIds, shape, "int32"),
                    token_type_ids: tf.tensor2d(chunk.segmentIds, shape, "int32"),
                    attention_mask: tf.tensor2d(chunk.inputMask, shape, "int32"),
                });
                return tf.argMax(logits, -1);
            });

            let labelIdsArray;
            try {
                labelIdsArray = await labelIds.array();
            } finally {
                // Release the tensor even if the download fails.
                labelIds.dispose();
            }

            const labels = labelIdsArray[0].map(id => this.id2label[id]);
            callback(this.alignLabels(labels, chunk.wordIds, encoded.words));
        }
    }
}