// general_response_op.cpp
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "core/general-server/op/general_response_op.h"

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"

namespace baidu {
namespace paddle_serving {
namespace serving {

using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FetchInst;
B
barrierye 已提交
38
using baidu::paddle_serving::predictor::general_model::ModelOutput;
39 40 41
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;

42
int GeneralResponseOp::inference() {
H
HexToString 已提交
43 44 45 46 47 48
  const std::vector<std::string> pre_node_names = pre_names();
  VLOG(2) << "pre node names size: " << pre_node_names.size();
  const GeneralBlob *input_blob;
  uint64_t log_id =
      get_depend_argument<GeneralBlob>(pre_node_names[0])->GetLogId();

W
wangjiawei04 已提交
49 50 51
  const Request *req = dynamic_cast<const Request *>(get_request_message());
  // response inst with only fetch_var_names
  Response *res = mutable_data<Response>();
H
HexToString 已提交
52 53 54 55 56 57 58 59

  Timer timeline;
  // double response_time = 0.0;
  // timeline.Start();
  int64_t start = timeline.TimeStampUS();

  VLOG(2) << "(logid=" << log_id
          << ") start to call load general model_conf op";
W
wangjiawei04 已提交
60 61
  baidu::paddle_serving::predictor::Resource &resource =
      baidu::paddle_serving::predictor::Resource::instance();
H
HexToString 已提交
62 63

  VLOG(2) << "(logid=" << log_id << ") get resource pointer done.";
W
wangjiawei04 已提交
64 65
  std::shared_ptr<PaddleGeneralModelConfig> model_config =
      resource.get_general_model_config();
H
HexToString 已提交
66 67 68 69 70 71

  VLOG(2) << "(logid=" << log_id
          << ") max body size : " << brpc::fLU64::FLAGS_max_body_size;

  std::vector<int> fetch_index;
  fetch_index.resize(req->fetch_var_names_size());
W
wangjiawei04 已提交
72
  for (int i = 0; i < req->fetch_var_names_size(); ++i) {
H
HexToString 已提交
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
    fetch_index[i] =
        model_config->_fetch_alias_name_to_index[req->fetch_var_names(i)];
  }

  for (uint32_t pi = 0; pi < pre_node_names.size(); ++pi) {
    const std::string &pre_name = pre_node_names[pi];
    VLOG(2) << "(logid=" << log_id << ") pre names[" << pi << "]: " << pre_name
            << " (" << pre_node_names.size() << ")";
    input_blob = get_depend_argument<GeneralBlob>(pre_name);
    // fprintf(stderr, "input(%s) blob address %x\n", pre_names.c_str(),
    // input_blob);
    if (!input_blob) {
      LOG(ERROR) << "(logid=" << log_id
                 << ") Failed mutable depended argument, op: " << pre_name;
      return -1;
W
wangjiawei04 已提交
88
    }
H
HexToString 已提交
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115

    const TensorVector *in = &input_blob->tensor_vector;

    ModelOutput *output = res->add_outputs();
    // To get the order of model return values
    output->set_engine_name(pre_name);
    FetchInst *fetch_inst = output->add_insts();

    for (auto &idx : fetch_index) {
      Tensor *tensor = fetch_inst->add_tensor_array();
      if (model_config->_is_lod_fetch[idx]) {
        VLOG(2) << "(logid=" << log_id << ") out[" << idx << "] "
                << model_config->_fetch_name[idx] << " is lod_tensor";
        for (int k = 0; k < in->at(idx).shape.size(); ++k) {
          VLOG(2) << "(logid=" << log_id << ") shape[" << k
                  << "]: " << in->at(idx).shape[k];
          tensor->add_shape(in->at(idx).shape[k]);
        }
      } else {
        VLOG(2) << "(logid=" << log_id << ") out[" << idx << "] "
                << model_config->_fetch_name[idx] << " is tensor";
        for (int k = 0; k < in->at(idx).shape.size(); ++k) {
          VLOG(2) << "(logid=" << log_id << ") shape[" << k
                  << "]: " << in->at(idx).shape[k];
          tensor->add_shape(in->at(idx).shape[k]);
        }
      }
W
wangjiawei04 已提交
116
    }
H
HexToString 已提交
117 118 119 120 121 122

    int var_idx = 0;
    for (auto &idx : fetch_index) {
      int cap = 1;
      for (int j = 0; j < in->at(idx).shape.size(); ++j) {
        cap *= in->at(idx).shape[j];
W
wangjiawei04 已提交
123
      }
H
HexToString 已提交
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141

      FetchInst *fetch_p = output->mutable_insts(0);
      auto dtype = in->at(idx).dtype;

      if (dtype == paddle::PaddleDType::INT64) {
        VLOG(2) << "(logid=" << log_id << ") Prepare int64 var ["
                << model_config->_fetch_name[idx] << "].";
        int64_t *data_ptr = static_cast<int64_t *>(in->at(idx).data.data());
        // from
        // https://stackoverflow.com/questions/15499641/copy-a-stdvector-to-a-repeated-field-from-protobuf-with-memcpy
        // `Swap` method is faster than `{}` method.
        google::protobuf::RepeatedField<int64_t> tmp_data(data_ptr,
                                                          data_ptr + cap);
        fetch_p->mutable_tensor_array(var_idx)->mutable_int64_data()->Swap(
            &tmp_data);
      } else if (dtype == paddle::PaddleDType::FLOAT32) {
        VLOG(2) << "(logid=" << log_id << ") Prepare float var ["
                << model_config->_fetch_name[idx] << "].";
H
HexToString 已提交
142
        
H
HexToString 已提交
143
        float *data_ptr = static_cast<float *>(in->at(idx).data.data());
H
HexToString 已提交
144 145 146 147
        std::cout<<" response op ---- for"<<std::endl;
        for(int k =0; k<cap; ++k){
          std::cout<< "i am ysl -response op-copy idx = "<< k<< "num = "<< *(data_ptr+k)<<std::endl;
        }
H
HexToString 已提交
148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
        google::protobuf::RepeatedField<float> tmp_data(data_ptr,
                                                        data_ptr + cap);
        fetch_p->mutable_tensor_array(var_idx)->mutable_float_data()->Swap(
            &tmp_data);
      } else if (dtype == paddle::PaddleDType::INT32) {
        VLOG(2) << "(logid=" << log_id << ")Prepare int32 var ["
                << model_config->_fetch_name[idx] << "].";
        int32_t *data_ptr = static_cast<int32_t *>(in->at(idx).data.data());
        google::protobuf::RepeatedField<int32_t> tmp_data(data_ptr,
                                                          data_ptr + cap);
        fetch_p->mutable_tensor_array(var_idx)->mutable_int_data()->Swap(
            &tmp_data);
      }

      if (model_config->_is_lod_fetch[idx]) {
        if (in->at(idx).lod.size() > 0) {
          for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
            fetch_p->mutable_tensor_array(var_idx)->add_lod(
                in->at(idx).lod[0][j]);
          }
        }
      }

      VLOG(2) << "(logid=" << log_id << ") fetch var ["
              << model_config->_fetch_name[idx] << "] ready";
      var_idx++;
W
wangjiawei04 已提交
174 175
    }
  }
H
HexToString 已提交
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200

