general_copy_op.cpp 3.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

G
guru4elephant 已提交
15
#include "core/general-server/op/general_copy_op.h"
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/util/include/timer.h"

namespace baidu {
namespace paddle_serving {
namespace serving {

using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FeedInst;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;

G
guru4elephant 已提交
36
int GeneralCopyOp::inference() {
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
  // reade request from client
  const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
  VLOG(2) << "precedent name: " << pre_name();
  const TensorVector *in = &input_blob->tensor_vector;
  VLOG(2) << "input size: " << in->size();
  int batch_size = input_blob->GetBatchSize();
  int input_var_num = 0;

  GeneralBlob *res = mutable_data<GeneralBlob>();
  TensorVector *out = &res->tensor_vector;

  VLOG(2) << "input batch size: " << batch_size;
  res->SetBatchSize(batch_size);

  if (!res) {
    LOG(ERROR) << "Failed get op tls reader object output";
  }

  Timer timeline;
  int64_t start = timeline.TimeStampUS();

  VLOG(2) << "Going to init lod tensor";
  for (int i = 0; i < in->size(); ++i) {
    paddle::PaddleTensor lod_tensor;
    CopyLod(&in->at(i), &lod_tensor);
    lod_tensor.dtype = in->at(i).dtype;
    lod_tensor.name = in->at(i).name;
    VLOG(2) << "lod tensor [" << i << "].name = " << lod_tensor.name;
    out->push_back(lod_tensor);
  }

  VLOG(2) << "pack done.";

  for (int i = 0; i < out->size(); ++i) {
    int64_t *src_ptr = static_cast<int64_t *>(in->at(i).data.data());
B
barrierye 已提交
72
    out->at(i).data.Resize(out->at(i).lod[0].back() * sizeof(int64_t));
73 74 75 76 77 78 79 80 81 82 83
    out->at(i).shape = {out->at(i).lod[0].back(), 1};
    int64_t *tgt_ptr = static_cast<int64_t *>(out->at(i).data.data());
    for (int j = 0; j < out->at(i).lod[0].back(); ++j) {
      tgt_ptr[j] = src_ptr[j];
    }
  }

  VLOG(2) << "output done.";

  timeline.Pause();
  int64_t end = timeline.TimeStampUS();
G
guru4elephant 已提交
84
  CopyBlobInfo(input_blob, res);
85 86 87 88 89 90
  AddBlobInfo(res, start);
  AddBlobInfo(res, end);

  VLOG(2) << "read data from client success";
  return 0;
}
G
guru4elephant 已提交
91 92

DEFINE_OP(GeneralCopyOp);
93 94 95
}  // namespace serving
}  // namespace paddle_serving
}  // namespace baidu