// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "examples/demo-serving/op/classify_op.h"
#include "examples/demo-serving/op/reader_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"

namespace baidu {
namespace paddle_serving {
namespace serving {

using baidu::paddle_serving::predictor::format::DensePrediction;
using baidu::paddle_serving::predictor::image_classification::ClassifyResponse;
using baidu::paddle_serving::predictor::InferManager;

int ClassifyOp::inference() {
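  // Fetch the tensors produced by the upstream image_reader_op.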
  const ReaderOutput* reader_out =
      get_depend_argument<ReaderOutput>("image_reader_op");
  if (!reader_out) {
    LOG(ERROR) << "Failed mutable depended argument, op:"
               << "reader_op";
    return -1;
  }

  const TensorVector* in = &reader_out->tensors;

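  // Borrow a TensorVector from the thread-local object pool to hold the
  // model output; it is handed back via butil::return_object() below.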
  TensorVector* out = butil::get_object<TensorVector>();
  if (!out) {
    LOG(ERROR) << "Failed get tls output object failed";
    return -1;
  }

  if (in->size() != 1) {
    LOG(ERROR) << "Samples should have been packed into a single tensor";
    return -1;
  }

  int batch_size = in->at(0).shape[0];
  // Call the Paddle Fluid model to run inference on this batch.
  if (InferManager::instance().infer(
          IMAGE_CLASSIFICATION_MODEL_NAME, in, out, batch_size)) {
    LOG(ERROR) << "Failed do infer in fluid model: "
               << IMAGE_CLASSIFICATION_MODEL_NAME;
    return -1;
  }

  if (out->size() != in->size()) {
    LOG(ERROR) << "Output size not eq input size: " << in->size()
               << out->size();
    return -1;
  }

  // copy output tensor into response
  ClassifyResponse* res = mutable_data<ClassifyResponse>();
  const paddle::PaddleTensor& out_tensor = (*out)[0];

#if 0
  int out_shape_size = out_tensor.shape.size();
  LOG(ERROR) << "out_tensor.shpae";
  for (int i = 0; i < out_shape_size; ++i) {
    LOG(ERROR) << out_tensor.shape[i] << ":";
  }

  if (out_shape_size != 2) {
    return -1;
  }
#endif

  uint32_t sample_size = out_tensor.shape[0];
#if 0
  LOG(ERROR) << "Output sample size " << sample_size;
#endif
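  // The output tensor is laid out as [sample_size, data_size]: one row of
  // category scores per input sample.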
  for (uint32_t si = 0; si < sample_size; si++) {
    DensePrediction* ins = res->add_predictions();
    if (!ins) {
      LOG(ERROR) << "Failed append new out tensor";
      return -1;
    }

    // assign output data
    uint32_t data_size = out_tensor.shape[1];
    float* data = reinterpret_cast<float*>(out_tensor.data.data()) +
                  si * data_size;
    for (uint32_t di = 0; di < data_size; ++di) {
      ins->add_categories(data[di]);
    }
  }

  // Release output tensors and return the TensorVector to the object pool.
  size_t out_size = out->size();
  for (size_t oi = 0; oi < out_size; ++oi) {
    (*out)[oi].shape.clear();
  }
  out->clear();
  butil::return_object<TensorVector>(out);

  return 0;
}

DEFINE_OP(ClassifyOp);

}  // namespace serving
}  // namespace paddle_serving
}  // namespace baidu