/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h"
#include "paddle/fluid/operators/math/functors.h"

namespace paddle {
namespace operators {

/**
 * Whether the compound function is Unary(Binary(X, Y)).
 * For Unary(Binary(X, Y)), the intermediate_out's shape is the same as the
 * final out's.
 */
bool IsUnaryCompound(const std::vector<std::string> &functor_list);
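// For example (pairs taken from RunFunctors below): functor_list =
// {"scale", "elementwise_add"} composes Z = scale(X + Y), i.e.
// Unary(Binary(X, Y)), so this returns true; functor_list =
// {"elementwise_add", "scale"} composes Z = X + scale(Y), i.e.
// Binary(X, Unary(Y)), so this returns false.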

/**
 *  For the in-place unary functor, the inputs of op_desc contain only Out and
 *  Out@Grad.
 */
bool HasInPlaceUnary(const std::vector<std::string> &functor_list);
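// For example, relu may run in place, overwriting its input with Out; its
// gradient can still be computed because relu_grad only needs Out and
// Out@Grad. (Illustrative; the actual in-place functor set is defined in
// the corresponding .cc file.)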

/**
 * Whether Input(X) can be absent.
 */
bool InputXCanBeAbsent(const std::vector<std::string> &functor_list);
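// For example, when functor_list contains elementwise_add_grad, the backward
// pass does not use the values of X, Y, or Out (see
// FusedElemwiseActivationGradKernel below), so Input(X) may be omitted.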

template <typename DeviceContext, typename T, typename BinaryFunctor,
          typename UnaryFunctor>
static void RunBinaryCompoundFunctor(
    const framework::ExecutionContext &ctx, const BinaryFunctor &binary_functor,
    const UnaryFunctor &unary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Binary(X, Unary(Y))
  // intermediate_out = Unary(Y)
  // out = Binary(X, Unary(Y))
  // In this case, the shapes of intermediate_out and out are different.
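  // Example (illustrative): with in_x of shape [N, C, H, W], in_y of shape
  // [C] and axis = 1, intermediate_out = Unary(Y) keeps Y's shape [C],
  // while out is broadcast to X's shape [N, C, H, W].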
  paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
      compound_func(binary_functor, unary_functor);
  int axis = ctx.Attr<int>("axis");
  if (ctx.Attr<bool>("save_intermediate_out")) {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::BinaryCompoundFunctor<
                                     T, BinaryFunctor, UnaryFunctor>,
                                 true /*KeepIntermediateValue*/,
                                 false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::BinaryCompoundFunctor<
                                     T, BinaryFunctor, UnaryFunctor>,
                                 false /*KeepIntermediateValue*/,
                                 false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
}

template <typename DeviceContext, typename T, typename UnaryFunctor,
          typename BinaryFunctor>
static void RunUnaryCompoundFunctors(
    const framework::ExecutionContext &ctx, const UnaryFunctor &unary_functor,
    const BinaryFunctor &binary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Unary(Binary(X, Y))
  // intermediate_out = Binary(X, Y)
  // out = Unary(Binary(X, Y))
  // In this case, the shapes of intermediate_out and out are the same.
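  // Example (illustrative): with in_x of shape [N, C, H, W], in_y of shape
  // [C] and axis = 1, intermediate_out = Binary(X, Y) is already broadcast
  // to [N, C, H, W], so it has the same shape as out.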
  int axis = ctx.Attr<int>("axis");

  paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
      compound_func(unary_functor, binary_functor);

  if (ctx.Attr<bool>("save_intermediate_out")) {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::UnaryCompoundFunctor<
                                     T, UnaryFunctor, BinaryFunctor>,
                                 true /*KeepIntermediateValue*/,
                                 true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::UnaryCompoundFunctor<
                                     T, UnaryFunctor, BinaryFunctor>,
                                 false /*KeepIntermediateValue*/,
                                 true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
}

template <typename DeviceContext, typename T, typename BinaryGradFunctor,
          typename UnaryFunctor, typename UnaryGradFunctor, bool InPlace>
static void RunBinaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const BinaryGradFunctor &binary_grad_functor,
    const UnaryFunctor &unary_functor,
    const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  // Z = Binary(X, Unary(Y))
  int axis = ctx.Attr<int>("axis");

  using BinaryCompoundDxFunctor =
      paddle::operators::math::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
                                                           UnaryFunctor>;
  using BinaryCompoundDyFunctor =
      paddle::operators::math::BinaryCompoundGradDyFunctor<
          T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
  using BinaryCompoundDIntermedaiteOutFunctor =
      paddle::operators::math::BinaryCompoundGradDIntermedaiteOutFunctor<
          T, BinaryGradFunctor, UnaryFunctor>;

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        BinaryCompoundDIntermedaiteOutFunctor, true /*UseIntermediateOut*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
                                unary_grad_functor),
        BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
                                              unary_functor));
  } else {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        BinaryCompoundDIntermedaiteOutFunctor, false /*UseIntermediateOut*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
                                unary_grad_functor),
        BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
                                              unary_functor));
  }
}

template <typename DeviceContext, typename T, typename UnaryGradFunctor,
          typename BinaryFunctor, typename BinaryGradFunctor, bool InPlace>
static void RunUnaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const UnaryGradFunctor &unary_grad_functor,
    const BinaryFunctor &binary_functor,
    const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  // Z = Unary(Binary(X, Y))
  int axis = ctx.Attr<int>("axis");

  using UnaryCompoundDxFunctor =
      paddle::operators::math::UnaryCompoundGradDxFunctor<
          T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
  using UnaryCompoundDyFunctor =
      paddle::operators::math::UnaryCompoundGradDyFunctor<
          T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
  using UnaryCompoundDIntermediateFunctor =
      paddle::operators::math::UnaryCompoundGradDIntermediateFunctor<
          T, UnaryGradFunctor, BinaryFunctor, InPlace>;

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
        UnaryCompoundDIntermediateFunctor, true /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
  } else {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
        UnaryCompoundDIntermediateFunctor, false /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
  }
}

template <typename DeviceContext, typename T>
static void RunFunctors(const framework::ExecutionContext &ctx,
                        const framework::Tensor &in_x,
                        const framework::Tensor &in_y,
                        std::vector<framework::Tensor *> *outputs) {
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");

  // TODO(zcd): The following code can be refined.
  auto funcs_str = functors[0] + "," + functors[1];
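  // For example, functors = {"elementwise_add", "scale"} gives
  // funcs_str = "elementwise_add,scale".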
  if (funcs_str == "elementwise_add,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::AddFunctor<T>,
                             paddle::operators::math::ScaleFunctor<T>>(
        ctx, paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
  } else if (funcs_str == "scale,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundFunctors<DeviceContext, T,
                             paddle::operators::math::ScaleFunctor<T>,
                             paddle::operators::math::AddFunctor<T>>(
        ctx, paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_add,relu") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::AddFunctor<T>,
                             paddle::operators::math::ReluFunctor<T>>(
        ctx, paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::ReluFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "relu,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T,
                             paddle::operators::math::ReluFunctor<T>,
                             paddle::operators::math::AddFunctor<T>>(
        ctx, paddle::operators::math::ReluFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::MulFunctor<T>,
                             paddle::operators::math::ScaleFunctor<T>>(
        ctx, paddle::operators::math::MulFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
  } else if (funcs_str == "tanh,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T,
                             paddle::operators::math::TanhFunctor<T>,
                             paddle::operators::math::AddFunctor<T>>(
        ctx, paddle::operators::math::TanhFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,tanh") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::MulFunctor<T>,
                             paddle::operators::math::TanhFunctor<T>>(
        ctx, paddle::operators::math::MulFunctor<T>(),
        paddle::operators::math::TanhFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,sigmoid") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::MulFunctor<T>,
                             paddle::operators::math::SigmoidFunctor<T>>(
        ctx, paddle::operators::math::MulFunctor<T>(),
        paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
  } else {
    PADDLE_THROW("%s has not been implemented.", funcs_str);
  }
}

template <typename DeviceContext, typename T, bool InPlace>
static void RunGradFunctors(
    const framework::ExecutionContext &ctx, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
  auto funcs_str = functors[0] + "," + functors[1];

  if (funcs_str == "elementwise_add_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
        paddle::operators::math::ScaleFunctor<T>,
        paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::AddGradFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "scale_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::ScaleGradFunctor<T>,
        paddle::operators::math::AddFunctor<T>,
        paddle::operators::math::AddGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::ScaleGradFunctor<T>(scale),
        paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_add_grad,relu_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
        paddle::operators::math::ReluFunctor<T>,
        paddle::operators::math::ReluGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::AddGradFunctor<T>(),
        paddle::operators::math::ReluFunctor<T>(),
        paddle::operators::math::ReluGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "relu_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::ReluGradFunctor<T>,
        paddle::operators::math::AddFunctor<T>,
        paddle::operators::math::AddGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::ReluGradFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
        paddle::operators::math::ScaleFunctor<T>,
        paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::MulGradFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "tanh_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::TanhGradFunctor<T>,
        paddle::operators::math::AddFunctor<T>,
        paddle::operators::math::AddGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::TanhGradFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,tanh_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
        paddle::operators::math::TanhFunctor<T>,
        paddle::operators::math::TanhGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::MulGradFunctor<T>(),
        paddle::operators::math::TanhFunctor<T>(),
        paddle::operators::math::TanhGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,sigmoid_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
        paddle::operators::math::SigmoidFunctor<T>,
        paddle::operators::math::SigmoidGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::MulGradFunctor<T>(),
        paddle::operators::math::SigmoidFunctor<T>(),
        paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else {
    PADDLE_THROW("%s has not been implemented.", funcs_str);
  }
}

template <typename DeviceContext, typename T>
class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto &in_x = detail::Ref(ctx.Input<framework::Tensor>("X"),
                             "Cannot get input tensor %s, variable name = %s",
                             "X", ctx.op().Input("X"));
    auto &in_y = detail::Ref(ctx.Input<framework::Tensor>("Y"),
                             "Cannot get input tensor %s, variable name = %s",
                             "Y", ctx.op().Input("Y"));
    PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");
    auto output = ctx.Output<framework::Tensor>("Out");

    std::vector<framework::Tensor *> outputs;
    outputs.emplace_back(output);

    if (ctx.Attr<bool>("save_intermediate_out")) {
      PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
                     "The save_intermediate_out is enable, so the "
401 402 403 404 405 406
                     "IntermediateOut should not be empty.");
      auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
      outputs.emplace_back(intermediate_out);
    } else {
      outputs.emplace_back(nullptr);
    }

    RunFunctors<DeviceContext, T>(ctx, in_x, in_y, &outputs);
  }
};

template <typename DeviceContext, typename T>
class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto in_y = ctx.Input<framework::Tensor>("Y");
    PADDLE_ENFORCE(in_y != nullptr, "Input(Y) should not be nullptr.");
    auto in_out = ctx.Input<framework::Tensor>("Out");
    PADDLE_ENFORCE(in_out != nullptr, "Input(Out) should not be nullptr.");
    auto in_out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE(in_out_grad != nullptr,
                   "Input(Out@Grad) should not be nullptr.");
    framework::Tensor *in_x =
        const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
    framework::Tensor *x_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    framework::Tensor *y_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
    framework::Tensor *d_intermediate_out = ctx.Output<framework::Tensor>(
        framework::GradVarName("IntermediateOut"));

    auto functor_list = ctx.Attr<std::vector<std::string>>("functor_list");

    // Get intermediate_out
    framework::Tensor *in_intermediate_out = nullptr;
    if (ctx.Attr<bool>("save_intermediate_out")) {
      // If save_intermediate_out is true, then for Unary(Binary(x, y)) and
      // Binary(x, Unary(y)), Binary(x, y) and Unary(y) do not need to be
      // recomputed.
      in_intermediate_out = const_cast<framework::Tensor *>(
          ctx.Input<framework::Tensor>("IntermediateOut"));
      PADDLE_ENFORCE(in_intermediate_out != nullptr,
                     "The option of 'save_intermediate_out' is opened, "
445 446
                     "so the number of 'Out' should be two.");
    } else {
      if (!InputXCanBeAbsent(functor_list)) {
        PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be null.");
      }
    }

    // Get in_x
    if (ctx.HasInput("X")) {
      PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be nullptr.");
    } else {
      // If functor_list contains elementwise_add, the backward doesn't use
      // in_x, in_y and in_out.
      PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
                     "Only when the compoundfunctor contains "
                     "elementwise_add_grad, the 'X' could be absent.");
      in_x = const_cast<framework::Tensor *>(in_out_grad);
    }

    bool has_in_place = HasInPlaceUnary(functor_list);
    if (has_in_place) {
      RunGradFunctors<DeviceContext, T, true /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    } else {
      RunGradFunctors<DeviceContext, T, false /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    }
  }
};
}  // namespace operators
}  // namespace paddle