提交 f37cb543 编写于 作者: G gaotingquan 提交者: cuicheng01

rm op black list in amp

the op flatten_contiguous_range and greater_than has supported amp mode since paddle 2.4
上级 a7ba6eab
...@@ -504,11 +504,7 @@ class Engine(object): ...@@ -504,11 +504,7 @@ class Engine(object):
batch_tensor = paddle.to_tensor(batch_data) batch_tensor = paddle.to_tensor(batch_data)
if self.amp and self.amp_eval: if self.amp and self.amp_eval:
with paddle.amp.auto_cast( with paddle.amp.auto_cast(level=self.amp_level):
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=self.amp_level):
out = self.model(batch_tensor) out = self.model(batch_tensor)
else: else:
out = self.model(batch_tensor) out = self.model(batch_tensor)
......
...@@ -56,11 +56,7 @@ def classification_eval(engine, epoch_id=0): ...@@ -56,11 +56,7 @@ def classification_eval(engine, epoch_id=0):
# image input # image input
if engine.amp and engine.amp_eval: if engine.amp and engine.amp_eval:
with paddle.amp.auto_cast( with paddle.amp.auto_cast(level=engine.amp_level):
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=engine.amp_level):
out = engine.model(batch[0]) out = engine.model(batch[0])
else: else:
out = engine.model(batch[0]) out = engine.model(batch[0])
...@@ -114,11 +110,7 @@ def classification_eval(engine, epoch_id=0): ...@@ -114,11 +110,7 @@ def classification_eval(engine, epoch_id=0):
# calc loss # calc loss
if engine.eval_loss_func is not None: if engine.eval_loss_func is not None:
if engine.amp and engine.amp_eval: if engine.amp and engine.amp_eval:
with paddle.amp.auto_cast( with paddle.amp.auto_cast(level=engine.amp_level):
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=engine.amp_level):
loss_dict = engine.eval_loss_func(preds, labels) loss_dict = engine.eval_loss_func(preds, labels)
else: else:
loss_dict = engine.eval_loss_func(preds, labels) loss_dict = engine.eval_loss_func(preds, labels)
......
...@@ -137,11 +137,7 @@ def compute_feature(engine, name="gallery"): ...@@ -137,11 +137,7 @@ def compute_feature(engine, name="gallery"):
has_camera = True has_camera = True
batch[2] = batch[2].reshape([-1, 1]).astype("int64") batch[2] = batch[2].reshape([-1, 1]).astype("int64")
if engine.amp and engine.amp_eval: if engine.amp and engine.amp_eval:
with paddle.amp.auto_cast( with paddle.amp.auto_cast(level=engine.amp_level):
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=engine.amp_level):
out = engine.model(batch[0]) out = engine.model(batch[0])
else: else:
out = engine.model(batch[0]) out = engine.model(batch[0])
......
...@@ -50,11 +50,7 @@ def train_epoch(engine, epoch_id, print_batch_step): ...@@ -50,11 +50,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
# image input # image input
if engine.amp: if engine.amp:
amp_level = engine.config["AMP"].get("level", "O1").upper() amp_level = engine.config["AMP"].get("level", "O1").upper()
with paddle.amp.auto_cast( with paddle.amp.auto_cast(level=amp_level):
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=amp_level):
out = forward(engine, batch) out = forward(engine, batch)
loss_dict = engine.train_loss_func(out, batch[1]) loss_dict = engine.train_loss_func(out, batch[1])
else: else:
......
...@@ -64,11 +64,7 @@ def train_epoch_fixmatch(engine, epoch_id, print_batch_step): ...@@ -64,11 +64,7 @@ def train_epoch_fixmatch(engine, epoch_id, print_batch_step):
# image input # image input
if engine.amp: if engine.amp:
amp_level = engine.config['AMP'].get("level", "O1").upper() amp_level = engine.config['AMP'].get("level", "O1").upper()
with paddle.amp.auto_cast( with paddle.amp.auto_cast(level=amp_level):
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=amp_level):
loss_dict, logits_label = get_loss( loss_dict, logits_label = get_loss(
engine, inputs, batch_size_label, temperture, threshold, engine, inputs, batch_size_label, temperture, threshold,
targets_x) targets_x)
......
...@@ -191,11 +191,7 @@ def forward(engine, batch, loss_func): ...@@ -191,11 +191,7 @@ def forward(engine, batch, loss_func):
batch_info = {"label": batch[1], "domain": batch[2]} batch_info = {"label": batch[1], "domain": batch[2]}
if engine.amp: if engine.amp:
amp_level = engine.config["AMP"].get("level", "O1").upper() amp_level = engine.config["AMP"].get("level", "O1").upper()
with paddle.amp.auto_cast( with paddle.amp.auto_cast(level=amp_level):
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=amp_level):
out = engine.model(batch[0], batch[1]) out = engine.model(batch[0], batch[1])
loss_dict = loss_func(out, batch_info) loss_dict = loss_func(out, batch_info)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册