Add NER support
77
sist2-vue/src/ml/BertNerModel.js
Normal file
@@ -0,0 +1,77 @@
import BertTokenizer from "@/ml/BertTokenizer";
import * as tf from "@tensorflow/tfjs";
import axios from "axios";

export default class BertNerModel {
    vocabUrl;
    modelUrl;

    id2label;
    _tokenizer;
    _model;
    inputSize = 128;

    _previousWordId = null;

    constructor(vocabUrl, modelUrl, id2label) {
        this.vocabUrl = vocabUrl;
        this.modelUrl = modelUrl;
        this.id2label = id2label;
    }

    async init(onProgress) {
        await Promise.all([this.loadTokenizer(), this.loadModel(onProgress)]);
    }

    async loadTokenizer() {
        const vocab = (await axios.get(this.vocabUrl)).data;
        this._tokenizer = new BertTokenizer(vocab);
    }

    async loadModel(onProgress) {
        this._model = await tf.loadGraphModel(this.modelUrl, {onProgress});
    }

    // Maps one label per word: keeps the label of a word's first sub-token,
    // skipping special/padding positions (wordId === -1) and continuation
    // sub-tokens of a word already seen. _previousWordId is an instance field
    // (reset in predict) so a word split across chunks isn't emitted twice.
    alignLabels(labels, wordIds, words) {
        const result = [];

        for (let i = 0; i < this.inputSize; i++) {
            const label = labels[i];
            const wordId = wordIds[i];

            if (wordId === -1) {
                continue;
            }
            if (wordId === this._previousWordId) {
                continue;
            }

            result.push({
                word: words[wordId].text, wordIndex: words[wordId].index, label: label
            });
            this._previousWordId = wordId;
        }

        return result;
    }

    async predict(text, callback) {
        this._previousWordId = null;
        const encoded = this._tokenizer.encodeText(text, this.inputSize);

        for (const chunk of encoded.inputChunks) {
            const rawResult = tf.tidy(() => this._model.execute({
                input_ids: tf.tensor2d(chunk.inputIds, [1, this.inputSize], "int32"),
                token_type_ids: tf.tensor2d(chunk.segmentIds, [1, this.inputSize], "int32"),
                attention_mask: tf.tensor2d(chunk.inputMask, [1, this.inputSize], "int32"),
            }));

            // argMax is synchronous and returns a tensor; only reading the
            // values back with .array() is asynchronous.
            const labelIds = tf.argMax(rawResult, -1);
            const labelIdsArray = await labelIds.array();
            const labels = labelIdsArray[0].map(id => this.id2label[id]);
            rawResult.dispose();
            labelIds.dispose();

            callback(this.alignLabels(labels, chunk.wordIds, encoded.words));
        }
    }
}
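For context, a minimal usage sketch of the class above. The URLs and the id2label mapping are hypothetical placeholders, not values shipped with this commit; in the application they come from the model repository described in modelsRepo.js below.

import BertNerModel from "@/ml/BertNerModel";

// Hypothetical model location and label map, for illustration only.
const model = new BertNerModel(
    "https://example.com/models/bert-ner/vocab.json",
    "https://example.com/models/bert-ner/model.json",
    {0: "O", 1: "B-PER", 2: "I-PER", 3: "B-LOC", 4: "I-LOC"}
);

await model.init(p => console.log(`model download: ${Math.round(p * 100)}%`));

// predict() invokes the callback once per 128-token chunk, passing
// [{word, wordIndex, label}, ...] for the first sub-token of each word.
await model.predict("Jane Doe visited Montreal.", entities => {
    entities.filter(e => e.label !== "O").forEach(e => console.log(e));
});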
184
sist2-vue/src/ml/BertTokenizer.js
Normal file
@@ -0,0 +1,184 @@
import {zip, chunk} from "underscore";

const UNK_INDEX = 100;
const CLS_INDEX = 101;
const SEP_INDEX = 102;
const CONTINUING_SUBWORD_PREFIX = "##";

function isWhitespace(ch) {
    return /\s/.test(ch);
}

function isInvalid(ch) {
    return (ch.charCodeAt(0) === 0 || ch.charCodeAt(0) === 0xfffd);
}

const punctuations = '[~`!@#$%^&*(){}[];:"\'<,.>?/\\|-_+=';

/** Returns true if the character is a punctuation mark. */
function isPunctuation(ch) {
    return punctuations.indexOf(ch) !== -1;
}

export default class BertTokenizer {
    vocab;

    constructor(vocab) {
        this.vocab = vocab;
    }

    // Splits text into lower-cased word objects ({text, index}), where
    // index is the character offset of the word in the original text.
    tokenize(text) {
        const charOriginalIndex = [];
        const cleanedText = this.cleanText(text, charOriginalIndex);
        const origTokens = cleanedText.split(' ');

        let charCount = 0;
        const tokens = origTokens.map((token) => {
            token = token.toLowerCase();
            const tokens = this.runSplitOnPunctuation(token, charCount, charOriginalIndex);
            charCount += token.length + 1;
            return tokens;
        });

        let flattenTokens = [];
        for (let index = 0; index < tokens.length; index++) {
            flattenTokens = flattenTokens.concat(tokens[index]);
        }
        return flattenTokens;
    }

    /* Performs invalid character removal and whitespace cleanup on text,
       recording the original offset of every kept character in
       charOriginalIndex. */
    cleanText(text, charOriginalIndex) {
        text = text.replace(/\?/g, "").trim();

        const stringBuilder = [];
        let originalCharIndex = 0;
        let newCharIndex = 0;

        for (const ch of text) {
            // Skip the characters that cannot be used.
            if (isInvalid(ch)) {
                originalCharIndex += ch.length;
                continue;
            }
            if (isWhitespace(ch)) {
                // Collapse runs of whitespace into a single space.
                if (stringBuilder.length > 0 && stringBuilder[stringBuilder.length - 1] !== ' ') {
                    stringBuilder.push(' ');
                    charOriginalIndex[newCharIndex] = originalCharIndex;
                    originalCharIndex += ch.length;
                } else {
                    originalCharIndex += ch.length;
                    continue;
                }
            } else {
                stringBuilder.push(ch);
                charOriginalIndex[newCharIndex] = originalCharIndex;
                originalCharIndex += ch.length;
            }
            newCharIndex++;
        }
        return stringBuilder.join('');
    }

    /* Splits punctuation on a piece of text. */
    runSplitOnPunctuation(text, count, charOriginalIndex) {
        const tokens = [];
        let startNewWord = true;
        for (const ch of text) {
            if (isPunctuation(ch)) {
                tokens.push({text: ch, index: charOriginalIndex[count]});
                count += ch.length;
                startNewWord = true;
            } else {
                if (startNewWord) {
                    tokens.push({text: '', index: charOriginalIndex[count]});
                    startNewWord = false;
                }
                tokens[tokens.length - 1].text += ch;
                count += ch.length;
            }
        }
        return tokens;
    }

    // Greedy longest-match-first WordPiece: maps each word to one or more
    // vocabulary ids and records the word index for every emitted token.
    encode(words) {
        const outputTokens = [];
        const wordIds = [];

        for (let i = 0; i < words.length; i++) {
            const chars = [...words[i].text];

            let isUnknown = false;
            let start = 0;
            const subTokens = [];

            while (start < chars.length) {
                let end = chars.length;
                let currentSubstring = null;
                while (start < end) {
                    let substr = chars.slice(start, end).join('');

                    if (start > 0) {
                        substr = CONTINUING_SUBWORD_PREFIX + substr;
                    }
                    if (this.vocab.includes(substr)) {
                        currentSubstring = this.vocab.indexOf(substr);
                        break;
                    }

                    --end;
                }
                if (currentSubstring == null) {
                    isUnknown = true;
                    break;
                }
                subTokens.push(currentSubstring);
                start = end;
            }

            if (isUnknown) {
                outputTokens.push(UNK_INDEX);
                wordIds.push(i);
            } else {
                subTokens.forEach(tok => {
                    outputTokens.push(tok);
                    wordIds.push(i);
                });
            }
        }

        return {tokens: outputTokens, wordIds};
    }

    // Encodes text into fixed-size model inputs, reserving two positions
    // per chunk for the [CLS] and [SEP] special tokens.
    encodeText(inputText, inputSize) {
        const tokenized = this.tokenize(inputText);
        const encoded = this.encode(tokenized);

        const encodedTokenChunks = chunk(encoded.tokens, inputSize - 2);
        const encodedWordIdChunks = chunk(encoded.wordIds, inputSize - 2);

        const chunks = [];

        zip(encodedTokenChunks, encodedWordIdChunks).forEach(([tokens, wordIds]) => {
            const inputIds = [CLS_INDEX, ...tokens, SEP_INDEX];
            const segmentIds = Array(inputIds.length).fill(0);
            const inputMask = Array(inputIds.length).fill(1);
            wordIds = [-1, ...wordIds, -1];

            // Pad all four arrays to the fixed input size.
            while (inputIds.length < inputSize) {
                inputIds.push(0);
                inputMask.push(0);
                segmentIds.push(0);
                wordIds.push(-1);
            }

            chunks.push({inputIds, inputMask, segmentIds, wordIds});
        });

        return {
            inputChunks: chunks,
            words: tokenized
        };
    }
}
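To make the tokenizer's output concrete, a small sketch with a toy vocabulary (real BERT vocabularies have ~30k entries; the only requirement here is that ids 100-102 line up with the UNK_INDEX, CLS_INDEX and SEP_INDEX constants above):

import BertTokenizer from "@/ml/BertTokenizer";

// Toy vocabulary for illustration only.
const vocab = new Array(100).fill("[unused]");
vocab.push("[UNK]", "[CLS]", "[SEP]");   // ids 100, 101, 102
vocab.push("hello", "wo", "##rld");      // ids 103, 104, 105

const tokenizer = new BertTokenizer(vocab);
const encoded = tokenizer.encodeText("Hello world", 8);

// Single chunk, padded to inputSize = 8:
//   inputIds  = [101, 103, 104, 105, 102, 0, 0, 0]   ([CLS] hello wo ##rld [SEP])
//   inputMask = [1, 1, 1, 1, 1, 0, 0, 0]
//   wordIds   = [-1, 0, 1, 1, -1, -1, -1, -1]        (both sub-tokens map to word 1)
console.log(encoded.inputChunks[0]);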
43
sist2-vue/src/ml/modelsRepo.js
Normal file
@@ -0,0 +1,43 @@
import axios from "axios";

class ModelsRepo {
    _repositories;
    data = {};

    async init(repositories) {
        this._repositories = repositories;

        const data = await Promise.all(this._repositories.map(this._loadRepository));

        data.forEach(models => {
            models.forEach(model => {
                this.data[model.name] = model;
            });
        });
    }

    async _loadRepository(repository) {
        const data = (await axios.get(repository)).data;
        data.forEach(model => {
            // Resolve the relative model/vocab paths against the repository URL.
            model["modelUrl"] = new URL(model["modelPath"], repository).href;
            model["vocabUrl"] = new URL(model["vocabPath"], repository).href;
        });
        return data;
    }

    getOptions() {
        return Object.values(this.data).map(model => ({
            text: `${model.name} (${Math.round(model.size / (1024 * 1024))}MB)`,
            value: model.name
        }));
    }

    getDefaultModel() {
        if (Object.values(this.data).length === 0) {
            return null;
        }
        // Fall back to null if no model is flagged as default.
        return Object.values(this.data).find(model => model.default)?.name ?? null;
    }
}

export default new ModelsRepo();
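Finally, a sketch of the repository format this class expects, inferred from _loadRepository above (the URL and the model entry are hypothetical):

import modelsRepo from "@/ml/modelsRepo";

// A repository URL is expected to serve a JSON array of model descriptors,
// e.g. (hypothetical entry, fields inferred from the code):
//   [{"name": "bert-ner-en",
//     "modelPath": "bert-ner-en/model.json",
//     "vocabPath": "bert-ner-en/vocab.json",
//     "size": 40000000,
//     "default": true}]
await modelsRepo.init(["https://example.com/models/repositories.json"]);

// modelPath/vocabPath are resolved relative to the repository URL, so the
// entry above yields https://example.com/models/bert-ner-en/model.json
console.log(modelsRepo.getOptions());      // [{text: "bert-ner-en (38MB)", value: "bert-ner-en"}]
console.log(modelsRepo.getDefaultModel()); // "bert-ner-en"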