From e8e47581127d157fcbe055d06ebbdbfbdddd907c Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Mon, 27 Dec 2021 17:53:18 +0800
Subject: [PATCH] [BugFix]Fix bug in pfp16 in DataParallel (#38378)

* fix bug in pfp16

* fix hip

* fix hip
---
 paddle/fluid/framework/data_type.h | 30 ++++++++++++++++++++++++++++++
 paddle/fluid/imperative/reducer.cc | 12 +++++++++++-
 paddle/fluid/imperative/reducer.cu | 12 +++++++++++-
 3 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index 08749b6b751..ec8284b8255 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -89,6 +89,22 @@ struct DataTypeTrait {
   _ForEachDataTypeHelper_(callback, int, INT32);      \
   _ForEachDataTypeHelper_(callback, int64_t, INT64);
 
+// It's only for DataParallel in HIP, bf16 not support in HIP.
+#define _ForEachDataTypeForHIP_(callback)                                \
+  _ForEachDataTypeHelper_(callback, float, FP32);                        \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);  \
+  _ForEachDataTypeHelper_(callback, double, FP64);                       \
+  _ForEachDataTypeHelper_(callback, int, INT32);                         \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64);                     \
+  _ForEachDataTypeHelper_(callback, bool, BOOL);                         \
+  _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                     \
+  _ForEachDataTypeHelper_(callback, int16_t, INT16);                     \
+  _ForEachDataTypeHelper_(callback, int8_t, INT8);                       \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>,  \
+                          COMPLEX64);                                    \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
+                          COMPLEX128);
+
 #define DefineDataTypeTrait(cpp_type, proto_type) \
   template <>                                     \
   struct DataTypeTrait<cpp_type> {                \
@@ -147,6 +163,20 @@ inline void VisitDataTypeTiny(proto::VarType::Type type, Visitor visitor) {
 #undef VisitDataTypeCallbackTiny
 }
 
+template <typename Visitor>
+inline void VisitDataTypeForHIP(proto::VarType::Type type, Visitor visitor) {
+#define VisitDataTypeCallbackHIP(cpp_type, proto_type) \
+  do {                                                 \
+    if (type == proto_type) {                          \
+      visitor.template apply<cpp_type>();              \
+      return;                                          \
+    }                                                  \
+  } while (0)
+
+  _ForEachDataTypeForHIP_(VisitDataTypeCallbackHIP);
+#undef VisitDataTypeCallbackHIP
+}
+
 extern std::string DataTypeToString(const proto::VarType::Type type);
 extern size_t SizeOfType(proto::VarType::Type type);
 inline std::ostream& operator<<(std::ostream& out,
diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc
index 9014871229b..beddbd5d120 100644
--- a/paddle/fluid/imperative/reducer.cc
+++ b/paddle/fluid/imperative/reducer.cc
@@ -48,9 +48,19 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
   } else if (platform::is_cpu_place(tensor->place())) {
     VLOG(4) << "before div 2" << *tensor;
     VLOG(4) << "NDiv for cpu devices : rank = " << nranks;
-    framework::VisitDataTypeSmall(
+#ifdef PADDLE_WITH_HIP
+    if (dtype_ == paddle::framework::proto::VarType_Type_BF16) {
+      PADDLE_THROW(paddle::platform::errors::Fatal(
+          "Unsupport BF16 in DataParallel for now"));
+    }
+    framework::VisitDataTypeForHIP(
         dtype_, DivNRanksForAllReduce<platform::CPUDeviceContext>(
                     tensor, nranks, context));
+#else
+    framework::VisitDataType(dtype_,
+                             DivNRanksForAllReduce<platform::CPUDeviceContext>(
+                                 tensor, nranks, context));
+#endif
     VLOG(4) << "after div 2" << *tensor;
   } else if (platform::is_xpu_place(tensor->place())) {
 #ifdef PADDLE_WITH_XPU_BKCL
diff --git a/paddle/fluid/imperative/reducer.cu b/paddle/fluid/imperative/reducer.cu
index ca233292b34..05453a61b7e 100644
--- a/paddle/fluid/imperative/reducer.cu
+++ b/paddle/fluid/imperative/reducer.cu
@@ -20,9 +20,19 @@ namespace imperative {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 void Group::DivNRanks(framework::Tensor *tensor, int64_t nranks,
                       const platform::DeviceContext &context) {
-  framework::VisitDataTypeSmall(
+#ifdef PADDLE_WITH_HIP
+  if (dtype_ == paddle::framework::proto::VarType_Type_BF16) {
+    PADDLE_THROW(paddle::platform::errors::Fatal(
+        "Unsupport BF16 in DataParallel for now"));
+  }
+  framework::VisitDataTypeForHIP(
       dtype_, DivNRanksForAllReduce<platform::CUDADeviceContext>(tensor, nranks,
                                                                  context));
+#else
+  framework::VisitDataType(
+      dtype_, DivNRanksForAllReduce<platform::CUDADeviceContext>(tensor, nranks,
+                                                                 context));
+#endif
 }
 #endif
 
--
GitLab
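
The fix above relies on an X-macro plus visitor-dispatch idiom: _ForEachDataTypeForHIP_ enumerates the (cpp_type, proto_type) pairs that DataParallel supports on HIP (bf16 is left out of the list), and VisitDataTypeForHIP maps the runtime dtype_ tag to a call of visitor.template apply<cpp_type>() on the DivNRanksForAllReduce functor, with BF16 rejected up front via PADDLE_THROW. Below is a minimal, self-contained sketch of that idiom under assumed names: DType, FOR_EACH_DTYPE, VisitDType, and DivByNRanks are invented for illustration and are not Paddle APIs.

// Standalone sketch of the X-macro + visitor-dispatch idiom used in the patch.
// Not Paddle code: DType, FOR_EACH_DTYPE, VisitDType and DivByNRanks are
// illustrative stand-ins for proto::VarType::Type, _ForEachDataTypeForHIP_,
// VisitDataTypeForHIP and DivNRanksForAllReduce.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

enum class DType { FP32, FP64, INT32, INT64 };

// One callback invocation per supported (cpp_type, runtime tag) pair.
#define FOR_EACH_DTYPE(callback)   \
  callback(float, DType::FP32);    \
  callback(double, DType::FP64);   \
  callback(int32_t, DType::INT32); \
  callback(int64_t, DType::INT64);

// Map a runtime dtype tag to visitor.template apply<cpp_type>().
template <typename Visitor>
void VisitDType(DType type, Visitor visitor) {
#define VISIT_CALLBACK(cpp_type, tag)     \
  do {                                    \
    if (type == tag) {                    \
      visitor.template apply<cpp_type>(); \
      return;                             \
    }                                     \
  } while (0)

  FOR_EACH_DTYPE(VISIT_CALLBACK);
#undef VISIT_CALLBACK
  std::cerr << "unsupported dtype\n";  // analogous to rejecting BF16 on HIP
}

// Plays the role of DivNRanksForAllReduce: scales a gradient buffer by
// 1/nranks once the element type is known.
struct DivByNRanks {
  void* data;
  std::size_t numel;
  std::int64_t nranks;

  template <typename T>
  void apply() const {
    T* p = static_cast<T*>(data);
    for (std::size_t i = 0; i < numel; ++i) p[i] /= static_cast<T>(nranks);
  }
};

int main() {
  std::vector<float> grad = {2.f, 4.f, 6.f};
  VisitDType(DType::FP32, DivByNRanks{grad.data(), grad.size(), 2});
  for (float v : grad) std::cout << v << ' ';  // prints: 1 2 3
  std::cout << '\n';
  return 0;
}

Keeping the HIP-only list separate from the default _ForEachDataType_ list lets the RCCL/HIP build dispatch over FP16 without ever touching a bf16 instantiation, while non-HIP builds keep using framework::VisitDataType over the full type list, as the #else branches in the two hunks above show.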