operator.py 43.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
B
barriery 已提交
16
import time
17 18 19 20 21 22
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
23
import os
B
barrierye 已提交
24
import sys
25
import collections
B
barrierye 已提交
26
import numpy as np
B
barrierye 已提交
27
from numpy import *
B
barrierye 已提交
28 29 30 31 32 33
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
34

B
barrierye 已提交
35
from .proto import pipeline_service_pb2
B
barrierye 已提交
36
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
B
bug fix  
barriery 已提交
37 38
                      ChannelData, ChannelDataType, ChannelStopError,
                      ChannelTimeoutError)
B
barrierye 已提交
39
from .util import NameGenerator
B
barriery 已提交
40
from .profiler import UnsafeTimeProfiler as TimeProfiler
B
barriery 已提交
41
from . import local_rpc_service_handler
42

43
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
44 45
_op_name_gen = NameGenerator("Op")

D
dongdaxiang 已提交
46 47 48

class Op(object):
    def __init__(self,
B
barrierye 已提交
49
                 name=None,
D
dongdaxiang 已提交
50
                 input_ops=[],
B
barriery 已提交
51 52
                 server_endpoints=None,
                 fetch_list=None,
B
barrierye 已提交
53
                 client_config=None,
B
barriery 已提交
54 55 56 57
                 concurrency=None,
                 timeout=None,
                 retry=None,
                 batch_size=None,
58
                 auto_batching_timeout=None,
B
barriery 已提交
59
                 local_rpc_service_handler=None):
B
barriery 已提交
60
        # In __init__, all the parameters are just saved and Op is not initialized
B
barrierye 已提交
61
        if name is None:
B
barrierye 已提交
62
            name = _op_name_gen.next()
63
        self.name = name  # to identify the type of OP, it must be globally unique
B
barrierye 已提交
64
        self.concurrency = concurrency  # amount of concurrency
B
barrierye 已提交
65
        self.set_input_ops(input_ops)
B
barrierye 已提交
66

B
barriery 已提交
67 68
        self._local_rpc_service_handler = local_rpc_service_handler
        self._server_endpoints = server_endpoints
B
barrierye 已提交
69
        self._fetch_names = fetch_list
B
barriery 已提交
70
        self._client_config = client_config
B
barriery 已提交
71
        self._timeout = timeout
72
        self._retry = max(1, retry)
B
barriery 已提交
73 74 75
        self._batch_size = batch_size
        self._auto_batching_timeout = auto_batching_timeout

76 77
        self._input = None
        self._outputs = []
B
barrierye 已提交
78

B
barriery 已提交
79 80 81 82 83 84 85 86 87
        self._server_use_profile = False
        self._tracer = None

        # only for thread op
        self._for_init_op_lock = threading.Lock()
        self._for_close_op_lock = threading.Lock()
        self._succ_init_op = False
        self._succ_close_op = False

