operator.py 38.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
B
barriery 已提交
15
from time import time as _time
B
barriery 已提交
16
import time
17 18 19 20 21 22
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
23
import os
B
barrierye 已提交
24
import sys
25
import collections
B
barrierye 已提交
26
import numpy as np
B
barrierye 已提交
27
from numpy import *
B
barrierye 已提交
28 29 30 31 32 33
if sys.version_info.major == 2:
    import Queue
elif sys.version_info.major == 3:
    import queue as Queue
else:
    raise Exception("Error Python version")
34

B
barrierye 已提交
35
from .proto import pipeline_service_pb2
B
barrierye 已提交
36
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
B
bug fix  
barriery 已提交
37 38
                      ChannelData, ChannelDataType, ChannelStopError,
                      ChannelTimeoutError)
B
barrierye 已提交
39
from .util import NameGenerator
B
barriery 已提交
40
from .profiler import UnsafeTimeProfiler as TimeProfiler
41

42
_LOGGER = logging.getLogger(__name__)
B
barrierye 已提交
43 44
_op_name_gen = NameGenerator("Op")

D
dongdaxiang 已提交
45 46 47

class Op(object):
    def __init__(self,
                 name=None,
                 input_ops=None,
                 server_endpoints=None,
                 fetch_list=None,
                 client_config=None,
                 concurrency=1,
                 timeout=-1,
                 retry=1,
                 batch_size=1,
                 auto_batching_timeout=None,
                 local_rpc_service_handler=None):
        """Build one Op node of the pipeline DAG.

        Args:
            name: globally-unique Op name; auto-generated when None.
            input_ops: upstream Op(s) feeding this Op (list, single Op or None).
            server_endpoints: remote serving "ip:port" endpoints; when empty,
                local_rpc_service_handler (if given) provides a local service.
            fetch_list: fetch variable names requested from the service.
            client_config: serving client configuration path.
            concurrency: number of worker threads/processes for this Op.
            timeout: process() timeout in milliseconds (<=0 disables).
            retry: retry count on process() timeout (minimum 1).
            batch_size: auto-batching batch size.
            auto_batching_timeout: auto-batching window in milliseconds.
            local_rpc_service_handler: helper that launches a local service.
        """
        # BUGFIX: list parameters defaulted to a shared mutable [] object;
        # use None sentinels so each call gets a fresh list. Passing []
        # explicitly behaves exactly as before.
        input_ops = [] if input_ops is None else input_ops
        server_endpoints = [] if server_endpoints is None else server_endpoints
        fetch_list = [] if fetch_list is None else fetch_list
        if name is None:
            name = _op_name_gen.next()
        self.name = name  # to identify the type of OP, it must be globally unique
        self.concurrency = concurrency  # amount of concurrency
        self.set_input_ops(input_ops)

        if len(server_endpoints) != 0:
            # remote service
            self.with_serving = True
        else:
            if local_rpc_service_handler is not None:
                # local rpc service
                self.with_serving = True
                local_rpc_service_handler.prepare_server()  # get fetch_list
                service_ports = local_rpc_service_handler.get_port_list()
                server_endpoints = [
                    "127.0.0.1:{}".format(p) for p in service_ports
                ]
                if client_config is None:
                    client_config = local_rpc_service_handler.get_client_config(
                    )
                if len(fetch_list) == 0:
                    fetch_list = local_rpc_service_handler.get_fetch_list()
            else:
                self.with_serving = False
        self._local_rpc_service_handler = local_rpc_service_handler
        self._server_endpoints = server_endpoints
        self._fetch_names = fetch_list
        self._client_config = client_config

        # timeout is given in ms; stored in seconds (-1 means disabled)
        if timeout > 0:
            self._timeout = timeout / 1000.0
        else:
            self._timeout = -1
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []

        self._batch_size = batch_size
        self._auto_batching_timeout = auto_batching_timeout
        if self._auto_batching_timeout is not None:
            if self._auto_batching_timeout <= 0 or self._batch_size == 1:
                _LOGGER.warning(
                    self._log(
                        "Because auto_batching_timeout <= 0 or batch_size == 1,"
                        " set auto_batching_timeout to None."))
                self._auto_batching_timeout = None
            else:
                self._auto_batching_timeout = self._auto_batching_timeout / 1000.0
        if not isinstance(self, RequestOp) and not isinstance(self, ResponseOp):
            _LOGGER.info(
                self._log("\n\tinput_ops: {},"
                          "\n\tserver_endpoints: {}"
                          "\n\tfetch_list: {}"
                          "\n\tclient_config: {}"
                          "\n\tconcurrency: {},"
                          "\n\ttimeout(s): {},"
                          "\n\tretry: {},"
                          "\n\tbatch_size: {},"
                          "\n\tauto_batching_timeout(s): {}".format(
                              ", ".join([op.name for op in input_ops
                                         ]), self._server_endpoints,
                              self._fetch_names, self._client_config,
                              self.concurrency, self._timeout, self._retry,
                              self._batch_size, self._auto_batching_timeout)))

        self._server_use_profile = False
        self._tracer = None

        # only for thread op
        self._for_init_op_lock = threading.Lock()
        self._for_close_op_lock = threading.Lock()
        self._succ_init_op = False
        self._succ_close_op = False
