未验证 提交 ed2a1852 编写于 作者: G gongweibao 提交者: GitHub

optimize nhwc for tensor core in ConvOp and ConvGradOp (#20597)

上级 c918788b
...@@ -40,6 +40,10 @@ using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; ...@@ -40,6 +40,10 @@ using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout; using DataLayout = platform::DataLayout;
static inline bool IsVoltaOrLater(const platform::CUDADeviceContext& dev_ctx) {
return dev_ctx.GetComputeCapability() >= 70;
}
template <typename T> template <typename T>
class CUDNNConvOpKernel : public framework::OpKernel<T> { class CUDNNConvOpKernel : public framework::OpKernel<T> {
public: public:
...@@ -68,11 +72,27 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -68,11 +72,27 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
const std::string data_format = ctx.Attr<std::string>("data_format"); const std::string data_format = ctx.Attr<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
auto dtype = platform::CudnnDataType<T>::type;
// Tensor Core introduced from Volta GPUs supports more faster conv op
// with FP16 in NHWC data format.
const bool compute_in_nhwc =
dtype == CUDNN_DATA_HALF && IsVoltaOrLater(dev_ctx);
// We will only do data format conversion from NHWC to NCHW.
// cudnn will convert NCHW to NHWC automatically on Tensor Core.
auto compute_format =
compute_in_nhwc && channel_last ? DataLayout::kNHWC : DataLayout::kNCHW;
VLOG(3) << "Compute ConvOp with cuDNN:"
<< " data_format=" << data_format << " compute_format="
<< (compute_format == DataLayout::kNHWC ? "NHWC" : "NCHW");
// ------------ transformed tensor ----------- // ------------ transformed tensor -----------
Tensor transformed_input_channel(input->type()); Tensor transformed_input_channel(input->type());
Tensor transformed_output(output->type()); Tensor transformed_output(output->type());
Tensor transformed_filter_channel(filter->type());
T* output_data = nullptr; T* output_data = nullptr;
if (channel_last) { if (channel_last && compute_format == DataLayout::kNCHW) {
VLOG(3) << "Transform input tensor from NHWC to NCHW.";
ResizeToChannelFirst<platform::CUDADeviceContext, T>( ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, input, &transformed_input_channel); ctx, input, &transformed_input_channel);
TransToChannelFirst<platform::CUDADeviceContext, T>( TransToChannelFirst<platform::CUDADeviceContext, T>(
...@@ -82,19 +102,36 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -82,19 +102,36 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
&transformed_output); &transformed_output);
} else { } else {
transformed_input_channel = *input; transformed_input_channel.ShareDataWith(*input);
transformed_output = *output; transformed_output.ShareDataWith(*output);
}
if (compute_format == DataLayout::kNHWC) {
VLOG(3) << "Transform filter tensor from NCHW to NHWC.";
ResizeToChannelLast<platform::CUDADeviceContext, T>(
ctx, filter, &transformed_filter_channel);
TransToChannelLast<platform::CUDADeviceContext, T>(
ctx, filter, &transformed_filter_channel);
} else {
transformed_filter_channel.ShareDataWith(*filter);
} }
output_data = transformed_output.data<T>(); output_data = transformed_output.data<T>();
// update padding and dilation // update padding and dilation
auto in_dims = transformed_input_channel.dims(); auto in_dims = transformed_input_channel.dims();
auto filter_dims = filter->dims(); auto filter_dims = transformed_filter_channel.dims();
framework::DDim in_data_dims; framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); framework::DDim filter_data_dims;
if (compute_format == DataLayout::kNCHW) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
filter_data_dims =
framework::slice_ddim(filter_dims, 1, filter_dims.size() - 1);
}
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims); std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize); in_data_dims, strides, ksize);
...@@ -108,17 +145,33 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -108,17 +145,33 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
std::vector<int> padding_diff(data_dim); std::vector<int> padding_diff(data_dim);
std::vector<int> new_input_shape_vec(data_dim + 2); std::vector<int> new_input_shape_vec(data_dim + 2);
new_input_shape_vec[0] = transformed_input_channel.dims()[0]; new_input_shape_vec[0] = transformed_input_channel.dims()[0];
new_input_shape_vec[1] = transformed_input_channel.dims()[1];
if (compute_format == DataLayout::kNCHW) {
new_input_shape_vec[1] = transformed_input_channel.dims()[1];
} else {
new_input_shape_vec[data_dim + 1] =
transformed_input_channel.dims()[data_dim + 1];
}
std::vector<int> input_pad(transformed_input_channel.dims().size() * 2, std::vector<int> input_pad(transformed_input_channel.dims().size() * 2,
0); 0);
for (size_t i = 0; i < data_dim; ++i) { for (size_t i = 0; i < data_dim; ++i) {
padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
new_input_shape_vec[i + 2] = if (compute_format == DataLayout::kNCHW) {
transformed_input_channel.dims()[i + 2] + padding_diff[i]; new_input_shape_vec[i + 2] =
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; transformed_input_channel.dims()[i + 2] + padding_diff[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; } else {
new_input_shape_vec[i + 1] =
transformed_input_channel.dims()[i + 1] + padding_diff[i];
}
if (compute_format == DataLayout::kNCHW) {
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
} else {
input_pad[2 * i + 2] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 2 + 1] = paddings[2 * i + 1] - padding_common[i];
}
} }
framework::DDim new_input_shape( framework::DDim new_input_shape(
framework::make_ddim(new_input_shape_vec)); framework::make_ddim(new_input_shape_vec));
...@@ -147,7 +200,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -147,7 +200,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
} }
} else { } else {
transformed_input = transformed_input_channel; transformed_input.ShareDataWith(transformed_input_channel);
if (paddings.size() == data_dim) { if (paddings.size() == data_dim) {
for (size_t i = 0; i < data_dim; ++i) { for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[i]; padding_common[i] = paddings[i];
...@@ -160,18 +213,20 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -160,18 +213,20 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
} }
const T* input_data = transformed_input.data<T>(); const T* input_data = transformed_input.data<T>();
const T* filter_data = filter->data<T>(); const T* filter_data = transformed_filter_channel.data<T>();
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
ConvArgs args{&transformed_input, filter, &transformed_output, strides, ConvArgs args{&transformed_input, &transformed_filter_channel,
padding_common, dilations}; &transformed_output, strides,
padding_common, dilations};
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto dtype = platform::CudnnDataType<T>::type; DataLayout layout = compute_format == DataLayout::kNHWC ? DataLayout::kNHWC
DataLayout layout = DataLayout::kNCHW; : DataLayout::kNCHW;
if (transformed_input_channel.dims().size() == 5) { if (transformed_input.dims().size() == 5) {
layout = DataLayout::kNCDHW; layout = compute_format == DataLayout::kNHWC ? DataLayout::kNDHWC
: DataLayout::kNCDHW;
} }
auto layout_format = GetCudnnTensorFormat(layout); auto layout_format = GetCudnnTensorFormat(layout);
...@@ -186,21 +241,27 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -186,21 +241,27 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
args.cdesc.desc(), groups)); args.cdesc.desc(), groups));
groups = 1; groups = 1;
#endif #endif
args.idesc.set(transformed_input, groups); args.idesc.set(transformed_input, layout_format);
args.wdesc.set(transformed_filter_channel, layout_format, groups);
args.wdesc.set(*filter, layout_format, groups); args.odesc.set(transformed_output, layout_format);
args.odesc.set(transformed_output, groups);
int i_n, i_c, i_d, i_h, i_w; int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(transformed_input.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d,
&i_h, &i_w);
int o_n, o_c, o_d, o_h, o_w; int o_n, o_c, o_d, o_h, o_w;
GetNCDHW(transformed_output.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d,
&o_h, &o_w); if (compute_format == DataLayout::kNHWC) {
GetNCDHW(transformed_input.dims(), DataLayout::kNHWC, &i_n, &i_c, &i_d,
&i_h, &i_w);
GetNCDHW(transformed_output.dims(), DataLayout::kNHWC, &o_n, &o_c, &o_d,
&o_h, &o_w);
} else {
GetNCDHW(transformed_input.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d,
&i_h, &i_w);
GetNCDHW(transformed_output.dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d,
&o_h, &o_w);
}
int group_offset_in = i_c / groups * i_h * i_w * i_d; int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d; int group_offset_out = o_c / groups * o_h * o_w * o_d;
int group_offset_filter = filter->numel() / groups; int group_offset_filter = transformed_filter_channel.numel() / groups;
// ------------------- cudnn conv workspace --------------------- // ------------------- cudnn conv workspace ---------------------
size_t workspace_size = 0; // final workspace to allocate. size_t workspace_size = 0; // final workspace to allocate.
// ------------------- cudnn conv algorithm --------------------- // ------------------- cudnn conv algorithm ---------------------
...@@ -225,7 +286,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -225,7 +286,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
workspace_size); workspace_size);
} }
if (channel_last) { if (channel_last && compute_format == DataLayout::kNCHW) {
TransToChannelLast<paddle::platform::CUDADeviceContext, T>( TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
ctx, &transformed_output, output); ctx, &transformed_output, output);
} }
...@@ -245,7 +306,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -245,7 +306,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input")); auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter")); auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
const T* filter_data = filter->data<T>();
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(ctx.GetPlace()); input_grad->mutable_data<T>(ctx.GetPlace());
} }
...@@ -269,12 +329,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -269,12 +329,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
const std::string data_format = ctx.Attr<std::string>("data_format"); const std::string data_format = ctx.Attr<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
auto dtype = platform::CudnnDataType<T>::type;
const bool compute_in_nhwc =
dtype == CUDNN_DATA_HALF && IsVoltaOrLater(dev_ctx);
auto compute_format =
compute_in_nhwc && channel_last ? DataLayout::kNHWC : DataLayout::kNCHW;
VLOG(3) << "Compute ConvGradOp with cuDNN:"
<< " data_format=" << data_format << " compute_format="
<< (compute_format == DataLayout::kNHWC ? "NHWC" : "NCHW");
// transform Tensor // transform Tensor
Tensor transformed_input_channel(input->type()); Tensor transformed_input_channel(input->type());
Tensor transformed_output_grad_channel(output_grad->type()); Tensor transformed_output_grad_channel(output_grad->type());
Tensor transformed_input_grad_channel(input->type()); Tensor transformed_input_grad_channel(input->type());
Tensor transformed_filter_channel(filter->type());
Tensor transformed_filter_grad_channel(filter->type());
if (channel_last) { if (channel_last && compute_format == DataLayout::kNCHW) {
VLOG(3) << "Transform input, output_grad, input_grad and tensor from "
"NHWC to NCHW.";
ResizeToChannelFirst<platform::CUDADeviceContext, T>( ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, input, &transformed_input_channel); ctx, input, &transformed_input_channel);
TransToChannelFirst<platform::CUDADeviceContext, T>( TransToChannelFirst<platform::CUDADeviceContext, T>(
...@@ -289,22 +362,46 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -289,22 +362,46 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
ResizeToChannelFirst<platform::CUDADeviceContext, T>( ResizeToChannelFirst<platform::CUDADeviceContext, T>(
ctx, input_grad, &transformed_input_grad_channel); ctx, input_grad, &transformed_input_grad_channel);
} }
} else { } else {
transformed_input_channel = *input; transformed_input_channel.ShareDataWith(*input);
transformed_output_grad_channel = *output_grad; transformed_output_grad_channel.ShareDataWith(*output_grad);
if (input_grad) { if (input_grad) {
transformed_input_grad_channel.ShareDataWith(*input_grad); transformed_input_grad_channel.ShareDataWith(*input_grad);
} }
} }
if (compute_format == DataLayout::kNHWC) {
VLOG(3) << "Transform filter and filter_grad tensor from NCHW to NHWC.";
ResizeToChannelLast<platform::CUDADeviceContext, T>(
ctx, filter, &transformed_filter_channel);
TransToChannelLast<platform::CUDADeviceContext, T>(
ctx, filter, &transformed_filter_channel);
if (filter_grad) {
ResizeToChannelLast<platform::CUDADeviceContext, T>(
ctx, filter_grad, &transformed_filter_grad_channel);
}
} else {
transformed_filter_channel.ShareDataWith(*filter);
if (filter_grad) {
transformed_filter_grad_channel.ShareDataWith(*filter_grad);
}
}
// update paddings // update paddings
auto in_dims = transformed_input_channel.dims(); auto in_dims = transformed_input_channel.dims();
auto filter_dims = filter->dims(); auto filter_dims = transformed_filter_channel.dims();
framework::DDim in_data_dims; framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); framework::DDim filter_data_dims;
framework::DDim filter_data_dims = if (compute_format == DataLayout::kNCHW) {
framework::slice_ddim(filter_dims, 2, filter_dims.size()); in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
filter_data_dims =
framework::slice_ddim(filter_dims, 1, filter_dims.size() - 1);
}
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims); std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize); in_data_dims, strides, ksize);
...@@ -323,15 +420,30 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -323,15 +420,30 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
std::vector<int> padding_diff(data_dim); std::vector<int> padding_diff(data_dim);
std::vector<int> new_input_shape_vec(data_dim + 2); std::vector<int> new_input_shape_vec(data_dim + 2);
new_input_shape_vec[0] = transformed_input_channel.dims()[0]; new_input_shape_vec[0] = transformed_input_channel.dims()[0];
new_input_shape_vec[1] = transformed_input_channel.dims()[1]; if (compute_format == DataLayout::kNCHW) {
new_input_shape_vec[1] = transformed_input_channel.dims()[1];
} else {
new_input_shape_vec[data_dim + 1] =
transformed_input_channel.dims()[data_dim + 1];
}
for (size_t i = 0; i < data_dim; ++i) { for (size_t i = 0; i < data_dim; ++i) {
padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
new_input_shape_vec[i + 2] = if (compute_format == DataLayout::kNCHW) {
transformed_input_channel.dims()[i + 2] + padding_diff[i]; new_input_shape_vec[i + 2] =
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; transformed_input_channel.dims()[i + 2] + padding_diff[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; } else {
new_input_shape_vec[i + 1] =
transformed_input_channel.dims()[i + 1] + padding_diff[i];
}
if (compute_format == DataLayout::kNCHW) {
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
} else {
input_pad[2 * i + 2] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 2 + 1] = paddings[2 * i + 1] - padding_common[i];
}
} }
framework::DDim new_input_shape( framework::DDim new_input_shape(
framework::make_ddim(new_input_shape_vec)); framework::make_ddim(new_input_shape_vec));
...@@ -384,42 +496,51 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -384,42 +496,51 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
const T* input_data = transformed_input.data<T>(); const T* input_data = transformed_input.data<T>();
const T* output_grad_data = transformed_output_grad_channel.data<T>(); const T* output_grad_data = transformed_output_grad_channel.data<T>();
const T* filter_data = transformed_filter_channel.data<T>();
T* filter_grad_data = nullptr; T* filter_grad_data = nullptr;
T* input_grad_data = nullptr; T* input_grad_data = nullptr;
T* transformed_input_grad_data = nullptr; T* transformed_input_grad_data = nullptr;
ConvArgs args1{&transformed_input_grad, ConvArgs args1{&transformed_input_grad,
filter, &transformed_filter_channel,
&transformed_output_grad_channel, &transformed_output_grad_channel,
strides, strides,
padding_common, padding_common,
dilations}; dilations};
ConvArgs args2{&transformed_input, ConvArgs args2{&transformed_input,
filter_grad, &transformed_filter_grad_channel,
&transformed_output_grad_channel, &transformed_output_grad_channel,
strides, strides,
padding_common, padding_common,
dilations}; dilations};
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
auto dtype = platform::CudnnDataType<T>::type; DataLayout layout = compute_format == DataLayout::kNHWC ? DataLayout::kNHWC
DataLayout layout = DataLayout::kNCHW; : DataLayout::kNCHW;
if (input->dims().size() == 5) { if (transformed_input.dims().size() == 5) {
layout = DataLayout::kNCDHW; layout = compute_format == DataLayout::kNHWC ? DataLayout::kNDHWC
: DataLayout::kNCDHW;
} }
auto layout_tensor = GetCudnnTensorFormat(layout); auto layout_tensor = GetCudnnTensorFormat(layout);
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
int i_n, i_c, i_d, i_h, i_w; int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(transformed_input.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d,
&i_h, &i_w);
int o_n, o_c, o_d, o_h, o_w; int o_n, o_c, o_d, o_h, o_w;
GetNCDHW(transformed_output_grad_channel.dims(), DataLayout::kNCHW, &o_n, if (compute_format == DataLayout::kNHWC) {
&o_c, &o_d, &o_h, &o_w); GetNCDHW(transformed_input.dims(), DataLayout::kNHWC, &i_n, &i_c, &i_d,
&i_h, &i_w);
GetNCDHW(transformed_output_grad_channel.dims(), DataLayout::kNHWC, &o_n,
&o_c, &o_d, &o_h, &o_w);
} else {
GetNCDHW(transformed_input.dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d,
&i_h, &i_w);
GetNCDHW(transformed_output_grad_channel.dims(), DataLayout::kNCHW, &o_n,
&o_c, &o_d, &o_h, &o_w);
}
int group_offset_in = i_c / groups * i_h * i_w * i_d; int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d; int group_offset_out = o_c / groups * o_h * o_w * o_d;
int group_offset_filter = filter->numel() / groups; int group_offset_filter = transformed_filter_channel.numel() / groups;
// ------------------- cudnn backward algorithm --------------------- // ------------------- cudnn backward algorithm ---------------------
cudnnConvolutionBwdDataAlgo_t data_algo = cudnnConvolutionBwdDataAlgo_t data_algo =
static_cast<cudnnConvolutionBwdDataAlgo_t>(0); static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
...@@ -439,9 +560,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -439,9 +560,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
input_grad_data = input_grad->data<T>(); input_grad_data = input_grad->data<T>();
transformed_input_grad_data = transformed_input_grad.data<T>(); transformed_input_grad_data = transformed_input_grad.data<T>();
args1.handle = handle; args1.handle = handle;
args1.idesc.set(transformed_input_grad, iwo_groups); args1.idesc.set(transformed_input_grad, layout_tensor);
args1.wdesc.set(*filter, layout_tensor, iwo_groups); args1.wdesc.set(transformed_filter_channel, layout_tensor, iwo_groups);
args1.odesc.set(transformed_output_grad_channel, iwo_groups); args1.odesc.set(transformed_output_grad_channel, layout_tensor);
args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups); args1.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>; using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
...@@ -453,11 +574,12 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -453,11 +574,12 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
if (filter_grad) { if (filter_grad) {
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
filter_grad_data = filter_grad->data<T>(); filter_grad_data = transformed_filter_grad_channel.data<T>();
args2.handle = handle; args2.handle = handle;
args2.idesc.set(transformed_input, iwo_groups); args2.idesc.set(transformed_input, layout_tensor);
args2.wdesc.set(*filter_grad, layout_tensor, iwo_groups); args2.wdesc.set(transformed_filter_grad_channel, layout_tensor,
args2.odesc.set(transformed_output_grad_channel, iwo_groups); iwo_groups);
args2.odesc.set(transformed_output_grad_channel, layout_tensor);
args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups); args2.cdesc.set(dtype, padding_common, strides, dilations, c_groups);
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>; using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
...@@ -506,7 +628,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -506,7 +628,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
} }
} }
if (channel_last) { if (channel_last && compute_format == DataLayout::kNCHW) {
TransToChannelLast<paddle::platform::CUDADeviceContext, T>( TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
ctx, &transformed_input_grad_channel, input_grad); ctx, &transformed_input_grad_channel, input_grad);
} }
...@@ -527,6 +649,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -527,6 +649,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
}, },
workspace_size); workspace_size);
} }
if (compute_format == DataLayout::kNHWC) {
TransToChannelFirst<paddle::platform::CUDADeviceContext, T>(
ctx, &transformed_filter_grad_channel, filter_grad);
}
} }
} }
}; };
......
...@@ -97,13 +97,15 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -97,13 +97,15 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
filter_dims[0], filter_dims, groups); filter_dims[0], filter_dims, groups);
framework::DDim in_data_dims; framework::DDim in_data_dims;
framework::DDim filter_data_dims;
if (channel_last) { if (channel_last) {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else { } else {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} }
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size()); filter_data_dims = framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims); std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize); in_data_dims, strides, ksize);
...@@ -117,9 +119,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -117,9 +119,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
(in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) { (in_data_dims[i] <= 0 || filter_dims[i + 2] <= 0)) {
output_shape.push_back(-1); output_shape.push_back(-1);
} else { } else {
output_shape.push_back(ConvOutputSize(in_data_dims[i], filter_dims[i + 2], output_shape.push_back(
dilations[i], paddings[2 * i], ConvOutputSize(in_data_dims[i], filter_data_dims[i], dilations[i],
paddings[2 * i + 1], strides[i])); paddings[2 * i], paddings[2 * i + 1], strides[i]));
} }
} }
if (channel_last) { if (channel_last) {
...@@ -335,7 +337,7 @@ parameters is checked in the infer-shape. ...@@ -335,7 +337,7 @@ parameters is checked in the infer-shape.
Input(Input) and Output(Output) are in NCHW or NHWC format. Where N is batch Input(Input) and Output(Output) are in NCHW or NHWC format. Where N is batch
size, C is the number of channels, H is the height of the feature, and W is size, C is the number of channels, H is the height of the feature, and W is
the width of the feature. the width of the feature.
Filters(Input) is MCHW format. Where M is the number of output image channels, C is Filters(Input) is MCHW format format. Where M is the number of output image channels, C is
the number of input image channels, H is the height of the filter, and W the number of input image channels, H is the height of the filter, and W
is the width of the filter. is the width of the filter.
Parameters(strides, paddings, dilations) are two elements. These two elements represent Parameters(strides, paddings, dilations) are two elements. These two elements represent
......
...@@ -154,6 +154,36 @@ inline void ResizeToChannelFirst(const framework::ExecutionContext& context, ...@@ -154,6 +154,36 @@ inline void ResizeToChannelFirst(const framework::ExecutionContext& context,
} }
} }
template <typename DeviceContext, typename T>
inline void ResizeToChannelLast(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[4];
in_dims_vec[4] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 2) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[3];
in_dims_vec[3] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
}
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
inline void TransToChannelFirst(const framework::ExecutionContext& context, inline void TransToChannelFirst(const framework::ExecutionContext& context,
const Tensor* input, const Tensor* input,
......
...@@ -34,6 +34,29 @@ inline cudnnDataType_t ToCudnnDataType(const T& t) { ...@@ -34,6 +34,29 @@ inline cudnnDataType_t ToCudnnDataType(const T& t) {
return ToCudnnDataType(type); return ToCudnnDataType(type);
} }
inline std::vector<int> TransformDimOrder(const std::vector<int>& dims) {
std::vector<int> transformed_dims(dims.begin(), dims.end());
int H, W, D, C;
if (dims.size() == 4) {
H = dims[1];
W = dims[2];
C = dims[3];
transformed_dims[1] = C;
transformed_dims[2] = H;
transformed_dims[3] = W;
} else {
D = dims[1];
H = dims[2];
W = dims[3];
C = dims[4];
transformed_dims[1] = C;
transformed_dims[2] = D;
transformed_dims[3] = H;
transformed_dims[4] = W;
}
return transformed_dims;
}
template <> template <>
inline cudnnDataType_t ToCudnnDataType( inline cudnnDataType_t ToCudnnDataType(
const framework::proto::VarType::Type& t) { const framework::proto::VarType::Type& t) {
...@@ -117,6 +140,19 @@ class TensorDescriptor { ...@@ -117,6 +140,19 @@ class TensorDescriptor {
dims_with_group.data(), strides.data())); dims_with_group.data(), strides.data()));
} }
void set(const Tensor& tensor, const cudnnTensorFormat_t format) {
auto dims = framework::vectorize<int>(tensor.dims());
std::vector<int> transformed_dims;
if (format == CUDNN_TENSOR_NHWC) {
transformed_dims = TransformDimOrder(dims);
} else {
transformed_dims = dims;
}
CUDNN_ENFORCE(dynload::cudnnSetTensorNdDescriptorEx(
desc_.get(), format, ToCudnnDataType(tensor.type()),
transformed_dims.size(), transformed_dims.data()));
}
private: private:
std::unique_ptr<T, Deleter> desc_; std::unique_ptr<T, Deleter> desc_;
}; };
...@@ -143,12 +179,18 @@ class FilterDescriptor { ...@@ -143,12 +179,18 @@ class FilterDescriptor {
void set(const Tensor& tensor, const cudnnTensorFormat_t format, void set(const Tensor& tensor, const cudnnTensorFormat_t format,
const int groups = 1) { const int groups = 1) {
auto dims = framework::vectorize<int>(tensor.dims()); auto dims = framework::vectorize<int>(tensor.dims());
std::vector<int> transformed_dims;
if (format == CUDNN_TENSOR_NHWC) {
transformed_dims = TransformDimOrder(dims);
} else {
transformed_dims = dims;
}
if (groups > 1) { if (groups > 1) {
dims[1] = dims[1] / groups; transformed_dims[1] = transformed_dims[1] / groups;
} }
CUDNN_ENFORCE(dynload::cudnnSetFilterNdDescriptor( CUDNN_ENFORCE(dynload::cudnnSetFilterNdDescriptor(
desc_.get(), ToCudnnDataType(tensor.type()), format, dims.size(), desc_.get(), ToCudnnDataType(tensor.type()), format,
dims.data())); transformed_dims.size(), transformed_dims.data()));
} }
private: private:
......
...@@ -81,7 +81,6 @@ def conv2d_forward_naive(input, ...@@ -81,7 +81,6 @@ def conv2d_forward_naive(input,
if len(pad) == 4: if len(pad) == 4:
pad_h_0, pad_h_1 = pad[0], pad[1] pad_h_0, pad_h_1 = pad[0], pad[1]
pad_w_0, pad_w_1 = pad[2], pad[3] pad_w_0, pad_w_1 = pad[2], pad[3]
out_h = 1 + (in_h + pad_h_0 + pad_h_1 - (dilation[0] * out_h = 1 + (in_h + pad_h_0 + pad_h_1 - (dilation[0] *
(f_h - 1) + 1)) // stride[0] (f_h - 1) + 1)) // stride[0]
out_w = 1 + (in_w + pad_w_0 + pad_w_1 - (dilation[1] * out_w = 1 + (in_w + pad_w_0 + pad_w_1 - (dilation[1] *
...@@ -204,6 +203,50 @@ def create_test_cudnn_channel_last_class(parent): ...@@ -204,6 +203,50 @@ def create_test_cudnn_channel_last_class(parent):
globals()[cls_name] = TestCudnnChannelLastCase globals()[cls_name] = TestCudnnChannelLastCase
def create_test_cudnn_channel_last_fp16_class(parent, grad_check=True):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCudnnChannelLastFp16(parent):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=2e-2)
def test_check_grad_no_filter(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and grad_check:
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
def init_data_format(self):
self.data_format = "NHWC"
def init_test_case_2(self):
N, C, H, W = self.input_size
self.input_size = [N, H, W, C]
cls_name = "{0}_{1}".format(parent.__name__, "CudnnChannelLastFp16")
TestCudnnChannelLastFp16.__name__ = cls_name
globals()[cls_name] = TestCudnnChannelLastFp16
def create_test_padding_SAME_class(parent): def create_test_padding_SAME_class(parent):
class TestPaddingSMAECase(parent): class TestPaddingSMAECase(parent):
def init_paddings(self): def init_paddings(self):
...@@ -699,7 +742,6 @@ class TestConv2dOp_v2(OpTest): ...@@ -699,7 +742,6 @@ class TestConv2dOp_v2(OpTest):
self.init_dilation() self.init_dilation()
self.init_data_format() self.init_data_format()
self.init_test_case() self.init_test_case()
self.init_paddings() self.init_paddings()
self.init_test_case_2() self.init_test_case_2()
...@@ -1195,6 +1237,17 @@ create_test_cudnn_channel_last_class(TestWithStride_AsyPadding) ...@@ -1195,6 +1237,17 @@ create_test_cudnn_channel_last_class(TestWithStride_AsyPadding)
create_test_cudnn_channel_last_class(TestWithGroup_AsyPadding) create_test_cudnn_channel_last_class(TestWithGroup_AsyPadding)
create_test_cudnn_channel_last_class(TestWithDilation_AsyPadding) create_test_cudnn_channel_last_class(TestWithDilation_AsyPadding)
create_test_cudnn_channel_last_fp16_class(
TestConv2dOp_AsyPadding, grad_check=False)
create_test_cudnn_channel_last_fp16_class(
TestWithPad_AsyPadding, grad_check=False)
create_test_cudnn_channel_last_fp16_class(
TestWithStride_AsyPadding, grad_check=False)
create_test_cudnn_channel_last_fp16_class(
TestWithGroup_AsyPadding, grad_check=False)
create_test_cudnn_channel_last_fp16_class(
TestWithDilation_AsyPadding, grad_check=False)
# --------- test python API --------------- # --------- test python API ---------------
class TestConv2dAPI(OpTest): class TestConv2dAPI(OpTest):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册