Commit c538b837 authored by zjun

remove enable hccl

Parent 92bb8a7c
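This commit drops the enable_hccl context flag throughout the tree: HCCL is brought up when a script calls init() for distributed training, so the example scripts collapse their eval/distribute branching into a single guard. Below is a minimal sketch of the resulting setup pattern; the argparse scaffolding is hypothetical, the flag names (do_eval, run_distribute, device_num) mirror the example scripts in this diff, and the auto_parallel_context import path is assumed from this era of the codebase.

# Sketch only: mirrors the post-change setup in the ResNet example scripts.
import argparse
from mindspore import context
from mindspore.communication.management import init
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import ParallelMode

parser = argparse.ArgumentParser(description='distributed setup sketch')
parser.add_argument('--do_eval', action='store_true')
parser.add_argument('--run_distribute', action='store_true')
parser.add_argument('--device_num', type=int, default=1)
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
if not args_opt.do_eval and args_opt.run_distribute:
    # No enable_hccl toggle any more; init() brings up HCCL for data parallelism.
    context.set_auto_parallel_context(device_num=args_opt.device_num,
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      mirror_mean=True, parameter_broadcast=True)
    auto_parallel_context().set_all_reduce_fusion_split_indices([140])
    init()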
@@ -51,17 +51,11 @@ context.set_context(enable_loop_sink=True)
 context.set_context(enable_mem_reuse=True)

 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              mirror_mean=True, parameter_broadcast=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          mirror_mean=True, parameter_broadcast=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()

     epoch_size = config.epoch_size
     net = resnet101(class_num=config.class_num)
......
@@ -56,17 +56,11 @@ context.set_context(enable_loop_sink=True)
 context.set_context(enable_mem_reuse=True)

 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              mirror_mean=True, parameter_broadcast=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          mirror_mean=True, parameter_broadcast=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()

     epoch_size = config.epoch_size
     net = resnet101(class_num=config.class_num)
......
@@ -51,17 +51,11 @@ context.set_context(enable_loop_sink=True)
 context.set_context(enable_mem_reuse=True)

 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              mirror_mean=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          mirror_mean=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()

     epoch_size = config.epoch_size
     net = resnet50(class_num=config.class_num)
......
@@ -54,17 +54,11 @@ context.set_context(enable_loop_sink=True)
 context.set_context(enable_mem_reuse=True)

 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              mirror_mean=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          mirror_mean=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()

     epoch_size = config.epoch_size
     net = resnet50(class_num=config.class_num)
......
@@ -37,7 +37,7 @@ if __name__ == '__main__':
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
     context.set_context(device_id=args_opt.device_id)
-    context.set_context(enable_mem_reuse=True, enable_hccl=False)
+    context.set_context(enable_mem_reuse=True)
     net = vgg16(num_classes=cfg.num_classes)
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
......
@@ -66,7 +66,7 @@ if __name__ == '__main__':
     context.set_context(device_id=args_opt.device_id)
     context.set_context(enable_task_sink=True)
     context.set_context(enable_loop_sink=True)
-    context.set_context(enable_mem_reuse=True, enable_hccl=False)
+    context.set_context(enable_mem_reuse=True)
     device_num = int(os.environ.get("DEVICE_NUM", 1))
     if device_num > 1:
......
@@ -90,13 +90,11 @@ if __name__ == '__main__':
     if args_opt.distribute:
         device_num = args_opt.device_num
         context.reset_auto_parallel_context()
-        context.set_context(enable_hccl=True)
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                           device_num=device_num)
         init()
         rank = args_opt.device_id % device_num
     else:
-        context.set_context(enable_hccl=False)
         rank = 0
         device_num = 1
......
@@ -115,8 +115,6 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.")
     .def("open_tsd", &mindspore::MsContext::OpenTsd, "Open tdt dataset client.")
     .def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.")
-    .def("set_hccl_flag", &mindspore::MsContext::set_enable_hccl, "Set enable hccl.")
-    .def("get_hccl_flag", &mindspore::MsContext::enable_hccl, "Get whether to enable hccl.")
     .def("set_task_sink_flag", &mindspore::MsContext::set_enable_task_sink, "Set enable task sink.")
     .def("get_task_sink_flag", &mindspore::MsContext::enable_task_sink, "Get whether to enable task sink.")
     .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
......
@@ -773,7 +773,7 @@ void InitHccl() {
   (void)ms_context->OpenTsd();
   uint32_t device_id = ms_context->device_id();
   std::string device_name = ms_context->device_target();
-  ms_context->set_enable_hccl(true);
   if (ms_context->backend_policy() == "ms" && ms_context->device_target() == kAscendDevice) {
     auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id);
     MS_EXCEPTION_IF_NULL(runtime_instance);
......
@@ -225,14 +225,6 @@ class _Context:
         if not success:
             raise RuntimeError("Device id set failed!!!")

     @property
-    def enable_hccl(self):
-        return self._context_handle.get_hccl_flag()
-
-    @enable_hccl.setter
-    def enable_hccl(self, hccl):
-        self._context_handle.set_hccl_flag(hccl)
-
-    @property
     def enable_ir_fusion(self):
         return self._context_handle.get_ir_fusion_flag()
@@ -482,7 +474,7 @@ def reset_auto_parallel_context():
 @args_type_check(mode=int, precompile_only=bool, device_target=str,
-                 device_id=int, enable_ir_fusion=bool, save_graphs=bool, enable_hccl=bool,
+                 device_id=int, enable_ir_fusion=bool, save_graphs=bool,
                  enable_task_sink=bool, save_graphs_path=str, enable_loop_sink=bool,
                  enable_mem_reuse=bool, save_ms_model=bool, save_ms_model_path=str, enable_gpu_summary=bool,
                  enable_auto_mixed_precision=bool, enable_dump=bool, save_dump_path=str,
@@ -515,7 +507,6 @@ def set_context(**kwargs):
             while device_num_per_host should no more than 4096. Default: 0.
         enable_ir_fusion (bool): Whether to enable ir fusion. Default: True.
         save_graphs (bool): Whether to save graphs. Default: False.
-        enable_hccl (bool): Whether to enable hccl. Default: False.
         enable_loop_sink (bool): Whether to enable loop sink. Default: True.
         enable_task_sink (bool): Whether to enable task sink. Default: True.
         enable_mem_reuse (bool): Whether to enable memory reuse. Default: True.
......
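With enable_hccl removed from both the args_type_check whitelist and the docstring above, a script that still passes the flag is no longer accepted; callers simply omit it. A hedged illustration of a set_context call using only flags that remain documented in this hunk (the exact combination is hypothetical):

from mindspore import context

# Illustration only: flag names and defaults are the ones listed in the
# docstring above. Passing enable_hccl here would now be rejected.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    save_graphs=False, enable_task_sink=True,
                    enable_loop_sink=True, enable_mem_reuse=True)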
@@ -130,7 +130,7 @@ class DistributedGradReducer(Cell):
         >>>
         >>> device_id = int(os.environ["DEVICE_ID"])
         >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
-        >>>                     device_id=int(device_id), enable_hccl=True)
+        >>>                     device_id=int(device_id))
         >>> init()
         >>> context.reset_auto_parallel_context()
         >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
......
@@ -33,7 +33,6 @@ def setup_module():
     global rank_id
     np.random.seed(0)
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    context.set_context(enable_hccl=True)
     context.set_context(enable_task_sink=True,
                         device_id=device_id)
     context.set_context(enable_ir_fusion=True)
......
@@ -46,7 +46,6 @@ def setup_module():
     global rank_id
     np.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    context.set_context(enable_hccl=True)
     context.set_context(enable_task_sink=True,
                         device_id=device_id)
     context.set_context(enable_ir_fusion=True)
......
@@ -31,7 +31,6 @@ from mindspore.train.callback import Callback
 from mindspore.parallel import set_algo_parameters

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-context.set_context(enable_hccl=True)
 context.set_context(enable_task_sink=True, device_id=int(os.getenv('DEVICE_ID')))
 context.set_context(enable_ir_fusion=True)
 context.set_context(enable_loop_sink=False)
......
@@ -122,16 +122,10 @@ class CrossEntropyLoss(nn.Cell):

 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
-            context.set_auto_parallel_context(all_reduce_fusion_split_indices=[140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
+        context.set_auto_parallel_context(all_reduce_fusion_split_indices=[140])
+        init()

     context.set_context(mode=context.GRAPH_MODE)
     epoch_size = args_opt.epoch_size
......
@@ -123,16 +123,10 @@ class CrossEntropyLoss(nn.Cell):

 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
-            context.set_auto_parallel_context(all_reduce_fusion_split_indices=[140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
+        context.set_auto_parallel_context(all_reduce_fusion_split_indices=[140])
+        init()

     context.set_context(mode=context.GRAPH_MODE)
     epoch_size = args_opt.epoch_size
......
@@ -122,16 +122,10 @@ class CrossEntropyLoss(nn.Cell):

 if __name__ == '__main__':
-    if args_opt.do_eval:
-        context.set_context(enable_hccl=False)
-    else:
-        if args_opt.run_distribute:
-            context.set_context(enable_hccl=True)
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([140])
-            init()
-        else:
-            context.set_context(enable_hccl=False)
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        init()

     context.set_context(mode=context.GRAPH_MODE)
     epoch_size = args_opt.epoch_size
......
@@ -153,7 +153,6 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size,
     context.set_context(enable_task_sink=True, device_id=device_id)
     context.set_context(enable_loop_sink=True)
     context.set_context(enable_mem_reuse=True)
-    context.set_context(enable_hccl=enable_hccl)
     os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
     os.environ['RANK_ID'] = str(device_id)
     os.environ['RANK_SIZE'] = str(device_num)
......
@@ -19,6 +19,7 @@ from mindspore import Tensor
 from mindspore.ops import operations as P
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.common.initializer import TruncatedNormal
+from mindspore.communication.management import init
 from mindspore.train.model import Model, ParallelMode
 from mindspore import context
 import os
@@ -31,10 +32,10 @@ from mindspore.parallel import set_algo_parameters
 from mindspore.parallel import _cost_model_context as cost_model_context

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-context.set_context(enable_hccl=True)
 context.set_context(enable_task_sink=True, device_id= 0)
 context.set_context(enable_ir_fusion=True)
 context.set_context(enable_loop_sink=False)
+init()

 def weight_variable(shape, factor=0.1):
     return TruncatedNormal(0.02)
......