operator.py 21.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
D
dongdaxiang 已提交
15

16 17 18 19 20 21
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout
22
import os
B
barrierye 已提交
23
import sys
B
barrierye 已提交
24
from numpy import *
25

B
barrierye 已提交
26
from .proto import pipeline_service_pb2
B
barrierye 已提交
27 28
from .channel import (ThreadChannel, ProcessChannel, ChannelDataEcode,
                      ChannelData, ChannelDataType, ChannelStopError)
B
barrierye 已提交
29
from .util import NameGenerator
B
barrierye 已提交
30
from .profiler import TimeProfiler
31

W
wangjiawei04 已提交
32
# Logger shared by every Op in this module; getLogger() without a name
# returns the root logger.
_LOGGER = logging.getLogger()
# Supplies a name for Ops that are constructed without an explicit one.
_op_name_gen = NameGenerator("Op")

D
dongdaxiang 已提交
35 36 37

class Op(object):
    def __init__(self,
B
barrierye 已提交
38
                 name=None,
D
dongdaxiang 已提交
39 40
                 input_ops=[],
                 server_endpoints=[],
B
barrierye 已提交
41 42
                 fetch_list=[],
                 client_config=None,
D
dongdaxiang 已提交
43 44 45
                 concurrency=1,
                 timeout=-1,
                 retry=1):
B
barrierye 已提交
46
        if name is None:
B
barrierye 已提交
47
            name = _op_name_gen.next()
48
        self.name = name  # to identify the type of OP, it must be globally unique
B
barrierye 已提交
49
        self.concurrency = concurrency  # amount of concurrency
B
barrierye 已提交
50
        self.set_input_ops(input_ops)
B
barrierye 已提交
51 52

        self._server_endpoints = server_endpoints
53
        self.with_serving = False
B
barrierye 已提交
54
        if len(self._server_endpoints) != 0:
55
            self.with_serving = True
B
barrierye 已提交
56 57 58
        self._client_config = client_config
        self._fetch_names = fetch_list

59 60 61 62
        self._timeout = timeout
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []
B
barrierye 已提交
63 64

        self._use_profile = False
65

B
barrierye 已提交
66 67
        # only for multithread
        self._for_init_op_lock = threading.Lock()
B
barrierye 已提交
68
        self._for_close_op_lock = threading.Lock()
B
barrierye 已提交
69
        self._succ_init_op = False
B
barrierye 已提交
70
        self._succ_close_op = False
B
barrierye 已提交
71

B
barrierye 已提交
72
    def use_profiler(self, use_profile):
B
barrierye 已提交
73
        self._use_profile = use_profile
74 75 76 77 78 79

    def _profiler_record(self, string):
        if self._profiler is None:
            return
        self._profiler.record(string)

B
barrierye 已提交
80 81
    def init_client(self, client_type, client_config, server_endpoints,
                    fetch_names):
        """Build and connect the client used by process().

        Args:
            client_type: 'brpc' or 'grpc'.
            client_config: configuration loaded into the brpc client; it is
                not used for the grpc client.
            server_endpoints: passed to the client's connect().
            fetch_names: stored on the Op as the fetch list for predict().

        Returns:
            The connected client, or None when the Op has no serving
            endpoints (`with_serving` is false).

        Raises:
            ValueError: if `client_type` is neither 'brpc' nor 'grpc'.
        """
        if not self.with_serving:
            _LOGGER.debug("{} no client".format(self.name))
            return None
        _LOGGER.debug("{} client_config: {}".format(self.name, client_config))
        _LOGGER.debug("{} fetch_names: {}".format(self.name, fetch_names))
        if client_type == 'brpc':
            client = Client()
            client.load_client_config(client_config)
        elif client_type == 'grpc':
            client = MultiLangClient()
        else:
            raise ValueError("unknow client type: {}".format(client_type))
        client.connect(server_endpoints)
        self._fetch_names = fetch_names
        return client
