operator.py 41.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
B
barriery 已提交
16
import time
17 18 19 20 21 22
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
23
import os
B
barrierye 已提交
24
import sys
25
import collections
B
barrierye 已提交
26
import numpy as np
B
barrierye 已提交
27
from numpy import *
B
barrierye 已提交
28 29 30 31 32 33
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
34

B
barrierye 已提交
35
from .proto import pipeline_service_pb2
B
barrierye 已提交
36
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
B
bug fix  
barriery 已提交
37 38
                      ChannelData, ChannelDataType, ChannelStopError,
                      ChannelTimeoutError)
B
barrierye 已提交
39
from .util import NameGenerator
B
barriery 已提交
40
from .profiler import UnsafeTimeProfiler as TimeProfiler
B
barriery 已提交
41
from . import local_rpc_service_handler
42

43
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
44 45
_op_name_gen = NameGenerator("Op")

D
dongdaxiang 已提交
46 47 48

class Op(object):
    def __init__(self,
                 name=None,
                 input_ops=None,
                 server_endpoints=None,
                 fetch_list=None,
                 client_config=None,
                 concurrency=None,
                 timeout=None,
                 retry=None,
                 batch_size=None,
                 auto_batching_timeout=None,
                 local_rpc_service_handler=None):
        """Save construction parameters; the Op is fully initialized later
        (init_from_dict / init_client / init_op).

        Args:
            name: globally unique OP name; auto-generated when None.
            input_ops: upstream Op instance(s); None/[] means no inputs.
                (default changed from `[]` to None: a mutable default
                list is shared across calls; set_input_ops maps None to [].)
            server_endpoints: remote serving endpoints; when None a local
                rpc service may be built from config later.
            fetch_list: fetch variable names for the serving client.
            client_config: serving client configuration path.
            concurrency: number of worker threads/processes.
            timeout: process() timeout (ms, converted in init_from_dict).
            retry: max retry count on timeout; clamped to >= 1 when given.
            batch_size: auto-batching batch size.
            auto_batching_timeout: auto-batching timeout (ms).
            local_rpc_service_handler: pre-built local service handler.
        """
        # In __init__, all the parameters are just saved and Op is not initialized
        if name is None:
            name = _op_name_gen.next()
        self.name = name  # to identify the type of OP, it must be globally unique
        self.concurrency = concurrency  # amount of concurrency
        self.set_input_ops(input_ops)

        self._local_rpc_service_handler = local_rpc_service_handler
        self._server_endpoints = server_endpoints
        self._fetch_names = fetch_list
        self._client_config = client_config
        self._timeout = timeout
        # Keep None so init_from_dict can fill it from config later;
        # the old `max(1, retry)` raises TypeError on Python 3 when
        # retry is None and would defeat the `is None` check there.
        self._retry = retry if retry is None else max(1, retry)
        self._batch_size = batch_size
        self._auto_batching_timeout = auto_batching_timeout

        self._input = None
        self._outputs = []

        self._server_use_profile = False
        self._tracer = None

        # only for thread op
        self._for_init_op_lock = threading.Lock()
        self._for_close_op_lock = threading.Lock()
        self._succ_init_op = False
        self._succ_close_op = False

