/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

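// Reads the scalar accumulator states (num_updates, num_accumulates,
// old_num_accumulates) from the operator's inputs; definitions are
// provided per device outside this header.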
template <typename DeviceContext>
void GetAccumulators(const framework::ExecutionContext& ctx,
                     int64_t* num_updates, int64_t* num_accumulates,
                     int64_t* old_num_accumulates);
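// A minimal sketch of a CPU definition, assuming the operator declares
// scalar input tensors named "in_num_updates", "in_num_accumulates" and
// "in_old_num_accumulates" (the names are illustrative, not fixed by this
// header):
//
//   template <>
//   void GetAccumulators<platform::CPUDeviceContext>(
//       const framework::ExecutionContext& ctx, int64_t* num_updates,
//       int64_t* num_accumulates, int64_t* old_num_accumulates) {
//     *num_updates = ctx.Input<Tensor>("in_num_updates")->data<int64_t>()[0];
//     *num_accumulates =
//         ctx.Input<Tensor>("in_num_accumulates")->data<int64_t>()[0];
//     *old_num_accumulates =
//         ctx.Input<Tensor>("in_old_num_accumulates")->data<int64_t>()[0];
//   }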

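// Writes the updated accumulator states back to the operator's outputs;
// like GetAccumulators, it is defined per device outside this header.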
template <typename DeviceContext>
void SetAccumulators(const framework::ExecutionContext& ctx,
                     int64_t num_updates, int64_t num_accumulates,
                     int64_t old_num_accumulates);

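// Accumulation step of parameter averaging. The running sum of parameter
// values is split across three buffers so that an approximate sliding
// window can be maintained without storing every update:
//   sum_1 -- the partial sum currently being accumulated;
//   sum_2 -- older partial sums folded out of sum_1 every
//            kMaxNumAccumulates updates to limit floating-point error;
//   sum_3 -- the total of the previous window, kept so that an average is
//            still available right after the window restarts.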
template <typename DeviceContext, typename T>
class AverageAccumulatesKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Fold sum_1 into sum_2 every kMaxNumAccumulates updates to avoid the
    // loss of precision caused by adding small values into a large running
    // sum.
    static const int64_t kMaxNumAccumulates = 16384;
    // Get accumulators from input
    int64_t num_updates = 0;
    int64_t num_accumulates = 0;
    int64_t old_num_accumulates = 0;
    GetAccumulators<DeviceContext>(ctx, &num_updates, &num_accumulates,
                                   &old_num_accumulates);

    // Get attrs
    float average_window = ctx.Attr<float>("average_window");
    int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
    int64_t min_average_window = ctx.Attr<int64_t>("min_average_window");
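    // The effective window length grows with training: it is
    // num_updates * average_window, clamped to
    // [min_average_window, max_average_window].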
    PADDLE_ENFORCE_LE(min_average_window, max_average_window,
                      "min_average_window shouldn't be larger than "
                      "max_average_window");

    // Get inputs
    auto* param = ctx.Input<Tensor>("param");
    auto* in_sum_1 = ctx.Input<Tensor>("in_sum_1");
    auto* in_sum_2 = ctx.Input<Tensor>("in_sum_2");
    auto* in_sum_3 = ctx.Input<Tensor>("in_sum_3");
    auto param_tensor = EigenVector<T>::Flatten(*param);
    auto in_sum_1_tensor = EigenVector<T>::Flatten(*in_sum_1);
    auto in_sum_2_tensor = EigenVector<T>::Flatten(*in_sum_2);
    auto in_sum_3_tensor = EigenVector<T>::Flatten(*in_sum_3);

    // Get outputs
    auto* out_sum_1 = ctx.Output<Tensor>("out_sum_1");
    auto* out_sum_2 = ctx.Output<Tensor>("out_sum_2");
    auto* out_sum_3 = ctx.Output<Tensor>("out_sum_3");
    auto out_sum_1_tensor = EigenVector<T>::Flatten(*out_sum_1);
    auto out_sum_2_tensor = EigenVector<T>::Flatten(*out_sum_2);
    auto out_sum_3_tensor = EigenVector<T>::Flatten(*out_sum_3);

    // Compute
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    math::SetConstant<DeviceContext, T> constant_functor;
    ++num_updates;
    ++num_accumulates;
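    // Fold the current parameter values into the active partial sum; the
    // other two buffers are copied through and may be rewritten below.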
    out_sum_1_tensor.device(place) = in_sum_1_tensor + param_tensor;
    out_sum_2_tensor.device(place) = in_sum_2_tensor;
    out_sum_3_tensor.device(place) = in_sum_3_tensor;
    if (num_updates % kMaxNumAccumulates == 0) {
      // Move the partial sum into sum_2 to avoid the loss of precision that
      // comes from accumulating too many terms in one buffer.
      out_sum_2_tensor.device(place) = in_sum_2_tensor + in_sum_1_tensor;
      constant_functor(ctx.template device_context<DeviceContext>(), out_sum_1,
                       0.0);
    }
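    // Restart the window once enough updates have accumulated: at least
    // min_average_window of them, and no fewer than the effective window
    // length min(max_average_window, num_updates * average_window).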
    if (num_accumulates >= min_average_window &&
        num_accumulates >= std::min<int64_t>(max_average_window,
                                             num_updates * average_window)) {
      // The average window has grown too long; move the current sums into
      // sum_3, discarding the previous window, and restart accumulation.
      out_sum_3_tensor.device(place) = in_sum_1_tensor + in_sum_2_tensor;
      constant_functor(ctx.template device_context<DeviceContext>(), out_sum_1,
                       0.0);
      constant_functor(ctx.template device_context<DeviceContext>(), out_sum_2,
                       0.0);
      old_num_accumulates = num_accumulates;
      num_accumulates = 0;
    }

    // Set accumulators to output
    SetAccumulators<DeviceContext>(ctx, num_updates, num_accumulates,
                                   old_num_accumulates);
  }
};

}  // namespace operators
}  // namespace paddle