Commit c78630d7 authored by lichenever

support multiple subgraphs

Parent d5c002f7
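In brief, this commit changes `ForwardGraph` to return a `std::set<FuncGraphPtr>` instead of a single `FuncGraphPtr`, and `FindLossCNodeFromRoot` to return a `std::vector<CNodePtr>`, so that networks whose backward pass goes through several J subgraphs (for example two losses driven by separate optimizers, as in the test added at the bottom) are handled. The sketch below is a minimal, self-contained mock of the new caller pattern, not the real MindSpore code: `FuncGraph`, `ForwardGraph`, and `ProcessAllreduceFusion` here are stand-ins with simplified signatures.

#include <iostream>
#include <memory>
#include <set>
#include <string>

struct FuncGraph {
  std::string name;
};
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

// Stand-in for the real graph search: after this commit, ForwardGraph
// collects every forward graph found under a J node, not just the first.
std::set<FuncGraphPtr> ForwardGraph(const std::set<FuncGraphPtr> &root_graphs) {
  return root_graphs;
}

// Callers that cannot yet handle several subgraphs guard on the set size
// and fall back, mirroring the guard added to ProcessAllreduceFusion.
void ProcessAllreduceFusion(const std::set<FuncGraphPtr> &root_graphs) {
  std::set<FuncGraphPtr> graph_set = ForwardGraph(root_graphs);
  if (graph_set.size() > 1) {
    std::cout << "AllReduce fusion: multiple subgraphs, skipping fusion\n";
    return;
  }
  FuncGraphPtr forward_graph = *(graph_set.begin());
  std::cout << "AllReduce fusion over graph: " << forward_graph->name << "\n";
}

int main() {
  auto loss_w = std::make_shared<FuncGraph>(FuncGraph{"loss_w"});
  auto loss_d = std::make_shared<FuncGraph>(FuncGraph{"loss_d"});
  ProcessAllreduceFusion({loss_w});          // one subgraph: fusion proceeds
  ProcessAllreduceFusion({loss_w, loss_d});  // two subgraphs: fusion is skipped
  return 0;
}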
......
@@ -399,7 +399,12 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) {
   ret_ = ret;
   root_graph_ = ret_->func_graph();
   MS_EXCEPTION_IF_NULL(root_graph_);
-  auto forward_graph = ForwardGraph(root_graph_);
+  auto graph_set = ForwardGraph(root_graph_);
+  if (graph_set.size() > 1) {
+    MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now.";
+    return SUCCESS;
+  }
+  auto forward_graph = *(graph_set.begin());
   MS_EXCEPTION_IF_NULL(forward_graph);
   forward_ret_ = forward_graph->get_return();
   MS_EXCEPTION_IF_NULL(forward_ret_);
......
......
@@ -1607,72 +1607,79 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }

-// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-bool IsGradSensNode(const AnfNodePtr &node) {
-  if (!node->isa<CNode>()) {
-    return false;
+CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  CNodePtr return_node = func_graph->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  if (return_node->size() < 2) {
+    MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2";
   }
+  AnfNodePtr pre_node = return_node->input(1);
+  MS_EXCEPTION_IF_NULL(pre_node);

-  // cnode(sens)-->cnode(tuple_getitem)
-  auto cnode = node->cast<CNodePtr>();
-  AnfNodePtr expect_tuple_getitem = cnode->input(0);
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
-  if (!expect_tuple_getitem->isa<CNode>()) {
-    return false;
-  }
-  auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode);
-  if (!IsValueNode<Primitive>(expect_tuple_getitem_cnode->input(0))) {
-    return false;
+  auto pre_cnode = pre_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(pre_cnode);
+  auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
+  // return -> cast
+  if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
+    pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(pre_cnode);
+    current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
   }
-  ValueNodePtr expect_tuple_getitem_value_node = expect_tuple_getitem_cnode->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_value_node);
-  PrimitivePtr expect_tuple_getitem_prim = expect_tuple_getitem_value_node->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_prim);
-  if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) {
-    return false;
+
+  // notice: the GetNext op has not input
+  if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
+    MS_LOG(INFO) << "The loss is: " << current_prim->name();
+    return pre_cnode;
   }

-  // cnode(sens)-->cnode(tuple_getitem)-->cnode
-  AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(expect_anonymous);
-  if (!expect_anonymous->isa<CNode>()) {
-    return false;
+  // size of common cnode is larger than 1
+  if (pre_cnode->size() < 2) {
+    MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2";
   }

-  // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-  auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_anonymous_cnode);
-  AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
-  MS_EXCEPTION_IF_NULL(expect_j);
-  if (!expect_j->isa<CNode>()) {
-    return false;
+  // return -> tuple_getitem -> loss
+  if (current_prim->name() == TUPLE_GETITEM) {
+    AnfNodePtr pre_pre_node = pre_cnode->input(1);
+    MS_EXCEPTION_IF_NULL(pre_pre_node);
+    auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
+    auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value);
+    PrimitivePtr prim = value->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(prim);
+    MS_LOG(DEBUG) << "The loss name is " << prim->name();
+    return pre_pre_cnode;
   }
-  auto expect_j_cnode = expect_j->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_j_cnode);
-  if (!IsValueNode<Primitive>(expect_j_cnode->input(0))) {
-    return false;
+
+  // return -> make_tuple
+  if (current_prim->name() == MAKE_TUPLE) {
+    MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported";
   }
-  ValueNodePtr expect_j_value_node = expect_j_cnode->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_j_value_node);
-  PrimitivePtr expect_j_prim = expect_j_value_node->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(expect_j_prim);
-  return (expect_j_prim->name() == J);
+
+  // return -> loss
+  MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
+  return pre_cnode;
 }

-TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
+TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  TensorLayouts ret;
+  if (!IsValueNode<FuncGraph>(cnode->input(1))) {
+    MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
+  }
+  auto func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
+  auto loss_cnode = FindLossCNode(func_graph);
   MS_EXCEPTION_IF_NULL(loss_cnode);
   AnfNodePtr node = loss_cnode->cast<AnfNodePtr>();
   MS_EXCEPTION_IF_NULL(node);
   LossNodeInfo node_info = GetLossNodeInfo(node);
   ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
   MS_EXCEPTION_IF_NULL(prim_anf_node);
   PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(prim);
-  TensorLayouts ret;
   if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) {
     MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now";
     return ret;
......
@@ -1680,7 +1687,6 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
   OperatorInfoPtr operator_info = loss_cnode->operator_info();
   MS_EXCEPTION_IF_NULL(operator_info);
   TensorInfo loss_grad_tensor_info;
   size_t op_output_size = operator_info->outputs_tensor_info().size();
   MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", the has tuple item is "
......
@@ -1805,6 +1811,100 @@ void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
   HandleDropoutNode(distribute_operator, cnode);
 }

+std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
+  // J->CNode->Graph
+  std::set<FuncGraphPtr> graph_set;
+  for (auto &node : root_all_nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
+      continue;
+    }
+    auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+    if (expect_j_prim->name() != J) {
+      continue;
+    }
+    if (IsValueNode<FuncGraph>(cnode->input(1))) {
+      auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
+      MS_LOG(DEBUG) << "Find the forward graph success";
+      graph_set.insert(graph);
+    }
+  }
+  return graph_set;
+}
+
+// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+void StepSplitSens(const AnfNodePtr &node) {
+  if (!node->isa<CNode>()) {
+    return;
+  }
+
+  // cnode(sens)-->cnode(tuple_getitem)
+  auto cnode = node->cast<CNodePtr>();
+  AnfNodePtr expect_tuple_getitem = cnode->input(0);
+  MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
+  if (!expect_tuple_getitem->isa<CNode>()) {
+    return;
+  }
+  auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode);
+  if (!IsValueNode<Primitive>(expect_tuple_getitem_cnode->input(0))) {
+    return;
+  }
+  auto expect_tuple_getitem_prim = GetValueNode<PrimitivePtr>(expect_tuple_getitem_cnode->input(0));
+  if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) {
+    return;
+  }
+
+  // cnode(sens)-->cnode(tuple_getitem)-->cnode
+  AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
+  MS_EXCEPTION_IF_NULL(expect_anonymous);
+  if (!expect_anonymous->isa<CNode>()) {
+    return;
+  }
+
+  // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+  auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(expect_anonymous_cnode);
+  AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
+  MS_EXCEPTION_IF_NULL(expect_j);
+  if (!expect_j->isa<CNode>()) {
+    return;
+  }
+  auto expect_j_cnode = expect_j->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(expect_j_cnode);
+  if (!IsValueNode<Primitive>(expect_j_cnode->input(0))) {
+    return;
+  }
+  auto expect_j_prim = GetValueNode<PrimitivePtr>(expect_j_cnode->input(0));
+  if (expect_j_prim->name() == J) {
+    auto loss_grad_layout = GetLossNodeGradOutputLayout(expect_j_cnode);
+    if (!loss_grad_layout.empty()) {
+      SplitSens(node, loss_grad_layout[0]);
+    }
+  }
+}
+
+std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  AnfNodePtr root_return_node = root->get_return();
+  MS_EXCEPTION_IF_NULL(root_return_node);
+  std::vector<CNodePtr> loss_node;
+  const auto &all_nodes = root->nodes();
+  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
+  if (graph_set.empty()) {
+    loss_node.push_back(FindLossCNode(root));
+  }
+  (void)std::transform(graph_set.begin(), graph_set.end(), std::back_inserter(loss_node),
+                       [](const FuncGraphPtr &graph) { return FindLossCNode(graph); });
+  return loss_node;
+}
+
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager) {
   MS_EXCEPTION_IF_NULL(root);
......
@@ -1812,18 +1912,15 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
   TensorRedistribution tensor_redistribution;

-  AnfNodePtr grad_sens_node = nullptr;
-  CNodePtr loss_cnode = FindLossCNodeFromRoot(root);
-  MS_EXCEPTION_IF_NULL(loss_cnode);
-  // get output layout of loss must before inserting the operators below
-  TensorLayouts loss_layout = GetLossNodeGradOutputLayout(loss_cnode);
+  std::vector<CNodePtr> loss_cnode = FindLossCNodeFromRoot(root);
+
+  // split sens must before inserting the operators.
   for (auto &node : all_nodes) {
-    // find sens node
-    if ((grad_sens_node == nullptr) && IsGradSensNode(node)) {
-      grad_sens_node = node;
-      MS_LOG(INFO) << "Find the sens node success";
-    }
+    // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
+    // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
+    StepSplitSens(node);
   }

   for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
......
@@ -1837,7 +1934,8 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
       }

       bool is_loss_cnode = false;
-      if (cnode == loss_cnode) {
+      auto iter = std::find(loss_cnode.begin(), loss_cnode.end(), cnode);
+      if (iter != loss_cnode.end()) {
         is_loss_cnode = true;
       }
       // insert forward ops
......
@@ -1857,12 +1955,6 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
       StepSplitTensor(node, manager);
     }
   }
-
-  // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
-  // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
-  if (grad_sens_node && !loss_layout.empty()) {
-    SplitSens(grad_sens_node, loss_layout[0]);
-  }
 }

 namespace {
......
@@ -2003,134 +2095,57 @@ void SetForwardFlag(const AnfNodeSet &all_nodes) {
   }
 }

-CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  CNodePtr return_node = func_graph->get_return();
-  MS_EXCEPTION_IF_NULL(return_node);
-  if (return_node->inputs().size() < 2) {
-    MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2";
-  }
-  AnfNodePtr pre_node = return_node->input(1);
-  MS_EXCEPTION_IF_NULL(pre_node);
-
-  auto pre_cnode = pre_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(pre_cnode);
-  auto current_value = pre_cnode->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(current_value);
-  PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(current_prim);
-
-  // return -> cast
-  if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
-    pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(pre_cnode);
-    current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
-  }
-
-  // notice: the GetNext op has not input
-  if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
-    MS_LOG(INFO) << "The loss is: " << current_prim->name();
-    return pre_cnode;
-  }
-
-  // size of common cnode is larger than 1
-  if (pre_cnode->inputs().size() < 2) {
-    MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2";
-  }
-
-  // return -> tuple_getitem -> loss
-  if (current_prim->name() == TUPLE_GETITEM) {
-    AnfNodePtr pre_pre_node = pre_cnode->input(1);
-    MS_EXCEPTION_IF_NULL(pre_pre_node);
-    auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
-    auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(value);
-    PrimitivePtr prim = value->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(prim);
-    MS_LOG(INFO) << "The loss name is " << prim->name();
-    return pre_pre_cnode;
-  } else if (current_prim->name() == MAKE_TUPLE) {
-    MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported";
-  }
-
-  // return -> loss
-  MS_LOG(INFO) << "The loss name is " << current_prim->name();
-  return pre_cnode;
+std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  const auto &all_nodes = root->nodes();
+  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
+  return graph_set;
 }

-FuncGraphPtr FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
-  for (auto &node : root_all_nodes) {
+std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto loss_cnode = FindLossCNode(graph);
+  MS_EXCEPTION_IF_NULL(loss_cnode);
+  auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy();
+  std::vector<AnfNodePtr> root_forward_nodes;
+  for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     if (!node->isa<CNode>()) {
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if ((cnode->inputs().size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
-    ValueNodePtr expect_j_value_node = cnode->input(0)->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(expect_j_value_node);
-    PrimitivePtr expect_j_prim = expect_j_value_node->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(expect_j_prim);
-    if (expect_j_prim->name() != J) {
-      continue;
-    }
-    MS_LOG(DEBUG) << "Find J prim: " << expect_j_value_node->DebugString() << ".";
-    if (IsValueNode<FuncGraph>(cnode->input(1))) {
-      auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
-      MS_LOG(INFO) << "Find the forward graph success";
-      return graph;
+    auto root_node_id = node->UniqueIdThroughCopy();
+    if (loss_cnode_id == root_node_id) {
+      root_forward_nodes = DeepLinkedGraphSearch(cnode);
+      break;
     }
   }
-  return nullptr;
-}
-
-CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr &root) {
-  MS_EXCEPTION_IF_NULL(root);
-  AnfNodePtr root_return_node = root->get_return();
-  MS_EXCEPTION_IF_NULL(root_return_node);
-  const auto &all_nodes = root->nodes();
-  FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes);
-  if (func_graph == nullptr) {
-    return FindLossCNode(root);
-  } else {
-    return FindLossCNode(func_graph);
-  }
-}
-
-FuncGraphPtr ForwardGraph(const FuncGraphPtr &root) {
-  FuncGraphPtr forward_graph = root;
-  MS_EXCEPTION_IF_NULL(root);
-  AnfNodePtr root_return_node = root->get_return();
-  MS_EXCEPTION_IF_NULL(root_return_node);
-  const auto &all_nodes = root->nodes();
-  FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes);
-  if (func_graph != nullptr) {
-    forward_graph = func_graph;
-  }
-  return forward_graph;
+  return root_forward_nodes;
 }

 void MarkForwardCNode(const FuncGraphPtr &root) {
   MS_EXCEPTION_IF_NULL(root);
   AnfNodePtr root_return_node = root->get_return();
   MS_EXCEPTION_IF_NULL(root_return_node);
-  auto &all_nodes = root->nodes();
-  FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes);
+  auto all_nodes = root->nodes();
+  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);

-  if (func_graph == nullptr) {
-    // Can not find the forward graph, so the ops in root graph are forward.
+  if (graph_set.empty()) {
+    MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph";
     SetForwardFlag(all_nodes);
   } else {
-    MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
-    AnfNodePtr return_node = func_graph->get_return();
-    MS_EXCEPTION_IF_NULL(return_node);
-    std::vector<AnfNodePtr> all_dfs_nodes = DeepLinkedGraphSearch(return_node);
-    SetForwardFlag(all_dfs_nodes);
+    for (auto &func_graph : graph_set) {
+      MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
+      auto return_node = func_graph->get_return();
+      MS_EXCEPTION_IF_NULL(return_node);
+      auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);
+      SetForwardFlag(all_dfs_nodes);
+      auto root_forward_nodes = FindRootForwardCNode(func_graph, all_nodes);
+      if (root_forward_nodes.empty()) {
+        continue;
+      }
+      // Mark forward flag for the nodes in root graph.
+      SetForwardFlag(root_forward_nodes);
+    }
   }
 }
......
......
@@ -24,6 +24,7 @@
 #include <string>
 #include <unordered_map>
 #include <utility>
+#include <set>
 #include "./common.h"
 #include "optimizer/opt.h"
......
@@ -142,13 +143,13 @@ bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);
 int32_t GetTupleGetItemIndex(const CNodePtr &cnode);
-CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr &root);
+std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root);
 Status ParallelInit();
 std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
-FuncGraphPtr ForwardGraph(const FuncGraphPtr &root);
+std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
 } // namespace parallel
 } // namespace mindspore
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import mindspore as ms
from mindspore import Tensor, Parameter, ParameterTuple, context
from mindspore import nn
from mindspore.common.api import _executor
from mindspore.nn.optim import Adam, FTRL
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
import numpy as np
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.mul = P.Mul()
        self.relu = P.ReLU()
        self.param1 = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="wide")
        self.param2 = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="deep")

    def construct(self, x):
        out = self.mul(x, self.param1)
        out = self.mul(out, self.param2)
        out = self.relu(out)
        return out


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.sum = P.ReduceSum(keep_dims=False).set_strategy(strategy=((4, 1, 1, 1),))
        self.mean = P.ReduceMean(keep_dims=False).set_strategy(strategy=((8, 1, 1, 1),))
        self.net = network

    def construct(self, x):
        net_out = self.net(x)
        loss1 = self.sum(net_out, -1)
        loss2 = self.mean(net_out, -1)
        return loss1, loss2


class IthOutputCell(nn.Cell):
    def __init__(self, network, output_index):
        super(IthOutputCell, self).__init__()
        self.network = network
        self.output_index = output_index

    def construct(self, x1):
        predict = self.network(x1)[self.output_index]
        return predict


class TrainStepWrap(nn.Cell):
    def __init__(self, network, sens=1000.0):
        super(TrainStepWrap, self).__init__()
        self.network = network
        self.network.set_train()
        self.trainable_params = network.trainable_params()
        weights_w = []
        weights_d = []
        for params in self.trainable_params:
            weights_w.append(params)
            weights_d.append(params)
        self.weights_w = ParameterTuple(weights_w)
        self.weights_d = ParameterTuple(weights_d)
        self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w,
                                l1=1e-8, l2=1e-8, initial_accum=1.0)
        self.optimizer_d = Adam(self.weights_d, learning_rate=3.5e-4, eps=1e-8,
                                loss_scale=sens)
        self.hyper_map = C.HyperMap()
        self.grad_w = C.GradOperation('grad_w', get_by_list=True,
                                      sens_param=True)
        self.grad_d = C.GradOperation('grad_d', get_by_list=True,
                                      sens_param=True)
        self.sens = sens
        self.loss_net_w = IthOutputCell(network, output_index=0)
        self.loss_net_d = IthOutputCell(network, output_index=1)

    def construct(self, x):
        weights_w = self.weights_w
        weights_d = self.weights_d
        loss_w, loss_d = self.network(x)
        sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
        sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
        grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
        grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)
        return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d))


def test_two_subgraphs():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    net = TrainStepWrap(NetWithLoss(Net()))
    input_x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
    _executor.compile(net, input_x)