97 98 99 100 101 102 103 104 105 106 107 108 109 110

    def get_input_ops(self):
        return self._input_ops

    def set_input_ops(self, ops):
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)
D
dongdaxiang 已提交
111

112 113 114 115 116 117 118
    def add_input_channel(self, channel):
        """Register this Op as a consumer of `channel` and use it as input.

        Raises:
            TypeError: if `channel` is neither a ThreadChannel nor a
                ProcessChannel.
        """
        if isinstance(channel, (ThreadChannel, ProcessChannel)):
            channel.add_consumer(self.name)
            self._input = channel
            return
        message = 'input channel must be Channel type, not {}'.format(
            type(channel))
        raise TypeError(self._log(message))
D
dongdaxiang 已提交
119

B
barrierye 已提交
120 121 122 123 124
    def _clean_input_channel(self):
        self._input = None

    def _get_input_channel(self):
        return self._input
D
dongdaxiang 已提交
125

126 127 128 129 130 131 132
    def add_output_channel(self, channel):
        """Register this Op as a producer of `channel` and add it as an output.

        Raises:
            TypeError: if `channel` is neither a ThreadChannel nor a
                ProcessChannel.
        """
        if isinstance(channel, (ThreadChannel, ProcessChannel)):
            channel.add_producer(self.name)
            self._outputs.append(channel)
            return
        message = 'output channel must be Channel type, not {}'.format(
            type(channel))
        raise TypeError(self._log(message))
D
dongdaxiang 已提交
133

B
barrierye 已提交
134 135 136 137 138 139
    def _clean_output_channels(self):
        self._outputs = []

    def _get_output_channels(self):
        return self._outputs

W
wangjiawei04 已提交
140
    def preprocess(self, input_dicts):
        """Default preprocess: pass the single predecessor's data through.

        Args:
            input_dicts: mapping whose values are the parsed outputs of the
                preceding Ops.

        Returns:
            The only value of `input_dicts`.

        Raises:
            NotImplementedError: if `input_dicts` does not hold exactly one
                entry; Ops with several predecessors must override this.
        """
        if len(input_dicts) == 1:
            return next(iter(input_dicts.values()))
        # multiple previous Op
        raise NotImplementedError(
            'this Op has multiple previous inputs. Please override this func.'
        )
B
barrierye 已提交
149

B
barrierye 已提交
150
    def process(self, feed_dict):
        """Default process: send `feed_dict` to the client and return its result.

        Raises:
            NotImplementedError: if ChannelData.check_npdata() rejects
                `feed_dict`; the message asks the user to override preprocess.
        """
        status, detail = ChannelData.check_npdata(feed_dict)
        if status != 0:
            raise NotImplementedError(
                "{} Please override preprocess func.".format(detail))
        prediction = self.client.predict(
            feed=feed_dict, fetch=self._fetch_names)
        _LOGGER.debug(self._log("get call_result"))
        return prediction

W
wangjiawei04 已提交
160
    def postprocess(self, input_dict, fetch_dict):
        """Default postprocess: return `fetch_dict` unchanged.

        `input_dict` is accepted for subclasses that need the Op's input; it
        is ignored here.
        """
        result = fetch_dict
        return result
D
dongdaxiang 已提交
162

B
barrierye 已提交
163
    def _parse_channeldata(self, channeldata_dict):
        """Split the items read from the input channel into id, error and data.

        Returns:
            A tuple (data_id, error_channeldata, parsed_data):
            data_id is the `id` of the first entry; error_channeldata is the
            first entry whose ecode is not OK (None if all are OK);
            parsed_data maps each name to `data.parse()` for the entries seen
            before an error was found.
        """
        first_key = list(channeldata_dict.keys())[0]
        data_id = channeldata_dict[first_key].id

        parsed_data = {}
        error_channeldata = None
        for name, data in channeldata_dict.items():
            if data.ecode == ChannelDataEcode.OK.value:
                parsed_data[name] = data.parse()
            else:
                # Stop at the first failed predecessor; it is forwarded as-is.
                error_channeldata = data
                break
        return data_id, error_channeldata, parsed_data
