/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
25
using SelectedRows = phi::SelectedRows;
Q
QI JUN 已提交
26
using LoDTensor = framework::LoDTensor;
27 28 29 30
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

Z
zhaoyuchen2018 已提交
31 32 33 34 35 36 37 38 39 40
/**
 * Sums the SelectedRows inputs "X" into the SelectedRows output "Out".
 *
 * Inputs whose row list is empty are skipped. When "Out" is the same variable
 * as X[0] (in-place summation), X[0] is first snapshotted into a temporary,
 * because Out's rows are cleared and rewritten by the merge below; with fewer
 * than two inputs the in-place case returns immediately. If none of the
 * collected inputs has rows, Out ends up with no rows and a 0-element value
 * tensor allocated on the kernel's place.
 */
template <typename DeviceContext, typename T>
void SelectedRowsCompute(const framework::ExecutionContext &context) {
  auto in_vars = context.MultiInputVar("X");
  auto out_var = context.OutputVar("Out");
  bool in_place = out_var == in_vars[0];

  // In-place with a single input: Out already holds the result.
  if (in_place && in_vars.size() < 2) {
    return;
  }

  std::vector<const phi::SelectedRows *> inputs;
  // Owns the snapshot of X[0] in the in-place case; it must stay alive while
  // `inputs` is used by the merge below.
  SelectedRows temp_in0;

  if (in_place) {
    auto &in0 = in_vars[0]->Get<phi::SelectedRows>();
    temp_in0.set_height(in0.height());
    temp_in0.set_rows(in0.rows());
    framework::TensorCopy(in0.value(), in0.place(), context.device_context(),
                          temp_in0.mutable_value());
    inputs.push_back(&temp_in0);
    for (size_t i = 1; i < in_vars.size(); ++i) {
      auto &in = in_vars[i]->Get<phi::SelectedRows>();
      if (in.rows().size() > 0) {
        inputs.push_back(&in);
      }
    }
  } else {
    for (auto &in_var : in_vars) {
      auto &in = in_var->Get<phi::SelectedRows>();
      if (in.rows().size() > 0) {
        // Reuse the reference fetched above rather than calling Get<>() again.
        inputs.push_back(&in);
      }
    }
  }

  auto *out = context.Output<phi::SelectedRows>("Out");
  out->mutable_rows()->clear();

  // Only the in-place snapshot can be collected without rows; every other
  // entry of `inputs` was filtered to have rows above.
  const bool has_data =
      std::any_of(inputs.begin(), inputs.end(),
                  [](const phi::SelectedRows *in) {
                    return in->rows().size() > 0;
                  });
  if (has_data) {
    math::scatter::MergeAdd<DeviceContext, T> merge_add;
    merge_add(context.template device_context<DeviceContext>(), inputs, out);

    out->SyncIndex();

  } else {
    // no data, just set a empty out tensor.
    out->mutable_value()->mutable_data<T>(phi::make_ddim({0}),
                                          context.GetPlace());
  }
}

/**
 * Element-wise sums LoDTensorArray inputs "X" into the LoDTensorArray "Out".
 *
 * For every input array and every position i that holds an initialized,
 * non-empty tensor:
 *  - Out is grown to at least i + 1 entries;
 *  - if Out[i] is uninitialized or empty, it becomes a copy of the input
 *    tensor (including its LoD);
 *  - otherwise the input is added element-wise into Out[i]. The LoD and the
 *    element counts must match.
 * When "Out" is the same variable as X[0] (in-place), X[0] is skipped.
 * Every other input must be a LoDTensorArray, otherwise InvalidArgument is
 * raised.
 */
template <typename DeviceContext, typename T>
void LodTensorArrayCompute(const framework::ExecutionContext &context) {
  auto in_vars = context.MultiInputVar("X");
  auto out_var = context.OutputVar("Out");
  bool in_place = out_var == in_vars[0];
  auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
  for (size_t in_idx = in_place ? 1 : 0; in_idx < in_vars.size(); ++in_idx) {
    PADDLE_ENFORCE_EQ(in_vars[in_idx]->IsType<framework::LoDTensorArray>(),
                      true,
                      platform::errors::InvalidArgument(
                          "Only support all inputs are TensorArray, "
                          "but inputs[%d] is not TensorArray.",
                          in_idx));
    auto &in_array = in_vars[in_idx]->Get<framework::LoDTensorArray>();

    // `i` indexes positions inside the arrays; `in_idx` indexes inputs.
    for (size_t i = 0; i < in_array.size(); ++i) {
      // Uninitialized or empty slots contribute nothing.
      if (!in_array[i].IsInitialized() || (in_array[i].numel() == 0)) {
        continue;
      }
      if (i >= out_array.size()) {
        out_array.resize(i + 1);
      }
      // First contribution to this slot: copy instead of accumulating.
      if (!out_array[i].IsInitialized() || (out_array[i].numel() == 0)) {
        framework::TensorCopy(in_array[i], in_array[i].place(),
                              context.device_context(), &out_array[i]);
        out_array[i].set_lod(in_array[i].lod());
        continue;
      }
      PADDLE_ENFORCE_EQ(
          out_array[i].lod(), in_array[i].lod(),
          platform::errors::InvalidArgument(
              "The lod message between inputs[%d] and"
              " outputs[%d] must be same, but now is not same.",
              i, i));
      // The flattened Eigen add below has no broadcasting; a size mismatch
      // would read/write out of bounds, so reject it explicitly.
      PADDLE_ENFORCE_EQ(
          out_array[i].numel(), in_array[i].numel(),
          platform::errors::InvalidArgument(
              "The number of elements of the %d-th tensor in inputs[%d] and "
              "in outputs must be the same, but received %d and %d.",
              i, in_idx, in_array[i].numel(), out_array[i].numel()));
      auto in = EigenVector<T>::Flatten(in_array[i]);
      auto result = EigenVector<T>::Flatten(out_array[i]);
      result.device(*context.template device_context<DeviceContext>()
                         .eigen_device()) = result + in;
    }
  }
}

