# operator.py
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import logging
import multiprocessing
import os
import sys
import threading
from concurrent import futures
from time import time as _time

import func_timeout
import numpy as np
from paddle_serving_client import MultiLangClient, Client
# NOTE(review): presumably kept so that eval() in
# RequestOp.unpack_request_package can resolve numpy reprs such as
# "array([...])"; do not remove without checking that path.
from numpy import *

from .proto import pipeline_service_pb2
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
                      ChannelData, ChannelDataType, ChannelStopError,
                      ChannelTimeoutError)
from .util import NameGenerator
from .profiler import UnsafeTimeProfiler as TimeProfiler

# Module logger. logging.getLogger() with no name returns the root logger,
# so everything this module logs goes through the root logger's config.
_LOGGER = logging.getLogger()
# Produces default Op names; Op.__init__ calls _op_name_gen.next() when no
# name is given.
_op_name_gen = NameGenerator("Op")


class Op(object):
    """Base class of a pipeline operator (Op).

    Each worker of an Op takes batches of ChannelData from a single input
    channel, runs ``preprocess`` -> ``process`` -> ``postprocess`` on them
    and pushes the results, or per-request error ChannelData, to every
    output channel. ``process`` calls a remote serving service only when
    ``server_endpoints`` is non-empty; otherwise the preprocessed data is
    handed straight to ``postprocess``. Subclasses customize the Op by
    overriding ``init_op``, ``preprocess``, ``process`` and ``postprocess``.
    """

    def __init__(self,
                 name=None,
                 input_ops=None,
                 server_endpoints=None,
                 fetch_list=None,
                 client_config=None,
                 concurrency=1,
                 timeout=-1,
                 retry=1,
                 batch_size=1,
                 auto_batching_timeout=None):
        """Create an Op.

        Args:
            name: identifies the Op and must be globally unique; a name is
                generated when omitted.
            input_ops: an Op or a list of Ops this Op consumes from
                (None means no inputs).
            server_endpoints: serving endpoints; a non-empty list makes
                ``process`` call the remote service (None means empty).
            fetch_list: names of the variables fetched by ``process``
                (None means empty).
            client_config: passed to ``Client.load_client_config`` when
                the 'brpc' client is used.
            concurrency: number of worker threads/processes.
            timeout: seconds allowed per ``process`` call; <= 0 disables
                the time limit.
            retry: number of attempts when ``process`` times out
                (clamped to at least 1).
            batch_size: maximum number of requests per batch.
            auto_batching_timeout: seconds to wait while filling a batch;
                reset to None when <= 0 or when batch_size == 1.
        """
        if name is None:
            name = _op_name_gen.next()
        self.name = name  # to identify the type of OP, it must be globally unique
        self.concurrency = concurrency  # amount of concurrency
        self.set_input_ops(input_ops)

        # None defaults (instead of shared [] literals) so that Ops never
        # share a list object through a default argument.
        self._server_endpoints = ([] if server_endpoints is None else
                                  server_endpoints)
        self.with_serving = len(self._server_endpoints) != 0
        self._client_config = client_config
        self._fetch_names = [] if fetch_list is None else fetch_list

        self._timeout = timeout
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []

        self._batch_size = batch_size
        self._auto_batching_timeout = auto_batching_timeout
        if self._auto_batching_timeout is not None:
            if self._auto_batching_timeout <= 0 or self._batch_size == 1:
                self._auto_batching_timeout = None

        self._server_use_profile = False

        # only for multithread: all worker threads share this Op object
        self._for_init_op_lock = threading.Lock()
        self._for_close_op_lock = threading.Lock()
        self._succ_init_op = False
        self._succ_close_op = False

    def use_default_auto_batching_config(self):
        """Disable auto-batching: force batch_size=1 and no batching timeout."""
        if self._batch_size != 1:
            _LOGGER.warning("Op({}) reset batch_size=1 (original: {})"
                            .format(self.name, self._batch_size))
            self._batch_size = 1
        if self._auto_batching_timeout is not None:
            # Report the value actually assigned below (None).
            _LOGGER.warning(
                "Op({}) reset auto_batching_timeout=None (original: {})"
                .format(self.name, self._auto_batching_timeout))
            self._auto_batching_timeout = None

    def use_profiler(self, use_profile):
        """Toggle writing each batch's profile string to stderr in ``_run``."""
        self._server_use_profile = use_profile

    def init_client(self, client_type, client_config, server_endpoints,
                    fetch_names):
        """Create and connect the serving client used by ``process``.

        Also stores ``fetch_names`` as the names fetched by ``process``.

        Returns:
            The connected client, or None when this Op has no serving
            endpoints.

        Raises:
            ValueError: ``client_type`` is neither 'brpc' nor 'grpc'.
        """
        if not self.with_serving:
            _LOGGER.info("Op({}) no client".format(self.name))
            return None
        _LOGGER.info("Op({}) service endpoints: {}".format(self.name,
                                                           server_endpoints))
        _LOGGER.debug("Op({}) fetch_names: {}".format(self.name, fetch_names))
        if client_type == 'brpc':
            _LOGGER.debug("Op({}) client_config: {}".format(self.name,
                                                            client_config))
            client = Client()
            client.load_client_config(client_config)
        elif client_type == 'grpc':
            client = MultiLangClient()
        else:
            raise ValueError("unknow client type: {}".format(client_type))
        client.connect(server_endpoints)
        self._fetch_names = fetch_names
        return client

    def get_input_ops(self):
        """Return the list of upstream Ops."""
        return self._input_ops

    def set_input_ops(self, ops):
        """Set the upstream Ops.

        ``ops`` may be None, a single Op, or a list of Ops.

        Raises:
            TypeError: an element is not an Op.
        """
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register as a consumer of ``channel`` and make it the input channel.

        Raises:
            TypeError: ``channel`` is not a ThreadChannel/ProcessChannel.
        """
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def clean_input_channel(self):
        """Forget the input channel (the channel itself is not modified)."""
        self._input = None

    def _get_input_channel(self):
        return self._input

    def add_output_channel(self, channel):
        """Register as a producer of ``channel`` and append it to the outputs.

        Raises:
            TypeError: ``channel`` is not a ThreadChannel/ProcessChannel.
        """
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def clean_output_channels(self):
        """Forget all output channels (the channels are not modified)."""
        self._outputs = []

    def _get_output_channels(self):
        return self._outputs

    def preprocess(self, input_dicts):
        """Return the single upstream input; override for multiple inputs.

        Args:
            input_dicts: dict of upstream key -> parsed data for one request
                (presumably keyed by producer Op name; verify).

        Raises:
            NotImplementedError: ``input_dicts`` does not hold exactly one
                entry.
        """
        # multiple previous Op
        if len(input_dicts) != 1:
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this func.'
            )

        (_, input_dict), = input_dicts.items()
        return input_dict

    def process(self, feed_batch):
        """Run the serving client on a batch.

        Args:
            feed_batch: list of preprocessed feed data, one element per
                request.

        Returns:
            Whatever ``self.client.predict`` returns. ``_run_process``
            expects a dict of name -> batched values, or None on failure.

        Raises:
            NotImplementedError: ``feed_batch`` fails
                ``ChannelData.check_batch_npdata``.
        """
        err, err_info = ChannelData.check_batch_npdata(feed_batch)
        if err != 0:
            raise NotImplementedError(
                "{} Please override preprocess func.".format(err_info))
        call_result = self.client.predict(
            feed=feed_batch, fetch=self._fetch_names)
        return call_result

    def postprocess(self, input_dict, fetch_dict):
        """Build one request's output dict; the default returns ``fetch_dict``.

        The result must be a dict, which ``_run_postprocess`` checks.
        """
        return fetch_dict

    def _parse_channeldata(self, channeldata_dict):
        """Unpack one request's channeldata dict taken from the input channel.

        Returns:
            A tuple ``(data_id, error_channeldata, parsed_data,
            client_need_profile, profile_set)``:
            - ``data_id`` and ``client_need_profile`` are read from the
              first entry.
            - ``error_channeldata`` is the first entry whose ecode is not
              OK, or None.
            - ``parsed_data`` maps each key to ``data.parse()``, stopping at
              the first error.
            - ``profile_set`` is the union of the entries' profile data when
              profiling is requested.
        """
        data_id, error_channeldata = None, None
        client_need_profile, profile_set = False, set()
        parsed_data = {}

        key = list(channeldata_dict.keys())[0]
        data_id = channeldata_dict[key].id
        client_need_profile = channeldata_dict[key].client_need_profile

        for name, data in channeldata_dict.items():
            if data.ecode != ChannelDataEcode.OK.value:
                error_channeldata = data
                break
            parsed_data[name] = data.parse()
            if client_need_profile:
                profile_set |= data.profile_data_set
        return (data_id, error_channeldata, parsed_data, client_need_profile,
                profile_set)

    def _push_to_output_channels(self,
                                 data,
                                 channels,
                                 name=None,
                                 profile_str=None,
                                 client_need_profile=False,
                                 profile_set=None):
        """Push ``data`` to every channel under producer name ``name``.

        ``name`` defaults to this Op's name. When the client asked for
        profiling and ``profile_set`` is given, ``profile_str`` (if any) is
        added to ``profile_set`` in place. The set is then attached with
        ``data.add_profile``.
        """
        if name is None:
            name = self.name

        # add profile into channeldata
        if client_need_profile and profile_set is not None:
            if profile_str is not None:
                profile_set.add(profile_str)
            data.add_profile(profile_set)

        for channel in channels:
            channel.push(data, name)

    def _push_error_channeldata_dict(self, err_channeldata_dict,
                                     output_channels, need_profile_dict,
                                     profile_dict):
        """Forward each request's error ChannelData to ``output_channels``.

        Exceptions raised while pushing (e.g. ChannelStopError) propagate to
        the caller.
        """
        for data_id, err_channeldata in err_channeldata_dict.items():
            self._push_to_output_channels(
                data=err_channeldata,
                channels=output_channels,
                client_need_profile=need_profile_dict[data_id],
                profile_set=profile_dict[data_id])

    def start_with_process(self, client_type):
        """Start ``self.concurrency`` worker processes running ``_run``.

        Returns:
            The list of started multiprocessing.Process objects.
        """
        processes = []
        for concurrency_idx in range(self.concurrency):
            p = multiprocessing.Process(
                target=self._run,
                args=(concurrency_idx, self._get_input_channel(),
                      self._get_output_channels(), client_type, False))
            p.start()
            processes.append(p)
        return processes

    def start_with_thread(self, client_type):
        """Start ``self.concurrency`` worker threads running ``_run``.

        Returns:
            The list of started threading.Thread objects.
        """
        threads = []
        for concurrency_idx in range(self.concurrency):
            t = threading.Thread(
                target=self._run,
                args=(concurrency_idx, self._get_input_channel(),
                      self._get_output_channels(), client_type, True))
            t.start()
            threads.append(t)
        return threads

    def init_op(self):
        """User hook for extra initialization; called after the client exists."""
        pass

    def _run_preprocess(self, parsed_data_dict, log_func):
        """Run ``preprocess`` on every request of the batch.

        Returns:
            A tuple ``(preped_data_dict, err_channeldata_dict)``, both keyed
            by data_id. A request whose preprocess raises gets an error
            ChannelData with ecode NOT_IMPLEMENTED, TYPE_ERROR or UNKNOW,
            depending on the exception type.
        """
        _LOGGER.debug(log_func("try to run preprocess"))
        preped_data_dict = {}
        err_channeldata_dict = {}
        for data_id, parsed_data in parsed_data_dict.items():
            try:
                preped_data = self.preprocess(parsed_data)
            except NotImplementedError as e:
                # preprocess function not implemented
                ecode, err = ChannelDataEcode.NOT_IMPLEMENTED.value, e
            except TypeError as e:
                # Error type in channeldata.datatype
                ecode, err = ChannelDataEcode.TYPE_ERROR.value, e
            except Exception as e:
                ecode, err = ChannelDataEcode.UNKNOW.value, e
            else:
                preped_data_dict[data_id] = preped_data
                continue
            error_info = log_func("preprocess data[{}] failed: {}".format(
                data_id, err))
            # Logged for every failure kind, including NOT_IMPLEMENTED.
            _LOGGER.error(error_info)
            err_channeldata_dict[data_id] = ChannelData(
                ecode=ecode, error_info=error_info, data_id=data_id)
        _LOGGER.debug(log_func("succ run preprocess"))
        return preped_data_dict, err_channeldata_dict

    def _run_process(self, preped_data_dict, log_func):
        """Run ``process`` once on the whole preprocessed batch.

        Without serving endpoints the input dict is returned as is.
        Otherwise ``process`` is called on the list of feeds. When
        ``self._timeout`` > 0 it runs under ``func_timeout``, and only
        timeouts are retried (up to ``self._retry`` attempts). The
        name -> batched-values result is then split back per data_id by
        position.

        Returns:
            A tuple ``(midped_data_dict, err_channeldata_dict)``. If the
            batch fails, every data_id gets the same error ChannelData.
        """
        _LOGGER.debug(log_func("try to run process"))
        midped_data_dict = {}
        err_channeldata_dict = {}
        if self.with_serving:
            data_ids = list(preped_data_dict.keys())
            feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
            midped_batch = None
            ecode = ChannelDataEcode.OK.value
            error_info = None
            if self._timeout <= 0:
                try:
                    midped_batch = self.process(feed_batch)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = log_func("process batch failed: {}".format(e))
                    _LOGGER.error(error_info)
            else:
                for i in range(self._retry):
                    try:
                        midped_batch = func_timeout.func_timeout(
                            self._timeout, self.process, args=(feed_batch, ))
                    except func_timeout.FunctionTimedOut as e:
                        if i + 1 >= self._retry:
                            ecode = ChannelDataEcode.TIMEOUT.value
                            error_info = log_func(e)
                            _LOGGER.error(error_info)
                        else:
                            _LOGGER.warning(
                                log_func("timeout, retry({}/{})"
                                         .format(i + 1, self._retry)))
                    except Exception as e:
                        # Non-timeout errors are not retried.
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = log_func("process batch failed: {}".format(
                            e))
                        _LOGGER.error(error_info)
                        break
                    else:
                        break
            if ecode != ChannelDataEcode.OK.value:
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ecode, error_info=error_info, data_id=data_id)
            elif midped_batch is None:
                # op client return None
                error_info = log_func(
                    "predict failed. pls check the server side.")
                _LOGGER.error(error_info)
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ChannelDataEcode.CLIENT_ERROR.value,
                        error_info=error_info,
                        data_id=data_id)
            else:
                # transform np format to dict format: row idx of every
                # fetched value belongs to data_ids[idx]
                for idx, data_id in enumerate(data_ids):
                    midped_data_dict[data_id] = {
                        k: v[idx]
                        for k, v in midped_batch.items()
                    }
        else:
            midped_data_dict = preped_data_dict
        _LOGGER.debug(log_func("succ run process"))
        return midped_data_dict, err_channeldata_dict

    def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func):
        """Run ``postprocess`` per request and wrap each result in ChannelData.

        A dict that passes ``ChannelData.check_npdata`` becomes
        CHANNEL_NPDATA; any other dict becomes DICT. Non-dict results and
        exceptions become UNKNOW error ChannelData.

        Returns:
            A tuple ``(postped_data_dict, err_channeldata_dict)``, both keyed
            by data_id.
        """
        _LOGGER.debug(log_func("try to run postprocess"))
        postped_data_dict = {}
        err_channeldata_dict = {}
        for data_id, midped_data in midped_data_dict.items():
            try:
                postped_data = self.postprocess(parsed_data_dict[data_id],
                                                midped_data)
            except Exception as e:
                error_info = log_func("postprocess data[{}] failed: {}"
                                      .format(data_id, e))
                _LOGGER.error(error_info)
                err_channeldata_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
                continue

            if not isinstance(postped_data, dict):
                error_info = log_func(
                    "output of postprocess funticon must be "
                    "dict type, but get {}".format(type(postped_data)))
                _LOGGER.error(error_info)
                err_channeldata_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
                continue

            err, _ = ChannelData.check_npdata(postped_data)
            if err == 0:
                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
            else:
                output_data = ChannelData(
                    ChannelDataType.DICT.value,
                    dictdata=postped_data,
                    data_id=data_id)
            postped_data_dict[data_id] = output_data
        _LOGGER.debug(log_func("succ run postprocess"))
        return postped_data_dict, err_channeldata_dict

    def _auto_batching_generator(self, input_channel, op_name, batch_size,
                                 timeout, log_func):
        """Endlessly yield non-empty batches (lists) of channeldata dicts.

        Each batch holds up to ``batch_size`` items taken with
        ``input_channel.front(op_name, ...)``. With a ``timeout`` (seconds),
        a batch is closed once its deadline passes or ``front`` raises
        ChannelTimeoutError, even if it is not full. An empty batch is never
        yielded. Other exceptions from ``front`` (e.g. ChannelStopError,
        which ``_run`` handles) propagate out of the generator.
        """
        while True:
            batch = []
            _LOGGER.debug(
                log_func("Auto-batching expect size: {}; timeout: {}".format(
                    batch_size, timeout)))
            while len(batch) == 0:
                endtime = None
                if timeout is not None:
                    endtime = _time() + timeout
                for _ in range(batch_size):
                    try:
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug(log_func("Auto-batching timeout"))
                                break
                            # Wait only for what is left of this batch's
                            # deadline. Waiting the full timeout on every
                            # item could stretch a single batch to
                            # batch_size * timeout.
                            channeldata_dict = input_channel.front(op_name,
                                                                   remaining)
                        else:
                            channeldata_dict = input_channel.front(op_name)
                        batch.append(channeldata_dict)
                    except ChannelTimeoutError:
                        _LOGGER.debug(log_func("Auto-batching timeout"))
                        break
            _LOGGER.debug(
                log_func("Auto-batching actual size: {}".format(len(batch))))
            yield batch

    def _parse_channeldata_batch(self, batch, output_channels):
        """Parse every channeldata dict in ``batch``.

        Requests that already carry an error from a predecessor Op are
        pushed to ``output_channels`` right away and left out of the
        result.

        Returns:
            A tuple ``(parsed_data_dict, need_profile_dict, profile_dict)``,
            all keyed by data_id.
        """
        parsed_data_dict = {}
        need_profile_dict = {}
        profile_dict = {}
        for channeldata_dict in batch:
            (data_id, error_channeldata, parsed_data,
                    client_need_profile, profile_set) = \
                            self._parse_channeldata(channeldata_dict)
            if error_channeldata is None:
                parsed_data_dict[data_id] = parsed_data
                need_profile_dict[data_id] = client_need_profile
                profile_dict[data_id] = profile_set
            else:
                # error data in predecessor Op
                # (error_channeldata with profile info)
                self._push_to_output_channels(error_channeldata,
                                              output_channels)

        return parsed_data_dict, need_profile_dict, profile_dict

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op):
        """Worker loop run by each thread/process of this Op.

        Initializes the worker; on failure the whole process exits with -1.
        Then it repeatedly takes a batch from ``input_channel``, runs
        preprocess/process/postprocess and pushes results and per-request
        errors to ``output_channels``. It stops after finalizing once a
        ChannelStopError is seen.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)

        def log(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        def on_channel_stop():
            _LOGGER.debug(log("channel stop."))
            self._finalize(is_thread_op)

        # init op
        profiler = None
        try:
            profiler = self._initialize(is_thread_op, client_type,
                                        concurrency_idx)
        except Exception as e:
            # Keep the traceback: os._exit() ends the process immediately.
            _LOGGER.error(log("init op failed: {}".format(e)), exc_info=True)
            os._exit(-1)
        _LOGGER.info(log("succ init"))

        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=self._batch_size,
            timeout=self._auto_batching_timeout,
            log_func=log)

        while True:
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                on_channel_stop()
                break

            # parse channeldata batch
            try:
                parsed_data_dict, need_profile_dict, profile_dict = \
                    self._parse_channeldata_batch(channeldata_dict_batch,
                                                  output_channels)
            except ChannelStopError:
                on_channel_stop()
                break
            if not parsed_data_dict:
                # data in the whole batch is all error data
                continue

            # preprocess
            profiler.record("prep#{}_0".format(op_info_prefix))
            preped_data_dict, err_channeldata_dict = self._run_preprocess(
                parsed_data_dict, log)
            profiler.record("prep#{}_1".format(op_info_prefix))
            try:
                self._push_error_channeldata_dict(
                    err_channeldata_dict, output_channels, need_profile_dict,
                    profile_dict)
            except ChannelStopError:
                on_channel_stop()
                break
            if not preped_data_dict:
                # every request failed in preprocess: do not call process
                # with an empty batch
                continue

            # process
            profiler.record("midp#{}_0".format(op_info_prefix))
            midped_data_dict, err_channeldata_dict = self._run_process(
                preped_data_dict, log)
            profiler.record("midp#{}_1".format(op_info_prefix))
            try:
                self._push_error_channeldata_dict(
                    err_channeldata_dict, output_channels, need_profile_dict,
                    profile_dict)
            except ChannelStopError:
                on_channel_stop()
                break
            if not midped_data_dict:
                continue

            # postprocess
            profiler.record("postp#{}_0".format(op_info_prefix))
            postped_data_dict, err_channeldata_dict = self._run_postprocess(
                parsed_data_dict, midped_data_dict, log)
            profiler.record("postp#{}_1".format(op_info_prefix))
            try:
                self._push_error_channeldata_dict(
                    err_channeldata_dict, output_channels, need_profile_dict,
                    profile_dict)
            except ChannelStopError:
                on_channel_stop()
                break
            if not postped_data_dict:
                continue

            # push data to channel (if run succ)
            try:
                profile_str = profiler.gen_profile_str()
                for data_id, postped_data in postped_data_dict.items():
                    if self._server_use_profile:
                        sys.stderr.write(profile_str)
                    self._push_to_output_channels(
                        data=postped_data,
                        channels=output_channels,
                        profile_str=profile_str,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                on_channel_stop()
                break

    def _initialize(self, is_thread_op, client_type, concurrency_idx):
        """Prepare this worker: create the client and run ``init_op``.

        Thread workers share one Op object. Their setup therefore runs only
        once, guarded by ``_for_init_op_lock``, and ``concurrency_idx`` is
        set to None. Process workers run the setup with their own
        ``concurrency_idx``.

        Returns:
            A new, enabled TimeProfiler for this worker.
        """
        if is_thread_op:
            with self._for_init_op_lock:
                if not self._succ_init_op:
                    # for the threaded version of Op, each thread cannot get its concurrency_idx
                    self.concurrency_idx = None
                    # init client
                    self.client = self.init_client(
                        client_type, self._client_config,
                        self._server_endpoints, self._fetch_names)
                    # user defined
                    self.init_op()
                    self._succ_init_op = True
                    self._succ_close_op = False
        else:
            self.concurrency_idx = concurrency_idx
            # init client
            self.client = self.init_client(client_type, self._client_config,
                                           self._server_endpoints,
                                           self._fetch_names)
            # user defined
            self.init_op()

        # use a separate TimeProfiler per thread or process
        profiler = TimeProfiler()
        profiler.enable(True)
        return profiler

    def _finalize(self, is_thread_op):
        """Release the shared client once for thread workers.

        Does nothing for process workers.
        """
        if is_thread_op:
            with self._for_close_op_lock:
                if not self._succ_close_op:
                    self._profiler = None
                    self.client = None
                    self._succ_init_op = False
                    self._succ_close_op = True

    def _log(self, info):
        """Prefix ``info`` with this Op's name."""
        return "{} {}".format(self.name, info)


class RequestOp(Op):
    """Entry Op of the pipeline, named "@G".

    RequestOp does not run preprocess, process or postprocess; it only turns
    an incoming request into a dict via ``unpack_request_package``.
    """

    def __init__(self):
        # PipelineService.name = "@G"
        super(RequestOp, self).__init__(name="@G", input_ops=[])
        # init op
        try:
            self.init_op()
        except Exception as e:
            # Keep the traceback: os._exit() ends the process immediately.
            _LOGGER.error(
                "Op(Request) init op failed: {}".format(e), exc_info=True)
            os._exit(-1)

    def unpack_request_package(self, request):
        """Build a dict from the request's parallel ``key``/``value`` fields.

        Each value is passed through ``eval``. This is presumably so that
        literals and ndarray reprs such as ``array([...])`` (resolvable via
        this module's ``from numpy import *``) become Python objects again.
        A value that fails to evaluate is kept as sent.

        Raises:
            IndexError: ``request.value`` is shorter than ``request.key``.
        """
        dictdata = {}
        for idx, key in enumerate(request.key):
            data = request.value[idx]
            try:
                # SECURITY(review): eval() executes arbitrary expressions
                # taken straight from the client request, which is remote
                # code execution if untrusted clients can reach the service.
                # It also sees this function's locals (e.g. the value "idx"
                # evaluates to the loop index). Replace it with a safe
                # decoder (ast.literal_eval plus explicit ndarray handling)
                # once the wire format allows it.
                data = eval(data)
            except Exception:
                # Deliberate best effort: not an evaluable expression
                # (e.g. a plain string), so keep the raw value.
                pass
            dictdata[key] = data
        return dictdata


class ResponseOp(Op):
    """ResponseOp does not run preprocess, process, or postprocess.

    It is the graph's exit op (named "@R") and packs the final ChannelData
    into a ``pipeline_service_pb2.Response`` via ``pack_response_package``.
    """

    def __init__(self, input_ops):
        super(ResponseOp, self).__init__(name="@R", input_ops=input_ops)
        # init op
        try:
            self.init_op()
        except Exception as e:
            # The process exits right after this, so keep the traceback in
            # the log; otherwise only the exception message survives.
            _LOGGER.error(
                "Op(ResponseOp) init op failed: {}".format(e), exc_info=True)
            os._exit(-1)

    def pack_response_package(self, channeldata):
        """Convert ``channeldata`` into a Response message.

        The ecode is always copied.
        - Non-OK ecode: only ``error_info`` is copied.
        - OK ecode with CHANNEL_NPDATA: each value is sent as its ``repr()``
          string.
        - OK ecode with DICT: each value must already be a ``str``; the first
          non-str value sets TYPE_ERROR and stops packing.
        - Any other datatype: sets TYPE_ERROR.
        """
        resp = pipeline_service_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
                feed = channeldata.parse()
                # ndarray to string:
                # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
                # Disable "..." truncation so the full array is serialized.
                # threshold=np.nan is rejected with ValueError by modern
                # numpy; sys.maxsize is the supported "never truncate" value.
                # Note: this still mutates numpy's process-global print state.
                np.set_printoptions(threshold=sys.maxsize)
                for name, var in feed.items():
                    resp.value.append(repr(var))
                    resp.key.append(name)
            elif channeldata.datatype == ChannelDataType.DICT.value:
                feed = channeldata.parse()
                for name, var in feed.items():
                    if not isinstance(var, str):
                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                        resp.error_info = self._log(
                            "fetch var type must be str({}).".format(
                                type(var)))
                        # Log like the unknown-datatype branch below does.
                        _LOGGER.error(resp.error_info)
                        break
                    resp.value.append(var)
                    resp.key.append(name)
            else:
                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                resp.error_info = self._log(
                    "Error type({}) in datatype.".format(channeldata.datatype))
                _LOGGER.error(resp.error_info)
        else:
            resp.error_info = channeldata.error_info
        return resp
680 681 682 683 684 685 686


class VirtualOp(Op):
    ''' For connecting two channels. '''

    def __init__(self, name, concurrency=1):
        """Create a pass-through op with no regular input ops.

        Args:
            name: op name forwarded to ``Op``.
            concurrency: worker count forwarded to ``Op``.
        """
        base_kwargs = dict(name=name, input_ops=None, concurrency=concurrency)
        super(VirtualOp, self).__init__(**base_kwargs)
        # Predecessors registered through add_virtual_pred_op(); entries may
        # themselves be VirtualOps and are resolved in _actual_pred_op_names.
        self._virtual_pred_ops = list()

    def add_virtual_pred_op(self, op):
        """Register ``op`` as a predecessor of this virtual op (kept in order)."""
        self._virtual_pred_ops.extend([op])

B
barrierye 已提交
693 694 695 696 697 698 699 700
    def _actual_pred_op_names(self, op):
        """Resolve ``op`` to the names of the non-virtual ops behind it.

        A non-virtual op yields ``[op.name]``. A VirtualOp is expanded
        depth-first over its registered predecessors, in registration order.
        """
        if isinstance(op, VirtualOp):
            return [
                resolved_name
                for pred in op._virtual_pred_ops
                for resolved_name in self._actual_pred_op_names(pred)
            ]
        return [op.name]

701 702 703 704 705 706
    def add_output_channel(self, channel):
        """Attach ``channel`` as an output of this virtual op.

        The real (non-virtual) ops behind every registered predecessor are
        added to ``channel`` as producers, in registration order, before the
        channel is appended to ``self._outputs``.

        Raises:
            TypeError: if ``channel`` is not a ThreadChannel or ProcessChannel.
        """
        channel_types = (ThreadChannel, ProcessChannel)
        if not isinstance(channel, channel_types):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        # Lazy generator: names are resolved per predecessor and registered
        # as they are produced.
        producer_names = (
            op_name
            for pred_op in self._virtual_pred_ops
            for op_name in self._actual_pred_op_names(pred_op))
        for producer_name in producer_names:
            channel.add_producer(producer_name)
        self._outputs.append(channel)
D
dongdaxiang 已提交
710

711
    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
712
             is_thread_op):
B
barrierye 已提交
713 714 715 716 717 718
        def get_log_func(op_info_prefix):
            def log_func(info_str):
                return "{} {}".format(op_info_prefix, info_str)

            return log_func

719
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
B
barrierye 已提交
720 721 722
        log = get_log_func(op_info_prefix)
        tid = threading.current_thread().ident

B
barrierye 已提交
723 724 725 726
        while True:
            try:
                channeldata_dict = input_channel.front(self.name)
            except ChannelStopError:
B
barriery 已提交
727
                _LOGGER.debug(log("Channel stop."))
B
barrierye 已提交
728
                break
D
dongdaxiang 已提交
729

B
barrierye 已提交
730 731 732 733 734
            try:
                for name, data in channeldata_dict.items():
                    self._push_to_output_channels(
                        data, channels=output_channels, name=name)
            except ChannelStopError:
B
barriery 已提交
735
                _LOGGER.debug(log("Channel stop."))
B
barrierye 已提交
736
                break