// sync_batch_norm_op.cu (4.2 KB). Lines 1-14 committed by qingqing01.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// Line 15 committed by Kaipeng Deng.
#include "paddle/fluid/operators/sync_batch_norm_op.cu.h"
// Lines 16-20 committed by qingqing01.

namespace paddle {
namespace operators {

template <typename T>
K
Kaipeng Deng 已提交
21 22
class SyncBatchNormKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
Q
qingqing01 已提交
23 24 25 26 27 28 29 30
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    const float momentum = ctx.Attr<float>("momentum");
    const bool is_test = ctx.Attr<bool>("is_test");
    const std::string layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout layout = framework::StringToDataLayout(layout_str);
    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
31
    const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
K
Kaipeng Deng 已提交
32 33 34 35 36
    PADDLE_ENFORCE_EQ(use_global_stats, false,
                      platform::errors::InvalidArgument(
                          "sync_batch_norm doesn't support "
                          "to set use_global_stats True. Please use batch_norm "
                          "in this case."));
Q
qingqing01 已提交
37 38 39 40

    const auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Y");

K
Kaipeng Deng 已提交
41 42
    const auto *est_mean = ctx.Input<Tensor>("Mean");
    const auto *est_var = ctx.Input<Tensor>("Variance");
Q
qingqing01 已提交
43

K
Kaipeng Deng 已提交
44 45 46
    // moving mean/variance
    auto *mean_out = ctx.Output<Tensor>("MeanOut");
    auto *variance_out = ctx.Output<Tensor>("VarianceOut");
Q
qingqing01 已提交
47

K
Kaipeng Deng 已提交
48 49
    auto *saved_mean = ctx.Output<Tensor>("SavedMean");
    auto *saved_inv_variance = ctx.Output<Tensor>("SavedVariance");
Q
qingqing01 已提交
50

51
    bool test_mode = is_test && (!trainable_stats);
K
Kaipeng Deng 已提交
52 53
    SyncBatchNormFunctor<platform::CUDADeviceContext, T>(
        ctx, layout, x, y, est_mean, est_var, mean_out, variance_out,
54
        saved_mean, saved_inv_variance, epsilon, momentum, test_mode,
K
Kaipeng Deng 已提交
55
        use_global_stats);
Q
qingqing01 已提交
56 57 58
  }
};

// Lines 59-61 committed by Kaipeng Deng.
/**
 * CUDA backward kernel of the `sync_batch_norm` operator.
 *
 * Partial specialization for platform::CUDADeviceContext. Collects the
 * attributes and tensors from the execution context and delegates the
 * gradient computation to SyncBatchNormGradFunctor.
 *
 * Attributes read: epsilon, data_layout.
 * Inputs:  gradient of Y (framework::GradVarName("Y")), Scale, Bias,
 *          SavedMean, SavedVariance.
 * Outputs: gradients of X, Scale and Bias (framework::GradVarName).
 *
 * Raises InvalidArgument (via PADDLE_ENFORCE_EQ) when the execution place
 * is not a GPU place.
 */
template <typename T>
class SyncBatchNormGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));
    // Read as float and widened to double before being passed on.
    const double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
    const std::string layout_str = ctx.Attr<std::string>("data_layout");
    const DataLayout layout = framework::StringToDataLayout(layout_str);

    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto *scale = ctx.Input<Tensor>("Scale");
    const auto *bias = ctx.Input<Tensor>("Bias");

    // Gradient output tensors, handed to the functor below.
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    // Named like the forward kernel's SavedMean/SavedVariance outputs.
    const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
    const auto *saved_inv_var = ctx.Input<Tensor>("SavedVariance");

    SyncBatchNormGradFunctor<platform::CUDADeviceContext, T>(
        ctx, layout, scale, bias, d_x, d_y, d_scale, d_bias, saved_mean,
        saved_inv_var, epsilon);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the forward and backward CUDA kernels of sync_batch_norm for the
// float, double and float16 element types.
REGISTER_OP_CUDA_KERNEL(
    sync_batch_norm, ops::SyncBatchNormKernel<plat::CUDADeviceContext, float>,
    ops::SyncBatchNormKernel<plat::CUDADeviceContext, double>,
    ops::SyncBatchNormKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    sync_batch_norm_grad,
    ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, float>,
    ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, double>,
    ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);

// clang-format on