提交 36a62576 编写于 作者: Y yangzhenzhang

support forward graph

上级 00191223
......@@ -345,7 +345,6 @@ bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
......@@ -903,9 +902,15 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
}
}
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, bool is_loss_node) {
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs) {
MS_EXCEPTION_IF_NULL(distribute_operator);
MS_EXCEPTION_IF_NULL(node);
bool is_loss_cnode =
std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
[node](const std::pair<CNodePtr, CNodePtr> &element) { return element.second == node; });
MirrorOps mirror_ops = distribute_operator->mirror_ops();
VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
// insert mirror op
......@@ -914,7 +919,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
InsertMirrorOps(mirror_ops, node);
}
// insert virtual div op
if (!virtual_div_op.empty() && is_loss_node) {
if (!virtual_div_op.empty() && is_loss_cnode) {
MS_LOG(INFO) << "insert virtual div op for " << distribute_operator->name();
InsertVirtualDivOp(virtual_div_op, node);
}
......@@ -986,10 +991,6 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
Dimensions dim;
if (elements[index]->isa<ValueSequeue>()) {
ValueTuplePtr value_tuple = elements[index]->cast<ValueTuplePtr>();
if (value_tuple == nullptr) {
MS_LOG(EXCEPTION) << "Failure:value_tuple is nullptr";
}
std::vector<ValuePtr> value_vector = value_tuple->value();
(void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
[](const ValuePtr &value) { return static_cast<int32_t>(GetValue<int>(value)); });
......@@ -1013,7 +1014,6 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
BaseShapePtr base_shape_ptr = node->Shape();
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (IsValueNode<Primitive>(cnode->input(0))) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
......@@ -1190,7 +1190,7 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
continue;
}
CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast<CNodePtr>();
if ((graph_cnode_inp0 == nullptr) || !IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) {
if (!IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) {
continue;
}
FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_cnode_inp0->input(1));
......@@ -1692,14 +1692,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
return pre_cnode;
}
TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
TensorLayouts ret;
if (!IsValueNode<FuncGraph>(cnode->input(1))) {
MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
}
auto func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
auto loss_cnode = FindLossCNode(func_graph);
MS_EXCEPTION_IF_NULL(loss_cnode);
AnfNodePtr node = loss_cnode->cast<AnfNodePtr>();
MS_EXCEPTION_IF_NULL(node);
......@@ -1735,16 +1729,16 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) {
return ret;
}
void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) {
void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) {
MS_EXCEPTION_IF_NULL(grad_sens_node);
auto cnode = grad_sens_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
AnfNodePtr sens_tensor_node = cnode->input(1);
if (grad_sens_node->size() <= 1) {
MS_LOG(EXCEPTION) << "The size of grad sens node is smaller than 2";
}
AnfNodePtr sens_tensor_node = grad_sens_node->input(1);
MS_EXCEPTION_IF_NULL(sens_tensor_node);
Shapes sens_shapes = GetNodeShape(sens_tensor_node);
if (sens_shapes.size() != 1) {
MS_LOG(EXCEPTION) << "SplitSens: GetNodeShape for sens_tensor_node, output size is not 1";
MS_LOG(EXCEPTION) << "GetNodeShape for sens_tensor_node, output size is not 1";
}
// If the shape of sens tensor is [] or [1], no need to split it.
Shape sens_shape = sens_shapes[0];
......@@ -1780,14 +1774,14 @@ void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_l
sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
return;
}
MS_LOG(EXCEPTION) << "SplitSens: the type of sens node is not Tensor or Parameter, it is unsupported now.";
MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";
}
// Use _GetTensorSlice operator to split the sens tensor
FuncGraphPtr func_graph = cnode->func_graph(); // only cnode can get the graph
FuncGraphPtr func_graph = grad_sens_node->func_graph(); // only cnode can get the graph
MS_EXCEPTION_IF_NULL(func_graph);
Operator op = CreateGetTensorSliceOp(loss_grad_layout);
InsertGetTensorSliceOp(op, cnode, func_graph, 1, SPLIT_SENS);
InsertGetTensorSliceOp(op, grad_sens_node, func_graph, 1, SPLIT_SENS);
}
void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
......@@ -1853,7 +1847,6 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
......@@ -1870,55 +1863,12 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
return graph_set;
}
// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
void StepSplitSens(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return;
}
// cnode(sens)-->cnode(tuple_getitem)
auto cnode = node->cast<CNodePtr>();
AnfNodePtr expect_tuple_getitem = cnode->input(0);
MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
if (!expect_tuple_getitem->isa<CNode>()) {
return;
}
auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode);
if (!IsValueNode<Primitive>(expect_tuple_getitem_cnode->input(0))) {
return;
}
auto expect_tuple_getitem_prim = GetValueNode<PrimitivePtr>(expect_tuple_getitem_cnode->input(0));
if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) {
return;
}
// cnode(sens)-->cnode(tuple_getitem)-->cnode
AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
MS_EXCEPTION_IF_NULL(expect_anonymous);
if (!expect_anonymous->isa<CNode>()) {
return;
}
// cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(expect_anonymous_cnode);
AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
MS_EXCEPTION_IF_NULL(expect_j);
if (!expect_j->isa<CNode>()) {
return;
}
auto expect_j_cnode = expect_j->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(expect_j_cnode);
if (!IsValueNode<Primitive>(expect_j_cnode->input(0))) {
return;
}
auto expect_j_prim = GetValueNode<PrimitivePtr>(expect_j_cnode->input(0));
if (expect_j_prim->name() == J) {
auto loss_grad_layout = GetLossNodeGradOutputLayout(expect_j_cnode);
if (!loss_grad_layout.empty()) {
SplitSens(node, loss_grad_layout[0]);
}
void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
CNodePtr sens_node = sens_loss_pair.first;
CNodePtr loss_node = sens_loss_pair.second;
auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
if (!loss_grad_layout.empty()) {
SplitSens(sens_node, loss_grad_layout[0]);
}
}
......@@ -1937,26 +1887,77 @@ std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root) {
return loss_node;
}
// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root);
std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs;
for (auto &node : root->nodes()) {
if (!node->isa<CNode>()) {
continue;
}
// cnode(sens)-->cnode(tuple_getitem)
auto sens_cnode = node->cast<CNodePtr>();
AnfNodePtr expect_tuple_getitem = sens_cnode->input(0);
MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
if (!expect_tuple_getitem->isa<CNode>()) {
continue;
}
auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) {
continue;
}
// cnode(sens)-->cnode(tuple_getitem)-->cnode
AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
MS_EXCEPTION_IF_NULL(expect_anonymous);
if (!expect_anonymous->isa<CNode>()) {
continue;
}
// cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
MS_EXCEPTION_IF_NULL(expect_j);
if (!expect_j->isa<CNode>()) {
continue;
}
auto expect_j_cnode = expect_j->cast<CNodePtr>();
if (!IsSomePrimitive(expect_j_cnode, J)) {
continue;
}
if (!IsValueNode<FuncGraph>(expect_j_cnode->input(1))) {
MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
}
auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
auto loss_cnode = FindLossCNode(func_graph);
std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode);
sens_loss_pairs.push_back(sens_loss_pair);
}
return sens_loss_pairs;
}
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(root);
MS_EXCEPTION_IF_NULL(manager);
TensorRedistribution tensor_redistribution;
AnfNodePtr grad_sens_node = nullptr;
std::vector<CNodePtr> loss_cnode = FindLossCNodeFromRoot(root);
std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs = GetSensLossPairs(root);
bool has_backward = !sens_loss_pairs.empty();
// split sens must before inserting the operators.
for (auto &node : all_nodes) {
for (auto &pair : sens_loss_pairs) {
// If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
// If the type of sens node is not Tensor, it is unsupported now, do nothing default.
StepSplitSens(node);
StepSplitSens(pair);
}
for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
......@@ -1965,11 +1966,6 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
continue;
}
bool is_loss_cnode = false;
auto iter = std::find(loss_cnode.begin(), loss_cnode.end(), cnode);
if (iter != loss_cnode.end()) {
is_loss_cnode = true;
}
// insert forward ops
InsertForwardOps(distribute_operator, cnode);
......@@ -1977,7 +1973,9 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
// insert backward ops
BackwardCommunication(distribute_operator, cnode, is_loss_cnode);
if (has_backward) {
BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
}
// StepReplace
StepReplace(distribute_operator, cnode);
......@@ -2099,7 +2097,6 @@ void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
......@@ -2117,7 +2114,6 @@ void SetForwardFlag(const AnfNodeSet &all_nodes) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
......@@ -2146,7 +2142,6 @@ std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const An
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto root_node_id = node->UniqueIdThroughCopy();
if (loss_cnode_id == root_node_id) {
root_forward_nodes = DeepLinkedGraphSearch(cnode);
......
......@@ -82,7 +82,8 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node);
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, bool is_loss_node);
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs);
// Generate and init parallel operator
OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.common.api import _executor
class Net(Cell):
def __init__(self, mul_weight, strategy1=None, strategy2=None):
super().__init__()
self.mul = P.Mul().set_strategy(strategy1)
self.neg = P.Neg().set_strategy(strategy2)
self.mul_weight = Parameter(mul_weight, "w1")
def construct(self, x, b):
out = self.mul(x, self.mul_weight)
out = self.neg(out)
return out, b
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
def compile(net):
_executor.compile(net, _x, _b)
context.reset_auto_parallel_context()
def test_forward_graph_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((16, 1, 1), (16, 1, 1))
strategy2 = ((16, 1, 1), )
net = Net(_w1, strategy1, strategy2)
compile(net)
def test_forward_graph_model_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((1, 1, 16), (1, 1, 16))
strategy2 = ((1, 1, 16), )
net = Net(_w1, strategy1, strategy2)
compile(net)
def test_forward_graph_hybrid_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4), (2, 2, 4))
strategy2 = ((2, 2, 4), )
net = Net(_w1, strategy1, strategy2)
compile(net)
def test_forward_graph_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
net = Net(_w1)
compile(net)
def test_forward_graph_repeat_calc():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4), (2, 2, 4))
strategy2 = ((1, 2, 2), )
net = Net(_w1, strategy1, strategy2)
compile(net)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册