/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/pten/kernels/funcs/compound_functors.h"
#include "paddle/pten/kernels/funcs/elementwise_functor.h"
#include "paddle/pten/kernels/funcs/functors.h"

namespace paddle {
namespace operators {

/**
 * Whether the compound function is Unary(Binary(X, Y)).
 * For Unary(Binary(X, Y)), the intermediate_out's shape is the same as the
 * final out's.
 */
bool IsUnaryCompound(const std::vector<std::string> &functor_list);

/**
 *  For the in-place unary functor, the inputs of op_desc only have Out and
 *  Out@Grad.
 */
bool HasInPlaceUnary(const std::vector<std::string> &functor_list);

/**
 * Whether the Input(X) could be absent. (The grad kernel below substitutes
 * Out@Grad for X in that case, and only allows it when the functor list
 * contains elementwise_add_grad.)
 */
bool InputXCanBeAbsent(const std::vector<std::string> &functor_list);

template <typename DeviceContext, typename T, typename BinaryFunctor,
          typename UnaryFunctor>
49 50 51 52 53 54 55 56
static void RunBinaryCompoundFunctor(
    const framework::ExecutionContext &ctx, const BinaryFunctor &binary_functor,
    const UnaryFunctor &unary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Binary(X, Unary(Y))
  // intermediate_out = Unary(Y)
  // out = Binary(X, Unary(Y))
  // In this case, the shape of intermediate_out and out are different.
57
  pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
58
      compound_func(binary_functor, unary_functor);
C
chengduo 已提交
59
  int axis = ctx.Attr<int>("axis");
C
chengduo 已提交
60
  if (ctx.Attr<bool>("save_intermediate_out")) {
61 62 63 64 65
    FusedElemwiseAndActComputeEx<
        DeviceContext, T,
        pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>,
        true /*KeepIntermediateValue*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
66 67
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
68 69 70 71 72
    FusedElemwiseAndActComputeEx<
        DeviceContext, T,
        pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>,
        false /*KeepIntermediateValue*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
73 74
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
C
chengduo 已提交
75 76 77 78
}

template <typename DeviceContext, typename T, typename UnaryFunctor,
          typename BinaryFunctor>
79 80 81 82 83 84 85 86
static void RunUnaryCompoundFunctors(
    const framework::ExecutionContext &ctx, const UnaryFunctor &unary_functor,
    const BinaryFunctor &binary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Unary(Binary(X, Y))
  // intermediate_out = Binary(X, Y)
  // out = Unary(Binary(X, Y))
  // In this case, the shape of intermediate_out and out are the same.
C
chengduo 已提交
87 88
  int axis = ctx.Attr<int>("axis");

89
  pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
90
      compound_func(unary_functor, binary_functor);
C
chengduo 已提交
91

C
chengduo 已提交
92
  if (ctx.Attr<bool>("save_intermediate_out")) {
93 94 95 96 97
    FusedElemwiseAndActComputeEx<
        DeviceContext, T,
        pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>,
        true /*KeepIntermediateValue*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
98 99
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
100 101 102 103 104
    FusedElemwiseAndActComputeEx<
        DeviceContext, T,
        pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>,
        false /*KeepIntermediateValue*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
105 106
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
C
chengduo 已提交
107 108 109
}

template <typename DeviceContext, typename T, typename BinaryGradFunctor,
C
chengduo 已提交
110
          typename UnaryFunctor, typename UnaryGradFunctor, bool InPlace>
C
chengduo 已提交
111 112 113 114 115 116
static void RunBinaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const BinaryGradFunctor &binary_grad_functor,
    const UnaryFunctor &unary_functor,
    const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
117
    const framework::Tensor *in_intermediate_out,
C
chengduo 已提交
118
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
C
chengduo 已提交
119
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
120
  // Z = Binary(X, Unary(Y))
C
chengduo 已提交
121 122 123
  int axis = ctx.Attr<int>("axis");

  using BinaryCompoundDxFunctor =
124 125 126 127
      pten::funcs::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
                                               UnaryFunctor>;
  using BinaryCompoundDyFunctor = pten::funcs::BinaryCompoundGradDyFunctor<
      T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
C
chengduo 已提交
128
  using BinaryCompoundDIntermedaiteOutFunctor =
129
      pten::funcs::BinaryCompoundGradDIntermedaiteOutFunctor<
C
chengduo 已提交
130
          T, BinaryGradFunctor, UnaryFunctor>;
131 132 133 134

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
C
chengduo 已提交
135
        BinaryCompoundDIntermedaiteOutFunctor, true /*UseIntermediateOut*/,
136 137
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
C
chengduo 已提交
138 139
        y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
140
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
C
chengduo 已提交
141 142 143
                                unary_grad_functor),
        BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
                                              unary_functor));
144 145 146
  } else {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
C
chengduo 已提交
147
        BinaryCompoundDIntermedaiteOutFunctor, false /*UseIntermediateOut*/,
148 149
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
C
chengduo 已提交
150 151
        y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
152
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
C
chengduo 已提交
153 154 155
                                unary_grad_functor),
        BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
                                              unary_functor));
