operator.py 43.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
B
barriery 已提交
16
import time
17 18 19 20 21 22
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
23
import os
B
barrierye 已提交
24
import sys
25
import collections
B
barrierye 已提交
26
import numpy as np
B
barrierye 已提交
27
from numpy import *
B
barrierye 已提交
28 29 30 31 32 33
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
34

B
barrierye 已提交
35
from .proto import pipeline_service_pb2
B
barrierye 已提交
36
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
B
bug fix  
barriery 已提交
37 38
                      ChannelData, ChannelDataType, ChannelStopError,
                      ChannelTimeoutError)
B
barrierye 已提交
39
from .util import NameGenerator
B
barriery 已提交
40
from .profiler import UnsafeTimeProfiler as TimeProfiler
W
wangjiawei04 已提交
41
from . import local_service_handler
42

43
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
44 45
_op_name_gen = NameGenerator("Op")

D
dongdaxiang 已提交
46 47 48

class Op(object):
    def __init__(self,
B
barrierye 已提交
49
                 name=None,
D
dongdaxiang 已提交
50
                 input_ops=[],
B
barriery 已提交
51 52
                 server_endpoints=None,
                 fetch_list=None,
B
barrierye 已提交
53
                 client_config=None,
B
barriery 已提交
54 55 56 57
                 concurrency=None,
                 timeout=None,
                 retry=None,
                 batch_size=None,
58
                 auto_batching_timeout=None,
W
wangjiawei04 已提交
59
                 local_service_handler=None):
B
barriery 已提交
60
        # In __init__, all the parameters are just saved and Op is not initialized
B
barrierye 已提交
61
        if name is None:
B
barrierye 已提交
62
            name = _op_name_gen.next()
63
        self.name = name  # to identify the type of OP, it must be globally unique
B
barrierye 已提交
64
        self.concurrency = concurrency  # amount of concurrency
B
barrierye 已提交
65
        self.set_input_ops(input_ops)
B
barrierye 已提交
66

W
wangjiawei04 已提交
67
        self._local_service_handler = local_service_handler
B
barriery 已提交
68
        self._server_endpoints = server_endpoints
B
barrierye 已提交
69
        self._fetch_names = fetch_list
B
barriery 已提交
70
        self._client_config = client_config
B
barriery 已提交
71
        self._timeout = timeout
72
        self._retry = max(1, retry)
B
barriery 已提交
73 74 75
        self._batch_size = batch_size
        self._auto_batching_timeout = auto_batching_timeout

76 77
        self._input = None
        self._outputs = []
B
barrierye 已提交
78

B
barriery 已提交
79 80 81 82 83 84 85 86 87
        self._server_use_profile = False
        self._tracer = None

        # only for thread op
        self._for_init_op_lock = threading.Lock()
        self._for_close_op_lock = threading.Lock()
        self._succ_init_op = False
        self._succ_close_op = False