B
barrierye 已提交
133

134
    def launch_local_rpc_service(self):
        """Start the local rpc service attached to this Op, if any."""
        handler = self._local_rpc_service_handler
        if handler is None:
            _LOGGER.warning(
                self._log("Failed to launch local rpc"
                          " service: local_rpc_service_handler is None."))
            return
        port = handler.get_port_list()
        handler.start_server()
        _LOGGER.info("Op({}) use local rpc service at port: {}"
                     .format(self.name, port))

B
barriery 已提交
145
    def use_default_auto_batching_config(self):
B
bug fix  
barriery 已提交
146
        if self._batch_size != 1:
147 148
            _LOGGER.warning("Op({}) reset batch_size=1 (original: {})"
                            .format(self.name, self._batch_size))
B
bug fix  
barriery 已提交
149 150
            self._batch_size = 1
        if self._auto_batching_timeout != None:
151
            _LOGGER.warning(
B
barriery 已提交
152 153
                "Op({}) reset auto_batching_timeout=None (original: {})"
                .format(self.name, self._auto_batching_timeout))
B
bug fix  
barriery 已提交
154
            self._auto_batching_timeout = None
B
barriery 已提交
155

B
barrierye 已提交
156
    def use_profiler(self, use_profile):
B
barrierye 已提交
157
        self._server_use_profile = use_profile
158

B
barriery 已提交
159 160 161
    def set_tracer(self, tracer):
        self._tracer = tracer

B
barrierye 已提交
162 163
    def init_client(self, client_type, client_config, server_endpoints,
                    fetch_names):
164
        if self.with_serving == False:
B
barriery 已提交
165
            _LOGGER.info("Op({}) has no client (and it also do not "
166
                         "run the process function)".format(self.name))
B
barrierye 已提交
167
            return None
168
        if client_type == 'brpc':
B
barrierye 已提交
169 170
            client = Client()
            client.load_client_config(client_config)
171
        elif client_type == 'grpc':
B
barrierye 已提交
172
            client = MultiLangClient()
173
        else:
B
barriery 已提交
174 175
            raise ValueError("Failed to init client: unknow client "
                             "type {}".format(client_type))
B
barrierye 已提交
176
        client.connect(server_endpoints)
177
        self._fetch_names = fetch_names
B
barrierye 已提交
178
        return client
179 180 181 182 183 184 185 186 187 188

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
189
                _LOGGER.critical(
B
barriery 已提交
190 191
                    self._log("Failed to set input_ops: input op "
                              "must be Op type, not {}".format(type(op))))
192
                os._exit(-1)
193
            self._input_ops.append(op)
D
dongdaxiang 已提交
194

195 196
    def add_input_channel(self, channel):
        """Set the single input channel of this Op and register this Op as
        one of the channel's consumers.

        Aborts the whole process on a wrong channel type: the DAG cannot
        be built correctly in that case.
        """
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical(
                self._log("Failed to set input_channel: input "
                          "channel must be Channel type, not {}".format(
                              type(channel))))
            os._exit(-1)
        channel.add_consumer(self.name)
        self._input = channel
D
dongdaxiang 已提交
204

205
    def clean_input_channel(self):
B
barrierye 已提交
206 207 208 209
        self._input = None

    def _get_input_channel(self):
        return self._input
D
dongdaxiang 已提交
210

211 212
    def add_output_channel(self, channel):
        """Append one output channel and register this Op as a producer
        on it.

        Aborts the whole process on a wrong channel type: the DAG cannot
        be built correctly in that case.
        """
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical(
                self._log("Failed to add output_channel: output channel "
                          "must be Channel type, not {}".format(type(channel))))
            os._exit(-1)
        channel.add_producer(self.name)
        self._outputs.append(channel)
D
dongdaxiang 已提交
219

220
    def clean_output_channels(self):
B
barrierye 已提交
221 222 223 224 225
        self._outputs = []

    def _get_output_channels(self):
        return self._outputs

W
wangjiawei04 已提交
226
    def preprocess(self, input_dicts):
B
barrierye 已提交
227
        # multiple previous Op
B
barrierye 已提交
228
        if len(input_dicts) != 1:
229 230
            _LOGGER.critical(
                self._log(
B
barriery 已提交
231 232
                    "Failed to run preprocess: this Op has multiple previous "
                    "inputs. Please override this func."))
233
            os._exit(-1)
