/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
24
using SelectedRows = phi::SelectedRows;
Q
QI JUN 已提交
25
using LoDTensor = framework::LoDTensor;
26 27 28 29
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

Z
zhaoyuchen2018 已提交
30 31 32 33 34 35 36 37 38 39
// Sums the SelectedRows inputs "X" into the SelectedRows output "Out".
//
// Inputs whose row list is empty are skipped. When "Out" is the same variable
// as the first input (in-place), that input is first copied into a temporary,
// because the output's rows are cleared below before merging. An in-place call
// with a single input is a no-op. The collected inputs are combined with
// math::scatter::MergeAdd (presumably summing values of duplicate row ids --
// see that functor). If no input carries any row, "Out" gets an empty value
// tensor of shape [0].
template <typename DeviceContext, typename T>
void SelectedRowsCompute(const framework::ExecutionContext &context) {
  auto in_vars = context.MultiInputVar("X");
  auto out_var = context.OutputVar("Out");
  // Guard the in_vars[0] access so an empty "X" cannot index past the end.
  bool in_place = !in_vars.empty() && out_var == in_vars[0];

  if (in_place && in_vars.size() < 2) {
    return;
  }

  std::vector<const phi::SelectedRows *> inputs;
  SelectedRows temp_in0;

  if (in_place) {
    // Snapshot the first input: "Out" aliases it and is cleared below.
    auto &in0 = in_vars[0]->Get<phi::SelectedRows>();
    temp_in0.set_height(in0.height());
    temp_in0.set_rows(in0.rows());
    framework::TensorCopy(in0.value(), in0.place(), context.device_context(),
                          temp_in0.mutable_value());
    inputs.push_back(&temp_in0);
    for (size_t i = 1; i < in_vars.size(); ++i) {
      auto &in = in_vars[i]->Get<phi::SelectedRows>();
      if (in.rows().size() > 0) {
        inputs.push_back(&in);
      }
    }
  } else {
    for (auto &in_var : in_vars) {
      auto &in = in_var->Get<phi::SelectedRows>();
      if (in.rows().size() > 0) {
        inputs.push_back(&in);
      }
    }
  }

  auto *out = context.Output<phi::SelectedRows>("Out");
  out->mutable_rows()->clear();

  // Only the in-place snapshot can still have an empty row list here.
  const bool has_data =
      std::any_of(inputs.begin(), inputs.end(),
                  [](const phi::SelectedRows *in) {
                    return in->rows().size() > 0;
                  });
  if (has_data) {
    math::scatter::MergeAdd<DeviceContext, T> merge_add;
    merge_add(context.template device_context<DeviceContext>(), inputs, out);

    out->SyncIndex();
  } else {
    // No data: just set an empty output tensor.
    out->mutable_value()->mutable_data<T>(phi::make_ddim({0}),
                                          context.GetPlace());
  }
}

// Element-wise sums the LoDTensorArray inputs "X" into the LoDTensorArray
// output "Out".
//
// For every input array and every element j that is initialized and
// non-empty, the element is:
//   * copied into out[j] (growing "Out" if needed) when out[j] is
//     uninitialized or empty; or
//   * added into out[j] after enforcing that both LoDs are identical.
// When "Out" is the same variable as the first input, that input is skipped
// because its values are already in place.
// NOTE(review): in the non-in-place case "Out" is not cleared first, so any
// pre-existing contents are accumulated into -- confirm this is intended.
template <typename DeviceContext, typename T>
void LodTensorArrayCompute(const framework::ExecutionContext &context) {
  auto in_vars = context.MultiInputVar("X");
  auto out_var = context.OutputVar("Out");
  // Guard the in_vars[0] access so an empty "X" cannot index past the end.
  bool in_place = !in_vars.empty() && out_var == in_vars[0];
  auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
  for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
    PADDLE_ENFORCE_EQ(in_vars[i]->IsType<framework::LoDTensorArray>(), true,
                      platform::errors::InvalidArgument(
                          "Only support all inputs are TensorArray, "
                          "but inputs[%d] is not TensorArray.",
                          i));
    auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>();

    // `j` indexes array elements; `i` still identifies the input so the
    // error message below can report both (they used to be shadowed).
    for (size_t j = 0; j < in_array.size(); ++j) {
      const auto &in_tensor = in_array[j];
      if (!in_tensor.IsInitialized() || in_tensor.numel() == 0) {
        continue;
      }
      if (j >= out_array.size()) {
        out_array.resize(j + 1);
      }
      // Take the reference only after a possible resize.
      auto &out_tensor = out_array[j];
      if (!out_tensor.IsInitialized() || out_tensor.numel() == 0) {
        framework::TensorCopy(in_tensor, in_tensor.place(),
                              context.device_context(), &out_tensor);
        out_tensor.set_lod(in_tensor.lod());
      } else {
        PADDLE_ENFORCE_EQ(
            out_tensor.lod(), in_tensor.lod(),
            platform::errors::InvalidArgument(
                "The lod message of element %d between inputs[%d] and"
                " outputs must be same, but now is not same.",
                j, i));
        auto in = EigenVector<T>::Flatten(in_tensor);
        auto result = EigenVector<T>::Flatten(out_tensor);
        result.device(*context.template device_context<DeviceContext>()
                           .eigen_device()) = result + in;
      }
    }
  }
}