176 177 178 179 180 181 182

    def _push_to_output_channels(self, data, channels, name=None):
        if name is None:
            name = self.name
        for channel in channels:
            channel.push(data, name)

B
barrierye 已提交
183
    def start_with_process(self, client_type):
184
        proces = []
B
barrierye 已提交
185
        for concurrency_idx in range(self.concurrency):
186 187
            p = multiprocessing.Process(
                target=self._run,
B
barrierye 已提交
188
                args=(concurrency_idx, self._get_input_channel(),
189
                      self._get_output_channels(), client_type, False))
190 191 192 193
            p.start()
            proces.append(p)
        return proces

B
barrierye 已提交
194
    def start_with_thread(self, client_type):
195
        threads = []
B
barrierye 已提交
196
        for concurrency_idx in range(self.concurrency):
197 198
            t = threading.Thread(
                target=self._run,
B
barrierye 已提交
199
                args=(concurrency_idx, self._get_input_channel(),
200
                      self._get_output_channels(), client_type, True))
201 202 203 204
            t.start()
            threads.append(t)
        return threads

B
barrierye 已提交
205
    def init_op(self):
        """User-defined initialization hook; the default does nothing.

        Called by _run() after the profiler and client have been set up.
        """
        return None

W
wangjiawei04 已提交
208
    def _run_preprocess(self, parsed_data, data_id, log_func):
209 210
        preped_data, error_channeldata = None, None
        try:
W
wangjiawei04 已提交
211
            preped_data = self.preprocess(parsed_data)
212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
        except NotImplementedError as e:
            # preprocess function not implemented
            error_info = log_func(e)
            _LOGGER.error(error_info)
            error_channeldata = ChannelData(
                ecode=ChannelDataEcode.NOT_IMPLEMENTED.value,
                error_info=error_info,
                data_id=data_id)
        except TypeError as e:
            # Error type in channeldata.datatype
            error_info = log_func(e)
            _LOGGER.error(error_info)
            error_channeldata = ChannelData(
                ecode=ChannelDataEcode.TYPE_ERROR.value,
                error_info=error_info,
                data_id=data_id)
        except Exception as e:
            error_info = log_func(e)
            _LOGGER.error(error_info)
            error_channeldata = ChannelData(
                ecode=ChannelDataEcode.UNKNOW.value,
                error_info=error_info,
                data_id=data_id)
        return preped_data, error_channeldata

B
barrierye 已提交
237
    def _run_process(self, preped_data, data_id, log_func):
238 239 240 241 242
        midped_data, error_channeldata = None, None
        if self.with_serving:
            ecode = ChannelDataEcode.OK.value
            if self._timeout <= 0:
                try:
B
barrierye 已提交
243
                    midped_data = self.process(preped_data)
244 245 246 247 248 249 250 251
                except Exception as e:
                    ecode = ChannelDataEcode.UNKNOW.value
                    error_info = log_func(e)
                    _LOGGER.error(error_info)
            else:
                for i in range(self._retry):
                    try:
                        midped_data = func_timeout.func_timeout(
B
barrierye 已提交
252
                            self._timeout, self.process, args=(preped_data, ))
253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281
                    except func_timeout.FunctionTimedOut as e:
                        if i + 1 >= self._retry:
                            ecode = ChannelDataEcode.TIMEOUT.value
                            error_info = log_func(e)
                            _LOGGER.error(error_info)
                        else:
                            _LOGGER.warn(
                                log_func("timeout, retry({})".format(i + 1)))
                    except Exception as e:
                        ecode = ChannelDataEcode.UNKNOW.value
                        error_info = log_func(e)
                        _LOGGER.error(error_info)
                        break
                    else:
                        break
            if ecode != ChannelDataEcode.OK.value:
                error_channeldata = ChannelData(
                    ecode=ecode, error_info=error_info, data_id=data_id)
            elif midped_data is None:
                # op client return None
                error_channeldata = ChannelData(
                    ecode=ChannelDataEcode.CLIENT_ERROR.value,
                    error_info=log_func(
                        "predict failed. pls check the server side."),
                    data_id=data_id)
        else:
            midped_data = preped_data
        return midped_data, error_channeldata

