operator.py 29.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
16 17 18 19 20 21
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
22
import os
B
barrierye 已提交
23
import sys
B
barrierye 已提交
24
import numpy as np
B
barrierye 已提交
25
from numpy import *
26

B
barrierye 已提交
27
from .proto import pipeline_service_pb2
B
barrierye 已提交
28
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
B
bug fix  
barriery 已提交
29 30
                      ChannelData, ChannelDataType, ChannelStopError,
                      ChannelTimeoutError)
B
barrierye 已提交
31
from .util import NameGenerator
B
barriery 已提交
32
from .profiler import UnsafeTimeProfiler as TimeProfiler
33

W
wangjiawei04 已提交
34
# Logger shared by everything in this module. getLogger() is called without a
# name, which returns the root logger.
_LOGGER = logging.getLogger()
B
barrierye 已提交
35 36
# Used by Op.__init__ to obtain a name when the caller does not supply one.
_op_name_gen = NameGenerator("Op")

D
dongdaxiang 已提交
37 38 39

class Op(object):
    def __init__(self,
B
barrierye 已提交
40
                 name=None,
D
dongdaxiang 已提交
41 42
                 input_ops=[],
                 server_endpoints=[],
B
barrierye 已提交
43 44
                 fetch_list=[],
                 client_config=None,
D
dongdaxiang 已提交
45 46
                 concurrency=1,
                 timeout=-1,
B
barriery 已提交
47 48
                 retry=1,
                 batch_size=1,
B
bug fix  
barriery 已提交
49
                 auto_batching_timeout=None):
B
barrierye 已提交
50
        if name is None:
B
barrierye 已提交
51
            name = _op_name_gen.next()
52
        self.name = name  # to identify the type of OP, it must be globally unique
B
barrierye 已提交
53
        self.concurrency = concurrency  # amount of concurrency
B
barrierye 已提交
54
        self.set_input_ops(input_ops)
B
barrierye 已提交
55 56

        self._server_endpoints = server_endpoints
57
        self.with_serving = False
B
barrierye 已提交
58
        if len(self._server_endpoints) != 0:
59
            self.with_serving = True
B
barrierye 已提交
60 61 62
        self._client_config = client_config
        self._fetch_names = fetch_list

63 64 65 66
        self._timeout = timeout
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []
B
barrierye 已提交
67

B
barriery 已提交
68
        self._batch_size = batch_size
B
bug fix  
barriery 已提交
69
        self._auto_batching_timeout = auto_batching_timeout
B
barriery 已提交
70 71 72
        if self._auto_batching_timeout is not None:
            if self._auto_batching_timeout <= 0 or self._batch_size == 1:
                self._auto_batching_timeout = None
73 74
            else:
                self._auto_batching_timeout = self._auto_batching_timeout / 1000.0
B
barriery 已提交
75

B
barrierye 已提交
76
        self._server_use_profile = False
77

B
barrierye 已提交
78 79
        # only for multithread
        self._for_init_op_lock = threading.Lock()
B
barrierye 已提交
80
        self._for_close_op_lock = threading.Lock()
B
barrierye 已提交
81
        self._succ_init_op = False
B
barrierye 已提交
82
        self._succ_close_op = False
B
barrierye 已提交
83

B
barriery 已提交
84
    def use_default_auto_batching_config(self):
B
bug fix  
barriery 已提交
85
        if self._batch_size != 1:
B
barriery 已提交
86 87
            _LOGGER.warn("Op({}) reset batch_size=1 (original: {})"
                         .format(self.name, self._batch_size))
B
bug fix  
barriery 已提交
88 89
            self._batch_size = 1
        if self._auto_batching_timeout != None:
B
barriery 已提交
90 91 92
            _LOGGER.warn(
                "Op({}) reset auto_batching_timeout=None (original: {})"
                .format(self.name, self._auto_batching_timeout))
B
bug fix  
barriery 已提交
93
            self._auto_batching_timeout = None
B
barriery 已提交
94

B
barrierye 已提交
95
    def use_profiler(self, use_profile):
B
barrierye 已提交
96
        self._server_use_profile = use_profile