B
barriery 已提交
88 89
    def init_from_dict(self, conf):
        """Fill every still-unset attribute from the config dict *conf*
        and decide the serving mode of this Op.

        Values passed to __init__ take precedence; *conf* supplies the
        rest. Serving mode resolution:
          - explicit server_endpoints (ctor or conf) -> remote service;
          - else a local service handler is built from
            conf["local_service_conf"] (or, if one was passed to
            __init__, it is prepared here and its ports adopted);
          - a missing model_config disables serving for this Op.
        """
        # init op
        if self.concurrency is None:
            self.concurrency = conf["concurrency"]
        if self._retry is None:
            self._retry = conf["retry"]
        if self._fetch_names is None:
            self._fetch_names = conf.get("fetch_list")
        if self._client_config is None:
            self._client_config = conf.get("client_config")

        # timeout is configured in milliseconds; non-positive disables it
        # (sentinel -1), otherwise convert to seconds.
        if self._timeout is None:
            self._timeout = conf["timeout"]
        if self._timeout > 0:
            self._timeout = self._timeout / 1000.0
        else:
            self._timeout = -1

        # auto-batching is meaningless with batch_size == 1 or a
        # non-positive timeout; normalize to None in that case.
        if self._batch_size is None:
            self._batch_size = conf["batch_size"]
        if self._auto_batching_timeout is None:
            self._auto_batching_timeout = conf["auto_batching_timeout"]
        if self._auto_batching_timeout <= 0 or self._batch_size == 1:
            _LOGGER.warning(
                self._log(
                    "Because auto_batching_timeout <= 0 or batch_size == 1,"
                    " set auto_batching_timeout to None."))
            self._auto_batching_timeout = None
        else:
            self._auto_batching_timeout = self._auto_batching_timeout / 1000.0

        if self._server_endpoints is None:
            server_endpoints = conf.get("server_endpoints", [])
            if len(server_endpoints) != 0:
                # remote service
                self.with_serving = True
                self._server_endpoints = server_endpoints
            else:
                if self._local_service_handler is None:
                    local_service_conf = conf.get("local_service_conf")
                    _LOGGER.info("local_service_conf: {}".format(
                        local_service_conf))
                    model_config = local_service_conf.get("model_config")
                    self.client_type = local_service_conf.get("client_type")
                    _LOGGER.info("model_config: {}".format(model_config))
                    if model_config is None:
                        # no model: this Op only runs pre/postprocess
                        self.with_serving = False
                    else:
                        # local rpc service
                        self.with_serving = True
                        if self.client_type == "brpc" or self.client_type == "grpc":
                            service_handler = local_service_handler.LocalServiceHandler(
                                model_config=model_config,
                                workdir=local_service_conf["workdir"],
                                thread_num=local_service_conf["thread_num"],
                                devices=local_service_conf["devices"],
                                mem_optim=local_service_conf["mem_optim"],
                                ir_optim=local_service_conf["ir_optim"])
                            service_handler.prepare_server()  # get fetch_list
                            # NOTE(review): "serivce_ports" is a typo for
                            # "service_ports" (local variable only).
                            serivce_ports = service_handler.get_port_list()
                            self._server_endpoints = [
                                "127.0.0.1:{}".format(p) for p in serivce_ports
                            ]
                            if self._client_config is None:
                                self._client_config = service_handler.get_client_config(
                                )
                            if self._fetch_names is None:
                                self._fetch_names = service_handler.get_fetch_list(
                                )
                        elif self.client_type == "local_predictor":
                            # in-process predictor: client is created here,
                            # no endpoints are needed
                            service_handler = local_service_handler.LocalPredictorServiceHandler(
                                model_config=model_config,
                                workdir=local_service_conf["workdir"],
                                thread_num=local_service_conf["thread_num"],
                                devices=local_service_conf["devices"])
                            service_handler.prepare_server()  # get fetch_list
                            self.local_predictor = service_handler.get_client()
                            if self._client_config is None:
                                self._client_config = service_handler.get_client_config(
                                )
                            if self._fetch_names is None:
                                self._fetch_names = service_handler.get_fetch_list(
                                )
                        self._local_service_handler = service_handler
                else:
                    # user supplied a handler in __init__: prepare it and
                    # adopt its ports/config/fetch list
                    self.with_serving = True
                    self._local_service_handler.prepare_server(
                    )  # get fetch_list
                    serivce_ports = self._local_service_handler.get_port_list()
                    self._server_endpoints = [
                        "127.0.0.1:{}".format(p) for p in serivce_ports
                    ]
                    if self._client_config is None:
                        self._client_config = self._local_service_handler.get_client_config(
                        )
                    if self._fetch_names is None:
                        self._fetch_names = self._local_service_handler.get_fetch_list(
                        )
        else:
            self.with_serving = True

        # RequestOp/ResponseOp are presumably defined later in this file;
        # ordinary Ops log their resolved configuration here.
        if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp):
            _LOGGER.info(
                self._log("\n\tinput_ops: {},"
                          "\n\tserver_endpoints: {}"
                          "\n\tfetch_list: {}"
                          "\n\tclient_config: {}"
                          "\n\tconcurrency: {},"
                          "\n\ttimeout(s): {},"
                          "\n\tretry: {},"
                          "\n\tbatch_size: {},"
                          "\n\tauto_batching_timeout(s): {}".format(
                              ", ".join([op.name for op in self._input_ops
                                         ]), self._server_endpoints,
                              self._fetch_names, self._client_config,
                              self.concurrency, self._timeout, self._retry,
                              self._batch_size, self._auto_batching_timeout)))
B
barriery 已提交
205

206
    def launch_local_rpc_service(self):
        """Start the local rpc server if a handler has been configured."""
        handler = self._local_service_handler
        if handler is None:
            _LOGGER.warning(
                self._log("Failed to launch local rpc"
                          " service: local_service_handler is None."))
            return
        port = handler.get_port_list()
        handler.start_server()
        _LOGGER.info("Op({}) use local rpc service at port: {}".format(
            self.name, port))

B
barriery 已提交
217
    def use_default_auto_batching_config(self):
B
bug fix  
barriery 已提交
218
        if self._batch_size != 1:
219 220
            _LOGGER.warning("Op({}) reset batch_size=1 (original: {})"
                            .format(self.name, self._batch_size))
B
bug fix  
barriery 已提交
221 222
            self._batch_size = 1
        if self._auto_batching_timeout != None:
223
            _LOGGER.warning(
B
barriery 已提交
224 225
                "Op({}) reset auto_batching_timeout=None (original: {})"
                .format(self.name, self._auto_batching_timeout))
B
bug fix  
barriery 已提交
226
            self._auto_batching_timeout = None
B
barriery 已提交
227

B
barrierye 已提交
228
    def use_profiler(self, use_profile):
B
barrierye 已提交
229
        self._server_use_profile = use_profile
230

B
barriery 已提交
231 232 233
    def set_tracer(self, tracer):
        self._tracer = tracer

W
wangjiawei04 已提交
234
    def init_client(self, client_config, server_endpoints):
235
        if self.with_serving == False:
B
barriery 已提交
236
            _LOGGER.info("Op({}) has no client (and it also do not "
237
                         "run the process function)".format(self.name))
