# How to write a general operator?

([简体中文](./NEW_OPERATOR_CN.md)|English)

This document explains how to develop a new server-side operator for PaddleServing. Before writing one, let's look at some sample code to get the basic idea. We assume you already know the basic computation logic on the server side of PaddleServing; please refer to []() if you are not familiar with it. The following code can be found in `core/general-server/op` of the Serving repo.

``` c++
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
#ifdef BCLOUD
#ifdef WITH_GPU
#include "paddle/paddle_inference_api.h"
#else
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#endif
#else
#include "paddle_inference_api.h"  // NOLINT
#endif
#include "core/general-server/general_model_service.pb.h"
#include "core/general-server/op/general_infer_helper.h"

namespace baidu {
namespace paddle_serving {
namespace serving {

class GeneralInferOp
    : public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
 public:
  typedef std::vector<paddle::PaddleTensor> TensorVector;

  DECLARE_OP(GeneralInferOp);

  int inference();

};

}  // namespace serving
}  // namespace paddle_serving
}  // namespace baidu
```

## Define an operator

The header file above declares a PaddleServing operator called `GeneralInferOp`. At runtime, the function `int inference()` will be called. A server-side operator is usually defined as a subclass of `baidu::paddle_serving::predictor::OpWithChannel`, parameterized with the `GeneralBlob` data structure.
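The subclassing pattern can be illustrated without the Serving headers. The sketch below is a minimal, self-contained approximation: `FakeOpWithChannel`, `FakeBlob`, `FakeTensor`, `HelloOp`, and `RunHelloOp` are hypothetical stand-ins invented for illustration, not the real PaddleServing classes, and the real base class offers more than what is shown here.

```cpp
#include <cassert>
#include <string>
#include <vector>

namespace op_sketch {

// Stand-in for paddle::PaddleTensor (assumption: only the fields used here).
struct FakeTensor {
  std::string name;
  std::vector<int> shape;
};

// Stand-in for GeneralBlob: the payload an operator produces.
struct FakeBlob {
  std::vector<FakeTensor> tensor_vector;
};

// Stand-in for predictor::OpWithChannel<T>: it owns the operator's output
// and requires subclasses to implement inference().
template <typename T>
class FakeOpWithChannel {
 public:
  virtual ~FakeOpWithChannel() = default;
  virtual int inference() = 0;
  T* mutable_data() { return &data_; }
  const T& data() const { return data_; }

 private:
  T data_;
};

// Hypothetical operator that emits a single named tensor.
class HelloOp : public FakeOpWithChannel<FakeBlob> {
 public:
  int inference() override {
    FakeBlob* out = mutable_data();
    out->tensor_vector.push_back(FakeTensor{"hello", {1, 1}});
    return 0;  // 0 on success, -1 on failure, as in GeneralInferOp
  }
};

// Runs the operator once and reports how many tensors it produced.
int RunHelloOp() {
  HelloOp op;
  if (op.inference() != 0) return -1;
  return static_cast<int>(op.data().tensor_vector.size());
}

}  // namespace op_sketch
```

In the real framework the operator is invoked by the server rather than by a helper such as `RunHelloOp`; the helper only exists so the sketch can be exercised on its own.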

## Use `GeneralBlob` between operators

`GeneralBlob` is a data structure that is passed between server-side operators. Its most important member is `tensor_vector`. A server-side operator can take multiple `paddle::PaddleTensor` objects as inputs and produce multiple `paddle::PaddleTensor` objects as outputs. In particular, `tensor_vector` can be fed into the Paddle inference engine directly with zero copy.

``` c++
struct GeneralBlob {
  std::vector<paddle::PaddleTensor> tensor_vector;
  int64_t time_stamp[20];
  int p_size = 0;

  int _batch_size;

  void Clear() {
    size_t tensor_count = tensor_vector.size();
    for (size_t ti = 0; ti < tensor_count; ++ti) {
      tensor_vector[ti].shape.clear();
    }
    tensor_vector.clear();
  }

  // Returns 0 so that the non-void function always returns a value.
  int SetBatchSize(int batch_size) {
    _batch_size = batch_size;
    return 0;
  }

  int GetBatchSize() const { return _batch_size; }
  std::string ShortDebugString() const { return "Not implemented!"; }
};
```
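To make the hand-off between operators concrete, here is a minimal sketch of one operator filling a blob and another reading it. It uses a trimmed copy of the structure above and a stand-in `Tensor` type so that it compiles without the Paddle headers; `Tensor`, `Blob`, `FillBlob`, `CountElements`, and the tensor names `words` and `label` are all hypothetical and chosen only for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

namespace blob_sketch {

// Stand-in for paddle::PaddleTensor with only the fields this sketch touches.
struct Tensor {
  std::string name;
  std::vector<int> shape;
};

// Trimmed copy of GeneralBlob built on the stand-in tensor type.
struct Blob {
  std::vector<Tensor> tensor_vector;
  int _batch_size = 0;

  void Clear() { tensor_vector.clear(); }
  int SetBatchSize(int batch_size) {
    _batch_size = batch_size;
    return 0;
  }
  int GetBatchSize() const { return _batch_size; }
};

// Hypothetical upstream operator: writes two tensors for one batch.
void FillBlob(Blob* blob, int batch_size) {
  blob->Clear();
  blob->tensor_vector.push_back(Tensor{"words", {batch_size, 128}});
  blob->tensor_vector.push_back(Tensor{"label", {batch_size, 1}});
  blob->SetBatchSize(batch_size);
}

// Hypothetical downstream operator: sums the element counts of all tensors.
std::int64_t CountElements(const Blob& blob) {
  std::int64_t total = 0;
  for (const Tensor& t : blob.tensor_vector) {
    std::int64_t n = 1;
    for (int d : t.shape) n *= d;
    total += n;
  }
  return total;
}

}  // namespace blob_sketch
```

The downstream function only takes a `const` reference, which mirrors how `GeneralInferOp` reads its input blob through a `const` pointer and writes to a separate output blob.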

## Implement `int inference()`

``` c++
int GeneralInferOp::inference() {
  VLOG(2) << "Going to run inference";
  const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
  VLOG(2) << "Get precedent op name: " << pre_name();
  GeneralBlob *output_blob = mutable_data<GeneralBlob>();

  if (!input_blob) {
    LOG(ERROR) << "Failed mutable depended argument, op:" << pre_name();
    return -1;
  }

  const TensorVector *in = &input_blob->tensor_vector;
  TensorVector *out = &output_blob->tensor_vector;
  int batch_size = input_blob->GetBatchSize();
  VLOG(2) << "input batch size: " << batch_size;

  output_blob->SetBatchSize(batch_size);

  VLOG(2) << "infer batch size: " << batch_size;

  Timer timeline;
  int64_t start = timeline.TimeStampUS();
  timeline.Start();

  if (InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)) {
    LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME;
    return -1;
  }

  int64_t end = timeline.TimeStampUS();
  CopyBlobInfo(input_blob, output_blob);
  AddBlobInfo(output_blob, start);
  AddBlobInfo(output_blob, end);
  return 0;
}
DEFINE_OP(GeneralInferOp);
```

`input_blob` and `output_blob` each hold multiple `paddle::PaddleTensor` objects, and the Paddle Inference library is called through `InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)`. Most of the other code in this function is for profiling; we may remove the redundant code in the future.
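The profiling part follows a simple pattern: take a timestamp before and after the work, carry over the timestamps recorded by the previous operator, then append the two new ones. The sketch below shows that pattern in isolation. The bodies of `CopyBlobInfo` and `AddBlobInfo` are not shown in this document, so `CopyTimestamps` and `AppendTimestamp` are hypothetical equivalents written under the assumption that the helpers operate on `time_stamp` and `p_size`; `NowMicros` and `TimedStage` are likewise invented names.

```cpp
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstring>

namespace profiling_sketch {

// Only the profiling fields of GeneralBlob are reproduced here.
struct Blob {
  std::int64_t time_stamp[20] = {};
  int p_size = 0;  // number of valid entries in time_stamp
};

// Monotonic clock reading in microseconds.
std::int64_t NowMicros() {
  using namespace std::chrono;
  return duration_cast<microseconds>(steady_clock::now().time_since_epoch())
      .count();
}

// Assumed behavior of CopyBlobInfo: carry the upstream timestamps forward.
void CopyTimestamps(const Blob* src, Blob* dst) {
  std::memcpy(dst->time_stamp, src->time_stamp,
              src->p_size * sizeof(std::int64_t));
  dst->p_size = src->p_size;
}

// Assumed behavior of AddBlobInfo: append one timestamp.
// Returns false instead of writing past the end of the array.
bool AppendTimestamp(Blob* blob, std::int64_t value) {
  if (blob->p_size >= 20) return false;
  blob->time_stamp[blob->p_size++] = value;
  return true;
}

// Wraps a unit of work with the same bookkeeping order as inference().
template <typename Work>
void TimedStage(const Blob* in, Blob* out, Work work) {
  std::int64_t start = NowMicros();
  work();
  std::int64_t end = NowMicros();
  CopyTimestamps(in, out);
  AppendTimestamp(out, start);
  AppendTimestamp(out, end);
}

}  // namespace profiling_sketch
```

With this layout each operator adds two entries, so the fixed array of 20 slots bounds how many timed stages a request can record; the bounds check in `AppendTimestamp` is a choice made for this sketch.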

The code above is essentially all that is needed to implement a new operator. If your operator needs to access a dictionary resource, refer to `core/predictor/framework/resource.cpp` to see how to add globally visible resources. Resources are initialized when the server starts.
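As a rough illustration of a globally visible resource that is initialized once at startup and then read by operators, consider the sketch below. It does not reproduce the contents of `resource.cpp`; `DictResource`, `initialize`, `lookup`, and `WordToId` are hypothetical names, and the singleton shown is just one common way to get this behavior.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>

namespace resource_sketch {

// Hypothetical process-wide dictionary resource.
class DictResource {
 public:
  // Single shared instance, constructed on first use.
  static DictResource& instance() {
    static DictResource r;
    return r;
  }

  // Intended to be called once while the server is starting.
  void initialize(std::unordered_map<std::string, int> entries) {
    dict_ = std::move(entries);
    ready_ = true;
  }

  bool ready() const { return ready_; }

  // Returns the stored id, or `fallback` when the key is absent.
  int lookup(const std::string& key, int fallback) const {
    auto it = dict_.find(key);
    return it == dict_.end() ? fallback : it->second;
  }

 private:
  DictResource() = default;
  std::unordered_map<std::string, int> dict_;
  bool ready_ = false;
};

// Hypothetical operator-side helper that reads the shared dictionary.
int WordToId(const std::string& word) {
  return DictResource::instance().lookup(word, -1);
}

}  // namespace resource_sketch
```

Because the dictionary is filled before any request is served and only read afterwards, operators in this sketch can call `lookup` without extra locking; a resource that changes while the server is running would need synchronization.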

## Define Python API

After you have defined a C++ operator on the server side of Paddle Serving, the last step is to register it in the Python API of the PaddleServing server. The relevant code is in `python/paddle_serving_server/__init__.py` in the repo: add an entry that maps the operator's Python-side name to its C++ class name.

``` python
self.op_dict = {
            "general_infer": "GeneralInferOp",
            "general_reader": "GeneralReaderOp",
            "general_response": "GeneralResponseOp",
            "general_text_reader": "GeneralTextReaderOp",
            "general_text_response": "GeneralTextResponseOp",
            "general_single_kv": "GeneralSingleKVOp",
            "general_dist_kv": "GeneralDistKVOp"
        }
```