From cf76e7ae7cf6d26fc340ff0d3677870182688cd1 Mon Sep 17 00:00:00 2001
From: Yichen Zhang <32740647+pkuzyc@users.noreply.github.com>
Date: Mon, 24 Jul 2023 10:46:37 +0800
Subject: [PATCH] [Semi-Auto] add split spmd rule (#55397)

* add split spmd rule

* add pytest in cmake file

* small fix

---
 .../auto_parallel/spmd_rules/common.cc        |   4 +-
 .../auto_parallel/spmd_rules/rules.h          |   5 +
 .../spmd_rules/split_spmd_rule.cc             | 126 +++++++++++
 .../spmd_rules/split_spmd_rule.h              |  40 ++++
 test/auto_parallel/spmd_rules/CMakeLists.txt  |   1 +
 .../spmd_rules/test_split_rule.py             | 205 ++++++++++++++++++
 6 files changed, 379 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc
 create mode 100644 paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h
 create mode 100644 test/auto_parallel/spmd_rules/test_split_rule.py

diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
index 47c0d9a683f..a0f46e1c462 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
@@ -182,8 +182,8 @@ TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr) {
 void VerifySpecs(const std::vector<DistTensorSpec>& specs,
                  const std::string& op_name) {
   for (size_t i = 0, n = specs.size(); i < n; ++i) {
-    std::vector<int64_t> shape = specs[i].shape();
-    std::vector<int64_t> dims_mapping = specs[i].dims_mapping();
+    const std::vector<int64_t>& shape = specs[i].shape();
+    const std::vector<int64_t>& dims_mapping = specs[i].dims_mapping();
     PADDLE_ENFORCE_EQ(shape.size(),
                       dims_mapping.size(),
                       phi::errors::InvalidArgument(

diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
index c58333d0fb7..bba43391980 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
@@ -23,6 +23,7 @@
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"

 // TODO(ljz) Automatic this process in cmake file.
 namespace paddle {
@@ -150,6 +151,10 @@ REGISTER_SPMD_RULE(log_softmax, SoftmaxSPMDRule);
 REGISTER_SPMD_RULE(cross_entropy_with_softmax, CrossEntropyWithSoftmaxSPMDRule);
 REGISTER_SPMD_RULE(softmax_with_cross_entropy, CrossEntropyWithSoftmaxSPMDRule);

+// split rule
+REGISTER_SPMD_RULE(split, SplitSPMDRule);
+REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule);
+
 }  // namespace auto_parallel
 }  // namespace distributed
 }  // namespace paddle
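Note that both op types are dispatched to the same rule class: "split" carries a "sections" attribute while "split_with_num" carries "num", and SplitSPMDRule reads whichever is present to decide the number of outputs. On the Python side a rule is fetched by op name through the static completion API, as the new unit test below does; a minimal sketch, assuming the same import path the test uses:

    from paddle.distributed.auto_parallel.static.completion import get_spmd_rule

    split_rule = get_spmd_rule("split")  # sections-based variant
    split_with_num_rule = get_spmd_rule("split_with_num")  # num-based variant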
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc
new file mode 100644
index 00000000000..59c962dab89
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc
@@ -0,0 +1,126 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
+#include
+#include
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+using phi::distributed::auto_parallel::str_join;
+
+std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+SplitSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
+                            const paddle::framework::AttributeMap& attrs) {
+  // step0: Verify Input Args Based on Split Logic
+  int64_t ninputs = input_specs.size();
+  PADDLE_ENFORCE_EQ(
+      ninputs,
+      1,
+      phi::errors::InvalidArgument("The size of InputSpec in split must "
+                                   "be equal to 1, but got [%d].",
+                                   ninputs));
+  VerifySpecs(input_specs, "split");
+
+  // step1: Build Einsum Notation
+  int64_t ndim = input_specs[0].shape().size();
+  int64_t noutput = 0;
+  // the split api uses either num or sections as its attribute
+  if (attrs.find("num") != attrs.end()) {
+    noutput = ExtractAttr<int>("num", attrs);
+  } else if (attrs.find("sections") != attrs.end()) {
+    std::vector<int> sections =
+        ExtractAttr<std::vector<int>>("sections", attrs);
+    noutput = sections.size();
+  }
+  int64_t axis = ExtractAttr<int>("axis", attrs);
+  if (axis < 0) {
+    axis += ndim;
+  }
+  std::string alphabet = "abcdefghijlmnopqrstuvwxyz";
+
+  // get the einsum notation for the input; the special
+  // notation 'k' marks the split axis of the input
+  std::vector<std::string> input_axes_vec;
+  std::string input_axes = alphabet.substr(0, ndim);
+  input_axes[axis] = 'k';
+  input_axes_vec.emplace_back(input_axes);
+
+  // get the einsum notation for the outputs; the split axis
+  // cannot be sharded, so mark it with the special notation
+  // '1' to force its dim mapping to -1
+  std::string output_axes(input_axes);
+  output_axes[axis] = '1';
+
+  // step2: Sharding Propagation
+  // step2.1: merge input shardings
+  std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
+  axes_sharding_info = GetAxesDimsMappingPair(input_axes_vec, input_specs);
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors(axes_sharding_info);
+
+  // step2.2: infer the output dims mapping from the merged input dims mapping
+  std::vector<int64_t> output_dims_mapping =
+      GetDimsMappingForAxes(output_axes, axis_to_dim_map);
+
+  // get the dist attributes for all outputs; the dist
+  // attributes are the same for all outputs
+  std::vector<TensorDistAttr> output_dist_attrs;
+  for (int64_t i = 0; i < noutput; i++) {
+    output_dist_attrs.emplace_back(
+        CopyTensorDistAttrForOutput(input_specs[0].dist_attr()));
+    output_dist_attrs[i].set_dims_mapping(output_dims_mapping);
+  }
+
+  // step2.3: get the new dist attribute for the input; the split
+  // axis cannot be sharded, so if it is sharded, set it to replicated
+  std::vector<TensorDistAttr> new_input_dist_attrs;
+  new_input_dist_attrs.emplace_back(input_specs[0].dist_attr());
+  std::vector<int64_t> new_input_dims_mapping(input_specs[0].dims_mapping());
+  new_input_dims_mapping[axis] = -1;
+  new_input_dist_attrs[0].set_dims_mapping(new_input_dims_mapping);
+
+  // step2.4: handle input tensor partial (TODO)
+  VLOG(4) << "SplitSPMDRule InferForward: ";
+  for (int64_t i = 0; i < ninputs; i++) {
+    VLOG(4) << "Input" << std::to_string(i) << " shape: ["
+            << str_join(input_specs[i].shape()) << "] "
+            << "einsum_notation: " << input_axes << " src_dims_mapping: ["
+            << str_join(input_specs[i].dims_mapping()) << "] "
+            << "dst_dims_mapping: ["
+            << str_join(new_input_dist_attrs[i].dims_mapping()) << "]";
+  }
+  for (int64_t i = 0; i < noutput; i++) {
+    VLOG(4) << "Output" << std::to_string(i) << " dims_mapping: ["
+            << str_join(output_dims_mapping) << "]";
+  }
+
+  return {new_input_dist_attrs, output_dist_attrs};
+}
+
+std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+SplitSPMDRule::InferBackward(const std::vector<DistTensorSpec>& output_specs,
+                             const paddle::framework::AttributeMap& attrs) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "InferBackward of SplitSPMDRule is NOT implemented yet."));
+
+  return {};
+}
+
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
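Because split takes a single input, the einsum merge above degenerates to a per-axis copy, and the whole dims-mapping computation reduces to forcing the split axis to -1 on both sides of the op. The following Python sketch models only that arithmetic; it is an illustration of the rule's logic, not a Paddle API:

    def split_dims_mapping(input_mapping, axis, noutput):
        """Model of SplitSPMDRule.InferForward's dims-mapping logic."""
        ndim = len(input_mapping)
        if axis < 0:
            axis += ndim
        # the split axis cannot stay sharded: force it to -1 (replicated)
        new_input = list(input_mapping)
        new_input[axis] = -1
        # every output shares the de-sharded input mapping
        return new_input, [list(new_input) for _ in range(noutput)]

    # e.g. an input sharded as [-1, -1, 0] and split along axis 2:
    # the input and all three outputs become fully replicated
    print(split_dims_mapping([-1, -1, 0], 2, 3))
    # -> ([-1, -1, -1], [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]])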
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h
new file mode 100644
index 00000000000..f974e4cccce
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h
@@ -0,0 +1,40 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+class SplitSPMDRule : public SPMDRuleBase {
+ public:
+  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+  InferForward(const std::vector<DistTensorSpec>& input_specs,
+               const paddle::framework::AttributeMap& attrs) override;
+
+  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+  InferBackward(const std::vector<DistTensorSpec>& output_specs,
+                const paddle::framework::AttributeMap& attrs) override;
+};
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
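Note that InferBackward is declared here but left unimplemented for now: as the .cc file above shows, calling it raises an Unimplemented error, so only forward sharding propagation is available for split at this point.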
diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt
index ed1cf37389e..1da9d4674c3 100644
--- a/test/auto_parallel/spmd_rules/CMakeLists.txt
+++ b/test/auto_parallel/spmd_rules/CMakeLists.txt
@@ -8,6 +8,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_matmul_rule MODULES test_embedding_rule)
   py_test_modules(test_matmul_rule MODULES test_replicated_rule)
   py_test_modules(test_matmul_rule MODULES test_softmax_rule)
+  py_test_modules(test_split_rule MODULES test_split_rule)

   # End of unittests WITH single card WITHOUT timeout
 endif()
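The new test target follows the existing py_test_modules pattern in this directory. Assuming a standard WITH_DISTRIBUTE/WITH_GPU build with these targets registered, it should be runnable on its own through ctest (e.g. ctest -R test_split_rule) or directly as a Python module, since the file below ends with a unittest.main() guard.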
+ """ + + def setUp(self): + self.rule = get_spmd_rule("split") + + x_shape = [64, 32, 48] + process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + + x_tensor_dist_attr = TensorDistAttr() + x_tensor_dist_attr.dims_mapping = [1, 0] + x_tensor_dist_attr.process_mesh = process_mesh + self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + + self.attrs = { + 'num_or_sections': 2, + 'axis': 1, + } + + def test_single_mesh_dim(self): + # num_or_sections = 2, axis = 1 + # [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1] + self.rule = get_spmd_rule("split_with_num") + self.attrs = {} + self.attrs['num'] = 2 + self.attrs['axis'] = 1 + self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(result_dist_attrs), 2) + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 2) + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1]) + self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0, -1, -1]) + + # num_or_sections = [15, 16, 17], axis = 2 + # [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1], [0, -1, -1] + self.rule = get_spmd_rule("split") + self.attrs = {} + self.attrs['sections'] = [15, 16, 17] + self.attrs['axis'] = 2 + self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(result_dist_attrs), 2) + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 3) + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1]) + self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0, -1, -1]) + self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0, -1, -1]) + + # num_or_sections = [15, 16, 17], axis = 2 + # [-1, -1, 0] --> [-1, -1, -1], [-1, -1, -1], [-1 -1, -1], [-1, -1, -1] + self.attrs = {} + self.attrs['sections'] = [15, 16, 17] + self.attrs['axis'] = 2 + self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(result_dist_attrs), 2) + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 3) + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1]) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[1].dims_mapping, [-1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[2].dims_mapping, [-1, -1, -1] + ) + + # num_or_sections = 2, axis = -2 + # [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1] + self.rule = get_spmd_rule("split_with_num") + self.attrs = {} + self.attrs['num'] = 2 + self.attrs['axis'] = -2 + self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = 
+
+    def test_multi_mesh_dim(self):
+        process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
+        self.x_dist_tensor_spec.set_process_mesh(process_mesh)
+        self.x_dist_tensor_spec.shape = [96, 32, 48, 24]
+
+        # num_or_sections = 3, axis = -1
+        # [0, 1, -1, -1] --> [0, 1, -1, -1], [0, 1, -1, -1], [0, 1, -1, -1], [0, 1, -1, -1]
+        self.rule = get_spmd_rule("split_with_num")
+        self.attrs = {}
+        self.attrs['num'] = 3
+        self.attrs['axis'] = -1
+        self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(len(result_dist_attrs), 2)
+        self.assertEqual(len(infered_input_dist_attrs), 1)
+        self.assertEqual(len(infered_output_dist_attrs), 3)
+
+        self.assertEqual(
+            infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
+        )
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
+        )
+        self.assertEqual(
+            infered_output_dist_attrs[1].dims_mapping, [0, 1, -1, -1]
+        )
+        self.assertEqual(
+            infered_output_dist_attrs[2].dims_mapping, [0, 1, -1, -1]
+        )
+
+        # num_or_sections = [32, 32, 32], axis = 0
+        # [0, 1, -1, -1] --> [-1, 1, -1, -1], [-1, 1, -1, -1], [-1, 1, -1, -1], [-1, 1, -1, -1]
+        self.rule = get_spmd_rule("split")
+        self.attrs = {}
+        self.attrs['sections'] = [32, 32, 32]
+        self.attrs['axis'] = 0
+        self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(len(result_dist_attrs), 2)
+        self.assertEqual(len(infered_input_dist_attrs), 1)
+        self.assertEqual(len(infered_output_dist_attrs), 3)
+
+        self.assertEqual(
+            infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, -1]
+        )
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1, -1]
+        )
+        self.assertEqual(
+            infered_output_dist_attrs[1].dims_mapping, [-1, 1, -1, -1]
+        )
+        self.assertEqual(
+            infered_output_dist_attrs[2].dims_mapping, [-1, 1, -1, -1]
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
--
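Taken together, the rule can also be exercised end to end outside the unittest harness. The sketch below assumes the same static auto-parallel APIs the test imports and the split-axis de-sharding behavior asserted above:

    from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
    from paddle.distributed.auto_parallel.static.dist_attribute import (
        DistTensorSpec,
        TensorDistAttr,
    )
    from paddle.distributed.fleet import auto

    # A 3-D input sharded along the axis that is about to be split (axis 1).
    dist_attr = TensorDistAttr()
    dist_attr.process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
    dist_attr.dims_mapping = [-1, 0, -1]
    x_spec = DistTensorSpec([64, 32, 48], dist_attr)

    rule = get_spmd_rule("split_with_num")
    input_attrs, output_attrs = rule.infer_forward(
        [x_spec], {"num": 2, "axis": 1}
    )

    # The split axis is de-sharded on both sides of the op:
    assert input_attrs[0].dims_mapping == [-1, -1, -1]
    assert all(out.dims_mapping == [-1, -1, -1] for out in output_attrs)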