diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt
index c22f88e0dbf291778a0d7a2bd4df9ed5f42a3dde..ac61f6d874c5959d972d5ec651130a6c03e12ecd 100644
--- a/paddle/fluid/distributed/collective/CMakeLists.txt
+++ b/paddle/fluid/distributed/collective/CMakeLists.txt
@@ -65,3 +65,15 @@ if(WITH_CUSTOM_DEVICE)
       comm_static_check
       dense_tensor)
 endif()
+
+set(COMM_UTILS_DEPS process_group)
+if(WITH_NCCL OR WITH_RCCL)
+  set(COMM_UTILS_DEPS ${COMM_UTILS_DEPS} process_group_nccl)
+endif()
+if(WITH_CUSTOM_DEVICE)
+  set(COMM_UTILS_DEPS ${COMM_UTILS_DEPS} process_group_custom)
+endif()
+cc_library(
+  processgroup_comm_utils
+  SRCS processgroup_comm_utils.cc
+  DEPS ${COMM_UTILS_DEPS})
diff --git a/paddle/phi/backends/processgroup_comm_utils.cc b/paddle/fluid/distributed/collective/processgroup_comm_utils.cc
similarity index 100%
rename from paddle/phi/backends/processgroup_comm_utils.cc
rename to paddle/fluid/distributed/collective/processgroup_comm_utils.cc
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 2292baa996d4c31b05d0b8060cbe230b719c5932..073eec71eb37e9ca70023b43cb0a0a67918a249a 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -20,7 +20,7 @@ cc_library(
   graph_helper
   SRCS graph_helper.cc
   DEPS graph program_utils scale_loss_grad_op_handle
-       grad_merge_all_reduce_op_handle)
+       grad_merge_all_reduce_op_handle collective_helper)
 cc_library(
   pass
   SRCS pass.cc
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index d39aeedd45908bb26193a85d9c762741d07b34c8..af43f606ffd4651414f705d29707d6b3d24758a4 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -162,6 +162,10 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper ps_gpu_wrapper)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_function)
+set(COMMON_OP_DEPS ${COMMON_OP_DEPS} processgroup_comm_utils)
+if(WITH_NCCL OR WITH_RCCL)
+  set(COMMON_OP_DEPS ${COMMON_OP_DEPS} process_group_nccl)
+endif()
 if (WITH_GPU OR WITH_ROCM)
   set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor)
 endif()
diff --git a/paddle/phi/kernels/gpu/class_center_sample_kernel.cu b/paddle/fluid/operators/class_center_sample_op.cu
similarity index 100%
rename from paddle/phi/kernels/gpu/class_center_sample_kernel.cu
rename to paddle/fluid/operators/class_center_sample_op.cu
diff --git a/paddle/fluid/operators/inplace_abn_op.cu b/paddle/fluid/operators/inplace_abn_op.cu
index bec88e5dfd2a714e555a19f6baf1a8de47cbe9a9..a7d5a514c585531c0900d9f98e7aefbb45745001 100644
--- a/paddle/fluid/operators/inplace_abn_op.cu
+++ b/paddle/fluid/operators/inplace_abn_op.cu
@@ -14,9 +14,9 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/inplace_abn_op.h"
 #include "paddle/fluid/operators/batch_norm_op.h"
+#include "paddle/fluid/operators/sync_batch_norm_utils.h"
 #include "paddle/phi/kernels/batch_norm_grad_kernel.h"
 #include "paddle/phi/kernels/batch_norm_kernel.h"
-#include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h"
 #include "paddle/phi/kernels/sync_batch_norm_grad_kernel.h"
 #include "paddle/phi/kernels/sync_batch_norm_kernel.h"
 
diff --git a/paddle/fluid/operators/margin_cross_entropy_op.cc b/paddle/fluid/operators/margin_cross_entropy_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..856ea09441b4b3bb9860c2bb92aa1e4c259da7ee
--- /dev/null
+++ b/paddle/fluid/operators/margin_cross_entropy_op.cc
@@ -0,0 +1,16 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file is used to compile margin_cross_entropy_op.cu.
+// It will be deleted once margin_cross_entropy_op is moved to phi.
diff --git a/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu b/paddle/fluid/operators/margin_cross_entropy_op.cu
similarity index 76%
rename from paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
rename to paddle/fluid/operators/margin_cross_entropy_op.cu
index 5cbb21c45b76d8f42372bbf125de6d725fce9a95..c5d007952b80c7f84eed40664124518d1925beaf 100644
--- a/paddle/phi/kernels/gpu/margin_cross_entropy_kernel.cu
+++ b/paddle/fluid/operators/margin_cross_entropy_op.cu
@@ -22,19 +22,25 @@ namespace cub = hipcub;
 #include <vector>
 
 #include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/reduce_function.h"
+#include "paddle/phi/kernels/impl/softmax_kernel_impl.h"
+#include "paddle/phi/kernels/margin_cross_entropy_grad_kernel.h"
+
+#include "paddle/phi/common/memory_utils.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/core/visit_type.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/distributed/collective/process_group.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
-// trace op include
 #include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/common/memory_utils.h"
-#include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
 
@@ -467,6 +473,116 @@ void MarginCrossEntropyKernel(const Context& dev_ctx,
 #endif
 }
 
