diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.cc
new file mode 100644
index 0000000000000000000000000000000000000000..64cecdeef96553507875347a6734ab8c4b9250b5
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.cc
@@ -0,0 +1,52 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+DistTensorSpec::DistTensorSpec(const std::vector<int64_t>& shape,
+                               const TensorDistAttr& dist_attr) {
+  shape_.assign(shape.begin(), shape.end());
+  // the newly inferred distributed attributes are merged with the original
+  // ones after inference, so we keep a copy of the original dist_attr here
+  dist_attr_.copy_from(dist_attr);
+}
+
+DistTensorSpec::~DistTensorSpec() {}
+
+const std::vector<int64_t>& DistTensorSpec::get_dims_mapping() {
+  return dist_attr_.dims_mapping();
+}
+
+void DistTensorSpec::set_dims_mapping(
+    const std::vector<int64_t>& dims_mapping) {
+  dist_attr_.set_dims_mapping(dims_mapping);
+}
+
+const ProcessMesh& DistTensorSpec::get_process_mesh() {
+  return dist_attr_.process_mesh();
+}
+
+void DistTensorSpec::set_process_mesh(const ProcessMesh& process_mesh) {
+  dist_attr_.set_process_mesh(process_mesh);
+}
+
+const std::vector<int64_t>& DistTensorSpec::get_shape() { return shape_; }
+
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h
new file mode 100644
index 0000000000000000000000000000000000000000..21fe7c41caddc8e9d983220071d2f2b1d9dfb3a6
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h
@@ -0,0 +1,55 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+/**
+ * A unified data class for inferring distributed attributes
+ * in both dygraph mode and static mode
+ */
+class DistTensorSpec {
+ public:
+  DistTensorSpec(const std::vector<int64_t>& shape,
+                 const TensorDistAttr& dist_attr);
+
+  ~DistTensorSpec();
+
+  // get dims_mapping from dist_attr_
+  const std::vector<int64_t>& get_dims_mapping();
+
+  // set dims_mapping in dist_attr_
+  void set_dims_mapping(const std::vector<int64_t>& dims_mapping);
+
+  // get process_mesh from dist_attr_
+  const ProcessMesh& get_process_mesh();
+
+  // set process_mesh in dist_attr_
+  void set_process_mesh(const ProcessMesh& process_mesh);
+
+  const std::vector<int64_t>& get_shape();
+
+ private:
+  std::vector<int64_t> shape_;
+  // distributed attributes of the corresponding tensor
+  TensorDistAttr dist_attr_;
+};
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
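
For context, a minimal usage sketch of `DistTensorSpec` from the point of view of an SPMD rule is shown below; it is not part of the diff. The `DistTensorSpec` calls come from this change, while the default-constructed `TensorDistAttr` and the `ProcessMesh(shape, process_ids, dim_names)` constructor are assumptions based on the existing `dist_attr.h` / `process_mesh.h` interfaces and may need adjusting.

```cpp
// Usage sketch only -- not part of this diff. Assumes TensorDistAttr is
// default-constructible and ProcessMesh takes (shape, process_ids,
// dim_names), as in the existing auto_parallel headers.
#include <cstdint>
#include <string>
#include <vector>

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"

namespace auto_parallel = paddle::distributed::auto_parallel;

void DistTensorSpecExample() {
  // a 2x2 mesh over processes {0, 1, 2, 3} with axes named "x" and "y"
  auto_parallel::ProcessMesh mesh({2, 2}, {0, 1, 2, 3}, {"x", "y"});

  // shard dim 0 of a [6, 12] tensor along mesh axis 0; -1 means replicated
  auto_parallel::TensorDistAttr dist_attr;
  dist_attr.set_process_mesh(mesh);
  dist_attr.set_dims_mapping({0, -1});

  // the spec copies dist_attr in its constructor, so later set_* calls
  // do not touch the caller's attribute
  auto_parallel::DistTensorSpec spec({6, 12}, dist_attr);

  std::vector<int64_t> shape = spec.get_shape();          // {6, 12}
  std::vector<int64_t> mapping = spec.get_dims_mapping();  // {0, -1}

  // an SPMD rule would write the inferred mapping back into the spec
  spec.set_dims_mapping({-1, 1});
  (void)shape;
  (void)mapping;
}
```

Because the constructor goes through `dist_attr_.copy_from(dist_attr)`, a rule can mutate the spec's mapping freely during inference and merge the result back explicitly, rather than modifying the operator's original `TensorDistAttr` in place.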