D
dongdaxiang 已提交
234

B
barrierye 已提交
235 236
        (_, input_dict), = input_dicts.items()
        return input_dict
B
barrierye 已提交
237

B
barriery 已提交
238
    def process(self, feed_batch, typical_logid):
        """Call the serving client to predict one (combined) batch.

        Args:
            feed_batch: list of feed dicts; must contain numpy data only,
                otherwise the process aborts.
            typical_logid: the logid chosen to represent this batch in logs.

        Returns:
            The client's fetch dict, or None when a MultiLangClient call
            failed (missing result or non-zero serving_status_code).
        """
        err, err_info = ChannelData.check_batch_npdata(feed_batch)
        if err != 0:
            _LOGGER.critical(
                self._log("Failed to run process: {}. Please override "
                          "preprocess func.".format(err_info)))
            os._exit(-1)
        call_result = self.client.predict(
            feed=feed_batch, fetch=self._fetch_names, log_id=typical_logid)
        if isinstance(self.client, MultiLangClient):
            # the grpc client reports its status inside the result dict
            if call_result is None or call_result["serving_status_code"] != 0:
                return None
            call_result.pop("serving_status_code")
        return call_result

W
wangjiawei04 已提交
253
    def postprocess(self, input_dict, fetch_dict):
B
barrierye 已提交
254
        return fetch_dict
D
dongdaxiang 已提交
255

B
barrierye 已提交
256
    def _parse_channeldata(self, channeldata_dict):
        """Parse one channeldata dict (upstream op name -> ChannelData).

        Returns:
            Tuple (data_id, error_channeldata, parsed_data,
            client_need_profile, profile_set): error_channeldata is the
            first non-OK entry (or None); parsed_data maps upstream op
            name -> parsed payload; profile_set accumulates upstream
            profile records when the client asked for them.
        """
        data_id, error_channeldata = None, None
        client_need_profile, profile_set = False, set()
        parsed_data = {}

        # every entry of the dict carries the same request id / profile flag
        key = list(channeldata_dict.keys())[0]
        data_id = channeldata_dict[key].id
        client_need_profile = channeldata_dict[key].client_need_profile

        for name, data in channeldata_dict.items():
            if data.ecode != ChannelDataEcode.OK.value:
                # stop at the first errored upstream output
                error_channeldata = data
                break
            parsed_data[name] = data.parse()
            if client_need_profile:
                profile_set |= data.profile_data_set
        return (data_id, error_channeldata, parsed_data, client_need_profile,
                profile_set)
B
barrierye 已提交
274 275 276 277 278

    def _push_to_output_channels(self,
                                 data,
                                 channels,
                                 name=None,
B
barriery 已提交
279
                                 profile_str=None,
B
barrierye 已提交
280
                                 client_need_profile=False,
B
barrierye 已提交
281
                                 profile_set=None):
282 283
        if name is None:
            name = self.name
B
barrierye 已提交
284

B
barriery 已提交
285
        # add profile into channeldata
B
barrierye 已提交
286
        if client_need_profile and profile_set is not None:
B
barriery 已提交
287 288
            if profile_str is not None:
                profile_set.add(profile_str)
B
barrierye 已提交
289
            data.add_profile(profile_set)
B
barrierye 已提交
290

B
barriery 已提交
291 292 293
        for channel in channels:
            channel.push(data, name)

B
barrierye 已提交
294
    def start_with_process(self, client_type):
B
barriery 已提交
295 296 297
        trace_buffer = None
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()
298
        proces = []
B
barrierye 已提交
299
        for concurrency_idx in range(self.concurrency):
300 301
            p = multiprocessing.Process(
                target=self._run,
B
barrierye 已提交
302
                args=(concurrency_idx, self._get_input_channel(),
B
barriery 已提交
303
                      self._get_output_channels(), client_type, False,
B
barriery 已提交
304
                      trace_buffer))
B
barriery 已提交
305
            p.daemon = True
306 307 308 309
            p.start()
            proces.append(p)
        return proces

B
barrierye 已提交
310
    def start_with_thread(self, client_type):
B
barriery 已提交
311 312 313
        trace_buffer = None
        if self._tracer is not None:
            trace_buffer = self._tracer.data_buffer()
314
        threads = []
B
barrierye 已提交
315
        for concurrency_idx in range(self.concurrency):
316 317
            t = threading.Thread(
                target=self._run,
B
barrierye 已提交
318
                args=(concurrency_idx, self._get_input_channel(),
B
barriery 已提交
319
                      self._get_output_channels(), client_type, True,
B
barriery 已提交
320
                      trace_buffer))
B
barriery 已提交
321 322 323
            # When a process exits, it attempts to terminate
            # all of its daemonic child processes.
            t.daemon = True
324 325 326 327
            t.start()
            threads.append(t)
        return threads

