diff --git a/hexlib/text.py b/hexlib/text.py index af6d3b9..77688f9 100644 --- a/hexlib/text.py +++ b/hexlib/text.py @@ -43,9 +43,20 @@ def _transform_bigram(ngram_seq, ngrams): yield ngram[0] +def _transform_trigram(ngram_seq, ngrams): + for ngram in ngram_seq: + if ngram in ngrams: + yield "_".join(ngram) + + ngram_seq.__next__() + ngram_seq.__next__() + else: + yield ngram[0] + + def preprocess(text, lowercase=False, clean_html=False, strip=False, remove_punctuation=False, remove_stopwords_en=False, lemmatize=False, fix_single_quotes=False, strip_quotes=False, - remove_urls=False, bigrams: set = None, remove_numbers=False): + remove_urls=False, bigrams: set = None, trigrams: set = None, remove_numbers=False): if lowercase: text = text.lower() @@ -81,6 +92,12 @@ def preprocess(text, lowercase=False, clean_html=False, strip=False, remove_punc words.append("*") text = " ".join(_transform_bigram(nltk.bigrams(words), bigrams)) + if trigrams: + words = text.split(" ") + words.append("*") + words.append("*") + text = " ".join(_transform_trigram(nltk.trigrams(words), trigrams)) + if remove_stopwords_en or lemmatize or remove_numbers: words = text.split(" ") diff --git a/setup.py b/setup.py index 9e62f82..7f1cf17 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup setup( name="hexlib", - version="1.49", + version="1.50", description="Misc utility methods", author="simon987", author_email="me@simon987.net", diff --git a/test/test_text.py b/test/test_text.py index 503f628..54546bd 100644 --- a/test/test_text.py +++ b/test/test_text.py @@ -245,6 +245,20 @@ class TestText(TestCase): self.assertEqual(cleaned, expected) + def test_trigrams(self): + text = "x A b c d e f g h" + cleaned = preprocess( + text, + lowercase=True, + trigrams={ + ("a", "b", "c"), + ("e", "f", "g"), + } + ) + expected = "x a_b_c d e_f_g h" + + self.assertEqual(cleaned, expected) + def test_remove_numbers(self): text = "Hello1 test1124test 12 1 1111111 world" cleaned = preprocess(