diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
index 35fb67938ee223965390b79ced273340c8d61269..f5a49ab0a9f18c2c1a932ca8c4dd38407b8e451e 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
@@ -77,7 +77,7 @@ class SPMDRuleBase {
     PADDLE_ENFORCE_NE(iter,
                       attrs.end(),
                       paddle::platform::errors::NotFound(
-                          "(%s) is not found in AttributeMap."));
+                          "(%s) is not found in AttributeMap.", name));
     return iter->second;
   }
 };
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
index bba4339198021ae0b5260b9c433052ea7a29cd9b..713a52770926de50fd1a66f461d599df85daecdf 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
@@ -24,6 +24,7 @@
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h"
 
 // TODO(ljz) Automatic this process in cmake file.
 namespace paddle {
@@ -155,6 +156,9 @@ REGISTER_SPMD_RULE(softmax_with_cross_entropy,
                    CrossEntropyWithSoftmaxSPMDRule);
 REGISTER_SPMD_RULE(split, SplitSPMDRule);
 REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule);
+// transpose rule
+REGISTER_SPMD_RULE(transpose, TransposeSPMDRule);
+
 }  // namespace auto_parallel
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fe567e70fa01992f6165ec2ba42b2f4a6e423cbc
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.cc
@@ -0,0 +1,103 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+using phi::distributed::auto_parallel::str_join;
+std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+TransposeSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
+                                const paddle::framework::AttributeMap& attrs) {
+  // step0: Verify Input Args Based on Transpose Logic
+  int64_t ninputs = input_specs.size();
+  PADDLE_ENFORCE_EQ(
+      ninputs,
+      1,
+      phi::errors::InvalidArgument("The size of InputSpec in transpose must "
+                                   "be equal to 1, but got [%d].",
+                                   ninputs));
+  VerifySpecs(input_specs, "transpose");
+
+  // step1: Build Einsum Notation
+  std::vector<int64_t> perm_dims =
+      ExtractAttr<std::vector<int64_t>>("perm", attrs);
+  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+
+  // get einsum notation for input
+  int64_t ndim = input_specs[0].shape().size();
+  std::vector<std::string> input_axes_vec;
+  std::string input_axes = alphabet.substr(0, ndim);
+  input_axes_vec.emplace_back(input_axes);
+
+  // get einsum notation for output
+  for (int64_t i = 0, n = perm_dims.size(); i < n; ++i) {
+    // convert the negative dim value to normal dim value
+    if (perm_dims[i] < 0) {
+      perm_dims[i] = ndim + perm_dims[i];
+    }
+  }
+  std::string output_axes = "";
+  for (int64_t i = 0; i < ndim; i++) {
+    output_axes.append(1, input_axes[perm_dims[i]]);
+  }
+
+  // step2: Sharding Propagation
+  // step2.1: merge input shardings
+  std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
+  axes_sharding_info = GetAxesDimsMappingPair(input_axes_vec, input_specs);
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors(axes_sharding_info);
+
+  // step2.2: infer output dims_mapping from merged input dims_mapping
+  std::vector<int64_t> output_dims_mapping =
+      GetDimsMappingForAxes(output_axes, axis_to_dim_map);
+
+  // initialize output dist_attr's process_mesh, batch_dim and dynamic dims with
+  // input dist_attr.
+  TensorDistAttr output_dist_attr =
+      CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
+  output_dist_attr.set_dims_mapping(output_dims_mapping);
+
+  // step2.3: handle input tensor partial (TODO)
+  VLOG(4) << "TransposeSPMDRule InferForward:";
+  for (int64_t i = 0; i < ninputs; i++) {
+    VLOG(4) << "Input" << std::to_string(i) << " shape: ["
+            << str_join(input_specs[i].shape()) << "] "
+            << "src_dims_mapping: [" << str_join(input_specs[i].dims_mapping())
+            << "] "
+            << "perm: [" << str_join(perm_dims) << "] "
+            << "dst_dims_mapping: [" << str_join(input_specs[i].dims_mapping())
+            << "]";
+  }
+  VLOG(4) << "Output dims_mapping: [" + str_join(output_dims_mapping) + "]\n\n";
+
+  return {{input_specs[0].dist_attr()}, {output_dist_attr}};
+}
+
+std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+TransposeSPMDRule::InferBackward(
+    const std::vector<DistTensorSpec>& output_specs,
+    const paddle::framework::AttributeMap& attrs) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "InferBackward of TransposeSPMDRule is NOT implemented yet."));
+
+  return {};
+}
+
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h
new file mode 100644
index 0000000000000000000000000000000000000000..b047932036a718e4c47f6d5681abe2ee31464e58
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h
@@ -0,0 +1,40 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <iterator>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+class TransposeSPMDRule : public SPMDRuleBase {
+ public:
+  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+  InferForward(const std::vector<DistTensorSpec>& input_specs,
+               const paddle::framework::AttributeMap& attrs) override;
+
+  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+  InferBackward(const std::vector<DistTensorSpec>& output_specs,
+                const paddle::framework::AttributeMap& attrs) override;
+};
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt
index 1da9d4674c381c166e7f285f4957a2d64f5d7690..43afd9aed75e7b6322ab1ec550997130b358ea1f 100644
--- a/test/auto_parallel/spmd_rules/CMakeLists.txt
+++ b/test/auto_parallel/spmd_rules/CMakeLists.txt
@@ -9,6 +9,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_matmul_rule MODULES test_replicated_rule)
   py_test_modules(test_matmul_rule MODULES test_softmax_rule)
   py_test_modules(test_split_rule MODULES test_split_rule)
+  py_test_modules(test_transpose_rule MODULES test_transpose_rule)
 
   # End of unittests WITH single card WITHOUT timeout
 endif()
diff --git a/test/auto_parallel/spmd_rules/test_transpose_rule.py b/test/auto_parallel/spmd_rules/test_transpose_rule.py
new file mode 100644
index 0000000000000000000000000000000000000000..62c86c3cf3f3856601de93d83bbe0dffbc53d208
--- /dev/null
+++ b/test/auto_parallel/spmd_rules/test_transpose_rule.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
+from paddle.distributed.auto_parallel.static.dist_attribute import (
+    DistTensorSpec,
+    TensorDistAttr,
+)
+from paddle.distributed.fleet import auto
+
+
+class TestTransposeSPMDRule(unittest.TestCase):
+    """
+    Unit tests for transpose spmd rule.
+ """ + + def setUp(self): + self.rule = get_spmd_rule("transpose") + + x_shape = [64, 36] + process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + + x_tensor_dist_attr = TensorDistAttr() + x_tensor_dist_attr.dims_mapping = [1, 0] + x_tensor_dist_attr.process_mesh = process_mesh + self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + + self.attrs = { + 'perm': [0, 1, 2, 3], + } + + def test_single_mesh_dim(self): + # perm = [1, 0] + # [0, -1] --> [0, -1], [-1, 0] + self.attrs['perm'] = [1, 0] + self.x_dist_tensor_spec.set_dims_mapping([0, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(result_dist_attrs), 2) + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0]) + + # perm = [0, 1] + # [0, -1] --> [0, -1], [0, -1] + self.attrs['perm'] = [0, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1]) + + # perm = [0, 2, 3, 1] + # [-1, -1, 0, -1] --> [-1, -1, 0, -1], [-1, 0, -1, -1] + self.x_dist_tensor_spec.shape = [64, 48, 36, 24] + self.attrs['perm'] = [0, 2, 3, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, -1] + ) + + def test_multi_mesh_dim(self): + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]]) + self.x_dist_tensor_spec.set_process_mesh(process_mesh) + self.x_dist_tensor_spec.shape = [64, 48, 36, 24] + + # perm = [0, 2, 3, 1] + # [-1, 0, 1, -1] --> [-1, 0, 1, -1], [-1, 1, -1, 0] + self.attrs['perm'] = [0, 2, 3, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(result_dist_attrs), 2) + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, 1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1, 0] + ) + + # perm = [0, 2, 3, 1] + # [-1, -1, -1, -1] --> [-1, -1, -1, -1], [-1, -1, -1, -1] + self.attrs['perm'] = [0, 2, 3, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 
+        )
+
+        # perm = [-1, 0, -2, 1]
+        # [-1, -1, 0, 1] --> [-1, -1, 0, 1], [1, -1, 0, -1]
+        self.attrs['perm'] = [-1, 0, -2, 1]
+        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1])
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(
+            infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, 1]
+        )
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1]
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
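
Note for reviewers: below is a minimal, self-contained Python sketch (illustrative only, not part of the patch) of the dims_mapping propagation that InferForward performs: normalize negative entries of `perm` to positive axes, then let output axis i inherit the mapping of input axis perm[i]. The helper name `infer_transpose_dims_mapping` is hypothetical.

def infer_transpose_dims_mapping(input_dims_mapping, perm):
    """Return the output dims_mapping produced by transposing with `perm`."""
    ndim = len(input_dims_mapping)
    # Normalize negative axes (e.g. -1 -> ndim - 1), mirroring the
    # "convert the negative dim value" loop in transpose_spmd_rule.cc.
    perm = [p + ndim if p < 0 else p for p in perm]
    # Output axis i is input axis perm[i], so it inherits that axis's mapping.
    return [input_dims_mapping[p] for p in perm]

# Reproduces the last case in test_multi_mesh_dim above:
assert infer_transpose_dims_mapping([-1, -1, 0, 1], [-1, 0, -2, 1]) == [1, -1, 0, -1]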