提交 ef71ae94 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!698 [Auto parallel] Support multi-subgraphs in auto-parallel

Merge pull request !698 from Xiaoda/support-wide-deep-in-auto-parallel
......@@ -44,6 +44,7 @@ namespace parallel {
#define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16
#define DEFAULT_FULLY_USE_DEVICES true
#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
#define DEFAULT_IS_MULTI_SUBGRAPHS false
class CostGraph;
using CostGraphPtr = std::shared_ptr<CostGraph>;
......
......@@ -46,6 +46,7 @@ void CostModelContext::ResetCostModel() {
costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
......@@ -84,6 +85,7 @@ void CostModelContext::set_costmodel_communi_const(double cm_communi_const) {
void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; }
void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; }
void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int32_t algorithm) {
costmodel_allreduce_fusion_algorithm_ = algorithm;
}
......
......@@ -67,6 +67,9 @@ class CostModelContext {
void set_costmodel_communi_bias(double);
double costmodel_communi_bias() const { return costmodel_communi_bias_; }
void set_multi_subgraphs(bool);
bool is_multi_subgraphs() const { return is_multi_subgraphs_; }
void set_costmodel_allreduce_fusion_algorithm(int32_t);
int32_t costmodel_allreduce_fusion_algorithm() const { return costmodel_allreduce_fusion_algorithm_; }
......@@ -138,6 +141,8 @@ class CostModelContext {
// COST_MODEL_COMMUNI_BIAS
double costmodel_communi_bias_;
bool is_multi_subgraphs_;
int32_t costmodel_allreduce_fusion_algorithm_;
int32_t costmodel_allreduce_fusion_times_;
......
......@@ -426,13 +426,13 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
return operator_info;
}
Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
// Using CNode's UniqueIds to construct nodes
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
entire_costgraph = std::make_shared<CostGraph>();
entire_costgraph->SetDeviceMemoryAndCostParameter();
bool new_operator = true, first_operator = true;
std::string first_operator_cnode;
size_t current_op_index = 0;
// The map from CNode's UniqueId to its operatorInfo
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
// Step 1
for (auto &node : all_nodes) {
......@@ -449,12 +449,8 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
// When visiting the second subgraph, use the corresponding operatorInfo which already created
bool modify_new_operator = (new_operator) && (!first_operator) && (cnode->UniqueId() == first_operator_cnode);
if (modify_new_operator) {
new_operator = false;
}
if (new_operator) {
auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
if (search_cnode == from_cnode_to_info.end()) {
auto operator_info = CreateTheOperatorInfo(prim, cnode);
if (operator_info == nullptr) {
return FAILED;
......@@ -465,14 +461,67 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
entire_costgraph->AddOperator(operator_info);
(void)cnode->set_operator_info(operator_info);
if (first_operator) {
first_operator_cnode = cnode->UniqueId();
first_operator = false;
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
// Needed by rec_parser
entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
} else {
// Two CNODEs' UniqueIds should not be equal
MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
}
}
MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
return SUCCESS;
}
// Using CNode's UniqueIdThroughCopys to construct nodes
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
entire_costgraph = std::make_shared<CostGraph>();
entire_costgraph->SetDeviceMemoryAndCostParameter();
// The map from CNode's UniqueIdThroughCopy to its operatorInfo
std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
for (auto &node : all_nodes) {
// NOTE: we only care about splittable Primitive operators
auto cnode = node->cast<CNodePtr>();
bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
if (bool_result) {
continue;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
if (!IsAutoParallelCareNode(cnode)) {
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
// Find the operatorInfo if it exists
auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
if (search_cnode == from_cnode_to_info.end()) {
// In this case, the corresponding OperatorInfo is not created, create the new one.
auto operator_info = CreateTheOperatorInfo(prim, cnode);
if (operator_info == nullptr) {
return FAILED;
}
// Needed by rec_parser
operator_info->set_type(prim->name());
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
entire_costgraph->AddOperator(operator_info);
(void)cnode->set_operator_info(operator_info);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
(void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
// Needed by rec_parser
entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
} else {
auto current_op_ptr = entire_costgraph->FindOperatorByIndex(current_op_index);
auto current_op_ptr = search_cnode->second;
if (current_op_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
} else {
......@@ -484,14 +533,12 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
<< " does not match the Prim: " << prim->name();
}
(void)cnode->set_operator_info(current_op_ptr);
current_op_index++;
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
}
}
}
if ((!new_operator) && (current_op_index != entire_costgraph->GetOperators().size())) {
MS_LOG(EXCEPTION) << "The second subgraph's operator number: " << current_op_index
<< " does not match the first ones: " << entire_costgraph->GetOperators().size();
}
MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
return SUCCESS;
......@@ -844,11 +891,20 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
// OUTPUT: the determined strategy for each operator.
// Step 1
if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) {
MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
<< " operators.";
if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
<< entire_costgraph->GetOperators().size() << " operators.";
} else {
MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
}
} else {
MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
<< entire_costgraph->GetOperators().size() << " operators.";
} else {
MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
}
}
// Step 2
......@@ -916,7 +972,7 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st
}
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) {
if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
<< " operators.";
} else {
......
......@@ -43,7 +43,9 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node);
std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node);
Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes);
......
......@@ -24,6 +24,7 @@
#include <functional>
#include "ir/func_graph_cloner.h"
#include "parallel/costmodel_context.h"
#include "pipeline/pass.h"
#include "pipeline/parse/parse_base.h"
#include "pipeline/parse/data_converter.h"
......@@ -341,7 +342,10 @@ static std::vector<ActionItem> CommonPipeline() {
// Resolve the python func
actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
if (!multi_graphs) {
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
}
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
// Evaluate type and shape, and specialize
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
......
......@@ -222,6 +222,8 @@ PYBIND11_MODULE(_c_expression, m) {
"Set the parameter cost_model_communi_bias of the DP algorithm.")
.def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias,
"Get the parameter cost_model_communi_bias of the DP algorithm.")
.def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
.def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
.def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
"Set the parameter gradient AllReduce fusion algorithm.")
.def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,
......
......@@ -214,6 +214,31 @@ class _CostModelContext:
raise ValueError("Context handle is none in context!!!")
return self._context_handle.get_costmodel_communi_bias()
def set_multi_subgraphs(self, multi_subgraph):
"""
Set the flag of ANF graph containing multiple subgraphs.
Args:
multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
Raises:
ValueError: If context handle is none.
"""
if self._context_handle is None:
raise ValueError("Context handle is none in context!!!")
self._context_handle.set_multi_subgraphs(multi_subgraph)
def get_multi_subgraphs(self):
"""
Get the flag of ANF graph containing multiple subgraphs.
Raises:
ValueError: If context handle is none.
"""
if self._context_handle is None:
raise ValueError("Context handle is none in context!!!")
return self._context_handle.get_multi_subgraphs()
def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
"""
Set costmodel allreduce fusion algorithm.
......@@ -427,6 +452,7 @@ set_cost_model_context_func_map = {
"costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
"costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
"costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
"multi_subgraphs": cost_model_context().set_multi_subgraphs,
"costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
"costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
"costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
......@@ -447,6 +473,7 @@ get_cost_model_context_func_map = {
"costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
"costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
"costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
"multi_subgraphs": cost_model_context().get_multi_subgraphs(),
"costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
"costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
"costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
......@@ -461,6 +488,7 @@ get_cost_model_context_func_map = {
@args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
multi_subgraphs=bool,
costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
costmodel_allreduce_fusion_allreduce_inherent_time=float,
......@@ -481,6 +509,7 @@ def set_cost_model_context(**kwargs):
costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
0: bypass allreduce fusion;
1: only use backward computation time to group allreduce;
......
import numpy as np
from mindspore import context
import mindspore as ms
import mindspore.nn as nn
from mindspore.nn.optim import Adam, FTRL
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.common.api import _executor
from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters
from mindspore.parallel._utils import _reset_op_id as reset_op_id
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.mul = P.Mul()
self.relu = P.ReLU()
self.wd = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="wide")
self.wt = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="l")
def construct(self, x):
out = self.mul(x, self.wd)
out = self.mul(out, self.wt)
out = self.relu(out)
return out
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.sum = P.ReduceSum()
self.mean = P.ReduceMean()
self.net = network
def construct(self, x):
predict = self.net(x)
loss1 = self.sum(predict, -1)
loss2 = self.mean(predict, -1)
return loss1, loss2
class IthOutputCell(nn.Cell):
def __init__(self, network, output_index):
super(IthOutputCell, self).__init__()
self.network = network
self.output_index = output_index
def construct(self, x):
predict = self.network(x)[self.output_index]
return predict
class TrainStepWarp(nn.Cell):
def __init__(self, network, sens=1000.0):
super(TrainStepWarp, self).__init__()
self.network = network
self.network.set_train()
self.trainable_params = network.trainable_params()
weights_w = []
weights_d = []
for params in self.trainable_params:
weights_w.append(params)
weights_d.append(params)
self.weights_w = ParameterTuple(weights_w)
self.weights_d = ParameterTuple(weights_d)
self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w, l1=1e-8,
l2=1e-8, initial_accum=1.0)
self.optimizer_d = Adam(self.weights_d, learning_rate=3.5e-4, eps=1e-8,
loss_scale=sens)
self.hyper_map = C.HyperMap()
self.grad_w = C.GradOperation('grad_w', get_by_list=True, sens_param=True)
self.grad_d = C.GradOperation('grad_d', get_by_list=True, sens_param=True)
self.sens = sens
self.loss_net_w = IthOutputCell(network, output_index=0)
self.loss_net_d = IthOutputCell(network, output_index=1)
def construct(self, x):
weights_w = self.weights_w
weights_d = self.weights_d
loss_w, loss_d = self.network(x)
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)
return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d))
def test_double_subgraphs():
cost_model_context.set_cost_model_context(multi_subgraphs=True)
context.set_context(save_graphs=True)
context.set_auto_parallel_context(device_num=8, global_rank=0)
net = TrainStepWarp(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
reset_op_id()
_executor.compile(net, x, phase='train')
strategies = _executor._get_strategy(net)
expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op0': [[8, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/ReLU-op1': [[8, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
assert strategies == expected_strategies
import numpy as np
from mindspore import context
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.common.api import _executor
from tests.ut.python.ops.test_math_ops import VirtualLoss
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._utils import _reset_op_id as reset_op_id
import re
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x):
predict = self.network(x)
return self.loss(predict)
class Blockcell(nn.Cell):
def __init__(self):
super(Blockcell, self).__init__()
self.bn = nn.BatchNorm2d(64, momentum=0.9)
def construct(self, x):
out = self.bn(x)
return out
def getBlock():
return Blockcell()
def test_two_bn():
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.block1 = getBlock()
self.block2 = getBlock()
self.relu = P.ReLU()
self.add = P.TensorAdd()
self.bias = Tensor(np.ones([64, 64]), dtype=ms.float32)
def construct(self, x):
out = self.block1(x)
out = self.relu(out)
out = self.add(out, self.bias)
out = self.block2(out)
return out
net = NetWithLoss(Net())
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
context.set_context(save_graphs=True)
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="auto_parallel")
set_algo_parameters(elementwise_op_strategy_follow=True)
reset_op_id()
_executor.compile(net, x, phase='train')
strategies = _executor._get_strategy(net)
assert len(strategies) == 4
for (k, v) in strategies.items():
if re.search('BatchNorm-op', k) is not None:
assert v == [[8, 1], [1], [1], [1], [1]]
elif re.search('TensorAdd-op', k) is not None:
assert v == [[8, 1], [8, 1]]
elif re.search('ReLU-op', k) is not None:
assert v == [[8, 1]]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册