#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
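"""Operator definitions for the pipeline server.

This module provides the base Op class, plus the built-in RequestOp,
ResponseOp, and VirtualOp used to wire channels and operators into a
pipeline.
"""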
from time import time as _time
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
import os
import sys
import numpy as np
from numpy import *

from .proto import pipeline_service_pb2
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
                      ChannelData, ChannelDataType, ChannelStopError,
                      ChannelTimeoutError)
from .util import NameGenerator
from .profiler import TimeProfiler

_LOGGER = logging.getLogger()
_op_name_gen = NameGenerator("Op")


class Op(object):
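    """Base class for a pipeline operator.

    An Op takes request data from its input channel, runs the
    preprocess -> process -> postprocess hooks (process usually calls a
    Paddle Serving client), and pushes the results to its output
    channels. Subclasses normally override preprocess, process, and/or
    postprocess.

    Illustrative construction (the operator name, endpoint, fetch name,
    and config path below are hypothetical, not defined in this module):

        read_op = RequestOp()
        my_op = Op(name="my_op",
                   input_ops=[read_op],
                   server_endpoints=["127.0.0.1:9292"],
                   fetch_list=["score"],
                   client_config="serving_client_conf.prototxt",
                   concurrency=2,
                   batch_size=8,
                   auto_batching_timeout=1.0)
    """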
    def __init__(self,
                 name=None,
                 input_ops=[],
                 server_endpoints=[],
                 fetch_list=[],
                 client_config=None,
                 concurrency=1,
                 timeout=-1,
                 retry=1,
                 batch_size=1,
                 auto_batching_timeout=None):
        if name is None:
            name = _op_name_gen.next()
        self.name = name  # to identify the type of OP, it must be globally unique
        self.concurrency = concurrency  # amount of concurrency
        self.set_input_ops(input_ops)

        self._server_endpoints = server_endpoints
        self.with_serving = False
        if len(self._server_endpoints) != 0:
            self.with_serving = True
        self._client_config = client_config
        self._fetch_names = fetch_list

        self._timeout = timeout
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []

        self._batch_size = batch_size
        self._auto_batching_timeout = auto_batching_timeout
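        # a non-positive timeout or a batch_size of 1 disables auto-batching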
        if self._auto_batching_timeout is not None:
            if self._auto_batching_timeout <= 0 or self._batch_size == 1:
                self._auto_batching_timeout = None

        self._server_use_profile = False

        # only for multithread
        self._for_init_op_lock = threading.Lock()
        self._for_close_op_lock = threading.Lock()
        self._succ_init_op = False
        self._succ_close_op = False

    def use_default_auto_batching_config(self):
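        """Reset batch_size to 1 and disable auto_batching_timeout."""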
        if self._batch_size != 1:
            _LOGGER.warn(
                    "Op({}) reset batch_size=1 (original: {})"
                    .format(self.name, self._batch_size))
            self._batch_size = 1
        if self._auto_batching_timeout is not None:
            _LOGGER.warn(
                    "Op({}) reset auto_batching_timeout=None (original: {})"
                    .format(self.name, self._auto_batching_timeout))
            self._auto_batching_timeout = None

    def use_profiler(self, use_profile):
        self._server_use_profile = use_profile

    def _profiler_record(self, string):
        if self._profiler is None:
            return
        self._profiler.record(string)

    def init_client(self, client_type, client_config, server_endpoints,
                    fetch_names):
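        """Create and connect a serving client ('brpc' or 'grpc').

        Returns None when this Op has no serving endpoints.
        """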
        if not self.with_serving:
            _LOGGER.info("Op({}) no client".format(self.name))
            return None
        _LOGGER.info("Op({}) service endpoints: {}".format(self.name, server_endpoints))
        _LOGGER.debug("Op({}) fetch_names: {}".format(self.name, fetch_names))
        if client_type == 'brpc':
            _LOGGER.debug("Op({}) client_config: {}".format(self.name, client_config))
            client = Client()
            client.load_client_config(client_config)
        elif client_type == 'grpc':
            client = MultiLangClient()
        else:
            raise ValueError("unknown client type: {}".format(client_type))
        client.connect(server_endpoints)
        self._fetch_names = fetch_names
        return client

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def clean_input_channel(self):
        self._input = None

    def _get_input_channel(self):
        return self._input

    def add_output_channel(self, channel):
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def clean_output_channels(self):
        self._outputs = []

    def _get_output_channels(self):
        return self._outputs

    def preprocess(self, input_dicts):
        # multiple previous Op
        if len(input_dicts) != 1:
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this func.'
            )

        (_, input_dict), = input_dicts.items()
        return input_dict

    def process(self, feed_batch):
        err, err_info = ChannelData.check_batch_npdata(feed_batch)
        if err != 0:
            raise NotImplementedError(
                "{} Please override preprocess func.".format(err_info))
        call_result = self.client.predict(
            feed=feed_batch, fetch=self._fetch_names)
        return call_result

    def postprocess(self, input_dict, fetch_dict):
        return fetch_dict
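
    # A minimal, hypothetical sketch of a user-defined Op built on the three
    # hooks above; the class name and the postprocess transformation are
    # illustrative only and not part of this module:
    #
    #     class MyOp(Op):
    #         def preprocess(self, input_dicts):
    #             (_, input_dict), = input_dicts.items()
    #             return input_dict  # becomes the feed for process()
    #
    #         def postprocess(self, input_dict, fetch_dict):
    #             return {name: str(var) for name, var in fetch_dict.items()}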

    def _parse_channeldata(self, channeldata_dict):
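        """Split one channeldata dict into (data_id, error_channeldata,
        parsed_data, client_need_profile, profile_set)."""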
        data_id, error_channeldata = None, None
        client_need_profile, profile_set = False, set()
        parsed_data = {}

        key = list(channeldata_dict.keys())[0]
        data_id = channeldata_dict[key].id
        client_need_profile = channeldata_dict[key].client_need_profile

        for name, data in channeldata_dict.items():
            if data.ecode != ChannelDataEcode.OK.value:
                error_channeldata = data
                break
            parsed_data[name] = data.parse()
            if client_need_profile:
                profile_set |= data.profile_data_set
        return (data_id, error_channeldata, parsed_data, client_need_profile,
                profile_set)

    def _push_to_output_channels(self,
                                 data,
                                 channels,
                                 name=None,
                                 client_need_profile=False,
                                 profile_set=None):
        if name is None:
            name = self.name
        self._add_profile_into_channeldata(data, client_need_profile,
                                           profile_set)
        for channel in channels:
            channel.push(data, name)

    def _add_profile_into_channeldata(self, data, client_need_profile,
                                      profile_set):
        profile_str = self._profiler.gen_profile_str()
        if self._server_use_profile:
            sys.stderr.write(profile_str)

        if client_need_profile and profile_set is not None:
            profile_set.add(profile_str)
            data.add_profile(profile_set)

    def start_with_process(self, client_type):
        processes = []
        for concurrency_idx in range(self.concurrency):
            p = multiprocessing.Process(
                target=self._run,
                args=(concurrency_idx, self._get_input_channel(),
                      self._get_output_channels(), client_type, False))
            p.start()
            processes.append(p)
        return processes

    def start_with_thread(self, client_type):
        threads = []
        for concurrency_idx in range(self.concurrency):
            t = threading.Thread(
                target=self._run,
                args=(concurrency_idx, self._get_input_channel(),
                      self._get_output_channels(), client_type, True))
            t.start()
            threads.append(t)
        return threads

    def init_op(self):
        pass

    def _run_preprocess(self, parsed_data_dict, log_func):
        _LOGGER.debug(log_func("try to run preprocess"))
        preped_data_dict = {}
        err_channeldata_dict = {}
        for data_id, parsed_data in parsed_data_dict.items():
            preped_data, error_channeldata = None, None
            try:
                preped_data = self.preprocess(parsed_data)
            except NotImplementedError as e:
                # preprocess function not implemented
                error_info = log_func(
                        "preprocess data[{}] failed: {}".format(
                            data_id, e))
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
                    error_info=error_info,
                    data_id=data_id)
            except TypeError as e:
                # Error type in channeldata.datatype
                error_info = log_func(
                        "preprocess data[{}] failed: {}".format(
                            data_id, e))
                _LOGGER.error(error_info)
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.TYPE_ERROR.value,
                    error_info=error_info,
                    data_id=data_id)
            except Exception as e:
                error_info = log_func(
                        "preprocess data[{}] failed: {}".format(
                            data_id, e))
                _LOGGER.error(error_info)
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
            if error_channeldata is not None:
                err_channeldata_dict[data_id] = error_channeldata
            else:
                preped_data_dict[data_id] = preped_data
        _LOGGER.debug(log_func("succ run preprocess"))
        return preped_data_dict, err_channeldata_dict

    def _run_process(self, preped_data_dict, log_func):
        _LOGGER.debug(log_func("try to run process"))
        midped_data_dict = {}
        err_channeldata_dict = {}
        if self.with_serving:
            data_ids = preped_data_dict.keys()
            feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
            midped_batch = None
            ecode = ChannelDataEcode.OK.value
            if self._timeout <= 0:
                try:
                    midped_batch = self.process(feed_batch)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = log_func("process batch failed: {}".format(e))
                    _LOGGER.error(error_info)
            else:
                for i in range(self._retry):
                    try:
                        midped_batch = func_timeout.func_timeout(
                            self._timeout, self.process, args=(feed_batch, ))
                    except func_timeout.FunctionTimedOut as e:
                        if i + 1 >= self._retry:
                            ecode = ChannelDataEcode.TIMEOUT.value
                            error_info = log_func(e)
                            _LOGGER.error(error_info)
                        else:
                            _LOGGER.warn(
                                log_func("timeout, retry({}/{})"
                                    .format(i + 1, self._retry)))
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = log_func("process batch failed: {}".format(e))
                        _LOGGER.error(error_info)
                        break
                    else:
                        break
            if ecode != ChannelDataEcode.OK.value:
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                            ecode=ecode,
                            error_info=error_info,
                            data_id=data_id)
            elif midped_batch is None:
                # op client return None
                error_info = log_func(
                        "predict failed. please check the server side.")
                _LOGGER.error(error_info)
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                            ecode=ChannelDataEcode.CLIENT_ERROR.value,
                            error_info=error_info,
                            data_id=data_id)
            else:
                # transform np format to dict format
                for idx, data_id in enumerate(data_ids):
                    midped_data_dict[data_id] = {
                        k: v[idx] for k, v in midped_batch.items()
                    }
        else:
            midped_data_dict = preped_data_dict
        _LOGGER.debug(log_func("succ run process"))
        return midped_data_dict, err_channeldata_dict

    def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func):
        _LOGGER.debug(log_func("try to run postprocess"))
        postped_data_dict = {}
        err_channeldata_dict = {}
        for data_id, midped_data in midped_data_dict.items():
            postped_data, err_channeldata = None, None
            try:
                postped_data = self.postprocess(
                        parsed_data_dict[data_id], midped_data)
            except Exception as e:
                error_info = log_func("postprocess data[{}] failed: {}"
                        .format(data_id, e))
                _LOGGER.error(error_info)
                err_channeldata = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
            if err_channeldata is not None:
                err_channeldata_dict[data_id] = err_channeldata
                continue
            else:
                if not isinstance(postped_data, dict):
                    error_info = log_func("output of postprocess function must be " \
                        "dict type, but got {}".format(type(postped_data)))
                    _LOGGER.error(error_info)
                    err_channeldata = ChannelData(
                        ecode=ChannelDataEcode.UNKNOW.value,
                        error_info=error_info,
                        data_id=data_id)
                    err_channeldata_dict[data_id] = err_channeldata
                    continue

                output_data = None
                err, _ = ChannelData.check_npdata(postped_data)
                if err == 0:
                    output_data = ChannelData(
                        ChannelDataType.CHANNEL_NPDATA.value,
                        npdata=postped_data,
                        data_id=data_id)
                else:
                    output_data = ChannelData(
                        ChannelDataType.DICT.value,
                        dictdata=postped_data,
                        data_id=data_id)
                postped_data_dict[data_id] = output_data
        _LOGGER.debug(log_func("succ run postprocess"))
        return postped_data_dict, err_channeldata_dict
    
    def _auto_batching_generator(self, input_channel, op_name,
            batch_size, timeout, log_func):
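        """Generator that groups requests read from input_channel into
        batches of at most batch_size items; when timeout (in seconds)
        is set, a non-empty batch is yielded early once it expires."""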
        while True:
            batch = []
            _LOGGER.debug(
                    log_func(
                        "Auto-batching expect size: {}; timeout: {}".format(
                            batch_size, timeout)))
            while len(batch) == 0:
                endtime = None
                if timeout is not None:
                    endtime = _time() + timeout
                for idx in range(batch_size):
                    try:
                        channeldata_dict = None
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(log_func("Auto-batching timeout"))
                                break
                            channeldata_dict = input_channel.front(op_name, timeout)
                        else:
                            channeldata_dict = input_channel.front(op_name)
                        batch.append(channeldata_dict)
                    except ChannelTimeoutError:
                        _LOGGER.debug(log_func("Auto-batching timeout"))
                        break
            _LOGGER.debug(log_func("Auto-batching actual size: {}".format(len(batch))))
            yield batch

    def _parse_channeldata_batch(self, batch, output_channels):
        parsed_data_dict = {}
        need_profile_dict = {}
        profile_dict = {}
        for channeldata_dict in batch:
            (data_id, error_channeldata, parsed_data,
                    client_need_profile, profile_set) = \
                            self._parse_channeldata(channeldata_dict)
            if error_channeldata is None:
                parsed_data_dict[data_id] = parsed_data
                need_profile_dict[data_id] = client_need_profile
                profile_dict[data_id] = profile_set
            else:
                # error data in predecessor Op
                # (error_channeldata with profile info)
                self._push_to_output_channels(
                        error_channeldata, output_channels)

        return parsed_data_dict, need_profile_dict, profile_dict
   
    def _run(self, concurrency_idx, input_channel, output_channels,
           client_type, is_thread_op):
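        """Worker loop: initialize the Op, then repeatedly fetch a batch
        from the input channel, run preprocess / process / postprocess,
        and push results (or error channeldata) to the output channels."""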
        def get_log_func(op_info_prefix):
            def log_func(info_str):
                return "{} {}".format(op_info_prefix, info_str)
            return log_func

        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = get_log_func(op_info_prefix)
        tid = threading.current_thread().ident

        # init op
        try:
            self._initialize(is_thread_op, client_type, concurrency_idx)
        except Exception as e:
            _LOGGER.error(log("init op failed: {}".format(e)))
            os._exit(-1)
        _LOGGER.info(log("succ init"))

        batch_generator = self._auto_batching_generator(
                input_channel=input_channel, 
                op_name=self.name,
                batch_size=self._batch_size,
                timeout=self._auto_batching_timeout,
                log_func=log)
        
        while True:
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break

            # parse channeldata batch
            try:
                parsed_data_dict, need_profile_dict, profile_dict \
                        = self._parse_channeldata_batch(
                                channeldata_dict_batch, output_channels)
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break
            if len(parsed_data_dict) == 0:
                # data in the whole batch is all error data
                continue

            # preprocess
            self._profiler_record("prep#{}_0".format(op_info_prefix))
            preped_data_dict, err_channeldata_dict \
                    = self._run_preprocess(parsed_data_dict, log)
            self._profiler_record("prep#{}_1".format(op_info_prefix))
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        err_channeldata,
                        output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break
            if len(preped_data_dict) == 0:
                continue

            # process
            self._profiler_record("midp#{}_0".format(op_info_prefix))
            midped_data_dict, err_channeldata_dict \
                    = self._run_process(preped_data_dict, log)
            self._profiler_record("midp#{}_1".format(op_info_prefix))
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                            err_channeldata,
                            output_channels,
                            client_need_profile=need_profile_dict[data_id],
                            profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break
            if len(midped_data_dict) == 0:
                continue

            # postprocess
            self._profiler_record("postp#{}_0".format(op_info_prefix))
            postped_data_dict, err_channeldata_dict \
                    = self._run_postprocess(
                            parsed_data_dict, midped_data_dict, log)
            self._profiler_record("postp#{}_1".format(op_info_prefix))
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                            err_channeldata,
                            output_channels,
                            client_need_profile=need_profile_dict[data_id],
                            profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break
            if len(postped_data_dict) == 0:
                continue

            # push data to channel (if run succ)
            try:
                for data_id, postped_data in postped_data_dict.items():
                    self._push_to_output_channels(
                            postped_data,
                            output_channels,
                            client_need_profile=need_profile_dict[data_id],
                            profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("channel stop."))
                self._finalize(is_thread_op)
                break

    def _initialize(self, is_thread_op, client_type, concurrency_idx):
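        """Create the profiler and serving client and call init_op().

        In thread mode this runs only once and is shared by all threads;
        in process mode every worker builds its own copies."""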
        if is_thread_op:
            with self._for_init_op_lock:
                if not self._succ_init_op:
                    # for the threaded version of Op, each thread cannot get its concurrency_idx
                    self.concurrency_idx = None
                    # init profiler
                    self._profiler = TimeProfiler()
                    self._profiler.enable(True)
                    # init client
                    self.client = self.init_client(
                            client_type, self._client_config,
                            self._server_endpoints, self._fetch_names)
                    # user defined
                    self.init_op()
                    self._succ_init_op = True
                    self._succ_close_op = False
        else:
            self.concurrency_idx = concurrency_idx
            # init profiler
            self._profiler = TimeProfiler()
            self._profiler.enable(True)
            # init client
            self.client = self.init_client(
                    client_type, self._client_config,
                    self._server_endpoints,
                    self._fetch_names)
            # user defined
            self.init_op()
 
    def _finalize(self, is_thread_op):
        if is_thread_op:
            with self._for_close_op_lock:
                if not self._succ_close_op:
                    self._profiler = None
                    self.client = None
                    self._succ_init_op = False
                    self._succ_close_op = True

    def _log(self, info):
        return "{} {}".format(self.name, info)


class RequestOp(Op):
    """ RequestOp does not run preprocess, process, or postprocess. """

    def __init__(self):
        # PipelineService.name = "@G"
        super(RequestOp, self).__init__(name="@G", input_ops=[])
        # init op
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.error(e)
            os._exit(-1)

    def unpack_request_package(self, request):
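        """Convert the request's parallel key/value lists into a dict;
        values are eval()-ed back into Python objects when possible."""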
        dictdata = {}
        for idx, key in enumerate(request.key):
            data = request.value[idx]
            try:
                data = eval(data)
            except Exception as e:
                pass
            dictdata[key] = data
        return dictdata


class ResponseOp(Op):
    """ ResponseOp does not run preprocess, process, or postprocess. """

    def __init__(self, input_ops):
        super(ResponseOp, self).__init__(name="@R", input_ops=input_ops)
        # init op
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.error(e)
            os._exit(-1)

    def pack_response_package(self, channeldata):
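        """Build a pipeline_service_pb2.Response from channeldata:
        numpy data is serialized via repr(), dict data must already hold
        string values, and errors are copied into ecode/error_info."""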
        resp = pipeline_service_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
                feed = channeldata.parse()
                # ndarray to string:
                # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
                np.set_printoptions(threshold=sys.maxsize)
                for name, var in feed.items():
                    resp.value.append(var.__repr__())
                    resp.key.append(name)
            elif channeldata.datatype == ChannelDataType.DICT.value:
                feed = channeldata.parse()
                for name, var in feed.items():
                    if not isinstance(var, str):
                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                        resp.error_info = self._log(
                            "fetch var type must be str({}).".format(
                                type(var)))
                        break
                    resp.value.append(var)
                    resp.key.append(name)
            else:
                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                resp.error_info = self._log(
                    "Error type({}) in datatype.".format(channeldata.datatype))
                _LOGGER.error(resp.error_info)
        else:
            resp.error_info = channeldata.error_info
        return resp


class VirtualOp(Op):
    ''' For connecting two channels. '''

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, input_ops=None, concurrency=concurrency)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        self._virtual_pred_ops.append(op)

    def _actual_pred_op_names(self, op):
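        """Recursively resolve a (possibly virtual) predecessor Op into
        the names of the real Ops behind it."""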
        if not isinstance(op, VirtualOp):
            return [op.name]
        names = []
        for x in op._virtual_pred_ops:
            names.extend(self._actual_pred_op_names(x))
        return names

    def add_output_channel(self, channel):
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            for op_name in self._actual_pred_op_names(op):
                channel.add_producer(op_name)
        self._outputs.append(channel)

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op):
        def get_log_func(op_info_prefix):
            def log_func(info_str):
                return "{} {}".format(op_info_prefix, info_str)

            return log_func

        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = get_log_func(op_info_prefix)
        tid = threading.current_thread().ident

        while True:
            try:
                channeldata_dict = input_channel.front(self.name)
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                break

            try:
                for name, data in channeldata_dict.items():
                    self._push_to_output_channels(
                        data, channels=output_channels, name=name)
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                break