Unverified · Commit c775bc69 · Authored by Aganlengzi, committed by GitHub

[CustomPlace] fix amp (#48090)

* [CustomPlace] fix amp

* [CustomPlace] fix amp

* fix ut because fp16 matmul takes too long
Parent 04709310
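In short, the NeedCast helper patched in the two C++ hunks below did not include CustomPlace in its device check, so tensors living on a plug-in (custom) device were never considered for fp16/bf16 casting and auto_cast effectively left them in float32. The hunks add is_custom_place to that check. A minimal usage sketch, assuming a custom-device plug-in is installed and registered ("custom_cpu" is PaddleCustomDevice's reference plug-in; substitute whatever device name your plug-in registers):

import paddle

paddle.set_device("custom_cpu")  # select the registered custom (plug-in) device

model = paddle.nn.Linear(16, 4)
x = paddle.rand([8, 16])

# With this fix, NeedCast recognizes CustomPlace tensors, so auto_cast can
# insert the fp16 casts instead of silently keeping the model in float32.
with paddle.amp.auto_cast(level='O1'):
    y = model(x)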
@@ -29,7 +29,8 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
       paddle::platform::is_xpu_place(place) ||
       paddle::platform::is_mlu_place(place) ||
       paddle::platform::is_npu_place(place) ||
-      paddle::platform::is_npu_pinned_place(place)) {
+      paddle::platform::is_npu_pinned_place(place) ||
+      paddle::platform::is_custom_place(place)) {
     // CudaPinndePlace is added for varbase created by dataloader
     if ((data_type == paddle::experimental::DataType::FLOAT32 ||
          data_type == paddle::experimental::DataType::FLOAT16 ||
@@ -27,7 +27,8 @@ static inline bool NeedCast(const paddle::experimental::Tensor& tensor,
       paddle::platform::is_xpu_place(place) ||
       paddle::platform::is_mlu_place(place) ||
       paddle::platform::is_npu_place(place) ||
-      paddle::platform::is_npu_pinned_place(place)) {
+      paddle::platform::is_npu_pinned_place(place) ||
+      paddle::platform::is_custom_place(place)) {
     // CudaPinndePlace is added for varbase created by dataloader
     if ((data_type == paddle::experimental::DataType::FLOAT32 ||
          data_type == paddle::experimental::DataType::FLOAT16 ||
@@ -54,7 +54,11 @@ def train_func_ampo1(epoch_id, train_loader, model, cost, optimizer, scaler):
     for batch_id, (images, labels) in enumerate(train_loader()):
         # forward
         with paddle.amp.auto_cast(
-            custom_black_list={"flatten_contiguous_range", "greater_than"},
+            custom_black_list={
+                "flatten_contiguous_range",
+                "greater_than",
+                "matmul_v2",
+            },
             level='O1',
         ):
             outputs = model(images)
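For reference, operators listed in custom_black_list run in float32 even inside an O1 autocast region, so adding matmul_v2 keeps the test's matmuls in fp32 on devices where fp16 matmul is slow, which is what the "fp16 matmul takes too long" bullet in the commit message refers to. A generic sketch of an O1 training step with loss scaling follows; it is illustrative only, not the elided body of train_func_ampo1, and the model, optimizer, and data are stand-ins:

import paddle

paddle.set_device("cpu")  # any registered device; a custom device is selected the same way
model = paddle.nn.Linear(784, 10)
optimizer = paddle.optimizer.SGD(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

images = paddle.rand([8, 784])
labels = paddle.randint(0, 10, [8, 1])

with paddle.amp.auto_cast(
    custom_black_list={"flatten_contiguous_range", "greater_than", "matmul_v2"},
    level='O1',
):
    logits = model(images)
    loss = paddle.nn.functional.cross_entropy(logits, labels)

scaled = scaler.scale(loss)  # scale the loss to avoid fp16 gradient underflow
scaled.backward()
scaler.step(optimizer)       # unscale gradients, then run the optimizer step
scaler.update()              # adjust the loss-scaling factor for the next step
optimizer.clear_grad()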