Unverified commit e8e47581, authored by ShenLiang, committed by GitHub

[BugFix]Fix bug in pfp16 in DataParallel (#38378)

* fix bug in pfp16

* fix hip

* fix hip
Parent 9cfdae91
@@ -89,6 +89,22 @@ struct DataTypeTrait<void> {
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64);
// It is only for DataParallel in HIP; bf16 is not supported in HIP.
#define _ForEachDataTypeForHIP_(callback)                                \
  _ForEachDataTypeHelper_(callback, float, FP32);                        \
  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16);  \
  _ForEachDataTypeHelper_(callback, double, FP64);                       \
  _ForEachDataTypeHelper_(callback, int, INT32);                         \
  _ForEachDataTypeHelper_(callback, int64_t, INT64);                     \
  _ForEachDataTypeHelper_(callback, bool, BOOL);                         \
  _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                     \
  _ForEachDataTypeHelper_(callback, int16_t, INT16);                     \
  _ForEachDataTypeHelper_(callback, int8_t, INT8);                       \
  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<float>,  \
                          COMPLEX64);                                    \
  _ForEachDataTypeHelper_(callback, ::paddle::platform::complex<double>, \
                          COMPLEX128);
#define DefineDataTypeTrait(cpp_type, proto_type) \
template <> \
struct DataTypeTrait<cpp_type> { \
@@ -147,6 +163,20 @@ inline void VisitDataTypeTiny(proto::VarType::Type type, Visitor visitor) {
#undef VisitDataTypeCallbackTiny
}
template <typename Visitor>
inline void VisitDataTypeForHIP(proto::VarType::Type type, Visitor visitor) {
#define VisitDataTypeCallbackHIP(cpp_type, proto_type) \
  do {                                                 \
    if (type == proto_type) {                          \
      visitor.template apply<cpp_type>();              \
      return;                                          \
    }                                                  \
  } while (0)
  _ForEachDataTypeForHIP_(VisitDataTypeCallbackHIP);
#undef VisitDataTypeCallbackHIP
}
extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out,
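The two hunks above register a HIP-specific type list (float16 included, bf16 deliberately absent) and a matching `VisitDataTypeForHIP` entry point. As a minimal, self-contained sketch of this X-macro plus visitor-dispatch pattern — the enum, macro names, and visitor below are simplified stand-ins, not Paddle's real `proto::VarType` or trait machinery:

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical stand-in for proto::VarType::Type.
enum class DType { FP32, FP64, INT32, INT64 };

// X-macro: expands callback(cpp_type, enum_value) once per supported type,
// mirroring how _ForEachDataTypeForHIP_ drives _ForEachDataTypeHelper_.
#define FOR_EACH_DTYPE(callback) \
  callback(float, DType::FP32);  \
  callback(double, DType::FP64); \
  callback(int, DType::INT32);   \
  callback(int64_t, DType::INT64);

// Runtime-to-compile-time dispatch: find the C++ type matching `type`
// and call visitor.apply<T>() on it, as VisitDataTypeForHIP does above.
template <typename Visitor>
void VisitDType(DType type, Visitor visitor) {
#define VISIT_CALLBACK(cpp_type, enum_value) \
  do {                                       \
    if (type == enum_value) {                \
      visitor.template apply<cpp_type>();    \
      return;                                \
    }                                        \
  } while (0)
  FOR_EACH_DTYPE(VISIT_CALLBACK);
#undef VISIT_CALLBACK
}

// Example visitor: reports the element size of whichever type was dispatched.
struct PrintSizeVisitor {
  template <typename T>
  void apply() const {
    std::cout << "element size = " << sizeof(T) << " bytes\n";
  }
};

int main() {
  VisitDType(DType::INT64, PrintSizeVisitor{});  // prints "element size = 8 bytes"
  return 0;
}
```

The point of a separate HIP list is that the dispatcher only ever instantiates `apply<T>()` for types the HIP build can actually handle; bf16 simply has no branch, which is why the callers below guard it explicitly.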
@@ -48,9 +48,19 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
  } else if (platform::is_cpu_place(tensor->place())) {
    VLOG(4) << "before div 2" << *tensor;
    VLOG(4) << "NDiv for cpu devices : rank = " << nranks;
-    framework::VisitDataTypeSmall(
-        dtype_, DivNRanksForAllReduce<platform::CPUDeviceContext>(
-                    tensor, nranks, context));
+#ifdef PADDLE_WITH_HIP
+    if (dtype_ == paddle::framework::proto::VarType_Type_BF16) {
+      PADDLE_THROW(paddle::platform::errors::Fatal(
+          "Unsupport BF16 in DataParallel for now"));
+    }
+    framework::VisitDataTypeForHIP(
+        dtype_, DivNRanksForAllReduce<platform::CPUDeviceContext>(
+                    tensor, nranks, context));
+#else
+    framework::VisitDataType(dtype_,
+                             DivNRanksForAllReduce<platform::CPUDeviceContext>(
+                                 tensor, nranks, context));
+#endif
    VLOG(4) << "after div 2" << *tensor;
  } else if (platform::is_xpu_place(tensor->place())) {
#ifdef PADDLE_WITH_XPU_BKCL
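For context on the CPU branch above: once the runtime dtype is resolved, `DivNRanksForAllReduce` runs through the visitor's `apply<T>()`, which presumably scales every gradient element by 1/nranks before or after the all-reduce. A rough, self-contained sketch of that shape — `FakeTensor` and `DivNRanksVisitor` are made-up stand-ins, not Paddle's `framework::Tensor` or the real functor:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for a type-erased gradient buffer.
struct FakeTensor {
  void *data;
  size_t numel;
};

// Visitor: divides every element of the buffer by nranks for the dispatched T.
struct DivNRanksVisitor {
  FakeTensor *tensor;
  int64_t nranks;

  template <typename T>
  void apply() const {
    T *ptr = static_cast<T *>(tensor->data);
    for (size_t i = 0; i < tensor->numel; ++i) {
      ptr[i] = static_cast<T>(ptr[i] / static_cast<T>(nranks));
    }
  }
};

int main() {
  std::vector<float> grad = {2.f, 4.f, 6.f};
  FakeTensor t{grad.data(), grad.size()};
  DivNRanksVisitor{&t, 2}.apply<float>();
  std::cout << grad[0] << " " << grad[1] << " " << grad[2] << "\n";  // 1 2 3
  return 0;
}
```

In the real code the element type is not hard-coded as in the last call: `VisitDataType` / `VisitDataTypeForHIP` picks `T` at runtime from `dtype_`, which is exactly what the HIP-specific type list above constrains.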
@@ -20,9 +20,19 @@ namespace imperative {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
void Group::DivNRanks(framework::Tensor *tensor, int64_t nranks,
                      const platform::DeviceContext &context) {
-  framework::VisitDataTypeSmall(
-      dtype_, DivNRanksForAllReduce<platform::CUDADeviceContext>(tensor, nranks,
-                                                                 context));
+#ifdef PADDLE_WITH_HIP
+  if (dtype_ == paddle::framework::proto::VarType_Type_BF16) {
+    PADDLE_THROW(paddle::platform::errors::Fatal(
+        "Unsupport BF16 in DataParallel for now"));
+  }
+  framework::VisitDataTypeForHIP(
+      dtype_, DivNRanksForAllReduce<platform::CUDADeviceContext>(tensor, nranks,
+                                                                 context));
+#else
+  framework::VisitDataType(
+      dtype_, DivNRanksForAllReduce<platform::CUDADeviceContext>(tensor, nranks,
+                                                                 context));
+#endif
}
#endif
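Both reducer paths reject BF16 up front on HIP because the HIP type list above has no bf16 entry, so an unsupported dtype fails with an explicit error instead of silently falling through the dispatch. A small sketch of that guard-before-dispatch pattern, with hypothetical names (`DType`, `DivNRanksOnDevice`) rather than Paddle's API:

```cpp
#include <stdexcept>

// Hypothetical dtype enum; BF16 exists as a value but has no dispatch entry.
enum class DType { FP32, FP16, BF16 };

void DivNRanksOnDevice(DType dtype /*, tensor, nranks, context */) {
  // Fail loudly for a dtype the HIP type list cannot handle, mirroring the
  // PADDLE_THROW(...Fatal...) guard in the diff above.
  if (dtype == DType::BF16) {
    throw std::runtime_error("BF16 is not supported in DataParallel for now");
  }
  // ...otherwise dispatch through the type visitor (VisitDataTypeForHIP above)
  // and divide the gradients by nranks for the resolved element type.
}

int main() {
  DivNRanksOnDevice(DType::FP16);  // fine: FP16 is in the HIP type list
  return 0;
}
```

Guarding before the visit keeps the failure message specific to DataParallel on HIP, rather than relying on whatever generic "unsupported type" error the visitor would otherwise produce.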