diff --git a/lite/kernels/arm/write_to_array_compute.cc b/lite/kernels/arm/write_to_array_compute.cc index ee68442ffcd0a5c12f3659e0739715c2128ece28..a394c28a698c278dea7ded51ae016b777d2a0971 100644 --- a/lite/kernels/arm/write_to_array_compute.cc +++ b/lite/kernels/arm/write_to_array_compute.cc @@ -20,28 +20,37 @@ namespace lite { namespace kernels { namespace arm { -void WriteToArrayCompute::PrepareForRun() {} - void WriteToArrayCompute::Run() { auto& ctx = this->ctx_->template As(); - auto& param = this->Param(); - + auto& param = this->template Param(); CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element"; - const auto* x_data = param.X->data(); - int id = param.I->data()[0]; - int id_test = param.I->data()[0]; - if (id >= param.Out->size()) { - for (int i = param.Out->size(); i < id + 1; i++) { - lite::Tensor tmp; - param.Out->push_back(tmp); - } + auto precision_type = param.X->precision(); + +#define SOLVE_TYPE(type__, T) \ + case type__: { \ + const auto* x_data = param.X->data(); \ + int id = param.I->data()[0]; \ + if (id >= param.Out->size()) { \ + for (int i = param.Out->size(); i < id + 1; i++) { \ + lite::Tensor tmp; \ + param.Out->push_back(tmp); \ + } \ + } \ + (*param.Out)[id].Resize(param.X->dims()); \ + auto out_lod = (*param.Out)[id].mutable_lod(); \ + *out_lod = param.X->lod(); \ + auto* o_data = (*param.Out)[id].mutable_data(TARGET(kHost)); \ + int input_size = param.X->numel(); \ + memcpy(o_data, x_data, sizeof(T) * input_size); \ + } break; + + switch (precision_type) { + SOLVE_TYPE(PRECISION(kFloat), float); + SOLVE_TYPE(PRECISION(kInt64), int64_t); + default: + LOG(FATAL) << "Unsupported precision type."; } - (*param.Out)[id].Resize(param.X->dims()); - auto out_lod = (*param.Out)[id].mutable_lod(); - *out_lod = param.X->lod(); - auto* o_data = (*param.Out)[id].mutable_data(TARGET(kHost)); - int input_size = param.X->numel(); - memcpy(o_data, x_data, sizeof(float) * input_size); +#undef SOLVE_TYPE } } // namespace arm @@ -51,11 +60,11 @@ void WriteToArrayCompute::Run() { REGISTER_LITE_KERNEL(write_to_array, kARM, - kFloat, + kAny, kNCHW, paddle::lite::kernels::arm::WriteToArrayCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/kernels/arm/write_to_array_compute.h b/lite/kernels/arm/write_to_array_compute.h index c7b7c64c341fb16188aea1a166a93ffe7a78ecb7..8235f9dae3fec639312f12faf08e764e79ab0bd5 100644 --- a/lite/kernels/arm/write_to_array_compute.h +++ b/lite/kernels/arm/write_to_array_compute.h @@ -23,12 +23,10 @@ namespace lite { namespace kernels { namespace arm { -class WriteToArrayCompute : public KernelLite { +class WriteToArrayCompute : public KernelLite { public: using param_t = operators::WriteToArrayParam; - void PrepareForRun() override; - void Run() override; ~WriteToArrayCompute() {}