Commit 5240b1f6 authored by lichenever

fix refkey bug for auto parallel

While extracting shapes, record the (node, input index) pair for every parameter that a CNode consumes through a RefKey (g_RefMap), and reuse that pair in CoverSliceShape to set the parameter's parallel shape instead of searching the graph again; the map is cleared once slice shapes are covered. A semi_auto_parallel AssignSub test is added to cover the case.

Parent a44b5293
......@@ -49,6 +49,9 @@ namespace mindspore {
namespace parallel {
const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
// g_RefMap: if input i of CNode B is a RefKey that refers to Parameter C,
// the map holds an entry with key C and value (B, i).
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
if (new_node_input.empty()) {
......@@ -1085,11 +1088,19 @@ std::vector<Shapes> ExtractShape(const CNodePtr& node) {
std::vector<AnfNodePtr> all_inputs = node->inputs();
std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
for (auto& input : node_inputs) {
size_t inputs_size = all_inputs.size();
for (size_t i = 1; i < inputs_size; ++i) {
Shapes input_shapes;
AnfNodePtr input = all_inputs[i];
if (IsValueNode<RefKey>(input)) {
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
if (parameters.size() != 1) {
MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
}
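      // Remember that input i of this node reaches the parameter through a RefKey,
      // so CoverSliceShape can reuse the (node, index) pair when setting its parallel shape.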
std::pair<AnfNodePtr, int> node_pair = std::make_pair(node, SizeToInt(i));
g_RefMap[parameters[0]] = node_pair;
input_shapes = GetRefKeyNodeShape(input, func_graph);
} else if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
input_shapes = GetNodeShape(input);
......@@ -1205,14 +1216,20 @@ void CoverSliceShape(const FuncGraphPtr& root) {
auto parameters = root->parameters();
for (auto& parameter : parameters) {
MS_EXCEPTION_IF_NULL(parameter->Shape());
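    // If this parameter was referenced through a RefKey, use the (node, index) pair
    // recorded in ExtractShape instead of searching the graph with FindSubGraph.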
auto iter = g_RefMap.find(parameter);
if (iter != g_RefMap.end()) {
SetParallelShape(parameter, g_RefMap[parameter]);
continue;
}
std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
if (res.first == nullptr) {
MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
} else {
SetParallelShape(parameter, res);
MS_LOG(DEBUG) << "parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
}
}
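  // The RefKey bookkeeping is only needed while slice shapes are being set; reset it here.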
g_RefMap.clear();
}
bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_node) {
......
......@@ -13,14 +13,13 @@
# limitations under the License.
import numpy as np
from mindspore import context
import mindspore as ms
from mindspore import Parameter, Tensor, context
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
from tests.ut.python.ops.test_math_ops import VirtualLoss
import mindspore as ms
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.common.api import _executor
from tests.ut.python.ops.test_math_ops import VirtualLoss
class NetWithLoss(nn.Cell):
......@@ -470,3 +469,30 @@ def test_matmul_floordiv_broadcast2():
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
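# AssignSub updates its first argument in place, so the compiled graph feeds the
# parameter to the op through a RefKey; this exercises the g_RefMap handling added above.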
def test_assign_sub():
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.assign_sub = P.AssignSub()
self.mul = P.Mul()
self.mul_weight = Parameter(Tensor(np.full([128, 32],
0.5, dtype=np.float32)),
name="mul_weight")
self.assignsub_weight = Parameter(Tensor(np.full([128, 32],
1.1, dtype=np.float32)),
name="assignsub_weight")
def construct(self, x, y, z):
out = self.mul(x, self.mul_weight)
out = self.assign_sub(self.assignsub_weight, out)
return out
context.set_auto_parallel_context(device_num=64, global_rank=15)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net = GradWrap(NetWithLoss(Net()))
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([128, 32]), dtype=ms.float32)
z = Tensor(np.ones([128, 32]), dtype=ms.float32)
_executor.compile(net, x, y, z)