expand_op.h 7.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
yangyaming 已提交
14 15 16

#pragma once

17 18
#include <vector>

Y
yangyaming 已提交
19 20 21 22 23 24
#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
Y
Yi Wang 已提交
25 26 27
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
Y
yangyaming 已提交
28

29 30
// Highest input rank the forward kernel dispatches on. The gradient kernel
// also uses it as the radix when packing (reshape size, reduce size) into a
// single switch value.
#define MAX_RANK_SUPPORTED 6

Y
yangyaming 已提交
31 32 33 34 35 36
// BOOST_PP_REPEAT callback (z and data are unused): for repetition index n it
// emits `case n + 1:` which calls Expand<n + 1>(context) and breaks. It must be
// expanded inside a switch, in a scope where `context` and `Expand` are
// visible.
#define EXPAND_TEMPLATE(z, n, data) \
  case n + 1: {                     \
    Expand<n + 1>(context);         \
    break;                          \
  }
// Emits n consecutive cases produced by EXPAND_TEMPLATE (indices 0 .. n-1,
// i.e. case labels 1 .. n).
#define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)
37 38 39
// Preprocessor predicate: expands to 1 when
// n / MAX_RANK_SUPPORTED >= n % MAX_RANK_SUPPORTED, otherwise 0.
// With the packing used by the gradient kernel (quotient = reshape size - 1,
// remainder = reduce size - 1) this keeps only the codes whose reshape size is
// at least the reduce size.
#define COND(n)                                               \
  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
Y
yangyaming 已提交
40 41 42 43 44
// Emits `case n:` which calls ExpandBackward<n> and breaks. The expansion site
// must have `context`, `reshape_dims_vec` and `reduce_dims_vec` in scope.
#define EXPAND_GRAD_CASE(n)                                        \
  case n: {                                                        \
    ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
    break;                                                         \
  }
Y
yangyaming 已提交
45
// BOOST_PP_REPEAT callback (z and data are unused): emits the switch case for
// repetition index n only when COND(n) holds, otherwise expands to nothing.
// The macro body must directly follow the line continuation; any text placed
// between the two lines would become the macro body instead.
#define EXPAND_GRAD_TEMPLATE(z, n, data) \
  BOOST_PP_IF(COND(n), EXPAND_GRAD_CASE(n), )
Y
yangyaming 已提交
47
// Emits the gradient switch cases for packed codes 0 .. n-1 (filtered by
// EXPAND_GRAD_TEMPLATE). The gradient kernel in this file invokes it with 72.
#define REP_EXPAND_GRAD_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_GRAD_TEMPLATE, ~)
Y
yangyaming 已提交
48 49 50 51

namespace paddle {
namespace operators {

Y
yangyaming 已提交
52
// Shorthand for the framework tensor type used by the kernels in this file.
using Tensor = framework::Tensor;
Y
yangyaming 已提交
53 54 55 56 57 58 59
// Alias of framework::EigenVector that defaults to row-major layout and
// Eigen::DenseIndex indices.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
// Alias of framework::EigenTensor with compile-time rank D, defaulting to
// row-major layout and Eigen::DenseIndex indices.
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

Q
QI JUN 已提交
60
template <typename DeviceContext, typename T>
Y
yangyaming 已提交
61
class ExpandKernel : public framework::OpKernel<T> {
Y
yangyaming 已提交
62 63
 public:
  void Compute(const framework::ExecutionContext& context) const override {
Y
yangyaming 已提交
64
    auto rank = context.Input<Tensor>("X")->dims().size();
Y
yangyaming 已提交
65
    switch (rank) {
66
      REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED)
Y
yangyaming 已提交
67
      default:
Y
yangyaming 已提交
68 69
        PADDLE_ENFORCE(false,
                       "Only support tensor with rank being between 1 and 6.");
Y
yangyaming 已提交
70
    }
Y
yangyaming 已提交
71 72 73 74 75
  }