97

B
barrierye 已提交
98 99
    def init_client(self, client_type, client_config, server_endpoints,
                    fetch_names):
        """Create and connect the serving client for this Op.

        Args:
            client_type: 'brpc' or 'grpc'.
            client_config: config loaded into the client (brpc only).
            server_endpoints: endpoints handed to ``client.connect``.
            fetch_names: stored on the Op as the names to fetch on predict.

        Returns:
            The connected client, or None when the Op has no serving
            endpoints.

        Raises:
            ValueError: if ``client_type`` is neither 'brpc' nor 'grpc'.
        """
        if not self.with_serving:
            _LOGGER.info("Op({}) no client".format(self.name))
            return None
        _LOGGER.info("Op({}) service endpoints: {}".format(self.name,
                                                           server_endpoints))
        _LOGGER.debug("Op({}) fetch_names: {}".format(self.name, fetch_names))
        if client_type == 'brpc':
            _LOGGER.debug("Op({}) client_config: {}".format(self.name,
                                                            client_config))
            client = Client()
            client.load_client_config(client_config)
        elif client_type == 'grpc':
            client = MultiLangClient()
        else:
            raise ValueError("unknow client type: {}".format(client_type))
        client.connect(server_endpoints)
        self._fetch_names = fetch_names
        return client
118 119 120 121 122 123 124 125 126 127 128 129 130 131

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)
D
dongdaxiang 已提交
132

133 134 135 136 137 138 139
    def add_input_channel(self, channel):
        """Register this Op as a consumer of ``channel`` and read from it.

        Raises:
            TypeError: if ``channel`` is neither a ThreadChannel nor a
                ProcessChannel.
        """
        channel_types = (ThreadChannel, ProcessChannel)
        if isinstance(channel, channel_types):
            channel.add_consumer(self.name)
            self._input = channel
            return
        raise TypeError(
            self._log('input channel must be Channel type, not {}'.format(
                type(channel))))
D
dongdaxiang 已提交
140

141
    def clean_input_channel(self):
B
barrierye 已提交
142 143 144 145
        self._input = None

    def _get_input_channel(self):
        return self._input
D
dongdaxiang 已提交
146

147 148 149 150 151 152 153
    def add_output_channel(self, channel):
        """Register this Op as a producer of ``channel`` and write to it.

        Raises:
            TypeError: if ``channel`` is neither a ThreadChannel nor a
                ProcessChannel.
        """
        channel_types = (ThreadChannel, ProcessChannel)
        if isinstance(channel, channel_types):
            channel.add_producer(self.name)
            self._outputs.append(channel)
            return
        raise TypeError(
            self._log('output channel must be Channel type, not {}'.format(
                type(channel))))
D
dongdaxiang 已提交
154

155
    def clean_output_channels(self):
B
barrierye 已提交
156 157 158 159 160
        self._outputs = []

    def _get_output_channels(self):
        return self._outputs

W
wangjiawei04 已提交
161
    def preprocess(self, input_dicts):
        """Default preprocess: pass through the data of the single predecessor.

        Args:
            input_dicts: mapping of predecessor name to that predecessor's
                data.

        Returns:
            The value of the only entry in ``input_dicts``.

        Raises:
            NotImplementedError: if there is not exactly one entry; an Op
                with several predecessors has to override this method.
        """
        if len(input_dicts) == 1:
            only_entry = next(iter(input_dicts.items()))
            return only_entry[1]
        # multiple previous Op
        raise NotImplementedError(
            'this Op has multiple previous inputs. Please override this func.'
        )
B
barrierye 已提交
170

B
bug fix  
barriery 已提交
171 172
    def process(self, feed_batch):
        """Default process: send the batch to the serving client.

        Args:
            feed_batch: list of preprocessed feeds, one per request.

        Returns:
            Whatever ``self.client.predict`` returns for the batch.

        Raises:
            NotImplementedError: if ``ChannelData.check_batch_npdata``
                rejects the batch; ``preprocess`` must then be overridden.
        """
        status, detail = ChannelData.check_batch_npdata(feed_batch)
        if status == 0:
            return self.client.predict(
                feed=feed_batch, fetch=self._fetch_names)
        raise NotImplementedError(
            "{} Please override preprocess func.".format(detail))

