operator.py 43.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
B
barriery 已提交
16
import time
17 18 19 20 21 22
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
23
import os
B
barrierye 已提交
24
import sys
25
import collections
B
barrierye 已提交
26
import numpy as np
B
barrierye 已提交
27
from numpy import *
B
barrierye 已提交
28 29 30 31 32 33
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
34

B
barrierye 已提交
35
from .proto import pipeline_service_pb2
B
barrierye 已提交
36
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
B
bug fix  
barriery 已提交
37 38
                      ChannelData, ChannelDataType, ChannelStopError,
                      ChannelTimeoutError)
B
barrierye 已提交
39
from .util import NameGenerator
B
barriery 已提交
40
from .profiler import UnsafeTimeProfiler as TimeProfiler
W
wangjiawei04 已提交
41
from . import local_service_handler
42

43
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
44 45
_op_name_gen = NameGenerator("Op")

D
dongdaxiang 已提交
46 47 48

class Op(object):
    def __init__(self,
B
barrierye 已提交
49
                 name=None,
D
dongdaxiang 已提交
50
                 input_ops=[],
B
barriery 已提交
51 52
                 server_endpoints=None,
                 fetch_list=None,
B
barrierye 已提交
53
                 client_config=None,
W
wangjiawei04 已提交
54
                 client_type=None,
B
barriery 已提交
55 56 57 58
                 concurrency=None,
                 timeout=None,
                 retry=None,
                 batch_size=None,
59
                 auto_batching_timeout=None,
W
wangjiawei04 已提交
60
                 local_service_handler=None):
B
barriery 已提交
61
        # In __init__, all the parameters are just saved and Op is not initialized
B
barrierye 已提交
62
        if name is None:
B
barrierye 已提交
63
            name = _op_name_gen.next()
64
        self.name = name  # to identify the type of OP, it must be globally unique
B
barrierye 已提交
65
        self.concurrency = concurrency  # amount of concurrency
B
barrierye 已提交
66
        self.set_input_ops(input_ops)
B
barrierye 已提交
67

W
wangjiawei04 已提交
68
        self._local_service_handler = local_service_handler
B
barriery 已提交
69
        self._server_endpoints = server_endpoints
B
barrierye 已提交
70
        self._fetch_names = fetch_list
B
barriery 已提交
71
        self._client_config = client_config
W
wangjiawei04 已提交
72
        self.client_type = client_type
B
barriery 已提交
73
        self._timeout = timeout
74
        self._retry = max(1, retry)
B
barriery 已提交
75 76 77
        self._batch_size = batch_size
        self._auto_batching_timeout = auto_batching_timeout

78 79
        self._input = None
        self._outputs = []
B
barrierye 已提交
80

B
barriery 已提交
81 82 83 84 85 86 87 88 89
        self._server_use_profile = False
        self._tracer = None

        # only for thread op
        self._for_init_op_lock = threading.Lock()
        self._for_close_op_lock = threading.Lock()
        self._succ_init_op = False
        self._succ_close_op = False