 protected:
  template <int Rank>
  void Expand(const framework::ExecutionContext& context) const {
Y
yangyaming 已提交
76
    auto* in0 = context.Input<Tensor>("X");
77
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
Y
yangyaming 已提交
78
    auto* out0 = context.Output<Tensor>("Out");
Y
yangyaming 已提交
79 80 81 82 83 84 85
    Eigen::DSizes<int, Rank> bcast_dims;
    for (size_t i = 0; i < expand_times.size(); ++i) {
      bcast_dims[i] = expand_times[i];
    }
    auto x = EigenTensor<T, Rank>::From(*in0);
    out0->mutable_data<T>(context.GetPlace());
    auto y = EigenTensor<T, Rank>::From(*out0);
Q
QI JUN 已提交
86 87
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
Y
yangyaming 已提交
88 89 90 91
    y.device(place) = x.broadcast(bcast_dims);
  }
};

Q
QI JUN 已提交
92
template <typename DeviceContext, typename T>
Y
yangyaming 已提交
93
class ExpandGradKernel : public framework::OpKernel<T> {
Y
yangyaming 已提交
94 95
 public:
  void Compute(const framework::ExecutionContext& context) const override {
Y
yangyaming 已提交
96
    auto* in0 = context.Input<Tensor>("X");
97
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
Y
yangyaming 已提交
98
    auto x_dims = in0->dims();
99 100 101 102 103 104
    // 1. reshape_dims_vec is the broadcast parameter. For each dimension i,
    //    if expand_times[i] > 1 and x_dims[i] > 1, i will be splitted to two
    //    dimensions [expand_times[i], x_dims[i]].
    // 2. reduce_dims_vec is the dimension parameter to compute gradients. For
    //    each dimension expanded, the gradients should be summed to original
    //    size.
Y
yangyaming 已提交
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
    std::vector<int> reshape_dims_vec;
    std::vector<int> reduce_dims_vec;
    for (size_t i = 0; i < expand_times.size(); ++i) {
      if (expand_times[i] == 1) {
        reshape_dims_vec.push_back(x_dims[i]);
      } else {
        if (x_dims[i] == 1) {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
        } else {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
          reshape_dims_vec.push_back(x_dims[i]);
        }
      }
    }

122 123
    int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED +
               reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1;
Y
yangyaming 已提交
124 125
    // no need reduce, just copy
    if (reduce_dims_vec.size() == 0) {
Y
yangyaming 已提交
126 127
      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
Y
yangyaming 已提交
128
      out0->mutable_data<T>(context.GetPlace());
Y
Yi Wang 已提交
129 130
      framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
                            out0);
Y
yangyaming 已提交
131 132 133 134
    } else {
      switch (dims) {
        REP_EXPAND_GRAD_TEMPLATE(72)
        default:
Y
yangyaming 已提交
135 136
          PADDLE_ENFORCE(
              false, "Only support tensor with rank being between 1 and 6.");
Y
yangyaming 已提交
137
      }
Y
yangyaming 已提交
138
    }
Y
yangyaming 已提交
139 140 141 142 143 144 145
  }

 protected:
  template <int Dims>
  void ExpandBackward(const framework::ExecutionContext& context,
                      const std::vector<int>& reshape_dims_vec,
                      const std::vector<int>& reduce_dims_vec) const {
146 147
    size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1;
    size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1;
Y
yangyaming 已提交
148
    PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
Y
yangyaming 已提交
149
                      "Inconsistent size between template Dims and "
Y
yangyaming 已提交
150 151
                      "reshape dimensions.");
    PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
Y
yangyaming 已提交
152
                      "Inconsistent size between template Dims and "
Y
yangyaming 已提交
153
                      "reduce dimensions.");
Y
yangyaming 已提交
154 155 156
    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
Y
yangyaming 已提交
157 158
    out0->mutable_data<T>(context.GetPlace());
    auto x_grad = EigenVector<T>::Flatten(*out0);
159
    Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
Y
yangyaming 已提交
160 161 162
    for (size_t i = 0; i < reshape_size; ++i) {
      reshape_dims[i] = reshape_dims_vec[i];
    }
163
    Eigen::DSizes<int, Dims % MAX_RANK_SUPPORTED + 1> reduce_dims;
Y
yangyaming 已提交
164 165 166 167
    for (size_t i = 0; i < reduce_size; ++i) {
      reduce_dims[i] = reduce_dims_vec[i];
    }
    auto out_grad = EigenVector<T>::Flatten(*in0);
Q
QI JUN 已提交
168 169
    x_grad.device(
        *context.template device_context<DeviceContext>().eigen_device()) =
Y
yangyaming 已提交
170 171 172 173
        out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions());
  }
};

Y
yangyaming 已提交
174 175
}  // namespace operators
}  // namespace paddle