提交 8f7aa5bd 编写于 作者: Y yao_yf

auto parallel context modify

上级 042ac51f
......@@ -42,15 +42,12 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
return inst_context_;
}
ParallelContext::ParallelContext() {
communication_backend_ = HCCL_BACKEND;
Reset();
}
ParallelContext::ParallelContext() { Reset(); }
void ParallelContext::Reset() {
mirror_mean_ = false;
full_batch_ = false;
cast_before_mirror_ = true;
gradient_fp32_sync_ = true;
loss_repeated_mean_ = true;
device_num_ = 1;
global_rank_ = 0;
......@@ -81,14 +78,10 @@ void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_
void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; }
void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }
void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
void ParallelContext::set_communication_backend(const std::string &communication_backend) {
communication_backend_ = communication_backend;
}
bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
if (iter == PARALLEL_MODE_LIST.end()) {
......
......@@ -58,8 +58,8 @@ class ParallelContext {
void set_full_batch(bool full_batch);
bool full_batch() const { return full_batch_; }
void set_cast_before_mirror(bool cast_before_mirror);
bool cast_before_mirror() const { return cast_before_mirror_; }
void set_gradient_fp32_sync(bool gradient_fp32_sync);
bool gradient_fp32_sync() const { return gradient_fp32_sync_; }
void set_loss_repeated_mean(bool loss_repeated_mean);
bool loss_repeated_mean() const { return loss_repeated_mean_; }
......@@ -70,9 +70,6 @@ class ParallelContext {
void set_global_rank(int32_t global_rank);
int32_t global_rank() const { return global_rank_; }
void set_communication_backend(const std::string &communication_backend);
std::string communication_backend() const { return communication_backend_; }
bool set_parallel_mode(const std::string &parallel_mode);
std::string parallel_mode() const { return parallel_mode_; }
......@@ -112,11 +109,10 @@ class ParallelContext {
static std::shared_ptr<ParallelContext> inst_context_;
bool mirror_mean_;
bool full_batch_;
bool cast_before_mirror_;
bool gradient_fp32_sync_;
bool loss_repeated_mean_;
int32_t device_num_;
int32_t global_rank_;
std::string communication_backend_;
std::string parallel_mode_;
std::string strategy_search_mode_;
bool parameter_broadcast_;
......
......@@ -43,6 +43,7 @@
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "utils/comm_manager.h"
#include "utils/symbolic.h"
#include "utils/ms_context.h"
using mindspore::tensor::Tensor;
......@@ -869,8 +870,8 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
}
bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
// only if cast_before_mirror is true, pre node is cast and type is not float32 return true
if (!ParallelContext::GetInstance()->cast_before_mirror()) {
// only if gradient_fp32_sync is true, pre node is cast and type is not float32 return true
if (!ParallelContext::GetInstance()->gradient_fp32_sync()) {
return false;
}
auto pre_node = node->input(index);
......@@ -2421,13 +2422,17 @@ Status ParallelInit() {
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
int32_t device_num = ParallelContext::GetInstance()->device_num();
int32_t global_rank = ParallelContext::GetInstance()->global_rank();
std::string backend = ParallelContext::GetInstance()->communication_backend();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
std::string world_group;
if (backend == HCCL_BACKEND) {
std::string communication_backend;
if (backend == kAscendDevice || backend == kDavinciDevice) {
world_group = HCCL_WORLD_GROUP;
} else if (backend == NCCL_BACKEND) {
communication_backend = HCCL_BACKEND;
} else if (backend == kGPUDevice) {
world_group = NCCL_WORLD_GROUP;
communication_backend = NCCL_BACKEND;
} else {
MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend;
}
......@@ -2450,14 +2455,14 @@ Status ParallelInit() {
MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank;
}
if (!InitDevice(device_num, global_rank, backend)) {
if (!InitDevice(device_num, global_rank, communication_backend)) {
MS_LOG(ERROR) << "Init device failed";
return FAILED;
}
MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank
<< ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean()
<< ", cast_before_mirror: " << ParallelContext::GetInstance()->cast_before_mirror();
<< ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
return SUCCESS;
}
......
......@@ -209,12 +209,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
.def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.")
.def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.")
.def("get_cast_before_mirror", &ParallelContext::cast_before_mirror, "Get cast before mirror.")
.def("set_cast_before_mirror", &ParallelContext::set_cast_before_mirror, "Set cast before mirror.")
.def("get_gradient_fp32_sync", &ParallelContext::gradient_fp32_sync, "Get cast before mirror.")
.def("set_gradient_fp32_sync", &ParallelContext::set_gradient_fp32_sync, "Set cast before mirror.")
.def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.")
.def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
.def("get_communication_backend", &ParallelContext::communication_backend, "Get communication backend.")
.def("set_communication_backend", &ParallelContext::set_communication_backend, "Set communication backend.")
.def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
.def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
.def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
......
......@@ -15,7 +15,6 @@
"""Communication management API"""
import os
from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
......@@ -86,9 +85,6 @@ def init(backend_name=None):
else:
raise RuntimeError("Backend name {} is not supported.".format(backend_name))
auto_parallel_context().set_communication_backend(backend_name)
def release():
"""
Release distributed resource. e.g., hccl/nccl.
......
......@@ -434,7 +434,7 @@ def _context():
return _k_context
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
def set_auto_parallel_context(**kwargs):
......@@ -454,9 +454,9 @@ def set_auto_parallel_context(**kwargs):
global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror.
"stand_alone" do not support mirror_mean. Default: False.
cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True.
gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True..
"stand_alone", "data_parallel" and "hybrid_parallel" do not support
cast_before_mirror. Default: True.
gradient_fp32_sync. Default: True.
parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
"hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
......@@ -492,7 +492,7 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(device_num=8)
>>> context.set_auto_parallel_context(global_rank=0)
>>> context.set_auto_parallel_context(mirror_mean=True)
>>> context.set_auto_parallel_context(cast_before_mirror=False)
>>> context.set_auto_parallel_context(gradient_fp32_sync=False)
>>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
>>> context.set_auto_parallel_context(parameter_broadcast=False)
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
......@@ -524,7 +524,7 @@ def reset_auto_parallel_context():
- device_num: 1.
- global_rank: 0.
- mirror_mean: False.
- cast_before_mirror: True.
- gradient_fp32_sync: True.
- parallel_mode: "stand_alone".
- parameter_broadcast: False.
- strategy_ckpt_load_file: "".
......
......@@ -113,24 +113,24 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_mirror_mean()
def set_cast_before_mirror(self, cast_before_mirror):
def set_gradient_fp32_sync(self, gradient_fp32_sync):
"""
Set cast_before_mirror.
Set gradient_fp32_sync.
Note:
If cast_before_mirror is true,
If gradient_fp32_sync is true,
it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
Args:
cast_before_mirror (bool): The cast_before_mirror flag.
gradient_fp32_sync (bool): The gradient_fp32_sync flag.
"""
self.check_context_handle()
self._context_handle.set_cast_before_mirror(cast_before_mirror)
self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)
def get_cast_before_mirror(self):
"""Get cast_before_mirror flag."""
def get_gradient_fp32_sync(self):
"""Get gradient_fp32_sync flag."""
self.check_context_handle()
return self._context_handle.get_cast_before_mirror()
return self._context_handle.get_gradient_fp32_sync()
def set_loss_repeated_mean(self, loss_repeated_mean):
"""
......@@ -152,21 +152,6 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_loss_repeated_mean()
def set_communication_backend(self, communication_backend):
"""
Set communication backend.
Args:
communication_backend (str): The communication backend.
"""
self.check_context_handle()
self._context_handle.set_communication_backend(communication_backend)
def get_communication_backend(self):
"""Get communication backend."""
self.check_context_handle()
return self._context_handle.get_communication_backend()
def set_parallel_mode(self, parallel_mode):
"""
Set parallel mode for auto parallel.
......@@ -469,7 +454,7 @@ _set_auto_parallel_context_func_map = {
"device_num": auto_parallel_context().set_device_num,
"global_rank": auto_parallel_context().set_global_rank,
"mirror_mean": auto_parallel_context().set_mirror_mean,
"cast_before_mirror": auto_parallel_context().set_cast_before_mirror,
"gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
"loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
"parallel_mode": auto_parallel_context().set_parallel_mode,
"auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
......@@ -484,7 +469,7 @@ _get_auto_parallel_context_func_map = {
"device_num": auto_parallel_context().get_device_num,
"global_rank": auto_parallel_context().get_global_rank,
"mirror_mean": auto_parallel_context().get_mirror_mean,
"cast_before_mirror": auto_parallel_context().get_cast_before_mirror,
"gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
"loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
"parallel_mode": auto_parallel_context().get_parallel_mode,
"auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
......@@ -495,7 +480,7 @@ _get_auto_parallel_context_func_map = {
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool,
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
......@@ -512,8 +497,9 @@ def _set_auto_parallel_context(**kwargs):
global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
calculations. Default: True.
cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True.
calculations. Default: True.
gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
Default: True.
parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
"hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
......@@ -577,7 +563,7 @@ def _reset_auto_parallel_context():
- device_num: 1.
- global_rank: 0.
- mirror_mean: False.
- cast_before_mirror: True.
- gradient_fp32_sync: True.
- parallel_mode: "stand_alone".
- parameter_broadcast: False.
- strategy_ckpt_load_file: ""
......
......@@ -61,7 +61,7 @@ def get_rank_id(group=None):
def get_rank_size(group=None):
hccl = Hccl()
if group is None:
if group is None or "nccl_world_group" in group:
return hccl.rank_size
if isinstance(group, str):
return int(group.split("-")[0])
......
......@@ -830,7 +830,7 @@ def test_matmul_cast():
compile_net(net, x, y, b)
def test_cast_before_mirror():
def test_gradient_fp32_sync():
class Net(nn.Cell):
def __init__(self, strategy1):
super().__init__()
......@@ -843,7 +843,7 @@ def test_cast_before_mirror():
out = self.matmul(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True)
context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True)
strategy1 = ((2, 2), (2, 2))
net = GradWrap(NetWithLoss(Net(strategy1)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
......@@ -854,7 +854,7 @@ def test_cast_before_mirror():
compile_net(net, x, y, b)
def test_cast_before_mirror1():
def test_gradient_fp32_sync1():
class Net(nn.Cell):
def __init__(self, strategy1):
super().__init__()
......@@ -867,7 +867,7 @@ def test_cast_before_mirror1():
out = self.matmul(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=True)
context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=True)
strategy1 = ((2, 2), (2, 2))
net = GradWrap(NetWithLoss(Net(strategy1)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
......@@ -878,7 +878,7 @@ def test_cast_before_mirror1():
compile_net(net, x, y, b)
def test_cast_before_mirror2():
def test_gradient_fp32_sync2():
class Net(nn.Cell):
def __init__(self, strategy1):
super().__init__()
......@@ -891,7 +891,7 @@ def test_cast_before_mirror2():
out = self.matmul(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0, cast_before_mirror=False)
context.set_auto_parallel_context(device_num=8, global_rank=0, gradient_fp32_sync=False)
strategy1 = ((2, 2), (2, 2))
net = GradWrap(NetWithLoss(Net(strategy1)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
......@@ -902,7 +902,7 @@ def test_cast_before_mirror2():
compile_net(net, x, y, b)
def test_cast_before_mirror3():
def test_gradient_fp32_sync3():
class Net(nn.Cell):
def __init__(self, strategy1):
super().__init__()
......
......@@ -20,25 +20,21 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
def test_set_auto_parallel_context():
context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, cast_before_mirror=False,
context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, gradient_fp32_sync=False,
parallel_mode="auto_parallel", parameter_broadcast=False)
device_num = context.get_auto_parallel_context("device_num")
global_rank = context.get_auto_parallel_context("global_rank")
mirror_mean = context.get_auto_parallel_context("mirror_mean")
cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror")
gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
parallel_mode = context.get_auto_parallel_context("parallel_mode")
parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
assert device_num == 4
assert global_rank == 3
assert mirror_mean
assert not cast_before_mirror
assert not gradient_fp32_sync
assert parallel_mode == "auto_parallel"
assert not parameter_broadcast
auto_parallel_context().set_communication_backend("hccl")
backend = auto_parallel_context().get_communication_backend()
assert backend == "hccl"
auto_parallel_context().set_device_num(4)
device_num = auto_parallel_context().get_device_num()
device_num_is_set = auto_parallel_context().get_device_num_is_set()
......@@ -53,9 +49,9 @@ def test_set_auto_parallel_context():
mirror_mean = auto_parallel_context().get_mirror_mean()
assert mirror_mean
auto_parallel_context().set_cast_before_mirror(False)
cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
assert not cast_before_mirror
auto_parallel_context().set_gradient_fp32_sync(False)
gradient_fp32_sync = auto_parallel_context().get_gradient_fp32_sync()
assert not gradient_fp32_sync
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
assert parameter_broadcast_is_set
......@@ -91,7 +87,7 @@ def test_reset_auto_parallel_context():
device_num = context.get_auto_parallel_context("device_num")
global_rank = context.get_auto_parallel_context("global_rank")
mirror_mean = context.get_auto_parallel_context("mirror_mean")
cast_before_mirror = context.get_auto_parallel_context("cast_before_mirror")
gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
parallel_mode = context.get_auto_parallel_context("parallel_mode")
parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
device_num_is_set = auto_parallel_context().get_device_num_is_set()
......@@ -99,7 +95,7 @@ def test_reset_auto_parallel_context():
assert device_num == 1
assert global_rank == 0
assert not mirror_mean
assert cast_before_mirror
assert gradient_fp32_sync
assert parallel_mode == "stand_alone"
assert not parameter_broadcast
assert not device_num_is_set
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册