B
barriery 已提交
88 89
    def init_from_dict(self, conf):
        """Fill in any parameters not given to __init__ from the pipeline
        config dict `conf`, and decide the serving mode of this Op.

        Serving-mode resolution (sets self.with_serving):
        - explicit server_endpoints (ctor or conf)  -> remote service
        - else a local_service_conf with a model_config -> local rpc service
          (a LocalRpcServiceHandler is built/prepared and its ports become
          the endpoints)
        - else -> no serving; process() is skipped for this Op.

        Note: times in conf are in milliseconds and converted to seconds.
        """
        # init op
        if self.concurrency is None:
            self.concurrency = conf["concurrency"]
        if self._retry is None:
            self._retry = conf["retry"]
        if self._fetch_names is None:
            self._fetch_names = conf.get("fetch_list")
        if self._client_config is None:
            self._client_config = conf.get("client_config")

        if self._timeout is None:
            self._timeout = conf["timeout"]
        # ms -> s; non-positive disables the timeout (-1 sentinel)
        if self._timeout > 0:
            self._timeout = self._timeout / 1000.0
        else:
            self._timeout = -1

        if self._batch_size is None:
            self._batch_size = conf["batch_size"]
        if self._auto_batching_timeout is None:
            self._auto_batching_timeout = conf["auto_batching_timeout"]
        # batch_size == 1 makes auto-batching pointless, so the timeout is
        # dropped entirely (None means "block until a full batch").
        if self._auto_batching_timeout <= 0 or self._batch_size == 1:
            _LOGGER.warning(
                self._log(
                    "Because auto_batching_timeout <= 0 or batch_size == 1,"
                    " set auto_batching_timeout to None."))
            self._auto_batching_timeout = None
        else:
            self._auto_batching_timeout = self._auto_batching_timeout / 1000.0

        if self._server_endpoints is None:
            server_endpoints = conf.get("server_endpoints", [])
            if len(server_endpoints) != 0:
                # remote service
                self.with_serving = True
                self._server_endpoints = server_endpoints
            else:
                if self._local_rpc_service_handler is None:
                    local_service_conf = conf.get("local_service_conf")
                    _LOGGER.info("local_service_conf: {}".format(
                        local_service_conf))
                    model_config = local_service_conf.get("model_config")
                    _LOGGER.info("model_config: {}".format(model_config))
                    if model_config is None:
                        # no model configured: this Op runs without serving
                        self.with_serving = False
                    else:
                        # local rpc service
                        self.with_serving = True
                        service_handler = local_rpc_service_handler.LocalRpcServiceHandler(
                            model_config=model_config,
                            workdir=local_service_conf["workdir"],
                            thread_num=local_service_conf["thread_num"],
                            devices=local_service_conf["devices"],
                            mem_optim=local_service_conf["mem_optim"],
                            ir_optim=local_service_conf["ir_optim"])
                        service_handler.prepare_server()  # get fetch_list
                        serivce_ports = service_handler.get_port_list()
                        self._server_endpoints = [
                            "127.0.0.1:{}".format(p) for p in serivce_ports
                        ]
                        # client config / fetch list come from the handler
                        # unless the user supplied them explicitly
                        if self._client_config is None:
                            self._client_config = service_handler.get_client_config(
                            )
                        if self._fetch_names is None:
                            self._fetch_names = service_handler.get_fetch_list()
                        self._local_rpc_service_handler = service_handler
                else:
                    # user supplied a handler: prepare it and derive
                    # endpoints/config/fetch list from it
                    self.with_serving = True
                    self._local_rpc_service_handler.prepare_server(
                    )  # get fetch_list
                    serivce_ports = self._local_rpc_service_handler.get_port_list(
                    )
                    self._server_endpoints = [
                        "127.0.0.1:{}".format(p) for p in serivce_ports
                    ]
                    if self._client_config is None:
                        self._client_config = self._local_rpc_service_handler.get_client_config(
                        )
                    if self._fetch_names is None:
                        self._fetch_names = self._local_rpc_service_handler.get_fetch_list(
                        )
        else:
            self.with_serving = True

        # RequestOp/ResponseOp are pipeline endpoints; only log the full
        # configuration for ordinary Ops.
        if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp):
            _LOGGER.info(
                self._log("\n\tinput_ops: {},"
                          "\n\tserver_endpoints: {}"
                          "\n\tfetch_list: {}"
                          "\n\tclient_config: {}"
                          "\n\tconcurrency: {},"
                          "\n\ttimeout(s): {},"
                          "\n\tretry: {},"
                          "\n\tbatch_size: {},"
                          "\n\tauto_batching_timeout(s): {}".format(
                              ", ".join([op.name for op in self._input_ops
                                         ]), self._server_endpoints,
                              self._fetch_names, self._client_config,
                              self.concurrency, self._timeout, self._retry,
                              self._batch_size, self._auto_batching_timeout)))
B
barriery 已提交
189

190
    def launch_local_rpc_service(self):
        """Start the configured local rpc serving process, if any.

        A warning is logged (and nothing happens) when no local service
        handler was configured for this Op.
        """
        handler = self._local_rpc_service_handler
        if handler is None:
            _LOGGER.warning(
                self._log("Failed to launch local rpc"
                          " service: local_rpc_service_handler is None."))
            return
        port = handler.get_port_list()
        handler.start_server()
        _LOGGER.info("Op({}) use local rpc service at port: {}"
                     .format(self.name, port))

B
barriery 已提交
201
    def use_default_auto_batching_config(self):
B
bug fix  
barriery 已提交
202
        if self._batch_size != 1:
203 204
            _LOGGER.warning("Op({}) reset batch_size=1 (original: {})"
                            .format(self.name, self._batch_size))
B
bug fix  
barriery 已提交
205 206
            self._batch_size = 1
        if self._auto_batching_timeout != None:
207
            _LOGGER.warning(
B
barriery 已提交
208 209
                "Op({}) reset auto_batching_timeout=None (original: {})"
                .format(self.name, self._auto_batching_timeout))
B
bug fix  
barriery 已提交
210
            self._auto_batching_timeout = None
B
barriery 已提交
211

B
barrierye 已提交
212
    def use_profiler(self, use_profile):
B
barrierye 已提交
213
        self._server_use_profile = use_profile
214

B
barriery 已提交
215 216 217
    def set_tracer(self, tracer):
        self._tracer = tracer

