Commit 2ff29f01 authored by chenzomi

fix hswishquant and hsigmoidquant validation false bug

Parent 34469401
@@ -920,7 +920,7 @@ class HSwishQuant(_QuantActivation):
                  symmetric=symmetric,
                  narrow_range=narrow_range,
                  quant_delay=quant_delay)
-        if isinstance(activation, nn.HSwish):
+        if issubclass(activation, nn.HSwish):
             self.act = activation()
         else:
             raise ValueError("Activation should be `nn.HSwish`")
@@ -989,7 +989,7 @@ class HSigmoidQuant(_QuantActivation):
                  symmetric=symmetric,
                  narrow_range=narrow_range,
                  quant_delay=quant_delay)
-        if isinstance(activation, nn.HSwish):
+        if issubclass(activation, nn.HSwish):
             self.act = activation()
         else:
             raise ValueError("Activation should be `nn.HSigmoid`")
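For context: callers hand these layers the activation as a class (it is instantiated later via activation()), so isinstance(activation, nn.HSwish) was false for exactly the intended input; issubclass is the check that matches a class argument. Below is a minimal sketch of the pattern with an illustrative helper name (_make_activation) that is not in the source; the isinstance(activation, type) guard is also an addition, there to avoid the TypeError that issubclass raises for non-class arguments. Note the HSigmoidQuant hunk above still compares against nn.HSwish even though its error message names nn.HSigmoid; the sketch takes the expected class as a parameter instead.

    from mindspore import nn

    def _make_activation(activation, expected_cls):
        # `activation` arrives as a class such as nn.HSwish, not an instance:
        # the quant wrapper builds the cell itself via `activation()`.
        if isinstance(activation, type) and issubclass(activation, expected_cls):
            return activation()
        raise ValueError("Activation should be `{}`".format(expected_cls.__name__))

    act = _make_activation(nn.HSwish, nn.HSwish)    # ok: a class is passed
    # _make_activation(nn.HSwish(), nn.HSwish)      # rejected: an instance, not a class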
...
@@ -18,6 +18,7 @@ import time
 import argparse
 import random
 import numpy as np
 from mindspore import context
 from mindspore import Tensor
 from mindspore import nn
@@ -32,8 +33,9 @@ from mindspore.train.model import Model, ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.communication.management import init, get_group_size
+from mindspore.communication.management import init, get_group_size, get_rank
 import mindspore.dataset.engine as de
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
 from src.config import config_gpu, config_ascend
@@ -60,9 +62,14 @@ if args_opt.platform == "Ascend":
                         device_id=device_id, save_graphs=False)
 elif args_opt.platform == "GPU":
     context.set_context(mode=context.GRAPH_MODE,
-                        device_target="GPU", save_graphs=False)
+                        device_target="GPU",
+                        save_graphs=False)
+    init("nccl")
+    context.set_auto_parallel_context(device_num=get_group_size(),
+                                      parallel_mode=ParallelMode.DATA_PARALLEL,
+                                      mirror_mean=True)
 else:
-    raise ValueError("Unsupport platform.")
+    raise ValueError("Unsupported device target.")
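The block added above moves the distributed setup next to the context configuration: NCCL is initialized first, and get_group_size() then reports how many processes were launched. A minimal sketch of that ordering, assuming the script is started through the mpirun-style launcher this MindSpore version expects for GPU training (mirror_mean=True averages gradients across devices in this API generation):

    from mindspore import context
    from mindspore.train.model import ParallelMode
    from mindspore.communication.management import init, get_group_size

    context.set_context(mode=context.GRAPH_MODE,
                        device_target="GPU",
                        save_graphs=False)
    init("nccl")                                    # join the NCCL process group
    context.set_auto_parallel_context(device_num=get_group_size(),  # GPUs in the job
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      mirror_mean=True)             # average gradients across devices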
class CrossEntropyWithLabelSmooth(_Loss):
@@ -155,12 +162,8 @@ class Monitor(Callback):
 if __name__ == '__main__':
     if args_opt.platform == "GPU":
         # train on gpu
-        print("train args: ", args_opt, "\ncfg: ", config_gpu)
-        init('nccl')
-        context.set_auto_parallel_context(parallel_mode="data_parallel",
-                                          mirror_mean=True,
-                                          device_num=get_group_size())
+        print("train args: ", args_opt)
+        print("cfg: ", config_gpu)
         # define net
         net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU")
@@ -201,13 +204,13 @@ if __name__ == '__main__':
                       loss_scale_manager=loss_scale)
         cb = [Monitor(lr_init=lr.asnumpy())]
+        ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
         if config_gpu.save_checkpoint:
             config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
                                          keep_checkpoint_max=config_gpu.keep_checkpoint_max)
-            ckpt_cb = ModelCheckpoint(
-                prefix="mobilenetV2", directory=config_gpu.save_checkpoint_path, config=config_ck)
+            ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
             cb += [ckpt_cb]
-        # begine train
+        # begin train
         model.train(epoch_size, dataset, callbacks=cb)
     elif args_opt.platform == "Ascend":
         # train on ascend
...
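The new ckpt_save_dir line gives each data-parallel worker its own checkpoint directory keyed by get_rank(), so ModelCheckpoint callbacks running in different processes do not overwrite each other's files. A short sketch of the same idea, reusing config_gpu, step_size and the cb callback list defined earlier in the script, and assuming save_checkpoint_path ends with a path separator:

    from mindspore.communication.management import get_rank
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

    # one sub-directory per rank, e.g. ".../ckpt_0/", ".../ckpt_1/", ...
    ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
    if config_gpu.save_checkpoint:
        config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
                                     keep_checkpoint_max=config_gpu.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
        cb += [ckpt_cb]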
@@ -18,6 +18,7 @@ import time
 import argparse
 import random
 import numpy as np
 from mindspore import context
 from mindspore import Tensor
 from mindspore import nn
@@ -33,7 +34,8 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 import mindspore.dataset.engine as de
-from mindspore.communication.management import init, get_group_size
+from mindspore.communication.management import init, get_group_size, get_rank
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
 from src.config import config_gpu, config_ascend
@@ -57,10 +59,16 @@ if args_opt.platform == "Ascend":
     device_id = int(os.getenv('DEVICE_ID'))
     context.set_context(mode=context.GRAPH_MODE,
                         device_target="Ascend",
-                        device_id=device_id, save_graphs=False)
+                        device_id=device_id,
+                        save_graphs=False)
 elif args_opt.platform == "GPU":
     context.set_context(mode=context.GRAPH_MODE,
-                        device_target="GPU", save_graphs=False)
+                        device_target="GPU",
+                        save_graphs=False)
+    init("nccl")
+    context.set_auto_parallel_context(device_num=get_group_size(),
+                                      parallel_mode=ParallelMode.DATA_PARALLEL,
+                                      mirror_mean=True)
 else:
     raise ValueError("Unsupport platform.")
@@ -155,12 +163,8 @@ class Monitor(Callback):
 if __name__ == '__main__':
     if args_opt.platform == "GPU":
         # train on gpu
-        print("train args: ", args_opt, "\ncfg: ", config_gpu)
-        init('nccl')
-        context.set_auto_parallel_context(parallel_mode="data_parallel",
-                                          mirror_mean=True,
-                                          device_num=get_group_size())
+        print("train args: ", args_opt)
+        print("cfg: ", config_gpu)
         # define net
         net = mobilenet_v3_large(num_classes=config_gpu.num_classes)
@@ -201,11 +205,11 @@ if __name__ == '__main__':
                       loss_scale_manager=loss_scale)
         cb = [Monitor(lr_init=lr.asnumpy())]
+        ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
         if config_gpu.save_checkpoint:
             config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
                                          keep_checkpoint_max=config_gpu.keep_checkpoint_max)
-            ckpt_cb = ModelCheckpoint(
-                prefix="mobilenetV3", directory=config_gpu.save_checkpoint_path, config=config_ck)
+            ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck)
             cb += [ckpt_cb]
         # begine train
         model.train(epoch_size, dataset, callbacks=cb)
...