W
wangjiawei04 已提交
180
    def postprocess(self, input_dict, fetch_dict):
        """Default postprocess: return ``fetch_dict`` untouched.

        Args:
            input_dict: the parsed input of the request (unused here).
            fetch_dict: the output of ``process`` for the request.
        """
        return fetch_dict
D
dongdaxiang 已提交
182

B
barrierye 已提交
183
    def _parse_channeldata(self, channeldata_dict):
        """Parse the channel data received for one request.

        Args:
            channeldata_dict: mapping of predecessor name to its channel data.

        Returns:
            Tuple ``(data_id, error_channeldata, parsed_data,
            client_need_profile, profile_set)``. ``data_id`` and
            ``client_need_profile`` are read from the first entry.
            ``error_channeldata`` is the first entry whose ecode is not OK
            (parsing stops there), or None.
        """
        first_key = list(channeldata_dict.keys())[0]
        data_id = channeldata_dict[first_key].id
        client_need_profile = channeldata_dict[first_key].client_need_profile

        error_channeldata = None
        parsed_data = {}
        profile_set = set()
        for pred_name, channeldata in channeldata_dict.items():
            if channeldata.ecode != ChannelDataEcode.OK.value:
                error_channeldata = channeldata
                break
            parsed_data[pred_name] = channeldata.parse()
            if client_need_profile:
                profile_set |= channeldata.profile_data_set

        return (data_id, error_channeldata, parsed_data, client_need_profile,
                profile_set)
B
barrierye 已提交
201 202 203 204 205

    def _push_to_output_channels(self,
                                 data,
                                 channels,
                                 name=None,
B
barriery 已提交
206
                                 profile_str=None,
B
barrierye 已提交
207
                                 client_need_profile=False,
B
barrierye 已提交
208
                                 profile_set=None):
209 210
        if name is None:
            name = self.name
B
barrierye 已提交
211

B
barriery 已提交
212
        # add profile into channeldata
B
barrierye 已提交
213
        if client_need_profile and profile_set is not None:
B
barriery 已提交
214 215
            if profile_str is not None:
                profile_set.add(profile_str)
B
barrierye 已提交
216
            data.add_profile(profile_set)
B
barrierye 已提交
217

B
barriery 已提交
218 219 220
        for channel in channels:
            channel.push(data, name)

B
barrierye 已提交
221
    def start_with_process(self, client_type):
        """Start ``self.concurrency`` worker processes running ``_run``.

        Args:
            client_type: client type forwarded to each worker.

        Returns:
            The list of started ``multiprocessing.Process`` objects.
        """
        workers = []
        for worker_idx in range(self.concurrency):
            # last argument False: the worker is not a thread Op
            run_args = (worker_idx, self._get_input_channel(),
                        self._get_output_channels(), client_type, False)
            worker = multiprocessing.Process(target=self._run, args=run_args)
            worker.start()
            workers.append(worker)
        return workers

B
barrierye 已提交
232
    def start_with_thread(self, client_type):
        """Start ``self.concurrency`` daemon threads running ``_run``.

        Args:
            client_type: client type forwarded to each worker.

        Returns:
            The list of started ``threading.Thread`` objects.
        """
        workers = []
        for worker_idx in range(self.concurrency):
            # last argument True: the worker is a thread Op
            run_args = (worker_idx, self._get_input_channel(),
                        self._get_output_channels(), client_type, True)
            worker = threading.Thread(target=self._run, args=run_args)
            # Daemon threads do not keep the process alive once the main
            # thread has finished.
            worker.daemon = True
            worker.start()
            workers.append(worker)
        return workers

B
barrierye 已提交
246
    def init_op(self):
        """User hook run once the client has been set up; no-op by default."""
        return None