B
barriery 已提交
90 91
    def init_from_dict(self, conf):
        """Fill in parameters not given to __init__ from the config dict.

        Resolution order for the serving backend:
          1. endpoints passed to __init__  -> remote service
          2. conf["server_endpoints"]      -> remote service
          3. conf["local_service_conf"]    -> launch/attach a local service
             (brpc/grpc local server, or an in-process local_predictor)
        Sets self.with_serving accordingly, and converts timeout values
        from milliseconds to seconds.
        """
        # init op
        if self.concurrency is None:
            self.concurrency = conf["concurrency"]
        if self._retry is None:
            self._retry = conf["retry"]
        if self._fetch_names is None:
            self._fetch_names = conf.get("fetch_list")
        if self._client_config is None:
            self._client_config = conf.get("client_config")

        if self._timeout is None:
            self._timeout = conf["timeout"]
        # ms -> s; non-positive means "no timeout" (normalized to -1)
        if self._timeout > 0:
            self._timeout = self._timeout / 1000.0
        else:
            self._timeout = -1

        if self._batch_size is None:
            self._batch_size = conf["batch_size"]
        if self._auto_batching_timeout is None:
            self._auto_batching_timeout = conf["auto_batching_timeout"]
        # With batch_size == 1 there is nothing to batch, so waiting is useless.
        if self._auto_batching_timeout <= 0 or self._batch_size == 1:
            _LOGGER.warning(
                self._log(
                    "Because auto_batching_timeout <= 0 or batch_size == 1,"
                    " set auto_batching_timeout to None."))
            self._auto_batching_timeout = None
        else:
            self._auto_batching_timeout = self._auto_batching_timeout / 1000.0

        if self._server_endpoints is None:
            server_endpoints = conf.get("server_endpoints", [])
            if len(server_endpoints) != 0:
                # remote service
                self.with_serving = True
                self._server_endpoints = server_endpoints
            else:
                if self._local_service_handler is None:
                    local_service_conf = conf.get("local_service_conf")
                    _LOGGER.info("local_service_conf: {}".format(
                        local_service_conf))
                    model_config = local_service_conf.get("model_config")
                    self.client_type = local_service_conf.get("client_type")
                    _LOGGER.info("model_config: {}".format(model_config))
                    if model_config is None:
                        # no model at all: this Op only runs its own
                        # preprocess/postprocess, no serving call
                        self.with_serving = False
                    else:
                        # local rpc service
                        self.with_serving = True
                        if self.client_type == "brpc" or self.client_type == "grpc":
                            service_handler = local_service_handler.LocalServiceHandler(
                                model_config=model_config,
                                client_type=self.client_type,
                                workdir=local_service_conf["workdir"],
                                thread_num=local_service_conf["thread_num"],
                                devices=local_service_conf["devices"],
                                mem_optim=local_service_conf["mem_optim"],
                                ir_optim=local_service_conf["ir_optim"])
                            service_handler.prepare_server()  # get fetch_list
                            # NOTE: "serivce_ports" is a long-standing typo for
                            # "service_ports"; kept as-is (local name only)
                            serivce_ports = service_handler.get_port_list()
                            self._server_endpoints = [
                                "127.0.0.1:{}".format(p) for p in serivce_ports
                            ]
                            if self._client_config is None:
                                self._client_config = service_handler.get_client_config(
                                )
                            if self._fetch_names is None:
                                self._fetch_names = service_handler.get_fetch_list(
                                )
                        elif self.client_type == "local_predictor":
                            # in-process predictor: no ports, the client is the
                            # predictor object itself
                            service_handler = local_service_handler.LocalServiceHandler(
                                model_config=model_config,
                                client_type=self.client_type,
                                workdir=local_service_conf["workdir"],
                                thread_num=local_service_conf["thread_num"],
                                devices=local_service_conf["devices"])
                            #service_handler.prepare_server()  # get fetch_list
                            self.local_predictor = service_handler.get_client()
                            if self._client_config is None:
                                self._client_config = service_handler.get_client_config(
                                )
                            if self._fetch_names is None:
                                self._fetch_names = service_handler.get_fetch_list(
                                )
                        self._local_service_handler = service_handler
                else:
                    # a handler was passed to __init__: reuse it
                    self.with_serving = True
                    self._local_service_handler.prepare_server(
                    )  # get fetch_list
                    serivce_ports = self._local_service_handler.get_port_list()
                    self._server_endpoints = [
                        "127.0.0.1:{}".format(p) for p in serivce_ports
                    ]
                    if self._client_config is None:
                        self._client_config = self._local_service_handler.get_client_config(
                        )
                    if self._fetch_names is None:
                        self._fetch_names = self._local_service_handler.get_fetch_list(
                        )
        else:
            self.with_serving = True

        # RequestOp/ResponseOp (presumably defined later in this file) are
        # endpoints of the DAG and have no serving parameters worth logging.
        if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp):
            _LOGGER.info(
                self._log("\n\tinput_ops: {},"
                          "\n\tserver_endpoints: {}"
                          "\n\tfetch_list: {}"
                          "\n\tclient_config: {}"
                          "\n\tconcurrency: {},"
                          "\n\ttimeout(s): {},"
                          "\n\tretry: {},"
                          "\n\tbatch_size: {},"
                          "\n\tauto_batching_timeout(s): {}".format(
                              ", ".join([op.name for op in self._input_ops
                                         ]), self._server_endpoints,
                              self._fetch_names, self._client_config,
                              self.concurrency, self._timeout, self._retry,
                              self._batch_size, self._auto_batching_timeout)))
    def launch_local_rpc_service(self):
        """Start the local rpc server held by this Op's service handler.

        No-op (with a warning) when no local service handler was configured.
        """
        handler = self._local_service_handler
        if handler is None:
            _LOGGER.warning(
                self._log("Failed to launch local rpc"
                          " service: local_service_handler is None."))
            return
        port = handler.get_port_list()
        #if self._local_service_handler.client_type == "local_predictor":
        #    _LOGGER.info("Op({}) use local predictor.")
        #    return
        handler.start_server()
        _LOGGER.info("Op({}) use local rpc service at port: {}"
                     .format(self.name, port))
    def use_default_auto_batching_config(self):
B
bug fix  
barriery 已提交
225
        if self._batch_size != 1:
226 227
            _LOGGER.warning("Op({}) reset batch_size=1 (original: {})"
                            .format(self.name, self._batch_size))
B
bug fix  
barriery 已提交
228 229
            self._batch_size = 1
        if self._auto_batching_timeout != None:
230
            _LOGGER.warning(
B
barriery 已提交
231 232
                "Op({}) reset auto_batching_timeout=None (original: {})"
                .format(self.name, self._auto_batching_timeout))
B
bug fix  
barriery 已提交
233
            self._auto_batching_timeout = None
B
barriery 已提交
234

B
barrierye 已提交
235
    def use_profiler(self, use_profile):
B
barrierye 已提交
236
        self._server_use_profile = use_profile
237

B
barriery 已提交
238 239 240
    def set_tracer(self, tracer):
        self._tracer = tracer

