/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/math/compound_functors.h" #include "paddle/fluid/operators/math/functors.h" namespace paddle { namespace operators { template static void RunBinaryCompoundFunctor( const framework::ExecutionContext &ctx, const BinaryFunctor &binary_functor, const UnaryFunctor &unary_functor, const framework::Tensor &in_x, const framework::Tensor &in_y, std::vector *outputs) { // Z = Binary(X, Unary(Y)) // intermediate_out = Unary(Y) // out = Binary(X, Unary(Y)) // In this case, the shape of intermediate_out and out are different. paddle::operators::math::BinaryCompoundFunctor compound_func(binary_functor, unary_functor); int axis = ctx.Attr("axis"); if (ctx.Attr("keep_intermediate_value")) { FusedElemwiseAndActComputeEx, true /*KeepIntermediateValue*/, false /*SameShapeOfIntermediateOutAndOut*/>( ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); } else { FusedElemwiseAndActComputeEx, false /*KeepIntermediateValue*/, false /*SameShapeOfIntermediateOutAndOut*/>( ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); } } template static void RunUnaryCompoundFunctors( const framework::ExecutionContext &ctx, const UnaryFunctor &unary_functor, const BinaryFunctor &binary_functor, const framework::Tensor &in_x, const framework::Tensor &in_y, std::vector *outputs) { // Z = Unary(Binary(X, Y)) // intermediate_out = Binary(X, Y) // out = Unary(Binary(X, Y)) // In this case, the shape of intermediate_out and out are the same. int axis = ctx.Attr("axis"); paddle::operators::math::UnaryCompoundFunctor compound_func(unary_functor, binary_functor); if (ctx.Attr("keep_intermediate_value")) { FusedElemwiseAndActComputeEx, true /*KeepIntermediateValue*/, true /*SameShapeOfIntermediateOutAndOut*/>( ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); } else { FusedElemwiseAndActComputeEx, false /*KeepIntermediateValue*/, true /*SameShapeOfIntermediateOutAndOut*/>( ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); } } template static void RunBinaryCompoundGradFunctors( const framework::ExecutionContext &ctx, const BinaryGradFunctor &binary_grad_functor, const UnaryFunctor &unary_functor, const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x, const framework::Tensor *in_y, const framework::Tensor *in_out, const framework::Tensor *in_intermediate_out, const framework::Tensor *in_out_grad, framework::Tensor *x_grad, framework::Tensor *y_grad) { // Z = Binary(X, Unary(Y)) int axis = ctx.Attr("axis"); using BinaryCompoundDxFunctor = paddle::operators::math::BinaryCompoundGradDxFunctor; using BinaryCompoundDyFunctor = paddle::operators::math::BinaryCompoundGradDyFunctor< T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor>; if (in_intermediate_out) { FusedElemwiseAndActGradComputeEx< DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor, true /*UseIntermediateOut*/, false /*SameShapeOfIntermediateOutAndOut*/>( ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), BinaryCompoundDyFunctor(binary_grad_functor, unary_functor, unary_grad_functor)); } else { FusedElemwiseAndActGradComputeEx< DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor, false /*UseIntermediateOut*/, false /*SameShapeOfIntermediateOutAndOut*/>( ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), BinaryCompoundDyFunctor(binary_grad_functor, unary_functor, unary_grad_functor)); } } template static void RunUnaryCompoundGradFunctors( const framework::ExecutionContext &ctx, const UnaryGradFunctor &unary_grad_functor, const BinaryFunctor &binary_functor, const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x, const framework::Tensor *in_y, const framework::Tensor *in_out, const framework::Tensor *in_intermediate_out, const framework::Tensor *in_out_grad, framework::Tensor *x_grad, framework::Tensor *y_grad) { // Z = Unary(Binary(X, Y)) int axis = ctx.Attr("axis"); using UnaryCompoundDxFunctor = paddle::operators::math::UnaryCompoundGradDxFunctor< T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>; using UnaryCompoundDyFunctor = paddle::operators::math::UnaryCompoundGradDyFunctor< T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>; if (in_intermediate_out) { FusedElemwiseAndActGradComputeEx< DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor, true /*UseIntermediateOut*/, true /*SameShapeOfIntermediateOutAndOut*/>( ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, binary_grad_functor), UnaryCompoundDyFunctor(unary_grad_functor, binary_functor, binary_grad_functor)); } else { FusedElemwiseAndActGradComputeEx( ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, binary_grad_functor), UnaryCompoundDyFunctor(unary_grad_functor, binary_functor, binary_grad_functor)); } } template static void RunFunctors(const framework::ExecutionContext &ctx, const framework::Tensor &in_x, const framework::Tensor &in_y, std::vector *outputs) { auto &functors = ctx.Attr>("functor_list"); // TODO(zcd): The following code can be refined. auto funcs_str = functors[0] + "," + functors[1]; if (funcs_str == "elementwise_add,scale") { // Z = Binary(X, Unary(Y)) T scale = static_cast(ctx.Attr("scale")); RunBinaryCompoundFunctor, paddle::operators::math::ScaleFunctor>( ctx, paddle::operators::math::AddFunctor(), paddle::operators::math::ScaleFunctor(scale), in_x, in_y, outputs); } else if (funcs_str == "scale,elementwise_add") { // Z = Unary(Binary(X, Y)) T scale = static_cast(ctx.Attr("scale")); RunUnaryCompoundFunctors, paddle::operators::math::AddFunctor>( ctx, paddle::operators::math::ScaleFunctor(scale), paddle::operators::math::AddFunctor(), in_x, in_y, outputs); } else if (funcs_str == "elementwise_add,relu") { // Z = Binary(X, Unary(Y)) RunBinaryCompoundFunctor, paddle::operators::math::ReluFunctor>( ctx, paddle::operators::math::AddFunctor(), paddle::operators::math::ReluFunctor(), in_x, in_y, outputs); } else if (funcs_str == "relu,elementwise_add") { // Z = Unary(Binary(X, Y)) RunUnaryCompoundFunctors, paddle::operators::math::AddFunctor>( ctx, paddle::operators::math::ReluFunctor(), paddle::operators::math::AddFunctor(), in_x, in_y, outputs); } else if (funcs_str == "elementwise_mul,scale") { // Z = Binary(X, Unary(Y)) T scale = static_cast(ctx.Attr("scale")); RunBinaryCompoundFunctor, paddle::operators::math::ScaleFunctor>( ctx, paddle::operators::math::MulFunctor(), paddle::operators::math::ScaleFunctor(scale), in_x, in_y, outputs); } else { PADDLE_THROW("%s has not been implemented.", funcs_str); } } template static void RunGradFunctors(const framework::ExecutionContext &ctx, const framework::Tensor *in_x, const framework::Tensor *in_y, const framework::Tensor *in_out, const framework::Tensor *in_intermediate_out, const framework::Tensor *in_out_grad, framework::Tensor *x_grad, framework::Tensor *y_grad) { auto &functors = ctx.Attr>("functor_list"); auto funcs_str = functors[0] + "," + functors[1]; // TODO(zcd): The following code can be refined. for example, use registrition if (funcs_str == "elementwise_add_grad,scale_grad") { // The backward of Z = Binary(X, Unary(Y)) T scale = static_cast(ctx.Attr("scale")); RunBinaryCompoundGradFunctors, paddle::operators::math::ScaleFunctor, paddle::operators::math::ScaleGradFunctor>( ctx, paddle::operators::math::AddGradFunctor(), paddle::operators::math::ScaleFunctor(scale), paddle::operators::math::ScaleGradFunctor(scale), in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad); } else if (funcs_str == "scale_grad,elementwise_add_grad") { // The backward of Z = Unary(Binary(X, Y)) T scale = static_cast(ctx.Attr("scale")); RunUnaryCompoundGradFunctors, paddle::operators::math::AddFunctor, paddle::operators::math::AddGradFunctor, ReComputation /*Recomputation*/>( ctx, paddle::operators::math::ScaleGradFunctor(scale), paddle::operators::math::AddFunctor(), paddle::operators::math::AddGradFunctor(), in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad); } else if (funcs_str == "elementwise_add_grad,relu_grad") { RunBinaryCompoundGradFunctors, paddle::operators::math::ReluFunctor, paddle::operators::math::ReluGradFunctor>( ctx, paddle::operators::math::AddGradFunctor(), paddle::operators::math::ReluFunctor(), paddle::operators::math::ReluGradFunctor(), in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad); } else if (funcs_str == "relu_grad,elementwise_add_grad") { RunUnaryCompoundGradFunctors, paddle::operators::math::AddFunctor, paddle::operators::math::AddGradFunctor, ReComputation /*Recomputation*/>( ctx, paddle::operators::math::ReluGradFunctor(), paddle::operators::math::AddFunctor(), paddle::operators::math::AddGradFunctor(), in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad); } else if (funcs_str == "elementwise_mul_grad,scale_grad") { // The backward of Z = Binary(X, Unary(Y)) T scale = static_cast(ctx.Attr("scale")); RunBinaryCompoundGradFunctors, paddle::operators::math::ScaleFunctor, paddle::operators::math::ScaleGradFunctor>( ctx, paddle::operators::math::MulGradFunctor(), paddle::operators::math::ScaleFunctor(scale), paddle::operators::math::ScaleGradFunctor(scale), in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad); } else { PADDLE_THROW("%s has not been implemented.", funcs_str); } } template class FusedElemwiseActivationKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { auto &in_x = detail::Ref(ctx.Input("X"), "Cannot get input tensor %s, variable name = %s", "X", ctx.op().Input("X")); auto &in_y = detail::Ref(ctx.Input("Y"), "Cannot get input tensor %s, variable name = %s", "Y", ctx.op().Input("Y")); PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty"); auto output = ctx.Output("Out"); std::vector outputs; outputs.emplace_back(output); if (ctx.Attr("keep_intermediate_value")) { PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"), "The keep_intermediate_value is enable, so the " "IntermediateOut should not be empty."); auto intermediate_out = ctx.Output("IntermediateOut"); outputs.emplace_back(intermediate_out); } else { outputs.emplace_back(nullptr); } RunFunctors(ctx, in_x, in_y, &outputs); } }; template class FusedElemwiseActivationGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { auto x = ctx.Input("X"); auto y = ctx.Input("Y"); auto in_out = ctx.Input("Out"); auto in_out_grad = ctx.Input(framework::GradVarName("Out")); framework::Tensor *x_grad = ctx.Output(framework::GradVarName("X")); framework::Tensor *y_grad = ctx.Output(framework::GradVarName("Y")); PADDLE_ENFORCE(y != nullptr, "Input(Y) should not be nullptr."); if (ctx.Attr("recomputation")) { PADDLE_ENFORCE( x != nullptr, "The recomputation is opened, so Input(X) should not be absent."); } else { PADDLE_ENFORCE(in_out != nullptr, "The recomputation is disabled, so the Input('Out') " "should not be empty."); } framework::Tensor *in_x; auto functor_list = ctx.Attr>("functor_list"); // If functor_list contains elementwise_add, the backward doesn't use // in_x, and in_outs. if (x == nullptr) { PADDLE_ENFORCE(functor_list[0] == "elementwise_add_grad" || functor_list[1] == "elementwise_add_grad", "Only when the compoundfunctor contains " "elementwise_add_grad, the 'X' could be absent."); in_x = const_cast(in_out_grad); in_out = const_cast(in_out_grad); } else { in_x = const_cast(x); } framework::Tensor *in_intermediate_out; if (ctx.Attr("keep_intermediate_value")) { in_intermediate_out = const_cast( ctx.Input("IntermediateOut")); PADDLE_ENFORCE(in_intermediate_out != nullptr, "The option of 'keep_intermediate_value' is opened, " "so the number of 'Out' should be two."); } else { in_intermediate_out = nullptr; } if (ctx.Attr("recomputation")) { RunGradFunctors( ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad); } else { RunGradFunctors( ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad); } } }; } // namespace operators } // namespace paddle