From c6272b6abc367f397a6a09da5518b1b641652f9b Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Thu, 14 Jul 2022 21:04:47 +0800
Subject: [PATCH] Some Ops support fp16 (#44295)

* sparse support amp

* EagerAmpAutoCasts support sparse
---
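Note: the snippet below is a minimal, hypothetical sketch of the call pattern
the new Cast() helper enables. NeedCast() and Cast() are the functions touched
in paddle/fluid/eager/eager_amp_auto_cast.h in this patch; the wrapper
AmpCastToFp16() is illustrative only and not part of the commit.

// Hypothetical caller, for illustration only: after this patch a single
// Cast() call covers both dense and sparse tensors. Sparse COO/CSR inputs
// go through sparse::cast, dense inputs through the regular cast, and the
// backward node is recorded only when trace_backward is true.
// (The header's own namespace is assumed to be open at this point.)
#include "paddle/fluid/eager/eager_amp_auto_cast.h"

paddle::experimental::Tensor AmpCastToFp16(
    const paddle::experimental::Tensor& input, bool trace_backward) {
  auto dst_dtype = paddle::experimental::DataType::FLOAT16;
  if (NeedCast(input, dst_dtype)) {  // skips tensors that must not be converted
    return Cast(input, dst_dtype, trace_backward);
  }
  return input;
}
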
 paddle/fluid/eager/eager_amp_auto_cast.h      | 35 ++++++++++++-------
 paddle/phi/kernels/gpu/pad3d_grad_kernel.cu   |  9 +++--
 paddle/phi/kernels/sparse/empty_kernel.cc     |  2 ++
 .../kernels/sparse/gpu/unary_grad_kernel.cu   |  4 +++
 paddle/phi/kernels/sparse/gpu/unary_kernel.cu |  4 +++
 5 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/eager/eager_amp_auto_cast.h b/paddle/fluid/eager/eager_amp_auto_cast.h
index 26af2b98ca0..f98f25635f7 100644
--- a/paddle/fluid/eager/eager_amp_auto_cast.h
+++ b/paddle/fluid/eager/eager_amp_auto_cast.h
@@ -39,6 +39,27 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
   return false;
 }
 
+inline paddle::experimental::Tensor Cast(
+    const paddle::experimental::Tensor& input,
+    const paddle::experimental::DataType& dst_dtype,
+    const bool trace_backward = true) {
+  if (input.is_sparse_coo_tensor() || input.is_sparse_csr_tensor()) {
+    if (trace_backward) {
+      return sparse::cast_final_state_dygraph_function(
+          input, paddle::experimental::DataType::UNDEFINED, dst_dtype);
+    } else {
+      return paddle::experimental::sparse::cast(
+          input, paddle::experimental::DataType::UNDEFINED, dst_dtype);
+    }
+  } else {
+    if (trace_backward) {
+      return cast_final_state_dygraph_function(input, dst_dtype);
+    } else {
+      return paddle::experimental::cast(input, dst_dtype);
+    }
+  }
+}
+
 inline std::vector<paddle::experimental::Tensor> EagerAmpAutoCasts(
     const std::string& inputs_name,
     const std::vector<paddle::experimental::Tensor>& inputs,
@@ -51,13 +72,7 @@ inline std::vector<paddle::experimental::Tensor> EagerAmpAutoCasts(
   std::vector<paddle::experimental::Tensor> inputs_casted;
   for (auto& input : inputs) {
     if (NeedCast(input, dst_dtype)) {
-      if (trace_backward) {
-        inputs_casted.emplace_back(
-            std::move(cast_final_state_dygraph_function(input, dst_dtype)));
-      } else {
-        inputs_casted.emplace_back(
-            std::move(paddle::experimental::cast(input, dst_dtype)));
-      }
+      inputs_casted.emplace_back(std::move(Cast(input, dst_dtype)));
     } else {
       inputs_casted.emplace_back(input);
     }
@@ -92,11 +107,7 @@ inline paddle::experimental::Tensor EagerAmpAutoCast(
     }
   }
   if (NeedCast(input, dst_dtype)) {
-    if (trace_backward) {
-      return cast_final_state_dygraph_function(input, dst_dtype);
-    } else {
-      return paddle::experimental::cast(input, dst_dtype);
-    }
+    return Cast(input, dst_dtype, trace_backward);
   }
   return input;
 }
diff --git a/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu b/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu
index 8f4af0a4508..e9f820a3184 100644
--- a/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu
@@ -503,5 +503,10 @@ void Pad3dGradKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    pad3d_grad, GPU, ALL_LAYOUT, phi::Pad3dGradKernel, float, double) {}
+PD_REGISTER_KERNEL(pad3d_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::Pad3dGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/sparse/empty_kernel.cc b/paddle/phi/kernels/sparse/empty_kernel.cc
index fe7fb72b4ca..c1706b9919d 100644
--- a/paddle/phi/kernels/sparse/empty_kernel.cc
+++ b/paddle/phi/kernels/sparse/empty_kernel.cc
@@ -97,6 +97,7 @@ PD_REGISTER_KERNEL(empty_like_coo,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::EmptyLikeCooKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
@@ -112,6 +113,7 @@ PD_REGISTER_KERNEL(empty_like_csr,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::EmptyLikeCsrKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
diff --git a/paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu
index c1f2b2a1f0d..be0f13fb0e5 100644
--- a/paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu
@@ -23,6 +23,7 @@
                      GPU,                                                 \
                      ALL_LAYOUT,                                          \
                      phi::sparse::prefix##CooGradKernel,                  \
+                     phi::dtype::float16,                                 \
                      float,                                               \
                      double) {                                            \
     kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);        \
@@ -32,6 +33,7 @@
                      GPU,                                                 \
                      ALL_LAYOUT,                                          \
                      phi::sparse::prefix##CsrGradKernel,                  \
+                     phi::dtype::float16,                                 \
                      float,                                               \
                      double) {                                            \
     kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);        \
@@ -56,6 +58,7 @@ PD_REGISTER_KERNEL(cast_coo_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::CastCooGradKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
@@ -69,6 +72,7 @@ PD_REGISTER_KERNEL(cast_csr_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::CastCsrGradKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
diff --git a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu
index fdf0b5106d3..6358b7b9835 100644
--- a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu
@@ -67,6 +67,7 @@ void DivCsrScalarKernel(const Context& dev_ctx,
                      GPU,                                                 \
                      ALL_LAYOUT,                                          \
                      phi::sparse::prefix##CooKernel,                      \
+                     phi::dtype::float16,                                 \
                      float,                                               \
                      double) {                                            \
     kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);        \
@@ -76,6 +77,7 @@ void DivCsrScalarKernel(const Context& dev_ctx,
                      GPU,                                                 \
                      ALL_LAYOUT,                                          \
                      phi::sparse::prefix##CsrKernel,                      \
+                     phi::dtype::float16,                                 \
                      float,                                               \
                      double) {                                            \
     kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);        \
@@ -119,6 +121,7 @@ PD_REGISTER_KERNEL(cast_coo,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::CastCooKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
@@ -132,6 +135,7 @@ PD_REGISTER_KERNEL(cast_csr,
                    GPU,
                    ALL_LAYOUT,
                    phi::sparse::CastCsrKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int8_t,
--
GitLab