249
    def _run_preprocess(self, parsed_data_dict, log_func):
        """Run ``preprocess`` on every request of the batch.

        Args:
            parsed_data_dict: mapping of data id to parsed input data.
            log_func: callable that prefixes a message with the Op info.

        Returns:
            ``(preped_data_dict, err_channeldata_dict)``: preprocessed data
            for the requests that succeeded, and error channel data for the
            ones that raised. Both are keyed by data id.
        """
        _LOGGER.debug(log_func("try to run preprocess"))
        succeeded = {}
        failed = {}
        for data_id, parsed_data in parsed_data_dict.items():
            try:
                result = self.preprocess(parsed_data)
            except Exception as e:
                if isinstance(e, NotImplementedError):
                    # preprocess function not implemented; this case is
                    # reported to the caller but not written to the error log
                    ecode = ChannelDataEcode.NOT_IMPLEMENTED.value
                    write_error_log = False
                elif isinstance(e, TypeError):
                    # Error type in channeldata.datatype
                    ecode = ChannelDataEcode.TYPE_ERROR.value
                    write_error_log = True
                else:
                    ecode = ChannelDataEcode.UNKNOW.value
                    write_error_log = True
                error_info = log_func("preprocess data[{}] failed: {}".format(
                    data_id, e))
                if write_error_log:
                    _LOGGER.error(error_info)
                failed[data_id] = ChannelData(
                    ecode=ecode, error_info=error_info, data_id=data_id)
            else:
                succeeded[data_id] = result
        _LOGGER.debug(log_func("succ run preprocess"))
        return succeeded, failed

    def _run_process(self, preped_data_dict, log_func):
        """Run ``process`` once on the whole batch.

        When the Op has no serving backend the preprocessed data is passed
        through unchanged. Otherwise ``process`` is called with the list of
        feeds. With ``_timeout <= 0`` it is called directly, exactly once.
        With a positive timeout it runs under ``func_timeout`` and is retried
        up to ``_retry`` times, but only on timeout; any other exception
        aborts immediately.

        Args:
            preped_data_dict: mapping of data id to preprocessed data.
            log_func: callable that prefixes a message with the Op info.

        Returns:
            ``(midped_data_dict, err_channeldata_dict)`` keyed by data id. A
            failure of the batch call marks every request of the batch as
            failed.
        """
        _LOGGER.debug(log_func("try to run process"))
        midped_data_dict = {}
        err_channeldata_dict = {}
        if self.with_serving:
            data_ids = preped_data_dict.keys()
            feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
            midped_batch = None
            ecode = ChannelDataEcode.OK.value
            error_info = None
            if self._timeout <= 0:
                try:
                    midped_batch = self.process(feed_batch)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = log_func("process batch failed: {}".format(e))
                    _LOGGER.error(error_info)
            else:
                for i in range(self._retry):
                    try:
                        midped_batch = func_timeout.func_timeout(
                            self._timeout, self.process, args=(feed_batch, ))
                    except func_timeout.FunctionTimedOut as e:
                        if i + 1 >= self._retry:
                            # out of attempts: report the timeout
                            ecode = ChannelDataEcode.TIMEOUT.value
                            error_info = log_func(e)
                            _LOGGER.error(error_info)
                        else:
                            _LOGGER.warning(
                                log_func("PaddleService timeout, retry({}/{})"
                                         .format(i + 1, self._retry)))
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = log_func("process batch failed: {}".format(
                            e))
                        _LOGGER.error(error_info)
                        break
                    else:
                        break
            if ecode != ChannelDataEcode.OK.value:
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ecode, error_info=error_info, data_id=data_id)
            elif midped_batch is None:
                # op client return None
                error_info = log_func(
                    "predict failed. pls check the server side.")
                _LOGGER.error(error_info)
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ChannelDataEcode.CLIENT_ERROR.value,
                        error_info=error_info,
                        data_id=data_id)
            else:
                # transform np format to dict format: element idx of every
                # fetched value belongs to the idx-th request of the batch
                for idx, data_id in enumerate(data_ids):
                    midped_data_dict[data_id] = {
                        k: v[idx]
                        for k, v in midped_batch.items()
                    }
        else:
            midped_data_dict = preped_data_dict
        _LOGGER.debug(log_func("succ run process"))
        return midped_data_dict, err_channeldata_dict

    def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func):
        """Run ``postprocess`` on every request and wrap the results.

        Args:
            parsed_data_dict: mapping of data id to the parsed input data.
            midped_data_dict: mapping of data id to the output of process.
            log_func: callable that prefixes a message with the Op info.

        Returns:
            ``(postped_data_dict, err_channeldata_dict)`` keyed by data id.
            A request fails when ``postprocess`` raises or returns something
            that is not a dict.
        """
        _LOGGER.debug(log_func("try to run postprocess"))
        wrapped = {}
        failed = {}

        def mark_failed(data_id, error_info):
            _LOGGER.error(error_info)
            failed[data_id] = ChannelData(
                ecode=ChannelDataEcode.UNKNOW.value,
                error_info=error_info,
                data_id=data_id)

        for data_id, midped_data in midped_data_dict.items():
            try:
                result = self.postprocess(parsed_data_dict[data_id],
                                          midped_data)
            except Exception as e:
                mark_failed(data_id,
                            log_func("postprocess data[{}] failed: {}"
                                     .format(data_id, e)))
                continue

            if not isinstance(result, dict):
                mark_failed(
                    data_id,
                    log_func("output of postprocess funticon must be "
                             "dict type, but get {}".format(type(result))))
                continue

            # Use the npdata representation when the check accepts the
            # result, a plain dict otherwise.
            err, _ = ChannelData.check_npdata(result)
            if err == 0:
                wrapped[data_id] = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=result,
                    data_id=data_id)
            else:
                wrapped[data_id] = ChannelData(
                    ChannelDataType.DICT.value,
                    dictdata=result,
                    data_id=data_id)
        _LOGGER.debug(log_func("succ run postprocess"))
        return wrapped, failed
