diff --git a/lite/kernels/arm/fill_constant_compute.cc b/lite/kernels/arm/fill_constant_compute.cc index 0b1911abf4fe553b670cf21dbb519c24dc08f184..05d43dddec47a303a89a2d48b3fb91ff45e6e2c0 100644 --- a/lite/kernels/arm/fill_constant_compute.cc +++ b/lite/kernels/arm/fill_constant_compute.cc @@ -25,6 +25,38 @@ class FillConstantCompute : public KernelLite { public: using param_t = operators::FillConstantParam; + inline DDimLite GetShape(const param_t& param) { + // 1. shape is a Tensor + if (param.shape_tensor != nullptr) { + auto* shape_tensor = param.shape_tensor; + auto* shape_data = shape_tensor->data(); + auto vec_shape = + std::vector(shape_data, shape_data + shape_tensor->numel()); + return DDimLite(vec_shape); + } + + // 2. shape is a list/tuple containing Tensor + auto shape_tensor_list = param.shape_tensor_list; + if (shape_tensor_list.size() > 0) { + std::vector vec_shape; + for (size_t i = 0; i < shape_tensor_list.size(); ++i) { + auto tensor = shape_tensor_list[i]; + vec_shape.push_back(*tensor->data()); + } + return DDimLite(vec_shape); + } + + // 3. shape is a list/tuple without containing Tensor + auto vec_shape = param.shape; + return DDimLite(vec_shape); + } + + void PrepareForRun() override { + auto& param = *param_.get_mutable(); + auto outdims = GetShape(param); + param.Out->Resize(outdims); + } + void Run() override { auto& param = *param_.get_mutable(); auto& context = ctx_->As(); @@ -107,6 +139,11 @@ REGISTER_LITE_KERNEL(fill_constant, kNCHW, paddle::lite::kernels::arm::FillConstantCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("ShapeTensor", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("ShapeTensorList", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); REGISTER_LITE_KERNEL( diff --git a/lite/kernels/x86/fill_constant_compute.cc b/lite/kernels/x86/fill_constant_compute.cc index 1eb76332ccc21b0c5196d71b9246ed8b144a6593..dace1e90258a93aa5c8e89d1d9369adf39416659 100644 --- a/lite/kernels/x86/fill_constant_compute.cc +++ b/lite/kernels/x86/fill_constant_compute.cc @@ -29,6 +29,38 @@ class FillConstantCompute : public KernelLite { public: using param_t = operators::FillConstantParam; + inline DDimLite GetShape(const param_t& param) { + // 1. shape is a Tensor + if (param.shape_tensor != nullptr) { + auto* shape_tensor = param.shape_tensor; + auto* shape_data = shape_tensor->data(); + auto vec_shape = + std::vector(shape_data, shape_data + shape_tensor->numel()); + return DDimLite(vec_shape); + } + + // 2. shape is a list/tuple containing Tensor + auto shape_tensor_list = param.shape_tensor_list; + if (shape_tensor_list.size() > 0) { + std::vector vec_shape; + for (size_t i = 0; i < shape_tensor_list.size(); ++i) { + auto tensor = shape_tensor_list[i]; + vec_shape.push_back(*tensor->data()); + } + return DDimLite(vec_shape); + } + + // 3. shape is a list/tuple without containing Tensor + auto vec_shape = param.shape; + return DDimLite(vec_shape); + } + + void PrepareForRun() override { + auto& param = *param_.get_mutable(); + auto outdims = GetShape(param); + param.Out->Resize(outdims); + } + void Run() override { auto& param = *param_.get_mutable(); auto& context = ctx_->As(); @@ -55,5 +87,9 @@ REGISTER_LITE_KERNEL(fill_constant, kNCHW, paddle::lite::kernels::x86::FillConstantCompute, def) + .BindInput("ShapeTensor", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("ShapeTensorList", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); diff --git a/lite/operators/fill_constant_op.cc b/lite/operators/fill_constant_op.cc index 6e4bee4da87095245d90c6af5db98d2e95d7d3d8..acf9701cbd750e83ba51f25c66064c2dd7781db6 100644 --- a/lite/operators/fill_constant_op.cc +++ b/lite/operators/fill_constant_op.cc @@ -29,6 +29,12 @@ class FillConstantOp : public OpLite { } bool InferShape() const override { + lite::Tensor* shape_tensor_ = param_.shape_tensor; + if (param_.shape.empty() && shape_tensor_ != nullptr) { + param_.Out->Resize(shape_tensor_->dims()); + return true; + } + param_.Out->Resize(param_.shape); return true; } @@ -41,6 +47,23 @@ class FillConstantOp : public OpLite { param_.shape = opdesc.GetAttr>("shape"); param_.value = opdesc.GetAttr("value"); param_.force_cpu = opdesc.GetAttr("force_cpu"); + param_.shape_tensor = nullptr; + param_.shape_tensor_list = {}; + + std::vector input_arg_names = opdesc.InputArgumentNames(); + if (std::find(input_arg_names.begin(), + input_arg_names.end(), + "ShapeTensor") != input_arg_names.end()) { + auto args = opdesc.Input("ShapeTensor"); + auto* var = scope->FindVar(args.front()); + param_.shape_tensor = var->GetMutable(); + } + if (opdesc.HasAttr("ShapeTensorList")) { + auto args = opdesc.Input("ShapeTensorList"); + auto* var = scope->FindVar(args.front()); + param_.shape_tensor_list = + *(var->GetMutable>()); + } return true; } diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index d47543961529b0147768ca11f2df70f8b3b66526..4f0c707484f6a66148dabc80968665c1d38de445 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -408,6 +408,9 @@ struct MeanGradParam { struct FillConstantParam { int dtype{static_cast(VarDescAPI::VarDataType::FP32)}; std::vector shape{}; + lite::Tensor* shape_tensor; + std::vector shape_tensor_list{}; + float value{0.0f}; // useless for x86, keep it for compatibility bool force_cpu{false}; diff --git a/lite/tests/kernels/fill_constant_compute_test.cc b/lite/tests/kernels/fill_constant_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e211582b04d279b535f0d3873a9b0c537e375a60 --- /dev/null +++ b/lite/tests/kernels/fill_constant_compute_test.cc @@ -0,0 +1,178 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" + +namespace paddle { +namespace lite { + +class FillConstantComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string out_ = "out"; + int dtype_{static_cast(VarDescAPI::VarDataType::FP32)}; + std::vector shape_{}; + std::string shape_tensor_ = "ShapeTensor"; + std::vector shape_tensor_list_; + bool is_use_shape_tensor_{false}; + bool is_use_shape_tensor_list_{false}; + + float value_{0.0f}; + // useless for x86, keep it for compatibility + bool force_cpu_{false}; + // DDim shape_tensor_data{{5, 3}}; + std::vector shape_tensor_data; + DDim shape_test{{1, 2}}; + + public: + FillConstantComputeTester(const Place& place, + const std::string& alias, + std::vector shape, + const bool is_use_shape_tensor, + const bool is_use_shape_tensor_list, + float value, + bool force_cpu) + : TestCase(place, alias) { + shape_ = shape; + value_ = value; + force_cpu_ = force_cpu; + is_use_shape_tensor_ = is_use_shape_tensor; + is_use_shape_tensor_list_ = is_use_shape_tensor_list; + + for (int i = 0; i < shape_test.size(); i++) { + shape_tensor_data.push_back(i + 1); + } + } + + void RunBaseline(Scope* scope) override { + auto* out = scope->NewTensor(out_); + DDim output_dims{shape_}; + if (is_use_shape_tensor_) { + auto* temp_shape = scope->FindTensor(shape_tensor_); + auto* shape_data = temp_shape->data(); + auto vec_shape = + std::vector(shape_data, shape_data + temp_shape->numel()); + output_dims.ConstructFrom(vec_shape); + } + if (is_use_shape_tensor_list_) { + std::vector vec_shape; + for (int i = 0; i < shape_tensor_list_.size(); i++) { + auto* temp_shape = scope->FindTensor(shape_tensor_list_[i]); + vec_shape.push_back(*temp_shape->data()); + } + + output_dims.ConstructFrom(vec_shape); + } + out->Resize(output_dims); + + auto* output_data = out->mutable_data(); + for (int i = 0; i < out->numel(); i++) { + output_data[i] = value_; + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + LOG(INFO) << "PrepareOpDesc"; + + op_desc->SetType("fill_constant"); + op_desc->SetAttr("dtype", dtype_); + op_desc->SetAttr("shape", shape_); + op_desc->SetAttr("value", value_); + op_desc->SetAttr("force_cpu", force_cpu_); + if (is_use_shape_tensor_) { + op_desc->SetInput("ShapeTensor", {shape_tensor_}); + } + if (is_use_shape_tensor_list_) { + // std::vector shape_tensor_list_; + for (int i = 0; i < shape_test.size(); ++i) { + shape_tensor_list_.push_back("shape_tensor_list_" + std::to_string(i)); + } + op_desc->SetInput("ShapeTensorList", {shape_tensor_list_}); + } + op_desc->SetOutput("Out", {out_}); + } + + void PrepareData() override { + if (is_use_shape_tensor_) { + // std::vector temp = x_dims_.data(); + // int64_t* data = temp.data(); + SetCommonTensor(shape_tensor_, shape_test, shape_tensor_data.data()); + } + if (is_use_shape_tensor_list_) { + Scope& scope_ = this->scope(); + for (int i = 0; i < shape_test.size(); ++i) { + auto* tensor = + scope_.NewTensor("shape_tensor_list_" + std::to_string(i)); + tensor->Resize(DDim({1})); + auto* d = tensor->mutable_data(); + d[0] = shape_tensor_data[i]; + } + } + } +}; + +TEST(fill_constant, precision) { + LOG(INFO) << "test fill_constant op, kARM"; +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + std::vector shape{1, 2}; + + for (int dtype : {static_cast(VarDescAPI::VarDataType::INT32)}) { + for (float value : {1, 2}) { + for (bool is_use_shape_tensor_list : {false, true}) { + for (bool is_use_shape_tensor : {false, true}) { + if (is_use_shape_tensor && is_use_shape_tensor_list) break; + LOG(INFO) << "value:" << value + << ", is_use_shape_tensor:" << is_use_shape_tensor + << ", is_use_shape_tensor_list:" + << is_use_shape_tensor_list; + + std::unique_ptr tester( + new FillConstantComputeTester(place, + "def", + shape, + is_use_shape_tensor, + is_use_shape_tensor_list, + value, + false)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + } + } +#endif + +#ifdef LITE_WITH_X86 + Place place(TARGET(kX86)); + LOG(INFO) << "test concate op, x86"; + for (int axis : {1, 2}) { + for (bool is_use_axis_tensor : {false, true}) { + LOG(INFO) << "axis:" << axis + << ", is_use_axis_tensor:" << is_use_axis_tensor; + std::unique_ptr tester( + new ConcateComputeTester(place, "def", axis, is_use_axis_tensor)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); + } + } + +#endif +} + +} // namespace lite +} // namespace paddle