#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
from time import time as _time
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
import os
import sys
import numpy as np
from numpy import *

from .proto import pipeline_service_pb2
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
                      ChannelData, ChannelDataType, ChannelStopError,
                      ChannelTimeoutError)
from .util import NameGenerator
from .profiler import TimeProfiler

_LOGGER = logging.getLogger()
_op_name_gen = NameGenerator("Op")


class Op(object):
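    """One stage of a pipeline service.

    An Op pulls ChannelData from its input channel, runs
    preprocess -> process -> postprocess, and pushes the results to its
    output channels. process() talks to a Paddle Serving backend when
    server_endpoints are configured; otherwise the preprocess output is
    passed straight to postprocess. Subclasses usually override preprocess
    and postprocess.

    Hypothetical usage sketch (ExampleOp, read_op and the endpoint below are
    illustrative, not part of this module):

        class ExampleOp(Op):
            def preprocess(self, input_dicts):
                (_, input_dict), = input_dicts.items()
                return {"x": np.array(input_dict["x"])}

            def postprocess(self, input_dict, fetch_dict):
                return {k: str(v) for k, v in fetch_dict.items()}

        example_op = ExampleOp(
            name="example",
            input_ops=[read_op],
            server_endpoints=["127.0.0.1:9393"],
            fetch_list=["score"],
            client_config="serving_client_conf.prototxt",
            concurrency=2,
            batch_size=1)
    """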
    def __init__(self,
                 name=None,
                 input_ops=[],
                 server_endpoints=[],
                 fetch_list=[],
                 client_config=None,
                 concurrency=1,
                 timeout=-1,
                 retry=1,
                 batch_size=1,
                 auto_batching_timeout=None):
        if name is None:
            name = _op_name_gen.next()
        self.name = name  # to identify the type of OP, it must be globally unique
        self.concurrency = concurrency  # amount of concurrency
        self.set_input_ops(input_ops)

        self._server_endpoints = server_endpoints
        self.with_serving = False
        if len(self._server_endpoints) != 0:
            self.with_serving = True
        self._client_config = client_config
        self._fetch_names = fetch_list

        self._timeout = timeout
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []

        self._batch_size = batch_size
        self._auto_batching_timeout = auto_batching_timeout
        if self._auto_batching_timeout is not None and self._auto_batching_timeout <= 0:
            self._auto_batching_timeout = None

        self._server_use_profile = False

        # only for multithread
        self._for_init_op_lock = threading.Lock()
        self._for_close_op_lock = threading.Lock()
        self._succ_init_op = False
        self._succ_close_op = False

    def use_profiler(self, use_profile):
        self._server_use_profile = use_profile

    def _profiler_record(self, string):
        if self._profiler is None:
            return
        self._profiler.record(string)

    def init_client(self, client_type, client_config, server_endpoints,
                    fetch_names):
        if not self.with_serving:
            _LOGGER.debug("{} no client".format(self.name))
            return None
        _LOGGER.debug("{} client_config: {}".format(self.name, client_config))
        _LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names))
        if client_type == 'brpc':
            client = Client()
            client.load_client_config(client_config)
        elif client_type == 'grpc':
            client = MultiLangClient()
        else:
            raise ValueError("unknown client type: {}".format(client_type))
        client.connect(server_endpoints)
        self._fetch_names = fetch_names
        return client

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def clean_input_channel(self):
        self._input = None

    def _get_input_channel(self):
        return self._input

    def add_output_channel(self, channel):
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def clean_output_channels(self):
        self._outputs = []

    def _get_output_channels(self):
        return self._outputs

    def preprocess(self, input_dicts):
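        """Build the feed data for process() from the parsed outputs of the
        previous Ops.

        The default implementation only supports a single predecessor and
        returns its parsed dict unchanged; override it to merge multiple
        inputs or to convert the data into a numpy feed dict.
        """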
        # multiple previous Op
        if len(input_dicts) != 1:
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this func.'
            )

        (_, input_dict), = input_dicts.items()
        return input_dict

    def process(self, feed_batch):
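        """Call the serving client with a batch of numpy feed dicts and
        return the batched fetch results. Raises NotImplementedError when
        the feed batch is not numpy data (override preprocess in that case).
        """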
        err, err_info = ChannelData.check_batch_npdata(feed_batch)
        if err != 0:
            raise NotImplementedError(
                "{} Please override preprocess func.".format(err_info))
        call_result = self.client.predict(
            feed=feed_batch, fetch=self._fetch_names)
        _LOGGER.debug(self._log("get call_result"))
        return call_result

    def postprocess(self, input_dict, fetch_dict):
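        """Post-process one request's fetch results; must return a dict.
        The default implementation returns fetch_dict unchanged.
        """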
        return fetch_dict

    def _parse_channeldata(self, channeldata_dict):
        data_id, error_channeldata = None, None
        client_need_profile, profile_set = False, set()
        parsed_data = {}

        key = list(channeldata_dict.keys())[0]
        data_id = channeldata_dict[key].id
        client_need_profile = channeldata_dict[key].client_need_profile

        for name, data in channeldata_dict.items():
            if data.ecode != ChannelDataEcode.OK.value:
                error_channeldata = data
                break
            parsed_data[name] = data.parse()
            if client_need_profile:
                profile_set |= data.profile_data_set
        return (data_id, error_channeldata, parsed_data, client_need_profile,
                profile_set)

    def _push_to_output_channels(self,
                                 data,
                                 channels,
                                 name=None,
                                 client_need_profile=False,
                                 profile_set=None):
        if name is None:
            name = self.name
        self._add_profile_into_channeldata(data, client_need_profile,
                                           profile_set)
        for channel in channels:
            channel.push(data, name)

    def _add_profile_into_channeldata(self, data, client_need_profile,
                                      profile_set):
        profile_str = self._profiler.gen_profile_str()
        if self._server_use_profile:
            sys.stderr.write(profile_str)

        if client_need_profile and profile_set is not None:
            profile_set.add(profile_str)
            data.add_profile(profile_set)

    def start_with_process(self, client_type):
        processes = []
        for concurrency_idx in range(self.concurrency):
            p = multiprocessing.Process(
                target=self._run,
                args=(concurrency_idx, self._get_input_channel(),
                      self._get_output_channels(), client_type, False))
            p.start()
            processes.append(p)
        return processes

    def start_with_thread(self, client_type):
        threads = []
        for concurrency_idx in range(self.concurrency):
            t = threading.Thread(
                target=self._run,
                args=(concurrency_idx, self._get_input_channel(),
                      self._get_output_channels(), client_type, True))
            t.start()
            threads.append(t)
        return threads

    def init_op(self):
        pass

    def _run_preprocess(self, parsed_data_dict, log_func):
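        """Run preprocess() for every request in the batch; exceptions are
        converted into error ChannelData keyed by data_id.
        """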
        preped_data_dict = {}
        err_channeldata_dict = {}
        for data_id, parsed_data in parsed_data_dict.items():
            preped_data, error_channeldata = None, None
            try:
                preped_data = self.preprocess(parsed_data)
            except NotImplementedError as e:
                # preprocess function not implemented
                error_info = log_func(e)
                _LOGGER.error(error_info)
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
                    error_info=error_info,
                    data_id=data_id)
            except TypeError as e:
                # Error type in channeldata.datatype
                error_info = log_func(e)
                _LOGGER.error(error_info)
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.TYPE_ERROR.value,
                    error_info=error_info,
                    data_id=data_id)
            except Exception as e:
                error_info = log_func(e)
                _LOGGER.error(error_info)
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
            if error_channeldata is not None:
                err_channeldata_dict[data_id] = error_channeldata
            else:
                preped_data_dict[data_id] = preped_data
        return preped_data_dict, err_channeldata_dict

    def _run_process(self, preped_data_dict, log_func):
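        """Run process() on the whole batch, honoring the timeout/retry
        settings, then split the batched result back into per-request dicts.
        """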
        midped_data_dict = {}
        err_channeldata_dict = {}
        if self.with_serving:
            data_ids = preped_data_dict.keys()
            feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
            midped_batch = None
            ecode = ChannelDataEcode.OK.value
            if self._timeout <= 0:
                try:
                    midped_batch = self.process(feed_batch)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = log_func(e)
                    _LOGGER.error(error_info)
            else:
                for i in range(self._retry):
                    try:
                        midped_batch = func_timeout.func_timeout(
                            self._timeout, self.process, args=(feed_batch, ))
                    except func_timeout.FunctionTimedOut as e:
                        if i + 1 >= self._retry:
                            ecode = ChannelDataEcode.TIMEOUT.value
                            error_info = log_func(e)
                            _LOGGER.error(error_info)
                        else:
                            _LOGGER.warn(
                                log_func("timeout, retry({})".format(i + 1)))
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = log_func(e)
                        _LOGGER.error(error_info)
                        break
                    else:
                        break
            if ecode != ChannelDataEcode.OK.value:
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                            ecode=ecode,
                            error_info=error_info,
                            data_id=data_id)
            elif midped_batch is None:
                # op client return None
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                            ecode=ChannelDataEcode.CLIENT_ERROR.value,
                            error_info=log_func(
                                "predict failed. pls check the server side."),
                            data_id=data_id)
            else:
                # transform np format to dict format
                for idx, data_id in enumerate(data_ids):
                    midped_data_dict[data_id] = {
                        k: v[idx] for k, v in midped_batch.items()
                    }
        else:
            midped_data_dict = preped_data_dict
        return midped_data_dict, err_channeldata_dict

    def _run_postprocess(self, parsed_data_dict, midped_data_dict, log_func):
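        """Run postprocess() per request and wrap each result into numpy or
        dict ChannelData; failures become error ChannelData.
        """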
        postped_data_dict = {}
        err_channeldata_dict = {}
        for data_id, midped_data in midped_data_dict.items():
            postped_data, err_channeldata = None, None
            try:
                postped_data = self.postprocess(
                        parsed_data_dict[data_id], midped_data)
            except Exception as e:
                error_info = log_func(e)
                _LOGGER.error(error_info)
                err_channeldata = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
            if err_channeldata is not None:
                err_channeldata_dict[data_id] = err_channeldata
                continue
            else:
                if not isinstance(postped_data, dict):
                    error_info = log_func("output of postprocess function must be " \
                        "dict type, but got {}".format(type(postped_data)))
                    _LOGGER.error(error_info)
                    err_channeldata = ChannelData(
                        ecode=ChannelDataEcode.UNKNOW.value,
                        error_info=error_info,
                        data_id=data_id)
                    err_channeldata_dict[data_id] = err_channeldata
                    continue

                output_data = None
                err, _ = ChannelData.check_npdata(postped_data)
                if err == 0:
                    output_data = ChannelData(
                        ChannelDataType.CHANNEL_NPDATA.value,
                        npdata=postped_data,
                        data_id=data_id)
                else:
                    output_data = ChannelData(
                        ChannelDataType.DICT.value,
                        dictdata=postped_data,
                        data_id=data_id)
                postped_data_dict[data_id] = output_data
        return postped_data_dict, err_channeldata_dict
    
    def _auto_batching_generator(self, input_channel, op_name, batch_size, timeout):
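        """Collect up to batch_size channeldata dicts from the input channel
        and yield them as one batch; a partial batch is yielded when
        auto-batching times out.
        """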
        while True:
            batch = []
            while len(batch) == 0:
                endtime = None
                if timeout is not None:
                    endtime = _time() + timeout
                for idx in range(batch_size):
                    try:
                        channeldata_dict = None
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.info(
                                    "[{}] auto-batching timeout".format(op_name))
                                break
                            channeldata_dict = input_channel.front(op_name, timeout)
                        else:
                            channeldata_dict = input_channel.front(op_name)
                        batch.append(channeldata_dict)
                    except ChannelTimeoutError:
                        _LOGGER.info(
                            "[{}] auto-batching timeout".format(op_name))
                        break
            yield batch

    def _parse_channeldata_batch(self, batch, output_channels):
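        """Parse a batch of channeldata dicts; error items are pushed directly
        to the output channels, valid ones are returned keyed by data_id
        together with their profiling flags.
        """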
        parsed_data_dict = {}
        need_profile_dict = {}
        profile_dict = {}
        for channeldata_dict in batch:
            (data_id, error_channeldata, parsed_data,
                    client_need_profile, profile_set) = \
                            self._parse_channeldata(channeldata_dict)
            if error_channeldata is None:
                parsed_data_dict[data_id] = parsed_data
                need_profile_dict[data_id] = client_need_profile
                profile_dict[data_id] = profile_set
            else:
                # error data in predecessor Op
                # (error_channeldata with profile info)
                self._push_to_output_channels(
                        error_channeldata, output_channels)

        return parsed_data_dict, need_profile_dict, profile_dict
   
    def _run(self, concurrency_idx, input_channel, output_channels,
             client_type, is_thread_op):
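        """Worker loop of one concurrency unit: initialize the client and
        profiler, then repeatedly fetch a batch, run
        preprocess/process/postprocess, and push the results (or error
        ChannelData) to the output channels until the channel is stopped.
        """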
        def get_log_func(op_info_prefix):
            def log_func(info_str):
                return "{} {}".format(op_info_prefix, info_str)
            return log_func

        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = get_log_func(op_info_prefix)
        preplog = get_log_func(op_info_prefix + "(prep)")
        midplog = get_log_func(op_info_prefix + "(midp)")
        postplog = get_log_func(op_info_prefix + "(postp)")
        tid = threading.current_thread().ident

        # init op
        try:
            self._initialize(is_thread_op, client_type, concurrency_idx)
        except Exception as e:
            _LOGGER.error(log(e))
            os._exit(-1)

        batch_generator = self._auto_batching_generator(
                input_channel=input_channel,
                op_name=self.name,
                batch_size=self._batch_size,
                timeout=self._auto_batching_timeout)

        while True:
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                self._finalize(is_thread_op)
                break

            # parse channeldata batch
            try:
                parsed_data_dict, need_profile_dict, profile_dict \
                        = self._parse_channeldata_batch(
                                channeldata_dict_batch, output_channels)
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                self._finalize(is_thread_op)
                break
            if len(parsed_data_dict) == 0:
                # data in the whole batch is all error data
                continue

            # preprocess
            self._profiler_record("prep#{}_0".format(op_info_prefix))
            preped_data_dict, err_channeldata_dict \
                    = self._run_preprocess(parsed_data_dict, preplog)
            self._profiler_record("prep#{}_1".format(op_info_prefix))
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        err_channeldata,
                        output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                self._finalize(is_thread_op)
                break
            if len(preped_data_dict) == 0:
                continue

            # process
            self._profiler_record("midp#{}_0".format(op_info_prefix))
            midped_data_dict, err_channeldata_dict \
                    = self._run_process(preped_data_dict, midplog)
            self._profiler_record("midp#{}_1".format(op_info_prefix))
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                            err_channeldata,
                            output_channels,
                            client_need_profile=need_profile_dict[data_id],
                            profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                self._finalize(is_thread_op)
                break
            if len(midped_data_dict) == 0:
                continue

            # postprocess
            self._profiler_record("postp#{}_0".format(op_info_prefix))
            postped_data_dict, err_channeldata_dict \
                    = self._run_postprocess(
                            parsed_data_dict, midped_data_dict, postplog)
            self._profiler_record("postp#{}_1".format(op_info_prefix))
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                            err_channeldata,
                            output_channels,
                            client_need_profile=need_profile_dict[data_id],
                            profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                self._finalize(is_thread_op)
                break
            if len(postped_data_dict) == 0:
                continue

            # push data to channel (if run succ)
            try:
                for data_id, postped_data in postped_data_dict.items():
                    self._push_to_output_channels(
                            postped_data,
                            output_channels,
                            client_need_profile=need_profile_dict[data_id],
                            profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                self._finalize(is_thread_op)
                break

    def _initialize(self, is_thread_op, client_type, concurrency_idx):
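        """Create the profiler and serving client. In thread mode the shared
        state is initialized only once (concurrency_idx is not recorded); in
        process mode every worker initializes its own copy and records its
        concurrency_idx.
        """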
        if is_thread_op:
            with self._for_init_op_lock:
                if not self._succ_init_op:
                    # for the threaded version of Op, each thread cannot get its concurrency_idx
                    self.concurrency_idx = None
                    # init profiler
                    self._profiler = TimeProfiler()
                    self._profiler.enable(True)
                    # init client
                    self.client = self.init_client(
                            client_type, self._client_config,
                            self._server_endpoints, self._fetch_names)
                    # user defined
                    self.init_op()
                    self._succ_init_op = True
                    self._succ_close_op = False
        else:
            self.concurrency_idx = concurrency_idx
            # init profiler
            self._profiler = TimeProfiler()
            self._profiler.enable(True)
            # init client
            self.client = self.init_client(
                    client_type, self._client_config,
                    self._server_endpoints,
                    self._fetch_names)
            # user defined
            self.init_op()
 
    def _finalize(self, is_thread_op):
        if is_thread_op:
            with self._for_close_op_lock:
                if not self._succ_close_op:
                    self._profiler = None
                    self.client = None
                    self._succ_init_op = False
                    self._succ_close_op = True

    def _log(self, info):
        return "{} {}".format(self.name, info)


class RequestOp(Op):
    """ RequestOp do not run preprocess, process, postprocess. """

    def __init__(self):
        # PipelineService.name = "@G"
        super(RequestOp, self).__init__(name="@G", input_ops=[])
        # init op
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.error(e)
            os._exit(-1)

    def unpack_request_package(self, request):
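        """Convert the request's parallel key/value lists into a dict; each
        value is passed through eval() when possible and kept as the raw
        string otherwise.
        """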
        dictdata = {}
        for idx, key in enumerate(request.key):
            data = request.value[idx]
            try:
                data = eval(data)
            except Exception as e:
                pass
            dictdata[key] = data
        return dictdata


class ResponseOp(Op):
    """ ResponseOp do not run preprocess, process, postprocess. """

    def __init__(self, input_ops):
        super(ResponseOp, self).__init__(name="@R", input_ops=input_ops)
        # init op
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.error(e)
            os._exit(-1)

    def pack_response_package(self, channeldata):
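        """Pack a ChannelData into a pipeline_service_pb2.Response: numpy
        payloads are serialized via repr(), dict payloads must already hold
        string values, anything else is reported as a type error.
        """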
        resp = pipeline_service_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
                feed = channeldata.parse()
                # ndarray to string:
                # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
                np.set_printoptions(threshold=sys.maxsize)
                for name, var in feed.items():
                    resp.value.append(var.__repr__())
                    resp.key.append(name)
            elif channeldata.datatype == ChannelDataType.DICT.value:
                feed = channeldata.parse()
                for name, var in feed.items():
                    if not isinstance(var, str):
                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                        resp.error_info = self._log(
                            "fetch var type must be str({}).".format(
                                type(var)))
                        break
                    resp.value.append(var)
                    resp.key.append(name)
            else:
                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                resp.error_info = self._log(
                    "Error type({}) in datatype.".format(channeldata.datatype))
                _LOGGER.error(resp.error_info)
        else:
            resp.error_info = channeldata.error_info
        return resp


class VirtualOp(Op):
    ''' For connecting two channels. '''

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, input_ops=None, concurrency=concurrency)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        self._virtual_pred_ops.append(op)

    def _actual_pred_op_names(self, op):
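        """Recursively resolve virtual predecessor Ops to the names of the
        real Ops that actually produce the data.
        """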
        if not isinstance(op, VirtualOp):
            return [op.name]
        names = []
        for x in op._virtual_pred_ops:
            names.extend(self._actual_pred_op_names(x))
        return names

    def add_output_channel(self, channel):
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            for op_name in self._actual_pred_op_names(op):
                channel.add_producer(op_name)
        self._outputs.append(channel)

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op):
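        """Forward every channeldata from the input channel to all output
        channels without any processing.
        """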
        def get_log_func(op_info_prefix):
            def log_func(info_str):
                return "{} {}".format(op_info_prefix, info_str)

            return log_func

        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        log = get_log_func(op_info_prefix)
        tid = threading.current_thread().ident

        while True:
            try:
                channeldata_dict = input_channel.front(self.name)
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                break

            try:
                for name, data in channeldata_dict.items():
                    self._push_to_output_channels(
                        data, channels=output_channels, name=name)
            except ChannelStopError:
                _LOGGER.debug(log("stop."))
                break