/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using SelectedRows = phi::SelectedRows;
using LoDTensor = framework::LoDTensor;
template <typename T,
          int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

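// Sums all SelectedRows inputs "X" into the SelectedRows output "Out" by
// merging rows with MergeAdd, e.g. rows {0, 2} + rows {2, 5} -> rows
// {0, 2, 5}, with the two row-2 slices added element-wise.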
template <typename DeviceContext, typename T>
void SelectedRowsCompute(const framework::ExecutionContext &context) {
  auto in_vars = context.MultiInputVar("X");
  auto out_var = context.OutputVar("Out");
  bool in_place = out_var == in_vars[0];

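  // Summing in place with a single input: Out already aliases X[0], so
  // there is nothing to accumulate.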
  if (in_place && in_vars.size() < 2) {
    return;
  }

  std::vector<const phi::SelectedRows *> inputs;
  SelectedRows temp_in0;

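  // When summing in place, Out aliases X[0]; copy X[0] into a temporary so
  // MergeAdd does not read from the buffer it is writing into.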
  if (in_place) {
    auto &in0 = in_vars[0]->Get<phi::SelectedRows>();
    temp_in0.set_height(in0.height());
    temp_in0.set_rows(in0.rows());
    framework::TensorCopy(in0.value(),
                          in0.place(),
                          context.device_context(),
                          temp_in0.mutable_value());
    inputs.push_back(&temp_in0);
    for (size_t i = 1; i < in_vars.size(); ++i) {
      auto &in = in_vars[i]->Get<phi::SelectedRows>();
      if (in.rows().size() > 0) {
        inputs.push_back(&in);
      }
    }
  } else {
    for (auto &in_var : in_vars) {
      auto &in = in_var->Get<phi::SelectedRows>();
      if (in.rows().size() > 0) {
        inputs.push_back(&in);
      }
    }
  }

  auto *out = context.Output<phi::SelectedRows>("Out");
  out->mutable_rows()->clear();

  bool has_data = false;
  for (auto &in : inputs) {
    if (in->rows().size() > 0) {
      has_data = true;
      break;
    }
  }
  if (has_data) {
    math::scatter::MergeAdd<DeviceContext, T> merge_add;
    merge_add(context.template device_context<DeviceContext>(), inputs, out);

    out->SyncIndex();

  } else {
    // No data: just set an empty output tensor.
    out->mutable_value()->mutable_data<T>(phi::make_ddim({0}),
                                          context.GetPlace());
  }
}

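// Element-wise sums LoDTensorArray inputs "X" into the LoDTensorArray output
// "Out": the first initialized tensor at each index is copied (together with
// its LoD), and later tensors are accumulated with Eigen after checking that
// their LoD matches.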
template <typename DeviceContext, typename T>
void LodTensorArrayCompute(const framework::ExecutionContext &context) {
  auto in_vars = context.MultiInputVar("X");
  auto out_var = context.OutputVar("Out");
  bool in_place = out_var == in_vars[0];
  auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
  for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
    PADDLE_ENFORCE_EQ(in_vars[i]->IsType<framework::LoDTensorArray>(),
                      true,
                      platform::errors::InvalidArgument(
                          "All inputs must be TensorArray, "
                          "but inputs[%d] is not a TensorArray.",
                          i));
    auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>();

    for (size_t j = 0; j < in_array.size(); ++j) {
      if (in_array[j].IsInitialized() && (in_array[j].numel() != 0)) {
        if (j >= out_array.size()) {
          out_array.resize(j + 1);
        }
        if (!out_array[j].IsInitialized() || (out_array[j].numel() == 0)) {
          framework::TensorCopy(in_array[j],
                                in_array[j].place(),
                                context.device_context(),
                                &out_array[j]);
          out_array[j].set_lod(in_array[j].lod());
        } else {
          PADDLE_ENFORCE_EQ(
              out_array[j].lod(),
              in_array[j].lod(),
              platform::errors::InvalidArgument(
                  "The LoD of inputs[%d] and outputs[%d] must be the same, "
                  "but they differ.",
                  j,
                  j));
          auto in = EigenVector<T>::Flatten(in_array[j]);
          auto result = EigenVector<T>::Flatten(out_array[j]);
          result.device(*context.template device_context<DeviceContext>()
                             .eigen_device()) = result + in;
        }
      }
    }
  }
}

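// SumKernel dispatches on the type of Out: for a LoDTensor output, LoDTensor
// inputs are accumulated with Eigen and SelectedRows inputs are added via
// SelectedRowsAddToTensor; SelectedRows and LoDTensorArray outputs are
// delegated to the helpers above.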
template <typename DeviceContext, typename T>
class SumKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    VLOG(10) << "start sum kernel";
    auto in_vars = context.MultiInputVar("X");
    size_t in_num = in_vars.size();
    auto out_var = context.OutputVar("Out");

    bool in_place = out_var == in_vars[0];

    if (out_var->IsType<framework::LoDTensor>()) {
      auto *out = out_var->GetMutable<framework::LoDTensor>();
      auto *out_ptr = out->mutable_data<T>(context.GetPlace());
      if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>() &&
          in_vars[0]->Get<framework::LoDTensor>().IsInitialized()) {
        auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>();
        if (in_0_tensor.numel() > 0) {
          in_place = (in_0_tensor.data<T>() == out_ptr);
        }
      }

      auto result = EigenVector<T>::Flatten(*out);
      auto &place =
          *context.template device_context<DeviceContext>().eigen_device();
      int start = in_place ? 1 : 0;
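      // Fast path when not summing in place: seed the output with the sum of
      // the first two LoDTensor inputs if possible; otherwise zero-fill the
      // output and accumulate every input in the loop below.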
      if (!in_place) {
        if ((in_num >= 2) && in_vars[0]->IsType<framework::LoDTensor>() &&
            in_vars[1]->IsType<framework::LoDTensor>() &&
            in_vars[0]->Get<framework::LoDTensor>().IsInitialized() &&
            in_vars[1]->Get<framework::LoDTensor>().IsInitialized()) {
          auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
          auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
          if (in_0.numel() && in_1.numel()) {
            auto in_0_e = EigenVector<T>::Flatten(in_0);
            auto in_1_e = EigenVector<T>::Flatten(in_1);
            result.device(place) = in_0_e + in_1_e;
            start = 2;
          }
        }
        if (start != 2) {
          VLOG(10) << "Fill with constant = 0 in sum kernel.";
          phi::funcs::SetConstant<DeviceContext, T> constant_functor;
          constant_functor(context.template device_context<DeviceContext>(),
                           out,
                           static_cast<T>(0));
        }
      }

      math::SelectedRowsAddToTensor<DeviceContext, T> functor;
      // If in_place, just skip the first tensor
      for (size_t i = start; i < in_num; i++) {
        if (in_vars[i]->IsType<framework::LoDTensor>()) {
          auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
          if (!in_t.IsInitialized() || in_t.numel() == 0) {
            continue;
          }
          auto in = EigenVector<T>::Flatten(in_t);
          result.device(place) = result + in;
        } else if (in_vars[i]->IsType<phi::SelectedRows>()) {
          auto &in_t = in_vars[i]->Get<phi::SelectedRows>();
          functor(context.template device_context<DeviceContext>(), in_t, out);
        } else {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Expected the %d-th Input(X) to be a LoDTensor or "
              "SelectedRows, but got unsupported type: %s.",
              i,
              framework::ToTypeName(in_vars[i]->Type())));
        }
      }
    } else if (out_var->IsType<phi::SelectedRows>()) {
      SelectedRowsCompute<DeviceContext, T>(context);
    } else if (out_var->IsType<framework::LoDTensorArray>()) {
      LodTensorArrayCompute<DeviceContext, T>(context);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Expected Output(Out) to be a LoDTensor, SelectedRows, or "
          "LoDTensorArray, but got unsupported type: %s.",
          framework::ToTypeName(out_var->Type())));
    }
    VLOG(10) << "end sum kernel";
  }
};
}  // namespace operators
}  // namespace paddle
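
// Usage note (a sketch, not part of the original header): this kernel is
// instantiated for concrete device/type pairs in the corresponding sum_op.cc.
// Assuming the fluid REGISTER_OP_CPU_KERNEL macro and phi::CPUContext, the
// CPU registration would look roughly like:
//
//   namespace ops = paddle::operators;
//   REGISTER_OP_CPU_KERNEL(sum,
//                          ops::SumKernel<phi::CPUContext, float>,
//                          ops::SumKernel<phi::CPUContext, double>,
//                          ops::SumKernel<phi::CPUContext, int>,
//                          ops::SumKernel<phi::CPUContext, int64_t>);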