expand_op.h 6.9 KB
Newer Older
Y
yangyaming 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   You may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

28 29
// Highest input tensor rank the expand kernels are instantiated for.
// This has to stay a plain integer-literal macro (not a constexpr variable):
// it is passed to BOOST_PP_REPEAT / BOOST_PP_DIV / BOOST_PP_MOD below, which
// operate on preprocessor tokens.
#define MAX_RANK_SUPPORTED 6

Y
yangyaming 已提交
30 31 32 33 34 35
// Callback for BOOST_PP_REPEAT: emits one `case` label of the rank switch in
// ExpandKernel::Compute. The repetition index n is zero-based, so the label
// and the template argument are both n + 1. Relies on a variable named
// `context` being in scope where the macro is expanded.
#define EXPAND_TEMPLATE(z, n, data) \
  case n + 1: {                     \
    Expand<n + 1>(context);         \
    break;                          \
  }
// Expands EXPAND_TEMPLATE for indices 0 .. n-1, i.e. produces the cases
// 1 .. n that dispatch to Expand<1> .. Expand<n>.
#define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)
36 37 38
// Preprocessor predicate: 1 when n / MAX_RANK_SUPPORTED >= n % MAX_RANK_SUPPORTED,
// 0 otherwise. ExpandGradKernel::ExpandBackward decodes its template argument
// as (n / MAX_RANK_SUPPORTED + 1) reshape dimensions and
// (n % MAX_RANK_SUPPORTED + 1) reduce dimensions, so this keeps only the
// codes whose reshape count is at least the reduce count.
#define COND(n)                                               \
  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
Y
yangyaming 已提交
39 40 41 42 43
// Emits the `case n:` label of the switch in ExpandGradKernel::Compute that
// forwards to ExpandBackward<n>. Relies on variables named `context`,
// `reshape_dims_vec` and `reduce_dims_vec` being in scope where the macro is
// expanded.
#define EXPAND_GRAD_CASE(n)                                        \
  case n: {                                                        \
    ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
    break;                                                         \
  }
Y
yangyaming 已提交
44
// Callback for BOOST_PP_REPEAT: emits EXPAND_GRAD_CASE(n) when COND(n) holds
// and nothing otherwise, so that ExpandBackward is only instantiated for the
// dimension codes COND accepts. The macro body must directly follow the
// line-continuation backslash.
#define EXPAND_GRAD_TEMPLATE(z, n, data) \
  BOOST_PP_IF(COND(n), EXPAND_GRAD_CASE(n), )
Y
yangyaming 已提交
46
// Expands EXPAND_GRAD_TEMPLATE for indices 0 .. n-1, producing the `case`
// labels of the gradient dispatch switch in ExpandGradKernel::Compute.
#define REP_EXPAND_GRAD_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_GRAD_TEMPLATE, ~)
Y
yangyaming 已提交
47 48 49 50

namespace paddle {
namespace operators {

Y
yangyaming 已提交
51
// Shorthand for the framework tensor type used throughout this header.
using Tensor = framework::Tensor;
Y
yangyaming 已提交
52 53 54 55 56 57 58 59
// Alias of framework::EigenVector that defaults to row-major layout and
// Eigen::DenseIndex indices.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
// Alias of framework::EigenTensor of rank D that defaults to row-major layout
// and Eigen::DenseIndex indices.
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

template <typename Place, typename T>
Y
yangyaming 已提交
60
class ExpandKernel : public framework::OpKernel<T> {
Y
yangyaming 已提交
61 62
 public:
  void Compute(const framework::ExecutionContext& context) const override {
Y
yangyaming 已提交
63
    auto rank = context.Input<Tensor>("X")->dims().size();
Y
yangyaming 已提交
64
    switch (rank) {
65
      REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED)
Y
yangyaming 已提交
66
      default:
Y
yangyaming 已提交
67 68
        PADDLE_ENFORCE(false,
                       "Only support tensor with rank being between 1 and 6.");
Y
yangyaming 已提交
69
    }
Y
yangyaming 已提交
70 71 72 73 74
  }

