提交 56362287 编写于 作者: L lichenever

update

上级 e32d539b
......@@ -28,9 +28,14 @@ namespace parallel {
std::string GetOpPythonPath(const OperatorName &op_name) {
// almost all ops are defined in two main paths
const std::string ops_module = OP_PATH;
const std::string inner_ops_module = INNER_OP_PATH;
py::module mod = py::module::import(common::SafeCStr(ops_module));
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
MS_LOG(EXCEPTION) << ops_module << " don't have op:" << op_name;
if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
}
return inner_ops_module;
}
return ops_module;
}
......
......@@ -56,6 +56,12 @@ Status GatherV2PInfo::GetAttrs() {
}
}
// target=CPU, axis must be 0
if (target_ == "CPU" && axis_ != 0) {
MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_;
return FAILED;
}
return SUCCESS;
}
......@@ -279,6 +285,11 @@ Status GatherV2PInfo::InferBias() {
int32_t rank = g_device_manager->global_rank();
auto input_shape = inputs_shape_.at(0);
auto params_strategy = strategy_->GetInputDim().at(0);
// axis don't split
if (params_strategy.at(axis_) == 1) {
bias_ = 0;
return SUCCESS;
}
// params_size=1, axis=0
if ((input_shape.size() == 1) && (axis_ == 0)) {
slice_size_ = input_shape.at(0) / params_strategy.at(0);
......@@ -353,26 +364,35 @@ Status GatherV2PInfo::InferForwardCommunication() {
}
auto group_size = group_.GetDevNum();
Attr attr_group;
// group size <= 8
std::vector<int32_t> rank_list;
if (group_size <= 8) {
reduce_scatter_flag_ = false;
operator_name = HOST_REDUCE_SCATTER;
rank_list = GetRankFromGroup(group_);
attr_group = std::make_pair(GROUP, MakeValue(rank_list));
if (host_reduce_scatter_) {
// group size <= 8
std::vector<int32_t> rank_list;
if (group_size <= 8) {
reduce_scatter_flag_ = false;
operator_name = HOST_REDUCE_SCATTER;
rank_list = GetRankFromGroup(group_);
attr_group = std::make_pair(GROUP, MakeValue(rank_list));
} else {
// group size > 8, don't support host reduce_scatter
reduce_scatter_flag_ = true;
split_num_ = SizeToInt(group_size / 8);
CheckGlobalDeviceManager();
operator_name = REDUCE_SCATTER;
int32_t rank = g_device_manager->global_rank();
size_t repeat = group_size / 8;
for (size_t i = 0; i < repeat; ++i) {
rank_list.push_back(rank + SizeToInt(i * 8));
}
Group g = g_device_manager->CreateGroup(rank_list);
attr_group = std::make_pair(GROUP, MakeValue(g.name()));
}
} else {
// group size > 8
reduce_scatter_flag_ = true;
split_num_ = SizeToInt(group_size / 8);
CheckGlobalDeviceManager();
operator_name = REDUCE_SCATTER;
int32_t rank = g_device_manager->global_rank();
size_t repeat = group_size / 8;
for (size_t i = 0; i < repeat; ++i) {
rank_list.push_back(rank + SizeToInt(i * 8));
if (InferGroup() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer Group failed.";
return FAILED;
}
Group g = g_device_manager->CreateGroup(rank_list);
attr_group = std::make_pair(GROUP, MakeValue(g.name()));
attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
}
Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
OperatorAttrs attrs = {attr_op, attr_group};
......@@ -446,8 +466,8 @@ Status GatherV2PInfo::ComputeReplaceOp() {
Attr param_offset = std::make_pair("offset", MakeValue(bias_));
Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_));
Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_));
OperatorParams params = {std::make_pair(param_offset, 4), std::make_pair(param_flag, 5),
std::make_pair(param_split_num, 6)};
OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4),
std::make_pair(param_split_num, 5)};
OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(op_name, args);
replace_op_.push_back(op);
......
......@@ -70,6 +70,7 @@ class GatherV2PInfo : public OperatorInfo {
Group group_;
bool reduce_scatter_flag_ = false;
int32_t split_num_ = 1;
bool host_reduce_scatter_ = false;
};
class SparseGatherV2Info : public GatherV2PInfo {
......
......@@ -55,6 +55,7 @@ constexpr char REDUCE_OP_SUM[] = "sum";
constexpr char REDUCE_OP_MAX[] = "max";
constexpr char REDUCE_OP_MIN[] = "min";
constexpr char OP_PATH[] = "mindspore.ops.operations";
constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops";
constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils";
constexpr char GET_OP_FUNCTION[] = "_get_python_op";
constexpr char KEEP_DIMS[] = "keep_dims";
......
......@@ -536,7 +536,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
auto prim = GetValueNode<PrimitivePtr>(node->input(0));
if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) {
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2), node->input(3)};
replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
}
if (!params.empty()) {
Param param_first = *(params.begin());
......
......@@ -184,7 +184,7 @@ def test_gatherv2_auto1():
_executor.compile(net, x, y)
def need_fix_test_gatherv2_cpu0():
def test_gatherv2_cpu0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
......@@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0():
_executor.compile(net, x, y)
def need_fix_test_gatherv2_cpu1():
def test_gatherv2_cpu1():
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((16, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
......@@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1():
_executor.compile(net, x, y)
def need_fix_test_gatherv2_cpu2():
def test_gatherv2_cpu2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
......
......@@ -184,7 +184,7 @@ def test_gatherv2_auto1():
_executor.compile(net, x, y)
def need_fix_test_gatherv2_cpu0():
def test_gatherv2_cpu0():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
......@@ -196,7 +196,7 @@ def need_fix_test_gatherv2_cpu0():
_executor.compile(net, x, y)
def need_fix_test_gatherv2_cpu1():
def test_gatherv2_cpu1():
context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((16, 1), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
......@@ -208,7 +208,7 @@ def need_fix_test_gatherv2_cpu1():
_executor.compile(net, x, y)
def need_fix_test_gatherv2_cpu2():
def test_gatherv2_cpu2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((1, 8), (1, 1))
strategy2 = ((4, 2, 1), (4, 2, 1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册