B
barrierye 已提交
328
    def init_op(self):
B
barrierye 已提交
329 330
        pass

B
barriery 已提交
331 332
    def _run_preprocess(self, parsed_data_dict, op_info_prefix):
        """Run self.preprocess for every sample of the current batch.

        Args:
            parsed_data_dict: data_id -> parsed input data.
            op_info_prefix: "[name|concurrency_idx]" prefix for logging.

        Returns:
            Tuple (preped_data_dict, err_channeldata_dict): successful
            results and per-sample error ChannelData, keyed by data_id.
        """
        _LOGGER.debug("{} Running preprocess".format(op_info_prefix))
        preped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        for data_id, parsed_data in parsed_data_dict.items():
            preped_data, error_channeldata = None, None
            try:
                preped_data = self.preprocess(parsed_data)
            except TypeError as e:
                # Error type in channeldata.datatype
                error_info = "(logid={}) {} Failed to preprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.TYPE_ERROR.value,
                    error_info=error_info,
                    data_id=data_id)
            except Exception as e:
                # any other failure is reported as UNKNOW
                error_info = "(logid={}) {} Failed to preprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
            if error_channeldata is not None:
                err_channeldata_dict[data_id] = error_channeldata
            else:
                preped_data_dict[data_id] = preped_data
        _LOGGER.debug("{} Succ preprocess".format(op_info_prefix))
        return preped_data_dict, err_channeldata_dict

B
barriery 已提交
363 364
    def _run_process(self, preped_data_dict, op_info_prefix):
        """Run the (remote) process stage for one auto-batched batch.

        Combines the per-request samples into a single feed batch, calls
        self.process (with optional timeout/retry), then scatters the
        batched result back into per-request dicts, slicing LoD tensors
        by their lod offsets.

        Args:
            preped_data_dict: data_id -> output of the preprocess stage.
            op_info_prefix: "[name|concurrency_idx]" prefix for logging.

        Returns:
            Tuple (midped_data_dict, err_channeldata_dict), keyed by
            data_id. Without a serving part the stage is the identity.
        """
        _LOGGER.debug("{} Running process".format(op_info_prefix))
        midped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        if self.with_serving:
            # BUGFIX: materialize the keys; python3 dict views do not
            # support indexing (data_ids[0] below raised TypeError).
            data_ids = list(preped_data_dict.keys())
            typical_logid = data_ids[0]
            if len(data_ids) != 1:
                for data_id in data_ids:
                    _LOGGER.info(
                        "(logid={}) {} During access to PaddleServingService,"
                        " we selected logid={} (from batch: {}) as a "
                        "representative for logging.".format(
                            data_id, op_info_prefix, typical_logid, data_ids))

            # combine samples to batch
            one_input = preped_data_dict[data_ids[0]]
            feed_batch = []
            input_offset = None
            if isinstance(one_input, dict):
                # sample input: one feed dict per request
                feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
                input_offset = list(range(len(data_ids) + 1))
            elif isinstance(one_input, list):
                # batch input: concatenate, remember per-request offsets
                input_offset = [0]
                for data_id in data_ids:
                    batch_input = preped_data_dict[data_id]
                    offset = input_offset[-1] + len(batch_input)
                    feed_batch += batch_input
                    input_offset.append(offset)
            else:
                _LOGGER.critical(
                    "{} Failed to process: expect input type is dict(sample"
                    " input) or list(batch input), but get {}".format(
                        op_info_prefix, type(one_input)))
                os._exit(-1)

            midped_batch = None
            ecode = ChannelDataEcode.OK.value
            if self._timeout <= 0:
                # no timeout configured: single attempt
                try:
                    midped_batch = self.process(feed_batch, typical_logid)
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
                        typical_logid, op_info_prefix, data_ids, e)
                    _LOGGER.error(error_info, exc_info=True)
            else:
                # timeout configured: retry up to self._retry times
                for i in range(self._retry):
                    try:
                        midped_batch = func_timeout.func_timeout(
                            self._timeout,
                            self.process,
                            args=(feed_batch, typical_logid))
                    except func_timeout.FunctionTimedOut as e:
                        if i + 1 >= self._retry:
                            ecode = ChannelDataEcode.TIMEOUT.value
                            error_info = "(logid={}) {} Failed to process(batch: {}): " \
                                    "exceeded retry count.".format(
                                            typical_logid, op_info_prefix, data_ids)
                            _LOGGER.error(error_info)
                        else:
                            _LOGGER.warning(
                                "(logid={}) {} Failed to process(batch: {}): timeout,"
                                " and retrying({}/{})...".format(
                                    typical_logid, op_info_prefix, data_ids, i +
                                    1, self._retry))
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = "(logid={}) {} Failed to process(batch: {}): {}".format(
                            typical_logid, op_info_prefix, data_ids, e)
                        _LOGGER.error(error_info, exc_info=True)
                        break
                    else:
                        break
            if ecode != ChannelDataEcode.OK.value:
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ecode, error_info=error_info, data_id=data_id)
            elif midped_batch is None:
                # op client return None
                error_info = "(logid={}) {} Failed to predict, please check if " \
                        "PaddleServingService is working properly.".format(
                                typical_logid, op_info_prefix)
                _LOGGER.error(error_info)
                for data_id in data_ids:
                    err_channeldata_dict[data_id] = ChannelData(
                        ecode=ChannelDataEcode.CLIENT_ERROR.value,
                        error_info=error_info,
                        data_id=data_id)
            else:
                # transform np format to dict format
                var_names = midped_batch.keys()
                lod_var_names = set()
                lod_offset_names = set()
                for name in var_names:
                    lod_offset_name = "{}.lod".format(name)
                    if lod_offset_name in var_names:
                        _LOGGER.debug("(logid={}) {} {} is LodTensor".format(
                            typical_logid, op_info_prefix, name))
                        lod_var_names.add(name)
                        lod_offset_names.add(lod_offset_name)

                # removed unused enumerate index from this loop
                for data_id in data_ids:
                    midped_data_dict[data_id] = {}

                for name, value in midped_batch.items():
                    if name in lod_offset_names:
                        continue
                    if name in lod_var_names:
                        # lodtensor: slice values by lod offsets per request
                        lod_offset_name = "{}.lod".format(name)
                        lod_offset = midped_batch[lod_offset_name]
                        for idx, data_id in enumerate(data_ids):
                            data_offset_left = input_offset[idx]
                            data_offset_right = input_offset[idx + 1]
                            lod_offset_left = lod_offset[data_offset_left]
                            lod_offset_right = lod_offset[data_offset_right]
                            midped_data_dict[data_id][name] = value[
                                lod_offset_left:lod_offset_right]
                            # re-base the lod offsets so each request's lod
                            # starts at 0
                            midped_data_dict[data_id][lod_offset_name] = \
                                    lod_offset[data_offset_left:data_offset_right + 1] - lod_offset[data_offset_left]
                    else:
                        # normal tensor: slice by sample offsets
                        for idx, data_id in enumerate(data_ids):
                            left = input_offset[idx]
                            right = input_offset[idx + 1]
                            midped_data_dict[data_id][name] = value[left:right]
        else:
            # no serving part: the process stage is the identity
            midped_data_dict = preped_data_dict
        _LOGGER.debug("{} Succ process".format(op_info_prefix))
        return midped_data_dict, err_channeldata_dict

