fused_elemwise_activation_op.h 23.0 KB
Newer Older
C
chengduo 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
W
Wu Yi 已提交
21
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
22
#include "paddle/fluid/operators/math/compound_functors.h"
C
chengduo 已提交
23 24 25 26 27
#include "paddle/fluid/operators/math/functors.h"

namespace paddle {
namespace operators {

C
chengduo 已提交
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
/**
 * Whether the compound function is Unary(Binary(X, Y)).
 * For Unary(Binary(X, Y)), the shape of intermediate_out is the same as that
 * of the final out.
 *
 * @param functor_list Names of the functors that make up the compound
 *        function. The definition lives outside this header.
 * @return true if the compound function has the form Unary(Binary(X, Y)).
 */
bool IsUnaryCompound(const std::vector<std::string> &functor_list);

/**
 * Whether the functor list contains an in-place unary functor.
 * For the in-place unary functor, the inputs of op_desc only have Out and
 * Out@Grad.
 *
 * The gradient kernel in this header uses the result to choose the InPlace
 * template argument of RunGradFunctors.
 *
 * @param functor_list Names of the functors that make up the compound
 *        function. The definition lives outside this header.
 */
bool HasInPlaceUnary(const std::vector<std::string> &functor_list);

/**
 * Whether the Input(X) could be absent.
 *
 * The gradient kernel in this header consults this before accepting a missing
 * Input(X); its error message states that X may only be absent when the
 * compound functor contains elementwise_add_grad.
 *
 * @param functor_list Names of the functors that make up the compound
 *        function. The definition lives outside this header.
 */
bool InputXCanBeAbsent(const std::vector<std::string> &functor_list);

C
chengduo 已提交
46 47
template <typename DeviceContext, typename T, typename BinaryFunctor,
          typename UnaryFunctor>
48 49 50 51 52 53 54 55 56 57
static void RunBinaryCompoundFunctor(
    const framework::ExecutionContext &ctx, const BinaryFunctor &binary_functor,
    const UnaryFunctor &unary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Binary(X, Unary(Y))
  // intermediate_out = Unary(Y)
  // out = Binary(X, Unary(Y))
  // In this case, the shape of intermediate_out and out are different.
  paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
      compound_func(binary_functor, unary_functor);
C
chengduo 已提交
58
  int axis = ctx.Attr<int>("axis");
C
chengduo 已提交
59
  if (ctx.Attr<bool>("save_intermediate_out")) {
60 61 62 63 64 65 66 67 68 69 70 71 72 73
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::BinaryCompoundFunctor<
                                     T, BinaryFunctor, UnaryFunctor>,
                                 true /*KeepIntermediateValue*/,
                                 false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::BinaryCompoundFunctor<
                                     T, BinaryFunctor, UnaryFunctor>,
                                 false /*KeepIntermediateValue*/,
                                 false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
C
chengduo 已提交
74 75 76 77
}

template <typename DeviceContext, typename T, typename UnaryFunctor,
          typename BinaryFunctor>
78 79 80 81 82 83 84 85
static void RunUnaryCompoundFunctors(
    const framework::ExecutionContext &ctx, const UnaryFunctor &unary_functor,
    const BinaryFunctor &binary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Unary(Binary(X, Y))
  // intermediate_out = Binary(X, Y)
  // out = Unary(Binary(X, Y))
  // In this case, the shape of intermediate_out and out are the same.
C
chengduo 已提交
86 87
  int axis = ctx.Attr<int>("axis");

88 89
  paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
      compound_func(unary_functor, binary_functor);
C
chengduo 已提交
90

C
chengduo 已提交
91
  if (ctx.Attr<bool>("save_intermediate_out")) {
92 93 94 95 96 97 98 99 100 101 102 103 104 105
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::UnaryCompoundFunctor<
                                     T, UnaryFunctor, BinaryFunctor>,
                                 true /*KeepIntermediateValue*/,
                                 true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::UnaryCompoundFunctor<
                                     T, UnaryFunctor, BinaryFunctor>,
                                 false /*KeepIntermediateValue*/,
                                 true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
C
chengduo 已提交
106 107 108
}

template <typename DeviceContext, typename T, typename BinaryGradFunctor,
C
chengduo 已提交
109
          typename UnaryFunctor, typename UnaryGradFunctor, bool InPlace>
C
chengduo 已提交
110 111 112 113 114 115
static void RunBinaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const BinaryGradFunctor &binary_grad_functor,
    const UnaryFunctor &unary_functor,
    const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
116
    const framework::Tensor *in_intermediate_out,
C
chengduo 已提交
117
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
C
chengduo 已提交
118
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
119
  // Z = Binary(X, Unary(Y))
C
chengduo 已提交
120 121 122
  int axis = ctx.Attr<int>("axis");

  using BinaryCompoundDxFunctor =
123 124
      paddle::operators::math::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
                                                           UnaryFunctor>;
C
chengduo 已提交
125
  using BinaryCompoundDyFunctor =
126
      paddle::operators::math::BinaryCompoundGradDyFunctor<
C
chengduo 已提交
127 128 129 130
          T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
  using BinaryCompoundDIntermedaiteOutFunctor =
      paddle::operators::math::BinaryCompoundGradDIntermedaiteOutFunctor<
          T, BinaryGradFunctor, UnaryFunctor>;
131 132 133 134

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
C
chengduo 已提交
135
        BinaryCompoundDIntermedaiteOutFunctor, true /*UseIntermediateOut*/,
136 137
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
C
chengduo 已提交
138 139
        y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
140
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
C
chengduo 已提交
141 142 143
                                unary_grad_functor),
        BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
                                              unary_functor));
144 145 146
  } else {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
C
chengduo 已提交
147
        BinaryCompoundDIntermedaiteOutFunctor, false /*UseIntermediateOut*/,
148 149
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
C
chengduo 已提交
150 151
        y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
152
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
C
chengduo 已提交
153 154 155
                                unary_grad_functor),
        BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
                                              unary_functor));