156
  }
C
chengduo 已提交
157 158 159
}

template <typename DeviceContext, typename T, typename UnaryGradFunctor,
C
chengduo 已提交
160
          typename BinaryFunctor, typename BinaryGradFunctor, bool InPlace>
C
chengduo 已提交
161 162 163 164 165 166
static void RunUnaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const UnaryGradFunctor &unary_grad_functor,
    const BinaryFunctor &binary_functor,
    const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
167
    const framework::Tensor *in_intermediate_out,
C
chengduo 已提交
168
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
C
chengduo 已提交
169
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
170
  // Z = Unary(Binary(X, Y))
C
chengduo 已提交
171 172
  int axis = ctx.Attr<int>("axis");

173 174 175 176
  using UnaryCompoundDxFunctor = pten::funcs::UnaryCompoundGradDxFunctor<
      T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
  using UnaryCompoundDyFunctor = pten::funcs::UnaryCompoundGradDyFunctor<
      T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
C
chengduo 已提交
177
  using UnaryCompoundDIntermediateFunctor =
178
      pten::funcs::UnaryCompoundGradDIntermediateFunctor<
C
chengduo 已提交
179
          T, UnaryGradFunctor, BinaryFunctor, InPlace>;
180 181 182 183

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
C
chengduo 已提交
184 185
        UnaryCompoundDIntermediateFunctor, true /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
186
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
C
chengduo 已提交
187 188 189
        y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
190
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
C
chengduo 已提交
191 192
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
193
  } else {
C
chengduo 已提交
194 195 196 197
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
        UnaryCompoundDIntermediateFunctor, false /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
198
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
C
chengduo 已提交
199 200 201
        y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
202
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
C
chengduo 已提交
203 204
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
205
  }
C
chengduo 已提交
206 207 208 209
}