B
barriery 已提交
497 498 499
    def _run_postprocess(self, parsed_data_dict, midped_data_dict,
                         op_info_prefix):
        """Run self.postprocess for each sample and wrap results as ChannelData.

        Args:
            parsed_data_dict: data_id -> parsed input data (fed back to
                postprocess alongside the process output).
            midped_data_dict: data_id -> output of the process stage.
            op_info_prefix: "[name|concurrency_idx]" prefix for logging.

        Returns:
            Tuple (postped_data_dict, err_channeldata_dict): successful
            results and per-sample error ChannelData, keyed by data_id.
        """
        _LOGGER.debug("{} Running postprocess".format(op_info_prefix))
        postped_data_dict = collections.OrderedDict()
        err_channeldata_dict = collections.OrderedDict()
        for data_id, midped_data in midped_data_dict.items():
            postped_data, err_channeldata = None, None
            try:
                postped_data = self.postprocess(parsed_data_dict[data_id],
                                                midped_data)
            except Exception as e:
                error_info = "(logid={}) {} Failed to postprocess: {}".format(
                    data_id, op_info_prefix, e)
                _LOGGER.error(error_info, exc_info=True)
                err_channeldata = ChannelData(
                    ecode=ChannelDataEcode.UNKNOW.value,
                    error_info=error_info,
                    data_id=data_id)
            if err_channeldata is not None:
                err_channeldata_dict[data_id] = err_channeldata
                continue
            else:
                if not isinstance(postped_data, dict):
                    # BUGFIX: fixed typo in the error message
                    # ("funticon" -> "function")
                    error_info = "(logid={}) {} Failed to postprocess: " \
                            "output of postprocess function must be " \
                            "dict type, but get {}".format(
                                data_id, op_info_prefix,
                                type(postped_data))
                    _LOGGER.error(error_info)
                    err_channeldata = ChannelData(
                        ecode=ChannelDataEcode.UNKNOW.value,
                        error_info=error_info,
                        data_id=data_id)
                    err_channeldata_dict[data_id] = err_channeldata
                    continue

                # pack as NPDATA when the dict holds numpy data only,
                # otherwise as a generic DICT payload
                output_data = None
                err, _ = ChannelData.check_npdata(postped_data)
                if err == 0:
                    output_data = ChannelData(
                        ChannelDataType.CHANNEL_NPDATA.value,
                        npdata=postped_data,
                        data_id=data_id)
                else:
                    output_data = ChannelData(
                        ChannelDataType.DICT.value,
                        dictdata=postped_data,
                        data_id=data_id)
                postped_data_dict[data_id] = output_data
        _LOGGER.debug("{} Succ postprocess".format(op_info_prefix))
        return postped_data_dict, err_channeldata_dict
