Add NER support
sist2-vue/src/ml/BertNerModel.js | 77 lines (new file)
@@ -0,0 +1,77 @@
import BertTokenizer from "@/ml/BertTokenizer";
import * as tf from "@tensorflow/tfjs";
import axios from "axios";

export default class BertNerModel {
    vocabUrl;
    modelUrl;

    id2label;
    _tokenizer;
    _model;
    inputSize = 128;

    _previousWordId = null;

    constructor(vocabUrl, modelUrl, id2label) {
        this.vocabUrl = vocabUrl;
        this.modelUrl = modelUrl;
        this.id2label = id2label;
    }

    // Load the tokenizer vocabulary and the TF.js graph model in parallel.
    async init(onProgress) {
        await Promise.all([this.loadTokenizer(), this.loadModel(onProgress)]);
    }

    async loadTokenizer() {
        const vocab = (await axios.get(this.vocabUrl)).data;
        this._tokenizer = new BertTokenizer(vocab);
    }

    async loadModel(onProgress) {
        this._model = await tf.loadGraphModel(this.modelUrl, {onProgress});
    }

    // Map token-level labels back to word-level results: special tokens
    // (wordId === -1) are skipped, and only the first sub-word token of each
    // word keeps its label.
    alignLabels(labels, wordIds, words) {
        const result = [];

        for (let i = 0; i < this.inputSize; i++) {
            const label = labels[i];
            const wordId = wordIds[i];

            if (wordId === -1) {
                continue;
            }
            if (wordId === this._previousWordId) {
                continue;
            }

            result.push({
                word: words[wordId].text,
                wordIndex: words[wordId].index,
                label: label
            });
            this._previousWordId = wordId;
        }

        return result;
    }

    // Tokenize the text into fixed-size chunks, run the model on each chunk,
    // and pass word-aligned labels to the callback as they become available.
    async predict(text, callback) {
        this._previousWordId = null;
        const encoded = this._tokenizer.encodeText(text, this.inputSize);

        for (let chunk of encoded.inputChunks) {
            const rawResult = tf.tidy(() => this._model.execute({
                input_ids: tf.tensor2d(chunk.inputIds, [1, this.inputSize], "int32"),
                token_type_ids: tf.tensor2d(chunk.segmentIds, [1, this.inputSize], "int32"),
                attention_mask: tf.tensor2d(chunk.inputMask, [1, this.inputSize], "int32"),
            }));

            // argMax over the last dimension gives one label id per token.
            const labelIds = tf.argMax(rawResult, -1);
            const labelIdsArray = await labelIds.array();
            const labels = labelIdsArray[0].map(id => this.id2label[id]);
            rawResult.dispose();
            labelIds.dispose();

            callback(this.alignLabels(labels, chunk.wordIds, encoded.words));
        }
    }
}
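For context, a minimal sketch of how this class might be wired up from the frontend. The asset URLs and the id2label mapping below are illustrative placeholders (a BIO-style label set), not values taken from this commit; only the constructor, init, and predict signatures come from the file above.

import BertNerModel from "@/ml/BertNerModel";

// Hypothetical label mapping and asset paths, for illustration only.
const id2label = {
    0: "O",
    1: "B-PER", 2: "I-PER",
    3: "B-ORG", 4: "I-ORG",
    5: "B-LOC", 6: "I-LOC",
};

async function runNerDemo() {
    const model = new BertNerModel(
        "/ml/bert/vocab.json",   // assumed vocabulary location
        "/ml/bert/model.json",   // assumed TF.js graph model location
        id2label
    );

    await model.init(p => console.log(`loading model: ${Math.round(p * 100)}%`));

    // predict() invokes the callback once per 128-token chunk with
    // word-aligned {word, wordIndex, label} entries.
    await model.predict("Jane Doe works at Acme Corp in Berlin.", entities => {
        entities
            .filter(e => e.label !== "O")
            .forEach(e => console.log(e.wordIndex, e.word, e.label));
    });
}

Passing a callback instead of returning a single array lets long documents surface entities incrementally: the text is split into fixed-size chunks and each chunk is executed and reported separately.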