W
wangjiawei04 已提交
241
    def init_client(self, client_config, server_endpoints):
        """Create (and connect) the serving client used by process().

        Args:
            client_config: client configuration for brpc clients.
            server_endpoints: "ip:port" list for rpc clients.

        Returns:
            The client object, or None when this Op has no serving backend.

        Raises:
            ValueError: unknown client_type, or local predictor not created.
        """
        if self.with_serving == False:
            _LOGGER.info("Op({}) has no client (and it also do not "
                         "run the process function)".format(self.name))
            return None
        if self.client_type == 'brpc':
            client = Client()
            client.load_client_config(client_config)
        elif self.client_type == 'grpc':
            client = MultiLangClient()
        elif self.client_type == 'local_predictor':
            if self.local_predictor is None:
                raise ValueError("local predictor not yet created")
            client = self.local_predictor
        else:
            raise ValueError("Failed to init client: unknow client "
                             "type {}".format(self.client_type))
        if self._fetch_names is None:
            self._fetch_names = client.fetch_names_
            # bug fix: the "{}" placeholder was previously never filled in
            _LOGGER.info("Op({}) has no fetch name set. So fetch all vars"
                         .format(self.name))
        if self.client_type != "local_predictor":
            client.connect(server_endpoints)
        return client
    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
274
                _LOGGER.critical(
B
barriery 已提交
275 276
                    self._log("Failed to set input_ops: input op "
                              "must be Op type, not {}".format(type(op))))
277
                os._exit(-1)
278
            self._input_ops.append(op)
D
dongdaxiang 已提交
279

280 281
    def add_input_channel(self, channel):
        """Register *channel* as this Op's input and subscribe to it.

        Aborts the process if *channel* is not a Channel instance.
        """
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            err_msg = ("Failed to set input_channel: input "
                       "channel must be Channel type, not {}".format(
                           type(channel)))
            _LOGGER.critical(self._log(err_msg))
            os._exit(-1)
        channel.add_consumer(self.name)
        self._input = channel
    def clean_input_channel(self):
B
barrierye 已提交
291 292 293 294
        self._input = None

    def _get_input_channel(self):
        return self._input
D
dongdaxiang 已提交
295

296 297
    def add_output_channel(self, channel):
        """Append *channel* to this Op's outputs and register as producer.

        Aborts the process if *channel* is not a Channel instance.
        """
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            err_msg = ("Failed to add output_channel: output channel "
                       "must be Channel type, not {}".format(type(channel)))
            _LOGGER.critical(self._log(err_msg))
            os._exit(-1)
        channel.add_producer(self.name)
        self._outputs.append(channel)
    def clean_output_channels(self):
B
barrierye 已提交
306 307 308 309 310
        self._outputs = []

    def _get_output_channels(self):
        return self._outputs

W
wangjiawei04 已提交
311
    def preprocess(self, input_dicts):
B
barrierye 已提交
312
        # multiple previous Op
B
barrierye 已提交
313
        if len(input_dicts) != 1:
314 315
            _LOGGER.critical(
                self._log(
B
barriery 已提交
316 317
                    "Failed to run preprocess: this Op has multiple previous "
                    "inputs. Please override this func."))
318
            os._exit(-1)
D
dongdaxiang 已提交
319

B
barrierye 已提交
320 321
        (_, input_dict), = input_dicts.items()
        return input_dict
B
barrierye 已提交
322

W
wangjiawei04 已提交
323
    def process(self, feed_batch, typical_logid):
        """Run the serving client on a feed batch.

        Returns the fetch dict, or None when a MultiLang call failed.
        Aborts the process when the feed batch is not numpy data.
        """
        err, err_info = ChannelData.check_batch_npdata(feed_batch)
        if err != 0:
            _LOGGER.critical(
                self._log("Failed to run process: {}. Please override "
                          "preprocess func.".format(err_info)))
            os._exit(-1)
        # local_predictor consumes a single feed dict; rpc clients
        # take the whole batch list
        if self.client_type == "local_predictor":
            feed = feed_batch[0]
        else:
            feed = feed_batch
        call_result = self.client.predict(
            feed=feed,
            fetch=self._fetch_names,
            batch=True,
            log_id=typical_logid)
        if isinstance(self.client, MultiLangClient):
            if call_result is None or call_result["serving_status_code"] != 0:
                return None
            call_result.pop("serving_status_code")
        return call_result
    def postprocess(self, input_dict, fetch_dict):
B
barrierye 已提交
349
        return fetch_dict
D
dongdaxiang 已提交
350

B
barrierye 已提交
351
    def _parse_channeldata(self, channeldata_dict):
        """Split a {op_name: ChannelData} dict into payloads and metadata.

        Returns:
            (data_id, error_channeldata, parsed_data, client_need_profile,
            profile_set) — error_channeldata is the first non-OK entry
            (parsing stops there), otherwise None.
        """
        error_channeldata = None
        parsed_data = {}
        profile_set = set()

        first_key = list(channeldata_dict.keys())[0]
        data_id = channeldata_dict[first_key].id
        client_need_profile = channeldata_dict[first_key].client_need_profile

        for op_name, channeldata in channeldata_dict.items():
            if channeldata.ecode != ChannelDataEcode.OK.value:
                error_channeldata = channeldata
                break
            parsed_data[op_name] = channeldata.parse()
            if client_need_profile:
                profile_set |= channeldata.profile_data_set
        return (data_id, error_channeldata, parsed_data, client_need_profile,
                profile_set)
    def _push_to_output_channels(self,
                                 data,
                                 channels,
                                 name=None,
B
barriery 已提交
374
                                 profile_str=None,
B
barrierye 已提交
375
                                 client_need_profile=False,
B
barrierye 已提交
376
                                 profile_set=None):
