From e1b78a45fbc919cde5197309567033f7fad2ef60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Fri, 28 Sep 2018 20:00:32 +0800 Subject: [PATCH] Optimize winograd memory use --- mace/kernels/conv_2d.h | 107 ++++++++++++++------------------- mace/kernels/opencl/conv_2d.cc | 5 +- mace/ops/conv_2d.h | 4 +- 3 files changed, 47 insertions(+), 69 deletions(-) diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index e7b8e633..c96b70ef 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -73,8 +73,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { const std::vector &paddings, const int *dilations, const ActivationType activation, - const float relux_max_limit, - const bool is_filter_transformed) + const float relux_max_limit) : Conv2dFunctorBase(context, strides, padding_type, @@ -82,8 +81,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { dilations, activation, relux_max_limit), - transformed_filter_(GetCPUAllocator(), DataType::DT_FLOAT), - is_filter_transformed_(is_filter_transformed) {} + is_filter_transformed_(false) {} void Conv2dGeneral(const float *input, const float *filter, @@ -270,14 +268,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { MACE_CHECK_NOTNULL(output); std::vector filter_shape(4); - if (is_filter_transformed_) { - // TOC -> OIHW - filter_shape[0] = filter->dim(1); - filter_shape[1] = filter->dim(2); - filter_shape[2] = filter_shape[3] = 3; - } else { - filter_shape = filter->shape(); - } + filter_shape = filter->shape(); std::vector output_shape(4); std::vector paddings(2); @@ -349,9 +340,9 @@ struct Conv2dFunctor : Conv2dFunctorBase { std::function conv_func; bool - use_winograd = is_filter_transformed_ || (filter_h == 3 && filter_w == 3 + use_winograd = filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1 - && input_channels >= 8 && channels >= 8); + && input_channels >= 8 && channels >= 8; bool use_neon_3x3_s1 = filter_h == 3 && filter_w == 3 && stride_h == 1 && 
stride_w == 1 && dilation_h == 1 && dilation_w == 1; bool use_neon_3x3_s2 = filter_h == 3 && filter_w == 3 @@ -452,15 +443,15 @@ struct Conv2dFunctor : Conv2dFunctorBase { index_t padded_output_size = 0; if (use_winograd) { transformed_input_size = - std::accumulate(transformed_input_shape.begin(), - transformed_input_shape.end(), - 1, - std::multiplies()) * sizeof(float); + std::accumulate(transformed_input_shape.begin(), + transformed_input_shape.end(), + 1, + std::multiplies()) * sizeof(float); transformed_output_size = - std::accumulate(transformed_output_shape.begin(), - transformed_output_shape.end(), - 1, - std::multiplies()) * sizeof(float); + std::accumulate(transformed_output_shape.begin(), + transformed_output_shape.end(), + 1, + std::multiplies()) * sizeof(float); total_scratch_size += transformed_input_size + transformed_output_size; } if (extra_input_height != input_height @@ -477,14 +468,13 @@ struct Conv2dFunctor : Conv2dFunctorBase { * sizeof(float); total_scratch_size += padded_output_size; } - // scratch for sgemm + // scratch for sgemm, preoccupy enough buffer if (use_neon_1x1_s1) { - total_scratch_size += - (input_batch * input_height * input_width - * (input_channels + channels)) * sizeof(float); + total_scratch_size += (input_batch * input_height * input_width + * (input_channels + channels)) + * sizeof(float); } else if (use_winograd) { - total_scratch_size += - (transformed_input_size + transformed_output_size) * sizeof(float); + total_scratch_size += transformed_input_size + transformed_output_size; } // Init scratch buffer @@ -506,36 +496,34 @@ struct Conv2dFunctor : Conv2dFunctorBase { MACE_UNUSED(extra_input_shape); MACE_UNUSED(extra_output_shape); + Tensor transformed_filter; + // decide which convolution function to call if (use_winograd) { transformed_input.Reshape(transformed_input_shape); transformed_output.Reshape(transformed_output_shape); - const float *transformed_filter_ptr; - if (transformed_filter_.dim_size() == 0) { - if 
(is_filter_transformed_) { - transformed_filter_ptr = filter_data; - } else { - MACE_RETURN_IF_ERROR(transformed_filter_.Resize( - transformed_filter_shape)); - switch (winograd_out_tile_size) { - case 2: - TransformFilter4x4(filter_data, - filter_shape[1], - filter_shape[0], - transformed_filter_.mutable_data()); - break; - case 6: - TransformFilter8x8(filter_data, - filter_shape[1], - filter_shape[0], - transformed_filter_.mutable_data()); - break; - default:MACE_NOT_IMPLEMENTED; - } - transformed_filter_ptr = transformed_filter_.data(); + const float *transformed_filter_data = nullptr; + // The filter only needs to be transformed once: after the first run + // transformed_filter_data is left as nullptr and no transform is redone. + if (!is_filter_transformed_) { + transformed_filter.Resize(transformed_filter_shape); + switch (winograd_out_tile_size) { + case 2: + TransformFilter4x4(filter_data, + filter_shape[1], + filter_shape[0], + transformed_filter.mutable_data()); + break; + case 6: + TransformFilter8x8(filter_data, + filter_shape[1], + filter_shape[0], + transformed_filter.mutable_data()); + break; + default:MACE_NOT_IMPLEMENTED; } - } else { - transformed_filter_ptr = transformed_filter_.data(); + transformed_filter_data = transformed_filter.data(); + is_filter_transformed_ = true; } float *transformed_input_data = transformed_input.mutable_data(); @@ -543,7 +531,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { conv_func = [=](const float *pad_input, float *pad_output) { WinoGradConv3x3s1(pad_input, - transformed_filter_ptr, + transformed_filter_data, batch, extra_input_height, extra_input_width, @@ -728,7 +716,6 @@ struct Conv2dFunctor : Conv2dFunctorBase { return MACE_SUCCESS; } - Tensor transformed_filter_; bool is_filter_transformed_; SGemm sgemm_; }; @@ -741,17 +728,14 @@ struct Conv2dFunctor : Conv2dFunctorBase { const std::vector &paddings, const int *dilations, const ActivationType activation, - const float relux_max_limit, - const bool is_filter_transformed) + const float 
relux_max_limit) : Conv2dFunctorBase(context, strides, padding_type, paddings, dilations, activation, - relux_max_limit) { - MACE_UNUSED(is_filter_transformed); - } + relux_max_limit) {} template inline void Im2col( @@ -998,8 +982,7 @@ struct Conv2dFunctor : Conv2dFunctorBase { const std::vector &paddings, const int *dilations, const ActivationType activation, - const float relux_max_limit, - const bool is_filter_transformed); + const float relux_max_limit); MaceStatus operator()(const Tensor *input, const Tensor *filter, diff --git a/mace/kernels/opencl/conv_2d.cc b/mace/kernels/opencl/conv_2d.cc index e21c4744..38bb2e8f 100644 --- a/mace/kernels/opencl/conv_2d.cc +++ b/mace/kernels/opencl/conv_2d.cc @@ -27,8 +27,7 @@ Conv2dFunctor::Conv2dFunctor( const std::vector &paddings, const int *dilations, const ActivationType activation, - const float relux_max_limit, - const bool is_filter_transformed) + const float relux_max_limit) : Conv2dFunctorBase(context, strides, padding_type, @@ -36,8 +35,6 @@ Conv2dFunctor::Conv2dFunctor( dilations, activation, relux_max_limit) { - MACE_UNUSED(is_filter_transformed); - if (context->device()->opencl_runtime()->UseImageMemory()) { kernel_.reset(new opencl::image::Conv2dKernel); } else { diff --git a/mace/ops/conv_2d.h b/mace/ops/conv_2d.h index 9d2c2426..9f731fa4 100644 --- a/mace/ops/conv_2d.h +++ b/mace/ops/conv_2d.h @@ -38,9 +38,7 @@ class Conv2dOp : public ConvPool2dOpBase { kernels::StringToActivationType( OperatorBase::GetOptionalArg("activation", "NOOP")), - OperatorBase::GetOptionalArg("max_limit", 0.0f), - static_cast(OperatorBase::GetOptionalArg( - "is_filter_transformed", false))) {} + OperatorBase::GetOptionalArg("max_limit", 0.0f)) {} MaceStatus Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); -- GitLab