Commit e1b78a45 authored by 李寅

Optimize winograd memory use

Parent e2a40a03
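In short, the diff below drops the persistent transformed_filter_ member tensor and the is_filter_transformed operator argument: the filter is now transformed into a function-local tensor on the first run only, is_filter_transformed_ is flipped afterwards, and later runs pass a null pointer (the new in-code comment notes the filter only needs to be transformed once). What follows is a minimal standalone sketch of that transform-once pattern; WinogradConvRunner, TransformFilter and WinogradConv are illustrative placeholders, not MACE's actual API, and the real kernels are assumed to keep their own packed copy of the transformed weights.

#include <algorithm>
#include <cstddef>
#include <vector>

// Placeholder kernels so the sketch compiles; the real transform and
// convolution code lives in MACE's NEON kernels and is not reproduced here.
static void TransformFilter(const float * /*filter*/, float *transformed,
                            std::size_t count) {
  std::fill(transformed, transformed + count, 0.0f);  // stand-in for G*g*G^T
}

static void WinogradConv(const float * /*input*/,
                         const float * /*transformed_filter*/,
                         float * /*output*/) {
  // Stand-in for: transform input tiles, GEMM against the (cached)
  // transformed filter, inverse-transform the result.
}

class WinogradConvRunner {
 public:
  void Run(const float *input, const float *filter, float *output,
           std::size_t transformed_filter_count) {
    // The transformed filter lives only during the first call; later calls
    // pass nullptr and the callee is assumed to reuse its cached packed copy.
    std::vector<float> transformed_filter;
    const float *transformed_filter_data = nullptr;
    if (!is_filter_transformed_) {
      transformed_filter.resize(transformed_filter_count);
      TransformFilter(filter, transformed_filter.data(),
                      transformed_filter_count);
      transformed_filter_data = transformed_filter.data();
      is_filter_transformed_ = true;
    }
    WinogradConv(input, transformed_filter_data, output);
  }

 private:
  bool is_filter_transformed_ = false;
};

int main() {
  std::vector<float> input(64), filter(9 * 8 * 8), output(64);
  WinogradConvRunner runner;
  runner.Run(input.data(), filter.data(), output.data(), 16 * 8 * 8);  // transforms once
  runner.Run(input.data(), filter.data(), output.data(), 16 * 8 * 8);  // reuses cache
  return 0;
}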
@@ -73,8 +73,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
- const float relux_max_limit,
- const bool is_filter_transformed)
+ const float relux_max_limit)
: Conv2dFunctorBase(context,
strides,
padding_type,
@@ -82,8 +81,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
dilations,
activation,
relux_max_limit),
- transformed_filter_(GetCPUAllocator(), DataType::DT_FLOAT),
- is_filter_transformed_(is_filter_transformed) {}
+ is_filter_transformed_(false) {}
void Conv2dGeneral(const float *input,
const float *filter,
@@ -270,14 +268,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
MACE_CHECK_NOTNULL(output);
std::vector<index_t> filter_shape(4);
- if (is_filter_transformed_) {
- // TOC -> OIHW
- filter_shape[0] = filter->dim(1);
- filter_shape[1] = filter->dim(2);
- filter_shape[2] = filter_shape[3] = 3;
- } else {
- filter_shape = filter->shape();
- }
+ filter_shape = filter->shape();
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
@@ -349,9 +340,9 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
std::function<void(const float *input, float *output)> conv_func;
bool
- use_winograd = is_filter_transformed_ || (filter_h == 3 && filter_w == 3
+ use_winograd = filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1
- && input_channels >= 8 && channels >= 8);
+ && input_channels >= 8 && channels >= 8;
bool use_neon_3x3_s1 = filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
bool use_neon_3x3_s2 = filter_h == 3 && filter_w == 3
@@ -452,15 +443,15 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
index_t padded_output_size = 0;
if (use_winograd) {
transformed_input_size =
- std::accumulate(transformed_input_shape.begin(),
- transformed_input_shape.end(),
- 1,
- std::multiplies<index_t>()) * sizeof(float);
+ std::accumulate(transformed_input_shape.begin(),
+ transformed_input_shape.end(),
+ 1,
+ std::multiplies<index_t>()) * sizeof(float);
transformed_output_size =
- std::accumulate(transformed_output_shape.begin(),
- transformed_output_shape.end(),
- 1,
- std::multiplies<index_t>()) * sizeof(float);
+ std::accumulate(transformed_output_shape.begin(),
+ transformed_output_shape.end(),
+ 1,
+ std::multiplies<index_t>()) * sizeof(float);
total_scratch_size += transformed_input_size + transformed_output_size;
}
if (extra_input_height != input_height
@@ -477,14 +468,13 @@
* sizeof(float);
total_scratch_size += padded_output_size;
}
- // scratch for sgemm
+ // scratch for sgemm, preoccupy enough buffer
if (use_neon_1x1_s1) {
- total_scratch_size +=
- (input_batch * input_height * input_width
- * (input_channels + channels)) * sizeof(float);
+ total_scratch_size += (input_batch * input_height * input_width
+ * (input_channels + channels))
+ * sizeof(float);
} else if (use_winograd) {
- total_scratch_size +=
- (transformed_input_size + transformed_output_size) * sizeof(float);
+ total_scratch_size += transformed_input_size + transformed_output_size;
}
// Init scratch buffer
@@ -506,36 +496,34 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
MACE_UNUSED(extra_input_shape);
MACE_UNUSED(extra_output_shape);
+ Tensor transformed_filter;
// decide which convolution function to call
if (use_winograd) {
transformed_input.Reshape(transformed_input_shape);
transformed_output.Reshape(transformed_output_shape);
- const float *transformed_filter_ptr;
- if (transformed_filter_.dim_size() == 0) {
- if (is_filter_transformed_) {
- transformed_filter_ptr = filter_data;
- } else {
- MACE_RETURN_IF_ERROR(transformed_filter_.Resize(
- transformed_filter_shape));
- switch (winograd_out_tile_size) {
- case 2:
- TransformFilter4x4(filter_data,
- filter_shape[1],
- filter_shape[0],
- transformed_filter_.mutable_data<float>());
- break;
- case 6:
- TransformFilter8x8(filter_data,
- filter_shape[1],
- filter_shape[0],
- transformed_filter_.mutable_data<float>());
- break;
- default:MACE_NOT_IMPLEMENTED;
- }
- transformed_filter_ptr = transformed_filter_.data<float>();
+ const float *transformed_filter_data = nullptr;
+ // filter only needs to be transformed once, set transformed_filter_data
+ // to null after the first run.
+ if (!is_filter_transformed_) {
+ transformed_filter.Resize(transformed_filter_shape);
+ switch (winograd_out_tile_size) {
+ case 2:
+ TransformFilter4x4(filter_data,
+ filter_shape[1],
+ filter_shape[0],
+ transformed_filter.mutable_data<float>());
+ break;
+ case 6:
+ TransformFilter8x8(filter_data,
+ filter_shape[1],
+ filter_shape[0],
+ transformed_filter.mutable_data<float>());
+ break;
+ default:MACE_NOT_IMPLEMENTED;
+ }
- } else {
- transformed_filter_ptr = transformed_filter_.data<float>();
+ transformed_filter_data = transformed_filter.data<float>();
+ is_filter_transformed_ = true;
}
float *transformed_input_data = transformed_input.mutable_data<float>();
@@ -543,7 +531,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
conv_func = [=](const float *pad_input, float *pad_output) {
WinoGradConv3x3s1(pad_input,
- transformed_filter_ptr,
+ transformed_filter_data,
batch,
extra_input_height,
extra_input_width,
@@ -728,7 +716,6 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
return MACE_SUCCESS;
}
- Tensor transformed_filter_;
bool is_filter_transformed_;
SGemm sgemm_;
};
@@ -741,17 +728,14 @@ struct Conv2dFunctor<DeviceType::CPU, uint8_t> : Conv2dFunctorBase {
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
- const float relux_max_limit,
- const bool is_filter_transformed)
+ const float relux_max_limit)
: Conv2dFunctorBase(context,
strides,
padding_type,
paddings,
dilations,
activation,
- relux_max_limit) {
- MACE_UNUSED(is_filter_transformed);
- }
+ relux_max_limit) {}
template <typename T>
inline void Im2col(
@@ -998,8 +982,7 @@ struct Conv2dFunctor<DeviceType::GPU, T> : Conv2dFunctorBase {
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
- const float relux_max_limit,
- const bool is_filter_transformed);
+ const float relux_max_limit);
MaceStatus operator()(const Tensor *input,
const Tensor *filter,
......
@@ -27,8 +27,7 @@ Conv2dFunctor<DeviceType::GPU, T>::Conv2dFunctor(
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
- const float relux_max_limit,
- const bool is_filter_transformed)
+ const float relux_max_limit)
: Conv2dFunctorBase(context,
strides,
padding_type,
@@ -36,8 +35,6 @@ Conv2dFunctor<DeviceType::GPU, T>::Conv2dFunctor(
dilations,
activation,
relux_max_limit) {
- MACE_UNUSED(is_filter_transformed);
if (context->device()->opencl_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::Conv2dKernel<T>);
} else {
......
@@ -38,9 +38,7 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
kernels::StringToActivationType(
OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")),
- OperatorBase::GetOptionalArg<float>("max_limit", 0.0f),
- static_cast<bool>(OperatorBase::GetOptionalArg<int>(
- "is_filter_transformed", false))) {}
+ OperatorBase::GetOptionalArg<float>("max_limit", 0.0f)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
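One more note on the scratch-buffer hunks above: transformed_input_size and transformed_output_size are computed in bytes (element count times sizeof(float)), which is why the new sgemm pre-allocation line adds them directly instead of multiplying by sizeof(float) a second time as the old line did. Below is a small standalone sketch of that byte-size arithmetic; index_t is redefined locally and the shape values are made up for illustration, not taken from the operator.

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

using index_t = int64_t;  // signed 64-bit index type, as in MACE

// Byte size of a tensor with the given shape, matching the diff's
// std::accumulate(...) * sizeof(float) pattern.
static index_t TensorBytes(const std::vector<index_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), static_cast<index_t>(1),
                         std::multiplies<index_t>()) *
         static_cast<index_t>(sizeof(float));
}

int main() {
  // Illustrative Winograd tile shapes, not values from the real operator.
  std::vector<index_t> transformed_input_shape{16, 8, 36, 1};
  std::vector<index_t> transformed_output_shape{16, 8, 36, 1};

  index_t transformed_input_size = TensorBytes(transformed_input_shape);
  index_t transformed_output_size = TensorBytes(transformed_output_shape);

  // The functor reserves the Winograd buffers once, then reserves the same
  // amount again so the sgemm scratch is large enough; both terms are
  // already in bytes, so no further sizeof(float) factor is applied.
  index_t total_scratch_size = 0;
  total_scratch_size += transformed_input_size + transformed_output_size;
  total_scratch_size += transformed_input_size + transformed_output_size;

  std::cout << "scratch bytes: " << total_scratch_size << "\n";
  return 0;
}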