 protected:
  template <int Rank>
  void Expand(const framework::ExecutionContext& context) const {
Y
yangyaming 已提交
75
    auto* in0 = context.Input<Tensor>("X");
76
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
Y
yangyaming 已提交
77
    auto* out0 = context.Output<Tensor>("Out");
Y
yangyaming 已提交
78 79 80 81 82 83 84 85 86 87 88 89 90 91
    Eigen::DSizes<int, Rank> bcast_dims;
    auto x_dims = in0->dims();
    for (size_t i = 0; i < expand_times.size(); ++i) {
      bcast_dims[i] = expand_times[i];
    }
    auto x = EigenTensor<T, Rank>::From(*in0);
    out0->mutable_data<T>(context.GetPlace());
    auto y = EigenTensor<T, Rank>::From(*out0);
    auto place = context.GetEigenDevice<Place>();
    y.device(place) = x.broadcast(bcast_dims);
  }
};

template <typename Place, typename T>
Y
yangyaming 已提交
92
class ExpandGradKernel : public framework::OpKernel<T> {
Y
yangyaming 已提交
93 94
 public:
  void Compute(const framework::ExecutionContext& context) const override {
Y
yangyaming 已提交
95
    auto* in0 = context.Input<Tensor>("X");
96
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
Y
yangyaming 已提交
97
    auto x_dims = in0->dims();
98 99 100 101 102 103
    // 1. reshape_dims_vec is the broadcast parameter. For each dimension i,
    //    if expand_times[i] > 1 and x_dims[i] > 1, i will be splitted to two
    //    dimensions [expand_times[i], x_dims[i]].
    // 2. reduce_dims_vec is the dimension parameter to compute gradients. For
    //    each dimension expanded, the gradients should be summed to original
    //    size.
Y
yangyaming 已提交
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
    std::vector<int> reshape_dims_vec;
    std::vector<int> reduce_dims_vec;
    for (size_t i = 0; i < expand_times.size(); ++i) {
      if (expand_times[i] == 1) {
        reshape_dims_vec.push_back(x_dims[i]);
      } else {
        if (x_dims[i] == 1) {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
        } else {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
          reshape_dims_vec.push_back(x_dims[i]);
        }
      }
    }

121 122
    int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED +
               reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1;
Y
yangyaming 已提交
123 124
    // no need reduce, just copy
    if (reduce_dims_vec.size() == 0) {
Y
yangyaming 已提交
125 126
      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
Y
yangyaming 已提交
127
      out0->mutable_data<T>(context.GetPlace());
Y
yangyaming 已提交
128
      out0->CopyFrom(*in0, context.GetPlace(), context.device_context());
Y
yangyaming 已提交
129 130 131 132
    } else {
      switch (dims) {
        REP_EXPAND_GRAD_TEMPLATE(72)
        default:
Y
yangyaming 已提交
133 134
          PADDLE_ENFORCE(
              false, "Only support tensor with rank being between 1 and 6.");
Y
yangyaming 已提交
135
      }
Y
yangyaming 已提交
136
    }
Y
yangyaming 已提交
137 138 139 140 141 142 143
  }

 protected:
  template <int Dims>
  void ExpandBackward(const framework::ExecutionContext& context,
                      const std::vector<int>& reshape_dims_vec,
                      const std::vector<int>& reduce_dims_vec) const {
144 145
    size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1;
    size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1;
Y
yangyaming 已提交
146
    PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
Y
yangyaming 已提交
147
                      "Inconsistent size between template Dims and "
Y
yangyaming 已提交
148 149
                      "reshape dimensions.");
    PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
Y
yangyaming 已提交
150
                      "Inconsistent size between template Dims and "
Y
yangyaming 已提交
151
                      "reduce dimensions.");
Y
yangyaming 已提交
152 153 154
    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
Y
yangyaming 已提交
155 156
    out0->mutable_data<T>(context.GetPlace());
    auto x_grad = EigenVector<T>::Flatten(*out0);
157
    Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
Y
yangyaming 已提交
158 159 160
    for (size_t i = 0; i < reshape_size; ++i) {
      reshape_dims[i] = reshape_dims_vec[i];
    }
161
    Eigen::DSizes<int, Dims % MAX_RANK_SUPPORTED + 1> reduce_dims;
Y
yangyaming 已提交
162 163 164 165 166 167 168 169 170
    for (size_t i = 0; i < reduce_size; ++i) {
      reduce_dims[i] = reduce_dims_vec[i];
    }
    auto out_grad = EigenVector<T>::Flatten(*in0);
    x_grad.device(context.GetEigenDevice<Place>()) =
        out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions());
  }
};

Y
yangyaming 已提交
171 172
}  // namespace operators
}  // namespace paddle