未验证 提交 1816f57f 编写于 作者: Z zhupengyang 提交者: GitHub

[NPU] enhance conv2d ut (#2753)

上级 91f0ef0b
......@@ -107,7 +107,8 @@ class TestCase {
void SetCommonTensor(const std::string& var_name,
const DDim& ddim,
const T* data,
const LoD& lod = {}) {
const LoD& lod = {},
bool is_persistable = false) {
auto* tensor = scope_->NewTensor(var_name);
tensor->Resize(ddim);
auto* d = tensor->mutable_data<T>();
......@@ -115,6 +116,8 @@ class TestCase {
// set lod
if (!lod.empty()) *tensor->mutable_lod() = lod;
// set persistable
tensor->set_persistable(is_persistable);
}
// Prepare for the operator.
......
......@@ -38,18 +38,21 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(input_type->layout() == DATALAYOUT(kNCHW));
auto input = scope->FindMutableTensor(input_name);
auto input_dims = input->dims();
auto filter_name = op_info->Input("Filter").front();
auto filter_type = kernel->GetInputDeclType("Filter");
CHECK(filter_type->precision() == PRECISION(kFloat));
CHECK(filter_type->layout() == DATALAYOUT(kNCHW));
auto filter = scope->FindMutableTensor(filter_name);
auto filter_dims = filter->dims();
auto output_name = op_info->Output("Output").front();
auto output_type = kernel->GetOutputDeclType("Output");
CHECK(output_type->precision() == PRECISION(kFloat));
CHECK(output_type->layout() == DATALAYOUT(kNCHW));
auto output = scope->FindMutableTensor(output_name);
auto output_dims = output->dims();
auto bs = input_dims[0];
auto ic = input_dims[1];
auto oc = filter_dims[0];
......@@ -62,8 +65,13 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto groups = op_info->GetAttr<int>("groups");
auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
auto fuse_relu =
op_info->HasAttr("fuse_relu") && op_info->GetAttr<bool>("fuse_relu");
bool with_act =
op_info->HasAttr("with_act") && op_info->GetAttr<bool>("with_act");
std::string act_type =
with_act ? op_info->GetAttr<std::string>("act_type") : "";
float leaky_relu_alpha = act_type == "leaky_relu"
? op_info->GetAttr<float>("leaky_relu_alpha")
: 0.f;
CHECK_EQ(strides.size(), 2L);
CHECK_EQ(dilations.size(), 2L);
......@@ -187,10 +195,15 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
conv_op->set_input_x(*input_node->data());
conv_op->set_input_w(*filter_node->data());
conv_op->set_attr_mode(1);
conv_op->set_attr_pad_mode(0); // NOTSET
// when padding_algorithm=="SAME", NPU is different from lite
if (padding_algorithm == "VALID") {
conv_op->set_attr_pad_mode(5);
} else {
conv_op->set_attr_pad_mode(0);
}
conv_op->set_attr_group(groups);
conv_op->set_attr_pad(ge::AttrValue::LIST_INT(
{paddings[0], paddings[0], paddings[2], paddings[2]}));
{paddings[0], paddings[1], paddings[2], paddings[3]}));
conv_op->set_attr_dilation(
ge::AttrValue::LIST_INT({dilations[0], dilations[1]}));
conv_op->set_attr_stride(ge::AttrValue::LIST_INT({strides[0], strides[1]}));
......@@ -212,13 +225,16 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
CHECK(conv_node);
if (fuse_relu) {
// Append relu node if fuse_relu is true
auto relu_node = graph->Add<ge::op::Activation>(output_name);
auto relu_op = relu_node->data<ge::op::Activation>();
relu_op->set_input_x(*conv_node->data());
relu_op->set_attr_mode(CvtActMode("relu"));
if (!act_type.empty()) {
auto act_node = graph->Add<ge::op::Activation>(output_name);
auto act_op = act_node->data<ge::op::Activation>();
act_op->set_input_x(*conv_node->data());
act_op->set_attr_mode(CvtActMode(act_type));
if (act_type == "leaky_relu") {
act_op->set_attr_negative_slope(leaky_relu_alpha);
}
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/conv_op.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
void conv_ref(const std::shared_ptr<operators::ConvOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto input =
scope->FindVar(op_info->Input("Input").front())->GetMutable<Tensor>();
auto filter =
scope->FindVar(op_info->Input("Filter").front())->GetMutable<Tensor>();
auto output =
scope->FindVar(op_info->Output("Output").front())->GetMutable<Tensor>();
std::vector<int32_t> strides =
op_info->GetAttr<std::vector<int32_t>>("strides");
std::vector<int32_t> paddings =
op_info->GetAttr<std::vector<int32_t>>("paddings");
int32_t groups = op_info->GetAttr<int32_t>("groups");
std::vector<int32_t> dilations =
op_info->GetAttr<std::vector<int32_t>>("dilations");
bool fuse_relu = op_info->GetAttr<bool>("fuse_relu");
auto input_dims = input->dims();
auto filter_dims = filter->dims();
auto output_dims = output->dims();
auto input_data = input->mutable_data<float>();
auto filter_data = filter->mutable_data<float>();
auto output_data = output->mutable_data<float>();
int kernel_w = filter_dims[3];
int kernel_h = filter_dims[2];
int stride_w = strides[1];
int stride_h = strides[0];
int dila_w = dilations[1];
int dila_h = dilations[0];
int pad_w = paddings[2];
int pad_h = paddings[0];
int batch_size = input_dims[0];
int in_ch_size = input_dims[1];
int in_h = input_dims[2];
int in_w = input_dims[3];
int out_ch_size = output_dims[1];
int out_h = output_dims[2];
int out_w = output_dims[3];
int out_c_group = out_ch_size / groups;
int in_c_group = in_ch_size / groups;
Tensor* bias = nullptr;
float* bias_data = nullptr;
bool is_channel_bias = false;
if (op_info->HasInput("Bias")) {
auto bias_var_names = op_info->Input("Bias");
if (bias_var_names.size() > 0) {
auto bias_var_name = bias_var_names.front();
bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto bias_dims = bias->dims();
is_channel_bias = bias_dims.production() == out_ch_size;
bias_data = bias->mutable_data<float>();
}
}
for (int n = 0; n < batch_size; ++n) {
for (int g = 0; g < groups; ++g) {
for (int oc = 0; oc < out_c_group; ++oc) {
for (int oh = 0; oh < out_h; ++oh) {
for (int ow = 0; ow < out_w; ++ow) {
int out_idx = n * groups * out_c_group * out_h * out_w +
g * out_c_group * out_h * out_w + oc * out_h * out_w +
oh * out_w + ow;
float out_value =
bias_data != nullptr
? (is_channel_bias ? bias_data[g * out_c_group + oc]
: bias_data[out_idx])
: 0;
// + out_value *= beta;
for (int ic = 0; ic < in_c_group; ++ic) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int iw = ow * stride_w - pad_w + kw * (dila_w);
int ih = oh * stride_h - pad_h + kh * (dila_h);
if (iw < 0 || iw >= in_w) continue;
if (ih < 0 || ih >= in_h) continue;
int in_idx = n * in_ch_size * in_h * in_w +
g * in_c_group * in_h * in_w + ic * in_h * in_w +
ih * in_w + iw;
int filter_idx =
g * out_c_group * in_c_group * kernel_h * kernel_w +
oc * in_c_group * kernel_h * kernel_w +
ic * kernel_h * kernel_w + kh * kernel_w + kw;
out_value += input_data[in_idx] * filter_data[filter_idx];
}
}
}
if (fuse_relu) {
out_value = out_value > 0 ? out_value : 0;
}
output_data[out_idx] = out_value;
}
}
}
}
}
}
void test_conv(int bs,
int ic,
int oc,
int ih,
int iw,
bool has_bias,
bool is_channel_bias,
bool fuse_relu,
bool depthwise,
int dilation,
int stride,
int padding,
int kernel) {
// prepare input&output variables
Scope scope;
std::string input_var_name("input");
std::string filter_var_name("filter");
std::string bias_var_name("bias");
std::string output_var_name("output");
std::string output_ref_var_name("output_ref");
auto* input = scope.Var(input_var_name)->GetMutable<Tensor>();
auto* filter = scope.Var(filter_var_name)->GetMutable<Tensor>();
auto* bias = scope.Var(bias_var_name)->GetMutable<Tensor>();
auto* output = scope.Var(output_var_name)->GetMutable<Tensor>();
auto* output_ref = scope.Var(output_ref_var_name)->GetMutable<Tensor>();
// get group size and input&filter shape
int groups = 1;
if (depthwise) { // depthwise convolution ?
groups = oc = ic;
}
std::vector<int64_t> input_shape = {bs, ic, ih, iw};
std::vector<int64_t> filter_shape = {oc, ic / groups, kernel, kernel};
std::vector<int64_t> output_shape({bs, oc});
for (size_t i = 0; i < 2; i++) {
const int dkernel = dilation * (kernel - 1) + 1;
int output_size = (input_shape[i + 2] + 2 * padding - dkernel) / stride + 1;
output_shape.push_back(output_size);
}
input->Resize(input_shape);
filter->Resize(filter_shape);
// initialize input&output data
FillTensor<float, int>(input);
FillTensor<float, int>(filter);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType(depthwise ? "depthwise_conv2d" : "conv2d");
opdesc.SetInput("Input", {input_var_name});
opdesc.SetInput("Filter", {filter_var_name});
opdesc.SetOutput("Output", {output_var_name});
opdesc.SetAttr("dilations", std::vector<int32_t>({dilation, dilation}));
opdesc.SetAttr("strides", std::vector<int32_t>({stride, stride}));
opdesc.SetAttr("paddings",
std::vector<int32_t>({padding, padding, padding, padding}));
opdesc.SetAttr("groups", groups);
opdesc.SetAttr("fuse_relu", static_cast<bool>(fuse_relu));
if (has_bias) {
if (is_channel_bias) {
bias->Resize({1, oc, 1, 1});
} else {
bias->Resize({output_shape});
}
FillTensor<float, int>(bias);
opdesc.SetInput("Bias", {bias_var_name});
}
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ConvOpLite>(opdesc, &scope);
LauchOp(op, {input_var_name}, {output_var_name});
output_ref->CopyDataFrom(*output);
// execute reference implementation and save to output tensor('out')
conv_ref(op);
// compare results
auto* output_data = output->mutable_data<float>();
auto* output_ref_data = output_ref->mutable_data<float>();
for (int i = 0; i < output->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
TEST(NPUBridges, conv) {
#if 1
for (auto bs : {1, 2}) {
for (auto ic : {3, 6}) {
for (auto oc : {6, 9}) {
for (auto ih : {14, 28}) {
for (auto iw : {14, 28}) {
for (auto has_bias : {false, true}) {
for (auto is_channel_bias : {false, true}) {
for (auto fuse_relu : {false, true}) {
for (auto depthwise : {false, true}) {
for (auto dilation : {1, 2}) {
for (auto stride : {1, 2}) {
for (auto kernel : {1, 3, 5}) {
std::vector<int> paddings = {kernel / 2};
if (kernel / 2 != 0) {
paddings.push_back(0);
}
for (auto padding : paddings) {
VLOG(3) << "bs: " << bs << " ic: " << ic
<< " oc: " << oc << " ih: " << ih
<< " iw: " << iw
<< " has_bias: " << has_bias
<< " is_channel_bias: " << is_channel_bias
<< " fuse_relu: " << fuse_relu
<< " depthwise: " << depthwise
<< " dilation: " << dilation
<< " stride: " << stride
<< " padding: " << padding
<< " kernel: " << kernel;
test_conv(bs,
ic,
oc,
ih,
iw,
has_bias,
is_channel_bias,
fuse_relu,
depthwise,
dilation,
stride,
padding,
kernel);
}
}
}
}
}
}
}
}
}
}
}
}
}
#else
test_conv(1, 3, 6, 14, 14, false, false, false, true, 2, 1, 1, 3);
test_conv(1, 3, 6, 14, 14, false, false, false, true, 2, 1, 0, 3);
test_conv(1, 3, 6, 14, 14, false, false, false, true, 2, 1, 2, 5);
test_conv(1, 3, 6, 14, 14, false, false, false, true, 2, 1, 0, 5);
#endif
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(conv2d);
USE_NPU_BRIDGE(conv2d);
USE_LITE_OP(depthwise_conv2d);
USE_NPU_BRIDGE(depthwise_conv2d);
......@@ -11,6 +11,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_activation_compute SRCS activation_compute_test.cc DEPS arena_framework ${npu_kernels} ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_argmax_compute SRCS argmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_conv_compute SRCS conv_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class ConvComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string op_type_ = "conv2d";
std::string input_ = "input";
std::string filter_ = "filter";
std::string output_ = "output";
DDim dims_;
int out_channels_ = 1;
int ksize_ = 3;
std::vector<int> strides_{1, 1};
std::vector<int> paddings_{0, 0};
int groups_ = 1;
std::vector<int> dilations_{1, 1};
std::string padding_algorithm_;
bool with_bias_ = false;
std::string bias_ = "bias";
bool with_act_ = false;
std::string act_type_;
float leaky_relu_alpha_ = 0.1;
public:
ConvComputeTester(const Place& place,
const std::string& alias,
DDim dims,
int out_channels = 1,
int ksize = 3,
std::vector<int> strides = {1, 1},
std::vector<int> paddings = {0, 0},
int groups = 1,
std::vector<int> dilations = {1, 1},
std::string padding_algorithm = "",
bool with_bias = false,
bool with_act = false,
std::string act_type = "",
float leaky_relu_alpha = 0.1)
: TestCase(place, alias),
dims_(dims),
out_channels_(out_channels),
ksize_(ksize),
strides_(strides),
paddings_(paddings),
groups_(groups),
dilations_(dilations),
padding_algorithm_(padding_algorithm),
with_bias_(with_bias),
with_act_(with_act),
act_type_(act_type),
leaky_relu_alpha_(leaky_relu_alpha) {}
void RunBaseline(Scope* scope) override {
auto* input = scope->FindTensor(input_);
auto* filter = scope->FindTensor(filter_);
auto input_dims = input->dims();
auto filter_dims = filter->dims();
auto* output = scope->NewTensor(output_);
CHECK(output);
if (paddings_.size() == 2L) {
paddings_.insert(paddings_.begin(), paddings_[0]);
paddings_.insert(paddings_.begin() + 2, paddings_[2]);
}
if (padding_algorithm_ == "SAME") {
for (size_t i = 0; i < strides_.size(); ++i) {
int out_size = (input_dims[i + 2] + strides_[i] - 1) / strides_[i];
int pad_sum =
std::max((out_size - 1) * strides_[i] + ksize_ - input_dims[i + 2],
(int64_t)0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
// pad
*(paddings_.begin() + i * 2) = pad_0;
*(paddings_.begin() + i * 2 + 1) = pad_1;
// dilation
*(dilations_.begin() + i) = 1;
}
} else if (padding_algorithm_ == "VALID") {
for (auto& it : paddings_) {
it = 0;
}
}
std::vector<int64_t> output_shape({input_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides_.size(); ++i) {
const int dkernel = dilations_[i] * (filter_dims[i + 2] - 1) + 1;
int output_size = (input_dims[i + 2] +
(paddings_[i * 2] + paddings_[i * 2 + 1]) - dkernel) /
strides_[i] +
1;
output_shape.push_back(output_size);
}
output->Resize(DDim(output_shape));
auto output_dims = output->dims();
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto output_data = output->mutable_data<float>();
int kernel_w = filter_dims[3];
int kernel_h = filter_dims[2];
int stride_w = strides_[1];
int stride_h = strides_[0];
int dila_w = dilations_[1];
int dila_h = dilations_[0];
int pad_w = paddings_[2];
int pad_h = paddings_[0];
int batch_size = input_dims[0];
int in_ch_size = input_dims[1];
int in_h = input_dims[2];
int in_w = input_dims[3];
int out_ch_size = output_dims[1];
int out_h = output_dims[2];
int out_w = output_dims[3];
int out_c_group = out_ch_size / groups_;
int in_c_group = in_ch_size / groups_;
const float* bias_data = nullptr;
bool is_channel_bias = true;
if (with_bias_) {
auto bias = scope->FindTensor(bias_);
bias_data = bias->data<float>();
}
for (int n = 0; n < batch_size; ++n) {
for (int g = 0; g < groups_; ++g) {
for (int oc = 0; oc < out_c_group; ++oc) {
for (int oh = 0; oh < out_h; ++oh) {
for (int ow = 0; ow < out_w; ++ow) {
int out_idx = n * groups_ * out_c_group * out_h * out_w +
g * out_c_group * out_h * out_w +
oc * out_h * out_w + oh * out_w + ow;
float out_value =
bias_data != nullptr
? (is_channel_bias ? bias_data[g * out_c_group + oc]
: bias_data[out_idx])
: 0;
// + out_value *= beta;
for (int ic = 0; ic < in_c_group; ++ic) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int iw = ow * stride_w - pad_w + kw * (dila_w);
int ih = oh * stride_h - pad_h + kh * (dila_h);
if (iw < 0 || iw >= in_w) continue;
if (ih < 0 || ih >= in_h) continue;
int in_idx = n * in_ch_size * in_h * in_w +
g * in_c_group * in_h * in_w +
ic * in_h * in_w + ih * in_w + iw;
int filter_idx =
g * out_c_group * in_c_group * kernel_h * kernel_w +
oc * in_c_group * kernel_h * kernel_w +
ic * kernel_h * kernel_w + kh * kernel_w + kw;
out_value += input_data[in_idx] * filter_data[filter_idx];
}
}
}
if (with_act_) {
if (act_type_ == "relu") {
out_value = out_value > 0 ? out_value : 0;
} else if (act_type_ == "leaky_relu") {
out_value =
std::max(out_value, out_value * leaky_relu_alpha_);
} else {
LOG(FATAL) << "unsupported";
}
}
output_data[out_idx] = out_value;
}
}
}
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType(op_type_);
op_desc->SetInput("Input", {input_});
op_desc->SetInput("Filter", {filter_});
if (with_bias_) {
op_desc->SetInput("Bias", {bias_});
}
op_desc->SetOutput("Output", {output_});
op_desc->SetAttr("strides", strides_);
op_desc->SetAttr("paddings", paddings_);
op_desc->SetAttr("groups", groups_);
op_desc->SetAttr("dilations", dilations_);
if (!padding_algorithm_.empty()) {
op_desc->SetAttr("padding_algorithm", padding_algorithm_);
}
if (with_act_) {
op_desc->SetAttr("with_act", with_act_);
op_desc->SetAttr("act_type", act_type_);
if (act_type_ == "leaky_relu") {
op_desc->SetAttr("leaky_relu_alpha", leaky_relu_alpha_);
}
}
}
void PrepareData() override {
std::vector<float> din(dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
SetCommonTensor(input_, dims_, din.data());
DDim filter_dims(std::vector<int64_t>{
out_channels_, dims_[1] / groups_, ksize_, ksize_});
std::vector<float> dfilter(filter_dims.production());
fill_data_rand(dfilter.data(), -1.f, 1.f, filter_dims.production());
SetCommonTensor(filter_, filter_dims, dfilter.data(), {}, true);
if (with_bias_) {
DDim bias_dims(std::vector<int64_t>{out_channels_});
std::vector<float> dbias(bias_dims.production());
fill_data_rand(din.data(), -1.f, 1.f, bias_dims.production());
SetCommonTensor(bias_, bias_dims, dbias.data(), {}, true);
}
}
};
void TestConvKsize(Place place, float abs_error = 2e-5) {
for (auto dims :
std::vector<std::vector<int64_t>>{{1, 2, 7, 8}, {5, 6, 17, 18}}) {
for (auto out_channels : {1, 3}) {
for (auto ksize : {1, 3, 5, 7}) {
std::unique_ptr<arena::TestCase> tester(new ConvComputeTester(
place, "def", DDim(dims), out_channels, ksize));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
}
void TestConvGroups(Place place, float abs_error = 2e-5) {
for (auto dims :
std::vector<std::vector<int64_t>>{{1, 6, 3, 4}, {5, 12, 7, 8}}) {
for (auto out_channels : {2, 3, 6}) {
for (auto groups : {2, 3, 6}) {
#ifdef LITE_WITH_NPU
if (out_channels % groups != 0) continue;
#endif
std::unique_ptr<arena::TestCase> tester(new ConvComputeTester(
place, "def", DDim(dims), out_channels, 3, {1, 1}, {0, 0}, groups));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
}
void TestConvDilations(Place place, float abs_error = 2e-5) {
for (auto dims :
std::vector<std::vector<int64_t>>{{1, 2, 5, 6}, {5, 6, 9, 10}}) {
for (auto out_channels : {1, 3}) {
for (auto dilations : std::vector<std::vector<int>>{{2, 2}, {1, 2}}) {
std::unique_ptr<arena::TestCase> tester(
new ConvComputeTester(place,
"def",
DDim(dims),
out_channels,
3,
{1, 1},
{0, 0},
1,
dilations));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
}
void TestConvStrides(Place place, float abs_error = 2e-5) {
for (auto dims :
std::vector<std::vector<int64_t>>{{1, 2, 3, 4}, {5, 6, 7, 8}}) {
for (auto out_channels : {1, 3}) {
for (auto strides :
std::vector<std::vector<int>>{{2, 2}, {3, 3}, {1, 2}, {3, 1}}) {
std::unique_ptr<arena::TestCase> tester(new ConvComputeTester(
place, "def", DDim(dims), out_channels, 3, strides));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
}
void TestConvPaddings(Place place, float abs_error = 2e-5) {
for (auto dims :
std::vector<std::vector<int64_t>>{{1, 2, 3, 4}, {5, 6, 7, 8}}) {
for (auto out_channels : {1, 3}) {
for (auto paddings : std::vector<std::vector<int>>{
{1, 1}, {2, 2}, {1, 0, 0, 1}, {1, 2, 0, 1}}) {
std::unique_ptr<arena::TestCase> tester(new ConvComputeTester(
place, "def", DDim(dims), out_channels, 3, {1, 1}, paddings));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
}
void TestConvPaddingAlgorithm(Place place, float abs_error = 2e-5) {
for (auto dims :
std::vector<std::vector<int64_t>>{{1, 2, 3, 4}, {5, 6, 7, 8}}) {
for (auto out_channels : {1, 3}) {
for (auto padding_algorithm : std::vector<std::string>{"VALID", "SAME"}) {
std::unique_ptr<arena::TestCase> tester(
new ConvComputeTester(place,
"def",
DDim(dims),
out_channels,
3,
{1, 1},
{0, 0},
1,
{1, 1},
padding_algorithm));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
}
void TestConvBias(Place place, float abs_error = 2e-5) {
for (auto dims :
std::vector<std::vector<int64_t>>{{1, 2, 3, 4}, {5, 6, 7, 8}}) {
for (auto out_channels : {1, 3}) {
std::unique_ptr<arena::TestCase> tester(
new ConvComputeTester(place,
"def",
DDim(dims),
out_channels,
3,
{1, 1},
{0, 0},
1,
{1, 1},
"",
true));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
void TestConvAct(Place place, float abs_error = 2e-5) {
for (auto dims :
std::vector<std::vector<int64_t>>{{1, 2, 3, 4}, {5, 6, 7, 8}}) {
for (auto out_channels : {1, 3}) {
std::unique_ptr<arena::TestCase> tester0(
new ConvComputeTester(place,
"def",
DDim(dims),
out_channels,
3,
{1, 1},
{0, 0},
1,
{1, 1},
"",
false,
true,
"relu"));
arena::Arena arena0(std::move(tester0), place, abs_error);
arena0.TestPrecision();
std::unique_ptr<arena::TestCase> tester1(
new ConvComputeTester(place,
"def",
DDim(dims),
out_channels,
3,
{1, 1},
{0, 0},
1,
{1, 1},
"",
false,
true,
"leaky_relu",
0.1));
arena::Arena arena1(std::move(tester1), place, abs_error);
arena1.TestPrecision();
}
}
}
TEST(Conv2d, precision) {
float abs_error = 2e-5;
Place place;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 5e-2; // Using fp16 in NPU
#else
return;
#endif
TestConvKsize(place, abs_error);
TestConvGroups(place, abs_error);
TestConvDilations(place, abs_error);
TestConvStrides(place, abs_error);
TestConvPaddings(place, abs_error);
TestConvPaddingAlgorithm(place, abs_error);
TestConvBias(place, abs_error);
TestConvAct(place, abs_error);
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册