B
barrierye 已提交
238
            return None
W
wangjiawei04 已提交
239
        if self.client_type == 'brpc':
B
barrierye 已提交
240 241
            client = Client()
            client.load_client_config(client_config)
W
wangjiawei04 已提交
242
        elif self.client_type == 'grpc':
B
barrierye 已提交
243
            client = MultiLangClient()
W
wangjiawei04 已提交
244 245 246 247
        elif self.client_type == 'local_predictor':
            if self.local_predictor is None:
                raise ValueError("local predictor not yet created")
            client = self.local_predictor
248
        else:
B
barriery 已提交
249
            raise ValueError("Failed to init client: unknow client "
W
wangjiawei04 已提交
250 251 252
                             "type {}".format(self.client_type))
        if self.client_type != "local_predictor":
            client.connect(server_endpoints)
B
barrierye 已提交
253
        return client
254 255 256 257 258 259 260 261 262 263

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
264
                _LOGGER.critical(
B
barriery 已提交
265 266
                    self._log("Failed to set input_ops: input op "
                              "must be Op type, not {}".format(type(op))))
267
                os._exit(-1)
268
            self._input_ops.append(op)
D
dongdaxiang 已提交
269

270 271
    def add_input_channel(self, channel):
        """Attach *channel* as this Op's single input channel and register
        this Op as one of its consumers. Aborts on a non-Channel argument.
        """
        is_channel = isinstance(channel, (ThreadChannel, ProcessChannel))
        if not is_channel:
            _LOGGER.critical(
                self._log("Failed to set input_channel: input "
                          "channel must be Channel type, not {}".format(
                              type(channel))))
            os._exit(-1)
        channel.add_consumer(self.name)
        self._input = channel
D
dongdaxiang 已提交
279

280
    def clean_input_channel(self):
B
barrierye 已提交
281 282 283 284
        self._input = None

    def _get_input_channel(self):
        return self._input
D
dongdaxiang 已提交
285

286 287
    def add_output_channel(self, channel):
        """Append *channel* to this Op's output channels and register this
        Op as one of its producers. Aborts on a non-Channel argument.
        """
        is_channel = isinstance(channel, (ThreadChannel, ProcessChannel))
        if not is_channel:
            _LOGGER.critical(
                self._log("Failed to add output_channel: output channel "
                          "must be Channel type, not {}".format(type(channel))))
            os._exit(-1)
        channel.add_producer(self.name)
        self._outputs.append(channel)
D
dongdaxiang 已提交
294

295
    def clean_output_channels(self):
B
barrierye 已提交
296 297 298 299 300
        self._outputs = []

    def _get_output_channels(self):
        return self._outputs

W
wangjiawei04 已提交
301
    def preprocess(self, input_dicts):
B
barrierye 已提交
302
        # multiple previous Op
B
barrierye 已提交
303
        if len(input_dicts) != 1:
304 305
            _LOGGER.critical(
                self._log(
B
barriery 已提交
306 307
                    "Failed to run preprocess: this Op has multiple previous "
                    "inputs. Please override this func."))
308
            os._exit(-1)
D
dongdaxiang 已提交
309

B
barrierye 已提交
310 311
        (_, input_dict), = input_dicts.items()
        return input_dict
B
barrierye 已提交
312

W
wangjiawei04 已提交
313
    def process(self, feed_batch, fetch_names, typical_logid):
        """Run prediction on a feed batch through this Op's client.

        Args:
            feed_batch: list of feed dicts; must contain numpy data
                (validated by ChannelData.check_batch_npdata, aborts
                otherwise).
            fetch_names: variable names to fetch from the service.
            typical_logid: representative logid used for request logging.

        Returns:
            The client's prediction result, or None when a MultiLangClient
            call failed (non-zero serving_status_code).

        NOTE(review): _run_process() calls self.process(feed_batch,
        typical_logid) with only two arguments, which does not match this
        three-argument signature — confirm which side is current.
        """
        err, err_info = ChannelData.check_batch_npdata(feed_batch)
        if err != 0:
            _LOGGER.critical(
                self._log("Failed to run process: {}. Please override "
                          "preprocess func.".format(err_info)))
            os._exit(-1)
        if self.client_type == "local_predictor":
            # the local predictor takes a single feed dict, not a batch list
            call_result = self.client.predict(
                feed=feed_batch[0],
                fetch=fetch_names,
                batch=True,
                log_id=typical_logid)
        else:
            call_result = self.client.predict(
                feed=feed_batch,
                fetch=fetch_names,
                batch=True,
                log_id=typical_logid)
        if isinstance(self.client, MultiLangClient):
            # grpc client embeds its status in the result dict; strip it
            # and signal failure to the caller with None
            if call_result is None or call_result["serving_status_code"] != 0:
                return None
            call_result.pop("serving_status_code")
        return call_result

W
wangjiawei04 已提交
338
    def postprocess(self, input_dict, fetch_dict):
B
barrierye 已提交
339
        return fetch_dict
D
dongdaxiang 已提交
340