B
barrierye 已提交
218 219
    def init_client(self, client_type, client_config, server_endpoints,
                    fetch_names):
        """Create and connect the serving client for this Op.

        Args:
            client_type: 'brpc' or 'grpc'.
            client_config: client configuration path (brpc only).
            server_endpoints: endpoints to connect to.
            fetch_names: fetch variable names, stored on the Op.

        Returns:
            A connected Client / MultiLangClient, or None when this Op has
            no serving backend (process() will not run).

        Raises:
            ValueError: for an unsupported client_type.
        """
        if not self.with_serving:
            _LOGGER.info("Op({}) has no client (and it also do not "
                         "run the process function)".format(self.name))
            return None
        if client_type == 'brpc':
            client = Client()
            client.load_client_config(client_config)
        elif client_type == 'grpc':
            client = MultiLangClient()
        else:
            # fixed typo in the error message ("unknow" -> "unknown")
            raise ValueError("Failed to init client: unknown client "
                             "type {}".format(client_type))
        client.connect(server_endpoints)
        self._fetch_names = fetch_names
        return client
235 236 237 238 239 240 241 242 243 244

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
245
                _LOGGER.critical(
B
barriery 已提交
246 247
                    self._log("Failed to set input_ops: input op "
                              "must be Op type, not {}".format(type(op))))
248
                os._exit(-1)
249
            self._input_ops.append(op)
D
dongdaxiang 已提交
250

251 252
    def add_input_channel(self, channel):
        """Attach the (single) input channel and register this Op as a
        consumer on it. Aborts the process on a non-Channel argument."""
        if isinstance(channel, (ThreadChannel, ProcessChannel)):
            channel.add_consumer(self.name)
            self._input = channel
            return
        _LOGGER.critical(
            self._log("Failed to set input_channel: input "
                      "channel must be Channel type, not {}".format(
                          type(channel))))
        os._exit(-1)
D
dongdaxiang 已提交
260

261
    def clean_input_channel(self):
B
barrierye 已提交
262 263 264 265
        self._input = None

    def _get_input_channel(self):
        return self._input
D
dongdaxiang 已提交
266

267 268
    def add_output_channel(self, channel):
        """Append an output channel and register this Op as a producer on
        it. Aborts the process on a non-Channel argument."""
        if isinstance(channel, (ThreadChannel, ProcessChannel)):
            channel.add_producer(self.name)
            self._outputs.append(channel)
            return
        _LOGGER.critical(
            self._log("Failed to add output_channel: output channel "
                      "must be Channel type, not {}".format(type(channel))))
        os._exit(-1)
D
dongdaxiang 已提交
275

276
    def clean_output_channels(self):
B
barrierye 已提交
277 278 279 280 281
        self._outputs = []

    def _get_output_channels(self):
        return self._outputs

W
wangjiawei04 已提交
282
    def preprocess(self, input_dicts):
B
barrierye 已提交
283
        # multiple previous Op
B
barrierye 已提交
284
        if len(input_dicts) != 1:
285 286
            _LOGGER.critical(
                self._log(
B
barriery 已提交
287 288
                    "Failed to run preprocess: this Op has multiple previous "
                    "inputs. Please override this func."))
289
            os._exit(-1)
D
dongdaxiang 已提交
290

B
barrierye 已提交
291 292
        (_, input_dict), = input_dicts.items()
        return input_dict
B
barrierye 已提交
293

B
barriery 已提交
294
    def process(self, feed_batch, typical_logid):
        """Send the combined feed batch to the serving backend.

        Args:
            feed_batch: batched numpy feed data; must pass
                ChannelData.check_batch_npdata (otherwise the process is
                aborted — override preprocess to produce np data).
            typical_logid: representative logid used for this batched call.

        Returns:
            The predict result dict, or None when a grpc (MultiLangClient)
            call failed (missing result or non-zero serving_status_code).
        """
        err, err_info = ChannelData.check_batch_npdata(feed_batch)
        if err != 0:
            _LOGGER.critical(
                self._log("Failed to run process: {}. Please override "
                          "preprocess func.".format(err_info)))
            os._exit(-1)
        call_result = self.client.predict(
            feed=feed_batch, fetch=self._fetch_names, log_id=typical_logid)
        if isinstance(self.client, MultiLangClient):
            # grpc client reports status inline; strip it so callers only
            # see fetch variables
            if call_result is None or call_result["serving_status_code"] != 0:
                return None
            call_result.pop("serving_status_code")
        return call_result

W
wangjiawei04 已提交
309
    def postprocess(self, input_dict, fetch_dict):
B
barrierye 已提交
310
        return fetch_dict
D
dongdaxiang 已提交
311

B
barrierye 已提交
312
    def _parse_channeldata(self, channeldata_dict):
        """Parse one {op_name: ChannelData} entry pulled from the channel.

        Returns a tuple (data_id, error_channeldata, parsed_data,
        client_need_profile, profile_set). error_channeldata is the first
        ChannelData whose ecode is not OK — parsing stops at that point.
        """
        error_channeldata = None
        parsed_data = {}
        profile_set = set()

        # all entries of the dict share the same request id / profile flag
        first = next(iter(channeldata_dict.values()))
        data_id = first.id
        client_need_profile = first.client_need_profile

        for op_name, data in channeldata_dict.items():
            if data.ecode != ChannelDataEcode.OK.value:
                error_channeldata = data
                break
            parsed_data[op_name] = data.parse()
            if client_need_profile:
                profile_set |= data.profile_data_set
        return (data_id, error_channeldata, parsed_data, client_need_profile,
                profile_set)