B
barriery 已提交
88 89
    def init_from_dict(self, conf):
        """Fill every parameter that was left as None at construction time
        from the pipeline config dict *conf*, and decide the serving mode.

        Serving mode resolution:
          * explicit server_endpoints (ctor or conf)  -> remote service
          * otherwise, a local_rpc_service_handler    -> local service
          * otherwise, model_config in conf           -> build a local handler
          * no model_config                           -> no serving at all

        Side effects: sets self.with_serving, and may set self.client_type,
        self.local_predictor, self._local_rpc_service_handler,
        self._server_endpoints, self._client_config, self._fetch_names.
        """
        # init op
        if self.concurrency is None:
            self.concurrency = conf["concurrency"]
        if self._retry is None:
            self._retry = conf["retry"]
        if self._fetch_names is None:
            self._fetch_names = conf.get("fetch_list")
        if self._client_config is None:
            self._client_config = conf.get("client_config")

        if self._timeout is None:
            self._timeout = conf["timeout"]
        # Config timeout is in milliseconds; convert to seconds here,
        # normalizing "disabled" to -1.
        if self._timeout > 0:
            self._timeout = self._timeout / 1000.0
        else:
            self._timeout = -1

        if self._batch_size is None:
            self._batch_size = conf["batch_size"]
        if self._auto_batching_timeout is None:
            self._auto_batching_timeout = conf["auto_batching_timeout"]
        if self._auto_batching_timeout <= 0 or self._batch_size == 1:
            _LOGGER.warning(
                self._log(
                    "Because auto_batching_timeout <= 0 or batch_size == 1,"
                    " set auto_batching_timeout to None."))
            self._auto_batching_timeout = None
        else:
            self._auto_batching_timeout = self._auto_batching_timeout / 1000.0

        if self._server_endpoints is None:
            server_endpoints = conf.get("server_endpoints", [])
            if len(server_endpoints) != 0:
                # remote service
                self.with_serving = True
                self._server_endpoints = server_endpoints
            else:
                if self._local_rpc_service_handler is None:
                    # No handler supplied: build one from local_service_conf.
                    # NOTE(review): conf.get("local_service_conf") may return
                    # None, which would raise AttributeError below — presumably
                    # the config schema always provides it; verify upstream.
                    local_service_conf = conf.get("local_service_conf")
                    _LOGGER.info("local_service_conf: {}".format(
                        local_service_conf))
                    model_config = local_service_conf.get("model_config")
                    self.client_type = local_service_conf.get("client_type")
                    _LOGGER.info("model_config: {}".format(model_config))
                    if model_config is None:
                        self.with_serving = False
                    else:
                        # local rpc service
                        self.with_serving = True
                        if self.client_type == "brpc" or self.client_type == "grpc":
                            service_handler = local_rpc_service_handler.LocalRpcServiceHandler(
                                model_config=model_config,
                                workdir=local_service_conf["workdir"],
                                thread_num=local_service_conf["thread_num"],
                                devices=local_service_conf["devices"],
                                mem_optim=local_service_conf["mem_optim"],
                                ir_optim=local_service_conf["ir_optim"])
                            service_handler.prepare_server()  # get fetch_list
                            # NOTE(review): "serivce_ports" is a typo of
                            # "service_ports" (local variable only).
                            serivce_ports = service_handler.get_port_list()
                            self._server_endpoints = [
                                "127.0.0.1:{}".format(p) for p in serivce_ports
                            ]
                            if self._client_config is None:
                                self._client_config = service_handler.get_client_config(
                                )
                            if self._fetch_names is None:
                                self._fetch_names = service_handler.get_fetch_list()
                        elif self.client_type == "local_predictor":
                            # In-process predictor: no port, client fetched
                            # directly from the handler.
                            service_handler = local_rpc_service_handler.LocalPredictorServiceHandler(
                                model_config=model_config,
                                workdir=local_service_conf["workdir"],
                                thread_num=local_service_conf["thread_num"],
                                devices=local_service_conf["devices"])
                            service_handler.prepare_server()  # get fetch_list
                            self.local_predictor = service_handler.get_client()
                            if self._client_config is None:
                                self._client_config = service_handler.get_client_config(
                                )
                            if self._fetch_names is None:
                                self._fetch_names = service_handler.get_fetch_list()
                        self._local_rpc_service_handler = service_handler
                else:
                    # A handler was supplied at construction time: only
                    # prepare it and derive endpoints/config/fetch list.
                    self.with_serving = True
                    self._local_rpc_service_handler.prepare_server(
                    )  # get fetch_list
                    serivce_ports = self._local_rpc_service_handler.get_port_list(
                    )
                    self._server_endpoints = [
                        "127.0.0.1:{}".format(p) for p in serivce_ports
                    ]
                    if self._client_config is None:
                        self._client_config = self._local_rpc_service_handler.get_client_config(
                        )
                    if self._fetch_names is None:
                        self._fetch_names = self._local_rpc_service_handler.get_fetch_list(
                        )
        else:
            self.with_serving = True

        # RequestOp/ResponseOp are pipeline entry/exit points; the summary
        # log below only applies to ordinary Ops.
        if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp):
            _LOGGER.info(
                self._log("\n\tinput_ops: {},"
                          "\n\tserver_endpoints: {}"
                          "\n\tfetch_list: {}"
                          "\n\tclient_config: {}"
                          "\n\tconcurrency: {},"
                          "\n\ttimeout(s): {},"
                          "\n\tretry: {},"
                          "\n\tbatch_size: {},"
                          "\n\tauto_batching_timeout(s): {}".format(
                              ", ".join([op.name for op in self._input_ops
                                         ]), self._server_endpoints,
                              self._fetch_names, self._client_config,
                              self.concurrency, self._timeout, self._retry,
                              self._batch_size, self._auto_batching_timeout)))
B
barriery 已提交
204

205
    def launch_local_rpc_service(self):
        """Start the local rpc server owned by this Op, if one exists."""
        handler = self._local_rpc_service_handler
        if handler is None:
            _LOGGER.warning(
                self._log("Failed to launch local rpc"
                          " service: local_rpc_service_handler is None."))
            return
        port = handler.get_port_list()
        handler.start_server()
        _LOGGER.info("Op({}) use local rpc service at port: {}"
                     .format(self.name, port))

B
barriery 已提交
216
    def use_default_auto_batching_config(self):
B
bug fix  
barriery 已提交
217
        if self._batch_size != 1:
218 219
            _LOGGER.warning("Op({}) reset batch_size=1 (original: {})"
                            .format(self.name, self._batch_size))
B
bug fix  
barriery 已提交
220 221
            self._batch_size = 1
        if self._auto_batching_timeout != None:
222
            _LOGGER.warning(
B
barriery 已提交
223 224
                "Op({}) reset auto_batching_timeout=None (original: {})"
                .format(self.name, self._auto_batching_timeout))
B
bug fix  
barriery 已提交
225
            self._auto_batching_timeout = None
B
barriery 已提交
226

B
barrierye 已提交
227
    def use_profiler(self, use_profile):
B
barrierye 已提交
228
        self._server_use_profile = use_profile
229

B
barriery 已提交
230 231 232
    def set_tracer(self, tracer):
        self._tracer = tracer

