From 6f648cfca0b52a3d913d622d70c94c75f36b4b55 Mon Sep 17 00:00:00 2001
From: huzhiqiang <912790387@qq.com>
Date: Thu, 27 Aug 2020 14:46:17 +0800
Subject: [PATCH] [BUG FIX][ARM] Fix the issue that OCR model can not operate (#4205)

---
 lite/kernels/arm/cast_compute.cc                    |  7 ++++---
 lite/kernels/host/one_hot_compute.cc                |  4 ++--
 lite/operators/CMakeLists.txt                       |  2 +-
 lite/operators/fusion_elementwise_activation_ops.cc |  1 +
 lite/operators/one_hot_op.cc                        | 13 +++++++------
 lite/operators/one_hot_op_test.cc                   |  3 +--
 6 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/lite/kernels/arm/cast_compute.cc b/lite/kernels/arm/cast_compute.cc
index 4be34df0fb..2c01800a00 100644
--- a/lite/kernels/arm/cast_compute.cc
+++ b/lite/kernels/arm/cast_compute.cc
@@ -31,12 +31,13 @@ void CastCompute::PrepareForRun() {}
 
 void CastCompute::Run() {
   auto& ctx = this->ctx_->template As<ARMContext>();
   auto& param = this->Param<operators::CastParam>();
-  auto input_dims = param.X->dims();
-
+  if (param.X->precision() == PrecisionType::kFloat) {
+    param.in_dtype = 5;
+  }
   // BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
   // SIZE_T = 19;UINT8 = 20;INT8 = 21;
-  if (param.in_dtype == param.out_dtype && param.in_dtype == 2) {
+  if (param.in_dtype == param.out_dtype && param.in_dtype == 5) {
     const auto* x_data = param.X->data<float>();
     auto* o_data = param.Out->mutable_data<float>();
     memcpy(o_data, x_data, sizeof(float) * param.X->numel());
diff --git a/lite/kernels/host/one_hot_compute.cc b/lite/kernels/host/one_hot_compute.cc
index 6880de39ae..8f4a83ffa1 100644
--- a/lite/kernels/host/one_hot_compute.cc
+++ b/lite/kernels/host/one_hot_compute.cc
@@ -24,7 +24,7 @@ void OneHotKernelFunctor(const Tensor* in,
                          Tensor* out,
                          int depth,
                          bool allow_out_of_range = false) {
-  auto* p_in_data = in->data<T>();
+  auto* p_in_data = in->data<int64_t>();
   auto numel = in->numel();
   auto* p_out_data = out->mutable_data<T>();
   memset(p_out_data, 0, out->numel() * sizeof(T));
@@ -77,7 +77,7 @@ REGISTER_LITE_KERNEL(
     one_hot, kHost, kAny, kAny, paddle::lite::kernels::host::OneHotCompute, def)
     .BindInput("X",
                {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kAny),
+                                      PRECISION(kInt64),
                                       DATALAYOUT(kAny))})
     .BindInput("depth_tensor",
                {LiteType::GetTensorTy(TARGET(kHost),
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index db0b8c6971..2099958960 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -177,8 +177,8 @@ add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS})
 add_operator(__xpu__resnet_cbam_op extra SRCS __xpu__resnet_cbam_op.cc DEPS ${op_DEPS})
 add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc DEPS ${op_DEPS})
 add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc DEPS ${op_DEPS})
-lite_cc_test(test_one_hot_op SRCS one_hot_op_test.cc DEPS one_hot_op memory scope ${op_deps} one_hot_compute_host)
 if (NOT LITE_WITH_X86)
+    lite_cc_test(test_one_hot_op SRCS one_hot_op_test.cc DEPS one_hot_op memory scope ${op_deps} one_hot_compute_host)
     lite_cc_test(test_fc_op SRCS fc_op_test.cc
         DEPS fc_op memory
         X86_DEPS fc_compute_x86
diff --git a/lite/operators/fusion_elementwise_activation_ops.cc b/lite/operators/fusion_elementwise_activation_ops.cc
index 3fa79acda3..825d3441e0 100644
--- a/lite/operators/fusion_elementwise_activation_ops.cc
+++ b/lite/operators/fusion_elementwise_activation_ops.cc
@@ -30,6 +30,7 @@ bool FusionElementwiseActivationOp::CheckShape() const {
 bool FusionElementwiseActivationOp::InferShapeImpl() const {
   size_t x_size = param_.X->dims().size();
   size_t y_size = param_.Y->dims().size();
+  param_.Out->set_lod(param_.X->lod());
   if (x_size >= y_size) {
     param_.Out->Resize(param_.X->dims());
   } else {
diff --git a/lite/operators/one_hot_op.cc b/lite/operators/one_hot_op.cc
index 88b939a0de..fe8db5e86b 100644
--- a/lite/operators/one_hot_op.cc
+++ b/lite/operators/one_hot_op.cc
@@ -25,11 +25,10 @@ bool OneHotOp::CheckShape() const {
 }
 
 bool OneHotOp::InferShapeImpl() const {
+  // Set output dims
   auto out_dims = param_.X->dims();
   CHECK_GE(out_dims.size(), 2);
-  int depth = param_.depth_tensor ? param_.depth
-                                  : param_.depth_tensor->data<int32_t>()[0];
-  out_dims[out_dims.size() - 1] = depth;
+  out_dims[out_dims.size() - 1] = param_.depth;
   param_.Out->Resize(out_dims);
   param_.Out->set_lod(param_.X->lod());
   return true;
@@ -41,15 +40,17 @@ bool OneHotOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
   param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
   param_.Out = scope->FindMutableTensor(out);
 
+  if (op_desc.HasAttr("depth")) {
+    param_.depth = op_desc.GetAttr<int>("depth");
+  }
+
   if (op_desc.HasInput("depth_tensor") &&
       !op_desc.Input("depth_tensor").empty()) {
     auto depth_tensor = op_desc.Input("depth_tensor").front();
     param_.depth_tensor = scope->FindVar(depth_tensor)->GetMutable<lite::Tensor>();
+    param_.depth = param_.depth_tensor->data<int32_t>()[0];
   }
 
-  if (op_desc.HasAttr("depth")) {
-    param_.depth = op_desc.GetAttr<int>("depth");
-  }
   if (op_desc.HasAttr("allow_out_of_range")) {
     param_.allow_out_of_range = op_desc.GetAttr<bool>("allow_out_of_range");
   }
diff --git a/lite/operators/one_hot_op_test.cc b/lite/operators/one_hot_op_test.cc
index 5daa837886..af5d0ba0fe 100644
--- a/lite/operators/one_hot_op_test.cc
+++ b/lite/operators/one_hot_op_test.cc
@@ -31,7 +31,7 @@ TEST(one_hot_op_lite, TestHost) {
   // set data
   x->Resize(DDim(std::vector<int64_t>({4, 1})));
-  auto* x_data = x->mutable_data<float>();
+  auto* x_data = x->mutable_data<int64_t>();
   x_data[0] = 1;
   x_data[1] = 1;
   x_data[2] = 3;
   x_data[3] = 0;
@@ -41,7 +41,6 @@
   cpp::OpDesc desc;
   desc.SetType("one_hot");
   desc.SetInput("X", {"X"});
-  desc.SetInput("depth_tensor", {"depth_tensor"});
   desc.SetOutput("Out", {"Out"});
   desc.SetAttr("depth", static_cast<int>(4));
   desc.SetAttr("dtype", static_cast<int>(1));
--
GitLab
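
The cast_compute.cc hunk keys its copy fast path on the framework's integer dtype codes (BOOL = 0, INT16 = 1, INT32 = 2, INT64 = 3, FP16 = 4, FP32 = 5, FP64 = 6): when the input tensor's precision is kFloat, in_dtype is forced to 5 so that a float-to-float cast is handled by a plain memcpy, whereas the old fast path compared against 2 (INT32) while copying float data. The snippet below is only a minimal standalone sketch of that decision; the PrecisionType enum, precision_to_dtype helper and cast_fp32_identity function are illustrative names, not Paddle-Lite APIs.

#include <cstddef>
#include <cstring>

// Illustrative stand-in; Paddle-Lite defines its own PrecisionType.
enum class PrecisionType { kFloat, kInt32, kInt64 };

// Integer dtype codes quoted in the cast_compute.cc comment:
// BOOL = 0; INT16 = 1; INT32 = 2; INT64 = 3; FP16 = 4; FP32 = 5; FP64 = 6;
inline int precision_to_dtype(PrecisionType p) {
  switch (p) {
    case PrecisionType::kFloat: return 5;  // FP32, the value the patch forces
    case PrecisionType::kInt32: return 2;
    case PrecisionType::kInt64: return 3;
  }
  return -1;  // unreachable for the enumerators above
}

// Fast path taken when in_dtype == out_dtype == 5 (FP32): a raw element copy.
inline void cast_fp32_identity(const float* x, float* out, size_t numel) {
  std::memcpy(out, x, sizeof(float) * numel);
}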
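The one_hot changes make the host kernel read its input indices as int64 (binding X as PRECISION(kInt64)) and move depth resolution into OneHotOp::AttachImpl: the "depth" attribute is read first and a non-empty "depth_tensor" input overrides it, so InferShapeImpl can rely on param_.depth directly. The program below is a simplified standalone sketch of that expansion, assuming int64 indices of shape [N, 1] and a row-major float output of shape [N, depth]; it is not the Paddle-Lite kernel, and OneHotSketch is a hypothetical helper.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Expand int64 indices into a row-major [N, depth] float one-hot matrix.
std::vector<float> OneHotSketch(const std::vector<int64_t>& indices,
                                int depth,
                                bool allow_out_of_range = false) {
  std::vector<float> out(indices.size() * static_cast<size_t>(depth), 0.0f);
  for (size_t i = 0; i < indices.size(); ++i) {
    const int64_t idx = indices[i];
    if (idx < 0 || idx >= depth) {
      // When out-of-range values are allowed the row simply stays all zeros;
      // otherwise treat it as a usage error in this sketch.
      assert(allow_out_of_range && "one_hot index out of range");
      continue;
    }
    out[i * static_cast<size_t>(depth) + static_cast<size_t>(idx)] = 1.0f;
  }
  return out;
}

int main() {
  // Sample data: indices {1, 1, 3, 0} with depth = 4.
  const std::vector<int64_t> x = {1, 1, 3, 0};
  const auto out = OneHotSketch(x, 4);
  for (size_t i = 0; i < x.size(); ++i) {
    for (int j = 0; j < 4; ++j) std::cout << out[i * 4 + j] << ' ';
    std::cout << '\n';
  }
  return 0;
}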