B
barrierye 已提交
330 331 332 333 334

    def _push_to_output_channels(self,
                                 data,
                                 channels,
                                 name=None,
B
barriery 已提交
335
                                 profile_str=None,
B
barrierye 已提交
336
                                 client_need_profile=False,
B
barrierye 已提交
337
                                 profile_set=None):
338 339
        if name is None:
            name = self.name
B
barrierye 已提交
340

B
barriery 已提交
341
        # add profile into channeldata
B
barrierye 已提交
342
        if client_need_profile and profile_set is not None:
B
barriery 已提交
343 344
            if profile_str is not None:
                profile_set.add(profile_str)
B
barrierye 已提交
345
            data.add_profile(profile_set)
B
barrierye 已提交
346

B
barriery 已提交
347 348 349
        for channel in channels:
            channel.push(data, name)

B
barrierye 已提交
350
    def start_with_process(self, client_type):
        """Spawn `self.concurrency` daemonic worker processes running
        self._run and return them."""
        trace_buffer = None
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()
        input_channel = self._get_input_channel()
        output_channels = self._get_output_channels()
        workers = []
        for idx in range(self.concurrency):
            worker = multiprocessing.Process(
                target=self._run,
                args=(idx, input_channel, output_channels, client_type,
                      False, trace_buffer))
            worker.daemon = True
            worker.start()
            workers.append(worker)
        return workers

B
barrierye 已提交
366
    def start_with_thread(self, client_type):
        """Spawn `self.concurrency` daemonic worker threads running
        self._run and return them."""
        trace_buffer = None
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()
        input_channel = self._get_input_channel()
        output_channels = self._get_output_channels()
        workers = []
        for idx in range(self.concurrency):
            worker = threading.Thread(
                target=self._run,
                args=(idx, input_channel, output_channels, client_type,
                      True, trace_buffer))
            # When a process exits, it attempts to terminate
            # all of its daemonic child processes.
            worker.daemon = True
            worker.start()
            workers.append(worker)
        return workers

B
barrierye 已提交
384
    def init_op(self):
B
barrierye 已提交
385 386
        pass

B
barriery 已提交
387 388
    def _run_preprocess(self, parsed_data_dict, op_info_prefix):
        """Run self.preprocess on each sample, separating successes from
        failures.

        Returns (preped_data_dict, err_channeldata_dict), both keyed by
        data_id. A TypeError maps to TYPE_ERROR (bad channeldata datatype),
        any other exception to UNKNOW.
        """
        _LOGGER.debug("{} Running preprocess".format(op_info_prefix))
        preped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        for data_id, parsed_data in parsed_data_dict.items():
            try:
                preped_data_dict[data_id] = self.preprocess(parsed_data)
            except TypeError as e:
                # Error type in channeldata.datatype
                error_info = "(logid={}) {} Failed to preprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_channeldata_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.TYPE_ERROR.value,
                    error_info=error_info,
                    data_id=data_id)
            except Exception as e:
                error_info = "(logid={}) {} Failed to preprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_channeldata_dict[data_id] = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
        _LOGGER.debug("{} Succ preprocess".format(op_info_prefix))
        return preped_data_dict, err_channeldata_dict

