Unverified commit 4a8708bb, authored by W Wilber, committed by GitHub

[Inference] Add conv_fusion nhwc impl. (#49047)

Parent commit: 7875accb
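For orientation: conv2d_fusion fuses convolution, bias addition, an optional residual add, and an activation into one op, and this change lets it accept channel-last (NHWC) input, where the filter is laid out as [O, KH, KW, I/groups] instead of the NCHW-style [O, I/groups, KH, KW]. Below is a minimal NumPy sketch of the fused computation for the NHWC case (the helper name and defaults are illustrative, not Paddle's API; groups == 1 is assumed):

```python
import numpy as np

def conv2d_fusion_nhwc_ref(x, w, bias, residual=None, stride=1, pad=1,
                           activation="relu"):
    """Reference of conv + bias (+ residual) + activation in NHWC layout.

    x:    [N, H, W, C_in]        input, channel-last
    w:    [C_out, KH, KW, C_in]  filter, channel-last (groups == 1 assumed)
    bias: [C_out]
    """
    n, h, w_in, c_in = x.shape
    c_out, kh, kw, _ = w.shape
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    ho = (h + 2 * pad - kh) // stride + 1
    wo = (w_in + 2 * pad - kw) // stride + 1
    out = np.zeros((n, ho, wo, c_out), dtype=x.dtype)
    for i in range(ho):
        for j in range(wo):
            # patch: [N, KH, KW, C_in]; contract against the filter's
            # spatial and input-channel dims to get [N, C_out].
            patch = xp[:, i * stride:i * stride + kh,
                       j * stride:j * stride + kw, :]
            out[:, i, j, :] = np.tensordot(patch, w,
                                           axes=([1, 2, 3], [1, 2, 3]))
    out = out + bias.reshape((1, 1, 1, -1))  # bias along the last (channel) dim
    if residual is not None:
        out = out + residual
    return np.maximum(out, 0) if activation == "relu" else out

# Example with the shapes used by TestSimpleNHWC below.
x = np.random.rand(3, 5, 5, 2).astype("float32")
w = np.random.rand(4, 3, 3, 2).astype("float32")
b = np.random.rand(4).astype("float32")
print(conv2d_fusion_nhwc_ref(x, w, b).shape)  # (3, 5, 5, 4)
```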
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/conv_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#include "paddle/phi/core/ddim.h"

 namespace paddle {
 namespace operators {
@@ -55,6 +56,10 @@ class Conv2DFusionOpMaker : public Conv2DOpMaker {
         "search_times",
         "The number of exhaustive search times for convolution algorithm.")
         .SetDefault(-1);
+    AddAttr<bool>(
+        "use_cudnn",
+        "(bool, default false) Only used in cudnn kernel, need install cudnn")
+        .SetDefault(true);
   }
 };
@@ -67,31 +72,14 @@ class Conv2DFusionOp : public operators::ConvOp {
     OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv2DFusion");
     OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "Conv2DFusion");

-    auto in_dims = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_EQ(
-        in_dims.size(),
-        4U,
-        platform::errors::InvalidArgument(
-            "The input's dimension of Operator(Conv2DFusion) is expected "
-            "to be 4. But received: input's dimension = %u, shape = [%s].",
-            in_dims.size(),
-            in_dims));
-
     // In some case, attribute data_format is "AnyLayout".
     std::string data_format = ctx->Attrs().Get<std::string>("data_format");
-    PADDLE_ENFORCE_NE(
-        data_format,
-        "NDHWC",
-        platform::errors::PermissionDenied(
-            "Operator(Conv2DFusion) supports data format of "
-            "channel first (NCHW,NCDHW) and data format of channel last(NHWC) "
-            "now. But received: data_format = '%s'.",
-            data_format));

     // MKL-DNN Kernels are using NCHW order of dims description
     // so we ignore data_format consideration for MKL-DNN kernel
     const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
                               (data_format == "NHWC" || data_format == "NDHWC");

-    std::vector<int64_t> output_shape = ComputeOutputShape(ctx);
+    std::vector<int64_t> output_shape =
+        ComputeOutputShape(ctx, data_format, channel_last);
     ctx->SetOutputDim("Output", phi::make_ddim(output_shape));
     ctx->ShareLoD("Input", "Output");
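The InferShape change above forwards data_format and channel_last into ComputeOutputShape, which then reads spatial sizes from dims 1..rank-2 of an NHWC input instead of dims 2 onward. A rough Python sketch of that shape rule, assuming explicit symmetric padding (the helper is hypothetical, not code from this patch):

```python
def conv_out_shape(in_shape, filter_shape, strides, paddings, dilations,
                   channel_last):
    """Infer a conv2d output shape; mirrors only the NCHW/NHWC branching."""
    if channel_last:                    # input [N, H, W, C], filter [O, KH, KW, I]
        in_spatial = in_shape[1:-1]
        k_spatial = filter_shape[1:-1]
    else:                               # input [N, C, H, W], filter [O, I, KH, KW]
        in_spatial = in_shape[2:]
        k_spatial = filter_shape[2:]
    out_channels = filter_shape[0]
    spatial = [
        (d + 2 * p - dil * (k - 1) - 1) // s + 1
        for d, k, s, p, dil in zip(in_spatial, k_spatial, strides, paddings,
                                   dilations)
    ]
    if channel_last:
        return [in_shape[0]] + spatial + [out_channels]
    return [in_shape[0], out_channels] + spatial

# NHWC example: [3, 5, 5, 2] input, [4, 3, 3, 2] filter, stride 1, pad 1.
print(conv_out_shape([3, 5, 5, 2], [4, 3, 3, 2], [1, 1], [1, 1], [1, 1], True))
# -> [3, 5, 5, 4]
```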
@@ -145,8 +133,9 @@ class Conv2DFusionOp : public operators::ConvOp {
     }
   }

-  std::vector<int64_t> ComputeOutputShape(
-      framework::InferShapeContext* ctx) const {
+  std::vector<int64_t> ComputeOutputShape(framework::InferShapeContext* ctx,
+                                          const std::string& data_format,
+                                          bool channel_last) const {
     OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Conv");
     OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter", "Conv");
@@ -170,24 +159,6 @@ class Conv2DFusionOp : public operators::ConvOp {
               "dilation is %d.",
               dilations[i]));
     }

-    const std::string data_format =
-        ctx->Attrs().Get<std::string>("data_format");
-    // if data_format is NHWC, we convert the weight dimension to the form of
-    // nchw to minimize program changes.
-    if (data_format == "NHWC") {
-      int kh = filter_dims[1];
-      int kw = filter_dims[2];
-      int ic = filter_dims[3];
-      filter_dims[1] = ic;
-      filter_dims[2] = kh;
-      filter_dims[3] = kw;
-    }
-
-    // MKL-DNN Kernels are using NCHW order of dims description
-    // so we ignore data_format consideration for MKL-DNN kernel
-    const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
-                              (data_format == "NHWC" || data_format == "NDHWC");
-
     PADDLE_ENFORCE_EQ(
         in_dims.size() == 4 || in_dims.size() == 5,
@@ -223,7 +194,6 @@ class Conv2DFusionOp : public operators::ConvOp {
               strides[i]));
     }

-    int in_sub_stride_size = in_dims.size() - stride_size;
     PADDLE_ENFORCE_EQ(
         in_dims.size(),
         strides.size() + 2U,
@@ -237,14 +207,15 @@ class Conv2DFusionOp : public operators::ConvOp {
             in_dims,
             strides.size(),
             phi::make_ddim(strides),
-            in_sub_stride_size));
+            in_dims.size() - stride_size));

     const auto input_channels =
         channel_last ? in_dims[in_dims.size() - 1] : in_dims[1];
     PADDLE_ENFORCE_EQ(
         input_channels,
-        filter_dims[1] * groups,
+        (channel_last ? filter_dims[filter_dims.size() - 1] : filter_dims[1]) *
+            groups,
         platform::errors::InvalidArgument(
             "The number of input's channels should be equal to filter's "
             "channels "
@@ -254,7 +225,7 @@ class Conv2DFusionOp : public operators::ConvOp {
             "The error may come from wrong data_format setting.",
             input_channels,
             in_dims,
-            filter_dims[1],
+            channel_last ? filter_dims[filter_dims.size() - 1] : filter_dims[1],
             filter_dims,
             groups,
             data_format));
@@ -285,8 +256,13 @@ class Conv2DFusionOp : public operators::ConvOp {
       in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
     }

-    framework::DDim filter_data_dims =
-        phi::slice_ddim(filter_dims, 2, filter_dims.size());
+    framework::DDim filter_data_dims;
+    if (channel_last) {
+      filter_data_dims =
+          phi::slice_ddim(filter_dims, 1, filter_dims.size() - 1);
+    } else {
+      filter_data_dims = phi::slice_ddim(filter_dims, 2, filter_dims.size());
+    }

     std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);

     UpdatePaddingAndDilation(
...
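The filter_data_dims branch above extracts the kernel's spatial extent from a channel-last filter by slicing dims 1 through rank-2, since an NHWC-layout filter stores its shape as [O, KH, KW, I/groups]. A small hedged illustration of the two slicing rules (standalone Python, not the operator code):

```python
def kernel_spatial_dims(filter_shape, channel_last):
    """Return the kernel's spatial dims (KH, KW) for either filter layout."""
    if channel_last:
        # [O, KH, KW, I/groups]: spatial dims are 1 .. rank-2
        return list(filter_shape[1:-1])
    # [O, I/groups, KH, KW]: spatial dims are 2 .. rank-1
    return list(filter_shape[2:])

print(kernel_spatial_dims([4, 3, 3, 2], channel_last=True))   # [3, 3]
print(kernel_spatial_dims([4, 2, 3, 3], channel_last=False))  # [3, 3]
```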
(This file's diff is collapsed.)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ConvFusionOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("conv2d_fusion",
                         {"Input", "Filter", "Bias", "ResidualData"},
                         {
                             "strides",
                             "paddings",
                             "padding_algorithm",
                             "dilations",
                             "groups",
                             "data_format",
                             "activation",
                             "exhaustive_search",
                             "split_channels",
                             "workspace_size_MB",
                         },
                         {"Output", "Outputs"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(conv2d_fusion, phi::ConvFusionOpArgumentMapping);
@@ -43,6 +43,30 @@ def create_test_padding_VALID_class(parent):
     globals()[cls_name] = TestPaddingVALIDCase


+def create_test_cudnn_channel_last_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    )
+    class TestCudnnChannelLastCase(parent):
+        def init_test_case(self):
+            super().init_test_case()
+            self.data_format = "NHWC"
+            N, C, H, W = self.input_size
+            self.input_size = [N, H, W, C]
+            K1, K2, R, S = self.filter_size
+            self.filter_size = [K1, R, S, K2]
+
+        def test_check_output(self):
+            print(self.attrs)
+            if self.has_cuda():
+                place = core.CUDAPlace(0)
+                self.check_output_with_place(place, atol=1e-5)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "CudnnChannelLast")
+    TestCudnnChannelLastCase.__name__ = cls_name
+    globals()[cls_name] = TestCudnnChannelLastCase
+
+
 class TestConv2DFusionOp(OpTest):
     def setUp(self):
         self.op_type = "conv2d_fusion"
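The create_test_cudnn_channel_last_class factory added above only permutes the shape metadata that the parent test defined in NCHW/OIHW order; the tensors themselves are regenerated from those shapes in setUp. A standalone NumPy sanity check of the permutations, with made-up shapes (not part of the test file):

```python
import numpy as np

# Shapes a parent test might define in NCHW / OIHW order (made-up values).
input_size_nchw = [2, 3, 5, 5]    # N, C, H, W
filter_size_oihw = [4, 3, 1, 1]   # O, I, KH, KW

# The channel-last factory rewrites them to NHWC / OHWI, as in init_test_case().
N, C, H, W = input_size_nchw
input_size_nhwc = [N, H, W, C]       # [2, 5, 5, 3]
K1, K2, R, S = filter_size_oihw
filter_size_ohwi = [K1, R, S, K2]    # [4, 1, 1, 3]

# Tensors built with the channel-last shapes map back to channel-first via a
# transpose, the same [0, 3, 1, 2] permutation the test's reference path uses.
x_nchw = np.transpose(np.zeros(input_size_nhwc), [0, 3, 1, 2])
w_oihw = np.transpose(np.zeros(filter_size_ohwi), [0, 3, 1, 2])
assert list(x_nchw.shape) == input_size_nchw
assert list(w_oihw.shape) == filter_size_oihw
```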
@@ -73,9 +97,14 @@ class TestConv2DFusionOp(OpTest):
         filter = np.random.random(self.filter_size).astype(self.dtype)
         bias = np.random.random(self.filter_size[0]).astype(self.dtype)

+        if self.data_format == "NHWC":
+            filter_nchw = np.transpose(filter, [0, 3, 1, 2])
+        else:
+            filter_nchw = filter
+
         self.output, _, _, _, _ = conv2d_forward_naive(
             input,
-            filter,
+            filter_nchw,
             self.groups,
             conv2d_param,
             self.padding_algorithm,
@@ -100,7 +129,10 @@ class TestConv2DFusionOp(OpTest):
             self.output += residual_data

         # Add bias
-        self.output = self.output + bias.reshape((1, bias.size, 1, 1))
+        if self.data_format == "NCHW":
+            self.output = self.output + bias.reshape((1, bias.size, 1, 1))
+        else:
+            self.output = self.output + bias.reshape((1, 1, 1, bias.size))

         assert self.activation in ['relu', 'identity']
         if self.activation == 'relu':
@@ -359,6 +391,23 @@ class TestWithInput1x1Filter1x1_AsyPadding(TestConv2DFusionOp):
         self.padding_algorithm = "EXPLICIT"


+class TestSimpleNHWC(TestConv2DFusionOp):
+    def init_test_case(self):
+        self.stride = [1, 1]
+        self.input_size = [3, 5, 5, 2]  # NHWC
+        self.data_format = "NHWC"
+        assert np.mod(self.input_size[3], self.groups) == 0
+        f_c = self.input_size[3] // self.groups
+        self.filter_size = [4, 3, 3, f_c]
+
+    def init_group(self):
+        self.groups = 1
+
+    def init_paddings(self):
+        self.pad = [1, 1]
+        self.padding_algorithm = "EXPLICIT"
+
+
 create_test_padding_SAME_class(TestAsyPadding)
 create_test_padding_SAME_class(TestWithPad_AsyPadding)
 create_test_padding_SAME_class(TestWithStride_AsyPadding)
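For the TestSimpleNHWC configuration just added (NHWC input [3, 5, 5, 2], OHWI filter [4, 3, 3, 2], stride 1, explicit padding [1, 1], groups 1), the expected output shape is [3, 5, 5, 4]: each spatial dim is (5 + 2*1 - 3)/1 + 1 = 5 and the channel dim becomes the filter's output-channel count. A quick standalone check (illustrative, not test code):

```python
# Standalone check of the TestSimpleNHWC shapes.
N, H, W, C = 3, 5, 5, 2          # self.input_size in NHWC order
O, KH, KW, f_c = 4, 3, 3, 2      # self.filter_size in OHWI order, groups == 1
stride, pad = 1, 1               # self.stride / self.pad

H_out = (H + 2 * pad - KH) // stride + 1   # (5 + 2 - 3) + 1 = 5
W_out = (W + 2 * pad - KW) // stride + 1   # 5
print([N, H_out, W_out, O])                # [3, 5, 5, 4]
```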
@@ -371,5 +420,11 @@ create_test_padding_VALID_class(TestWithStride_AsyPadding)
 create_test_padding_VALID_class(TestWithGroup_AsyPadding)
 create_test_padding_VALID_class(TestWithInput1x1Filter1x1_AsyPadding)

+create_test_cudnn_channel_last_class(TestAsyPadding)
+create_test_cudnn_channel_last_class(TestWithPad_AsyPadding)
+create_test_cudnn_channel_last_class(TestWithStride_AsyPadding)
+create_test_cudnn_channel_last_class(TestWithGroup_AsyPadding)
+create_test_cudnn_channel_last_class(TestWithInput1x1Filter1x1_AsyPadding)
+
 if __name__ == '__main__':
     unittest.main()