提交 84d5e4f9 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!643 [AutoParallel]Support reshape parameter

Merge pull request !643 from lichen/support_reshape_parameter
...@@ -1530,9 +1530,32 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n ...@@ -1530,9 +1530,32 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
return nullptr; return nullptr;
} }
std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
// Create DataParallel tensor layout for parameter(support WideDeep).
CheckGlobalDeviceManager();
int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
TensorLayout input_tensor_layout;
// create input_shape
Shapes inputs_shape = GetNodeShape(node);
Shape input_shape_array = inputs_shape[0];
if (input_shape_array.empty()) {
MS_LOG(EXCEPTION) << "Don't support reshape a scalar parameter.";
}
// create tensor_map
size_t shape_size = input_shape_array.size();
TensorMap input_tensor_map_array(SizeToInt(shape_size) - 1, -1);
input_tensor_map_array.insert(input_tensor_map_array.begin(), 0);
// create dev_matrix
Shape dev_matrix_array = {dev_num};
if (input_tensor_layout.InitFromVector(dev_matrix_array, input_tensor_map_array, input_shape_array) != SUCCESS) {
MS_LOG(EXCEPTION) << "Create tensor layout for parameter failed.";
}
return std::make_shared<TensorLayout>(input_tensor_layout);
}
std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) { std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
if (node->isa<Parameter>()) { if (node->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "Failure: parameter before reshape is not supported temporary"; return CreateParameterLayout(node);
} }
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
return nullptr; return nullptr;
......
...@@ -415,8 +415,11 @@ def set_auto_parallel_context(**kwargs): ...@@ -415,8 +415,11 @@ def set_auto_parallel_context(**kwargs):
Args: Args:
device_num (int): Available device number, the value must be in [1, 4096]. Default: 1. device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0. global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False. mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror.
cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True. Default: True. "stand_alone" does not support mirror_mean. Default: False.
cast_before_mirror (bool): Insert Mirror Op after the cast if this flag is True.
"stand_alone", "data_parallel" and "hybrid_parallel" do not support
cast_before_mirror. Default: True.
parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel", parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
"hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone". "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from tests.ut.python.ops.test_math_ops import VirtualLoss
import numpy as np
class NetWithLoss(nn.Cell):
    """Cell that runs a wrapped network and feeds its output into a VirtualLoss."""

    def __init__(self, network):
        """Store the network to wrap and create the VirtualLoss applied to its output."""
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        # Forward both inputs through the wrapped network, then apply the loss to the result.
        return self.loss(self.network(x, y))
class GradWrap(nn.Cell):
    """Cell that evaluates C.grad_all of a wrapped network on the given inputs."""

    def __init__(self, network):
        """Keep a reference to the network that C.grad_all is applied to."""
        super().__init__()
        self.network = network

    def construct(self, x, y):
        # Build the grad_all function for the wrapped network, then call it with both inputs.
        grad_fn = C.grad_all(self.network)
        return grad_fn(x, y)
class Net(nn.Cell):
    """Reshape -> Mul -> ReLU network; only the Mul operator is given a strategy."""

    def __init__(self, strategy):
        """Create the three operators; `strategy` is passed to Mul via set_strategy."""
        super().__init__()
        self.reshape = P.Reshape()
        self.mul = P.Mul().set_strategy(strategy)
        self.relu = P.ReLU()

    def construct(self, x, y):
        # x is reshaped to (10000, 36, 1) before being multiplied with y.
        reshaped = self.reshape(x, (10000, 36, 1))
        product = self.mul(reshaped, y)
        return self.relu(product)
def test_reshape_parameter_data_parallel():
    """Compile the reshape network on 8 devices with Mul strategy ((8, 1, 1), (8, 1, 1)).

    The test passes when `_executor.compile` completes without raising.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    mul_strategy = ((8, 1, 1), (8, 1, 1))
    grad_net = GradWrap(NetWithLoss(Net(mul_strategy)))
    # x is 2-D and gets reshaped inside Net; y already has the reshaped 3-D shape.
    x_input = Tensor(np.ones([10000, 36]), dtype=ms.float32)
    y_input = Tensor(np.ones([10000, 36, 1]), dtype=ms.float32)
    _executor.compile(grad_net, x_input, y_input)
def test_reshape_parameter_model_parallel():
    """Compile the reshape network on 8 devices with Mul strategy ((4, 2, 1), (4, 2, 1)).

    The test passes when `_executor.compile` completes without raising.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    grad_net = GradWrap(NetWithLoss(Net(mul_strategy)))
    # x is 2-D and gets reshaped inside Net; y already has the reshaped 3-D shape.
    x_input = Tensor(np.ones([10000, 36]), dtype=ms.float32)
    y_input = Tensor(np.ones([10000, 36, 1]), dtype=ms.float32)
    _executor.compile(grad_net, x_input, y_input)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册