/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

28 29
// Highest input rank the expand kernels are instantiated for.
#define MAX_RANK_SUPPORTED 6

// Emits one `switch` case for the forward kernel: `case n + 1` calls
// Expand<n + 1>(context). The (z, n, data) signature is the one
// BOOST_PP_REPEAT passes to its macro argument; z and data are unused.
#define EXPAND_TEMPLATE(z, n, data) \
  case n + 1: {                     \
    Expand<n + 1>(context);         \
    break;                          \
  }
// Emits the forward cases for ranks 1..n.
#define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)

// The backward kernel packs two sizes into one switch value n:
//   n / MAX_RANK_SUPPORTED + 1 == number of reshape dimensions
//   n % MAX_RANK_SUPPORTED + 1 == number of reduce dimensions
// COND(n) is 1 when the reshape count is at least the reduce count; only
// those combinations get a case (and therefore a template instantiation).
#define COND(n)                                               \
  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
// Emits one `switch` case for the backward kernel. It expects `context`,
// `reshape_dims_vec` and `reduce_dims_vec` to be in scope at the use site.
#define EXPAND_GRAD_CASE(n)                                        \
  case n: {                                                        \
    ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
    break;                                                         \
  }
// BOOST_PP_REPEAT callback: emits EXPAND_GRAD_CASE(n) when COND(n) holds and
// nothing otherwise.
#define EXPAND_GRAD_TEMPLATE(z, n, data) \
  BOOST_PP_IF(COND(n), EXPAND_GRAD_CASE(n), )
// Emits the backward cases for packed values 0..n-1.
#define REP_EXPAND_GRAD_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_GRAD_TEMPLATE, ~)
Y
yangyaming 已提交
47 48 49 50

namespace paddle {
namespace operators {

Y
yangyaming 已提交
51
// Shorthand for the framework tensor type used by the kernels in this file.
using Tensor = framework::Tensor;

// Alias of framework::EigenVector with row-major layout and
// Eigen::DenseIndex as the default major type / index type.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

// Alias of framework::EigenTensor for a tensor of compile-time rank D, with
// the same defaults as EigenVector above.
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

Q
QI JUN 已提交
59
template <typename DeviceContext, typename T>
Y
yangyaming 已提交
60
class ExpandKernel : public framework::OpKernel<T> {
Y
yangyaming 已提交
61 62
 public:
  void Compute(const framework::ExecutionContext& context) const override {
Y
yangyaming 已提交
63
    auto rank = context.Input<Tensor>("X")->dims().size();
Y
yangyaming 已提交
64
    switch (rank) {
65
      REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED)
Y
yangyaming 已提交
66
      default:
Y
yangyaming 已提交
67 68
        PADDLE_ENFORCE(false,
                       "Only support tensor with rank being between 1 and 6.");
Y
yangyaming 已提交
69
    }
Y
yangyaming 已提交
70 71 72 73 74
  }

