From 79618219010cf5f6bb92a8bc126c887dd317e511 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Mon, 4 Mar 2019 00:05:12 +0800 Subject: [PATCH] Enable using optimization implementation for conv_add_relu op --- .../arm/convolution/conv_add_relu_kernel.cpp | 41 ++++- .../central-arm-func/conv_add_relu_arm_func.h | 153 ------------------ .../kernel/central-arm-func/conv_arm_func.h | 94 +++++++++++ src/operators/math/conv_func.h | 52 ++++++ src/operators/math/gemm/executor.h | 30 ++-- 5 files changed, 201 insertions(+), 169 deletions(-) delete mode 100644 src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h diff --git a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp index e318a866a3..054cfd4c45 100644 --- a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp @@ -15,21 +15,58 @@ limitations under the License. */ #ifdef FUSION_CONVADDRELU_OP #include "operators/kernel/conv_add_relu_kernel.h" -#include "operators/kernel/central-arm-func/conv_add_relu_arm_func.h" +#include "operators/kernel/arm/convolution/conv_common.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { template <> bool ConvAddReluKernel::Init(FusionConvAddReluParam *param) { + InitBaseConvKernel(param); return true; } template <> void ConvAddReluKernel::Compute( const FusionConvAddReluParam ¶m) { - ConvAddReluCompute(param); + switch (param.ExecMode()) { + case ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT: + math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), + param.Bias(), true, true); + break; + case ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT: + math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), + param.Output(), param.Bias(), true, true); + break; + case ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT: + math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), + param.Bias(), true, true); + break; + case ConvParam::EXEC_DEPTHWISE3x3_FLOAT: + math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), + param.Filter(), nullptr, param.Output(), false); + math::AddChannelWise(param.Output(), param.Bias(), param.Output()); + break; +#ifndef __aarch64__ + case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: + DepthwiseConv5x5(param); + math::AddChannelWise(param.Output(), param.Bias(), param.Output()); + break; + case ConvParam::EXEC_WINOGRAD3X3_FLOAT: + WinogradConv3x3<8, 3>(param); + math::AddChannelWise(param.Output(), param.Bias(), param.Output()); + break; +#endif // __aarch64__ + case ConvParam::EXEC_GEMM_FLOAT: + ConvAddReluBasic>(param); + break; + default: + PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", + param.ExecMode()); + } } + template class ConvAddReluKernel; } // namespace operators diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h deleted file mode 100644 index 04a84fc976..0000000000 --- a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h +++ /dev/null @@ -1,153 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef FUSION_CONVADDRELU_OP - -#pragma once -#include -#include -#include "operators/math/conv_func.h" -#include "operators/math/im2col.h" -#include "operators/math/math_function.h" -#include "operators/math/vol2col.h" -#include "operators/op_param.h" - -namespace paddle_mobile { -namespace operators { - -template -void ConvAddReluBasic(const FusionConvAddReluParam ¶m) { - const Tensor *input = param.Input(); - Tensor filter = *param.Filter(); - Tensor bias = *param.Bias(); - int32_t axis = param.Axis(); - Otype *bias_data = bias.data(); - Tensor *output = param.Output(); - output->mutable_data(); - - float alpha = 1.0f; - float beta = 1.0f; - int32_t groups = param.Groups(); - std::vector strides = param.Strides(); - std::vector paddings = param.Paddings(); - std::vector dilations = param.Dilations(); - - const int32_t batch_size = static_cast(input->dims()[0]); - - std::vector filter_shape_vec(framework::vectorize(filter.dims())); - - std::vector output_shape_vec(framework::vectorize(output->dims())); - size_t data_dim = filter_shape_vec.size() - 2; - std::vector col_shape_vec(1 + 2 * data_dim); - col_shape_vec[0] = input->dims()[1] / groups; - for (size_t j = 0; j < data_dim; ++j) { - col_shape_vec[j + 1] = filter_shape_vec[j + 2]; - col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; - } - framework::DDim col_shape(framework::make_ddim(col_shape_vec)); - - framework::DDim col_matrix_shape = - framework::flatten_to_2d(col_shape, data_dim + 1); - - bool is_expand = - math::IsExpand(filter_shape_vec, strides, paddings, dilations); - Tensor col; - Tensor col_matrix; - if (is_expand) { - col.mutable_data(col_shape); - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); - } - - framework::DDim input_shape = framework::slice_ddim( - input->dims(), 1, static_cast(input->dims().size())); - - framework::DDim filter_matrix_shape = {filter.dims()[0], - filter.numel() / filter.dims()[0]}; - filter.Resize(filter_matrix_shape); - framework::DDim output_matrix_shape = { - output->dims()[1], - output->numel() / (output->dims()[0] * output->dims()[1])}; - - // convolution operator: im2col(or vol2col) + gemm - int32_t in_step = static_cast(input->dims()[1]) / groups; - int32_t out_step = static_cast(output->dims()[1]) / groups; - - math::Vol2ColFunctor vol2col; - math::Im2ColFunctor im2col; - - for (int32_t i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - - for (int32_t g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (!is_expand) { - col.ShareDataWith(in_slice); - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); - } else if (data_dim == 2U) { - // im2col - im2col(in_slice, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, - &col); - } else if (data_dim == 3U) { - // vol2col - vol2col(in_slice, dilations, strides, paddings, &col); - } - - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - - math::MatMul(filter_slice, false, col_matrix, false, alpha, - &out_slice, beta, true, bias_data); - } - } -} - -template -void ConvAddReluCompute(const FusionConvAddReluParam ¶m) { - param.Output()->mutable_data(); - if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { - math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), - param.Bias(), true, true); - } else if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { - // math::DepthwiseConv3x3(param.Input(), param.Strides(), - // param.Paddings(), - // param.Filter(), param.Bias(), - // param.Output(), false); - if (param.Paddings()[0] == 0) { - math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), - param.Bias(), true, true); - } else { - math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), - param.Output(), param.Bias(), true, true); - } - } else { - ConvAddReluBasic(param); - } -} -} // namespace operators -} // namespace paddle_mobile - -#endif diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index 6f37a0b711..7170c3ff4d 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -212,6 +212,100 @@ inline void DepthwiseConv5x5(const ConvParam ¶m) { } #endif // __aarch64__ +template +void ConvAddReluBasic(const ParamType ¶m) { + const Tensor *input = param.Input(); + Tensor filter = *param.Filter(); + Tensor bias = *param.Bias(); + + Tensor *output = param.Output(); + output->mutable_data(); + + float alpha = 1.0f; + float beta = 1.0f; + int32_t groups = param.Groups(); + int32_t axis = param.Axis(); + std::vector strides = param.Strides(); + std::vector paddings = param.Paddings(); + std::vector dilations = param.Dilations(); + + const int32_t batch_size = static_cast(input->dims()[0]); + + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + + std::vector output_shape_vec(framework::vectorize(output->dims())); + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = input->dims()[1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); + + framework::DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, data_dim + 1); + + bool is_expand = + math::IsExpand(filter_shape_vec, strides, paddings, dilations); + Tensor col; + Tensor col_matrix; + if (is_expand) { + col.mutable_data(col_shape); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + + framework::DDim input_shape = framework::slice_ddim( + input->dims(), 1, static_cast(input->dims().size())); + + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + framework::DDim output_matrix_shape = { + output->dims()[1], + output->numel() / (output->dims()[0] * output->dims()[1])}; + + // convolution operator: im2col(or vol2col) + gemm + int32_t in_step = static_cast(input->dims()[1]) / groups; + int32_t out_step = static_cast(output->dims()[1]) / groups; + + float *bias_data = bias.data(); + + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; + + for (int32_t i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + + for (int32_t g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (!is_expand) { + col_matrix = in_slice; + col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + // im2col + im2col(in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (data_dim == 3U) { + // vol2col + vol2col(in_slice, dilations, strides, paddings, &col); + } + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + + math::MatMul(filter_slice, false, col_matrix, false, alpha, + &out_slice, beta, true, bias_data); + } + } +} + template void ConvBNReluBasic(const ParamType ¶m) { const Tensor *input = param.Input(); diff --git a/src/operators/math/conv_func.h b/src/operators/math/conv_func.h index 40320dedac..4debd2e105 100644 --- a/src/operators/math/conv_func.h +++ b/src/operators/math/conv_func.h @@ -99,6 +99,58 @@ inline bool IsExpand(const std::vector &filter_dim, return !(filter_1 && strides_1 && padding_0 && dilation_1); } +template +void AddChannelWise(const framework::Tensor *input, + const framework::Tensor *bias, framework::Tensor *output) { + const float *input_ptr = input->data(); + const float *bias_ptr = bias->data(); + float *output_ptr = output->mutable_data(); + // maybe check shape + int batch_size = input->dims()[0]; + int channels = input->dims()[1]; + size_t spatial_size = input->dims()[2] * input->dims()[3]; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int channel = 0; channel < channels; ++channel) { + size_t offset = (batch * channels + channel) * spatial_size; + const float *x = input_ptr + offset; + float *y = output_ptr + offset; + float beta = bias_ptr[channel]; + int j = 0; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + float32x4_t __bias = vdupq_n_f32(beta); + for (; j < spatial_size - 15; j += 16, x += 16, y += 16) { + float32x4_t in0 = vld1q_f32(x); + float32x4_t in1 = vld1q_f32(x + 4); + float32x4_t in2 = vld1q_f32(x + 8); + float32x4_t in3 = vld1q_f32(x + 12); + in0 = vaddq_f32(__bias, in0); + in1 = vaddq_f32(__bias, in1); + in2 = vaddq_f32(__bias, in2); + in3 = vaddq_f32(__bias, in3); + in0 = math::vActiveq_f32(in0); + in1 = math::vActiveq_f32(in1); + in2 = math::vActiveq_f32(in2); + in3 = math::vActiveq_f32(in3); + vst1q_f32(y, in0); + vst1q_f32(y + 4, in1); + vst1q_f32(y + 8, in2); + vst1q_f32(y + 12, in3); + } + for (; j < spatial_size - 3; j += 4, x += 4, y += 4) { + float32x4_t in0 = vld1q_f32(x); + in0 = vaddq_f32(__bias, in0); + in0 = math::vActiveq_f32(in0); + vst1q_f32(y, in0); + } +#endif + for (; j < spatial_size; ++j, ++x, ++y) { + *y = math::Active((*x) + beta); + } + } + } +} + template void ScaleAddChannelWise(const framework::Tensor *input, const framework::Tensor *scale, diff --git a/src/operators/math/gemm/executor.h b/src/operators/math/gemm/executor.h index 9dcf808019..cf7c5687c8 100644 --- a/src/operators/math/gemm/executor.h +++ b/src/operators/math/gemm/executor.h @@ -61,7 +61,7 @@ class GemmExecutor : public Executor { K_(K) { unsigned int L1_size = info->L1_cache; unsigned int L2_size = info->L2_cache; - // if (N_ > 10000) L1_size *= 2; + if (N_ > 30000 && K_ > 100) L1_size *= 2; if (num_threads_ >= 2) L1_size /= 2; rhs_tile_num_ = L1_size / (K * sizeof(Itype)); @@ -74,8 +74,8 @@ class GemmExecutor : public Executor { rhs_tile_num_ *= Strategy::out_width(); } - // lhs_tile_num_ = CeilDiv(M, Strategy::out_height()) * - // Strategy::out_height(); + // lhs_tile_num_ = CeilDiv(M, Strategy::out_height()) * + // Strategy::out_height(); lhs_tile_num_ = L2_size / (K * sizeof(Itype)); if (lhs_tile_num_ == 0) { lhs_tile_num_ = Strategy::out_height(); @@ -90,8 +90,8 @@ class GemmExecutor : public Executor { void operator()(const float alpha, const Itype *A, const int lda, const Itype *B, const int ldb, const float beta, Otype *C, const int ldc) { - // struct timeval tv_begin, tv_end; - // gettimeofday(&tv_begin,NULL); + // struct timeval tv_begin, tv_end; + // gettimeofday(&tv_begin,NULL); int mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height(); lhs_worksize_ = sizeof(Itype) * mblock * K_; @@ -107,9 +107,10 @@ class GemmExecutor : public Executor { strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true); - // std::cout << "M: " << M_ << ", N: " << N_ << ", K: " << K_ << - // std::endl; std::cout << "rhs_block: " << CeilDiv(N_, rhs_tile_num_) << - // std::endl; + // std::cout << "M: " << M_ << ", N: " << N_ + // << ", K: " << K_ << std::endl; + // std::cout << "rhs_block: " << CeilDiv(N_, rhs_tile_num_) + // << std::endl; #pragma omp parallel for if (N_ > 128) for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) { @@ -145,11 +146,12 @@ class GemmExecutor : public Executor { paddle_mobile::memory::Free(rhs_workspace_); paddle_mobile::memory::Free(out_workspace_); - // gettimeofday(&tv_end,NULL); - // float elapsed = (tv_end.tv_sec - tv_begin.tv_sec) * 1000.f + - // (tv_end.tv_usec - tv_begin.tv_usec) / 1000.f; std::cout << "elapsed: " - // << elapsed << "ms, speed: " << (M_ * N_ * K_ / 1000.f / 1000.f) / - // elapsed << " gflops" << std::endl; + // gettimeofday(&tv_end,NULL); + // float elapsed = (tv_end.tv_sec - tv_begin.tv_sec) * 1000.f + + // (tv_end.tv_usec - tv_begin.tv_usec) / 1000.f; + // std::cout << "elapsed: " << elapsed << "ms, speed: " + // << (M_ * N_ * K_ / 1000.f / 1000.f) / elapsed + // << " gflops" << std::endl; } virtual ~GemmExecutor() {} @@ -189,7 +191,7 @@ class GemvExecutor : public Executor { void operator()(const float alpha, const Itype *A, const int lda, const Itype *B, const float beta, Otype *C) { - // strategy_.kernel(); + // strategy_.kernel(); } virtual ~GemvExecutor() {} -- GitLab