Commit e1b78a45 authored by 李寅

Optimize winograd memory use

Parent e2a40a03
@@ -73,8 +73,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
                 const std::vector<int> &paddings,
                 const int *dilations,
                 const ActivationType activation,
-                const float relux_max_limit,
-                const bool is_filter_transformed)
+                const float relux_max_limit)
       : Conv2dFunctorBase(context,
                           strides,
                           padding_type,
@@ -82,8 +81,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
                           dilations,
                           activation,
                           relux_max_limit),
-        transformed_filter_(GetCPUAllocator(), DataType::DT_FLOAT),
-        is_filter_transformed_(is_filter_transformed) {}
+        is_filter_transformed_(false) {}
   void Conv2dGeneral(const float *input,
                      const float *filter,
@@ -270,14 +268,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
     MACE_CHECK_NOTNULL(output);
     std::vector<index_t> filter_shape(4);
-    if (is_filter_transformed_) {
-      // TOC -> OIHW
-      filter_shape[0] = filter->dim(1);
-      filter_shape[1] = filter->dim(2);
-      filter_shape[2] = filter_shape[3] = 3;
-    } else {
-      filter_shape = filter->shape();
-    }
+    filter_shape = filter->shape();
     std::vector<index_t> output_shape(4);
     std::vector<int> paddings(2);
@@ -349,9 +340,9 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
     std::function<void(const float *input, float *output)> conv_func;
     bool
-        use_winograd = is_filter_transformed_ || (filter_h == 3 && filter_w == 3
-        && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1
-        && input_channels >= 8 && channels >= 8);
+        use_winograd = filter_h == 3 && filter_w == 3
+        && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1
+        && input_channels >= 8 && channels >= 8;
     bool use_neon_3x3_s1 = filter_h == 3 && filter_w == 3
         && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
     bool use_neon_3x3_s2 = filter_h == 3 && filter_w == 3
@@ -452,15 +443,15 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
     index_t padded_output_size = 0;
     if (use_winograd) {
       transformed_input_size =
           std::accumulate(transformed_input_shape.begin(),
                           transformed_input_shape.end(),
                           1,
                           std::multiplies<index_t>()) * sizeof(float);
       transformed_output_size =
           std::accumulate(transformed_output_shape.begin(),
                           transformed_output_shape.end(),
                           1,
                           std::multiplies<index_t>()) * sizeof(float);
       total_scratch_size += transformed_input_size + transformed_output_size;
     }
     if (extra_input_height != input_height
@@ -477,14 +468,13 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
           * sizeof(float);
       total_scratch_size += padded_output_size;
     }
-    // scratch for sgemm
+    // scratch for sgemm, preoccupy enough buffer
     if (use_neon_1x1_s1) {
-      total_scratch_size +=
-          (input_batch * input_height * input_width
-              * (input_channels + channels)) * sizeof(float);
+      total_scratch_size += (input_batch * input_height * input_width
+          * (input_channels + channels))
+          * sizeof(float);
     } else if (use_winograd) {
-      total_scratch_size +=
-          (transformed_input_size + transformed_output_size) * sizeof(float);
+      total_scratch_size += transformed_input_size + transformed_output_size;
     }
     // Init scratch buffer
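Aside (not part of the diff): transformed_input_size and transformed_output_size are already byte counts, since their element counts are multiplied by sizeof(float) when they are computed. The Winograd sgemm branch therefore now adds them to total_scratch_size directly instead of multiplying by sizeof(float) a second time. A minimal standalone sketch of that byte-size bookkeeping, with made-up shapes and index_t assumed to be a 64-bit signed integer:

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

using index_t = int64_t;  // assumption: MACE's index_t is a 64-bit signed integer

// Byte size of a float tensor with the given shape.
index_t FloatTensorBytes(const std::vector<index_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(),
                         static_cast<index_t>(1),
                         std::multiplies<index_t>()) *
         static_cast<index_t>(sizeof(float));
}

int main() {
  // Hypothetical transformed-tensor shapes; the real ones depend on the
  // Winograd tile size and the input geometry.
  std::vector<index_t> transformed_input_shape{16, 64, 36, 1};
  std::vector<index_t> transformed_output_shape{16, 128, 36, 1};

  index_t transformed_input_size = FloatTensorBytes(transformed_input_shape);
  index_t transformed_output_size = FloatTensorBytes(transformed_output_shape);

  index_t total_scratch_size = 0;
  // Buffers for the transformed input and output themselves.
  total_scratch_size += transformed_input_size + transformed_output_size;
  // Extra scratch for the Winograd sgemm: the sizes are already in bytes,
  // so they are added directly, with no further * sizeof(float).
  total_scratch_size += transformed_input_size + transformed_output_size;
  return 0;
}
```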
@@ -506,36 +496,34 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
     MACE_UNUSED(extra_input_shape);
     MACE_UNUSED(extra_output_shape);
+    Tensor transformed_filter;
     // decide which convolution function to call
     if (use_winograd) {
       transformed_input.Reshape(transformed_input_shape);
       transformed_output.Reshape(transformed_output_shape);
-      const float *transformed_filter_ptr;
-      if (transformed_filter_.dim_size() == 0) {
-        if (is_filter_transformed_) {
-          transformed_filter_ptr = filter_data;
-        } else {
-          MACE_RETURN_IF_ERROR(transformed_filter_.Resize(
-              transformed_filter_shape));
-          switch (winograd_out_tile_size) {
-            case 2:
-              TransformFilter4x4(filter_data,
-                                 filter_shape[1],
-                                 filter_shape[0],
-                                 transformed_filter_.mutable_data<float>());
-              break;
-            case 6:
-              TransformFilter8x8(filter_data,
-                                 filter_shape[1],
-                                 filter_shape[0],
-                                 transformed_filter_.mutable_data<float>());
-              break;
-            default:MACE_NOT_IMPLEMENTED;
-          }
-          transformed_filter_ptr = transformed_filter_.data<float>();
-        }
-      } else {
-        transformed_filter_ptr = transformed_filter_.data<float>();
-      }
+      const float *transformed_filter_data = nullptr;
+      // filter only needs to be transformed once, set transformed_filter_data
+      // to null after the first run.
+      if (!is_filter_transformed_) {
+        transformed_filter.Resize(transformed_filter_shape);
+        switch (winograd_out_tile_size) {
+          case 2:
+            TransformFilter4x4(filter_data,
+                               filter_shape[1],
+                               filter_shape[0],
+                               transformed_filter.mutable_data<float>());
+            break;
+          case 6:
+            TransformFilter8x8(filter_data,
+                               filter_shape[1],
+                               filter_shape[0],
+                               transformed_filter.mutable_data<float>());
+            break;
+          default:MACE_NOT_IMPLEMENTED;
+        }
+        transformed_filter_data = transformed_filter.data<float>();
+        is_filter_transformed_ = true;
+      }
       float *transformed_input_data = transformed_input.mutable_data<float>();
@@ -543,7 +531,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
       conv_func = [=](const float *pad_input, float *pad_output) {
         WinoGradConv3x3s1(pad_input,
-                          transformed_filter_ptr,
+                          transformed_filter_data,
                           batch,
                           extra_input_height,
                           extra_input_width,
@@ -728,7 +716,6 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
     return MACE_SUCCESS;
   }
-  Tensor transformed_filter_;
   bool is_filter_transformed_;
   SGemm sgemm_;
 };
@@ -741,17 +728,14 @@ struct Conv2dFunctor<DeviceType::CPU, uint8_t> : Conv2dFunctorBase {
                 const std::vector<int> &paddings,
                 const int *dilations,
                 const ActivationType activation,
-                const float relux_max_limit,
-                const bool is_filter_transformed)
+                const float relux_max_limit)
       : Conv2dFunctorBase(context,
                           strides,
                           padding_type,
                           paddings,
                           dilations,
                           activation,
-                          relux_max_limit) {
-    MACE_UNUSED(is_filter_transformed);
-  }
+                          relux_max_limit) {}
   template <typename T>
   inline void Im2col(
@@ -998,8 +982,7 @@ struct Conv2dFunctor<DeviceType::GPU, T> : Conv2dFunctorBase {
                 const std::vector<int> &paddings,
                 const int *dilations,
                 const ActivationType activation,
-                const float relux_max_limit,
-                const bool is_filter_transformed);
+                const float relux_max_limit);
   MaceStatus operator()(const Tensor *input,
                         const Tensor *filter,
......
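Aside (not part of the diff): the header change above drops the cached transformed_filter_ member in favor of a Tensor that lives only inside operator(), and keeps is_filter_transformed_ only as a flag so the filter transform runs at most once. A rough sketch of that lazy, one-shot transform pattern, using plain std::vector in place of MACE's Tensor and a placeholder transform; the names Functor and TransformOnce are illustrative, not MACE API:

```cpp
#include <vector>

// Illustrative stand-in for TransformFilter4x4/TransformFilter8x8: here it
// just copies, while the real functions compute the Winograd filter transform.
static void TransformOnce(const std::vector<float> &filter,
                          std::vector<float> *transformed) {
  transformed->assign(filter.begin(), filter.end());
}

struct Functor {
  bool is_filter_transformed_ = false;

  void Run(const std::vector<float> &filter) {
    // Scratch lives only for this call, so no transformed copy of the
    // filter is kept alive between runs.
    std::vector<float> transformed_filter;
    const float *transformed_filter_data = nullptr;
    if (!is_filter_transformed_) {
      TransformOnce(filter, &transformed_filter);
      transformed_filter_data = transformed_filter.data();
      is_filter_transformed_ = true;
    }
    // On later runs transformed_filter_data stays null; in the real code the
    // downstream sgemm is presumably able to reuse weights it has already
    // packed, which appears to be the memory saving this commit is after.
    (void)transformed_filter_data;
  }
};

int main() {
  Functor f;
  std::vector<float> filter(3 * 3 * 8 * 8, 1.0f);  // made-up 3x3 filter data
  f.Run(filter);  // transforms once
  f.Run(filter);  // skips the transform on subsequent runs
  return 0;
}
```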
@@ -27,8 +27,7 @@ Conv2dFunctor<DeviceType::GPU, T>::Conv2dFunctor(
     const std::vector<int> &paddings,
     const int *dilations,
     const ActivationType activation,
-    const float relux_max_limit,
-    const bool is_filter_transformed)
+    const float relux_max_limit)
     : Conv2dFunctorBase(context,
                         strides,
                         padding_type,
@@ -36,8 +35,6 @@ Conv2dFunctor<DeviceType::GPU, T>::Conv2dFunctor(
                         dilations,
                         activation,
                         relux_max_limit) {
-  MACE_UNUSED(is_filter_transformed);
   if (context->device()->opencl_runtime()->UseImageMemory()) {
     kernel_.reset(new opencl::image::Conv2dKernel<T>);
   } else {
......
@@ -38,9 +38,7 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
             kernels::StringToActivationType(
                 OperatorBase::GetOptionalArg<std::string>("activation",
                                                           "NOOP")),
-            OperatorBase::GetOptionalArg<float>("max_limit", 0.0f),
-            static_cast<bool>(OperatorBase::GetOptionalArg<int>(
-                "is_filter_transformed", false))) {}
+            OperatorBase::GetOptionalArg<float>("max_limit", 0.0f)) {}
   MaceStatus Run(StatsFuture *future) override {
     const Tensor *input = this->Input(INPUT);
......