156
  }
C
chengduo 已提交
157 158 159
}

template <typename DeviceContext, typename T, typename UnaryGradFunctor,
C
chengduo 已提交
160
          typename BinaryFunctor, typename BinaryGradFunctor, bool InPlace>
C
chengduo 已提交
161 162 163 164 165 166
static void RunUnaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const UnaryGradFunctor &unary_grad_functor,
    const BinaryFunctor &binary_functor,
    const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
167
    const framework::Tensor *in_intermediate_out,
C
chengduo 已提交
168
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
C
chengduo 已提交
169
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
170
  // Z = Unary(Binary(X, Y))
C
chengduo 已提交
171 172 173
  int axis = ctx.Attr<int>("axis");

  using UnaryCompoundDxFunctor =
174
      paddle::operators::math::UnaryCompoundGradDxFunctor<
C
chengduo 已提交
175
          T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
C
chengduo 已提交
176
  using UnaryCompoundDyFunctor =
177
      paddle::operators::math::UnaryCompoundGradDyFunctor<
C
chengduo 已提交
178 179 180 181
          T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
  using UnaryCompoundDIntermediateFunctor =
      paddle::operators::math::UnaryCompoundGradDIntermediateFunctor<
          T, UnaryGradFunctor, BinaryFunctor, InPlace>;
182 183 184 185

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
C
chengduo 已提交
186 187
        UnaryCompoundDIntermediateFunctor, true /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
188
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
C
chengduo 已提交
189 190 191
        y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
192
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
C
chengduo 已提交
193 194
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
195
  } else {
C
chengduo 已提交
196 197 198 199
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
        UnaryCompoundDIntermediateFunctor, false /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
200
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
C
chengduo 已提交
201 202 203
        y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
204
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
C
chengduo 已提交
205 206
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
207
  }
C
chengduo 已提交
208 209 210 211
}

