From 1283833395645c8d52d7b603c2e8bc3092d4ef12 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Mon, 11 Mar 2019 15:18:32 +0800 Subject: [PATCH] zero_copy tensor support INT32 test=develop --- .../fluid/inference/api/details/zero_copy_tensor.cc | 5 +++++ paddle/fluid/inference/tests/api/tester_helper.h | 11 +++++++++++ 2 files changed, 16 insertions(+) diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index cf02901d963..9a40cf4b60a 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -126,15 +126,20 @@ void ZeroCopyTensor::copy_to_cpu(T *data) { } template void ZeroCopyTensor::copy_from_cpu(const float *data); template void ZeroCopyTensor::copy_from_cpu(const int64_t *data); +template void ZeroCopyTensor::copy_from_cpu(const int32_t *data); template void ZeroCopyTensor::copy_to_cpu(float *data); template void ZeroCopyTensor::copy_to_cpu(int64_t *data); +template void ZeroCopyTensor::copy_to_cpu(int32_t *data); template float *ZeroCopyTensor::data(PaddlePlace *place, int *size) const; template int64_t *ZeroCopyTensor::data(PaddlePlace *place, int *size) const; +template int32_t *ZeroCopyTensor::data(PaddlePlace *place, + int *size) const; template float *ZeroCopyTensor::mutable_data(PaddlePlace place); template int64_t *ZeroCopyTensor::mutable_data(PaddlePlace place); +template int32_t *ZeroCopyTensor::mutable_data(PaddlePlace place); void *ZeroCopyTensor::FindTensor() const { PADDLE_ENFORCE(!name_.empty(), diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 915ea772ed0..a4881afe58a 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -141,6 +141,15 @@ void CompareResult(const std::vector &outputs, } break; } + case PaddleDType::INT32: { + int32_t *pdata = static_cast(out.data.data()); + int32_t *pdata_ref = ref_out.data(&place, &ref_size); + EXPECT_EQ(size, ref_size); + for (size_t j = 0; j < size; ++j) { + EXPECT_EQ(pdata_ref[j], pdata[j]); + } + break; + } } } } @@ -253,6 +262,8 @@ void ConvertPaddleTensorToZeroCopyTensor( ZeroCopyTensorAssignData(tensor.get(), input.data); } else if (input.dtype == PaddleDType::FLOAT32) { ZeroCopyTensorAssignData(tensor.get(), input.data); + } else if (input.dtype == PaddleDType::INT32) { + ZeroCopyTensorAssignData(tensor.get(), input.data); } else { LOG(ERROR) << "unsupported feed type " << input.dtype; } -- GitLab