B
barriery 已提交
419 420
    def _run_process(self, preped_data_dict, op_info_prefix):
        """Run self.process (the serving call) over a batch of samples.

        Combines preprocessed samples into a single feed batch, calls
        process() with timeout/retry handling, then scatters the batched
        results (including LoD tensors) back into per-sample dicts.

        Returns (midped_data_dict, err_channeldata_dict), keyed by data_id.
        When the Op has no serving backend, preprocessed data is passed
        through unchanged.
        """
        _LOGGER.debug("{} Running process".format(op_info_prefix))
        midped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        if self.with_serving:
            # BUGFIX: dict.keys() is a non-indexable view on Python 3;
            # materialize it so data_ids[0] below works on both versions.
            data_ids = list(preped_data_dict.keys())
            typical_logid = data_ids[0]
            if len(data_ids) != 1:
                for data_id in data_ids:
                    _LOGGER.info(
                        "(logid={}) {} During access to PaddleServingService,"
                        " we selected logid={} (from batch: {}) as a "
                        "representative for logging.".format(
                            data_id, op_info_prefix, typical_logid, data_ids))

            # combine samples to batch; input_offset[i]:input_offset[i+1]
            # marks sample i's rows in the combined batch
            one_input = preped_data_dict[data_ids[0]]
            feed_batch = []
            input_offset = None
            if isinstance(one_input, dict):
                # sample input
                feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
                input_offset = list(range(len(data_ids) + 1))
            elif isinstance(one_input, list):
                # batch input
                input_offset = [0]
                for data_id in data_ids:
                    batch_input = preped_data_dict[data_id]
                    offset = input_offset[-1] + len(batch_input)
                    feed_batch += batch_input
                    input_offset.append(offset)
            else:
                _LOGGER.critical(
                    "{} Failed to process: expect input type is dict(sample"
                    " input) or list(batch input), but get {}".format(
                        op_info_prefix, type(one_input)))
                os._exit(-1)

            midped_batch = None
            ecode = ChannelDataEcode.OK.value
            if self._timeout <= 0:
                # no timeout configured: single attempt
                try:
                    midped_batch = self.process(feed_batch, typical_logid)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
                        typical_logid, op_info_prefix, data_ids, e)
                    _LOGGER.error(error_info, exc_info=True)
            else:
                # bounded attempts: retry only on timeout, abort on error
                for i in range(self._retry):
                    try:
                        midped_batch = func_timeout.func_timeout(
                            self._timeout,
                            self.process,
                            args=(feed_batch, typical_logid))
                    except func_timeout.FunctionTimedOut as e:
                        if i + 1 >= self._retry:
                            ecode = ChannelDataEcode.TIMEOUT.value
                            error_info = "(logid={}) {} Failed to process(batch: {}): " \
                                    "exceeded retry count.".format(
                                            typical_logid, op_info_prefix, data_ids)
                            _LOGGER.error(error_info)
                        else:
                            _LOGGER.warning(
                                "(logid={}) {} Failed to process(batch: {}): timeout,"
                                " and retrying({}/{})...".format(
                                    typical_logid, op_info_prefix, data_ids, i +
                                    1, self._retry))
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
                            typical_logid, op_info_prefix, data_ids, e)
                        _LOGGER.error(error_info, exc_info=True)
                        break
                    else:
                        break
            if ecode != ChannelDataEcode.OK.value:
                # serving failed: every sample in the batch gets the error
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ecode, error_info=error_info, data_id=data_id)
            elif midped_batch is None:
                # op client return None
                error_info = "(logid={}) {} Failed to predict, please check if " \
                        "PaddleServingService is working properly.".format(
                                typical_logid, op_info_prefix)
                _LOGGER.error(error_info)
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ChannelDataEcode.CLIENT_ERROR.value,
                        error_info=error_info,
                        data_id=data_id)
            else:
                # transform np format to dict format
                # a variable with a matching "<name>.lod" entry is a
                # LodTensor and must be sliced by its lod offsets
                var_names = midped_batch.keys()
                lod_var_names = set()
                lod_offset_names = set()
                for name in var_names:
                    lod_offset_name = "{}.lod".format(name)
                    if lod_offset_name in var_names:
                        _LOGGER.debug("(logid={}) {} {} is LodTensor".format(
                            typical_logid, op_info_prefix, name))
                        lod_var_names.add(name)
                        lod_offset_names.add(lod_offset_name)

                for idx, data_id in enumerate(data_ids):
                    midped_data_dict[data_id] = {}

                for name, value in midped_batch.items():
                    if name in lod_offset_names:
                        continue
                    if name in lod_var_names:
                        # lodtensor: slice rows by lod offsets per sample
                        lod_offset_name = "{}.lod".format(name)
                        lod_offset = midped_batch[lod_offset_name]
                        for idx, data_id in enumerate(data_ids):
                            data_offset_left = input_offset[idx]
                            data_offset_right = input_offset[idx + 1]
                            lod_offset_left = lod_offset[data_offset_left]
                            lod_offset_right = lod_offset[data_offset_right]
                            midped_data_dict[data_id][name] = value[
                                lod_offset_left:lod_offset_right]
                            # rebase the per-sample lod so it starts at 0
                            midped_data_dict[data_id][lod_offset_name] = \
                                    lod_offset[data_offset_left:data_offset_right + 1] - lod_offset[data_offset_left]
                    else:
                        # normal tensor: slice by sample offsets
                        for idx, data_id in enumerate(data_ids):
                            left = input_offset[idx]
                            right = input_offset[idx + 1]
                            midped_data_dict[data_id][name] = value[left:right]
        else:
            midped_data_dict = preped_data_dict
        _LOGGER.debug("{} Succ process".format(op_info_prefix))
        return midped_data_dict, err_channeldata_dict