/**
 * Dispatches the forward computation according to the "functor_list"
 * attribute.
 *
 * The first two entries of "functor_list" are joined as "<f0>,<f1>" and
 * matched against the supported combinations:
 *   - "<binary>,<unary>" runs Z = Binary(X, Unary(Y))
 *   - "<unary>,<binary>" runs Z = Unary(Binary(X, Y))
 * Combinations involving "scale" additionally read the float "scale"
 * attribute.
 *
 * @param ctx     Execution context providing the attributes.
 * @param in_x    Input X.
 * @param in_y    Input Y.
 * @param outputs Output tensors, forwarded to the selected compound runner.
 * @throws via PADDLE_ENFORCE if "functor_list" has fewer than two entries, and
 *         via PADDLE_THROW if the combination is not supported.
 */
template <typename DeviceContext, typename T>
static void RunFunctors(const framework::ExecutionContext &ctx,
                        const framework::Tensor &in_x,
                        const framework::Tensor &in_y,
                        std::vector<framework::Tensor *> *outputs) {
  const auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
  // Both functors[0] and functors[1] are read below; guard the indexing.
  PADDLE_ENFORCE(functors.size() >= 2UL,
                 "The attribute 'functor_list' should contain at least two "
                 "functors.");

  // TODO(zcd): The following code can be refined.
  const std::string funcs_str = functors[0] + "," + functors[1];
  if (funcs_str == "elementwise_add,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T, math::AddFunctor<T>,
                             math::ScaleFunctor<T>>(
        ctx, math::AddFunctor<T>(), math::ScaleFunctor<T>(scale), in_x, in_y,
        outputs);
  } else if (funcs_str == "scale,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundFunctors<DeviceContext, T, math::ScaleFunctor<T>,
                             math::AddFunctor<T>>(
        ctx, math::ScaleFunctor<T>(scale), math::AddFunctor<T>(), in_x, in_y,
        outputs);
  } else if (funcs_str == "elementwise_add,relu") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T, math::AddFunctor<T>,
                             math::ReluFunctor<T>>(
        ctx, math::AddFunctor<T>(), math::ReluFunctor<T>(), in_x, in_y,
        outputs);
  } else if (funcs_str == "relu,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T, math::ReluFunctor<T>,
                             math::AddFunctor<T>>(
        ctx, math::ReluFunctor<T>(), math::AddFunctor<T>(), in_x, in_y,
        outputs);
  } else if (funcs_str == "elementwise_mul,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T, math::MulFunctor<T>,
                             math::ScaleFunctor<T>>(
        ctx, math::MulFunctor<T>(), math::ScaleFunctor<T>(scale), in_x, in_y,
        outputs);
  } else if (funcs_str == "tanh,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T, math::TanhFunctor<T>,
                             math::AddFunctor<T>>(
        ctx, math::TanhFunctor<T>(), math::AddFunctor<T>(), in_x, in_y,
        outputs);
  } else if (funcs_str == "elementwise_mul,tanh") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T, math::MulFunctor<T>,
                             math::TanhFunctor<T>>(
        ctx, math::MulFunctor<T>(), math::TanhFunctor<T>(), in_x, in_y,
        outputs);
  } else if (funcs_str == "elementwise_mul,sigmoid") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T, math::MulFunctor<T>,
                             math::SigmoidFunctor<T>>(
        ctx, math::MulFunctor<T>(), math::SigmoidFunctor<T>(), in_x, in_y,
        outputs);
  } else {
    PADDLE_THROW("%s has not been implemented.", funcs_str);
  }
}

C
chengduo 已提交
283 284 285 286 287 288 289
/**
 * Dispatches the backward computation according to the "functor_list"
 * attribute.
 *
 * The first two entries of "functor_list" are joined as "<f0>,<f1>" and
 * matched against the supported gradient combinations:
 *   - "<binary>_grad,<unary>_grad" is the backward of Z = Binary(X, Unary(Y))
 *   - "<unary>_grad,<binary>_grad" is the backward of Z = Unary(Binary(X, Y))
 * Combinations involving "scale_grad" additionally read the float "scale"
 * attribute.
 *
 * @tparam InPlace Forwarded to the compound gradient runners.
 * @param ctx                 Execution context providing the attributes.
 * @param in_x, in_y, in_out  Forward tensors, passed through unchanged.
 * @param in_intermediate_out Saved intermediate output, may be nullptr.
 * @param in_out_grad         Gradient w.r.t. out.
 * @param x_grad, y_grad, d_intermediate_out Gradient outputs, passed through
 *                            unchanged.
 * @throws via PADDLE_ENFORCE if "functor_list" has fewer than two entries, and
 *         via PADDLE_THROW if the combination is not supported.
 */