B
barrierye 已提交
341
    def _parse_channeldata(self, channeldata_dict):
        """Split a {op_name: ChannelData} dict into its components.

        Returns a tuple (data_id, error_channeldata, parsed_data,
        client_need_profile, profile_set). error_channeldata is the first
        entry whose ecode is not OK (parsing stops there), or None when
        every entry is OK.
        """
        error_channeldata = None
        parsed_data = {}
        profile_set = set()

        first_key = list(channeldata_dict.keys())[0]
        first_entry = channeldata_dict[first_key]
        data_id = first_entry.id
        client_need_profile = first_entry.client_need_profile

        for op_name, data in channeldata_dict.items():
            if data.ecode != ChannelDataEcode.OK.value:
                error_channeldata = data
                break
            parsed_data[op_name] = data.parse()
            if client_need_profile:
                profile_set |= data.profile_data_set
        return (data_id, error_channeldata, parsed_data, client_need_profile,
                profile_set)
B
barrierye 已提交
359 360 361 362 363

    def _push_to_output_channels(self,
                                 data,
                                 channels,
                                 name=None,
B
barriery 已提交
364
                                 profile_str=None,
B
barrierye 已提交
365
                                 client_need_profile=False,
B
barrierye 已提交
366
                                 profile_set=None):
367 368
        if name is None:
            name = self.name
B
barrierye 已提交
369

B
barriery 已提交
370
        # add profile into channeldata
B
barrierye 已提交
371
        if client_need_profile and profile_set is not None:
B
barriery 已提交
372 373
            if profile_str is not None:
                profile_set.add(profile_str)
B
barrierye 已提交
374
            data.add_profile(profile_set)
B
barrierye 已提交
375

B
barriery 已提交
376 377 378
        for channel in channels:
            channel.push(data, name)

W
wangjiawei04 已提交
379
    def start_with_process(self):
        """Spawn one daemon process per concurrency slot, each running
        self._run(), and return the list of started processes.
        """
        trace_buffer = None
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()
        workers = []
        for worker_idx in range(self.concurrency):
            worker = multiprocessing.Process(
                target=self._run,
                args=(worker_idx, self._get_input_channel(),
                      self._get_output_channels(), False, trace_buffer))
            worker.daemon = True
            worker.start()
            workers.append(worker)
        return workers
393

W
wangjiawei04 已提交
394
    def start_with_thread(self):
        """Spawn one daemon thread per concurrency slot, each running
        self._run(), and return the list of started threads.
        """
        trace_buffer = None
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()
        workers = []
        for worker_idx in range(self.concurrency):
            worker = threading.Thread(
                target=self._run,
                args=(worker_idx, self._get_input_channel(),
                      self._get_output_channels(), True, trace_buffer))
            # When a process exits, it attempts to terminate
            # all of its daemonic child processes.
            worker.daemon = True
            worker.start()
            workers.append(worker)
        return workers

B
barrierye 已提交
411
    def init_op(self):
B
barrierye 已提交
412 413
        pass

B
barriery 已提交
414 415
    def _run_preprocess(self, parsed_data_dict, op_info_prefix):
        """Run preprocess() for each request and split results from errors.

        Returns (preped_data_dict, err_channeldata_dict), both keyed by
        data_id: successful preprocess outputs in the first dict,
        ChannelData error records in the second.
        """
        _LOGGER.debug("{} Running preprocess".format(op_info_prefix))
        succ_dict = collections.OrderedDict()
        err_dict = collections.OrderedDict()
        for data_id, parsed_data in parsed_data_dict.items():
            try:
                succ_dict[data_id] = self.preprocess(parsed_data)
            except TypeError as e:
                # Error type in channeldata.datatype
                error_info = "(logid={}) {} Failed to preprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.TYPE_ERROR.value,
                    error_info=error_info,
                    data_id=data_id)
            except Exception as e:
                error_info = "(logid={}) {} Failed to preprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
        _LOGGER.debug("{} Succ preprocess".format(op_info_prefix))
        return succ_dict, err_dict

