/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>
#include <vector>

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class TransposeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Transpose");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Transpose");
    auto x_dims = ctx->GetInputDim("X");
    std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");

    size_t x_rank = x_dims.size();
    size_t axis_size = axis.size();

    // Note: x_rank > axis_size when squeeze2 + transpose2 are fused;
    // otherwise x_rank == axis_size
    PADDLE_ENFORCE_GE(x_rank,
                      axis_size,
                      platform::errors::InvalidArgument(
                          "The input tensor's dimension "
                          "should be equal to or greater than the axis's size. "
                          "But received input tensor's dimension is %d, "
                          "axis's size is %d",
                          x_rank,
                          axis_size));

    std::vector<int> count(axis_size, 0);
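    // Check that "axis" is a valid permutation: every entry must lie in
    // [0, axis_size) and occur exactly once. E.g. {1, 0, 2} is accepted
    // for a rank-3 input, while {0, 0, 2} fails the uniqueness check.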
    for (size_t i = 0; i < axis_size; i++) {
      PADDLE_ENFORCE_GE(axis[i],
                        0,
                        platform::errors::InvalidArgument(
                            "The axis should be greater than or equal to 0. "
                            "But received %d of axis[%d]",
                            axis[i],
                            i));

      PADDLE_ENFORCE_EQ(
          axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1,
          true,
          platform::errors::InvalidArgument(
              "Each element of Attribute axis should be a unique value "
              "in the range [0, dims - 1], where dims is the axis's size; "
              "that is, each axis value can appear only once. "
              "But received axis[%d] is %d, axis_size is %d, "
              "count[axis[%d]] is %d",
              i,
              axis[i],
              axis_size,
              i,
              count[axis[i]]));
    }

    framework::DDim out_dims(x_dims);
#ifdef PADDLE_WITH_MKLDNN
    // Here we need to match dims to the paddle layout
    // as we are producing a non-oneDNN result
    if (ctx->IsRunMKLDNNKernel() && (x_dims.size() >= 3) &&
        (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
         phi::DataLayout::kNHWC)) {
      auto dims = phi::vectorize<int>(x_dims);
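      // Move the channel dimension to the end,
      // e.g. [N, C, H, W] -> [N, H, W, C].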
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      x_dims = x_dims.reshape(dims);
      VLOG(3)
          << "Rotating Shape in Transpose from: kMKLDNN to: kNHWC output_shape";
    }
#endif
    for (size_t i = 0; i < axis_size; i++) {
      out_dims[i] = x_dims[axis[i]];
    }
    ctx->SetOutputDim("Out", out_dims);
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    auto &data_format = ctx.Attr<std::string>("data_format");
    phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
    return phi::KernelKey(
        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
  }
};

class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "X",
        "(Tensor) The input tensor, tensors with rank up to 6 are supported.");
    AddOutput("Out", "(Tensor) The output tensor.");
    AddAttr<std::vector<int>>(
        "axis",
        "(vector<int>) A list of values, and the size of the list should be "
        "the same with the input tensor rank. This operator permutes the input "
        "tensor's axes according to the values given.");
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false)
        .AsExtra();
    AddAttr<std::string>(
        "data_format",
        "(string, default \"AnyLayout\") An optional string from: \"NHWC\", "
        "\"NCHW\". Specify the data format of the output data; "
        "the input will be transformed automatically.")
        .SetDefault("AnyLayout")
        .AsExtra();
    AddAttr<bool>(
        "use_quantizer",
        "(bool, default false) "
        "This parameter is no longer used. Use 'mkldnn_data_type' instead.")
        .SetDefault(false)
        .AsExtra();
    AddAttr<std::string>(
        "mkldnn_data_type",
        "(string, default \"float32\"). Data type of mkldnn kernel")
        .SetDefault("float32")
        .InEnum({"float32", "int8", "bfloat16"})
        .AsExtra();
    /* int8 parameters */
    AddComment(R"DOC(
Transpose Operator.

The input tensor will be permuted according to the axes given.
The behavior of this operator is similar to how `numpy.transpose` works.

- Suppose the input `X` is a 2-D tensor:
    $$
    X = \begin{pmatrix}
    0 &1 &2 \\
    3 &4 &5
    \end{pmatrix}$$

    the given `axis` is $[1, 0]$, and $Y$ = transpose($X$, axis)

    then the output $Y$ is:

    $$
    Y = \begin{pmatrix}
         0 &3 \\
         1 &4  \\
         2 &5
    \end{pmatrix}$$

- Given an input tensor with shape $(N, C, H, W)$ and `axis` $[0, 2, 3, 1]$,
the shape of the output tensor will be $(N, H, W, C)$.

)DOC");
  }
};

class TransposeOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "TransposeOpGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
                   "Input",
                   framework::GradVarName("Out"),
                   "TransposeOpGrad");
    auto x_dims = ctx->GetInputDim("X");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
    }
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    std::string data_format = ctx.Attr<std::string>("data_format");
    phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
    return phi::KernelKey(
        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
  }
};

// FIXME(zcd): transpose2 adds an intermediate output (XShape) based on
// transpose. XShape carries the shape and lod of X, which will be used
// in transpose_grad; this way, the framework can reuse the memory of X
// immediately after transpose2_op finishes.
// Considering compatibility issues, we could not fix transpose2_op.
class Transpose2Op : public TransposeOp {
 public:
  Transpose2Op(const std::string &type,
               const framework::VariableNameMap &inputs,
               const framework::VariableNameMap &outputs,
               const framework::AttributeMap &attrs)
      : TransposeOp(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const override {
    TransposeOp::InferShape(ctx);
    if (!ctx->HasOutput("XShape")) return;
    const auto &in_dims = ctx->GetInputDim("X");
    std::vector<int64_t> x_shape_dim(in_dims.size() + 1);
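    // The leading entry is a placeholder; X's real shape occupies dims
    // 1..rank so that transpose2_grad can recover it from XShape alone.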
    x_shape_dim[0] = 0;
    for (int i = 0; i < in_dims.size(); ++i) {
      x_shape_dim[i + 1] = in_dims[i];
    }
    ctx->SetOutputDim("XShape", phi::make_ddim(x_shape_dim));
    ctx->ShareLoD("X", /*->*/ "XShape");
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::proto::VarType::Type data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "X");
    std::string data_format = ctx.Attr<std::string>("data_format");
    phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
    return phi::KernelKey(
        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
  }
};

class Transpose2OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "X",
        "(Tensor) The input tensor, tensors with rank up to 6 are supported.");
    AddOutput("Out", "(Tensor) The output tensor.");
    AddAttr<std::vector<int>>(
        "axis",
        "(vector<int>) A list of values, and the size of the list should be "
        "the same with the input tensor rank. This operator permutes the input "
        "tensor's axes according to the values given.");
    AddOutput("XShape",
              "(Tensor) An intermediate tensor that carries the shape and "
              "LoD of X for use by transpose2_grad.")
        .AsIntermediate()
        .AsExtra();
    AddComment(R"DOC(
Transpose Operator.

The input tensor will be permuted according to the axes given.
The behavior of this operator is similar to how `numpy.transpose` works.

- Suppose the input `X` is a 2-D tensor:
    $$
    X = \begin{pmatrix}
    0 &1 &2 \\
    3 &4 &5
    \end{pmatrix}$$

    the given `axis` is $[1, 0]$, and $Y$ = transpose($X$, axis)

    then the output $Y$ is:

    $$
    Y = \begin{pmatrix}
         0 &3 \\
         1 &4  \\
         2 &5
    \end{pmatrix}$$

- Given an input tensor with shape $(N, C, H, W)$ and `axis` $[0, 2, 3, 1]$,
the shape of the output tensor will be $(N, H, W, C)$.

)DOC");
  }
};

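// The grad op consumes XShape instead of X: only the shape (and LoD) of X
// is needed to compute dX, so the memory of X can be released early.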
template <typename T>
class Transpose2GradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("transpose2_grad");
    grad_op->SetInput("XShape", this->Output("XShape"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

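// Transpose is linear, so its second-order gradient is another transpose2
// with the same axis: permuting ddX by "axis" yields ddOut.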
template <typename T>
class Transpose2DoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("transpose2");
    grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    grad_op->SetOutput("XShape", this->Input("XShape"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

class Transpose2OpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(
        ctx->HasInput("XShape"), "Input", "XShape", "Transpose2OpGrad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
                   "Input",
                   framework::GradVarName("Out"),
                   "Transpose2OpGrad");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      auto xshape_dim = ctx->GetInputDim("XShape");
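      // Drop the placeholder leading dimension of XShape to recover the
      // shape of X.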
      auto x_shape_dim = phi::slice_ddim(xshape_dim, 1, xshape_dim.size());
      ctx->SetOutputDim(framework::GradVarName("X"), x_shape_dim);
      ctx->ShareLoD("XShape", framework::GradVarName("X"));
    }
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::proto::VarType::Type data_type =
        OperatorWithKernel::IndicateVarDataType(ctx,
                                                framework::GradVarName("Out"));
    std::string data_format = ctx.Attr<std::string>("data_format");
    phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
    return phi::KernelKey(
        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
  }
};

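// Sync dX's variable type (e.g. LoDTensor vs. SelectedRows) and data type
// with those of dOut.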
class TransposeGradInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    ctx->SyncTypeAndDataType(framework::GradVarName("Out"),
                             framework::GradVarName("X"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
    transpose,
    ops::TransposeOp,
    ops::TransposeOpMaker,
    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(transpose_grad,
                  ops::TransposeOpGrad,
                  ops::TransposeGradInferVarType);

REGISTER_OPERATOR(transpose2,
                  ops::Transpose2Op,
                  ops::Transpose2OpMaker,
                  ops::Transpose2GradMaker<paddle::framework::OpDesc>,
                  ops::Transpose2GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(transpose2_grad,
                  ops::Transpose2OpGrad,
                  ops::TransposeGradInferVarType,
                  ops::Transpose2DoubleGradMaker<paddle::framework::OpDesc>,
                  ops::Transpose2DoubleGradMaker<paddle::imperative::OpBase>);