/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>  // for max
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

class ElementwiseOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ElementwiseOp");
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ElementwiseOp");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ElementwiseOp");

    PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Y").front(),
                      framework::proto::VarType::LOD_TENSOR,
                      platform::errors::InvalidArgument(
                          "The input var's type should be LoDTensor, but the "
                          "received is %s [%s].",
                          ctx->GetInputsVarType("Y").front(),
                          ctx->Inputs("Y").front()));

    if (ctx->GetInputsVarType("X").front() ==
        framework::proto::VarType::SELECTED_ROWS) {
      PADDLE_ENFORCE_EQ(
          ctx->GetInputDim("Y").size(),
          1u,
          platform::errors::InvalidArgument(
              "For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
              "), Y must be scalar, the size of Y should be 1. "
              "But received the size of Y = %s.",
              ctx->GetInputDim("Y").size()));
      PADDLE_ENFORCE_EQ(
          ctx->GetInputDim("Y")[0],
          1,
          platform::errors::InvalidArgument(
              "For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
              "), Y must be scalar, the first dimension of Y should be 1. "
              "But received the first dimension of Y = %s.",
              ctx->GetInputDim("Y")[0]));
    } else if (ctx->GetInputsVarType("X").front() !=
               framework::proto::VarType::LOD_TENSOR) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Input X's type[%s] is not supported by elementwise_op. Please set "
          "its type to LOD_TENSOR.",
          ctx->GetInputsVarType("X").front()));
    }

    if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
      ctx->ShareDim("X", /*->*/ "Out");
      ctx->ShareLoD("X", /*->*/ "Out");
    } else {
      auto x_dims = ctx->GetInputDim("X");
      auto y_dims = ctx->GetInputDim("Y");
      int max_dim = std::max(x_dims.size(), y_dims.size());
      int axis = ctx->Attrs().Get<int>("axis");
      if (x_dims.size() == y_dims.size()) {
        PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0),
                          true,
                          platform::errors::InvalidArgument(
                              "axis should be -1 or 0 while the dimension of "
                              "tensor X (%s) is equal to the dimension of "
                              "tensor Y (%s), but received axis: %s",
                              x_dims.size(),
                              y_dims.size(),
                              axis));
      }
      PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim),
                        true,
                        platform::errors::InvalidArgument(
                            "The axis range must be [%s, %s), but axis is %s. "
                            "Please set the axis again.",
                            -1 * max_dim,
                            max_dim,
                            axis));
      axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1)
                       : axis);
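      // Worked example (illustrative): with x_dims = (2, 3, 4, 5),
      // y_dims = (4, 5) and axis = -1, the line above rewrites axis to
      // |4 - 2| + (-1) + 1 = 2, i.e. Y is aligned with dimensions 2..3 of X
      // before the broadcast arrays below are filled in.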
      std::vector<int> x_dims_array(max_dim);
      std::vector<int> y_dims_array(max_dim);
      std::vector<int> out_dims_array(max_dim);
#ifdef PADDLE_WITH_MKLDNN
      // (jczaja): Broadcasting of dims has to be done on Paddle shapes (NHWC)
      // when the model is using NHWC and any of the shapes is at least 3D
      bool should_rotate =
          ctx->IsRunMKLDNNKernel() &&
          (platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
           framework::DataLayout::kNHWC) &&
          (x_dims.size() >= 3 || y_dims.size() >= 3);
      if (should_rotate) {
        // Pick bigger shape and rotate this one
        bool x_over_y = (x_dims.size() > y_dims.size());
        auto vdims = x_over_y ? phi::vectorize<int>(x_dims)
                              : phi::vectorize<int>(y_dims);
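        // e.g. (illustrative) a canonical NCHW shape {8, 16, 4, 4} becomes
        // {8, 4, 4, 16}: std::rotate shifts every dim after the batch axis
        // one position to the left, so the channel dim ends up last and the
        // shape lines up with the NHWC data before broadcasting.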
        std::rotate(vdims.begin() + 1, vdims.begin() + 2, vdims.end());
        if (x_over_y) {
          x_dims = phi::make_ddim(vdims);
        } else {
          y_dims = phi::make_ddim(vdims);
        }
      }
#endif

      GetBroadcastDimsArrays(x_dims,
                             y_dims,
                             x_dims_array.data(),
                             y_dims_array.data(),
                             out_dims_array.data(),
                             max_dim,
                             axis);
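      // For example (illustrative): x_dims = (2, 3, 4, 5), y_dims = (4, 5)
      // and axis = 2 yield x_dims_array = {2, 3, 4, 5},
      // y_dims_array = {1, 1, 4, 5} and out_dims_array = {2, 3, 4, 5}.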
#ifdef PADDLE_WITH_MKLDNN
      // Now rotate shape back if needed (NHWC -> NCHW)
      if (should_rotate) {
        std::rotate(out_dims_array.begin() + 1,
                    out_dims_array.end() - 1,
                    out_dims_array.end());
      }
#endif
      ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array));
      // to do
      ctx->ShareLoD("X", /*->*/ "Out");
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type =
        OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote the input's type when the expected data type is complex
      return framework::OpKernelType(
          framework::TransToProtoVarType(tensor.dtype()),
          tensor.place(),
          tensor.layout());
    } else {
#ifdef PADDLE_WITH_MKLDNN
      // When elementwise is the first oneDNN op after a non-oneDNN op, we
      // also need to rotate the shape NHWC -> NCHW.
      if ((expected_kernel_type.data_layout_ ==
           framework::DataLayout::kMKLDNN) &&
          (tensor.layout() != framework::DataLayout::kMKLDNN) &&
          paddle::platform::MKLDNNDeviceContext::tls()
                  .get_cur_paddle_data_layout() ==
              framework::DataLayout::kNHWC) {
        return framework::OpKernelType(expected_kernel_type.data_type_,
                                       tensor.place(),
                                       framework::DataLayout::kNHWC);
      }
#endif
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
    }
  }
};

class ElementwiseOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
      const override {
    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
    return m;
  }
};

class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() final {
    AddInputX();
    AddInputY();
    AddOpOutput();
    AddAttr<int>("axis",
                 "(int, default -1). If X.dimension != Y.dimension, "
                 "Y's dimensions must be a continuous subsequence of "
                 "X's dimensions, and axis is the start dimension index "
                 "for broadcasting Y onto X. ")
        .SetDefault(-1);
    AddOpComment();
  }

 protected:
  virtual void AddInputX() {
    AddInput("X", "(Tensor), The first input tensor of elementwise op.");
  }
  virtual void AddInputY() {
    AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
  }
  virtual void AddOpOutput() {
    AddOutput("Out",
              "N-dimension tensor. A location into which the result is stored. "
              "It's dimension "
              "equals with x");
  }
  virtual void AddOpComment() { AddComment(GetCommentExamples()); }

  virtual std::string GetOpFuntionality() const { return ""; }

  virtual std::string GetName() const = 0;
  virtual std::string GetEquation() const = 0;

  std::string GetCommentExamples() const {
    return string::Sprintf(R"DOC(
Elementwise %s Operator.

%s

The equation is:

$$%s$$

- $X$: a tensor of any dimension.
- $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$.

There are two cases for this operator:

1. The shape of $Y$ is the same as that of $X$.
2. The shape of $Y$ is a continuous subsequence of $X$.

For case 2:

1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
   for broadcasting $Y$ onto $X$.
2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
   subsequence, such as shape(Y) = (2, 1) => (2).

For example:

  .. code-block:: text

    shape(X) = (2, 3, 4, 5), shape(Y) = (,)
    shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
    shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2
    shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
    shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
    shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0

)DOC",
                           GetName(),
                           GetOpFuntionality(),
                           GetEquation());
  }
};
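// Illustrative sketch (not part of this header; class and equation names are
// hypothetical): a maker derived from ElementwiseOpMaker usually only has to
// supply the operator name and the equation used in the generated comment:
//
//   class ElementwiseAddOpMaker : public ElementwiseOpMaker {
//    protected:
//     std::string GetName() const override { return "Add"; }
//     std::string GetEquation() const override { return "Out = X + Y"; }
//   };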

class ElementwiseOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ElementwiseOpGrad");
    OP_INOUT_CHECK(ctx->HasInput(out_grad_name),
                   "Input",
                   out_grad_name,
                   "ElementwiseOpGrad");
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->ShareDim("X", /*->*/ x_grad_name);
      ctx->ShareLoD("X", /*->*/ x_grad_name);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->ShareDim("Y", /*->*/ y_grad_name);
      ctx->ShareLoD("Y", /*->*/ y_grad_name);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote the input's type when the expected data type is complex
      return framework::OpKernelType(
          framework::TransToProtoVarType(tensor.dtype()),
          tensor.place(),
          tensor.layout());
    } else {
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
    }
  }
};

class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->ShareDim("X", x_grad_name);
      ctx->ShareLoD("X", x_grad_name);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->ShareDim("Y", y_grad_name);
      ctx->ShareLoD("Y", y_grad_name);
    }
    if (ctx->HasOutput("DDOut")) {
      ctx->ShareDim("DOut", "DDOut");
      ctx->ShareLoD("DOut", "DDOut");
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote the input's type when the expected data type is complex
      return framework::OpKernelType(
          framework::TransToProtoVarType(tensor.dtype()),
          tensor.place(),
          tensor.layout());
    } else {
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
    }
  }
};

class ElementwiseOpDoubleGradWithoutDXDY
    : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    if (ctx->HasOutput("DDOut")) {
      ctx->ShareDim("DOut", "DDOut");
      ctx->ShareLoD("DOut", "DDOut");
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::proto::VarType::Type input_data_type;
    if (ctx.HasInput("DDX") == false) {
      OP_INOUT_CHECK(ctx.HasInput("DDY"),
                     "Input",
                     "DDY",
                     "ElementwiseOpDoubleGradWithoutDXDY");
      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDY");
    } else if (ctx.HasInput("DDY") == false) {
      OP_INOUT_CHECK(ctx.HasInput("DDX"),
                     "Input",
                     "DDX",
                     "ElementwiseOpDoubleGradWithoutDXDY");
      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
    } else {
      input_data_type =
          OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY");
    }

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote the input's type when the expected data type is complex
      return framework::OpKernelType(
          framework::TransToProtoVarType(tensor.dtype()),
          tensor.place(),
          tensor.layout());
    } else {
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
    }
  }
};

class ElementwiseOpTripleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    if (ctx->HasOutput("D_DDX")) {
      ctx->ShareDim("DDX", "D_DDX");
      ctx->ShareLoD("DDX", "D_DDX");
    }
    if (ctx->HasOutput("D_DDY")) {
      ctx->ShareDim("DDY", "D_DDY");
      ctx->ShareLoD("DDY", "D_DDY");
    }
    if (ctx->HasOutput("D_X")) {
      ctx->ShareDim("X", "D_X");
      ctx->ShareLoD("X", "D_X");
    }
    if (ctx->HasOutput("D_Y")) {
      ctx->ShareDim("Y", "D_Y");
      ctx->ShareLoD("Y", "D_Y");
    }
    if (ctx->HasOutput("D_DOut")) {
      ctx->ShareDim("DOut", "D_DOut");
      ctx->ShareLoD("DOut", "D_DOut");
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::proto::VarType::Type input_data_type;
    input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "D_DDOut");

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type,
                                     ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote the input's type when the expected data type is complex
      return framework::OpKernelType(
          framework::TransToProtoVarType(tensor.dtype()),
          tensor.place(),
          tensor.layout());
    } else {
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
    }
  }
};

template <typename T>
class ElemwiseGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *dx =
        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
    auto &dout =
        *context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    phi::funcs::ElementwiseGradPreProcess(dout, dx);
  }
};

DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplaceInferer,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplaceInferer,
                           {"DDX", "DDOut"});

DECLARE_INPLACE_OP_INFERER(ElementwiseTripleGradOpInplaceInferer,
                           {"D_DDOut", "D_DDX"});

DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseGradNoBufVarsInferer, "X", "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseDoubleGradNoBufVarsInferer,
                                    "Y",
                                    "DOut");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseTripleGradNoBufVarsInferer,
                                    "DDX",
                                    "DDY");

}  // namespace operators
}  // namespace paddle
#define REGISTER_ELEMWISE_GRAD_MAKER(kernel_type, op_name)              \
  template <typename T>                                                 \
  class kernel_type##GradMaker                                          \
      : public paddle::framework::SingleGradOpMaker<T> {                \
   public:                                                              \
    using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker; \
                                                                        \
   protected:                                                           \
    void Apply(::paddle::framework::GradOpPtr<T> op) const override {   \
      op->SetType(#kernel_type "_grad");                                \
      op->SetInput("X", this->Input("X"));                              \
      op->SetInput("Y", this->Input("Y"));                              \
      op->SetInput(::paddle::framework::GradVarName("Out"),             \
                   this->OutputGrad("Out"));                            \
      op->SetAttrMap(this->Attrs());                                    \
      op->SetOutput(::paddle::framework::GradVarName("X"),              \
                    this->InputGrad("X"));                              \
      op->SetOutput(::paddle::framework::GradVarName("Y"),              \
                    this->InputGrad("Y"));                              \
    }                                                                   \
  }
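// Typical usage (illustrative, op names are examples only):
//   REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add);
// generates a templated elementwise_addGradMaker that the registration macro
// below instantiates for both OpDesc (static graph) and OpBase (imperative).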

#define REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(op_type, op_name)    \
  REGISTER_OPERATOR(op_type,                                            \
                    ::paddle::operators::ElementwiseOp,                 \
                    ::paddle::operators::Elementwise##op_name##OpMaker, \
                    ::paddle::operators::ElementwiseOpInferVarType,     \
                    op_type##GradMaker<::paddle::framework::OpDesc>,    \
                    op_type##GradMaker<::paddle::imperative::OpBase>,   \
                    ::paddle::operators::ElementwiseOpInplaceInferer);
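// Typical usage (illustrative, op names are examples only):
//   REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_sub, Sub)
// registers the forward op with the shared ElementwiseOp class, the
// ElementwiseSubOpMaker proto maker, var-type inference, the grad makers
// generated above and the in-place inferer; the matching *_grad operator is
// registered separately.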