B
barriery 已提交
400 401 402

    def _auto_batching_generator(self, input_channel, op_name, batch_size,
                                 timeout, log_func):
        """Yield non-empty batches of channel data read from the input.

        Args:
            input_channel: channel to read from via ``front``.
            op_name: consumer name passed to ``front``.
            batch_size: maximum number of items per batch.
            timeout: time budget for collecting one batch (reported in the
                log as seconds), or None to wait until the batch is full.
            log_func: callable that prefixes a message with the Op info.

        Yields:
            A list with 1 to ``batch_size`` items. A collection round that
            times out with nothing collected is restarted with a fresh
            deadline, so an empty batch is never yielded.
        """
        while True:
            batch = []
            _LOGGER.debug(
                log_func("Auto-batching expect size: {}; timeout(s): {}".format(
                    batch_size, timeout)))
            while len(batch) == 0:
                endtime = None
                if timeout is not None:
                    endtime = _time() + timeout
                for _ in range(batch_size):
                    try:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(log_func("Auto-batching timeout"))
                                break
                            # Wait only for what is left of the deadline.
                            # Passing the full timeout for every item lets
                            # the batch overshoot its deadline.
                            channeldata_dict = input_channel.front(op_name,
                                                                   remaining)
                        else:
                            channeldata_dict = input_channel.front(op_name)
                        batch.append(channeldata_dict)
                    except ChannelTimeoutError:
                        _LOGGER.debug(log_func("Auto-batching timeout"))
                        break
            _LOGGER.debug(
                log_func("Auto-batching actual size: {}".format(len(batch))))
            yield batch
431

432 433 434 435
    def _parse_channeldata_batch(self, batch, output_channels):
        """Parse every request of a batch, forwarding upstream errors.

        Args:
            batch: list of per-request channel data dicts.
            output_channels: channels that receive requests which already
                failed in a predecessor Op.

        Returns:
            ``(parsed_data_dict, need_profile_dict, profile_dict)`` keyed by
            data id. Only requests without an upstream error are included.
        """
        parsed_by_id = {}
        need_profile_by_id = {}
        profiles_by_id = {}
        for channeldata_dict in batch:
            parsed = self._parse_channeldata(channeldata_dict)
            data_id, error_channeldata = parsed[0], parsed[1]
            if error_channeldata is not None:
                # error data in predecessor Op
                # (error_channeldata with profile info)
                self._push_to_output_channels(error_channeldata,
                                              output_channels)
                continue
            parsed_by_id[data_id] = parsed[2]
            need_profile_by_id[data_id] = parsed[3]
            profiles_by_id[data_id] = parsed[4]

        return parsed_by_id, need_profile_by_id, profiles_by_id
