From 89bced5e8a89928cfabe29628a47e569f2b65a7f Mon Sep 17 00:00:00 2001 From: June Weng Date: Sun, 12 Dec 2021 22:29:24 +0800 Subject: [PATCH] Dist op compatible (#37994) * dist matmul op compatible * dist op unittest * modify dist matmul * modify dist reshape * modify dist reshape * add a space * add a space * delete dist matmul op * modify reshape * add dist op unittest * modify dist op unittest --- .../auto_parallel/operators/common.py | 3 + .../auto_parallel/operators/dist_embedding.py | 26 ++ .../auto_parallel/operators/dist_reshape.py | 67 +++ .../auto_parallel/operators/dist_softmax.py | 19 + .../auto_parallel/operators/dist_transpose.py | 29 ++ .../unittests/test_auto_search_dist_op.py | 431 ++++++++++++++++++ 6 files changed, 575 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 678f4e7fdc..3ebda4694c 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -57,6 +57,9 @@ class DistributedOperatorImpl: return self.is_input_compatible(dist_op) and \ self.is_output_compatible(dist_op) + def is_auto_compatible(self, dist_op): + raise NotImplementedError("Please Implement this method in Subclass.") + def update_dims_mapping(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index 3df04a70a5..20722cdf60 100755 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -80,6 +80,32 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): return False return True + def is_auto_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + ids_name = op_desc.input('Ids')[0] + w_name = op_desc.input('W')[0] + out_name = op_desc.output('Out')[0] + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) + w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name) + if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[ + -1]): + return False + # Other dimensions must be replicate except the batch dimension + for mapping in ids_dims_mapping[1:]: + if is_dim_shard(mapping): + return False + for mapping in out_dims_mapping[1:]: + if is_dim_shard(mapping): + return False + if w_dims_mapping[-1] != out_dims_mapping[-1]: + return False + if ids_dims_mapping != out_dims_mapping[:len(ids_dims_mapping)]: + return False + + return True + def update_dims_mapping(self, dist_op): changed = False op_desc = dist_op.serial_op.desc diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py index 8821f3bc65..d72d13803f 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py @@ -74,6 +74,36 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): return True + def is_auto_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_shape_name = 
op_desc.output('XShape')[0] + x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping( + x_shape_name) + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + if len(x_dims_mapping) != len(out_dims_mapping) - 1: + return False + + if is_dim_shard(out_dims_mapping[-1]): + return False + + for idx, item in enumerate(out_dims_mapping[:-2]): + if x_dims_mapping[idx] != item: + return False + if out_dims_mapping[-2] != x_dims_mapping[-1]: + return False + + if x_shape_dims_mapping[0] != -1: + return False + + if x_shape_dims_mapping[1:] != x_dims_mapping[:]: + return False + + return True + def update_dims_mapping(self, dist_op): changed = False op_desc = dist_op.serial_op.desc @@ -201,6 +231,43 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): return True + def is_auto_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_shape_name = op_desc.output('XShape')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping( + x_shape_name) + + if len(x_dims_mapping) == len(out_dims_mapping) + 2: + if out_dims_mapping[0] != x_dims_mapping[0]: + return False + if x_dims_mapping[-1] != -1 or x_dims_mapping[-2] != -1: + return False + elif len(x_dims_mapping) != len(out_dims_mapping) + 1: + return False + + if is_dim_shard(x_dims_mapping[-1]): + return False + + for idx, item in enumerate(x_dims_mapping[:-2]): + if out_dims_mapping[idx] != item: + return False + + if x_dims_mapping[-2] != out_dims_mapping[-1]: + return False + + if x_shape_dims_mapping[0] != -1: + return False + + if x_shape_dims_mapping[1:] != x_dims_mapping[:]: + return False + + return True + def update_dims_mapping(self, dist_op): changed = False op_desc = dist_op.serial_op.desc diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py index c90fc7da89..de2d0ba62e 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -71,6 +71,25 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): return True + def is_auto_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + x_name = op_desc.input('X')[0] + axis = op_desc.attr('axis') + out_name = op_desc.output('Out')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + if axis != -1 and axis != len(x_dims_mapping) - 1: + return False + + if is_dim_shard(x_dims_mapping[axis]): + return False + + if x_dims_mapping != out_dims_mapping: + return False + + return True + def update_dims_mapping(self, dist_op): changed = False op_desc = dist_op.serial_op.desc diff --git a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py index 0bfc7d9f4c..98c4681051 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py @@ -47,6 +47,35 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): def is_output_compatible(self, dist_op): return True + def is_auto_compatible(self, 
dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + perm = op_desc.attr('axis') + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_shape_name = op_desc.output('XShape')[0] + x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping( + x_shape_name) + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + new_dims_mapping = [-1 for i in range(len(x_dims_mapping))] + for i in range(len(x_dims_mapping)): + new_dims_mapping[i] = x_dims_mapping[perm[i]] + + if len(x_dims_mapping) != len(out_dims_mapping): + return False + + if new_dims_mapping != out_dims_mapping: + return False + + if x_shape_dims_mapping[0] != -1: + return False + + if x_shape_dims_mapping[1:] != x_dims_mapping[:]: + return False + + return True + def update_dims_mapping(self, dist_op): changed = False op_desc = dist_op.serial_op.desc diff --git a/python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py b/python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py new file mode 100644 index 0000000000..8f53a0c765 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py @@ -0,0 +1,431 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import print_function +import unittest +import copy + +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +import paddle.fluid.core as core +from paddle.fluid import layers +from paddle.distributed.auto_parallel.operators.common import DistributedOperatorImplContainer +from paddle.distributed.auto_parallel.operators.common import DistributedOperatorImpl +from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container +from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext +from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute +from paddle.distributed.auto_parallel.dist_op import DistributedOperator +paddle.enable_static() +device = "gpu" if core.is_compiled_with_cuda() else "cpu" + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + + def forward(self, input): + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + + return out + + +def mlp_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sqrt_hidden_size = 32 + double_hidden_size = 64 + + input = static.data(name="input", shape=[8, 8, 16], dtype='int32') + input = paddle.reshape(input, [hidden_size]) + input = paddle.reshape(input, [sqrt_hidden_size, sqrt_hidden_size]) + embedding = paddle.nn.Embedding(2, batch_size, sparse=True) + input = embedding(input) + input = paddle.reshape(input, [hidden_size, batch_size]) + input = paddle.transpose(input, perm=[1, 0]) + matmulinput = static.data( + name="matmulinput", + shape=[hidden_size, hidden_size], + dtype='float32') + input = layers.matmul(x=input, y=matmulinput) + label = static.data( + name="label", shape=[batch_size, 1], dtype='float32') + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + initializer_range=0.02) + + predict = mlp(input) + error_cost = paddle.nn.functional.square_error_cost(predict, label) + loss = paddle.mean(error_cost) + m = paddle.nn.Softmax() + loss = m(loss) + return loss, train_program, start_program + + +class Testcompatible(unittest.TestCase): + def test_raise_compatible(self): + valid_op_dist_attr_list = [] + program = paddle.static.Program() + startup_program = paddle.static.Program() + loss, program, start_program = mlp_forward(program, startup_program) + ops = program.global_block().ops + for idx, op in enumerate(ops): + if op.type == 'transpose2': + op_dist_attr = OperatorDistributedAttribute() + dist_op = DistributedOperator(op, op_dist_attr) + impls = DistributedOperatorImpl() + try: + impls.is_auto_compatible(dist_op) + except NotImplementedError: + e = False + self.assertTrue(e == False) + + def test_reshape_remove_compatible(self): + valid_op_dist_attr_list = 
[] + program = paddle.static.Program() + startup_program = paddle.static.Program() + loss, program, start_program = mlp_forward(program, startup_program) + ops = program.global_block().ops + for idx, op in enumerate(ops): + if op.type == 'reshape2': + dist_op_impl_container = get_distributed_operator_impl_container( + op.type) + impls = dist_op_impl_container.get_impls() + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, -1, -1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [-1, -1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, -1, -1, -1]) + self.assertTrue(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, -1, -1, 1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [0, -1, -1, 1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, 1, -1, -1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, -1, 1, -1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [1, -1, -1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [0, -1, -1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [0, -1, -1, -1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, 0, -1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [-1, -1]) + + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + def test_reshape_remove_two_compatible(self): + valid_op_dist_attr_list = [] + program = paddle.static.Program() + startup_program = paddle.static.Program() + loss, program, start_program = mlp_forward(program, startup_program) + ops = program.global_block().ops + for idx, op in enumerate(ops): + if op.type == 'reshape2': + dist_op_impl_container = get_distributed_operator_impl_container( + op.type) + impls = dist_op_impl_container.get_impls() + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, -1, -1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [-1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, -1, -1, -1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertTrue(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, 1, 0]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [0, 1, 1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [1, -1, -1, -1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, 
op_dist_attr))) + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, 1, 1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, -1, -1, 1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, 1, -1, -1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, -1, 1, -1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [1, -1, -1]) + self.assertFalse(impls[1].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + def test_reshape_add_compatible(self): + valid_op_dist_attr_list = [] + program = paddle.static.Program() + startup_program = paddle.static.Program() + loss, program, start_program = mlp_forward(program, startup_program) + ops = program.global_block().ops + for idx, op in enumerate(ops): + if op.type == 'reshape2': + dist_op_impl_container = get_distributed_operator_impl_container( + op.type) + impls = dist_op_impl_container.get_impls() + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [-1, -1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, -1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, -1]) + self.assertTrue(impls[0].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, 0]) + self.assertFalse(impls[0].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [0, -1]) + + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1]) + self.assertFalse(impls[0].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, 1]) + self.assertFalse(impls[0].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [1, -1]) + self.assertFalse(impls[0].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [1, 1]) + self.assertFalse(impls[0].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [-1, -1, 1]) + self.assertFalse(impls[0].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [0, -1]) + self.assertFalse(impls[0].is_auto_compatible( + DistributedOperator(op, op_dist_attr))) + + def test_transpose_compatible(self): + valid_op_dist_attr_list = [] + program = paddle.static.Program() + startup_program = paddle.static.Program() + loss, program, start_program = mlp_forward(program, startup_program) + ops = program.global_block().ops + for idx, op in enumerate(ops): + if op.type == 'transpose2': + dist_op_impl_container = get_distributed_operator_impl_container( + op.type) + impls = dist_op_impl_container.get_impls() + 
op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, -1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [-1, -1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, -1, -1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertTrue(impls[0].is_auto_compatible(dist_op)) + + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, 0, 0]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [0, 0, 0]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [1, -1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [-1, 0, 0]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [0, -1, -1]) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, -1]) + + op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], + [0, 1, 1]) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + + def test_softmax_compatible(self): + valid_op_dist_attr_list = [] + program = paddle.static.Program() + startup_program = paddle.static.Program() + loss, program, start_program = mlp_forward(program, startup_program) + ops = program.global_block().ops + for idx, op in enumerate(ops): + if op.type == 'softmax': + dist_op_impl_container = get_distributed_operator_impl_container( + op.type) + impls = dist_op_impl_container.get_impls() + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, -1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [-1, -1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertTrue(impls[0].is_auto_compatible(dist_op)) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, 1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + op.all_attrs()['axis'] = 2 + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + + def test_embedding_compatible(self): + valid_op_dist_attr_list = [] + program = paddle.static.Program() + startup_program = paddle.static.Program() + loss, program, start_program = mlp_forward(program, startup_program) + ops = program.global_block().ops + for idx, op in enumerate(ops): + if op.type == 'c_embedding' or op.type == 'lookup_table_v2': + dist_op_impl_container = get_distributed_operator_impl_container( + op.type) + impls = dist_op_impl_container.get_impls() + op_dist_attr = OperatorDistributedAttribute() + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, -1]) + op_dist_attr.set_input_dims_mapping(op.input_arg_names[1], + [1, -1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [-1, -1, -1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertTrue(impls[0].is_auto_compatible(dist_op)) + 
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [-1, 0, 0]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, 1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + op_dist_attr.set_input_dims_mapping(op.input_arg_names[1], + [-1, 1]) + dist_op = DistributedOperator(op, op_dist_attr) + + op_dist_attr.set_input_dims_mapping(op.input_arg_names[1], + [1, 1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [-1, -1, -1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, 1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + op_dist_attr.set_input_dims_mapping(op.input_arg_names[1], + [1, 1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [-1, -1, -1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], + [-1, -1]) + op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], + [1, 1, -1]) + dist_op = DistributedOperator(op, op_dist_attr) + self.assertFalse(impls[0].is_auto_compatible(dist_op)) + + +if __name__ == "__main__": + unittest.main() -- GitLab
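
Usage note (not part of the patch): the short sketch below is distilled from the added test_auto_search_dist_op.py and shows how the new is_auto_compatible() hook is meant to be exercised through the auto-parallel helpers. The helper name softmax_ops_auto_compatible and the fixed [-1, -1] dims mappings are illustrative assumptions; the imports and calls themselves follow the ones used in the unit test.

from paddle.distributed.auto_parallel.operators.common import \
    get_distributed_operator_impl_container
from paddle.distributed.auto_parallel.dist_attribute import \
    OperatorDistributedAttribute
from paddle.distributed.auto_parallel.dist_op import DistributedOperator


def softmax_ops_auto_compatible(program):
    """Return True when every softmax op in `program`, annotated with fully
    replicated dims mappings, passes the new is_auto_compatible() check of at
    least one registered distributed implementation (a sketch; assumes the
    softmax inputs/outputs are 2-D)."""
    for op in program.global_block().ops:
        if op.type != 'softmax':
            continue
        # Look up the registered distributed implementations for this op type.
        container = get_distributed_operator_impl_container(op.type)
        impls = container.get_impls()
        # A dims-mapping entry of -1 means that axis is replicated, not sharded.
        op_dist_attr = OperatorDistributedAttribute()
        op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1, -1])
        op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], [-1, -1])
        dist_op = DistributedOperator(op, op_dist_attr)
        if not any(impl.is_auto_compatible(dist_op) for impl in impls):
            return False
    return True

Unlike is_input_compatible and is_output_compatible, which each inspect one side of the op, is_auto_compatible cross-checks the input and output dims mappings against each other (and against XShape where the op has one), which is why the unit tests flip a single mapping entry and expect the check to flip from True to False.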