未验证 提交 6f648cfc 编写于 作者: H huzhiqiang 提交者: GitHub

[BUG FIX][ARM] Fix the issue that OCR model can not operate (#4205)

上级 e4128757
......@@ -31,12 +31,13 @@ void CastCompute::PrepareForRun() {}
void CastCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::CastParam>();
auto input_dims = param.X->dims();
if (param.X->precision() == PrecisionType::kFloat) {
param.in_dtype = 5;
}
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
if (param.in_dtype == param.out_dtype && param.in_dtype == 2) {
if (param.in_dtype == param.out_dtype && param.in_dtype == 5) {
const auto* x_data = param.X->data<float>();
auto* o_data = param.Out->mutable_data<float>();
memcpy(o_data, x_data, sizeof(float) * param.X->numel());
......
......@@ -24,7 +24,7 @@ void OneHotKernelFunctor(const Tensor* in,
Tensor* out,
int depth,
bool allow_out_of_range = false) {
auto* p_in_data = in->data<T>();
auto* p_in_data = in->data<int64_t>();
auto numel = in->numel();
auto* p_out_data = out->mutable_data<T>();
memset(p_out_data, 0, out->numel() * sizeof(T));
......@@ -77,7 +77,7 @@ REGISTER_LITE_KERNEL(
one_hot, kHost, kAny, kAny, paddle::lite::kernels::host::OneHotCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
PRECISION(kInt64),
DATALAYOUT(kAny))})
.BindInput("depth_tensor",
{LiteType::GetTensorTy(TARGET(kHost),
......
......@@ -177,8 +177,8 @@ add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS})
add_operator(__xpu__resnet_cbam_op extra SRCS __xpu__resnet_cbam_op.cc DEPS ${op_DEPS})
add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc DEPS ${op_DEPS})
add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc DEPS ${op_DEPS})
lite_cc_test(test_one_hot_op SRCS one_hot_op_test.cc DEPS one_hot_op memory scope ${op_deps} one_hot_compute_host)
if (NOT LITE_WITH_X86)
lite_cc_test(test_one_hot_op SRCS one_hot_op_test.cc DEPS one_hot_op memory scope ${op_deps} one_hot_compute_host)
lite_cc_test(test_fc_op SRCS fc_op_test.cc
DEPS fc_op memory
X86_DEPS fc_compute_x86
......
......@@ -30,6 +30,7 @@ bool FusionElementwiseActivationOp::CheckShape() const {
bool FusionElementwiseActivationOp::InferShapeImpl() const {
size_t x_size = param_.X->dims().size();
size_t y_size = param_.Y->dims().size();
param_.Out->set_lod(param_.X->lod());
if (x_size >= y_size) {
param_.Out->Resize(param_.X->dims());
} else {
......
......@@ -25,11 +25,10 @@ bool OneHotOp::CheckShape() const {
}
bool OneHotOp::InferShapeImpl() const {
// Set output dims
auto out_dims = param_.X->dims();
CHECK_GE(out_dims.size(), 2);
int depth = param_.depth_tensor ? param_.depth
: param_.depth_tensor->data<int32_t>()[0];
out_dims[out_dims.size() - 1] = depth;
out_dims[out_dims.size() - 1] = param_.depth;
param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.X->lod());
return true;
......@@ -41,15 +40,17 @@ bool OneHotOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.X = scope->FindVar(x)->GetMutable<Tensor>();
param_.Out = scope->FindMutableTensor(out);
if (op_desc.HasAttr("depth")) {
param_.depth = op_desc.GetAttr<int>("depth");
}
if (op_desc.HasInput("depth_tensor") &&
!op_desc.Input("depth_tensor").empty()) {
auto depth_tensor = op_desc.Input("depth_tensor").front();
param_.depth_tensor = scope->FindVar(depth_tensor)->GetMutable<Tensor>();
param_.depth = param_.depth_tensor->data<int32_t>()[0];
}
if (op_desc.HasAttr("depth")) {
param_.depth = op_desc.GetAttr<int>("depth");
}
if (op_desc.HasAttr("allow_out_of_range")) {
param_.allow_out_of_range = op_desc.GetAttr<bool>("allow_out_of_range");
}
......
......@@ -31,7 +31,7 @@ TEST(one_hot_op_lite, TestHost) {
// set data
x->Resize(DDim(std::vector<int64_t>({4, 1})));
auto* x_data = x->mutable_data<int32_t>();
auto* x_data = x->mutable_data<int64_t>();
x_data[0] = 1;
x_data[1] = 1;
x_data[2] = 3;
......@@ -41,7 +41,6 @@ TEST(one_hot_op_lite, TestHost) {
cpp::OpDesc desc;
desc.SetType("one_hot");
desc.SetInput("X", {"X"});
desc.SetInput("depth_tensor", {"depth_tensor"});
desc.SetOutput("Out", {"Out"});
desc.SetAttr("depth", static_cast<int>(4));
desc.SetAttr("dtype", static_cast<int>(1));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册