+template <typename T, typename IndexT>
+__global__ void CalculateGrad(T* logits_grad,
+                              const T* loss_grad,
+                              const T* logits,
+                              const IndexT* label,
+                              const float margin1,
+                              const float margin2,
+                              const float scale,
+                              const int rank,
+                              const int64_t N,
+                              const int64_t D,
+                              const int* class_interval_ptr) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  int start_index = class_interval_ptr[rank];
+  CUDA_KERNEL_LOOP(i, N * D) {
+    auto row = i / D;
+    auto col = i % D;
+    if ((col + start_index) == label[row]) {
+      logits_grad[i] = (logits_grad[i] - static_cast<T>(1.0)) * loss_grad[row];
+      if (fabs(margin1 - 1.0) > 1e-8 || fabs(margin2) > 1e-8) {
+        MPType dout = static_cast<MPType>(logits_grad[i]);
+        MPType one = static_cast<MPType>(1.0f);
+        MPType x = static_cast<MPType>(logits[i]);
+        MPType m1 = static_cast<MPType>(margin1);
+        MPType m2 = static_cast<MPType>(margin2);
+
+        MPType d = m1 * sin(m1 * acos(x) + m2) / sqrt(one - x * x);
+        logits_grad[i] = static_cast<T>(dout * d);
+      }
+    } else {
+      logits_grad[i] *= loss_grad[row];
+    }
+    if (fabs(scale - 1.0) > 1e-8) {
+      logits_grad[i] *= static_cast<T>(scale);
+    }
+  }
+}
+
+template <typename T, typename Context>
+void MarginCrossEntropyGradKernel(const Context& dev_ctx,
+                                  const DenseTensor& logits,
+                                  const DenseTensor& label,
+                                  const DenseTensor& softmax,
+                                  const DenseTensor& loss_grad,
+                                  bool return_softmax,
+                                  int ring_id,
+                                  int rank,
+                                  int nranks,
+                                  float margin1,
+                                  float margin2,
+                                  float margin3,
+                                  float scale,
+                                  DenseTensor* logits_grad) {
+  const auto softmax_dims = softmax.dims();
+  const int axis = softmax_dims.size() - 1;
+  const int N = phi::funcs::SizeToAxis(axis, softmax_dims);
+  const int D = phi::funcs::SizeFromAxis(axis, softmax_dims);
+
+  if (return_softmax) {
+    phi::Copy<Context>(
+        dev_ctx, softmax, dev_ctx.GetPlace(), false, logits_grad);
+  } else {
+    logits_grad->ShareDataWith(softmax);
+  }
+
+  int blocks = NumBlocks(N * D);
+  int threads = kNumCUDAThreads;
+  const auto& label_type = label.dtype();
+
+  DenseTensor class_interval;
+  GetClassInterval(dev_ctx.stream(),
+                   dev_ctx.GetPlace(),
+                   dev_ctx,
+                   ring_id,
+                   rank,
+                   nranks,
+                   D,
+                   &class_interval);
+
+  if (label_type == phi::DataType::INT32) {
+    typedef int32_t LabelT;
+    CalculateGrad<T, LabelT>
+        <<<blocks, threads, 0, dev_ctx.stream()>>>(logits_grad->data<T>(),
+                                                   loss_grad.data<T>(),
+                                                   logits.data<T>(),
+                                                   label.data<LabelT>(),
+                                                   margin1,
+                                                   margin2,
+                                                   scale,
+                                                   rank,
+                                                   N,
+                                                   D,
+                                                   class_interval.data<int>());
+  } else if (label_type == phi::DataType::INT64) {
+    typedef int64_t LabelT;
+    CalculateGrad<T, LabelT>
+        <<<blocks, threads, 0, dev_ctx.stream()>>>(logits_grad->data<T>(),
+                                                   loss_grad.data<T>(),
+                                                   logits.data<T>(),
+                                                   label.data<LabelT>(),
+                                                   margin1,
+                                                   margin2,
+                                                   scale,
+                                                   rank,
+                                                   N,
+                                                   D,
+                                                   class_interval.data<int>());
+  }
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(margin_cross_entropy,
@@ -476,3 +592,11 @@ PD_REGISTER_KERNEL(margin_cross_entropy,
                    float,
                    double,
                    phi::dtype::float16) {}
+
+PD_REGISTER_KERNEL(margin_cross_entropy_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MarginCrossEntropyGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu b/paddle/fluid/operators/sync_batch_norm_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9b5cd05db3321af607f9b0887aa355b620feb876
--- /dev/null
+++ b/paddle/fluid/operators/sync_batch_norm_op.cu
@@ -0,0 +1,393 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/operators/sync_batch_norm_utils.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/sync_batch_norm_kernel.h" + +// sparse header +#include "paddle/phi/kernels/sparse/empty_kernel.h" + +namespace phi { + +template +void SyncBatchNormKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& mean, + const DenseTensor& variance, + const DenseTensor& scale, + const DenseTensor& bias, + bool is_test, + float momentum, + float epsilon_f, + const std::string& data_layout_str, + bool use_global_stats, + bool trainable_statistics, + DenseTensor* y, + DenseTensor* mean_out, + DenseTensor* variance_out, + DenseTensor* saved_mean, + DenseTensor* saved_variance, + DenseTensor* reserve_space) { + PADDLE_ENFORCE_EQ(use_global_stats, + false, + phi::errors::InvalidArgument( + "sync_batch_norm doesn't support " + "to set use_global_stats True. Please use batch_norm " + "in this case.")); + + double epsilon = epsilon_f; + const bool trainable_stats = trainable_statistics; + const DataLayout layout = phi::StringToDataLayout(data_layout_str); + bool test_mode = is_test && (!trainable_statistics); + const auto& x_dims = x.dims(); + PADDLE_ENFORCE_GE(x_dims.size(), + 2, + phi::errors::InvalidArgument( + "The Input dim size should be larger than 1.")); + PADDLE_ENFORCE_LE(x_dims.size(), + 5, + phi::errors::InvalidArgument( + "The Input dim size should be less than 6.")); + int N, C, H, W, D; + funcs::ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); + int x_numel = x.numel(); + + const T* x_d = x.template data(); + const auto* s_d = scale.template data>(); + const auto* b_d = bias.template data>(); + + T* y_d = ctx.template Alloc(y); + + const BatchNormParamType* mean_data = nullptr; + const BatchNormParamType* var_data = nullptr; + + auto stream = ctx.stream(); + const int block = 512; + int max_threads = ctx.GetMaxPhysicalThreadCount(); + + phi::Allocator::AllocationPtr alloc_ptr{nullptr}; + + if (test_mode) { + mean_data = mean.template data>(); + var_data = variance.template data>(); + } else { + // x, x^2, 1, here 1 is used to calc device num + // device num also can be got from phi::DeviceContextPool + const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType); + alloc_ptr = phi::memory_utils::Alloc( + ctx.GetPlace(), + bytes, + phi::Stream(reinterpret_cast(ctx.stream()))); + + auto* stats = reinterpret_cast*>(alloc_ptr->ptr()); + const int threads = 256; + int grid = std::min(C, (max_threads + threads - 1) / threads); + if (layout == phi::DataLayout::kNCHW) { + KeLocalStats + <<>>(x_d, N, H * W * D, C, stats); + } else { + KeLocalStats + <<>>(x_d, N, H * W * D, C, stats); + } + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + ncclComm_t comm = static_cast(detail::GetCCLComm(x.place(), 0)); + if (comm == nullptr) { + comm = ctx.nccl_comm(); + } + + if (comm) { + int dtype = phi::ToNCCLDataType(mean_out->dtype()); + // In-place operation + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::ncclAllReduce(stats, + stats, + 2 * C + 1, + static_cast(dtype), + ncclSum, + comm, + stream)); + VLOG(3) << "Sync result using all reduce"; + } +#endif + + auto* est_mean_data = ctx.template Alloc>(mean_out); + auto* est_var_data = + ctx.template Alloc>(variance_out); + + auto* sv_mean_data = ctx.template Alloc>(saved_mean); + auto* sv_inv_var_data = + ctx.template Alloc>(saved_variance); + + // Note, Input('Mean')/Input('Variance') share variable with + 
+    // Output('MeanOut')/Output('VarianceOut')
+    KeSyncAndMovingStats<T>
+        <<<(C + block - 1) / block, block, 0, stream>>>(stats,
+                                                        stats + C,
+                                                        stats + 2 * C,
+                                                        C,
+                                                        momentum,
+                                                        epsilon,
+                                                        sv_mean_data,
+                                                        sv_inv_var_data,
+                                                        est_mean_data,
+                                                        est_var_data);
+
+    mean_data = sv_mean_data;
+    var_data = stats + C;
+  }
+
+  int grid2 = (std::min(x_numel, max_threads) + block - 1) / block;
+  if (layout == phi::DataLayout::kNCHW) {
+    KeNormAffine<T, phi::DataLayout::kNCHW>
+        <<<grid2, block, 0, stream>>>(x_d,
+                                      s_d,
+                                      b_d,
+                                      mean_data,
+                                      var_data,
+                                      epsilon,
+                                      C,
+                                      H * W * D,
+                                      x_numel,
+                                      y_d);
+  } else {
+    KeNormAffine<T, phi::DataLayout::kNHWC>
+        <<<grid2, block, 0, stream>>>(x_d,
+                                      s_d,
+                                      b_d,
+                                      mean_data,
+                                      var_data,
+                                      epsilon,
+                                      C,
+                                      H * W * D,
+                                      x_numel,
+                                      y_d);
+  }
+}
+
+template <typename T, typename Context>
+void SyncBatchNormGradKernel(const Context& ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& scale,
+                             const DenseTensor& bias,
+                             const DenseTensor& saved_mean,
+                             const DenseTensor& saved_variance,
+                             const paddle::optional<DenseTensor>& reserve_space,
+                             const DenseTensor& y_grad,
+                             float momentum,
+                             float epsilon_f,
+                             const std::string& data_layout_str,
+                             bool is_test,
+                             bool use_global_stats,
+                             bool trainable_statistics,
+                             DenseTensor* x_grad,
+                             DenseTensor* scale_grad,
+                             DenseTensor* bias_grad) {
+  SyncBatchNormGradFunctor<T, Context>(ctx,
+                                       &x,
+                                       nullptr,
+                                       scale,
+                                       bias,
+                                       saved_mean,
+                                       saved_variance,
+                                       y_grad,
+                                       epsilon_f,
+                                       data_layout_str,
+                                       x_grad,
+                                       scale_grad,
+                                       bias_grad);
+}
+
+}  // namespace phi
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void SyncBatchNormCooKernel(const Context& dev_ctx,
+                            const SparseCooTensor& x,
+                            const DenseTensor& mean,
+                            const DenseTensor& variance,
+                            const DenseTensor& scale,
+                            const DenseTensor& bias,
+                            bool is_test,
+                            float momentum,
+                            float epsilon,
+                            const std::string& data_layout,
+                            bool use_global_stats,
+                            bool trainable_statistics,
+                            SparseCooTensor* y,
+                            DenseTensor* mean_out,
+                            DenseTensor* variance_out,
+                            DenseTensor* saved_mean,
+                            DenseTensor* saved_variance,
+                            DenseTensor* reserve_space) {
+  EmptyLikeCooKernel<T, Context>(dev_ctx, x, y);
+  phi::SyncBatchNormKernel<T, Context>(dev_ctx,
+                                       x.values(),
+                                       mean,
+                                       variance,
+                                       scale,
+                                       bias,
+                                       is_test,
+                                       momentum,
+                                       epsilon,
+                                       data_layout,
+                                       use_global_stats,
+                                       trainable_statistics,
+                                       y->mutable_values(),
+                                       mean_out,
+                                       variance_out,
+                                       saved_mean,
+                                       saved_variance,
+                                       reserve_space);
+  y->SetIndicesDict(x.GetIndicesDict());
+}
+
+template <typename T, typename Context>
+void SyncBatchNormCooGradKernel(
+    const Context& dev_ctx,
+    const SparseCooTensor& x,
+    const DenseTensor& scale,
+    const DenseTensor& bias,
+    const DenseTensor& saved_mean,
+    const DenseTensor& saved_variance,
+    const paddle::optional<DenseTensor>& reserve_space,
+    const SparseCooTensor& y_grad,
+    float momentum,
+    float epsilon,
+    const std::string& data_layout,
+    bool is_test,
+    bool use_global_stats,
+    bool trainable_statistics,
+    SparseCooTensor* x_grad,
+    DenseTensor* scale_grad,
+    DenseTensor* bias_grad) {
+  EmptyLikeCooKernel<T, Context>(dev_ctx, x, x_grad);
+  *scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
+  *bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
+  phi::SyncBatchNormGradKernel<T, Context>(dev_ctx,
+                                           x.values(),
+                                           scale,
+                                           bias,
+                                           saved_mean,
+                                           saved_variance,
+                                           reserve_space,
+                                           y_grad.values(),
+                                           momentum,
+                                           epsilon,
+                                           data_layout,
+                                           is_test,
+                                           use_global_stats,
+                                           trainable_statistics,
+                                           x_grad->mutable_values(),
+                                           scale_grad,
+                                           bias_grad);
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(sync_batch_norm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SyncBatchNormKernel,
+                   float,
+                   phi::dtype::float16) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+  }
+}
+#else
+PD_REGISTER_KERNEL(sync_batch_norm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SyncBatchNormKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+  }
+}
+#endif
+
+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(sync_batch_norm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SyncBatchNormGradKernel,
+                   float,
+                   phi::dtype::float16) {}
+#else
+PD_REGISTER_KERNEL(sync_batch_norm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SyncBatchNormGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+#endif
+
+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(sync_batch_norm_coo,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::SyncBatchNormCooKernel,
+                   float,
+                   phi::dtype::float16) {}
+#else
+PD_REGISTER_KERNEL(sync_batch_norm_coo,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::SyncBatchNormCooKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+#endif
+
+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(sync_batch_norm_coo_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::SyncBatchNormCooGradKernel,
+                   float,
+                   phi::dtype::float16) {}
+#else
+PD_REGISTER_KERNEL(sync_batch_norm_coo_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::SyncBatchNormCooGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_utils.h b/paddle/fluid/operators/sync_batch_norm_utils.h
similarity index 100%
rename from paddle/phi/kernels/gpu/sync_batch_norm_utils.h
rename to paddle/fluid/operators/sync_batch_norm_utils.h
diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index c90cde6ab58337306dff6d65163b302ca14f4f6b..7f88a8f7de65cbfd7a69d0db6244b218f2fcee29 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -227,10 +227,18 @@ if(WITH_GPU)
       SRCS device_event_test.cc
       DEPS device_event_gpu)
   endif()
-  nv_library(
-    cuda_graph_with_memory_pool
-    SRCS cuda_graph_with_memory_pool.cc
-    DEPS ${DEVICE_EVENT_LIBS} device_context allocator phi_backends)
+  if(WITH_CUSTOM_DEVICE)
+    nv_library(
+      cuda_graph_with_memory_pool
+      SRCS cuda_graph_with_memory_pool.cc
+      DEPS ${DEVICE_EVENT_LIBS} device_event_custom_device device_context
+           allocator phi_backends)
+  else()
+    nv_library(
+      cuda_graph_with_memory_pool
+      SRCS cuda_graph_with_memory_pool.cc
+      DEPS ${DEVICE_EVENT_LIBS} device_context allocator phi_backends)
+  endif()
   nv_test(
     device_context_test
     SRCS device_context_test.cu
diff --git a/paddle/phi/backends/CMakeLists.txt b/paddle/phi/backends/CMakeLists.txt
index f82145814a6ac64ed93f25adae4175768a9e133a..54888cf8932af908419ab148c36d7d4c1e6019fd 100644
--- a/paddle/phi/backends/CMakeLists.txt
+++ b/paddle/phi/backends/CMakeLists.txt
@@ -74,15 +74,3 @@ if(WITH_CUSTOM_DEVICE)
     SRCS custom/capi_test.cc
     DEPS phi_capi)
 endif()
-
-set(COMM_UTILS_DEPS process_group)
-if(WITH_NCCL OR WITH_RCCL)
-  set(COMM_UTILS_DEPS ${PROCESS_GROUP_UTILS_DEPS} process_group_nccl)
-endif()
-if(WITH_CUSTOM_DEVICE)
-  set(COMM_UTILS_DEPS ${PROCESS_GROUP_UTILS_DEPS} process_group_custom)
-endif()
-cc_library(
-  processgroup_comm_utils
-  SRCS processgroup_comm_utils.cc
-  DEPS ${COMM_UTILS_DEPS})
diff --git a/paddle/phi/core/distributed/CMakeLists.txt b/paddle/phi/core/distributed/CMakeLists.txt
index 3c4d9d850084e9ed203102eb22462c30d317092d..73a9b1f9c4bab5240c187d7d725e9805f27f205a 100644
--- a/paddle/phi/core/distributed/CMakeLists.txt
+++ b/paddle/phi/core/distributed/CMakeLists.txt
@@ -16,7 +16,7 @@ if(WITH_GLOO)
   cc_library(
     gloo_utils
     SRCS gloo_utils.cc
-    DEPS gloo dense_tensor enforce)
+    DEPS gloo dense_tensor enforce tcp_store)
 
   cc_library(
     gloo_comm_context
diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt
index 409c746938f0097058be49e876d5180d58136a88..f233c2d13a98cbaaca3781ccb81da92125ccd89f 100644
--- a/paddle/phi/kernels/CMakeLists.txt
+++ b/paddle/phi/kernels/CMakeLists.txt
@@ -88,13 +88,11 @@ if(WITH_FLASHATTN)
 endif()
 
 if(WITH_NCCL OR WITH_RCCL)
-  set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group_nccl
-      nccl_comm_context)
+  set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} nccl_comm_context)
 endif()
 if(WITH_GLOO)
   set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} gloo_comm_context)
 endif()
