Unverified commit 9611e046, authored by huzhiqiang, committed via GitHub

[BUG FIX] fix the bug that the OCR model can not be loaded properly (#2757)

Parent commit: d2fb7f8f
...@@ -20,28 +20,37 @@ namespace lite { ...@@ -20,28 +20,37 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
void WriteToArrayCompute::PrepareForRun() {}
void WriteToArrayCompute::Run() { void WriteToArrayCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::WriteToArrayParam>(); auto& param = this->template Param<operators::WriteToArrayParam>();
CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element"; CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element";
const auto* x_data = param.X->data<float>(); auto precision_type = param.X->precision();
int id = param.I->data<float>()[0];
int id_test = param.I->data<int64_t>()[0]; #define SOLVE_TYPE(type__, T) \
if (id >= param.Out->size()) { case type__: { \
for (int i = param.Out->size(); i < id + 1; i++) { const auto* x_data = param.X->data<T>(); \
lite::Tensor tmp; int id = param.I->data<int64_t>()[0]; \
param.Out->push_back(tmp); if (id >= param.Out->size()) { \
} for (int i = param.Out->size(); i < id + 1; i++) { \
lite::Tensor tmp; \
param.Out->push_back(tmp); \
} \
} \
(*param.Out)[id].Resize(param.X->dims()); \
auto out_lod = (*param.Out)[id].mutable_lod(); \
*out_lod = param.X->lod(); \
auto* o_data = (*param.Out)[id].mutable_data<T>(TARGET(kHost)); \
int input_size = param.X->numel(); \
memcpy(o_data, x_data, sizeof(T) * input_size); \
} break;
switch (precision_type) {
SOLVE_TYPE(PRECISION(kFloat), float);
SOLVE_TYPE(PRECISION(kInt64), int64_t);
default:
LOG(FATAL) << "Unsupported precision type.";
} }
(*param.Out)[id].Resize(param.X->dims()); #undef SOLVE_TYPE
auto out_lod = (*param.Out)[id].mutable_lod();
*out_lod = param.X->lod();
auto* o_data = (*param.Out)[id].mutable_data<float>(TARGET(kHost));
int input_size = param.X->numel();
memcpy(o_data, x_data, sizeof(float) * input_size);
} }
} // namespace arm } // namespace arm
...@@ -51,11 +60,11 @@ void WriteToArrayCompute::Run() { ...@@ -51,11 +60,11 @@ void WriteToArrayCompute::Run() {
REGISTER_LITE_KERNEL(write_to_array, REGISTER_LITE_KERNEL(write_to_array,
kARM, kARM,
kFloat, kAny,
kNCHW, kNCHW,
paddle::lite::kernels::arm::WriteToArrayCompute, paddle::lite::kernels::arm::WriteToArrayCompute,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("I", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -23,12 +23,10 @@ namespace lite { ...@@ -23,12 +23,10 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public: public:
using param_t = operators::WriteToArrayParam; using param_t = operators::WriteToArrayParam;
void PrepareForRun() override;
void Run() override; void Run() override;
~WriteToArrayCompute() {} ~WriteToArrayCompute() {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册