未验证 提交 89df8f01 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

fill_constant op support param shape can be tensor or tensorlist, test=develop (#2459)

* fill_constant can support shape is tensor or tensorlist
上级 826f6605
...@@ -25,6 +25,38 @@ class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { ...@@ -25,6 +25,38 @@ class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public: public:
using param_t = operators::FillConstantParam; using param_t = operators::FillConstantParam;
inline DDimLite GetShape(const param_t& param) {
// 1. shape is a Tensor
if (param.shape_tensor != nullptr) {
auto* shape_tensor = param.shape_tensor;
auto* shape_data = shape_tensor->data<int>();
auto vec_shape =
std::vector<int64_t>(shape_data, shape_data + shape_tensor->numel());
return DDimLite(vec_shape);
}
// 2. shape is a list/tuple containing Tensor
auto shape_tensor_list = param.shape_tensor_list;
if (shape_tensor_list.size() > 0) {
std::vector<int64_t> vec_shape;
for (size_t i = 0; i < shape_tensor_list.size(); ++i) {
auto tensor = shape_tensor_list[i];
vec_shape.push_back(*tensor->data<int>());
}
return DDimLite(vec_shape);
}
// 3. shape is a list/tuple without containing Tensor
auto vec_shape = param.shape;
return DDimLite(vec_shape);
}
void PrepareForRun() override {
auto& param = *param_.get_mutable<param_t>();
auto outdims = GetShape(param);
param.Out->Resize(outdims);
}
void Run() override { void Run() override {
auto& param = *param_.get_mutable<param_t>(); auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<ARMContext>(); auto& context = ctx_->As<ARMContext>();
...@@ -107,6 +139,11 @@ REGISTER_LITE_KERNEL(fill_constant, ...@@ -107,6 +139,11 @@ REGISTER_LITE_KERNEL(fill_constant,
kNCHW, kNCHW,
paddle::lite::kernels::arm::FillConstantCompute<float>, paddle::lite::kernels::arm::FillConstantCompute<float>,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("ShapeTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
......
...@@ -29,6 +29,38 @@ class FillConstantCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -29,6 +29,38 @@ class FillConstantCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public: public:
using param_t = operators::FillConstantParam; using param_t = operators::FillConstantParam;
inline DDimLite GetShape(const param_t& param) {
// 1. shape is a Tensor
if (param.shape_tensor != nullptr) {
auto* shape_tensor = param.shape_tensor;
auto* shape_data = shape_tensor->data<int>();
auto vec_shape =
std::vector<int64_t>(shape_data, shape_data + shape_tensor->numel());
return DDimLite(vec_shape);
}
// 2. shape is a list/tuple containing Tensor
auto shape_tensor_list = param.shape_tensor_list;
if (shape_tensor_list.size() > 0) {
std::vector<int64_t> vec_shape;
for (size_t i = 0; i < shape_tensor_list.size(); ++i) {
auto tensor = shape_tensor_list[i];
vec_shape.push_back(*tensor->data<int>());
}
return DDimLite(vec_shape);
}
// 3. shape is a list/tuple without containing Tensor
auto vec_shape = param.shape;
return DDimLite(vec_shape);
}
void PrepareForRun() override {
auto& param = *param_.get_mutable<param_t>();
auto outdims = GetShape(param);
param.Out->Resize(outdims);
}
void Run() override { void Run() override {
auto& param = *param_.get_mutable<param_t>(); auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>(); auto& context = ctx_->As<X86Context>();
...@@ -55,5 +87,9 @@ REGISTER_LITE_KERNEL(fill_constant, ...@@ -55,5 +87,9 @@ REGISTER_LITE_KERNEL(fill_constant,
kNCHW, kNCHW,
paddle::lite::kernels::x86::FillConstantCompute<float>, paddle::lite::kernels::x86::FillConstantCompute<float>,
def) def)
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("ShapeTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -29,6 +29,12 @@ class FillConstantOp : public OpLite { ...@@ -29,6 +29,12 @@ class FillConstantOp : public OpLite {
} }
bool InferShape() const override { bool InferShape() const override {
lite::Tensor* shape_tensor_ = param_.shape_tensor;
if (param_.shape.empty() && shape_tensor_ != nullptr) {
param_.Out->Resize(shape_tensor_->dims());
return true;
}
param_.Out->Resize(param_.shape); param_.Out->Resize(param_.shape);
return true; return true;
} }
...@@ -41,6 +47,23 @@ class FillConstantOp : public OpLite { ...@@ -41,6 +47,23 @@ class FillConstantOp : public OpLite {
param_.shape = opdesc.GetAttr<std::vector<int64_t>>("shape"); param_.shape = opdesc.GetAttr<std::vector<int64_t>>("shape");
param_.value = opdesc.GetAttr<float>("value"); param_.value = opdesc.GetAttr<float>("value");
param_.force_cpu = opdesc.GetAttr<bool>("force_cpu"); param_.force_cpu = opdesc.GetAttr<bool>("force_cpu");
param_.shape_tensor = nullptr;
param_.shape_tensor_list = {};
std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
if (std::find(input_arg_names.begin(),
input_arg_names.end(),
"ShapeTensor") != input_arg_names.end()) {
auto args = opdesc.Input("ShapeTensor");
auto* var = scope->FindVar(args.front());
param_.shape_tensor = var->GetMutable<lite::Tensor>();
}
if (opdesc.HasAttr("ShapeTensorList")) {
auto args = opdesc.Input("ShapeTensorList");
auto* var = scope->FindVar(args.front());
param_.shape_tensor_list =
*(var->GetMutable<std::vector<lite::Tensor*>>());
}
return true; return true;
} }
......
...@@ -408,6 +408,9 @@ struct MeanGradParam { ...@@ -408,6 +408,9 @@ struct MeanGradParam {
struct FillConstantParam { struct FillConstantParam {
int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)}; int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
std::vector<int64_t> shape{}; std::vector<int64_t> shape{};
lite::Tensor* shape_tensor;
std::vector<lite::Tensor*> shape_tensor_list{};
float value{0.0f}; float value{0.0f};
// useless for x86, keep it for compatibility // useless for x86, keep it for compatibility
bool force_cpu{false}; bool force_cpu{false};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class FillConstantComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string out_ = "out";
int dtype_{static_cast<int>(VarDescAPI::VarDataType::FP32)};
std::vector<int64_t> shape_{};
std::string shape_tensor_ = "ShapeTensor";
std::vector<std::string> shape_tensor_list_;
bool is_use_shape_tensor_{false};
bool is_use_shape_tensor_list_{false};
float value_{0.0f};
// useless for x86, keep it for compatibility
bool force_cpu_{false};
// DDim shape_tensor_data{{5, 3}};
std::vector<int32_t> shape_tensor_data;
DDim shape_test{{1, 2}};
public:
FillConstantComputeTester(const Place& place,
const std::string& alias,
std::vector<int64_t> shape,
const bool is_use_shape_tensor,
const bool is_use_shape_tensor_list,
float value,
bool force_cpu)
: TestCase(place, alias) {
shape_ = shape;
value_ = value;
force_cpu_ = force_cpu;
is_use_shape_tensor_ = is_use_shape_tensor;
is_use_shape_tensor_list_ = is_use_shape_tensor_list;
for (int i = 0; i < shape_test.size(); i++) {
shape_tensor_data.push_back(i + 1);
}
}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(out_);
DDim output_dims{shape_};
if (is_use_shape_tensor_) {
auto* temp_shape = scope->FindTensor(shape_tensor_);
auto* shape_data = temp_shape->data<int>();
auto vec_shape =
std::vector<int64_t>(shape_data, shape_data + temp_shape->numel());
output_dims.ConstructFrom(vec_shape);
}
if (is_use_shape_tensor_list_) {
std::vector<int64_t> vec_shape;
for (int i = 0; i < shape_tensor_list_.size(); i++) {
auto* temp_shape = scope->FindTensor(shape_tensor_list_[i]);
vec_shape.push_back(*temp_shape->data<int>());
}
output_dims.ConstructFrom(vec_shape);
}
out->Resize(output_dims);
auto* output_data = out->mutable_data<float>();
for (int i = 0; i < out->numel(); i++) {
output_data[i] = value_;
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
LOG(INFO) << "PrepareOpDesc";
op_desc->SetType("fill_constant");
op_desc->SetAttr("dtype", dtype_);
op_desc->SetAttr("shape", shape_);
op_desc->SetAttr("value", value_);
op_desc->SetAttr("force_cpu", force_cpu_);
if (is_use_shape_tensor_) {
op_desc->SetInput("ShapeTensor", {shape_tensor_});
}
if (is_use_shape_tensor_list_) {
// std::vector<std::string> shape_tensor_list_;
for (int i = 0; i < shape_test.size(); ++i) {
shape_tensor_list_.push_back("shape_tensor_list_" + std::to_string(i));
}
op_desc->SetInput("ShapeTensorList", {shape_tensor_list_});
}
op_desc->SetOutput("Out", {out_});
}
void PrepareData() override {
if (is_use_shape_tensor_) {
// std::vector<int64_t> temp = x_dims_.data();
// int64_t* data = temp.data();
SetCommonTensor(shape_tensor_, shape_test, shape_tensor_data.data());
}
if (is_use_shape_tensor_list_) {
Scope& scope_ = this->scope();
for (int i = 0; i < shape_test.size(); ++i) {
auto* tensor =
scope_.NewTensor("shape_tensor_list_" + std::to_string(i));
tensor->Resize(DDim({1}));
auto* d = tensor->mutable_data<int>();
d[0] = shape_tensor_data[i];
}
}
}
};
TEST(fill_constant, precision) {
LOG(INFO) << "test fill_constant op, kARM";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
std::vector<int64_t> shape{1, 2};
for (int dtype : {static_cast<int>(VarDescAPI::VarDataType::INT32)}) {
for (float value : {1, 2}) {
for (bool is_use_shape_tensor_list : {false, true}) {
for (bool is_use_shape_tensor : {false, true}) {
if (is_use_shape_tensor && is_use_shape_tensor_list) break;
LOG(INFO) << "value:" << value
<< ", is_use_shape_tensor:" << is_use_shape_tensor
<< ", is_use_shape_tensor_list:"
<< is_use_shape_tensor_list;
std::unique_ptr<arena::TestCase> tester(
new FillConstantComputeTester(place,
"def",
shape,
is_use_shape_tensor,
is_use_shape_tensor_list,
value,
false));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
#endif
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
LOG(INFO) << "test concate op, x86";
for (int axis : {1, 2}) {
for (bool is_use_axis_tensor : {false, true}) {
LOG(INFO) << "axis:" << axis
<< ", is_use_axis_tensor:" << is_use_axis_tensor;
std::unique_ptr<arena::TestCase> tester(
new ConcateComputeTester(place, "def", axis, is_use_axis_tensor));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册