377 378
        if name is None:
            name = self.name
B
barrierye 已提交
379

B
barriery 已提交
380
        # add profile into channeldata
B
barrierye 已提交
381
        if client_need_profile and profile_set is not None:
B
barriery 已提交
382 383
            if profile_str is not None:
                profile_set.add(profile_str)
B
barrierye 已提交
384
            data.add_profile(profile_set)
B
barrierye 已提交
385

B
barriery 已提交
386 387 388
        for channel in channels:
            channel.push(data, name)

W
wangjiawei04 已提交
389
    def start_with_process(self):
B
barriery 已提交
390 391 392
        trace_buffer = None
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()
W
wangjiawei04 已提交
393
        process = []
B
barrierye 已提交
394
        for concurrency_idx in range(self.concurrency):
395 396
            p = multiprocessing.Process(
                target=self._run,
B
barrierye 已提交
397
                args=(concurrency_idx, self._get_input_channel(),
W
wangjiawei04 已提交
398
                      self._get_output_channels(), False, trace_buffer))
B
barriery 已提交
399
            p.daemon = True
400
            p.start()
W
wangjiawei04 已提交
401 402
            process.append(p)
        return process
403

W
wangjiawei04 已提交
404
    def start_with_thread(self):
B
barriery 已提交
405 406 407
        trace_buffer = None
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()
408
        threads = []
B
barrierye 已提交
409
        for concurrency_idx in range(self.concurrency):
410 411
            t = threading.Thread(
                target=self._run,
B
barrierye 已提交
412
                args=(concurrency_idx, self._get_input_channel(),
W
wangjiawei04 已提交
413
                      self._get_output_channels(), True, trace_buffer))
B
barriery 已提交
414 415 416
            # When a process exits, it attempts to terminate
            # all of its daemonic child processes.
            t.daemon = True
417 418 419 420
            t.start()
            threads.append(t)
        return threads

B
barrierye 已提交
421
    def init_op(self):
B
barrierye 已提交
422 423
        pass

