未验证 提交 c775bc69 编写于 作者: A Aganlengzi 提交者: GitHub

[CustomPlace] fix amp (#48090)

* [CustomPlace] fix amp

* [CustomPlace] fix amp

* fix ut because of too long time matmul fp16
上级 04709310
...@@ -29,7 +29,8 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor, ...@@ -29,7 +29,8 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
paddle::platform::is_xpu_place(place) || paddle::platform::is_xpu_place(place) ||
paddle::platform::is_mlu_place(place) || paddle::platform::is_mlu_place(place) ||
paddle::platform::is_npu_place(place) || paddle::platform::is_npu_place(place) ||
paddle::platform::is_npu_pinned_place(place)) { paddle::platform::is_npu_pinned_place(place) ||
paddle::platform::is_custom_place(place)) {
// CudaPinndePlace is added for varbase created by dataloader // CudaPinndePlace is added for varbase created by dataloader
if ((data_type == paddle::experimental::DataType::FLOAT32 || if ((data_type == paddle::experimental::DataType::FLOAT32 ||
data_type == paddle::experimental::DataType::FLOAT16 || data_type == paddle::experimental::DataType::FLOAT16 ||
......
...@@ -27,7 +27,8 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor, ...@@ -27,7 +27,8 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
paddle::platform::is_xpu_place(place) || paddle::platform::is_xpu_place(place) ||
paddle::platform::is_mlu_place(place) || paddle::platform::is_mlu_place(place) ||
paddle::platform::is_npu_place(place) || paddle::platform::is_npu_place(place) ||
paddle::platform::is_npu_pinned_place(place)) { paddle::platform::is_npu_pinned_place(place) ||
paddle::platform::is_custom_place(place)) {
// CudaPinndePlace is added for varbase created by dataloader // CudaPinndePlace is added for varbase created by dataloader
if ((data_type == paddle::experimental::DataType::FLOAT32 || if ((data_type == paddle::experimental::DataType::FLOAT32 ||
data_type == paddle::experimental::DataType::FLOAT16 || data_type == paddle::experimental::DataType::FLOAT16 ||
......
...@@ -54,7 +54,11 @@ def train_func_ampo1(epoch_id, train_loader, model, cost, optimizer, scaler): ...@@ -54,7 +54,11 @@ def train_func_ampo1(epoch_id, train_loader, model, cost, optimizer, scaler):
for batch_id, (images, labels) in enumerate(train_loader()): for batch_id, (images, labels) in enumerate(train_loader()):
# forward # forward
with paddle.amp.auto_cast( with paddle.amp.auto_cast(
custom_black_list={"flatten_contiguous_range", "greater_than"}, custom_black_list={
"flatten_contiguous_range",
"greater_than",
"matmul_v2",
},
level='O1', level='O1',
): ):
outputs = model(images) outputs = model(images)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册