diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc index b559d66fe74561e9f750dfd3da2a640ca1f74dfc..c17f8326a399448791b1875336f7abd8ab256801 100644 --- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc @@ -183,7 +183,7 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad( std::string d_ele_y_n = d_ele_y->Name(); OpDesc desc; - desc.SetType("fused_elemwise_activation_grad"); + desc.SetType("fused_elemwise_add_activation_grad"); desc.SetInput("IntermediateOut", {}); desc.SetInput("X", {}); desc.SetInput("Y", std::vector({ele_y_n})); @@ -231,7 +231,7 @@ Node *FuseElewiseAddActPass::CreateFuseElewiseAddActNode( desc.SetInput("Y", std::vector({ele_y_n})); desc.SetOutput("Out", std::vector({act_out_n})); desc.SetOutput("IntermediateOut", std::vector({ele_out_n})); - desc.SetType("fused_elemwise_activation"); + desc.SetType("fused_elemwise_add_activation"); desc.SetAttr("save_intermediate_out", true); desc.SetAttr("functor_list", std::vector( {op_1->Op()->Type(), op_2->Op()->Type()})); @@ -251,7 +251,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { std::unordered_set need_removed_nodes; for (auto &cur_node : graph->Nodes()) { if (cur_node->IsVar()) continue; - if (cur_node->Name() == "fused_elemwise_activation") { + if (cur_node->Name() == "fused_elemwise_add_activation") { bool save_intermediate_out = BOOST_GET_CONST( bool, cur_node->Op()->GetAttr("save_intermediate_out")); auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut"); @@ -272,7 +272,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { } } } - } else if (cur_node->Name() == "fused_elemwise_activation_grad") { + } else if (cur_node->Name() == "fused_elemwise_add_activation_grad") { auto intermediate_out_grad_args = cur_node->Op()->Output(GradVarName("IntermediateOut")); PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 206eeea87fb03dc32cb9a2e86e7f34b7a78b7101..bce22ca9a7c20ed0fdeeeae4a45b98a20cca03d4 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -2273,8 +2273,9 @@ template struct FusedElemwiseAndActGradNoBroadcast { HOSTDEVICE void operator()(size_t i) { - T x_val = x_[i]; - T y_val = y_[i]; + T zero = static_cast(0); + T x_val = (x_ == nullptr) ? zero : x_[i]; + T y_val = (y_ == nullptr) ? zero : y_[i]; T out_val = out_[i]; T dout_val = dout_[i]; T intermediate_out_val = UseIntermediateOut @@ -2320,16 +2321,19 @@ void FusedElemwiseAndActGradComputeNoBroadcast( size_t N = static_cast(framework::product(x_dim)); platform::ForRange for_range( ctx.template device_context(), N); - for_range( - FusedElemwiseAndActGradNoBroadcast{ - x->data(), y->data(), - intermediate_out ? intermediate_out->data() : nullptr, - out->data(), dout->data(), dx_op, dy_op, dintermediate_op, - dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), - dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace()), - dintermediate == nullptr ? nullptr : dintermediate->mutable_data( - ctx.GetPlace())}); + const T *x_data = nullptr; + const T *y_data = nullptr; + if (x->IsInitialized()) x_data = x->data(); + if (y->IsInitialized()) y_data = y->data(); + + for_range(FusedElemwiseAndActGradNoBroadcast< + T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut>{ + x_data, y_data, intermediate_out ? intermediate_out->data() : nullptr, + out->data(), dout->data(), dx_op, dy_op, dintermediate_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace()), + dintermediate == nullptr ? nullptr : dintermediate->mutable_data( + ctx.GetPlace())}); } template (0); for (int i = 0; i < h; ++i) { for (int j = 0; j < w; ++j) { int offset = i * w + j; @@ -2347,6 +2352,8 @@ static void FusedElemwiseAndActGradBroadcast1CPU( tmp_out_idx = BcastY ? j : offset; y_idx = BcastY ? j : offset; x_idx = BcastY ? offset : j; + T x_val = (x == nullptr) ? zero : x[x_idx]; + T y_val = (y == nullptr) ? zero : y[y_idx]; if (SameShapeOfIntermediateOutAndOut) { tmp_out_idx = offset; @@ -2354,11 +2361,10 @@ static void FusedElemwiseAndActGradBroadcast1CPU( if (dx != nullptr) { T tmp = UseIntermediateOut - ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx], + ? dx_op.UseIntermediateOut(x_val, y_val, intermediate_out[tmp_out_idx], out[offset], dout[offset]) - : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], - dout[offset]); + : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]); if (BcastY) { dx[x_idx] = tmp; @@ -2372,11 +2378,10 @@ static void FusedElemwiseAndActGradBroadcast1CPU( } if (dy != nullptr) { T tmp = UseIntermediateOut - ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx], + ? dy_op.UseIntermediateOut(x_val, y_val, intermediate_out[tmp_out_idx], out[offset], dout[offset]) - : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], - dout[offset]); + : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]); if (BcastY) { if (i == 0) { dy[y_idx] = tmp; @@ -2390,10 +2395,10 @@ static void FusedElemwiseAndActGradBroadcast1CPU( if (d_intermediate != nullptr) { T tmp = UseIntermediateOut ? dintermediate_op.UseIntermediateOut( - x[x_idx], intermediate_out[tmp_out_idx], out[offset], + x_val, intermediate_out[tmp_out_idx], out[offset], dout[offset]) - : dintermediate_op.Recompute(x[x_idx], y[y_idx], - out[offset], dout[i]); + : dintermediate_op.Recompute(x_val, y_val, out[offset], + dout[i]); if (SameShapeOfIntermediateOutAndOut) { d_intermediate[tmp_out_idx] = tmp; } else { @@ -2416,6 +2421,7 @@ static void FusedElemwiseAndActGradBroadcast2CPU( const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) { int64_t tmp_out_idx, x_idx, y_idx; + T zero = static_cast(0); for (int i = 0; i < pre; ++i) { for (int j = 0; j < n; ++j) { for (int k = 0; k < post; ++k) { @@ -2425,17 +2431,20 @@ static void FusedElemwiseAndActGradBroadcast2CPU( y_idx = BcastY ? j : offset; x_idx = BcastY ? offset : j; + T x_val = (x == nullptr) ? zero : x[x_idx]; + T y_val = (y == nullptr) ? zero : y[y_idx]; + if (SameShapeOfIntermediateOutAndOut) { tmp_out_idx = offset; } if (dx != nullptr) { - T tmp = UseIntermediateOut - ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx], - intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], - dout[offset]); + T tmp = + UseIntermediateOut + ? dx_op.UseIntermediateOut(x_val, y_val, + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]); if (BcastY) { dx[x_idx] = tmp; @@ -2448,12 +2457,12 @@ static void FusedElemwiseAndActGradBroadcast2CPU( } } if (dy != nullptr) { - T tmp = UseIntermediateOut - ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx], - intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], - dout[offset]); + T tmp = + UseIntermediateOut + ? dy_op.UseIntermediateOut(x_val, y_val, + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]); if (BcastY) { if (i == 0 && k == 0) { dy[y_idx] = tmp; @@ -2467,10 +2476,10 @@ static void FusedElemwiseAndActGradBroadcast2CPU( if (d_intermediate != nullptr) { T tmp = UseIntermediateOut ? dintermediate_op.UseIntermediateOut( - x[x_idx], intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dintermediate_op.Recompute(x[x_idx], y[y_idx], - out[offset], dout[i]); + x_val, intermediate_out[tmp_out_idx], out[offset], + dout[offset]) + : dintermediate_op.Recompute(x_val, y_val, out[offset], + dout[i]); if (SameShapeOfIntermediateOutAndOut) { d_intermediate[tmp_out_idx] = tmp; } else { @@ -2499,6 +2508,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel( int tid = threadIdx.x; T val(0), inter_val(0); int64_t tmp_out_idx, x_idx, y_idx; + T zero = static_cast(0); do { int offset = i * w + j; @@ -2506,18 +2516,19 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel( tmp_out_idx = BcastY ? j : offset; y_idx = BcastY ? j : offset; x_idx = BcastY ? offset : j; + T x_val = (x == nullptr) ? zero : x[x_idx]; + T y_val = (y == nullptr) ? zero : y[y_idx]; if (SameShapeOfIntermediateOutAndOut) { tmp_out_idx = offset; } if (dx != nullptr) { - T tmp = - UseIntermediateOut - ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx], - intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]); + T tmp = UseIntermediateOut + ? dx_op.UseIntermediateOut(x_val, y_val, + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]); if (BcastY) { dx[x_idx] = tmp; @@ -2526,12 +2537,11 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel( } } if (dy != nullptr) { - T tmp = - UseIntermediateOut - ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx], - intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]); + T tmp = UseIntermediateOut + ? dy_op.UseIntermediateOut(x_val, y_val, + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]); if (BcastY) { val += tmp; } else { @@ -2543,7 +2553,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel( ? dintermediate_op.UseIntermediateOut( y[y_idx], intermediate_out[tmp_out_idx], out[offset], dout[offset]) - : dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset], + : dintermediate_op.Recompute(x_val, y_val, out[offset], dout[offset]); if (SameShapeOfIntermediateOutAndOut) { d_intermediate[tmp_out_idx] = tmp; @@ -2610,6 +2620,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel( T val(0), inter_val(0); int ttid = tid; int64_t tmp_out_idx, x_idx, y_idx; + T zero = static_cast(0); while (true) { int i = ttid / post; int k = ttid % post; @@ -2620,18 +2631,19 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel( tmp_out_idx = BcastY ? j : offset; y_idx = BcastY ? j : offset; x_idx = BcastY ? offset : j; + T x_val = (x == nullptr) ? zero : x[x_idx]; + T y_val = (y == nullptr) ? zero : y[y_idx]; if (SameShapeOfIntermediateOutAndOut) { tmp_out_idx = offset; } if (dx != nullptr) { - T tmp = - UseIntermediateOut - ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx], - intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]); + T tmp = UseIntermediateOut + ? dx_op.UseIntermediateOut(x_val, y_val, + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dx_op.Recompute(x_val, y_val, out[offset], dout[offset]); if (BcastY) { dx[x_idx] = tmp; @@ -2640,12 +2652,11 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel( } } if (dy != nullptr) { - T tmp = - UseIntermediateOut - ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx], - intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]); + T tmp = UseIntermediateOut + ? dy_op.UseIntermediateOut(x_val, y_val, + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dy_op.Recompute(x_val, y_val, out[offset], dout[offset]); if (BcastY) { val += tmp; } else { @@ -2655,9 +2666,9 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel( if (d_intermediate != nullptr) { T tmp = UseIntermediateOut ? dintermediate_op.UseIntermediateOut( - y[y_idx], intermediate_out[tmp_out_idx], out[offset], + y_val, intermediate_out[tmp_out_idx], out[offset], dout[offset]) - : dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset], + : dintermediate_op.Recompute(x_val, y_val, out[offset], dout[offset]); if (SameShapeOfIntermediateOutAndOut) { d_intermediate[tmp_out_idx] = tmp; @@ -2730,16 +2741,20 @@ void FusedElemwiseAndActGradComputeWithBroadcast( int pre, n, post, is_run_common_broadcast; get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &is_run_common_broadcast); + const T *x_data = nullptr; + const T *y_data = nullptr; + if (x->IsInitialized()) x_data = x->data(); + if (y->IsInitialized()) y_data = y->data(); if (post == 1) { int h = pre; int w = n; + if (platform::is_gpu_place(ctx.GetPlace())) { #ifdef __NVCC__ FusedElemwiseAndActGradBroadcast1CUDA( - ctx.template device_context().stream(), x->data(), - y->data(), + ctx.template device_context().stream(), x_data, y_data, intermediate_out == nullptr ? nullptr : intermediate_out->data(), out->data(), dout->data(), h, w, dx_op, dy_op, dintermediate_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), @@ -2751,7 +2766,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast( FusedElemwiseAndActGradBroadcast1CPU( - x->data(), y->data(), + x_data, y_data, intermediate_out == nullptr ? nullptr : intermediate_out->data(), out->data(), dout->data(), h, w, dx_op, dy_op, dintermediate_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), @@ -2765,8 +2780,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast( FusedElemwiseAndActGradBroadcast2CUDA( - ctx.template device_context().stream(), x->data(), - y->data(), + ctx.template device_context().stream(), x_data, y_data, intermediate_out == nullptr ? nullptr : intermediate_out->data(), out->data(), dout->data(), pre, n, post, dx_op, dy_op, dintermediate_op, @@ -2779,7 +2793,7 @@ void FusedElemwiseAndActGradComputeWithBroadcast( FusedElemwiseAndActGradBroadcast2CPU( - x->data(), y->data(), + x_data, y_data, intermediate_out == nullptr ? nullptr : intermediate_out->data(), out->data(), dout->data(), pre, n, post, dx_op, dy_op, dintermediate_op, diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc index 2de24cddd7d678c33615b70c935f28e9a307266e..4ff66d0d2b856d505fade0510c22b565e0d94678 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc @@ -287,6 +287,15 @@ class FusedElemwiseActivationGradMaker } }; +class FusedElemwiseAddActivationMaker : public FusedElemwiseActivationMaker {}; + +template +class FusedElemwiseAddActivationGradMaker + : public FusedElemwiseActivationGradMaker { + public: + using FusedElemwiseActivationGradMaker::FusedElemwiseActivationGradMaker; +}; + class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -361,10 +370,61 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); + } +}; + +class FusedElemwiseAddActivationOp : public FusedElemwiseActivationOp { + public: + using FusedElemwiseActivationOp::FusedElemwiseActivationOp; + void InferShape(framework::InferShapeContext *ctx) const override { + FusedElemwiseActivationOp::InferShape(ctx); + std::vector functor_names = + ctx->Attrs().Get>("functor_list"); + bool elemntwise_add_detected = false; + for (auto names : functor_names) { + if (names == "elementwise_add") { + elemntwise_add_detected = true; + break; + } + } + PADDLE_ENFORCE_EQ( + elemntwise_add_detected, true, + platform::errors::InvalidArgument( + "When the FusedElemwiseAddActivationOp Is used in fused pass, the " + "elementwise_add Op must be" + "detected and used, Please check the fuse pass pattern")); + } +}; + +class FusedElemwiseAddActivationOpGrad : public FusedElemwiseActivationOpGrad { + public: + using FusedElemwiseActivationOpGrad::FusedElemwiseActivationOpGrad; + + void InferShape(framework::InferShapeContext *ctx) const override { + FusedElemwiseActivationOpGrad::InferShape(ctx); + std::vector functor_names = + ctx->Attrs().Get>("functor_list"); + bool elemntwise_add_grad_detected = false; + for (auto names : functor_names) { + if (names == "elementwise_add_grad") { + elemntwise_add_grad_detected = true; + break; + } + } + PADDLE_ENFORCE_EQ( + elemntwise_add_grad_detected, true, + platform::errors::InvalidArgument( + "When the FusedElemwiseAddActivationOpGrad Is used in fused pass, " + "the elementwise_add_grad Op must be" + "detected and used, Please check the fuse pass pattern")); } }; + +DECLARE_NO_NEED_BUFFER_VARS_INFERER( + FusedElemwiseAddActivationNoNeddBufVarInferer, "X", "Y"); } // namespace operators } // namespace paddle @@ -390,3 +450,27 @@ REGISTER_OP_CPU_KERNEL( float>, ops::FusedElemwiseActivationGradKernel); + +// for memory optimization, we register the fused_elemwise_add_activation OP +REGISTER_OPERATOR( + fused_elemwise_add_activation, ops::FusedElemwiseAddActivationOp, + ops::FusedElemwiseAddActivationMaker, + ops::FusedElemwiseAddActivationGradMaker, + ops::FusedElemwiseAddActivationGradMaker); +REGISTER_OPERATOR(fused_elemwise_add_activation_grad, + ops::FusedElemwiseAddActivationNoNeddBufVarInferer, + ops::FusedElemwiseAddActivationOpGrad); + +REGISTER_OP_CPU_KERNEL( + fused_elemwise_add_activation, + ops::FusedElemwiseActivationKernel, + ops::FusedElemwiseActivationKernel); + +REGISTER_OP_CPU_KERNEL( + fused_elemwise_add_activation_grad, + ops::FusedElemwiseActivationGradKernel, + ops::FusedElemwiseActivationGradKernel); diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cu b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cu index dba4097c7f31d5b5df992feb01623008ec9aedec..7b44aa82e4a22ba195fbe8b86ef78ad7f37397f8 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cu +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cu @@ -32,3 +32,21 @@ REGISTER_OP_CUDA_KERNEL( double>, ops::FusedElemwiseActivationGradKernel); + +REGISTER_OP_CUDA_KERNEL( + fused_elemwise_add_activation, + ops::FusedElemwiseActivationKernel, + ops::FusedElemwiseActivationKernel, + ops::FusedElemwiseActivationKernel); + +REGISTER_OP_CUDA_KERNEL( + fused_elemwise_add_activation_grad, + ops::FusedElemwiseActivationGradKernel, + ops::FusedElemwiseActivationGradKernel, + ops::FusedElemwiseActivationGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py b/python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py index a1c20be9a92f83d67e934eeaf84b95c2fac0b579..6c3fa9e61d2406bc8e84d2c691759a5241b6d67b 100644 --- a/python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py +++ b/python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py @@ -77,4 +77,6 @@ class TestMNIST(TestParallelExecutorBase): if __name__ == '__main__': + import paddle + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py index e00dc8c7688d6f8ef88d1bbe2a1c82d498f6c55b..80bb14adf7b9fe4acd87e5ad1eaaafb04e5e6757 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py @@ -390,4 +390,6 @@ for mode in {0, 1}: grad_chek=False) if __name__ == '__main__': + import paddle + paddle.enable_static() unittest.main()