template <typename DeviceContext, typename T>
static void RunFunctors(const framework::ExecutionContext &ctx,
                        const framework::Tensor &in_x,
                        const framework::Tensor &in_y,
                        std::vector<framework::Tensor *> *outputs) {
  // Forward dispatch: joins the two functor names ("binary,unary" or
  // "unary,binary") and routes to the matching compound implementation.
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");

  // TODO(zcd): The following code can be refined.
  auto funcs_str = functors[0] + "," + functors[1];
  if (funcs_str == "elementwise_add,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::AddFunctor<T>,
                             pten::funcs::ScaleFunctor<T>>(
        ctx, pten::funcs::AddFunctor<T>(), pten::funcs::ScaleFunctor<T>(scale),
        in_x, in_y, outputs);
  } else if (funcs_str == "scale,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::ScaleFunctor<T>,
                             pten::funcs::AddFunctor<T>>(
        ctx, pten::funcs::ScaleFunctor<T>(scale), pten::funcs::AddFunctor<T>(),
        in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_add,relu") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::AddFunctor<T>,
                             pten::funcs::ReluFunctor<T>>(
        ctx, pten::funcs::AddFunctor<T>(), pten::funcs::ReluFunctor<T>(), in_x,
        in_y, outputs);
  } else if (funcs_str == "relu,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::ReluFunctor<T>,
                             pten::funcs::AddFunctor<T>>(
        ctx, pten::funcs::ReluFunctor<T>(), pten::funcs::AddFunctor<T>(), in_x,
        in_y, outputs);
  } else if (funcs_str == "elementwise_mul,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::MultiplyFunctor<T>,
                             pten::funcs::ScaleFunctor<T>>(
        ctx, pten::funcs::MultiplyFunctor<T>(),
        pten::funcs::ScaleFunctor<T>(scale), in_x, in_y, outputs);
  } else if (funcs_str == "tanh,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::TanhFunctor<T>,
                             pten::funcs::AddFunctor<T>>(
        ctx, pten::funcs::TanhFunctor<T>(), pten::funcs::AddFunctor<T>(), in_x,
        in_y, outputs);
  } else if (funcs_str == "elementwise_mul,tanh") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::MultiplyFunctor<T>,
                             pten::funcs::TanhFunctor<T>>(
        ctx, pten::funcs::MultiplyFunctor<T>(), pten::funcs::TanhFunctor<T>(),
        in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,sigmoid") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::MultiplyFunctor<T>,
                             pten::funcs::SigmoidFunctor<T>>(
        ctx, pten::funcs::MultiplyFunctor<T>(),
        pten::funcs::SigmoidFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "gelu,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::GeluFunctor<T>,
                             pten::funcs::AddFunctor<T>>(
        ctx, pten::funcs::GeluFunctor<T>(), pten::funcs::AddFunctor<T>(), in_x,
        in_y, outputs);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s has not been implemented.", funcs_str));
  }
}

template <typename DeviceContext, typename T, bool InPlace>
static void RunGradFunctors(
    const framework::ExecutionContext &ctx, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  // Backward dispatch: joins the two "*_grad" functor names and routes to
  // the matching compound backward implementation.
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
  auto funcs_str = functors[0] + "," + functors[1];

  if (funcs_str == "elementwise_add_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<DeviceContext, T,
                                  pten::funcs::AddGradFunctor<T>,
                                  pten::funcs::ScaleFunctor<T>,
                                  pten::funcs::ScaleGradFunctor<T>, InPlace>(
        ctx, pten::funcs::AddGradFunctor<T>(),
        pten::funcs::ScaleFunctor<T>(scale),
        pten::funcs::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "scale_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, pten::funcs::ScaleGradFunctor<T>,
        pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
        ctx, pten::funcs::ScaleGradFunctor<T>(scale),
        pten::funcs::AddFunctor<T>(), pten::funcs::AddGradFunctor<T>(), in_x,
        in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad,
        d_intermediate_out);
  } else if (funcs_str == "elementwise_add_grad,relu_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, pten::funcs::AddGradFunctor<T>,
        pten::funcs::ReluFunctor<T>, pten::funcs::ReluGradFunctor<T>, InPlace>(
        ctx, pten::funcs::AddGradFunctor<T>(), pten::funcs::ReluFunctor<T>(),
        pten::funcs::ReluGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "relu_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, pten::funcs::ReluGradFunctor<T>,
        pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
        ctx, pten::funcs::ReluGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
        pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<DeviceContext, T,
                                  pten::funcs::MulGradFunctor<T>,
                                  pten::funcs::ScaleFunctor<T>,
                                  pten::funcs::ScaleGradFunctor<T>, InPlace>(
        ctx, pten::funcs::MulGradFunctor<T>(),
        pten::funcs::ScaleFunctor<T>(scale),
        pten::funcs::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "tanh_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, pten::funcs::TanhGradFunctor<T>,
        pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
        ctx, pten::funcs::TanhGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
        pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,tanh_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, pten::funcs::MulGradFunctor<T>,
        pten::funcs::TanhFunctor<T>, pten::funcs::TanhGradFunctor<T>, InPlace>(
        ctx, pten::funcs::MulGradFunctor<T>(), pten::funcs::TanhFunctor<T>(),
        pten::funcs::TanhGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,sigmoid_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<DeviceContext, T,
                                  pten::funcs::MulGradFunctor<T>,
                                  pten::funcs::SigmoidFunctor<T>,
                                  pten::funcs::SigmoidGradFunctor<T>, InPlace>(
        ctx, pten::funcs::MulGradFunctor<T>(), pten::funcs::SigmoidFunctor<T>(),
        pten::funcs::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "gelu_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, pten::funcs::GeluGradFunctor<T>,
        pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
        ctx, pten::funcs::GeluGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
        pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s has not been implemented.", funcs_str));
  }
}

template <typename DeviceContext, typename T>
class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
 public:
  // Forward kernel: reads X and Y, writes Out and (optionally)
  // IntermediateOut, then dispatches via RunFunctors.
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto &in_x = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("X"), "Input",
                                 "X", "FusedElemwiseActivation");
    auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("Y"), "Input",
                                 "Y", "FusedElemwiseActivation");

    PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true,
                      platform::errors::InvalidArgument(
                          "The output(Out) should not be empty"));

    // outputs[0] is Out; outputs[1] is IntermediateOut, or nullptr when the
    // intermediate result is not saved.
    std::vector<framework::Tensor *> outputs;
    outputs.push_back(ctx.Output<framework::Tensor>("Out"));

    if (ctx.Attr<bool>("save_intermediate_out")) {
      PADDLE_ENFORCE_EQ(ctx.HasOutput("IntermediateOut"), true,
                        platform::errors::InvalidArgument(
                            "The save_intermediate_out is enable, so the "
                            "IntermediateOut should not be empty."));
      outputs.push_back(ctx.Output<framework::Tensor>("IntermediateOut"));
    } else {
      outputs.push_back(nullptr);
    }

    RunFunctors<DeviceContext, T>(ctx, in_x, in_y, &outputs);
  }
};

template <typename DeviceContext, typename T>
class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
 public:
  // Backward kernel: consumes Y, Out, Out@Grad (and optionally X and
  // IntermediateOut), fills X@Grad, Y@Grad and IntermediateOut@Grad.
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto in_y = ctx.Input<framework::Tensor>("Y");
    PADDLE_ENFORCE_NE(in_y, nullptr, platform::errors::InvalidArgument(
                                         "Input(Y) should not be nullptr."));
    auto in_out = ctx.Input<framework::Tensor>("Out");
    PADDLE_ENFORCE_NE(
        in_out, nullptr,
        platform::errors::InvalidArgument("Input(Out) should not be nullptr."));
    auto in_out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE_NE(in_out_grad, nullptr,
                      platform::errors::InvalidArgument(
                          "Input(Out@Grad) should not be nullptr."));

    // NOTE(review): const_cast is needed only because in_x may be re-pointed
    // at in_out_grad below when X is absent; the tensor contents are not
    // modified here.
    framework::Tensor *in_x =
        const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
    framework::Tensor *x_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    framework::Tensor *y_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
    framework::Tensor *d_intermediate_out = ctx.Output<framework::Tensor>(
        framework::GradVarName("IntermediateOut"));

    auto functor_list = ctx.Attr<std::vector<std::string>>("functor_list");

    // Get intermediate_out. If save_intermediate_out was enabled in the
    // forward pass, Unary(Binary(x, y)) / Binary(x, Unary(y)) intermediates
    // do not need to be recomputed here.
    framework::Tensor *in_intermediate_out = nullptr;
    if (ctx.Attr<bool>("save_intermediate_out")) {
      in_intermediate_out = const_cast<framework::Tensor *>(
          ctx.Input<framework::Tensor>("IntermediateOut"));
      PADDLE_ENFORCE_NE(in_intermediate_out, nullptr,
                        platform::errors::InvalidArgument(
                            "The option of 'save_intermediate_out' is opened,"
                            " so the number of 'Out' should be two."));
    } else {
      if (!InputXCanBeAbsent(functor_list)) {
        PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument(
                                             "Input(X) should not be null."));
      }
    }

    // Get in_x. When X is absent the functor list must contain
    // elementwise_add_grad; Out@Grad is substituted as a stand-in.
    if (ctx.HasInput("X")) {
      PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument(
                                           "Input(X) should not be null."));
    } else {
      // If functor_list contains elementwise_add, the backward doesn't use
      // in_x, in_y and in_out.
      PADDLE_ENFORCE_EQ(InputXCanBeAbsent(functor_list), true,
                        platform::errors::InvalidArgument(
                            "Only when the compoundfunctor contains "
                            "elementwise_add_grad, the 'X' could be absent."));
      in_x = const_cast<framework::Tensor *>(in_out_grad);
    }

    if (HasInPlaceUnary(functor_list)) {
      RunGradFunctors<DeviceContext, T, true /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    } else {
      RunGradFunctors<DeviceContext, T, false /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    }
  }
};
}  // namespace operators
}  // namespace paddle