From b14e21c30a75f9c3679867f244ed96bbc567d833 Mon Sep 17 00:00:00 2001
From: zhupengyang <1165938320@qq.com>
Date: Thu, 16 Jan 2020 13:37:56 +0800
Subject: [PATCH] [NPU] enhance conv_transpose and ut (#2773)

---
 lite/kernels/npu/bridges/conv_transpose_op.cc |  72 +++-
 .../npu/bridges/conv_transpose_op_test.cc     | 372 ------------------
 lite/operators/conv_transpose_op.cc           |  31 +-
 lite/tests/kernels/CMakeLists.txt             |   2 +
 .../kernels/conv_transpose_compute_test.cc    | 341 ++++++++++++++++
 lite/tests/utils/naive_math_impl.h            |   1 -
 6 files changed, 403 insertions(+), 416 deletions(-)
 delete mode 100644 lite/kernels/npu/bridges/conv_transpose_op_test.cc
 create mode 100644 lite/tests/kernels/conv_transpose_compute_test.cc

diff --git a/lite/kernels/npu/bridges/conv_transpose_op.cc b/lite/kernels/npu/bridges/conv_transpose_op.cc
index aa27be234c..adade8844b 100644
--- a/lite/kernels/npu/bridges/conv_transpose_op.cc
+++ b/lite/kernels/npu/bridges/conv_transpose_op.cc
@@ -15,6 +15,7 @@
 #include "lite/kernels/npu/bridges/graph.h"
 #include "lite/kernels/npu/bridges/registry.h"
 #include "lite/kernels/npu/bridges/utility.h"
+#include "lite/operators/conv_op.h"
 
 namespace paddle {
 namespace lite {
@@ -38,6 +39,7 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto input = scope->FindMutableTensor(input_name);
   auto input_dims = input->dims();
   CHECK_EQ(input_dims.size(), 4);
+
   auto filter_name = op_info->Input("Filter").front();
   auto filter_type = kernel->GetInputDeclType("Filter");
   CHECK(filter_type->precision() == PRECISION(kFloat));
@@ -45,18 +47,53 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto filter = scope->FindMutableTensor(filter_name);
   auto filter_dims = filter->dims();
   CHECK_EQ(filter_dims.size(), 4);
+
   auto output_name = op_info->Output("Output").front();
   auto output_type = kernel->GetOutputDeclType("Output");
   CHECK(output_type->precision() == PRECISION(kFloat));
   CHECK(output_type->layout() == DATALAYOUT(kNCHW));
+
   auto strides = op_info->GetAttr<std::vector<int>>("strides");
-  auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
+  CHECK_EQ(strides.size(), 2L);
   auto groups = op_info->GetAttr<int>("groups");
-  auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
+  if (groups > 1) {
+    LOG(WARNING) << "[NPU] only support groups == 1";
+    return FAILED;
+  }
   auto fuse_relu =
       op_info->HasAttr("fuse_relu") && op_info->GetAttr<bool>("fuse_relu");
-  CHECK_EQ(strides.size(), 2L);
+  std::vector<int> output_size;
+  if (op_info->HasAttr("output_size")) {
+    output_size = op_info->GetAttr<std::vector<int>>("output_size");
+  }
+
+  auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
+  auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
   CHECK_EQ(dilations.size(), 2L);
+  std::string padding_algorithm =
+      op_info->HasAttr("padding_algorithm")
+          ? op_info->GetAttr<std::string>("padding_algorithm")
+          : "";
+  if (paddings.size() == 2L) {
+    for (size_t i = 0; i < 2L; ++i) {
+      int copy_pad = *(paddings.begin() + 2 * i);
+      paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
+    }
+  }
+  CHECK_EQ(paddings.size(), 4L)
+      << "[NPU] Paddings size should be the same or twice as the input size.";
+  operators::UpdatePaddingAndDilation(&paddings,
+                                      &dilations,
+                                      strides,
+                                      padding_algorithm,
+                                      input_dims,
+                                      filter_dims);
+  if (paddings[0] != paddings[1] || paddings[2] != paddings[3]) {
+    LOG(WARNING) << "[NPU] only support \"pad_top == pad_bottom && pad_left == "
+                    "pad_right\" .";
+    return FAILED;
+  }
 
   // Input node
   std::shared_ptr<Node> input_node = nullptr;
@@ -67,23 +104,23 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   }
 
   // Create input sizes node to describe the dimensions of input tensor
-  if (paddings.size() == 2L) {
-    for (size_t i = 0; i < 2L; ++i) {
-      int copy_pad = *(paddings.begin() + 2 * i);
-      paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
-    }
-  }
-  CHECK_EQ(paddings.size(), 4L)
-      << "[NPU] Paddings size should be the same or twice as the input size.";
   std::vector<int32_t> input_sizes;
   input_sizes.push_back(input_dims[0]);
   input_sizes.push_back(filter_dims[1] * groups);
   for (int i = 0; i < strides.size(); i++) {
     int kernel_ext = dilations[i] * (filter_dims[i + 2] - 1) + 1;
-    int output_size =
-        (input_dims[i + 2] - 1) * strides[i] + kernel_ext - 2 * paddings[i];
+    int output_size = (input_dims[i + 2] - 1) * strides[i] + kernel_ext -
+                      paddings[i * 2] - paddings[i * 2 + 1];
     input_sizes.push_back(output_size);
   }
+  if (!output_size.empty()) {
+    CHECK_EQ(output_size.size(), 2L);
+    if (output_size[0] != input_sizes[2] || output_size[1] != input_sizes[3]) {
+      LOG(WARNING) << "[NPU] not support output_size: " << output_size[0]
+                   << ", " << output_size[1];
+      return FAILED;
+    }
+  }
   auto input_sizes_node =
       graph->Add(output_name + "/input_sizes", input_sizes);
 
   // Filter node
@@ -96,8 +133,13 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   conv_transpose_op->set_input_filter(*filter_node->data());
   conv_transpose_op->set_input_x(*input_node->data());
   // Set attributes
-  conv_transpose_op->set_attr_format(0);    // NCHW
-  conv_transpose_op->set_attr_pad_mode(0);  // NOTSET
+  conv_transpose_op->set_attr_format(0);  // NCHW
+  // "SAME" is different from paddle
+  if (padding_algorithm == "VALID") {
+    conv_transpose_op->set_attr_pad_mode(5);
+  } else {
+    conv_transpose_op->set_attr_pad_mode(0);  // NOTSET
+  }
   conv_transpose_op->set_attr_group(groups);
   conv_transpose_op->set_attr_pad(ge::AttrValue::LIST_INT(
       {paddings[0], paddings[1], paddings[2], paddings[3]}));
diff --git a/lite/kernels/npu/bridges/conv_transpose_op_test.cc b/lite/kernels/npu/bridges/conv_transpose_op_test.cc
deleted file mode 100644
index f96e57c06f..0000000000
--- a/lite/kernels/npu/bridges/conv_transpose_op_test.cc
+++ /dev/null
@@ -1,372 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "lite/operators/conv_transpose_op.h"
-#include <gtest/gtest.h>
-#include <random>
-#include "lite/core/op_registry.h"
-#include "lite/kernels/npu/bridges/registry.h"
-#include "lite/kernels/npu/bridges/test_helper.h"
-
-namespace paddle {
-namespace lite {
-namespace kernels {
-namespace npu {
-namespace bridges {
-
-template <typename DType>
-void add_bias_with_relu(DType* data,
-                        const DType* bias,
-                        int channel_size,
-                        int inner_size,
-                        bool has_relu) {
-  for (int c = 0; c < channel_size; ++c) {
-    DType bias_val = bias != nullptr ? bias[c] : 0;
-    for (int i = 0; i < inner_size; i++) {
-      DType data_val = data[i];
-      data_val += bias_val;
-      if (has_relu) {
-        data_val = data_val > 0 ? data_val : 0.f;
-      }
-      data[i] = data_val;
-    }
-    data += inner_size;
-  }
-}
-
-template <typename DType>
-void col2im(const DType* data_col,
-            const int channel_size,
-            const int height,
-            const int width,
-            const int kernel_h,
-            const int kernel_w,
-            const int pad_h,
-            const int pad_w,
-            const int stride_h,
-            const int stride_w,
-            const int dilation_h,
-            const int dilation_w,
-            DType* data_im) {
-  memset(data_im, 0, height * width * channel_size * sizeof(DType));
-  const int output_h =
-      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
-  const int output_w =
-      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
-  const int inner_size = height * width;
-  for (int c = channel_size; c--; data_im += inner_size) {
-    for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
-      for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
-        int input_row = -pad_h + kernel_row * dilation_h;
-        for (int output_rows = output_h; output_rows; output_rows--) {
-          if (input_row < 0 || input_row >= height) {
-            data_col += output_w;
-          } else {
-            int input_col = -pad_w + kernel_col * dilation_w;
-            for (int output_col = output_w; output_col; output_col--) {
-              if (input_col >= 0 && input_col < width) {
-                data_im[input_row * width + input_col] += *data_col;
-              }
-              data_col++;
-              input_col += stride_w;
-            }
-          }
-          input_row += stride_h;
-        }
-      }
-    }
-  }
-}
-
-template <typename IType, typename OType>
-void gemm(int M,
-          int N,
-          int K,
-          const IType* A,
-          const IType* B,
-          OType* C,
-          OType alpha,
-          OType beta,
-          bool is_trans_A = false,
-          bool is_trans_B = false) {
-  for (int m = 0; m < M; ++m) {
-    for (int n = 0; n < N; ++n) {
-      OType sum = static_cast<OType>(0);
-      for (int k = 0; k < K; ++k) {
-        IType a;
-        IType b;
-        if (is_trans_A) {
-          a = A[k * M + m];
-        } else {
-          a = A[m * K + k];
-        }
-        if (is_trans_B) {
-          b = B[n * K + k];
-        } else {
-          b = B[k * N + n];
-        }
-        sum += a * b;
-      }
-      C[m * N + n] = alpha * sum + beta * C[m * N + n];
-    }
-  }
-}
-
-template <typename IType, typename OType>
-void conv_transpose_ref(
-    const std::shared_ptr<operators::ConvTransposeOpLite> op) {
-  Scope* scope = op->scope();
-  const OpInfo* op_info = op->op_info();
-  auto input =
-      scope->FindVar(op_info->Input("Input").front())->GetMutable<Tensor>();
-  auto filter =
-      scope->FindVar(op_info->Input("Filter").front())->GetMutable<Tensor>();
-  auto output =
-      scope->FindVar(op_info->Output("Output").front())->GetMutable<Tensor>();
-  std::vector<int32_t> strides =
-      op_info->GetAttr<std::vector<int32_t>>("strides");
-  std::vector<int32_t> paddings =
-      op_info->GetAttr<std::vector<int32_t>>("paddings");
-  int32_t groups = op_info->GetAttr<int32_t>("groups");
-  std::vector<int32_t> dilations =
-      op_info->GetAttr<std::vector<int32_t>>("dilations");
-  bool fuse_relu = op_info->GetAttr<bool>("fuse_relu");
-  Tensor* bias = nullptr;
-  OType* bias_data = nullptr;
-  if (op_info->HasInput("Bias")) {
-    auto bias_var_names = op_info->Input("Bias");
-    if (bias_var_names.size() > 0) {
-      auto bias_var_name = bias_var_names.front();
-      bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
-      bias_data = bias->mutable_data<OType>();
-    }
-  }
-  auto input_dims = input->dims();
-  auto filter_dims = filter->dims();
-  auto output_dims = output->dims();
-  auto input_data = input->mutable_data<IType>();
-  auto filter_data = filter->mutable_data<IType>();
-  auto output_data = output->mutable_data<OType>();
-  int kernel_w = filter_dims[3];
-  int kernel_h = filter_dims[2];
-  int stride_w = strides[1];
-  int stride_h = strides[0];
-  int dila_w = dilations[1];
-  int dila_h = dilations[0];
-  int pad_w = paddings[1];
-  int pad_h = paddings[0];
-  int batch_size = input_dims[0];
-  int in_ch_size = input_dims[1];
-  int in_h = input_dims[2];
-  int in_w = input_dims[3];
-  int out_ch_size = output_dims[1];
-  int out_h = output_dims[2];
-  int out_w = output_dims[3];
-
-  int M = out_ch_size * kernel_w * kernel_h / groups;
-  int N = in_h * in_w;
-  int K = in_ch_size / groups;
-
-  if (in_ch_size != out_ch_size || groups != in_ch_size) {
-    CHECK_EQ(in_ch_size % groups, 0);
-    CHECK_EQ(out_ch_size % groups, 0);
-  }
-
-  auto workspace = std::vector<OType>(groups * M * N);
-  int group_input_size = in_w * in_h * in_ch_size / groups;
-  int group_output_size = out_w * out_h * out_ch_size / groups;
-  int group_col_size = M * N;
-  int group_filter_size =
-      in_ch_size * out_ch_size * kernel_w * kernel_h / (groups * groups);
-  bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) &&
-                      (stride_w == 1) && (pad_w == 1) && (pad_h == 1) &&
-                      (dila_w == 1) && (dila_h == 1);
-  for (int n = 0; n < batch_size; ++n) {
-    input_data += n * in_ch_size * in_h * in_w;
-    output_data += n * out_ch_size * out_h * out_w;
-    auto col_data = workspace.data();
-    if (flag_1x1s1p1) {
-      col_data = output_data;
-    }
-    memset(col_data, 0, sizeof(OType) * group_col_size);
-    for (int g = 0; g < groups; ++g) {
-      auto input_group_data = input_data + g * group_input_size;
-      auto filter_group_data = filter_data + g * group_filter_size;
-      auto col_group_data = col_data + g * group_col_size;
-      gemm(M,
-           N,
-           K,
-           filter_group_data,
-           input_group_data,
-           col_group_data,
-           static_cast<OType>(1),
-           static_cast<OType>(0),
-           true,
-           false);
-    }
-    if (!flag_1x1s1p1) {
-      col2im(col_data,
-             out_ch_size,
-             out_h,
-             out_w,
-             kernel_h,
-             kernel_w,
-             pad_h,
-             pad_w,
-             stride_h,
-             stride_w,
-             dila_h,
-             dila_w,
-             output_data);
-    }
-    add_bias_with_relu(
-        output_data, bias_data, out_ch_size, out_w * out_h, fuse_relu);
-  }
-}
-
-void test_conv_transpose(int bs,
-                         int ic,
-                         int ih,
-                         int iw,
-                         bool has_bias,
-                         bool fuse_relu,
-                         int filters,
-                         int groups,
-                         int dilation,
-                         int stride,
-                         int padding,
-                         int kernel) {
-  // prepare input&output variables
-  Scope scope;
-  std::string input_var_name("input");
-  std::string filter_var_name("filter");
-  std::string bias_var_name("bias");
-  std::string output_var_name("output");
-  std::string output_ref_var_name("output_ref");
-  auto* input = scope.Var(input_var_name)->GetMutable<Tensor>();
-  auto* filter = scope.Var(filter_var_name)->GetMutable<Tensor>();
-  auto* bias = scope.Var(bias_var_name)->GetMutable<Tensor>();
-  auto* output = scope.Var(output_var_name)->GetMutable<Tensor>();
-  auto* output_ref = scope.Var(output_ref_var_name)->GetMutable<Tensor>();
-
-  // get group size and input&filter shape
-  std::vector<int64_t> input_shape = {bs, ic, ih, iw};
-  std::vector<int64_t> filter_shape = {ic, filters, kernel, kernel};
-  input->Resize(input_shape);
-  filter->Resize(filter_shape);
-
-  // initialize input&output data
-  FillTensor<float>(input);
-  FillTensor<float>(filter);
-
-  // initialize op desc
-  cpp::OpDesc opdesc;
-  opdesc.SetType("conv2d_transpose");
-  opdesc.SetInput("Input", {input_var_name});
-  opdesc.SetInput("Filter", {filter_var_name});
-  opdesc.SetOutput("Output", {output_var_name});
-  opdesc.SetAttr("dilations", std::vector<int>({dilation, dilation}));
-  opdesc.SetAttr("strides", std::vector<int>({stride, stride}));
-  opdesc.SetAttr("paddings",
-                 std::vector<int>({padding, padding, padding, padding}));
-  opdesc.SetAttr("groups", groups);
-  opdesc.SetAttr("fuse_relu", static_cast<bool>(fuse_relu));
-  if (has_bias) {
-    bias->Resize({1, filters * groups, 1, 1});
-    FillTensor<float>(bias);
-    opdesc.SetInput("Bias", {bias_var_name});
-  }
-
-  // create and convert op to NPU model, then run it on NPU
-  auto op = CreateOp<operators::ConvTransposeOpLite>(opdesc, &scope);
-  LauchOp(op, {input_var_name}, {output_var_name});
-  output_ref->CopyDataFrom(*output);
-
-  // execute reference implementation and save to output tensor('out')
-  conv_transpose_ref<float, float>(op);
-
-  // compare results
-  auto* output_data = output->mutable_data<float>();
-  auto* output_ref_data = output_ref->mutable_data<float>();
-  for (int i = 0; i < output->dims().production(); i++) {
-    VLOG(5) << i;
-    EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
-  }
-}
-
-TEST(NPUBridges, conv_transpose) {
-#if 1
-  for (auto bs : {1, 2}) {
-    for (auto ic : {3, 6}) {
-      for (auto ih : {14, 28}) {
-        for (auto iw : {14, 28}) {
-          for (auto has_bias : {false, true}) {
-            for (auto fuse_relu : {false, true}) {
-              for (auto filters : {1, 2, 5}) {
-                for (auto groups : {1 /* , 2, 5*/}) {
-                  for (auto dilation : {1, 2}) {
-                    for (auto stride : {1, 2}) {
-                      for (auto kernel : {1, 3, 5}) {
-                        std::vector<int> paddings = {kernel / 2};
-                        if (kernel / 2 != 0) {
-                          paddings.push_back(0);
-                        }
-                        for (auto padding : paddings) {
-                          VLOG(3) << "bs: " << bs << " ic: " << ic
-                                  << " ih: " << ih << " iw: " << iw
-                                  << " has_bias: " << has_bias
-                                  << " fuse_relu: " << fuse_relu
-                                  << " filters: " << filters
-                                  << " groups: " << groups
-                                  << " dilation: " << dilation
-                                  << " stride: " << stride
-                                  << " padding: " << padding
-                                  << " kernel: " << kernel;
-                          test_conv_transpose(bs,
-                                              ic,
-                                              ih,
-                                              iw,
-                                              has_bias,
-                                              fuse_relu,
-                                              filters,
-                                              groups,
-                                              dilation,
-                                              stride,
-                                              padding,
-                                              kernel);
-                        }
-                      }
-                    }
-                  }
-                }
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-#else
-  test_conv_transpose(1, 6, 8, 8, false, false, 5, 2, 1, 1, 1, 3);
-#endif
-}
-
-}  // namespace bridges
-}  // namespace npu
-}  // namespace kernels
-}  // namespace lite
-}  // namespace paddle
-
-USE_LITE_OP(conv2d_transpose);
-USE_NPU_BRIDGE(conv2d_transpose);
diff --git a/lite/operators/conv_transpose_op.cc b/lite/operators/conv_transpose_op.cc
index 94a621491f..a84b975492 100644
--- a/lite/operators/conv_transpose_op.cc
+++ b/lite/operators/conv_transpose_op.cc
@@ -11,10 +11,12 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
 #include "lite/operators/conv_transpose_op.h"
 #include <algorithm>
 #include "lite/core/op_lite.h"
 #include "lite/core/op_registry.h"
+#include "lite/operators/conv_op.h"
 
 namespace paddle {
 namespace lite {
@@ -50,34 +52,6 @@ inline int ConvTransposeOutputSize(int input_size,
   return output_size;
 }
 
-inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
-                                     std::vector<int>* dilations,
-                                     const std::vector<int>& strides,
-                                     const std::string padding_algorithm,
-                                     const lite::DDim data_dims,
-                                     const lite::DDim& ksize) {
-  // when padding_desc is "VALID" or "SAME"
-  if (padding_algorithm == "SAME") {
-    for (size_t i = 0; i < strides.size(); ++i) {
-      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
-      int pad_sum = std::max(
-          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
-          (int64_t)0);
-      int pad_0 = pad_sum / 2;
-      int pad_1 = pad_sum - pad_0;
-      // pad
-      *(paddings->begin() + i * 2) = pad_0;
-      *(paddings->begin() + i * 2 + 1) = pad_1;
-      // dilation
-      *(dilations->begin() + i) = 1;
-    }
-  } else if (padding_algorithm == "VALID") {
-    for (auto& it : *paddings) {
-      it = 0;
-    }
-  }
-}
-
 bool ConvTransposeOpLite::InferShape() const {
   const auto in_dims = param_.x->dims();
   const auto filter_dims = param_.filter->dims();
@@ -169,6 +143,7 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc,
   }
   if (op_desc.HasAttr("fuse_relu")) {
     param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
+    param_.activation_param.active_type = lite_api::ActivationType::kRelu;
   }
   if (op_desc.HasAttr("output_size")) {
     param_.output_size = op_desc.GetAttr<std::vector<int>>("output_size");
diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt
index 113f6a8b33..b6acfb45c1 100644
--- a/lite/tests/kernels/CMakeLists.txt
+++ b/lite/tests/kernels/CMakeLists.txt
@@ -1,4 +1,6 @@
 if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
+    lite_cc_test(test_kernel_conv_compute SRCS conv_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+    lite_cc_test(test_kernel_conv_transpose_compute SRCS conv_transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_scale_compute SRCS scale_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_power_compute SRCS power_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_shuffle_channel_compute SRCS shuffle_channel_compute_test.cc DEPS arena_framework ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
diff --git a/lite/tests/kernels/conv_transpose_compute_test.cc b/lite/tests/kernels/conv_transpose_compute_test.cc
new file mode 100644
index 0000000000..584212e2cd
--- /dev/null
+++ b/lite/tests/kernels/conv_transpose_compute_test.cc
@@ -0,0 +1,341 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/core/arena/framework.h"
+#include "lite/tests/utils/fill_data.h"
+#include "lite/tests/utils/naive_math_impl.h"
+
+namespace paddle {
+namespace lite {
+
+class ConvTransposeComputeTester : public arena::TestCase {
+ protected:
+  // common attributes for this op.
+  std::string op_type_ = "conv2d_transpose";
+  std::string input_ = "input";
+  std::string filter_ = "filter";
+  std::string output_ = "output";
+  DDim dims_;
+
+  int filter_channels_ = 1;
+  std::vector<int> ksize_{3, 3};
+  std::vector<int> strides_{1, 1};
+  std::vector<int> paddings_{0, 0};
+  int groups_ = 1;
+  std::vector<int> dilations_{1, 1};
+  std::string padding_algorithm_ = "";
+  std::vector<int> output_size_{};
+  std::string bias_ = "";
+  bool fuse_relu_ = false;
+
+ public:
+  ConvTransposeComputeTester(const Place& place,
+                             const std::string& alias,
+                             DDim dims,
+                             int filter_channels = 1,
+                             std::vector<int> ksize = {3, 3},
+                             std::vector<int> strides = {1, 1},
+                             std::vector<int> paddings = {0, 0},
+                             int groups = 1,
+                             std::vector<int> dilations = {1, 1},
+                             std::string padding_algorithm = "",
+                             std::vector<int> output_size = {},
+                             std::string bias = "",
+                             bool fuse_relu = false)
+      : TestCase(place, alias),
+        dims_(dims),
+        filter_channels_(filter_channels),
+        ksize_(ksize),
+        strides_(strides),
+        paddings_(paddings),
+        groups_(groups),
+        dilations_(dilations),
+        padding_algorithm_(padding_algorithm),
+        output_size_(output_size),
+        bias_(bias),
+        fuse_relu_(fuse_relu) {}
+
+  void RunBaseline(Scope* scope) override {
+    if (paddings_.size() == 2L) {
+      paddings_.insert(paddings_.begin(), paddings_[0]);
+      paddings_.insert(paddings_.begin() + 2, paddings_[2]);
+    }
+    CHECK_EQ(paddings_.size(), 4);
+
+    if (padding_algorithm_ == "SAME") {
+      for (size_t i = 0; i < strides_.size(); ++i) {
+        int out_size = (dims_[i + 2] + strides_[i] - 1) / strides_[i];
+        int pad_sum =
+            std::max((out_size - 1) * strides_[i] + ksize_[i] - dims_[i + 2],
+                     (int64_t)0);
+        int pad_0 = pad_sum / 2;
+        int pad_1 = pad_sum - pad_0;
+        // pad
+        paddings_[i * 2] = pad_0;
+        paddings_[i * 2 + 1] = pad_1;
+        // dilation
+        dilations_[i] = 1;
+      }
+    } else if (padding_algorithm_ == "VALID") {
+      for (auto& it : paddings_) {
+        it = 0;
+      }
+    }
+
+    std::vector<int64_t> output_shape{dims_[0], filter_channels_ * groups_};
+    for (size_t i = 0; i < strides_.size(); ++i) {
+      const int dkernel = dilations_[i] * (ksize_[i] - 1) + 1;
+      int output_size = (dims_[i + 2] - 1) * strides_[i] - paddings_[i * 2] -
+                        paddings_[i * 2 + 1] + dkernel;
+      output_shape.push_back(output_size);
+    }
+
+    if (!output_size_.empty()) {
+      for (size_t i = 0; i < output_size_.size(); ++i) {
+        output_shape[i + 2] = output_size_[i];
+      }
+    }
+    auto output = scope->NewTensor(output_);
+    output->Resize(output_shape);
+
+    const Tensor* input = scope->FindTensor(input_);
+    const Tensor* filter = scope->FindTensor(filter_);
+    const Tensor* bias = scope->FindTensor(bias_);
+    auto input_dims = input->dims();
+    auto filter_dims = filter->dims();
+    auto output_dims = output->dims();
+    auto input_data = input->data<float>();
+    auto filter_data = filter->data<float>();
+    auto output_data = output->mutable_data<float>();
+
+    bool flag_bias = bias != nullptr;
+    const float* bias_data = flag_bias ? bias->data<float>() : nullptr;
+    deconv_basic(input_data,
+                 output_data,
+                 input_dims[0],
+                 output_dims[1],
+                 output_dims[2],
+                 output_dims[3],
+                 input_dims[1],
+                 input_dims[2],
+                 input_dims[3],
+                 filter_data,
+                 bias_data,
+                 groups_,
+                 filter_dims[3],
+                 filter_dims[2],
+                 strides_[1],
+                 strides_[0],
+                 dilations_[1],
+                 dilations_[0],
+                 paddings_[2],
+                 paddings_[3],
+                 paddings_[0],
+                 paddings_[1],
+                 flag_bias,
+                 fuse_relu_);
+  }
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType(op_type_);
+    op_desc->SetInput("Input", {input_});
+    op_desc->SetInput("Filter", {filter_});
+    if (!bias_.empty()) {
+      op_desc->SetInput("Bias", {bias_});
+    }
+    op_desc->SetOutput("Output", {output_});
+    op_desc->SetAttr("strides", strides_);
+    op_desc->SetAttr("paddings", paddings_);
+    op_desc->SetAttr("groups", groups_);
+    op_desc->SetAttr("dilations", dilations_);
+    if (!padding_algorithm_.empty()) {
+      op_desc->SetAttr("padding_algorithm", padding_algorithm_);
+    }
+    if (!output_size_.empty()) {
+      op_desc->SetAttr("output_size", output_size_);
+    }
+    op_desc->SetAttr("fuse_relu", fuse_relu_);
+  }
+
+  void PrepareData() override {
+    std::vector<float> din(dims_.production());
+    fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
+    SetCommonTensor(input_, dims_, din.data());
+
+    DDim filter_dims(
+        std::vector<int64_t>{dims_[1], filter_channels_, ksize_[0], ksize_[1]});
+    std::vector<float> dfilter(filter_dims.production());
+    fill_data_rand(dfilter.data(), -1.f, 1.f, filter_dims.production());
+    SetCommonTensor(filter_, filter_dims, dfilter.data(), {}, true);
+
+    if (!bias_.empty()) {
+      DDim bias_dims(std::vector<int64_t>{filter_channels_ * groups_});
+      std::vector<float> dbias(bias_dims.production());
+      fill_data_rand(dbias.data(), -1.f, 1.f, bias_dims.production());
+      SetCommonTensor(bias_, bias_dims, dbias.data(), {}, true);
+    }
+  }
+};
+
+void TestConvTransposeKsize(Place place, float abs_error = 2e-5) {
+  for (auto dims : std::vector<std::vector<int64_t>>{{5, 6, 11, 12}}) {
+    for (auto filter_channels : {1, 3}) {
+      for (auto ksize :
+           std::vector<std::vector<int>>{{1, 1}, {2, 2}, {3, 3}, {2, 3}}) {
+        std::unique_ptr<arena::TestCase> tester(new ConvTransposeComputeTester(
+            place, "def", DDim(dims), filter_channels, ksize));
+        arena::Arena arena(std::move(tester), place, abs_error);
+        arena.TestPrecision();
+      }
+    }
+  }
+}
+
+void TestConvTransposeStrides(Place place, float abs_error = 2e-5) {
+  for (auto dims : std::vector<std::vector<int64_t>>{{5, 6, 11, 12}}) {
+    for (auto strides : std::vector<std::vector<int>>{{2, 2}, {3, 3}, {1, 2}}) {
+      std::unique_ptr<arena::TestCase> tester(new ConvTransposeComputeTester(
+          place, "def", DDim(dims), 3, {3, 3}, strides));
+      arena::Arena arena(std::move(tester), place, abs_error);
+      arena.TestPrecision();
+    }
+  }
+}
+
+void TestConvTransposePaddings(Place place, float abs_error = 2e-5) {
+  for (auto dims : std::vector<std::vector<int64_t>>{{5, 6, 11, 12}}) {
+    for (auto paddings : std::vector<std::vector<int>>{
+             {1, 1}, {2, 2}, {0, 1}, {1, 0, 0, 1}, {1, 2, 0, 1}}) {
+      std::unique_ptr<arena::TestCase> tester(new ConvTransposeComputeTester(
+          place, "def", DDim(dims), 3, {3, 3}, {1, 1}, paddings));
+      arena::Arena arena(std::move(tester), place, abs_error);
+      arena.TestPrecision();
+    }
+  }
+}
+
+void TestConvTransposeGroups(Place place, float abs_error = 2e-5) {
+  for (auto dims : std::vector<std::vector<int64_t>>{{5, 6, 11, 12}}) {
+    for (auto groups : {2, 3, 6}) {
+      std::unique_ptr<arena::TestCase> tester(new ConvTransposeComputeTester(
+          place, "def", DDim(dims), 12, {3, 3}, {1, 1}, {0, 0}, groups));
+      arena::Arena arena(std::move(tester), place, abs_error);
+      arena.TestPrecision();
+    }
+  }
+}
+
+void TestConvTransposeDilations(Place place, float abs_error = 2e-5) {
+  for (auto dims : std::vector<std::vector<int64_t>>{{5, 6, 11, 12}}) {
+    for (auto dilations : std::vector<std::vector<int>>{{2, 2}, {1, 2}}) {
+      std::unique_ptr<arena::TestCase> tester(new ConvTransposeComputeTester(
+          place, "def", DDim(dims), 3, {3, 3}, {1, 1}, {0, 0}, 1, dilations));
+      arena::Arena arena(std::move(tester), place, abs_error);
+      arena.TestPrecision();
+    }
+  }
+}
+
+void TestConvTransposePaddingAlgorithm(Place place, float abs_error = 2e-5) {
+  for (auto dims : std::vector<std::vector<int64_t>>{{5, 6, 11, 12}}) {
+    for (auto padding_algorithm : std::vector<std::string>{"SAME", "VALID"}) {
+      std::unique_ptr<arena::TestCase> tester(
+          new ConvTransposeComputeTester(place,
+                                         "def",
+                                         DDim(dims),
+                                         3,
+                                         {3, 3},
+                                         {2, 2},
+                                         {0, 0},
+                                         1,
+                                         {1, 1},
+                                         padding_algorithm));
+      arena::Arena arena(std::move(tester), place, abs_error);
+      arena.TestPrecision();
+    }
+  }
+}
+
+void TestConvTransposeOutputSize(Place place, float abs_error = 2e-5) {
+  for (auto dims : std::vector<std::vector<int64_t>>{{5, 6, 12, 12}}) {
+    for (auto output_size : std::vector<std::vector<int>>{{25, 26}, {26, 26}}) {
+      std::unique_ptr<arena::TestCase> tester(
+          new ConvTransposeComputeTester(place,
+                                         "def",
+                                         DDim(dims),
+                                         3,
+                                         {3, 3},
+                                         {2, 2},
+                                         {0, 0},
+                                         1,
+                                         {1, 1},
+                                         "",
+                                         output_size));
+      arena::Arena arena(std::move(tester), place, abs_error);
+      arena.TestPrecision();
+    }
+  }
+}
+
+void TestConvTransposeBiasRelu(Place place, float abs_error = 2e-5) {
+  for (auto dims : std::vector<std::vector<int64_t>>{{5, 6, 11, 12}}) {
+    for (auto bias : std::vector<std::string>{"", "bias"}) {
+      for (bool fuse_relu : {true, false}) {
+        if (bias.empty() && fuse_relu) continue;
+        std::unique_ptr<arena::TestCase> tester(
+            new ConvTransposeComputeTester(place,
+                                           "def",
+                                           DDim(dims),
+                                           3,
+                                           {3, 3},
+                                           {1, 1},
+                                           {0, 0},
+                                           1,
+                                           {1, 1},
+                                           "",
+                                           {},
+                                           bias,
+                                           fuse_relu));
+        arena::Arena arena(std::move(tester), place, abs_error);
+        arena.TestPrecision();
+      }
+    }
+  }
+}
+
+TEST(Conv_transpose, precision) {
+  float abs_error = 2e-5;
+  Place place;
+#if defined(LITE_WITH_NPU)
+  place = TARGET(kNPU);
+  abs_error = 5e-2;  // Using fp16 in NPU
+#else
+  return;
+#endif
+
+  TestConvTransposeKsize(place, abs_error);
+  TestConvTransposeStrides(place, abs_error);
+  TestConvTransposePaddings(place, abs_error);
+  TestConvTransposeGroups(place, abs_error);
+  TestConvTransposeDilations(place, abs_error);
+  TestConvTransposePaddingAlgorithm(place, abs_error);
+  TestConvTransposeOutputSize(place, abs_error);
+  TestConvTransposeBiasRelu(place, abs_error);
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/utils/naive_math_impl.h b/lite/tests/utils/naive_math_impl.h
index fd868e85ac..91e398c5a9 100644
--- a/lite/tests/utils/naive_math_impl.h
+++ b/lite/tests/utils/naive_math_impl.h
@@ -407,7 +407,6 @@ void deconv_basic(const Dtype1* din,
   int k = chin / group;
 
   int group_size_in = win * hin * chin / group;
-  int group_size_out = wout * hout * chout / group;
   int group_size_coldata = m * n;
   int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group);
   bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) &&
-- 
GitLab
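
Editor's note: a quick numeric check of the deconvolution output-size rule this patch changes in two places (the NPU bridge's input_sizes loop and the test baseline's output_shape loop), where the old symmetric 2 * paddings[i] term becomes the four-element paddings[i * 2] + paddings[i * 2 + 1] form. This is an illustrative sketch only, not part of the patch; ConvTransposeOutSize is a made-up helper, not a PaddleLite API.

#include <cassert>

// Output extent of conv2d_transpose along one spatial axis:
//   (in - 1) * stride + dilation * (kernel - 1) + 1 - pad_begin - pad_end
int ConvTransposeOutSize(
    int in, int stride, int kernel, int dilation, int pad_begin, int pad_end) {
  int kernel_ext = dilation * (kernel - 1) + 1;  // dilated kernel extent
  return (in - 1) * stride + kernel_ext - pad_begin - pad_end;
}

int main() {
  // Shapes from TestConvTransposeOutputSize above: 12x12 input, 3x3 kernel,
  // stride 2, zero padding -> (12 - 1) * 2 + 3 = 25 per axis.
  assert(ConvTransposeOutSize(12, 2, 3, 1, 0, 0) == 25);
  // Asymmetric padding {1, 2} trims pad_begin + pad_end = 3 off that axis,
  // which the old symmetric 2 * paddings[i] term could not express.
  assert(ConvTransposeOutSize(12, 2, 3, 1, 1, 2) == 22);
  return 0;
}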