 protected:
  template <int Rank>
  void Expand(const framework::ExecutionContext& context) const {
Y
yangyaming 已提交
75
    auto* in0 = context.Input<Tensor>("X");
76
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
Y
yangyaming 已提交
77
    auto* out0 = context.Output<Tensor>("Out");
Y
yangyaming 已提交
78 79 80 81 82 83 84 85
    Eigen::DSizes<int, Rank> bcast_dims;
    auto x_dims = in0->dims();
    for (size_t i = 0; i < expand_times.size(); ++i) {
      bcast_dims[i] = expand_times[i];
    }
    auto x = EigenTensor<T, Rank>::From(*in0);
    out0->mutable_data<T>(context.GetPlace());
    auto y = EigenTensor<T, Rank>::From(*out0);
Q
QI JUN 已提交
86 87
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
Y
yangyaming 已提交
88 89 90 91
    y.device(place) = x.broadcast(bcast_dims);
  }
};

Q
QI JUN 已提交
92
template <typename DeviceContext, typename T>
Y
yangyaming 已提交
93
class ExpandGradKernel : public framework::OpKernel<T> {
Y
yangyaming 已提交
94 95
 public:
  void Compute(const framework::ExecutionContext& context) const override {
Y
yangyaming 已提交
96
    auto* in0 = context.Input<Tensor>("X");
97
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
Y
yangyaming 已提交
98
    auto x_dims = in0->dims();
99 100 101 102 103 104
    // 1. reshape_dims_vec is the broadcast parameter. For each dimension i,
    //    if expand_times[i] > 1 and x_dims[i] > 1, i will be splitted to two
    //    dimensions [expand_times[i], x_dims[i]].
    // 2. reduce_dims_vec is the dimension parameter to compute gradients. For
    //    each dimension expanded, the gradients should be summed to original
    //    size.
Y
yangyaming 已提交
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
    std::vector<int> reshape_dims_vec;
    std::vector<int> reduce_dims_vec;
    for (size_t i = 0; i < expand_times.size(); ++i) {
      if (expand_times[i] == 1) {
        reshape_dims_vec.push_back(x_dims[i]);
      } else {
        if (x_dims[i] == 1) {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
        } else {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
          reshape_dims_vec.push_back(x_dims[i]);
        }
      }
    }

122 123
    int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED +
               reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1;
Y
yangyaming 已提交
124 125
    // no need reduce, just copy
    if (reduce_dims_vec.size() == 0) {
Y
yangyaming 已提交
126 127
      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
Y
yangyaming 已提交
128
      out0->mutable_data<T>(context.GetPlace());
129
      framework::Copy(*in0, context.GetPlace(), context.device_context(), out0);
Y
yangyaming 已提交
130 131 132 133
    } else {
      switch (dims) {
        REP_EXPAND_GRAD_TEMPLATE(72)
        default:
Y
yangyaming 已提交
134 135
          PADDLE_ENFORCE(
              false, "Only support tensor with rank being between 1 and 6.");
Y
yangyaming 已提交
136
      }
Y
yangyaming 已提交
137
    }
Y
yangyaming 已提交
138 139 140 141 142 143 144
  }

 protected:
  template <int Dims>
  void ExpandBackward(const framework::ExecutionContext& context,
                      const std::vector<int>& reshape_dims_vec,
                      const std::vector<int>& reduce_dims_vec) const {
145 146
    size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1;
    size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1;
Y
yangyaming 已提交
147
    PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
Y
yangyaming 已提交
148
                      "Inconsistent size between template Dims and "
Y
yangyaming 已提交
149 150
                      "reshape dimensions.");
    PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
Y
yangyaming 已提交
151
                      "Inconsistent size between template Dims and "
Y
yangyaming 已提交
152
                      "reduce dimensions.");
Y
yangyaming 已提交
153 154 155
    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
Y
yangyaming 已提交
156 157
    out0->mutable_data<T>(context.GetPlace());
    auto x_grad = EigenVector<T>::Flatten(*out0);
158
    Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
Y
yangyaming 已提交
159 160 161
    for (size_t i = 0; i < reshape_size; ++i) {
      reshape_dims[i] = reshape_dims_vec[i];
    }
162
    Eigen::DSizes<int, Dims % MAX_RANK_SUPPORTED + 1> reduce_dims;
Y
yangyaming 已提交
163 164 165 166
    for (size_t i = 0; i < reduce_size; ++i) {
      reduce_dims[i] = reduce_dims_vec[i];
    }
    auto out_grad = EigenVector<T>::Flatten(*in0);
Q
QI JUN 已提交
167 168
    x_grad.device(
        *context.template device_context<DeviceContext>().eigen_device()) =
Y
yangyaming 已提交
169 170 171 172
        out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions());
  }
};

Y
yangyaming 已提交
173 174
}  // namespace operators
}  // namespace paddle