sist2/sist2-vue/src/ml/BertTokenizer.js
2023-04-23 12:53:27 -04:00

184 lines
5.5 KiB
JavaScript

import {zip, chunk} from "underscore";
// Token id that BertTokenizer.encode emits for a word it cannot fully match
// against the vocabulary.
// NOTE(review): presumably the position of the "[UNK]" entry in the vocabulary
// file in use (likewise "[CLS]"/"[SEP]" below) — confirm against that file.
const UNK_INDEX = 100;
// Token id that BertTokenizer.encodeText places at the start of every chunk.
const CLS_INDEX = 101;
// Token id that BertTokenizer.encodeText places right after the real tokens of
// every chunk (before any zero padding).
const SEP_INDEX = 102;
// Prefix prepended to a candidate sub-word that does not start at the beginning
// of its word, before it is looked up in the vocabulary.
const CONTINUING_SUBWORD_PREFIX = "##";
/**
 * Tells whether `ch` contains a character matched by the regex class `\s`.
 *
 * @param {string} ch - Character to inspect (callers pass a single character).
 * @returns {boolean} `true` when a whitespace character is found.
 */
function isWhitespace(ch) {
    const whitespaceMatch = /\s/.exec(ch);
    return whitespaceMatch !== null;
}
/**
 * Tells whether `ch` starts with a code unit the tokenizer discards:
 * NUL (0x0000) or the Unicode replacement character (0xFFFD).
 *
 * @param {string} ch - Character to inspect; only its first UTF-16 code unit is read.
 * @returns {boolean} `true` when the character must be skipped.
 */
function isInvalid(ch) {
    const firstCodeUnit = ch.charCodeAt(0);
    if (firstCodeUnit === 0) {
        return true;
    }
    return firstCodeUnit === 0xfffd;
}
// Characters that are split off into tokens of their own.
const punctuations = '[~`!@#$%^&*(){}[];:"\'<,.>?/\\|-_+=';

/**
 * Tells whether `ch` occurs in the `punctuations` list above.
 *
 * @param {string} ch - Character to inspect (callers pass a single character).
 * @returns {boolean} `true` when `ch` is found in the list.
 */
function isPunctuation(ch) {
    const isListed = punctuations.includes(ch);
    return isListed;
}
/**
 * Tokenizer that cleans and lower-cases text, splits it on whitespace and
 * punctuation, and encodes every word as vocabulary ids using a greedy
 * longest-match sub-word search.
 *
 * All character offsets handled by this class are UTF-16 code-unit offsets,
 * i.e. they can be used directly with `String.prototype.slice` on the text
 * that was passed in.
 */
export default class BertTokenizer {
    /** Vocabulary list (expected: array of strings); a token's id is its position in it. */
    vocab;

    /**
     * Cache created by `getVocabLookup()`; `null` until the first `encode()`.
     * Internal, not meant to be touched by callers.
     */
    vocabLookup = null;

    /**
     * @param {string[]} vocab - Vocabulary list; ids returned by `encode` are
     *     positions in this list.
     */
    constructor(vocab) {
        this.vocab = vocab;
    }

    /**
     * Splits `text` into lower-cased word and punctuation tokens.
     *
     * @param {string} text - Raw input text.
     * @returns {{text: string, index: number}[]} Tokens in reading order;
     *     `index` is the offset in `text` at which the token starts.
     */
    tokenize(text) {
        const charOriginalIndex = [];
        const cleanedText = this.cleanText(text, charOriginalIndex);
        const flattenTokens = [];
        let charCount = 0;
        for (const token of cleanedText.split(' ')) {
            const pieces = this.runSplitOnPunctuation(token.toLowerCase(), charCount, charOriginalIndex);
            for (const piece of pieces) {
                flattenTokens.push(piece);
            }
            // Advance by the length of the token as it appears in the cleaned
            // text: lower-casing can change the length (e.g. U+0130), which
            // would shift the offsets of every following token.
            charCount += token.length + 1;
        }
        return flattenTokens;
    }

