operator.py 16.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
D
dongdaxiang 已提交
15

16 17 18 19 20 21 22 23
import threading
import multiprocessing
from paddle_serving_client import MultiLangClient, Client
from concurrent import futures
import logging
import func_timeout

from .channel import ThreadChannel, ProcessChannel, ChannelDataEcode, ChannelData, ChannelDataType
B
barrierye 已提交
24
from .util import NameGenerator
25

B
barrierye 已提交
26 27
# Module-level name generator: Op.__init__ calls _op_name_gen.next() to name
# any Op constructed without an explicit ``name``.
_op_name_gen = NameGenerator("Op")

D
dongdaxiang 已提交
28 29 30

class Op(object):
    """A pipeline operator fed by one input channel, feeding output channels.

    Each worker started by ``start_with_thread`` / ``start_with_process`` runs
    ``_run``: it takes ``ChannelData`` from the input channel, runs
    ``preprocess`` -> ``process`` -> ``postprocess`` and pushes the result, or
    an error ``ChannelData``, to every output channel. ``process`` is invoked
    only when ``server_endpoints`` is non-empty; otherwise the preprocessed
    data goes straight to ``postprocess``.
    """

    def __init__(self,
                 name=None,
                 input_ops=None,
                 server_endpoints=None,
                 fetch_list=None,
                 client_config=None,
                 concurrency=1,
                 timeout=-1,
                 retry=1):
        """Configure the Op. No client and no worker is created here.

        Args:
            name: Op identifier; it must be globally unique. Generated by the
                module-level name generator when None.
            input_ops: predecessor Op, list of Ops, or None.
            server_endpoints: endpoints passed to the client's ``connect``.
                Empty or None means the Op never calls ``process``.
            fetch_list: fetch names passed to the client's ``predict``.
            client_config: passed to the client's ``load_client_config``.
            concurrency: number of workers started by ``start_with_*``.
            timeout: timeout of one ``process`` call, passed to
                ``func_timeout.func_timeout``; <= 0 disables the timeout
                (and with it the retries).
            retry: number of ``process`` attempts on timeout, clamped to >= 1.
        """
        # The list-valued parameters default to None rather than [] so that
        # Ops built without them never share one default list object.
        if name is None:
            name = _op_name_gen.next()
        self._is_run = False
        self.name = name  # to identify the type of OP, it must be globally unique
        self.concurrency = concurrency  # amount of concurrency
        self.set_input_ops(input_ops)

        self._server_endpoints = ([] if server_endpoints is None else
                                  server_endpoints)
        self.with_serving = len(self._server_endpoints) != 0
        self._client_config = client_config
        self._fetch_names = [] if fetch_list is None else fetch_list
        # Set by init_client() inside each worker; None until then.
        self._client = None

        self._timeout = timeout
        self._retry = max(1, retry)
        self._input = None
        self._outputs = []
        self._profiler = None

    def init_profiler(self, profiler):
        """Attach a profiler whose ``record(str)`` is called at each stage."""
        self._profiler = profiler

    def _profiler_record(self, string):
        """Forward ``string`` to the profiler, if one is attached."""
        if self._profiler is None:
            return
        self._profiler.record(string)

    def init_client(self, client_type, client_config, server_endpoints,
                    fetch_names):
        """Create, configure and connect the serving client used by process().

        Does nothing when the Op has no serving endpoints.

        Args:
            client_type: 'brpc' selects ``Client``; 'grpc' selects
                ``MultiLangClient``.
            client_config: passed to ``load_client_config``.
            server_endpoints: passed to ``connect``.
            fetch_names: stored as the fetch list used by ``process``.

        Raises:
            ValueError: for any other ``client_type``.
        """
        if not self.with_serving:
            logging.debug("{} no client".format(self.name))
            return
        logging.debug("{} client_config: {}".format(self.name, client_config))
        logging.debug("{} fetch_names: {}".format(self.name, fetch_names))
        if client_type == 'brpc':
            self._client = Client()
        elif client_type == 'grpc':
            self._client = MultiLangClient()
        else:
            raise ValueError("unknown client type: {}".format(client_type))
        self._client.load_client_config(client_config)
        self._client.connect(server_endpoints)
        self._fetch_names = fetch_names

    def _get_input_channel(self):
        """Return the input channel (None until add_input_channel is called)."""
        return self._input

    def get_input_ops(self):
        """Return the list of predecessor Ops."""
        return self._input_ops

    def set_input_ops(self, ops):
        """Set the predecessor Ops.

        Args:
            ops: an Op, a list of Ops, or None (meaning no predecessor).

        Raises:
            TypeError: if any element is not an ``Op``.
        """
        if not isinstance(ops, list):
            ops = [] if ops is None else [ops]
        self._input_ops = []
        for op in ops:
            if not isinstance(op, Op):
                raise TypeError(
                    self._log('input op must be Op type, not {}'.format(
                        type(op))))
            self._input_ops.append(op)

    def add_input_channel(self, channel):
        """Register this Op as a consumer of ``channel`` and read from it.

        Raises:
            TypeError: if ``channel`` is not a ThreadChannel/ProcessChannel.
        """
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('input channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_consumer(self.name)
        self._input = channel

    def _get_output_channels(self):
        """Return the list of output channels."""
        return self._outputs

    def add_output_channel(self, channel):
        """Register this Op as a producer of ``channel`` and write to it.

        Raises:
            TypeError: if ``channel`` is not a ThreadChannel/ProcessChannel.
        """
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        channel.add_producer(self.name)
        self._outputs.append(channel)

    def preprocess(self, channeldata):
        """Turn the input ChannelData into the feed dict for ``process``.

        The default implementation handles a single predecessor whose data
        is of type CHANNEL_NPDATA and returns ``channeldata.parse()`` (a dict
        of numpy data, per the original comment). Override it otherwise.

        Raises:
            NotImplementedError: for multiple predecessors (``channeldata``
                is a dict) or a datatype other than CHANNEL_NPDATA.
        """
        # multiple previous Op
        if isinstance(channeldata, dict):
            raise NotImplementedError(
                'this Op has multiple previous inputs. Please override this method'
            )

        if channeldata.datatype != ChannelDataType.CHANNEL_NPDATA.value:
            raise NotImplementedError(
                'datatype in channeldata is not CHANNEL_NPDATA({}). '
                'Please override this method'.format(channeldata.datatype))

        # get numpy dict
        feed_data = channeldata.parse()
        return feed_data

    def process(self, feed_dict):
        """Send ``feed_dict`` to the serving backend and return its result.

        Args:
            feed_dict: the dict returned by ``preprocess``.

        Returns:
            The value of ``self._client.predict`` for ``self._fetch_names``.

        Raises:
            TypeError: if ``feed_dict`` is not a dict.
        """
        if not isinstance(feed_dict, dict):
            raise TypeError(
                self._log(
                    'feed_dict must be dict type(the output of preprocess()), but get {}'.
                    format(type(feed_dict))))
        logging.debug(self._log('feed_dict: {}'.format(feed_dict)))
        logging.debug(self._log('fetch: {}'.format(self._fetch_names)))
        # Both client flavours use the same predict() call; only the debug
        # message differs.
        call_result = self._client.predict(
            feed=feed_dict, fetch=self._fetch_names)
        if isinstance(self._client, MultiLangClient):
            logging.debug(self._log("get call_result"))
        else:
            logging.debug(self._log("get fetch_dict"))
        return call_result

    def postprocess(self, fetch_dict):
        """Post-process the result of ``process``; identity by default.

        The value returned must be a dict, otherwise ``_run`` reports a
        TYPE_ERROR downstream.
        """
        return fetch_dict

    def stop(self):
        """Stop the input and output channels and end the worker loop."""
        # An Op that was never wired to an input channel has _input None.
        if self._input is not None:
            self._input.stop()
        for channel in self._outputs:
            channel.stop()
        self._is_run = False

    def _parse_channeldata(self, channeldata):
        """Extract the data id and the first erroneous ChannelData, if any.

        Args:
            channeldata: a ChannelData, or a dict of them (one per
                predecessor).

        Returns:
            (data_id, error_channeldata), where error_channeldata is None when
            every entry has ecode OK. For a dict, data_id is taken from its
            first value (presumably all values share the same id — verify).
        """
        data_id, error_channeldata = None, None
        if isinstance(channeldata, dict):
            data_id = list(channeldata.values())[0].id
            for data in channeldata.values():
                if data.ecode != ChannelDataEcode.OK.value:
                    error_channeldata = data
                    break
        else:
            data_id = channeldata.id
            if channeldata.ecode != ChannelDataEcode.OK.value:
                error_channeldata = channeldata
        return data_id, error_channeldata

    def _push_to_output_channels(self, data, channels, name=None):
        """Push ``data`` to every channel under ``name`` (default: self.name)."""
        if name is None:
            name = self.name
        for channel in channels:
            channel.push(data, name)

    def _push_error(self, ecode, error_info, data_id, output_channels):
        """Log ``error_info`` and push an error ChannelData downstream."""
        logging.error(error_info)
        self._push_to_output_channels(
            ChannelData(
                ecode=ecode, error_info=error_info, data_id=data_id),
            output_channels)

    def start_with_process(self, client_type):
        """Start ``self.concurrency`` worker processes; return them."""
        proces = []
        for concurrency_idx in range(self.concurrency):
            p = multiprocessing.Process(
                target=self._run,
                args=(concurrency_idx, self._get_input_channel(),
                      self._get_output_channels(), client_type))
            p.start()
            proces.append(p)
        return proces

    def start_with_thread(self, client_type):
        """Start ``self.concurrency`` worker threads; return them."""
        threads = []
        for concurrency_idx in range(self.concurrency):
            t = threading.Thread(
                target=self._run,
                args=(concurrency_idx, self._get_input_channel(),
                      self._get_output_channels(), client_type))
            t.start()
            threads.append(t)
        return threads

    def _call_process(self, preped_data, log):
        """Run ``process`` honoring ``self._timeout`` and ``self._retry``.

        Without a timeout (``self._timeout <= 0``) ``process`` is called once.
        Otherwise each attempt is bounded by ``func_timeout`` and only timed
        out attempts are retried; any other exception fails immediately.

        Returns:
            (midped_data, ecode, error_info). On success ecode is
            ``ChannelDataEcode.OK.value`` and error_info is None; on failure
            midped_data is None and error_info is a prefixed message that
            has not been logged yet.
        """
        if self._timeout <= 0:
            try:
                return (self.process(preped_data), ChannelDataEcode.OK.value,
                        None)
            except Exception as e:
                return None, ChannelDataEcode.UNKNOW.value, log(e)

        # self._retry >= 1, so the last iteration always returns.
        for i in range(self._retry):
            try:
                midped_data = func_timeout.func_timeout(
                    self._timeout, self.process, args=(preped_data, ))
            except func_timeout.FunctionTimedOut as e:
                if i + 1 >= self._retry:
                    return None, ChannelDataEcode.TIMEOUT.value, log(e)
                logging.warning(log("timeout, retry({})".format(i + 1)))
            except Exception as e:
                return None, ChannelDataEcode.UNKNOW.value, log(e)
            else:
                return midped_data, ChannelDataEcode.OK.value, None

    def _run(self, concurrency_idx, input_channel, output_channels,
             client_type):
        """Worker loop: read, preprocess, process, postprocess, push.

        Errors at any stage (or error data from a predecessor) are pushed
        downstream as ChannelData carrying an ecode instead of a result.

        NOTE(review): in thread mode every worker runs on this same instance,
        so each init_client() call overwrites ``self._client`` and all
        threads end up using the last client created — confirm the client
        is safe to share across threads.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)

        def log(info_str):
            return "{} {}".format(op_info_prefix, info_str)

        tid = threading.current_thread().ident

        # create client based on client_type
        self.init_client(client_type, self._client_config,
                         self._server_endpoints, self._fetch_names)

        self._is_run = True
        while self._is_run:
            self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
            channeldata = input_channel.front(self.name)
            self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))
            logging.debug(log("input_data: {}".format(channeldata)))

            data_id, error_channeldata = self._parse_channeldata(channeldata)

            # error data in predecessor Op: forward it unchanged
            if error_channeldata is not None:
                self._push_to_output_channels(error_channeldata,
                                              output_channels)
                continue

            # preprocess
            try:
                self._profiler_record("{}-prep#{}_0".format(op_info_prefix,
                                                            tid))
                preped_data = self.preprocess(channeldata)
                self._profiler_record("{}-prep#{}_1".format(op_info_prefix,
                                                            tid))
            except NotImplementedError as e:
                # preprocess function not implemented
                self._push_error(ChannelDataEcode.NOT_IMPLEMENTED.value,
                                 log(e), data_id, output_channels)
                continue
            except TypeError as e:
                # Error type in channeldata.datatype
                self._push_error(ChannelDataEcode.TYPE_ERROR.value, log(e),
                                 data_id, output_channels)
                continue
            except Exception as e:
                self._push_error(ChannelDataEcode.UNKNOW.value, log(e),
                                 data_id, output_channels)
                continue

            # process
            if self.with_serving:
                self._profiler_record("{}-midp#{}_0".format(op_info_prefix,
                                                            tid))
                midped_data, ecode, error_info = self._call_process(
                    preped_data, log)
                if ecode != ChannelDataEcode.OK.value:
                    self._push_error(ecode, error_info, data_id,
                                     output_channels)
                    continue
                self._profiler_record("{}-midp#{}_1".format(op_info_prefix,
                                                            tid))
            else:
                midped_data = preped_data

            # postprocess
            self._profiler_record("{}-postp#{}_0".format(op_info_prefix, tid))
            try:
                postped_data = self.postprocess(midped_data)
            except Exception as e:
                self._push_error(ChannelDataEcode.UNKNOW.value, log(e),
                                 data_id, output_channels)
                continue
            if not isinstance(postped_data, dict):
                error_info = log("output of postprocess function must be "
                                 "dict type, but get {}".format(
                                     type(postped_data)))
                self._push_error(ChannelDataEcode.TYPE_ERROR.value,
                                 error_info, data_id, output_channels)
                continue

            # Prefer the numpy representation; fall back to a plain dict.
            err, _ = ChannelData.check_npdata(postped_data)
            if err == 0:
                output_data = ChannelData(
                    ChannelDataType.CHANNEL_NPDATA.value,
                    npdata=postped_data,
                    data_id=data_id)
            else:
                output_data = ChannelData(
                    ChannelDataType.DICT.value,
                    dictdata=postped_data,
                    data_id=data_id)
            self._profiler_record("{}-postp#{}_1".format(op_info_prefix, tid))

            # push data to channel (if run succ)
            self._profiler_record("{}-push#{}_0".format(op_info_prefix, tid))
            self._push_to_output_channels(output_data, output_channels)
            self._profiler_record("{}-push#{}_1".format(op_info_prefix, tid))

    def _log(self, info):
        """Prefix ``info`` with this Op's name."""
        return "{} {}".format(self.name, info)


B
barrierye 已提交
372 373 374 375 376
class ReadOp(Op):
    """Op with no predecessor Ops, always named "#G"."""

    def __init__(self, concurrency=1):
        # "#G" mirrors PipelineService.name; there are no input Ops.
        init_kwargs = dict(name="#G", input_ops=[], concurrency=concurrency)
        super(ReadOp, self).__init__(**init_kwargs)
377 378 379 380 381 382 383


class VirtualOp(Op):
    """Pass-through Op that connects two channels.

    Data read from the input channel is pushed unchanged to the output
    channels. The Ops registered via ``add_virtual_pred_op`` (not this Op)
    are registered as the producers of each output channel.
    """

    def __init__(self, name, concurrency=1):
        super(VirtualOp, self).__init__(
            name=name, input_ops=None, concurrency=concurrency)
        self._virtual_pred_ops = []

    def add_virtual_pred_op(self, op):
        """Record ``op`` as a predecessor whose data this Op forwards."""
        self._virtual_pred_ops.append(op)

    def add_output_channel(self, channel):
        """Register every virtual predecessor as a producer of ``channel``.

        Raises:
            TypeError: if ``channel`` is not a ThreadChannel/ProcessChannel.
        """
        if not isinstance(channel, (ThreadChannel, ProcessChannel)):
            raise TypeError(
                self._log('output channel must be Channel type, not {}'.format(
                    type(channel))))
        for op in self._virtual_pred_ops:
            channel.add_producer(op.name)
        self._outputs.append(channel)

    def _run(self, concurrency_idx, input_channel, output_channels,
             client_type):
        """Worker loop: forward every input item to the output channels.

        ``client_type`` is unused; it is kept for signature compatibility
        with ``Op._run``.
        """
        op_info_prefix = "[{}|{}]".format(self.name, concurrency_idx)
        tid = threading.current_thread().ident

        self._is_run = True
        while self._is_run:
            self._profiler_record("{}-get#{}_0".format(op_info_prefix, tid))
            channeldata = input_channel.front(self.name)
            self._profiler_record("{}-get#{}_1".format(op_info_prefix, tid))

            self._profiler_record("{}-push#{}_0".format(op_info_prefix, tid))
            if isinstance(channeldata, dict):
                # Several predecessors: push each item under its producer's
                # name, as registered in add_output_channel.
                for name, data in channeldata.items():
                    self._push_to_output_channels(
                        data, channels=output_channels, name=name)
            else:
                # Single item: attribute it to the first virtual predecessor
                # (raises IndexError if none was added).
                self._push_to_output_channels(
                    channeldata,
                    channels=output_channels,
                    name=self._virtual_pred_ops[0].name)
            self._profiler_record("{}-push#{}_1".format(op_info_prefix, tid))