B
barriery 已提交
424 425
    def _run_preprocess(self, parsed_data_dict, op_info_prefix):
        """Run self.preprocess on each request of the batch.

        Returns (preped_data_dict, err_channeldata_dict): successes keyed
        by data_id, and error ChannelData for requests whose preprocess
        raised (TypeError -> TYPE_ERROR, anything else -> UNKNOW).
        """
        _LOGGER.debug("{} Running preprocess".format(op_info_prefix))
        preped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        for data_id, parsed_data in parsed_data_dict.items():
            try:
                preped_data_dict[data_id] = self.preprocess(parsed_data)
            except TypeError as e:
                # Error type in channeldata.datatype
                error_info = "(logid={}) {} Failed to preprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_channeldata_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.TYPE_ERROR.value,
                    error_info=error_info,
                    data_id=data_id)
            except Exception as e:
                error_info = "(logid={}) {} Failed to preprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_channeldata_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
        _LOGGER.debug("{} Succ preprocess".format(op_info_prefix))
        return preped_data_dict, err_channeldata_dict
    def _run_process(self, preped_data_dict, op_info_prefix):
        """Batch the preprocessed requests, call process(), split results.

        Combines the per-request preprocessed data into one feed batch,
        runs self.process (with optional timeout/retry), then scatters the
        batched fetch results back to per-request dicts (handling LodTensor
        offsets). When self.with_serving is False the input passes through.

        Returns:
            (midped_data_dict, err_channeldata_dict) keyed by data_id.
        """
        _LOGGER.debug("{} Running process".format(op_info_prefix))
        midped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        if self.with_serving:
            # bug fix: dict views are not subscriptable on Python 3, and
            # data_ids is indexed below (data_ids[0])
            data_ids = list(preped_data_dict.keys())
            typical_logid = data_ids[0]
            if len(data_ids) != 1:
                for data_id in data_ids:
                    _LOGGER.info(
                        "(logid={}) {} During access to PaddleServingService,"
                        " we selected logid={} (from batch: {}) as a "
                        "representative for logging.".format(
                            data_id, op_info_prefix, typical_logid, data_ids))

            # combine samples to batch
            one_input = preped_data_dict[data_ids[0]]
            feed_batch = []
            # input_offset[i]:input_offset[i+1] is the slice of request i
            # inside the combined batch
            input_offset = None
            if isinstance(one_input, dict):
                # sample input
                feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
                input_offset = list(range(len(data_ids) + 1))
            elif isinstance(one_input, list):
                # batch input
                input_offset = [0]
                for data_id in data_ids:
                    batch_input = preped_data_dict[data_id]
                    offset = input_offset[-1] + len(batch_input)
                    feed_batch += batch_input
                    input_offset.append(offset)
            else:
                _LOGGER.critical(
                    "{} Failed to process: expect input type is dict(sample"
                    " input) or list(batch input), but get {}".format(
                        op_info_prefix, type(one_input)))
                os._exit(-1)

            midped_batch = None
            ecode = ChannelDataEcode.OK.value
            if self._timeout <= 0:
                # no timeout configured: a single attempt
                try:
                    midped_batch = self.process(feed_batch, typical_logid)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
                        typical_logid, op_info_prefix, data_ids, e)
                    _LOGGER.error(error_info, exc_info=True)
            else:
                # retry only on timeout; other exceptions abort immediately
                for i in range(self._retry):
                    try:
                        midped_batch = func_timeout.func_timeout(
                            self._timeout,
                            self.process,
                            args=(feed_batch, typical_logid))
                    except func_timeout.FunctionTimedOut as e:
                        if i + 1 >= self._retry:
                            ecode = ChannelDataEcode.TIMEOUT.value
                            error_info = "(logid={}) {} Failed to process(batch: {}): " \
                                    "exceeded retry count.".format(
                                            typical_logid, op_info_prefix, data_ids)
                            _LOGGER.error(error_info)
                        else:
                            _LOGGER.warning(
                                "(logid={}) {} Failed to process(batch: {}): timeout,"
                                " and retrying({}/{})...".format(
                                    typical_logid, op_info_prefix, data_ids, i +
                                    1, self._retry))
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
                            typical_logid, op_info_prefix, data_ids, e)
                        _LOGGER.error(error_info, exc_info=True)
                        break
                    else:
                        break
            if ecode != ChannelDataEcode.OK.value:
                # propagate the same error to every request of the batch
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ecode, error_info=error_info, data_id=data_id)
            elif midped_batch is None:
                # op client return None
                error_info = "(logid={}) {} Failed to predict, please check if " \
                        "PaddleServingService is working properly.".format(
                                typical_logid, op_info_prefix)
                _LOGGER.error(error_info)
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ChannelDataEcode.CLIENT_ERROR.value,
                        error_info=error_info,
                        data_id=data_id)
            else:
                # transform np format to dict format
                # a variable with a companion "<name>.lod" key is a LodTensor
                var_names = midped_batch.keys()
                lod_var_names = set()
                lod_offset_names = set()
                for name in var_names:
                    lod_offset_name = "{}.lod".format(name)
                    if lod_offset_name in var_names:
                        _LOGGER.debug("(logid={}) {} {} is LodTensor".format(
                            typical_logid, op_info_prefix, name))
                        lod_var_names.add(name)
                        lod_offset_names.add(lod_offset_name)

                for idx, data_id in enumerate(data_ids):
                    midped_data_dict[data_id] = {}

                for name, value in midped_batch.items():
                    if name in lod_offset_names:
                        continue
                    if name in lod_var_names:
                        # lodtensor: slice both the data rows (via the lod
                        # offsets) and the rebased lod vector per request
                        lod_offset_name = "{}.lod".format(name)
                        lod_offset = midped_batch[lod_offset_name]
                        for idx, data_id in enumerate(data_ids):
                            data_offset_left = input_offset[idx]
                            data_offset_right = input_offset[idx + 1]
                            lod_offset_left = lod_offset[data_offset_left]
                            lod_offset_right = lod_offset[data_offset_right]
                            midped_data_dict[data_id][name] = value[
                                lod_offset_left:lod_offset_right]
                            midped_data_dict[data_id][lod_offset_name] = \
                                    lod_offset[data_offset_left:data_offset_right + 1] - lod_offset[data_offset_left]
                    else:
                        # normal tensor
                        for idx, data_id in enumerate(data_ids):
                            left = input_offset[idx]
                            right = input_offset[idx + 1]
                            midped_data_dict[data_id][name] = value[left:right]
        else:
            # no serving backend: pass preprocessed data straight through
            midped_data_dict = preped_data_dict
        _LOGGER.debug("{} Succ process".format(op_info_prefix))
        return midped_data_dict, err_channeldata_dict
    def _run_postprocess(self, parsed_data_dict, midped_data_dict,
                         op_info_prefix):
        """Run the user-defined postprocess stage for one mini-batch.

        Args:
            parsed_data_dict: data_id -> parsed input data (the same dict the
                preprocess stage consumed; passed through so postprocess can
                see the original request fields).
            midped_data_dict: data_id -> output of the process stage.
            op_info_prefix: "[op_name|concurrency_idx]" tag for log lines.

        Returns:
            (postped_data_dict, err_channeldata_dict). Each data_id lands in
            exactly one of the two OrderedDicts: successful results wrapped
            as ChannelData, or error ChannelData for failed/invalid outputs.
        """
        _LOGGER.debug("{} Running postprocess".format(op_info_prefix))
        postped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        for data_id, midped_data in midped_data_dict.items():
            postped_data = None
            try:
                postped_data = self.postprocess(parsed_data_dict[data_id],
                                                midped_data)
            except Exception as e:
                error_info = "(logid={}) {} Failed to postprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_channeldata_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
                continue
            if not isinstance(postped_data, dict):
                # typo fix: error message previously said "funticon"
                error_info = "(logid={}) {} Failed to postprocess: " \
                        "output of postprocess function must be " \
                        "dict type, but get {}".format(
                            data_id, op_info_prefix,
                            type(postped_data))
                _LOGGER.error(error_info)
                err_channeldata_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
                continue

            # pure-numpy dicts can take the NPDATA channel representation;
            # anything else falls back to the generic DICT representation
            err, _ = ChannelData.check_npdata(postped_data)
            if err == 0:
                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
            else:
                output_data = ChannelData(
                    ChannelDataType.DICT.value,
                    dictdata=postped_data,
                    data_id=data_id)
            postped_data_dict[data_id] = output_data
        _LOGGER.debug("{} Succ postprocess".format(op_info_prefix))
        return postped_data_dict, err_channeldata_dict
