Unverified commit 89bced5e authored by 沉潜的鱼儿, committed by GitHub

Dist op compatible (#37994)

* dist matmul op compatible

* dist op unittest

* modify dist matmul

* modify dist reshape

* modify dist reshape

* add a space

* add a space

* delete dist matmul op

* modify reshape

* add dist op unittest

* modify dist op unittest
Parent 698fca80
@@ -57,6 +57,9 @@ class DistributedOperatorImpl:
return self.is_input_compatible(dist_op) and \
self.is_output_compatible(dist_op)
def is_auto_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
......
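Note: the new base-class hook above only fixes the contract — every concrete operator impl is expected to override is_auto_compatible and return a bool saying whether the dims mappings already attached to a DistributedOperator are mutually consistent. A minimal standalone sketch of that contract (plain Python; ToyImplBase, ToyIdentityImpl and the dict-shaped dist_op are illustrative stand-ins, not Paddle APIs):

# Illustrative sketch only -- not the Paddle classes.
class ToyImplBase:
    def is_auto_compatible(self, dist_op):
        raise NotImplementedError("Please Implement this method in Subclass.")

class ToyIdentityImpl(ToyImplBase):
    def is_auto_compatible(self, dist_op):
        # Compatible only when the input and output dims mappings already match.
        return dist_op["x_dims_mapping"] == dist_op["out_dims_mapping"]

print(ToyIdentityImpl().is_auto_compatible(
    {"x_dims_mapping": [0, -1], "out_dims_mapping": [0, -1]}))  # True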
@@ -80,6 +80,32 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return False
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[
-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in ids_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
for mapping in out_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
if w_dims_mapping[-1] != out_dims_mapping[-1]:
return False
if ids_dims_mapping != out_dims_mapping[:len(ids_dims_mapping)]:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
......
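Reading of the embedding check above: the table W must be sharded along its row (vocabulary) dimension and replicated along the embedding dimension, only the leading (batch) dimension of Ids/Out may be sharded, and the mappings must agree where the tensors overlap. A self-contained sketch of the same conditions over plain lists (the two helpers are local stand-ins for Paddle's is_dim_shard/is_dim_replicate, and embedding_auto_compatible is an illustrative name, not the Paddle API):

# Standalone sketch of the rule, operating on raw dims-mapping lists.
def is_dim_shard(mapping):
    # A dim is sharded when it is mapped to a mesh axis (>= 0).
    return mapping != -1

def is_dim_replicate(mapping):
    # A dim is replicated when it is mapped to -1.
    return mapping == -1

def embedding_auto_compatible(ids, w, out):
    # W must be row-sharded and column-replicated.
    if is_dim_replicate(w[-2]) or is_dim_shard(w[-1]):
        return False
    # Every non-batch dim of Ids and Out must be replicated.
    if any(is_dim_shard(m) for m in ids[1:]) or any(is_dim_shard(m) for m in out[1:]):
        return False
    # The overlapping parts of the mappings must agree.
    return w[-1] == out[-1] and ids == out[:len(ids)]

print(embedding_auto_compatible(ids=[0, -1], w=[1, -1], out=[0, -1, -1]))  # True
print(embedding_auto_compatible(ids=[0, -1], w=[-1, 1], out=[0, -1, 1]))   # False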
@@ -74,6 +74,36 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) - 1:
return False
if is_dim_shard(out_dims_mapping[-1]):
return False
for idx, item in enumerate(out_dims_mapping[:-2]):
if x_dims_mapping[idx] != item:
return False
if out_dims_mapping[-2] != x_dims_mapping[-1]:
return False
if x_shape_dims_mapping[0] != -1:
return False
if x_shape_dims_mapping[1:] != x_dims_mapping[:]:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
@@ -201,6 +231,43 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
if len(x_dims_mapping) == len(out_dims_mapping) + 2:
if out_dims_mapping[0] != x_dims_mapping[0]:
return False
if x_dims_mapping[-1] != -1 or x_dims_mapping[-2] != -1:
return False
elif len(x_dims_mapping) != len(out_dims_mapping) + 1:
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
for idx, item in enumerate(x_dims_mapping[:-2]):
if out_dims_mapping[idx] != item:
return False
if x_dims_mapping[-2] != out_dims_mapping[-1]:
return False
if x_shape_dims_mapping[0] != -1:
return False
if x_shape_dims_mapping[1:] != x_dims_mapping[:]:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
......
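The two reshape impls above mirror each other: DistributedReshapeImpl0 covers the case where Out gains one trailing dimension relative to X, and DistributedReshapeImpl1 the case where X has one more dimension than Out (or two, when X's two trailing mappings are both -1). In both, the extra axes must be unsharded, the shared axes must carry identical mappings, and XShape must be -1 followed by X's mapping. A self-contained sketch of the Impl0 condition over plain lists (reshape_add_dim_auto_compatible is an illustrative name, not the Paddle API):

# Standalone sketch of the "add one trailing dim" reshape rule (Impl0).
def reshape_add_dim_auto_compatible(x, out, x_shape):
    # Out must have exactly one more dimension than X.
    if len(x) != len(out) - 1:
        return False
    # The new trailing axis of Out must not be sharded.
    if out[-1] != -1:
        return False
    # Shared leading axes must carry the same mapping, and X's last
    # axis maps onto Out's second-to-last axis.
    if out[:-2] != x[:len(out) - 2] or out[-2] != x[-1]:
        return False
    # XShape is X's mapping with a leading -1 for the inserted 0th dim.
    return x_shape[0] == -1 and x_shape[1:] == x

print(reshape_add_dim_auto_compatible(x=[0, -1], out=[0, -1, -1], x_shape=[-1, 0, -1]))  # True
print(reshape_add_dim_auto_compatible(x=[0, -1], out=[0, -1, 1], x_shape=[-1, 0, -1]))   # False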
@@ -71,6 +71,25 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis')
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if axis != -1 and axis != len(x_dims_mapping) - 1:
return False
if is_dim_shard(x_dims_mapping[axis]):
return False
if x_dims_mapping != out_dims_mapping:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
......
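The softmax check above reduces to three conditions: the softmax axis must be the last dimension, that dimension must not be sharded, and X and Out must share the same dims mapping. A self-contained sketch over plain lists (softmax_auto_compatible is an illustrative name, not the Paddle API):

# Standalone sketch of the softmax compatibility rule.
def softmax_auto_compatible(x, out, axis):
    # Only softmax over the last axis is supported here.
    if axis != -1 and axis != len(x) - 1:
        return False
    # The softmax axis must not be sharded.
    if x[axis] != -1:
        return False
    # Softmax is elementwise over the remaining axes, so the mappings must match.
    return x == out

print(softmax_auto_compatible(x=[0, -1], out=[0, -1], axis=-1))  # True
print(softmax_auto_compatible(x=[0, 1], out=[0, 1], axis=-1))    # False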
@@ -47,6 +47,35 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
def is_output_compatible(self, dist_op):
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
perm = op_desc.attr('axis')
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
new_dims_mapping = [-1 for i in range(len(x_dims_mapping))]
for i in range(len(x_dims_mapping)):
new_dims_mapping[i] = x_dims_mapping[perm[i]]
if len(x_dims_mapping) != len(out_dims_mapping):
return False
if new_dims_mapping != out_dims_mapping:
return False
if x_shape_dims_mapping[0] != -1:
return False
if x_shape_dims_mapping[1:] != x_dims_mapping[:]:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
......
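The transpose2 check above permutes X's dims mapping by the op's axis attribute and requires the result to equal Out's mapping, with XShape again being -1 followed by X's unpermuted mapping. A self-contained sketch over plain lists (transpose_auto_compatible is an illustrative name and perm stands for the axis attr; not the Paddle API):

# Standalone sketch of the transpose2 compatibility rule.
def transpose_auto_compatible(x, out, x_shape, perm):
    if len(x) != len(out):
        return False
    # Out's mapping must be X's mapping permuted the same way as the data.
    permuted = [x[perm[i]] for i in range(len(x))]
    if permuted != out:
        return False
    # XShape keeps a leading -1 followed by X's unpermuted mapping.
    return x_shape[0] == -1 and x_shape[1:] == x

print(transpose_auto_compatible(x=[0, -1], out=[-1, 0], x_shape=[-1, 0, -1], perm=[1, 0]))  # True
print(transpose_auto_compatible(x=[0, -1], out=[0, -1], x_shape=[-1, 0, -1], perm=[1, 0]))  # False

The new unit test below exercises exactly these rules through the real API: it builds a static program, fetches each op's impl via get_distributed_operator_impl_container(op.type).get_impls(), attaches dims mappings with OperatorDistributedAttribute, and asserts on is_auto_compatible.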
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import copy
import numpy as np
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.fluid.core as core
from paddle.fluid import layers
from paddle.distributed.auto_parallel.operators.common import DistributedOperatorImplContainer
from paddle.distributed.auto_parallel.operators.common import DistributedOperatorImpl
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_op import DistributedOperator
paddle.enable_static()
device = "gpu" if core.is_compiled_with_cuda() else "cpu"
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
return out
def mlp_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sqrt_hidden_size = 32
double_hidden_size = 64
input = static.data(name="input", shape=[8, 8, 16], dtype='int32')
input = paddle.reshape(input, [hidden_size])
input = paddle.reshape(input, [sqrt_hidden_size, sqrt_hidden_size])
embedding = paddle.nn.Embedding(2, batch_size, sparse=True)
input = embedding(input)
input = paddle.reshape(input, [hidden_size, batch_size])
input = paddle.transpose(input, perm=[1, 0])
matmulinput = static.data(
name="matmulinput",
shape=[hidden_size, hidden_size],
dtype='float32')
input = layers.matmul(x=input, y=matmulinput)
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
predict = mlp(input)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
m = paddle.nn.Softmax()
loss = m(loss)
return loss, train_program, start_program
class Testcompatible(unittest.TestCase):
def test_raise_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'transpose2':
op_dist_attr = OperatorDistributedAttribute()
dist_op = DistributedOperator(op, op_dist_attr)
impls = DistributedOperatorImpl()
try:
impls.is_auto_compatible(dist_op)
except NotImplementedError:
e = False
self.assertTrue(e == False)
def test_reshape_remove_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, -1, -1])
self.assertTrue(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, -1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[0, -1, -1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, 1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, 1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[0, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[0, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 0, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
def test_reshape_remove_two_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, -1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertTrue(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 1, 0])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[0, 1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[1, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, -1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, 1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, 1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
def test_reshape_add_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1])
self.assertTrue(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, 0])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[0, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, 1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[1, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[1, 1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[0, -1])
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
def test_transpose_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'transpose2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertTrue(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, 0, 0])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[0, 0, 0])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, 0, 0])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[0, -1, -1])
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[0, 1, 1])
self.assertFalse(impls[0].is_auto_compatible(dist_op))
def test_softmax_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'softmax':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertTrue(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op.all_attrs()['axis'] = 2
self.assertFalse(impls[0].is_auto_compatible(dist_op))
def test_embedding_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'c_embedding' or op.type == 'lookup_table_v2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1])
op_dist_attr.set_input_dims_mapping(op.input_arg_names[1],
[1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1, -1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertTrue(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1, 0, 0])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[1],
[-1, 1])
dist_op = DistributedOperator(op, op_dist_attr)
op_dist_attr.set_input_dims_mapping(op.input_arg_names[1],
[1, 1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1, -1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[1],
[1, 1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1, -1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[1, 1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertFalse(impls[0].is_auto_compatible(dist_op))
if __name__ == "__main__":
unittest.main()