W
wangjiawei04 已提交
233
    def init_client(self, client_config, server_endpoints,
                    fetch_names):
        """Create (and connect) the client used to call the serving backend.

        Args:
            client_config: client-side model config path.
            server_endpoints: list of "ip:port" endpoints to connect to.
            fetch_names: variable names to fetch (stored on self).

        Returns:
            None when this Op has no serving backend; otherwise a client
            whose concrete type depends on self.client_type
            ("brpc", "grpc" or "local_predictor").

        Raises:
            ValueError: unknown client_type, or local predictor not created.
        """
        # Debug print to stdout replaced with proper logging.
        _LOGGER.debug("Op({}) init client, fetch_names: {}".format(
            self.name, fetch_names))
        if not self.with_serving:
            _LOGGER.info("Op({}) has no client (and it also do not "
                         "run the process function)".format(self.name))
            return None
        if self.client_type == 'brpc':
            client = Client()
            client.load_client_config(client_config)
        elif self.client_type == 'grpc':
            client = MultiLangClient()
        elif self.client_type == 'local_predictor':
            if self.local_predictor is None:
                raise ValueError("local predictor not yet created")
            client = self.local_predictor
        else:
            raise ValueError("Failed to init client: unknow client "
                             "type {}".format(self.client_type))
        # The local predictor runs in-process and has no endpoint to connect.
        if self.client_type != "local_predictor":
            client.connect(server_endpoints)
        self._fetch_names = fetch_names
        return client
256 257 258 259 260 261 262 263 264 265

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
266
                _LOGGER.critical(
B
barriery 已提交
267 268
                    self._log("Failed to set input_ops: input op "
                              "must be Op type, not {}".format(type(op))))
269
                os._exit(-1)
270
            self._input_ops.append(op)
D
dongdaxiang 已提交
271

272 273
    def add_input_channel(self, channel):
        """Register *channel* as the single input channel of this Op and
        subscribe this Op as one of its consumers. Aborts on a non-Channel."""
        if isinstance(channel, (ThreadChannel, ProcessChannel)):
            channel.add_consumer(self.name)
            self._input = channel
            return
        _LOGGER.critical(
            self._log("Failed to set input_channel: input "
                      "channel must be Channel type, not {}".format(
                          type(channel))))
        os._exit(-1)
D
dongdaxiang 已提交
281

282
    def clean_input_channel(self):
B
barrierye 已提交
283 284 285 286
        self._input = None

    def _get_input_channel(self):
        return self._input
D
dongdaxiang 已提交
287

288 289
    def add_output_channel(self, channel):
        """Append *channel* to the output channels of this Op and register
        this Op as one of its producers. Aborts on a non-Channel."""
        if isinstance(channel, (ThreadChannel, ProcessChannel)):
            channel.add_producer(self.name)
            self._outputs.append(channel)
            return
        _LOGGER.critical(
            self._log("Failed to add output_channel: output channel "
                      "must be Channel type, not {}".format(type(channel))))
        os._exit(-1)
D
dongdaxiang 已提交
296

297
    def clean_output_channels(self):
B
barrierye 已提交
298 299 300 301 302
        self._outputs = []

    def _get_output_channels(self):
        return self._outputs

W
wangjiawei04 已提交
303
    def preprocess(self, input_dicts):
B
barrierye 已提交
304
        # multiple previous Op
B
barrierye 已提交
305
        if len(input_dicts) != 1:
306 307
            _LOGGER.critical(
                self._log(
B
barriery 已提交
308 309
                    "Failed to run preprocess: this Op has multiple previous "
                    "inputs. Please override this func."))
310
            os._exit(-1)
D
dongdaxiang 已提交
311

B
barrierye 已提交
312 313
        (_, input_dict), = input_dicts.items()
        return input_dict
B
barrierye 已提交
314

B
barriery 已提交
315
    def process(self, feed_batch, typical_logid):
        """Call the serving backend on a feed batch.

        Args:
            feed_batch: list of numpy feed dicts (validated below).
            typical_logid: representative log id for this batch.

        Returns:
            The prediction result dict, or None when a MultiLangClient call
            failed (non-zero serving_status_code).
        """
        err, err_info = ChannelData.check_batch_npdata(feed_batch)
        if err != 0:
            _LOGGER.critical(
                self._log("Failed to run process: {}. Please override "
                          "preprocess func.".format(err_info)))
            os._exit(-1)
        # Leftover debug print()s removed: library code must not write
        # progress messages to stdout.
        if self.client_type == "local_predictor":
            # local predictor takes a single feed dict, not a batch list
            call_result = self.client.predict(
                feed=feed_batch[0],
                fetch=self._fetch_names,
                log_id=typical_logid)
        else:
            call_result = self.client.predict(
                feed=feed_batch, fetch=self._fetch_names, log_id=typical_logid)
        if isinstance(self.client, MultiLangClient):
            if call_result is None or call_result["serving_status_code"] != 0:
                return None
            call_result.pop("serving_status_code")
        return call_result

W
wangjiawei04 已提交
335
    def postprocess(self, input_dict, fetch_dict):
B
barrierye 已提交
336
        return fetch_dict
D
dongdaxiang 已提交
337