B
barriery 已提交
641 642

    def _auto_batching_generator(self, input_channel, op_name, batch_size,
                                 timeout, op_info_prefix):
        """Yield mini-batches of channeldata dicts pulled from input_channel.

        Blocks until at least one item is available, then collects up to
        batch_size items. When timeout (seconds) is not None, the whole
        collection round is bounded by a single deadline and a partial
        batch is yielded once it expires.
        """
        while True:
            batch = []
            # loop until the batch is non-empty: a timeout with zero items
            # just restarts the round instead of yielding an empty batch
            while len(batch) == 0:
                endtime = None
                if timeout is not None:
                    endtime = _time() + timeout
                for idx in range(batch_size):
                    try:
                        channeldata_dict = None
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug("{} Failed to generate batch: "
                                              "timeout".format(op_info_prefix))
                                break
                            # BUG FIX: wait only for the remaining budget of
                            # this round, not the full timeout per item --
                            # previously a single round could block for up to
                            # batch_size * timeout, defeating the endtime
                            # computed above.
                            channeldata_dict = input_channel.front(op_name,
                                                                   remaining)
                        else:
                            channeldata_dict = input_channel.front(op_name)
                        batch.append(channeldata_dict)
                    except ChannelTimeoutError:
                        _LOGGER.debug("{} Failed to generate batch: "
                                      "timeout".format(op_info_prefix))
                        break
            _LOGGER.debug("{} Got actual batch_size: {}".format(op_info_prefix,
                                                                len(batch)))
            yield batch
671

672
    def _parse_channeldata_batch(self, batch, output_channels):
        """Parse a batch of channeldata dicts into per-request dicts.

        Error channeldata produced by a predecessor Op is forwarded
        directly to the output channels (it already carries its profile
        info) and excluded from the returned dicts.

        Returns:
            (parsed_data_dict, need_profile_dict, profile_dict), all keyed
            by data_id; only successfully parsed requests are included.
        """
        parsed_data_dict = collections.OrderedDict()
        need_profile_dict = {}
        profile_dict = {}
        for channeldata_dict in batch:
            (data_id, error_channeldata, parsed_data,
             client_need_profile, profile_set) = \
                 self._parse_channeldata(channeldata_dict)
            if error_channeldata is not None:
                # error data in predecessor Op
                # (error_channeldata with profile info)
                self._push_to_output_channels(error_channeldata,
                                              output_channels)
                continue
            parsed_data_dict[data_id] = parsed_data
            need_profile_dict[data_id] = client_need_profile
            profile_dict[data_id] = profile_set
        return parsed_data_dict, need_profile_dict, profile_dict
B
barriery 已提交
691

