未验证 提交 3e2dec5b 编写于 作者: Z Zhang Zheng 提交者: GitHub

Change the input param of fusion op interface from pointer to tensor (#36349)

上级 fba355fb
......@@ -536,32 +536,20 @@ class CudnnBNAddReluTester {
bn_bias->Resize({1, 1, 1, channels_});
// input
float *sum_ptr = sum->data<float>();
float *sum_of_square_ptr = sum_of_square->data<float>();
float *bn_scale_ptr = bn_scale->data<float>();
float *bn_bias_ptr = bn_bias->data<float>();
mean->Resize({1, 1, 1, channels_});
var->Resize({1, 1, 1, channels_});
// output
float *mean_ptr = mean->data<float>();
float *var_ptr = var->data<float>();
float *saved_mean_ptr =
saved_mean->mutable_data<float>({1, 1, 1, channels_}, place);
float *saved_var_ptr =
saved_var->mutable_data<float>({1, 1, 1, channels_}, place);
T *equiv_scale_ptr =
equiv_scale->mutable_data<T>({1, 1, 1, channels_}, place);
T *equiv_bias_ptr =
equiv_bias->mutable_data<T>({1, 1, 1, channels_}, place);
equiv_scale->Resize({1, 1, 1, channels_});
equiv_bias->Resize({1, 1, 1, channels_});
saved_mean->Resize({1, 1, 1, channels_});
saved_var->Resize({1, 1, 1, channels_});
auto param_shape = framework::vectorize<int>(bn_scale->dims());
op::CudnnBNStatsFinalize<T> bn_op(ctx, param_shape);
bn_op.Forward(ctx, sum_ptr, sum_of_square_ptr, bn_scale_ptr, bn_bias_ptr,
saved_mean_ptr, saved_var_ptr, mean_ptr, var_ptr,
equiv_scale_ptr, equiv_bias_ptr, eps_, momentum_, ele_count_,
true);
bn_op.Forward(ctx, *sum, *sum_of_square, *bn_scale, *bn_bias, saved_mean,
saved_var, mean, var, equiv_scale, equiv_bias, eps_,
momentum_, ele_count_, true);
}
// Get forward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu
......@@ -627,21 +615,13 @@ class CudnnBNAddReluTester {
&saved_var_z, &equiv_scale_z, &equiv_bias_z);
}
T *x_ptr = x.data<T>();
T *z_ptr = (fuse_add_ || has_shortcut_) ? z.data<T>() : nullptr;
T *equiv_scale_x_ptr = equiv_scale_x.data<T>();
T *equiv_bias_x_ptr = equiv_bias_x.data<T>();
T *equiv_scale_z_ptr = has_shortcut_ ? equiv_scale_z.data<T>() : nullptr;
T *equiv_bias_z_ptr = has_shortcut_ ? equiv_bias_z.data<T>() : nullptr;
T *y_ptr =
y.mutable_data<T>({batch_size_, height_, width_, channels_}, place);
y.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
int c = channels_;
int64_t nhw = ele_count_;
int32_t c_int32_elems = ((c + 63) & ~63) / 32;
int32_t nhw_int32_elems = (nhw + 31) & ~31;
int32_t *bitmask_ptr = bitmask.mutable_data<int32_t>(
{nhw_int32_elems, c_int32_elems, 1}, place);
bitmask.Resize(framework::make_ddim({nhw_int32_elems, c_int32_elems, 1}));
auto data_shape = framework::vectorize<int>(x.dims());
auto param_shape = framework::vectorize<int>(bn_scale_x.dims());
......@@ -651,8 +631,8 @@ class CudnnBNAddReluTester {
op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type_, fuse_add_,
has_shortcut_, data_shape, param_shape,
bitmask_shape);
sbar_op.Forward(ctx, x_ptr, equiv_scale_x_ptr, equiv_bias_x_ptr, y_ptr,
bitmask_ptr, z_ptr, equiv_scale_z_ptr, equiv_bias_z_ptr);
sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, z, equiv_scale_z,
equiv_bias_z, &y, &bitmask);
TensorCopySync(mean_x, platform::CPUPlace(), cpu_mean_x);
TensorCopySync(var_x, platform::CPUPlace(), cpu_var_x);
......@@ -697,19 +677,10 @@ class CudnnBNAddReluTester {
saved_mean.Resize({1, 1, 1, channels_});
saved_var.Resize({1, 1, 1, channels_});
T *dy_ptr = dy.data<T>();
T *x_ptr = x.data<T>();
float *bn_scale_ptr = bn_scale.data<float>();
float *bn_bias_ptr = bn_bias.data<float>();
float *saved_mean_ptr = saved_mean.data<float>();
float *saved_var_ptr = saved_var.data<float>();
int32_t *bitmask_ptr = bitmask.data<int32_t>();
T *dx_ptr =
dx.mutable_data<T>({batch_size_, height_, width_, channels_}, place);
T *dz_ptr =
dz.mutable_data<T>({batch_size_, height_, width_, channels_}, place);
float *dscale_ptr = dscale.mutable_data<float>({1, 1, 1, channels_}, place);
float *dbias_ptr = dbias.mutable_data<float>({1, 1, 1, channels_}, place);
dx.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
dz.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
dscale.Resize(framework::make_ddim({1, 1, 1, channels_}));
dbias.Resize(framework::make_ddim({1, 1, 1, channels_}));
auto data_shape = framework::vectorize<int>(x.dims());
auto param_shape = framework::vectorize<int>(bn_scale.dims());
......@@ -718,9 +689,8 @@ class CudnnBNAddReluTester {
std::string act_type = "relu";
op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type, true, false, data_shape,
param_shape, bitmask_shape);
sbar_op.Backward(ctx, dy_ptr, x_ptr, bn_scale_ptr, bn_bias_ptr,
saved_mean_ptr, saved_var_ptr, bitmask_ptr, dx_ptr, dz_ptr,
dscale_ptr, dbias_ptr, eps_);
sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var,
bitmask, &dx, &dz, &dscale, &dbias, eps_);
TensorCopySync(dx, platform::CPUPlace(), cpu_dx);
TensorCopySync(dz, platform::CPUPlace(), cpu_dz);
......
......@@ -68,12 +68,13 @@ class CudnnBNStatsFinalize {
}
~CudnnBNStatsFinalize() {}
void Forward(const platform::CUDADeviceContext &ctx, float *sum_ptr,
float *sum_of_squares_ptr, float *scale_ptr, float *bias_ptr,
float *saved_mean_ptr, float *saved_invstd_ptr,
float *running_mean_ptr, float *running_var_ptr,
T *equiv_scale_ptr, T *equiv_bias_ptr, double eps,
float momentum, int64_t ele_count, bool is_train) {
void Forward(const platform::CUDADeviceContext &ctx, const Tensor &sum,
const Tensor &sum_of_squares, const Tensor &scale,
const Tensor &bias, Tensor *saved_mean, Tensor *saved_invstd,
Tensor *running_mean, Tensor *running_var, Tensor *equiv_scale,
Tensor *equiv_bias, double eps, float momentum,
int64_t ele_count, bool is_train) {
auto place = ctx.GetPlace();
if (is_train) {
TrainInit(ctx);
} else {
......@@ -82,6 +83,17 @@ class CudnnBNStatsFinalize {
auto &op = is_train ? train_op_ : inference_op_;
// Set variant_param for both inference_op_ and train_op_
float *sum_ptr = const_cast<float *>(sum.data<float>());
float *sum_of_squares_ptr =
const_cast<float *>(sum_of_squares.data<float>());
float *scale_ptr = const_cast<float *>(scale.data<float>());
float *bias_ptr = const_cast<float *>(bias.data<float>());
float *saved_mean_ptr = saved_mean->mutable_data<float>(place);
float *saved_invstd_ptr = saved_invstd->mutable_data<float>(place);
float *running_mean_ptr = running_mean->mutable_data<float>(place);
float *running_var_ptr = running_var->mutable_data<float>(place);
T *equiv_scale_ptr = equiv_scale->mutable_data<T>(place);
T *equiv_bias_ptr = equiv_bias->mutable_data<T>(place);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_MEAN, running_mean_ptr);
......
......@@ -38,7 +38,8 @@ struct NormConvolutionArgs {
compute_type = platform::CudnnDataType<float>::type;
}
void Set(const std::vector<int> &input_shape,
void Set(const platform::CUDADeviceContext &ctx,
const std::vector<int> &input_shape,
const std::vector<int> &filter_shape,
const std::vector<int> &output_shape, int padding, int stride,
int dilation, int group) {
......@@ -61,12 +62,33 @@ struct NormConvolutionArgs {
"The filter_shape is expected to store as nhwc, and "
"h = w = 1 or 3. But recieved filter_shape is [%s].",
framework::make_ddim(filter_shape)));
PADDLE_ENFORCE_EQ((filter_shape[0] % 32 == 0 && filter_shape[3] % 8 == 0),
true,
platform::errors::InvalidArgument(
"The input channel is expected to be multiple of 8, "
"and the output channel is expected to be multiple "
"of 32. But recieved input channel is %d, output "
"channel is %d.",
filter_shape[3], filter_shape[0]));
PADDLE_ENFORCE_EQ(
output_shape.size(), 4U,
platform::errors::InvalidArgument(
"The size of output_shape is expected to 4. But recieved "
"filter_shape's size is %d, filter_shape is [%s].",
output_shape.size(), framework::make_ddim(output_shape)));
is_support = IsSupport(ctx, filter_shape, stride, dilation, group);
PADDLE_ENFORCE_EQ(
is_support, true,
platform::errors::InvalidArgument(
"Current test is only supported in the platforms with "
"compatiblity greater than or equal to 70 and the kernel size "
"must be equal to 1 or 3. When the kernel size is 1, "
"the stride must be 1 if the compatiblity is equal to 70. "
"Besides, the dilation and group must be equal to 1. But recieved "
"compatiblity is %d, kernel size is %d, stride is %d, "
"dilation is %d, group is %d",
ctx.GetComputeCapability(), filter_shape[1], stride, dilation,
group));
for (size_t i = 0; i < input_shape.size(); ++i) {
in_dims.push_back(input_shape[i]);
......@@ -89,6 +111,25 @@ struct NormConvolutionArgs {
conv_desc.set(dtype, paddings, strides, dilations, false, group);
}
bool IsSupport(const platform::CUDADeviceContext &ctx,
const std::vector<int> &filter_shape, int stride, int dilation,
int group) {
int kernel_size = filter_shape[1];
if (dilation != 1 || group != 1) {
return false;
}
if (ctx.GetComputeCapability() == 70) {
if ((kernel_size == 3) || ((kernel_size == 1) && (stride == 1))) {
return true;
}
} else if (ctx.GetComputeCapability() > 70) {
if ((kernel_size == 3) || (kernel_size == 1)) {
return true;
}
}
return false;
}
cudnnDataType_t dtype;
cudnnTensorFormat_t format;
cudnnDataType_t compute_type;
......@@ -104,6 +145,8 @@ struct NormConvolutionArgs {
platform::TensorDescriptor out_desc;
platform::TensorDescriptor out_stats_desc;
platform::ConvolutionDescriptor conv_desc;
bool is_support;
};
template <typename T>
......@@ -115,15 +158,16 @@ class CudnnNormConvolution {
const std::vector<int> &output_shape, const int &padding,
const int &stride, const int &dilation,
const int &group) {
args_.Set(input_shape, filter_shape, output_shape, padding, stride,
args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
dilation, group);
}
~CudnnNormConvolution() {}
void Forward(const platform::CUDADeviceContext &ctx, T *input_ptr,
T *filter_ptr, T *output_ptr, float *sum_ptr,
float *sum_of_squares_ptr) {
void Forward(const platform::CUDADeviceContext &ctx, const Tensor &input,
const Tensor &filter, Tensor *output, Tensor *sum,
Tensor *sum_of_squares) {
auto cudnn_handle = ctx.cudnn_handle();
auto place = ctx.GetPlace();
CudnnFusionOp *fwd_op = GetForwardOp(ctx);
size_t workspace_size = RoundUp(
......@@ -132,12 +176,17 @@ class CudnnNormConvolution {
// Set variant_param
// input ptr
T *input_ptr = const_cast<T *>(input.data<T>());
T *filter_ptr = const_cast<T *>(filter.data<T>());
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr);
fwd_op->SetOpVariantParamAttrPtr(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);
// output ptr
T *output_ptr = output->mutable_data<T>(place);
float *sum_ptr = sum->mutable_data<float>(place);
float *sum_of_squares_ptr = sum_of_squares->mutable_data<float>(place);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
......@@ -209,28 +258,34 @@ class CudnnNormConvolutionGrad {
const std::vector<int> &output_shape,
const int &padding, const int &stride,
const int &dilation, const int &group) {
args_.Set(input_shape, filter_shape, output_shape, padding, stride,
args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
dilation, group);
dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
~CudnnNormConvolutionGrad() {}
void Backward(const platform::CUDADeviceContext &ctx, T *input_ptr,
T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr,
T *filter_grad_ptr, bool use_addto = false) {
if (filter_grad_ptr) {
BackwardFilter(ctx, input_ptr, output_grad_ptr, filter_ptr,
filter_grad_ptr);
void Backward(const platform::CUDADeviceContext &ctx, const Tensor &input,
const Tensor &filter, const Tensor &output_grad,
Tensor *input_grad, Tensor *filter_grad,
bool use_addto = false) {
auto place = ctx.GetPlace();
T *input_ptr = const_cast<T *>(input.data<T>());
T *filter_ptr = const_cast<T *>(filter.data<T>());
T *output_grad_ptr = const_cast<T *>(output_grad.data<T>());
if (filter_grad) {
T *filter_grad_ptr = filter_grad->mutable_data<T>(place);
BackwardFilter(ctx, output_grad_ptr, input_ptr, filter_grad_ptr);
}
if (input_grad_ptr) {
BackwardData(ctx, input_ptr, output_grad_ptr, filter_ptr, input_grad_ptr,
use_addto);
if (input_grad) {
T *input_grad_ptr = input_grad->mutable_data<T>(place);
BackwardData(ctx, output_grad_ptr, filter_ptr, input_grad_ptr, use_addto);
}
}
private:
void BackwardFilter(const platform::CUDADeviceContext &ctx, T *input_ptr,
T *output_grad_ptr, T *filter_ptr, T *filter_grad_ptr) {
void BackwardFilter(const platform::CUDADeviceContext &ctx,
T *output_grad_ptr, T *input_ptr, T *filter_grad_ptr) {
auto cudnn_handle = ctx.cudnn_handle();
CudnnFusionOp *wgrad_op = GetBackwardFilterOp(ctx);
......@@ -255,9 +310,8 @@ class CudnnNormConvolutionGrad {
workspace_size);
}
void BackwardData(const platform::CUDADeviceContext &ctx, T *input_ptr,
T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr,
bool use_addto = false) {
void BackwardData(const platform::CUDADeviceContext &ctx, T *output_grad_ptr,
T *filter_ptr, T *input_grad_ptr, bool use_addto = false) {
auto cudnn_handle = ctx.cudnn_handle();
size_t workspace_size = GetWorkspaceSizeBwdData(ctx);
......
......@@ -229,15 +229,6 @@ class CudnnNormConvolutionTester {
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(0)));
if (!Support(*ctx)) {
LOG(INFO)
<< "Current test is only supported in the platforms with "
<< "compatiblity greater than or equal to 70 and the kernel size "
<< "must be equal to 1 or 3. Besides, when the kernel size is 1, "
<< "the stride must be 1 if the compatiblity is equal to 70.";
return;
}
framework::Tensor cpu_output_base;
framework::Tensor cpu_sum_base;
framework::Tensor cpu_sum_of_square_base;
......@@ -325,14 +316,10 @@ class CudnnNormConvolutionTester {
TensorCopySync(cpu_input_, place, &input);
TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc);
T *input_ptr = input.data<T>();
T *filter_ptr = filter_nhwc.data<T>();
T *output_ptr = output.mutable_data<T>(
{batch_size_, out_height_, out_width_, output_channels_}, place);
float *sum_ptr =
sum.mutable_data<float>({1, 1, 1, output_channels_}, place);
float *sum_of_square_ptr =
sum_of_square.mutable_data<float>({1, 1, 1, output_channels_}, place);
output.Resize(framework::make_ddim(
{batch_size_, out_height_, out_width_, output_channels_}));
sum.Resize(framework::make_ddim({1, 1, 1, output_channels_}));
sum_of_square.Resize(framework::make_ddim({1, 1, 1, output_channels_}));
auto input_shape = framework::vectorize<int>(input.dims());
auto filter_shape = framework::vectorize<int>(filter_nhwc.dims());
......@@ -340,8 +327,7 @@ class CudnnNormConvolutionTester {
op::CudnnNormConvolution<T> conv_op(ctx, input_shape, filter_shape,
output_shape, padding_, stride_,
dilation_, group_);
conv_op.Forward(ctx, input_ptr, filter_ptr, output_ptr, sum_ptr,
sum_of_square_ptr);
conv_op.Forward(ctx, input, filter_nhwc, &output, &sum, &sum_of_square);
TensorCopySync(output, platform::CPUPlace(), cpu_output);
TensorCopySync(sum, platform::CPUPlace(), cpu_sum);
......@@ -362,11 +348,8 @@ class CudnnNormConvolutionTester {
TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc);
TensorCopySync(cpu_output_grad_, place, &output_grad);
T *input_ptr = input.data<T>();
T *filter_ptr = filter_nhwc.data<T>();
T *output_grad_ptr = output_grad.data<T>();
T *input_grad_ptr = input_grad.mutable_data<T>(input.dims(), place);
T *filter_grad_ptr = filter_grad.mutable_data<T>(filter_nhwc.dims(), place);
input_grad.Resize(input.dims());
filter_grad.Resize(filter_nhwc.dims());
auto input_shape = framework::vectorize<int>(input.dims());
auto filter_shape = framework::vectorize<int>(filter_nhwc.dims());
......@@ -374,26 +357,13 @@ class CudnnNormConvolutionTester {
op::CudnnNormConvolutionGrad<T> conv_grad_op(ctx, input_shape, filter_shape,
output_shape, padding_,
stride_, dilation_, group_);
conv_grad_op.Backward(ctx, input_ptr, output_grad_ptr, filter_ptr,
input_grad_ptr, filter_grad_ptr);
conv_grad_op.Backward(ctx, input, filter_nhwc, output_grad, &input_grad,
&filter_grad);
TensorCopySync(input_grad, platform::CPUPlace(), cpu_input_grad);
TensorCopySync(filter_grad, platform::CPUPlace(), cpu_filter_grad);
}
bool Support(const platform::CUDADeviceContext &ctx) {
if (ctx.GetComputeCapability() == 70) {
if ((kernel_size_ == 3) || ((kernel_size_ == 1) && (stride_ == 1))) {
return true;
}
} else if (ctx.GetComputeCapability() > 70) {
if ((kernel_size_ == 3) || (kernel_size_ == 1)) {
return true;
}
}
return false;
}
private:
int batch_size_;
int height_;
......@@ -477,6 +447,15 @@ TEST(CudnnNormConvFp16, K1S2O4) {
CudnnNormConvolutionTester<paddle::platform::float16> test(
batch_size, height, width, input_channels, output_channels, kernel_size,
stride);
test.CheckForward(1e-3, true);
test.CheckBackward(1e-3);
platform::CUDADeviceContext *ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(platform::CUDAPlace(0)));
if (ctx->GetComputeCapability() <= 70) {
ASSERT_THROW(test.CheckForward(1e-3, true),
paddle::platform::EnforceNotMet);
ASSERT_THROW(test.CheckBackward(1e-3), paddle::platform::EnforceNotMet);
} else {
ASSERT_NO_THROW(test.CheckForward(1e-3, true));
ASSERT_NO_THROW(test.CheckBackward(1e-3));
}
}
......@@ -107,25 +107,33 @@ class CudnnScaleBiasAddRelu {
~CudnnScaleBiasAddRelu() {}
void Forward(const platform::CUDADeviceContext &ctx, T *x_ptr, T *x_scale_ptr,
T *x_bias_ptr, T *out_ptr, int32_t *bitmask_ptr,
T *z_ptr = nullptr, T *z_scale_ptr = nullptr,
T *z_bias_ptr = nullptr) {
void Forward(const platform::CUDADeviceContext &ctx, const Tensor &x,
const Tensor &x_scale, const Tensor &x_bias, const Tensor &z,
const Tensor &z_scale, const Tensor &z_bias, Tensor *out,
Tensor *bitmask) {
ForwardInit(ctx);
auto handle = ctx.cudnn_handle();
auto place = ctx.GetPlace();
auto workspace_handle = ctx.cudnn_workspace_handle();
fwd_workspace_byte_ = fwd_op_.GetWorkspaceSizeInBytes(handle);
// Set variant_param
// input ptr
T *x_ptr = const_cast<T *>(x.data<T>());
T *x_scale_ptr = const_cast<T *>(x_scale.data<T>());
T *x_bias_ptr = const_cast<T *>(x_bias.data<T>());
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr);
if (has_shortcut_) {
T *z_ptr = const_cast<T *>(z.data<T>());
T *z_scale_ptr = const_cast<T *>(z_scale.data<T>());
T *z_bias_ptr = const_cast<T *>(z_bias.data<T>());
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr);
} else {
if (fused_add_) {
T *z_ptr = const_cast<T *>(z.data<T>());
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
}
}
......@@ -134,6 +142,8 @@ class CudnnScaleBiasAddRelu {
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &fwd_workspace_byte_);
// output ptr
T *out_ptr = out->mutable_data<T>(place);
int32_t *bitmask_ptr = bitmask->mutable_data<int32_t>(place);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, out_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr);
......@@ -147,16 +157,30 @@ class CudnnScaleBiasAddRelu {
fwd_workspace_byte_);
}
void Backward(const platform::CUDADeviceContext &ctx, T *dy_ptr, T *x_ptr,
float *scale_ptr, float *bias_ptr, float *saved_mean_ptr,
float *saved_invstd_ptr, int32_t *bitmask_ptr, T *dx_ptr,
T *dz_ptr, float *dscale_ptr, float *dbias_ptr, double eps) {
void Backward(const platform::CUDADeviceContext &ctx, const Tensor &dy,
const Tensor &x, const Tensor &scale, const Tensor &bias,
const Tensor &saved_mean, const Tensor &saved_invstd,
const Tensor &bitmask, Tensor *dx, Tensor *dz, Tensor *dscale,
Tensor *dbias, double eps) {
BackwardInit(ctx);
auto handle = ctx.cudnn_handle();
auto place = ctx.GetPlace();
auto workspace_handle = ctx.cudnn_workspace_handle();
bwd_workspace_byte_ = bwd_op_.GetWorkspaceSizeInBytes(handle);
// Set variant_param
// input ptr
T *dy_ptr = const_cast<T *>(dy.data<T>());
T *x_ptr = const_cast<T *>(x.data<T>());
float *scale_ptr = const_cast<float *>(scale.data<float>());
float *bias_ptr = const_cast<float *>(bias.data<float>());
float *saved_mean_ptr = const_cast<float *>(saved_mean.data<float>());
float *saved_invstd_ptr = const_cast<float *>(saved_invstd.data<float>());
int32_t *bitmask_ptr = const_cast<int32_t *>(bitmask.data<int32_t>());
T *dx_ptr = dx->mutable_data<T>(place);
T *dz_ptr = dz ? dz->mutable_data<T>(place) : nullptr;
float *dscale_ptr = dscale ? dscale->mutable_data<float>(place) : nullptr;
float *dbias_ptr = dbias ? dbias->mutable_data<float>(place) : nullptr;
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, dy_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册