Commit f17f924e authored by barrierye

fix bug in dag executor

Parent c045eebb
@@ -15,6 +15,7 @@
 import threading
 import multiprocessing
 import sys
+import copy
 if sys.version_info.major == 2:
     import Queue
 elif sys.version_info.major == 3:
@@ -58,8 +59,8 @@ class DAGExecutor(object):
         self._profiler = TimeProfiler()
         self._profiler.enable(use_profile)
-        self._dag = DAG(response_op, self._profiler, use_profile,
-                        use_multithread, client_type, channel_size)
+        self._dag = DAG(response_op, use_profile, use_multithread, client_type,
+                        channel_size)
         (in_channel, out_channel, pack_rpc_func,
          unpack_rpc_func) = self._dag.build()
         self._dag.start()
@@ -75,8 +76,9 @@ class DAGExecutor(object):
         self._id_lock = threading.Lock()
         self._id_counter = 0
         self._reset_max_id = 1000000000000000000
-        self._cv = threading.Condition()
-        self._fetch_buffer = {}
+        self._cv_pool = {}
+        self._cv_for_cv_pool = threading.Condition()
+        self._fetch_buffer = None
         self._is_run = False
         self._recive_func = None
@@ -115,6 +117,7 @@ class DAGExecutor(object):
         self._out_channel = out_channel

     def _recive_out_channel_func(self):
+        cv = None
         while self._is_run:
             channeldata_dict = self._out_channel.front(self.name)
             if len(channeldata_dict) != 1:
@@ -125,18 +128,25 @@ class DAGExecutor(object):
                 raise TypeError(
                     self._log('data must be ChannelData type, but get {}'.
                               format(type(channeldata))))
-            with self._cv:
-                data_id = channeldata.id
-                self._fetch_buffer[data_id] = channeldata
-                self._cv.notify_all()
+            data_id = channeldata.id
+            _LOGGER.debug("recive thread fetch data: {}".format(data_id))
+            with self._cv_for_cv_pool:
+                cv = self._cv_pool[data_id]
+            with cv:
+                self._fetch_buffer = channeldata
+                cv.notify_all()

     def _get_channeldata_from_fetch_buffer(self, data_id):
         resp = None
-        with self._cv:
-            while data_id not in self._fetch_buffer:
-                self._cv.wait()
-            resp = self._fetch_buffer.pop(data_id)
-            self._cv.notify_all()
+        cv = threading.Condition()
+        with self._cv_for_cv_pool:
+            self._cv_pool[data_id] = cv
+        with cv:
+            cv.wait()
+        _LOGGER.debug("resp func get lock (data_id: {})".format(data_id))
+        resp = copy.deepcopy(self._fetch_buffer)
+        # cv.notify_all()
         return resp

     def _pack_channeldata(self, rpc_request, data_id):
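The core of the fix is in the two methods above: the old code shared one Condition and one fetch-buffer dict across all requests, so every arriving result woke every waiting RPC thread; the new code registers a per-request Condition in `_cv_pool` keyed by `data_id`, and the receive thread notifies only the Condition that belongs to the arriving result. A minimal standalone sketch of that per-request condition-variable pattern (class and method names here are illustrative, not part of the patch):

```python
import threading

class ResultWaiter(object):
    """Sketch of the per-request condition-variable pattern used by the fix."""

    def __init__(self):
        self._pool_lock = threading.Lock()
        self._cv_pool = {}   # data_id -> Condition of the waiting RPC thread
        self._results = {}   # data_id -> finished result

    def wait_for(self, data_id):
        # called by the RPC thread: register a private Condition, then wait on it
        cv = threading.Condition()
        with self._pool_lock:
            self._cv_pool[data_id] = cv
        with cv:
            while data_id not in self._results:
                cv.wait()
        with self._pool_lock:
            self._cv_pool.pop(data_id)
        return self._results.pop(data_id)

    def publish(self, data_id, result):
        # called by the receive thread: store the result, wake only its waiter
        with self._pool_lock:
            cv = self._cv_pool.get(data_id)
        self._results[data_id] = result
        if cv is not None:
            with cv:
                cv.notify_all()
```

Note that the patch itself keeps a single `_fetch_buffer` slot (deep-copied by the waiter) rather than a per-request result dict as in the sketch.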
@@ -204,8 +214,8 @@ class DAGExecutor(object):

 class DAG(object):
-    def __init__(self, response_op, profiler, use_profile, use_multithread,
-                 client_type, channel_size):
+    def __init__(self, response_op, use_profile, use_multithread, client_type,
+                 channel_size):
         self._response_op = response_op
         self._use_profile = use_profile
         self._use_multithread = use_multithread
@@ -213,7 +223,6 @@ class DAG(object):
         self._client_type = client_type
         if not self._use_multithread:
             self._manager = multiprocessing.Manager()
-        self._profiler = profiler

     def get_use_ops(self, response_op):
         unique_names = set()
@@ -413,7 +422,7 @@ class DAG(object):
     def start(self):
         self._threads_or_proces = []
         for op in self._actual_ops:
-            op.init_profiler(self._profiler, self._use_profile)
+            op.use_profiler(self._use_profile)
             if self._use_multithread:
                 self._threads_or_proces.extend(
                     op.start_with_thread(self._client_type))
......
@@ -20,6 +20,7 @@ from concurrent import futures
 import logging
 import func_timeout
 import os
+import sys
 from numpy import *

 from .proto import pipeline_service_pb2
@@ -59,14 +60,14 @@ class Op(object):
         self._retry = max(1, retry)
         self._input = None
         self._outputs = []
-        self._profiler = None
         self._use_profile = False

         # only for multithread
         self._for_init_op_lock = threading.Lock()
         self._succ_init_op = False

-    def init_profiler(self, profiler, use_profile):
-        self._profiler = profiler
+    def use_profiler(self, use_profile):
         self._use_profile = use_profile

     def _profiler_record(self, string):
@@ -351,7 +352,6 @@ class Op(object):
                 os._exit(-1)

         # init profiler
-        if not use_multithread:
-            self._profiler = TimeProfiler()
-            self._profiler.enable(self._use_profile)
+        self._profiler = TimeProfiler()
+        self._profiler.enable(self._use_profile)
@@ -400,11 +400,17 @@
                                                       output_channels)
                     continue

+            if self._use_profile:
+                profile_str = self._profiler.gen_profile_str()
+                sys.stderr.write(profile_str)
+                #TODO
+                #output_data.add_profile(profile_str)
+
             # push data to channel (if run succ)
             #self._profiler_record("push#{}_0".format(op_info_prefix))
             self._push_to_output_channels(output_data, output_channels)
             #self._profiler_record("push#{}_1".format(op_info_prefix))
-            self._profiler.print_profile()
+            #self._profiler.print_profile()

     def _log(self, info):
         return "{} {}".format(self.name, info)
......
@@ -50,6 +50,9 @@ class TimeProfiler(object):
         self._time_record.put((name, tag, timestamp))

     def print_profile(self):
+        sys.stderr.write(self.gen_profile_str())
+
+    def gen_profile_str(self):
         if self._enable is False:
             return
         print_str = self._print_head
...@@ -64,7 +67,7 @@ class TimeProfiler(object): ...@@ -64,7 +67,7 @@ class TimeProfiler(object):
else: else:
tmp[name] = (tag, timestamp) tmp[name] = (tag, timestamp)
print_str = "\n{}\n".format(print_str) print_str = "\n{}\n".format(print_str)
sys.stderr.write(print_str)
for name, item in tmp.items(): for name, item in tmp.items():
tag, timestamp = item tag, timestamp = item
self._time_record.put((name, tag, timestamp)) self._time_record.put((name, tag, timestamp))
return print_str
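Splitting `gen_profile_str()` out of `print_profile()` keeps the old stderr behaviour while letting callers capture the timing summary as a string, e.g. to attach it to the response later as the TODO in the op hunk above suggests. A small usage sketch; the import path is an assumption (inside the package the class lives in profiler.py), and no events are recorded here:

```python
import sys

from .profiler import TimeProfiler  # assumed in-package import

profiler = TimeProfiler()
profiler.enable(True)
# ... timestamps get recorded while ops run ...

profile_str = profiler.gen_profile_str()  # now returns the summary string
if profile_str:  # gen_profile_str() returns None when profiling is disabled
    sys.stderr.write(profile_str)  # same output as print_profile()
```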