/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

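// GetAccumulators reads the three scalar counters out of the operator's state
// tensors into host integers, and SetAccumulators writes them back. Only the
// declarations live in this header; each device context is expected to supply
// its own definition (presumably the CPU one in average_accumulates_op.cc and
// the CUDA one in average_accumulates_op.cu, which must copy the scalars
// between device and host).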
template <typename DeviceContext>
void GetAccumulators(const framework::ExecutionContext& ctx,
                     int64_t& num_updates, int64_t& num_accumulates,
                     int64_t& old_num_accumulates);

template <typename DeviceContext>
void SetAccumulators(const framework::ExecutionContext& ctx,
                     int64_t num_updates, int64_t num_accumulates,
                     int64_t old_num_accumulates);

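// Maintains windowed running sums of a parameter across training updates so
// that a later pass can apply a moving average of the parameter (e.g., a
// model-averaging optimizer).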
template <typename DeviceContext, typename T>
class AverageAccumulatesKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // Cap on how many terms may be folded into a single partial sum before it
    // is flushed, to avoid loss of floating-point precision.
    static const int64_t kMaxNumAccumulates = 16384;
    // Get accumulators from input
    int64_t num_updates = 0;
    int64_t num_accumulates = 0;
    int64_t old_num_accumulates = 0;
    GetAccumulators<DeviceContext>(ctx, num_updates, num_accumulates,
                                   old_num_accumulates);

    // Get attrs
    float average_window = ctx.Attr<float>("average_window");
    int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
    int64_t min_average_window = ctx.Attr<int64_t>("min_average_window");
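    // Keep the window bounds consistent: min must not exceed max.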
    min_average_window =
        std::min<int64_t>(min_average_window, max_average_window);

    // Get inputs
    auto* param = ctx.Input<Tensor>("param");
    auto* in_sum_1 = ctx.Input<Tensor>("in_sum_1");
    auto* in_sum_2 = ctx.Input<Tensor>("in_sum_2");
    auto* in_sum_3 = ctx.Input<Tensor>("in_sum_3");
    auto param_tensor = EigenVector<T>::Flatten(*param);
    auto in_sum_1_tensor = EigenVector<T>::Flatten(*in_sum_1);
    auto in_sum_2_tensor = EigenVector<T>::Flatten(*in_sum_2);
    auto in_sum_3_tensor = EigenVector<T>::Flatten(*in_sum_3);

    // Get outputs
    auto* out_sum_1 = ctx.Output<Tensor>("out_sum_1");
    auto* out_sum_2 = ctx.Output<Tensor>("out_sum_2");
    auto* out_sum_3 = ctx.Output<Tensor>("out_sum_3");
    auto out_sum_1_tensor = EigenVector<T>::Flatten(*out_sum_1);
    auto out_sum_2_tensor = EigenVector<T>::Flatten(*out_sum_2);
    auto out_sum_3_tensor = EigenVector<T>::Flatten(*out_sum_3);

    // Compute
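    // The three sum buffers form a cascaded accumulator (see the branches
    // below):
    //   sum_1 -- running sum of the most recent parameter snapshots, flushed
    //            every kMaxNumAccumulates updates;
    //   sum_2 -- accumulated chunks flushed out of sum_1;
    //   sum_3 -- total of the previous, already-closed averaging window.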
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    math::SetConstant<DeviceContext, T> constant_functor;
    ++num_updates;
    ++num_accumulates;
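    // Default path: fold the current parameter into sum_1 and carry the other
    // buffers through unchanged.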
    out_sum_1_tensor.device(place) = in_sum_1_tensor + param_tensor;
    out_sum_2_tensor.device(place) = in_sum_2_tensor;
    out_sum_3_tensor.device(place) = in_sum_3_tensor;
    if (num_updates % kMaxNumAccumulates == 0) {
      // Move the sum to a different buffer to avoid loss of precision due to
      // too many sums.
      out_sum_2_tensor.device(place) = in_sum_2_tensor + in_sum_1_tensor;
      constant_functor(ctx.template device_context<DeviceContext>(), out_sum_1,
                       0.0);
    }
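
    // Close the window once enough snapshots have accumulated:
    // average_window acts as a ratio of window length to total updates,
    // clamped to [min_average_window, max_average_window].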
    if (num_accumulates >= min_average_window &&
        num_accumulates >= std::min<int64_t>(max_average_window,
                                             num_updates * average_window)) {
      // The averaging window has grown too long: archive the current window's
      // total into sum_3 (replacing the previous window) and start a new one.
      out_sum_3_tensor.device(place) = in_sum_1_tensor + in_sum_2_tensor;
      constant_functor(ctx.template device_context<DeviceContext>(), out_sum_1,
                       0.0);
      constant_functor(ctx.template device_context<DeviceContext>(), out_sum_2,
                       0.0);
      old_num_accumulates = num_accumulates;
      num_accumulates = 0;
    }

    // Set accumulators to output
    SetAccumulators<DeviceContext>(ctx, num_updates, num_accumulates,
                                   old_num_accumulates);
  }
};

}  // namespace operators
}  // namespace paddle
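
// Illustrative sketch (not part of this header): the same cascaded-sum idea on
// plain scalars, assuming a hypothetical value(i) that yields the i-th
// parameter snapshot.
//
//   float sum_1 = 0.f, sum_2 = 0.f;
//   for (int64_t i = 1; i <= n; ++i) {
//     sum_1 += value(i);
//     if (i % 16384 == 0) {  // kMaxNumAccumulates
//       sum_2 += sum_1;      // flush before the partial sum degrades
//       sum_1 = 0.f;
//     }
//   }
//
// A consumer of this op (e.g., a model-averaging pass) could then recover the
// averaged parameter roughly as
//   avg = (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates);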