未验证 提交 7e6c0cee 编写于 作者: Z Zhang Zheng 提交者: GitHub

Implement Fused BN + Add + Relu with cudnnFusedOps API. (#35955)

上级 91119271
......@@ -80,5 +80,6 @@ if (WITH_GPU OR WITH_ROCM)
endif()
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory)
cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory)
endif()
endif()
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <random>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h"
#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/float16.h"
DECLARE_bool(cudnn_batchnorm_spatial_persistent);
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace op = paddle::operators;
using Tensor = paddle::framework::Tensor;
USE_OP(batch_norm);
template <typename T>
void InitRandomTensor(const std::vector<int64_t> &dims,
framework::Tensor *cpu_out) {
T *cpu_out_ptr = cpu_out->mutable_data<T>(framework::make_ddim(dims),
platform::CPUPlace());
std::default_random_engine random(0);
std::uniform_real_distribution<float> dis(0.0, 1.0);
for (int i = 0; i < cpu_out->numel(); ++i) {
cpu_out_ptr[i] = static_cast<T>(dis(random));
}
}
template <typename T>
void InitConstantTensor(const std::vector<int64_t> &dims, T value,
framework::Tensor *cpu_out) {
T *cpu_out_ptr = cpu_out->mutable_data<T>(framework::make_ddim(dims),
platform::CPUPlace());
for (int i = 0; i < cpu_out->numel(); ++i) {
cpu_out_ptr[i] = value;
}
}
template <typename T>
void CheckOutput(std::string name, const framework::Tensor &cpu_res,
const framework::Tensor &cpu_base, float diff,
bool is_relative_atol = false) {
if (cpu_res.dims().size() == cpu_base.dims().size()) {
EXPECT_EQ(cpu_res.dims(), cpu_base.dims());
} else {
EXPECT_EQ(cpu_res.numel(), cpu_base.numel());
}
const T *cpu_res_ptr = cpu_res.data<T>();
const T *cpu_base_ptr = cpu_base.data<T>();
float max_diff = 0;
int index = 0;
for (int i = 0; i < cpu_res.numel(); ++i) {
float cur_diff;
if (is_relative_atol) {
cur_diff = static_cast<float>(
std::abs((cpu_res_ptr[i] - cpu_base_ptr[i]) / cpu_base_ptr[i]));
EXPECT_LT(static_cast<float>(std::abs((cpu_res_ptr[i] - cpu_base_ptr[i]) /
cpu_base_ptr[i])),
diff);
} else {
cur_diff = static_cast<float>(std::abs(cpu_res_ptr[i] - cpu_base_ptr[i]));
EXPECT_LT(static_cast<float>(std::abs(cpu_res_ptr[i] - cpu_base_ptr[i])),
diff);
}
if (cur_diff > max_diff) {
max_diff = cur_diff;
index = i;
}
}
std::string error_type = is_relative_atol ? "relative" : "absolute";
LOG(INFO) << "[" << name << "], The dims is [" << cpu_res.dims()
<< "], maximum " << error_type << " error is " << max_diff << ": "
<< cpu_res_ptr[index] << " vs " << cpu_base_ptr[index];
}
template <typename T>
void ComputeSumAndSquareSum(const framework::Tensor &cpu_x,
framework::Tensor *cpu_sum,
framework::Tensor *cpu_sum_of_square) {
// x is in NHWC format.
auto dims = cpu_x.dims();
int64_t c = dims[3];
const T *cpu_x_ptr = cpu_x.data<T>();
float *cpu_sum_ptr =
cpu_sum->mutable_data<float>({1, 1, 1, c}, platform::CPUPlace());
float *cpu_sum_square_ptr = cpu_sum_of_square->mutable_data<float>(
{1, 1, 1, c}, platform::CPUPlace());
for (int j = 0; j < c; ++j) {
float tmp_sum = 0.0f;
float tmp_sum_of_squares = 0.0f;
for (int i = 0; i < cpu_x.numel() / c; ++i) {
float tmp_x = static_cast<float>(cpu_x_ptr[i * c + j]);
tmp_sum += tmp_x;
tmp_sum_of_squares += tmp_x * tmp_x;
}
cpu_sum_ptr[j] = tmp_sum;
cpu_sum_square_ptr[j] = tmp_sum_of_squares;
}
}
// get paddle batchnorm op results as baseline
void ComputeBatchNormForward(const platform::CUDADeviceContext &ctx,
const Tensor &cpu_x, const Tensor &cpu_scale,
const Tensor &cpu_bias, Tensor *cpu_mean,
Tensor *cpu_var, Tensor *cpu_saved_mean,
Tensor *cpu_saved_var, Tensor *cpu_y,
Tensor *cpu_reserve_space) {
framework::Scope scope;
auto *x = scope.Var("X")->GetMutable<framework::LoDTensor>();
auto *scale = scope.Var("Scale")->GetMutable<framework::LoDTensor>();
auto *bias = scope.Var("Bias")->GetMutable<framework::LoDTensor>();
auto *mean = scope.Var("Mean")->GetMutable<framework::LoDTensor>();
auto *var = scope.Var("Variance")->GetMutable<framework::LoDTensor>();
auto *y = scope.Var("Y")->GetMutable<framework::LoDTensor>();
auto *saved_mean = scope.Var("SavedMean")->GetMutable<framework::LoDTensor>();
auto *saved_var =
scope.Var("SavedVariance")->GetMutable<framework::LoDTensor>();
auto *reserve_space =
scope.Var("ReserveSpace")->GetMutable<framework::LoDTensor>();
auto place = ctx.GetPlace();
TensorCopySync(cpu_x, place, x);
TensorCopySync(cpu_scale, place, scale);
TensorCopySync(cpu_bias, place, bias);
TensorCopySync(*cpu_mean, place, mean);
TensorCopySync(*cpu_var, place, var);
int64_t channels = x->dims()[3];
scale->Resize({channels});
bias->Resize({channels});
mean->Resize({channels});
var->Resize({channels});
framework::AttributeMap attrs;
std::string data_layout = "NHWC";
attrs.insert({"data_layout", data_layout});
auto op = framework::OpRegistry::CreateOp(
"batch_norm", {{"X", {"X"}},
{"Scale", {"Scale"}},
{"Bias", {"Bias"}},
{"Mean", {"Mean"}},
{"Variance", {"Variance"}}},
{{"Y", {"Y"}},
{"MeanOut", {"Mean"}},
{"VarianceOut", {"Variance"}},
{"SavedMean", {"SavedMean"}},
{"SavedVariance", {"SavedVariance"}},
{"ReserveSpace", {"ReserveSpace"}}},
attrs);
op->Run(scope, ctx.GetPlace());
TensorCopySync(*y, platform::CPUPlace(), cpu_y);
TensorCopySync(*mean, platform::CPUPlace(), cpu_mean);
TensorCopySync(*var, platform::CPUPlace(), cpu_var);
TensorCopySync(*saved_mean, platform::CPUPlace(), cpu_saved_mean);
TensorCopySync(*saved_var, platform::CPUPlace(), cpu_saved_var);
TensorCopySync(*reserve_space, platform::CPUPlace(), cpu_reserve_space);
}
template <typename T>
class CudnnBNAddReluTester {
public:
CudnnBNAddReluTester(int batch_size, int height, int width, int channels) {
batch_size_ = batch_size;
height_ = height;
width_ = width;
channels_ = channels;
ele_count_ = batch_size_ * height_ * width_;
SetUp();
}
~CudnnBNAddReluTester() {}
void CheckForward(float diff, bool is_relative_atol = false) {
platform::CUDADeviceContext *ctx =
static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(0)));
framework::Tensor cpu_mean_base;
framework::Tensor cpu_var_base;
framework::Tensor cpu_saved_mean_base;
framework::Tensor cpu_saved_var_base;
framework::Tensor cpu_y_base;
framework::Tensor cpu_reserve_space_base;
BaselineForward(*ctx, &cpu_mean_base, &cpu_var_base, &cpu_saved_mean_base,
&cpu_saved_var_base, &cpu_y_base, &cpu_reserve_space_base);
framework::Tensor cpu_mean;
framework::Tensor cpu_var;
framework::Tensor cpu_saved_mean;
framework::Tensor cpu_saved_var;
framework::Tensor cpu_y;
framework::Tensor cpu_bitmask;
FusedForward(*ctx, &cpu_mean, &cpu_var, &cpu_saved_mean, &cpu_saved_var,
&cpu_y, &cpu_bitmask);
CheckOutput<float>("Mean", cpu_mean, cpu_mean_base, diff, is_relative_atol);
CheckOutput<float>("Variance", cpu_var, cpu_var_base, diff,
is_relative_atol);
CheckOutput<float>("SavedMean", cpu_saved_mean, cpu_saved_mean_base, diff,
is_relative_atol);
CheckOutput<float>("SavedVariance", cpu_saved_var, cpu_saved_var_base, diff,
is_relative_atol);
CheckOutput<T>("Y", cpu_y, cpu_y_base, diff, is_relative_atol);
}
private:
void SetUp() {
// Initialize input data
InitRandomTensor<T>({batch_size_, height_, width_, channels_}, &cpu_x_);
ComputeSumAndSquareSum<T>(cpu_x_, &cpu_sum_, &cpu_sum_of_square_);
// scale and bias should be initialized randomly.
InitConstantTensor<float>({channels_}, static_cast<float>(1.0f),
&cpu_bn_scale_);
InitConstantTensor<float>({channels_}, static_cast<float>(0.0f),
&cpu_bn_bias_);
}
void InitMeanVar(Tensor *cpu_mean, Tensor *cpu_var, Tensor *cpu_saved_mean,
Tensor *cpu_saved_var) {
InitConstantTensor<float>({channels_}, static_cast<float>(0.0f), cpu_mean);
InitConstantTensor<float>({channels_}, static_cast<float>(1.0f), cpu_var);
InitConstantTensor<float>({channels_}, static_cast<float>(0.0f),
cpu_saved_mean);
InitConstantTensor<float>({channels_}, static_cast<float>(0.0f),
cpu_saved_var);
}
void BaselineForward(const platform::CUDADeviceContext &ctx, Tensor *cpu_mean,
Tensor *cpu_var, Tensor *cpu_saved_mean,
Tensor *cpu_saved_var, Tensor *cpu_y,
Tensor *cpu_reserve_space) {
InitMeanVar(cpu_mean, cpu_var, cpu_saved_mean, cpu_saved_var);
ComputeBatchNormForward(ctx, cpu_x_, cpu_bn_scale_, cpu_bn_bias_, cpu_mean,
cpu_var, cpu_saved_mean, cpu_saved_var, cpu_y,
cpu_reserve_space);
}
// Get forward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu
void FusedForward(const platform::CUDADeviceContext &ctx, Tensor *cpu_mean,
Tensor *cpu_var, Tensor *cpu_saved_mean,
Tensor *cpu_saved_var, Tensor *cpu_y, Tensor *cpu_bitmask) {
framework::Tensor x;
framework::Tensor sum;
framework::Tensor sum_of_square;
framework::Tensor bn_scale;
framework::Tensor bn_bias;
auto place = ctx.GetPlace();
TensorCopySync(cpu_x_, place, &x);
TensorCopySync(cpu_sum_, place, &sum);
TensorCopySync(cpu_sum_of_square_, place, &sum_of_square);
TensorCopySync(cpu_bn_scale_, place, &bn_scale);
TensorCopySync(cpu_bn_bias_, place, &bn_bias);
bn_scale.Resize({1, 1, 1, channels_});
bn_bias.Resize({1, 1, 1, channels_});
T *x_ptr = x.data<T>();
float *sum_ptr = sum.data<float>();
float *sum_of_square_ptr = sum_of_square.data<float>();
float *bn_scale_ptr = bn_scale.data<float>();
float *bn_bias_ptr = bn_bias.data<float>();
framework::Tensor mean;
framework::Tensor var;
framework::Tensor saved_mean;
framework::Tensor saved_var;
framework::Tensor equiv_scale;
framework::Tensor equiv_bias;
framework::Tensor y;
framework::Tensor bitmask;
InitMeanVar(cpu_mean, cpu_var, cpu_saved_mean, cpu_saved_var);
TensorCopySync(*cpu_mean, place, &mean);
TensorCopySync(*cpu_var, place, &var);
mean.Resize({1, 1, 1, channels_});
var.Resize({1, 1, 1, channels_});
float *mean_ptr = mean.data<float>();
float *var_ptr = var.data<float>();
float *saved_mean_ptr =
saved_mean.mutable_data<float>({1, 1, 1, channels_}, place);
float *saved_var_ptr =
saved_var.mutable_data<float>({1, 1, 1, channels_}, place);
T *equiv_scale_ptr =
equiv_scale.mutable_data<T>({1, 1, 1, channels_}, place);
T *equiv_bias_ptr = equiv_bias.mutable_data<T>({1, 1, 1, channels_}, place);
T *y_ptr =
y.mutable_data<T>({batch_size_, height_, width_, channels_}, place);
// bitmask
int c = channels_;
int64_t nhw = ele_count_;
int32_t c_int32_elems = ((c + 63) & ~63) / 32;
int32_t nhw_int32_elems = (nhw + 31) & ~31;
int32_t *bitmask_ptr = bitmask.mutable_data<int32_t>(
{nhw_int32_elems, c_int32_elems, 1}, place);
auto data_shape = framework::vectorize<int>(x.dims());
auto param_shape = framework::vectorize<int>(bn_scale.dims());
auto bitmask_shape = framework::vectorize<int>(bitmask.dims());
// 1. BN Stats Finalize
op::CudnnBNStatsFinalize<T> bn_op(ctx, param_shape);
bn_op.Forward(ctx, sum_ptr, sum_of_square_ptr, bn_scale_ptr, bn_bias_ptr,
saved_mean_ptr, saved_var_ptr, mean_ptr, var_ptr,
equiv_scale_ptr, equiv_bias_ptr, eps_, momentum_, ele_count_,
true);
// 2. Scale Bias + Relu (not fused add)
std::string act_type = "";
op::CudnnScaleBiasAddRelu<T> sbar_op(
ctx, act_type, false, false, data_shape, param_shape, bitmask_shape);
sbar_op.Forward(ctx, x_ptr, equiv_scale_ptr, equiv_bias_ptr, y_ptr,
bitmask_ptr);
TensorCopySync(mean, platform::CPUPlace(), cpu_mean);
TensorCopySync(var, platform::CPUPlace(), cpu_var);
TensorCopySync(saved_mean, platform::CPUPlace(), cpu_saved_mean);
TensorCopySync(saved_var, platform::CPUPlace(), cpu_saved_var);
TensorCopySync(y, platform::CPUPlace(), cpu_y);
TensorCopySync(bitmask, platform::CPUPlace(), cpu_bitmask);
}
private:
int batch_size_;
int height_;
int width_;
int channels_;
int ele_count_;
// Forward input
framework::Tensor cpu_x_;
framework::Tensor cpu_sum_;
framework::Tensor cpu_sum_of_square_;
framework::Tensor cpu_bn_scale_;
framework::Tensor cpu_bn_bias_;
double eps_ = 1e-5;
float momentum_ = 0.9;
};
TEST(CudnnBNAddReluForward, GPUCudnnBNAddReluForwardFp16) {
int batch_size = 4;
int height = 8;
int width = 8;
int channels = 64;
FLAGS_cudnn_batchnorm_spatial_persistent = true;
CudnnBNAddReluTester<paddle::platform::float16> test(batch_size, height,
width, channels);
test.CheckForward(2e-3);
}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
namespace dynload = platform::dynload;
template <typename T>
using BatchNormParamType =
typename platform::CudnnDataType<T>::BatchNormParamType;
#if CUDNN_VERSION >= 8000
template <typename T>
struct BNStatsFinalizeArgs {
BNStatsFinalizeArgs() {
dtype = platform::CudnnDataType<T>::type;
param_dtype = platform::CudnnDataType<BatchNormParamType<T>>::type;
format = CUDNN_TENSOR_NHWC;
}
void Set(const std::vector<int> &param_shape) {
PADDLE_ENFORCE_EQ(
param_shape.size(), 4U,
platform::errors::InvalidArgument(
"The size of param_shape is expected to 4. But recieved "
"param_shape's size is %d, param_shape is [%s].",
param_shape.size(), framework::make_ddim(param_shape)));
in_desc.set(param_shape, format, param_dtype);
out_desc.set(param_shape, format, dtype);
}
cudnnDataType_t dtype;
cudnnDataType_t param_dtype;
cudnnTensorFormat_t format;
platform::TensorDescriptor in_desc;
platform::TensorDescriptor out_desc;
};
template <typename T>
class CudnnBNStatsFinalize {
public:
CudnnBNStatsFinalize(const platform::CUDADeviceContext &ctx,
const std::vector<int> &param_shape)
: train_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_TRAINING),
inference_op_(CUDNN_FUSED_BN_FINALIZE_STATISTICS_INFERENCE) {
args_.Set(param_shape);
}
~CudnnBNStatsFinalize() {}
void Forward(const platform::CUDADeviceContext &ctx, float *sum_ptr,
float *sum_of_squares_ptr, float *scale_ptr, float *bias_ptr,
float *saved_mean_ptr, float *saved_invstd_ptr,
float *running_mean_ptr, float *running_var_ptr,
T *equiv_scale_ptr, T *equiv_bias_ptr, double eps,
float momentum, int64_t ele_count, bool is_train) {
if (is_train) {
TrainInit(ctx);
} else {
InferenceInit(ctx);
}
auto &op = is_train ? train_op_ : inference_op_;
// Set variant_param for both inference_op_ and train_op_
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_MEAN, running_mean_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_VAR, running_var_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, equiv_scale_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, equiv_bias_ptr);
op.SetOpVariantParamAttrPtr<double>(CUDNN_SCALAR_DOUBLE_BN_EPSILON, &eps);
// Set extra variant_param only for train_op_:
if (is_train) {
op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_MEAN, saved_mean_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_INVSTD, saved_invstd_ptr);
double avg_factor = 1.0 - momentum;
op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_INT64_T_BN_ACCUMULATION_COUNT,
&ele_count);
op.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EXP_AVG_FACTOR,
&avg_factor);
}
// fused op execute
auto handle = ctx.cudnn_handle();
op.Execute(handle);
}
private:
void TrainInit(const platform::CUDADeviceContext &ctx) {
// Set constant_param for train op
train_op_.SetOpConstParamAttr(
{CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER,
CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER,
CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER,
CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER,
CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER,
CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER,
CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
// Set input and output desc for train op
train_op_.SetOpConstParamDesc(
{CUDNN_PARAM_YSTATS_DESC, CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC},
args_.in_desc.desc());
train_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC,
args_.out_desc.desc());
// Get workspace
auto handle = ctx.cudnn_handle();
train_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
// Check workspace size, also creates plan.
size_t workspace_size_bytes = train_op_.GetWorkspaceSizeInBytes(handle);
PADDLE_ENFORCE_EQ(workspace_size_bytes, 0U,
platform::errors::InvalidArgument(
"Unexpected non-zero workspace size for "
"CudnnBNStatsFinalize."));
train_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
static_cast<void *>(nullptr));
train_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
&workspace_size_bytes);
}
void InferenceInit(const platform::CUDADeviceContext &ctx) {
// Set constant_param for inference op
inference_op_.SetOpConstParamAttr(
{CUDNN_PARAM_BN_SCALE_PLACEHOLDER, CUDNN_PARAM_BN_BIAS_PLACEHOLDER,
CUDNN_PARAM_BN_RUNNING_MEAN_PLACEHOLDER,
CUDNN_PARAM_BN_RUNNING_VAR_PLACEHOLDER,
CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER, CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
// Set input and output desc for inference op
inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC,
args_.in_desc.desc());
inference_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC,
args_.out_desc.desc());
// Get workspace
auto handle = ctx.cudnn_handle();
inference_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
// Check workspace size, also creates plan.
size_t workspace_size_bytes = inference_op_.GetWorkspaceSizeInBytes(handle);
PADDLE_ENFORCE_EQ(workspace_size_bytes, 0U,
platform::errors::InvalidArgument(
"Unexpected non-zero workspace size for "
"CudnnBNStatsFinalize."));
inference_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
static_cast<void *>(nullptr));
inference_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
&workspace_size_bytes);
}
BNStatsFinalizeArgs<T> args_;
CudnnFusionOp train_op_;
CudnnFusionOp inference_op_;
};
#endif
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
namespace dynload = platform::dynload;
template <typename T>
using BatchNormParamType =
typename platform::CudnnDataType<T>::BatchNormParamType;
#if CUDNN_VERSION >= 8000
template <typename T>
struct ScaleBiasAddReluArgs {
ScaleBiasAddReluArgs() {
dtype = platform::CudnnDataType<T>::type;
param_dtype = platform::CudnnDataType<BatchNormParamType<T>>::type;
format = CUDNN_TENSOR_NHWC;
}
void Set(const std::string &act_type, const std::vector<int> &data_shape,
const std::vector<int> &param_shape,
const std::vector<int> &bitmask_shape) {
PADDLE_ENFORCE_EQ(
data_shape.size(), 4U,
platform::errors::InvalidArgument(
"The size of data_shape is expected to 4. But recieved "
"data_shape's size is %d, data_shape is [%s].",
data_shape.size(), framework::make_ddim(data_shape)));
PADDLE_ENFORCE_EQ(
param_shape.size(), 4U,
platform::errors::InvalidArgument(
"The size of param_shape is expected to 4. But recieved "
"param_shape's size is %d, param_shape is [%s].",
param_shape.size(), framework::make_ddim(param_shape)));
PADDLE_ENFORCE_EQ(
bitmask_shape.size(), 3U,
platform::errors::InvalidArgument(
"The size of bitmask_shape is expected to 3. But recieved "
"bitmask_shape's size is %d, bitmask_shape is [%s].",
bitmask_shape.size(), framework::make_ddim(bitmask_shape)));
in_desc.set(data_shape, format, dtype);
out_desc.set(data_shape, format, dtype);
equiv_scale_bias_desc.set(param_shape, format, dtype);
scale_bias_mean_var_desc.set(param_shape, format, param_dtype);
bitmask_desc.set(bitmask_shape, format, CUDNN_DATA_INT32);
// set activation desc
cudnnActivationMode_t mode = CUDNN_ACTIVATION_IDENTITY;
if (act_type != "") {
PADDLE_ENFORCE_EQ(
act_type, "relu",
platform::errors::InvalidArgument(
"Only relu activation supported in normalized convolution."));
mode = CUDNN_ACTIVATION_RELU;
}
double dummy_clip = 0.0;
activation_desc.set(mode, dummy_clip);
}
cudnnDataType_t dtype;
cudnnDataType_t param_dtype;
cudnnTensorFormat_t format;
platform::TensorDescriptor in_desc;
platform::TensorDescriptor out_desc;
platform::TensorDescriptor equiv_scale_bias_desc;
platform::TensorDescriptor scale_bias_mean_var_desc;
platform::TensorDescriptor bitmask_desc;
platform::ActivationDescriptor activation_desc;
};
template <typename T>
class CudnnScaleBiasAddRelu {
public:
CudnnScaleBiasAddRelu(const platform::CUDADeviceContext &ctx,
const std::string &act_type, bool fused_add,
bool has_shortcut, const std::vector<int> &data_shape,
const std::vector<int> &param_shape,
const std::vector<int> &bitmask_shape)
: fwd_op_(CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK),
bwd_op_(CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM) {
fused_add_ = fused_add;
has_shortcut_ = has_shortcut;
args_.Set(act_type, data_shape, param_shape, bitmask_shape);
}
~CudnnScaleBiasAddRelu() {}
void Forward(const platform::CUDADeviceContext &ctx, T *x_ptr, T *x_scale_ptr,
T *x_bias_ptr, T *out_ptr, int32_t *bitmask_ptr,
T *z_ptr = nullptr, T *z_scale_ptr = nullptr,
T *z_bias_ptr = nullptr) {
ForwardInit(ctx);
auto handle = ctx.cudnn_handle();
auto workspace_handle = ctx.cudnn_workspace_handle();
fwd_workspace_byte_ = fwd_op_.GetWorkspaceSizeInBytes(handle);
// Set variant_param
// input ptr
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr);
if (has_shortcut_) {
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr);
} else {
if (fused_add_) {
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
}
}
fwd_op_.SetOpVariantParamAttrPtr(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &fwd_workspace_byte_);
// output ptr
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, out_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr);
workspace_handle.RunFunc(
[&](void *workspace_ptr) {
// workspace ptr
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
// workspace ptr
fwd_op_.Execute(handle);
},
fwd_workspace_byte_);
}
void Backward(const platform::CUDADeviceContext &ctx, T *dy_ptr, T *x_ptr,
float *scale_ptr, float *bias_ptr, float *saved_mean_ptr,
float *saved_invstd_ptr, int32_t *bitmask_ptr, T *dx_ptr,
T *dz_ptr, float *dscale_ptr, float *dbias_ptr, double eps) {
BackwardInit(ctx);
auto handle = ctx.cudnn_handle();
auto workspace_handle = ctx.cudnn_workspace_handle();
bwd_workspace_byte_ = bwd_op_.GetWorkspaceSizeInBytes(handle);
// Set variant_param
// input ptr
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, x_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, dy_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_MEAN, saved_mean_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SAVED_INVSTD,
saved_invstd_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ACTIVATION_BITMASK, bitmask_ptr);
bwd_op_.SetOpVariantParamAttrPtr(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &bwd_workspace_byte_);
// output ptr
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DXDATA, dx_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DSCALE, dscale_ptr);
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DBIAS, dbias_ptr);
bwd_op_.SetOpVariantParamAttrPtr<double>(CUDNN_SCALAR_DOUBLE_BN_EPSILON,
&eps);
if (has_shortcut_ || fused_add_) {
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DZDATA, dz_ptr);
}
workspace_handle.RunFunc(
[&](void *workspace_ptr) {
// workspace ptr
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
// workspace ptr
bwd_op_.Execute(handle);
},
bwd_workspace_byte_);
}
private:
void ForwardInit(const platform::CUDADeviceContext &ctx) {
// Set constant_param
fwd_op_.SetOpConstParamAttr(
{CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_BN_EQSCALE_PLACEHOLDER,
CUDNN_PARAM_BN_EQBIAS_PLACEHOLDER, CUDNN_PARAM_YDATA_PLACEHOLDER,
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
if (has_shortcut_) {
fwd_op_.SetOpConstParamAttr(
{CUDNN_PARAM_ZDATA_PLACEHOLDER, CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER,
CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
} else if (fused_add_) {
fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_ZDATA_PLACEHOLDER,
CUDNN_PTR_16B_ALIGNED);
}
// input desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
if (has_shortcut_ || fused_add_) {
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ZDESC, args_.in_desc.desc());
}
// equiv scale/bias desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_EQSCALEBIAS_DESC,
args_.equiv_scale_bias_desc.desc());
if (has_shortcut_) {
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_Z_EQSCALEBIAS_DESC,
args_.equiv_scale_bias_desc.desc());
}
// output desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_YDESC, args_.out_desc.desc());
// bitmask desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_BITMASK_DESC,
args_.bitmask_desc.desc());
// activation desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_DESC,
args_.activation_desc.desc());
// others
fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
}
void BackwardInit(const platform::CUDADeviceContext &ctx) {
// Set constant_param
bwd_op_.SetOpConstParamAttr(
{CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_DYDATA_PLACEHOLDER,
CUDNN_PARAM_DXDATA_PLACEHOLDER, CUDNN_PARAM_BN_SCALE_PLACEHOLDER,
CUDNN_PARAM_BN_BIAS_PLACEHOLDER, CUDNN_PARAM_BN_SAVED_MEAN_PLACEHOLDER,
CUDNN_PARAM_BN_SAVED_INVSTD_PLACEHOLDER,
CUDNN_PARAM_BN_DSCALE_PLACEHOLDER, CUDNN_PARAM_BN_DBIAS_PLACEHOLDER,
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
if (has_shortcut_ || fused_add_) {
bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_DZDATA_PLACEHOLDER,
CUDNN_PTR_16B_ALIGNED);
}
// input desc
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DXDESC, args_.in_desc.desc());
if (has_shortcut_ || fused_add_) {
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DZDESC, args_.in_desc.desc());
}
// scale/bias/mean/var desc for backward
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_BN_SCALEBIAS_MEANVAR_DESC,
args_.scale_bias_mean_var_desc.desc());
// output desc
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DYDESC, args_.out_desc.desc());
// bitmask desc
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_BITMASK_DESC,
args_.bitmask_desc.desc());
// activation desc
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ACTIVATION_DESC,
args_.activation_desc.desc());
// others
bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
}
bool fused_add_ = false;
bool has_shortcut_ = false;
size_t fwd_workspace_byte_;
size_t bwd_workspace_byte_;
ScaleBiasAddReluArgs<T> args_;
CudnnFusionOp fwd_op_;
CudnnFusionOp bwd_op_;
};
#endif
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册