average_accumulates_op.h 4.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

// Short name for the framework tensor type used by this kernel.
using Tensor = framework::Tensor;

// Local alias for framework::EigenVector, keeping the same default template
// arguments (row-major layout, Eigen::DenseIndex indices).
template <typename ElemT, int Layout = Eigen::RowMajor,
          typename IndexT = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<ElemT, Layout, IndexT>;

template <typename DeviceContext>
/**
 * Reads the step counters used by AverageAccumulatesKernel.
 *
 * Only declared here; the definition lives elsewhere (presumably specialized
 * per DeviceContext in the .cc/.cu sources — confirm there).
 *
 * @param ctx                  Kernel execution context.
 * @param num_updates          [out] Update counter; the kernel increments it
 *                             once per Compute() call.
 * @param num_accumulates      [out] Number of values accumulated in the
 *                             current averaging window.
 * @param old_num_accumulates  [out] Length of the most recently completed
 *                             averaging window.
 */
void getAccumulators(const framework::ExecutionContext& ctx,
                     int64_t& num_updates, int64_t& num_accumulates,
                     int64_t& old_num_accumulates);

template <typename DeviceContext>
/**
 * Writes back the step counters updated by AverageAccumulatesKernel.
 *
 * Only declared here; the definition lives elsewhere (presumably specialized
 * per DeviceContext in the .cc/.cu sources — confirm there). The kernel calls
 * it once at the end of each Compute().
 *
 * @param ctx                  Kernel execution context.
 * @param num_updates          Update counter after this step.
 * @param num_accumulates      Number of values accumulated in the current
 *                             window after this step (0 when a new window
 *                             has just been started).
 * @param old_num_accumulates  Length of the most recently completed window.
 */
void setAccumulators(const framework::ExecutionContext& ctx,
                     int64_t num_updates, int64_t num_accumulates,
                     int64_t old_num_accumulates);

/**
 * Kernel that accumulates parameter values for parameter averaging.
 *
 * Each Compute() adds `param` into a set of running sums and advances the
 * step counters (read via getAccumulators, written back via setAccumulators):
 *
 *   - sum_1 = in_sum_1 + param.
 *   - Every kMaxNumAccumulates updates, sum_1 is folded into sum_2 and sum_1
 *     is reset to zero, so no single buffer absorbs too many additions.
 *   - Once num_accumulates reaches the target window length, i.e.
 *       num_accumulates >= min_average_window and
 *       num_accumulates >= min(max_average_window,
 *                              num_updates * average_window),
 *     sum_3 is overwritten with sum_1 + sum_2, sum_1 and sum_2 are reset to
 *     zero, old_num_accumulates takes the current num_accumulates, and
 *     num_accumulates restarts from zero.
 *
 * Inputs:  param, in_sum_1, in_sum_2, in_sum_3.
 * Outputs: out_sum_1, out_sum_2, out_sum_3.
 * Attrs:   average_window (float), max_average_window (int64),
 *          min_average_window (int64; clamped to max_average_window).
 *
 * Each in_sum_k / out_sum_k pair may be the same variable (in-place update)
 * or two distinct variables; both cases produce the same sums.
 */
template <typename DeviceContext, typename T>
class AverageAccumulatesKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // sum_1 is folded into sum_2 this often, to avoid loss of precision from
    // adding too many values into a single buffer.
    constexpr int64_t kMaxNumAccumulates = 16384;

    // Get accumulators from input
    int64_t num_updates = 0;
    int64_t num_accumulates = 0;
    int64_t old_num_accumulates = 0;
    getAccumulators<DeviceContext>(ctx, num_updates, num_accumulates,
                                   old_num_accumulates);

    // Get attrs
    const float average_window = ctx.Attr<float>("average_window");
    const int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
    const int64_t min_average_window = std::min<int64_t>(
        ctx.Attr<int64_t>("min_average_window"), max_average_window);

    // Get inputs
    auto* param = ctx.Input<Tensor>("param");
    auto* in_sum_1 = ctx.Input<Tensor>("in_sum_1");
    auto* in_sum_2 = ctx.Input<Tensor>("in_sum_2");
    auto* in_sum_3 = ctx.Input<Tensor>("in_sum_3");
    auto param_tensor = EigenVector<T>::Flatten(*param);
    auto in_sum_1_tensor = EigenVector<T>::Flatten(*in_sum_1);
    auto in_sum_2_tensor = EigenVector<T>::Flatten(*in_sum_2);
    auto in_sum_3_tensor = EigenVector<T>::Flatten(*in_sum_3);

    // Get outputs. Allocate them before taking Eigen views so that outputs
    // that are distinct from the inputs have storage on this place; for
    // in-place use the existing buffer is expected to be reused — confirm
    // against framework::Tensor::mutable_data if the op's I/O binding changes.
    auto* out_sum_1 = ctx.Output<Tensor>("out_sum_1");
    auto* out_sum_2 = ctx.Output<Tensor>("out_sum_2");
    auto* out_sum_3 = ctx.Output<Tensor>("out_sum_3");
    out_sum_1->mutable_data<T>(ctx.GetPlace());
    out_sum_2->mutable_data<T>(ctx.GetPlace());
    out_sum_3->mutable_data<T>(ctx.GetPlace());
    auto out_sum_1_tensor = EigenVector<T>::Flatten(*out_sum_1);
    auto out_sum_2_tensor = EigenVector<T>::Flatten(*out_sum_2);
    auto out_sum_3_tensor = EigenVector<T>::Flatten(*out_sum_3);

    // Compute
    const auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto& place = *dev_ctx.eigen_device();
    math::SetConstant<DeviceContext, T> constant_functor;
    ++num_updates;
    ++num_accumulates;

    // Start from the input sums; this step's param is added into sum_1.
    out_sum_1_tensor.device(place) = in_sum_1_tensor + param_tensor;
    out_sum_2_tensor.device(place) = in_sum_2_tensor;
    out_sum_3_tensor.device(place) = in_sum_3_tensor;

    // From here on read only the out_sum_* buffers: they already contain this
    // step's param. Reading in_sum_* instead would silently drop the current
    // param whenever the outputs are not the same variables as the inputs.
    if (num_updates % kMaxNumAccumulates == 0) {
      // Move sum_1 into sum_2 to limit precision loss from too many adds.
      out_sum_2_tensor.device(place) = out_sum_2_tensor + out_sum_1_tensor;
      constant_functor(dev_ctx, out_sum_1, 0.0);
    }
    if (num_accumulates >= min_average_window &&
        num_accumulates >= std::min<int64_t>(max_average_window,
                                             num_updates * average_window)) {
      // The current window is long enough: its sum replaces sum_3 (the
      // previous window's sum is discarded) and a new window starts.
      out_sum_3_tensor.device(place) = out_sum_1_tensor + out_sum_2_tensor;
      constant_functor(dev_ctx, out_sum_1, 0.0);
      constant_functor(dev_ctx, out_sum_2, 0.0);
      old_num_accumulates = num_accumulates;
      num_accumulates = 0;
    }

    // Set accumulators to output
    setAccumulators<DeviceContext>(ctx, num_updates, num_accumulates,
                                   old_num_accumulates);
  }
};

}  // namespace operators
}  // namespace paddle