From f17f924e9d88998fe10e2587c532c9e9f4ceea98 Mon Sep 17 00:00:00 2001 From: barrierye Date: Mon, 6 Jul 2020 17:14:50 +0800 Subject: [PATCH] fix bug in dag executor --- python/pipeline/dag.py | 43 ++++++++++++++++++++++--------------- python/pipeline/operator.py | 20 +++++++++++------ python/pipeline/profiler.py | 5 ++++- 3 files changed, 43 insertions(+), 25 deletions(-) diff --git a/python/pipeline/dag.py b/python/pipeline/dag.py index 15c3317f..b1f853fd 100644 --- a/python/pipeline/dag.py +++ b/python/pipeline/dag.py @@ -15,6 +15,7 @@ import threading import multiprocessing import sys +import copy if sys.version_info.major == 2: import Queue elif sys.version_info.major == 3: @@ -58,8 +59,8 @@ class DAGExecutor(object): self._profiler = TimeProfiler() self._profiler.enable(use_profile) - self._dag = DAG(response_op, self._profiler, use_profile, - use_multithread, client_type, channel_size) + self._dag = DAG(response_op, use_profile, use_multithread, client_type, + channel_size) (in_channel, out_channel, pack_rpc_func, unpack_rpc_func) = self._dag.build() self._dag.start() @@ -75,8 +76,9 @@ class DAGExecutor(object): self._id_lock = threading.Lock() self._id_counter = 0 self._reset_max_id = 1000000000000000000 - self._cv = threading.Condition() - self._fetch_buffer = {} + self._cv_pool = {} + self._cv_for_cv_pool = threading.Condition() + self._fetch_buffer = None self._is_run = False self._recive_func = None @@ -115,6 +117,7 @@ class DAGExecutor(object): self._out_channel = out_channel def _recive_out_channel_func(self): + cv = None while self._is_run: channeldata_dict = self._out_channel.front(self.name) if len(channeldata_dict) != 1: @@ -125,18 +128,25 @@ class DAGExecutor(object): raise TypeError( self._log('data must be ChannelData type, but get {}'. format(type(channeldata)))) - with self._cv: - data_id = channeldata.id - self._fetch_buffer[data_id] = channeldata - self._cv.notify_all() + + data_id = channeldata.id + _LOGGER.debug("recive thread fetch data: {}".format(data_id)) + with self._cv_for_cv_pool: + cv = self._cv_pool[data_id] + with cv: + self._fetch_buffer = channeldata + cv.notify_all() def _get_channeldata_from_fetch_buffer(self, data_id): resp = None - with self._cv: - while data_id not in self._fetch_buffer: - self._cv.wait() - resp = self._fetch_buffer.pop(data_id) - self._cv.notify_all() + cv = threading.Condition() + with self._cv_for_cv_pool: + self._cv_pool[data_id] = cv + with cv: + cv.wait() + _LOGGER.debug("resp func get lock (data_id: {})".format(data_id)) + resp = copy.deepcopy(self._fetch_buffer) + # cv.notify_all() return resp def _pack_channeldata(self, rpc_request, data_id): @@ -204,8 +214,8 @@ class DAGExecutor(object): class DAG(object): - def __init__(self, response_op, profiler, use_profile, use_multithread, - client_type, channel_size): + def __init__(self, response_op, use_profile, use_multithread, client_type, + channel_size): self._response_op = response_op self._use_profile = use_profile self._use_multithread = use_multithread @@ -213,7 +223,6 @@ class DAG(object): self._client_type = client_type if not self._use_multithread: self._manager = multiprocessing.Manager() - self._profiler = profiler def get_use_ops(self, response_op): unique_names = set() @@ -413,7 +422,7 @@ class DAG(object): def start(self): self._threads_or_proces = [] for op in self._actual_ops: - op.init_profiler(self._profiler, self._use_profile) + op.use_profiler(self._use_profile) if self._use_multithread: self._threads_or_proces.extend( op.start_with_thread(self._client_type)) diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index 7e515be9..9a569abd 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -20,6 +20,7 @@ from concurrent import futures import logging import func_timeout import os +import sys from numpy import * from .proto import pipeline_service_pb2 @@ -59,14 +60,14 @@ class Op(object): self._retry = max(1, retry) self._input = None self._outputs = [] - self._profiler = None + + self._use_profile = False # only for multithread self._for_init_op_lock = threading.Lock() self._succ_init_op = False - def init_profiler(self, profiler, use_profile): - self._profiler = profiler + def use_profiler(self, use_profile): self._use_profile = use_profile def _profiler_record(self, string): @@ -351,9 +352,8 @@ class Op(object): os._exit(-1) # init profiler - if not use_multithread: - self._profiler = TimeProfiler() - self._profiler.enable(self._use_profile) + self._profiler = TimeProfiler() + self._profiler.enable(self._use_profile) self._is_run = True while self._is_run: @@ -400,11 +400,17 @@ class Op(object): output_channels) continue + if self._use_profile: + profile_str = self._profiler.gen_profile_str() + sys.stderr.write(profile_str) + #TODO + #output_data.add_profile(profile_str) + # push data to channel (if run succ) #self._profiler_record("push#{}_0".format(op_info_prefix)) self._push_to_output_channels(output_data, output_channels) #self._profiler_record("push#{}_1".format(op_info_prefix)) - self._profiler.print_profile() + #self._profiler.print_profile() def _log(self, info): return "{} {}".format(self.name, info) diff --git a/python/pipeline/profiler.py b/python/pipeline/profiler.py index cbcf7c81..df497c48 100644 --- a/python/pipeline/profiler.py +++ b/python/pipeline/profiler.py @@ -50,6 +50,9 @@ class TimeProfiler(object): self._time_record.put((name, tag, timestamp)) def print_profile(self): + sys.stderr.write(self.gen_profile_str()) + + def gen_profile_str(self): if self._enable is False: return print_str = self._print_head @@ -64,7 +67,7 @@ class TimeProfiler(object): else: tmp[name] = (tag, timestamp) print_str = "\n{}\n".format(print_str) - sys.stderr.write(print_str) for name, item in tmp.items(): tag, timestamp = item self._time_record.put((name, tag, timestamp)) + return print_str -- GitLab