Commit 155328a4 authored by Yihua Xu

Clean Code

test=develop
Parent 65dbc7cc
@@ -28,6 +28,46 @@ using mkldnn::stream;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;

+inline void GetWeightsTz(std::vector<int>& weights_tz, int groups,  // NOLINT
+                         bool is_conv3d) {
+  if (groups > 1) {
+    if (is_conv3d) {
+      int output = weights_tz[0];
+      int input = weights_tz[1];
+      int dimension = weights_tz[2];
+      int height = weights_tz[3];
+      int width = weights_tz[4];
+      weights_tz.resize(6);
+      weights_tz[0] = groups;
+      weights_tz[1] = output / groups;
+      weights_tz[2] = input;
+      weights_tz[3] = dimension;
+      weights_tz[4] = height;
+      weights_tz[5] = width;
+    } else {
+      int output = weights_tz[0];
+      int input = weights_tz[1];
+      int height = weights_tz[2];
+      int width = weights_tz[3];
+      weights_tz.resize(5);
+      weights_tz[0] = groups;
+      weights_tz[1] = output / groups;
+      weights_tz[2] = input;
+      weights_tz[3] = height;
+      weights_tz[4] = width;
+    }
+  }
+}
+
+inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format,
+                                               int groups, bool is_conv3d) {
+  if (is_conv3d) {
+    return (groups == 1) ? format : mkldnn::memory::format::goidhw;
+  } else {
+    return (groups == 1) ? format : mkldnn::memory::format::goihw;
+  }
+}
+
template <typename T>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
@@ -53,7 +93,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                       filter->format() != memory::format::format_undef,
                   "Wrong layout/format set for Filter tensor");
    PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
-                   "Input must be with 4 or 5dimensions, i.e. NCHW or NCDHW");
+                   "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
    PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
                   "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
    if (bias) {
@@ -87,33 +127,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    std::vector<int> weights_tz =
        paddle::framework::vectorize2int(filter->dims());
    int g = std::max(groups, 1);
-    if (g > 1) {
-      if (is_conv3d) {
-        int o = weights_tz[0];
-        int i = weights_tz[1];
-        int d = weights_tz[2];
-        int h = weights_tz[3];
-        int w = weights_tz[4];
-        weights_tz.resize(6);
-        weights_tz[0] = g;
-        weights_tz[1] = o / g;
-        weights_tz[2] = i;
-        weights_tz[3] = d;
-        weights_tz[4] = h;
-        weights_tz[5] = w;
-      } else {
-        int o = weights_tz[0];
-        int i = weights_tz[1];
-        int h = weights_tz[2];
-        int w = weights_tz[3];
-        weights_tz.resize(5);
-        weights_tz[0] = g;
-        weights_tz[1] = o / g;
-        weights_tz[2] = i;
-        weights_tz[3] = h;
-        weights_tz[4] = w;
-      }
-    }
+    GetWeightsTz(weights_tz, g, is_conv3d);
    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
    // Get unique name for storing MKLDNN primitives
@@ -126,12 +140,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    auto src_format = input->format();
    mkldnn::memory::format weights_format =
-        (g == 1) ? filter->format() : mkldnn::memory::format::goihw;
-    if (is_conv3d) {
-      weights_format =
-          (g == 1) ? filter->format() : mkldnn::memory::format::goidhw;
-    }
+        GetWeightsFormat(filter->format(), g, is_conv3d);
    auto user_src_md = platform::MKLDNNMemDesc(
        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
@@ -146,15 +155,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    auto chosen_memory_format =
        platform::data_format_to_memory_format(data_format);
-    weights_format =
-        (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw;
    if (is_conv3d) {
      chosen_memory_format =
          platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
-      weights_format =
-          (g == 1) ? chosen_memory_format : mkldnn::memory::format::goidhw;
    }
+    weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
    auto src_md = platform::MKLDNNMemDesc(
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
@@ -397,43 +402,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
    std::vector<int> weights_tz =
        paddle::framework::vectorize2int(filter->dims());
    int g = std::max(groups, 1);
-    if (g > 1) {
-      if (is_conv3d) {
-        int o = weights_tz[0];
-        int i = weights_tz[1];
-        int d = weights_tz[2];
-        int h = weights_tz[3];
-        int w = weights_tz[4];
-        weights_tz.resize(6);
-        weights_tz[0] = g;
-        weights_tz[1] = o / g;
-        weights_tz[2] = i;
-        weights_tz[3] = d;
-        weights_tz[4] = h;
-        weights_tz[5] = w;
-      } else {
-        int o = weights_tz[0];
-        int i = weights_tz[1];
-        int h = weights_tz[2];
-        int w = weights_tz[3];
-        weights_tz.resize(5);
-        weights_tz[0] = g;
-        weights_tz[1] = o / g;
-        weights_tz[2] = i;
-        weights_tz[3] = h;
-        weights_tz[4] = w;
-      }
-    }
+    GetWeightsTz(weights_tz, g, is_conv3d);
    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
    auto src_format = input->format();
    mkldnn::memory::format weights_format =
-        (g == 1) ? filter->format() : mkldnn::memory::format::goihw;
-    if (is_conv3d) {
-      weights_format =
-          (g == 1) ? filter->format() : mkldnn::memory::format::goidhw;
-    }
+        GetWeightsFormat(filter->format(), g, is_conv3d);
    // Get an unique name from "argument" name of "Output" variable
    // as well as attributes of primitive to be created
@@ -461,15 +435,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
    auto chosen_memory_format =
        platform::data_format_to_memory_format(data_format);
-    weights_format =
-        (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw;
    if (is_conv3d) {
      chosen_memory_format =
          platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
-      weights_format =
-          (g == 1) ? chosen_memory_format : mkldnn::memory::format::goidhw;
    }
+    weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
    auto src_md = platform::MKLDNNMemDesc(
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
...
@@ -134,14 +134,14 @@ void Conv2DOpMaker::Make() {
           "The format of output tensor is X (one-dimensional) of size equal"
           "to the number of output channels. Only used with MKL-DNN.")
      .AsDispensable();
-  AddOutput("Output",
-            "(Tensor) The output tensor of convolution operator. "
-            "The format of output tensor is also NCHW.");
  AddInput("ResidualData",
           "(Tensor) Tensor with residual data "
           "to which convolution output will be added."
           "Used with fuse_residual_connection fusion.")
      .AsDispensable();
+  AddOutput("Output",
+            "(Tensor) The output tensor of convolution operator. "
+            "The format of output tensor is also NCHW.");
  AddAttr<std::vector<int>>("strides",
                            "(vector<int> default:{1, 1}), the "
                            "strides(h_stride, w_stride) of "
@@ -251,14 +251,14 @@ void Conv3DOpMaker::Make() {
           "is the width of the filter."
           "If the groups attribute is greater than 1, C equals the number of "
           "input image channels divided by the groups.");
-  AddOutput("Output",
-            "(Tensor) The output tensor of convolution operator."
-            "The format of output tensor is also NCDHW.");
  AddInput("ResidualData",
           "(Tensor) Tensor with residual data "
           "to which convolution output will be added."
           "Used with fuse_residual_connection fusion.")
      .AsDispensable();
+  AddOutput("Output",
+            "(Tensor) The output tensor of convolution operator."
+            "The format of output tensor is also NCDHW.");
  AddAttr<std::vector<int>>("strides",
                            "(vector<int>, default:{1, 1, 1}), the "
                            "strides(d_stride, h_stride, w_stride) of "
...
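
For reference, here is a minimal standalone sketch (not part of this commit) of the reshaping the new GetWeightsTz helper performs for the 2D case: when groups > 1, an OIHW dims vector {O, I, H, W} becomes GOIHW {g, O/g, I, H, W}. The function name GetWeightsTz2D and the example dims below are made up for illustration only.

```cpp
#include <cassert>
#include <iostream>
#include <vector>

// Hypothetical standalone replica of the 2D branch of GetWeightsTz:
// reshapes an OIHW dims vector {O, I, H, W} into GOIHW {g, O/g, I, H, W}
// when groups > 1; with groups == 1 the vector is left untouched.
void GetWeightsTz2D(std::vector<int>& weights_tz, int groups) {
  if (groups > 1) {
    int output = weights_tz[0];
    int input = weights_tz[1];
    int height = weights_tz[2];
    int width = weights_tz[3];
    weights_tz.resize(5);
    weights_tz[0] = groups;
    weights_tz[1] = output / groups;
    weights_tz[2] = input;
    weights_tz[3] = height;
    weights_tz[4] = width;
  }
}

int main() {
  // Made-up filter dims: 64 output channels, 16 input channels, 3x3 kernel.
  std::vector<int> weights_tz = {64, 16, 3, 3};
  GetWeightsTz2D(weights_tz, /*groups=*/2);
  // Expected GOIHW dims: {2, 32, 16, 3, 3}.
  assert((weights_tz == std::vector<int>{2, 32, 16, 3, 3}));
  for (int d : weights_tz) std::cout << d << ' ';
  std::cout << '\n';  // prints: 2 32 16 3 3
  return 0;
}
```

GetWeightsFormat plays the matching role for the memory format tag: it returns the caller's format when groups == 1 and goihw (or goidhw for conv3d) otherwise, as shown in the diff above.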