Q
QI JUN 已提交
128
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
129
class SumKernel : public framework::OpKernel<T> {
130
 public:
131
  void Compute(const framework::ExecutionContext &context) const override {
X
xiongkun 已提交
132
    VLOG(10) << "start sum kernel";
Y
Yu Yang 已提交
133
    auto in_vars = context.MultiInputVar("X");
134
    size_t in_num = in_vars.size();
Q
QI JUN 已提交
135 136
    auto out_var = context.OutputVar("Out");

Y
Yu Yang 已提交
137 138
    bool in_place = out_var == in_vars[0];

Q
QI JUN 已提交
139
    if (out_var->IsType<framework::LoDTensor>()) {
140 141
      auto *out = out_var->GetMutable<framework::LoDTensor>();
      auto *out_ptr = out->mutable_data<T>(context.GetPlace());
X
xiongkun 已提交
142 143
      if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>() &&
          in_vars[0]->Get<framework::LoDTensor>().IsInitialized()) {
144 145 146 147
        auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>();
        if (in_0_tensor.numel() > 0) {
          in_place = (in_0_tensor.data<T>() == out_ptr);
        }
Y
Update  
Yang Yu 已提交
148
      }
149

Y
Update  
Yang Yu 已提交
150
      auto result = EigenVector<T>::Flatten(*out);
151 152 153
      auto &place =
          *context.template device_context<DeviceContext>().eigen_device();
      int start = in_place ? 1 : 0;
Y
Update  
Yang Yu 已提交
154
      if (!in_place) {
155
        if ((in_num >= 2) && in_vars[0]->IsType<framework::LoDTensor>() &&
X
xiongkun 已提交
156 157 158
            in_vars[1]->IsType<framework::LoDTensor>() &&
            in_vars[0]->Get<framework::LoDTensor>().IsInitialized() &&
            in_vars[1]->Get<framework::LoDTensor>().IsInitialized()) {
159 160 161 162 163 164 165 166 167 168
          auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
          auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
          if (in_0.numel() && in_1.numel()) {
            auto in_0_e = EigenVector<T>::Flatten(in_0);
            auto in_1_e = EigenVector<T>::Flatten(in_1);
            result.device(place) = in_0_e + in_1_e;
            start = 2;
          }
        }
        if (start != 2) {
X
xiongkun 已提交
169
          VLOG(10) << "Fill with constant = 0 in sum kernel.";
170
          phi::funcs::SetConstant<DeviceContext, T> constant_functor;
171
          constant_functor(context.template device_context<DeviceContext>(),
C
chengduo 已提交
172
                           out, static_cast<T>(0));
173
        }
Y
Yu Yang 已提交
174
      }
Q
QI JUN 已提交
175

Q
QI JUN 已提交
176
      math::SelectedRowsAddToTensor<DeviceContext, T> functor;
Y
Yu Yang 已提交
177
      // If in_place, just skip the first tensor
178
      for (size_t i = start; i < in_num; i++) {
Q
QI JUN 已提交
179
        if (in_vars[i]->IsType<framework::LoDTensor>()) {
180
          auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
X
xiongkun 已提交
181
          if (!in_t.IsInitialized() || in_t.numel() == 0) {
182 183
            continue;
          }
Q
QI JUN 已提交
184 185
          auto in = EigenVector<T>::Flatten(in_t);
          result.device(place) = result + in;
186 187
        } else if (in_vars[i]->IsType<phi::SelectedRows>()) {
          auto &in_t = in_vars[i]->Get<phi::SelectedRows>();
Q
QI JUN 已提交
188
          functor(context.template device_context<DeviceContext>(), in_t, out);
Q
QI JUN 已提交
189
        } else {
190 191 192 193 194
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Expected type of Input(X) of %d-th must be Tensor, "
              "SelectedRows. But got "
              "unsupport type: %s.",
              framework::ToTypeName(in_vars[i]->Type())));
Q
QI JUN 已提交
195 196
        }
      }
197
    } else if (out_var->IsType<phi::SelectedRows>()) {
Z
zhaoyuchen2018 已提交
198
      SelectedRowsCompute<DeviceContext, T>(context);
199
    } else if (out_var->IsType<framework::LoDTensorArray>()) {
Z
zhaoyuchen2018 已提交
200
      LodTensorArrayCompute<DeviceContext, T>(context);
201
    } else {
202 203 204 205 206
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Expected type of Output(out) must be Tensor, SelectedRows, "
          "LoDTensorArray. But got "
          "unsupport type: %s.",
          framework::ToTypeName(out_var->Type())));
207
    }
X
xiongkun 已提交
208
    VLOG(10) << "end sum kernel";
209 210 211 212
  }
};
}  // namespace operators
}  // namespace paddle