template <typename DeviceContext, typename T, bool InPlace>
static void RunGradFunctors(
    const framework::ExecutionContext &ctx, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  const auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
  // Both functors[0] and functors[1] are read below; guard the indexing.
  PADDLE_ENFORCE(functors.size() >= 2UL,
                 "The attribute 'functor_list' should contain at least two "
                 "functors.");
  const std::string funcs_str = functors[0] + "," + functors[1];

  if (funcs_str == "elementwise_add_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<DeviceContext, T, math::AddGradFunctor<T>,
                                  math::ScaleFunctor<T>,
                                  math::ScaleGradFunctor<T>, InPlace>(
        ctx, math::AddGradFunctor<T>(), math::ScaleFunctor<T>(scale),
        math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "scale_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundGradFunctors<DeviceContext, T, math::ScaleGradFunctor<T>,
                                 math::AddFunctor<T>, math::AddGradFunctor<T>,
                                 InPlace>(
        ctx, math::ScaleGradFunctor<T>(scale), math::AddFunctor<T>(),
        math::AddGradFunctor<T>(), in_x, in_y, in_out, in_intermediate_out,
        in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_add_grad,relu_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<DeviceContext, T, math::AddGradFunctor<T>,
                                  math::ReluFunctor<T>,
                                  math::ReluGradFunctor<T>, InPlace>(
        ctx, math::AddGradFunctor<T>(), math::ReluFunctor<T>(),
        math::ReluGradFunctor<T>(), in_x, in_y, in_out, in_intermediate_out,
        in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "relu_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<DeviceContext, T, math::ReluGradFunctor<T>,
                                 math::AddFunctor<T>, math::AddGradFunctor<T>,
                                 InPlace>(
        ctx, math::ReluGradFunctor<T>(), math::AddFunctor<T>(),
        math::AddGradFunctor<T>(), in_x, in_y, in_out, in_intermediate_out,
        in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<DeviceContext, T, math::MulGradFunctor<T>,
                                  math::ScaleFunctor<T>,
                                  math::ScaleGradFunctor<T>, InPlace>(
        ctx, math::MulGradFunctor<T>(), math::ScaleFunctor<T>(scale),
        math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "tanh_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<DeviceContext, T, math::TanhGradFunctor<T>,
                                 math::AddFunctor<T>, math::AddGradFunctor<T>,
                                 InPlace>(
        ctx, math::TanhGradFunctor<T>(), math::AddFunctor<T>(),
        math::AddGradFunctor<T>(), in_x, in_y, in_out, in_intermediate_out,
        in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,tanh_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<DeviceContext, T, math::MulGradFunctor<T>,
                                  math::TanhFunctor<T>,
                                  math::TanhGradFunctor<T>, InPlace>(
        ctx, math::MulGradFunctor<T>(), math::TanhFunctor<T>(),
        math::TanhGradFunctor<T>(), in_x, in_y, in_out, in_intermediate_out,
        in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,sigmoid_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<DeviceContext, T, math::MulGradFunctor<T>,
                                  math::SigmoidFunctor<T>,
                                  math::SigmoidGradFunctor<T>, InPlace>(
        ctx, math::MulGradFunctor<T>(), math::SigmoidFunctor<T>(),
        math::SigmoidGradFunctor<T>(), in_x, in_y, in_out, in_intermediate_out,
        in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else {
    PADDLE_THROW("%s has not been implemented.", funcs_str);
  }
}

/**
 * Forward kernel of the fused elementwise + activation op.
 *
 * Reads inputs X and Y, collects the output tensors and hands them to
 * RunFunctors, which selects the compound functor from the op attributes.
 */
template <typename DeviceContext, typename T>
class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
 public:
  /**
   * @param ctx Execution context. Inputs "X" and "Y" and output "Out" are
   *            required; output "IntermediateOut" is required only when the
   *            "save_intermediate_out" attribute is true.
   */
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto &in_x = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("X"), "Input",
                                 "X", "FusedElemwiseActivation");
    auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("Y"), "Input",
                                 "Y", "FusedElemwiseActivation");

    PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");
    framework::Tensor *out = ctx.Output<framework::Tensor>("Out");

    // Stays nullptr unless the intermediate result has to be kept.
    framework::Tensor *intermediate_out = nullptr;
    if (ctx.Attr<bool>("save_intermediate_out")) {
      PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
                     "The save_intermediate_out is enable, so the "
                     "IntermediateOut should not be empty.");
      intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
    }

    // Slot 0 is Out, slot 1 is IntermediateOut (possibly nullptr).
    std::vector<framework::Tensor *> outputs{out, intermediate_out};
    RunFunctors<DeviceContext, T>(ctx, in_x, in_y, &outputs);
  }
};

/**
 * Backward kernel of the fused elementwise + activation op.
 *
 * Validates the inputs, resolves the optional X / IntermediateOut tensors and
 * hands everything to RunGradFunctors, instantiated with InPlace chosen by
 * HasInPlaceUnary(functor_list).
 */
template <typename DeviceContext, typename T>
class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
 public:
  /**
   * @param ctx Execution context. Inputs "Y", "Out" and Out@GRAD are
   *            required. "IntermediateOut" is required when the
   *            "save_intermediate_out" attribute is true. "X" may be missing
   *            only when InputXCanBeAbsent(functor_list) holds.
   */
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto in_y = ctx.Input<framework::Tensor>("Y");
    PADDLE_ENFORCE(in_y != nullptr, "Input(Y) should not be nullptr.");
    auto in_out = ctx.Input<framework::Tensor>("Out");
    PADDLE_ENFORCE(in_out != nullptr, "Input(Out) should not be nullptr.");
    auto in_out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE(in_out_grad != nullptr,
                   "Input(Out@Grad) should not be nullptr.");
    // X is only ever read downstream (RunGradFunctors takes a const pointer),
    // so no const_cast is needed.
    const framework::Tensor *in_x = ctx.Input<framework::Tensor>("X");
    framework::Tensor *x_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    framework::Tensor *y_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
    framework::Tensor *d_intermediate_out = ctx.Output<framework::Tensor>(
        framework::GradVarName("IntermediateOut"));

    // Bind by reference: the attribute does not need to be copied.
    const auto &functor_list =
        ctx.Attr<std::vector<std::string>>("functor_list");

    // Get intermediate_out
    const framework::Tensor *in_intermediate_out = nullptr;
    if (ctx.Attr<bool>("save_intermediate_out")) {
      // if save_intermediate_out is true, for Unary(Binary(x, y)) and
      // Binary(x, Unary(y)), the Binary(x, y) and Unary(y) not need to
      // recompute.
      in_intermediate_out = ctx.Input<framework::Tensor>("IntermediateOut");
      PADDLE_ENFORCE(in_intermediate_out != nullptr,
                     "The option of 'save_intermediate_out' is opened, "
                     "so the number of 'Out' should be two.");
    } else {
      if (!InputXCanBeAbsent(functor_list)) {
        PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be null.");
      }
    }

    // Get in_x
    if (ctx.HasInput("X")) {
      PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be nullptr.");
    } else {
      // If functor_list contains elementwise_add, the backward doesn't use
      // in_x, in_y and in_out.
      PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
                     "Only when the compoundfunctor contains "
                     "elementwise_add_grad, the 'X' could be absent.");
      // Out@GRAD stands in for the missing X so a non-null tensor is passed.
      in_x = in_out_grad;
    }

    // InPlace is a template parameter, so the runtime flag selects between
    // two instantiations.
    bool has_in_place = HasInPlaceUnary(functor_list);
    if (has_in_place) {
      RunGradFunctors<DeviceContext, T, true /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    } else {
      RunGradFunctors<DeviceContext, T, false /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    }
  }
};
}  // namespace operators
}  // namespace paddle