From a6b1e4fa1246dcb31515b128bb121a3b61964b5d Mon Sep 17 00:00:00 2001
From: Xiaoyang LI
Date: Sat, 12 Oct 2019 10:33:30 +0800
Subject: [PATCH] fix conv_transpose error (#2165)

* fix conv_transpose error

* fix build error, enable basic test of conv_transpose, test=develop
---
 cmake/generic.cmake                            |   2 +-
 lite/kernels/arm/conv_transpose_compute.cc     |   7 +-
 lite/tests/math/CMakeLists.txt                 |   1 +
 .../tests/math/conv_transpose_compute_test.cc  | 333 ++++++++++++++++++
 lite/tests/utils/naive_math_impl.h             | 173 +++++++++
 5 files changed, 512 insertions(+), 4 deletions(-)
 create mode 100644 lite/tests/math/conv_transpose_compute_test.cc

diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index 3d05bad64b..f1b73f1d88 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -306,7 +306,7 @@ function(cc_library TARGET_NAME)
       if(${source_file} MATCHES "__generated_code__.cc")
         list(APPEND full_path_src ${source_file})
       else()
-        if(NOT ${source_file} MATCHES "framework.pb.cc")
+        if(NOT ${source_file} MATCHES "framework.pb.cc" AND NOT ${source_file} MATCHES "__generated_code__.cc")
           list(APPEND full_path_src ${CMAKE_CURRENT_SOURCE_DIR}/${source_file})
         endif()
       endif()
diff --git a/lite/kernels/arm/conv_transpose_compute.cc b/lite/kernels/arm/conv_transpose_compute.cc
index 06392ae1e2..9fca00ad6b 100644
--- a/lite/kernels/arm/conv_transpose_compute.cc
+++ b/lite/kernels/arm/conv_transpose_compute.cc
@@ -49,10 +49,11 @@ void Conv2DTransposeCompute::PrepareForRun() {
 
   lite::Tensor tmp_weights;
   lite::arm::math::prepackA(
-      &tmp_weights, *(param.filter), 1., m, k, group, true, &ctx);
+      &tmp_weights, *(param.filter), 1.f, m, k, group, true, &ctx);
   param.filter->Resize(tmp_weights.dims());
   param.filter->CopyDataFrom(tmp_weights);
   param.filter->Resize(w_dims);
+  is_first_epoch_ = false;
 }
 
 void Conv2DTransposeCompute::Run() {
@@ -96,7 +97,7 @@ void Conv2DTransposeCompute::Run() {
     const float* din_batch = din + i * chin * hin * win;
     float* dout_batch = dout + i * chout * hout * wout;
     float* col_data = static_cast<float*>(ctx.workspace_data<float>()) +
-                      ctx.l2_cache_size() / sizeof(float);
+                      ctx.llc_size() / sizeof(float);
     if (flag_1x1s1p1) {
       col_data = dout_batch;
     }
@@ -112,7 +113,7 @@ void Conv2DTransposeCompute::Run() {
                                      weights_group,
                                      din_group,
                                      n,
-                                     0.,
+                                     0.f,
                                      coldata_group,
                                      n,
                                      nullptr,
diff --git a/lite/tests/math/CMakeLists.txt b/lite/tests/math/CMakeLists.txt
index 67cb9576bd..55fcfc005c 100644
--- a/lite/tests/math/CMakeLists.txt
+++ b/lite/tests/math/CMakeLists.txt
@@ -2,5 +2,6 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
     lite_cc_test(sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+    lite_cc_test(conv_transpose_compute_test SRCS conv_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(conv_int8_compute_test SRCS conv_int8_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 endif()
diff --git a/lite/tests/math/conv_transpose_compute_test.cc b/lite/tests/math/conv_transpose_compute_test.cc
new file mode 100644
index 0000000000..3a1bbac04b
--- /dev/null
+++ b/lite/tests/math/conv_transpose_compute_test.cc
@@ -0,0 +1,333 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include "lite/core/context.h"
+#include "lite/operators/op_params.h"
+#include "lite/tests/utils/naive_math_impl.h"
+#include "lite/tests/utils/tensor_utils.h"
+#include "lite/tests/utils/timer.h"
+
+#ifdef LITE_WITH_ARM
+#include "lite/kernels/arm/conv_transpose_compute.h"
+#endif  // LITE_WITH_ARM
+
+DEFINE_int32(cluster, 3, "cluster id");
+DEFINE_int32(threads, 1, "threads num");
+DEFINE_int32(warmup, 0, "warmup times");
+DEFINE_int32(repeats, 1, "repeats times");
+DEFINE_bool(basic_test, false, "do all tests");
+DEFINE_bool(check_result, true, "check the result");
+
+DEFINE_int32(batch, 1, "batch size");
+DEFINE_int32(in_channel, 32, "input channel");
+DEFINE_int32(in_height, 32, "input height");
+DEFINE_int32(in_width, 32, "input width");
+
+DEFINE_int32(out_channel, 64, "output channel");
+DEFINE_int32(group, 1, "group");
+DEFINE_int32(kernel_h, 2, "kernel height");
+DEFINE_int32(kernel_w, 2, "kernel width");
+DEFINE_int32(pad_h, 0, "pad height");
+DEFINE_int32(pad_w, 0, "pad width");
+DEFINE_int32(stride_h, 2, "stride height");
+DEFINE_int32(stride_w, 2, "stride width");
+DEFINE_int32(dila_h, 1, "dilation height");
+DEFINE_int32(dila_w, 1, "dilation width");
+
+DEFINE_bool(flag_relu, false, "do relu");
+DEFINE_bool(flag_bias, false, "with bias");
+
+typedef paddle::lite::DDim DDim;
+typedef paddle::lite::Tensor Tensor;
+typedef paddle::lite::operators::ConvParam ConvParam;
+
+DDim compute_out_dim(const DDim& dim_in,
+                     const paddle::lite::operators::ConvParam& param) {
+  auto filter_dims = param.filter->dims();
+  DDim output_shape = dim_in;
+  output_shape[1] = filter_dims[1] * param.groups;
+  for (int i = 0; i < 2; i++) {
+    int kernel_extent = param.dilations[i] * (filter_dims[i + 2] - 1) + 1;
+    int output_len = (dim_in[i + 2] - 1) * param.strides[i] + kernel_extent -
+                     2 * param.paddings[i];
+    output_shape[i + 2] = output_len;
+  }
+  return output_shape;
+}
+
+#ifdef LITE_WITH_ARM
+void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
+                              const DDim& weight_dim,
+                              int group,
+                              const std::vector<int>& strides,
+                              const std::vector<int>& pads,
+                              const std::vector<int>& dilas,
+                              bool flag_bias,
+                              bool flag_relu,
+                              const std::vector<int>& thread_num,
+                              const std::vector<int>& cluster_id) {
+#ifdef LITE_WITH_ARM
+  paddle::lite::DeviceInfo::Init();
+#endif
+  ConvParam param;
+  param.x = new Tensor;
+  param.x->set_precision(PRECISION(kFloat));
+  param.filter = new Tensor;
+  param.filter->Resize(weight_dim);
+  param.filter->set_precision(PRECISION(kFloat));
+  if (flag_bias) {
+    param.bias = new Tensor;
+    param.bias->Resize({weight_dim[0]});
+    param.bias->set_precision(PRECISION(kFloat));
+  }
+  param.strides = strides;
+  param.paddings = pads;
+  param.dilations = dilas;
+  param.fuse_relu = flag_relu;
+  param.groups = group;
+
+  param.output = new Tensor;
+  param.output->set_precision(PRECISION(kFloat));
+
+  //  paddle::lite::fill_tensor_rand(*param.filter, -1.f, 1.f);
+  paddle::lite::fill_tensor_const(*param.filter, 1.f);
+  if (flag_bias) {
+    //  paddle::lite::fill_tensor_rand(*param.bias, -1.f, 1.f);
+    paddle::lite::fill_tensor_const(*param.bias, 1.f);
+  }
+  Tensor tmp_weights;
+  tmp_weights.Resize(weight_dim);
+  tmp_weights.CopyDataFrom(*param.filter);
+  auto wptr = tmp_weights.data<float>();
+  auto bias_ptr = flag_bias ? param.bias->data<float>() : nullptr;
+
+  for (auto& cls : cluster_id) {
+    for (auto& th : thread_num) {
+      paddle::lite::kernels::arm::Conv2DTransposeCompute conv_t;
+      std::unique_ptr<paddle::lite::KernelContext> ctx1(
+          new paddle::lite::KernelContext);
+      auto& ctx = ctx1->As<paddle::lite::ARMContext>();
+      ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
+      /// set param and context
+      for (auto& dim_in : input_dims) {
+        param.x->Resize(dim_in);
+        DDim out_tmp_dims = compute_out_dim(dim_in, param);
+        if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) {
+          continue;
+        }
+        param.output->Resize(out_tmp_dims);
+        break;
+      }
+      conv_t.SetParam(param);
+      conv_t.SetContext(std::move(ctx1));
+      /// prepare for run
+      conv_t.PrepareForRun();
+
+      for (auto& dim_in : input_dims) {
+        CHECK_EQ(weight_dim[0], dim_in[1])
+            << "input channel must equal weights channel";
+        DDim dim_out = compute_out_dim(dim_in, param);
+        if (dim_out[2] < 1 || dim_out[3] < 1) {
+          continue;
+        }
+        param.x->Resize(dim_in);
+        param.output->Resize(dim_out);
+
+        //  paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f);
+        paddle::lite::fill_tensor_const(*param.x, 1.f);
+        auto din = param.x->data<float>();
+
+        Tensor tout_basic;
+        if (FLAGS_check_result) {
+          tout_basic.set_precision(PRECISION(kFloat));
+          tout_basic.Resize(dim_out);
+          fill_tensor_const(tout_basic, 0.f);
+          auto dout_basic = tout_basic.mutable_data<float>();
+
+          deconv_basic(din,
+                       dout_basic,
+                       dim_in[0],
+                       dim_out[1],
+                       dim_out[2],
+                       dim_out[3],
+                       dim_in[1],
+                       dim_in[2],
+                       dim_in[3],
+                       wptr,
+                       bias_ptr,
+                       group,
+                       weight_dim[3],
+                       weight_dim[2],
+                       strides[1],
+                       strides[0],
+                       dilas[1],
+                       dilas[0],
+                       pads[1],
+                       pads[0],
+                       flag_bias,
+                       flag_relu);
+        }
+        /// warm up
+        for (int i = 0; i < FLAGS_warmup; ++i) {
+          conv_t.Launch();
+        }
+        /// compute
+        lite::test::Timer t0;
+        for (int i = 0; i < FLAGS_repeats; ++i) {
+          t0.start();
+          conv_t.Launch();
+          t0.end();
+        }
+
+        float gops =
+            2.f * tmp_weights.numel() * dim_in[0] * dim_in[2] * dim_in[3];
+        LOG(INFO) << "conv fp32: input shape: " << dim_in
+                  << ", output shape: " << dim_out
+                  << ", running time, avg: " << t0.get_average_ms()
+                  << ", min time: " << t0.get_min_time()
+                  << ", total GOPS: " << 1e-9 * gops
+                  << " GOPS, avg GOPs: " << 1e-6 * gops / t0.get_average_ms()
+                  << " GOPs, max GOPs: " << 1e-6 * gops / t0.get_min_time();
+
+        if (FLAGS_check_result) {
+          double max_ratio = 0;
+          double max_diff = 0;
+          tensor_cmp_host(tout_basic, *param.output, max_ratio, max_diff);
+          LOG(INFO) << "compare result, max diff: " << max_diff
+                    << ", max ratio: " << max_ratio;
+          if (std::abs(max_ratio) > 1e-3f) {
+            if (max_diff > 5e-4f) {
+              LOG(WARNING) << "basic result";
+              print_tensor(tout_basic);
+              LOG(WARNING) << "saber result";
+              print_tensor(*param.output);
+              Tensor tdiff;
+              tdiff.Resize(tout_basic.dims());
+              tdiff.set_precision(PRECISION(kFloat));
+              tensor_diff(tout_basic, *param.output, tdiff);
+              print_tensor(tdiff);
+              LOG(FATAL) << "test fp32 conv: input: " << dim_in
+                         << ", output: " << dim_out
+                         << ", weight dim: " << weight_dim
+                         << ", pad: " << pads[0] << ", " << pads[1]
+                         << ", stride: " << strides[0] << ", " << strides[1]
+                         << ", dila: " << dilas[0] << ", " << dilas[1]
+                         << ", bias: " << (flag_bias ? "true" : "false")
+                         << ", relu: " << (flag_relu ? "true" : "false")
"true" : "false") + << ", threads: " << th << ", cluster: " << cls + << " failed!!\n"; + } + } + } + LOG(INFO) << "test fp32 conv: input: " << dim_in + << ", output: " << dim_out << ", weight dim: " << weight_dim + << ", pad: " << pads[0] << ", " << pads[1] + << ", stride: " << strides[0] << ", " << strides[1] + << ", dila_: " << dilas[0] << ", " << dilas[1] + << ", bias: " << (flag_bias ? "true" : "false") + << ", relu: " << (flag_relu ? "true" : "false") + << ", threads: " << th << ", cluster: " << cls + << " successed!!\n"; + } + } + } + + delete param.x; + delete param.filter; + delete param.output; + delete param.bias; +} +#else +void test_conv_transpose_fp32(const std::vector& input_dims, + const DDim& weight_dim, + int group, + const std::vector& strides, + const std::vector& pads, + const std::vector& dilas, + bool flag_bias, + bool flag_relu, + const std::vector& thread_num, + const std::vector& cluster_id) {} +#endif // LITE_WITH_ARM + +#if 1 /// random param conv +TEST(TestConvRand, test_conv_transpose_rand) { + if (FLAGS_basic_test) { + for (auto& cin : {1, 3, 8, 16}) { + for (auto& cout : {1, 5, 8, 16}) { + for (auto& g : {1, 2}) { + for (auto& kw : {1, 2, 3}) { + for (auto& kh : {1, 2, 3}) { + for (auto& stride : {1, 2}) { + for (auto& pad : {0, 1, 2}) { + for (auto& dila : {1, 2}) { + for (auto& flag_bias : {false, true}) { + for (auto& flag_relu : {false, true}) { + if (cin % g != 0 || cout % g != 0) { + continue; + } + std::vector dims; + DDim weights_dim({cin, cout / g, kh, kw}); + for (auto& batch : {1, 2}) { + for (auto& h : {1, 3, 19, 32, 28}) { + dims.push_back(DDim({batch, cin, h, h})); + } + } + test_conv_transpose_fp32(dims, + weights_dim, + g, + {stride, stride}, + {pad, pad}, + {dila, dila}, + flag_bias, + flag_relu, + {1, 2, 4}, + {FLAGS_cluster}); + } + } + } + } + } + } + } + } + } + } + } +} +#endif /// random param conv + +#if 1 /// custom +TEST(TestConvCustom, test_conv_transpose_fp32_custom_size) { + CHECK_EQ(FLAGS_in_channel % FLAGS_group, 0) + << "input channel must be divided by group"; + CHECK_EQ(FLAGS_out_channel % FLAGS_group, 0) + << "num_output must be divided by group"; + test_conv_transpose_fp32( + {DDim({FLAGS_batch, FLAGS_in_channel, FLAGS_in_height, FLAGS_in_width})}, + DDim({FLAGS_in_channel, + FLAGS_out_channel / FLAGS_group, + FLAGS_kernel_h, + FLAGS_kernel_w}), + FLAGS_group, + {FLAGS_stride_h, FLAGS_stride_w}, + {FLAGS_pad_h, FLAGS_pad_w}, + {FLAGS_dila_h, FLAGS_dila_w}, + FLAGS_flag_bias, + FLAGS_flag_relu, + {FLAGS_threads}, + {FLAGS_cluster}); +} +#endif // custom diff --git a/lite/tests/utils/naive_math_impl.h b/lite/tests/utils/naive_math_impl.h index accbb0eead..846126ac24 100644 --- a/lite/tests/utils/naive_math_impl.h +++ b/lite/tests/utils/naive_math_impl.h @@ -189,3 +189,176 @@ static void conv_basic(const Dtype1* din, } } } + +template +static void fill_bias_relu(Dtype* tensor, + const Dtype* bias, + int channel, + int channel_size, + bool flag_bias, + bool flag_relu) { + Dtype* data = tensor; + for (int j = 0; j < channel; ++j) { + Dtype bias_c = flag_bias ? bias[j] : 0; + for (int i = 0; i < channel_size; i++) { + data[i] += bias_c; + if (flag_relu) { + data[i] = data[i] > 0 ? data[i] : 0.f; + } + } + data += channel_size; + } +} + +template +static void do_relu(Dtype* tensor, int size) { + for (int j = 0; j < size; ++j) { + tensor[j] = tensor[j] > 0 ? 
tensor[j] : (Dtype)0; + } +} + +inline bool is_a_ge_zero_and_a_lt_b(int a, int b) { + return static_cast(a) < static_cast(b); +} + +template +static void col2im(const Dtype* data_col, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + Dtype* data_im) { + memset(data_im, 0, height * width * channels * sizeof(Dtype)); + const int output_h = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int output_w = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + const int channel_size = height * width; + + for (int channel = channels; channel--; data_im += channel_size) { + for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_row = -pad_h + kernel_row * dilation_h; + + for (int output_rows = output_h; output_rows; output_rows--) { + if (!is_a_ge_zero_and_a_lt_b(input_row, height)) { + data_col += output_w; + } else { + int input_col = -pad_w + kernel_col * dilation_w; + + for (int output_col = output_w; output_col; output_col--) { + if (is_a_ge_zero_and_a_lt_b(input_col, width)) { + data_im[input_row * width + input_col] += *data_col; + } + data_col++; + input_col += stride_w; + } + } + input_row += stride_h; + } + } + } + } +} + +//! for float, dtype1 and type2 is float +//! for int8, dytpe1 is char, dtype2 is int +template +void deconv_basic(const Dtype1* din, + Dtype2* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const Dtype1* weights, + const Dtype2* bias, + int group, + int kernel_w, + int kernel_h, + int stride_w, + int stride_h, + int dila_w, + int dila_h, + int pad_w, + int pad_h, + bool flag_bias, + bool flag_relu) { + int m = chout * kernel_w * kernel_h / group; + int n = hin * win; + int k = chin / group; + + int group_size_in = win * hin * chin / group; + int group_size_out = wout * hout * chout / group; + int group_size_coldata = m * n; + int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group); + bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) && + (stride_w == 1) && (pad_w == 1) && (pad_h == 1) && + (dila_w == 1) && (dila_h == 1); + + Dtype2* workspace_ptr = + static_cast(malloc(sizeof(float) * m * n * group)); + + for (int i = 0; i < num; ++i) { + const Dtype1* din_batch = din + i * chin * hin * win; + Dtype2* dout_batch = dout + i * chout * hout * wout; + + Dtype2* col_data = workspace_ptr; + if (flag_1x1s1p1) { + col_data = dout_batch; + } + memset(col_data, 0, sizeof(Dtype2) * group_size_coldata); + for (int g = 0; g < group; ++g) { + const Dtype1* din_group = din_batch + g * group_size_in; + const Dtype1* weights_group = weights + g * group_size_weights; + Dtype2* coldata_group = col_data + g * group_size_coldata; + basic_gemm(true, + false, + m, + n, + k, + 1, + weights_group, + m, + din_group, + n, + 0, + coldata_group, + n, + nullptr, + false, + (!flag_bias && flag_relu)); + } + + if (!flag_1x1s1p1) { + col2im(col_data, + chout, + hout, + wout, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dila_h, + dila_w, + dout_batch); + } + //! add bias + if (flag_bias) { + fill_bias_relu( + dout_batch, bias, chout, wout * hout, flag_bias, flag_relu); + } + } + free(workspace_ptr); +} -- GitLab
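
A note on the shape math the new test depends on: compute_out_dim in conv_transpose_compute_test.cc applies, per spatial axis, out = (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + 1, and the col2im-based deconv_basic reference must produce tensors of exactly that size. The sketch below is not part of the patch; the helper name deconv_out_size is illustrative only, and it simply checks this formula against the test's default flag values (32x32 input, 2x2 kernel, stride 2, pad 0, dilation 1) and one padded, strided case from the random-test loops.

// Hypothetical standalone sketch (not part of the commit): checks the
// transposed-convolution output-size formula used by compute_out_dim.
#include <cassert>
#include <cstdio>

// out = (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + 1
static int deconv_out_size(int in, int kernel, int stride, int pad, int dila) {
  return (in - 1) * stride - 2 * pad + dila * (kernel - 1) + 1;
}

int main() {
  // Test defaults: 32x32 input, 2x2 kernel, stride 2, pad 0 -> 64x64 output.
  assert(deconv_out_size(32, 2, 2, 0, 1) == 64);
  // A padded, strided 3x3 case of the kind covered by the random-test loops.
  std::printf("in=19 k=3 s=2 p=1 d=1 -> out=%d\n",
              deconv_out_size(19, 3, 2, 1, 1));  // prints 37
  return 0;
}

Assuming a default ARM build of the repository, the unit test added by this patch is driven by the gflags defined at the top of conv_transpose_compute_test.cc, for example ./conv_transpose_compute_test --basic_test=true --check_result=true --threads=1 --cluster=3 (the binary's location depends on the local build layout).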