Unverified commit e8e47581, authored by ShenLiang, committed by GitHub

[BugFix]Fix bug in pfp16 in DataParallel (#38378)

* fix bug in pfp16

* fix hip

* fix hip
Parent: 9cfdae91
@@ -89,6 +89,22 @@ struct DataTypeTrait<void> {
   _ForEachDataTypeHelper_(callback, int, INT32);                          \
   _ForEachDataTypeHelper_(callback, int64_t, INT64);
 
+// Only used by DataParallel on HIP; bf16 is not supported on HIP.
+#define _ForEachDataTypeForHIP_(callback)                                \
+  _ForEachDataTypeHelper_(callback, float, FP32);                        \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);  \
+  _ForEachDataTypeHelper_(callback, double, FP64);                       \
+  _ForEachDataTypeHelper_(callback, int, INT32);                         \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64);                     \
+  _ForEachDataTypeHelper_(callback, bool, BOOL);                         \
+  _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                     \
+  _ForEachDataTypeHelper_(callback, int16_t, INT16);                     \
+  _ForEachDataTypeHelper_(callback, int8_t, INT8);                       \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>,  \
+                          COMPLEX64);                                    \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
+                          COMPLEX128);
+
 #define DefineDataTypeTrait(cpp_type, proto_type) \
   template <>                                     \
   struct DataTypeTrait<cpp_type> {                \
@@ -147,6 +163,20 @@ inline void VisitDataTypeTiny(proto::VarType::Type type, Visitor visitor) {
 #undef VisitDataTypeCallbackTiny
 }
 
+template <typename Visitor>
+inline void VisitDataTypeForHIP(proto::VarType::Type type, Visitor visitor) {
+#define VisitDataTypeCallbackHIP(cpp_type, proto_type) \
+  do {                                                 \
+    if (type == proto_type) {                          \
+      visitor.template apply<cpp_type>();              \
+      return;                                          \
+    }                                                  \
+  } while (0)
+
+  _ForEachDataTypeForHIP_(VisitDataTypeCallbackHIP);
+#undef VisitDataTypeCallbackHIP
+}
+
 extern std::string DataTypeToString(const proto::VarType::Type type);
 extern size_t SizeOfType(proto::VarType::Type type);
 inline std::ostream& operator<<(std::ostream& out,
......
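The two hunks above implement the classic X-macro visitor idiom: _ForEachDataTypeForHIP_ is a list of (cpp_type, proto_type) pairs, and VisitDataTypeForHIP expands that list into a chain of runtime type checks, each dispatching to the visitor's templated apply<T>(). Below is a minimal, self-contained sketch of the mechanism; every name in it (TypeId, _ForEachType_, PrintSize, Visit) is an illustrative stand-in, not a Paddle identifier.

#include <cstdint>
#include <cstdio>

enum class TypeId { FP32, FP64, INT32, INT64 };

// The X-macro: one callback invocation per (cpp_type, tag) pair. Adding or
// removing a pair here changes every visitor built on top of the list.
#define _ForEachType_(callback)     \
  callback(float, TypeId::FP32);    \
  callback(double, TypeId::FP64);   \
  callback(int32_t, TypeId::INT32); \
  callback(int64_t, TypeId::INT64);

// A visitor is any object with a templated apply<T>().
struct PrintSize {
  template <typename T>
  void apply() {
    std::printf("size = %zu\n", sizeof(T));
  }
};

template <typename Visitor>
void Visit(TypeId type, Visitor visitor) {
// Expands to one if-statement per pair; the first match calls
// visitor.apply<cpp_type>() and returns.
#define VisitCallback(cpp_type, tag)      \
  do {                                    \
    if (type == tag) {                    \
      visitor.template apply<cpp_type>(); \
      return;                             \
    }                                     \
  } while (0)
  _ForEachType_(VisitCallback);
#undef VisitCallback
}

int main() {
  Visit(TypeId::INT64, PrintSize{});  // prints: size = 8
  return 0;
}

Because every visitor is generated from the macro list, dropping a pair from the list (as the HIP variant drops BF16) removes the corresponding dispatch branch everywhere at once, with no per-visitor edits.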
@@ -48,9 +48,19 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
   } else if (platform::is_cpu_place(tensor->place())) {
     VLOG(4) << "before div 2" << *tensor;
     VLOG(4) << "NDiv for cpu devices : rank = " << nranks;
-    framework::VisitDataTypeSmall(
-        dtype_, DivNRanksForAllReduce<platform::CPUDeviceContext>(
-                    tensor, nranks, context));
+#ifdef PADDLE_WITH_HIP
+    if (dtype_ == paddle::framework::proto::VarType_Type_BF16) {
+      PADDLE_THROW(paddle::platform::errors::Fatal(
+          "BF16 is not supported in DataParallel for now"));
+    }
+    framework::VisitDataTypeForHIP(
+        dtype_, DivNRanksForAllReduce<platform::CPUDeviceContext>(
+                    tensor, nranks, context));
+#else
+    framework::VisitDataType(dtype_,
+                             DivNRanksForAllReduce<platform::CPUDeviceContext>(
+                                 tensor, nranks, context));
+#endif
     VLOG(4) << "after div 2" << *tensor;
   } else if (platform::is_xpu_place(tensor->place())) {
 #ifdef PADDLE_WITH_XPU_BKCL
......
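The change above has a guard-then-dispatch shape: on HIP builds, bf16 tensors are rejected with an explicit error before dispatch (BF16 is absent from the HIP type list), while non-HIP builds keep the full VisitDataType list and are unchanged. The sketch below shows that shape under simplified assumptions; TypeId, VisitReducedTypeList, and DivByNRanks are illustrative names, not Paddle's, and gradients are stored as plain float for brevity.

#include <cstdint>
#include <stdexcept>
#include <vector>

enum class TypeId { FP32, FP16, BF16 };

// Divides every gradient element by the rank count, turning an all-reduce
// sum into an average. The sketch keeps one float buffer regardless of T,
// purely to stay short; the real functor is templated over the element type.
struct DivByNRanks {
  std::vector<float>* grads;
  int64_t nranks;
  template <typename T>
  void apply() {
    for (float& g : *grads) g /= static_cast<float>(nranks);
  }
};

// A reduced type list that deliberately omits BF16: reject it up front with
// a clear error instead of silently falling through the dispatch chain.
template <typename Visitor>
void VisitReducedTypeList(TypeId type, Visitor visitor) {
  if (type == TypeId::BF16) {
    throw std::runtime_error("BF16 is not supported in DataParallel for now");
  }
  if (type == TypeId::FP32) { visitor.template apply<float>(); return; }
  if (type == TypeId::FP16) { visitor.template apply<uint16_t>(); return; }
  throw std::runtime_error("unexpected data type");
}

int main() {
  std::vector<float> grads = {2.0f, 4.0f};  // gradients summed over 2 ranks
  VisitReducedTypeList(TypeId::FP32, DivByNRanks{&grads, 2});
  // grads is now {1.0f, 2.0f}: the per-rank average
  return 0;
}

Dividing by nranks after the collective sum is what turns the summed gradients into the average that DataParallel applies; the same guard is repeated for the CUDA-context overload in the next hunk.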
@@ -20,9 +20,19 @@ namespace imperative {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 void Group::DivNRanks(framework::Tensor *tensor, int64_t nranks,
                       const platform::DeviceContext &context) {
-  framework::VisitDataTypeSmall(
-      dtype_, DivNRanksForAllReduce<platform::CUDADeviceContext>(tensor, nranks,
-                                                                 context));
+#ifdef PADDLE_WITH_HIP
+  if (dtype_ == paddle::framework::proto::VarType_Type_BF16) {
+    PADDLE_THROW(paddle::platform::errors::Fatal(
+        "BF16 is not supported in DataParallel for now"));
+  }
+  framework::VisitDataTypeForHIP(
+      dtype_, DivNRanksForAllReduce<platform::CUDADeviceContext>(tensor, nranks,
+                                                                 context));
+#else
+  framework::VisitDataType(
+      dtype_, DivNRanksForAllReduce<platform::CUDADeviceContext>(tensor, nranks,
+                                                                 context));
+#endif
 }
 #endif
......