/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class TransposeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Transpose");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Transpose");
    auto x_dims = ctx->GetInputDim("X");
    std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");

    int x_rank = x_dims.size();
    int axis_size = axis.size();

    // Note: x_rank > axis_size when fusing squeeze2 + transpose2; otherwise
    // x_rank == axis_size.
    PADDLE_ENFORCE_GE(x_rank,
                      axis_size,
                      platform::errors::InvalidArgument(
                          "The input tensor's dimension "
                          "should be equal to or greater than the axis's size. "
                          "But received input tensor's dimension is %d, "
                          "axis's size is %d",
                          x_rank,
                          axis_size));

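    // Normalize negative axis values into [0, axis_size) and verify that each
    // value appears exactly once, e.g. axis = [0, -1, 1] with axis_size = 3
    // becomes [0, 2, 1].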
    std::vector<int> formatted_axis = axis;
    std::vector<int> count(axis_size, 0);
    for (int i = 0; i < axis_size; i++) {
      PADDLE_ENFORCE_LT(axis[i],
                        axis_size,
                        platform::errors::InvalidArgument(
                            "The axis value at index %d should be in the "
                            "range [-dimension(X), dimension(X)), where "
                            "dimension = %d. But received axis value %d.",
                            i,
                            axis_size,
                            axis[i]));
      PADDLE_ENFORCE_GE(axis[i],
                        -axis_size,
                        platform::errors::InvalidArgument(
                            "The axis value at index %d should be in the "
                            "range [-dimension(X), dimension(X)), where "
                            "dimension = %d. But received axis value %d.",
                            i,
                            axis_size,
                            axis[i]));

      if (axis[i] < 0) {
        formatted_axis[i] = axis[i] + axis_size;
      }
      PADDLE_ENFORCE_EQ(++count[formatted_axis[i]],
                        1,
                        platform::errors::InvalidArgument(
                            "Each element of axis should be unique. But "
                            "received axis[%d] = %d, which appears more than "
                            "once.",
                            i,
                            axis[i]));
    }

    framework::DDim out_dims(x_dims);
#ifdef PADDLE_WITH_MKLDNN
    // Here we need to match dims to the Paddle layout,
    // as we are producing a non-oneDNN result.
    if (ctx->IsRunMKLDNNKernel() && (x_dims.size() >= 3) &&
        (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
         phi::DataLayout::kNHWC)) {
      auto dims = phi::vectorize<int>(x_dims);
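      // Rotate dims so the axis at index 1 moves to the end,
      // e.g. [N, C, H, W] -> [N, H, W, C].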
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      x_dims = x_dims.reshape(dims);
      VLOG(3)
          << "Rotating Shape in Transpose from: kMKLDNN to: kNHWC output_shape";
    }
#endif
    for (int i = 0; i < axis_size; i++) {
      out_dims[i] = x_dims[formatted_axis[i]];
    }
    ctx->SetOutputDim("Out", out_dims);
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    auto &data_format = ctx.Attr<std::string>("data_format");
    phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
    return phi::KernelKey(
        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
  }
};

class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "X",
        "(Tensor) The input tensor, tensors with rank up to 6 are supported.");
    AddOutput("Out", "(Tensor)The output tensor.");
    AddAttr<std::vector<int>>(
        "axis",
        "(vector<int>) A list of values, and the size of the list should be "
        "the same with the input tensor rank. This operator permutes the input "
        "tensor's axes according to the values given.");
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false)
        .AsExtra();
    AddAttr<std::string>(
        "data_format",
        "(string, default \"AnyLayout\") "
        "An optional string from: \"NHWC\", \"NCHW\". "
        "Specify the data format of the output data; "
        "the input will be transformed automatically.")
        .SetDefault("AnyLayout")
        .AsExtra();
    /* int8 parameters */
    AddAttr<bool>(
        "use_quantizer",
        "(bool, default false) "
        "This parameter is no longer used. Use 'mkldnn_data_type' instead.")
        .SetDefault(false)
        .AsExtra();
    AddAttr<std::string>(
        "mkldnn_data_type",
        "(string, default \"float32\"). Data type of mkldnn kernel")
        .SetDefault("float32")
        .InEnum({"float32", "int8", "bfloat16"})
        .AsExtra();
    AddComment(R"DOC(
Transpose Operator.

The input tensor will be permuted according to the axes given.
The behavior of this operator is similar to how `numpy.transpose` works.

- Suppose the input `X` is a 2-D tensor:
    $$
    X = \begin{pmatrix}
    0 &1 &2 \\
    3 &4 &5
    \end{pmatrix}$$

    the given `axis` is $[1, 0]$, and $Y$ = transpose($X$, axis)

    then the output $Y$ is:

    $$
    Y = \begin{pmatrix}
         0 &3 \\
         1 &4  \\
         2 &5
    \end{pmatrix}$$

- Given an input tensor with shape $(N, C, H, W)$ and the `axis`
$[0, 2, 3, 1]$, the shape of the output tensor will be $(N, H, W, C)$.

)DOC");
  }
};

class TransposeOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    std::string data_format = ctx.Attr<std::string>("data_format");
    phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
    return phi::KernelKey(
        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
  }
};

// FIXME(zcd): transpose2 adds an intermediate output (XShape) on top of
// transpose. XShape carries the shape and LoD of X for use in
// transpose2_grad; this way, the framework can reuse the memory of X as
// soon as transpose2_op finishes.
// Considering compatibility issues, we could not fix transpose2_op.
class Transpose2Op : public TransposeOp {
 public:
  Transpose2Op(const std::string &type,
               const framework::VariableNameMap &inputs,
               const framework::VariableNameMap &outputs,
               const framework::AttributeMap &attrs)
      : TransposeOp(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const override {
    TransposeOp::InferShape(ctx);
    if (!ctx->HasOutput("XShape")) return;
    const auto &in_dims = ctx->GetInputDim("X");
    std::vector<int64_t> x_shape_dim(in_dims.size() + 1);
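    // The leading dimension is a placeholder 0; the real dims of X start at
    // index 1, so the grad op can recover X's shape without keeping X alive.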
    x_shape_dim[0] = 0;
    for (int i = 0; i < in_dims.size(); ++i) {
      x_shape_dim[i + 1] = in_dims[i];
    }
    ctx->SetOutputDim("XShape", phi::make_ddim(x_shape_dim));
    ctx->ShareLoD("X", /*->*/ "XShape");
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::proto::VarType::Type data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "X");
    std::string data_format = ctx.Attr<std::string>("data_format");
    phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
    return phi::KernelKey(
        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
  }
};

class Transpose2OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "X",
        "(Tensor) The input tensor; tensors with rank up to 6 are supported.");
    AddOutput("Out", "(Tensor) The output tensor.");
    AddAttr<std::vector<int>>(
        "axis",
        "(vector<int>) A list of values whose size should be the same as the "
        "input tensor's rank. This operator permutes the input tensor's axes "
        "according to the values given.");
    AddOutput("XShape", "(Tensor)The output tensor.")
        .AsIntermediate()
        .AsExtra();
    AddComment(R"DOC(
Transpose Operator.

The input tensor will be permuted according to the axes given.
The behavior of this operator is similar to how `numpy.transpose` works.

- suppose the input `X` is a 2-D tensor:
    $$
    X = \begin{pmatrix}
    0 &1 &2 \\
    3 &4 &5
    \end{pmatrix}$$

    the given `axes` is: $[1, 0]$, and $Y$ = transpose($X$, axis)

    then the output $Y$ is:

    $$
    Y = \begin{pmatrix}
         0 &3 \\
         1 &4  \\
         2 &5
    \end{pmatrix}$$

- Given a input tensor with shape $(N, C, H, W)$ and the `axes` is
$[0, 2, 3, 1]$, then shape of the output tensor will be: $(N, H, W, C)$.

)DOC");
  }
};

template <typename T>
class Transpose2GradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("transpose2_grad");
    grad_op->SetInput("XShape", this->Output("XShape"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

template <typename T>
class Transpose2DoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
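    // Differentiating transpose2_grad is just the forward transpose again
    // with the same axis: ddX is transposed to produce ddOut, so this maker
    // reuses the transpose2 op directly.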
    grad_op->SetType("transpose2");
    grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
    grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
    grad_op->SetOutput("XShape", this->Input("XShape"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

class Transpose2OpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::proto::VarType::Type data_type =
        OperatorWithKernel::IndicateVarDataType(ctx,
                                                framework::GradVarName("Out"));
    std::string data_format = ctx.Attr<std::string>("data_format");
    phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
    return phi::KernelKey(
        ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
  }
};

class TransposeGradInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
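    // Make the variable type and data type of dX follow those of dOut.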
    ctx->SyncTypeAndDataType(framework::GradVarName("Out"),
                             framework::GradVarName("X"));
  }
};

}  // namespace operators
}  // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(transpose_grad,
                            TransposeGradInferShapeFunctor,
                            PD_INFER_META(phi::TransposeGradInferMeta));

DECLARE_INFER_SHAPE_FUNCTOR(transpose2_grad,
                            Transpose2GradInferShapeFunctor,
                            PD_INFER_META(phi::TransposeGradInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(
    transpose,
    ops::TransposeOp,
    ops::TransposeOpMaker,
    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(transpose_grad,
                  ops::TransposeOpGrad,
361 362
                  ops::TransposeGradInferVarType,
                  TransposeGradInferShapeFunctor);

REGISTER_OPERATOR(transpose2,
                  ops::Transpose2Op,
                  ops::Transpose2OpMaker,
                  ops::Transpose2GradMaker<paddle::framework::OpDesc>,
                  ops::Transpose2GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(transpose2_grad,
                  ops::Transpose2OpGrad,
                  ops::TransposeGradInferVarType,
                  ops::Transpose2DoubleGradMaker<paddle::framework::OpDesc>,
                  ops::Transpose2DoubleGradMaker<paddle::imperative::OpBase>,
                  Transpose2GradInferShapeFunctor);