pad_op.h 4.3 KB
Newer Older
W
wanghaoshuang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

Q
QI JUN 已提交
29
template <typename DeviceContext, typename T, size_t D>
W
wanghaoshuang 已提交
30
void PadFunction(const framework::ExecutionContext& context) {
W
wanghaoshuang 已提交
31
  auto pads = context.Attr<std::vector<int>>("paddings");
W
wanghaoshuang 已提交
32
  Eigen::array<std::pair<int, int>, D> paddings;
W
wanghaoshuang 已提交
33
  for (size_t i = 0; i < paddings.size(); ++i) {
W
wanghaoshuang 已提交
34 35
    paddings[i].first = pads[i * 2];
    paddings[i].second = pads[i * 2 + 1];
W
wanghaoshuang 已提交
36
  }
W
wanghaoshuang 已提交
37
  T pad_value = context.Attr<T>("pad_value");
W
wanghaoshuang 已提交
38

W
wanghaoshuang 已提交
39 40 41
  auto* x = context.Input<Tensor>("X");
  auto* out = context.Output<Tensor>("Out");
  out->mutable_data<T>(context.GetPlace());
W
wanghaoshuang 已提交
42

W
wanghaoshuang 已提交
43 44
  auto x_tensor = EigenTensor<T, D>::From(*x);
  auto out_tensor = EigenTensor<T, D>::From(*out);
Q
QI JUN 已提交
45 46
  auto& place =
      *context.template device_context<DeviceContext>().eigen_device();
W
wanghaoshuang 已提交
47
  out_tensor.device(place) = x_tensor.pad(paddings, pad_value);
W
wanghaoshuang 已提交
48 49
}

Q
QI JUN 已提交
50
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
51
class PadKernel : public framework::OpKernel<T> {
W
wanghaoshuang 已提交
52 53
 public:
  void Compute(const framework::ExecutionContext& context) const override {
W
wanghaoshuang 已提交
54 55
    int rank = context.Input<Tensor>("X")->dims().size();
    switch (rank) {
W
wanghaoshuang 已提交
56
      case 1:
Q
QI JUN 已提交
57
        PadFunction<DeviceContext, T, 1>(context);
W
wanghaoshuang 已提交
58 59
        break;
      case 2:
Q
QI JUN 已提交
60
        PadFunction<DeviceContext, T, 2>(context);
W
wanghaoshuang 已提交
61 62
        break;
      case 3:
Q
QI JUN 已提交
63
        PadFunction<DeviceContext, T, 3>(context);
W
wanghaoshuang 已提交
64 65
        break;
      case 4:
Q
QI JUN 已提交
66
        PadFunction<DeviceContext, T, 4>(context);
W
wanghaoshuang 已提交
67 68
        break;
      case 5:
Q
QI JUN 已提交
69
        PadFunction<DeviceContext, T, 5>(context);
W
wanghaoshuang 已提交
70 71
        break;
      case 6:
Q
QI JUN 已提交
72
        PadFunction<DeviceContext, T, 6>(context);
W
wanghaoshuang 已提交
73 74
        break;
      default:
W
wanghaoshuang 已提交
75 76
        PADDLE_THROW(
            "PadOp only support tensors with no more than 6 dimensions.");
W
wanghaoshuang 已提交
77
    }
W
wanghaoshuang 已提交
78 79 80
  }
};

Q
QI JUN 已提交
81
template <typename DeviceContext, typename T, size_t D>
W
wanghaoshuang 已提交
82
void PadGradFunction(const framework::ExecutionContext& context) {
W
wanghaoshuang 已提交
83
  auto pads = context.Attr<std::vector<int>>("paddings");
W
wanghaoshuang 已提交
84
  Eigen::array<std::pair<int, int>, D> paddings;
W
wanghaoshuang 已提交
85
  for (size_t i = 0; i < paddings.size(); ++i) {
W
wanghaoshuang 已提交
86 87
    paddings[i].first = -pads[i * 2];
    paddings[i].second = -pads[i * 2 + 1];
W
wanghaoshuang 已提交
88
  }
W
wanghaoshuang 已提交
89 90
  auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
  auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
W
wanghaoshuang 已提交
91 92 93 94
  if (d_x != nullptr) {
    d_x->mutable_data<T>(context.GetPlace());
    auto d_x_tensor = EigenTensor<T, D>::From(*d_x);
    auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
Q
QI JUN 已提交
95 96
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();
W
wanghaoshuang 已提交
97 98
    d_x_tensor.device(place) = d_out_tensor.pad(paddings, 0);
  }
W
wanghaoshuang 已提交
99 100
}

Q
QI JUN 已提交
101
template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
102
class PadGradKernel : public framework::OpKernel<T> {
W
wanghaoshuang 已提交
103
 public:
W
wanghaoshuang 已提交
104
  void Compute(const framework::ExecutionContext& context) const override {
W
wanghaoshuang 已提交
105
    size_t rank =
W
wanghaoshuang 已提交
106
        context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
W
wanghaoshuang 已提交
107
    switch (rank) {
W
wanghaoshuang 已提交
108
      case 1:
Q
QI JUN 已提交
109
        PadGradFunction<DeviceContext, T, 1>(context);
W
wanghaoshuang 已提交
110 111
        break;
      case 2:
Q
QI JUN 已提交
112
        PadGradFunction<DeviceContext, T, 2>(context);
W
wanghaoshuang 已提交
113 114
        break;
      case 3:
Q
QI JUN 已提交
115
        PadGradFunction<DeviceContext, T, 3>(context);
W
wanghaoshuang 已提交
116 117
        break;
      case 4:
Q
QI JUN 已提交
118
        PadGradFunction<DeviceContext, T, 4>(context);
W
wanghaoshuang 已提交
119 120
        break;
      case 5:
Q
QI JUN 已提交
121
        PadGradFunction<DeviceContext, T, 5>(context);
W
wanghaoshuang 已提交
122 123
        break;
      case 6:
Q
QI JUN 已提交
124
        PadGradFunction<DeviceContext, T, 6>(context);
W
wanghaoshuang 已提交
125 126
        break;
      default:
W
wanghaoshuang 已提交
127 128
        PADDLE_THROW(
            "PadOp only support tensors with no more than 6 dimensions.");
W
wanghaoshuang 已提交
129 130 131 132 133 134
    }
  }
};

}  // namespace operators
}  // namespace paddle