未验证 提交 6f648cfc 编写于 作者: H huzhiqiang 提交者: GitHub

[BUG FIX][ARM] Fix the issue that OCR model can not operate (#4205)

上级 e4128757
...@@ -31,12 +31,13 @@ void CastCompute::PrepareForRun() {} ...@@ -31,12 +31,13 @@ void CastCompute::PrepareForRun() {}
void CastCompute::Run() { void CastCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::CastParam>(); auto& param = this->Param<operators::CastParam>();
auto input_dims = param.X->dims(); auto input_dims = param.X->dims();
if (param.X->precision() == PrecisionType::kFloat) {
param.in_dtype = 5;
}
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6; // BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21; // SIZE_T = 19;UINT8 = 20;INT8 = 21;
if (param.in_dtype == param.out_dtype && param.in_dtype == 2) { if (param.in_dtype == param.out_dtype && param.in_dtype == 5) {
const auto* x_data = param.X->data<float>(); const auto* x_data = param.X->data<float>();
auto* o_data = param.Out->mutable_data<float>(); auto* o_data = param.Out->mutable_data<float>();
memcpy(o_data, x_data, sizeof(float) * param.X->numel()); memcpy(o_data, x_data, sizeof(float) * param.X->numel());
......
...@@ -24,7 +24,7 @@ void OneHotKernelFunctor(const Tensor* in, ...@@ -24,7 +24,7 @@ void OneHotKernelFunctor(const Tensor* in,
Tensor* out, Tensor* out,
int depth, int depth,
bool allow_out_of_range = false) { bool allow_out_of_range = false) {
auto* p_in_data = in->data<T>(); auto* p_in_data = in->data<int64_t>();
auto numel = in->numel(); auto numel = in->numel();
auto* p_out_data = out->mutable_data<T>(); auto* p_out_data = out->mutable_data<T>();
memset(p_out_data, 0, out->numel() * sizeof(T)); memset(p_out_data, 0, out->numel() * sizeof(T));
...@@ -77,7 +77,7 @@ REGISTER_LITE_KERNEL( ...@@ -77,7 +77,7 @@ REGISTER_LITE_KERNEL(
one_hot, kHost, kAny, kAny, paddle::lite::kernels::host::OneHotCompute, def) one_hot, kHost, kAny, kAny, paddle::lite::kernels::host::OneHotCompute, def)
.BindInput("X", .BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost), {LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny), PRECISION(kInt64),
DATALAYOUT(kAny))}) DATALAYOUT(kAny))})
.BindInput("depth_tensor", .BindInput("depth_tensor",
{LiteType::GetTensorTy(TARGET(kHost), {LiteType::GetTensorTy(TARGET(kHost),
......
...@@ -177,8 +177,8 @@ add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS}) ...@@ -177,8 +177,8 @@ add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS})
add_operator(__xpu__resnet_cbam_op extra SRCS __xpu__resnet_cbam_op.cc DEPS ${op_DEPS}) add_operator(__xpu__resnet_cbam_op extra SRCS __xpu__resnet_cbam_op.cc DEPS ${op_DEPS})
add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc DEPS ${op_DEPS}) add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc DEPS ${op_DEPS})
add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc DEPS ${op_DEPS}) add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc DEPS ${op_DEPS})
lite_cc_test(test_one_hot_op SRCS one_hot_op_test.cc DEPS one_hot_op memory scope ${op_deps} one_hot_compute_host)
if (NOT LITE_WITH_X86) if (NOT LITE_WITH_X86)
lite_cc_test(test_one_hot_op SRCS one_hot_op_test.cc DEPS one_hot_op memory scope ${op_deps} one_hot_compute_host)
lite_cc_test(test_fc_op SRCS fc_op_test.cc lite_cc_test(test_fc_op SRCS fc_op_test.cc
DEPS fc_op memory DEPS fc_op memory
X86_DEPS fc_compute_x86 X86_DEPS fc_compute_x86
......
...@@ -30,6 +30,7 @@ bool FusionElementwiseActivationOp::CheckShape() const { ...@@ -30,6 +30,7 @@ bool FusionElementwiseActivationOp::CheckShape() const {
bool FusionElementwiseActivationOp::InferShapeImpl() const { bool FusionElementwiseActivationOp::InferShapeImpl() const {
size_t x_size = param_.X->dims().size(); size_t x_size = param_.X->dims().size();
size_t y_size = param_.Y->dims().size(); size_t y_size = param_.Y->dims().size();
param_.Out->set_lod(param_.X->lod());
if (x_size >= y_size) { if (x_size >= y_size) {
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
} else { } else {
......
...@@ -25,11 +25,10 @@ bool OneHotOp::CheckShape() const { ...@@ -25,11 +25,10 @@ bool OneHotOp::CheckShape() const {
} }
bool OneHotOp::InferShapeImpl() const { bool OneHotOp::InferShapeImpl() const {
// Set output dims
auto out_dims = param_.X->dims(); auto out_dims = param_.X->dims();
CHECK_GE(out_dims.size(), 2); CHECK_GE(out_dims.size(), 2);
int depth = param_.depth_tensor ? param_.depth out_dims[out_dims.size() - 1] = param_.depth;
: param_.depth_tensor->data<int32_t>()[0];
out_dims[out_dims.size() - 1] = depth;
param_.Out->Resize(out_dims); param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.X->lod()); param_.Out->set_lod(param_.X->lod());
return true; return true;
...@@ -41,15 +40,17 @@ bool OneHotOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { ...@@ -41,15 +40,17 @@ bool OneHotOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.X = scope->FindVar(x)->GetMutable<Tensor>(); param_.X = scope->FindVar(x)->GetMutable<Tensor>();
param_.Out = scope->FindMutableTensor(out); param_.Out = scope->FindMutableTensor(out);
if (op_desc.HasAttr("depth")) {
param_.depth = op_desc.GetAttr<int>("depth");
}
if (op_desc.HasInput("depth_tensor") && if (op_desc.HasInput("depth_tensor") &&
!op_desc.Input("depth_tensor").empty()) { !op_desc.Input("depth_tensor").empty()) {
auto depth_tensor = op_desc.Input("depth_tensor").front(); auto depth_tensor = op_desc.Input("depth_tensor").front();
param_.depth_tensor = scope->FindVar(depth_tensor)->GetMutable<Tensor>(); param_.depth_tensor = scope->FindVar(depth_tensor)->GetMutable<Tensor>();
param_.depth = param_.depth_tensor->data<int32_t>()[0];
} }
if (op_desc.HasAttr("depth")) {
param_.depth = op_desc.GetAttr<int>("depth");
}
if (op_desc.HasAttr("allow_out_of_range")) { if (op_desc.HasAttr("allow_out_of_range")) {
param_.allow_out_of_range = op_desc.GetAttr<bool>("allow_out_of_range"); param_.allow_out_of_range = op_desc.GetAttr<bool>("allow_out_of_range");
} }
......
...@@ -31,7 +31,7 @@ TEST(one_hot_op_lite, TestHost) { ...@@ -31,7 +31,7 @@ TEST(one_hot_op_lite, TestHost) {
// set data // set data
x->Resize(DDim(std::vector<int64_t>({4, 1}))); x->Resize(DDim(std::vector<int64_t>({4, 1})));
auto* x_data = x->mutable_data<int32_t>(); auto* x_data = x->mutable_data<int64_t>();
x_data[0] = 1; x_data[0] = 1;
x_data[1] = 1; x_data[1] = 1;
x_data[2] = 3; x_data[2] = 3;
...@@ -41,7 +41,6 @@ TEST(one_hot_op_lite, TestHost) { ...@@ -41,7 +41,6 @@ TEST(one_hot_op_lite, TestHost) {
cpp::OpDesc desc; cpp::OpDesc desc;
desc.SetType("one_hot"); desc.SetType("one_hot");
desc.SetInput("X", {"X"}); desc.SetInput("X", {"X"});
desc.SetInput("depth_tensor", {"depth_tensor"});
desc.SetOutput("Out", {"Out"}); desc.SetOutput("Out", {"Out"});
desc.SetAttr("depth", static_cast<int>(4)); desc.SetAttr("depth", static_cast<int>(4));
desc.SetAttr("dtype", static_cast<int>(1)); desc.SetAttr("dtype", static_cast<int>(1));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册