general_copy_op.cpp 3.7 KB
Newer Older
B
barrierye 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "core/general-server/op/general_copy_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/util/include/timer.h"

namespace baidu {
namespace paddle_serving {
namespace serving {

using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FeedInst;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;

int GeneralCopyOp::inference() {
  // reade request from client
  const std::vector<std::string> pre_node_names = pre_names();
  if (pre_node_names.size() != 1) {
B
barrierye 已提交
40 41 42
    LOG(ERROR) << "This op(" << op_name()
               << ") can only have one predecessor op, but received "
               << pre_node_names.size();
B
barrierye 已提交
43 44 45 46 47
    return -1;
  }
  const std::string pre_name = pre_node_names[0];

  const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
B
barriery 已提交
48 49 50
  uint64_t log_id = input_blob->GetLogId();

  VLOG(2) << "(logid=" << log_id << ") precedent name: " << pre_name;
B
barrierye 已提交
51
  const TensorVector *in = &input_blob->tensor_vector;
B
barriery 已提交
52
  VLOG(2) << "(logid=" << log_id << ") input size: " << in->size();
B
barrierye 已提交
53 54 55 56
  int batch_size = input_blob->GetBatchSize();
  int input_var_num = 0;

  GeneralBlob *res = mutable_data<GeneralBlob>();
B
barriery 已提交
57
  res->SetLogId(log_id);
B
barrierye 已提交
58 59
  TensorVector *out = &res->tensor_vector;

B
barriery 已提交
60
  VLOG(2) << "(logid=" << log_id << ") input batch size: " << batch_size;
B
barrierye 已提交
61 62 63
  res->SetBatchSize(batch_size);

  if (!res) {
B
barriery 已提交
64 65
    LOG(ERROR) << "(logid=" << log_id
               << ") Failed get op tls reader object output";
B
barrierye 已提交
66 67 68 69 70
  }

  Timer timeline;
  int64_t start = timeline.TimeStampUS();

B
barriery 已提交
71
  VLOG(2) << "(logid=" << log_id << ") Going to init lod tensor";
B
barrierye 已提交
72 73 74 75 76
  for (int i = 0; i < in->size(); ++i) {
    paddle::PaddleTensor lod_tensor;
    CopyLod(&in->at(i), &lod_tensor);
    lod_tensor.dtype = in->at(i).dtype;
    lod_tensor.name = in->at(i).name;
B
barriery 已提交
77 78
    VLOG(2) << "(logid=" << log_id << ") lod tensor [" << i
            << "].name = " << lod_tensor.name;
B
barrierye 已提交
79 80 81
    out->push_back(lod_tensor);
  }

B
barriery 已提交
82
  VLOG(2) << "(logid=" << log_id << ") pack done.";
B
barrierye 已提交
83 84 85 86 87 88 89 90 91 92 93

  for (int i = 0; i < out->size(); ++i) {
    int64_t *src_ptr = static_cast<int64_t *>(in->at(i).data.data());
    out->at(i).data.Resize(out->at(i).lod[0].back() * sizeof(int64_t));
    out->at(i).shape = {out->at(i).lod[0].back(), 1};
    int64_t *tgt_ptr = static_cast<int64_t *>(out->at(i).data.data());
    for (int j = 0; j < out->at(i).lod[0].back(); ++j) {
      tgt_ptr[j] = src_ptr[j];
    }
  }

B
barriery 已提交
94
  VLOG(2) << "(logid=" << log_id << ") output done.";
B
barrierye 已提交
95 96 97 98 99 100 101

  timeline.Pause();
  int64_t end = timeline.TimeStampUS();
  CopyBlobInfo(input_blob, res);
  AddBlobInfo(res, start);
  AddBlobInfo(res, end);

B
barriery 已提交
102
  VLOG(2) << "(logid=" << log_id << ") read data from client success";
B
barrierye 已提交
103 104 105 106 107 108 109
  return 0;
}

// NOTE(review): DEFINE_OP is a macro declared outside this file; presumably it
// registers GeneralCopyOp with the serving framework's op registry so it can
// be referenced by name in a workflow — confirm against the macro definition.
DEFINE_OP(GeneralCopyOp);
}  // namespace serving
}  // namespace paddle_serving
}  // namespace baidu