#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from time import time as _time
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
import os
import sys
import numpy as np
from numpy import *

from .proto import pipeline_service_pb2
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
                      ChannelData, ChannelDataType, ChannelStopError,
                      ChannelTimeoutError)
from .util import NameGenerator
from .profiler import TimeProfiler

_LOGGER = logging.getLogger()
_op_name_gen = NameGenerator("Op")


class Op(object):
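    """A single operator (OP) node in a pipeline service.

    Each worker of an Op pulls requests from its input channel, runs
    preprocess -> process -> postprocess, and pushes the results (or error
    ChannelData) to its output channels. Typical usage is to subclass Op and
    override preprocess and/or postprocess. A minimal sketch (UppercaseOp
    and the "text" key are illustrative, not part of this module):

        class UppercaseOp(Op):
            def postprocess(self, input_dict, fetch_dict):
                return {"text": fetch_dict["text"].upper()}
    """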
    def __init__(self,
                 name=None,
                 input_ops=[],
                 server_endpoints=[],
                 fetch_list=[],
                 client_config=None,
                 concurrency=1,
                 timeout=-1,
                 retry=1,
                 batch_size=1,
                 auto_batching_timeout=None):
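        """Create an Op.

        Args:
            name: globally unique OP name; auto-generated when None.
            input_ops: predecessor Op instances (or a single Op).
            server_endpoints: Paddle Serving endpoints; serving is skipped
                when the list is empty.
            fetch_list: variable names to fetch from the serving client.
            client_config: path of the client configuration file.
            concurrency: number of worker threads/processes for this OP.
            timeout: per-call process() timeout in seconds; <= 0 disables it.
            retry: retry count on timeout (at least 1).
            batch_size: auto-batching batch size.
            auto_batching_timeout: max seconds to wait while collecting a
                batch; disabled when None, <= 0, or batch_size == 1.
        """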
        if name is None:
            name = _op_name_gen.next()
        self.name = name  # the OP type identifier; must be globally unique
        self.concurrency = concurrency  # number of concurrent workers
        self.set_input_ops(input_ops)

        self._server_endpoints = server_endpoints
        self.with_serving = False
        if len(self._server_endpoints) != 0:
            self.with_serving = True
        self._client_config = client_config
        self._fetch_names = fetch_list

        self._timeout = timeout
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []

        self._batch_size = batch_size
        self._auto_batching_timeout = auto_batching_timeout
        if self._auto_batching_timeout is not None:
            if self._auto_batching_timeout <= 0 or self._batch_size == 1:
                self._auto_batching_timeout = None

        self._server_use_profile = False

        # only used in the multithread mode
        self._for_init_op_lock = threading.Lock()
        self._for_close_op_lock = threading.Lock()
        self._succ_init_op = False
        self._succ_close_op = False

    def use_default_auto_batching_config(self):
        self._batch_size = 1
        self._auto_batching_timeout = None

    def use_profiler(self, use_profile):
        self._server_use_profile = use_profile

    def _profiler_record(self, string):
        if self._profiler is None:
            return
        self._profiler.record(string)

    def init_client(self, client_type, client_config, server_endpoints,
                    fetch_names):
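        """Create and connect a serving client ('brpc' or 'grpc'), or return
        None when this OP has no serving endpoints."""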
        if not self.with_serving:
            _LOGGER.info("Op({}) no client".format(self.name))
            return None
        _LOGGER.info("Op({}) service endpoints: {}".format(self.name, server_endpoints))
        _LOGGER.debug("Op({}) fetch_names: {}".format(self.name, fetch_names))
        if client_type == 'brpc':
            _LOGGER.debug("Op({}) client_config: {}".format(self.name, client_config))
            client = Client()
            client.load_client_config(client_config)
        elif client_type == 'grpc':
            client = MultiLangClient()
        else:
            raise ValueError("unknown client type: {}".format(client_type))
        client.connect(server_endpoints)
        self._fetch_names = fetch_names
        return client

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def clean_input_channel(self):
        self._input = None

    def _get_input_channel(self):
        return self._input

    def add_output_channel(self, channel):
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def clean_output_channels(self):
        self._outputs = []

    def _get_output_channels(self):
        return self._outputs

    def preprocess(self, input_dicts):
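        """Default preprocess: pass through the output of the single
        predecessor OP; override this when the OP has multiple inputs."""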
        # the default implementation supports only one predecessor OP
        if len(input_dicts) != 1:
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this func.'
            )

        (_, input_dict), = input_dicts.items()
        return input_dict

    def process(self, feed_batch):
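        """Default process: send the numpy feed batch to the serving client
        and return its prediction; requires preprocess to produce numpy data."""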
        err, err_info = ChannelData.check_batch_npdata(feed_batch)
        if err != 0:
            raise NotImplementedError(
                "{} Please override preprocess func.".format(err_info))
        call_result = self.client.predict(
            feed=feed_batch, fetch=self._fetch_names)
        return call_result

    def postprocess(self, input_dict, fetch_dict):
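        """Default postprocess: return the fetch dict produced by process()
        unchanged."""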
        return fetch_dict

    def _parse_channeldata(self, channeldata_dict):
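        """Split a {op_name: ChannelData} dict into (data_id, error
        channeldata, parsed data, client_need_profile flag, profile set)."""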
        data_id, error_channeldata = None, None
        client_need_profile, profile_set = False, set()
        parsed_data = {}

        key = list(channeldata_dict.keys())[0]
        data_id = channeldata_dict[key].id
        client_need_profile = channeldata_dict[key].client_need_profile

        for name, data in channeldata_dict.items():
            if data.ecode != ChannelDataEcode.OK.value:
                error_channeldata = data
                break
            parsed_data[name] = data.parse()
            if client_need_profile:
                profile_set |= data.profile_data_set
        return (data_id, error_channeldata, parsed_data, client_need_profile,
                profile_set)

    def _push_to_output_channels(self,
                                 data,
                                 channels,
                                 name=None,
                                 client_need_profile=False,
                                 profile_set=None):
        if name is None:
            name = self.name
        self._add_profile_into_channeldata(data, client_need_profile,
                                           profile_set)
        for channel in channels:
            channel.push(data, name)

    def _add_profile_into_channeldata(self, data, client_need_profile,
                                      profile_set):
        profile_str = self._profiler.gen_profile_str()
        if self._server_use_profile:
            sys.stderr.write(profile_str)

        if client_need_profile and profile_set is not None:
            profile_set.add(profile_str)
            data.add_profile(profile_set)

    def start_with_process(self, client_type):
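        """Start self.concurrency worker processes running _run()."""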
        processes = []
        for concurrency_idx in range(self.concurrency):
            p = multiprocessing.Process(
                target=self._run,
                args=(concurrency_idx, self._get_input_channel(),
                      self._get_output_channels(), client_type, False))
            p.start()
            processes.append(p)
        return processes

    def start_with_thread(self, client_type):
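        """Start self.concurrency worker threads running _run()."""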
        threads = []
        for concurrency_idx in range(self.concurrency):
            t = threading.Thread(
                target=self._run,
                args=(concurrency_idx, self._get_input_channel(),
                      self._get_output_channels(), client_type, True))
            t.start()
            threads.append(t)
        return threads

    def init_op(self):
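        """User-defined initialization hook; called once per worker process,
        or only once (shared) in thread mode."""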
        pass

    def _run_preprocess(self, parsed_data_dict, log_func):
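        """Run preprocess for each request in the batch; returns dicts of
        preprocessed data and of error ChannelData, both keyed by data_id."""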
        _LOGGER.debug(log_func("try to run preprocess"))
        preped_data_dict = {}
        err_channeldata_dict = {}
        for data_id, parsed_data in parsed_data_dict.items():
            preped_data, error_channeldata = None, None
            try:
                preped_data = self.preprocess(parsed_data)
            except NotImplementedError as e:
                # preprocess function not implemented
                error_info = log_func(
                        "preprocess data[{}] failed: {}".format(
                            data_id, e))
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
                    error_info=error_info,
                    data_id=data_id)
            except TypeError as e:
                # Error type in channeldata.datatype
                error_info = log_func(
                        "preprocess data[{}] failed: {}".format(
                            data_id, e))
                _LOGGER.error(error_info)
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.TYPE_ERROR.value,
                    error_info=error_info,
                    data_id=data_id)
            except Exception as e:
                error_info = log_func(
                        "preprocess data[{}] failed: {}".format(
                            data_id, e))
                _LOGGER.error(error_info)
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
            if error_channeldata is not None:
                err_channeldata_dict[data_id] = error_channeldata
            else:
                preped_data_dict[data_id] = preped_data
        _LOGGER.debug(log_func("succ run preprocess"))
        return preped_data_dict, err_channeldata_dict

    def _run_process(self, preped_data_dict, log_func):
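        """Run process() on the whole batch (with optional timeout and retry)
        and split the batched prediction back into per-request results."""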
        _LOGGER.debug(log_func("try to run process"))
        midped_data_dict = {}
        err_channeldata_dict = {}
        if self.with_serving:
            data_ids = preped_data_dict.keys()
            feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
            midped_batch = None
            ecode = ChannelDataEcode.OK.value
            if self._timeout <= 0:
                try:
                    midped_batch = self.process(feed_batch)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = log_func("process batch failed: {}".format(e))
                    _LOGGER.error(error_info)
            else:
                for i in range(self._retry):
                    try:
                        midped_batch = func_timeout.func_timeout(
                            self._timeout, self.process, args=(feed_batch, ))
                    except func_timeout.FunctionTimedOut as e:
                        if i + 1 >= self._retry:
                            ecode = ChannelDataEcode.TIMEOUT.value
                            error_info = log_func(e)
                            _LOGGER.error(error_info)
                        else:
                            _LOGGER.warning(
                                log_func("timeout, retry({}/{})"
                                    .format(i + 1, self._retry)))
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = log_func("process batch failed: {}".format(e))
                        _LOGGER.error(error_info)
                        break
                    else:
                        break
            if ecode != ChannelDataEcode.OK.value:
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                            ecode=ecode,
                            error_info=error_info,
                            data_id=data_id)
            elif midped_batch is None:
                # the serving client returned None
                error_info = log_func(
                        "predict failed. please check the server side.")
                _LOGGER.error(error_info)
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                            ecode=ChannelDataEcode.CLIENT_ERROR.value,
                            error_info=error_info,
                            data_id=data_id)
            else:
                # transform np format to dict format
                for idx, data_id in enumerate(data_ids):
                    midped_data_dict[data_id] = {
                        k: v[idx] for k, v in midped_batch.items()
                    }
        else:
            midped_data_dict = preped_data_dict
        _LOGGER.debug(log_func("succ run process"))
        return midped_data_dict, err_channeldata_dict

    def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func):
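        """Run postprocess for each request; the result must be a dict and is
        wrapped into ChannelData (NPDATA or DICT)."""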
        _LOGGER.debug(log_func("try to run postprocess"))
        postped_data_dict = {}
        err_channeldata_dict = {}
        for data_id, midped_data in midped_data_dict.items():
            postped_data, err_channeldata = None, None
            try:
                postped_data = self.postprocess(
                        parsed_data_dict[data_id], midped_data)
            except Exception as e:
                error_info = log_func("postprocess data[{}] failed: {}"
                        .format(data_id, e))
                _LOGGER.error(error_info)
                err_channeldata = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
            if err_channeldata is not None:
                err_channeldata_dict[data_id] = err_channeldata
                continue
            else:
                if not isinstance(postped_data, dict):
                    error_info = log_func("output of postprocess function must be " \
                        "dict type, but got {}".format(type(postped_data)))
                    _LOGGER.error(error_info)
                    err_channeldata = ChannelData(
                        ecode=ChannelDataEcode.UNKNOW.value,
                        error_info=error_info,
                        data_id=data_id)
                    err_channeldata_dict[data_id] = err_channeldata
                    continue

                output_data = None
                err, _ = ChannelData.check_npdata(postped_data)
                if err == 0:
                    output_data = ChannelData(
                        ChannelDataType.CHANNEL_NPDATA.value,
                        npdata=postped_data,
                        data_id=data_id)
                else:
                    output_data = ChannelData(
                        ChannelDataType.DICT.value,
                        dictdata=postped_data,
                        data_id=data_id)
                postped_data_dict[data_id] = output_data
        _LOGGER.debug(log_func("succ run postprocess"))
        return postped_data_dict, err_channeldata_dict
    
    def _auto_batching_generator(self, input_channel, op_name,
            batch_size, timeout, log_func):
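        """Yield batches of channeldata dicts collected from input_channel:
        up to batch_size items per batch, or fewer if timeout (seconds)
        expires while waiting."""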
        while True:
            batch = []
            _LOGGER.debug(
                    log_func(
                        "Auto-batching expect size: {}; timeout: {}".format(
                            batch_size, timeout)))
            while len(batch) == 0:
                endtime = None
                if timeout is not None:
                    endtime = _time() + timeout
                for idx in range(batch_size):
                    try:
                        channeldata_dict = None
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(log_func("Auto-batching timeout"))
                                break
                            channeldata_dict = input_channel.front(op_name, timeout)
                        else:
                            channeldata_dict = input_channel.front(op_name)
                        batch.append(channeldata_dict)
                    except ChannelTimeoutError:
                        _LOGGER.debug(log_func("Auto-batching timeout"))
                        break
            _LOGGER.debug(log_func("Auto-batching actual size: {}".format(len(batch))))
            yield batch

    def _parse_channeldata_batch(self, batch, output_channels):
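        """Parse a batch of channeldata dicts into per-request parsed data and
        profiling info; error items from predecessor OPs are pushed straight
        to the output channels."""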
        parsed_data_dict = {}
        need_profile_dict = {}
        profile_dict = {}
        for channeldata_dict in batch:
            (data_id, error_channeldata, parsed_data,
                    client_need_profile, profile_set) = \
                            self._parse_channeldata(channeldata_dict)
            if error_channeldata is None:
                parsed_data_dict[data_id] = parsed_data
                need_profile_dict[data_id] = client_need_profile
                profile_dict[data_id] = profile_set
            else:
                # error data in predecessor Op
                # (error_channeldata with profile info)
                self._push_to_output_channels(
                        error_channeldata, output_channels)

        return parsed_data_dict, need_profile_dict, profile_dict
   
    def _run(self, concurrency_idx, input_channel, output_channels,
             client_type, is_thread_op):
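        """Worker loop: initialize the OP, then repeatedly pull a batch from
        the input channel, run preprocess/process/postprocess, and push
        results (or error ChannelData) to the output channels until the
        channel is stopped."""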
        def get_log_func(op_info_prefix):
            def log_func(info_str):
                return "{} {}".format(op_info_prefix, info_str)
            return log_func

        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = get_log_func(op_info_prefix)
        tid = threading.current_thread().ident

        # init op
        try:
            self._initialize(is_thread_op, client_type, concurrency_idx)
        except Exception as e:
            _LOGGER.error(log("init op failed: {}".format(e)))
            os._exit(-1)
        _LOGGER.info(log("succ init"))

        batch_generator = self._auto_batching_generator(
                input_channel=input_channel, 
                op_name=self.name,
                batch_size=self._batch_size,
                timeout=self._auto_batching_timeout,
                log_func=log)
        
        while True:
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break

            # parse channeldata batch
            try:
                parsed_data_dict, need_profile_dict, profile_dict \
                        = self._parse_channeldata_batch(
                                channeldata_dict_batch, output_channels)
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break
            if len(parsed_data_dict) == 0:
                # data in the whole batch is all error data
                continue

            # preprocess
            self._profiler_record("prep#{}_0".format(op_info_prefix))
            preped_data_dict, err_channeldata_dict \
                    = self._run_preprocess(parsed_data_dict, log)
            self._profiler_record("prep#{}_1".format(op_info_prefix))
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        err_channeldata,
                        output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break
            if len(preped_data_dict) == 0:
                continue

            # process
            self._profiler_record("midp#{}_0".format(op_info_prefix))
            midped_data_dict, err_channeldata_dict \
                    = self._run_process(preped_data_dict, log)
            self._profiler_record("midp#{}_1".format(op_info_prefix))
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                            err_channeldata,
                            output_channels,
                            client_need_profile=need_profile_dict[data_id],
                            profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break
            if len(midped_data_dict) == 0:
                continue

            # postprocess
            self._profiler_record("postp#{}_0".format(op_info_prefix))
            postped_data_dict, err_channeldata_dict \
                    = self._run_postprocess(
                            parsed_data_dict, midped_data_dict, log)
            self._profiler_record("postp#{}_1".format(op_info_prefix))
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                            err_channeldata,
                            output_channels,
                            client_need_profile=need_profile_dict[data_id],
                            profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break
            if len(postped_data_dict) == 0:
                continue

            # push data to channel (if run succ)
            try:
                for data_id, postped_data in postped_data_dict.items():
                    self._push_to_output_channels(
                            postped_data,
                            output_channels,
                            client_need_profile=need_profile_dict[data_id],
                            profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break

    def _initialize(self, is_thread_op, client_type, concurrency_idx):
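        """Set up the profiler, serving client and user-defined state; in
        thread mode the shared setup runs only once across worker threads."""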
        if is_thread_op:
            with self._for_init_op_lock:
                if not self._succ_init_op:
                    # for the threaded version of Op, each thread cannot get its concurrency_idx
                    self.concurrency_idx = None
                    # init profiler
                    self._profiler = TimeProfiler()
                    self._profiler.enable(True)
                    # init client
                    self.client = self.init_client(
                            client_type, self._client_config,
                            self._server_endpoints, self._fetch_names)
                    # user defined
                    self.init_op()
                    self._succ_init_op = True
                    self._succ_close_op = False
        else:
            self.concurrency_idx = concurrency_idx
            # init profiler
            self._profiler = TimeProfiler()
            self._profiler.enable(True)
            # init client
            self.client = self.init_client(
                    client_type, self._client_config,
                    self._server_endpoints,
                    self._fetch_names)
            # user defined
            self.init_op()
 
    def _finalize(self, is_thread_op):
        if is_thread_op:
            with self._for_close_op_lock:
                if not self._succ_close_op:
                    self._profiler = None
                    self.client = None
                    self._succ_init_op = False
                    self._succ_close_op = True

    def _log(self, info):
        return "{} {}".format(self.name, info)


class RequestOp(Op):
    """ RequestOp do not run preprocess, process, postprocess. """

    def __init__(self):
        # PipelineService.name = "@G"
        super(RequestOp, self).__init__(name="@G", input_ops=[])
        # init op
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.error(e)
            os._exit(-1)

    def unpack_request_package(self, request):
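        """Convert the proto request (parallel key/value lists) into a dict;
        string values are eval'ed back to Python/numpy objects when possible."""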
        dictdata = {}
        for idx, key in enumerate(request.key):
            data = request.value[idx]
            try:
                data = eval(data)
            except Exception as e:
                pass
            dictdata[key] = data
        return dictdata


class ResponseOp(Op):
    """ ResponseOp do not run preprocess, process, postprocess. """

    def __init__(self, input_ops):
        super(ResponseOp, self).__init__(name="@R", input_ops=input_ops)
        # init op
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.error(e)
            os._exit(-1)

    def pack_response_package(self, channeldata):
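        """Serialize ChannelData into a pipeline_service_pb2.Response: numpy
        values are rendered with repr(); dict values must already be str."""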
        resp = pipeline_service_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
                feed = channeldata.parse()
                # ndarray to string:
                # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
                # np.nan is rejected as a print threshold by newer numpy versions
                np.set_printoptions(threshold=sys.maxsize)
                for name, var in feed.items():
                    resp.value.append(var.__repr__())
                    resp.key.append(name)
            elif channeldata.datatype == ChannelDataType.DICT.value:
                feed = channeldata.parse()
                for name, var in feed.items():
                    if not isinstance(var, str):
                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                        resp.error_info = self._log(
                            "fetch var type must be str({}).".format(
                                type(var)))
                        break
                    resp.value.append(var)
                    resp.key.append(name)
            else:
                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                resp.error_info = self._log(
                    "Error type({}) in datatype.".format(channeldata.datatype))
                _LOGGER.error(resp.error_info)
        else:
            resp.error_info = channeldata.error_info
        return resp


class VirtualOp(Op):
    ''' For connecting two channels. '''

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, input_ops=None, concurrency=concurrency)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        self._virtual_pred_ops.append(op)

    def _actual_pred_op_names(self, op):
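        """Recursively resolve a (possibly virtual) predecessor OP to the
        names of the real OPs behind it."""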
        if not isinstance(op, VirtualOp):
            return [op.name]
        names = []
        for x in op._virtual_pred_ops:
            names.extend(self._actual_pred_op_names(x))
        return names

    def add_output_channel(self, channel):
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            for op_name in self._actual_pred_op_names(op):
                channel.add_producer(op_name)
        self._outputs.append(channel)

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op):
        def get_log_func(op_info_prefix):
            def log_func(info_str):
                return "{} {}".format(op_info_prefix, info_str)

            return log_func

        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = get_log_func(op_info_prefix)
        tid = threading.current_thread().ident

        while True:
            try:
                channeldata_dict = input_channel.front(self.name)
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                break

            try:
                for name, data in channeldata_dict.items():
                    self._push_to_output_channels(
                        data, channels=output_channels, name=name)
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                break