Unverified — commit 68398abc, authored by: C cc, committed by: GitHub

[Inference] zero_copy_tensor supports int8_t (#30053)

* zero_copy_tensor supports int8_t
Parent commit: 1b999d2b
......@@ -165,10 +165,14 @@ template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<int32_t>(
const int32_t *data);
// Explicit template instantiations of ZeroCopyTensor's host<->device copy
// helpers, one per element type the public inference API exposes.
// The int8_t pair below is the addition made by this change; the diff context
// above shows float/int64_t/int32_t instantiations already existed.
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<uint8_t>(
const uint8_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_from_cpu<int8_t>(
const int8_t *data);
// copy_to_cpu mirrors copy_from_cpu: every dtype accepted on input must also
// be readable back out, so the instantiation lists are kept in lockstep.
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<float>(float *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int64_t>(int64_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int32_t>(int32_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<uint8_t>(uint8_t *data);
template PD_INFER_DECL void ZeroCopyTensor::copy_to_cpu<int8_t>(int8_t *data);
template PD_INFER_DECL float *ZeroCopyTensor::data<float>(PaddlePlace *place,
int *size) const;
......@@ -178,6 +182,9 @@ template PD_INFER_DECL int32_t *ZeroCopyTensor::data<int32_t>(
PaddlePlace *place, int *size) const;
// Explicit instantiations of the raw-pointer accessor ZeroCopyTensor::data<T>;
// int8_t is added here so callers can view int8 tensors without a copy.
template PD_INFER_DECL uint8_t *ZeroCopyTensor::data<uint8_t>(
PaddlePlace *place, int *size) const;
template PD_INFER_DECL int8_t *ZeroCopyTensor::data<int8_t>(PaddlePlace *place,
int *size) const;
// mutable_data<T> instantiations follow the same supported-dtype list.
template PD_INFER_DECL float *ZeroCopyTensor::mutable_data<float>(
PaddlePlace place);
template PD_INFER_DECL int64_t *ZeroCopyTensor::mutable_data<int64_t>(
......@@ -186,6 +193,8 @@ template PD_INFER_DECL int32_t *ZeroCopyTensor::mutable_data<int32_t>(
PaddlePlace place);
// mutable_data<T> must be instantiated for every dtype that copy_from_cpu
// accepts, since copy_from_cpu allocates through it; int8_t added to match.
template PD_INFER_DECL uint8_t *ZeroCopyTensor::mutable_data<uint8_t>(
PaddlePlace place);
template PD_INFER_DECL int8_t *ZeroCopyTensor::mutable_data<int8_t>(
PaddlePlace place);
void *ZeroCopyTensor::FindTensor() const {
PADDLE_ENFORCE_EQ(
......
......@@ -39,6 +39,7 @@ enum PaddleDType {
// Enumerators of PaddleDType (enum header is outside this hunk).
// NOTE(review): INT8 is the newly added entry; every switch over PaddleDType
// elsewhere in the codebase needs a matching case or it will hit the
// unsupported-dtype default branch.
INT64,
INT32,
UINT8,
INT8,
// TODO(Superjomn) support more data types if needed.
};
......
......@@ -257,40 +257,29 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
EXPECT_GT(size, 0UL);
EXPECT_EQ(size, ref_size);
EXPECT_EQ(out.dtype, ref_out.dtype);
switch (out.dtype) {
case PaddleDType::INT64: {
int64_t *pdata = static_cast<int64_t *>(out.data.data());
int64_t *pdata_ref = static_cast<int64_t *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::FLOAT32: {
float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = static_cast<float *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
CheckError(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::INT32: {
int32_t *pdata = static_cast<int32_t *>(out.data.data());
int32_t *pdata_ref = static_cast<int32_t *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::UINT8: {
uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
uint8_t *pdata_ref = static_cast<uint8_t *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
#define COMPARE(paddle_type, type, func) \
case paddle_type: { \
type *pdata = static_cast<type *>(out.data.data()); \
type *pdata_ref = static_cast<type *>(ref_out.data.data()); \
for (size_t j = 0; j < size; ++j) { \
func(pdata_ref[j], pdata[j]); \
} \
break; \
}
switch (out.dtype) {
COMPARE(PaddleDType::INT64, int64_t, EXPECT_EQ);
COMPARE(PaddleDType::FLOAT32, float, CheckError);
COMPARE(PaddleDType::INT32, int32_t, EXPECT_EQ);
COMPARE(PaddleDType::UINT8, uint8_t, EXPECT_EQ);
COMPARE(PaddleDType::INT8, int8_t, EXPECT_EQ);
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"VarMessageToVarType: Unsupported dtype %d",
static_cast<int>(out.dtype)));
}
#undef COMPARE
}
}
......@@ -306,44 +295,30 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
EXPECT_GT(size, 0UL);
int ref_size = 0; // this is the number of elements not memory size
PaddlePlace place;
switch (out.dtype) {
case PaddleDType::INT64: {
int64_t *pdata = static_cast<int64_t *>(out.data.data());
int64_t *pdata_ref = ref_out.data<int64_t>(&place, &ref_size);
EXPECT_EQ(size, static_cast<size_t>(ref_size));
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::FLOAT32: {
float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = ref_out.data<float>(&place, &ref_size);
EXPECT_EQ(size, static_cast<size_t>(ref_size));
for (size_t j = 0; j < size; ++j) {
CheckError(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::INT32: {
int32_t *pdata = static_cast<int32_t *>(out.data.data());
int32_t *pdata_ref = ref_out.data<int32_t>(&place, &ref_size);
EXPECT_EQ(size, static_cast<size_t>(ref_size));
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
case PaddleDType::UINT8: {
uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
uint8_t *pdata_ref = ref_out.data<uint8_t>(&place, &ref_size);
EXPECT_EQ(size, static_cast<size_t>(ref_size));
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
#define COMPARE(paddle_type, type, func) \
case paddle_type: { \
type *pdata = static_cast<type *>(out.data.data()); \
type *pdata_ref = ref_out.data<type>(&place, &ref_size); \
EXPECT_EQ(size, static_cast<size_t>(ref_size)); \
for (size_t j = 0; j < size; ++j) { \
func(pdata_ref[j], pdata[j]); \
} \
break; \
}
switch (out.dtype) {
COMPARE(PaddleDType::INT64, int64_t, EXPECT_EQ);
COMPARE(PaddleDType::FLOAT32, float, CheckError);
COMPARE(PaddleDType::INT32, int32_t, EXPECT_EQ);
COMPARE(PaddleDType::UINT8, uint8_t, EXPECT_EQ);
COMPARE(PaddleDType::INT8, int8_t, EXPECT_EQ);
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"VarMessageToVarType: Unsupported dtype %d",
static_cast<int>(out.dtype)));
}
#undef COMPARE
}
}
......
......@@ -199,6 +199,9 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT
// Dispatch arms of the dtype switch in ZeroCopyTensorToNumpy: copy tensor
// contents into the pre-allocated numpy buffer with the matching C type.
case PaddleDType::UINT8:
tensor.copy_to_cpu<uint8_t>(static_cast<uint8_t *>(array.mutable_data()));
break;
// New arm: int8 tensors can now round-trip to Python as well.
case PaddleDType::INT8:
tensor.copy_to_cpu<int8_t>(static_cast<int8_t *>(array.mutable_data()));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now only supports INT32, INT64, UINT8 and "
......@@ -223,6 +226,12 @@ py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) { // NOLINT
case PaddleDType::FLOAT32:
tensor.CopyToCpu<float>(static_cast<float *>(array.mutable_data()));
break;
// New arms in PaddleInferTensorToNumpy's dtype switch. Unlike the FLOAT32 arm
// above, these rely on template-argument deduction of CopyToCpu<T> from the
// pointer type instead of spelling the type explicitly — behavior is the same.
case PaddleDType::UINT8:
tensor.CopyToCpu(static_cast<uint8_t *>(array.mutable_data()));
break;
case PaddleDType::INT8:
tensor.CopyToCpu(static_cast<int8_t *>(array.mutable_data()));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now only supports INT32, INT64 and "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.