W
wangjiawei04 已提交
282
    def _run_postprocess(self, input_dict, midped_data, data_id, log_func):
        """Call postprocess() and wrap its result in a ChannelData.

        Args:
            input_dict: the Op's parsed input, passed to postprocess().
            midped_data: output of the process stage.
            data_id: id attached to the produced ChannelData.
            log_func: formats an exception or message into the logged text.

        Returns:
            (output_data, None) on success, where output_data is a
            CHANNEL_NPDATA ChannelData if ChannelData.check_npdata() accepts
            the result and a DICT ChannelData otherwise; or
            (None, error_channeldata) with ecode UNKNOW when postprocess()
            raised or did not return a dict.
        """

        def failure(reason):
            # Log the problem and package it as an UNKNOW error result.
            error_info = log_func(reason)
            _LOGGER.error(error_info)
            return None, ChannelData(
                ecode=ChannelDataEcode.UNKNOW.value,
                error_info=error_info,
                data_id=data_id)

        try:
            postped_data = self.postprocess(input_dict, midped_data)
        except Exception as exc:
            return failure(exc)

        if not isinstance(postped_data, dict):
            return failure("output of postprocess funticon must be "
                           "dict type, but get {}".format(type(postped_data)))

        status, _ = ChannelData.check_npdata(postped_data)
        if status == 0:
            wrapped = ChannelData(
                ChannelDataType.CHANNEL_NPDATA.value,
                npdata=postped_data,
                data_id=data_id)
        else:
            wrapped = ChannelData(
                ChannelDataType.DICT.value,
                dictdata=postped_data,
                data_id=data_id)
        return wrapped, None

