Unverified commit ebff68fa, authored by Yang Zhang, committed by GitHub

Add float16 support to `sync_batch_norm_op` (#19681)

* Add float16 support to `sync_batch_norm_op`

test=develop

* Add test for sync_bn with FP16 input

test=develop
Parent 039b9710
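In short, the CUDA kernels in this diff keep the tensor in FP16 but do all statistics and normalization arithmetic in the FP32 parameter type (`BatchNormParamType<T>`), casting the input up on load and the output back down on store. A minimal single-device NumPy sketch of that numerics, for orientation only (the function name and shapes are illustrative, not part of the commit):

import numpy as np

def sync_bn_forward_fp16(x_fp16, scale, bias, eps=1e-5):
    """Sketch: accumulate in FP32, store FP16 (single device).

    x_fp16: (N, C, H, W) float16 input; scale, bias: (C,) float32 parameters.
    """
    x = x_fp16.astype(np.float32)                      # cast up, like BatchNormParamType<T>
    mean = x.mean(axis=(0, 2, 3))                      # per-channel E[x]
    var = (x * x).mean(axis=(0, 2, 3)) - mean * mean   # E[x^2] - E[x]^2
    x_hat = (x - mean[None, :, None, None]) / np.sqrt(var[None, :, None, None] + eps)
    y = x_hat * scale[None, :, None, None] + bias[None, :, None, None]
    return y.astype(np.float16)                        # cast back, like static_cast<T>(y_i)

# usage
x = np.random.rand(2, 3, 4, 4).astype(np.float16)
y = sync_bn_forward_fp16(x, np.ones(3, np.float32), np.zeros(3, np.float32))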
@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+// clang-format off
 #include <algorithm>
 #include <cfloat>
+#include <cmath>
 #include <string>
 #include <vector>
 #include "cub/cub.cuh"
@@ -32,24 +34,27 @@ using Tensor = framework::Tensor;
 using DataLayout = framework::DataLayout;
 template <typename T>
 using CudnnDataType = platform::CudnnDataType<T>;
+template <typename T>
+using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
 
 template <typename T, int BlockDim, framework::DataLayout layout>
-__global__ void KeLocalStats(const T *x, int N, int M, int C, T *mean_var) {
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+__global__ void KeLocalStats(const T *x, int N, int M, int C,
+                             BatchNormParamType<T> *mean_var) {
+  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   for (int k = blockIdx.x; k < C; k += gridDim.x) {
-    T x_sum = 0;
-    T x2_sum = 0;
+    BatchNormParamType<T> x_sum = 0.;
+    BatchNormParamType<T> x2_sum = 0.;
     for (int i = threadIdx.x; i < N * M; i += BlockDim) {
       int id = layout == framework::DataLayout::kNCHW
                    ? (i / M) * C * M + k * M + i % M
                    : i * C + k;
-      T x_in = x[id];
+      auto x_in = static_cast<BatchNormParamType<T>>(x[id]);
      x_sum += x_in;
      x2_sum += x_in * x_in;
    }
    __syncthreads();
-    T out = BlockReduce(temp_storage).Reduce(x_sum, cub::Sum());
+    auto out = BlockReduce(temp_storage).Reduce(x_sum, cub::Sum());
    __syncthreads();
    if (threadIdx.x == 0) {
      mean_var[k] = out / (N * M);
@@ -61,22 +66,24 @@ __global__ void KeLocalStats(const T *x, int N, int M, int C,
     }
   }
   if (blockIdx.x == 0 && threadIdx.x == 0) {
-    mean_var[2 * C] = static_cast<T>(1.0);
+    mean_var[2 * C] = static_cast<BatchNormParamType<T>>(1.0);
   }
 }
 
 template <typename T>
-__global__ void KeSyncAndMovingStats(T *means, T *variances, T *num_dev,
-                                     const int C, const T momentum,
-                                     const double epsilon, T *sv_mean_data,
-                                     T *sv_inv_var_data, T *moving_means,
-                                     T *moving_variances) {
+__global__ void KeSyncAndMovingStats(
+    BatchNormParamType<T> *means, BatchNormParamType<T> *variances,
+    BatchNormParamType<T> *num_dev, const int C,
+    const BatchNormParamType<T> momentum, const double epsilon,
+    BatchNormParamType<T> *sv_mean_data, BatchNormParamType<T> *sv_inv_var_data,
+    BatchNormParamType<T> *moving_means,
+    BatchNormParamType<T> *moving_variances) {
   // sync stats across multi-devices
   int gid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
   for (int i = gid; i < C; i += stride) {
-    T mean = means[i] / (*num_dev);
-    T var = variances[i] / (*num_dev);
+    auto mean = means[i] / (*num_dev);
+    auto var = variances[i] / (*num_dev);
     var = var - mean * mean;
 
     // sync stats
@@ -92,15 +99,21 @@ __global__ void KeSyncAndMovingStats(
 }
 
 template <typename T, framework::DataLayout layout>
-static __global__ void KeNormAffine(const T *x, const T *scale, const T *bias,
-                                    const T *mean, const T *variance,
+static __global__ void KeNormAffine(const T *x,
+                                    const BatchNormParamType<T> *scale,
+                                    const BatchNormParamType<T> *bias,
+                                    const BatchNormParamType<T> *mean,
+                                    const BatchNormParamType<T> *variance,
                                     const double epsilon, const int C,
                                     const int M, const int num, T *y) {
   int gid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
   for (int i = gid; i < num; i += stride) {
     const int c = layout == framework::DataLayout::kNCHW ? (i / M) % C : i % C;
-    y[i] = (x[i] - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c];
+    auto x_i = static_cast<BatchNormParamType<T>>(x[i]);
+    auto y_i =
+        (x_i - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c];
+    y[i] = static_cast<T>(y_i);
   }
 }
@@ -128,14 +141,14 @@ class SyncBatchNormKernel : public framework::OpKernel<T> {
     int x_numel = x->numel();
     const T *x_d = x->data<T>();
-    const T *s_d = ctx.Input<Tensor>("Scale")->data<T>();
-    const T *b_d = ctx.Input<Tensor>("Bias")->data<T>();
+    const auto *s_d = ctx.Input<Tensor>("Scale")->data<BatchNormParamType<T>>();
+    const auto *b_d = ctx.Input<Tensor>("Bias")->data<BatchNormParamType<T>>();
 
     auto *y = ctx.Output<Tensor>("Y");
     T *y_d = y->mutable_data<T>(ctx.GetPlace());
 
-    const T *mean_data = nullptr;
-    const T *var_data = nullptr;
+    const BatchNormParamType<T> *mean_data = nullptr;
+    const BatchNormParamType<T> *var_data = nullptr;
 
     auto &dev_ctx = ctx.cuda_device_context();
     auto stream = dev_ctx.stream();
@@ -148,51 +161,53 @@ class SyncBatchNormKernel : public framework::OpKernel<T> {
     if (is_test) {
       const auto *est_mean = ctx.Input<Tensor>("Mean");
       const auto *est_var = ctx.Input<Tensor>("Variance");
-      mean_data = est_mean->data<T>();
-      var_data = est_var->data<T>();
+      mean_data = est_mean->data<BatchNormParamType<T>>();
+      var_data = est_var->data<BatchNormParamType<T>>();
     } else {
       // x, x^2, 1, here 1 is used to calc device num
       // device num also can be got from platform::DeviceContextPool
-      const int bytes = (C * 2 + 1) * sizeof(T);
+      const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType<T>);
       alloc_ptr = memory::Alloc(dev_ctx, bytes);
 
-      T *stats = reinterpret_cast<T *>(alloc_ptr->ptr());
+      auto *stats = reinterpret_cast<BatchNormParamType<T> *>(alloc_ptr->ptr());
       const int threads = 256;
       int grid = std::min(C, (max_threads + threads - 1) / threads);
       if (layout == framework::DataLayout::kNCHW) {
-        KeLocalStats<
-            T, threads,
-            framework::DataLayout::kNCHW><<<grid, threads, 0, stream>>>(
-            x_d, N, H * W * D, C, stats);
+        KeLocalStats<T, threads, framework::DataLayout::kNCHW>
+            <<<grid, threads, 0, stream>>>(x_d, N, H * W * D, C, stats);
       } else {
-        KeLocalStats<
-            T, threads,
-            framework::DataLayout::kNHWC><<<grid, threads, 0, stream>>>(
-            x_d, N, H * W * D, C, stats);
+        KeLocalStats<T, threads, framework::DataLayout::kNHWC>
+            <<<grid, threads, 0, stream>>>(x_d, N, H * W * D, C, stats);
       }
 
+      // moving mean/variance
+      auto *mean_out = ctx.Output<Tensor>("MeanOut");
+      auto *variance_out = ctx.Output<Tensor>("VarianceOut");
+      auto *est_mean_data =
+          mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+      auto *est_var_data =
+          variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+
+      auto *saved_mean = ctx.Output<Tensor>("SavedMean");
+      auto *saved_inv_variance = ctx.Output<Tensor>("SavedVariance");
+      auto *sv_mean_data =
+          saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+      auto *sv_inv_var_data =
+          saved_inv_variance->mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace());
+
       Tensor c_g_st;
-      T *c_g_st_d = c_g_st.mutable_data<T>({2 * C + 1}, platform::CPUPlace());
+      auto *c_g_st_d = c_g_st.mutable_data<BatchNormParamType<T>>(
+          {2 * C + 1}, platform::CPUPlace());
       auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
       memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0);
 
-      int dtype = platform::ToNCCLDataType(x->type());
+      int dtype = platform::ToNCCLDataType(mean_out->type());
       // In-place operation
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
           stats, stats, 2 * C + 1, static_cast<ncclDataType_t>(dtype), ncclSum,
           comm, stream));
 
-      // moving mean/variance
-      auto *mean_out = ctx.Output<Tensor>("MeanOut");
-      auto *variance_out = ctx.Output<Tensor>("VarianceOut");
-      T *est_mean_data = mean_out->mutable_data<T>(ctx.GetPlace());
-      T *est_var_data = variance_out->mutable_data<T>(ctx.GetPlace());
-
-      auto *saved_mean = ctx.Output<Tensor>("SavedMean");
-      auto *saved_inv_variance = ctx.Output<Tensor>("SavedVariance");
-      T *sv_mean_data = saved_mean->mutable_data<T>(ctx.GetPlace());
-      T *sv_inv_var_data = saved_inv_variance->mutable_data<T>(ctx.GetPlace());
-
       // Note, Input('Mean')/Input('Variance') share variable with
       // Output('MeanOut')/Output('VarianceOut')
       KeSyncAndMovingStats<T><<<(C + block - 1) / block, block, 0, stream>>>(
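As the kernels above show, the `stats` buffer that is built locally and then all-reduced has a fixed layout: per-channel E[x] in the first C slots, per-channel E[x^2] in the next C slots, and a constant 1 in the last slot, so the sum-all-reduce also yields the number of participating devices. A NumPy sketch of how the reduced buffer turns back into mean and variance, mirroring the visible part of `KeSyncAndMovingStats` (the all-reduce is simulated by summing per-device buffers; names are illustrative, not Paddle APIs):

import numpy as np

def combine_device_stats(per_device_stats, C):
    """Mimic ncclAllReduce(sum) over the packed (2*C + 1) stats buffer.

    Each entry of per_device_stats is a float32 array of shape (2*C + 1,):
    [E[x] per channel | E[x^2] per channel | 1.0].
    """
    stats = np.sum(per_device_stats, axis=0)       # what the ncclSum all-reduce produces
    num_dev = stats[2 * C]                         # the trailing 1s add up to the device count
    mean = stats[:C] / num_dev                     # average of per-device means
    var = stats[C:2 * C] / num_dev - mean * mean   # E[x^2] - E[x]^2
    return mean, var, num_dev

# usage: two devices, 3 channels
C = 3
dev0 = np.concatenate([np.full(C, 0.1), np.full(C, 0.5), [1.0]]).astype(np.float32)
dev1 = np.concatenate([np.full(C, 0.3), np.full(C, 0.7), [1.0]]).astype(np.float32)
print(combine_device_stats([dev0, dev1], C))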
@@ -205,39 +220,40 @@ class SyncBatchNormKernel : public framework::OpKernel<T> {
     int grid2 = (std::min(x_numel, max_threads) + block - 1) / block;
     if (layout == framework::DataLayout::kNCHW) {
-      KeNormAffine<T,
-                   framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
-          x_d, s_d, b_d, mean_data, var_data, epsilon, C, H * W * D, x_numel,
-          y_d);
+      KeNormAffine<T, framework::DataLayout::kNCHW>
+          <<<grid2, block, 0, stream>>>(x_d, s_d, b_d, mean_data, var_data,
+                                        epsilon, C, H * W * D, x_numel, y_d);
     } else {
-      KeNormAffine<T,
-                   framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
-          x_d, s_d, b_d, mean_data, var_data, epsilon, C, H * W * D, x_numel,
-          y_d);
+      KeNormAffine<T, framework::DataLayout::kNHWC>
+          <<<grid2, block, 0, stream>>>(x_d, s_d, b_d, mean_data, var_data,
+                                        epsilon, C, H * W * D, x_numel, y_d);
     }
   }
 };
 
 template <typename T, const int BlockDim, framework::DataLayout layout>
-__global__ void KeBackwardLocalStats(const T *dy, const T *x, const T *means,
-                                     int N, int M, int C, T *sum_dy_prod) {
-  typedef cub::BlockReduce<double, BlockDim> BlockReduce;
+__global__ void KeBackwardLocalStats(const T *dy, const T *x,
+                                     const BatchNormParamType<T> *means, int N,
+                                     int M, int C,
+                                     BatchNormParamType<T> *sum_dy_prod) {
+  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   for (int k = blockIdx.x; k < C; k += gridDim.x) {
-    T sum1 = 0;
-    T sum2 = 0;
-    T mean = means[k];
+    BatchNormParamType<T> sum1 = 0.;
+    BatchNormParamType<T> sum2 = 0.;
+    auto mean = means[k];
     for (int i = threadIdx.x; i < N * M; i += blockDim.x) {
       int id = layout == framework::DataLayout::kNCHW
                    ? (i / M) * C * M + k * M + i % M
                    : i * C + k;
-      T g = dy[id];
+      auto g = static_cast<BatchNormParamType<T>>(dy[id]);
       sum1 += g;
-      sum2 += g * (x[id] - mean);
+      auto x_i = static_cast<BatchNormParamType<T>>(x[id]);
+      sum2 += g * (x_i - mean);
     }
     __syncthreads();
-    T out = BlockReduce(temp_storage).Reduce(sum1, cub::Sum());
+    auto out = BlockReduce(temp_storage).Reduce(sum1, cub::Sum());
     __syncthreads();
     if (threadIdx.x == 0) {
       sum_dy_prod[k] = out;
@@ -249,72 +265,75 @@ __global__ void KeBackwardLocalStats(const T *dy, const T *x,
     }
   }
   if (blockIdx.x == 0 && threadIdx.x == 0) {
-    sum_dy_prod[2 * C] = static_cast<T>(1.0);
+    sum_dy_prod[2 * C] = 1.0;
   }
 }
 
 template <typename T, int BlockDim, framework::DataLayout layout>
-static __global__ void KeBNBackwardScaleBias(const T *dy, const T *x,
-                                             const T *mean,
-                                             const T *inv_variance,
-                                             const double epsilon, const int N,
-                                             const int C, const int HxW,
-                                             T *dscale, T *dbias) {
+static __global__ void KeBNBackwardScaleBias(
+    const T *dy, const T *x, const BatchNormParamType<T> *mean,
+    const BatchNormParamType<T> *inv_variance, const double epsilon,
+    const int N, const int C, const int HxW, BatchNormParamType<T> *dscale,
+    BatchNormParamType<T> *dbias) {
   const int outer_size = C;
   const int inner_size = N * HxW;
-  typedef cub::BlockReduce<double, BlockDim> BlockReduce;
+  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
   for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
-    T ds_sum = static_cast<T>(0);
-    T db_sum = static_cast<T>(0);
+    BatchNormParamType<T> ds_sum = 0.;
+    BatchNormParamType<T> db_sum = 0.;
 
-    T inv_var_i = inv_variance[i];
-    T mean_i = mean[i];
+    auto inv_var_i = inv_variance[i];
+    auto mean_i = mean[i];
     for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
       const int id = layout == framework::DataLayout::kNCHW
                          ? ((j / HxW) * C + i) * HxW + (j % HxW)
                          : j * outer_size + i;
-      ds_sum += dy[id] * (x[id] - mean_i);
-      db_sum += dy[id];
+      auto x_i = static_cast<BatchNormParamType<T>>(x[id]);
+      auto dy_i = static_cast<BatchNormParamType<T>>(dy[id]);
+      ds_sum += dy_i * (x_i - mean_i);
+      db_sum += dy_i;
    }
    __syncthreads();
-    double os = BlockReduce(temp_storage)
-                    .Reduce(static_cast<double>(ds_sum), cub::Sum());
+    auto os = BlockReduce(temp_storage).Reduce(ds_sum, cub::Sum());
    __syncthreads();
-    double ob = BlockReduce(temp_storage)
-                    .Reduce(static_cast<double>(db_sum), cub::Sum());
+    auto ob = BlockReduce(temp_storage).Reduce(db_sum, cub::Sum());
    __syncthreads();
    if (threadIdx.x == 0) {
-      dscale[i] = static_cast<T>(os * inv_var_i);
-      dbias[i] = static_cast<T>(ob);
+      dscale[i] = os * inv_var_i;
+      dbias[i] = ob;
    }
    __syncthreads();
  }
 }
 
 template <typename T, framework::DataLayout layout>
-static __global__ void KeBNBackwardData(const T *dy, const T *x, const T *beta,
-                                        const T *mean, const T *inv_variance,
-                                        const T *g_sum_dy,
-                                        const T *g_sum_dy_prod,
-                                        const T *num_dev, const double epsilon,
-                                        const int C, const int HxW,
-                                        const int num, T *dx) {
+static __global__ void KeBNBackwardData(
+    const T *dy, const T *x, const BatchNormParamType<T> *gamma,
+    const BatchNormParamType<T> *mean,
+    const BatchNormParamType<T> *inv_variance,
+    const BatchNormParamType<T> *g_sum_dy,
+    const BatchNormParamType<T> *g_sum_dy_prod,
+    const BatchNormParamType<T> *num_dev, const double epsilon, const int C,
+    const int HxW, const int num, T *dx) {
   int gid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
-  T scale = static_cast<T>(C) / num;
-  T dev_num = num_dev[0];
+  auto scale = static_cast<BatchNormParamType<T>>(C) / num;
+  auto dev_num = num_dev[0];
   for (int i = gid; i < num; i += stride) {
     const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
-    T inv_var = inv_variance[c];
-    T s_d = beta[c];
-    T gvar = -1.0 * (g_sum_dy_prod[c] / dev_num) * s_d * inv_var *
-             (inv_var * inv_var);
-    T gmean = -1.0 * (g_sum_dy[c] / dev_num) * s_d * inv_var;
-
-    dx[i] =
-        dy[i] * s_d * inv_var + gmean * scale + gvar * scale * (x[i] - mean[c]);
+    auto inv_var = inv_variance[c];
+    auto s_d = gamma[c];
+    auto gvar =
+        -((g_sum_dy_prod[c] / dev_num) * s_d * inv_var * (inv_var * inv_var));
+    auto gmean = -((g_sum_dy[c] / dev_num) * s_d * inv_var);
+
+    auto x_i = static_cast<BatchNormParamType<T>>(x[i]);
+    auto dy_i = static_cast<BatchNormParamType<T>>(dy[i]);
+    auto dx_i =
+        dy_i * s_d * inv_var + gmean * scale + gvar * scale * (x_i - mean[c]);
+    dx[i] = static_cast<T>(dx_i);
  }
 }
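For a single device (`num_dev == 1`, i.e. before the cross-device sums matter), the formulas in `KeBNBackwardScaleBias` and `KeBNBackwardData` above reduce, after a small algebraic rearrangement, to the usual batch-norm backward pass, again accumulated in FP32 with only the input gradient cast back to FP16. A hedged NumPy sketch under that single-device assumption (the function name and layout are illustrative):

import numpy as np

def sync_bn_backward_fp16(dy_fp16, x_fp16, gamma, mean, inv_var):
    """Sketch of the backward formulas above with dev_num == 1.

    dy_fp16, x_fp16: (N, C, H, W) float16; gamma, mean, inv_var: (C,) float32,
    with inv_var = 1 / sqrt(var + eps) as stored in SavedVariance.
    """
    dy = dy_fp16.astype(np.float32)
    x = x_fp16.astype(np.float32)
    N, C, H, W = x.shape
    m = N * H * W                                    # elements per channel

    x_c = x - mean[None, :, None, None]
    sum_dy = dy.sum(axis=(0, 2, 3))                  # g_sum_dy, and dbias
    sum_dy_x = (dy * x_c).sum(axis=(0, 2, 3))        # g_sum_dy_prod
    dscale = sum_dy_x * inv_var                      # KeBNBackwardScaleBias
    dbias = sum_dy

    g = gamma[None, :, None, None] * inv_var[None, :, None, None]
    dx = g * (dy
              - sum_dy[None, :, None, None] / m
              - x_c * (inv_var[None, :, None, None] ** 2)
              * sum_dy_x[None, :, None, None] / m)   # KeBNBackwardData, rearranged
    return dx.astype(np.float16), dscale, dbias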
@@ -348,8 +367,8 @@ class SyncBatchNormGradKernel : public framework::OpKernel<T> {
     d_x->mutable_data<T>(ctx.GetPlace());
     if (d_scale && d_bias) {
-      d_scale->mutable_data<T>(ctx.GetPlace());
-      d_bias->mutable_data<T>(ctx.GetPlace());
+      d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+      d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
     }
     PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
     PADDLE_ENFORCE_EQ(scale->dims()[0], C);
@@ -371,11 +390,13 @@ class SyncBatchNormGradKernel : public framework::OpKernel<T> {
     auto stream = dev_ctx.stream();
     auto *comm = dev_ctx.nccl_comm();
 
-    const T *saved_mean = ctx.Input<Tensor>("SavedMean")->data<T>();
-    const T *saved_inv_var = ctx.Input<Tensor>("SavedVariance")->data<T>();
-    const int bytes = (C * 2 + 1) * sizeof(T);
+    const auto *saved_mean =
+        ctx.Input<Tensor>("SavedMean")->data<BatchNormParamType<T>>();
+    const auto *saved_inv_var =
+        ctx.Input<Tensor>("SavedVariance")->data<BatchNormParamType<T>>();
+    const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType<T>);
     auto alloc_ptr = memory::Alloc(dev_ctx, bytes);
-    T *stats = reinterpret_cast<T *>(alloc_ptr->ptr());
+    auto *stats = reinterpret_cast<BatchNormParamType<T> *>(alloc_ptr->ptr());
 
     const int threads = 256;
     int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
@@ -384,17 +405,15 @@ class SyncBatchNormGradKernel : public framework::OpKernel<T> {
     int fsize = H * W * D;
     if (layout == framework::DataLayout::kNCHW) {
-      KeBackwardLocalStats<
-          T, threads,
-          framework::DataLayout::kNCHW><<<grid, threads, 0, stream>>>(
-          dy_d, x_d, saved_mean, N, fsize, C, stats);
+      KeBackwardLocalStats<T, threads, framework::DataLayout::kNCHW>
+          <<<grid, threads, 0, stream>>>(dy_d, x_d, saved_mean, N, fsize, C,
+                                         stats);
     } else {
-      KeBackwardLocalStats<
-          T, threads,
-          framework::DataLayout::kNHWC><<<grid, threads, 0, stream>>>(
-          dy_d, x_d, saved_mean, N, fsize, C, stats);
+      KeBackwardLocalStats<T, threads, framework::DataLayout::kNHWC>
+          <<<grid, threads, 0, stream>>>(dy_d, x_d, saved_mean, N, fsize, C,
+                                         stats);
     }
-    int dtype = platform::ToNCCLDataType(x->type());
+    int dtype = platform::ToNCCLDataType(scale->type());
     // In-place operation
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
         stats, stats, 2 * C + 1, static_cast<ncclDataType_t>(dtype), ncclSum,
@@ -404,33 +423,33 @@ class SyncBatchNormGradKernel : public framework::OpKernel<T> {
     int grid2 = (std::min(x_numel, max_threads) + block - 1) / block;
     if (layout == framework::DataLayout::kNCHW) {
       if (d_scale && d_bias) {
-        KeBNBackwardScaleBias<
-            T, threads,
-            framework::DataLayout::kNCHW><<<grid, threads, 0, stream>>>(
-            dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize,
-            d_scale->data<T>(), d_bias->data<T>());
+        KeBNBackwardScaleBias<T, threads, framework::DataLayout::kNCHW>
+            <<<grid, threads, 0, stream>>>(
+                dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize,
+                d_scale->data<BatchNormParamType<T>>(),
+                d_bias->data<BatchNormParamType<T>>());
       }
       if (d_x) {
-        KeBNBackwardData<
-            T, framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
-            dy_d, x_d, scale->data<T>(), saved_mean, saved_inv_var, stats,
-            stats + C, stats + 2 * C, epsilon, C, fsize, x->numel(),
-            d_x->data<T>());
+        KeBNBackwardData<T, framework::DataLayout::kNCHW>
+            <<<grid2, block, 0, stream>>>(
+                dy_d, x_d, scale->data<BatchNormParamType<T>>(), saved_mean,
+                saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C,
+                fsize, x->numel(), d_x->data<T>());
      }
    } else {
      if (d_scale && d_bias) {
-        KeBNBackwardScaleBias<
-            T, threads,
-            framework::DataLayout::kNHWC><<<grid, threads, 0, stream>>>(
-            dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize,
-            d_scale->data<T>(), d_bias->data<T>());
+        KeBNBackwardScaleBias<T, threads, framework::DataLayout::kNHWC>
+            <<<grid, threads, 0, stream>>>(
+                dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize,
+                d_scale->data<BatchNormParamType<T>>(),
+                d_bias->data<BatchNormParamType<T>>());
      }
      if (d_x) {
-        KeBNBackwardData<
-            T, framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
-            dy_d, x_d, scale->data<T>(), saved_mean, saved_inv_var, stats,
-            stats + C, stats + 2 * C, epsilon, C, fsize, x->numel(),
-            d_x->data<T>());
+        KeBNBackwardData<T, framework::DataLayout::kNHWC>
+            <<<grid2, block, 0, stream>>>(
+                dy_d, x_d, scale->data<BatchNormParamType<T>>(), saved_mean,
+                saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C,
+                fsize, x->numel(), d_x->data<T>());
      }
    }
  }
@@ -443,8 +462,12 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     sync_batch_norm, ops::SyncBatchNormKernel<plat::CUDADeviceContext, float>,
-    ops::SyncBatchNormKernel<plat::CUDADeviceContext, double>);
+    ops::SyncBatchNormKernel<plat::CUDADeviceContext, double>,
+    ops::SyncBatchNormKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     sync_batch_norm_grad,
     ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, float>,
-    ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, double>);
+    ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, double>,
+    ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
+
+// clang-format on
@@ -11,6 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+test for sync bachnorm op.
+for both FP64 and FP16 input.
+"""
 
 from __future__ import print_function
@@ -22,9 +26,24 @@ import paddle.fluid.core as core
 import paddle.fluid as fluid
 from paddle.fluid import compiler
+from op_test import OpTest
+
+
+def create_or_get_tensor(scope, var_name, var, place):
+    """Get tensor, if not found, create a new one."""
+    tensor = scope.var(var_name).get_tensor()
+    if var is not None:
+        assert isinstance(var, np.ndarray)
+        tensor.set_recursive_sequence_lengths([])
+        tensor.set(var, place)
+    return tensor
 
 
 class TestSyncBatchNormOpTraining(unittest.TestCase):
+    """sync_batch_norm op test."""
+
     def setUp(self):
+        """Setup."""
         #self.dtype = np.float32
         self.dtype = np.float64
         self.N = 32
@@ -32,17 +51,20 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
         self.H = 64
         self.W = 32
         self.dshape = [self.N, self.C, self.H, self.W]
+        self.atol = 1e-3
 
-    def build_program(self,
-                      place,
-                      layout,
-                      seed,
-                      sync_bn=False,
-                      only_forward=False):
+    def _build_program(self,
+                       place,
+                       layout,
+                       seed,
+                       sync_bn=False,
+                       only_forward=False):
+        """Build program."""
         main = fluid.Program()
         startup = fluid.Program()
         main.random_seed = seed
         startup.random_seed = seed
+        use_cudnn = self.dtype == np.float16
         with fluid.unique_name.guard():
             with fluid.program_guard(main, startup):
                 data = fluid.layers.data(
@@ -56,7 +78,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
                     filter_size=1,
                     param_attr=fluid.ParamAttr(name='conv2d_weight'),
                     bias_attr=False,
-                    use_cudnn=False)
+                    use_cudnn=use_cudnn)
                 bn = fluid.layers.batch_norm(
                     conv,
                     param_attr=fluid.ParamAttr(name='bn_scale'),
@@ -65,6 +87,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
                     moving_variance_name='bn_moving_variance',
                     data_layout=layout,
                     is_test=only_forward)
+                bn = fluid.layers.cast(bn, 'float64')
                 sigmoid = fluid.layers.sigmoid(bn)
                 out = fluid.layers.reduce_sum(sigmoid)
                 if not sync_bn:
@@ -74,12 +97,17 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
                     sgd_opt.backward(out)
         return main, startup, [out, conv, bn]
 
-    def compare(self, place, layout, only_forward):
+    def _compare(self, place, layout, only_forward):
+        """Compare results."""
         seed = 10
         os.environ['FLAGS_cudnn_deterministic'] = "1"
+        scope = core.Scope()
         data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2
+        data = create_or_get_tensor(scope, "input",
+                                    OpTest.np_dtype_to_fluid_dtype(data), place)
         # Single-GPU, N = 32 per GPU
-        main, startup, outs = self.build_program(place, layout, seed, False,
-                                                 only_forward)
+        main, startup, outs = self._build_program(place, layout, seed, False,
+                                                  only_forward)
         exe = fluid.Executor(place)
         exe.run(startup)
@@ -99,7 +127,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
         #####################################################################
         # Multi-GPUs, self.N / core.get_cuda_device_count() per GPU
         assert core.get_cuda_device_count() > 1
-        main, startup, outs = self.build_program(place, layout, seed, True,
-                                                 only_forward)
+        main, startup, outs = self._build_program(place, layout, seed, True,
+                                                  only_forward)
         exe = fluid.Executor(place)
         exe.run(startup)
@@ -133,27 +161,43 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
                 sync_bn_val = sync_bn_val[:bn_val.shape[0]]
             self.assertTrue(
                 np.allclose(
-                    bn_val, sync_bn_val, atol=1e-3),
+                    bn_val, sync_bn_val, atol=self.atol),
                 "Output (" + fetch_names[i] + ") has diff. \n" + "\nBN " +
                 str(bn_val) + "\n" + "Sync BN " + str(sync_bn_val))
 
     def test_train(self):
+        """Test training."""
         if not core.is_compiled_with_cuda():
             return
 
         places = [core.CUDAPlace(0)]
         for place in places:
             for layout in ["NCHW", "NHWC"]:
-                self.compare(place, layout, False)
+                self._compare(place, layout, False)
 
     def test_infer(self):
+        """Test inference."""
         if not core.is_compiled_with_cuda():
             return
 
         places = [core.CUDAPlace(0)]
         for place in places:
             for layout in ["NCHW", "NHWC"]:
-                self.compare(place, layout, True)
+                self._compare(place, layout, True)
+
+
+class TestFP16SyncBatchNormOpTraining(TestSyncBatchNormOpTraining):
+    """sync_batch_norm op test for FP16 input."""
+
+    def setUp(self):
+        """Setup."""
+        self.dtype = np.float16
+        self.N = 32
+        self.C = 16
+        self.H = 64
+        self.W = 32
+        self.dshape = [self.N, self.C, self.H, self.W]
+        self.atol = 1e-2
+
 
 if __name__ == '__main__':
......