Commit 7786adc3 authored by mindspore-ci-bot, committed by Gitee

!5722 fix semi auto parallel parameter of reshape has another user

Merge pull request !5722 from yao_yf/semi_auto_parallel_reshape_parameter_has_another_user
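Change summary: CreateParameterLayout previously always built a data-parallel layout for a parameter. When the parameter feeds a Reshape but is also consumed directly by another parallel-care operator (the pattern exercised by test_reshape_auto_7 below, where mul_weight is used both through Reshape and as a direct input of Mul), that default can conflict with the layout the other user expects. This commit adds FindParameterNextLayout, which searches the parameter's users for a parallel-care operator and reuses the input layout that operator expects; the data-parallel construction is kept only as the fallback.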
@@ -1645,8 +1645,36 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
  return nullptr;
}

// Search a parameter's users for a parallel-care operator and return the input
// layout that user expects for the parameter; nullptr if no such user exists.
std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) {
  FuncGraphManagerPtr manager = node->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[node];
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    // Skip Depend users where the parameter is only a dependency input
    // (index != 1), and skip Reshape users: the layout must come from
    // another consumer, not from the reshape itself.
    if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == RESHAPE) {
      continue;
    }
    // A parallel-care user with operator info determines the layout: reuse the
    // input layout it expects at the parameter's input index.
    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
      auto layout = GetInputLayoutFromCNode(node_pair);
      return std::make_shared<TensorLayout>(layout);
    }
  }
  return nullptr;
}

std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
  // Prefer the layout expected by a parallel-care user of this parameter;
  // otherwise create a DataParallel tensor layout (supports WideDeep).
  auto next_layout = FindParameterNextLayout(node);
  if (next_layout != nullptr) {
    return next_layout;
  }
  CheckGlobalDeviceManager();
  int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
  TensorLayout input_tensor_layout;
......
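(The remainder of CreateParameterLayout is collapsed in this view; per the hunk header it is the pre-existing data-parallel construction over the dev_num devices of stage 0, now reached only when FindParameterNextLayout returns nullptr.)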
@@ -156,6 +156,8 @@ using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeI
RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode);
std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node);
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
} // namespace parallel
} // namespace mindspore
......
@@ -292,3 +292,25 @@ def test_reshape_auto_6():
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
_executor.compile(net, x, y)
def test_reshape_auto_7():
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.reshape = P.Reshape()
self.mul = P.Mul().set_strategy(((1, 2, 4), (2, 4)))
self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
def construct(self, x):
weight = self.reshape(self.mul_weight, (1, 128, 96))
out = self.mul(weight, self.mul_weight)
return out
size = 8
context.set_auto_parallel_context(device_num=size, global_rank=0)
x = Tensor(np.ones([128, 28]), dtype=ms.float32)
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
_executor.compile(net, x)
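With the fix, compiling this graph in semi_auto_parallel mode should derive the parameter's layout from Mul's strategy ((2, 4) across the 8 devices) via FindParameterNextLayout, rather than from the default data-parallel construction, so the Reshape user no longer imposes a conflicting layout on mul_weight.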