-set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_comm_utils)
 if(WITH_CUDNN_FRONTEND)
   set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} cudnn-frontend)
 endif()
@@ -180,32 +178,24 @@ endif()
 
 file(GLOB kernel_xpu "xpu/*.cc" "selected_rows/xpu/*.cc" "fusion/xpu/*.cc")
 
-add_library(phi_cpu ${kernel_cc})
-kernel_declare("${kernel_cc}")
 if(WITH_MKLDNN)
-  target_link_libraries(phi_cpu ${COMMON_KERNEL_DEPS}
-                        get_kerneltype_forvar_utils)
-else()
-  target_link_libraries(phi_cpu ${COMMON_KERNEL_DEPS})
+  set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} get_kerneltype_forvar_utils)
 endif()
 
-set(ADD_PHI_KERNELS phi_cpu)
-
 if(WITH_GPU OR WITH_ROCM)
   if(WITH_GPU)
-    add_library(phi_gpu ${kernel_cu})
+    add_library(phi_gpu ${kernel_cu} ${kernel_cc})
     if(WITH_CUTLASS)
       add_dependencies(phi_gpu cutlass_codegen)
     endif()
   elseif(WITH_ROCM)
-    hip_add_library(phi_gpu STATIC ${kernel_cu})
+    hip_add_library(phi_gpu STATIC ${kernel_cu} ${kernel_cc})
   endif()
   kernel_declare("${kernel_cu}")