B
barriery 已提交
451 452 453

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op):
        """Worker loop of one thread/process of this Op.

        Initializes the Op, then repeatedly takes a batch from the input
        channel and runs parse -> preprocess -> process -> postprocess.
        After every stage the requests that failed in it are pushed to the
        output channels as error data, and only the remaining requests go on
        to the next stage. The loop ends when a channel operation raises
        ChannelStopError.

        Args:
            concurrency_idx: index of this worker within the Op.
            input_channel: channel to read requests from.
            output_channels: channels to push results and errors to.
            client_type: client type forwarded to ``_initialize``.
            is_thread_op: True when the worker runs as a thread.

        Note:
            A failure during initialization terminates the whole process via
            ``os._exit(-1)``.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)

        def log(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        def push_errors(err_channeldata_dict, need_profile_dict, profile_dict):
            # forward the requests that failed in the stage that just ran
            for data_id, err_channeldata in err_channeldata_dict.items():
                self._push_to_output_channels(
                    data=err_channeldata,
                    channels=output_channels,
                    client_need_profile=need_profile_dict[data_id],
                    profile_set=profile_dict[data_id])

        # init op
        profiler = None
        try:
            profiler = self._initialize(is_thread_op, client_type,
                                        concurrency_idx)
        except Exception as e:
            _LOGGER.error(log("init op failed: {}".format(e)))
            os._exit(-1)
        _LOGGER.info(log("succ init"))

        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=self._batch_size,
            timeout=self._auto_batching_timeout,
            log_func=log)

        while True:
            # ChannelStopError is the shutdown signal. It can surface from
            # reading the input channel or from any push to the output
            # channels, so the whole iteration shares one handler.
            try:
                channeldata_dict_batch = next(batch_generator)

                # parse channeldata batch
                parsed_data_dict, need_profile_dict, profile_dict \
                        = self._parse_channeldata_batch(
                                channeldata_dict_batch, output_channels)
                if len(parsed_data_dict) == 0:
                    # data in the whole batch is all error data
                    continue

                # preprocess
                profiler.record("prep#{}_0".format(op_info_prefix))
                preped_data_dict, err_channeldata_dict \
                        = self._run_preprocess(parsed_data_dict, log)
                profiler.record("prep#{}_1".format(op_info_prefix))
                push_errors(err_channeldata_dict, need_profile_dict,
                            profile_dict)
                # Test the preprocess OUTPUT. parsed_data_dict is known to be
                # non-empty here, so testing it would never skip a batch
                # whose requests all failed in preprocess.
                if len(preped_data_dict) == 0:
                    continue

                # process
                profiler.record("midp#{}_0".format(op_info_prefix))
                midped_data_dict, err_channeldata_dict \
                        = self._run_process(preped_data_dict, log)
                profiler.record("midp#{}_1".format(op_info_prefix))
                push_errors(err_channeldata_dict, need_profile_dict,
                            profile_dict)
                if len(midped_data_dict) == 0:
                    continue

                # postprocess
                profiler.record("postp#{}_0".format(op_info_prefix))
                postped_data_dict, err_channeldata_dict \
                        = self._run_postprocess(
                                parsed_data_dict, midped_data_dict, log)
                profiler.record("postp#{}_1".format(op_info_prefix))
                push_errors(err_channeldata_dict, need_profile_dict,
                            profile_dict)
                if len(postped_data_dict) == 0:
                    continue

                # push data to channel (if run succ)
                profile_str = profiler.gen_profile_str()
                for data_id, postped_data in postped_data_dict.items():
                    if self._server_use_profile:
                        sys.stderr.write(profile_str)
                    self._push_to_output_channels(
                        data=postped_data,
                        channels=output_channels,
                        profile_str=profile_str,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break
B
barriery 已提交
576

B
bug fix  
barriery 已提交
577
    def _initialize(self, is_thread_op, client_type, concurrency_idx):
        """Set up the client and user state, and create a profiler.

        For a thread Op the setup runs once under ``_for_init_op_lock`` and
        is skipped by the other threads. For a process Op every worker runs
        it.

        Args:
            is_thread_op: True when the worker runs as a thread.
            client_type: client type forwarded to ``init_client``.
            concurrency_idx: index of this worker within the Op.

        Returns:
            A new, enabled TimeProfiler for the calling worker.
        """

        def setup(idx):
            self.concurrency_idx = idx
            # init client
            self.client = self.init_client(client_type, self._client_config,
                                           self._server_endpoints,
                                           self._fetch_names)
            # user defined
            self.init_op()

        if not is_thread_op:
            setup(concurrency_idx)
        else:
            with self._for_init_op_lock:
                if not self._succ_init_op:
                    # for the threaded version of Op, each thread cannot get
                    # its concurrency_idx
                    setup(None)
                    self._succ_init_op = True
                    self._succ_close_op = False

        # use a separate TimeProfiler per thread or process
        worker_profiler = TimeProfiler()
        worker_profiler.enable(True)
        return worker_profiler

B
barriery 已提交
605 606 607 608 609 610 611 612
    def _finalize(self, is_thread_op):
        if is_thread_op:
            with self._for_close_op_lock:
                if not self._succ_close_op:
                    self._profiler = None
                    self.client = None
                    self._succ_init_op = False
                    self._succ_close_op = True
613 614 615 616 617

    def _log(self, info):
        return "{} {}".format(self.name, info)


B
barrierye 已提交
618 619 620
class RequestOp(Op):
    """ RequestOp do not run preprocess, process, postprocess. """

    def __init__(self):
        """Create the request Op named "@G" with no predecessors.

        ``init_op`` is run right away; if it raises, the error is logged and
        the process is terminated via ``os._exit(-1)``.
        """
        # PipelineService.name = "@G"
        super(RequestOp, self).__init__(name="@G", input_ops=[])
        # init op
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.error("Op(Request) init op failed: {}".format(e))
            os._exit(-1)

    def unpack_request_package(self, request):
        """Turn a request into a dict of key -> value.

        Each value is evaluated as a Python expression; a value that fails
        to evaluate is kept as received.

        Args:
            request: object with parallel ``key`` and ``value`` sequences.

        Returns:
            Dict mapping every key of the request to its (possibly
            evaluated) value.
        """
        dictdata = {}
        for idx, key in enumerate(request.key):
            data = request.value[idx]
            try:
                # SECURITY: eval() runs arbitrary code taken from the request
                # payload. Only expose this service to trusted clients, or
                # replace this with a safe decoder.
                # NOTE(review): presumably this relies on the module-level
                # `from numpy import *` to rebuild ndarray reprs, so
                # ast.literal_eval is not a drop-in replacement - confirm.
                data = eval(data)
            except Exception:
                # best effort: not an expression, keep the raw value
                pass
            dictdata[key] = data
        return dictdata


class ResponseOp(Op):
    """ ResponseOp does not run preprocess, process, postprocess. """

    def __init__(self, input_ops):
        """Create the op under the fixed name "@R" and run ``init_op``.

        Args:
            input_ops: forwarded unchanged to ``Op.__init__``.

        If ``init_op`` raises, the error is logged and the whole process is
        terminated with ``os._exit(-1)``.
        """
        super(ResponseOp, self).__init__(name="@R", input_ops=input_ops)
        # init op
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.error("Op(ResponseOp) init op failed: {}".format(e))
            os._exit(-1)

    def pack_response_package(self, channeldata):
        """Build a ``pipeline_service_pb2.Response`` from ``channeldata``.

        The response ``ecode`` is copied from ``channeldata.ecode``. When it
        is not OK, only ``error_info`` is copied. Otherwise the parsed payload
        is written into the parallel ``key`` / ``value`` fields:

        * CHANNEL_NPDATA: every value is stored as its full ``repr``.
        * DICT: every value must be a ``str``; the first value that is not
          turns the response into a TYPE_ERROR and stops the packing
          (entries appended before it are left in place).
        * any other datatype: the response becomes a TYPE_ERROR.

        Args:
            channeldata: object exposing ``ecode``, ``datatype``,
                ``error_info`` and ``parse()``.

        Returns:
            The populated ``pipeline_service_pb2.Response``.
        """
        resp = pipeline_service_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode != ChannelDataEcode.OK.value:
            resp.error_info = channeldata.error_info
            return resp

        if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            feed = channeldata.parse()
            # ndarray to string:
            # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
            # threshold=np.nan is rejected by current NumPy ("threshold must
            # be non-NAN"); sys.maxsize is the documented way to disable
            # summarization so that repr() contains every element.
            np.set_printoptions(threshold=sys.maxsize)
            for name, var in feed.items():
                resp.value.append(repr(var))
                resp.key.append(name)
        elif channeldata.datatype == ChannelDataType.DICT.value:
            feed = channeldata.parse()
            for name, var in feed.items():
                if not isinstance(var, str):
                    resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                    resp.error_info = self._log(
                        "fetch var type must be str({}).".format(
                            type(var)))
                    # Logged like the unknown-datatype branch below so the
                    # failure is visible on the server side as well.
                    _LOGGER.error(resp.error_info)
                    break
                resp.value.append(var)
                resp.key.append(name)
        else:
            resp.ecode = ChannelDataEcode.TYPE_ERROR.value
            resp.error_info = self._log(
                "Error type({}) in datatype.".format(channeldata.datatype))
            _LOGGER.error(resp.error_info)
        return resp
686 687 688 689 690 691 692


class VirtualOp(Op):
    """ Op used for connecting two channels. """

    def __init__(self, name, concurrency=1):
        """Create the op with no input ops and an empty predecessor list.

        Args:
            name: op name, forwarded to ``Op.__init__``.
            concurrency: forwarded to ``Op.__init__``.
        """
        super(VirtualOp, self).__init__(
            name=name, input_ops=None, concurrency=concurrency)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Record ``op`` as a predecessor of this virtual op."""
        self._virtual_pred_ops.append(op)

    def _actual_pred_op_names(self, op):
        """Return the names of the non-virtual ops reachable through ``op``.

        A non-virtual ``op`` yields its own name; a ``VirtualOp`` is expanded
        recursively through its recorded predecessors.
        """
        if isinstance(op, VirtualOp):
            return [
                pred_name
                for pred in op._virtual_pred_ops
                for pred_name in self._actual_pred_op_names(pred)
            ]
        return [op.name]

    def add_output_channel(self, channel):
        """Attach ``channel`` as an output of this op.

        Every non-virtual predecessor name is registered on the channel via
        ``channel.add_producer`` before the channel is appended to
        ``self._outputs``.

        Raises:
            TypeError: if ``channel`` is neither a ``ThreadChannel`` nor a
                ``ProcessChannel``.
        """
        accepted_types = (ThreadChannel, ProcessChannel)
        if not isinstance(channel, accepted_types):
            message = 'output channel must be Channel type, not {}'.format(
                type(channel))
            raise TypeError(self._log(message))
        for pred_op in self._virtual_pred_ops:
            for producer_name in self._actual_pred_op_names(pred_op):
                channel.add_producer(producer_name)
        self._outputs.append(channel)

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op):
        """Forward everything read from ``input_channel`` to ``output_channels``.

        Loops until reading from the input channel or pushing to the output
        channels raises ``ChannelStopError``. ``client_type`` and
        ``is_thread_op`` are not read in this method.
        """
        prefix = "[{}|{}]".format(self.name, concurrency_idx)

        def log(info_str):
            return "{} {}".format(prefix, info_str)

        # Not read below; evaluated only so this method mirrors the original.
        thread_ident = threading.current_thread().ident

        while True:
            try:
                batch = input_channel.front(self.name)
                for item_name, item in batch.items():
                    self._push_to_output_channels(
                        item, channels=output_channels, name=item_name)
            except ChannelStopError:
                _LOGGER.debug(log("Channel stop."))
                break