/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sync_batch_norm_op.cu.h"
namespace paddle {
namespace operators {

template <typename T>
K
Kaipeng Deng 已提交
21 22
class SyncBatchNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
Q
qingqing01 已提交
23 24 25 26 27 28 29 30
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    const float momentum = ctx.Attr<float>("momentum");
    const bool is_test = ctx.Attr<bool>("is_test");
    const std::string layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout layout = framework::StringToDataLayout(layout_str);
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
K
Kaipeng Deng 已提交
31 32 33 34 35
    PADDLE_ENFORCE_EQ(use_global_stats, false,
                      platform::errors::InvalidArgument(
                          "sync_batch_norm doesn't support "
                          "to set use_global_stats True. Please use batch_norm "
                          "in this case."));
Q
qingqing01 已提交
36 37 38 39

    const auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Y");

K
Kaipeng Deng 已提交
40 41
    const auto *est_mean = ctx.Input<Tensor>("Mean");
    const auto *est_var = ctx.Input<Tensor>("Variance");
Q
qingqing01 已提交
42

K
Kaipeng Deng 已提交
43 44 45
    // moving mean/variance
    auto *mean_out = ctx.Output<Tensor>("MeanOut");
    auto *variance_out = ctx.Output<Tensor>("VarianceOut");
Q
qingqing01 已提交
46

K
Kaipeng Deng 已提交
47 48
    auto *saved_mean = ctx.Output<Tensor>("SavedMean");
    auto *saved_inv_variance = ctx.Output<Tensor>("SavedVariance");
Q
qingqing01 已提交
49

K
Kaipeng Deng 已提交
50 51 52 53
    SyncBatchNormFunctor<platform::CUDADeviceContext, T>(
        ctx, layout, x, y, est_mean, est_var, mean_out, variance_out,
        saved_mean, saved_inv_variance, epsilon, momentum, is_test,
        use_global_stats);
Q
qingqing01 已提交
54 55 56
  }
};
// Backward CUDA kernel for the sync_batch_norm op.
//
// Collects the upstream gradient, the affine parameters and the statistics
// written by the forward kernel, then delegates the computation to
// SyncBatchNormGradFunctor (from sync_batch_norm_op.cu.h).
//
// Attributes read: epsilon, data_layout.
// Inputs read:     GradVarName("Y"), Scale, Bias, SavedMean, SavedVariance.
// Outputs:         GradVarName("X"), GradVarName("Scale"),
//                  GradVarName("Bias").
template <typename T>
class SyncBatchNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    // This kernel only handles CUDA places.
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));

    // The attribute is stored as float but the functor takes a double.
    const double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    const std::string layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout layout = framework::StringToDataLayout(layout_str);

    // Upstream gradient and affine parameters.
    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");

    // init output
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    // Statistics the forward kernel wrote to its SavedMean / SavedVariance
    // outputs. The local name suggests an inverse variance — TODO confirm.
    const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
    const auto *saved_inv_var = ctx.Input<Tensor>("SavedVariance");

    SyncBatchNormGradFunctor<platform::CUDADeviceContext, T>(
        ctx, layout, scale, bias, d_x, d_y, d_scale, d_bias, saved_mean,
        saved_inv_var, epsilon);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the forward and backward sync_batch_norm kernels on the CUDA
// device context for float, double and float16 element types.
REGISTER_OP_CUDA_KERNEL(
    sync_batch_norm, ops::SyncBatchNormKernel<plat::CUDADeviceContext, float>,
    ops::SyncBatchNormKernel<plat::CUDADeviceContext, double>,
    ops::SyncBatchNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    sync_batch_norm_grad,
    ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, float>,
    ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, double>,
    ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);