B
barriery 已提交
548 549

    def _auto_batching_generator(self, input_channel, op_name, batch_size,
                                 timeout, op_info_prefix):
        """Yield batches of channeldata dicts pulled from input_channel.

        Blocks until at least one item arrives; then collects up to
        batch_size items within the timeout window (seconds). A None
        timeout disables the window and blocks per item.
        """
        while True:
            batch = []
            while len(batch) == 0:
                endtime = None
                if timeout is not None:
                    endtime = _time() + timeout
                for idx in range(batch_size):
                    try:
                        channeldata_dict = None
                        if timeout is not None:
                            remaining = endtime - _time()
                            if remaining <= 0.0:
                                _LOGGER.debug("{} Failed to generate batch: "
                                              "timeout".format(op_info_prefix))
                                break
                            # BUGFIX: wait only for the remaining window,
                            # not the full timeout; otherwise one batch
                            # could block up to batch_size * timeout.
                            channeldata_dict = input_channel.front(op_name,
                                                                   remaining)
                        else:
                            channeldata_dict = input_channel.front(op_name)
                        batch.append(channeldata_dict)
                    except ChannelTimeoutError:
                        _LOGGER.debug("{} Failed to generate batch: "
                                      "timeout".format(op_info_prefix))
                        break
            _LOGGER.debug("{} Got actual batch_size: {}".format(op_info_prefix,
                                                                len(batch)))
            yield batch
578

579
    def _parse_channeldata_batch(self, batch, output_channels):
        """Parse a batch of channeldata dicts into per-request data.

        Requests already carrying error data from a predecessor Op are
        forwarded to the output channels directly and excluded from the
        returned dicts.

        Returns:
            Tuple (parsed_data_dict, need_profile_dict, profile_dict),
            all keyed by data_id.
        """
        parsed_data_dict = collections.OrderedDict()
        need_profile_dict = {}
        profile_dict = {}
        for channeldata_dict in batch:
            (data_id, error_channeldata, parsed_data,
                    client_need_profile, profile_set) = \
                            self._parse_channeldata(channeldata_dict)
            if error_channeldata is None:
                parsed_data_dict[data_id] = parsed_data
                need_profile_dict[data_id] = client_need_profile
                profile_dict[data_id] = profile_set
            else:
                # error data in predecessor Op
                # (error_channeldata with profile info)
                self._push_to_output_channels(error_channeldata,
                                              output_channels)

        return parsed_data_dict, need_profile_dict, profile_dict