+  kernel_declare("${kernel_cc}")
   target_link_libraries(phi_gpu ${COMMON_KERNEL_DEPS})
   set(ADD_PHI_KERNELS ${ADD_PHI_KERNELS} phi_gpu)
-endif()
-
-if(WITH_XPU)
+elseif(WITH_XPU)
   if(WITH_XPU_KP)
     file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/kps/
          DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/kps/)
@@ -215,15 +205,27 @@ if(WITH_XPU)
       file(RENAME ${kernel} "${CMAKE_CURRENT_BINARY_DIR}/kps/${name}.kps")
     endforeach()
     file(GLOB kernel_xpu_kps "${CMAKE_CURRENT_BINARY_DIR}/kps/*.kps")
-    xpu_add_library(phi_xpu STATIC ${kernel_xpu} ${kernel_xpu_kps} DEPENDS
-                    ${COMMON_KERNEL_DEPS})
+    xpu_add_library(
+      phi_xpu
+      STATIC
+      ${kernel_xpu}
+      ${kernel_xpu_kps}
+      ${kernel_cc}
+      DEPENDS
+      ${COMMON_KERNEL_DEPS})
   else()
-    add_library(phi_xpu ${kernel_xpu})
+    add_library(phi_xpu ${kernel_xpu} ${kernel_cc})
   endif()
   kernel_declare("${kernel_xpu}")
   kernel_declare("${kernel_xpu_kps}")
