Commit 5e813b53 authored by Jie Fang, committed by gongweibao

nhwc optimization for batchnorm (#21090)

Parent fce24315
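This change lets batch_norm run directly on NHWC tensors through cuDNN's spatial-persistent batch-norm kernels (cuDNN >= 7.4.1) instead of always transposing to NCHW. A minimal, hypothetical usage sketch follows; only fluid.layers.batch_norm, the NHWC data_layout, and the FLAGS_cudnn_batchnorm_spatial_persistent flag come from this patch, while the shapes and variable names are illustrative.

```python
import os
# The flag must be set before the program is built and run, so that both the
# Python layer (which then creates the ReserveSpace output) and the CUDA
# kernel (which then calls the cuDNN *Ex entry points) can see it.
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'

import paddle.fluid as fluid

# NHWC, float16 input: the combination this commit accelerates.
image = fluid.layers.data(name='image', shape=[224, 224, 3], dtype='float16')
out = fluid.layers.batch_norm(input=image, data_layout='NHWC')
```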
@@ -141,6 +141,10 @@ class GradOpDescMakerBase {
return (fwd_op_.Inputs().count(name) > 0);
}
bool HasOutput(const std::string& name) const {
return (fwd_op_.Outputs().count(name) > 0);
}
private:
const OpDesc& fwd_op_;
const std::unordered_set<std::string>& no_grad_set_;
......
@@ -107,6 +107,12 @@ class GradOpBaseMakerBase {
return it != var_base_map_in_.end();
}
bool HasOutput(const std::string name) const {
auto it = var_base_map_out_.find(name);
return it != var_base_map_out_.end();
}
private:
std::vector<std::shared_ptr<VarBase>> GetVarBaseList(const std::string& name,
bool is_grad,
......
@@ -25,27 +25,42 @@ namespace paddle {
namespace operators {
void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mean"),
"Input(Mean) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Variance"),
"Input(Variance) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of ConvOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Scale"), true,
platform::errors::InvalidArgument(
"Input(Scale) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Bias"), true,
platform::errors::InvalidArgument(
"Input(Bias) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Mean"), true,
platform::errors::InvalidArgument(
"Input(Mean) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Variance"), true,
platform::errors::InvalidArgument(
"Input(Variance) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true,
platform::errors::InvalidArgument(
"Output(Y) of BatchNormOp should not be null."));
bool is_test = ctx->Attrs().Get<bool>("is_test");
if (!is_test) {
PADDLE_ENFORCE(ctx->HasOutput("MeanOut"),
"Output(MeanOut) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("VarianceOut"),
"Output(VarianceOut) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("SavedMean"),
"Output(SavedMean) of ConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("SavedVariance"),
"Output(SavedVariance) of ConvOp should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasOutput("MeanOut"), true,
platform::errors::InvalidArgument(
"Output(MeanOut) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("VarianceOut"), true,
platform::errors::InvalidArgument(
"Output(VarianceOut) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("SavedMean"), true,
platform::errors::InvalidArgument(
"Output(SavedMean) of BatchNormOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("SavedVariance"), true,
platform::errors::InvalidArgument(
"Output(SavedVariance) of BatchNormOp should not be null."));
}
// make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
@@ -200,6 +215,10 @@ void BatchNormOpMaker::Make() {
"Variance of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
AddOutput("ReserveSpace",
"Reserve GPU space for triggering the new semi-persistent "
"NHWC kernel")
.AsDispensable();
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
@@ -643,6 +662,9 @@ std::unique_ptr<T> BatchNormGradMaker<T>::Apply() const {
op->SetInput("Bias", this->Input("Bias"));
op->SetInput("SavedMean", this->Output("SavedMean"));
op->SetInput("SavedVariance", this->Output("SavedVariance"));
if (this->HasOutput("ReserveSpace")) {
op->SetInput("ReserveSpace", this->Output("ReserveSpace"));
}
// used when setting use_global_stats True during training
if (boost::get<bool>(this->GetAttr("use_global_stats"))) {
......
@@ -56,12 +56,39 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
auto *y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace());
int N, C, H, W, D;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
auto dtype = platform::CudnnDataType<T>::type;
const bool fast_nhwc_batch_norm =
is_test ||
(dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent);
auto compute_format =
fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
? DataLayout::kNHWC
: DataLayout::kNCHW;
Tensor transformed_x(x->type());
Tensor transformed_y(y->type());
if (data_layout == DataLayout::kNHWC &&
compute_format == DataLayout::kNCHW && x_dims.size() > 2) {
VLOG(3) << "Transform input tensor from NHWC to NCHW.";
ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
&transformed_x);
TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
&transformed_x);
ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, y,
&transformed_y);
} else {
transformed_x.ShareDataWith(*x);
transformed_y.ShareDataWith(*y);
}
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
@@ -90,7 +117,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
VLOG(3) << "Setting descriptors.";
std::vector<int> dims;
std::vector<int> strides;
if (data_layout == DataLayout::kNCHW) {
if (compute_format == DataLayout::kNCHW) {
dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else {
@@ -126,8 +153,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
handle,
// Note: PERSISTENT not implemented for inference
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
CudnnDataType<T>::kZero(), data_desc_,
transformed_x.template data<T>(), data_desc_,
transformed_y.template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(),
est_mean->template data<BatchNormParamType<T>>(),
@@ -167,23 +195,102 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
} else {
double this_factor = 1. - momentum;
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
data_desc_, x->template data<T>(), data_desc_,
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), this_factor,
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
variance_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())));
bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
if (compute_format == DataLayout::kNHWC) {
called = true;
size_t workspace_size = 0;
size_t reserve_space_size = 0;
void *reserve_space_ptr = nullptr;
void *workspace_ptr = nullptr;
Tensor workspace_tensor;
// Create reserve space and workspace for batch norm.
// Create tensor for each batchnorm op, it will be used in the
// backward. Thus this tensor shouldn't be temp.
auto *reserve_space = ctx.Output<Tensor>("ReserveSpace");
PADDLE_ENFORCE_NOT_NULL(
reserve_space,
platform::errors::NotFound(
"The argument ReserveSpace of batch_norm op is not found."));
// --------------- cudnn batchnorm workspace ---------------
CUDNN_ENFORCE(
platform::dynload::
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
/*handle=*/handle,
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*xDesc=*/data_desc_,
/*zDesc=*/nullptr,
/*yDesc=*/data_desc_,
/*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
/*activationDesc=*/nullptr,
/*sizeInBytes=*/&workspace_size));
// -------------- cudnn batchnorm reserve space --------------
CUDNN_ENFORCE(
platform::dynload::
cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
/*handle=*/handle,
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*activationDesc=*/nullptr,
/*xDesc=*/data_desc_,
/*sizeInBytes=*/&reserve_space_size));
reserve_space_ptr = reserve_space->mutable_data(
ctx.GetPlace(), transformed_x.type(), reserve_space_size);
workspace_ptr = workspace_tensor.mutable_data(
ctx.GetPlace(), transformed_x.type(), workspace_size);
CUDNN_ENFORCE(
platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
handle, mode_, CUDNN_BATCHNORM_OPS_BN,
CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
data_desc_, transformed_x.template data<T>(), nullptr,
nullptr, data_desc_, transformed_y.template data<T>(),
bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), this_factor,
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
variance_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon,
saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
nullptr, workspace_ptr, workspace_size, reserve_space_ptr,
reserve_space_size));
}
#endif
if (!called) {
CUDNN_ENFORCE(
platform::dynload::cudnnBatchNormalizationForwardTraining(
handle, mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_,
transformed_x.template data<T>(), data_desc_,
transformed_y.template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), this_factor,
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
variance_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon,
saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())));
}
}
}
if (data_layout == DataLayout::kNHWC &&
compute_format == DataLayout::kNCHW && x_dims.size() > 2) {
VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
ctx, &transformed_y, y);
}
// clean when exit.
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
CUDNN_ENFORCE(
@@ -337,9 +444,41 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
auto dtype = platform::CudnnDataType<T>::type;
const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
const bool fast_nhwc_batch_norm =
dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent &&
reserve_space != nullptr;
auto compute_format =
fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
? DataLayout::kNHWC
: DataLayout::kNCHW;
Tensor transformed_x(x->type());
Tensor transformed_d_y(d_y->type());
Tensor transformed_d_x(d_x->type());
if (data_layout == DataLayout::kNHWC &&
compute_format == DataLayout::kNCHW) {
VLOG(3) << "Transform input tensor from NHWC to NCHW.";
ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
&transformed_x);
TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
&transformed_x);
ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
&transformed_d_y);
TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
&transformed_d_y);
ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_x,
&transformed_d_x);
} else {
transformed_x.ShareDataWith(*x);
transformed_d_y.ShareDataWith(*d_y);
transformed_d_x.ShareDataWith(*d_x);
}
std::vector<int> dims;
std::vector<int> strides;
if (data_layout == DataLayout::kNCHW) {
if (compute_format == DataLayout::kNCHW) {
dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else {
@@ -348,7 +487,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
const int num = x->numel();
const int num = transformed_x.numel();
const int block = 512;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
@@ -404,20 +543,95 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
saved_var->template data<BatchNormParamType<T>>();
if (d_scale && d_bias) {
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, d_y->template data<T>(), data_desc_,
d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean_data, saved_var_data));
bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
if (compute_format == DataLayout::kNHWC) {
called = true;
size_t workspace_size = 0;
void *workspace_ptr = nullptr;
Tensor workspace_tensor;
auto reserve_space_size = reserve_space->memory_size();
// --------------- cudnn batchnorm workspace ---------------
CUDNN_ENFORCE(platform::dynload::
cudnnGetBatchNormalizationBackwardExWorkspaceSize(
/*handle=*/dev_ctx.cudnn_handle(),
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*xDesc=*/data_desc_,
/*yDesc=*/data_desc_,
/*dyDesc=*/data_desc_,
/*dzDesc=*/nullptr,
/*dxDesc=*/data_desc_,
/*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
/*activationDesc=*/nullptr,
/*sizeInBytes=*/&workspace_size));
workspace_ptr = workspace_tensor.mutable_data(
ctx.GetPlace(), transformed_x.type(), workspace_size);
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackwardEx(
/*handle=*/dev_ctx.cudnn_handle(),
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*alphaDataDiff=*/CudnnDataType<T>::kOne(),
/*betaDataDiff=*/CudnnDataType<T>::kZero(),
/*alphaParamDiff=*/CudnnDataType<T>::kOne(),
/*betaParamDiff=*/CudnnDataType<T>::kZero(),
/*xDesc=*/data_desc_,
/*xData=*/transformed_x.template data<T>(),
/*yDesc=*/nullptr,
/*yData=*/nullptr,
/*dyDesc=*/data_desc_,
/*dyData=*/transformed_d_y.template data<T>(),
/*dzDesc=*/nullptr,
/*dzData=*/nullptr,
/*dxDesc=*/data_desc_,
/*dxData=*/transformed_d_x.template mutable_data<T>(
ctx.GetPlace()),
/*dBnScaleBiasDesc=*/bn_param_desc_,
/*bnScaleData=*/scale->template data<BatchNormParamType<T>>(),
/*bnBiasData=*/nullptr,
/*dBnScaleData=*/d_scale
->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
/*dBnBiasData=*/d_bias
->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
/*epsilon=*/epsilon,
/*savedMean=*/saved_mean_data,
/*savedInvVariance=*/saved_var_data,
/*activationDesc=*/nullptr,
/*workspace=*/workspace_ptr,
/*workSpaceSizeInBytes=*/workspace_size,
/*reserveSpace=*/const_cast<T *>(
reserve_space->template data<T>()),
/*reserveSpaceSizeInBytes=*/reserve_space_size));
}
#endif
if (!called) {
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_,
transformed_x.template data<T>(), data_desc_,
transformed_d_y.template data<T>(), data_desc_,
transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean_data, saved_var_data));
}
if (data_layout == DataLayout::kNHWC &&
compute_format == DataLayout::kNCHW) {
VLOG(3) << "Transform batchnorm output from NCHW to NHWC";
TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
ctx, &transformed_d_x, d_x);
}
} else {
if (data_layout == framework::DataLayout::kNCHW) {
if (compute_format == DataLayout::kNCHW) {
if (d_x) {
BNBackwardData<T, block, framework::DataLayout::kNCHW><<<
grid2, block, 0, dev_ctx.stream()>>>(
@@ -450,7 +664,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
const auto *running_var_data =
running_var->template data<BatchNormParamType<T>>();
if (data_layout == framework::DataLayout::kNCHW) {
if (compute_format == DataLayout::kNCHW) {
if (d_x) {
KeBNBackwardData<T, framework::DataLayout::kNCHW><<<
grid1, block, 0, dev_ctx.stream()>>>(
......
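The heart of the kernel changes above is the choice of compute_format: the new cuDNN Ex path is taken only when the data are NHWC and the fast path is allowed, otherwise the tensors are transposed to NCHW and the pre-existing path runs. Below is a small Python distillation of that decision; the function and argument names are illustrative, not part of the patch.

```python
def choose_compute_format(data_layout, dtype, is_test, spatial_persistent_flag,
                          backward=False, has_reserve_space=True):
    """Sketch of the compute_format selection in batch_norm_op.cu."""
    if backward:
        # Backward can take the NHWC path only if the forward pass stored a
        # ReserveSpace tensor and the fp16 persistent kernel is enabled.
        fast_nhwc = (dtype == 'float16' and spatial_persistent_flag
                     and has_reserve_space)
    else:
        # Forward: inference always tolerates NHWC; training needs fp16 + flag.
        fast_nhwc = is_test or (dtype == 'float16' and spatial_persistent_flag)
    return 'NHWC' if (fast_nhwc and data_layout == 'NHWC') else 'NCHW'

# fp16 NHWC training with the flag on stays in NHWC ...
assert choose_compute_format('NHWC', 'float16', False, True) == 'NHWC'
# ... while fp32 NHWC training falls back to a transpose into NCHW.
assert choose_compute_format('NHWC', 'float32', False, True) == 'NCHW'
```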
@@ -16,8 +16,10 @@ limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/norm_utils.h"
namespace paddle {
@@ -39,24 +41,109 @@ template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename DeviceContext, typename T>
inline void ResizeToChannelFirst(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[4];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
in_dims_vec[4] = input->dims()[3];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 2) {
// input
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[3];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
} else if (dim == 1) {
transformed_input->Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[2];
in_dims_vec[2] = input->dims()[1];
transformed_input->Resize(framework::make_ddim(in_dims_vec));
transformed_input->mutable_data<T>(context.GetPlace());
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelFirst(const framework::ExecutionContext& context,
const Tensor* input,
Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 4, 1, 2, 3};
math::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, transformed_input, axis);
} else if (dim == 2) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 3, 1, 2};
math::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, *input, transformed_input, axis);
} else if (dim == 1) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 1};
math::Transpose<DeviceContext, T, 3> trans3;
trans3(dev_ctx, *input, transformed_input, axis);
}
}
template <typename DeviceContext, typename T>
inline void TransToChannelLast(const framework::ExecutionContext& context,
const Tensor* input, Tensor* transformed_input) {
int dim = input->dims().size() - 2;
if (dim == 3) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 3, 4, 1};
math::Transpose<DeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, transformed_input, axis);
} else if (dim == 2) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 3, 1};
math::Transpose<DeviceContext, T, 4> trans4;
trans4(dev_ctx, *input, transformed_input, axis);
} else if (dim == 1) {
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis{0, 2, 1};
math::Transpose<DeviceContext, T, 3> trans3;
trans3(dev_ctx, *input, transformed_input, axis);
}
}
class BatchNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
const framework::ExecutionContext& ctx) const override;
};
class BatchNormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
const framework::ExecutionContext& ctx) const override;
};
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -85,13 +172,13 @@ class BatchNormOpInferVarType
template <typename DeviceContext, typename T>
class BatchNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class BatchNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
void Compute(const framework::ExecutionContext& ctx) const override;
};
} // namespace operators
......
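The three helpers added to batch_norm_op.h above (ResizeToChannelFirst, TransToChannelFirst, TransToChannelLast) only resize and transpose between channel-last and channel-first layouts. A NumPy sketch of the same axis permutations, given purely for illustration, makes the axis orders easy to verify.

```python
import numpy as np

# Axis orders used by TransToChannelFirst / TransToChannelLast, keyed by the
# number of spatial dimensions (tensor rank minus 2).
TO_CHANNEL_FIRST = {1: (0, 2, 1), 2: (0, 3, 1, 2), 3: (0, 4, 1, 2, 3)}
TO_CHANNEL_LAST = {1: (0, 2, 1), 2: (0, 2, 3, 1), 3: (0, 2, 3, 4, 1)}

x_nhwc = np.random.rand(8, 32, 32, 16).astype('float16')   # NHWC
x_nchw = np.transpose(x_nhwc, TO_CHANNEL_FIRST[2])          # -> NCHW
assert x_nchw.shape == (8, 16, 32, 32)
# Round-tripping back restores the original layout and values.
assert np.array_equal(np.transpose(x_nchw, TO_CHANNEL_LAST[2]), x_nhwc)
```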
@@ -46,6 +46,10 @@ CUDNN_DNN_ROUTINE_EACH_R6(DEFINE_WRAP);
CUDNN_DNN_ROUTINE_EACH_R7(DEFINE_WRAP);
#endif
#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R7
CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DEFINE_WRAP);
#endif
#ifdef PADDLE_USE_DSO
bool HasCUDNN() {
std::call_once(cudnn_dso_flag,
......
@@ -189,6 +189,15 @@ CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
#if CUDNN_VERSION >= 7401
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R7(__macro) \
__macro(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize); \
__macro(cudnnBatchNormalizationForwardTrainingEx); \
__macro(cudnnGetBatchNormalizationBackwardExWorkspaceSize); \
__macro(cudnnBatchNormalizationBackwardEx); \
__macro(cudnnGetBatchNormalizationTrainingExReserveSpaceSize);
CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
@@ -2523,6 +2523,13 @@ def batch_norm(input,
check_type_and_dtype(input, 'input', Variable,
['float16', 'float32', 'float64'], 'batch_norm')
dtype = helper.input_dtype()
has_reserve_space = False
if data_layout == 'NHWC':
flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent')
if flag is not None and flag.lower() in ['true', '1']:
has_reserve_space = True
# use fp32 for bn parameter
if dtype == core.VarDesc.VarType.FP16:
dtype = core.VarDesc.VarType.FP32
@@ -2577,6 +2584,11 @@
saved_variance = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
reserve_space = None
if has_reserve_space:
reserve_space = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.FP16, stop_gradient=True)
batch_norm_out = input if in_place else helper.create_variable_for_type_inference(
dtype)
@@ -2599,17 +2611,19 @@
inputs['MomemtumTensor'] = momentum
else:
attrs['momentum'] = momentum
outputs = {
"Y": batch_norm_out,
"MeanOut": mean_out,
"VarianceOut": variance_out,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
}
if reserve_space is not None:
outputs["ReserveSpace"] = reserve_space
helper.append_op(
type="batch_norm",
inputs=inputs,
outputs={
"Y": batch_norm_out,
"MeanOut": mean_out,
"VarianceOut": variance_out,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
},
attrs=attrs)
type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
return helper.append_activation(batch_norm_out)
......
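On the Python side, the ReserveSpace output is created only when data_layout is 'NHWC' and the spatial-persistent flag is enabled in the environment. A standalone sketch of that check (the helper name is illustrative, not from nn.py):

```python
import os

def _needs_reserve_space(data_layout):
    # Same environment check as the new code in fluid/layers/nn.py: only the
    # NHWC layout combined with FLAGS_cudnn_batchnorm_spatial_persistent
    # causes batch_norm to create the extra fp16 ReserveSpace output.
    if data_layout != 'NHWC':
        return False
    flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent')
    return flag is not None and flag.lower() in ['true', '1']

os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = 'true'
assert _needs_reserve_space('NHWC') is True
assert _needs_reserve_space('NCHW') is False
```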
@@ -14,6 +14,7 @@
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle.fluid.core as core
@@ -413,16 +414,28 @@ class TestBatchNormOpTraining(unittest.TestCase):
inputs['MomentumTensor'] = block.var('momentum_var')
else:
attrs['momentum'] = momentum
outputs = {
"Y": block.var('y'),
"MeanOut": block.var('mean'), # share memory
"VarianceOut": block.var('variance'), # share memory
"SavedMean": block.var('saved_mean'),
"SavedVariance": block.var('saved_variance')
}
has_reserve_space = False
if data_format == 'NHWC':
flag = os.environ.get(
'FLAGS_cudnn_batchnorm_spatial_persistent')
if flag is not None and flag.lower() in ['true', '1']:
has_reserve_space = True
if has_reserve_space:
block.create_var(name="reserve_space", dtype='float16')
outputs["ReserveSpace"] = block.var('reserve_space')
del os.environ['FLAGS_cudnn_batchnorm_spatial_persistent']
bn_op = block.append_op(
type="batch_norm",
inputs=inputs,
outputs={
"Y": block.var('y'),
"MeanOut": block.var('mean'), # share memory
"VarianceOut": block.var('variance'), # share memory
"SavedMean": block.var('saved_mean'),
"SavedVariance": block.var('saved_variance')
},
outputs=outputs,
attrs=attrs)
block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
@@ -479,6 +492,17 @@ class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining):
self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD']
class TestBatchNormOpTrainingCase2(TestBatchNormOpTraining):
def init_test_case(self):
self.use_global_stats = False
self.no_grad_set = set()
self.fetch_list = [
'y', 'mean', 'variance', 'saved_mean', 'saved_variance', 'x@GRAD',
'scale@GRAD', 'bias@GRAD'
]
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = "1"
class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining):
def init_test_case(self):
self.use_momentum_variable = True
......
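The new TestBatchNormOpTrainingCase2 turns the flag on through os.environ, and the test body deletes it again after the op is built. A hedged helper, not part of the patch, that keeps the flag scoped to a single test could look like this:

```python
import os
from contextlib import contextmanager

@contextmanager
def cudnn_spatial_persistent(enabled=True):
    """Temporarily set FLAGS_cudnn_batchnorm_spatial_persistent (sketch)."""
    key = 'FLAGS_cudnn_batchnorm_spatial_persistent'
    old = os.environ.get(key)
    os.environ[key] = '1' if enabled else '0'
    try:
        yield
    finally:
        # Restore whatever was there before so other tests are unaffected.
        if old is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = old

# Usage: build and run the NHWC batch_norm program inside the scope.
with cudnn_spatial_persistent():
    pass  # e.g. the TestBatchNormOpTrainingCase2 body would go here
```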