B
barriery 已提交
446 447
    def _run_process(self, preped_data_dict, op_info_prefix):
        """Batch the preprocessed samples, run process() through the
        serving client, and scatter the batched results back per request.

        Args:
            preped_data_dict: OrderedDict of data_id -> preprocess output
                (a feed dict for sample input, or a list of feed dicts for
                batch input).
            op_info_prefix: prefix string used in log messages.

        Returns:
            (midped_data_dict, err_channeldata_dict), both keyed by
            data_id. When the Op has no serving backend, the input dict is
            passed through unchanged.
        """
        _LOGGER.debug("{} Running process".format(op_info_prefix))
        midped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        if self.with_serving:
            # BUG FIX: dict.keys() returns a view on Python 3, which is not
            # indexable; materialize it as a list.
            data_ids = list(preped_data_dict.keys())
            typical_logid = data_ids[0]
            if len(data_ids) != 1:
                for data_id in data_ids:
                    _LOGGER.info(
                        "(logid={}) {} During access to PaddleServingService,"
                        " we selected logid={} (from batch: {}) as a "
                        "representative for logging.".format(
                            data_id, op_info_prefix, typical_logid, data_ids))

            # combine samples to batch; input_offset[i]:input_offset[i+1]
            # is the slice of feed_batch belonging to data_ids[i]
            one_input = preped_data_dict[data_ids[0]]
            feed_batch = []
            input_offset = None
            if isinstance(one_input, dict):
                # sample input
                feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
                input_offset = list(range(len(data_ids) + 1))
            elif isinstance(one_input, list):
                # batch input
                input_offset = [0]
                for data_id in data_ids:
                    batch_input = preped_data_dict[data_id]
                    offset = input_offset[-1] + len(batch_input)
                    feed_batch += batch_input
                    input_offset.append(offset)
            else:
                _LOGGER.critical(
                    "{} Failed to process: expect input type is dict(sample"
                    " input) or list(batch input), but get {}".format(
                        op_info_prefix, type(one_input)))
                os._exit(-1)

            midped_batch = None
            ecode = ChannelDataEcode.OK.value
            if self._timeout <= 0:
                try:
                    # BUG FIX: process() takes (feed_batch, fetch_names,
                    # typical_logid); the fetch_names argument was missing.
                    midped_batch = self.process(feed_batch, self._fetch_names,
                                                typical_logid)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
                        typical_logid, op_info_prefix, data_ids, e)
                    _LOGGER.error(error_info, exc_info=True)
            else:
                # bounded retries, each attempt limited to self._timeout s
                for i in range(self._retry):
                    try:
                        midped_batch = func_timeout.func_timeout(
                            self._timeout,
                            self.process,
                            args=(feed_batch, self._fetch_names,
                                  typical_logid))
                    except func_timeout.FunctionTimedOut as e:
                        if i + 1 >= self._retry:
                            ecode = ChannelDataEcode.TIMEOUT.value
                            error_info = "(logid={}) {} Failed to process(batch: {}): " \
                                    "exceeded retry count.".format(
                                            typical_logid, op_info_prefix, data_ids)
                            _LOGGER.error(error_info)
                        else:
                            _LOGGER.warning(
                                "(logid={}) {} Failed to process(batch: {}): timeout,"
                                " and retrying({}/{})...".format(
                                    typical_logid, op_info_prefix, data_ids, i +
                                    1, self._retry))
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
                            typical_logid, op_info_prefix, data_ids, e)
                        _LOGGER.error(error_info, exc_info=True)
                        break
                    else:
                        break
            if ecode != ChannelDataEcode.OK.value:
                # one error record per request of the failed batch
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ecode, error_info=error_info, data_id=data_id)
            elif midped_batch is None:
                # op client return None
                error_info = "(logid={}) {} Failed to predict, please check if " \
                        "PaddleServingService is working properly.".format(
                                typical_logid, op_info_prefix)
                _LOGGER.error(error_info)
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ChannelDataEcode.CLIENT_ERROR.value,
                        error_info=error_info,
                        data_id=data_id)
            else:
                # transform np format to dict format
                # a "<name>.lod" companion key marks <name> as a LodTensor
                var_names = midped_batch.keys()
                lod_var_names = set()
                lod_offset_names = set()
                for name in var_names:
                    lod_offset_name = "{}.lod".format(name)
                    if lod_offset_name in var_names:
                        _LOGGER.debug("(logid={}) {} {} is LodTensor".format(
                            typical_logid, op_info_prefix, name))
                        lod_var_names.add(name)
                        lod_offset_names.add(lod_offset_name)

                for idx, data_id in enumerate(data_ids):
                    midped_data_dict[data_id] = {}

                for name, value in midped_batch.items():
                    if name in lod_offset_names:
                        continue
                    if name in lod_var_names:
                        # lodtensor: slice values via the lod offsets and
                        # rebase each request's lod to start at 0
                        lod_offset_name = "{}.lod".format(name)
                        lod_offset = midped_batch[lod_offset_name]
                        for idx, data_id in enumerate(data_ids):
                            data_offset_left = input_offset[idx]
                            data_offset_right = input_offset[idx + 1]
                            lod_offset_left = lod_offset[data_offset_left]
                            lod_offset_right = lod_offset[data_offset_right]
                            midped_data_dict[data_id][name] = value[
                                lod_offset_left:lod_offset_right]
                            midped_data_dict[data_id][lod_offset_name] = \
                                    lod_offset[data_offset_left:data_offset_right + 1] - lod_offset[data_offset_left]
                    else:
                        # normal tensor: slice directly by sample offsets
                        for idx, data_id in enumerate(data_ids):
                            left = input_offset[idx]
                            right = input_offset[idx + 1]
                            midped_data_dict[data_id][name] = value[left:right]
        else:
            midped_data_dict = preped_data_dict
        _LOGGER.debug("{} Succ process".format(op_info_prefix))
        return midped_data_dict, err_channeldata_dict

