Commit c538b837 authored by zjun

remove enable hccl

Parent 92bb8a7c
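This commit removes the manual `enable_hccl` flag from `context.set_context` across model scripts, tests, the Python context module, and the C++ bindings; HCCL setup is now driven entirely by `mindspore.communication.management.init()`. A minimal sketch of the resulting training-script pattern, distilled from the hunks below (`args_opt` stands in for each script's parsed command-line arguments, and the `auto_parallel_context` import path is assumed from the 2020-era MindSpore source layout):

from mindspore import context
from mindspore.communication.management import init
from mindspore.train.model import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context

# Distributed training: configure data parallelism, then call init();
# no enable_hccl flag is needed (or accepted) after this commit.
if not args_opt.do_eval and args_opt.run_distribute:
    context.set_auto_parallel_context(device_num=args_opt.device_num,
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      mirror_mean=True)
    auto_parallel_context().set_all_reduce_fusion_split_indices([140])
    init()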
@@ -51,17 +51,11 @@ context.set_context(enable_loop_sink=True)
 context.set_context(enable_mem_reuse=True)
 
 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              mirror_mean=True, parameter_broadcast=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          mirror_mean=True, parameter_broadcast=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()
 
     epoch_size = config.epoch_size
     net = resnet101(class_num=config.class_num)
...
@@ -56,17 +56,11 @@ context.set_context(enable_loop_sink=True)
 context.set_context(enable_mem_reuse=True)
 
 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              mirror_mean=True, parameter_broadcast=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          mirror_mean=True, parameter_broadcast=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()
 
     epoch_size = config.epoch_size
     net = resnet101(class_num=config.class_num)
...
@@ -51,17 +51,11 @@ context.set_context(enable_loop_sink=True)
 context.set_context(enable_mem_reuse=True)
 
 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              mirror_mean=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          mirror_mean=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()
 
     epoch_size = config.epoch_size
     net = resnet50(class_num=config.class_num)
...
@@ -54,17 +54,11 @@ context.set_context(enable_loop_sink=True)
 context.set_context(enable_mem_reuse=True)
 
 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              mirror_mean=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          mirror_mean=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()
 
     epoch_size = config.epoch_size
     net = resnet50(class_num=config.class_num)
...
@@ -37,7 +37,7 @@ if __name__ == '__main__':
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
     context.set_context(device_id=args_opt.device_id)
-    context.set_context(enable_mem_reuse=True, enable_hccl=False)
+    context.set_context(enable_mem_reuse=True)
 
     net = vgg16(num_classes=cfg.num_classes)
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
...
@@ -66,7 +66,7 @@ if __name__ == '__main__':
     context.set_context(device_id=args_opt.device_id)
     context.set_context(enable_task_sink=True)
     context.set_context(enable_loop_sink=True)
-    context.set_context(enable_mem_reuse=True, enable_hccl=False)
+    context.set_context(enable_mem_reuse=True)
 
     device_num = int(os.environ.get("DEVICE_NUM", 1))
     if device_num > 1:
...
@@ -90,13 +90,11 @@ if __name__ == '__main__':
     if args_opt.distribute:
         device_num = args_opt.device_num
         context.reset_auto_parallel_context()
-        context.set_context(enable_hccl=True)
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                           device_num=device_num)
         init()
         rank = args_opt.device_id % device_num
     else:
-        context.set_context(enable_hccl=False)
        rank = 0
         device_num = 1
...
@@ -115,8 +115,6 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.")
     .def("open_tsd", &mindspore::MsContext::OpenTsd, "Open tdt dataset client.")
     .def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.")
-    .def("set_hccl_flag", &mindspore::MsContext::set_enable_hccl, "Set enable hccl.")
-    .def("get_hccl_flag", &mindspore::MsContext::enable_hccl, "Get whether to enable hccl.")
     .def("set_task_sink_flag", &mindspore::MsContext::set_enable_task_sink, "Set enable task sink.")
     .def("get_task_sink_flag", &mindspore::MsContext::enable_task_sink, "Get whether to enable task sink.")
     .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
...
@@ -773,7 +773,6 @@ void InitHccl() {
   (void)ms_context->OpenTsd();
   uint32_t device_id = ms_context->device_id();
   std::string device_name = ms_context->device_target();
-  ms_context->set_enable_hccl(true);
   if (ms_context->backend_policy() == "ms" && ms_context->device_target() == kAscendDevice) {
     auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id);
     MS_EXCEPTION_IF_NULL(runtime_instance);
...
@@ -225,14 +225,6 @@ class _Context:
         if not success:
             raise RuntimeError("Device id set failed!!!")
 
-    @property
-    def enable_hccl(self):
-        return self._context_handle.get_hccl_flag()
-
-    @enable_hccl.setter
-    def enable_hccl(self, hccl):
-        self._context_handle.set_hccl_flag(hccl)
-
     @property
     def enable_ir_fusion(self):
         return self._context_handle.get_ir_fusion_flag()
...
@@ -482,7 +474,7 @@ def reset_auto_parallel_context():
 @args_type_check(mode=int, precompile_only=bool, device_target=str,
-                 device_id=int, enable_ir_fusion=bool, save_graphs=bool, enable_hccl=bool,
+                 device_id=int, enable_ir_fusion=bool, save_graphs=bool,
                  enable_task_sink=bool, save_graphs_path=str, enable_loop_sink=bool,
                  enable_mem_reuse=bool, save_ms_model=bool, save_ms_model_path=str, enable_gpu_summary=bool,
                  enable_auto_mixed_precision=bool, enable_dump=bool, save_dump_path=str,
@@ -515,7 +507,6 @@ def set_context(**kwargs):
         while device_num_per_host should no more than 4096. Default: 0.
         enable_ir_fusion (bool): Whether to enable ir fusion. Default: True.
         save_graphs (bool): Whether to save graphs. Default: False.
-        enable_hccl (bool): Whether to enable hccl. Default: False.
         enable_loop_sink (bool): Whether to enable loop sink. Default: True.
         enable_task_sink (bool): Whether to enable task sink. Default: True.
         enable_mem_reuse (bool): Whether to enable memory reuse. Default: True.
...
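Taken together with the `_Context` hunk above, these two hunks retire the keyword at the Python API boundary: it disappears from the `args_type_check` whitelist and from the documented kwargs. A hedged migration sketch for downstream callers (the surrounding keywords are taken from the docstring above; the exact error raised for the stale keyword depends on `set_context`'s keyword validation, which this diff does not show):

from mindspore import context

# Before this commit (now rejected, since enable_hccl was removed
# from args_type_check and from the _Context properties):
# context.set_context(save_graphs=False, enable_hccl=True, enable_loop_sink=True)

# After: drop enable_hccl; distributed setup goes through init() instead.
context.set_context(save_graphs=False, enable_loop_sink=True)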
@@ -130,7 +130,7 @@ class DistributedGradReducer(Cell):
         >>>
         >>> device_id = int(os.environ["DEVICE_ID"])
         >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
-        >>>                     device_id=int(device_id), enable_hccl=True)
+        >>>                     device_id=int(device_id))
         >>> init()
         >>> context.reset_auto_parallel_context()
         >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
...
@@ -33,7 +33,6 @@ def setup_module():
     global rank_id
     np.random.seed(0)
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    context.set_context(enable_hccl=True)
     context.set_context(enable_task_sink=True,
                         device_id=device_id)
     context.set_context(enable_ir_fusion=True)
...
@@ -46,7 +46,6 @@ def setup_module():
     global rank_id
     np.random.seed(0)
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    context.set_context(enable_hccl=True)
     context.set_context(enable_task_sink=True,
                         device_id=device_id)
     context.set_context(enable_ir_fusion=True)
...
@@ -31,7 +31,6 @@ from mindspore.train.callback import Callback
 from mindspore.parallel import set_algo_parameters
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-context.set_context(enable_hccl=True)
 context.set_context(enable_task_sink=True, device_id=int(os.getenv('DEVICE_ID')))
 context.set_context(enable_ir_fusion=True)
 context.set_context(enable_loop_sink=False)
...
@@ -122,16 +122,10 @@ class CrossEntropyLoss(nn.Cell):
 
 
 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
-            context.set_auto_parallel_context(all_reduce_fusion_split_indices=[140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
+        context.set_auto_parallel_context(all_reduce_fusion_split_indices=[140])
+        init()
 
     context.set_context(mode=context.GRAPH_MODE)
     epoch_size = args_opt.epoch_size
...
@@ -123,16 +123,10 @@ class CrossEntropyLoss(nn.Cell):
 
 
 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
-            context.set_auto_parallel_context(all_reduce_fusion_split_indices=[140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
+        context.set_auto_parallel_context(all_reduce_fusion_split_indices=[140])
+        init()
 
     context.set_context(mode=context.GRAPH_MODE)
     epoch_size = args_opt.epoch_size
...
@@ -122,16 +122,10 @@ class CrossEntropyLoss(nn.Cell):
 
 
 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()
 
     context.set_context(mode=context.GRAPH_MODE)
     epoch_size = args_opt.epoch_size
...
@@ -153,7 +153,6 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size,
     context.set_context(enable_task_sink=True, device_id=device_id)
     context.set_context(enable_loop_sink=True)
     context.set_context(enable_mem_reuse=True)
-    context.set_context(enable_hccl=enable_hccl)
     os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
     os.environ['RANK_ID'] = str(device_id)
     os.environ['RANK_SIZE'] = str(device_num)
...
@@ -19,6 +19,7 @@ from mindspore import Tensor
 from mindspore.ops import operations as P
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.common.initializer import TruncatedNormal
+from mindspore.communication.management import init
 from mindspore.train.model import Model, ParallelMode
 from mindspore import context
 import os
...
@@ -31,10 +32,10 @@ from mindspore.parallel import set_algo_parameters
 from mindspore.parallel import _cost_model_context as cost_model_context
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-context.set_context(enable_hccl=True)
 context.set_context(enable_task_sink=True, device_id= 0)
 context.set_context(enable_ir_fusion=True)
 context.set_context(enable_loop_sink=False)
+init()
 
 def weight_variable(shape, factor=0.1):
     return TruncatedNormal(0.02)
...