/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h"
#include "paddle/fluid/operators/math/functors.h"

namespace paddle {
namespace operators {

/**
 * Whether the compound function is Unary(Binary(X, Y)).
 * For Unary(Binary(X, Y)), the shape of intermediate_out is the same as that
 * of the final out.
 */
bool IsUnaryCompound(const std::vector<std::string> &functor_list);

/**
 * Whether the compound function contains an in-place unary functor. For an
 * in-place unary functor, the inputs of op_desc contain only Out and
 * Out@Grad.
 */
bool HasInPlaceUnary(const std::vector<std::string> &functor_list);

/**
 * Whether Input(X) can be absent.
 */
bool InputXCanBeAbsent(const std::vector<std::string> &functor_list);
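
// Example functor_list values, sketched from how these predicates are used
// below (that relu/relu_grad are the in-place unary functors is an assumption
// not spelled out in this header):
//   {"scale", "elementwise_add"}          -> Unary(Binary(X, Y)), so
//                                            IsUnaryCompound returns true.
//   {"elementwise_add", "relu"}           -> Binary(X, Unary(Y)), and relu is
//                                            the in-place unary case probed
//                                            by HasInPlaceUnary.
//   {"relu_grad", "elementwise_add_grad"} -> contains elementwise_add_grad,
//                                            so InputXCanBeAbsent returns
//                                            true.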

template <typename DeviceContext, typename T, typename BinaryFunctor,
          typename UnaryFunctor>
static void RunBinaryCompoundFunctor(
    const framework::ExecutionContext &ctx, const BinaryFunctor &binary_functor,
    const UnaryFunctor &unary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Binary(X, Unary(Y))
  // intermediate_out = Unary(Y)
  // out = Binary(X, Unary(Y))
  // In this case, the shapes of intermediate_out and out may differ.
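  // For example (hypothetical shapes): if X is [32, 64] and Y is [64] with
  // axis = 1, then intermediate_out = Unary(Y) is [64] while out is [32, 64].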
  paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
      compound_func(binary_functor, unary_functor);
  int axis = ctx.Attr<int>("axis");
  if (ctx.Attr<bool>("save_intermediate_out")) {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::BinaryCompoundFunctor<
                                     T, BinaryFunctor, UnaryFunctor>,
                                 true /*KeepIntermediateValue*/,
                                 false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::BinaryCompoundFunctor<
                                     T, BinaryFunctor, UnaryFunctor>,
                                 false /*KeepIntermediateValue*/,
                                 false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
}

template <typename DeviceContext, typename T, typename UnaryFunctor,
          typename BinaryFunctor>
static void RunUnaryCompoundFunctors(
    const framework::ExecutionContext &ctx, const UnaryFunctor &unary_functor,
    const BinaryFunctor &binary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Unary(Binary(X, Y))
  // intermediate_out = Binary(X, Y)
  // out = Unary(Binary(X, Y))
  // In this case, the shapes of intermediate_out and out are the same.
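  // For example (hypothetical shapes): if X is [32, 64] and Y is [64] with
  // axis = 1, then intermediate_out = Binary(X, Y) and out are both [32, 64].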
  int axis = ctx.Attr<int>("axis");

  paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
      compound_func(unary_functor, binary_functor);

  if (ctx.Attr<bool>("save_intermediate_out")) {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::UnaryCompoundFunctor<
                                     T, UnaryFunctor, BinaryFunctor>,
                                 true /*KeepIntermediateValue*/,
                                 true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::UnaryCompoundFunctor<
                                     T, UnaryFunctor, BinaryFunctor>,
                                 false /*KeepIntermediateValue*/,
                                 true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
}

template <typename DeviceContext, typename T, typename BinaryGradFunctor,
          typename UnaryFunctor, typename UnaryGradFunctor, bool InPlace>
static void RunBinaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const BinaryGradFunctor &binary_grad_functor,
    const UnaryFunctor &unary_functor,
    const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  // Z = Binary(X, Unary(Y))
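  // A sketch of the chain rule that the three grad functors below implement,
  // writing U = Unary(Y) and dZ = in_out_grad:
  //   dX             = dZ * dBinary(X, U)/dX
  //   d_intermediate = dZ * dBinary(X, U)/dU
  //   dY             = d_intermediate * dUnary(Y)/dY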
  int axis = ctx.Attr<int>("axis");

  using BinaryCompoundDxFunctor =
      paddle::operators::math::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
                                                           UnaryFunctor>;
  using BinaryCompoundDyFunctor =
      paddle::operators::math::BinaryCompoundGradDyFunctor<
          T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
  using BinaryCompoundDIntermediateOutFunctor =
      paddle::operators::math::BinaryCompoundGradDIntermedaiteOutFunctor<
          T, BinaryGradFunctor, UnaryFunctor>;

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        BinaryCompoundDIntermediateOutFunctor, true /*UseIntermediateOut*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
                                unary_grad_functor),
        BinaryCompoundDIntermediateOutFunctor(binary_grad_functor,
                                              unary_functor));
  } else {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        BinaryCompoundDIntermediateOutFunctor, false /*UseIntermediateOut*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
                                unary_grad_functor),
        BinaryCompoundDIntermediateOutFunctor(binary_grad_functor,
                                              unary_functor));
  }
}

template <typename DeviceContext, typename T, typename UnaryGradFunctor,
          typename BinaryFunctor, typename BinaryGradFunctor, bool InPlace>
static void RunUnaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const UnaryGradFunctor &unary_grad_functor,
    const BinaryFunctor &binary_functor,
    const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  // Z = Unary(Binary(X, Y))
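  // A sketch of the chain rule that the three grad functors below implement,
  // writing B = Binary(X, Y) and dZ = in_out_grad:
  //   d_intermediate = dZ * dUnary(B)/dB
  //   dX             = d_intermediate * dBinary(X, Y)/dX
  //   dY             = d_intermediate * dBinary(X, Y)/dY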
  int axis = ctx.Attr<int>("axis");

  using UnaryCompoundDxFunctor =
      paddle::operators::math::UnaryCompoundGradDxFunctor<
          T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
  using UnaryCompoundDyFunctor =
      paddle::operators::math::UnaryCompoundGradDyFunctor<
          T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
  using UnaryCompoundDIntermediateFunctor =
      paddle::operators::math::UnaryCompoundGradDIntermediateFunctor<
          T, UnaryGradFunctor, BinaryFunctor, InPlace>;

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
        UnaryCompoundDIntermediateFunctor, true /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
  } else {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
        UnaryCompoundDIntermediateFunctor, false /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
  }
}

template <typename DeviceContext, typename T>
static void RunFunctors(const framework::ExecutionContext &ctx,
                        const framework::Tensor &in_x,
                        const framework::Tensor &in_y,
                        std::vector<framework::Tensor *> *outputs) {
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");

  // TODO(zcd): The following code can be refined.
  auto funcs_str = functors[0] + "," + functors[1];
  if (funcs_str == "elementwise_add,scale") {
    // Z = Binary(X, Unary(Y))
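    // Here out = X + scale * Y and intermediate_out = scale * Y.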
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::AddFunctor<T>,
                             paddle::operators::math::ScaleFunctor<T>>(
        ctx, paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
  } else if (funcs_str == "scale,elementwise_add") {
    // Z = Unary(Binary(X, Y))
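    // Here out = scale * (X + Y) and intermediate_out = X + Y.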
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundFunctors<DeviceContext, T,
                             paddle::operators::math::ScaleFunctor<T>,
                             paddle::operators::math::AddFunctor<T>>(
        ctx, paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_add,relu") {
    // Z = Binary(X, Unary(Y))
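    // Here out = X + relu(Y) and intermediate_out = relu(Y).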
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::AddFunctor<T>,
                             paddle::operators::math::ReluFunctor<T>>(
        ctx, paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::ReluFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "relu,elementwise_add") {
    // Z = Unary(Binary(X, Y))
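    // Here out = relu(X + Y) and intermediate_out = X + Y.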
    RunUnaryCompoundFunctors<DeviceContext, T,
                             paddle::operators::math::ReluFunctor<T>,
                             paddle::operators::math::AddFunctor<T>>(
        ctx, paddle::operators::math::ReluFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,scale") {
    // Z = Binary(X, Unary(Y))
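    // Here out = X * (scale * Y) and intermediate_out = scale * Y.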
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::MulFunctor<T>,
                             paddle::operators::math::ScaleFunctor<T>>(
        ctx, paddle::operators::math::MulFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
  } else {
    PADDLE_THROW("%s has not been implemented.", funcs_str);
  }
}

template <typename DeviceContext, typename T, bool InPlace>
static void RunGradFunctors(
    const framework::ExecutionContext &ctx, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
  auto funcs_str = functors[0] + "," + functors[1];

  if (funcs_str == "elementwise_add_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
        paddle::operators::math::ScaleFunctor<T>,
        paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::AddGradFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "scale_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::ScaleGradFunctor<T>,
        paddle::operators::math::AddFunctor<T>,
        paddle::operators::math::AddGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::ScaleGradFunctor<T>(scale),
        paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_add_grad,relu_grad") {
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
        paddle::operators::math::ReluFunctor<T>,
        paddle::operators::math::ReluGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::AddGradFunctor<T>(),
        paddle::operators::math::ReluFunctor<T>(),
        paddle::operators::math::ReluGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "relu_grad,elementwise_add_grad") {
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::ReluGradFunctor<T>,
        paddle::operators::math::AddFunctor<T>,
        paddle::operators::math::AddGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::ReluGradFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
        paddle::operators::math::ScaleFunctor<T>,
        paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::MulGradFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else {
    PADDLE_THROW("%s has not been implemented.", funcs_str);
  }
}

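// A minimal sketch of the attributes these kernels read (names taken from the
// ctx.Attr<> calls above; the actual op registration lives elsewhere):
//   functor_list          - two functor names, e.g. {"elementwise_add", "relu"}
//   axis                  - broadcast axis forwarded to the elementwise compute
//   scale                 - read only when a scale functor is in functor_list
//   save_intermediate_out - whether IntermediateOut is written by the forward
//                           pass and reused by the backward pass
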
template <typename DeviceContext, typename T>
class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto &in_x = detail::Ref(ctx.Input<framework::Tensor>("X"),
                             "Cannot get input tensor %s, variable name = %s",
                             "X", ctx.op().Input("X"));
    auto &in_y = detail::Ref(ctx.Input<framework::Tensor>("Y"),
                             "Cannot get input tensor %s, variable name = %s",
                             "Y", ctx.op().Input("Y"));
    PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");
    auto output = ctx.Output<framework::Tensor>("Out");

    std::vector<framework::Tensor *> outputs;
    outputs.emplace_back(output);

    if (ctx.Attr<bool>("save_intermediate_out")) {
      PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
                     "The save_intermediate_out is enable, so the "
348 349 350 351 352 353
                     "IntermediateOut should not be empty.");
      auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
      outputs.emplace_back(intermediate_out);
    } else {
      outputs.emplace_back(nullptr);
    }

    RunFunctors<DeviceContext, T>(ctx, in_x, in_y, &outputs);
  }
};

template <typename DeviceContext, typename T>
class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto in_y = ctx.Input<framework::Tensor>("Y");
    PADDLE_ENFORCE(in_y != nullptr, "Input(Y) should not be nullptr.");
    auto in_out = ctx.Input<framework::Tensor>("Out");
    PADDLE_ENFORCE(in_out != nullptr, "Input(Out) should not be nullptr.");
    auto in_out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE(in_out_grad != nullptr,
                   "Input(Out@Grad) should not be nullptr.");
    framework::Tensor *in_x =
        const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
    framework::Tensor *x_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    framework::Tensor *y_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
    framework::Tensor *d_intermediate_out = ctx.Output<framework::Tensor>(
        framework::GradVarName("IntermediateOut"));

    auto functor_list = ctx.Attr<std::vector<std::string>>("functor_list");

    // Get intermediate_out
    framework::Tensor *in_intermediate_out = nullptr;
    if (ctx.Attr<bool>("save_intermediate_out")) {
      // If save_intermediate_out is true, then for Unary(Binary(x, y)) and
      // Binary(x, Unary(y)), Binary(x, y) and Unary(y) do not need to be
      // recomputed in the backward pass.
      in_intermediate_out = const_cast<framework::Tensor *>(
          ctx.Input<framework::Tensor>("IntermediateOut"));
      PADDLE_ENFORCE(in_intermediate_out != nullptr,
                     "The option of 'save_intermediate_out' is opened, "
392 393
                     "so the number of 'Out' should be two.");
    } else {
      if (!InputXCanBeAbsent(functor_list)) {
        PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be null.");
      }
    }

    // Get in_x
    if (ctx.HasInput("X")) {
      PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be nullptr.");
    } else {
      // If functor_list contains elementwise_add_grad, the backward doesn't
      // use in_x, in_y or in_out.
      PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
                     "Only when the compound functor contains "
                     "elementwise_add_grad can Input(X) be absent.");
      in_x = const_cast<framework::Tensor *>(in_out_grad);
    }

    bool has_in_place = HasInPlaceUnary(functor_list);
    if (has_in_place) {
      RunGradFunctors<DeviceContext, T, true /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    } else {
      RunGradFunctors<DeviceContext, T, false /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    }
  }
};
}  // namespace operators
}  // namespace paddle