W
wangjiawei04 已提交
692
    def _run(self, concurrency_idx, input_channel, output_channels,
             is_thread_op, trace_buffer):
        """Worker loop of this Op: pull batches, run the three user stages
        (preprocess / process / postprocess), and push results downstream.

        Runs until the input channel is stopped (ChannelStopError), at any
        of the blocking points, then finalizes and exits the loop. Errors
        produced by a stage are pushed to the output channels immediately;
        only requests that survived the stage continue to the next one.

        Args:
            concurrency_idx: index of this worker (thread or process).
            input_channel: channel to pull request batches from.
            output_channels: channels to push results/errors to.
            is_thread_op: True when workers are threads sharing this Op.
            trace_buffer: optional queue for per-stage latency records;
                None disables tracing.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        tid = threading.current_thread().ident

        # init op: build client/profiler; any failure here is fatal for
        # the whole server process (os._exit, not sys.exit)
        profiler = None
        try:
            profiler = self._initialize(is_thread_op, concurrency_idx)
        except Exception as e:
            _LOGGER.critical(
                "{} Failed to init op: {}".format(op_info_prefix, e),
                exc_info=True)
            os._exit(-1)
        _LOGGER.info("{} Succ init".format(op_info_prefix))

        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=self._batch_size,
            timeout=self._auto_batching_timeout,
            op_info_prefix=op_info_prefix)

        start, end = None, None
        trace_que = collections.deque()
        while True:
            # timestamps are microseconds since epoch (int)
            start = int(round(_time() * 1000000))
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            end = int(round(_time() * 1000000))
            in_time = end - start

            # parse channeldata batch; predecessor errors are forwarded
            # inside _parse_channeldata_batch
            try:
                parsed_data_dict, need_profile_dict, profile_dict \
                        = self._parse_channeldata_batch(
                                channeldata_dict_batch, output_channels)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(parsed_data_dict) == 0:
                # data in the whole batch is all error data
                continue

            # preprocess
            start = profiler.record("prep#{}_0".format(op_info_prefix))
            preped_data_dict, err_channeldata_dict \
                    = self._run_preprocess(parsed_data_dict, op_info_prefix)
            end = profiler.record("prep#{}_1".format(op_info_prefix))
            prep_time = end - start
            try:
                # push per-request preprocess failures downstream with the
                # requester's profile flags preserved
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(preped_data_dict) == 0:
                continue

            # process (model inference)
            start = profiler.record("midp#{}_0".format(op_info_prefix))
            midped_data_dict, err_channeldata_dict \
                    = self._run_process(preped_data_dict, op_info_prefix)
            end = profiler.record("midp#{}_1".format(op_info_prefix))
            midp_time = end - start
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(midped_data_dict) == 0:
                continue

            # postprocess
            start = profiler.record("postp#{}_0".format(op_info_prefix))
            postped_data_dict, err_channeldata_dict \
                    = self._run_postprocess(
                            parsed_data_dict, midped_data_dict, op_info_prefix)
            end = profiler.record("postp#{}_1".format(op_info_prefix))
            postp_time = end - start
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(postped_data_dict) == 0:
                continue

            # push data to channel (if run succ)
            start = int(round(_time() * 1000000))
            try:
                profile_str = profiler.gen_profile_str()
                for data_id, postped_data in postped_data_dict.items():
                    if self._server_use_profile:
                        sys.stderr.write(profile_str)
                    self._push_to_output_channels(
                        data=postped_data,
                        channels=output_channels,
                        profile_str=profile_str,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            end = int(round(_time() * 1000000))
            out_time = end - start
            if trace_buffer is not None:
                # buffer latency records locally so a full trace queue
                # never blocks the worker; retry pending records each loop
                trace_que.append({
                    "name": self.name,
                    "actions": {
                        "in": in_time,
                        "prep": prep_time,
                        "midp": midp_time,
                        "postp": postp_time,
                        "out": out_time,
                    }
                })
                while trace_que:
                    info = trace_que[0]
                    try:
                        trace_buffer.put_nowait(info)
                        trace_que.popleft()
                    except Queue.Full:
                        break
B
barriery 已提交
839

W
wangjiawei04 已提交
840
    def _initialize(self, is_thread_op, concurrency_idx):
        """Initialize this worker: create the client, run the user-defined
        init_op, and return a fresh enabled TimeProfiler.

        In process mode every worker owns its own Op instance, so setup runs
        unconditionally. In thread mode all workers share one Op instance,
        so setup runs once, guarded by a lock and the _succ_init_op flag.
        """
        if not is_thread_op:
            self.concurrency_idx = concurrency_idx
            # init client
            self.client = self.init_client(self._client_config,
                                           self._server_endpoints)
            # user defined
            self.init_op()
        else:
            with self._for_init_op_lock:
                if not self._succ_init_op:
                    # for the threaded version of Op, each thread cannot get
                    # its concurrency_idx
                    self.concurrency_idx = None
                    # init client
                    self.client = self.init_client(self._client_config,
                                                   self._server_endpoints)
                    # user defined
                    self.init_op()
                    self._succ_init_op = True
                    self._succ_close_op = False

        # use a separate TimeProfiler per thread or process
        profiler = TimeProfiler()
        profiler.enable(True)
        return profiler

B
barriery 已提交
866 867 868 869 870 871 872 873
    def _finalize(self, is_thread_op):
        """Release shared per-Op state when the worker stops.

        Only thread mode holds shared state; teardown runs once, guarded
        by a lock and the _succ_close_op flag. Process mode needs nothing:
        resources die with the worker process.
        """
        if not is_thread_op:
            return
        with self._for_close_op_lock:
            if self._succ_close_op:
                return
            self._profiler = None
            self.client = None
            self._succ_init_op = False
            self._succ_close_op = True
874 875 876 877 878

    def _log(self, info):
        """Prefix *info* with this Op's name for log/error messages."""
        return "%s %s" % (self.name, info)


