未验证 提交 68398abc 编写于 作者: C cc 提交者: GitHub

[Inference] zero_copy_tensor supports int8_t (#30053)

* zero_copy_tensor supports int8_t
上级 1b999d2b
...@@ -165,10 +165,14 @@ template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<int32_t>( ...@@ -165,10 +165,14 @@ template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<int32_t>(
const int32_t *data); const int32_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<uint8_t>( template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<uint8_t>(
const uint8_t *data); const uint8_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<int8_t>(
const int8_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<float>(float *data); template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<float>(float *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int64_t>(int64_t *data); template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int64_t>(int64_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int32_t>(int32_t *data); template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int32_t>(int32_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<uint8_t>(uint8_t *data); template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<uint8_t>(uint8_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int8_t>(int8_t *data);
template PD_INFER_DECL float *ZeroCopyTensor::data<float>(PaddlePlace *place, template PD_INFER_DECL float *ZeroCopyTensor::data<float>(PaddlePlace *place,
int *size) const; int *size) const;
...@@ -178,6 +182,9 @@ template PD_INFER_DECL int32_t *ZeroCopyTensor::data<int32_t>( ...@@ -178,6 +182,9 @@ template PD_INFER_DECL int32_t *ZeroCopyTensor::data<int32_t>(
PaddlePlace *place, int *size) const; PaddlePlace *place, int *size) const;
template PD_INFER_DECL uint8_t *ZeroCopyTensor::data<uint8_t>( template PD_INFER_DECL uint8_t *ZeroCopyTensor::data<uint8_t>(
PaddlePlace *place, int *size) const; PaddlePlace *place, int *size) const;
template PD_INFER_DECL int8_t *ZeroCopyTensor::data<int8_t>(PaddlePlace *place,
int *size) const;
template PD_INFER_DECL float *ZeroCopyTensor::mutable_data<float>( template PD_INFER_DECL float *ZeroCopyTensor::mutable_data<float>(
PaddlePlace place); PaddlePlace place);
template PD_INFER_DECL int64_t *ZeroCopyTensor::mutable_data<int64_t>( template PD_INFER_DECL int64_t *ZeroCopyTensor::mutable_data<int64_t>(
...@@ -186,6 +193,8 @@ template PD_INFER_DECL int32_t *ZeroCopyTensor::mutable_data<int32_t>( ...@@ -186,6 +193,8 @@ template PD_INFER_DECL int32_t *ZeroCopyTensor::mutable_data<int32_t>(
PaddlePlace place); PaddlePlace place);
template PD_INFER_DECL uint8_t *ZeroCopyTensor::mutable_data<uint8_t>( template PD_INFER_DECL uint8_t *ZeroCopyTensor::mutable_data<uint8_t>(
PaddlePlace place); PaddlePlace place);
template PD_INFER_DECL int8_t *ZeroCopyTensor::mutable_data<int8_t>(
PaddlePlace place);
void *ZeroCopyTensor::FindTensor() const { void *ZeroCopyTensor::FindTensor() const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
......
...@@ -39,6 +39,7 @@ enum PaddleDType { ...@@ -39,6 +39,7 @@ enum PaddleDType {
INT64, INT64,
INT32, INT32,
UINT8, UINT8,
INT8,
// TODO(Superjomn) support more data types if needed. // TODO(Superjomn) support more data types if needed.
}; };
......
...@@ -257,40 +257,29 @@ void CompareResult(const std::vector<PaddleTensor> &outputs, ...@@ -257,40 +257,29 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
EXPECT_GT(size, 0UL); EXPECT_GT(size, 0UL);
EXPECT_EQ(size, ref_size); EXPECT_EQ(size, ref_size);
EXPECT_EQ(out.dtype, ref_out.dtype); EXPECT_EQ(out.dtype, ref_out.dtype);
switch (out.dtype) {
case PaddleDType::INT64: { #define COMPARE(paddle_type, type, func) \
int64_t *pdata = static_cast<int64_t *>(out.data.data()); case paddle_type: { \
int64_t *pdata_ref = static_cast<int64_t *>(ref_out.data.data()); type *pdata = static_cast<type *>(out.data.data()); \
for (size_t j = 0; j < size; ++j) { type *pdata_ref = static_cast<type *>(ref_out.data.data()); \
EXPECT_EQ(pdata_ref[j], pdata[j]); for (size_t j = 0; j < size; ++j) { \
} func(pdata_ref[j], pdata[j]); \
break; } \
} break; \
case PaddleDType::FLOAT32: {
float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = static_cast<float *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
CheckError(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::INT32: {
int32_t *pdata = static_cast<int32_t *>(out.data.data());
int32_t *pdata_ref = static_cast<int32_t *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::UINT8: {
uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
uint8_t *pdata_ref = static_cast<uint8_t *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
} }
switch (out.dtype) {
COMPARE(PaddleDType::INT64, int64_t, EXPECT_EQ);
COMPARE(PaddleDType::FLOAT32, float, CheckError);
COMPARE(PaddleDType::INT32, int32_t, EXPECT_EQ);
COMPARE(PaddleDType::UINT8, uint8_t, EXPECT_EQ);
COMPARE(PaddleDType::INT8, int8_t, EXPECT_EQ);
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"VarMessageToVarType: Unsupported dtype %d",
static_cast<int>(out.dtype)));
} }
#undef COMPARE
} }
} }
...@@ -306,44 +295,30 @@ void CompareResult(const std::vector<PaddleTensor> &outputs, ...@@ -306,44 +295,30 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
EXPECT_GT(size, 0UL); EXPECT_GT(size, 0UL);
int ref_size = 0; // this is the number of elements not memory size int ref_size = 0; // this is the number of elements not memory size
PaddlePlace place; PaddlePlace place;
switch (out.dtype) {
case PaddleDType::INT64: { #define COMPARE(paddle_type, type, func) \
int64_t *pdata = static_cast<int64_t *>(out.data.data()); case paddle_type: { \
int64_t *pdata_ref = ref_out.data<int64_t>(&place, &ref_size); type *pdata = static_cast<type *>(out.data.data()); \
EXPECT_EQ(size, static_cast<size_t>(ref_size)); type *pdata_ref = ref_out.data<type>(&place, &ref_size); \
for (size_t j = 0; j < size; ++j) { EXPECT_EQ(size, static_cast<size_t>(ref_size)); \
EXPECT_EQ(pdata_ref[j], pdata[j]); for (size_t j = 0; j < size; ++j) { \
} func(pdata_ref[j], pdata[j]); \
break; } \
} break; \
case PaddleDType::FLOAT32: {
float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = ref_out.data<float>(&place, &ref_size);
EXPECT_EQ(size, static_cast<size_t>(ref_size));
for (size_t j = 0; j < size; ++j) {
CheckError(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::INT32: {
int32_t *pdata = static_cast<int32_t *>(out.data.data());
int32_t *pdata_ref = ref_out.data<int32_t>(&place, &ref_size);
EXPECT_EQ(size, static_cast<size_t>(ref_size));
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::UINT8: {
uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
uint8_t *pdata_ref = ref_out.data<uint8_t>(&place, &ref_size);
EXPECT_EQ(size, static_cast<size_t>(ref_size));
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
} }
switch (out.dtype) {
COMPARE(PaddleDType::INT64, int64_t, EXPECT_EQ);
COMPARE(PaddleDType::FLOAT32, float, CheckError);
COMPARE(PaddleDType::INT32, int32_t, EXPECT_EQ);
COMPARE(PaddleDType::UINT8, uint8_t, EXPECT_EQ);
COMPARE(PaddleDType::INT8, int8_t, EXPECT_EQ);
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"VarMessageToVarType: Unsupported dtype %d",
static_cast<int>(out.dtype)));
} }
#undef COMPARE
} }
} }
......
...@@ -199,6 +199,9 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT ...@@ -199,6 +199,9 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT
case PaddleDType::UINT8: case PaddleDType::UINT8:
tensor.copy_to_cpu<uint8_t>(static_cast<uint8_t *>(array.mutable_data())); tensor.copy_to_cpu<uint8_t>(static_cast<uint8_t *>(array.mutable_data()));
break; break;
case PaddleDType::INT8:
tensor.copy_to_cpu<int8_t>(static_cast<int8_t *>(array.mutable_data()));
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now only supports INT32, INT64, UINT8 and " "Unsupported data type. Now only supports INT32, INT64, UINT8 and "
...@@ -223,6 +226,12 @@ py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) { // NOLINT ...@@ -223,6 +226,12 @@ py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) { // NOLINT
case PaddleDType::FLOAT32: case PaddleDType::FLOAT32:
tensor.CopyToCpu<float>(static_cast<float *>(array.mutable_data())); tensor.CopyToCpu<float>(static_cast<float *>(array.mutable_data()));
break; break;
case PaddleDType::UINT8:
tensor.CopyToCpu(static_cast<uint8_t *>(array.mutable_data()));
break;
case PaddleDType::INT8:
tensor.CopyToCpu(static_cast<int8_t *>(array.mutable_data()));
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now only supports INT32, INT64 and " "Unsupported data type. Now only supports INT32, INT64 and "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册