/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

#define MAX_RANK_SUPPORTED 6

#define EXPAND_TEMPLATE(z, n, data) \
  case n + 1: {                     \
    Expand<n + 1>(context);         \
    break;                          \
  }
#define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)
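// For illustration only: REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) uses
// BOOST_PP_REPEAT to instantiate EXPAND_TEMPLATE for n = 0..5, so the switch
// in ExpandKernel::Compute below ends up roughly equivalent to
//   case 1: { Expand<1>(context); break; }
//   ...
//   case 6: { Expand<6>(context); break; }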
#define COND(n) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
#define EXPAND_GRAD_CASE(n)                                        \
  case n: {                                                        \
    ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
    break;                                                         \
  }
#define EXPAND_GRAD_TEMPLATE(z, n, data) \
  BOOST_PP_IF(COND(n), EXPAND_GRAD_CASE(n), )
#define REP_EXPAND_GRAD_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_GRAD_TEMPLATE, ~)
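// Likewise, REP_EXPAND_GRAD_TEMPLATE(MAX_RANK_SUPPORTED) repeats
// EXPAND_GRAD_TEMPLATE for n = 0..5; every n for which COND(n) holds emits a
// "case n:" that dispatches to ExpandBackward<n> in the switch inside
// ExpandGradKernel::Compute below.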

namespace paddle {
namespace operators {
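// Resolves the effective expand times, checked in this order:
//   1. the optional Input("ExpandTimes") tensor,
//   2. the optional tensor list Input("expand_times_tensor"),
//   3. the Attr<std::vector<int>>("expand_times") attribute.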
inline std::vector<int> get_expand_times(
    const framework::ExecutionContext& ctx) {
  if (ctx.HasInput("ExpandTimes")) {
    auto* expand_tensor = ctx.Input<framework::LoDTensor>("ExpandTimes");
    auto* expand_data = expand_tensor->data<int>();
    framework::Tensor cpu_expand_tensor;
    if (platform::is_gpu_place(expand_tensor->place())) {
      TensorCopySync(*expand_tensor, platform::CPUPlace(), &cpu_expand_tensor);
      expand_data = cpu_expand_tensor.data<int>();
    }
    auto vec_expand_times =
        std::vector<int>(expand_data, expand_data + expand_tensor->numel());
    return vec_expand_times;
  }

  auto list_expand_times_tensor =
      ctx.MultiInput<framework::Tensor>("expand_times_tensor");
  if (list_expand_times_tensor.size() > 0) {
    // Collect one expand time from each tensor in the list.
    std::vector<int> vec_expand_times;
    for (size_t i = 0; i < list_expand_times_tensor.size(); ++i) {
      auto tensor = list_expand_times_tensor[i];
      if (platform::is_gpu_place(tensor->place())) {
        framework::Tensor temp;
        TensorCopySync(*tensor, platform::CPUPlace(), &temp);
        vec_expand_times.push_back(*temp.data<int32_t>());
      } else {
        vec_expand_times.push_back(*tensor->data<int32_t>());
      }
    }

    return vec_expand_times;
  } else {
    return ctx.Attr<std::vector<int>>("expand_times");
  }
}

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

template <typename DeviceContext, typename T>
class ExpandKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto rank = context.Input<Tensor>("X")->dims().size();
    switch (rank) {
      REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED)
      default:
        PADDLE_ENFORCE(false,
                       "Only tensors with rank between 1 and 6 are supported.");
    }
  }

 protected:
  template <int Rank>
  void Expand(const framework::ExecutionContext& context) const {
    auto* in0 = context.Input<Tensor>("X");

    auto in_dims = in0->dims();
    auto expand_times = get_expand_times(context);
    PADDLE_ENFORCE_EQ(static_cast<size_t>(in_dims.size()), expand_times.size(),
                      "The number of elements of expand_times must be equal "
                      "to the rank of Input(X).");
    auto* out0 = context.Output<Tensor>("Out");
    Eigen::DSizes<int, Rank> bcast_dims;
    for (size_t i = 0; i < expand_times.size(); ++i) {
      bcast_dims[i] = expand_times[i];
    }

    framework::DDim out_dims(in_dims);
    for (size_t i = 0; i < expand_times.size(); ++i) {
      out_dims[i] *= expand_times[i];
    }

    out0->Resize(out_dims);
    auto x = EigenTensor<T, Rank>::From(*in0);
    out0->mutable_data<T>(context.GetPlace());
    auto y = EigenTensor<T, Rank>::From(*out0);
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
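    // For example, if X has shape [2, 3] and expand_times is [2, 2], then
    // bcast_dims is [2, 2] and the broadcast below tiles X twice along each
    // axis, producing an Out of shape [4, 6].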
    y.device(place) = x.broadcast(bcast_dims);
  }
};

template <typename DeviceContext, typename T>
class ExpandGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("X");
    // auto& expand_times = context.Attr<std::vector<int>>("expand_times");
    auto expand_times = get_expand_times(context);
    auto x_dims = in0->dims();
    // 1. reshape_dims_vec is the broadcast parameter.
    // 2. reduce_dims_vec is the dimension parameter to compute gradients. For
    //    each dimension expanded, the gradients should be summed to original
    //    size.
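    //    For example, with X of shape [2, 3] and expand_times [2, 2], the
    //    loop below builds reshape_dims_vec = [2, 2, 2, 3] and
    //    reduce_dims_vec = [0, 2]: dOut of shape [4, 6] is viewed as
    //    [2, 2, 2, 3] and summed over axes 0 and 2 to recover dX of
    //    shape [2, 3].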
    std::vector<int> reshape_dims_vec;
    std::vector<int> reduce_dims_vec;
    for (size_t i = 0; i < expand_times.size(); ++i) {
      reduce_dims_vec.push_back(reshape_dims_vec.size());
      reshape_dims_vec.push_back(expand_times[i]);
      reshape_dims_vec.push_back(x_dims[i]);
    }

    int dims = reduce_dims_vec.size();

    bool just_copy = true;
    for (size_t i = 0; i < expand_times.size(); i++) {
      if (expand_times[i] != 1) {
        just_copy = false;
        break;
      }
    }
    // No reduction is needed when every expand time is 1; just copy dOut to dX.
    if (just_copy) {
      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
      out0->mutable_data<T>(context.GetPlace());
      framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
                            out0);
    } else {
      switch (dims) {
        REP_EXPAND_GRAD_TEMPLATE(MAX_RANK_SUPPORTED)
        default:
          PADDLE_ENFORCE(
              false, "Only tensors with rank between 1 and 6 are supported.");
      }
    }
  }

 protected:
  template <int Dims>
  void ExpandBackward(const framework::ExecutionContext& context,
                      const std::vector<int>& reshape_dims_vec,
                      const std::vector<int>& reduce_dims_vec) const {
    size_t reshape_size = reshape_dims_vec.size();
    size_t reduce_size = reduce_dims_vec.size();
    PADDLE_ENFORCE_EQ(reshape_size, static_cast<size_t>(Dims) * 2,
                      "Inconsistent size between template Dims and "
                      "reshape dimensions.");
    PADDLE_ENFORCE_EQ(reduce_size, static_cast<size_t>(Dims),
                      "Inconsistent size between template Dims and "
                      "reduce dimensions.");
    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    out0->mutable_data<T>(context.GetPlace());
    auto x_grad = EigenVector<T>::Flatten(*out0);
    Eigen::DSizes<int, Dims * 2> reshape_dims;
    for (size_t i = 0; i < reshape_size; ++i) {
      reshape_dims[i] = reshape_dims_vec[i];
    }
    Eigen::DSizes<int, Dims> reduce_dims;
    for (size_t i = 0; i < reduce_size; ++i) {
      reduce_dims[i] = reduce_dims_vec[i];
    }
    auto out_grad = EigenVector<T>::Flatten(*in0);
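    // View the flattened dOut as interleaved (expand_time, x_dim) axes, sum
    // over the expanded axes, and reshape the result back to dX's shape.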
    x_grad.device(
        *context.template device_context<DeviceContext>().eigen_device()) =
        out_grad.reshape(reshape_dims)
            .sum(reduce_dims)
            .reshape(x_grad.dimensions());
  }
};

}  // namespace operators
}  // namespace paddle
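
// Note: the op and its kernels are registered outside this header (e.g. in
// expand_op.cc / expand_op.cu). A minimal sketch of the usual CPU
// registration, assuming float data and the kernel classes defined above:
//
//   namespace ops = paddle::operators;
//   REGISTER_OP_CPU_KERNEL(
//       expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>);
//   REGISTER_OP_CPU_KERNEL(
//       expand_grad,
//       ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, float>);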