template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
130
class SumKernel : public framework::OpKernel<T> {
131
 public:
132
  void Compute(const framework::ExecutionContext &context) const override {
X
xiongkun 已提交
133
    VLOG(10) << "start sum kernel";
Y
Yu Yang 已提交
134
    auto in_vars = context.MultiInputVar("X");
135
    size_t in_num = in_vars.size();
Q
QI JUN 已提交
136 137
    auto out_var = context.OutputVar("Out");

Y
Yu Yang 已提交
138 139
    bool in_place = out_var == in_vars[0];

Q
QI JUN 已提交
140
    if (out_var->IsType<framework::LoDTensor>()) {
141 142
      auto *out = out_var->GetMutable<framework::LoDTensor>();
      auto *out_ptr = out->mutable_data<T>(context.GetPlace());
X
xiongkun 已提交
143 144
      if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>() &&
          in_vars[0]->Get<framework::LoDTensor>().IsInitialized()) {
145 146 147 148
        auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>();
        if (in_0_tensor.numel() > 0) {
          in_place = (in_0_tensor.data<T>() == out_ptr);
        }
Y
Update  
Yang Yu 已提交
149
      }
150

Y
Update  
Yang Yu 已提交
151
      auto result = EigenVector<T>::Flatten(*out);
152 153 154
      auto &place =
          *context.template device_context<DeviceContext>().eigen_device();
      int start = in_place ? 1 : 0;
Y
Update  
Yang Yu 已提交
155
      if (!in_place) {
156
        if ((in_num >= 2) && in_vars[0]->IsType<framework::LoDTensor>() &&
X
xiongkun 已提交
157 158 159
            in_vars[1]->IsType<framework::LoDTensor>() &&
            in_vars[0]->Get<framework::LoDTensor>().IsInitialized() &&
            in_vars[1]->Get<framework::LoDTensor>().IsInitialized()) {
160 161 162 163 164 165 166 167 168 169
          auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
          auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
          if (in_0.numel() && in_1.numel()) {
            auto in_0_e = EigenVector<T>::Flatten(in_0);
            auto in_1_e = EigenVector<T>::Flatten(in_1);
            result.device(place) = in_0_e + in_1_e;
            start = 2;
          }
        }
        if (start != 2) {
X
xiongkun 已提交
170
          VLOG(10) << "Fill with constant = 0 in sum kernel.";
171
          phi::funcs::SetConstant<DeviceContext, T> constant_functor;
172
          constant_functor(context.template device_context<DeviceContext>(),
C
chengduo 已提交
173
                           out, static_cast<T>(0));
174
        }
Y
Yu Yang 已提交
175
      }
Q
QI JUN 已提交
176

Q
QI JUN 已提交
177
      math::SelectedRowsAddToTensor<DeviceContext, T> functor;
Y
Yu Yang 已提交
178
      // If in_place, just skip the first tensor
179
      for (size_t i = start; i < in_num; i++) {
Q
QI JUN 已提交
180
        if (in_vars[i]->IsType<framework::LoDTensor>()) {
181
          auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
X
xiongkun 已提交
182
          if (!in_t.IsInitialized() || in_t.numel() == 0) {
183 184
            continue;
          }
Q
QI JUN 已提交
185 186
          auto in = EigenVector<T>::Flatten(in_t);
          result.device(place) = result + in;
187 188
        } else if (in_vars[i]->IsType<phi::SelectedRows>()) {
          auto &in_t = in_vars[i]->Get<phi::SelectedRows>();
Q
QI JUN 已提交
189
          functor(context.template device_context<DeviceContext>(), in_t, out);
Q
QI JUN 已提交
190
        } else {
191 192 193 194 195
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Expected type of Input(X) of %d-th must be Tensor, "
              "SelectedRows. But got "
              "unsupport type: %s.",
              framework::ToTypeName(in_vars[i]->Type())));
Q
QI JUN 已提交
196 197
        }
      }
198
    } else if (out_var->IsType<phi::SelectedRows>()) {
Z
zhaoyuchen2018 已提交
199
      SelectedRowsCompute<DeviceContext, T>(context);
200
    } else if (out_var->IsType<framework::LoDTensorArray>()) {
Z
zhaoyuchen2018 已提交
201
      LodTensorArrayCompute<DeviceContext, T>(context);
202
    } else {
203 204 205 206 207
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Expected type of Output(out) must be Tensor, SelectedRows, "
          "LoDTensorArray. But got "
          "unsupport type: %s.",
          framework::ToTypeName(out_var->Type())));
208
    }
X
xiongkun 已提交
209
    VLOG(10) << "end sum kernel";
210 211 212 213
  }
};
}  // namespace operators
}  // namespace paddle