/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/expand_op.h"
#include <memory>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

// Forward operator definition for `expand`.
//
// Shape inference multiplies every dimension of Input(X) by the matching
// entry of Attr(expand_times). When Input(expand_times_tensor) is supplied the
// attribute is not read and every output dimension is reported as -1.
class ExpandOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Validates inputs/attributes and sets the dims (and, when the first
  // dimension is unchanged, the LoD) of Output(Out).
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    // -1 marks "repeat count not known here"; it is kept for every dimension
    // when the counts are provided through expand_times_tensor.
    std::vector<int> expand_times(x_dims.size(), -1);

    if (!ctx->HasInputs("expand_times_tensor")) {
      expand_times = ctx->Attrs().Get<std::vector<int>>("expand_times");
    }

    PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
                      "The number of Attr(expand_times)'s value must be equal "
                      "to the rank of Input(X).");
    // A rank-0 X combined with the default (empty) expand_times would pass the
    // equality check above and then index dimension 0 out of bounds below.
    PADDLE_ENFORCE_GE(x_dims.size(), 1,
                      "The rank of Input(X) must not be less than 1.");
    PADDLE_ENFORCE_LE(x_dims.size(), 6,
                      "The rank of Input(X) must not be greater than 6.");

    std::vector<int64_t> out_shape(x_dims.size());
    for (size_t i = 0; i < expand_times.size(); ++i) {
      if (x_dims[i] == -1 || expand_times[i] == -1) {
        out_shape[i] = -1;
      } else {
        // Zero or negative repeat counts would yield an empty or
        // negative-sized output dimension, so reject them here.
        PADDLE_ENFORCE_GT(
            expand_times[i], 0,
            "Each value of Attr(expand_times) must be greater than 0.");
        out_shape[i] = x_dims[i] * expand_times[i];
      }
    }

    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
    // LoD is only shared when the first dimension is left unchanged.
    if (out_shape[0] == x_dims[0]) {
      ctx->ShareLoD("X", "Out");
    }
  }

  // The kernel is selected from the data type of Input(X) and the current
  // device context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.device_context());
  }

  // For expand_times_tensor the expected kernel type is returned unchanged;
  // every other variable keeps its own place and layout.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "expand_times_tensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

// Declares the inputs, outputs, attributes and user-facing documentation of
// the `expand` operator.
class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Adjacent literals are concatenated by the compiler, so each fragment
    // that ends a sentence carries its own trailing space.
    AddInput("X",
             "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]. "
             "X is the input to be expanded.");
    // Optional list of tensors that can supply the expand times instead of
    // the attribute.
    AddInput("expand_times_tensor", "(Tensor Tensor<int>), expand times for X")
        .AsDuplicable()
        .AsDispensable();
    AddOutput("Out",
              "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]. "
              "The rank of Output(Out) have the same with Input(X). "
              "After expanding, size of each dimension of Output(Out) is equal "
              "to size of the corresponding dimension of Input(X) multiplying "
              "the corresponding value given by Attr(expand_times).");
    AddAttr<std::vector<int>>("expand_times",
                              "Expand times number for each dimension.")
        .SetDefault({});
    AddComment(R"DOC(
Expand operator tiles the input by given times number. You should set times
number for each dimension by providing attribute 'expand_times'. The rank of X
should be in [1, 6]. Please note that size of 'expand_times' must be the same
with X's rank. Following is a using case:

Input(X) is a 3-D tensor with shape [2, 3, 1]:

        [
           [[1], [2], [3]],
           [[4], [5], [6]]
        ]

Attr(expand_times):  [1, 2, 2]

Output(Out) is a 3-D tensor with shape [2, 6, 2]:

        [
            [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
            [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
        ]

)DOC");
  }
};

// Gradient operator definition for `expand`.
//
// Shape inference checks Input(Out@GRAD) against Input(X) scaled by
// Attr(expand_times) and gives Output(X@GRAD) the dims of X.
class ExpandGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Validates the gradient shape and sets the dims of X@GRAD when that output
  // is present.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    std::vector<int> expand_times =
        ctx->Attrs().Get<std::vector<int>>("expand_times");

    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));

    // The loop below indexes both dim vectors with positions taken from
    // expand_times, so the attribute must not be longer than either rank.
    PADDLE_ENFORCE_LE(expand_times.size(), static_cast<size_t>(x_dims.size()),
                      "The number of Attr(expand_times)'s value must not be "
                      "greater than the rank of Input(X).");
    PADDLE_ENFORCE_LE(expand_times.size(),
                      static_cast<size_t>(out_dims.size()),
                      "The number of Attr(expand_times)'s value must not be "
                      "greater than the rank of Input(Out@GRAD).");

    const bool is_runtime = ctx->IsRuntime();

    size_t start_pos = 0u;
    // Outside of runtime a negative first dimension is unknown; it only has to
    // match between X and Out@GRAD and is excluded from the product check.
    if (!is_runtime && x_dims.size() > 0 && out_dims.size() > 0 &&
        x_dims[0] < 0) {
      PADDLE_ENFORCE_EQ(
          x_dims[0], out_dims[0],
          "The first dimension size of Input(Out@GRAD) should be "
          "equal to the crroresponding dimension size of Input(X)");
      start_pos = 1u;
    }

    for (size_t i = start_pos; i < expand_times.size(); ++i) {
      // Outside of runtime any other dimension may be unknown (negative) as
      // well; multiplying such a placeholder by the repeat count is
      // meaningless, so the comparison is skipped for it.
      if (!is_runtime && (x_dims[i] < 0 || out_dims[i] < 0)) {
        continue;
      }
      PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
                        "Each dimension size of Input(Out@GRAD) should be "
                        "equal to multiplication of crroresponding dimension "
                        "size of Input(X) and Attr(expand_times) value.");
    }

    auto x_grad_name = framework::GradVarName("X");

    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
  }

  // The kernel is selected from the data type of Input(Out@GRAD) and the
  // current device context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
        ctx.device_context());
  }

  // For expand_times_tensor the expected kernel type is returned unchanged;
  // every other variable keeps its own place and layout.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "expand_times_tensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("expand_grad");
    op->SetInput("X", Input("X"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
193
    op->SetInput("expand_times_tensor", Input("expand_times_tensor"));
S
sneaxiy 已提交
194 195 196 197 198
    op->SetAttrMap(Attrs());
    return op;
  }
};

199 200
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandGradNoNeedBufVarsInferer, "X");

Y
yangyaming 已提交
201 202 203 204
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Y
Yang Yang 已提交
205
REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker,
S
sneaxiy 已提交
206
                  ops::ExpandGradOpDescMaker);
207 208
REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp,
                  ops::ExpandGradNoNeedBufVarsInferer);
Y
yangyaming 已提交
209
REGISTER_OP_CPU_KERNEL(
210 211 212 213
    expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ExpandKernel<paddle::platform::CPUDeviceContext, double>,
    ops::ExpandKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ExpandKernel<paddle::platform::CPUDeviceContext, bool>);
Q
QI JUN 已提交
214 215
REGISTER_OP_CPU_KERNEL(
    expand_grad,
216 217
    ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, double>);