diff --git a/hexlib/concurrency.py b/hexlib/concurrency.py index d4d851e..34d3219 100644 --- a/hexlib/concurrency.py +++ b/hexlib/concurrency.py @@ -1,9 +1,65 @@ from queue import Queue, Empty from multiprocessing import Process +from multiprocessing import Queue as MPQueue from threading import Thread +from hexlib.misc import ichunks -def queue_iter(q: Queue, **get_args): + +class StatefulStreamWorker: + + def __init__(self): + pass + + def run(self, q: Queue): + for chunk in queue_iter(q, timeout=3): + self.process_chunk(chunk) + + def process_chunk(self, chunk): + for item in chunk: + self.process(item) + + def process(self, item) -> None: + raise NotImplementedError + + +class StatefulStreamProcessor: + def __init__(self, worker_factory, chunk_size=128, processes=1): + self._chunk_size = 128 + self._queue = MPQueue(maxsize=chunk_size) + self._process_count = processes + self._processes = [] + self._factory = worker_factory + self._workers = [] + + if processes > 1: + for _ in range(processes): + worker = self._factory() + p = Process(target=worker.run, args=(self._queue,)) + p.start() + self._processes.append(p) + self._workers.append(worker) + else: + self._workers.append(self._factory()) + + def injest(self, iterable): + + if self._process_count > 1: + for chunk in ichunks(iterable, self._chunk_size): + self._queue.put(chunk) + + for p in self._processes: + p.join() + else: + for item in iterable: + self._workers[0].process(item) + + def get_results(self): + for worker in self._workers: + yield worker.results() + + +def queue_iter(q: Queue, joinable=True, **get_args): while True: try: task = q.get(**get_args) @@ -12,7 +68,8 @@ def queue_iter(q: Queue, **get_args): break yield task - q.task_done() + if joinable: + q.task_done() except Empty: break except KeyboardInterrupt: diff --git a/hexlib/misc.py b/hexlib/misc.py index 91cbc04..0b341cd 100644 --- a/hexlib/misc.py +++ b/hexlib/misc.py @@ -1,4 +1,5 @@ import atexit +import itertools import os import sys import time @@ -33,6 +34,15 @@ def chunks(lst: list, chunk_len: int): yield lst[i:i + chunk_len] +def ichunks(iterable, chunk_len: int): + it = iter(iterable) + while True: + chunk = tuple(itertools.islice(it, chunk_len)) + if not chunk: + break + yield chunk + + def rate_limit(per_second): min_interval = 1.0 / float(per_second) diff --git a/hexlib/text.py b/hexlib/text.py index 16d5b5f..86009ba 100644 --- a/hexlib/text.py +++ b/hexlib/text.py @@ -60,7 +60,7 @@ SINGLE_QUOTES = ("’", "`") SINGLE_QUOTE_TRANS = str.maketrans("".join(SINGLE_QUOTES), "".join(repeat("'", len(SINGLE_QUOTES)))) PUNCTUATION = ".,;:\"!?/()|*=>" -PUNCTUATION_TRANS = str.maketrans(PUNCTUATION, len(PUNCTUATION)) +PUNCTUATION_TRANS = str.maketrans(PUNCTUATION, " " * len(PUNCTUATION)) def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False, remove_stopwords_en=False, diff --git a/setup.py b/setup.py index 6465cb2..56d3d27 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup setup( name="hexlib", - version="1.52", + version="1.53", description="Misc utility methods", author="simon987", author_email="me@simon987.net",