diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index d8a772123f722135325bf9637199496fb7e91b36..9670149114d0f7cc953129b83215c0e8b7caa56a 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -56,7 +56,6 @@ add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_k add_kernel(crop_compute_arm ARM extra SRCS crop_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(power_compute_arm ARM extra SRCS power_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(norm_compute_arm ARM extra SRCS norm_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(assign_compute_arm ARM extra SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm) ## 3. extra kernels add_kernel(lrn_compute_arm ARM extra SRCS lrn_compute.cc DEPS ${lite_kernel_deps} math_arm) diff --git a/lite/kernels/arm/assign_compute.cc b/lite/kernels/arm/assign_compute.cc deleted file mode 100644 index 8398634bb365c628b64e1ddd2b14984d5f2acb59..0000000000000000000000000000000000000000 --- a/lite/kernels/arm/assign_compute.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/kernels/arm/assign_compute.h" -#include -#include "lite/backends/arm/math/funcs.h" -#include "lite/core/op_registry.h" -#include "lite/core/type_system.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace arm { - -void AssignCompute::Run() { - auto& param = Param(); - param.Out->CopyDataFrom(*param.X); -} - -} // namespace arm -} // namespace kernels -} // namespace lite -} // namespace paddle - -REGISTER_LITE_KERNEL( - assign, kARM, kAny, kNCHW, paddle::lite::kernels::arm::AssignCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) - .Finalize(); diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index b4935496cd63540095fa95500a2099308ed95758..a0085e6d6c5e65667e96393c42a1608c8dd24d0c 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -12,3 +12,4 @@ add_kernel(logical_compute_host Host extra SRCS logical_compute.cc DEPS ${lite_k add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${lite_kernel_deps}) add_kernel(write_to_array_compute_host Host extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps}) add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps}) diff --git a/lite/kernels/host/assign_compute.cc b/lite/kernels/host/assign_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..e496ffbd1d9a6362d730117be949cbdab83ec62a --- /dev/null +++ b/lite/kernels/host/assign_compute.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/host/assign_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +void AssignCompute::Run() { + auto& param = Param(); + if (param.X != nullptr) { + param.Out->CopyDataFrom(*param.X); + } else if (param.X_array != nullptr) { + auto x_array = param.X_array; + auto out_array = param.Out_array; + out_array->resize(x_array->size()); + for (size_t i = 0; i < x_array->size(); i++) { + out_array->at(i).CopyDataFrom(x_array->at(i)); + } + } else { + LOG(FATAL) << "x or x_array of assign must be set."; + } +} + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL( + assign, kHost, kAny, kAny, paddle::lite::kernels::host::AssignCompute, def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kAny), + DATALAYOUT(kAny))}) + .Finalize(); diff --git a/lite/kernels/arm/assign_compute.h b/lite/kernels/host/assign_compute.h similarity index 84% rename from lite/kernels/arm/assign_compute.h rename to lite/kernels/host/assign_compute.h index e144486b5970b4e4e82c58148e33ccc5b2d37ff4..01b8e5a4bc2b36699b0687a908c92160bca54c14 100644 --- a/lite/kernels/arm/assign_compute.h +++ b/lite/kernels/host/assign_compute.h @@ -15,14 +15,15 @@ #pragma once #include #include "lite/core/kernel.h" -#include "lite/operators/assign_op.h" +#include "lite/core/op_registry.h" namespace paddle { namespace lite { namespace kernels { -namespace arm { +namespace host { -class AssignCompute : public KernelLite { +class AssignCompute + : public KernelLite { public: using param_t = operators::AssignParam; @@ -31,7 +32,7 @@ class AssignCompute : public KernelLite { virtual ~AssignCompute() = default; }; -} // namespace arm +} // namespace host } // namespace kernels } // namespace lite } // namespace paddle diff --git a/lite/operators/assign_op.cc b/lite/operators/assign_op.cc index 25e8539d2e55a07a19d707713489d86f84aa64db..fe1e8db1f954af38041621d1d676cf16833357da 100644 --- a/lite/operators/assign_op.cc +++ b/lite/operators/assign_op.cc @@ -27,20 +27,33 @@ bool AssignOpLite::CheckShape() const { } bool AssignOpLite::InferShapeImpl() const { - lite::DDim input_dims; - input_dims = param_.X->dims(); - param_.Out->Resize(lite::DDim(input_dims)); + if (param_.X != nullptr) { + param_.Out->Resize(param_.X->dims()); + } else if (param_.X_array != nullptr) { + param_.Out_array->resize(param_.Out_array->size()); + } else { + LOG(FATAL) << "x or x_array must be set."; + } return true; } // TODO(Superjomn) replace framework::OpDesc with a lite one. bool AssignOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { - auto input = op_desc.Input("X").front(); - auto out = op_desc.Output("Out").front(); + auto x_name = op_desc.Input("X").front(); + auto out_name = op_desc.Output("Out").front(); - param_.X = scope->FindVar(input)->GetMutable(); - CHECK(scope->FindVar(out)); - param_.Out = scope->FindVar(out)->GetMutable(); + auto x_var = scope->FindVar(x_name); + if (x_var->IsType()) { + param_.X = scope->FindTensor(x_name); + param_.Out = scope->FindMutableTensor(out_name); + } else if (x_var->IsType>()) { + param_.X_array = x_var->GetMutable>(); + param_.Out_array = + scope->FindVar(out_name)->GetMutable>(); + } else { + LOG(FATAL) << "X type for assign op is unsupported. Expected type is " + "tensor or tensor_array."; + } return true; } diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index c5e595672eb580a282ee101294a48a053d8d9c02..05bcdd54cdc42b4cc874db2157579cc1cc9a65cb 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1279,8 +1279,13 @@ struct GatherParam : ParamBase { /// ----------------------- assign operators ----------------------- struct AssignParam : ParamBase { - const lite::Tensor* X{}; - lite::Tensor* Out{}; + // for tensor + const lite::Tensor* X{nullptr}; + lite::Tensor* Out{nullptr}; + + // for tensor_array + const std::vector* X_array{nullptr}; + std::vector* Out_array{nullptr}; }; /// ----------------------- roi_align operators ----------------------- diff --git a/lite/tests/kernels/assign_compute_test.cc b/lite/tests/kernels/assign_compute_test.cc index d757b906083f1ae63ea94ea5e092f1eb3e77a732..07bc9cf6ed08d9b62d5d8025defd2d44cd24fc46 100644 --- a/lite/tests/kernels/assign_compute_test.cc +++ b/lite/tests/kernels/assign_compute_test.cc @@ -69,7 +69,7 @@ void TestAssign(const Place& place) { TEST(Assign, precision) { Place place; #ifdef LITE_WITH_ARM - place = TARGET(kARM); + place = TARGET(kHost); #else return; #endif