/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/pten/kernels/funcs/compound_functors.h"
#include "paddle/pten/kernels/funcs/elementwise_functor.h"
#include "paddle/pten/kernels/funcs/functors.h"

namespace paddle {
namespace operators {

/**
 * Whether the compound function is Unary(Binary(X, Y)).
 * For Unary(Binary(X, Y)), the intermediate_out's shape is the same as the
 * final out's.
 */
bool IsUnaryCompound(const std::vector<std::string> &functor_list);

/**
 * For the in-place unary functor, the inputs of op_desc only have Out and
 * Out@Grad.
 */
bool HasInPlaceUnary(const std::vector<std::string> &functor_list);

/**
 * Whether the Input(X) could be absent.
 */
bool InputXCanBeAbsent(const std::vector<std::string> &functor_list);

template <typename DeviceContext, typename T, typename BinaryFunctor,
          typename UnaryFunctor>
static void RunBinaryCompoundFunctor(
    const framework::ExecutionContext &ctx,
    const BinaryFunctor &binary_functor, const UnaryFunctor &unary_functor,
    const framework::Tensor &in_x, const framework::Tensor &in_y,
    std::vector<framework::Tensor *> *outputs) {
  // Z = Binary(X, Unary(Y))
  // intermediate_out = Unary(Y)
  // out = Binary(X, Unary(Y))
  // In this case, the shapes of intermediate_out and out are different.
  pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
      compound_func(binary_functor, unary_functor);
  int axis = ctx.Attr<int>("axis");
  if (ctx.Attr<bool>("save_intermediate_out")) {
    FusedElemwiseAndActComputeEx<
        DeviceContext, T,
        pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>,
        true /*KeepIntermediateValue*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<
        DeviceContext, T,
        pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>,
        false /*KeepIntermediateValue*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
}

template <typename DeviceContext, typename T, typename UnaryFunctor,
          typename BinaryFunctor>
static void RunUnaryCompoundFunctors(
    const framework::ExecutionContext &ctx, const UnaryFunctor &unary_functor,
    const BinaryFunctor &binary_functor, const framework::Tensor &in_x,
    const framework::Tensor &in_y,
    std::vector<framework::Tensor *> *outputs) {
  // Z = Unary(Binary(X, Y))
  // intermediate_out = Binary(X, Y)
  // out = Unary(Binary(X, Y))
  // In this case, the shapes of intermediate_out and out are the same.
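  // For example (illustrative, based on the "scale,elementwise_add" pair
  // dispatched by RunFunctors below): intermediate_out = X + Y and
  // out = scale * (X + Y), so both tensors share the broadcasted shape of
  // Binary(X, Y), which is why SameShapeOfIntermediateOutAndOut is true here.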
  int axis = ctx.Attr<int>("axis");
  pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
      compound_func(unary_functor, binary_functor);
  if (ctx.Attr<bool>("save_intermediate_out")) {
    FusedElemwiseAndActComputeEx<
        DeviceContext, T,
        pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>,
        true /*KeepIntermediateValue*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  } else {
    FusedElemwiseAndActComputeEx<
        DeviceContext, T,
        pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>,
        false /*KeepIntermediateValue*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
  }
}

template <typename DeviceContext, typename T, typename BinaryGradFunctor,
          typename UnaryFunctor, typename UnaryGradFunctor, bool InPlace>
static void RunBinaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const BinaryGradFunctor &binary_grad_functor,
    const UnaryFunctor &unary_functor,
    const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  // Z = Binary(X, Unary(Y))
  int axis = ctx.Attr<int>("axis");

  using BinaryCompoundDxFunctor =
      pten::funcs::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
                                               UnaryFunctor>;
  using BinaryCompoundDyFunctor =
      pten::funcs::BinaryCompoundGradDyFunctor<T, BinaryGradFunctor,
                                               UnaryFunctor, UnaryGradFunctor,
                                               InPlace>;
  using BinaryCompoundDIntermedaiteOutFunctor =
      pten::funcs::BinaryCompoundGradDIntermedaiteOutFunctor<
          T, BinaryGradFunctor, UnaryFunctor>;

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        BinaryCompoundDIntermedaiteOutFunctor, true /*UseIntermediateOut*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis,
        x_grad, y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
                                unary_grad_functor),
        BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
                                              unary_functor));
  } else {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
        BinaryCompoundDIntermedaiteOutFunctor, false /*UseIntermediateOut*/,
        false /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis,
        x_grad, y_grad, d_intermediate_out,
        BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
        BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
                                unary_grad_functor),
        BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor,
                                              unary_functor));
  }
}

template <typename DeviceContext, typename T, typename UnaryGradFunctor,
          typename BinaryFunctor, typename BinaryGradFunctor, bool InPlace>
static void RunUnaryCompoundGradFunctors(
    const framework::ExecutionContext &ctx,
    const UnaryGradFunctor &unary_grad_functor,
    const BinaryFunctor &binary_functor,
    const BinaryGradFunctor &binary_grad_functor,
    const framework::Tensor *in_x, const framework::Tensor *in_y,
    const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  // Z = Unary(Binary(X, Y))
  int axis = ctx.Attr<int>("axis");

  using UnaryCompoundDxFunctor =
      pten::funcs::UnaryCompoundGradDxFunctor<T, UnaryGradFunctor,
                                              BinaryFunctor, BinaryGradFunctor,
                                              InPlace>;
  using UnaryCompoundDyFunctor =
      pten::funcs::UnaryCompoundGradDyFunctor<T, UnaryGradFunctor,
                                              BinaryFunctor, BinaryGradFunctor,
                                              InPlace>;
  using UnaryCompoundDIntermediateFunctor =
      pten::funcs::UnaryCompoundGradDIntermediateFunctor<T, UnaryGradFunctor,
                                                         BinaryFunctor,
                                                         InPlace>;

  if (in_intermediate_out) {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
        UnaryCompoundDIntermediateFunctor, true /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis,
        x_grad, y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
  } else {
    FusedElemwiseAndActGradComputeEx<
        DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
        UnaryCompoundDIntermediateFunctor, false /*UseIntermediateOut*/,
        true /*SameShapeOfIntermediateOutAndOut*/>(
        ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis,
        x_grad, y_grad, d_intermediate_out,
        UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
                               binary_grad_functor),
        UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor));
  }
}

template <typename DeviceContext, typename T>
static void RunFunctors(const framework::ExecutionContext &ctx,
                        const framework::Tensor &in_x,
                        const framework::Tensor &in_y,
                        std::vector<framework::Tensor *> *outputs) {
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");

  // TODO(zcd): The following code can be refined.
  auto funcs_str = functors[0] + "," + functors[1];
  if (funcs_str == "elementwise_add,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::AddFunctor<T>,
                             pten::funcs::ScaleFunctor<T>>(
        ctx, pten::funcs::AddFunctor<T>(), pten::funcs::ScaleFunctor<T>(scale),
        in_x, in_y, outputs);
  } else if (funcs_str == "scale,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::ScaleFunctor<T>,
                             pten::funcs::AddFunctor<T>>(
        ctx, pten::funcs::ScaleFunctor<T>(scale), pten::funcs::AddFunctor<T>(),
        in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_add,relu") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::AddFunctor<T>,
                             pten::funcs::ReluFunctor<T>>(
        ctx, pten::funcs::AddFunctor<T>(), pten::funcs::ReluFunctor<T>(),
        in_x, in_y, outputs);
  } else if (funcs_str == "relu,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::ReluFunctor<T>,
                             pten::funcs::AddFunctor<T>>(
        ctx, pten::funcs::ReluFunctor<T>(), pten::funcs::AddFunctor<T>(),
        in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,scale") {
    // Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundFunctor<DeviceContext, T,
                             pten::funcs::MultiplyFunctor<T>,
                             pten::funcs::ScaleFunctor<T>>(
        ctx, pten::funcs::MultiplyFunctor<T>(),
        pten::funcs::ScaleFunctor<T>(scale), in_x, in_y, outputs);
  } else if (funcs_str == "tanh,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::TanhFunctor<T>,
                             pten::funcs::AddFunctor<T>>(
        ctx, pten::funcs::TanhFunctor<T>(), pten::funcs::AddFunctor<T>(),
        in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,tanh") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T,
                             pten::funcs::MultiplyFunctor<T>,
                             pten::funcs::TanhFunctor<T>>(
        ctx, pten::funcs::MultiplyFunctor<T>(), pten::funcs::TanhFunctor<T>(),
        in_x, in_y, outputs);
  } else if (funcs_str == "elementwise_mul,sigmoid") {
    // Z = Binary(X, Unary(Y))
    RunBinaryCompoundFunctor<DeviceContext, T,
                             pten::funcs::MultiplyFunctor<T>,
                             pten::funcs::SigmoidFunctor<T>>(
        ctx, pten::funcs::MultiplyFunctor<T>(),
        pten::funcs::SigmoidFunctor<T>(), in_x, in_y, outputs);
  } else if (funcs_str == "gelu,elementwise_add") {
    // Z = Unary(Binary(X, Y))
    RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::GeluFunctor<T>,
                             pten::funcs::AddFunctor<T>>(
        ctx, pten::funcs::GeluFunctor<T>(), pten::funcs::AddFunctor<T>(),
        in_x, in_y, outputs);
  } else {
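    // None of the supported functor pairs matched; any other combination
    // (for example {"sigmoid", "elementwise_mul"}) is rejected here rather
    // than silently falling back to an unfused computation.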
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s has not been implemented.", funcs_str));
  }
}

template <typename DeviceContext, typename T, bool InPlace>
static void RunGradFunctors(
    const framework::ExecutionContext &ctx, const framework::Tensor *in_x,
    const framework::Tensor *in_y, const framework::Tensor *in_out,
    const framework::Tensor *in_intermediate_out,
    const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
    framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) {
  auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
  auto funcs_str = functors[0] + "," + functors[1];

  if (funcs_str == "elementwise_add_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<DeviceContext, T,
                                  pten::funcs::AddGradFunctor<T>,
                                  pten::funcs::ScaleFunctor<T>,
                                  pten::funcs::ScaleGradFunctor<T>, InPlace>(
        ctx, pten::funcs::AddGradFunctor<T>(),
        pten::funcs::ScaleFunctor<T>(scale),
        pten::funcs::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "scale_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunUnaryCompoundGradFunctors<DeviceContext, T,
                                 pten::funcs::ScaleGradFunctor<T>,
                                 pten::funcs::AddFunctor<T>,
                                 pten::funcs::AddGradFunctor<T>, InPlace>(
        ctx, pten::funcs::ScaleGradFunctor<T>(scale),
        pten::funcs::AddFunctor<T>(), pten::funcs::AddGradFunctor<T>(), in_x,
        in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad,
        d_intermediate_out);
  } else if (funcs_str == "elementwise_add_grad,relu_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<DeviceContext, T,
                                  pten::funcs::AddGradFunctor<T>,
                                  pten::funcs::ReluFunctor<T>,
                                  pten::funcs::ReluGradFunctor<T>, InPlace>(
        ctx, pten::funcs::AddGradFunctor<T>(), pten::funcs::ReluFunctor<T>(),
        pten::funcs::ReluGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "relu_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<DeviceContext, T,
                                 pten::funcs::ReluGradFunctor<T>,
                                 pten::funcs::AddFunctor<T>,
                                 pten::funcs::AddGradFunctor<T>, InPlace>(
        ctx, pten::funcs::ReluGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
        pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,scale_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    T scale = static_cast<T>(ctx.Attr<float>("scale"));
    RunBinaryCompoundGradFunctors<DeviceContext, T,
                                  pten::funcs::MulGradFunctor<T>,
                                  pten::funcs::ScaleFunctor<T>,
                                  pten::funcs::ScaleGradFunctor<T>, InPlace>(
        ctx, pten::funcs::MulGradFunctor<T>(),
        pten::funcs::ScaleFunctor<T>(scale),
        pten::funcs::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "tanh_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<DeviceContext, T,
                                 pten::funcs::TanhGradFunctor<T>,
                                 pten::funcs::AddFunctor<T>,
                                 pten::funcs::AddGradFunctor<T>, InPlace>(
        ctx, pten::funcs::TanhGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
        pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,tanh_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<DeviceContext, T,
                                  pten::funcs::MulGradFunctor<T>,
                                  pten::funcs::TanhFunctor<T>,
                                  pten::funcs::TanhGradFunctor<T>, InPlace>(
        ctx, pten::funcs::MulGradFunctor<T>(), pten::funcs::TanhFunctor<T>(),
        pten::funcs::TanhGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else if (funcs_str == "elementwise_mul_grad,sigmoid_grad") {
    // The backward of Z = Binary(X, Unary(Y))
    RunBinaryCompoundGradFunctors<DeviceContext, T,
                                  pten::funcs::MulGradFunctor<T>,
                                  pten::funcs::SigmoidFunctor<T>,
                                  pten::funcs::SigmoidGradFunctor<T>, InPlace>(
        ctx, pten::funcs::MulGradFunctor<T>(),
        pten::funcs::SigmoidFunctor<T>(), pten::funcs::SigmoidGradFunctor<T>(),
        in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad,
        d_intermediate_out);
  } else if (funcs_str == "gelu_grad,elementwise_add_grad") {
    // The backward of Z = Unary(Binary(X, Y))
    RunUnaryCompoundGradFunctors<DeviceContext, T,
                                 pten::funcs::GeluGradFunctor<T>,
                                 pten::funcs::AddFunctor<T>,
                                 pten::funcs::AddGradFunctor<T>, InPlace>(
        ctx, pten::funcs::GeluGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
        pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
        in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s has not been implemented.", funcs_str));
  }
}

template <typename DeviceContext, typename T>
class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto &in_x = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("X"), "Input",
                                 "X", "FusedElemwiseActivation");
    auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("Y"), "Input",
                                 "Y", "FusedElemwiseActivation");
    PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true,
                      platform::errors::InvalidArgument(
                          "The output(Out) should not be empty."));
    auto output = ctx.Output<framework::Tensor>("Out");

    std::vector<framework::Tensor *> outputs;
    outputs.emplace_back(output);

    if (ctx.Attr<bool>("save_intermediate_out")) {
      PADDLE_ENFORCE_EQ(ctx.HasOutput("IntermediateOut"), true,
                        platform::errors::InvalidArgument(
                            "save_intermediate_out is enabled, so the "
                            "IntermediateOut should not be empty."));
      auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
      outputs.emplace_back(intermediate_out);
    } else {
      outputs.emplace_back(nullptr);
    }

    RunFunctors<DeviceContext, T>(ctx, in_x, in_y, &outputs);
  }
};

template <typename DeviceContext, typename T>
class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto in_y = ctx.Input<framework::Tensor>("Y");
    PADDLE_ENFORCE_NE(in_y, nullptr,
                      platform::errors::InvalidArgument(
                          "Input(Y) should not be nullptr."));
    auto in_out = ctx.Input<framework::Tensor>("Out");
    PADDLE_ENFORCE_NE(in_out, nullptr,
                      platform::errors::InvalidArgument(
                          "Input(Out) should not be nullptr."));
    auto in_out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE_NE(in_out_grad, nullptr,
                      platform::errors::InvalidArgument(
                          "Input(Out@Grad) should not be nullptr."));
    framework::Tensor *in_x =
        const_cast<framework::Tensor *>(ctx.Input<framework::Tensor>("X"));
    framework::Tensor *x_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    framework::Tensor *y_grad =
        ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
    framework::Tensor *d_intermediate_out = ctx.Output<framework::Tensor>(
        framework::GradVarName("IntermediateOut"));

    auto functor_list = ctx.Attr<std::vector<std::string>>("functor_list");

    // Get intermediate_out.
    framework::Tensor *in_intermediate_out = nullptr;
    if (ctx.Attr<bool>("save_intermediate_out")) {
      // If save_intermediate_out is true, then for Unary(Binary(x, y)) and
      // Binary(x, Unary(y)), Binary(x, y) and Unary(y) do not need to be
      // recomputed.
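      // For example (illustrative): with functor_list = {"elementwise_add",
      // "relu"}, the saved IntermediateOut already holds Relu(Y), so the
      // backward pass reads it here instead of re-evaluating the unary
      // functor on Y.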
      in_intermediate_out = const_cast<framework::Tensor *>(
          ctx.Input<framework::Tensor>("IntermediateOut"));
      PADDLE_ENFORCE_NE(in_intermediate_out, nullptr,
                        platform::errors::InvalidArgument(
                            "The option 'save_intermediate_out' is enabled,"
                            " so the number of 'Out' should be two."));
    } else {
      if (!InputXCanBeAbsent(functor_list)) {
        PADDLE_ENFORCE_NE(in_x, nullptr,
                          platform::errors::InvalidArgument(
                              "Input(X) should not be null."));
      }
    }

    // Get in_x.
    if (ctx.HasInput("X")) {
      PADDLE_ENFORCE_NE(in_x, nullptr,
                        platform::errors::InvalidArgument(
                            "Input(X) should not be null."));
    } else {
      // If functor_list contains elementwise_add, the backward doesn't use
      // in_x, in_y and in_out.
      PADDLE_ENFORCE_EQ(InputXCanBeAbsent(functor_list), true,
                        platform::errors::InvalidArgument(
                            "Only when the compound functor contains "
                            "elementwise_add_grad can 'X' be absent."));
      in_x = const_cast<framework::Tensor *>(in_out_grad);
    }

    bool has_in_place = HasInPlaceUnary(functor_list);
    if (has_in_place) {
      RunGradFunctors<DeviceContext, T, true /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    } else {
      RunGradFunctors<DeviceContext, T, false /*InPlace*/>(
          ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad,
          y_grad, d_intermediate_out);
    }
  }
};

}  // namespace operators
}  // namespace paddle
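
// Registration sketch (illustrative; the actual registration lives in the
// op's .cc/.cu files, and the exact device-context types may differ): the
// kernels above are instantiated per device and dtype with the standard
// macros from op_registry.h, roughly like this:
//
//   namespace ops = paddle::operators;
//   REGISTER_OP_CPU_KERNEL(
//       fused_elemwise_activation,
//       ops::FusedElemwiseActivationKernel<paddle::platform::CPUDeviceContext,
//                                          float>,
//       ops::FusedElemwiseActivationKernel<paddle::platform::CPUDeviceContext,
//                                          double>);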