/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>  // for max
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

class ElementwiseOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ElementwiseOp");
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ElementwiseOp");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ElementwiseOp");

    PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Y").front(),
                      framework::proto::VarType::LOD_TENSOR,
                      platform::errors::InvalidArgument(
                          "The input var's type should be LoDTensor, but the "
                          "received is %s [%s].",
                          ctx->GetInputsVarType("Y").front(),
                          ctx->Inputs("Y").front()));

    if (ctx->GetInputsVarType("X").front() ==
        framework::proto::VarType::SELECTED_ROWS) {
      PADDLE_ENFORCE_EQ(
          ctx->GetInputDim("Y").size(),
          1u,
          platform::errors::InvalidArgument(
              "For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
              "), Y must be scalar, the size of Y should be 1. "
              "But reveived the size of Y = %s.",
              ctx->GetInputDim("Y").size()));
      PADDLE_ENFORCE_EQ(
          ctx->GetInputDim("Y")[0],
          1,
          platform::errors::InvalidArgument(
              "For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
              "), Y must be scalar, the first dimension of Y should be 1. "
              "But reveived the first dimension of Y = %s.",
              ctx->GetInputDim("Y")[0]));
    } else if (ctx->GetInputsVarType("X").front() !=
               framework::proto::VarType::LOD_TENSOR) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Input X's type[%s] is not supported by elementwise_op. Please set "
          "its type to LOD_TENSOR.",
          ctx->GetInputsVarType("X").front()));
    }

    if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
      ctx->ShareDim("X", /*->*/ "Out");
      ctx->ShareLoD("X", /*->*/ "Out");
    } else {
      auto x_dims = ctx->GetInputDim("X");
      auto y_dims = ctx->GetInputDim("Y");
      int max_dim = std::max(x_dims.size(), y_dims.size());
      int axis = ctx->Attrs().Get<int>("axis");
      if (x_dims.size() == y_dims.size()) {
        PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0),
                          true,
                          platform::errors::InvalidArgument(
                              "axis should be -1 or 0 while the dimension of "
                              "tensor X (%s) is equal to the dimension of "
                              "tensor Y (%s), but received axis: %s",
                              x_dims.size(),
                              y_dims.size(),
                              axis));
      }
      PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim),
                        true,
                        platform::errors::InvalidArgument(
                            "The axis range must be [%s, %s), but axis is %s. "
                            "Please set the axis again.",
                            -1 * max_dim,
                            max_dim,
                            axis));
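      // A negative axis is normalized against the rank difference; a brief
      // worked example (illustrative): rank(X) = 4, rank(Y) = 2, axis = -1
      // gives |4 - 2| + (-1) + 1 = 2, i.e. axis = rank(X) - rank(Y).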
      axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1)
                       : axis);
      std::vector<int> x_dims_array(max_dim);
      std::vector<int> y_dims_array(max_dim);
      std::vector<int> out_dims_array(max_dim);
#ifdef PADDLE_WITH_MKLDNN
      // (jczaja): Broadcasting of dims has to be done on Paddle shapes (NHWC)
      // if the model uses NHWC and any of the shapes is at least 3D
      bool should_rotate =
          ctx->IsRunMKLDNNKernel() &&
          (platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
           framework::DataLayout::kNHWC) &&
          (x_dims.size() >= 3 || y_dims.size() >= 3);
      if (should_rotate) {
        // Pick bigger shape and rotate this one
        bool x_over_y = (x_dims.size() > y_dims.size());
        auto vdims = x_over_y ? phi::vectorize<int>(x_dims)
                              : phi::vectorize<int>(y_dims);
        std::rotate(vdims.begin() + 1, vdims.begin() + 2, vdims.end());
        if (x_over_y) {
          x_dims = phi::make_ddim(vdims);
        } else {
          y_dims = phi::make_ddim(vdims);
        }
      }
#endif

      GetBroadcastDimsArrays(x_dims,
                             y_dims,
                             x_dims_array.data(),
                             y_dims_array.data(),
                             out_dims_array.data(),
                             max_dim,
                             axis);
#ifdef PADDLE_WITH_MKLDNN
      // Now rotate shape back if needed (NHWC -> NCHW)
      if (should_rotate) {
        std::rotate(out_dims_array.begin() + 1,
                    out_dims_array.end() - 1,
                    out_dims_array.end());
      }
#endif
      ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array));
      // TODO
      ctx->ShareLoD("X", /*->*/ "Out");
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type =
        OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote the inputs' types when a complex input is present
      return framework::OpKernelType(
          framework::TransToProtoVarType(tensor.dtype()),
          tensor.place(),
          tensor.layout());
    } else {
#ifdef PADDLE_WITH_MKLDNN
      // When elementwise is the first oneDNN op in the graph (i.e. it was
      // preceded by a non-oneDNN op), the shape also needs to be rotated
      // NHWC -> NCHW
      if ((expected_kernel_type.data_layout_ ==
           framework::DataLayout::kMKLDNN) &&
          (tensor.layout() != framework::DataLayout::kMKLDNN) &&
          paddle::platform::MKLDNNDeviceContext::tls()
                  .get_cur_paddle_data_layout() ==
              framework::DataLayout::kNHWC) {
        return framework::OpKernelType(expected_kernel_type.data_type_,
                                       tensor.place(),
                                       framework::DataLayout::kNHWC);
      }
#endif
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
    }
  }
};

class ElementwiseOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
      const override {
    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
    return m;
  }
};

class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() final {
    AddInputX();
    AddInputY();
    AddOpOutput();

    AddAttr<int>("axis",
                 "(int, default -1). If X.dimension != Y.dimension,"
                 "Y.dimension must be a subsequence of x.dimension. And axis "
                 "is the start dimension index "
                 "for broadcasting Y onto X. ")
        .SetDefault(-1);
    AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.")
        .SetDefault(false)
        .AsExtra();
    AddAttr<std::string>("x_data_format", "This parameter is no longer used.")
        .SetDefault("")
        .AsExtra();
    AddAttr<std::string>("y_data_format", "This parameter is no longer used.")
        .SetDefault("")
        .AsExtra();
    AddAttr<bool>(
        "use_quantizer",
        "(bool, default false) "
        "This parameter is no longer used. Use 'mkldnn_data_type' instead.")
        .SetDefault(false)
        .AsExtra();
    AddAttr<std::string>(
        "mkldnn_data_type",
        "(string, default \"float32\"). Data type of mkldnn kernel")
        .SetDefault("float32")
        .InEnum({"float32", "int8", "bfloat16"})
        .AsExtra();
    /* int8 parameters */
    AddAttr<float>("Scale_x",
                   "(float, default 1.0f), The quantize scale of X tensor")
        .SetDefault(1.0f)
        .AsExtra();
    AddAttr<float>("Scale_y",
                   "(float, default 1.0f), The quantize scale of Y tensor")
        .SetDefault(1.0f)
        .AsExtra();
    AddAttr<float>("Scale_out",
                   "(float, default 1.0f), The quantize scale of output data")
        .SetDefault(1.0f)
        .AsExtra();
    AddOpComment();
  }

 protected:
  virtual void AddInputX() {
    AddInput("X", "(Tensor), The first input tensor of elementwise op.");
  }
  virtual void AddInputY() {
    AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
  }
  virtual void AddOpOutput() {
    AddOutput("Out",
              "N-dimension tensor. A location into which the result is stored. "
              "It's dimension "
              "equals with x");
  }
  virtual void AddOpComment() { AddComment(GetCommentExamples()); }

  virtual std::string GetOpFuntionality() const { return ""; }

  virtual std::string GetName() const = 0;
  virtual std::string GetEquation() const = 0;

  std::string GetCommentExamples() const {
    return string::Sprintf(R"DOC(
Elementwise %s Operator.

%s

The equation is:

$$%s$$

- $X$: a tensor of any dimension.
- $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$.

There are two cases for this operator:

1. The shape of $Y$ is the same as that of $X$.
2. The shape of $Y$ is a continuous subsequence of $X$.

For case 2:

1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
   for broadcasting $Y$ onto $X$.
2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
3. The trailing dimensions of size 1 for $Y$ will be ignored when considering
   the subsequence, such as shape(Y) = (2, 1) => (2).

For example:

  .. code-block:: text

    shape(X) = (2, 3, 4, 5), shape(Y) = (,)
    shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
    shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2
    shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
    shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
    shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0

)DOC",
                           GetName(),
                           GetOpFuntionality(),
                           GetEquation());
  }
};
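
// A minimal sketch of how a concrete op plugs into ElementwiseOpMaker; the
// subclass name and equation below are illustrative, not an actual maker in
// this codebase:
//
//   class ElementwiseFooOpMaker : public ElementwiseOpMaker {
//    protected:
//     std::string GetName() const override { return "Foo"; }
//     std::string GetEquation() const override { return "Out = X \\star Y"; }
//   };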

class ElementwiseOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ElementwiseOpGrad");
    OP_INOUT_CHECK(ctx->HasInput(out_grad_name),
                   "Input",
                   out_grad_name,
                   "ElementwiseOpGrad");
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->ShareDim("X", /*->*/ x_grad_name);
      ctx->ShareLoD("X", /*->*/ x_grad_name);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->ShareDim("Y", /*->*/ y_grad_name);
      ctx->ShareLoD("Y", /*->*/ y_grad_name);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote the inputs' types when a complex input is present
      return framework::OpKernelType(
          framework::TransToProtoVarType(tensor.dtype()),
          tensor.place(),
          tensor.layout());
    } else {
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
    }
  }
};

class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->ShareDim("X", x_grad_name);
      ctx->ShareLoD("X", x_grad_name);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->ShareDim("Y", y_grad_name);
      ctx->ShareLoD("Y", y_grad_name);
    }
    if (ctx->HasOutput("DDOut")) {
      ctx->ShareDim("DOut", "DDOut");
      ctx->ShareLoD("DOut", "DDOut");
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote the inputs' types when a complex input is present
      return framework::OpKernelType(
          framework::TransToProtoVarType(tensor.dtype()),
          tensor.place(),
          tensor.layout());
    } else {
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
    }
  }
};

class ElementwiseOpDoubleGradWithoutDXDY
    : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    if (ctx->HasOutput("DDOut")) {
      ctx->ShareDim("DOut", "DDOut");
      ctx->ShareLoD("DOut", "DDOut");
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::proto::VarType::Type input_data_type;
    if (ctx.HasInput("DDX") == false) {
      OP_INOUT_CHECK(ctx.HasInput("DDY"),
                     "Input",
                     "DDY",
                     "ElementwiseOpDoubleGradWithoutDXDY");
      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDY");
    } else if (ctx.HasInput("DDY") == false) {
      OP_INOUT_CHECK(ctx.HasInput("DDX"),
                     "Input",
                     "DDX",
                     "ElementwiseOpDoubleGradWithoutDXDY");
      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
    } else {
      input_data_type =
          OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY");
    }

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote the inputs' types when a complex input is present
      return framework::OpKernelType(
          framework::TransToProtoVarType(tensor.dtype()),
          tensor.place(),
          tensor.layout());
    } else {
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
    }
  }
};

class ElementwiseOpTripleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    if (ctx->HasOutput("D_DDX")) {
      ctx->ShareDim("DDX", "D_DDX");
      ctx->ShareLoD("DDX", "D_DDX");
    }
    if (ctx->HasOutput("D_DDY")) {
      ctx->ShareDim("DDY", "D_DDY");
      ctx->ShareLoD("DDY", "D_DDY");
    }
    if (ctx->HasOutput("D_X")) {
      ctx->ShareDim("X", "D_X");
      ctx->ShareLoD("X", "D_X");
    }
    if (ctx->HasOutput("D_Y")) {
      ctx->ShareDim("Y", "D_Y");
      ctx->ShareLoD("Y", "D_Y");
    }
    if (ctx->HasOutput("D_DOut")) {
      ctx->ShareDim("DOut", "D_DOut");
      ctx->ShareLoD("DOut", "D_DOut");
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::proto::VarType::Type input_data_type;
    input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "D_DDOut");

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote the inputs' types when a complex input is present
      return framework::OpKernelType(
          framework::TransToProtoVarType(tensor.dtype()),
          tensor.place(),
          tensor.layout());
    } else {
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
    }
  }
};

template <typename T>
class ElemwiseGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *dx =
        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
    auto &dout =
        *context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    phi::funcs::ElementwiseGradPreProcess(dout, dx);
  }
};

DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplaceInferer,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplaceInferer,
                           {"DDX", "DDOut"});

DECLARE_INPLACE_OP_INFERER(ElementwiseTripleGradOpInplaceInferer,
                           {"D_DDOut", "D_DDX"});

DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseGradNoBufVarsInferer, "X", "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseDoubleGradNoBufVarsInferer,
                                    "Y",
                                    "DOut");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseTripleGradNoBufVarsInferer,
                                    "DDX",
                                    "DDY");

}  // namespace operators
}  // namespace paddle
#define REGISTER_ELEMWISE_GRAD_MAKER(kernel_type, op_name)              \
  template <typename T>                                                 \
  class kernel_type##GradMaker                                          \
      : public paddle::framework::SingleGradOpMaker<T> {                \
   public:                                                              \
    using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker; \
                                                                        \
   protected:                                                           \
    void Apply(::paddle::framework::GradOpPtr<T> op) const override {   \
      op->SetType(#kernel_type "_grad");                                \
      op->SetInput("X", this->Input("X"));                              \
      op->SetInput("Y", this->Input("Y"));                              \
      op->SetInput(::paddle::framework::GradVarName("Out"),             \
                   this->OutputGrad("Out"));                            \
      op->SetAttrMap(this->Attrs());                                    \
      op->SetOutput(::paddle::framework::GradVarName("X"),              \
                    this->InputGrad("X"));                              \
      op->SetOutput(::paddle::framework::GradVarName("Y"),              \
                    this->InputGrad("Y"));                              \
    }                                                                   \
  }

#define REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(op_type, op_name)    \
  REGISTER_OPERATOR(op_type,                                            \
                    ::paddle::operators::ElementwiseOp,                 \
                    ::paddle::operators::Elementwise##op_name##OpMaker, \
                    ::paddle::operators::ElementwiseOpInferVarType,     \
                    op_type##GradMaker<::paddle::framework::OpDesc>,    \
                    op_type##GradMaker<::paddle::imperative::OpBase>,   \
                    ::paddle::operators::ElementwiseOpInplaceInferer);
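
// Illustrative usage of the two macros above for a hypothetical
// elementwise_foo op (actual registrations live in the op's .cc file):
//
//   REGISTER_ELEMWISE_GRAD_MAKER(elementwise_foo, Foo);
//   REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_foo, Foo);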