Unverified · Commit ed2a1852 authored by gongweibao, committed by GitHub

optimize nhwc for tensor core in ConvOp and ConvGradOp (#20597)

Parent c918788b
paddle/fluid/operators/conv_op.cc

```diff
@@ -97,13 +97,15 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
                         filter_dims[0], filter_dims, groups);
 
   framework::DDim in_data_dims;
+  framework::DDim filter_data_dims;
+
   if (channel_last) {
     in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
   } else {
     in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
  }
-  framework::DDim filter_data_dims =
-      framework::slice_ddim(filter_dims, 2, filter_dims.size());
+  filter_data_dims = framework::slice_ddim(filter_dims, 2, filter_dims.size());
+
   std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
   UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                            in_data_dims, strides, ksize);
```
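For context, `UpdatePaddingAndDilation` expands a 2-element `paddings` to a per-side 4-element form and resolves `padding_algorithm`. A minimal sketch of the usual TF-style `SAME`/`VALID` semantics, written as a standalone paraphrase (the real helper lives in conv_op.h and may differ in details):

```cpp
#include <algorithm>
#include <string>

// Hedged sketch of SAME/VALID padding resolution (TF-style rules assumed):
// SAME:  out = ceil(in / stride), total padding split as evenly as possible;
// VALID: no padding at all; any other value keeps user-provided paddings.
void ResolvePaddingSketch(const std::string& algo, int in, int ksize,
                          int stride, int* pad0, int* pad1) {
  if (algo == "SAME") {
    int out = (in + stride - 1) / stride;
    int pad_sum = std::max((out - 1) * stride + ksize - in, 0);
    *pad0 = pad_sum / 2;
    *pad1 = pad_sum - *pad0;  // odd totals put the extra pixel on one side
  } else if (algo == "VALID") {
    *pad0 = *pad1 = 0;
  }
}
```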
```diff
@@ -117,9 +119,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
         (in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) {
       output_shape.push_back(-1);
     } else {
-      output_shape.push_back(ConvOutputSize(in_data_dims[i], filter_dims[i + 2],
-                                            dilations[i], paddings[2 * i],
-                                            paddings[2 * i + 1], strides[i]));
+      output_shape.push_back(
+          ConvOutputSize(in_data_dims[i], filter_data_dims[i], dilations[i],
+                         paddings[2 * i], paddings[2 * i + 1], strides[i]));
     }
   }
   if (channel_last) {
```
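The output size used here follows the standard dilated-convolution formula with asymmetric padding. A self-contained sketch (illustrative; the real `ConvOutputSize` lives in conv_op.h):

```cpp
#include <cassert>

// out = (in + pad0 + pad1 - (dilation * (k - 1) + 1)) / stride + 1
int ConvOutputSizeSketch(int input, int filter, int dilation, int pad0,
                         int pad1, int stride) {
  const int dkernel = dilation * (filter - 1) + 1;  // effective kernel extent
  int output = (input + pad0 + pad1 - dkernel) / stride + 1;
  assert(output > 0 && "invalid conv geometry");
  return output;
}
// e.g. input=32, filter=3, dilation=1, pad0=pad1=1, stride=1 -> 32
```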
```diff
@@ -335,7 +337,7 @@ parameters is checked in the infer-shape.
 Input(Input) and Output(Output) are in NCHW or NHWC format. Where N is batch
 size, C is the number of channels, H is the height of the feature, and W is
 the width of the feature.
-Filters(Input) is MCHW format. Where M is the number of output image channels, C is
+Filters(Input) is MCHW format format. Where M is the number of output image channels, C is
 the number of input image channels, H is the height of the filter, and W
 is the width of the filter.
 Parameters(strides, paddings, dilations) are two elements. These two elements represent
```
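To make these layouts concrete, an illustrative example (shapes invented for this note, not taken from the patch):

```cpp
// A 2-D convolution over an 8-image batch with 3 input channels:
//   Input  (NCHW): {8, 3, 32, 32}   equivalently (NHWC): {8, 32, 32, 3}
//   Filter (MCHW): {64, 3, 3, 3}    64 output channels, 3x3 kernel
//   Output (NCHW): {8, 64, 32, 32}  with stride 1, dilation 1, padding 1
```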
paddle/fluid/operators/conv_op.h

```diff
@@ -154,6 +154,36 @@ inline void ResizeToChannelFirst(const framework::ExecutionContext& context,
   }
 }
 
+template <typename DeviceContext, typename T>
+inline void ResizeToChannelLast(const framework::ExecutionContext& context,
+                                const Tensor* input,
+                                Tensor* transformed_input) {
+  int dim = input->dims().size() - 2;
+  if (dim == 3) {
+    // 5-D input: NCDHW -> NDHWC
+    transformed_input->Resize(input->dims());
+
+    auto in_dims_vec = framework::vectorize(input->dims());
+    in_dims_vec[1] = input->dims()[2];
+    in_dims_vec[2] = input->dims()[3];
+    in_dims_vec[3] = input->dims()[4];
+    in_dims_vec[4] = input->dims()[1];
+    transformed_input->Resize(framework::make_ddim(in_dims_vec));
+    transformed_input->mutable_data<T>(context.GetPlace());
+  } else if (dim == 2) {
+    // 4-D input: NCHW -> NHWC
+    transformed_input->Resize(input->dims());
+
+    auto in_dims_vec = framework::vectorize(input->dims());
+    in_dims_vec[1] = input->dims()[2];
+    in_dims_vec[2] = input->dims()[3];
+    in_dims_vec[3] = input->dims()[1];
+    transformed_input->Resize(framework::make_ddim(in_dims_vec));
+    transformed_input->mutable_data<T>(context.GetPlace());
+  }
+}
+
 template <typename DeviceContext, typename T>
 inline void TransToChannelFirst(const framework::ExecutionContext& context,
                                 const Tensor* input,
```
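The shape permutation above is easy to check in isolation. A pure-shape version operating on plain vectors (illustrative only; the real helper resizes a Tensor and allocates memory):

```cpp
#include <iostream>
#include <vector>

// Move the channel dim (index 1) to the back, shifting spatial dims forward.
// Works for both 4-D (NCHW -> NHWC) and 5-D (NCDHW -> NDHWC) shapes.
std::vector<int> ToChannelLastDims(const std::vector<int>& nchw) {
  std::vector<int> out(nchw);
  for (size_t i = 1; i + 1 < nchw.size(); ++i) out[i] = nchw[i + 1];
  out.back() = nchw[1];
  return out;
}

int main() {
  for (int d : ToChannelLastDims({8, 3, 32, 32})) std::cout << d << ' ';
  // prints: 8 32 32 3
}
```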
paddle/fluid/platform/cudnn_desc.h

```diff
@@ -34,6 +34,29 @@ inline cudnnDataType_t ToCudnnDataType(const T& t) {
   return ToCudnnDataType(type);
 }
 
+inline std::vector<int> TransformDimOrder(const std::vector<int>& dims) {
+  std::vector<int> transformed_dims(dims.begin(), dims.end());
+  int H, W, D, C;
+  if (dims.size() == 4) {
+    H = dims[1];
+    W = dims[2];
+    C = dims[3];
+    transformed_dims[1] = C;
+    transformed_dims[2] = H;
+    transformed_dims[3] = W;
+  } else {
+    D = dims[1];
+    H = dims[2];
+    W = dims[3];
+    C = dims[4];
+    transformed_dims[1] = C;
+    transformed_dims[2] = D;
+    transformed_dims[3] = H;
+    transformed_dims[4] = W;
+  }
+  return transformed_dims;
+}
+
 template <>
 inline cudnnDataType_t ToCudnnDataType(
     const framework::proto::VarType::Type& t) {
```
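`TransformDimOrder` reorders channel-last dims back into channel-first order. The reason, as I read the cuDNN docs (worth double-checking): the Nd descriptor setters always take the dimension array in NCHW/NCDHW order, and the separate `format` argument is what selects the memory layout. A quick check of the helper's effect, assuming `TransformDimOrder` from the hunk above is visible:

```cpp
#include <cassert>
#include <vector>

int main() {
  // NHWC dims of an 8x32x32x64 tensor...
  std::vector<int> nhwc = {8, 32, 32, 64};
  // ...become NCHW-ordered dims, as cuDNN's Nd descriptor setters expect.
  assert((TransformDimOrder(nhwc) == std::vector<int>{8, 64, 32, 32}));
}
```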
```diff
@@ -117,6 +140,19 @@ class TensorDescriptor {
         dims_with_group.data(), strides.data()));
   }
 
+  void set(const Tensor& tensor, const cudnnTensorFormat_t format) {
+    auto dims = framework::vectorize<int>(tensor.dims());
+    std::vector<int> transformed_dims;
+    if (format == CUDNN_TENSOR_NHWC) {
+      transformed_dims = TransformDimOrder(dims);
+    } else {
+      transformed_dims = dims;
+    }
+    CUDNN_ENFORCE(dynload::cudnnSetTensorNdDescriptorEx(
+        desc_.get(), format, ToCudnnDataType(tensor.type()),
+        transformed_dims.size(), transformed_dims.data()));
+  }
+
 private:
  std::unique_ptr<T, Deleter> desc_;
 };
```
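For reference, the raw cuDNN call being wrapped here. A minimal sketch of describing an NHWC float16 tensor directly against the cuDNN API (descriptor creation and error handling elided, shapes illustrative):

```cpp
#include <cudnn.h>
#include <vector>

// Describe an N=8, C=64, H=32, W=32 tensor whose memory layout is NHWC,
// mirroring what the new TensorDescriptor::set overload does.
void DescribeNhwcTensor(cudnnTensorDescriptor_t desc) {
  // Dims are passed in NCHW order even for NHWC layout; the format
  // argument is what actually selects the channel-last memory layout.
  std::vector<int> dims = {8, 64, 32, 32};
  cudnnSetTensorNdDescriptorEx(desc, CUDNN_TENSOR_NHWC, CUDNN_DATA_HALF,
                               static_cast<int>(dims.size()), dims.data());
}
```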
```diff
@@ -143,12 +179,18 @@ class FilterDescriptor {
   void set(const Tensor& tensor, const cudnnTensorFormat_t format,
            const int groups = 1) {
     auto dims = framework::vectorize<int>(tensor.dims());
+    std::vector<int> transformed_dims;
+    if (format == CUDNN_TENSOR_NHWC) {
+      transformed_dims = TransformDimOrder(dims);
+    } else {
+      transformed_dims = dims;
+    }
     if (groups > 1) {
-      dims[1] = dims[1] / groups;
+      transformed_dims[1] = transformed_dims[1] / groups;
     }
     CUDNN_ENFORCE(dynload::cudnnSetFilterNdDescriptor(
-        desc_.get(), ToCudnnDataType(tensor.type()), format, dims.size(),
-        dims.data()));
+        desc_.get(), ToCudnnDataType(tensor.type()), format,
+        transformed_dims.size(), transformed_dims.data()));
   }
 
 private:
```
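One subtlety: for grouped convolution, the filter descriptor carries the per-group input channel count, hence the division by `groups` above. A hedged sketch with invented shapes:

```cpp
#include <cudnn.h>
#include <vector>

// Sketch: filter descriptor for a grouped conv. With 64 output channels,
// 32 input channels and groups = 2, cuDNN expects the filter dims as
// {M, C/groups, kH, kW} = {64, 16, 3, 3}, in channel-first order.
void DescribeGroupedFilter(cudnnFilterDescriptor_t desc, int groups) {
  std::vector<int> dims = {64, 32 / groups, 3, 3};
  cudnnSetFilterNdDescriptor(desc, CUDNN_DATA_HALF, CUDNN_TENSOR_NHWC,
                             static_cast<int>(dims.size()), dims.data());
}
```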
python/paddle/fluid/tests/unittests/test_conv2d_op.py

```diff
@@ -81,7 +81,6 @@ def conv2d_forward_naive(input,
     if len(pad) == 4:
         pad_h_0, pad_h_1 = pad[0], pad[1]
         pad_w_0, pad_w_1 = pad[2], pad[3]
-
     out_h = 1 + (in_h + pad_h_0 + pad_h_1 - (dilation[0] *
                                              (f_h - 1) + 1)) // stride[0]
     out_w = 1 + (in_w + pad_w_0 + pad_w_1 - (dilation[1] *
```
```diff
@@ -204,6 +203,50 @@ def create_test_cudnn_channel_last_class(parent):
     globals()[cls_name] = TestCudnnChannelLastCase
 
 
+def create_test_cudnn_channel_last_fp16_class(parent, grad_check=True):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestCudnnChannelLastFp16(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place, atol=2e-2)
+
+        def test_check_grad_no_filter(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place) and grad_check:
+                self.check_grad_with_place(
+                    place, ['Input'],
+                    'Output',
+                    max_relative_error=0.02,
+                    no_grad_set=set(['Filter']))
+
+        def test_check_grad_no_input(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place) and grad_check:
+                self.check_grad_with_place(
+                    place, ['Filter'],
+                    'Output',
+                    max_relative_error=0.02,
+                    no_grad_set=set(['Input']))
+
+        def init_data_format(self):
+            self.data_format = "NHWC"
+
+        def init_test_case_2(self):
+            N, C, H, W = self.input_size
+            self.input_size = [N, H, W, C]
+
+    cls_name = "{0}_{1}".format(parent.__name__, "CudnnChannelLastFp16")
+    TestCudnnChannelLastFp16.__name__ = cls_name
+    globals()[cls_name] = TestCudnnChannelLastFp16
+
+
 def create_test_padding_SAME_class(parent):
     class TestPaddingSMAECase(parent):
         def init_paddings(self):
```
```diff
@@ -699,7 +742,6 @@ class TestConv2dOp_v2(OpTest):
         self.init_dilation()
         self.init_data_format()
         self.init_test_case()
-
         self.init_paddings()
         self.init_test_case_2()
```
```diff
@@ -1195,6 +1237,17 @@ create_test_cudnn_channel_last_class(TestWithStride_AsyPadding)
 create_test_cudnn_channel_last_class(TestWithGroup_AsyPadding)
 create_test_cudnn_channel_last_class(TestWithDilation_AsyPadding)
 
+create_test_cudnn_channel_last_fp16_class(
+    TestConv2dOp_AsyPadding, grad_check=False)
+create_test_cudnn_channel_last_fp16_class(
+    TestWithPad_AsyPadding, grad_check=False)
+create_test_cudnn_channel_last_fp16_class(
+    TestWithStride_AsyPadding, grad_check=False)
+create_test_cudnn_channel_last_fp16_class(
+    TestWithGroup_AsyPadding, grad_check=False)
+create_test_cudnn_channel_last_fp16_class(
+    TestWithDilation_AsyPadding, grad_check=False)
+
 # --------- test python API ---------------
 class TestConv2dAPI(OpTest):
```