diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index 8fb3d2cc843b23d2c856fbeee748a330b7ccc126..81c54cbb03278ca4f717c299a1d3c340713dc245 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -174,23 +174,30 @@ class Op(object): p = multiprocessing.Process( target=self._run, args=(concurrency_idx, self._get_input_channel(), - self._get_output_channels(), client_type)) + self._get_output_channels(), client_type, False)) p.start() proces.append(p) return proces def start_with_thread(self, client_type): + # load user resources + try: + self.init_op() + except Exception as e: + _LOGGER.error(log(e)) + os._exit(-1) + threads = [] for concurrency_idx in range(self.concurrency): t = threading.Thread( target=self._run, args=(concurrency_idx, self._get_input_channel(), - self._get_output_channels(), client_type)) + self._get_output_channels(), client_type, True)) t.start() threads.append(t) return threads - def load_user_resources(self): + def init_op(self): pass def _run_preprocess(self, parsed_data, data_id, log_func): @@ -309,8 +316,8 @@ class Op(object): data_id=data_id) return output_data, error_channeldata - def _run(self, concurrency_idx, input_channel, output_channels, - client_type): + def _run(self, concurrency_idx, input_channel, output_channels, client_type, + use_multithread): def get_log_func(op_info_prefix): def log_func(info_str): return "{} {}".format(op_info_prefix, info_str) @@ -329,13 +336,18 @@ class Op(object): self._server_endpoints, self._fetch_names) if client is not None: client_predict_handler = client.predict - - # load user resources - self.load_user_resources() except Exception as e: _LOGGER.error(log(e)) os._exit(-1) + if not use_multithread: + # load user resources + try: + self.init_op() + except Exception as e: + _LOGGER.error(log(e)) + os._exit(-1) + self._is_run = True while self._is_run: self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid)) @@ -399,7 +411,7 @@ class RequestOp(Op): name="#G", input_ops=[], concurrency=concurrency) # load user resources try: - self.load_user_resources() + self.init_op() except Exception as e: _LOGGER.error(log(e)) os._exit(-1) @@ -424,7 +436,7 @@ class ResponseOp(Op): name="#R", input_ops=input_ops, concurrency=concurrency) # load user resources try: - self.load_user_resources() + self.init_op() except Exception as e: _LOGGER.error(log(e)) os._exit(-1) @@ -490,8 +502,8 @@ class VirtualOp(Op): channel.add_producer(op_name) self._outputs.append(channel) - def _run(self, concurrency_idx, input_channel, output_channels, - client_type): + def _run(self, concurrency_idx, input_channel, output_channels, client_type, + use_multithread): def get_log_func(op_info_prefix): def log_func(info_str): return "{} {}".format(op_info_prefix, info_str)