B
barriery 已提交
553 554 555
    def _run_postprocess(self, parsed_data_dict, midped_data_dict,
                         op_info_prefix):
        """Run self.postprocess per sample and wrap results as ChannelData.

        postprocess must return a dict; dicts of pure numpy data are
        wrapped as CHANNEL_NPDATA, everything else as DICT channeldata.

        Returns (postped_data_dict, err_channeldata_dict), keyed by data_id.
        """
        _LOGGER.debug("{} Running postprocess".format(op_info_prefix))
        postped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        for data_id, midped_data in midped_data_dict.items():
            postped_data, err_channeldata = None, None
            try:
                postped_data = self.postprocess(parsed_data_dict[data_id],
                                                midped_data)
            except Exception as e:
                error_info = "(logid={}) {} Failed to postprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_channeldata = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
            if err_channeldata is not None:
                err_channeldata_dict[data_id] = err_channeldata
                continue
            else:
                if not isinstance(postped_data, dict):
                    # fixed typo in the error message ("funticon" -> "function")
                    error_info = "(logid={}) {} Failed to postprocess: " \
                            "output of postprocess function must be " \
                            "dict type, but get {}".format(
                                data_id, op_info_prefix,
                                type(postped_data))
                    _LOGGER.error(error_info)
                    err_channeldata = ChannelData(
                        ecode=ChannelDataEcode.UNKNOW.value,
                        error_info=error_info,
                        data_id=data_id)
                    err_channeldata_dict[data_id] = err_channeldata
                    continue

                output_data = None
                err, _ = ChannelData.check_npdata(postped_data)
                if err == 0:
                    # numpy-only result: store in the compact np format
                    output_data = ChannelData(
                        ChannelDataType.CHANNEL_NPDATA.value,
                        npdata=postped_data,
                        data_id=data_id)
                else:
                    output_data = ChannelData(
                        ChannelDataType.DICT.value,
                        dictdata=postped_data,
                        data_id=data_id)
                postped_data_dict[data_id] = output_data
        _LOGGER.debug("{} Succ postprocess".format(op_info_prefix))
        return postped_data_dict, err_channeldata_dict
B
barriery 已提交
604 605

    def _auto_batching_generator(self, input_channel, op_name, batch_size,
                                 timeout, op_info_prefix):
        """Yield non-empty batches of up to batch_size channeldata dicts.

        With `timeout` (seconds) set, a possibly-smaller batch is yielded
        once the deadline passes; with timeout None, blocks until a full
        batch of batch_size items has been collected.
        """
        while True:
            batch = []
            while len(batch) == 0:
                endtime = None
                if timeout is not None:
                    endtime = _time() + timeout
                for idx in range(batch_size):
                    try:
                        channeldata_dict = None
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug("{} Failed to generate batch: "
                                              "timeout".format(op_info_prefix))
                                break
                            # BUGFIX: wait only for the remaining time
                            # (not the full `timeout` again), so the whole
                            # batch respects the endtime deadline.
                            channeldata_dict = input_channel.front(op_name,
                                                                   remaining)
                        else:
                            channeldata_dict = input_channel.front(op_name)
                        batch.append(channeldata_dict)
                    except ChannelTimeoutError:
                        _LOGGER.debug("{} Failed to generate batch: "
                                      "timeout".format(op_info_prefix))
                        break
            _LOGGER.debug("{} Got actual batch_size: {}".format(op_info_prefix,
                                                                len(batch)))
            yield batch
634

635
    def _parse_channeldata_batch(self, batch, output_channels):
        """Split a batch of channeldata dicts into per-request parsed data.

        Requests that already carry an error from a predecessor Op are
        pushed straight to output_channels (their profile info travels
        with the error ChannelData) and excluded from the returned dicts.

        Args:
            batch: list of channeldata dicts from the batching generator.
            output_channels: channels receiving pass-through error data.

        Returns:
            Tuple (parsed_data_dict, need_profile_dict, profile_dict),
            each keyed by data_id; parsed_data_dict preserves batch order.
        """
        parsed_data_dict = collections.OrderedDict()
        need_profile_dict = {}
        profile_dict = {}
        for channeldata_dict in batch:
            (data_id, error_channeldata, parsed_data,
                    client_need_profile, profile_set) = \
                            self._parse_channeldata(channeldata_dict)
            if error_channeldata is None:
                parsed_data_dict[data_id] = parsed_data
                need_profile_dict[data_id] = client_need_profile
                profile_dict[data_id] = profile_set
            else:
                # error data in predecessor Op
                # (error_channeldata with profile info)
                self._push_to_output_channels(error_channeldata,
                                              output_channels)
        return parsed_data_dict, need_profile_dict, profile_dict