B
barrierye 已提交
338
    def _parse_channeldata(self, channeldata_dict):
        """Parse the channel data coming from every upstream Op.

        Returns a tuple (data_id, error_channeldata, parsed_data,
        client_need_profile, profile_set); error_channeldata is the first
        non-OK entry found (parsing stops there), otherwise None.
        """
        first = channeldata_dict[list(channeldata_dict.keys())[0]]
        data_id = first.id
        client_need_profile = first.client_need_profile

        error_channeldata = None
        parsed_data = {}
        profile_set = set()
        for op_name, data in channeldata_dict.items():
            if data.ecode != ChannelDataEcode.OK.value:
                error_channeldata = data
                break
            parsed_data[op_name] = data.parse()
            if client_need_profile:
                profile_set |= data.profile_data_set
        return (data_id, error_channeldata, parsed_data, client_need_profile,
                profile_set)
B
barrierye 已提交
356 357 358 359 360

    def _push_to_output_channels(self,
                                 data,
                                 channels,
                                 name=None,
B
barriery 已提交
361
                                 profile_str=None,
B
barrierye 已提交
362
                                 client_need_profile=False,
B
barrierye 已提交
363
                                 profile_set=None):
364 365
        if name is None:
            name = self.name
B
barrierye 已提交
366

B
barriery 已提交
367
        # add profile into channeldata
B
barrierye 已提交
368
        if client_need_profile and profile_set is not None:
B
barriery 已提交
369 370
            if profile_str is not None:
                profile_set.add(profile_str)
B
barrierye 已提交
371
            data.add_profile(profile_set)
B
barrierye 已提交
372

B
barriery 已提交
373 374 375
        for channel in channels:
            channel.push(data, name)

W
wangjiawei04 已提交
376
    def start_with_process(self):
        """Spawn self.concurrency daemon worker processes running _run.

        Returns the list of started multiprocessing.Process objects.
        """
        trace_buffer = None
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()
        workers = []
        for worker_idx in range(self.concurrency):
            worker = multiprocessing.Process(
                target=self._run,
                args=(worker_idx, self._get_input_channel(),
                      self._get_output_channels(), False,
                      trace_buffer))
            worker.daemon = True
            worker.start()
            workers.append(worker)
        return workers
391

W
wangjiawei04 已提交
392
    def start_with_thread(self):
        """Spawn self.concurrency daemon worker threads running _run.

        Returns the list of started threading.Thread objects.
        """
        trace_buffer = None
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()
        workers = []
        for worker_idx in range(self.concurrency):
            worker = threading.Thread(
                target=self._run,
                args=(worker_idx, self._get_input_channel(),
                      self._get_output_channels(), True,
                      trace_buffer))
            # When a process exits, it attempts to terminate
            # all of its daemonic child processes.
            worker.daemon = True
            worker.start()
            workers.append(worker)
        return workers

B
barrierye 已提交
410
    def init_op(self):
        # Hook for user-defined Ops: override to run one-time setup in a
        # worker before serving begins; the default does nothing.
        # NOTE(review): exact call site not visible in this chunk — appears
        # to be invoked from the worker loop; confirm against _run().
        pass

B
barriery 已提交
413 414
    def _run_preprocess(self, parsed_data_dict, op_info_prefix):
        """Run self.preprocess on every sample in the batch.

        Returns (preped_data_dict, err_channeldata_dict): per-id successful
        results, and per-id error ChannelData for failed samples. TypeError
        maps to TYPE_ERROR (wrong datatype in channeldata), anything else
        to UNKNOW.
        """
        _LOGGER.debug("{} Running preprocess".format(op_info_prefix))
        succ_dict = collections.OrderedDict()
        err_dict = collections.OrderedDict()
        for data_id, parsed_data in parsed_data_dict.items():
            try:
                succ_dict[data_id] = self.preprocess(parsed_data)
            except TypeError as e:
                # Error type in channeldata.datatype
                error_info = "(logid={}) {} Failed to preprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.TYPE_ERROR.value,
                    error_info=error_info,
                    data_id=data_id)
            except Exception as e:
                error_info = "(logid={}) {} Failed to preprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
        _LOGGER.debug("{} Succ preprocess".format(op_info_prefix))
        return succ_dict, err_dict