B
barriery 已提交
580 581 582
    def _run_postprocess(self, parsed_data_dict, midped_data_dict,
                         op_info_prefix):
        """Run postprocess() per request and wrap the results as
        ChannelData.

        Args:
            parsed_data_dict: data_id -> original parsed input (passed to
                postprocess alongside the prediction result).
            midped_data_dict: data_id -> process()/pass-through output.
            op_info_prefix: prefix string used in log messages.

        Returns:
            (postped_data_dict, err_channeldata_dict), both keyed by
            data_id; successful outputs are wrapped as CHANNEL_NPDATA when
            they contain pure numpy data, otherwise as DICT.
        """
        _LOGGER.debug("{} Running postprocess".format(op_info_prefix))
        postped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        for data_id, midped_data in midped_data_dict.items():
            postped_data, err_channeldata = None, None
            try:
                postped_data = self.postprocess(parsed_data_dict[data_id],
                                                midped_data)
            except Exception as e:
                error_info = "(logid={}) {} Failed to postprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_channeldata = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
            if err_channeldata is not None:
                err_channeldata_dict[data_id] = err_channeldata
                continue
            else:
                # postprocess must return a dict
                if not isinstance(postped_data, dict):
                    # typo fix in the error message: "funticon" -> "function"
                    error_info = "(logid={}) {} Failed to postprocess: " \
                            "output of postprocess function must be " \
                            "dict type, but get {}".format(
                                data_id, op_info_prefix,
                                type(postped_data))
                    _LOGGER.error(error_info)
                    err_channeldata = ChannelData(
                        ecode=ChannelDataEcode.UNKNOW.value,
                        error_info=error_info,
                        data_id=data_id)
                    err_channeldata_dict[data_id] = err_channeldata
                    continue

                output_data = None
                err, _ = ChannelData.check_npdata(postped_data)
                if err == 0:
                    # pure numpy payload: use the compact NPDATA form
                    output_data = ChannelData(
                        ChannelDataType.CHANNEL_NPDATA.value,
                        npdata=postped_data,
                        data_id=data_id)
                else:
                    output_data = ChannelData(
                        ChannelDataType.DICT.value,
                        dictdata=postped_data,
                        data_id=data_id)
                postped_data_dict[data_id] = output_data
        _LOGGER.debug("{} Succ postprocess".format(op_info_prefix))
        return postped_data_dict, err_channeldata_dict
B
barriery 已提交
631 632

    def _auto_batching_generator(self, input_channel, op_name, batch_size,
                                 timeout, op_info_prefix):
        """Yield batches of channeldata dicts pulled from input_channel.

        Blocks until at least one item is available. With a timeout, the
        batch is cut short once the time budget for the whole batch is
        spent; without one, it blocks until batch_size items arrive.

        Args:
            input_channel: channel to read from.
            op_name: consumer name used for channel.front().
            batch_size: maximum number of items per batch.
            timeout: total seconds allowed to assemble one batch, or None
                to wait indefinitely.
            op_info_prefix: "[op_name|concurrency_idx]" tag for logging.

        Yields:
            list of channeldata dicts, length in [1, batch_size].
        """
        while True:
            batch = []
            while len(batch) == 0:
                endtime = None
                if timeout is not None:
                    endtime = _time() + timeout
                for _ in range(batch_size):
                    try:
                        channeldata_dict = None
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug("{} Failed to generate batch: "
                                              "timeout".format(op_info_prefix))
                                break
                            # BUG FIX: wait only for the remaining budget,
                            # not the full timeout, so assembling one batch
                            # never exceeds `timeout` in total (previously
                            # each front() could wait the whole timeout).
                            channeldata_dict = input_channel.front(op_name,
                                                                   remaining)
                        else:
                            channeldata_dict = input_channel.front(op_name)
                        batch.append(channeldata_dict)
                    except ChannelTimeoutError:
                        _LOGGER.debug("{} Failed to generate batch: "
                                      "timeout".format(op_info_prefix))
                        break
            _LOGGER.debug("{} Got actual batch_size: {}".format(op_info_prefix,
                                                                len(batch)))
            yield batch

662
    def _parse_channeldata_batch(self, batch, output_channels):
663
        parsed_data_dict = collections.OrderedDict()
664 665
        need_profile_dict = {}
        profile_dict = {}
B
bug fix  
barriery 已提交
666
        for channeldata_dict in batch:
667 668 669 670 671 672 673 674 675 676
            (data_id, error_channeldata, parsed_data,
                    client_need_profile, profile_set) = \
                            self._parse_channeldata(channeldata_dict)
            if error_channeldata is None:
                parsed_data_dict[data_id] = parsed_data
                need_profile_dict[data_id] = client_need_profile
                profile_dict[data_id] = profile_set
            else:
                # error data in predecessor Op
                # (error_channeldata with profile info)
B
barriery 已提交
677 678
                self._push_to_output_channels(error_channeldata,
                                              output_channels)
679 680

        return parsed_data_dict, need_profile_dict, profile_dict
B
barriery 已提交
681

