Unbreak StatefulStreamWorker: return worker results through an output queue

This commit is contained in:
simon987 2021-09-23 19:14:11 -04:00
parent c560cc2010
commit 9bd1f4b799
2 changed files with 12 additions and 9 deletions

View File

@ -1,6 +1,6 @@
from queue import Queue, Empty
from multiprocessing import Process from multiprocessing import Process
from multiprocessing import Queue as MPQueue from multiprocessing import Queue as MPQueue
from queue import Queue, Empty
from threading import Thread from threading import Thread
from hexlib.misc import ichunks from hexlib.misc import ichunks
@ -11,10 +11,12 @@ class StatefulStreamWorker:
def __init__(self): def __init__(self):
pass pass
def run(self, q: Queue): def run(self, q: Queue, q_out: Queue):
for chunk in queue_iter(q, joinable=False, timeout=3): for chunk in queue_iter(q, joinable=False, timeout=3):
self._process_chunk(chunk) self._process_chunk(chunk)
q_out.put(self.results())
def _process_chunk(self, chunk): def _process_chunk(self, chunk):
for item in chunk: for item in chunk:
self.process(item) self.process(item)
@ -30,6 +32,7 @@ class StatefulStreamProcessor:
def __init__(self, worker_factory, chunk_size=128, processes=1): def __init__(self, worker_factory, chunk_size=128, processes=1):
self._chunk_size = 128 self._chunk_size = 128
self._queue = MPQueue(maxsize=chunk_size) self._queue = MPQueue(maxsize=chunk_size)
self._queue_out = MPQueue()
self._process_count = processes self._process_count = processes
self._processes = [] self._processes = []
self._factory = worker_factory self._factory = worker_factory
@ -38,8 +41,9 @@ class StatefulStreamProcessor:
if processes > 1: if processes > 1:
for _ in range(processes): for _ in range(processes):
worker = self._factory() worker = self._factory()
p = Process(target=worker.run, args=(self._queue,)) p = Process(target=worker.run, args=(self._queue, self._queue_out))
p.start() p.start()
self._processes.append(p) self._processes.append(p)
self._workers.append(worker) self._workers.append(worker)
else: else:
@ -50,16 +54,15 @@ class StatefulStreamProcessor:
if self._process_count > 1: if self._process_count > 1:
for chunk in ichunks(iterable, self._chunk_size): for chunk in ichunks(iterable, self._chunk_size):
self._queue.put(chunk) self._queue.put(chunk)
for p in self._processes:
p.join()
else: else:
for item in iterable: for item in iterable:
self._workers[0].process(item) self._workers[0].process(item)
def get_results(self): def get_results(self):
for worker in self._workers: for _ in range(self._process_count):
yield worker.results() yield self._queue_out.get()
for p in self._processes:
p.join()
def queue_iter(q: Queue, joinable=True, **get_args): def queue_iter(q: Queue, joinable=True, **get_args):

View File

@ -2,7 +2,7 @@ from setuptools import setup
setup( setup(
name="hexlib", name="hexlib",
version="1.55", version="1.56",
description="Misc utility methods", description="Misc utility methods",
author="simon987", author="simon987",
author_email="me@simon987.net", author_email="me@simon987.net",