mirror of
https://github.com/simon987/hexlib.git
synced 2025-04-04 02:12:59 +00:00
Add StatefulStreamProcessor
This commit is contained in:
parent
7349c9a5f1
commit
71cd00c063
@ -1,9 +1,65 @@
|
||||
from queue import Queue, Empty
|
||||
from multiprocessing import Process
|
||||
from multiprocessing import Queue as MPQueue
|
||||
from threading import Thread
|
||||
|
||||
from hexlib.misc import ichunks
|
||||
|
||||
def queue_iter(q: Queue, **get_args):
|
||||
|
||||
class StatefulStreamWorker:
    """Base class for a worker that consumes chunks of items from a queue.

    Subclasses override :meth:`process` to handle a single item and
    :meth:`results` to expose whatever they accumulated.
    """

    def __init__(self):
        pass

    def run(self, q: Queue):
        """Consume chunks from *q* and pass each one to :meth:`process_chunk`.

        Iteration is delegated to ``queue_iter``, which forwards ``timeout=3``
        to ``q.get`` and stops when the queue raises ``Empty``.
        """
        # A plain multiprocessing.Queue has no task_done(); only acknowledge
        # tasks on queues that actually provide it (queue.Queue,
        # multiprocessing.JoinableQueue).
        joinable = hasattr(q, "task_done")
        for chunk in queue_iter(q, joinable=joinable, timeout=3):
            self.process_chunk(chunk)

    def process_chunk(self, chunk):
        """Call :meth:`process` on every item of *chunk*, in order."""
        for item in chunk:
            self.process(item)

    def process(self, item) -> None:
        """Handle one item. Must be overridden by subclasses."""
        raise NotImplementedError

    def results(self):
        """Return the worker's accumulated state. Must be overridden.

        ``StatefulStreamProcessor.get_results`` calls this on every worker.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class StatefulStreamProcessor:
    """Feed an iterable to one or more stateful workers.

    With ``processes > 1`` the items are grouped into chunks and put on a
    multiprocessing queue that is consumed by one child process per worker.
    Otherwise a single worker handles every item in the calling process.
    """

    def __init__(self, worker_factory, chunk_size=128, processes=1):
        """Create the workers and, in multi-process mode, start them.

        :param worker_factory: zero-argument callable returning a new worker
            (an object with ``run``, ``process`` and ``results``).
        :param chunk_size: number of items per queued chunk; also used as
            the ``maxsize`` of the internal queue.
        :param processes: number of worker processes; values of 1 or less
            select in-process operation.
        """
        # Honour the caller's chunk_size instead of a hard-coded 128.
        self._chunk_size = chunk_size
        self._queue = MPQueue(maxsize=chunk_size)
        self._process_count = processes
        self._processes = []
        self._factory = worker_factory
        self._workers = []

        if processes > 1:
            for _ in range(processes):
                worker = self._factory()
                p = Process(target=worker.run, args=(self._queue,))
                p.start()
                self._processes.append(p)
                self._workers.append(worker)
        else:
            self._workers.append(self._factory())

    def injest(self, iterable):
        """Process every item of *iterable*.

        In multi-process mode the items are chunked onto the queue and this
        call blocks until all worker processes have exited. In single-process
        mode each item is handed directly to the only worker.
        """
        if self._process_count > 1:
            for chunk in ichunks(iterable, self._chunk_size):
                self._queue.put(chunk)

            for p in self._processes:
                p.join()
        else:
            worker = self._workers[0]
            for item in iterable:
                worker.process(item)

    # Correctly spelled alias; ``injest`` is kept for existing callers.
    ingest = injest

    def get_results(self):
        """Yield ``worker.results()`` for every worker, in creation order."""
        # NOTE(review): in multi-process mode these are the parent-side worker
        # objects; state mutated inside the child processes is presumably not
        # visible here -- confirm before relying on it.
        for worker in self._workers:
            yield worker.results()
|
||||
|
||||
|
||||
def queue_iter(q: Queue, joinable=True, **get_args):
|
||||
while True:
|
||||
try:
|
||||
task = q.get(**get_args)
|
||||
@ -12,7 +68,8 @@ def queue_iter(q: Queue, **get_args):
|
||||
break
|
||||
|
||||
yield task
|
||||
q.task_done()
|
||||
if joinable:
|
||||
q.task_done()
|
||||
except Empty:
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
|
@ -1,4 +1,5 @@
|
||||
import atexit
|
||||
import itertools
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@ -33,6 +34,15 @@ def chunks(lst: list, chunk_len: int):
|
||||
yield lst[i:i + chunk_len]
|
||||
|
||||
|
||||
def ichunks(iterable, chunk_len: int):
    """Lazily yield consecutive tuples of up to *chunk_len* items.

    Items are pulled from *iterable* on demand. The last tuple may be shorter
    than *chunk_len*, and nothing is yielded once the input is exhausted.
    """
    source = iter(iterable)

    def take():
        return tuple(itertools.islice(source, chunk_len))

    # Two-argument iter() keeps calling take() until it returns an empty tuple.
    yield from iter(take, ())
|
||||
|
||||
|
||||
def rate_limit(per_second):
|
||||
min_interval = 1.0 / float(per_second)
|
||||
|
||||
|
@ -60,7 +60,7 @@ SINGLE_QUOTES = ("’", "`")
|
||||
SINGLE_QUOTE_TRANS = str.maketrans("".join(SINGLE_QUOTES), "".join(repeat("'", len(SINGLE_QUOTES))))
|
||||
|
||||
# Punctuation characters that PUNCTUATION_TRANS replaces with spaces.
PUNCTUATION = ".,;:\"!?/()|*=>"
# Translation table mapping each punctuation character to a single space.
# str.maketrans needs a replacement *string* of the same length as the first
# argument; passing the bare length (an int) raises TypeError.
PUNCTUATION_TRANS = str.maketrans(PUNCTUATION, " " * len(PUNCTUATION))
|
||||
|
||||
|
||||
def preprocess(text, lowercase=False, clean_html=False, remove_punctuation=False, remove_stopwords_en=False,
|
||||
|
Loading…
x
Reference in New Issue
Block a user