Unverified commit 7c13645d authored by 沉潜的鱼儿, committed by GitHub

dist matmul op compatible (#37949)

* dist matmul op compatible

* modify common dist op

* modify common

* add a space
Parent 2360406d
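
The new is_auto_compatible methods let each distributed matmul implementation decide whether a given assignment of dims mappings for X, Y and Out matches its sharding pattern. Below is a minimal sketch of how such a check is driven, mirroring the APIs exercised in the test added at the end of this diff; it assumes `op` is a serial 2-D matmul_v2 op taken from a program's global block, and the mappings shown are illustrative:

dist_op_impl_container = get_distributed_operator_impl_container(op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
x_name, y_name = op.input_arg_names[0], op.input_arg_names[1]
out_name = op.output_arg_names[0]
op_dist_attr.set_input_dims_mapping(x_name, [-1, -1])    # X fully replicated
op_dist_attr.set_input_dims_mapping(y_name, [-1, 0])     # Y sharded on its column dim
op_dist_attr.set_output_dims_mapping(out_name, [-1, 0])  # Out sharded like Y's last dim
compatible = impls[0].is_auto_compatible(DistributedOperator(op, op_dist_attr))
# For a 2-D matmul this is the column-parallel pattern accepted by impls[0].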
@@ -296,6 +296,83 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
return False
return True
def is_auto_compatible(self, dist_op):
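# The checks below implement the column-parallel style rule: Out must be
# sharded on its last dimension, Y must be sharded on its column (second)
# dimension and not on its row dimension, X's reduce (last) dimension must
# stay replicated, and the remaining non-batch dimensions must not be sharded.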
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims >= y dims"
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_replicate(out_dims_mapping[-1]):
return False
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
input_dims_mapping = []
ordered_input_shard_dims_mapping = []
for dim in (x_dims_mapping + y_dims_mapping):
input_dims_mapping.append(dim)
for item in input_dims_mapping:
if item not in ordered_input_shard_dims_mapping and item != -1:
ordered_input_shard_dims_mapping.append(item)
for mapping in out_dims_mapping:
if mapping not in input_dims_mapping:
return False
if is_dim_shard(x_dims_mapping[0]):
order_index = 0
for idx, item in enumerate(out_dims_mapping):
if item != -1:
if item != ordered_input_shard_dims_mapping[order_index]:
return False
else:
order_index += 1
if order_index != len(ordered_input_shard_dims_mapping):
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
1]):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
if is_dim_shard(x_dims_mapping[0]):
for mapping in y_dims_mapping[1:]:
if is_dim_shard(mapping) and mapping == x_dims_mapping[0]:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(dist_op)
@@ -510,6 +587,95 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
return False
return True
def is_auto_compatible(self, dist_op):
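# The checks below implement the row-parallel style rule: transpose_X and
# transpose_Y must be off, X must be sharded on its reduce (last) dimension,
# Y on its row (second-to-last) dimension with the same mesh dimension, and
# Out's last dimension must stay replicated; other non-batch dimensions must
# not be sharded.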
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if op_desc.attr('transpose_X') or op_desc.attr('transpose_Y'):
return False
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
# For GPT-2, x has more dims than y; this is a temporary solution
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims >= y dims"
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_shard(out_dims_mapping[-1]):
return False
# Other dimensions must be replicated except the batch dimension
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
-1]):
return False
# Other dimensions must be replicated except the batch dimension
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
x_shard_dim_count = 0
x_shard_dims = []
y_shard_dim_count = 0
y_shard_dims = []
for dim in x_dims_mapping:
if is_dim_shard(dim):
x_shard_dim_count += 1
x_shard_dims.append(dim)
for dim in y_dims_mapping:
if is_dim_shard(dim):
y_shard_dim_count += 1
y_shard_dims.append(dim)
if not x_shard_dims and not y_shard_dims:
return False
if x_shard_dims[-1] != y_shard_dims[0]:
return False
if x_shard_dim_count == y_shard_dim_count:
for dim in out_dims_mapping:
if is_dim_shard(dim):
return False
if x_shard_dims != y_shard_dims:
return False
else:
if x_shard_dim_count < y_shard_dim_count:
return False
output_shard_dims = []
for dim in out_dims_mapping:
if is_dim_shard(dim):
output_shard_dims.append(dim)
if not output_shard_dims or output_shard_dims[0] != x_shard_dims[0]:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(dist_op)
@@ -710,6 +876,59 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
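# The checks below implement the fully replicated rule: the last two
# dimensions of X, Y and Out must not be sharded (only leading batch
# dimensions may be sharded).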
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
assert len(x_dims_mapping) >= len(
y_dims_mapping
), "now just support x dims >= y dims, but x:{0} and y:{1}".format(
x_dims_mapping, y_dims_mapping)
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping,
-2) and is_dim_shard(out_dims_mapping[-2]):
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(dist_op)
@@ -777,6 +996,86 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
return False
return True
def is_auto_compatible(self, dist_op):
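# Same column-parallel style rule as DistributedMatmulImpl0, applied to
# matmul_v2: trans_x/trans_y are rejected, Y must be sharded on its column
# dimension, X's reduce dimension must stay replicated, and Out is sharded
# on its last dimension.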
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if op_desc.attr('trans_x') or op_desc.attr('trans_y'):
return False
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims >= y dims"
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_replicate(out_dims_mapping[-1]):
return False
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
input_dims_mapping = []
ordered_input_shard_dims_mapping = []
for dim in (x_dims_mapping + y_dims_mapping):
input_dims_mapping.append(dim)
for item in input_dims_mapping:
if item not in ordered_input_shard_dims_mapping and item != -1:
ordered_input_shard_dims_mapping.append(item)
for mapping in out_dims_mapping:
if mapping not in input_dims_mapping:
return False
if is_dim_shard(x_dims_mapping[0]):
order_index = 0
for idx, item in enumerate(out_dims_mapping):
if item != -1:
if item != ordered_input_shard_dims_mapping[order_index]:
return False
else:
order_index += 1
if order_index != len(ordered_input_shard_dims_mapping):
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
1]):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
if is_dim_shard(x_dims_mapping[0]):
for mapping in y_dims_mapping[1:]:
if is_dim_shard(mapping) and mapping == x_dims_mapping[0]:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(dist_op)
@@ -985,6 +1284,94 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
return False
return True
def is_auto_compatible(self, dist_op):
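# Same row-parallel style rule as DistributedMatmulImpl1, applied to
# matmul_v2: trans_x/trans_y are rejected, X is sharded on its reduce
# dimension, Y on the matching row dimension, and Out's last dimension stays
# replicated.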
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if op_desc.attr('trans_x') or op_desc.attr('trans_y'):
return False
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims >= y dims"
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_shard(out_dims_mapping[-1]):
return False
# Other dimensions must be replicated except the batch dimension
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
-1]):
return False
# Other dimensions must be replicated except the batch dimension
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
x_shard_dim_count = 0
x_shard_dims = []
y_shard_dim_count = 0
y_shard_dims = []
for dim in x_dims_mapping:
if is_dim_shard(dim):
x_shard_dim_count += 1
x_shard_dims.append(dim)
for dim in y_dims_mapping:
if is_dim_shard(dim):
y_shard_dim_count += 1
y_shard_dims.append(dim)
if not x_shard_dims and not y_shard_dims:
return False
if x_shard_dims[-1] != y_shard_dims[0]:
return False
if x_shard_dim_count == y_shard_dim_count:
for dim in out_dims_mapping:
if is_dim_shard(dim):
return False
if x_shard_dims != y_shard_dims:
return False
else:
if x_shard_dim_count < y_shard_dim_count:
return False
output_shard_dims = []
for dim in out_dims_mapping:
if is_dim_shard(dim):
output_shard_dims.append(dim)
if not output_shard_dims or output_shard_dims[0] != x_shard_dims[0]:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(dist_op)
@@ -1183,6 +1570,61 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
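# Same fully replicated rule as DistributedMatmulImpl2, applied to matmul_v2:
# the last two dimensions of X, Y and Out must not be sharded.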
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
assert len(x_dims_mapping) >= len(
y_dims_mapping
), "now just support x dims >= y dims, but x:{0} and y:{1}".format(
x_dims_mapping, y_dims_mapping)
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping,
-2) and is_dim_shard(out_dims_mapping[-2]):
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(dist_op)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import copy
import numpy as np
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.fluid.core as core
from paddle.fluid import layers
from paddle.distributed.auto_parallel.operators.common import DistributedOperatorImplContainer
from paddle.distributed.auto_parallel.operators.common import DistributedOperatorImpl
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_op import DistributedOperator
paddle.enable_static()
device = "gpu" if core.is_compiled_with_cuda() else "cpu"
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
return out
def mlp_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sqrt_hidden_size = 32
double_hidden_size = 64
input = static.data(name="input", shape=[8, 8, 16], dtype='int32')
input = paddle.reshape(input, [hidden_size])
input = paddle.reshape(input, [sqrt_hidden_size, sqrt_hidden_size])
embedding = paddle.nn.Embedding(2, batch_size, sparse=True)
input = embedding(input)
input = paddle.reshape(input, [hidden_size, batch_size])
input = paddle.transpose(input, perm=[1, 0])
matmulinput = static.data(
name="matmulinput",
shape=[hidden_size, hidden_size],
dtype='float32')
input = layers.matmul(x=input, y=matmulinput)
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
predict = mlp(input)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
m = paddle.nn.Softmax()
loss = m(loss)
return loss, train_program, start_program
class Testcompatible(unittest.TestCase):
def test_matmulv2_matmul_2_compatible(self):
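# impls[2] is the replicate-parallel style implementation: it is compatible
# only when the last two dimensions of X, Y and Out are all unsharded.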
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
with static.program_guard(program,
start_program), utils.unique_name.guard():
matmulx3 = static.data(
name="matmulx3", shape=[6, 2, 6], dtype='float32')
matmuly3 = static.data(
name="matmuly3", shape=[6, 6], dtype='float32')
output1 = paddle.matmul(x=matmulx3, y=matmuly3)
output_1 = layers.matmul(x=matmulx3, y=matmuly3)
matmulx4 = static.data(
name="matmulx4", shape=[6, 6, 2, 6], dtype='float32')
matmuly4 = static.data(
name="matmuly4", shape=[6, 6, 6, 6], dtype='float32')
output2 = paddle.matmul(x=matmulx4, y=matmuly4)
output_2 = layers.matmul(x=matmulx4, y=matmuly4)
ops = program.global_block().ops
vars = program.global_block().vars
for idx, op in enumerate(ops):
if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0]
Y = op.input_arg_names[1]
out = op.output_arg_names[0]
if len(vars[X].shape) == 2 and len(vars[Y].shape) == 2:
op_dist_attr.set_input_dims_mapping(X, [-1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1])
self.assertTrue(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(X, [1, -1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(X, [-1, 1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [1, -1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [-1, 1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [-1, 1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [1, -1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
if len(vars[X].shape) == 3 and len(vars[Y].shape) == 2:
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1])
self.assertTrue(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [1, -1, -1])
op_dist_attr.set_input_dims_mapping(X, [-1, -1, 1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [1, -1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [-1, 1, -1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
if len(vars[X].shape) == 4 and len(vars[Y].shape) == 4:
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertTrue(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, -1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 0, -1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [-1, -1, 0, -1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [-1, -1, 0, -1])
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
def test_matmulv2_matmul_1_compatible(self):
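# impls[1] is the row-parallel style implementation: X sharded on its last
# (reduce) dimension, Y sharded on the matching row dimension, Out left
# replicated.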
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
with static.program_guard(program,
start_program), utils.unique_name.guard():
matmulx3 = static.data(
name="matmulx3", shape=[6, 2, 6], dtype='float32')
matmuly3 = static.data(
name="matmuly3", shape=[6, 6], dtype='float32')
output1 = paddle.matmul(x=matmulx3, y=matmuly3)
output_1 = layers.matmul(x=matmulx3, y=matmuly3)
matmulx4 = static.data(
name="matmulx4", shape=[6, 6, 6, 6], dtype='float32')
matmuly4 = static.data(
name="matmuly4", shape=[6, 6, 6, 6], dtype='float32')
output2 = paddle.matmul(x=matmulx4, y=matmuly4)
output_2 = layers.matmul(x=matmulx4, y=matmuly4)
ops = program.global_block().ops
vars = program.global_block().vars
for idx, op in enumerate(ops):
if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0]
Y = op.input_arg_names[1]
out = op.output_arg_names[0]
if len(vars[X].shape) == 2 and len(vars[Y].shape) == 2:
op_dist_attr.set_input_dims_mapping(X, [-1, 1])
op_dist_attr.set_input_dims_mapping(Y, [1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
op_dist_attr.set_output_dims_mapping(out, [1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(X, [-1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [-1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
if len(vars[X].shape) == 3 and len(vars[Y].shape) == 2:
op_dist_attr.set_input_dims_mapping(X, [-1, -1, 1])
op_dist_attr.set_input_dims_mapping(Y, [1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1])
self.assertTrue(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [1, -1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(out, [-1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [-1, 0, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
if len(vars[X].shape) == 4 and len(vars[Y].shape) == 4:
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertTrue(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 0, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [-1, -1, 0, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [-1, -1, 0, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
def test_matmulv2_matmul_0_compatible(self):
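# impls[0] is the column-parallel style implementation: X replicated on its
# reduce dimension, Y sharded on its column dimension, Out sharded on its
# last dimension.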
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
with static.program_guard(program,
start_program), utils.unique_name.guard():
matmulx3 = static.data(
name="matmulx3", shape=[6, 2, 6], dtype='float32')
matmuly3 = static.data(
name="matmuly3", shape=[6, 6], dtype='float32')
output1 = paddle.matmul(x=matmulx3, y=matmuly3)
output_1 = layers.matmul(x=matmulx3, y=matmuly3)
matmulx4 = static.data(
name="matmulx4", shape=[6, 6, 2, 6], dtype='float32')
matmuly4 = static.data(
name="matmuly4", shape=[6, 6, 6, 6], dtype='float32')
output2 = paddle.matmul(x=matmulx4, y=matmuly4)
output_2 = layers.matmul(x=matmulx4, y=matmuly4)
ops = program.global_block().ops
vars = program.global_block().vars
for idx, op in enumerate(ops):
if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0]
Y = op.input_arg_names[1]
out = op.output_arg_names[0]
if len(vars[X].shape) == 2 and len(vars[Y].shape) == 2:
op_dist_attr.set_input_dims_mapping(X, [-1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, 1])
op_dist_attr.set_output_dims_mapping(out, [-1, 1])
self.assertTrue(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(X, [-1, 1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [1, 1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [0, 0])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(X, [0, -1])
op_dist_attr.set_output_dims_mapping(out, [1, 1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
if len(vars[X].shape) == 3 and len(vars[Y].shape) == 2:
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, 1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, 1])
self.assertTrue(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(X, [-1, 0, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(X, [-1, 1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [-1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [-1, 1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
if len(vars[X].shape) == 4 and len(vars[Y].shape) == 4:
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1])
self.assertTrue(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(X, [-1, 1, 1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(X, [-1, 1, -1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(X, [-1, -1, 1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [-1, 1, 1, 1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [-1, -1, 1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
if __name__ == "__main__":
unittest.main()