提交 40a3a714 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5148 change group conv dtype in gpu resnext50

Merge pull request !5148 from zhaoting/master
......@@ -44,9 +44,6 @@ def auto_mixed_precision(network):
elif name == 'fc':
network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32))
change = True
elif name == 'conv2':
subcell.to_float(mstype.float32)
change = True
elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16))
change = True
......
......@@ -36,7 +36,6 @@ from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
from src.utils.logging import get_logger
from src.utils.optimizers__init__ import get_param_groups
from src.image_classification import get_network
from src.utils.auto_mixed_precision import auto_mixed_precision
from src.config import config
......@@ -273,8 +272,8 @@ def train(cloud_args=None):
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
metrics={'acc'}, amp_level="O3")
else:
auto_mixed_precision(network)
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'})
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
metrics={'acc'}, amp_level="O2")
# checkpoint save
progress_cb = ProgressMonitor(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册