+  kernel_declare("${kernel_cc}")
   target_link_libraries(phi_xpu ${COMMON_KERNEL_DEPS})
   set(ADD_PHI_KERNELS ${ADD_PHI_KERNELS} phi_xpu)
+else()
+  add_library(phi_cpu ${kernel_cc})
+  target_link_libraries(phi_cpu ${COMMON_KERNEL_DEPS})
+  kernel_declare("${kernel_cc}")
+  set(ADD_PHI_KERNELS phi_cpu)
 endif()
 
 set_property(GLOBAL PROPERTY PHI_KERNELS ${ADD_PHI_KERNELS})
diff --git a/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu b/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu
deleted file mode 100644
index 51f5e28b032f8e317bdda4e1575b934e9db7eaf3..0000000000000000000000000000000000000000
--- a/paddle/phi/kernels/gpu/margin_cross_entropy_grad_kernel.cu
+++ /dev/null
@@ -1,241 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// old op include, fluid should be removed
-#ifdef PADDLE_WITH_HIP
-#include <hipcub/hipcub.hpp>
-namespace cub = hipcub;
-#else
-#include <cub/cub.cuh>
-#endif
-
-#include <vector>
-
-#include "paddle/phi/common/amp_type_traits.h"
-#include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/kernels/funcs/axis_utils.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-#include "paddle/phi/kernels/impl/softmax_kernel_impl.h"
-#include "paddle/phi/kernels/margin_cross_entropy_grad_kernel.h"
-
-#include "paddle/phi/common/memory_utils.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/core/tensor_utils.h"
-#include "paddle/phi/core/visit_type.h"
-#include "paddle/phi/kernels/funcs/eigen/common.h"
-
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-#include "paddle/fluid/distributed/collective/process_group.h"
-#include "paddle/fluid/platform/collective_helper.h"
-#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
-#endif
-#include "paddle/phi/backends/gpu/gpu_context.h"
-
-namespace phi {
-
-static constexpr int kNumCUDAThreads = 512;
-static constexpr int kNumMaxinumNumBlocks = 4096;
-
-static inline int NumBlocks(const int N) {
-  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
-                  kNumMaxinumNumBlocks);
-}
-
-template <typename Context>
-void GetClassInterval(const gpuStream_t& stream,
-                      const phi::Place& place,
-                      const Context& dev_ctx,
-                      const int rid,
-                      const int rank,
-                      const int nranks,
-                      const int D,
-                      DenseTensor* class_interval) {
-  std::vector<int> shard_dim_vec(nranks + 1, 0);
-  shard_dim_vec[rank + 1] = D;
-  if (nranks <= 1) {
-    phi::TensorFromVector(shard_dim_vec, dev_ctx, class_interval);
-    return;
-  }
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  DenseTensor num_classes_per_device;
-  phi::TensorFromVector(shard_dim_vec, dev_ctx, &num_classes_per_device);
-  int* num_classes_per_device_ptr = num_classes_per_device.data<int>();
-
-  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
-  if (map->has(rid)) {
-    // Use ProcessGroup
-    paddle::distributed::ProcessGroup* pg = map->get(rid);
-    std::vector<phi::DenseTensor> in_tensor;
-    std::vector<phi::DenseTensor> out_tensor;
-    in_tensor.push_back(num_classes_per_device);
-    out_tensor.push_back(num_classes_per_device);
-
-    paddle::distributed::AllreduceOptions opts;
-    opts.reduce_op = paddle::distributed::ReduceOp::SUM;
-    auto task = pg->AllReduce(in_tensor, out_tensor, opts);
-    task->Wait();
-  } else {
-    const auto& comm =
-        paddle::platform::NCCLCommContext::Instance().Get(rid, place);
-    // use global calculate stream
-    const auto calcu_stream =
-        static_cast<GPUContext*>(phi::DeviceContextPool::Instance().Get(place))
-            ->stream();
-
-    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
-        num_classes_per_device_ptr,
-        num_classes_per_device_ptr,
-        num_classes_per_device.numel(),
-        phi::ToNCCLDataType(num_classes_per_device.dtype()),
-        ncclSum,
-        comm->comm(),
-        calcu_stream));
-  }
-
-  class_interval->Resize({nranks + 1});
-  auto class_interval_ptr = dev_ctx.template Alloc<int>(class_interval);
-
-  size_t cub_temp_storage_bytes = 0;
-  cub::DeviceScan::InclusiveSum<int*, int*>(
-      nullptr, cub_temp_storage_bytes, nullptr, nullptr, nranks + 1, stream);
-  auto cub_temp_storage =
-      phi::memory_utils::Alloc(place, cub_temp_storage_bytes);
-  cub::DeviceScan::InclusiveSum<int*, int*>(cub_temp_storage->ptr(),
-                                            cub_temp_storage_bytes,
-                                            num_classes_per_device_ptr,
-                                            class_interval_ptr,
-                                            nranks + 1,
-                                            stream);
-  return;
-#endif
-}
-
-template <typename T, typename IndexT>
-__global__ void CalculateGrad(T* logits_grad,
-                              const T* loss_grad,
-                              const T* logits,
-                              const IndexT* label,
-                              const float margin1,
-                              const float margin2,
-                              const float scale,
-                              const int rank,
-                              const int64_t N,
-                              const int64_t D,
-                              const int* class_interval_ptr) {
-  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
-  int start_index = class_interval_ptr[rank];
-  CUDA_KERNEL_LOOP(i, N * D) {
-    auto row = i / D;
-    auto col = i % D;
-    if ((col + start_index) == label[row]) {
-      logits_grad[i] = (logits_grad[i] - static_cast<T>(1.0)) * loss_grad[row];
-      if (fabs(margin1 - 1.0) > 1e-8 || fabs(margin2) > 1e-8) {
-        MPType dout = static_cast<MPType>(logits_grad[i]);
-        MPType one = static_cast<MPType>(1.0f);
-        MPType x = static_cast<MPType>(logits[i]);
-        MPType m1 = static_cast<MPType>(margin1);
-        MPType m2 = static_cast<MPType>(margin2);
-
-        MPType d = m1 * sin(m1 * acos(x) + m2) / sqrt(one - x * x);
-        logits_grad[i] = static_cast<T>(dout * d);
-      }
-    } else {
-      logits_grad[i] *= loss_grad[row];
-    }
-    if (fabs(scale - 1.0) > 1e-8) {
-      logits_grad[i] *= static_cast<T>(scale);
-    }
-  }
-}
-
-template <typename T, typename Context>
-void MarginCrossEntropyGradKernel(const Context& dev_ctx,
-                                  const DenseTensor& logits,
-                                  const DenseTensor& label,
-                                  const DenseTensor& softmax,
-                                  const DenseTensor& loss_grad,
-                                  bool return_softmax,
-                                  int ring_id,
-                                  int rank,
-                                  int nranks,
-                                  float margin1,
-                                  float margin2,
-                                  float margin3,
-                                  float scale,
-                                  DenseTensor* logits_grad) {
-  const auto softmax_dims = softmax.dims();
-  const int axis = softmax_dims.size() - 1;
-  const int N = phi::funcs::SizeToAxis(axis, softmax_dims);
-  const int D = phi::funcs::SizeFromAxis(axis, softmax_dims);
-
-  if (return_softmax) {
-    phi::Copy<Context>(
-        dev_ctx, softmax, dev_ctx.GetPlace(), false, logits_grad);
-  } else {
-    logits_grad->ShareDataWith(softmax);
-  }
-
-  int blocks = NumBlocks(N * D);
-  int threads = kNumCUDAThreads;
-  const auto& label_type = label.dtype();
-
-  DenseTensor class_interval;
-  GetClassInterval(dev_ctx.stream(),
-                   dev_ctx.GetPlace(),
-                   dev_ctx,
-                   ring_id,
-                   rank,
-                   nranks,
-                   D,
-                   &class_interval);
-
-  if (label_type == phi::DataType::INT32) {
-    typedef int32_t LabelT;
-    CalculateGrad<T, LabelT>
-        <<<blocks, threads, 0, dev_ctx.stream()>>>(logits_grad->data<T>(),
-                                                   loss_grad.data<T>(),
-                                                   logits.data<T>(),
-                                                   label.data<LabelT>(),
-                                                   margin1,
-                                                   margin2,
-                                                   scale,
-                                                   rank,
-                                                   N,
-                                                   D,
-                                                   class_interval.data<int>());
-  } else if (label_type == phi::DataType::INT64) {
-    typedef int64_t LabelT;
-    CalculateGrad<T, LabelT>
-        <<<blocks, threads, 0, dev_ctx.stream()>>>(logits_grad->data<T>(),
-                                                   loss_grad.data<T>(),
-                                                   logits.data<T>(),
-                                                   label.data<LabelT>(),
-                                                   margin1,
-                                                   margin2,
-                                                   scale,
-                                                   rank,
-                                                   N,
-                                                   D,
-                                                   class_interval.data<int>());
-  }
-}
-
-}  // namespace phi
-
-PD_REGISTER_KERNEL(margin_cross_entropy_grad,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::MarginCrossEntropyGradKernel,
-                   float,
-                   double,
-                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/sync_batch_norm_grad_kernel.cu
deleted file mode 100644
index 84d3f3c972ad9076e989327b642f577eb5670b21..0000000000000000000000000000000000000000
--- a/paddle/phi/kernels/gpu/sync_batch_norm_grad_kernel.cu
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/kernels/sync_batch_norm_grad_kernel.h"
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void SyncBatchNormGradKernel(const Context& ctx,
-                             const DenseTensor& x,
-                             const DenseTensor& scale,
-                             const DenseTensor& bias,
-                             const DenseTensor& saved_mean,
-                             const DenseTensor& saved_variance,
-                             const paddle::optional<DenseTensor>& reserve_space,
-                             const DenseTensor& y_grad,
-                             float momentum,
-                             float epsilon_f,
-                             const std::string& data_layout_str,
-                             bool is_test,
-                             bool use_global_stats,
-                             bool trainable_statistics,
-                             DenseTensor* x_grad,
-                             DenseTensor* scale_grad,
-                             DenseTensor* bias_grad) {
-  SyncBatchNormGradFunctor<T, Context>(ctx,
-                                       &x,
-                                       nullptr,
-                                       scale,
-                                       bias,
-                                       saved_mean,
-                                       saved_variance,
-                                       y_grad,
-                                       epsilon_f,
-                                       data_layout_str,
-                                       x_grad,
-                                       scale_grad,
-                                       bias_grad);
-}
-
-}  // namespace phi
-
-#ifdef PADDLE_WITH_HIP
-PD_REGISTER_KERNEL(sync_batch_norm_grad,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::SyncBatchNormGradKernel,
-                   float,
-                   phi::dtype::float16) {}
-#else
-PD_REGISTER_KERNEL(sync_batch_norm_grad,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::SyncBatchNormGradKernel,
-                   float,
-                   double,
-                   phi::dtype::float16) {}
-#endif
diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu
deleted file mode 100644
index 19b9f5845bf76aeb6805cd5d7cf38c74aad8c4c0..0000000000000000000000000000000000000000
--- a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu
+++ /dev/null
@@ -1,218 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/kernels/sync_batch_norm_kernel.h"
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/common/memory_utils.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void SyncBatchNormKernel(const Context &ctx,
-                         const DenseTensor &x,
-                         const DenseTensor &mean,
-                         const DenseTensor &variance,
-                         const DenseTensor &scale,
-                         const DenseTensor &bias,
-                         bool is_test,
-                         float momentum,
-                         float epsilon_f,
-                         const std::string &data_layout_str,
-                         bool use_global_stats,
-                         bool trainable_statistics,
-                         DenseTensor *y,
-                         DenseTensor *mean_out,
-                         DenseTensor *variance_out,
-                         DenseTensor *saved_mean,
-                         DenseTensor *saved_variance,
-                         DenseTensor *reserve_space) {
-  PADDLE_ENFORCE_EQ(use_global_stats,
-                    false,
-                    phi::errors::InvalidArgument(
-                        "sync_batch_norm doesn't support "
-                        "to set use_global_stats True. Please use batch_norm "
-                        "in this case."));
-
-  double epsilon = epsilon_f;
-  const bool trainable_stats = trainable_statistics;
-  const DataLayout layout = phi::StringToDataLayout(data_layout_str);
-  bool test_mode = is_test && (!trainable_statistics);
-  const auto &x_dims = x.dims();
-  PADDLE_ENFORCE_GE(x_dims.size(),
-                    2,
-                    phi::errors::InvalidArgument(
-                        "The Input dim size should be larger than 1."));
-  PADDLE_ENFORCE_LE(x_dims.size(),
-                    5,
-                    phi::errors::InvalidArgument(
-                        "The Input dim size should be less than 6."));
-  int N, C, H, W, D;
-  funcs::ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D);
-  int x_numel = x.numel();
-
-  const T *x_d = x.template data<T>();
-  const auto *s_d = scale.template data<BatchNormParamType<T>>();
-  const auto *b_d = bias.template data<BatchNormParamType<T>>();
-
-  T *y_d = ctx.template Alloc<T>(y);
-
-  const BatchNormParamType<T> *mean_data = nullptr;
-  const BatchNormParamType<T> *var_data = nullptr;
-
-  auto stream = ctx.stream();
-  const int block = 512;
-  int max_threads = ctx.GetMaxPhysicalThreadCount();
-
-  phi::Allocator::AllocationPtr alloc_ptr{nullptr};
-
-  if (test_mode) {
-    mean_data = mean.template data<BatchNormParamType<T>>();
-    var_data = variance.template data<BatchNormParamType<T>>();
-  } else {
-    // x, x^2, 1, here 1 is used to calc device num
-    // device num also can be got from phi::DeviceContextPool
-    const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType<T>);
-    alloc_ptr = phi::memory_utils::Alloc(
-        ctx.GetPlace(),
-        bytes,
-        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
-
-    auto *stats = reinterpret_cast<BatchNormParamType<T> *>(alloc_ptr->ptr());
-    const int threads = 256;
-    int grid = std::min(C, (max_threads + threads - 1) / threads);
-    if (layout == phi::DataLayout::kNCHW) {
-      KeLocalStats<T, threads, phi::DataLayout::kNCHW>
-          <<<grid, threads, 0, stream>>>(x_d, N, H * W * D, C, stats);
-    } else {
-      KeLocalStats<T, threads, phi::DataLayout::kNHWC>
-          <<<grid, threads, 0, stream>>>(x_d, N, H * W * D, C, stats);
-    }
-
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-    ncclComm_t comm = static_cast<ncclComm_t>(detail::GetCCLComm(x.place(), 0));
-    if (comm == nullptr) {
-      comm = ctx.nccl_comm();
-    }
-
-    if (comm) {
-      int dtype = phi::ToNCCLDataType(mean_out->dtype());
-      // In-place operation
-      PADDLE_ENFORCE_GPU_SUCCESS(
-          phi::dynload::ncclAllReduce(stats,
-                                      stats,
-                                      2 * C + 1,
-                                      static_cast<ncclDataType_t>(dtype),
-                                      ncclSum,
-                                      comm,
-                                      stream));
-      VLOG(3) << "Sync result using all reduce";
-    }
-#endif
-
-    auto *est_mean_data = ctx.template Alloc<BatchNormParamType<T>>(mean_out);
-    auto *est_var_data =
-        ctx.template Alloc<BatchNormParamType<T>>(variance_out);
-
-    auto *sv_mean_data = ctx.template Alloc<BatchNormParamType<T>>(saved_mean);
-    auto *sv_inv_var_data =
-        ctx.template Alloc<BatchNormParamType<T>>(saved_variance);
-
-    // Note, Input('Mean')/Input('Variance') share variable with
-    // Output('MeanOut')/Output('VarianceOut')
-    KeSyncAndMovingStats<T>
-        <<<(C + block - 1) / block, block, 0, stream>>>(stats,
-                                                        stats + C,
-                                                        stats + 2 * C,
-                                                        C,
-                                                        momentum,
-                                                        epsilon,
-                                                        sv_mean_data,
-                                                        sv_inv_var_data,
-                                                        est_mean_data,
-                                                        est_var_data);
-
-    mean_data = sv_mean_data;
-    var_data = stats + C;
-  }
-
-  int grid2 = (std::min(x_numel, max_threads) + block - 1) / block;
-  if (layout == phi::DataLayout::kNCHW) {
-    KeNormAffine<T, phi::DataLayout::kNCHW>
-        <<<grid2, block, 0, stream>>>(x_d,
-                                      s_d,
-                                      b_d,
-                                      mean_data,
-                                      var_data,
-                                      epsilon,
-                                      C,
-                                      H * W * D,
-                                      x_numel,
-                                      y_d);
-  } else {
-    KeNormAffine<T, phi::DataLayout::kNHWC>
-        <<<grid2, block, 0, stream>>>(x_d,
-                                      s_d,
-                                      b_d,
-                                      mean_data,
-                                      var_data,
-                                      epsilon,
-                                      C,
-                                      H * W * D,
-                                      x_numel,
-                                      y_d);
-  }
-}
-
-}  // namespace phi
-
-#ifdef PADDLE_WITH_HIP
-PD_REGISTER_KERNEL(sync_batch_norm,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::SyncBatchNormKernel,
-                   float,
-                   phi::dtype::float16) {
-  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
-    kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
-    kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
-    kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
-    kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
-    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
-    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
-    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
-    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
-  }
-}
-#else
-PD_REGISTER_KERNEL(sync_batch_norm,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::SyncBatchNormKernel,
-                   float,
-                   double,
-                   phi::dtype::float16) {
-  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
-    kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
-    kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
-    kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
-    kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
-    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
-    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
-    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
-    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
-  }
-}
-#endif
diff --git a/paddle/phi/kernels/sparse/gpu/sync_batch_norm_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/sync_batch_norm_grad_kernel.cu
deleted file mode 100644
index 664b3a1ee2699286620b2a0f82994cc1e124ee78..0000000000000000000000000000000000000000
--- a/paddle/phi/kernels/sparse/gpu/sync_batch_norm_grad_kernel.cu
+++ /dev/null
@@ -1,83 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/phi/kernels/sparse/sync_batch_norm_grad_kernel.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/sparse/empty_kernel.h"
-#include "paddle/phi/kernels/sync_batch_norm_grad_kernel.h"
-
-namespace phi {
-namespace sparse {
-
-template <typename T, typename Context>
-void SyncBatchNormCooGradKernel(
-    const Context& dev_ctx,
-    const SparseCooTensor& x,
-    const DenseTensor& scale,
-    const DenseTensor& bias,
-    const DenseTensor& saved_mean,
-    const DenseTensor& saved_variance,
-    const paddle::optional<DenseTensor>& reserve_space,
-    const SparseCooTensor& y_grad,
-    float momentum,
-    float epsilon,
-    const std::string& data_layout,
-    bool is_test,
-    bool use_global_stats,
-    bool trainable_statistics,
-    SparseCooTensor* x_grad,
-    DenseTensor* scale_grad,
-    DenseTensor* bias_grad) {
-  EmptyLikeCooKernel<T, Context>(dev_ctx, x, x_grad);
-  *scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
-  *bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
-  phi::SyncBatchNormGradKernel<T, Context>(dev_ctx,
-                                           x.values(),
-                                           scale,
-                                           bias,
-                                           saved_mean,
-                                           saved_variance,
-                                           reserve_space,
-                                           y_grad.values(),
-                                           momentum,
-                                           epsilon,
-                                           data_layout,
-                                           is_test,
-                                           use_global_stats,
-                                           trainable_statistics,
-                                           x_grad->mutable_values(),
-                                           scale_grad,
-                                           bias_grad);
-}
-
-}  // namespace sparse
-}  // namespace phi
-
-#ifdef PADDLE_WITH_HIP
-PD_REGISTER_KERNEL(sync_batch_norm_coo_grad,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SyncBatchNormCooGradKernel,
-                   float,
-                   phi::dtype::float16) {}
-#else
-PD_REGISTER_KERNEL(sync_batch_norm_coo_grad,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SyncBatchNormCooGradKernel,
-                   float,
-                   double,
-                   phi::dtype::float16) {}
-#endif
diff --git a/paddle/phi/kernels/sparse/gpu/sync_batch_norm_kernel.cu b/paddle/phi/kernels/sparse/gpu/sync_batch_norm_kernel.cu
deleted file mode 100644
index 162f1f4b937655273bba5d680973099333178224..0000000000000000000000000000000000000000
--- a/paddle/phi/kernels/sparse/gpu/sync_batch_norm_kernel.cu
+++ /dev/null
@@ -1,82 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/phi/kernels/sparse/sync_batch_norm_kernel.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/sparse/empty_kernel.h"
-#include "paddle/phi/kernels/sync_batch_norm_kernel.h"
-
-namespace phi {
-namespace sparse {
-
-template <typename T, typename Context>
-void SyncBatchNormCooKernel(const Context& dev_ctx,
-                            const SparseCooTensor& x,
-                            const DenseTensor& mean,
-                            const DenseTensor& variance,
-                            const DenseTensor& scale,
-                            const DenseTensor& bias,
-                            bool is_test,
-                            float momentum,
-                            float epsilon,
-                            const std::string& data_layout,
-                            bool use_global_stats,
-                            bool trainable_statistics,
-                            SparseCooTensor* y,
-                            DenseTensor* mean_out,
-                            DenseTensor* variance_out,
-                            DenseTensor* saved_mean,
-                            DenseTensor* saved_variance,
-                            DenseTensor* reserve_space) {
-  EmptyLikeCooKernel<T, Context>(dev_ctx, x, y);
-  phi::SyncBatchNormKernel<T, Context>(dev_ctx,
-                                       x.values(),
-                                       mean,
-                                       variance,
-                                       scale,
-                                       bias,
-                                       is_test,
-                                       momentum,
-                                       epsilon,
-                                       data_layout,
-                                       use_global_stats,
-                                       trainable_statistics,
-                                       y->mutable_values(),
-                                       mean_out,
-                                       variance_out,
-                                       saved_mean,
-                                       saved_variance,
-                                       reserve_space);
-  y->SetIndicesDict(x.GetIndicesDict());
-}
-
-}  // namespace sparse
-}  // namespace phi
-
-#ifdef PADDLE_WITH_HIP
-PD_REGISTER_KERNEL(sync_batch_norm_coo,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SyncBatchNormCooKernel,
-                   float,
-                   phi::dtype::float16) {}
-#else
-PD_REGISTER_KERNEL(sync_batch_norm_coo,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SyncBatchNormCooKernel,
-                   float,
-                   double,
-                   phi::dtype::float16) {}
-#endif