/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>  // for max
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

class ElementwiseOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ElementwiseOp");
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ElementwiseOp");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ElementwiseOp");

    PADDLE_ENFORCE_EQ(
        ctx->GetInputsVarType("Y").front(),
        framework::proto::VarType::LOD_TENSOR,
        platform::errors::InvalidArgument(
            "The input var's type should be phi::DenseTensor, but the "
            "received is %s [%s].",
            ctx->GetInputsVarType("Y").front(),
            ctx->Inputs("Y").front()));

    if (ctx->GetInputsVarType("X").front() ==
        framework::proto::VarType::SELECTED_ROWS) {
      PADDLE_ENFORCE_EQ(
          ctx->GetInputDim("Y").size(),
          1u,
          platform::errors::InvalidArgument(
              "For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
              "), Y must be scalar, the size of Y should be 1. "
              "But received the size of Y = %s.",
              ctx->GetInputDim("Y").size()));
      PADDLE_ENFORCE_EQ(
          ctx->GetInputDim("Y")[0],
          1,
          platform::errors::InvalidArgument(
              "For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
              "), Y must be scalar, the first dimension of Y should be 1. "
              "But received the first dimension of Y = %s.",
              ctx->GetInputDim("Y")[0]));
    } else if (ctx->GetInputsVarType("X").front() !=
               framework::proto::VarType::LOD_TENSOR) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Input X's type[%s] is not supported by elementwise_op. Please set "
          "its type to LOD_TENSOR.",
          ctx->GetInputsVarType("X").front()));
    }

    if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
      ctx->ShareDim("X", /*->*/ "Out");
      ctx->ShareLoD("X", /*->*/ "Out");
    } else {
      auto x_dims = ctx->GetInputDim("X");
      auto y_dims = ctx->GetInputDim("Y");
      int max_dim = std::max(x_dims.size(), y_dims.size());
      int axis = ctx->Attrs().Get<int>("axis");
      if (x_dims.size() == y_dims.size()) {
        PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0),
                          true,
                          platform::errors::InvalidArgument(
                              "axis should be -1 or 0 when the dimension of "
                              "tensor X (%s) is equal to the dimension of "
                              "tensor Y (%s), but received axis: %s",
                              x_dims.size(),
                              y_dims.size(),
                              axis));
      }
      PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim),
                        true,
                        platform::errors::InvalidArgument(
                            "The axis range must be [%s, %s), but axis is %s. "
                            "Please set the axis again.",
                            -1 * max_dim,
                            max_dim,
                            axis));
      axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1)
                       : axis);
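      // e.g. x_dims = (2, 3, 4, 5), y_dims = (4, 5), axis = -1:
      // axis becomes |4 - 2| + (-1) + 1 = 2, so Y aligns with dims 2..3 of X.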
      std::vector<int> x_dims_array(max_dim);
      std::vector<int> y_dims_array(max_dim);
      std::vector<int> out_dims_array(max_dim);
#ifdef PADDLE_WITH_MKLDNN
      // Broadcasting of dims has to be done on Paddle shapes (NHWC)
      // if the model uses NHWC layout and any of the shapes is at least 3-D
      bool should_rotate =
          ctx->IsRunMKLDNNKernel() &&
          (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
           phi::DataLayout::kNHWC) &&
          (x_dims.size() >= 3 || y_dims.size() >= 3);
      if (should_rotate) {
        // Pick bigger shape and rotate this one
        bool x_over_y = (x_dims.size() > y_dims.size());
        auto vdims = x_over_y ? phi::vectorize<int>(x_dims)
                              : phi::vectorize<int>(y_dims);
        std::rotate(vdims.begin() + 1, vdims.begin() + 2, vdims.end());
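        // e.g. canonical NCHW dims (2, 8, 4, 5) become NHWC order
        // (2, 4, 5, 8), matching the layout the data actually has.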
        if (x_over_y) {
          x_dims = phi::make_ddim(vdims);
        } else {
          y_dims = phi::make_ddim(vdims);
        }
      }
#endif

      GetBroadcastDimsArrays(x_dims,
                             y_dims,
                             x_dims_array.data(),
                             y_dims_array.data(),
                             out_dims_array.data(),
                             max_dim,
                             axis);
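      // e.g. x_dims = (2, 3, 4, 5), y_dims = (4, 5), axis = 2 yields
      // out_dims_array = {2, 3, 4, 5}; mismatched dims must be equal or 1.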
#ifdef PADDLE_WITH_MKLDNN
      // Now rotate shape back if needed (NHWC -> NCHW)
      if (should_rotate) {
        std::rotate(out_dims_array.begin() + 1,
                    out_dims_array.end() - 1,
                    out_dims_array.end());
      }
#endif
      ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array));
      // to do
      ctx->ShareLoD("X", /*->*/ "Out");
    }
  }

  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
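    // IndicateOrPromoteVarDataTypes promotes mixed dtypes; e.g. a complex64 X
    // with a float32 Y should yield a complex64 kernel key.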
    auto input_data_type =
        OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
    return phi::KernelKey(input_data_type, ctx.GetPlace());
  }

  phi::KernelKey GetKernelTypeForVar(
      const std::string &var_name UNUSED,
      const phi::DenseTensor &tensor,
      const phi::KernelKey &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.dtype())) {
      // Only promote input data types when a complex input is involved.
      return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
    } else {
#ifdef PADDLE_WITH_MKLDNN
      // When elementwise is the first oneDNN op (some non-oneDNN op ran
      // previously), we also need to rotate the shape NHWC -> NCHW.
      if ((expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
          (tensor.layout() != phi::DataLayout::ONEDNN) &&
          phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
              phi::DataLayout::kNHWC) {
        return phi::KernelKey(tensor.place(),
                              phi::DataLayout::kNHWC,
                              expected_kernel_type.dtype());
      }
#endif
      return phi::KernelKey(
          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
    }
  }
};

class ElementwiseOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
      const override {
    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
    return m;
  }
};

class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() final {
    AddInputX();
    AddInputY();
    AddOpOutput();
    AddAttr<int>("axis",
                 "(int, default -1). If X.dimension != Y.dimension, "
                 "Y's dimensions must be a contiguous subsequence of "
                 "X's dimensions, and axis is the start dimension index "
                 "for broadcasting Y onto X. ")
        .SetDefault(-1);
    AddOpComment();
  }

 protected:
  virtual void AddInputX() {
    AddInput("X", "(Tensor), The first input tensor of elementwise op.");
  }
  virtual void AddInputY() {
    AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
  }
  virtual void AddOpOutput() {
    AddOutput("Out",
              "N-dimension tensor. A location into which the result is stored. "
              "It's dimension "
              "equals with x");
  }
  virtual void AddOpComment() { AddComment(GetCommentExamples()); }

  virtual std::string GetOpFunctionality() const { return ""; }

  virtual std::string GetName() const = 0;
  virtual std::string GetEquation() const = 0;

  std::string GetCommentExamples() const {
    return string::Sprintf(R"DOC(
Elementwise %s Operator.

%s

The equation is:

$$%s$$

- $X$: a tensor of any dimension.
- $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$.

There are two cases for this operator:

1. The shape of $Y$ is the same as that of $X$.
2. The shape of $Y$ is a continuous subsequence of that of $X$.

For case 2:

1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
   for broadcasting $Y$ onto $X$.
2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
3. The trailing dimensions of size 1 for $Y$ will be ignored when deciding
   whether its shape is a subsequence, such as shape(Y) = (2, 1) => (2).

For example:

  .. code-block:: text

    shape(X) = (2, 3, 4, 5), shape(Y) = (,)
    shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
    shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2
    shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
    shape(X) = (2, 3, 4, 5), shape(Y) = (2,), with axis=0
    shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0

)DOC",
                           GetName(),
                           GetOpFunctionality(),
                           GetEquation());
  }
};
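
// Illustrative sketch (hypothetical, not from this header): a concrete maker
// only needs to override the two pure virtuals, e.g. for an "add" operator:
//
//   class ElementwiseAddOpMaker : public ElementwiseOpMaker {
//    protected:
//     std::string GetName() const override { return "Add"; }
//     std::string GetEquation() const override { return "Out = X + Y"; }
//   };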

class ElementwiseOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ElementwiseOpGrad");
    OP_INOUT_CHECK(ctx->HasInput(out_grad_name),
                   "Input",
                   out_grad_name,
                   "ElementwiseOpGrad");
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->ShareDim("X", /*->*/ x_grad_name);
      ctx->ShareLoD("X", /*->*/ x_grad_name);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->ShareDim("Y", /*->*/ y_grad_name);
      ctx->ShareLoD("Y", /*->*/ y_grad_name);
    }
  }

  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    return phi::KernelKey(input_data_type, ctx.GetPlace());
  }

  phi::KernelKey GetKernelTypeForVar(
      const std::string &var_name UNUSED,
      const phi::DenseTensor &tensor,
      const phi::KernelKey &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.dtype())) {
      // Only promote input data types when a complex input is involved.
      return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
    } else {
      return phi::KernelKey(
          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
    }
  }
};

class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->ShareDim("X", x_grad_name);
      ctx->ShareLoD("X", x_grad_name);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->ShareDim("Y", y_grad_name);
      ctx->ShareLoD("Y", y_grad_name);
    }
    if (ctx->HasOutput("DDOut")) {
      ctx->ShareDim("DOut", "DDOut");
      ctx->ShareLoD("DOut", "DDOut");
    }
  }

  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
    return phi::KernelKey(input_data_type, ctx.GetPlace());
  }

  phi::KernelKey GetKernelTypeForVar(
      const std::string &var_name UNUSED,
      const phi::DenseTensor &tensor,
      const phi::KernelKey &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.dtype())) {
      // Only promote input data types when a complex input is involved.
      return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
    } else {
      return phi::KernelKey(
          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
    }
  }
};

class ElementwiseOpDoubleGradWithoutDXDY
    : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    if (ctx->HasOutput("DDOut")) {
      ctx->ShareDim("DOut", "DDOut");
      ctx->ShareLoD("DOut", "DDOut");
    }
  }

  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
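    // Pick the kernel dtype from whichever of DDX / DDY is present; when both
    // exist, promote their dtypes (relevant for complex gradients).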
    framework::proto::VarType::Type input_data_type;
    if (ctx.HasInput("DDX") == false) {
      OP_INOUT_CHECK(ctx.HasInput("DDY"),
                     "Input",
                     "DDY",
                     "ElementwiseOpDoubleGradWithoutDXDY");
      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDY");
    } else if (ctx.HasInput("DDY") == false) {
      OP_INOUT_CHECK(ctx.HasInput("DDX"),
                     "Input",
                     "DDX",
                     "ElementwiseOpDoubleGradWithoutDXDY");
      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
    } else {
      input_data_type =
          OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY");
    }
    return phi::KernelKey(input_data_type, ctx.GetPlace());
  }

  phi::KernelKey GetKernelTypeForVar(
      const std::string &var_name UNUSED,
      const phi::DenseTensor &tensor,
      const phi::KernelKey &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.dtype())) {
      // Only promote input data types when a complex input is involved.
      return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
    } else {
      return phi::KernelKey(
          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
    }
  }
};

class ElementwiseOpTripleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    if (ctx->HasOutput("D_DDX")) {
      ctx->ShareDim("DDX", "D_DDX");
      ctx->ShareLoD("DDX", "D_DDX");
    }
    if (ctx->HasOutput("D_DDY")) {
      ctx->ShareDim("DDY", "D_DDY");
      ctx->ShareLoD("DDY", "D_DDY");
    }
    if (ctx->HasOutput("D_X")) {
      ctx->ShareDim("X", "D_X");
      ctx->ShareLoD("X", "D_X");
    }
    if (ctx->HasOutput("D_Y")) {
      ctx->ShareDim("Y", "D_Y");
      ctx->ShareLoD("Y", "D_Y");
    }
    if (ctx->HasOutput("D_DOut")) {
      ctx->ShareDim("DOut", "D_DOut");
      ctx->ShareLoD("DOut", "D_DOut");
    }
  }

  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::proto::VarType::Type input_data_type;
    input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "D_DDOut");
    return phi::KernelKey(input_data_type, ctx.GetPlace());
  }

  phi::KernelKey GetKernelTypeForVar(
      const std::string &var_name UNUSED,
      const phi::DenseTensor &tensor,
      const phi::KernelKey &expected_kernel_type) const override {
    if (framework::IsComplexType(expected_kernel_type.dtype())) {
      // Only promote input data types when a complex input is involved.
      return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
    } else {
      return phi::KernelKey(
          tensor.place(), tensor.layout(), expected_kernel_type.dtype());
    }
  }
};

template <typename T>
class ElemwiseGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *dx = context.Output<phi::DenseTensor>(framework::GradVarName("X"));
    auto &dout =
        *context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
    phi::funcs::ElementwiseGradPreProcess(dout, dx);
  }
};

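// Each in-place inferer below marks an output that may reuse the buffer of
// its paired input (e.g. Out may share X's memory in the forward op).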
DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplaceInferer,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplaceInferer,
                           {"DDX", "DDOut"});

DECLARE_INPLACE_OP_INFERER(ElementwiseTripleGradOpInplaceInferer,
                           {"D_DDOut", "D_DDX"});

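// The no-need-buffer inferers mark inputs whose data is never read by the
// grad kernels (only their shapes are needed), letting memory be reclaimed.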
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseGradNoBufVarsInferer, "X", "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseDoubleGradNoBufVarsInferer,
                                    "Y",
                                    "DOut");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseTripleGradNoBufVarsInferer,
                                    "DDX",
                                    "DDY");

}  // namespace operators
}  // namespace paddle
#define REGISTER_ELEMWISE_GRAD_MAKER(kernel_type, op_name)              \
  template <typename T>                                                 \
  class kernel_type##GradMaker                                          \
      : public paddle::framework::SingleGradOpMaker<T> {                \
   public:                                                              \
    using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker; \
                                                                        \
   protected:                                                           \
    void Apply(::paddle::framework::GradOpPtr<T> op) const override {   \
      op->SetType(#kernel_type "_grad");                                \
      op->SetInput("X", this->Input("X"));                              \
      op->SetInput("Y", this->Input("Y"));                              \
      op->SetInput(::paddle::framework::GradVarName("Out"),             \
                   this->OutputGrad("Out"));                            \
      op->SetAttrMap(this->Attrs());                                    \
      op->SetOutput(::paddle::framework::GradVarName("X"),              \
                    this->InputGrad("X"));                              \
      op->SetOutput(::paddle::framework::GradVarName("Y"),              \
                    this->InputGrad("Y"));                              \
    }                                                                   \
  }
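// Illustrative usage (hypothetical): REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add)
// expands to an elementwise_addGradMaker whose Apply() emits an
// "elementwise_add_grad" op consuming X, Y and the gradient of Out.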

#define REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(op_type, op_name)    \
  REGISTER_OPERATOR(op_type,                                            \
                    ::paddle::operators::ElementwiseOp,                 \
                    ::paddle::operators::Elementwise##op_name##OpMaker, \
                    ::paddle::operators::ElementwiseOpInferVarType,     \
                    op_type##GradMaker<::paddle::framework::OpDesc>,    \
                    op_type##GradMaker<::paddle::imperative::OpBase>,   \
                    ::paddle::operators::ElementwiseOpInplaceInferer);
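
// Illustrative usage (hypothetical): for a subtraction operator,
//   REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_sub, Sub);
// registers the op together with its maker, var-type inferer, grad makers,
// and the in-place inferer declared above.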