B
barriery 已提交
445 446
    def _run_process(self, preped_data_dict, op_info_prefix):
        """Run the process stage (model inference) for a mini-batch.

        Combines the per-id preprocessed samples into one feed batch, calls
        self.process (with optional timeout/retry), then splits the batched
        result back into per-id dicts (handling LodTensor offsets).

        Returns (midped_data_dict, err_channeldata_dict).
        """
        _LOGGER.debug("{} Running process".format(op_info_prefix))
        midped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        if self.with_serving:
            # list() is required: dict.keys() returns a view on Python 3,
            # which does not support indexing (data_ids[0] below raised
            # TypeError).
            data_ids = list(preped_data_dict.keys())
            typical_logid = data_ids[0]
            if len(data_ids) != 1:
                for data_id in data_ids:
                    _LOGGER.info(
                        "(logid={}) {} During access to PaddleServingService,"
                        " we selected logid={} (from batch: {}) as a "
                        "representative for logging.".format(
                            data_id, op_info_prefix, typical_logid, data_ids))

            # combine samples to batch
            one_input = preped_data_dict[data_ids[0]]
            feed_batch = []
            input_offset = None
            if isinstance(one_input, dict):
                # sample input: one feed dict per data_id
                feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
                input_offset = list(range(len(data_ids) + 1))
            elif isinstance(one_input, list):
                # batch input: concatenate, remembering per-id offsets
                input_offset = [0]
                for data_id in data_ids:
                    batch_input = preped_data_dict[data_id]
                    offset = input_offset[-1] + len(batch_input)
                    feed_batch += batch_input
                    input_offset.append(offset)
            else:
                _LOGGER.critical(
                    "{} Failed to process: expect input type is dict(sample"
                    " input) or list(batch input), but get {}".format(
                        op_info_prefix, type(one_input)))
                os._exit(-1)

            midped_batch = None
            ecode = ChannelDataEcode.OK.value
            if self._timeout <= 0:
                # no timeout configured: single direct call
                try:
                    midped_batch = self.process(feed_batch, typical_logid)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
                        typical_logid, op_info_prefix, data_ids, e)
                    _LOGGER.error(error_info, exc_info=True)
            else:
                # bounded call with up to self._retry attempts on timeout
                for i in range(self._retry):
                    try:
                        midped_batch = func_timeout.func_timeout(
                            self._timeout,
                            self.process,
                            args=(feed_batch, typical_logid))
                    except func_timeout.FunctionTimedOut as e:
                        if i + 1 >= self._retry:
                            ecode = ChannelDataEcode.TIMEOUT.value
                            error_info = "(logid={}) {} Failed to process(batch: {}): " \
                                    "exceeded retry count.".format(
                                            typical_logid, op_info_prefix, data_ids)
                            _LOGGER.error(error_info)
                        else:
                            _LOGGER.warning(
                                "(logid={}) {} Failed to process(batch: {}): timeout,"
                                " and retrying({}/{})...".format(
                                    typical_logid, op_info_prefix, data_ids, i +
                                    1, self._retry))
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
                            typical_logid, op_info_prefix, data_ids, e)
                        _LOGGER.error(error_info, exc_info=True)
                        break
                    else:
                        break
            if ecode != ChannelDataEcode.OK.value:
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ecode, error_info=error_info, data_id=data_id)
            elif midped_batch is None:
                # op client return None
                error_info = "(logid={}) {} Failed to predict, please check if " \
                        "PaddleServingService is working properly.".format(
                                typical_logid, op_info_prefix)
                _LOGGER.error(error_info)
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ChannelDataEcode.CLIENT_ERROR.value,
                        error_info=error_info,
                        data_id=data_id)
            else:
                # transform np format to dict format
                # A var named "X" with a companion "X.lod" entry is a
                # LodTensor: split it by lod offsets instead of row offsets.
                var_names = midped_batch.keys()
                lod_var_names = set()
                lod_offset_names = set()
                for name in var_names:
                    lod_offset_name = "{}.lod".format(name)
                    if lod_offset_name in var_names:
                        _LOGGER.debug("(logid={}) {} {} is LodTensor".format(
                            typical_logid, op_info_prefix, name))
                        lod_var_names.add(name)
                        lod_offset_names.add(lod_offset_name)

                for data_id in data_ids:
                    midped_data_dict[data_id] = {}

                for name, value in midped_batch.items():
                    if name in lod_offset_names:
                        continue
                    if name in lod_var_names:
                        # lodtensor
                        lod_offset_name = "{}.lod".format(name)
                        lod_offset = midped_batch[lod_offset_name]
                        for idx, data_id in enumerate(data_ids):
                            data_offset_left = input_offset[idx]
                            data_offset_right = input_offset[idx + 1]
                            lod_offset_left = lod_offset[data_offset_left]
                            lod_offset_right = lod_offset[data_offset_right]
                            midped_data_dict[data_id][name] = value[
                                lod_offset_left:lod_offset_right]
                            # re-base the lod offsets of this sample to 0
                            # (debug print of the raw offsets removed)
                            midped_data_dict[data_id][lod_offset_name] = \
                                    lod_offset[data_offset_left:data_offset_right + 1] - lod_offset[data_offset_left]
                    else:
                        # normal tensor
                        for idx, data_id in enumerate(data_ids):
                            left = input_offset[idx]
                            right = input_offset[idx + 1]
                            midped_data_dict[data_id][name] = value[left:right]
        else:
            midped_data_dict = preped_data_dict
        _LOGGER.debug("{} Succ process".format(op_info_prefix))
        return midped_data_dict, err_channeldata_dict

