Unverified commit 3e2dec5b, authored by Zhang Zheng, committed by GitHub

Change the input param of fusion op interface from pointer to tensor (#36349)

Parent: fba355fb
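The pattern applied throughout this change: fusion-op entry points that previously took raw device pointers now take `const Tensor &` for inputs and `Tensor *` for outputs, obtain the pointers internally, and allocate output memory themselves via `mutable_data`, so callers only `Resize` their output tensors. Below is a minimal sketch of the before/after shape of such an interface; the class and parameter names are illustrative, not taken from the diff.

```cpp
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class ExampleFusionOp {  // hypothetical op, for illustration only
 public:
  // Old style: callers allocate every buffer and pass raw device pointers.
  //   void Forward(const platform::CUDADeviceContext &ctx, T *in_ptr,
  //                T *out_ptr);
  // New style: inputs are const Tensor&, outputs are Tensor*; the op extracts
  // the data pointers and allocates the outputs on the right place itself.
  void Forward(const platform::CUDADeviceContext &ctx, const Tensor &in,
               Tensor *out) {
    auto place = ctx.GetPlace();
    const T *in_ptr = in.data<T>();            // read-only input buffer
    T *out_ptr = out->mutable_data<T>(place);  // allocated inside the op
    // ... hand in_ptr / out_ptr to the underlying cuDNN fusion plan ...
  }
};

}  // namespace operators
}  // namespace paddle
```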
@@ -536,32 +536,20 @@ class CudnnBNAddReluTester {
     bn_bias->Resize({1, 1, 1, channels_});
 
     // input
-    float *sum_ptr = sum->data<float>();
-    float *sum_of_square_ptr = sum_of_square->data<float>();
-    float *bn_scale_ptr = bn_scale->data<float>();
-    float *bn_bias_ptr = bn_bias->data<float>();
     mean->Resize({1, 1, 1, channels_});
     var->Resize({1, 1, 1, channels_});
 
     // output
-    float *mean_ptr = mean->data<float>();
-    float *var_ptr = var->data<float>();
-    float *saved_mean_ptr =
-        saved_mean->mutable_data<float>({1, 1, 1, channels_}, place);
-    float *saved_var_ptr =
-        saved_var->mutable_data<float>({1, 1, 1, channels_}, place);
-    T *equiv_scale_ptr =
-        equiv_scale->mutable_data<T>({1, 1, 1, channels_}, place);
-    T *equiv_bias_ptr =
-        equiv_bias->mutable_data<T>({1, 1, 1, channels_}, place);
+    equiv_scale->Resize({1, 1, 1, channels_});
+    equiv_bias->Resize({1, 1, 1, channels_});
+    saved_mean->Resize({1, 1, 1, channels_});
+    saved_var->Resize({1, 1, 1, channels_});
 
     auto param_shape = framework::vectorize<int>(bn_scale->dims());
     op::CudnnBNStatsFinalize<T> bn_op(ctx, param_shape);
-    bn_op.Forward(ctx, sum_ptr, sum_of_square_ptr, bn_scale_ptr, bn_bias_ptr,
-                  saved_mean_ptr, saved_var_ptr, mean_ptr, var_ptr,
-                  equiv_scale_ptr, equiv_bias_ptr, eps_, momentum_, ele_count_,
-                  true);
+    bn_op.Forward(ctx, *sum, *sum_of_square, *bn_scale, *bn_bias, saved_mean,
+                  saved_var, mean, var, equiv_scale, equiv_bias, eps_,
+                  momentum_, ele_count_, true);
   }
 
   // Get forward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu
@@ -627,21 +615,13 @@ class CudnnBNAddReluTester {
                          &saved_var_z, &equiv_scale_z, &equiv_bias_z);
     }
 
-    T *x_ptr = x.data<T>();
-    T *z_ptr = (fuse_add_ || has_shortcut_) ? z.data<T>() : nullptr;
-    T *equiv_scale_x_ptr = equiv_scale_x.data<T>();
-    T *equiv_bias_x_ptr = equiv_bias_x.data<T>();
-    T *equiv_scale_z_ptr = has_shortcut_ ? equiv_scale_z.data<T>() : nullptr;
-    T *equiv_bias_z_ptr = has_shortcut_ ? equiv_bias_z.data<T>() : nullptr;
-    T *y_ptr =
-        y.mutable_data<T>({batch_size_, height_, width_, channels_}, place);
+    y.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
 
     int c = channels_;
     int64_t nhw = ele_count_;
     int32_t c_int32_elems = ((c + 63) & ~63) / 32;
     int32_t nhw_int32_elems = (nhw + 31) & ~31;
-    int32_t *bitmask_ptr = bitmask.mutable_data<int32_t>(
-        {nhw_int32_elems, c_int32_elems, 1}, place);
+    bitmask.Resize(framework::make_ddim({nhw_int32_elems, c_int32_elems, 1}));
 
     auto data_shape = framework::vectorize<int>(x.dims());
     auto param_shape = framework::vectorize<int>(bn_scale_x.dims());
@@ -651,8 +631,8 @@ class CudnnBNAddReluTester {
     op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type_, fuse_add_,
                                          has_shortcut_, data_shape, param_shape,
                                          bitmask_shape);
-    sbar_op.Forward(ctx, x_ptr, equiv_scale_x_ptr, equiv_bias_x_ptr, y_ptr,
-                    bitmask_ptr, z_ptr, equiv_scale_z_ptr, equiv_bias_z_ptr);
+    sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, z, equiv_scale_z,
+                    equiv_bias_z, &y, &bitmask);
 
     TensorCopySync(mean_x, platform::CPUPlace(), cpu_mean_x);
     TensorCopySync(var_x, platform::CPUPlace(), cpu_var_x);
@@ -697,19 +677,10 @@ class CudnnBNAddReluTester {
     saved_mean.Resize({1, 1, 1, channels_});
     saved_var.Resize({1, 1, 1, channels_});
 
-    T *dy_ptr = dy.data<T>();
-    T *x_ptr = x.data<T>();
-    float *bn_scale_ptr = bn_scale.data<float>();
-    float *bn_bias_ptr = bn_bias.data<float>();
-    float *saved_mean_ptr = saved_mean.data<float>();
-    float *saved_var_ptr = saved_var.data<float>();
-    int32_t *bitmask_ptr = bitmask.data<int32_t>();
-    T *dx_ptr =
-        dx.mutable_data<T>({batch_size_, height_, width_, channels_}, place);
-    T *dz_ptr =
-        dz.mutable_data<T>({batch_size_, height_, width_, channels_}, place);
-    float *dscale_ptr = dscale.mutable_data<float>({1, 1, 1, channels_}, place);
-    float *dbias_ptr = dbias.mutable_data<float>({1, 1, 1, channels_}, place);
+    dx.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
+    dz.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
+    dscale.Resize(framework::make_ddim({1, 1, 1, channels_}));
+    dbias.Resize(framework::make_ddim({1, 1, 1, channels_}));
 
     auto data_shape = framework::vectorize<int>(x.dims());
     auto param_shape = framework::vectorize<int>(bn_scale.dims());
@@ -718,9 +689,8 @@ class CudnnBNAddReluTester {
     std::string act_type = "relu";
     op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type, true, false, data_shape,
                                          param_shape, bitmask_shape);
-    sbar_op.Backward(ctx, dy_ptr, x_ptr, bn_scale_ptr, bn_bias_ptr,
-                     saved_mean_ptr, saved_var_ptr, bitmask_ptr, dx_ptr, dz_ptr,
-                     dscale_ptr, dbias_ptr, eps_);
+    sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var,
+                     bitmask, &dx, &dz, &dscale, &dbias, eps_);
 
     TensorCopySync(dx, platform::CPUPlace(), cpu_dx);
     TensorCopySync(dz, platform::CPUPlace(), cpu_dz);
......
@@ -68,12 +68,13 @@ class CudnnBNStatsFinalize {
   }
   ~CudnnBNStatsFinalize() {}
 
-  void Forward(const platform::CUDADeviceContext &ctx, float *sum_ptr,
-               float *sum_of_squares_ptr, float *scale_ptr, float *bias_ptr,
-               float *saved_mean_ptr, float *saved_invstd_ptr,
-               float *running_mean_ptr, float *running_var_ptr,
-               T *equiv_scale_ptr, T *equiv_bias_ptr, double eps,
-               float momentum, int64_t ele_count, bool is_train) {
+  void Forward(const platform::CUDADeviceContext &ctx, const Tensor &sum,
+               const Tensor &sum_of_squares, const Tensor &scale,
+               const Tensor &bias, Tensor *saved_mean, Tensor *saved_invstd,
+               Tensor *running_mean, Tensor *running_var, Tensor *equiv_scale,
+               Tensor *equiv_bias, double eps, float momentum,
+               int64_t ele_count, bool is_train) {
+    auto place = ctx.GetPlace();
     if (is_train) {
       TrainInit(ctx);
     } else {
@@ -82,6 +83,17 @@ class CudnnBNStatsFinalize {
     auto &op = is_train ? train_op_ : inference_op_;
 
     // Set variant_param for both inference_op_ and train_op_
+    float *sum_ptr = const_cast<float *>(sum.data<float>());
+    float *sum_of_squares_ptr =
+        const_cast<float *>(sum_of_squares.data<float>());
+    float *scale_ptr = const_cast<float *>(scale.data<float>());
+    float *bias_ptr = const_cast<float *>(bias.data<float>());
+    float *saved_mean_ptr = saved_mean->mutable_data<float>(place);
+    float *saved_invstd_ptr = saved_invstd->mutable_data<float>(place);
+    float *running_mean_ptr = running_mean->mutable_data<float>(place);
+    float *running_var_ptr = running_var->mutable_data<float>(place);
+    T *equiv_scale_ptr = equiv_scale->mutable_data<T>(place);
+    T *equiv_bias_ptr = equiv_bias->mutable_data<T>(place);
     op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
     op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr);
     op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_MEAN, running_mean_ptr);
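For callers, the new `CudnnBNStatsFinalize<T>::Forward` means the output tensors only need a shape before the call; the memory is allocated inside the op via `mutable_data`. A sketch of the caller-side protocol, modeled on the updated test above (the helper name and its parameter list are assumptions for illustration, not part of the diff):

```cpp
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
// plus the header that declares paddle::operators::CudnnBNStatsFinalize

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace op = paddle::operators;

template <typename T>
void RunBNStatsFinalize(const platform::CUDADeviceContext &ctx,
                        const framework::Tensor &sum,
                        const framework::Tensor &sum_of_square,
                        const framework::Tensor &bn_scale,
                        const framework::Tensor &bn_bias, int channels,
                        double eps, float momentum, int64_t ele_count,
                        framework::Tensor *equiv_scale,
                        framework::Tensor *equiv_bias) {
  framework::Tensor saved_mean, saved_var, running_mean, running_var;
  // Shape the outputs only; Forward() allocates them internally.
  saved_mean.Resize({1, 1, 1, channels});
  saved_var.Resize({1, 1, 1, channels});
  running_mean.Resize({1, 1, 1, channels});
  running_var.Resize({1, 1, 1, channels});
  equiv_scale->Resize({1, 1, 1, channels});
  equiv_bias->Resize({1, 1, 1, channels});

  auto param_shape = framework::vectorize<int>(bn_scale.dims());
  op::CudnnBNStatsFinalize<T> bn_op(ctx, param_shape);
  bn_op.Forward(ctx, sum, sum_of_square, bn_scale, bn_bias, &saved_mean,
                &saved_var, &running_mean, &running_var, equiv_scale,
                equiv_bias, eps, momentum, ele_count, /*is_train=*/true);
}
```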
......
@@ -38,7 +38,8 @@ struct NormConvolutionArgs {
     compute_type = platform::CudnnDataType<float>::type;
   }
 
-  void Set(const std::vector<int> &input_shape,
+  void Set(const platform::CUDADeviceContext &ctx,
+           const std::vector<int> &input_shape,
            const std::vector<int> &filter_shape,
            const std::vector<int> &output_shape, int padding, int stride,
            int dilation, int group) {
@@ -61,12 +62,33 @@ struct NormConvolutionArgs {
             "The filter_shape is expected to store as nhwc, and "
             "h = w = 1 or 3. But recieved filter_shape is [%s].",
             framework::make_ddim(filter_shape)));
+    PADDLE_ENFORCE_EQ((filter_shape[0] % 32 == 0 && filter_shape[3] % 8 == 0),
+                      true,
+                      platform::errors::InvalidArgument(
+                          "The input channel is expected to be multiple of 8, "
+                          "and the output channel is expected to be multiple "
+                          "of 32. But recieved input channel is %d, output "
+                          "channel is %d.",
+                          filter_shape[3], filter_shape[0]));
     PADDLE_ENFORCE_EQ(
         output_shape.size(), 4U,
         platform::errors::InvalidArgument(
             "The size of output_shape is expected to 4. But recieved "
             "filter_shape's size is %d, filter_shape is [%s].",
             output_shape.size(), framework::make_ddim(output_shape)));
+    is_support = IsSupport(ctx, filter_shape, stride, dilation, group);
+    PADDLE_ENFORCE_EQ(
+        is_support, true,
+        platform::errors::InvalidArgument(
+            "Current test is only supported in the platforms with "
+            "compatiblity greater than or equal to 70 and the kernel size "
+            "must be equal to 1 or 3. When the kernel size is 1, "
+            "the stride must be 1 if the compatiblity is equal to 70. "
+            "Besides, the dilation and group must be equal to 1. But recieved "
+            "compatiblity is %d, kernel size is %d, stride is %d, "
+            "dilation is %d, group is %d",
+            ctx.GetComputeCapability(), filter_shape[1], stride, dilation,
+            group));
 
     for (size_t i = 0; i < input_shape.size(); ++i) {
       in_dims.push_back(input_shape[i]);
@@ -89,6 +111,25 @@ struct NormConvolutionArgs {
     conv_desc.set(dtype, paddings, strides, dilations, false, group);
   }
 
+  bool IsSupport(const platform::CUDADeviceContext &ctx,
+                 const std::vector<int> &filter_shape, int stride, int dilation,
+                 int group) {
+    int kernel_size = filter_shape[1];
+    if (dilation != 1 || group != 1) {
+      return false;
+    }
+    if (ctx.GetComputeCapability() == 70) {
+      if ((kernel_size == 3) || ((kernel_size == 1) && (stride == 1))) {
+        return true;
+      }
+    } else if (ctx.GetComputeCapability() > 70) {
+      if ((kernel_size == 3) || (kernel_size == 1)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   cudnnDataType_t dtype;
   cudnnTensorFormat_t format;
   cudnnDataType_t compute_type;
@@ -104,6 +145,8 @@ struct NormConvolutionArgs {
   platform::TensorDescriptor out_desc;
   platform::TensorDescriptor out_stats_desc;
   platform::ConvolutionDescriptor conv_desc;
+
+  bool is_support;
 };
 
 template <typename T>
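Putting the new checks together, `Set()` now rejects unsupported configurations up front instead of leaving it to the tests. A standalone restatement of the combined predicate (a hypothetical helper for illustration only; the real code keeps the checks as the `PADDLE_ENFORCE_EQ` calls plus `IsSupport` above):

```cpp
#include <vector>

// The filter is stored as [output_channels, h, w, input_channels], so
// filter_shape[1] is the kernel size, filter_shape[3] the input channels and
// filter_shape[0] the output channels.
bool NormConvSupported(int compute_capability,
                       const std::vector<int> &filter_shape, int stride,
                       int dilation, int group) {
  const int kernel_size = filter_shape[1];
  const bool channels_ok =
      filter_shape[3] % 8 == 0 && filter_shape[0] % 32 == 0;
  if (!channels_ok || dilation != 1 || group != 1) {
    return false;
  }
  if (compute_capability == 70) {
    // On sm_70 a 1x1 kernel additionally requires stride 1.
    return kernel_size == 3 || (kernel_size == 1 && stride == 1);
  }
  if (compute_capability > 70) {
    return kernel_size == 3 || kernel_size == 1;
  }
  return false;  // compute capability below 70 is not supported
}
```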
@@ -115,15 +158,16 @@ class CudnnNormConvolution {
                       const std::vector<int> &output_shape, const int &padding,
                       const int &stride, const int &dilation,
                       const int &group) {
-    args_.Set(input_shape, filter_shape, output_shape, padding, stride,
+    args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
               dilation, group);
   }
   ~CudnnNormConvolution() {}
 
-  void Forward(const platform::CUDADeviceContext &ctx, T *input_ptr,
-               T *filter_ptr, T *output_ptr, float *sum_ptr,
-               float *sum_of_squares_ptr) {
+  void Forward(const platform::CUDADeviceContext &ctx, const Tensor &input,
+               const Tensor &filter, Tensor *output, Tensor *sum,
+               Tensor *sum_of_squares) {
     auto cudnn_handle = ctx.cudnn_handle();
+    auto place = ctx.GetPlace();
 
     CudnnFusionOp *fwd_op = GetForwardOp(ctx);
     size_t workspace_size = RoundUp(
@@ -132,12 +176,17 @@ class CudnnNormConvolution {
     // Set variant_param
     // input ptr
+    T *input_ptr = const_cast<T *>(input.data<T>());
+    T *filter_ptr = const_cast<T *>(filter.data<T>());
     fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
     fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr);
     fwd_op->SetOpVariantParamAttrPtr(
         CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);
 
     // output ptr
+    T *output_ptr = output->mutable_data<T>(place);
+    float *sum_ptr = sum->mutable_data<float>(place);
+    float *sum_of_squares_ptr = sum_of_squares->mutable_data<float>(place);
     fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
     fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
     fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
@@ -209,28 +258,34 @@ class CudnnNormConvolutionGrad {
                           const std::vector<int> &output_shape,
                           const int &padding, const int &stride,
                           const int &dilation, const int &group) {
-    args_.Set(input_shape, filter_shape, output_shape, padding, stride,
+    args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
               dilation, group);
     dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
   }
   ~CudnnNormConvolutionGrad() {}
 
-  void Backward(const platform::CUDADeviceContext &ctx, T *input_ptr,
-                T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr,
-                T *filter_grad_ptr, bool use_addto = false) {
-    if (filter_grad_ptr) {
-      BackwardFilter(ctx, input_ptr, output_grad_ptr, filter_ptr,
-                     filter_grad_ptr);
+  void Backward(const platform::CUDADeviceContext &ctx, const Tensor &input,
+                const Tensor &filter, const Tensor &output_grad,
+                Tensor *input_grad, Tensor *filter_grad,
+                bool use_addto = false) {
+    auto place = ctx.GetPlace();
+    T *input_ptr = const_cast<T *>(input.data<T>());
+    T *filter_ptr = const_cast<T *>(filter.data<T>());
+    T *output_grad_ptr = const_cast<T *>(output_grad.data<T>());
+
+    if (filter_grad) {
+      T *filter_grad_ptr = filter_grad->mutable_data<T>(place);
+      BackwardFilter(ctx, output_grad_ptr, input_ptr, filter_grad_ptr);
     }
-    if (input_grad_ptr) {
-      BackwardData(ctx, input_ptr, output_grad_ptr, filter_ptr, input_grad_ptr,
-                   use_addto);
+    if (input_grad) {
+      T *input_grad_ptr = input_grad->mutable_data<T>(place);
+      BackwardData(ctx, output_grad_ptr, filter_ptr, input_grad_ptr, use_addto);
     }
   }
 
  private:
-  void BackwardFilter(const platform::CUDADeviceContext &ctx, T *input_ptr,
-                      T *output_grad_ptr, T *filter_ptr, T *filter_grad_ptr) {
+  void BackwardFilter(const platform::CUDADeviceContext &ctx,
+                      T *output_grad_ptr, T *input_ptr, T *filter_grad_ptr) {
     auto cudnn_handle = ctx.cudnn_handle();
 
     CudnnFusionOp *wgrad_op = GetBackwardFilterOp(ctx);
@@ -255,9 +310,8 @@ class CudnnNormConvolutionGrad {
                            workspace_size);
   }
 
-  void BackwardData(const platform::CUDADeviceContext &ctx, T *input_ptr,
-                    T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr,
-                    bool use_addto = false) {
+  void BackwardData(const platform::CUDADeviceContext &ctx, T *output_grad_ptr,
+                    T *filter_ptr, T *input_grad_ptr, bool use_addto = false) {
     auto cudnn_handle = ctx.cudnn_handle();
     size_t workspace_size = GetWorkspaceSizeBwdData(ctx);
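Note the optional-output convention in the new `Backward`: a gradient is skipped by passing a null `Tensor *` rather than a null raw pointer, and `mutable_data` is only called on the outputs that are actually requested. A hedged usage fragment (the surrounding variables are assumed to be set up as in the tests further below):

```cpp
// Both gradients:
conv_grad_op.Backward(ctx, input, filter_nhwc, output_grad, &input_grad,
                      &filter_grad);
// Filter gradient only: the null Tensor* makes Backward() skip the
// data-gradient path entirely.
conv_grad_op.Backward(ctx, input, filter_nhwc, output_grad,
                      /*input_grad=*/nullptr, &filter_grad);
```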
......
@@ -229,15 +229,6 @@ class CudnnNormConvolutionTester {
             platform::DeviceContextPool::Instance().Get(
                 platform::CUDAPlace(0)));
 
-    if (!Support(*ctx)) {
-      LOG(INFO)
-          << "Current test is only supported in the platforms with "
-          << "compatiblity greater than or equal to 70 and the kernel size "
-          << "must be equal to 1 or 3. Besides, when the kernel size is 1, "
-          << "the stride must be 1 if the compatiblity is equal to 70.";
-      return;
-    }
-
     framework::Tensor cpu_output_base;
     framework::Tensor cpu_sum_base;
     framework::Tensor cpu_sum_of_square_base;
@@ -325,14 +316,10 @@ class CudnnNormConvolutionTester {
     TensorCopySync(cpu_input_, place, &input);
     TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc);
 
-    T *input_ptr = input.data<T>();
-    T *filter_ptr = filter_nhwc.data<T>();
-    T *output_ptr = output.mutable_data<T>(
-        {batch_size_, out_height_, out_width_, output_channels_}, place);
-    float *sum_ptr =
-        sum.mutable_data<float>({1, 1, 1, output_channels_}, place);
-    float *sum_of_square_ptr =
-        sum_of_square.mutable_data<float>({1, 1, 1, output_channels_}, place);
+    output.Resize(framework::make_ddim(
+        {batch_size_, out_height_, out_width_, output_channels_}));
+    sum.Resize(framework::make_ddim({1, 1, 1, output_channels_}));
+    sum_of_square.Resize(framework::make_ddim({1, 1, 1, output_channels_}));
 
     auto input_shape = framework::vectorize<int>(input.dims());
     auto filter_shape = framework::vectorize<int>(filter_nhwc.dims());
@@ -340,8 +327,7 @@ class CudnnNormConvolutionTester {
     op::CudnnNormConvolution<T> conv_op(ctx, input_shape, filter_shape,
                                         output_shape, padding_, stride_,
                                         dilation_, group_);
-    conv_op.Forward(ctx, input_ptr, filter_ptr, output_ptr, sum_ptr,
-                    sum_of_square_ptr);
+    conv_op.Forward(ctx, input, filter_nhwc, &output, &sum, &sum_of_square);
 
     TensorCopySync(output, platform::CPUPlace(), cpu_output);
     TensorCopySync(sum, platform::CPUPlace(), cpu_sum);
@@ -362,11 +348,8 @@ class CudnnNormConvolutionTester {
     TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc);
     TensorCopySync(cpu_output_grad_, place, &output_grad);
 
-    T *input_ptr = input.data<T>();
-    T *filter_ptr = filter_nhwc.data<T>();
-    T *output_grad_ptr = output_grad.data<T>();
-    T *input_grad_ptr = input_grad.mutable_data<T>(input.dims(), place);
-    T *filter_grad_ptr = filter_grad.mutable_data<T>(filter_nhwc.dims(), place);
+    input_grad.Resize(input.dims());
+    filter_grad.Resize(filter_nhwc.dims());
 
     auto input_shape = framework::vectorize<int>(input.dims());
     auto filter_shape = framework::vectorize<int>(filter_nhwc.dims());
@@ -374,26 +357,13 @@ class CudnnNormConvolutionTester {
     op::CudnnNormConvolutionGrad<T> conv_grad_op(ctx, input_shape, filter_shape,
                                                  output_shape, padding_,
                                                  stride_, dilation_, group_);
-    conv_grad_op.Backward(ctx, input_ptr, output_grad_ptr, filter_ptr,
-                          input_grad_ptr, filter_grad_ptr);
+    conv_grad_op.Backward(ctx, input, filter_nhwc, output_grad, &input_grad,
+                          &filter_grad);
 
     TensorCopySync(input_grad, platform::CPUPlace(), cpu_input_grad);
     TensorCopySync(filter_grad, platform::CPUPlace(), cpu_filter_grad);
   }
 
-  bool Support(const platform::CUDADeviceContext &ctx) {
-    if (ctx.GetComputeCapability() == 70) {
-      if ((kernel_size_ == 3) || ((kernel_size_ == 1) && (stride_ == 1))) {
-        return true;
-      }
-    } else if (ctx.GetComputeCapability() > 70) {
-      if ((kernel_size_ == 3) || (kernel_size_ == 1)) {
-        return true;
-      }
-    }
-    return false;
-  }
-
  private:
   int batch_size_;
   int height_;
@@ -477,6 +447,15 @@ TEST(CudnnNormConvFp16, K1S2O4) {
   CudnnNormConvolutionTester<paddle::platform::float16> test(
       batch_size, height, width, input_channels, output_channels, kernel_size,
       stride);
-  test.CheckForward(1e-3, true);
-  test.CheckBackward(1e-3);
+  platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
+      platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
+
+  if (ctx->GetComputeCapability() <= 70) {
+    ASSERT_THROW(test.CheckForward(1e-3, true),
+                 paddle::platform::EnforceNotMet);
+    ASSERT_THROW(test.CheckBackward(1e-3), paddle::platform::EnforceNotMet);
+  } else {
+    ASSERT_NO_THROW(test.CheckForward(1e-3, true));
+    ASSERT_NO_THROW(test.CheckBackward(1e-3));
+  }
 }
@@ -107,25 +107,33 @@ class CudnnScaleBiasAddRelu {
   ~CudnnScaleBiasAddRelu() {}
 
-  void Forward(const platform::CUDADeviceContext &ctx, T *x_ptr, T *x_scale_ptr,
-               T *x_bias_ptr, T *out_ptr, int32_t *bitmask_ptr,
-               T *z_ptr = nullptr, T *z_scale_ptr = nullptr,
-               T *z_bias_ptr = nullptr) {
+  void Forward(const platform::CUDADeviceContext &ctx, const Tensor &x,
+               const Tensor &x_scale, const Tensor &x_bias, const Tensor &z,
+               const Tensor &z_scale, const Tensor &z_bias, Tensor *out,
+               Tensor *bitmask) {
     ForwardInit(ctx);
     auto handle = ctx.cudnn_handle();
+    auto place = ctx.GetPlace();
     auto workspace_handle = ctx.cudnn_workspace_handle();
     fwd_workspace_byte_ = fwd_op_.GetWorkspaceSizeInBytes(handle);
 
     // Set variant_param
     // input ptr
+    T *x_ptr = const_cast<T *>(x.data<T>());
+    T *x_scale_ptr = const_cast<T *>(x_scale.data<T>());
+    T *x_bias_ptr = const_cast<T *>(x_bias.data<T>());
     fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr);
     fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr);
     fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr);
     if (has_shortcut_) {
+      T *z_ptr = const_cast<T *>(z.data<T>());
+      T *z_scale_ptr = const_cast<T *>(z_scale.data<T>());
+      T *z_bias_ptr = const_cast<T *>(z_bias.data<T>());
       fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
       fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr);
       fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr);
     } else {
       if (fused_add_) {
+        T *z_ptr = const_cast<T *>(z.data<T>());
         fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
       }
     }
@@ -134,6 +142,8 @@ class CudnnScaleBiasAddRelu {
         CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &fwd_workspace_byte_);
 
     // output ptr
+    T *out_ptr = out->mutable_data<T>(place);
+    int32_t *bitmask_ptr = bitmask->mutable_data<int32_t>(place);
     fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, out_ptr);
     fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr);
@@ -147,16 +157,30 @@ class CudnnScaleBiasAddRelu {
                                 fwd_workspace_byte_);
   }
 
-  void Backward(const platform::CUDADeviceContext &ctx, T *dy_ptr, T *x_ptr,
-                float *scale_ptr, float *bias_ptr, float *saved_mean_ptr,
-                float *saved_invstd_ptr, int32_t *bitmask_ptr, T *dx_ptr,
-                T *dz_ptr, float *dscale_ptr, float *dbias_ptr, double eps) {
+  void Backward(const platform::CUDADeviceContext &ctx, const Tensor &dy,
+                const Tensor &x, const Tensor &scale, const Tensor &bias,
+                const Tensor &saved_mean, const Tensor &saved_invstd,
+                const Tensor &bitmask, Tensor *dx, Tensor *dz, Tensor *dscale,
+                Tensor *dbias, double eps) {
     BackwardInit(ctx);
     auto handle = ctx.cudnn_handle();
+    auto place = ctx.GetPlace();
    auto workspace_handle = ctx.cudnn_workspace_handle();
     bwd_workspace_byte_ = bwd_op_.GetWorkspaceSizeInBytes(handle);
 
     // Set variant_param
     // input ptr
+    T *dy_ptr = const_cast<T *>(dy.data<T>());
+    T *x_ptr = const_cast<T *>(x.data<T>());
+    float *scale_ptr = const_cast<float *>(scale.data<float>());
+    float *bias_ptr = const_cast<float *>(bias.data<float>());
+    float *saved_mean_ptr = const_cast<float *>(saved_mean.data<float>());
+    float *saved_invstd_ptr = const_cast<float *>(saved_invstd.data<float>());
+    int32_t *bitmask_ptr = const_cast<int32_t *>(bitmask.data<int32_t>());
+    T *dx_ptr = dx->mutable_data<T>(place);
+    T *dz_ptr = dz ? dz->mutable_data<T>(place) : nullptr;
+    float *dscale_ptr = dscale ? dscale->mutable_data<float>(place) : nullptr;
+    float *dbias_ptr = dbias ? dbias->mutable_data<float>(place) : nullptr;
+
     bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr);
     bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, dy_ptr);
     bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
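The guarded `mutable_data` calls above suggest that `dz`, `dscale`, and `dbias` are optional outputs of the new `Backward`: passing a null `Tensor *` forwards a null pointer for the corresponding cuDNN variant param. A hedged fragment showing the full call as used by the test earlier in this diff, plus the reduced form implied by those null checks:

```cpp
// All outputs requested (mirrors the updated test call earlier in this diff):
sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var, bitmask,
                 &dx, &dz, &dscale, &dbias, eps_);
// If only dx is needed, the remaining outputs can be skipped (an assumption
// based on the null checks above; whether the fused cuDNN plan accepts null
// pointers here depends on its configuration):
sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var, bitmask,
                 &dx, /*dz=*/nullptr, /*dscale=*/nullptr, /*dbias=*/nullptr,
                 eps_);
```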
......