expand_op.h 7.0 KB
Newer Older
Y
yangyaming 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   You may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

28 29
// Highest tensor rank for which the expand kernels in this header are
// instantiated. It stays a plain numeric macro because it is passed to
// Boost.Preprocessor arithmetic (BOOST_PP_DIV / BOOST_PP_MOD) and to
// BOOST_PP_REPEAT below.
#define MAX_RANK_SUPPORTED 6

Y
yangyaming 已提交
30 31 32 33 34 35
// BOOST_PP_REPEAT callback for the forward kernel: emits a `case n + 1:` label
// that calls Expand<n + 1>(context) and breaks. A variable named `context`
// must be in scope where the macro is expanded. `z` and `data` are unused.
#define EXPAND_TEMPLATE(z, n, data) \
  case n + 1: {                     \
    Expand<n + 1>(context);         \
    break;                          \
  }
// Expands EXPAND_TEMPLATE for indices 0..n-1, i.e. produces the switch cases
// for ranks 1..n.
#define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)
36 37 38
// Preprocessor predicate: 1 when n / MAX_RANK_SUPPORTED is greater than or
// equal to n % MAX_RANK_SUPPORTED, 0 otherwise. In the dispatch code used by
// ExpandGradKernel the quotient is (reshape rank - 1) and the remainder is
// (reduce rank - 1), so this keeps only codes whose reduce rank does not
// exceed the reshape rank.
#define COND(n)                                               \
  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
Y
yangyaming 已提交
39 40 41 42 43
// Emits a `case n:` label that calls ExpandBackward<n> and breaks. The names
// `context`, `reshape_dims_vec` and `reduce_dims_vec` must be in scope where
// the macro is expanded.
#define EXPAND_GRAD_CASE(n)                                        \
  case n: {                                                        \
    ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
    break;                                                         \
  }
Y
yangyaming 已提交
44
// BOOST_PP_REPEAT callback for the gradient kernel: expands to
// EXPAND_GRAD_CASE(n) when COND(n) is 1 and to nothing otherwise, so that only
// the dispatch codes accepted by COND get a switch case (and a template
// instantiation). `z` and `data` are unused.
#define EXPAND_GRAD_TEMPLATE(z, n, data) \
  BOOST_PP_IF(COND(n), EXPAND_GRAD_CASE(n), )
Y
yangyaming 已提交
46
// Expands EXPAND_GRAD_TEMPLATE for dispatch codes 0..n-1.
#define REP_EXPAND_GRAD_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_GRAD_TEMPLATE, ~)
Y
yangyaming 已提交
47 48 49 50

namespace paddle {
namespace operators {

Y
yangyaming 已提交
51
// Shorthand for framework::Tensor inside paddle::operators.
using Tensor = framework::Tensor;
Y
yangyaming 已提交
52 53 54 55 56 57 58
// Alias for framework::EigenVector that defaults to row-major layout and
// Eigen::DenseIndex indexing.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
// Alias for the rank-D framework::EigenTensor with the same defaults.
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

Q
QI JUN 已提交
59
template <typename DeviceContext, typename T>
Y
yangyaming 已提交
60
class ExpandKernel : public framework::OpKernel<T> {
Y
yangyaming 已提交
61 62
 public:
  void Compute(const framework::ExecutionContext& context) const override {
Y
yangyaming 已提交
63
    auto rank = context.Input<Tensor>("X")->dims().size();
Y
yangyaming 已提交
64
    switch (rank) {
65
      REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED)
Y
yangyaming 已提交
66
      default:
Y
yangyaming 已提交
67 68
        PADDLE_ENFORCE(false,
                       "Only support tensor with rank being between 1 and 6.");
Y
yangyaming 已提交
69
    }
Y
yangyaming 已提交
70 71 72 73 74
  }

