diff --git a/paddle/fluid/distributed/auto_parallel/CMakeLists.txt b/paddle/fluid/distributed/auto_parallel/CMakeLists.txt
index 9bffd1a7fb0814a0397435232e8acf21f9c8a559..9b2f7c237bf1b5b1510d931de71ddfa2c599043a 100644
--- a/paddle/fluid/distributed/auto_parallel/CMakeLists.txt
+++ b/paddle/fluid/distributed/auto_parallel/CMakeLists.txt
@@ -1,9 +1,10 @@
+add_subdirectory(spmd_rules)
+
 cc_library(
   op_dist_attr
   SRCS dist_attr.cc
   DEPS phi auto_parallel_proto proto_desc)
 
-cc_library(auto_parallel DEPS op_dist_attr spmd_rule)
+cc_library(auto_parallel DEPS op_dist_attr spmd_rules)
 
 add_subdirectory(test)
-add_subdirectory(spmd_rules)
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/CMakeLists.txt b/paddle/fluid/distributed/auto_parallel/spmd_rules/CMakeLists.txt
index d044a390f44f835da28ed48ece752ab26c69cbf9..42fde81693429c715d023534d00e90d3c6101115 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/CMakeLists.txt
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/CMakeLists.txt
@@ -1,5 +1,6 @@
-file(GLOB SPMD_SRCS "*.cc")
+file(GLOB spmd_srcs *.cc)
+
 cc_library(
-  spmd_rule
-  SRCS ${SPMD_SRCS}
+  spmd_rules
+  SRCS ${spmd_srcs}
   DEPS phi)
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
index f957412fc69388448bd7ea27591bc735ca2cb35c..7e17d2d34db081982492fdbfaf2cb1f7217f5f1a 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
@@ -39,7 +39,7 @@ SPMDRuleBase::InferBackward(const std::vector<DistTensorSpec>& output_specs,
 }
 
 std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
-    const std::vector<std::pair<const std::string, std::vector<int64_t>>>&
+    const std::vector<std::pair<std::string, std::vector<int64_t>>>&
         tensor_axes_to_dim_pairs) {
   std::unordered_map<std::string, int64_t> axis_to_dim_map;
   std::unordered_map<int64_t, std::string> dim_to_axis_map;
@@ -168,6 +168,56 @@ TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr) {
   return replicated_dist_attr;
 }
 
+void VerifySpecs(const std::vector<DistTensorSpec>& specs,
+                 const std::string& op_name) {
+  for (size_t i = 0, n = specs.size(); i < n; ++i) {
+    std::vector<int64_t> shape = specs[i].shape();
+    std::vector<int64_t> dims_mapping = specs[i].dims_mapping();
+    PADDLE_ENFORCE_EQ(shape.size(),
+                      dims_mapping.size(),
+                      phi::errors::InvalidArgument(
+                          "Mismatch in %s, spec[%d]'s tensor size: [%d] and "
+                          "spec[%d]'s dims_mapping size [%d].",
+                          op_name,
+                          i,
+                          shape.size(),
+                          i,
+                          dims_mapping.size()));
+  }
+}
+
+std::vector<std::pair<std::string, std::vector<int64_t>>>
+GetAxesDimsMappingPair(const std::vector<std::string>& tensor_axes,
+                       const std::vector<DistTensorSpec>& specs) {
+  std::vector<std::pair<std::string, std::vector<int64_t>>> res;
+  size_t ntensor = specs.size();
+  for (size_t i = 0; i < ntensor; ++i) {
+    res.emplace_back(std::pair<std::string, std::vector<int64_t>>(
+        tensor_axes[i], specs[i].dims_mapping()));
+  }
+  return res;
+}
+
+std::vector<int64_t> GetDimsMappingForAxes(
+    const std::string& axes,
+    const std::unordered_map<std::string, int64_t>& axis_to_dim_map) {
+  std::vector<int64_t> dims_mapping;
+  for (int64_t i = 0, n = axes.size(); i < n; i++) {
+    std::string axis = axes.substr(i, 1);
+    if (axis == "1") {
+      dims_mapping.emplace_back(-1);
+    } else {
+      auto iter = axis_to_dim_map.find(axis);
+      if (iter == axis_to_dim_map.end()) {
+        PADDLE_THROW(phi::errors::InvalidArgument(
+            "Tensor axis [%s] not in axis_to_dim_map.", axis));
+      }
+      dims_mapping.emplace_back(iter->second);
+    }
+  }
+  return dims_mapping;
+}
+
 // SPMDRuleMap
 SPMDRuleMap& SPMDRuleMap::Instance() {
   static SPMDRuleMap g_spmd_rule_map;
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
index a3dcbe0cce7cee7c73b49b696ca649c4dc835a1b..59a24cebae9e153de878051d4ab4a8971b03d68c 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
@@ -67,22 +67,12 @@ class SPMDRuleBase {
   inline const T ExtractAttr(
       const std::string& name,
       const paddle::framework::AttributeMap& attrs) const {
-    auto& attr = GetAttr(name, attrs);
-
-    // In order to get bool attr properly
-    framework::proto::AttrType attr_type =
-        static_cast<framework::proto::AttrType>(attr.index() - 1);
-    if (attr_type == framework::proto::AttrType::INT) {
-      if (std::is_same<bool, T>::value) {
-        return static_cast<bool>(PADDLE_GET_CONST(int, attr));
-      }
-    }
-
-    return PADDLE_GET_CONST(T, attr);
+    auto attr = GetAttr(name, attrs);
+    return *paddle::framework::ExtractAttribute<T>(name)(attr);
   }
 
-  const Attribute& GetAttr(const std::string& name,
-                           const paddle::framework::AttributeMap& attrs) const {
+  Attribute GetAttr(const std::string& name,
+                    const paddle::framework::AttributeMap& attrs) const {
     auto iter = attrs.find(name);
     PADDLE_ENFORCE_NE(iter,
                       attrs.end(),
@@ -95,7 +85,7 @@ class SPMDRuleBase {
 // Merge sharding specification (dims mapping) of given tensors.
 // The same axes of different tensors will be merged.
 std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
-    const std::vector<std::pair<const std::string, std::vector<int64_t>>>&
+    const std::vector<std::pair<std::string, std::vector<int64_t>>>&
        tensor_axes_to_dim_pairs);
 
 // Merge the sharding specification (dims mapping) for one tensor Axis.
@@ -133,6 +123,27 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim,
 // (unsharded).
 TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr);
 
+// Check whether the given DistTensorSpec objects are valid. For each
+// DistTensorSpec, the rank of its dims mapping must be equal to the rank of
+// its corresponding tensor shape. The parameter op_name is used for logging
+// error messages.
+void VerifySpecs(const std::vector<DistTensorSpec>& specs,
+                 const std::string& op_name);
+
+// Get dims mapping pairs for the given tensors. Return the pair of each
+// tensor's einsum notation and the corresponding dims mapping.
+std::vector<std::pair<std::string, std::vector<int64_t>>>
+GetAxesDimsMappingPair(const std::vector<std::string>& tensor_axes,
+                       const std::vector<DistTensorSpec>& specs);
+
+// Get dims mapping for the given axes according to the sharding information
+// of the annotated axes after inferring forward or backward. The parameter
+// axes stores the axes of the tensor. "1" is a special axis: for the axis
+// "1", set its dims mapping to -1.
+std::vector<int64_t> GetDimsMappingForAxes(
+    const std::string& axes,
+    const std::unordered_map<std::string, int64_t>& axis_to_dim_map);
+
 // The static map that stores and initializes all the registered SPMD rules.
 class SPMDRuleMap {
  public:
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6a19aed1f7fbc6bced6ab653ae25d5eadcd7b20a
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.cc
@@ -0,0 +1,140 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+using phi::distributed::auto_parallel::str_join;
+
+std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+ElementwiseSPMDRule::InferForward(
+    const std::vector<DistTensorSpec>& input_specs,
+    const paddle::framework::AttributeMap& attrs) {
+  // step0: Verify Input Args Based on Elementwise Logic
+  int64_t ninputs = input_specs.size();
+  PADDLE_ENFORCE_GT(
+      ninputs,
+      0,
+      phi::errors::InvalidArgument("The size of InputSpec in elementwise must "
+                                   "be greater than 0, but got [%d].",
+                                   ninputs));
+  VerifySpecs(input_specs, "elementwise");
+
+  // step1: Build Einsum Notation
+  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+  std::vector<std::string> input_axes_vec;
+  int64_t max_ndim = 0;
+  for (int64_t i = 0; i < ninputs; ++i) {
+    int64_t ndim = input_specs[i].shape().size();
+    if (ndim > max_ndim) {
+      max_ndim = ndim;
+    }
+  }
+
+  // get einsum notation for each input, deal with broadcast
+  std::vector<int64_t> broadcast_axis_count(max_ndim, 0);
+  for (int64_t i = 0; i < ninputs; ++i) {
+    std::vector<int64_t> shape = input_specs[i].shape();
+    int64_t ndim = shape.size();
+    int64_t start_dim = max_ndim - ndim;
+    std::string axes_notation = GetBroadcastAxes(ndim, max_ndim, alphabet);
+    if (ninputs > 1) {
+      for (int64_t idim = 0; idim < max_ndim; idim++) {
+        // deal with the broadcast axes, record the
+        // input number at each broadcast axis
+        if (idim < start_dim) {
+          broadcast_axis_count[idim] += 1;
+        } else if (shape[idim - start_dim] == 1) {
+          broadcast_axis_count[idim] += 1;
+          // mark the broadcast axis to a special "1"
+          axes_notation[idim - start_dim] = '1';
+        }
+      }
+    }
+    input_axes_vec.emplace_back(axes_notation);
+  }
+
+  // get einsum notation for output
+  std::string output_axes = GetBroadcastAxes(max_ndim, max_ndim, alphabet);
+  for (int64_t idim = 0; idim < max_ndim; idim++) {
+    // if all inputs broadcast at this dimension,
+    // mark this axis in output as broadcast
+    if (broadcast_axis_count[idim] == ninputs) {
+      output_axes[idim] = '1';
+    }
+  }
+
+  // step2: Sharding Propagation
+  // step2.1: merge input shardings
+  std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
+  axes_sharding_info = GetAxesDimsMappingPair(input_axes_vec, input_specs);
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors(axes_sharding_info);
+
+  // step2.2: infer output dims mapping from merged input dims mapping
+  std::vector<int64_t> output_dims_mapping =
+      GetDimsMappingForAxes(output_axes, axis_to_dim_map);
+
+  // initialize output dist_attr's process_mesh, batch_dim and dynamic dims
+  // with input dist_attr.
+  TensorDistAttr output_dist_attr =
+      CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
+  output_dist_attr.set_dims_mapping(output_dims_mapping);
+
+  std::vector<TensorDistAttr> new_input_dist_attrs;
+  std::vector<TensorDistAttr> output_dist_attrs;
+
+  // step2.3: update inputs' dims mapping with merged one.
+  for (int64_t i = 0; i < ninputs; i++) {
+    const DistTensorSpec& spec = input_specs[i];
+    TensorDistAttr dist_attr(spec.dist_attr());
+    std::vector<int64_t> new_dims_mapping =
+        GetDimsMappingForAxes(input_axes_vec[i], axis_to_dim_map);
+    dist_attr.set_dims_mapping(new_dims_mapping);
+    new_input_dist_attrs.emplace_back(dist_attr);
+  }
+
+  // step2.4: handle partial
+  // TODO: handle input tensor partial
+  VLOG(4) << "ElementwiseSPMDRule InferForward:";
+  for (int64_t i = 0; i < ninputs; i++) {
+    VLOG(4) << "Input" << std::to_string(i) << " shape: ["
+            << str_join(input_specs[i].shape()) << "] "
+            << "src_dims_mapping: [" << str_join(input_specs[i].dims_mapping())
+            << "] "
+            << "dst_dims_mapping: ["
+            << str_join(new_input_dist_attrs[i].dims_mapping()) << "]";
+  }
+  VLOG(4) << "Output dims_mapping: [" + str_join(output_dims_mapping) + "]\n\n";
+
+  output_dist_attrs.emplace_back(output_dist_attr);
+  return {new_input_dist_attrs, output_dist_attrs};
+}
+
+std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+ElementwiseSPMDRule::InferBackward(
+    const std::vector<DistTensorSpec>& output_specs,
+    const paddle::framework::AttributeMap& attrs) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "InferBackward of ElementwiseSPMDRule is NOT implemented yet."));
+
+  return {};
+}
+
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h
new file mode 100644
index 0000000000000000000000000000000000000000..113c34e4f43ab99b9ab98b9aadbfe01d8aff9905
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h
@@ -0,0 +1,40 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <iterator>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+class ElementwiseSPMDRule : public SPMDRuleBase {
+ public:
+  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+  InferForward(const std::vector<DistTensorSpec>& input_specs,
+               const paddle::framework::AttributeMap& attrs) override;
+
+  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+  InferBackward(const std::vector<DistTensorSpec>& output_specs,
+                const paddle::framework::AttributeMap& attrs) override;
+};
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
index ae3c767c99fdd5083b081d94e97f03d9c850cecb..f75379cc18950c4019628bee29d76bef38a16396 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
@@ -27,6 +28,94 @@ namespace auto_parallel {
 // matmul rule
 REGISTER_SPMD_RULE(matmul, MatmulSPMDRule);
 
+// elementwise rule
+REGISTER_SPMD_RULE(add, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(assign, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(assign_out_, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(divide, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(elementwise_pow, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(exponential_, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(floor_divide, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(fmin, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(hardswish, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(heaviside, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(maximum, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(minimum, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(mish, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(multiply, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(relu6, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(remainder, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(subtract, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(swish, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(acos, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(acosh, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(asin, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(asinh, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(atan, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(atanh, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(bernoulli, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(bitwise_and, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(bitwise_not, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(bitwise_or, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(bitwise_xor, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(ceil, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(celu, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(clip, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(conj, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(cos, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(cosh, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(det, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(digamma, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(elu, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(erf, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(erfinv, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(exp, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(expm1, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(fill, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(floor, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(fmax, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(gelu, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(hardshrink, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(hardsigmoid, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(hardtanh, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(label_smooth, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(leaky_relu, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(lgamma, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(log, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(log10, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(log1p, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(log2, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(logical_and, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(logical_not, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(logical_or, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(logical_xor, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(logit, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(logsigmoid, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(poisson, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(pow, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(reciprocal, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(relu, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(round, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(rsqrt, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(scale, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(selu, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(sigmoid, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(sign, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(silu, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(sin, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(sinh, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(softplus, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(softshrink, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(softsign, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(sqrt, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(square, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(stanh, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(tan, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(tanh, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(tanh_shrink, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(thresholded_relu, ElementwiseSPMDRule);
+REGISTER_SPMD_RULE(trunc, ElementwiseSPMDRule);
+
 // layer_norm rule
 REGISTER_SPMD_RULE(layer_norm, LayerNormSPMDRule);
 
diff --git a/paddle/fluid/distributed/auto_parallel/test/CMakeLists.txt b/paddle/fluid/distributed/auto_parallel/test/CMakeLists.txt
index fc370f2a512f83d78703d87d387211cc7501d451..b0beaad0f6b1fd4d2995f4835b42dc36c048ab07 100644
--- a/paddle/fluid/distributed/auto_parallel/test/CMakeLists.txt
+++ b/paddle/fluid/distributed/auto_parallel/test/CMakeLists.txt
@@ -15,4 +15,4 @@ cc_test(
 
 cc_test_old(dist_mapper_test SRCS dist_mapper_test.cc DEPS phi)
 
-cc_test_old(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rule)
+cc_test_old(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rules)
diff --git a/test/auto_parallel/spmd_rules/test_elementwise_rule.py b/test/auto_parallel/spmd_rules/test_elementwise_rule.py
new file mode 100644
index 0000000000000000000000000000000000000000..34e3194410cc18916adbde9c9d60d2353d1efb06
--- /dev/null
+++ b/test/auto_parallel/spmd_rules/test_elementwise_rule.py
@@ -0,0 +1,314 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
+from paddle.distributed.auto_parallel.static.dist_attribute import (
+    DistTensorSpec,
+    TensorDistAttr,
+)
+from paddle.distributed.fleet import auto
+
+
+class TestElementwiseSPMDRule(unittest.TestCase):
+    def setUp(self):
+        self.rule = get_spmd_rule("add")
+
+        x_shape = [64, 36]
+        y_shape = [64, 36]
+        process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
+
+        x_tensor_dist_attr = TensorDistAttr()
+        x_tensor_dist_attr.dims_mapping = [1, 0]
+        x_tensor_dist_attr.process_mesh = process_mesh
+        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
+
+        y_tensor_dist_attr = TensorDistAttr()
+        y_tensor_dist_attr.dims_mapping = [0, -1]
+        y_tensor_dist_attr.process_mesh = process_mesh
+        self.y_dist_tensor_spec = DistTensorSpec(y_shape, y_tensor_dist_attr)
+
+        self.attrs = {}
+
+    def test_single_mesh_dim(self):
+        # [0, -1], [-1, -1] --> [0, -1], [0, -1], [0, -1]
+        self.x_dist_tensor_spec.set_dims_mapping([0, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1])
+
+        # [0, -1], [-1, 0] --> [0, -1], [0, -1], [0, -1]
+        self.x_dist_tensor_spec.set_dims_mapping([0, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([-1, 0])
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1])
+
+        # [-1, -1], [-1, -1] --> [-1, -1], [-1, -1], [-1, -1]
+        self.x_dist_tensor_spec.set_dims_mapping([-1, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
+
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
+
+        # [-1, 0]--> [-1, 0], [-1, 0]
+        self.x_dist_tensor_spec.set_dims_mapping([-1, 0])
+
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0])
+
+    def test_single_mesh_dim_broadcast(self):
+        self.x_dist_tensor_spec.shape = [64, 36, 12]
+        self.y_dist_tensor_spec.shape = [12]
+
+        # [0, -1, -1], [-1] --> [0, -1, -1], [-1], [0, -1, -1]
+        self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([-1])
+
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(len(resulted_dist_attrs), 2)
+        self.assertEqual(len(infered_input_dist_attrs), 2)
+        self.assertEqual(len(infered_output_dist_attrs), 1)
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
+
+        # [-1, 0, -1], [-1] --> [-1, 0, -1], [-1], [-1, 0, -1]
+        self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([-1])
+
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1])
+        self.assertEqual((infered_input_dist_attrs[1].dims_mapping), [-1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1])
+
+        # [-1, -1, 0], [-1] --> [-1, -1, 0], [0], [-1, -1, 0]
+        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0])
+        self.y_dist_tensor_spec.set_dims_mapping([-1])
+
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
+        self.assertEqual((infered_input_dist_attrs[1].dims_mapping), [0])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])
+
+        # [-1, -1, -1], [0] --> [-1, -1, 0], [0], [-1, -1, 0]
+        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([0])
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])
+
+        self.x_dist_tensor_spec.shape = [64, 36, 12]
+        self.y_dist_tensor_spec.shape = [1, 12]
+        # [-1, 0, -1], [-1, -1] --> [-1, 0, -1], [-1, -1], [-1, 0, -1]
+        self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
+
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1])
+
+        self.x_dist_tensor_spec.shape = [64, 1, 1, 12]
+        self.y_dist_tensor_spec.shape = [64, 32, 12]
+        # [0, -1, -1, -1], [-1, -1, -1] --> [0, -1, -1, -1], [-1, -1, -1], [0, -1, -1, -1]
+        self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([-1, -1, -1])
+
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(
+            infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1]
+        )
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1, -1])
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, -1]
+        )
+
+        # [-1, -1, -1, -1], [0, -1, -1] --> [-1, -1, -1, -1], [0, -1, -1], [-1, 0, -1, -1]
+        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([0, -1, -1])
+
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(
+            infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1]
+        )
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1, -1])
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, -1]
+        )
+
+    def test_multi_mesh_dim(self):
+        process_mesh = auto.ProcessMesh([[0, 1, 2], [3, 4, 5]])
+        self.x_dist_tensor_spec.set_process_mesh(process_mesh)
+        self.y_dist_tensor_spec.set_process_mesh(process_mesh)
+        self.x_dist_tensor_spec.shape = [96, 24, 48]
+        self.y_dist_tensor_spec.shape = [96, 24, 48]
+
+        # [0, 1, -1], [-1, -1, -1] --> [0, 1, -1], [0, 1, -1], [0, 1, -1]
+        self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([-1, -1, -1])
+
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(len(resulted_dist_attrs), 2)
+        self.assertEqual(len(infered_input_dist_attrs), 2)
+        self.assertEqual(len(infered_output_dist_attrs), 1)
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, 1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1])
+
+        # [0, -1, -1], [-1, 1, 0] --> [0, 1, -1], [0, 1, -1], [0, 1, -1]
+        self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([-1, 1, 0])
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, 1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1])
+
+    def test_multi_mesh_dim_broadcast(self):
+        process_mesh = auto.ProcessMesh([[0, 1, 2], [3, 4, 5]])
+        self.x_dist_tensor_spec.set_process_mesh(process_mesh)
+        self.y_dist_tensor_spec.set_process_mesh(process_mesh)
+        self.x_dist_tensor_spec.shape = [96, 24, 48]
+        self.y_dist_tensor_spec.shape = [48]
+
+        # [0, -1, -1], [1] --> [0, -1, 1], [1], [0, -1, 1]
+        self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([1])
+
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(len(resulted_dist_attrs), 2)
+        self.assertEqual(len(infered_input_dist_attrs), 2)
+        self.assertEqual(len(infered_output_dist_attrs), 1)
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, 1])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1])
+
+        # [0, 1, -1], [0] --> [0, 1, -1], [-1], [0, 1, -1]
+        self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
+        self.y_dist_tensor_spec.set_dims_mapping([0])
+
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1])
+
+        self.x_dist_tensor_spec.shape = [96, 1, 1, 48]
+        self.y_dist_tensor_spec.shape = [96, 24, 48]
+        # [-1, -1, -1, 1], [0, -1, 1] --> [-1, -1, -1, 1], [0, -1, 1], [-1, 0, -1, 1]
+        self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1, 1])
+        self.y_dist_tensor_spec.set_dims_mapping([0, -1, 1])
+
+        resulted_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = resulted_dist_attrs[0]
+        infered_output_dist_attrs = resulted_dist_attrs[1]
+
+        self.assertEqual(
+            infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, 1]
+        )
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1, 1])
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
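
For reviewers who want to see the scheme in isolation: the einsum-style notation building and sharding merge that elementwise_spmd_rule.cc performs can be sketched in a few lines of plain Python. This is an illustrative sketch only; the helper names below are hypothetical (not Paddle APIs), and the conflict handling of ShardingMergeForTensors is simplified to "first non-replicated mapping wins".

# Illustrative sketch of the elementwise rule's notation/merge scheme.
# Hypothetical helpers; not Paddle APIs. Conflict resolution is simplified.


def broadcast_axes(ndim, max_ndim, alphabet="abcdefghijklmnopqrstuvwxyz"):
    # Rightmost `ndim` letters out of the first `max_ndim`, so the trailing
    # dimensions of differently ranked inputs share the same axis letter.
    return alphabet[max_ndim - ndim : max_ndim]


def infer_elementwise_forward(shapes, dims_mappings):
    max_ndim = max(len(s) for s in shapes)
    # Build one notation per input; size-1 (broadcast) dims become "1".
    notations = []
    for shape in shapes:
        axes = list(broadcast_axes(len(shape), max_ndim))
        for i, d in enumerate(shape):
            if d == 1:
                axes[i] = "1"
        notations.append(axes)
    # Merge shardings: an axis is sharded if any input shards it
    # (the real ShardingMergeForTensors also resolves conflicts).
    axis_to_dim = {}
    for axes, mapping in zip(notations, dims_mappings):
        for axis, dim in zip(axes, mapping):
            if axis != "1" and dim != -1:
                axis_to_dim.setdefault(axis, dim)

    def mapping_for(axes):
        return [axis_to_dim.get(a, -1) if a != "1" else -1 for a in axes]

    out_axes = broadcast_axes(max_ndim, max_ndim)
    return [mapping_for(a) for a in notations], mapping_for(out_axes)


# Mirrors the first broadcast case in the unit test above:
# x: [64, 36, 12] sharded [0, -1, -1]; y: [12] sharded [-1]
print(infer_elementwise_forward([[64, 36, 12], [12]], [[0, -1, -1], [-1]]))
# -> ([[0, -1, -1], [-1]], [0, -1, -1])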