未验证 提交 9d97d56e 编写于 作者: W Wilber 提交者: GitHub

fix fill_constant kernel bug test=develop (#2376)

fill_constant kernel only registered float type, only the float data type is produced, which is obviously a bug.

Now, produce data based on the data type attr.

By the way, fix the cast kernel bug.
上级 9d481778
......@@ -49,10 +49,7 @@ void CastCompute::Run() {
const int32_t* x_data_begin = param.X->data<int32_t>();
const int32_t* x_data_end = x_data_begin + param.X->numel();
float* out_data = param.Out->mutable_data<float>();
// std::transform(x_data_begin, x_data_end, out_data, TransOp<int32_t,
// float>);
// todo: the input type actually is float.
memcpy(out_data, x_data_begin, sizeof(float) * param.X->numel());
std::transform(x_data_begin, x_data_end, out_data, TransOp<int32_t, float>);
} else if (param.in_dtype == 20 && param.out_dtype == 5) { // uint8->float32
const unsigned char* x_data_begin = param.X->data<unsigned char>();
const unsigned char* x_data_end = x_data_begin + param.X->numel();
......
......@@ -29,9 +29,25 @@ class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<ARMContext>();
auto data = param.Out->template mutable_data<T>();
for (int i = 0; i < param.Out->numel(); i++) {
data[i] = param.value;
if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
auto data = param.Out->template mutable_data<float>();
for (int i = 0; i < param.Out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT32)) {
auto data = param.Out->template mutable_data<int32_t>();
for (int i = 0; i < param.Out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT8)) {
auto data = param.Out->template mutable_data<int8_t>();
for (int i = 0; i < param.Out->numel(); i++) {
data[i] = param.value;
}
} else {
LOG(FATAL) << "not supported dtype " << param.dtype;
}
}
......@@ -54,9 +70,25 @@ class FillConstantBatchLikeCompute
param.out->Resize(odims);
}
auto data = param.out->template mutable_data<T>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
auto data = param.out->template mutable_data<float>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT32)) {
auto data = param.out->template mutable_data<int32_t>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT8)) {
auto data = param.out->template mutable_data<int8_t>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else {
LOG(FATAL) << "not supported dtype " << param.dtype;
}
}
......
......@@ -116,10 +116,10 @@ TEST(Cast, precision) {
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
// std::unique_ptr<arena::TestCase> tester1(
// new CastComputeTester(place, "def", 2, 5));
// arena::Arena arena1(std::move(tester1), place, 2e-5);
// arena1.TestPrecision();
std::unique_ptr<arena::TestCase> tester1(
new CastComputeTester(place, "def", 2, 5));
arena::Arena arena1(std::move(tester1), place, 2e-5);
arena1.TestPrecision();
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册