    /**
     * Performs invalid character removal and whitespace cleanup on text.
     *
     * Question marks, NUL and U+FFFD are dropped, leading/trailing whitespace
     * is removed and every inner whitespace run becomes a single space.
     *
     * @param {string} text - Raw input text.
     * @param {number[]} charOriginalIndex - Output array. For each character
     *     kept, the entry at the character's offset in the returned string is
     *     set to the character's offset in `text`.
     * @returns {string} The cleaned text.
     */
    cleanText(text, charOriginalIndex) {
        const stringBuilder = [];
        let originalCharIndex = 0;
        // Offset in the cleaned string, counted in UTF-16 code units because
        // that is how tokenize()/runSplitOnPunctuation() read the map.
        let cleanedLength = 0;
        for (const ch of text) {
            const charStart = originalCharIndex;
            originalCharIndex += ch.length;
            // Dropped characters still advance the original offset, so the
            // offsets recorded below keep pointing into the caller's text.
            if (ch === '?' || isInvalid(ch)) {
                continue;
            }
            if (isWhitespace(ch)) {
                const atStart = stringBuilder.length === 0;
                if (atStart || stringBuilder[stringBuilder.length - 1] === ' ') {
                    continue;
                }
                stringBuilder.push(' ');
                charOriginalIndex[cleanedLength] = charStart;
                cleanedLength += 1;
            } else {
                stringBuilder.push(ch);
                charOriginalIndex[cleanedLength] = charStart;
                cleanedLength += ch.length;
            }
        }
        // A whitespace run at the end leaves one trailing space behind.
        if (stringBuilder.length > 0 && stringBuilder[stringBuilder.length - 1] === ' ') {
            stringBuilder.pop();
        }
        return stringBuilder.join('');
    }

    /**
     * Splits punctuation on a piece of text.
     *
     * Every punctuation character becomes a token of its own; the characters
     * between punctuation are grouped into word tokens.
     *
     * @param {string} text - One whitespace-free piece of the cleaned text.
     * @param {number} count - Offset of `text` within the cleaned text.
     * @param {number[]} charOriginalIndex - Offset map filled by `cleanText`.
     * @returns {{text: string, index: number}[]} Tokens found in `text`.
     */
    runSplitOnPunctuation(text, count, charOriginalIndex) {
        const tokens = [];
        let startNewWord = true;
        for (const ch of text) {
            if (isPunctuation(ch)) {
                tokens.push({text: ch, index: charOriginalIndex[count]});
                startNewWord = true;
            } else {
                if (startNewWord) {
                    tokens.push({text: '', index: charOriginalIndex[count]});
                    startNewWord = false;
                }
                tokens[tokens.length - 1].text += ch;
            }
            count += ch.length;
        }
        return tokens;
    }

    /**
     * Returns the lookup structure for the current `vocab`, building it on
     * first use. Internal helper of `encode`.
     *
     * The cache is rebuilt when `vocab` is replaced by another list or when
     * its length changes; replacing entries in place without changing the
     * length is not detected.
     *
     * @returns {{ids: Map<string, number>, maxLength: number}} `ids` maps each
     *     entry to its first position in `vocab`; `maxLength` is the length of
     *     the longest entry.
     */
    getVocabLookup() {
        const vocab = this.vocab;
        const cached = this.vocabLookup;
        if (cached !== null && cached.source === vocab && cached.size === vocab.length) {
            return cached;
        }
        const ids = new Map();
        let maxLength = 0;
        for (let id = 0; id < vocab.length; id++) {
            const entry = vocab[id];
            if (typeof entry !== 'string') {
                continue;
            }
            // Keep the first position of duplicated entries, like indexOf().
            if (!ids.has(entry)) {
                ids.set(entry, id);
            }
            if (entry.length > maxLength) {
                maxLength = entry.length;
            }
        }
        this.vocabLookup = {source: vocab, size: vocab.length, ids, maxLength};
        return this.vocabLookup;
    }

