/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

#define MAX_RANK_SUPPORTED 6

#define EXPAND_TEMPLATE(z, n, data) \
  case n + 1: {                     \
    Expand<n + 1>(context);         \
    break;                          \
  }
#define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)
#define COND(n)                                               \
  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
#define EXPAND_GRAD_CASE(n)                                        \
  case n: {                                                        \
    ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
    break;                                                         \
  }
#define EXPAND_GRAD_TEMPLATE(z, n, data) \
  BOOST_PP_IF(COND(n), EXPAND_GRAD_CASE(n), )
#define REP_EXPAND_GRAD_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_GRAD_TEMPLATE, ~)
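// Roughly speaking, REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) generates the
// switch cases
//   case 1: { Expand<1>(context); break; }
//   ...
//   case 6: { Expand<6>(context); break; }
// so a runtime rank dispatches to the matching compile-time instantiation.
// For the gradient kernel the switch key packs both reshape_dims_vec.size()
// and reduce_dims_vec.size() (see ExpandGradKernel below); COND(n) keeps only
// the keys where the number of reshape dimensions is at least the number of
// reduce dimensions, which is the only combination valid inputs can produce.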

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

template <typename DeviceContext, typename T>
class ExpandKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto rank = context.Input<Tensor>("X")->dims().size();
    switch (rank) {
      REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED)
      default:
        PADDLE_ENFORCE(false,
                       "Only tensors with rank between 1 and 6 are supported.");
    }
  }

 protected:
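  // Expand<Rank> tiles the input along every dimension: dimension i of X is
  // repeated expand_times[i] times via Eigen's broadcast. For example, an
  // input of shape [2, 3] with expand_times = [2, 1] yields an output of
  // shape [4, 3].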
  template <int Rank>
  void Expand(const framework::ExecutionContext& context) const {
    auto* in0 = context.Input<Tensor>("X");
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
    auto* out0 = context.Output<Tensor>("Out");
    Eigen::DSizes<int, Rank> bcast_dims;
    auto x_dims = in0->dims();
    for (size_t i = 0; i < expand_times.size(); ++i) {
      bcast_dims[i] = expand_times[i];
    }
    auto x = EigenTensor<T, Rank>::From(*in0);
    out0->mutable_data<T>(context.GetPlace());
    auto y = EigenTensor<T, Rank>::From(*out0);
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
    y.device(place) = x.broadcast(bcast_dims);
  }
};

template <typename DeviceContext, typename T>
class ExpandGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("X");
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
    auto x_dims = in0->dims();
    // 1. reshape_dims_vec is the broadcast parameter. For each dimension i,
    //    if expand_times[i] > 1 and x_dims[i] > 1, dimension i is split into
    //    two dimensions [expand_times[i], x_dims[i]].
    // 2. reduce_dims_vec is the dimension parameter used to compute gradients.
    //    For each expanded dimension, the gradients are summed back to the
    //    original size.
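    //
    // For example, with x_dims = [2, 1] and expand_times = [2, 3], the loop
    // below produces reshape_dims_vec = [2, 2, 3] and reduce_dims_vec = [0, 2]:
    // the output gradient of shape [4, 3] is viewed as [2, 2, 3], summed over
    // dimensions 0 and 2, and reshaped back to the input shape [2, 1].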
    std::vector<int> reshape_dims_vec;
    std::vector<int> reduce_dims_vec;
    for (size_t i = 0; i < expand_times.size(); ++i) {
      if (expand_times[i] == 1) {
        reshape_dims_vec.push_back(x_dims[i]);
      } else {
        if (x_dims[i] == 1) {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
        } else {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
          reshape_dims_vec.push_back(x_dims[i]);
        }
      }
    }

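    // Pack both vector sizes into a single switch key:
    //   dims = (reshape_dims_vec.size() - 1) * MAX_RANK_SUPPORTED +
    //          (reduce_dims_vec.size() - 1)
    // ExpandBackward<Dims> recovers the two sizes with division and modulo by
    // MAX_RANK_SUPPORTED. reshape_dims_vec holds at most 2 * MAX_RANK_SUPPORTED
    // entries, so the keys range from 0 to 71 and the generated switch uses
    // REP_EXPAND_GRAD_TEMPLATE(72).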
    int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED +
               reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1;
    // No reduction is needed; just copy the output gradient.
    if (reduce_dims_vec.size() == 0) {
      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
      out0->mutable_data<T>(context.GetPlace());
      framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
                            out0);
    } else {
      switch (dims) {
        REP_EXPAND_GRAD_TEMPLATE(72)
        default:
          PADDLE_ENFORCE(
              false, "Only tensors with rank between 1 and 6 are supported.");
      }
    }
  }

 protected:
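  // ExpandBackward<Dims> decodes the packed switch key into the number of
  // reshape and reduce dimensions, reshapes the output gradient accordingly,
  // sums it over the expanded dimensions, and writes the result in the shape
  // of X.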
  template <int Dims>
  void ExpandBackward(const framework::ExecutionContext& context,
                      const std::vector<int>& reshape_dims_vec,
                      const std::vector<int>& reduce_dims_vec) const {
    size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1;
    size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1;
    PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
                      "Inconsistent size between template Dims and "
                      "reshape dimensions.");
    PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
                      "Inconsistent size between template Dims and "
                      "reduce dimensions.");
    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
    out0->mutable_data<T>(context.GetPlace());
    auto x_grad = EigenVector<T>::Flatten(*out0);
    Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
    for (size_t i = 0; i < reshape_size; ++i) {
      reshape_dims[i] = reshape_dims_vec[i];
    }
    Eigen::DSizes<int, Dims % MAX_RANK_SUPPORTED + 1> reduce_dims;
    for (size_t i = 0; i < reduce_size; ++i) {
      reduce_dims[i] = reduce_dims_vec[i];
    }
    auto out_grad = EigenVector<T>::Flatten(*in0);
    x_grad.device(
        *context.template device_context<DeviceContext>().eigen_device()) =
        out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions());
  }
};

}  // namespace operators
}  // namespace paddle
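
// A minimal sketch (not part of this header) of how these kernels are
// typically registered in the corresponding expand_op.cc, assuming the usual
// Paddle fluid registration macros and the CPU device context:
//
//   namespace ops = paddle::operators;
//   REGISTER_OP_CPU_KERNEL(
//       expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>);
//   REGISTER_OP_CPU_KERNEL(
//       expand_grad,
//       ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, float>);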