B
barriery 已提交
598 599

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op, trace_buffer):
        """Main worker loop of the Op (one per thread/process).

        Pulls auto-batched channeldata from ``input_channel``, runs the
        preprocess -> process -> postprocess stages, and pushes the results
        (or per-request error channeldata) to ``output_channels``. The loop
        terminates when a ``ChannelStopError`` is raised by any channel
        operation, at which point ``_finalize`` releases worker resources.

        Args:
            concurrency_idx: index of this worker; ``None``-like semantics
                for threaded ops are handled inside ``_initialize``.
            input_channel: channel to consume batches from.
            output_channels: list of downstream channels.
            client_type: forwarded to ``_initialize`` for client creation.
            is_thread_op: True when running as a thread, False as a process.
            trace_buffer: optional queue for per-batch timing records; when
                full, records are retained locally and retried next batch.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        tid = threading.current_thread().ident

        # init op; any failure here is fatal for the whole worker
        profiler = None
        try:
            profiler = self._initialize(is_thread_op, client_type,
                                        concurrency_idx)
        except Exception as e:
            _LOGGER.critical(
                "{} Failed to init op: {}".format(op_info_prefix, e),
                exc_info=True)
            os._exit(-1)
        _LOGGER.info("{} Succ init".format(op_info_prefix))

        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=self._batch_size,
            timeout=self._auto_batching_timeout,
            op_info_prefix=op_info_prefix)

        start, end = None, None
        trace_que = collections.deque()
        while True:
            # timestamps in microseconds for the trace records below
            start = int(round(_time() * 1000000))
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            end = int(round(_time() * 1000000))
            in_time = end - start

            # parse channeldata batch
            try:
                parsed_data_dict, need_profile_dict, profile_dict \
                        = self._parse_channeldata_batch(
                                channeldata_dict_batch, output_channels)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(parsed_data_dict) == 0:
                # data in the whole batch is all error data
                continue

            # preprocess stage
            start = profiler.record("prep#{}_0".format(op_info_prefix))
            preped_data_dict, err_channeldata_dict \
                    = self._run_preprocess(parsed_data_dict, op_info_prefix)
            end = profiler.record("prep#{}_1".format(op_info_prefix))
            prep_time = end - start
            # forward per-request errors downstream with their profile info
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(preped_data_dict) == 0:
                continue

            # process stage
            start = profiler.record("midp#{}_0".format(op_info_prefix))
            midped_data_dict, err_channeldata_dict \
                    = self._run_process(preped_data_dict, op_info_prefix)
            end = profiler.record("midp#{}_1".format(op_info_prefix))
            midp_time = end - start
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(midped_data_dict) == 0:
                continue

            # postprocess stage
            start = profiler.record("postp#{}_0".format(op_info_prefix))
            postped_data_dict, err_channeldata_dict \
                    = self._run_postprocess(
                            parsed_data_dict, midped_data_dict, op_info_prefix)
            end = profiler.record("postp#{}_1".format(op_info_prefix))
            postp_time = end - start
            try:
                for data_id, err_channeldata in err_channeldata_dict.items():
                    self._push_to_output_channels(
                        data=err_channeldata,
                        channels=output_channels,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            if len(postped_data_dict) == 0:
                continue

            # push data to channel (if run succ)
            start = int(round(_time() * 1000000))
            try:
                profile_str = profiler.gen_profile_str()
                for data_id, postped_data in postped_data_dict.items():
                    if self._server_use_profile:
                        sys.stderr.write(profile_str)
                    self._push_to_output_channels(
                        data=postped_data,
                        channels=output_channels,
                        profile_str=profile_str,
                        client_need_profile=need_profile_dict[data_id],
                        profile_set=profile_dict[data_id])
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break
            end = int(round(_time() * 1000000))
            out_time = end - start
            if trace_buffer is not None:
                # queue the timing record; drain as much of the local deque
                # as the (bounded) trace_buffer will accept without blocking
                trace_que.append({
                    "name": self.name,
                    "actions": {
                        "in": in_time,
                        "prep": prep_time,
                        "midp": midp_time,
                        "postp": postp_time,
                        "out": out_time,
                    }
                })
                while trace_que:
                    info = trace_que[0]
                    try:
                        trace_buffer.put_nowait(info)
                        trace_que.popleft()
                    except Queue.Full:
                        break
    def _initialize(self, is_thread_op, client_type, concurrency_idx):
        """Initialize this worker's client and user-defined state.

        In thread mode the init runs at most once across all threads
        (guarded by ``_for_init_op_lock`` / ``_succ_init_op``); in process
        mode every process initializes its own client.

        Returns:
            A freshly enabled ``TimeProfiler`` owned by this worker.
        """
        if is_thread_op:
            with self._for_init_op_lock:
                if not self._succ_init_op:
                    # for the threaded version of Op, each thread cannot get its concurrency_idx
                    self.concurrency_idx = None
                    # init client
                    self.client = self.init_client(
                        client_type, self._client_config,
                        self._server_endpoints, self._fetch_names)
                    # user defined
                    self.init_op()
                    self._succ_init_op = True
                    self._succ_close_op = False
        else:
            self.concurrency_idx = concurrency_idx
            # init client
            self.client = self.init_client(client_type, self._client_config,
                                           self._server_endpoints,
                                           self._fetch_names)
            # user defined
            self.init_op()

        # use a separate TimeProfiler per thread or process
        profiler = TimeProfiler()
        profiler.enable(True)
        return profiler
    def _finalize(self, is_thread_op):
        """Release worker resources; in thread mode the close runs once."""
        if not is_thread_op:
            # process mode: nothing shared to tear down here
            return
        with self._for_close_op_lock:
            if self._succ_close_op:
                return
            self._profiler = None
            self.client = None
            self._succ_init_op = False
            self._succ_close_op = True
    def _log(self, info):
        """Prefix *info* with this Op's name for log/error messages."""
        return "{} {}".format(self.name, info)


B
barrierye 已提交
789 790 791
class RequestOp(Op):
    """ RequestOp do not run preprocess, process, postprocess. """

    def __init__(self):
        # PipelineService.name = "@DAGExecutor"
        super(RequestOp, self).__init__(name="@DAGExecutor", input_ops=[])
        # init op; failure is fatal for the server
        try:
            self.init_op()
        except Exception as e:
            _LOGGER.critical("Op(Request) Failed to init: {}".format(e))
            os._exit(-1)

    def unpack_request_package(self, request):
        """Convert an RPC request (parallel key/value lists) into a dict.

        Each value is tried with ``eval``; only when the result is an
        ``np.ndarray`` is the evaluated object kept, otherwise the raw
        string is stored unchanged.
        """
        dictdata = {}
        for idx, key in enumerate(request.key):
            data = request.value[idx]
            try:
                # NOTE(review): eval() on client-supplied request data is a
                # code-injection risk if the endpoint is exposed to untrusted
                # clients — consider ast.literal_eval or an explicit ndarray
                # decoding scheme. Left as-is to preserve wire compatibility.
                evaled_data = eval(data)
                if isinstance(evaled_data, np.ndarray):
                    data = evaled_data
            except Exception as e:
                # best-effort: non-evaluable values fall through as strings
                pass
            dictdata[key] = data
        return dictdata