B
barriery 已提交
580 581 582
    def _run_postprocess(self, parsed_data_dict, midped_data_dict,
                         op_info_prefix):
        """Run self.postprocess on every sample and wrap results as
        ChannelData (NPDATA when values are numpy-compatible, DICT otherwise).

        Returns (postped_data_dict, err_channeldata_dict).
        """
        _LOGGER.debug("{} Running postprocess".format(op_info_prefix))
        postped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        for data_id, midped_data in midped_data_dict.items():
            try:
                postped_data = self.postprocess(parsed_data_dict[data_id],
                                                midped_data)
            except Exception as e:
                error_info = "(logid={}) {} Failed to postprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_channeldata_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
                continue

            if not isinstance(postped_data, dict):
                # "funticon" typo in the original message fixed
                error_info = "(logid={}) {} Failed to postprocess: " \
                        "output of postprocess function must be " \
                        "dict type, but get {}".format(
                            data_id, op_info_prefix,
                            type(postped_data))
                _LOGGER.error(error_info)
                err_channeldata_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
                continue

            err, _ = ChannelData.check_npdata(postped_data)
            if err == 0:
                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
            else:
                output_data = ChannelData(
                    ChannelDataType.DICT.value,
                    dictdata=postped_data,
                    data_id=data_id)
            postped_data_dict[data_id] = output_data
        _LOGGER.debug("{} Succ postprocess".format(op_info_prefix))
        return postped_data_dict, err_channeldata_dict
B
barriery 已提交
631 632

    def _auto_batching_generator(self, input_channel, op_name, batch_size,
                                 timeout, op_info_prefix):
        """Generator that yields non-empty batches of channeldata dicts.

        Blocks until at least one sample arrives. When `timeout` (seconds)
        is set it is a budget for the WHOLE batch: once the budget runs out
        a partial batch is yielded; when `timeout` is None, waits until
        `batch_size` samples have been collected.

        Args:
            input_channel: channel to pull samples from.
            op_name: consumer name registered on the channel.
            batch_size: max samples per yielded batch.
            timeout: total batch-collection budget in seconds, or None.
            op_info_prefix: "[op_name|concurrency_idx]" tag used in logs.
        """
        while True:
            batch = []
            while len(batch) == 0:
                endtime = None
                if timeout is not None:
                    endtime = _time() + timeout
                for idx in range(batch_size):
                    try:
                        channeldata_dict = None
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug("{} Failed to generate batch: "
                                              "timeout".format(op_info_prefix))
                                break
                            # BUG FIX: wait only for the remaining budget;
                            # the original passed the full `timeout` here,
                            # letting a batch block up to batch_size*timeout.
                            channeldata_dict = input_channel.front(op_name,
                                                                   remaining)
                        else:
                            channeldata_dict = input_channel.front(op_name)
                        batch.append(channeldata_dict)
                    except ChannelTimeoutError:
                        _LOGGER.debug("{} Failed to generate batch: "
                                      "timeout".format(op_info_prefix))
                        break
            _LOGGER.debug("{} Got actual batch_size: {}".format(op_info_prefix,
                                                                len(batch)))
            yield batch
661

662
    def _parse_channeldata_batch(self, batch, output_channels):
663
        parsed_data_dict = collections.OrderedDict()
664 665
        need_profile_dict = {}
        profile_dict = {}
B
bug fix  
barriery 已提交
666
        for channeldata_dict in batch:
667 668 669 670 671 672 673 674 675 676
            (data_id, error_channeldata, parsed_data,
                    client_need_profile, profile_set) = \
                            self._parse_channeldata(channeldata_dict)
            if error_channeldata is None:
                parsed_data_dict[data_id] = parsed_data
                need_profile_dict[data_id] = client_need_profile
                profile_dict[data_id] = profile_set
            else:
                # error data in predecessor Op
                # (error_channeldata with profile info)
B
barriery 已提交
677 678
                self._push_to_output_channels(error_channeldata,
                                              output_channels)
679 680

        return parsed_data_dict, need_profile_dict, profile_dict
B
barriery 已提交
681

