sum_op.h 6.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
Y
Yi Wang 已提交
13 14 15 16 17
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
18 19 20 21 22

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
Q
QI JUN 已提交
23 24
using SelectedRows = framework::SelectedRows;
using LoDTensor = framework::LoDTensor;
25 26 27 28
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

Q
QI JUN 已提交
29
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
30
class SumKernel : public framework::OpKernel<T> {
31
 public:
32
  void Compute(const framework::ExecutionContext &context) const override {
Y
Yu Yang 已提交
33
    auto in_vars = context.MultiInputVar("X");
Q
QI JUN 已提交
34 35 36
    int N = in_vars.size();
    auto out_var = context.OutputVar("Out");

Y
Yu Yang 已提交
37 38
    bool in_place = out_var == in_vars[0];

Q
QI JUN 已提交
39
    if (out_var->IsType<framework::LoDTensor>()) {
Y
Update  
Yang Yu 已提交
40
      auto *out = context.Output<LoDTensor>("Out");
Y
Yu Yang 已提交
41
      if (!in_place) {
Y
Refine  
Yang Yu 已提交
42
        out->mutable_data<T>(context.GetPlace());
Y
Update  
Yang Yu 已提交
43 44 45
      }
      auto result = EigenVector<T>::Flatten(*out);
      if (!in_place) {
Q
QI JUN 已提交
46 47 48
        math::SetConstant<DeviceContext, T> constant_functor;
        constant_functor(context.template device_context<DeviceContext>(), out,
                         0.0);
Y
Yu Yang 已提交
49
      }
Q
QI JUN 已提交
50

Q
QI JUN 已提交
51 52 53
      math::SelectedRowsAddToTensor<DeviceContext, T> functor;
      auto &place =
          *context.template device_context<DeviceContext>().eigen_device();
Y
Yu Yang 已提交
54 55
      // If in_place, just skip the first tensor
      for (int i = in_place ? 1 : 0; i < N; i++) {
Q
QI JUN 已提交
56
        if (in_vars[i]->IsType<framework::LoDTensor>()) {
57
          auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
58 59 60
          if (in_t.numel() == 0) {
            continue;
          }
Q
QI JUN 已提交
61 62 63
          auto in = EigenVector<T>::Flatten(in_t);
          result.device(place) = result + in;
        } else if (in_vars[i]->IsType<framework::SelectedRows>()) {
64
          auto &in_t = in_vars[i]->Get<framework::SelectedRows>();
Q
QI JUN 已提交
65
          functor(context.template device_context<DeviceContext>(), in_t, out);
Q
QI JUN 已提交
66 67 68 69 70
        } else {
          PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
        }
      }
    } else if (out_var->IsType<framework::SelectedRows>()) {
Y
Yang Yu 已提交
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
      std::unique_ptr<framework::SelectedRows> in0;
      if (in_place) {
        // If is in_place, we store the input[0] to in0
        auto &in_sel0 = in_vars[0]->Get<SelectedRows>();
        auto &rows = in_sel0.rows();
#ifdef PADDLE_WITH_CUDA
        std::vector<int64_t> rows_in_cpu;
        rows_in_cpu.reserve(rows.size());
        for (auto item : rows) {
          rows_in_cpu.push_back(item);
        }
        in0.reset(new framework::SelectedRows(rows_in_cpu, in_sel0.height()));
#else
        in0.reset(new framework::SelectedRows(rows, in_sel0.height()));
#endif
        in0->mutable_value()->ShareDataWith(in_sel0.value());
      }

      auto get_selected_row = [&](size_t i) -> const SelectedRows & {
        if (i == 0 && in0) {
          return *in0.get();
        } else {
          return in_vars[i]->Get<SelectedRows>();
        }
      };

97
      auto *out = context.Output<SelectedRows>("Out");
Y
Yancey 已提交
98
      out->mutable_rows()->clear();
99
      auto *out_value = out->mutable_value();
Q
QI JUN 已提交
100 101 102 103

      // Runtime InferShape
      size_t first_dim = 0;
      for (int i = 0; i < N; i++) {
Y
Yang Yu 已提交
104 105
        auto &sel_row = get_selected_row(i);
        first_dim += sel_row.rows().size();
Q
QI JUN 已提交
106
      }
Y
Yang Yu 已提交
107 108 109
      auto in_dim =
          framework::vectorize(get_selected_row(N - 1).value().dims());
      in_dim[0] = static_cast<int64_t>(first_dim);
Q
QI JUN 已提交
110

Y
Yang Yu 已提交
111
      out_value->Resize(framework::make_ddim(in_dim));
Q
QI JUN 已提交
112 113
      out_value->mutable_data<T>(context.GetPlace());

Q
QI JUN 已提交
114
      math::SelectedRowsAddTo<DeviceContext, T> functor;
Q
QI JUN 已提交
115 116 117

      int64_t offset = 0;
      for (int i = 0; i < N; i++) {
Y
Yang Yu 已提交
118
        auto &sel_row = get_selected_row(i);
119 120 121
        if (!sel_row.value().IsInitialized() || sel_row.rows().size() == 0) {
          continue;
        }
Y
Yang Yu 已提交
122 123 124 125
        PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
        functor(context.template device_context<DeviceContext>(), sel_row,
                offset, out);
        offset += sel_row.value().numel();
Q
QI JUN 已提交
126
      }
127 128 129 130 131 132 133 134 135 136 137 138 139
    } else if (out_var->IsType<framework::LoDTensorArray>()) {
      auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
      for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
        PADDLE_ENFORCE(in_vars[i]->IsType<framework::LoDTensorArray>(),
                       "Only support all inputs are TensorArray");
        auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>();

        for (size_t i = 0; i < in_array.size(); ++i) {
          if (in_array[i].numel() != 0) {
            if (i >= out_array.size()) {
              out_array.resize(i + 1);
            }
            if (out_array[i].numel() == 0) {
140 141
              framework::Copy(in_array[i], in_array[i].place(),
                              context.device_context(), &out_array[i]);
142 143 144 145 146
              out_array[i].set_lod(in_array[i].lod());
            } else {
              PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod());
              auto in = EigenVector<T>::Flatten(in_array[i]);
              auto result = EigenVector<T>::Flatten(out_array[i]);
Q
QI JUN 已提交
147 148
              result.device(*context.template device_context<DeviceContext>()
                                 .eigen_device()) = result + in;
149 150 151 152 153 154 155
            }
          }
        }
      }
    } else {
      PADDLE_THROW("Unexpected branch, output variable type is %s",
                   out_var->Type().name());
156 157 158 159 160
    }
  }
};
}  // namespace operators
}  // namespace paddle