    /**
     * Converts tokens to vocabulary ids.
     *
     * Each word is consumed left to right, always taking the longest prefix
     * found in the vocabulary; pieces that do not start the word are looked up
     * with the "##" prefix. A word that cannot be consumed completely yields a
     * single UNK_INDEX.
     *
     * @param {{text: string}[]} words - Tokens as returned by `tokenize`.
     * @returns {{tokens: number[], wordIds: number[]}} `tokens[k]` is a
     *     vocabulary id and `wordIds[k]` the position in `words` of the word
     *     it was produced from.
     */
    encode(words) {
        const {ids, maxLength} = this.getVocabLookup();
        const outputTokens = [];
        const wordIds = [];
        for (let i = 0; i < words.length; i++) {
            const chars = [...words[i].text];
            const subTokens = [];
            let isUnknown = false;
            let start = 0;
            while (start < chars.length) {
                // A candidate with more characters than the longest vocabulary
                // entry can never match, so longer ones are not tried.
                let end = Math.min(chars.length, start + maxLength);
                let matchedId;
                while (start < end) {
                    let substr = chars.slice(start, end).join('');
                    if (start > 0) {
                        substr = `${CONTINUING_SUBWORD_PREFIX}${substr}`;
                    }
                    matchedId = ids.get(substr);
                    if (matchedId !== undefined) {
                        break;
                    }
                    --end;
                }
                if (matchedId === undefined) {
                    isUnknown = true;
                    break;
                }
                subTokens.push(matchedId);
                start = end;
            }
            if (isUnknown) {
                outputTokens.push(UNK_INDEX);
                wordIds.push(i);
            } else {
                for (const tok of subTokens) {
                    outputTokens.push(tok);
                    wordIds.push(i);
                }
            }
        }
        return {tokens: outputTokens, wordIds};
    }

    /**
     * Tokenizes and encodes `inputText`, then cuts the result into chunks of
     * exactly `inputSize` entries.
     *
     * Every chunk holds up to `inputSize - 2` token ids framed by CLS_INDEX
     * and SEP_INDEX and is padded with zeros up to `inputSize`.
     *
     * @param {string} inputText - Raw input text.
     * @param {number} inputSize - Length of every produced array.
     * @returns {{inputChunks: {inputIds: number[], inputMask: number[],
     *     segmentIds: number[], wordIds: number[]}[], words: {text: string,
     *     index: number}[]}} `inputMask` is 1 for CLS/token/SEP positions and
     *     0 for padding; `segmentIds` is all zeros; `wordIds` gives for every
     *     position the index into `words`, or -1 for CLS, SEP and padding.
     */
    encodeText(inputText, inputSize) {
        const tokenized = this.tokenize(inputText);
        const encoded = this.encode(tokenized);
        // Two slots of every chunk are reserved for CLS_INDEX and SEP_INDEX.
        const chunkSize = inputSize - 2;
        const encodedTokenChunks = chunk(encoded.tokens, chunkSize);
        const encodedWordIdChunks = chunk(encoded.wordIds, chunkSize);
        const chunks = [];
        zip(encodedTokenChunks, encodedWordIdChunks).forEach(([tokens, chunkWordIds]) => {
            const inputIds = [CLS_INDEX, ...tokens, SEP_INDEX];
            const segmentIds = Array(inputIds.length).fill(0);
            const inputMask = Array(inputIds.length).fill(1);
            const wordIds = [-1, ...chunkWordIds, -1];
            while (inputIds.length < inputSize) {
                inputIds.push(0);
                inputMask.push(0);
                segmentIds.push(0);
                wordIds.push(-1);
            }
            chunks.push({inputIds, inputMask, segmentIds, wordIds});
        });
        return {
            inputChunks: chunks,
            words: tokenized
        };
    }
}