/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
24
class ReshapeKernel : public framework::OpKernel<T> {
Y
Yibing Liu 已提交
25 26
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
Y
Yibing Liu 已提交
27 28
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto* in = ctx.Input<framework::Tensor>("X");
Y
ying 已提交
29

C
caoying03 已提交
30 31
    auto out_dims =
        ValidateShape(ctx.Attr<std::vector<int>>("shape"), in->dims());
Y
Yan Chunwei 已提交
32 33 34 35 36 37 38 39 40
    bool inplace = ctx.Attr<bool>("inplace");
    if (!inplace) {
      out->mutable_data<T>(ctx.GetPlace());
      framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
      out->Resize(out_dims);
    } else {
      out->ShareDataWith(*in);
      out->Resize(out_dims);
    }
Y
Yibing Liu 已提交
41
  }
Y
ying 已提交
42 43

 private:
C
caoying03 已提交
44 45 46 47 48 49 50 51
  framework::DDim ValidateShape(const std::vector<int> shape_attr,
                                const framework::DDim& in_dims) const {
    const int64_t in_size = framework::product(in_dims);
    // only one dimension canbe set to -1, whose size will be automatically
    // infered.
    const int64_t unknown_index = -1;

    std::vector<int64_t> output_shape(shape_attr.size(), 0);
Y
ying 已提交
52
    int64_t capacity = 1;
C
caoying03 已提交
53 54 55 56 57 58
    int neg_dim_idx = -1;
    for (size_t i = 0; i < shape_attr.size(); ++i) {
      if (shape_attr[i] == unknown_index) neg_dim_idx = i;
      capacity *= (shape_attr[i] ? shape_attr[i] : in_dims[i]);
      output_shape[i] =
          (shape_attr[i] ? static_cast<int64_t>(shape_attr[i]) : in_dims[i]);
Y
ying 已提交
59 60
    }

C
caoying03 已提交
61 62 63 64 65 66 67 68
    if (neg_dim_idx != -1) {
      output_shape[neg_dim_idx] = -in_size / capacity;
      PADDLE_ENFORCE_EQ(output_shape[neg_dim_idx] * capacity, -in_size,
                        "Invalid shape is given.");
    } else {
      PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
    }
    return framework::make_ddim(output_shape);
Y
ying 已提交
69
  }
Y
Yibing Liu 已提交
70 71
};

template <typename DeviceContext, typename T>
Y
Yu Yang 已提交
73
class ReshapeGradKernel : public framework::OpKernel<T> {
Y
Yibing Liu 已提交
74 75
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
Y
Yibing Liu 已提交
76 77
    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
C
caoying03 已提交
78

Y
Yibing Liu 已提交
79
    d_x->mutable_data<T>(ctx.GetPlace());
C
caoying03 已提交
80
    bool inplace = ctx.Attr<bool>("inplace");
Y
Yibing Liu 已提交
81 82

    auto in_dims = d_x->dims();
C
caoying03 已提交
83 84 85 86 87 88 89
    if (!inplace) {
      framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
      d_x->Resize(in_dims);
    } else {
      d_x->ShareDataWith(*d_out);
      d_x->Resize(in_dims);
    }
Y
Yibing Liu 已提交
90 91
  }
};
}  // namespace operators
}  // namespace paddle