class ResponseOp(Op):
    """ ResponseOp do not run preprocess, process, postprocess. """

    def __init__(self, input_ops):
        super(ResponseOp, self).__init__(
            name="@DAGExecutor", input_ops=input_ops)
        # init op; failure is fatal for the server
        try:
            self.init_op()
        except Exception as e:
            # BUG FIX: exc_info=True was previously passed to str.format()
            # (where it was silently ignored) instead of _LOGGER.critical(),
            # so the traceback was never logged.
            _LOGGER.critical(
                "Op(ResponseOp) Failed to init: {}".format(e), exc_info=True)
            os._exit(-1)

    def pack_response_package(self, channeldata):
        """Build the RPC Response message from the final channeldata.

        On OK ecode, CHANNEL_NPDATA values are serialized via repr()
        (restorable on the client side) and DICT values must already be
        strings; any other datatype or non-str dict value turns the
        response into a TYPE_ERROR. On non-OK ecode only the error info
        is propagated.
        """
        resp = pipeline_service_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode == ChannelDataEcode.OK.value:
            if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
                feed = channeldata.parse()
                # ndarray to string:
                # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
                np.set_printoptions(threshold=sys.maxsize)
                for name, var in feed.items():
                    resp.value.append(var.__repr__())
                    resp.key.append(name)
            elif channeldata.datatype == ChannelDataType.DICT.value:
                feed = channeldata.parse()
                for name, var in feed.items():
                    if not isinstance(var, str):
                        resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                        resp.error_info = self._log(
                            "fetch var type must be str({}).".format(
                                type(var)))
                        _LOGGER.error("(logid={}) Failed to pack RPC "
                                      "response package: {}".format(
                                          channeldata.id, resp.error_info))
                        break
                    resp.value.append(var)
                    resp.key.append(name)
            else:
                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                resp.error_info = self._log(
                    "error type({}) in datatype.".format(channeldata.datatype))
                _LOGGER.error("(logid={}) Failed to pack RPC response"
                              " package: {}".format(channeldata.id,
                                                    resp.error_info))
        else:
            resp.error_info = channeldata.error_info
        return resp
class VirtualOp(Op):
    ''' For connecting two channels. '''

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, input_ops=None, concurrency=concurrency)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Register *op* as a (possibly virtual) predecessor of this Op."""
        self._virtual_pred_ops.append(op)

    def _actual_pred_op_names(self, op):
        """Recursively resolve *op* to the names of its real (non-virtual)
        predecessor Ops."""
        # can use disjoint-set, but it's not necessary
        if not isinstance(op, VirtualOp):
            return [op.name]
        names = []
        for x in op._virtual_pred_ops:
            names.extend(self._actual_pred_op_names(x))
        return names

    def add_output_channel(self, channel):
        """Attach an output channel, registering every actual predecessor
        Op as a producer on it. Exits the process on a type error."""
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            _LOGGER.critical(
                self._log("Failed to add output_channel: output_channel"
                          " must be Channel type, not {}".format(
                              type(channel))))
            os._exit(-1)
        for op in self._virtual_pred_ops:
            for op_name in self._actual_pred_op_names(op):
                channel.add_producer(op_name)
        self._outputs.append(channel)

    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             is_thread_op):
        """Pass-through worker loop: forward every channeldata from the
        input channel to all output channels, batch size fixed at 1.

        BUG FIX: the old body called the removed helper ``get_log_func``
        (a NameError at runtime) and passed the stale ``log_func=`` kwarg;
        ``_auto_batching_generator`` takes ``op_info_prefix`` instead.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)

        batch_generator = self._auto_batching_generator(
            input_channel=input_channel,
            op_name=self.name,
            batch_size=1,
            timeout=None,
            op_info_prefix=op_info_prefix)

        while True:
            try:
                channeldata_dict_batch = next(batch_generator)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break

            try:
                # forward each item unchanged, keyed by its producer name
                for channeldata_dict in channeldata_dict_batch:
                    for name, data in channeldata_dict.items():
                        self._push_to_output_channels(
                            data, channels=output_channels, name=name)
            except ChannelStopError:
                _LOGGER.debug("{} Stop.".format(op_info_prefix))
                self._finalize(is_thread_op)
                break