Unverified commit 2882edb1, authored by zhupengyang, committed by GitHub

split fill_constant & fill_constant_batch_size_like, enable and add uts (#3023)

Parent: 9d88feea
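A note before the diff: this change splits the previously combined fill_constant / fill_constant_batch_size_like code into separate op and kernel files, registers both as basic kernels, and adds unit tests. The following minimal standalone sketch (illustrative stand-ins, not the Lite API) shows the fill_constant semantics the split code implements: the output shape comes from a ShapeTensor input if present, else from a ShapeTensorList, else from the shape attribute, and every output element is set to value.

#include <cstdint>
#include <vector>

// Illustrative stand-in for a Lite tensor holding int32 shape entries.
struct ShapeTensor {
  std::vector<int> data;
};

// Shape resolution with the same priority as the new
// FillConstantOp::InferShape: ShapeTensor > ShapeTensorList > "shape" attr.
std::vector<int64_t> ResolveShape(
    const ShapeTensor* shape_tensor,
    const std::vector<ShapeTensor*>& shape_tensor_list,
    const std::vector<int64_t>& shape_attr) {
  if (shape_tensor != nullptr) {
    return {shape_tensor->data.begin(), shape_tensor->data.end()};
  }
  if (!shape_tensor_list.empty()) {
    std::vector<int64_t> shape;
    for (auto* t : shape_tensor_list) {
      shape.push_back(t->data[0]);  // each list entry contributes one dim
    }
    return shape;
  }
  return shape_attr;
}

// Fill: every element of the flattened output buffer gets `value`.
std::vector<float> FillConstant(const std::vector<int64_t>& shape,
                                float value) {
  int64_t numel = 1;
  for (auto d : shape) numel *= d;
  return std::vector<float>(static_cast<size_t>(numel), value);
}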
@@ -76,7 +76,7 @@ void TestCase::PrepareInputsForInstruction() {
     const auto* inst_type = Type::GetTensorTy(TARGET(kHost));
     CHECK(scope_->FindVar(var));
-    const auto* shared_tensor = scope_->FindTensor((var));
+    const auto* shared_tensor = scope_->FindTensor(var);
     if (!TargetCompatibleTo(*inst_type, *param_type->type)) {
       /// Create a tensor in the instruction's scope, alloc memory and then
       /// copy data there.
......
@@ -83,7 +83,6 @@ add_kernel(conditional_block_compute_arm ARM extra SRCS conditional_block_comput
 add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(distribute_fpn_proposals_compute_arm ARM extra SRCS distribute_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
 # for OCR specific
 add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(gru_compute_arm ARM extra SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm)
@@ -100,6 +99,7 @@ add_kernel(write_to_array_compute_arm ARM extra SRCS write_to_array_compute.cc D
 add_kernel(read_from_array_compute_arm ARM extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(beam_search_compute_arm ARM extra SRCS beam_search_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(fill_constant_compute_arm ARM basic SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(fill_constant_batch_size_like_compute_arm ARM basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(lod_reset_compute_arm ARM extra SRCS lod_reset_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(is_empty_compute_arm ARM extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(lstm_arm ARM extra SRCS lstm_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/fill_constant_batch_size_like_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void FillConstantBatchSizeLikeCompute::Run() {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<ARMContext>();
if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
auto data = param.out->template mutable_data<float>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT32)) {
auto data = param.out->template mutable_data<int32_t>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype == static_cast<int32_t>(lite::core::FluidType::INT8)) {
auto data = param.out->template mutable_data<int8_t>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else {
LOG(FATAL) << "not supported dtype " << param.dtype;
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
fill_constant_batch_size_like,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::FillConstantBatchSizeLikeCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class FillConstantBatchSizeLikeCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::FillConstantBatchSizeLikeParam;
void Run() override;
~FillConstantBatchSizeLikeCompute() {}
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -12,118 +12,37 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "lite/core/kernel.h"
-#include "lite/core/op_registry.h"
+#include "lite/kernels/arm/fill_constant_compute.h"
 
 namespace paddle {
 namespace lite {
 namespace kernels {
 namespace arm {
 
-class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
- public:
-  using param_t = operators::FillConstantParam;
-
-  inline DDimLite GetShape(const param_t& param) {
-    // 1. shape is a Tensor
-    if (param.shape_tensor != nullptr) {
-      auto* shape_tensor = param.shape_tensor;
-      auto* shape_data = shape_tensor->data<int>();
-      auto vec_shape =
-          std::vector<int64_t>(shape_data, shape_data + shape_tensor->numel());
-      return DDimLite(vec_shape);
-    }
-
-    // 2. shape is a list/tuple containing Tensor
-    auto shape_tensor_list = param.shape_tensor_list;
-    if (shape_tensor_list.size() > 0) {
-      std::vector<int64_t> vec_shape;
-      for (size_t i = 0; i < shape_tensor_list.size(); ++i) {
-        auto tensor = shape_tensor_list[i];
-        vec_shape.push_back(*tensor->data<int>());
-      }
-      return DDimLite(vec_shape);
-    }
-
-    // 3. shape is a list/tuple without containing Tensor
-    auto vec_shape = param.shape;
-    return DDimLite(vec_shape);
-  }
-
-  void PrepareForRun() override {
-    auto& param = *param_.get_mutable<param_t>();
-    auto outdims = GetShape(param);
-    param.Out->Resize(outdims);
-  }
-
-  void Run() override {
-    auto& param = *param_.get_mutable<param_t>();
-    auto& context = ctx_->As<ARMContext>();
-
-    if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
-      auto data = param.Out->template mutable_data<float>();
-      for (int i = 0; i < param.Out->numel(); i++) {
-        data[i] = param.value;
-      }
-    } else if (param.dtype ==
-               static_cast<int32_t>(lite::core::FluidType::INT32)) {
-      auto data = param.Out->template mutable_data<int32_t>();
-      for (int i = 0; i < param.Out->numel(); i++) {
-        data[i] = param.value;
-      }
-    } else if (param.dtype ==
-               static_cast<int32_t>(lite::core::FluidType::INT8)) {
-      auto data = param.Out->template mutable_data<int8_t>();
-      for (int i = 0; i < param.Out->numel(); i++) {
-        data[i] = param.value;
-      }
-    } else {
-      LOG(FATAL) << "not supported dtype " << param.dtype;
-    }
-  }
-
-  virtual ~FillConstantCompute() = default;
-};
-
-class FillConstantBatchLikeCompute
-    : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
- public:
-  using param_t = operators::FillConstantBatchLikeParam;
-
-  void Run() override {
-    auto& param = *param_.get_mutable<param_t>();
-    auto& context = ctx_->As<ARMContext>();
-
-    if (param.input->lod().size() && param.input_dim_idx == 0) {
-      auto odims = param.out->dims();
-      odims[param.output_dim_idx] = param.input->lod().back().size() - 1;
-      param.out->Resize(odims);
-    }
-
-    if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
-      auto data = param.out->template mutable_data<float>();
-      for (int i = 0; i < param.out->numel(); i++) {
-        data[i] = param.value;
-      }
-    } else if (param.dtype ==
-               static_cast<int32_t>(lite::core::FluidType::INT32)) {
-      auto data = param.out->template mutable_data<int32_t>();
-      for (int i = 0; i < param.out->numel(); i++) {
-        data[i] = param.value;
-      }
-    } else if (param.dtype ==
-               static_cast<int32_t>(lite::core::FluidType::INT8)) {
-      auto data = param.out->template mutable_data<int8_t>();
-      for (int i = 0; i < param.out->numel(); i++) {
-        data[i] = param.value;
-      }
-    } else {
-      LOG(FATAL) << "not supported dtype " << param.dtype;
-    }
-  }
-
-  virtual ~FillConstantBatchLikeCompute() = default;
-};
+void FillConstantCompute::Run() {
+  auto& param = *param_.get_mutable<param_t>();
+  auto& context = ctx_->As<ARMContext>();
+
+  if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
+    auto data = param.out->template mutable_data<float>();
+    for (int i = 0; i < param.out->numel(); i++) {
+      data[i] = param.value;
+    }
+  } else if (param.dtype ==
+             static_cast<int32_t>(lite::core::FluidType::INT32)) {
+    auto data = param.out->template mutable_data<int32_t>();
+    for (int i = 0; i < param.out->numel(); i++) {
+      data[i] = param.value;
+    }
+  } else if (param.dtype == static_cast<int32_t>(lite::core::FluidType::INT8)) {
+    auto data = param.out->template mutable_data<int8_t>();
+    for (int i = 0; i < param.out->numel(); i++) {
+      data[i] = param.value;
+    }
+  } else {
+    LOG(FATAL) << "not supported dtype " << param.dtype;
+  }
+}
 
 }  // namespace arm
 }  // namespace kernels
@@ -133,23 +52,13 @@ class FillConstantBatchLikeCompute
 // float
 REGISTER_LITE_KERNEL(fill_constant,
                      kARM,
-                     kAny,
+                     kFloat,
                      kNCHW,
                      paddle::lite::kernels::arm::FillConstantCompute,
                      def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("ShapeTensor",
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindInput("ShapeTensorList",
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
    .Finalize();
-
-REGISTER_LITE_KERNEL(fill_constant_batch_size_like,
-                     kARM,
-                     kAny,
-                     kNCHW,
-                     paddle::lite::kernels::arm::FillConstantBatchLikeCompute,
-                     def)
-    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
-    .Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::FillConstantParam;
void Run() override;
~FillConstantCompute() {}
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -34,8 +34,8 @@ class FillConstantBatchSizeLikeCompute
   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
     auto& ctx = ctx_->As<X86Context>();
-    auto* out = param.Out;
-    auto* in = param.Input;
+    auto* out = param.out;
+    auto* in = param.input;
     if (in->lod().size() && param.input_dim_idx == 0) {
       // set the correct batch size for the LoDTensor.
       auto odims = out->dims();
......
@@ -56,8 +56,8 @@ TEST(fill_constant_batch_size_like_x86, run_test) {
   FillConstantBatchSizeLikeCompute<float> fill_constant_batch_size_like;
   operators::FillConstantBatchSizeLikeParam param;
-  param.Input = &input;
-  param.Out = &out;
+  param.input = &input;
+  param.out = &out;
   std::vector<int> shape{-1, 132, 7};
   float value = 3.5;
   param.shape = shape;
......
@@ -20,6 +20,7 @@ add_operator(box_coder_op_lite basic SRCS box_coder_op.cc DEPS ${op_DEPS})
 add_operator(multiclass_nms_op_lite basic SRCS multiclass_nms_op.cc DEPS ${op_DEPS})
 add_operator(mean_op basic SRCS mean_op.cc DEPS ${op_DEPS})
 add_operator(fill_constant_op basic SRCS fill_constant_op.cc DEPS ${op_DEPS})
+add_operator(fill_constant_batch_size_like_op basic SRCS fill_constant_batch_size_like_op.cc DEPS ${op_DEPS})
 add_operator(shuffle_channel_op basic SRCS shuffle_channel_op.cc DEPS ${op_DEPS})
 add_operator(yolo_box_op basic SRCS yolo_box_op.cc DEPS ${op_DEPS})
 add_operator(interpolate_op basic SRCS interpolate_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/fill_constant_batch_size_like_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool FillConstantBatchSizeLikeOp::CheckShape() const {
CHECK(param_.out);
CHECK(param_.input);
CHECK_GT(param_.shape.size(), 0);
CHECK_GE(param_.input_dim_idx, 0);
CHECK_GE(param_.output_dim_idx, 0);
return true;
}
bool FillConstantBatchSizeLikeOp::InferShape() const {
std::vector<int64_t> output_dim{param_.shape.begin(), param_.shape.end()};
if (param_.input_dim_idx == 0 && !param_.input->lod().empty()) {
output_dim[param_.output_dim_idx] = param_.input->lod().back().size() - 1;
} else {
output_dim[param_.output_dim_idx] =
param_.input->dims()[param_.input_dim_idx];
}
param_.out->Resize(output_dim);
return true;
}
bool FillConstantBatchSizeLikeOp::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
auto out_name = opdesc.Output("Out").front();
auto input_name = opdesc.Input("Input").front();
param_.out = GetMutableVar<lite::Tensor>(scope, out_name);
param_.input = GetMutableVar<lite::Tensor>(scope, input_name);
param_.dtype = opdesc.GetAttr<int>("dtype");
param_.shape = opdesc.GetAttr<std::vector<int>>("shape");
if (opdesc.HasAttr("value")) {
param_.value = opdesc.GetAttr<float>("value");
}
if (opdesc.HasAttr("input_dim_idx")) {
param_.input_dim_idx = opdesc.GetAttr<int>("input_dim_idx");
}
if (opdesc.HasAttr("output_dim_idx")) {
param_.output_dim_idx = opdesc.GetAttr<int>("output_dim_idx");
}
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(fill_constant_batch_size_like,
paddle::lite::operators::FillConstantBatchSizeLikeOp);
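The InferShape rule above is compact, so here is a minimal standalone sketch of the same semantics (illustrative types and names, not the Lite API): start from the shape attribute, then overwrite the entry at output_dim_idx with the input's batch size, which is the sequence count of the last LoD level when input_dim_idx is 0 and the input carries LoD, and dims[input_dim_idx] otherwise.

#include <cstdint>
#include <vector>

// Sketch of FillConstantBatchSizeLikeOp::InferShape (illustrative types).
std::vector<int64_t> InferBatchSizeLikeShape(
    const std::vector<int>& shape_attr,
    const std::vector<int64_t>& input_dims,
    const std::vector<std::vector<uint64_t>>& input_lod,
    int input_dim_idx,
    int output_dim_idx) {
  std::vector<int64_t> out_shape(shape_attr.begin(), shape_attr.end());
  if (input_dim_idx == 0 && !input_lod.empty()) {
    // A LoD level with N+1 offsets delimits N sequences; the batch size of
    // a LoDTensor is the sequence count of its last (finest) level.
    out_shape[output_dim_idx] =
        static_cast<int64_t>(input_lod.back().size()) - 1;
  } else {
    out_shape[output_dim_idx] = input_dims[input_dim_idx];
  }
  return out_shape;
}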
@@ -45,6 +45,6 @@ class FillConstantBatchSizeLikeOp : public OpLite {
   mutable FillConstantBatchSizeLikeParam param_;
 };
 
-} /* namespace operators */
-} /* namespace lite */
-} /* namespace paddle */
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
@@ -12,129 +12,69 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "lite/core/op_lite.h"
+#include "lite/operators/fill_constant_op.h"
 #include "lite/core/op_registry.h"
 
 namespace paddle {
 namespace lite {
 namespace operators {
 
-class FillConstantOp : public OpLite {
- public:
-  explicit FillConstantOp(const std::string& type) : OpLite(type) {}
-
-  bool CheckShape() const override {
-    CHECK_OR_FALSE(param_.Out);
-    return true;
-  }
-
-  bool InferShape() const override {
-    lite::Tensor* shape_tensor_ = param_.shape_tensor;
-    if (param_.shape.empty() && shape_tensor_ != nullptr) {
-      param_.Out->Resize(shape_tensor_->dims());
-      return true;
-    }
-    param_.Out->Resize(param_.shape);
-    return true;
-  }
-
-  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
-    auto Out_name = opdesc.Output("Out").front();
-    param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
-    param_.dtype = opdesc.GetAttr<int>("dtype");
-    param_.shape = opdesc.GetAttr<std::vector<int64_t>>("shape");
-    param_.value = opdesc.GetAttr<float>("value");
-    param_.force_cpu = opdesc.GetAttr<bool>("force_cpu");
-    param_.shape_tensor = nullptr;
-    param_.shape_tensor_list = {};
-
-    std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
-    if (opdesc.HasInput("ShapeTensor") &&
-        !opdesc.Input("ShapeTensor").empty()) {
-      auto args = opdesc.Input("ShapeTensor");
-      auto* var = scope->FindVar(args.front());
-      param_.shape_tensor = var->GetMutable<lite::Tensor>();
-    }
-    if (opdesc.HasAttr("ShapeTensorList")) {
-      auto args = opdesc.Input("ShapeTensorList");
-      auto* var = scope->FindVar(args.front());
-      param_.shape_tensor_list =
-          *(var->GetMutable<std::vector<lite::Tensor*>>());
-    }
-    return true;
-  }
-
-  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
-
-  std::string DebugString() const override { return "fill_constant"; }
-
- private:
-  mutable operators::FillConstantParam param_;
-};
-
-class FillConstantBatchLikeOp : public OpLite {
- public:
-  explicit FillConstantBatchLikeOp(const std::string& type) : OpLite(type) {}
-
-  bool CheckShape() const override {
-    CHECK_OR_FALSE(param_.out);
-    CHECK_OR_FALSE(param_.input);
-    CHECK_GT_OR_FALSE(param_.shape.size(), 0);
-    CHECK_GE_OR_FALSE(param_.input_dim_idx, 0);
-    CHECK_GE_OR_FALSE(param_.output_dim_idx, 0);
-    return true;
-  }
-
-  bool InferShape() const override {
-    auto output_dim = param_.shape;
-    output_dim[param_.output_dim_idx] =
-        param_.input->dims()[param_.input_dim_idx];
-    param_.out->Resize(output_dim);
-    return true;
-  }
-
-  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
-    auto Out_name = opdesc.Output("Out").front();
-    auto In_name = opdesc.Input("Input").front();
-
-    param_.out = GetMutableVar<lite::Tensor>(scope, Out_name);
-    param_.input = GetMutableVar<lite::Tensor>(scope, In_name);
-    param_.dtype = opdesc.GetAttr<int>("dtype");
-    auto shape = opdesc.GetAttr<std::vector<int>>("shape");
-    std::vector<int64_t> outshape;
-    for (auto i : shape) {
-      outshape.push_back(i);
-    }
-    param_.shape = outshape;
-    if (opdesc.HasAttr("value")) {
-      param_.value = opdesc.GetAttr<float>("value");
-    }
-    if (opdesc.HasAttr("input_dim_idx")) {
-      param_.input_dim_idx = opdesc.GetAttr<int>("input_dim_idx");
-    }
-    if (opdesc.HasAttr("output_dim_idx")) {
-      param_.output_dim_idx = opdesc.GetAttr<int>("output_dim_idx");
-    }
-    return true;
-  }
-
-  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
-
-  std::string DebugString() const override {
-    return "fill_constant_batch_size_like";
-  }
-
- private:
-  mutable operators::FillConstantBatchLikeParam param_;
-};
+bool FillConstantOp::CheckShape() const {
+  CHECK(param_.out);
+  return true;
+}
+
+bool FillConstantOp::InferShape() const {
+  std::vector<int64_t> out_shape;
+  auto shape_tensor = param_.shape_tensor;
+  auto shape_tensor_list = param_.shape_tensor_list;
+  if (shape_tensor != nullptr) {
+    auto shape_tensor_data = shape_tensor->data<int>();
+    for (int i = 0; i < shape_tensor->numel(); i++) {
+      out_shape.push_back(shape_tensor_data[i]);
+    }
+  } else if (!shape_tensor_list.empty()) {
+    for (int i = 0; i < shape_tensor_list.size(); i++) {
+      out_shape.push_back(shape_tensor_list[i]->data<int>()[0]);
+    }
+  } else if (!param_.shape.empty()) {
+    out_shape = param_.shape;
+  } else {
+    LOG(FATAL) << "no valid out_shape. Must set one of shape_tensor, or "
+                  "shape_tensor_list, or shape.";
+  }
+  param_.out->Resize(out_shape);
+  return true;
+}
+
+bool FillConstantOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
+  auto out_name = opdesc.Output("Out").front();
+  param_.out = GetMutableVar<lite::Tensor>(scope, out_name);
+  param_.dtype = opdesc.GetAttr<int>("dtype");
+  if (opdesc.HasAttr("shape")) {
+    param_.shape = opdesc.GetAttr<std::vector<int64_t>>("shape");
+  }
+  param_.value = opdesc.GetAttr<float>("value");
+  param_.force_cpu = opdesc.GetAttr<bool>("force_cpu");
+  if (opdesc.HasInput("ShapeTensor") && !opdesc.Input("ShapeTensor").empty()) {
+    auto shape_tensor_name = opdesc.Input("ShapeTensor").front();
+    param_.shape_tensor = GetMutableVar<lite::Tensor>(scope, shape_tensor_name);
+  }
+  if (opdesc.HasInput("ShapeTensorList") &&
+      !opdesc.Input("ShapeTensorList").empty()) {
+    for (auto shape_tensor_name : opdesc.Input("ShapeTensorList")) {
+      param_.shape_tensor_list.push_back(
+          GetMutableVar<lite::Tensor>(scope, shape_tensor_name));
+    }
+  }
+  return true;
+}
 
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle
 
 REGISTER_LITE_OP(fill_constant, paddle::lite::operators::FillConstantOp);
-REGISTER_LITE_OP(fill_constant_batch_size_like,
-                 paddle::lite::operators::FillConstantBatchLikeOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class FillConstantOp : public OpLite {
public:
FillConstantOp() {}
explicit FillConstantOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "fill_constant"; }
private:
mutable FillConstantParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
@@ -418,35 +418,26 @@ struct MeanGradParam {
 struct FillConstantParam {
   int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
   std::vector<int64_t> shape{};
-  lite::Tensor* shape_tensor;
+  lite::Tensor* shape_tensor{nullptr};
   std::vector<lite::Tensor*> shape_tensor_list{};
   float value{0.0f};
   // useless for x86, keep it for compatibility
   bool force_cpu{false};
-  lite::Tensor* Out{};
-};
-
-struct FillConstantBatchLikeParam {
-  int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
-  std::vector<int64_t> shape{};
-  float value{0.0f};
-  // useless for x86, keep it for compatibility
-  bool force_cpu{false};
-  lite::Tensor* out{};
-  const lite::Tensor* input{};
-  int input_dim_idx{0};
-  int output_dim_idx{0};
+  lite::Tensor* out{};
 };
 
 struct FillConstantBatchSizeLikeParam {
-  lite::Tensor* Input;
-  lite::Tensor* Out;
-  std::vector<int> shape;
+  const lite::Tensor* input{nullptr};
+  lite::Tensor* out{nullptr};
+  std::vector<int> shape{};
   int input_dim_idx{0};
   int output_dim_idx{0};
   int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
   float value{0.0f};
+  // useless for x86, keep it for compatibility
+  bool force_cpu{false};
 };
 
 //
......
@@ -34,6 +34,8 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_
     lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_batch_norm_compute SRCS batch_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_pool_compute SRCS pool_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+    lite_cc_test(test_kernel_fill_constant_compute SRCS fill_constant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+    lite_cc_test(test_kernel_fill_constant_batch_size_like_compute SRCS fill_constant_batch_size_like_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     if(LITE_BUILD_EXTRA)
         lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class FillConstantBatchSizeLikeComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "input";
std::string out_ = "out";
DDim in_dims_{};
LoD in_lod_{};
std::vector<int> shape_{};
float value_{0.f};
int input_dim_idx_{0};
int output_dim_idx_{0};
int dtype_{static_cast<int>(VarDescAPI::VarDataType::FP32)};
public:
FillConstantBatchSizeLikeComputeTester(
const Place& place,
const std::string& alias,
DDim in_dims,
LoD in_lod,
std::vector<int> shape,
float value = 0.f,
int input_dim_idx = 0,
int output_dim_idx = 0,
int dtype = static_cast<int>(VarDescAPI::VarDataType::FP32))
: TestCase(place, alias),
in_dims_(in_dims),
in_lod_(in_lod),
shape_(shape),
value_(value),
input_dim_idx_(input_dim_idx),
output_dim_idx_(output_dim_idx),
dtype_(dtype) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(out_);
auto* input = scope->FindTensor(input_);
std::vector<int64_t> out_shape{shape_.begin(), shape_.end()};
if (input_dim_idx_ == 0 && !input->lod().empty()) {
out_shape[output_dim_idx_] = input->lod().back().size() - 1;
} else {
out_shape[output_dim_idx_] = input->dims()[input_dim_idx_];
}
out->Resize(out_shape);
auto* output_data = out->mutable_data<float>();
for (int i = 0; i < out->numel(); i++) {
output_data[i] = value_;
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("fill_constant_batch_size_like");
op_desc->SetInput("Input", {input_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("shape", shape_);
op_desc->SetAttr("value", value_);
op_desc->SetAttr("input_dim_idx", input_dim_idx_);
op_desc->SetAttr("output_dim_idx", output_dim_idx_);
op_desc->SetAttr("dtype", dtype_);
}
void PrepareData() override {
std::vector<float> din(in_dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, in_dims_.production());
SetCommonTensor(input_, in_dims_, din.data(), in_lod_);
SetPrecisionType(out_, PRECISION(kFloat));
}
};
void TestFillConstantBatchSizeLike(Place place, float abs_error) {
for (auto input_dim_idx : {0, 1, 2}) {
for (auto output_dim_idx : {0, 1, 2}) {
std::unique_ptr<arena::TestCase> tester(
new FillConstantBatchSizeLikeComputeTester(place,
"def",
DDim{{5, 4, 3}},
{},
{2, 3, 4},
0.f,
input_dim_idx,
output_dim_idx));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
void TestFillConstantBatchSizeLikeLod(Place place, float abs_error) {
for (auto lod : std::vector<LoD>{{{0, 1, 4, 5}}, {{0, 2, 4}, {0, 1, 4, 5}}}) {
std::unique_ptr<arena::TestCase> tester(
new FillConstantBatchSizeLikeComputeTester(
place, "def", DDim{{5, 4, 3}}, lod, {2, 3, 4}, 0.f));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
void TestFillConstantBatchSizeLikeValue(Place place, float abs_error) {
std::vector<float> values{-1., 3.5};
for (auto value : values) {
std::unique_ptr<arena::TestCase> tester(
new FillConstantBatchSizeLikeComputeTester(
place, "def", DDim{{5, 4, 3}}, {}, {2, 3}, value));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
TEST(fill_constant_batch_size_like, precision) {
LOG(INFO) << "test fill_constant_batch_size_like op";
Place place;
float abs_error = 1e-5;
#ifdef LITE_WITH_ARM
place = TARGET(kARM);
#else
return;
#endif
TestFillConstantBatchSizeLike(place, abs_error);
TestFillConstantBatchSizeLikeLod(place, abs_error);
TestFillConstantBatchSizeLikeValue(place, abs_error);
}
} // namespace lite
} // namespace paddle
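A quick arithmetic check of the LoD cases exercised above, with input dims {5, 4, 3}, shape attr {2, 3, 4}, and both dim indices 0:

// LoD {{0, 1, 4, 5}}: the last level has 4 offsets -> 3 sequences,
// so RunBaseline resizes out to {3, 3, 4}.
// LoD {{0, 2, 4}, {0, 1, 4, 5}}: lod.back() is again {0, 1, 4, 5} -> 3
// sequences, so out is {3, 3, 4} here too.
// With an empty LoD, out_shape[0] = input->dims()[0] = 5 -> {5, 3, 4}.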
@@ -24,60 +24,56 @@ class FillConstantComputeTester : public arena::TestCase {
  protected:
   // common attributes for this op.
   std::string out_ = "out";
-  int dtype_{static_cast<int>(VarDescAPI::VarDataType::FP32)};
-  std::vector<int64_t> shape_{};
-  std::string shape_tensor_ = "ShapeTensor";
-  std::vector<std::string> shape_tensor_list_;
-  bool is_use_shape_tensor_{false};
-  bool is_use_shape_tensor_list_{false};
-  float value_{0.0f};
-  // useless for x86, keep it for compatibility
-  bool force_cpu_{false};
-  // DDim shape_tensor_data{{5, 3}};
-  std::vector<int32_t> shape_tensor_data;
-  DDim shape_test{{1, 2}};
+  std::string shape_tensor_ = "shape_tensor";
+  std::vector<std::string> shape_tensor_list_{};
+  std::vector<int64_t> shape_{};
+  float value_{0.0f};
+  int dtype_{static_cast<int>(VarDescAPI::VarDataType::FP32)};
+  bool is_use_shape_tensor_{false};
+  bool is_use_shape_tensor_list_{false};
+  // useless for x86, keep it for compatibility
+  bool force_cpu_{false};
 
  public:
   FillConstantComputeTester(const Place& place,
                             const std::string& alias,
                             std::vector<int64_t> shape,
-                            const bool is_use_shape_tensor,
-                            const bool is_use_shape_tensor_list,
-                            float value,
-                            bool force_cpu)
-      : TestCase(place, alias) {
-    shape_ = shape;
-    value_ = value;
-    force_cpu_ = force_cpu;
-    is_use_shape_tensor_ = is_use_shape_tensor;
-    is_use_shape_tensor_list_ = is_use_shape_tensor_list;
-
-    for (int i = 0; i < shape_test.size(); i++) {
-      shape_tensor_data.push_back(i + 1);
-    }
-  }
+                            float value,
+                            int dtype,
+                            const bool is_use_shape_tensor = false,
+                            const bool is_use_shape_tensor_list = false)
+      : TestCase(place, alias),
+        shape_(shape),
+        value_(value),
+        dtype_(dtype),
+        is_use_shape_tensor_(is_use_shape_tensor),
+        is_use_shape_tensor_list_(is_use_shape_tensor_list) {
+    if (is_use_shape_tensor_list) {
+      for (int i = 0; i < shape.size(); i++) {
+        shape_tensor_list_.push_back(shape_tensor_ + std::to_string(i));
+      }
+    }
+  }
 
   void RunBaseline(Scope* scope) override {
     auto* out = scope->NewTensor(out_);
-    DDim output_dims{shape_};
+    std::vector<int64_t> out_shape;
     if (is_use_shape_tensor_) {
-      auto* temp_shape = scope->FindTensor(shape_tensor_);
-      auto* shape_data = temp_shape->data<int>();
-      auto vec_shape =
-          std::vector<int64_t>(shape_data, shape_data + temp_shape->numel());
-      output_dims.ConstructFrom(vec_shape);
-    }
-    if (is_use_shape_tensor_list_) {
-      std::vector<int64_t> vec_shape;
+      auto* shape_tensor = scope->FindTensor(shape_tensor_);
+      auto* shape_tensor_data = shape_tensor->data<int>();
+      out_shape = std::vector<int64_t>(
+          shape_tensor_data, shape_tensor_data + shape_tensor->numel());
+    } else if (is_use_shape_tensor_list_) {
       for (int i = 0; i < shape_tensor_list_.size(); i++) {
-        auto* temp_shape = scope->FindTensor(shape_tensor_list_[i]);
-        vec_shape.push_back(*temp_shape->data<int>());
+        auto* shape_tensor = scope->FindTensor(shape_tensor_list_[i]);
+        out_shape.push_back(shape_tensor->data<int>()[0]);
       }
-      output_dims.ConstructFrom(vec_shape);
+    } else {
+      out_shape = shape_;
     }
-    out->Resize(output_dims);
+    out->Resize(out_shape);
     auto* output_data = out->mutable_data<float>();
     for (int i = 0; i < out->numel(); i++) {
@@ -86,92 +82,105 @@ class FillConstantComputeTester : public arena::TestCase {
   }
 
   void PrepareOpDesc(cpp::OpDesc* op_desc) {
-    LOG(INFO) << "PrepareOpDesc";
     op_desc->SetType("fill_constant");
-    op_desc->SetAttr("dtype", dtype_);
-    op_desc->SetAttr("shape", shape_);
-    op_desc->SetAttr("value", value_);
-    op_desc->SetAttr("force_cpu", force_cpu_);
     if (is_use_shape_tensor_) {
       op_desc->SetInput("ShapeTensor", {shape_tensor_});
-    }
-    if (is_use_shape_tensor_list_) {
-      // std::vector<std::string> shape_tensor_list_;
-      for (int i = 0; i < shape_test.size(); ++i) {
-        shape_tensor_list_.push_back("shape_tensor_list_" + std::to_string(i));
-      }
-      op_desc->SetInput("ShapeTensorList", {shape_tensor_list_});
+    } else if (is_use_shape_tensor_list_) {
+      op_desc->SetInput("ShapeTensorList", shape_tensor_list_);
+    } else {
+      op_desc->SetAttr("shape", shape_);
     }
     op_desc->SetOutput("Out", {out_});
+    op_desc->SetAttr("dtype", dtype_);
+    op_desc->SetAttr("value", value_);
+    op_desc->SetAttr("force_cpu", force_cpu_);
   }
 
   void PrepareData() override {
     if (is_use_shape_tensor_) {
-      // std::vector<int64_t> temp = x_dims_.data();
-      // int64_t* data = temp.data();
-      SetCommonTensor(shape_tensor_, shape_test, shape_tensor_data.data());
+      std::vector<int> dshape_tensor(shape_.begin(), shape_.end());
+      SetCommonTensor(shape_tensor_,
+                      DDim({static_cast<int64_t>(shape_.size())}),
+                      dshape_tensor.data());
     }
     if (is_use_shape_tensor_list_) {
-      Scope& scope_ = this->scope();
-      for (int i = 0; i < shape_test.size(); ++i) {
-        auto* tensor =
-            scope_.NewTensor("shape_tensor_list_" + std::to_string(i));
-        tensor->Resize(DDim({1}));
-        auto* d = tensor->mutable_data<int>();
-        d[0] = shape_tensor_data[i];
+      for (int i = 0; i < shape_.size(); ++i) {
+        std::vector<int> dshape_tensor{static_cast<int>(shape_[i])};
+        SetCommonTensor(shape_tensor_list_[i], DDim({1}), dshape_tensor.data());
       }
     }
+    SetPrecisionType(out_, PRECISION(kFloat));
   }
 };
 
-TEST(fill_constant, precision) {
-  LOG(INFO) << "test fill_constant op, kARM";
-#ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
-  std::vector<int64_t> shape{1, 2};
-
-  for (int dtype : {static_cast<int>(VarDescAPI::VarDataType::INT32)}) {
-    for (float value : {1, 2}) {
-      for (bool is_use_shape_tensor_list : {false, true}) {
-        for (bool is_use_shape_tensor : {false, true}) {
-          if (is_use_shape_tensor && is_use_shape_tensor_list) break;
-          LOG(INFO) << "value:" << value
-                    << ", is_use_shape_tensor:" << is_use_shape_tensor
-                    << ", is_use_shape_tensor_list:"
-                    << is_use_shape_tensor_list;
-          std::unique_ptr<arena::TestCase> tester(
-              new FillConstantComputeTester(place,
-                                            "def",
-                                            shape,
-                                            is_use_shape_tensor,
-                                            is_use_shape_tensor_list,
-                                            value,
-                                            false));
-          arena::Arena arena(std::move(tester), place, 2e-5);
-          arena.TestPrecision();
-        }
-      }
-    }
-  }
-#endif
-
-#ifdef LITE_WITH_X86
-  Place place(TARGET(kX86));
-  LOG(INFO) << "test concate op, x86";
-  for (int axis : {1, 2}) {
-    for (bool is_use_axis_tensor : {false, true}) {
-      LOG(INFO) << "axis:" << axis
-                << ", is_use_axis_tensor:" << is_use_axis_tensor;
-      std::unique_ptr<arena::TestCase> tester(
-          new ConcateComputeTester(place, "def", axis, is_use_axis_tensor));
-      arena::Arena arena(std::move(tester), place, 2e-5);
-      arena.TestPrecision();
-    }
-  }
-#endif
-}
+void TestFillConstantShape(Place place, float abs_error) {
+  std::vector<std::vector<int64_t>> out_shapes{
+      {2, 3, 4, 5}, {2, 3, 4}, {3, 4}, {4}};
+  for (auto out_shape : out_shapes) {
+    std::unique_ptr<arena::TestCase> tester(new FillConstantComputeTester(
+        place,
+        "def",
+        out_shape,
+        1.f,
+        static_cast<int>(VarDescAPI::VarDataType::FP32)));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision();
+  }
+}
+
+void TestFillConstantValue(Place place, float abs_error) {
+  std::vector<float> values{-1., 0., 3.5};
+  for (auto value : values) {
+    std::unique_ptr<arena::TestCase> tester(new FillConstantComputeTester(
+        place,
+        "def",
+        {2, 3},
+        value,
+        static_cast<int>(VarDescAPI::VarDataType::FP32)));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision();
+  }
+}
+
+void TestFillConstantShapeTensor(Place place, float abs_error) {
+  std::unique_ptr<arena::TestCase> tester(new FillConstantComputeTester(
+      place,
+      "def",
+      {2, 3, 4},
+      1.f,
+      static_cast<int>(VarDescAPI::VarDataType::FP32),
+      true));
+  arena::Arena arena(std::move(tester), place, abs_error);
+  arena.TestPrecision();
+}
+
+void TestFillConstantShapeTensorList(Place place, float abs_error) {
+  std::unique_ptr<arena::TestCase> tester(new FillConstantComputeTester(
+      place,
+      "def",
+      {2, 3, 4},
+      1.f,
+      static_cast<int>(VarDescAPI::VarDataType::FP32),
+      false,
+      true));
+  arena::Arena arena(std::move(tester), place, abs_error);
+  arena.TestPrecision();
+}
+
+TEST(fill_constant, precision) {
+  LOG(INFO) << "test fill_constant op";
+  Place place;
+  float abs_error = 1e-5;
+#ifdef LITE_WITH_ARM
+  place = TARGET(kARM);
+#else
+  return;
+#endif
+
+  TestFillConstantShape(place, abs_error);
+  TestFillConstantValue(place, abs_error);
+  TestFillConstantShapeTensor(place, abs_error);
+  TestFillConstantShapeTensorList(place, abs_error);
+}
 
 }  // namespace lite
......