B
barrierye 已提交
879 880 881
class RequestOp(Op):
    """ RequestOp do not run preprocess, process, postprocess.

    Entry Op of the DAG: unpacks the incoming RPC request into a plain
    dict that downstream Ops consume.
    """

    def __init__(self):
        # PipelineService.name = "@DAGExecutor"
        super(RequestOp, self).__init__(name="@DAGExecutor", input_ops=[])
        # init op; failure is fatal for the whole server process
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.critical("Op(Request) Failed to init: {}".format(e))
            os._exit(-1)

    def unpack_request_package(self, request):
        """Convert the RPC request's parallel key/value lists into a dict.

        Values whose text evaluates to a numpy ndarray (e.g. an
        `array(...)` repr produced by the client) are replaced by the
        evaluated array; all other values are kept as the raw string.
        """
        dictdata = {}
        for idx, key in enumerate(request.key):
            data = request.value[idx]
            try:
                # SECURITY NOTE(review): eval() on a client-supplied string
                # executes arbitrary Python and is exploitable by untrusted
                # callers. It is kept here because the intent is to round-trip
                # numpy `array(...)` reprs (ast.literal_eval cannot parse
                # them) -- consider a safe structured encoding instead.
                evaled_data = eval(data)
                if isinstance(evaled_data, np.ndarray):
                    data = evaled_data
            except Exception as e:
                # non-evaluable values are passed through as strings
                pass
            dictdata[key] = data
        return dictdata


class ResponseOp(Op):
    """ ResponseOp do not run preprocess, process, postprocess.

    Exit Op of the DAG: packs the final channeldata into the RPC
    Response message returned to the client.
    """

    def __init__(self, input_ops):
        super(ResponseOp, self).__init__(
            name="@DAGExecutor", input_ops=input_ops)
        # init op; failure is fatal for the whole server process
        try:
            self.init_op()
        except Exception as e:
            # BUG FIX: exc_info=True was previously passed to str.format()
            # (where extra kwargs are silently ignored) instead of the
            # logging call, so the traceback was never recorded.
            _LOGGER.critical(
                "Op(ResponseOp) Failed to init: {}".format(e), exc_info=True)
            os._exit(-1)

    def pack_response_package(self, channeldata):
        """Build the RPC Response from the DAG's final channeldata.

        NPDATA payloads are serialized via ndarray repr; DICT payloads
        must contain only str values. Any type mismatch turns the
        response into a TYPE_ERROR with a descriptive error_info.
        """
        resp = pipeline_service_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
                feed = channeldata.parse()
                # ndarray to string:
                # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
                np.set_printoptions(threshold=sys.maxsize)
                for name, var in feed.items():
                    resp.value.append(var.__repr__())
                    resp.key.append(name)
            elif channeldata.datatype == ChannelDataType.DICT.value:
                feed = channeldata.parse()
                for name, var in feed.items():
                    if not isinstance(var, str):
                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                        resp.error_info = self._log(
                            "fetch var type must be str({}).".format(
                                type(var)))
                        _LOGGER.error("(logid={}) Failed to pack RPC "
                                      "response package: {}".format(
                                          channeldata.id, resp.error_info))
                        break
                    resp.value.append(var)
                    resp.key.append(name)
            else:
                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                resp.error_info = self._log(
                    "error type({}) in datatype.".format(channeldata.datatype))
                _LOGGER.error("(logid={}) Failed to pack RPC response"
                              " package: {}".format(channeldata.id,
                                                    resp.error_info))
        else:
            # propagate the upstream error message unchanged
            resp.error_info = channeldata.error_info
        return resp
956 957 958 959 960 961 962


class VirtualOp(Op):
    ''' For connecting two channels.

    Inserted by the DAG builder between Ops that are not adjacent in the
    topology; it runs no preprocess/process/postprocess and simply relays
    channeldata from its input channel to its output channels.
    '''

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, input_ops=None, concurrency=concurrency)
        # the real (possibly virtual) predecessor Ops this node stands in for
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Register a predecessor Op (may itself be a VirtualOp)."""
        self._virtual_pred_ops.append(op)

    def _actual_pred_op_names(self, op):
        """Recursively resolve *op* to the names of real (non-virtual) Ops."""
        # can use disjoint-set, but it's not necessary
        if not isinstance(op, VirtualOp):
            return [op.name]
        names = []
        for x in op._virtual_pred_ops:
            names.extend(self._actual_pred_op_names(x))
        return names

    def add_output_channel(self, channel):
        """Attach an output channel, registering every actual predecessor
        Op name as a producer on it. Exits the process on wrong type."""
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical(
                self._log("Failed to add output_channel: output_channel"
                          " must be Channel type, not {}".format(
                              type(channel))))
            os._exit(-1)
        for op in self._virtual_pred_ops:
            for op_name in self._actual_pred_op_names(op):
                channel.add_producer(op_name)
        self._outputs.append(channel)

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op):
        """Relay loop: forward every channeldata item to the output
        channels until the input channel is stopped.

        BUG FIX: the previous body called the undefined name get_log_func
        and passed log_func= to _auto_batching_generator, whose signature
        takes op_info_prefix -- both raised at runtime. Aligned with the
        logging/batching conventions used by Op._run.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)

        # batch_size=1 / no timeout: a virtual Op just relays items one by one
        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=1,
            timeout=None,
            op_info_prefix=op_info_prefix)

        while True:
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break

            try:
                for channeldata_dict in channeldata_dict_batch:
                    for name, data in channeldata_dict.items():
                        self._push_to_output_channels(
                            data, channels=output_channels, name=name)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break