318 319
    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             use_multithread):
        """Worker loop: read from the input channel, run the stages, push.

        Args:
            concurrency_idx: index of this worker among the Op's workers.
            input_channel: channel read with front(self.name).
            output_channels: channels that receive results and errors.
            client_type: forwarded to init_client().
            use_multithread: True when started by start_with_thread(); the
                Op is then initialized only once, under a lock.

        The loop ends when the input channel raises ChannelStopError, or when
        the final push of a successful result does. If initialization raises,
        the error is logged and the process exits via os._exit(-1).
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)

        def log(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        def init_resources():
            # init profiler
            self._profiler = TimeProfiler()
            self._profiler.enable(self._use_profile)
            # init client
            self.client = self.init_client(client_type, self._client_config,
                                           self._server_endpoints,
                                           self._fetch_names)
            # user defined
            self.init_op()

        def forward_error(error_channeldata):
            # Pass an error downstream; a stopped channel is only logged here,
            # the loop notices the stop on its next read.
            try:
                self._push_to_output_channels(error_channeldata,
                                              output_channels)
            except ChannelStopError:
                _LOGGER.info(log("stop."))

        # init op
        self.concurrency_idx = concurrency_idx
        try:
            if use_multithread:
                with self._for_init_op_lock:
                    if not self._succ_init_op:
                        init_resources()
                        self._succ_init_op = True
                        self._succ_close_op = False
            else:
                init_resources()
        except Exception as e:
            _LOGGER.error(log(e))
            os._exit(-1)

        while True:
            try:
                channeldata_dict = input_channel.front(self.name)
            except ChannelStopError:
                _LOGGER.info(log("stop."))
                with self._for_close_op_lock:
                    if not self._succ_close_op:
                        self._clean_input_channel()
                        self._clean_output_channels()
                        self._profiler = None
                        self.client = None
                        self._succ_init_op = False
                        self._succ_close_op = True
                break
            _LOGGER.debug(log("input_data: {}".format(channeldata_dict)))

            data_id, error_channeldata, parsed_data = self._parse_channeldata(
                channeldata_dict)
            # error data in predecessor Op
            if error_channeldata is not None:
                forward_error(error_channeldata)
                continue

            # preprocess
            self._profiler_record("prep#{}_0".format(op_info_prefix))
            preped_data, error_channeldata = self._run_preprocess(
                parsed_data, data_id, log)
            self._profiler_record("prep#{}_1".format(op_info_prefix))
            if error_channeldata is not None:
                forward_error(error_channeldata)
                continue

            # process
            self._profiler_record("midp#{}_0".format(op_info_prefix))
            midped_data, error_channeldata = self._run_process(
                preped_data, data_id, log)
            self._profiler_record("midp#{}_1".format(op_info_prefix))
            if error_channeldata is not None:
                forward_error(error_channeldata)
                continue

            # postprocess
            self._profiler_record("postp#{}_0".format(op_info_prefix))
            output_data, error_channeldata = self._run_postprocess(
                parsed_data, midped_data, data_id, log)
            self._profiler_record("postp#{}_1".format(op_info_prefix))
            if error_channeldata is not None:
                forward_error(error_channeldata)
                continue

            if self._use_profile:
                profile_str = self._profiler.gen_profile_str()
                sys.stderr.write(profile_str)
                # TODO: attach profile_str to output_data
                # (output_data.add_profile(profile_str)).

            # push data to channel (if run succ)
            try:
                self._push_to_output_channels(output_data, output_channels)
            except ChannelStopError:
                _LOGGER.info(log("stop."))
                break
444 445 446 447 448

    def _log(self, info):
        return "{} {}".format(self.name, info)


B
barrierye 已提交
449 450 451
class RequestOp(Op):
    """ RequestOp do not run preprocess, process, postprocess. """

B
barrierye 已提交
452
    def __init__(self, concurrency=1):
        """Create the request Op, named "@G", with no predecessors.

        init_op() is invoked right away; if it raises, the exception is
        logged and the process exits via os._exit(-1).
        """
        # PipelineService.name = "@G"
        super(RequestOp, self).__init__(
            name="@G", input_ops=[], concurrency=concurrency)
        # init op
        try:
            self.init_op()
        except Exception as init_error:
            _LOGGER.error(init_error)
            os._exit(-1)
B
barrierye 已提交
462 463 464 465

    def unpack_request_package(self, request):
        """Convert a request's parallel `key`/`value` sequences into a dict.

        Each value is evaluated as a Python expression; a value that cannot
        be evaluated is stored as received.

        Args:
            request: object exposing `key` and `value` sequences of equal
                length.

        Returns:
            dict mapping each key to its (possibly evaluated) value.
        """
        dictdata = {}
        for idx, key in enumerate(request.key):
            data = request.value[idx]
            try:
                # SECURITY(review): eval() executes whatever expression the
                # request carries. It is kept because callers appear to rely
                # on it to rebuild values from their repr (this module does
                # `from numpy import *`, presumably for ndarray reprs), but
                # it must not be exposed to untrusted clients as-is.
                data = eval(data)
            except Exception:
                # Best effort: not an evaluable expression, keep it as-is.
                pass
            dictdata[key] = data
        return dictdata


class ResponseOp(Op):
    """ ResponseOp do not run preprocess, process, postprocess. """

    def __init__(self, input_ops, concurrency=1):
        """Create the response Op, named "@R", fed by `input_ops`.

        init_op() is invoked right away; if it raises, the exception is
        logged and the process exits via os._exit(-1).
        """
        super(ResponseOp, self).__init__(
            name="@R", input_ops=input_ops, concurrency=concurrency)
        # init op
        try:
            self.init_op()
        except Exception as init_error:
            _LOGGER.error(init_error)
            os._exit(-1)
B
barrierye 已提交
487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517

    def pack_response_package(self, channeldata):
        """Build a pipeline_service_pb2.Response from `channeldata`.

        A non-OK ecode is copied together with the error_info. For OK data,
        CHANNEL_NPDATA entries are sent as their repr and DICT entries must
        already be str; anything else yields a TYPE_ERROR response.
        """
        resp = pipeline_service_pb2.Response()
        resp.ecode = channeldata.ecode
        if resp.ecode != ChannelDataEcode.OK.value:
            resp.error_info = channeldata.error_info
            return resp

        if channeldata.datatype == ChannelDataType.CHANNEL_NPDATA.value:
            # ndarray to string:
            # https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
            for name, var in channeldata.parse().items():
                resp.value.append(repr(var))
                resp.key.append(name)
        elif channeldata.datatype == ChannelDataType.DICT.value:
            for name, var in channeldata.parse().items():
                if isinstance(var, str):
                    resp.value.append(var)
                    resp.key.append(name)
                    continue
                # Entries appended before the offending one stay in resp.
                resp.ecode = ChannelDataEcode.TYPE_ERROR.value
                resp.error_info = self._log(
                    "fetch var type must be str({}).".format(type(var)))
                break
        else:
            resp.ecode = ChannelDataEcode.TYPE_ERROR.value
            resp.error_info = self._log(
                "Error type({}) in datatype.".format(channeldata.datatype))
            _LOGGER.error(resp.error_info)
        return resp
518 519 520 521 522 523 524


class VirtualOp(Op):
    ''' For connecting two channels. '''

    def __init__(self, name, concurrency=1):
        """Create a VirtualOp with no input Ops and no virtual predecessors."""
        super(VirtualOp, self).__init__(
            name=name, input_ops=None, concurrency=concurrency)
        # Filled through add_virtual_pred_op().
        self._virtual_pred_ops = list()

    def add_virtual_pred_op(self, op):
        self._virtual_pred_ops.append(op)

B
barrierye 已提交
531 532 533 534 535 536 537 538
    def _actual_pred_op_names(self, op):
        if not isinstance(op, VirtualOp):
            return [op.name]
        names = []
        for x in op._virtual_pred_ops:
            names.extend(self._actual_pred_op_names(x))
        return names

539 540 541 542 543 544
    def add_output_channel(self, channel):
        """Add `channel` as an output on behalf of the real predecessor Ops.

        Every non-virtual Op found behind this VirtualOp's predecessors is
        registered as a producer of `channel`; the VirtualOp's own name is
        not registered.

        Raises:
            TypeError: if `channel` is neither a ThreadChannel nor a
                ProcessChannel.
        """
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            message = 'output channel must be Channel type, not {}'.format(
                type(channel))
            raise TypeError(self._log(message))
        for pred in self._virtual_pred_ops:
            for producer_name in self._actual_pred_op_names(pred):
                channel.add_producer(producer_name)
        self._outputs.append(channel)
D
dongdaxiang 已提交
548

549 550
    def _run(self, concurrency_idx, input_channel, output_channels, client_type,
             use_multithread):
        """Worker loop: copy every item from the input channel to the outputs.

        Each item read is pushed under the name it arrived with, not under
        this VirtualOp's name. `client_type` and `use_multithread` are
        accepted for signature compatibility with Op._run and are not used.

        The loop ends when reading or pushing raises ChannelStopError.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)

        # Reading and pushing share the same stop handling (log and leave),
        # so a single handler around the loop covers both.
        try:
            while True:
                channeldata_dict = input_channel.front(self.name)
                for name, data in channeldata_dict.items():
                    self._push_to_output_channels(
                        data, channels=output_channels, name=name)
        except ChannelStopError:
            _LOGGER.info("{} {}".format(op_info_prefix, "stop."))