提交 717bbc08 编写于 作者: L lidanqing 提交者: ceci3

Add INT32 support. INT32 in last switch case

test=develop
上级 bcd7b993
......@@ -243,6 +243,8 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
input_ptr = input.mutable_data<int64_t>(ddim, place_);
} else if (inputs[i].dtype == PaddleDType::FLOAT32) {
input_ptr = input.mutable_data<float>(ddim, place_);
} else if (inputs[i].dtype == PaddleDType::INT32) {
input_ptr = input.mutable_data<int32_t>(ddim, place_);
} else {
LOG(ERROR) << "unsupported feed type " << inputs[i].dtype;
return false;
......@@ -326,8 +328,11 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
} else if (type == framework::proto::VarType::INT64) {
GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64;
} else if (type == framework::proto::VarType::INT32) {
GetFetchOne<int32_t>(fetch, output);
output->dtype = PaddleDType::INT32;
} else {
LOG(ERROR) << "unknown type, only support float32 and int64 now.";
LOG(ERROR) << "unknown type, only support float32, int64 and int32 now.";
}
}
return true;
......
......@@ -28,6 +28,8 @@ int PaddleDtypeSize(PaddleDType dtype) {
return sizeof(float);
case PaddleDType::INT64:
return sizeof(int64_t);
case PaddleDType::INT32:
return sizeof(int32_t);
default:
assert(false);
return -1;
......
......@@ -203,6 +203,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
input_ptr = input.mutable_data<int64_t>(ddim, place_);
} else if (inputs[i].dtype == PaddleDType::FLOAT32) {
input_ptr = input.mutable_data<float>(ddim, place_);
} else if (inputs[i].dtype == PaddleDType::INT32) {
input_ptr = input.mutable_data<int32_t>(ddim, place_);
} else {
LOG(ERROR) << "unsupported feed type " << inputs[i].dtype;
return false;
......@@ -281,8 +283,11 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
} else if (type == framework::DataTypeTrait<int64_t>::DataType) {
GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64;
} else if (type == framework::DataTypeTrait<int32_t>::DataType) {
GetFetchOne<int32_t>(fetch, output);
output->dtype = PaddleDType::INT32;
} else {
LOG(ERROR) << "unknown type, only support float32 and int64 now.";
LOG(ERROR) << "unknown type, only support float32, int64 and int32 now.";
}
}
return true;
......
......@@ -42,6 +42,9 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
} else if (t->type() == framework::proto::VarType::FP32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32;
} else if (t->type() == framework::proto::VarType::INT32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(int32_t));
pt.dtype = PaddleDType::INT32;
} else {
LOG(FATAL) << "unsupported type.";
}
......
......@@ -88,13 +88,20 @@ void CheckOutput(const std::string& referfile, const PaddleTensor& output) {
}
break;
}
case PaddleDType::FLOAT32:
case PaddleDType::FLOAT32: {
for (size_t i = 0; i < numel; ++i) {
CHECK_LT(
fabs(static_cast<float*>(output.data.data())[i] - refer.data[i]),
1e-5);
}
break;
}
case PaddleDType::INT32: {
for (size_t i = 0; i < numel; ++i) {
CHECK_EQ(static_cast<int32_t*>(output.data.data())[i], refer.data[i]);
}
break;
}
}
}
......@@ -113,11 +120,18 @@ static std::string SummaryTensor(const PaddleTensor& tensor) {
}
break;
}
case PaddleDType::FLOAT32:
case PaddleDType::FLOAT32: {
for (int i = 0; i < std::min(num_elems, 10); i++) {
ss << static_cast<float*>(tensor.data.data())[i] << " ";
}
break;
}
case PaddleDType::INT32: {
for (int i = 0; i < std::min(num_elems, 10); i++) {
ss << static_cast<int32_t*>(tensor.data.data())[i] << " ";
}
break;
}
}
return ss.str();
}
......
......@@ -202,6 +202,9 @@ static std::string DescribeTensor(const PaddleTensor &tensor,
case PaddleDType::INT64:
os << "int64";
break;
case PaddleDType::INT32:
os << "int32";
break;
default:
os << "unset";
}
......
......@@ -36,6 +36,7 @@ namespace paddle {
enum PaddleDType {
FLOAT32,
INT64,
INT32,
// TODO(Superjomn) support more data types if needed.
};
......
......@@ -25,7 +25,6 @@
#ifdef WITH_GPERFTOOLS
#include <gperftools/profiler.h>
#endif
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
......@@ -97,6 +96,14 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
}
break;
}
case PaddleDType::INT32: {
int32_t *pdata = static_cast<int32_t *>(out.data.data());
int32_t *pdata_ref = static_cast<int32_t *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) {
EXPECT_EQ(pdata_ref[j], pdata[j]);
}
break;
}
}
}
}
......
......@@ -65,7 +65,8 @@ void BindInferenceApi(py::module *m) {
void BindPaddleDType(py::module *m) {
py::enum_<PaddleDType>(*m, "PaddleDType")
.value("FLOAT32", PaddleDType::FLOAT32)
.value("INT64", PaddleDType::INT64);
.value("INT64", PaddleDType::INT64)
.value("INT32", PaddleDType::INT32);
}
void BindPaddleBuf(py::module *m) {
......@@ -103,6 +104,11 @@ void BindPaddleBuf(py::module *m) {
int64_t *data = static_cast<int64_t *>(self.data());
return {data, data + self.length() / sizeof(*data)};
})
.def("int32_data",
[](PaddleBuf &self) -> std::vector<int32_t> {
int32_t *data = static_cast<int32_t *>(self.data());
return {data, data + self.length() / sizeof(*data)};
})
.def("length", &PaddleBuf::length);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册