expand_op.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   You may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

#define EXPAND_TEMPLATE(z, n, data) \
  case n + 1: {                     \
    Expand<n + 1>(context);         \
    break;                          \
  }
#define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)
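// REP_EXPAND_TEMPLATE(n) uses BOOST_PP_REPEAT to emit the switch cases
//   case 1: { Expand<1>(context); break; } ... case n: { Expand<n>(context); break; }
// so the runtime rank of the input can be dispatched to the compile-time
// template parameter of Expand<Rank>.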

#define COND(n) BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, 6), BOOST_PP_MOD(n, 6))
#define EXPAND_GRAD_CASE(n)                                        \
  case n: {                                                        \
    ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
    break;                                                         \
  }
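// The gradient kernel packs the sizes of reshape_dims_vec and reduce_dims_vec
// into a single switch value: dims = (reshape_size - 1) * 6 + (reduce_size - 1).
// Since reduce_size never exceeds reshape_size, only values with
// n / 6 >= n % 6 are reachable; COND(n) checks this, so
// REP_EXPAND_GRAD_TEMPLATE instantiates ExpandBackward only for valid pairs.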
#define EXPAND_GRAD_TEMPLATE(z, n, data) \
  BOOST_PP_IF(COND(n), EXPAND_GRAD_CASE(n), )
#define REP_EXPAND_GRAD_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_GRAD_TEMPLATE, ~)

namespace paddle {
namespace operators {

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

template <typename Place, typename T>
class ExpandKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto rank = context.Input<framework::Tensor>("X")->dims().size();
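    // Dispatch the runtime rank to a compile-time constant; REP_EXPAND_TEMPLATE(6)
    // generates case 1 through case 6, each calling Expand<Rank>.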
    switch (rank) {
      REP_EXPAND_TEMPLATE(6)
      default:
        PADDLE_ENFORCE(false,
                       "Only tensors with a rank between 1 and 6 are supported.");
    };
  }

 protected:
  template <int Rank>
  void Expand(const framework::ExecutionContext& context) const {
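    // Tile the input: output dimension i has size x_dims[i] * expand_times[i].
    // For example, an input of shape [2, 3] with expandTimes = [2, 2] yields
    // an output of shape [4, 6].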
    auto* in0 = context.Input<framework::Tensor>("X");
    auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
    auto* out0 = context.Output<framework::LoDTensor>("Out");
    Eigen::DSizes<int, Rank> bcast_dims;
    auto x_dims = in0->dims();
    for (size_t i = 0; i < expand_times.size(); ++i) {
      bcast_dims[i] = expand_times[i];
    }
    auto x = EigenTensor<T, Rank>::From(*in0);
    out0->mutable_data<T>(context.GetPlace());
    auto y = EigenTensor<T, Rank>::From(*out0);
    auto place = context.GetEigenDevice<Place>();
    y.device(place) = x.broadcast(bcast_dims);
  }
};

template <typename Place, typename T>
class ExpandGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
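    // d(X) is obtained by summing d(Out) over all tiled copies of each
    // element of X.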
    auto* in0 = context.Input<framework::Tensor>("X");
    auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
    auto x_dims = in0->dims();
    std::vector<int> reshape_dims_vec;
    std::vector<int> reduce_dims_vec;
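    // Build the reshape/reduce plan: for every tiled dimension, d(Out) is
    // reshaped so that the expand_times[i] copies get their own axis (split
    // into expand_times[i] x x_dims[i], or just expand_times[i] when
    // x_dims[i] == 1), and that axis is recorded in reduce_dims_vec to be
    // summed away.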
    for (size_t i = 0; i < expand_times.size(); ++i) {
      if (expand_times[i] == 1) {
        reshape_dims_vec.push_back(x_dims[i]);
      } else {
        if (x_dims[i] == 1) {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
        } else {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
          reshape_dims_vec.push_back(x_dims[i]);
        }
      }
    }

    int dims = reshape_dims_vec.size() * 6 + reduce_dims_vec.size() - 7;
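    // dims encodes both vector sizes; ExpandBackward<Dims> decodes them as
    // Dims / 6 + 1 and Dims % 6 + 1.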
    // No reduction is needed; just copy the gradient.
    if (reduce_dims_vec.size() == 0) {
      auto* in0 =
          context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
      auto* out0 =
          context.Output<framework::LoDTensor>(framework::GradVarName("X"));
      out0->mutable_data<T>(context.GetPlace());
      if (platform::is_cpu_place(context.GetPlace())) {
        out0->CopyFrom<T>(*in0, platform::CPUPlace());
      } else {
        out0->CopyFrom<T>(*in0, platform::GPUPlace());
      }
    } else {
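      // Dispatch the encoded value; REP_EXPAND_GRAD_TEMPLATE(72) covers all
      // valid (reshape_size, reduce_size) pairs for ranks 1 to 6.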
      switch (dims) {
        REP_EXPAND_GRAD_TEMPLATE(72)
        default:
          PADDLE_ENFORCE(
              false, "Only tensors with a rank between 1 and 6 are supported.");
      };
    }
  }

 protected:
  template <int Dims>
  void ExpandBackward(const framework::ExecutionContext& context,
                      const std::vector<int>& reshape_dims_vec,
                      const std::vector<int>& reduce_dims_vec) const {
    size_t reshape_size = Dims / 6 + 1;
    size_t reduce_size = Dims % 6 + 1;
    PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
                      "Inconsistent size between template Dims and "
                      "reshape dimensions.");
    PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
                      "Inconsistent size between template Dims and "
                      "reduce dimensions.");
    auto* in0 =
        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto* out0 =
        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
    auto x = EigenVector<T>::Flatten(*(context.Input<framework::Tensor>("X")));
    out0->mutable_data<T>(context.GetPlace());
    auto x_grad = EigenVector<T>::Flatten(*out0);
    Eigen::DSizes<int, Dims / 6 + 1> reshape_dims;
    for (size_t i = 0; i < reshape_size; ++i) {
      reshape_dims[i] = reshape_dims_vec[i];
    }
    Eigen::DSizes<int, Dims % 6 + 1> reduce_dims;
    for (size_t i = 0; i < reduce_size; ++i) {
      reduce_dims[i] = reduce_dims_vec[i];
    }
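    // Reshape d(Out), sum over the tiled axes, and reshape the result back to
    // the shape of X.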
    auto out_grad = EigenVector<T>::Flatten(*in0);
    x_grad.device(context.GetEigenDevice<Place>()) =
        out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions());
  }
};

}  // namespace operators
}  // namespace paddle