提交 383c5087 编写于 作者: H huzhiqiang 提交者: GitHub

【BUG FIX】fix the bug that ocr model can not be loaded properly (#2757)

上级 b14e21c3
......@@ -20,28 +20,37 @@ namespace lite {
namespace kernels {
namespace arm {
void WriteToArrayCompute::PrepareForRun() {}
void WriteToArrayCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::WriteToArrayParam>();
auto& param = this->template Param<operators::WriteToArrayParam>();
CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element";
const auto* x_data = param.X->data<float>();
int id = param.I->data<float>()[0];
int id_test = param.I->data<int64_t>()[0];
if (id >= param.Out->size()) {
for (int i = param.Out->size(); i < id + 1; i++) {
lite::Tensor tmp;
param.Out->push_back(tmp);
}
auto precision_type = param.X->precision();
#define SOLVE_TYPE(type__, T) \
case type__: { \
const auto* x_data = param.X->data<T>(); \
int id = param.I->data<int64_t>()[0]; \
if (id >= param.Out->size()) { \
for (int i = param.Out->size(); i < id + 1; i++) { \
lite::Tensor tmp; \
param.Out->push_back(tmp); \
} \
} \
(*param.Out)[id].Resize(param.X->dims()); \
auto out_lod = (*param.Out)[id].mutable_lod(); \
*out_lod = param.X->lod(); \
auto* o_data = (*param.Out)[id].mutable_data<T>(TARGET(kHost)); \
int input_size = param.X->numel(); \
memcpy(o_data, x_data, sizeof(T) * input_size); \
} break;
switch (precision_type) {
SOLVE_TYPE(PRECISION(kFloat), float);
SOLVE_TYPE(PRECISION(kInt64), int64_t);
default:
LOG(FATAL) << "Unsupported precision type.";
}
(*param.Out)[id].Resize(param.X->dims());
auto out_lod = (*param.Out)[id].mutable_lod();
*out_lod = param.X->lod();
auto* o_data = (*param.Out)[id].mutable_data<float>(TARGET(kHost));
int input_size = param.X->numel();
memcpy(o_data, x_data, sizeof(float) * input_size);
#undef SOLVE_TYPE
}
} // namespace arm
......@@ -51,11 +60,11 @@ void WriteToArrayCompute::Run() {
REGISTER_LITE_KERNEL(write_to_array,
kARM,
kFloat,
kAny,
kNCHW,
paddle::lite::kernels::arm::WriteToArrayCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("I", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))})
.Finalize();
......@@ -23,12 +23,10 @@ namespace lite {
namespace kernels {
namespace arm {
class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class WriteToArrayCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public:
using param_t = operators::WriteToArrayParam;
void PrepareForRun() override;
void Run() override;
~WriteToArrayCompute() {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册