 protected:
  template <int Rank>
  void Expand(const framework::ExecutionContext& context) const {
Y
yangyaming 已提交
75
    auto* in0 = context.Input<Tensor>("X");
76
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
Y
yangyaming 已提交
77
    auto* out0 = context.Output<Tensor>("Out");
Y
yangyaming 已提交
78 79 80 81 82 83 84 85
    Eigen::DSizes<int, Rank> bcast_dims;
    auto x_dims = in0->dims();
    for (size_t i = 0; i < expand_times.size(); ++i) {
      bcast_dims[i] = expand_times[i];
    }
    auto x = EigenTensor<T, Rank>::From(*in0);
    out0->mutable_data<T>(context.GetPlace());
    auto y = EigenTensor<T, Rank>::From(*out0);
Q
QI JUN 已提交
86 87
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
Y
yangyaming 已提交
88 89 90 91
    y.device(place) = x.broadcast(bcast_dims);
  }
};

Q
QI JUN 已提交
92
template <typename DeviceContext, typename T>
Y
yangyaming 已提交
93
class ExpandGradKernel : public framework::OpKernel<T> {
Y
yangyaming 已提交
94 95
 public:
  void Compute(const framework::ExecutionContext& context) const override {
Y
yangyaming 已提交
96
    auto* in0 = context.Input<Tensor>("X");
97
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
Y
yangyaming 已提交
98
    auto x_dims = in0->dims();
99 100 101 102 103 104
    // 1. reshape_dims_vec is the broadcast parameter. For each dimension i,
    //    if expand_times[i] > 1 and x_dims[i] > 1, i will be splitted to two
    //    dimensions [expand_times[i], x_dims[i]].
    // 2. reduce_dims_vec is the dimension parameter to compute gradients. For
    //    each dimension expanded, the gradients should be summed to original
    //    size.
Y
yangyaming 已提交
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
    std::vector<int> reshape_dims_vec;
    std::vector<int> reduce_dims_vec;
    for (size_t i = 0; i < expand_times.size(); ++i) {
      if (expand_times[i] == 1) {
        reshape_dims_vec.push_back(x_dims[i]);
      } else {
        if (x_dims[i] == 1) {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
        } else {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
          reshape_dims_vec.push_back(x_dims[i]);
        }
      }
    }

122 123
    int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED +
               reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1;
Y
yangyaming 已提交
124 125
    // no need reduce, just copy
    if (reduce_dims_vec.size() == 0) {
Y
yangyaming 已提交
126 127
      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
Y
yangyaming 已提交
128
      out0->mutable_data<T>(context.GetPlace());
D
dzhwinter 已提交
129 130
      framework::CopyFrom(*in0, context.GetPlace(), context.device_context(),
                          out0);
Y
yangyaming 已提交
131 132 133 134
    } else {
      switch (dims) {
        REP_EXPAND_GRAD_TEMPLATE(72)
        default:
Y
yangyaming 已提交
135 136
          PADDLE_ENFORCE(
              false, "Only support tensor with rank being between 1 and 6.");
Y
yangyaming 已提交
137
      }
Y
yangyaming 已提交
138
    }
Y
yangyaming 已提交
139 140 141 142 143 144 145
  }

 protected:
  template <int Dims>
  void ExpandBackward(const framework::ExecutionContext& context,
                      const std::vector<int>& reshape_dims_vec,
                      const std::vector<int>& reduce_dims_vec) const {
146 147
    size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1;
    size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1;
Y
yangyaming 已提交
148
    PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
Y
yangyaming 已提交
149
                      "Inconsistent size between template Dims and "
Y
yangyaming 已提交
150 151
                      "reshape dimensions.");
    PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
Y
yangyaming 已提交
152
                      "Inconsistent size between template Dims and "
Y
yangyaming 已提交
153
                      "reduce dimensions.");
Y
yangyaming 已提交
154 155 156
    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
Y
yangyaming 已提交
157 158
    out0->mutable_data<T>(context.GetPlace());
    auto x_grad = EigenVector<T>::Flatten(*out0);
159
    Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
Y
yangyaming 已提交
160 161 162
    for (size_t i = 0; i < reshape_size; ++i) {
      reshape_dims[i] = reshape_dims_vec[i];
    }
163
    Eigen::DSizes<int, Dims % MAX_RANK_SUPPORTED + 1> reduce_dims;
Y
yangyaming 已提交
164 165 166 167
    for (size_t i = 0; i < reduce_size; ++i) {
      reduce_dims[i] = reduce_dims_vec[i];
    }
    auto out_grad = EigenVector<T>::Flatten(*in0);
Q
QI JUN 已提交
168 169
    x_grad.device(
        *context.template device_context<DeviceContext>().eigen_device()) =
Y
yangyaming 已提交
170 171 172 173
        out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions());
  }
};

Y
yangyaming 已提交
174 175
}  // namespace operators
}  // namespace paddle