  if (req->profile_server()) {
    int64_t end = timeline.TimeStampUS();
    // TODO(barriery): multi-model profile_time.
    // At present, only the response_op is multi-input, so here we get
    // the profile_time by hard coding. It needs to be replaced with
    // a more elegant way.
    for (uint32_t pi = 0; pi < pre_node_names.size(); ++pi) {
      input_blob = get_depend_argument<GeneralBlob>(pre_node_names[pi]);
      VLOG(2) << "(logid=" << log_id
              << ") p size for input blob: " << input_blob->p_size;
      int profile_time_idx = -1;
      if (pi == 0) {
        profile_time_idx = 0;
      } else {
        profile_time_idx = input_blob->p_size - 2;
      }
      for (; profile_time_idx < input_blob->p_size; ++profile_time_idx) {
        res->add_profile_time(input_blob->time_stamp[profile_time_idx]);
      }
    }
    // TODO(guru4elephant): find more elegant way to do this
    res->add_profile_time(start);
    res->add_profile_time(end);
  }
H
HexToString 已提交
201 202 203
  std::cout << "GeneralResponseOp    ---ysl" << std::endl;
  LOG(ERROR) << "GeneralResponseOp    ---ysl";
  
204 205
  return 0;
}
206 207

DEFINE_OP(GeneralResponseOp);
208 209 210

}  // namespace serving
}  // namespace paddle_serving
H
HexToString 已提交
211
}  // namespace baidu