W
wangjiawei04 已提交
682
    def _run(self, concurrency_idx, input_channel, output_channels,
             is_thread_op, trace_buffer):
        """Worker loop of this Op (one invocation per worker thread/process).

        Repeatedly pulls auto-batched channeldata from `input_channel`, runs
        the preprocess -> process -> postprocess pipeline, and pushes the
        results (or per-request error ChannelData) into `output_channels`,
        until the input channel is stopped (ChannelStopError).

        Args:
            concurrency_idx: index of this worker; in thread mode the
                per-worker index is not usable (see _initialize).
            input_channel: channel this Op consumes from.
            output_channels: channels results/errors are pushed to.
            is_thread_op: True when workers are threads sharing this Op
                instance; False when they are separate processes.
            trace_buffer: optional queue receiving per-stage timing records
                (microseconds); None disables tracing.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        tid = threading.current_thread().ident

        # init op (client + user init_op); any failure here is fatal for
        # the whole server, hence os._exit
        profiler = None
        try:
            profiler = self._initialize(is_thread_op,
                                        concurrency_idx)
        except Exception as e:
            _LOGGER.critical(
                "{} Failed to init op: {}".format(op_info_prefix, e),
                exc_info=True)
            os._exit(-1)
        _LOGGER.info("{} Succ init".format(op_info_prefix))

        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=self._batch_size,
            timeout=self._auto_batching_timeout,
            op_info_prefix=op_info_prefix)

        start, end = None, None
        trace_que = collections.deque()
        while True:
            # `in_time` measures how long we waited for a batch (us)
            start = int(round(_time() * 1000000))
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            end = int(round(_time() * 1000000))
            in_time = end - start

            # parse channeldata batch; predecessor errors are forwarded
            # downstream inside _parse_channeldata_batch
            try:
                parsed_data_dict, need_profile_dict, profile_dict \
                        = self._parse_channeldata_batch(
                                channeldata_dict_batch, output_channels)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(parsed_data_dict) == 0:
                # data in the whole batch is all error data
                continue

            # preprocess
            start = profiler.record("prep#{}_0".format(op_info_prefix))
            preped_data_dict, err_channeldata_dict \
                    = self._run_preprocess(parsed_data_dict, op_info_prefix)
            end = profiler.record("prep#{}_1".format(op_info_prefix))
            prep_time = end - start
            # push per-request preprocess errors downstream
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(preped_data_dict) == 0:
                continue

            # process (model inference stage)
            start = profiler.record("midp#{}_0".format(op_info_prefix))
            midped_data_dict, err_channeldata_dict \
                    = self._run_process(preped_data_dict, op_info_prefix)
            end = profiler.record("midp#{}_1".format(op_info_prefix))
            midp_time = end - start
            # push per-request process errors downstream
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(midped_data_dict) == 0:
                continue

            # postprocess
            start = profiler.record("postp#{}_0".format(op_info_prefix))
            postped_data_dict, err_channeldata_dict \
                    = self._run_postprocess(
                            parsed_data_dict, midped_data_dict, op_info_prefix)
            end = profiler.record("postp#{}_1".format(op_info_prefix))
            postp_time = end - start
            # push per-request postprocess errors downstream
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(postped_data_dict) == 0:
                continue

            # push data to channel (if run succ)
            start = int(round(_time() * 1000000))
            try:
                profile_str = profiler.gen_profile_str()
                for data_id, postped_data in postped_data_dict.items():
                    if self._server_use_profile:
                        sys.stderr.write(profile_str)
                    self._push_to_output_channels(
                        data=postped_data,
                        channels=output_channels,
                        profile_str=profile_str,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            end = int(round(_time() * 1000000))
            out_time = end - start

            # hand timing info to the tracer; buffered locally in trace_que
            # so a full trace_buffer never blocks the worker
            if trace_buffer is not None:
                trace_que.append({
                    "name": self.name,
                    "actions": {
                        "in": in_time,
                        "prep": prep_time,
                        "midp": midp_time,
                        "postp": postp_time,
                        "out": out_time,
                    }
                })
                while trace_que:
                    info = trace_que[0]
                    try:
                        trace_buffer.put_nowait(info)
                        trace_que.popleft()
                    except Queue.Full:
                        break
B
barriery 已提交
830

W
wangjiawei04 已提交
831
    def _initialize(self, is_thread_op, concurrency_idx):
        """Set up the serving client and run the user init_op for a worker.

        In thread mode all workers share one Op instance, so the client /
        init_op are initialized only once, under self._for_init_op_lock.
        Each worker always receives its own enabled TimeProfiler.

        Returns:
            A fresh, enabled TimeProfiler for this worker.
        """

        def _build_profiler():
            # one profiler per thread/process — never shared
            prof = TimeProfiler()
            prof.enable(True)
            return prof

        if not is_thread_op:
            # process mode: every worker owns its Op copy, init unconditionally
            self.concurrency_idx = concurrency_idx
            self.client = self.init_client(self._client_config,
                                           self._server_endpoints,
                                           self._fetch_names)
            # user defined
            self.init_op()
            return _build_profiler()

        # thread mode: initialize shared state exactly once
        with self._for_init_op_lock:
            if not self._succ_init_op:
                # a threaded worker cannot know its own concurrency_idx
                self.concurrency_idx = None
                self.client = self.init_client(
                    self._client_config,
                    self._server_endpoints, self._fetch_names)
                # user defined
                self.init_op()
                self._succ_init_op = True
                self._succ_close_op = False
        return _build_profiler()

B
barriery 已提交
859 860 861 862 863 864 865 866
    def _finalize(self, is_thread_op):
        if is_thread_op:
            with self._for_close_op_lock:
                if not self._succ_close_op:
                    self._profiler = None
                    self.client = None
                    self._succ_init_op = False
                    self._succ_close_op = True
867 868 869 870 871

    def _log(self, info):
        return "{} {}".format(self.name, info)


B
barrierye 已提交
872 873 874
class RequestOp(Op):
    """ RequestOp do not run preprocess, process, postprocess.

    Entry Op of the DAG: it only unpacks the incoming RPC request into a
    plain dict that downstream Ops consume.
    """

    def __init__(self):
        # PipelineService.name = "@DAGExecutor"
        super(RequestOp, self).__init__(name="@DAGExecutor", input_ops=[])
        # init op; a failure here leaves the server unusable, so exit hard
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.critical("Op(Request) Failed to init: {}".format(e))
            os._exit(-1)

    def unpack_request_package(self, request):
        """Convert a pipeline RPC request (parallel key/value lists) into
        a dict; values that eval() to a numpy ndarray are deserialized,
        all others are kept as the raw string.

        SECURITY NOTE(review): eval() runs arbitrary Python on data taken
        straight from the network request — acceptable only if the service
        is never exposed to untrusted clients; consider ast.literal_eval
        or an explicit ndarray wire format.
        """
        dictdata = {}
        for idx, key in enumerate(request.key):
            data = request.value[idx]
            try:
                # values are repr()-style strings (see ResponseOp packing);
                # only ndarray results replace the raw string
                evaled_data = eval(data)
                if isinstance(evaled_data, np.ndarray):
                    data = evaled_data
            except Exception as e:
                # not evaluable -> keep the raw string value
                pass
            dictdata[key] = data
        return dictdata


class ResponseOp(Op):
    """ ResponseOp do not run preprocess, process, postprocess.

    Exit Op of the DAG: it only packs the final ChannelData into the
    pipeline RPC Response message.
    """

    def __init__(self, input_ops):
        super(ResponseOp, self).__init__(
            name="@DAGExecutor", input_ops=input_ops)
        # init op; a failure here leaves the server unusable, so exit hard
        try:
            self.init_op()
        except Exception as e:
            # BUG FIX: exc_info was previously passed to str.format() (where
            # it was silently ignored) instead of to the logger call.
            _LOGGER.critical(
                "Op(ResponseOp) Failed to init: {}".format(e), exc_info=True)
            os._exit(-1)

    def pack_response_package(self, channeldata):
        """Serialize one ChannelData into a pipeline_service_pb2.Response.

        NPDATA payloads are stringified via repr() (the client side
        reconstructs them, see RequestOp.unpack_request_package); DICT
        payloads must already contain str values. Any other datatype or a
        non-str dict value yields a TYPE_ERROR response.
        """
        resp = pipeline_service_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
                feed = channeldata.parse()
                # ndarray to string:
                # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
                # NOTE: set_printoptions is a process-global side effect;
                # needed so repr() does not truncate large arrays
                np.set_printoptions(threshold=sys.maxsize)
                for name, var in feed.items():
                    resp.value.append(var.__repr__())
                    resp.key.append(name)
            elif channeldata.datatype == ChannelDataType.DICT.value:
                feed = channeldata.parse()
                for name, var in feed.items():
                    if not isinstance(var, str):
                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                        resp.error_info = self._log(
                            "fetch var type must be str({}).".format(
                                type(var)))
                        _LOGGER.error("(logid={}) Failed to pack RPC "
                                      "response package: {}".format(
                                          channeldata.id, resp.error_info))
                        break
                    resp.value.append(var)
                    resp.key.append(name)
            else:
                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                resp.error_info = self._log(
                    "error type({}) in datatype.".format(channeldata.datatype))
                _LOGGER.error("(logid={}) Failed to pack RPC response"
                              " package: {}".format(channeldata.id,
                                                    resp.error_info))
        else:
            # upstream already failed: propagate its error message as-is
            resp.error_info = channeldata.error_info
        return resp
949 950 951 952 953 954 955


class VirtualOp(Op):
    ''' For connecting two channels.

    A VirtualOp carries data across DAG layers without running any of the
    preprocess/process/postprocess stages: it simply forwards whatever it
    reads from its input channel to all of its output channels.
    '''

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, input_ops=None, concurrency=concurrency)
        # predecessor Ops this virtual node stands in for (may themselves
        # be VirtualOps, forming a chain)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Register a (possibly virtual) predecessor Op."""
        self._virtual_pred_ops.append(op)

    def _actual_pred_op_names(self, op):
        """Resolve `op` to the names of the real (non-virtual) Ops behind it.

        # can use disjoint-set, but it's not necessary
        """
        if not isinstance(op, VirtualOp):
            return [op.name]
        names = []
        for x in op._virtual_pred_ops:
            names.extend(self._actual_pred_op_names(x))
        return names

    def add_output_channel(self, channel):
        """Attach an output channel, registering every real predecessor Op
        as a producer on it; exits the process on a non-Channel argument."""
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical(
                self._log("Failed to add output_channel: output_channel"
                          " must be Channel type, not {}".format(
                              type(channel))))
            os._exit(-1)
        for op in self._virtual_pred_ops:
            for op_name in self._actual_pred_op_names(op):
                channel.add_producer(op_name)
        self._outputs.append(channel)

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op):
        """Forwarding loop: move channeldata from input to output channels
        one sample at a time until the input channel stops.

        `client_type` is unused but kept for caller compatibility.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        # BUG FIX: the original called undefined get_log_func() and passed a
        # `log_func=` kwarg that _auto_batching_generator does not accept
        # (its parameter is `op_info_prefix`), crashing at runtime.
        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=1,
            timeout=None,
            op_info_prefix=op_info_prefix)

        while True:
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break

            try:
                for channeldata_dict in channeldata_dict_batch:
                    for name, data in channeldata_dict.items():
                        self._push_to_output_channels(
                            data, channels=output_channels, name=name)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break