fused_elemwise_activation_op.h
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h"
#include "paddle/fluid/operators/math/functors.h"

namespace paddle {
namespace operators {
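// This file implements the forward and backward kernels of the fused
// elementwise + activation operator. The fused computation is either
//   Z = Binary(X, Unary(Y)), e.g. elementwise_add(X, scale(Y)), or
//   Z = Unary(Binary(X, Y)), e.g. relu(elementwise_add(X, Y)),
// where the binary/unary functor pair is selected by the "functor_list"
// attribute. The intermediate result (Unary(Y) or Binary(X, Y)) can be
// written to "IntermediateOut" when "keep_intermediate_value" is set, so
// the backward pass can reuse it instead of recomputing it.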

template <typename DeviceContext, typename T, typename BinaryFunctor,
          typename UnaryFunctor>
static void RunBinaryCompoundFunctor(
    const framework::ExecutionContext &ctx, const BinaryFunctor &binary_functor,
    const UnaryFunctor &unary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Binary(X, Unary(Y))
  // intermediate_out = Unary(Y)
  // out = Binary(X, Unary(Y))
  // In this case, the shape of intermediate_out and out are different.
  paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
      compound_func(binary_functor, unary_functor);
  int axis = ctx.Attr<int>("axis");
  if (ctx.Attr<bool>("keep_intermediate_value")) {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::BinaryCompoundFunctor<
                                     T, BinaryFunctor, UnaryFunctor>,
                                 true /*KeepIntermediateValue*/,
                                 false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::BinaryCompoundFunctor<
                                     T, BinaryFunctor, UnaryFunctor>,
                                 false /*KeepIntermediateValue*/,
                                 false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
}

template <typename DeviceContext, typename T, typename UnaryFunctor,
          typename BinaryFunctor>
static void RunUnaryCompoundFunctors(
    const framework::ExecutionContext &ctx, const UnaryFunctor &unary_functor,
    const BinaryFunctor &binary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
  // Z = Unary(Binary(X, Y))
  // intermediate_out = Binary(X, Y)
  // out = Unary(Binary(X, Y))
  // In this case, the shape of intermediate_out and out are the same.
  int axis = ctx.Attr<int>("axis");

  paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
      compound_func(unary_functor, binary_functor);

  if (ctx.Attr<bool>("keep_intermediate_value")) {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::UnaryCompoundFunctor<
                                     T, UnaryFunctor, BinaryFunctor>,
                                 true /*KeepIntermediateValue*/,
                                 true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<DeviceContext, T,
                                 paddle::operators::math::UnaryCompoundFunctor<
                                     T, UnaryFunctor, BinaryFunctor>,
                                 false /*KeepIntermediateValue*/,
                                 true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
}

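// Backward of Z = Binary(X, Unary(Y)): computes x_grad and y_grad from
// in_out_grad. If the forward pass kept the intermediate result
// (in_intermediate_out != nullptr), it is reused; otherwise Unary(Y) is
// recomputed inside the gradient functors.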
template <typename DeviceContext, typename T, typename BinaryGradFunctor,
          typename UnaryFunctor, typename UnaryGradFunctor>
static void RunBinaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const BinaryGradFunctor &binary_grad_functor,
    const UnaryFunctor &unary_functor,
    const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad) {
  // Z = Binary(X, Unary(Y))
  int axis = ctx.Attr<int>("axis");

  using BinaryCompoundDxFunctor =
      paddle::operators::math::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
                                                           UnaryFunctor>;
  using BinaryCompoundDyFunctor =
      paddle::operators::math::BinaryCompoundGradDyFunctor<
          T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor>;

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        true /*UseIntermediateOut*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
                                unary_grad_functor));
  } else {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        false /*UseIntermediateOut*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
                                unary_grad_functor));
  }
}

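// Backward of Z = Unary(Binary(X, Y)). The Recomputation flag selects whether
// the unary derivative is evaluated from the recomputed intermediate
// Binary(X, Y) (true) or from the saved forward output Out (false).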
template <typename DeviceContext, typename T, typename UnaryGradFunctor,
          typename BinaryFunctor, typename BinaryGradFunctor,
          bool Recomputation = true>
static void RunUnaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const UnaryGradFunctor &unary_grad_functor,
    const BinaryFunctor &binary_functor,
    const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad) {
  // Z = Unary(Binary(X, Y))
  int axis = ctx.Attr<int>("axis");

  using UnaryCompoundDxFunctor =
      paddle::operators::math::UnaryCompoundGradDxFunctor<
          T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>;
  using UnaryCompoundDyFunctor =
      paddle::operators::math::UnaryCompoundGradDyFunctor<
          T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>;

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
        true /*UseIntermediateOut*/, true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                                       binary_grad_functor),
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor));
  } else {
    FusedElemwiseAndActGradComputeEx<DeviceContext, T, UnaryCompoundDxFunctor,
                                     UnaryCompoundDyFunctor,
                                     false /*UseIntermediateOut*/,
                                     true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
        y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                                       binary_grad_functor),
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor));
  }
}

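// Dispatches the forward computation based on the "functor_list" attribute:
// its two entries are joined into a key such as "elementwise_add,relu" and
// matched against the supported compound functor combinations below.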
template <typename DeviceContext, typename T>
static void RunFunctors(const framework::ExecutionContext &ctx,
                        const framework::Tensor &in_x,
                        const framework::Tensor &in_y,
                        std::vector<framework::Tensor *> *outputs) {
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");

  // TODO(zcd): The following code can be refined.
  auto funcs_str = functors[0] + "," + functors[1];
  if (funcs_str == "elementwise_add,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::AddFunctor<T>,
                             paddle::operators::math::ScaleFunctor<T>>(
        ctx, paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
  } else if (funcs_str == "scale,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundFunctors<DeviceContext, T,
                             paddle::operators::math::ScaleFunctor<T>,
                             paddle::operators::math::AddFunctor<T>>(
        ctx, paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_add,relu") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::AddFunctor<T>,
                             paddle::operators::math::ReluFunctor<T>>(
        ctx, paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::ReluFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "relu,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T,
                             paddle::operators::math::ReluFunctor<T>,
                             paddle::operators::math::AddFunctor<T>>(
        ctx, paddle::operators::math::ReluFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T,
                             paddle::operators::math::MulFunctor<T>,
                             paddle::operators::math::ScaleFunctor<T>>(
        ctx, paddle::operators::math::MulFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
  } else {
    PADDLE_THROW("%s has not been implemented.", funcs_str);
  }
}

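// Backward counterpart of RunFunctors: dispatches on the gradient functor
// names (e.g. "elementwise_add_grad,relu_grad") and forwards the
// ReComputation flag to the unary compound gradient functors.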
template <typename DeviceContext, typename T, bool ReComputation>
static void RunGradFunctors(const framework::ExecutionContext &ctx,
                            const framework::Tensor *in_x,
                            const framework::Tensor *in_y,
                            const framework::Tensor *in_out,
                            const framework::Tensor *in_intermediate_out,
                            const framework::Tensor *in_out_grad,
                            framework::Tensor *x_grad,
                            framework::Tensor *y_grad) {
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
  auto funcs_str = functors[0] + "," + functors[1];

  // TODO(zcd): The following code can be refined, for example, by using
  // registration instead of string matching.
  if (funcs_str == "elementwise_add_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<DeviceContext, T,
                                  paddle::operators::math::AddGradFunctor<T>,
                                  paddle::operators::math::ScaleFunctor<T>,
                                  paddle::operators::math::ScaleGradFunctor<T>>(
        ctx, paddle::operators::math::AddGradFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad);
  } else if (funcs_str == "scale_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundGradFunctors<DeviceContext, T,
                                 paddle::operators::math::ScaleGradFunctor<T>,
                                 paddle::operators::math::AddFunctor<T>,
                                 paddle::operators::math::AddGradFunctor<T>,
                                 ReComputation /*Recomputation*/>(
        ctx, paddle::operators::math::ScaleGradFunctor<T>(scale),
        paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad);
  } else if (funcs_str == "elementwise_add_grad,relu_grad") {
    RunBinaryCompoundGradFunctors<DeviceContext, T,
                                  paddle::operators::math::AddGradFunctor<T>,
                                  paddle::operators::math::ReluFunctor<T>,
                                  paddle::operators::math::ReluGradFunctor<T>>(
        ctx, paddle::operators::math::AddGradFunctor<T>(),
        paddle::operators::math::ReluFunctor<T>(),
        paddle::operators::math::ReluGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad);
  } else if (funcs_str == "relu_grad,elementwise_add_grad") {
    RunUnaryCompoundGradFunctors<DeviceContext, T,
                                 paddle::operators::math::ReluGradFunctor<T>,
                                 paddle::operators::math::AddFunctor<T>,
                                 paddle::operators::math::AddGradFunctor<T>,
                                 ReComputation /*Recomputation*/>(
        ctx, paddle::operators::math::ReluGradFunctor<T>(),
        paddle::operators::math::AddFunctor<T>(),
        paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad);
  } else if (funcs_str == "elementwise_mul_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<DeviceContext, T,
                                  paddle::operators::math::MulGradFunctor<T>,
                                  paddle::operators::math::ScaleFunctor<T>,
                                  paddle::operators::math::ScaleGradFunctor<T>>(
        ctx, paddle::operators::math::MulGradFunctor<T>(),
        paddle::operators::math::ScaleFunctor<T>(scale),
        paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad);
  } else {
    PADDLE_THROW("%s has not been implemented.", funcs_str);
  }
}

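// Forward kernel: reads inputs X and Y, writes Out, and additionally writes
// IntermediateOut when "keep_intermediate_value" is enabled.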
template <typename DeviceContext, typename T>
class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto &in_x = detail::Ref(ctx.Input<framework::Tensor>("X"),
                             "Cannot get input tensor %s, variable name = %s",
                             "X", ctx.op().Input("X"));
    auto &in_y = detail::Ref(ctx.Input<framework::Tensor>("Y"),
                             "Cannot get input tensor %s, variable name = %s",
                             "Y", ctx.op().Input("Y"));
    PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");
    auto output = ctx.Output<framework::Tensor>("Out");

    std::vector<framework::Tensor *> outputs;
    outputs.emplace_back(output);

    if (ctx.Attr<bool>("keep_intermediate_value")) {
      PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
                     "The keep_intermediate_value is enable, so the "
                     "IntermediateOut should not be empty.");
      auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
      outputs.emplace_back(intermediate_out);
    } else {
      outputs.emplace_back(nullptr);
    }

    RunFunctors<DeviceContext, T>(ctx, in_x, in_y, &outputs);
  }
};

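// Backward kernel: computes X@GRAD and Y@GRAD from Out@GRAD. Depending on
// "recomputation" and the functor combination, X or Out may be absent; in
// that case Out@GRAD is substituted as a placeholder for the unused tensors.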
template <typename DeviceContext, typename T>
class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto x = ctx.Input<framework::Tensor>("X");
    auto y = ctx.Input<framework::Tensor>("Y");

    auto in_out = ctx.Input<framework::Tensor>("Out");
    auto in_out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));

    framework::Tensor *x_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    framework::Tensor *y_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Y"));

    PADDLE_ENFORCE(y != nullptr, "Input(Y) should not be nullptr.");

    if (ctx.Attr<bool>("recomputation")) {
      PADDLE_ENFORCE(
          x != nullptr,
          "The recomputation is opened, so Input(X) should not be absent.");
    } else {
      PADDLE_ENFORCE(in_out != nullptr,
                     "The recomputation is disabled, so the Input('Out') "
                     "should not be empty.");
    }

    framework::Tensor *in_x;
    auto functor_list = ctx.Attr<std::vector<std::string>>("functor_list");

    // If functor_list contains elementwise_add, the backward doesn't use
    // in_x or in_out.
    if (x == nullptr) {
      PADDLE_ENFORCE(functor_list[0] == "elementwise_add_grad" ||
                         functor_list[1] == "elementwise_add_grad",
                     "Only when the compoundfunctor contains "
                     "elementwise_add_grad, the 'X' could be absent.");
      in_x = const_cast<framework::Tensor *>(in_out_grad);
      in_out = const_cast<framework::Tensor *>(in_out_grad);
    } else {
      in_x = const_cast<framework::Tensor *>(x);
    }

    framework::Tensor *in_intermediate_out;
    if (ctx.Attr<bool>("keep_intermediate_value")) {
      in_intermediate_out = const_cast<framework::Tensor *>(
          ctx.Input<framework::Tensor>("IntermediateOut"));
      PADDLE_ENFORCE(in_intermediate_out != nullptr,
                     "The option of 'keep_intermediate_value' is opened, "
                     "so the number of 'Out' should be two.");
    } else {
      in_intermediate_out = nullptr;
    }

    if (ctx.Attr<bool>("recomputation")) {
      RunGradFunctors<DeviceContext, T, true /*Recomputation*/>(
          ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad);
    } else {
      RunGradFunctors<DeviceContext, T, false /*Recomputation*/>(
          ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad);
    }
  }
};
}  // namespace operators
}  // namespace paddle
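
// A minimal sketch of how the matching .cc file would typically register these
// kernels (the exact data-type list here is an assumption; see
// fused_elemwise_activation_op.cc for the authoritative registration):
//
//   namespace ops = paddle::operators;
//   REGISTER_OP_CPU_KERNEL(
//       fused_elemwise_activation,
//       ops::FusedElemwiseActivationKernel<paddle::platform::CPUDeviceContext,
//                                          float>,
//       ops::FusedElemwiseActivationKernel<paddle::platform::CPUDeviceContext,
//                                          double>);
//   REGISTER_OP_CPU_KERNEL(
//       fused_elemwise_activation_grad,
//       ops::FusedElemwiseActivationGradKernel<
//           paddle::platform::CPUDeviceContext, float>,
//       ops::FusedElemwiseActivationGradKernel<
//           paddle::platform::CPUDeviceContext, double>);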