From 3a04e11d831e7caab5358e51f3627b7f8401a7bf Mon Sep 17 00:00:00 2001 From: zhupengyang <1165938320@qq.com> Date: Fri, 17 Apr 2020 16:30:09 +0800 Subject: [PATCH] move read_from_array and write_to_array to host (#3428) --- lite/kernels/arm/CMakeLists.txt | 2 -- .../kernels/arm/beam_search_decode_compute.cc | 4 +-- lite/kernels/host/CMakeLists.txt | 2 ++ .../{arm => host}/read_from_array_compute.cc | 31 ++++++++++++------- .../{arm => host}/read_from_array_compute.h | 11 +++---- .../{arm => host}/write_to_array_compute.cc | 28 ++++++++++------- .../{arm => host}/write_to_array_compute.h | 9 +++--- lite/operators/read_from_array_op.cc | 7 +---- lite/operators/write_to_array_op.cc | 8 +---- .../kernels/read_from_array_compute_test.cc | 2 +- .../kernels/write_to_array_compute_test.cc | 2 +- 11 files changed, 52 insertions(+), 54 deletions(-) rename lite/kernels/{arm => host}/read_from_array_compute.cc (57%) rename lite/kernels/{arm => host}/read_from_array_compute.h (79%) rename lite/kernels/{arm => host}/write_to_array_compute.cc (61%) rename lite/kernels/{arm => host}/write_to_array_compute.h (83%) diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 93ceb976d2..d8a772123f 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -92,8 +92,6 @@ add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute. add_kernel(while_compute_arm ARM extra SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(write_to_array_compute_arm ARM extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(read_from_array_compute_arm ARM extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(fill_constant_compute_arm ARM basic SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(fill_constant_batch_size_like_compute_arm ARM basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_arm) diff --git a/lite/kernels/arm/beam_search_decode_compute.cc b/lite/kernels/arm/beam_search_decode_compute.cc index e0d4ae3f13..bbd17d98c6 100644 --- a/lite/kernels/arm/beam_search_decode_compute.cc +++ b/lite/kernels/arm/beam_search_decode_compute.cc @@ -114,14 +114,14 @@ struct BeamSearchDecoder { lod.push_back(source_level_lod); lod.push_back(sentence_level_lod); - *(id_tensor->mutable_lod()) = lod; + id_tensor->set_lod(lod); id_tensor->Resize({static_cast(id_data.size())}); auto id_ptr = id_tensor->mutable_data(); TargetCopy( TARGET(kARM), id_ptr, id_data.data(), id_data.size() * sizeof(int64_t)); - *(score_tensor->mutable_lod()) = lod; + score_tensor->set_lod(lod); score_tensor->Resize({static_cast(score_data.size())}); auto score_ptr = score_tensor->mutable_data(); TargetCopy(TARGET(kARM), diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index c906c49786..b4935496cd 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -10,3 +10,5 @@ add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEP add_kernel(compare_compute_host Host extra SRCS compare_compute.cc DEPS ${lite_kernel_deps}) add_kernel(logical_compute_host Host extra SRCS logical_compute.cc DEPS ${lite_kernel_deps}) add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(write_to_array_compute_host Host extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps}) diff --git a/lite/kernels/arm/read_from_array_compute.cc b/lite/kernels/host/read_from_array_compute.cc similarity index 57% rename from lite/kernels/arm/read_from_array_compute.cc rename to lite/kernels/host/read_from_array_compute.cc index f2aff42f1c..7520fcb8b3 100644 --- a/lite/kernels/arm/read_from_array_compute.cc +++ b/lite/kernels/host/read_from_array_compute.cc @@ -12,17 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/kernels/arm/read_from_array_compute.h" -#include "lite/backends/arm/math/funcs.h" +#include "lite/kernels/host/read_from_array_compute.h" namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { void ReadFromArrayCompute::Run() { - auto& ctx = this->ctx_->template As(); - auto& param = this->Param(); + auto& param = this->Param(); CHECK_EQ(param.I->numel(), 1) << "I should have only one element"; int id = param.I->data()[0]; @@ -33,18 +31,27 @@ void ReadFromArrayCompute::Run() { param.Out->CopyDataFrom((*param.X)[id]); } -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle REGISTER_LITE_KERNEL(read_from_array, - kARM, + kHost, kAny, - kNCHW, - paddle::lite::kernels::arm::ReadFromArrayCompute, + kAny, + paddle::lite::kernels::host::ReadFromArrayCompute, def) - .BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))}) - .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) + .BindInput("X", + {LiteType::GetTensorListTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .BindInput("I", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt64), + DATALAYOUT(kAny))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) .Finalize(); diff --git a/lite/kernels/arm/read_from_array_compute.h b/lite/kernels/host/read_from_array_compute.h similarity index 79% rename from lite/kernels/arm/read_from_array_compute.h rename to lite/kernels/host/read_from_array_compute.h index b44f46792a..66ba548ff4 100644 --- a/lite/kernels/arm/read_from_array_compute.h +++ b/lite/kernels/host/read_from_array_compute.h @@ -13,20 +13,17 @@ // limitations under the License. #pragma once -#include -#include "lite/backends/arm/math/type_trans.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { -class ReadFromArrayCompute : public KernelLite { +class ReadFromArrayCompute + : public KernelLite { public: - using param_t = operators::ReadFromArrayParam; - void Run() override; ~ReadFromArrayCompute() {} @@ -34,7 +31,7 @@ class ReadFromArrayCompute : public KernelLite { private: }; -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/kernels/arm/write_to_array_compute.cc b/lite/kernels/host/write_to_array_compute.cc similarity index 61% rename from lite/kernels/arm/write_to_array_compute.cc rename to lite/kernels/host/write_to_array_compute.cc index 6b82f99126..682805e602 100644 --- a/lite/kernels/arm/write_to_array_compute.cc +++ b/lite/kernels/host/write_to_array_compute.cc @@ -12,16 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "lite/kernels/arm/write_to_array_compute.h" -#include "lite/backends/arm/math/funcs.h" +#include "lite/kernels/host/write_to_array_compute.h" namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { void WriteToArrayCompute::Run() { - auto& ctx = this->ctx_->template As(); auto& param = this->template Param(); CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element"; @@ -32,19 +30,27 @@ void WriteToArrayCompute::Run() { param.Out->at(id).CopyDataFrom(*param.X); } -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle REGISTER_LITE_KERNEL(write_to_array, - kARM, + kHost, kAny, - kNCHW, - paddle::lite::kernels::arm::WriteToArrayCompute, + kAny, + paddle::lite::kernels::host::WriteToArrayCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) - .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .BindInput("I", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kInt64), + DATALAYOUT(kAny))}) .BindOutput("Out", - {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kAny))}) + {LiteType::GetTensorListTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) .Finalize(); diff --git a/lite/kernels/arm/write_to_array_compute.h b/lite/kernels/host/write_to_array_compute.h similarity index 83% rename from lite/kernels/arm/write_to_array_compute.h rename to lite/kernels/host/write_to_array_compute.h index 960c53d4ef..dcb1433d9b 100644 --- a/lite/kernels/arm/write_to_array_compute.h +++ b/lite/kernels/host/write_to_array_compute.h @@ -13,17 +13,16 @@ // limitations under the License. #pragma once -#include -#include "lite/backends/arm/math/type_trans.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { -class WriteToArrayCompute : public KernelLite { +class WriteToArrayCompute + : public KernelLite { public: void Run() override; @@ -32,7 +31,7 @@ class WriteToArrayCompute : public KernelLite { private: }; -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/operators/read_from_array_op.cc b/lite/operators/read_from_array_op.cc index 495fd752c9..6a264f9a75 100644 --- a/lite/operators/read_from_array_op.cc +++ b/lite/operators/read_from_array_op.cc @@ -26,12 +26,7 @@ bool ReadFromArrayOp::CheckShape() const { return true; } -bool ReadFromArrayOp::InferShapeImpl() const { - int id = param_.I->data()[0]; - auto out_dims = (*param_.X)[id].dims(); - param_.Out->Resize(out_dims); - return true; -} +bool ReadFromArrayOp::InferShapeImpl() const { return true; } bool ReadFromArrayOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { diff --git a/lite/operators/write_to_array_op.cc b/lite/operators/write_to_array_op.cc index d2cf7b4f94..8d2c4d6b5c 100644 --- a/lite/operators/write_to_array_op.cc +++ b/lite/operators/write_to_array_op.cc @@ -26,13 +26,7 @@ bool WriteToArrayOp::CheckShape() const { return true; } -bool WriteToArrayOp::InferShapeImpl() const { - int id = param_.I->data()[0]; - if (param_.Out->size() < id + 1) { - param_.Out->resize(id + 1); - } - return true; -} +bool WriteToArrayOp::InferShapeImpl() const { return true; } bool WriteToArrayOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { auto inputs = opdesc.Input("X").front(); diff --git a/lite/tests/kernels/read_from_array_compute_test.cc b/lite/tests/kernels/read_from_array_compute_test.cc index cd3596ff56..0a4b095b53 100644 --- a/lite/tests/kernels/read_from_array_compute_test.cc +++ b/lite/tests/kernels/read_from_array_compute_test.cc @@ -88,7 +88,7 @@ TEST(ReadFromArray, precision) { Place place; float abs_error = 1e-5; #ifdef LITE_WITH_ARM - place = TARGET(kARM); + place = TARGET(kHost); #else return; #endif diff --git a/lite/tests/kernels/write_to_array_compute_test.cc b/lite/tests/kernels/write_to_array_compute_test.cc index 233403171a..b8110a2e2c 100644 --- a/lite/tests/kernels/write_to_array_compute_test.cc +++ b/lite/tests/kernels/write_to_array_compute_test.cc @@ -85,7 +85,7 @@ TEST(WriteToArray, precision) { Place place; float abs_error = 1e-5; #ifdef LITE_WITH_ARM - place = TARGET(kARM); + place = TARGET(kHost); #else return; #endif -- GitLab