/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h"
#include "paddle/fluid/operators/math/functors.h"

namespace paddle {
namespace operators {

/**
 * Whether the compound function is Unary(Binary(X, Y)).
 * For Unary(Binary(X, Y)), the shape of intermediate_out is the same as that
 * of the final out.
 */
bool IsUnaryCompound(const std::vector<std::string> &functor_list);
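// For example, functor_list = {"relu", "elementwise_add"} denotes
// relu(X + Y), i.e. Unary(Binary(X, Y)), while {"elementwise_add", "relu"}
// denotes X + relu(Y), i.e. Binary(X, Unary(Y)).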

/**
 * Whether the compound function contains an in-place unary functor, i.e. one
 * whose op_desc inputs contain only Out and Out@Grad.
 */
bool HasInPlaceUnary(const std::vector<std::string> &functor_list);
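// relu is the typical example: its backward can be computed from Out alone
// (dX = dOut * (Out > 0)), so the forward input X need not be kept.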

/**
 * Whether the Input(X) could be absent.
 */
bool InputXCanBeAbsent(const std::vector<std::string> &functor_list);
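// Input(X) can be absent only when functor_list contains elementwise_add:
// the gradient of X + Y w.r.t. either input is 1, so the add's backward needs
// neither X, Y nor Out (see the grad kernel below).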

template <typename DeviceContext, typename T, typename BinaryFunctor,
          typename UnaryFunctor>
static void RunBinaryCompoundFunctor(
    const framework::ExecutionContext &ctx, const BinaryFunctor &binary_functor,
    const UnaryFunctor &unary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Binary(X, Unary(Y))
  // intermediate_out = Unary(Y)
  // out = Binary(X, Unary(Y))
  // In this case, the shapes of intermediate_out and out are different.
  paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
      compound_func(binary_functor, unary_functor);
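  // "axis" is the broadcast axis for the underlying elementwise op, following
  // Paddle's elementwise convention (the default -1 aligns Y with the
  // trailing dimensions of X).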
  int axis = ctx.Attr<int>("axis");
  if (ctx.Attr<bool>("save_intermediate_out")) {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::BinaryCompoundFunctor<
                                     T, BinaryFunctor, UnaryFunctor>,
                                 true /*KeepIntermediateValue*/,
                                 false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::BinaryCompoundFunctor<
                                     T, BinaryFunctor, UnaryFunctor>,
                                 false /*KeepIntermediateValue*/,
                                 false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
}
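
// A sketch of a call site, mirroring how RunFunctors below instantiates this
// helper for functor_list = {"elementwise_add", "scale"}:
//   RunBinaryCompoundFunctor<DeviceContext, T,
//                            paddle::operators::math::AddFunctor<T>,
//                            paddle::operators::math::ScaleFunctor<T>>(
//       ctx, paddle::operators::math::AddFunctor<T>(),
//       paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
// which computes Out = X + scale * Y with IntermediateOut = scale * Y.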

template <typename DeviceContext, typename T, typename UnaryFunctor,
          typename BinaryFunctor>
static void RunUnaryCompoundFunctors(
    const framework::ExecutionContext &ctx, const UnaryFunctor &unary_functor,
    const BinaryFunctor &binary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Unary(Binary(X, Y))
  // intermediate_out = Binary(X, Y)
  // out = Unary(Binary(X, Y))
  // In this case, the shapes of intermediate_out and out are the same.
  int axis = ctx.Attr<int>("axis");

  paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
      compound_func(unary_functor, binary_functor);

  if (ctx.Attr<bool>("save_intermediate_out")) {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::UnaryCompoundFunctor<
                                     T, UnaryFunctor, BinaryFunctor>,
                                 true /*KeepIntermediateValue*/,
                                 true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::UnaryCompoundFunctor<
                                     T, UnaryFunctor, BinaryFunctor>,
                                 false /*KeepIntermediateValue*/,
                                 true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
}

template <typename DeviceContext, typename T, typename BinaryGradFunctor,
          typename UnaryFunctor, typename UnaryGradFunctor, bool InPlace>
static void RunBinaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const BinaryGradFunctor &binary_grad_functor,
    const UnaryFunctor &unary_functor,
    const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  // Z = Binary(X, Unary(Y))
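  // Given Out@Grad, compute X@Grad, Y@Grad and, when requested,
  // IntermediateOut@Grad. If in_intermediate_out is null, Unary(Y) is
  // recomputed inside the grad functors instead of being read back.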
  int axis = ctx.Attr<int>("axis");

  using BinaryCompoundDxFunctor =
      paddle::operators::math::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
                                                           UnaryFunctor>;
  using BinaryCompoundDyFunctor =
      paddle::operators::math::BinaryCompoundGradDyFunctor<
          T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
  using BinaryCompoundDIntermedaiteOutFunctor =
      paddle::operators::math::BinaryCompoundGradDIntermedaiteOutFunctor<
          T, BinaryGradFunctor, UnaryFunctor>;

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        BinaryCompoundDIntermedaiteOutFunctor, true /*UseIntermediateOut*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
                                unary_grad_functor),
        BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
                                              unary_functor));
  } else {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        BinaryCompoundDIntermedaiteOutFunctor, false /*UseIntermediateOut*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
                                unary_grad_functor),
        BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
                                              unary_functor));
  }
}

template <typename DeviceContext, typename T, typename UnaryGradFunctor,
          typename BinaryFunctor, typename BinaryGradFunctor, bool InPlace>
static void RunUnaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const UnaryGradFunctor &unary_grad_functor,
    const BinaryFunctor &binary_functor,
    const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  // Z = Unary(Binary(X, Y))
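  // Given Out@Grad, compute X@Grad, Y@Grad and, when requested,
  // IntermediateOut@Grad. If in_intermediate_out is null, Binary(X, Y) is
  // recomputed inside the grad functors instead of being read back.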
  int axis = ctx.Attr<int>("axis");

  using UnaryCompoundDxFunctor =
      paddle::operators::math::UnaryCompoundGradDxFunctor<
          T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
  using UnaryCompoundDyFunctor =
      paddle::operators::math::UnaryCompoundGradDyFunctor<
          T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
  using UnaryCompoundDIntermediateFunctor =
      paddle::operators::math::UnaryCompoundGradDIntermediateFunctor<
          T, UnaryGradFunctor, BinaryFunctor, InPlace>;

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
        UnaryCompoundDIntermediateFunctor, true /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
  } else {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
        UnaryCompoundDIntermediateFunctor, false /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
  }
}

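// Dispatches on the "functor_list" attribute to a concrete compound functor.
// The supported compounds and the formulas they compute are:
//   elementwise_add,scale   : Out = X + scale * Y
//   scale,elementwise_add   : Out = scale * (X + Y)
//   elementwise_add,relu    : Out = X + relu(Y)
//   relu,elementwise_add    : Out = relu(X + Y)
//   elementwise_mul,scale   : Out = X * (scale * Y)
//   tanh,elementwise_add    : Out = tanh(X + Y)
//   elementwise_mul,tanh    : Out = X * tanh(Y)
//   elementwise_mul,sigmoid : Out = X * sigmoid(Y)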
template <typename DeviceContext, typename T>
static void RunFunctors(const framework::ExecutionContext &ctx,
                        const framework::Tensor &in_x,
                        const framework::Tensor &in_y,
                        std::vector<framework::Tensor *> *outputs) {
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");

  // TODO(zcd): The following code can be refined.
  auto funcs_str = functors[0] + "," + functors[1];
  if (funcs_str == "elementwise_add,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::AddFunctor<T>,
                             paddle::operators::math::ScaleFunctor<T>>(
        ctx, paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
  } else if (funcs_str == "scale,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundFunctors<DeviceContext, T,
                             paddle::operators::math::ScaleFunctor<T>,
                             paddle::operators::math::AddFunctor<T>>(
        ctx, paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_add,relu") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::AddFunctor<T>,
                             paddle::operators::math::ReluFunctor<T>>(
        ctx, paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::ReluFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "relu,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T,
                             paddle::operators::math::ReluFunctor<T>,
                             paddle::operators::math::AddFunctor<T>>(
        ctx, paddle::operators::math::ReluFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::MulFunctor<T>,
                             paddle::operators::math::ScaleFunctor<T>>(
        ctx, paddle::operators::math::MulFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
  } else if (funcs_str == "tanh,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T,
                             paddle::operators::math::TanhFunctor<T>,
                             paddle::operators::math::AddFunctor<T>>(
        ctx, paddle::operators::math::TanhFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,tanh") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::MulFunctor<T>,
                             paddle::operators::math::TanhFunctor<T>>(
        ctx, paddle::operators::math::MulFunctor<T>(),
        paddle::operators::math::TanhFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,sigmoid") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::MulFunctor<T>,
                             paddle::operators::math::SigmoidFunctor<T>>(
        ctx, paddle::operators::math::MulFunctor<T>(),
        paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s has not been implemented.", funcs_str));
  }
}

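// Backward counterpart of RunFunctors. Here "functor_list" carries the grad
// functor names, e.g. {"elementwise_add_grad", "scale_grad"}.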
template <typename DeviceContext, typename T, bool InPlace>
static void RunGradFunctors(
    const framework::ExecutionContext &ctx, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
  auto funcs_str = functors[0] + "," + functors[1];

  if (funcs_str == "elementwise_add_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
        paddle::operators::math::ScaleFunctor<T>,
        paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::AddGradFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "scale_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::ScaleGradFunctor<T>,
        paddle::operators::math::AddFunctor<T>,
        paddle::operators::math::AddGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::ScaleGradFunctor<T>(scale),
        paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_add_grad,relu_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
        paddle::operators::math::ReluFunctor<T>,
        paddle::operators::math::ReluGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::AddGradFunctor<T>(),
        paddle::operators::math::ReluFunctor<T>(),
        paddle::operators::math::ReluGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "relu_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::ReluGradFunctor<T>,
        paddle::operators::math::AddFunctor<T>,
        paddle::operators::math::AddGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::ReluGradFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
        paddle::operators::math::ScaleFunctor<T>,
        paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::MulGradFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "tanh_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::TanhGradFunctor<T>,
        paddle::operators::math::AddFunctor<T>,
        paddle::operators::math::AddGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::TanhGradFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,tanh_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
        paddle::operators::math::TanhFunctor<T>,
        paddle::operators::math::TanhGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::MulGradFunctor<T>(),
        paddle::operators::math::TanhFunctor<T>(),
        paddle::operators::math::TanhGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,sigmoid_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<
        DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
        paddle::operators::math::SigmoidFunctor<T>,
        paddle::operators::math::SigmoidGradFunctor<T>, InPlace>(
        ctx, paddle::operators::math::MulGradFunctor<T>(),
        paddle::operators::math::SigmoidFunctor<T>(),
        paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s has not been implemented.", funcs_str));
  }
}

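// Forward kernel: reads X and Y, runs the compound functor selected by
// "functor_list", and writes Out (plus IntermediateOut when
// save_intermediate_out is set).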
template <typename DeviceContext, typename T>
class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto &in_x = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("X"), "Input",
                                 "X", "FusedElemwiseActivation");
    auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("Y"), "Input",
                                 "Y", "FusedElemwiseActivation");

    PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true,
                      platform::errors::InvalidArgument(
                          "Output(Out) should not be empty."));
    auto output = ctx.Output<framework::Tensor>("Out");

    std::vector<framework::Tensor *> outputs;
    outputs.emplace_back(output);

    if (ctx.Attr<bool>("save_intermediate_out")) {
      PADDLE_ENFORCE_EQ(ctx.HasOutput("IntermediateOut"), true,
                        platform::errors::InvalidArgument(
                            "save_intermediate_out is enabled, so "
                            "Output(IntermediateOut) should not be empty."));

      auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
      outputs.emplace_back(intermediate_out);
    } else {
      outputs.emplace_back(nullptr);
    }

    RunFunctors<DeviceContext, T>(ctx, in_x, in_y, &outputs);
  }
};

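// Backward kernel: gathers the forward inputs and outputs that are still
// needed, reuses IntermediateOut when it was saved, and dispatches to the
// matching compound grad functors.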
template <typename DeviceContext, typename T>
class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto in_y = ctx.Input<framework::Tensor>("Y");
    PADDLE_ENFORCE_NE(in_y, nullptr, platform::errors::InvalidArgument(
                                         "Input(Y) should not be nullptr."));
    auto in_out = ctx.Input<framework::Tensor>("Out");
    PADDLE_ENFORCE_NE(
        in_out, nullptr,
        platform::errors::InvalidArgument("Input(Out) should not be nullptr."));
    auto in_out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE_NE(in_out_grad, nullptr,
                      platform::errors::InvalidArgument(
                          "Input(Out@Grad) should not be nullptr."));

    framework::Tensor *in_x =
        const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
    framework::Tensor *x_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    framework::Tensor *y_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
    framework::Tensor *d_intermediate_out = ctx.Output<framework::Tensor>(
        framework::GradVarName("IntermediateOut"));

    auto functor_list = ctx.Attr<std::vector<std::string>>("functor_list");

    // Get intermediate_out
    framework::Tensor *in_intermediate_out = nullptr;
    if (ctx.Attr<bool>("save_intermediate_out")) {
      // If save_intermediate_out is true, Binary(X, Y) (for
      // Unary(Binary(X, Y))) and Unary(Y) (for Binary(X, Unary(Y))) do not
      // need to be recomputed in the backward pass.
      in_intermediate_out = const_cast<framework::Tensor *>(
          ctx.Input<framework::Tensor>("IntermediateOut"));
      PADDLE_ENFORCE_NE(in_intermediate_out, nullptr,
                        platform::errors::InvalidArgument(
                            "The option 'save_intermediate_out' is enabled, "
                            "so Input(IntermediateOut) should not be null."));
    } else {
      if (!InputXCanBeAbsent(functor_list)) {
        PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument(
                                             "Input(X) should not be null."));
      }
    }

    // Get in_x
    if (ctx.HasInput("X")) {
      PADDLE_ENFORCE_NE(in_x, nullptr, platform::errors::InvalidArgument(
                                           "Input(X) should not be null."));
    } else {
      // If functor_list contains elementwise_add, the backward doesn't use
      // in_x, in_y and in_out.
      PADDLE_ENFORCE_EQ(InputXCanBeAbsent(functor_list), true,
                        platform::errors::InvalidArgument(
                            "Input(X) may be absent only when the compound "
                            "functor contains elementwise_add_grad."));
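      // X is not used by the grad functors in this case; Out@Grad (which has
      // Out's shape) is substituted as a stand-in.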
      in_x = const_cast<framework::Tensor *>(in_out_grad);
    }

    bool has_in_place = HasInPlaceUnary(functor_list);
    if (has_in_place) {
      RunGradFunctors<DeviceContext, T, true /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    } else {
      RunGradFunctors<DeviceContext, T, false /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    }
  }
};
}  // namespace operators
}  // namespace paddle