W
wangjiawei04 已提交
682
    def _run(self, concurrency_idx, input_channel, output_channels,
             is_thread_op, trace_buffer):
        """Main worker loop of an Op (one invocation per thread/process).

        Repeatedly pulls auto-batched requests from input_channel, runs
        preprocess -> process -> postprocess on each batch, and pushes
        results (or per-request error ChannelData) to output_channels,
        until the input channel raises ChannelStopError.

        Args:
            concurrency_idx: index of this worker among the Op's workers.
            input_channel: channel to read batched requests from.
            output_channels: downstream channels to push results into.
            is_thread_op: True when workers are threads sharing this Op
                instance, False when they are separate processes.
            trace_buffer: optional queue for per-stage latency samples;
                None disables tracing.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        tid = threading.current_thread().ident

        # init op: build client/profiler; a failure here is fatal for the
        # whole worker, so exit hard.
        profiler = None
        try:
            profiler = self._initialize(is_thread_op, concurrency_idx)
        except Exception as e:
            _LOGGER.critical(
                "{} Failed to init op: {}".format(op_info_prefix, e),
                exc_info=True)
            os._exit(-1)
        _LOGGER.info("{} Succ init".format(op_info_prefix))

        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=self._batch_size,
            timeout=self._auto_batching_timeout,
            op_info_prefix=op_info_prefix)

        start, end = None, None
        trace_que = collections.deque()
        while True:
            # "in" stage: time (in microseconds) spent waiting for a batch
            start = int(round(_time() * 1000000))
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            end = int(round(_time() * 1000000))
            in_time = end - start

            # parse channeldata batch
            try:
                parsed_data_dict, need_profile_dict, profile_dict \
                        = self._parse_channeldata_batch(
                                channeldata_dict_batch, output_channels)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(parsed_data_dict) == 0:
                # data in the whole batch is all error data
                continue

            # preprecess
            start = profiler.record("prep#{}_0".format(op_info_prefix))
            preped_data_dict, err_channeldata_dict \
                    = self._run_preprocess(parsed_data_dict, op_info_prefix)
            end = profiler.record("prep#{}_1".format(op_info_prefix))
            prep_time = end - start
            # forward per-request preprocess errors downstream
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(preped_data_dict) == 0:
                continue

            # process
            start = profiler.record("midp#{}_0".format(op_info_prefix))
            midped_data_dict, err_channeldata_dict \
                    = self._run_process(preped_data_dict, op_info_prefix)
            end = profiler.record("midp#{}_1".format(op_info_prefix))
            midp_time = end - start
            # forward per-request process errors downstream
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(midped_data_dict) == 0:
                continue

            # postprocess
            start = profiler.record("postp#{}_0".format(op_info_prefix))
            postped_data_dict, err_channeldata_dict \
                    = self._run_postprocess(
                            parsed_data_dict, midped_data_dict, op_info_prefix)
            end = profiler.record("postp#{}_1".format(op_info_prefix))
            postp_time = end - start
            # forward per-request postprocess errors downstream
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(postped_data_dict) == 0:
                continue

            # push data to channel (if run succ)
            start = int(round(_time() * 1000000))
            try:
                profile_str = profiler.gen_profile_str()
                for data_id, postped_data in postped_data_dict.items():
                    if self._server_use_profile:
                        sys.stderr.write(profile_str)
                    self._push_to_output_channels(
                        data=postped_data,
                        channels=output_channels,
                        profile_str=profile_str,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            end = int(round(_time() * 1000000))
            out_time = end - start

            # record per-stage latency; never block the worker on a full
            # trace buffer — keep unsent samples queued locally instead
            if trace_buffer is not None:
                trace_que.append({
                    "name": self.name,
                    "actions": {
                        "in": in_time,
                        "prep": prep_time,
                        "midp": midp_time,
                        "postp": postp_time,
                        "out": out_time,
                    }
                })
                while trace_que:
                    info = trace_que[0]
                    try:
                        trace_buffer.put_nowait(info)
                        trace_que.popleft()
                    except Queue.Full:
                        break
B
barriery 已提交
829

W
wangjiawei04 已提交
830
    def _initialize(self, is_thread_op, concurrency_idx):
        """Set up per-worker state and return a fresh TimeProfiler.

        In thread mode the Op instance is shared, so client/init_op are
        created exactly once (guarded by _for_init_op_lock); in process
        mode every worker builds its own.
        """
        if is_thread_op:
            # threads share this Op: only the first one through the lock
            # performs the initialization
            with self._for_init_op_lock:
                if not self._succ_init_op:
                    # the threaded version of Op cannot know its own
                    # concurrency_idx
                    self.concurrency_idx = None
                    self.client = self.init_client(self._client_config,
                                                   self._server_endpoints)
                    # user-defined init hook
                    self.init_op()
                    self._succ_init_op = True
                    self._succ_close_op = False
        else:
            self.concurrency_idx = concurrency_idx
            self.client = self.init_client(self._client_config,
                                           self._server_endpoints)
            # user-defined init hook
            self.init_op()

        # each thread/process gets its own profiler
        worker_profiler = TimeProfiler()
        worker_profiler.enable(True)
        return worker_profiler

B
barriery 已提交
856 857 858 859 860 861 862 863
    def _finalize(self, is_thread_op):
        if is_thread_op:
            with self._for_close_op_lock:
                if not self._succ_close_op:
                    self._profiler = None
                    self.client = None
                    self._succ_init_op = False
                    self._succ_close_op = True
864 865 866 867 868

    def _log(self, info):
        return "{} {}".format(self.name, info)


B
barrierye 已提交
869 870 871
class RequestOp(Op):
    """ RequestOp do not run preprocess, process, postprocess. """

    def __init__(self):
        # PipelineService.name = "@DAGExecutor"
        super(RequestOp, self).__init__(name="@DAGExecutor", input_ops=[])
        # init op; a failure here is fatal for the service
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.critical("Op(Request) Failed to init: {}".format(e))
            os._exit(-1)

    def unpack_request_package(self, request):
        """Unpack an RPC request into a feed dict.

        Each request.value[i] is tried as a Python literal; values that
        evaluate to numpy ndarrays are kept as arrays, everything else
        stays a raw string (best-effort, failures are ignored).
        """
        dictdata = {}
        for idx, key in enumerate(request.key):
            data = request.value[idx]
            try:
                # NOTE(security): eval() on client-supplied strings can
                # execute arbitrary code; acceptable only when clients are
                # trusted — consider ast.literal_eval or an explicit
                # serialization format instead.
                evaled_data = eval(data)
                if isinstance(evaled_data, np.ndarray):
                    data = evaled_data
            except Exception:
                # not an evaluable expression: keep the raw string
                pass
            dictdata[key] = data
        return dictdata


class ResponseOp(Op):
    """ ResponseOp do not run preprocess, process, postprocess. """

    def __init__(self, input_ops):
        super(ResponseOp, self).__init__(
            name="@DAGExecutor", input_ops=input_ops)
        # init op; a failure here is fatal for the service
        try:
            self.init_op()
        except Exception as e:
            # BUG FIX: exc_info=True was previously passed to str.format()
            # (where it is silently ignored); it must go to the logger call
            # so the traceback is recorded.
            _LOGGER.critical(
                "Op(ResponseOp) Failed to init: {}".format(e), exc_info=True)
            os._exit(-1)

    def pack_response_package(self, channeldata):
        """Pack a ChannelData result into an RPC Response message.

        NPDATA values are serialized via repr() (the client is expected to
        reconstruct the ndarray); DICT values must already be strings.
        Non-OK channeldata only carries its error_info through.
        """
        resp = pipeline_service_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
                feed = channeldata.parse()
                # ndarray to string:
                # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
                # full threshold so repr() is not elided with "..."
                np.set_printoptions(threshold=sys.maxsize)
                for name, var in feed.items():
                    resp.value.append(var.__repr__())
                    resp.key.append(name)
            elif channeldata.datatype == ChannelDataType.DICT.value:
                feed = channeldata.parse()
                for name, var in feed.items():
                    if not isinstance(var, str):
                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                        resp.error_info = self._log(
                            "fetch var type must be str({}).".format(
                                type(var)))
                        _LOGGER.error("(logid={}) Failed to pack RPC "
                                      "response package: {}".format(
                                          channeldata.id, resp.error_info))
                        break
                    resp.value.append(var)
                    resp.key.append(name)
            else:
                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                resp.error_info = self._log(
                    "error type({}) in datatype.".format(channeldata.datatype))
                _LOGGER.error("(logid={}) Failed to pack RPC response"
                              " package: {}".format(channeldata.id,
                                                    resp.error_info))
        else:
            resp.error_info = channeldata.error_info
        return resp
946 947 948 949 950 951 952


class VirtualOp(Op):
    ''' For connecting two channels. '''

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, input_ops=None, concurrency=concurrency)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        # Record a (possibly virtual) predecessor Op of this virtual Op.
        self._virtual_pred_ops.append(op)

    def _actual_pred_op_names(self, op):
        """Resolve the real (non-virtual) predecessor Op names of `op`."""
        # can use disjoint-set, but it's not necessary
        if not isinstance(op, VirtualOp):
            return [op.name]
        names = []
        for x in op._virtual_pred_ops:
            names.extend(self._actual_pred_op_names(x))
        return names

    def add_output_channel(self, channel):
        """Attach an output channel, registering every actual predecessor
        Op as a producer on it."""
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical(
                self._log("Failed to add output_channel: output_channel"
                          " must be Channel type, not {}".format(
                              type(channel))))
            os._exit(-1)
        for op in self._virtual_pred_ops:
            for op_name in self._actual_pred_op_names(op):
                channel.add_producer(op_name)
        self._outputs.append(channel)

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op):
        """Forward data from input_channel to output_channels unchanged."""
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)

        # BUG FIX: this previously called the undefined helper
        # get_log_func() (NameError) and passed log_func= to
        # _auto_batching_generator, whose signature takes op_info_prefix
        # (TypeError) — now consistent with Op._auto_batching_generator.
        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=1,
            timeout=None,
            op_info_prefix=op_info_prefix)

        while True:
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break

            try:
                for channeldata_dict in channeldata_dict_batch:
                    for name, data in channeldata_dict.items():
                        self._push_to_output_channels(
                            data, channels=output_channels, name=name)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break