B
barriery 已提交
654 655

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op, trace_buffer):
        """Main worker loop of the Op.

        Pulls auto-generated batches from input_channel, runs the
        preprocess -> process -> postprocess pipeline, and pushes results
        (or per-request error ChannelData) to output_channels. The loop
        exits when any channel operation raises ChannelStopError.

        Args:
            concurrency_idx: index of this worker (ignored in thread mode).
            input_channel: channel to read channeldata batches from.
            output_channels: downstream channels to push results to.
            client_type: forwarded to _initialize for client construction.
            is_thread_op: True when workers are threads sharing one Op.
            trace_buffer: optional queue collecting per-stage latency
                records; None disables tracing.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        tid = threading.current_thread().ident

        # init op: failure to initialize is fatal for the whole worker
        profiler = None
        try:
            profiler = self._initialize(is_thread_op, client_type,
                                        concurrency_idx)
        except Exception as e:
            _LOGGER.critical(
                "{} Failed to init op: {}".format(op_info_prefix, e),
                exc_info=True)
            os._exit(-1)
        _LOGGER.info("{} Succ init".format(op_info_prefix))

        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=self._batch_size,
            timeout=self._auto_batching_timeout,
            op_info_prefix=op_info_prefix)

        start, end = None, None
        trace_que = collections.deque()
        while True:
            start = int(round(_time() * 1000000))
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            end = int(round(_time() * 1000000))
            in_time = end - start

            # parse channeldata batch
            try:
                parsed_data_dict, need_profile_dict, profile_dict \
                        = self._parse_channeldata_batch(
                                channeldata_dict_batch, output_channels)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(parsed_data_dict) == 0:
                # data in the whole batch is all error data
                continue

            # preprocess
            start = profiler.record("prep#{}_0".format(op_info_prefix))
            preped_data_dict, err_channeldata_dict \
                    = self._run_preprocess(parsed_data_dict, op_info_prefix)
            end = profiler.record("prep#{}_1".format(op_info_prefix))
            prep_time = end - start
            try:
                # forward per-request preprocess errors downstream
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(preped_data_dict) == 0:
                continue

            # process
            start = profiler.record("midp#{}_0".format(op_info_prefix))
            midped_data_dict, err_channeldata_dict \
                    = self._run_process(preped_data_dict, op_info_prefix)
            end = profiler.record("midp#{}_1".format(op_info_prefix))
            midp_time = end - start
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(midped_data_dict) == 0:
                continue

            # postprocess
            start = profiler.record("postp#{}_0".format(op_info_prefix))
            postped_data_dict, err_channeldata_dict \
                    = self._run_postprocess(
                            parsed_data_dict, midped_data_dict, op_info_prefix)
            end = profiler.record("postp#{}_1".format(op_info_prefix))
            postp_time = end - start
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(postped_data_dict) == 0:
                continue

            # push data to channel (if run succ)
            start = int(round(_time() * 1000000))
            try:
                profile_str = profiler.gen_profile_str()
                for data_id, postped_data in postped_data_dict.items():
                    if self._server_use_profile:
                        sys.stderr.write(profile_str)
                    self._push_to_output_channels(
                        data=postped_data,
                        channels=output_channels,
                        profile_str=profile_str,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            end = int(round(_time() * 1000000))
            out_time = end - start

            if trace_buffer is not None:
                # buffer trace records locally; drain into the shared queue
                # without blocking, keeping undelivered records for later
                trace_que.append({
                    "name": self.name,
                    "actions": {
                        "in": in_time,
                        "prep": prep_time,
                        "midp": midp_time,
                        "postp": postp_time,
                        "out": out_time,
                    }
                })
                while trace_que:
                    info = trace_que[0]
                    try:
                        trace_buffer.put_nowait(info)
                        trace_que.popleft()
                    except Queue.Full:
                        break
B
barriery 已提交
803

B
bug fix  
barriery 已提交
804
    def _initialize(self, is_thread_op, client_type, concurrency_idx):
        """Initialize client and user state for one worker; return a profiler.

        In thread mode all workers share this Op instance, so the client
        and init_op() are set up exactly once under _for_init_op_lock; in
        process mode each process initializes its own copy.

        Args:
            is_thread_op: True when workers are threads sharing one Op.
            client_type: forwarded to init_client.
            concurrency_idx: worker index; stored on self in process mode
                only (threads cannot be told apart, so None is stored).

        Returns:
            A freshly enabled TimeProfiler for this worker.
        """
        if is_thread_op:
            with self._for_init_op_lock:
                if not self._succ_init_op:
                    # for the threaded version of Op, each thread cannot get its concurrency_idx
                    self.concurrency_idx = None
                    # init client
                    self.client = self.init_client(
                        client_type, self._client_config,
                        self._server_endpoints, self._fetch_names)
                    # user defined
                    self.init_op()
                    self._succ_init_op = True
                    self._succ_close_op = False
        else:
            self.concurrency_idx = concurrency_idx
            # init client
            self.client = self.init_client(client_type, self._client_config,
                                           self._server_endpoints,
                                           self._fetch_names)
            # user defined
            self.init_op()

        # use a separate TimeProfiler per thread or process
        profiler = TimeProfiler()
        profiler.enable(True)
        return profiler

B
barriery 已提交
832 833 834 835 836 837 838 839
    def _finalize(self, is_thread_op):
        if is_thread_op:
            with self._for_close_op_lock:
                if not self._succ_close_op:
                    self._profiler = None
                    self.client = None
                    self._succ_init_op = False
                    self._succ_close_op = True
840 841 842 843 844

    def _log(self, info):
        return "{} {}".format(self.name, info)


B
barrierye 已提交
845 846 847
class RequestOp(Op):
    """ RequestOp do not run preprocess, process, postprocess. """

    def __init__(self):
        # PipelineService.name = "@DAGExecutor"
        super(RequestOp, self).__init__(name="@DAGExecutor", input_ops=[])
        # init op: user-defined init failure is fatal for the service
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.critical("Op(Request) Failed to init: {}".format(e))
            os._exit(-1)

    def unpack_request_package(self, request):
        """Convert an RPC request into a feed dict keyed by request.key.

        Each value is eval()-ed so that repr()-serialized numpy arrays
        (see ResponseOp.pack_response_package) round-trip back to
        ndarray; values that fail to eval keep their raw string form.
        """
        dictdata = {}
        for idx, key in enumerate(request.key):
            data = request.value[idx]
            try:
                # SECURITY NOTE: eval() executes arbitrary expressions from
                # the client payload; safe only for trusted callers.
                # Consider ast.literal_eval or an explicit ndarray
                # deserializer for untrusted input.
                evaled_data = eval(data)
                if isinstance(evaled_data, np.ndarray):
                    data = evaled_data
            except Exception as e:
                # non-evaluable values are passed through unchanged
                pass
            dictdata[key] = data
        return dictdata


class ResponseOp(Op):
    """ ResponseOp do not run preprocess, process, postprocess. """

    def __init__(self, input_ops):
        super(ResponseOp, self).__init__(
            name="@DAGExecutor", input_ops=input_ops)
        # init op: user-defined init failure is fatal for the service
        try:
            self.init_op()
        except Exception as e:
            # BUGFIX: exc_info=True was previously passed to str.format()
            # (a silent no-op) instead of to the logging call.
            _LOGGER.critical(
                "Op(ResponseOp) Failed to init: {}".format(e), exc_info=True)
            os._exit(-1)

    def pack_response_package(self, channeldata):
        """Build the RPC Response proto from a finished ChannelData.

        NPDATA values are serialized via repr() (round-tripped by
        RequestOp.unpack_request_package); DICT values must already be
        strings. Any other datatype, or a non-str dict value, turns the
        response into a TYPE_ERROR.
        """
        resp = pipeline_service_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
                feed = channeldata.parse()
                # ndarray to string:
                # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
                np.set_printoptions(threshold=sys.maxsize)
                for name, var in feed.items():
                    resp.value.append(var.__repr__())
                    resp.key.append(name)
            elif channeldata.datatype == ChannelDataType.DICT.value:
                feed = channeldata.parse()
                for name, var in feed.items():
                    if not isinstance(var, str):
                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                        resp.error_info = self._log(
                            "fetch var type must be str({}).".format(
                                type(var)))
                        _LOGGER.error("(logid={}) Failed to pack RPC "
                                      "response package: {}".format(
                                          channeldata.id, resp.error_info))
                        break
                    resp.value.append(var)
                    resp.key.append(name)
            else:
                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                resp.error_info = self._log(
                    "error type({}) in datatype.".format(channeldata.datatype))
                _LOGGER.error("(logid={}) Failed to pack RPC response"
                              " package: {}".format(channeldata.id,
                                                    resp.error_info))
        else:
            # propagate the upstream error message unchanged
            resp.error_info = channeldata.error_info
        return resp
922 923 924 925 926 927 928


class VirtualOp(Op):
    ''' For connecting two channels. '''

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, input_ops=None, concurrency=concurrency)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Register a predecessor Op (real or virtual) of this VirtualOp."""
        self._virtual_pred_ops.append(op)

    def _actual_pred_op_names(self, op):
        """Recursively collect the names of the real (non-virtual) Ops
        behind *op*, flattening chains of VirtualOps."""
        # can use disjoint-set, but it's not necessary
        if not isinstance(op, VirtualOp):
            return [op.name]
        names = []
        for x in op._virtual_pred_ops:
            names.extend(self._actual_pred_op_names(x))
        return names

    def add_output_channel(self, channel):
        """Attach an output channel, registering every actual predecessor
        Op name as a producer on it. Exits the process on type error."""
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical(
                self._log("Failed to add output_channel: output_channel"
                          " must be Channel type, not {}".format(
                              type(channel))))
            os._exit(-1)
        for op in self._virtual_pred_ops:
            for op_name in self._actual_pred_op_names(op):
                channel.add_producer(op_name)
        self._outputs.append(channel)

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op):
        """Forward channeldata from input_channel to output_channels
        unchanged, one item at a time, until the channel stops."""
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        tid = threading.current_thread().ident

        # BUGFIX: this previously called the removed get_log_func() helper
        # (NameError) and passed log_func= to _auto_batching_generator,
        # whose signature takes op_info_prefix (TypeError).
        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=1,
            timeout=None,
            op_info_prefix=op_info_prefix)

        while True:
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break

            try:
                for channeldata_dict in channeldata_dict_batch:
                    for name, data in channeldata_dict.items():
                        self._push_to_output_channels(
                            data, channels=output_channels, name=name)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break