Unverified commit 6863e2ae, authored by JZ-LIANG, committed by GitHub

[Semi-Auto] SPMD Parallel Rule Base (#53863)

* base rule

* add sharding merge

* add sharding axis merge

* define unified data class for inferring dist_attr

* test wrap DistTensorSpec in dygraph mode

* matmul main logic done

* define unified data class for inferring dist_attr

---------
Co-authored-by: Yichen Zhang <zhangyichen03@baidu.com>
Parent 689e27af
@@ -3,4 +3,7 @@ cc_library(
SRCS dist_attr.cc
DEPS phi auto_parallel_proto proto_desc)
cc_library(auto_parallel DEPS op_dist_attr spmd_rule)
add_subdirectory(test)
add_subdirectory(spmd_rules)
cc_library(
spmd_rule
SRCS common.cc dist_tensor_spec.cc matmul_spmd_rule.cc
DEPS phi)
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include <glog/logging.h>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SPMDRuleBase::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(
phi::errors::Unimplemented("InferForward should be called from a "
"derived class of SPMDRuleBase !"));
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SPMDRuleBase::InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(
phi::errors::Unimplemented("InferBackward should be called from a "
"derived class of SPMDRuleBase !"));
}
std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
tensor_axes_to_dim_pairs) {
std::unordered_map<std::string, int64_t> axis_to_dim_map;
std::unordered_map<int64_t, std::string> dim_to_axis_map;
int64_t merge_dim;
for (auto& pair : tensor_axes_to_dim_pairs) {
for (size_t i = 0; i < pair.second.size(); ++i) {
auto tensor_axis = pair.first.substr(i, 1);
auto mesh_dim = pair.second[i];
if (axis_to_dim_map.count(tensor_axis) == 0) {
merge_dim = mesh_dim;
} else {
merge_dim = ShardingMergeForAxis(
tensor_axis, mesh_dim, axis_to_dim_map[tensor_axis]);
}
axis_to_dim_map[tensor_axis] = merge_dim;
if (merge_dim != -1) {
if (dim_to_axis_map.count(merge_dim) == 0) {
dim_to_axis_map.insert({merge_dim, tensor_axis});
} else if (dim_to_axis_map[merge_dim].find(tensor_axis) ==
std::string::npos) {
dim_to_axis_map[merge_dim] += tensor_axis;
}
}
}
}
// Resolve the "one mesh_dim shards more than one tensor axis" conflict.
// For now we just naively pick the first axis.
// (TODO) use a local cost model to pick the axis with the lowest cost (in
// terms of memory, communication or computation).
for (auto& it : dim_to_axis_map) {
if (it.second.size() > 1) {
VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first
<< "] are Sharding Multiple Tensor Axis: [" << it.second
<< "]. The Axis: [" << it.second[0] << "] is Picked.";
for (size_t i = 1; i < it.second.size(); ++i) {
axis_to_dim_map[it.second.substr(i, 1)] = -1;
}
}
}
return axis_to_dim_map;
}
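// Example (illustrative): for a 2-D matmul written in einsum notation as
// "mk,kn->mn", merging {"mk", [1, 0]} with {"kn", [-1, -1]} yields the map
// {m: 1, k: 0, n: -1}; the sharding of axis "k" on mesh dim 0 is kept and
// later propagated back to the "kn" tensor.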
// Rule1: A replicated dimension could be merged by any sharded dimension.
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
// (TODO: trigger a heuristic cost model and reshard to handle the case where
// an axis is sharded by multiple mesh dimensions.)
int64_t ShardingMergeForAxis(const std::string& axis,
const int64_t& mesh_dim1,
const int64_t& mesh_dim2) {
if (mesh_dim1 != mesh_dim2) {
if (mesh_dim1 == -1) {
return mesh_dim2;
} else if (mesh_dim2 == -1) {
return mesh_dim1;
} else {
// (TODO) local cost model here.
PADDLE_THROW(
phi::errors::Unimplemented("Tensor Axis[%s] is Sharded by two "
"different mesh dimension [%d] and [%d].",
axis,
mesh_dim1,
mesh_dim2));
}
} else {
return mesh_dim1;
}
}
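// Example (illustrative): ShardingMergeForAxis("k", -1, 0) returns 0, since a
// replicated dimension is overridden by a sharded one; ShardingMergeForAxis(
// "k", 1, 0) throws, since axis "k" would be sharded by two mesh dimensions.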
TensorDistAttr CopyTensorDistAttrForOutput(
const TensorDistAttr& src_dist_attr) {
TensorDistAttr new_dist_attr = TensorDistAttr();
new_dist_attr.set_process_mesh(src_dist_attr.process_mesh());
new_dist_attr.set_batch_dim(src_dist_attr.batch_dim());
new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims());
// new_dist_attr.set_annotated(false); TODO unset field is false by default.
return new_dist_attr;
}
std::vector<int64_t> ResoluteOutputPartialDimension(
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const std::string& tensor_axes) {
std::vector<int64_t> partial_on_dims;
for (auto& it : axis_to_dim_map) {
if (tensor_axes.find(it.first) == std::string::npos) {
if (it.second > -1) {
partial_on_dims.push_back(it.second);
}
}
}
return partial_on_dims;
}
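// Example (illustrative): with axis_to_dim_map = {m: 1, k: 0, n: -1} and
// output axes "mn", the contracted axis "k" is sharded on mesh dim 0 but does
// not appear in the output, so the result is {0}: the output tensor is
// partial (pending reduction) on mesh dim 0.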
std::string GetBroadcastAxes(const int64_t& tensor_ndim,
const int64_t& broadcast_ndim,
const std::string& alphabet) {
PADDLE_ENFORCE_GE(
alphabet.size(),
broadcast_ndim,
phi::errors::InvalidArgument(
"size of alphabet [%d] is less than broadcast ndim [%d]",
alphabet.size(),
broadcast_ndim));
PADDLE_ENFORCE_GE(broadcast_ndim,
tensor_ndim,
phi::errors::InvalidArgument(
"broadcast ndim [%d] is less than tensor ndim [%d]",
broadcast_ndim,
tensor_ndim));
if (tensor_ndim <= 0) {
return std::string();
}
return alphabet.substr(broadcast_ndim - tensor_ndim, tensor_ndim);
}
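// Example (illustrative): GetBroadcastAxes(2, 4, "abcdefg") returns "cd",
// i.e. the axes of a rank-2 tensor are aligned to the two rightmost axes of
// the rank-4 broadcast shape; GetBroadcastAxes(0, 2, "abcdefg") returns "".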
// SPMDRuleMap
SPMDRuleMap& SPMDRuleMap::Instance() {
static SPMDRuleMap g_spmd_rule_map;
return g_spmd_rule_map;
}
// A default replicated spmd rule could be enabled for ops that are NOT
// registered, in which all input and output tensors would be replicated
// across all ranks of the mesh.
SPMDRuleBase* SPMDRuleMap::Get(const std::string& op_type) const {
auto rule_ptr = GetNullable(op_type);
if (rule_ptr == nullptr) {
std::string str;
for (const auto& item : map_) {
str += item.first + ", ";
}
VLOG(4) << "Size of current map [" << map_.size() << "]";
VLOG(4) << "Keys are [" << str << "]";
}
PADDLE_ENFORCE_NOT_NULL(
rule_ptr,
platform::errors::NotFound(
"NO SPMD Rule has been registered for Operator [%s].", op_type));
return rule_ptr;
}
SPMDRuleBase* SPMDRuleMap::GetNullable(const std::string& op_type) const {
auto it = map_.find(op_type);
if (it == map_.end()) {
return nullptr;
} else {
return it->second.get();
}
}
int SPMDRuleMap::Insert(const std::string& op_type,
std::unique_ptr<SPMDRuleBase> rule) {
VLOG(4) << "Call SPMDRuleMap::Insert!";
PADDLE_ENFORCE_NE(
Has(op_type),
true,
platform::errors::AlreadyExists(
"SPMD Rule for Operator [%s] has been registered.", op_type));
map_.insert({op_type, std::move(rule)});
return 1;
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iterator>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/utils/flat_hash_map.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using paddle::framework::Attribute;
class SPMDRuleBase {
public:
virtual ~SPMDRuleBase() {}
// Based on the information of input tensors and op attributes:
// 1. Merge the sharding (dims_mapping) among input tensors.
// 2. Infer the sharding (dims_mapping) for output tensors.
// The info of input tensors (shape and dist_attr) is wrapped as
// DistTensorSpec, and op attributes should be given as an AttributeMap. The
// output is a pair consisting of two vectors:
// 1. The first vector: the merged DistAttr of input tensors.
// 2. The second vector: the inferred DistAttr of output tensors.
// The merged DistAttr might differ from the original input DistAttrs,
// which means that the corresponding input tensors need to be resharded.
virtual std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs);
// Based on the information of output tensors and op attributes:
// 1. Merge the sharding (dims_mapping) among output tensors.
// 2. Infer the sharding (dims_mapping) for input tensors.
// The info of output tensors (shape and dist_attr) is wrapped as
// DistTensorSpec, and op attributes should be given as an AttributeMap. The
// output is a pair consisting of two vectors:
// 1. The first vector: the merged DistAttr of output tensors.
// 2. The second vector: the inferred DistAttr of input tensors.
virtual std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs);
template <typename T>
inline const T ExtractAttr(
const std::string& name,
const paddle::framework::AttributeMap& attrs) const {
auto& attr = GetAttr(name, attrs);
// In order to get bool attr properly
framework::proto::AttrType attr_type =
static_cast<framework::proto::AttrType>(attr.index() - 1);
if (attr_type == framework::proto::AttrType::INT) {
if (std::is_same<bool, T>::value) {
return static_cast<bool>(PADDLE_GET_CONST(int, attr));
}
}
return PADDLE_GET_CONST(T, attr);
}
const Attribute& GetAttr(const std::string& name,
const paddle::framework::AttributeMap& attrs) const {
auto iter = attrs.find(name);
PADDLE_ENFORCE_NE(iter,
attrs.end(),
paddle::platform::errors::NotFound(
"(%s) is not found in AttributeMap."));
return iter->second;
}
};
// Merge sharding specification (dims mapping) of given tensors.
// The same axes of different tensors will be merged.
std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
tensor_axes_to_dim_pairs);
// Merge the sharding specification (dims mapping) for one tensor axis.
// Rule1: A replicated dimension could be merged by any sharded dimension.
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
// (TODO: trigger a heuristic cost model and reshard to handle the case where
// an axis is sharded by multiple mesh dimensions.)
int64_t ShardingMergeForAxis(const std::string& axis,
const int64_t& mesh_dim1,
const int64_t& mesh_dim2);
TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr);
// Resolve the partial mesh dimensions of an output tensor, given the merged
// sharding specification of the input tensors and the axis names of the
// output tensor.
std::vector<int64_t> ResoluteOutputPartialDimension(
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const std::string& tensor_axes);
// Generate the axis notation of a tensor for the einsum notation of a
// broadcast operation (alignment starts from the rightmost axis).
// tensor_ndim: the rank of the tensor. broadcast_ndim: the maximum rank of
// the tensors in this broadcast operation. alphabet: the characters used to
// represent the tensor axes; the length of alphabet should be >= broadcast_ndim.
std::string GetBroadcastAxes(const int64_t& tensor_ndim,
const int64_t& broadcast_ndim,
const std::string& alphabet);
// The static map that stores and initializes all the registered SPMD rules.
class SPMDRuleMap {
public:
~SPMDRuleMap() = default;
// A singleton
static SPMDRuleMap& Instance();
// Returns the spmd rule for the given op_type
SPMDRuleBase* Get(const std::string& op_type) const;
// Returns the spmd rule by op_type, or nullptr if not registered
SPMDRuleBase* GetNullable(const std::string& op_type) const;
// Register a spmd rule for an op_type.
int Insert(const std::string& op_type, std::unique_ptr<SPMDRuleBase> rule);
bool Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}
private:
SPMDRuleMap() = default;
paddle::flat_hash_map<std::string, std::unique_ptr<SPMDRuleBase>> map_;
DISABLE_COPY_AND_ASSIGN(SPMDRuleMap);
};
#define REGISTER_SPMD_RULE(op_type, rule_class, ...) \
UNUSED static int __spmd_rule_holder_##op_type = \
::paddle::distributed::auto_parallel::SPMDRuleMap::Instance().Insert( \
#op_type, std::make_unique<rule_class>(__VA_ARGS__))
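// Usage sketch (illustrative): a rule is registered once per op type, e.g.
//   REGISTER_SPMD_RULE(matmul, MatmulSPMDRule);
// and later looked up and invoked as
//   auto* rule = SPMDRuleMap::Instance().Get("matmul");
//   auto dist_attrs = rule->InferForward(input_specs, attrs);
// Get() throws a NotFound error if no rule has been registered for op_type.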
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
DistTensorSpec::DistTensorSpec(const std::vector<int64_t>& shape,
const TensorDistAttr& dist_attr) {
shape_.assign(shape.begin(), shape.end());
// we should merge the newly inferred distributed attributes with the original
// ones after inference, so we keep a copy of the original dist_attr here
dist_attr_.copy_from(dist_attr);
}
DistTensorSpec::DistTensorSpec(const DistTensorSpec& spec) {
std::vector<int64_t> spec_shape = spec.shape();
shape_.assign(spec_shape.begin(), spec_shape.end());
dist_attr_.copy_from(spec.dist_attr());
}
DistTensorSpec::~DistTensorSpec() {}
DistTensorSpec::DistTensorSpec(const Tensor& tensor) {
shape_ = tensor.shape();
}
DistTensorSpec& DistTensorSpec::operator=(const DistTensorSpec& spec) {
std::vector<int64_t> spec_shape = spec.shape();
shape_ = spec_shape;
dist_attr_.copy_from(spec.dist_attr());
return *this;
}
const std::vector<int64_t>& DistTensorSpec::dims_mapping() const {
return dist_attr_.dims_mapping();
}
void DistTensorSpec::set_dims_mapping(
const std::vector<int64_t>& dims_mapping) {
dist_attr_.set_dims_mapping(dims_mapping);
}
const ProcessMesh& DistTensorSpec::process_mesh() const {
return dist_attr_.process_mesh();
}
void DistTensorSpec::set_process_mesh(const ProcessMesh& process_mesh) {
dist_attr_.set_process_mesh(process_mesh);
}
const std::vector<int64_t>& DistTensorSpec::shape() const { return shape_; }
void DistTensorSpec::set_shape(const std::vector<int64_t>& shape) {
shape_ = shape;
}
const TensorDistAttr& DistTensorSpec::dist_attr() const { return dist_attr_; }
void DistTensorSpec::set_dist_attr(const TensorDistAttr& dist_attr) {
dist_attr_ = dist_attr;
}
std::string DistTensorSpec::to_string() const {
using phi::distributed::auto_parallel::str_join;
std::string spec_str = "{tensor_shape:[" + str_join(shape_) + "], ";
spec_str += "dist_attr:" + dist_attr_.to_string() + "}";
return spec_str;
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::ProcessMesh;
using phi::distributed::auto_parallel::TensorDistAttr;
/**
* A unified data class for inferring distributed attributes
* in both dygraph mode and static mode
*/
class DistTensorSpec {
public:
DistTensorSpec() = default;
DistTensorSpec(const std::vector<int64_t>& shape,
const TensorDistAttr& dist_attr);
DistTensorSpec(const DistTensorSpec& spec);
// temp function, only for test in dygraph mode
explicit DistTensorSpec(const Tensor& tensor);
~DistTensorSpec();
DistTensorSpec& operator=(const DistTensorSpec& spec);
// get dims_mapping from dist_attr_
const std::vector<int64_t>& dims_mapping() const;
// set dims_mapping in dist_attr_
void set_dims_mapping(const std::vector<int64_t>& dims_mapping);
// get process_mesh from dist_attr_
const ProcessMesh& process_mesh() const;
// set process_mesh in dist_attr_
void set_process_mesh(const ProcessMesh& process_mesh);
const TensorDistAttr& dist_attr() const;
void set_dist_attr(const TensorDistAttr& dist_attr);
const std::vector<int64_t>& shape() const;
void set_shape(const std::vector<int64_t>& shape);
std::string to_string() const;
private:
std::vector<int64_t> shape_;
// distributed attributes of the corresponding tensor
TensorDistAttr dist_attr_;
};
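// Usage sketch (illustrative): build a spec from a shape and a TensorDistAttr
// and pass it to a rule, e.g.
//   TensorDistAttr x_dist_attr;
//   x_dist_attr.set_process_mesh(process_mesh);
//   x_dist_attr.set_dims_mapping({1, -1});
//   DistTensorSpec x_spec({64, 32}, x_dist_attr);
// as done in spmd_rule_test.cc.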
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
MatmulSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: verify input args based on matmul logic
auto input_specs_size = input_specs.size();
PADDLE_ENFORCE_EQ(
input_specs_size,
2,
phi::errors::InvalidArgument(
"The size of InputSpec of matmul should be 2, but got [%d].",
input_specs_size));
auto x_shape = input_specs[0].shape();
auto y_shape = input_specs[1].shape();
int x_ndim = x_shape.size();
int y_ndim = y_shape.size();
auto x_dist_attr_src = input_specs[0].dist_attr();
auto y_dist_attr_src = input_specs[1].dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
std::vector<int64_t> y_dims_mapping = y_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ(
x_ndim,
x_dims_mapping.size(),
phi::errors::InvalidArgument(
"Mismatch of X's tensor size: [%d] and X's dims_mapping size [%d].",
x_ndim,
x_dims_mapping.size()));
PADDLE_ENFORCE_EQ(
y_ndim,
y_dims_mapping.size(),
phi::errors::InvalidArgument(
"Mismatch of Y's tensor size: [%d] and Y's dims_mapping size [%d].",
y_ndim,
y_dims_mapping.size()));
bool trans_x = ExtractAttr<bool>("trans_x", attrs);
bool trans_y = ExtractAttr<bool>("trans_y", attrs);
// Step2.3.2 handle input tensor partial (TODO)
VLOG(4) << "MatmulSPMDRule InferForward Inputs: "
<< "X shape: [" << str_join(x_shape) << "], x_dims_mapping: ["
<< str_join(x_dims_mapping) << "]; Y shape: [" << str_join(y_shape)
<< "], y_dims_mapping: [" << str_join(y_dims_mapping)
<< "]; trans_x: "
<< "[" << (trans_x ? "true" : "false") << "]; "
<< "trans_y: "
<< "[" << (trans_y ? "true" : "false") << "]; ";
// step1: build Einsum Notation
// reserve the chars 'k', 'm', 'n' for the matrix product notation: mk,kn -> mn
int max_ndim = std::max(x_ndim, y_ndim);
std::string alphabet = "abcdefghijlopqrstuvwxyz";
std::string x_axes;
std::string y_axes;
std::string out_axes;
// Handle 4 different matmul cases in Paddle
// vector * vector = scalar
if (x_ndim == 1 && y_ndim == 1) {
x_axes = "k";
y_axes = "k";
out_axes = "";
// vector * batched matrix
} else if (x_ndim == 1 && y_ndim > 1) {
x_axes = "k";
std::string y_broadcast_axes =
GetBroadcastAxes(y_ndim - 2, y_ndim - 2, alphabet);
y_axes = y_broadcast_axes + "kn";
out_axes = y_broadcast_axes + "n";
// batched matrix * vector
} else if (x_ndim > 1 && y_ndim == 1) {
y_axes = "k";
std::string x_broadcast_axes =
GetBroadcastAxes(x_ndim - 2, x_ndim - 2, alphabet);
x_axes = x_broadcast_axes + "mk";
out_axes = x_broadcast_axes + "m";
// batched matrix * batched matrix
} else if (x_ndim > 1 && y_ndim > 1) {
std::string x_broadcast_axes =
GetBroadcastAxes(x_ndim - 2, max_ndim - 2, alphabet);
std::string y_broadcast_axes =
GetBroadcastAxes(y_ndim - 2, max_ndim - 2, alphabet);
x_axes = x_broadcast_axes + "mk";
y_axes = y_broadcast_axes + "kn";
if (x_ndim > y_ndim) {
out_axes = x_broadcast_axes + "mn";
} else {
out_axes = y_broadcast_axes + "mn";
}
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"MatmulSPMDRule Receive Unsupported x_dim [%d] and y_dim [%d].",
x_ndim,
y_ndim));
}
VLOG(4) << "MatmulSPMDRule build Einsum notation: [" << x_axes << ","
<< y_axes << " --> " << out_axes << "].";
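// Example (illustrative): for x_ndim = 4 and y_ndim = 2 the notation built
// above is x_axes = "abmk", y_axes = "kn", out_axes = "abmn".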
// step2: Sharding Propagation
if (trans_x) {
PADDLE_ENFORCE_GE(
x_ndim,
2,
phi::errors::InvalidArgument("When trans_x is True, the size of X "
"tensor should be 2, but got [%d].",
x_ndim));
std::iter_swap(x_dims_mapping.end() - 2, x_dims_mapping.end() - 1);
}
if (trans_y) {
PADDLE_ENFORCE_GE(
y_ndim,
2,
phi::errors::InvalidArgument("When trans_x is True, the size of X "
"tensor should be 2, but got [%d].",
y_ndim));
std::iter_swap(y_dims_mapping.end() - 2, y_dims_mapping.end() - 1);
}
// step2.1: Sharding Merge
std::pair<std::string, std::vector<int64_t>> x_pair(x_axes, x_dims_mapping);
std::pair<std::string, std::vector<int64_t>> y_pair(y_axes, y_dims_mapping);
auto axis_to_dim_map = ShardingMergeForTensors({x_pair, y_pair});
// step2.2: Infer Output's Dims Mapping.
TensorDistAttr output_dist_attr_dst =
CopyTensorDistAttrForOutput(x_dist_attr_src);
std::vector<int64_t> out_dims_mapping;
out_dims_mapping.reserve(out_axes.size());
for (size_t i = 0; i < out_axes.size(); ++i) {
out_dims_mapping.push_back(axis_to_dim_map[out_axes.substr(i, 1)]);
}
output_dist_attr_dst.set_dims_mapping(out_dims_mapping);
// step2.3: Merge and get Inputs' New Dims Mapping.
TensorDistAttr x_dist_attr_dst = GetInferedDistAttr(
x_dist_attr_src, x_shape, x_axes, axis_to_dim_map, trans_x);
TensorDistAttr y_dist_attr_dst = GetInferedDistAttr(
y_dist_attr_src, y_shape, y_axes, axis_to_dim_map, trans_y);
// step2.3: Handle Partial
// Step2.3.1 Output Partial
std::vector<int64_t> partial_on_dims =
ResoluteOutputPartialDimension(axis_to_dim_map, out_axes);
// Step2.3.2 handle input tensor partial (TODO)
VLOG(4) << "MatmulSPMDRule InferForward: "
<< "X shape: [" << str_join(x_shape) << "], src_dims_mapping: ["
<< str_join(x_dist_attr_src.dims_mapping())
<< "], dst_dims_mapping: ["
<< str_join(x_dist_attr_dst.dims_mapping()) << "]; Y shape: ["
<< str_join(y_shape) << "], src_dims_mapping: ["
<< str_join(y_dist_attr_src.dims_mapping())
<< "], dst_dims_mapping: ["
<< str_join(y_dist_attr_dst.dims_mapping())
<< "]; Output dims_mapping: [" << str_join(out_dims_mapping)
<< "], partial_on_dims: [" << str_join(partial_on_dims) << "]";
return {{x_dist_attr_dst, y_dist_attr_dst}, {output_dist_attr_dst}};
}
TensorDistAttr GetInferedDistAttr(
const TensorDistAttr& origin_dist_attr,
const std::vector<int64_t>& shape,
const std::string& tensor_axis,
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const bool trans_axis) {
TensorDistAttr dist_attr_ = CopyTensorDistAttrForOutput(origin_dist_attr);
std::vector<int64_t> infered_dims_mapping;
infered_dims_mapping.reserve(tensor_axis.size());
for (size_t i = 0; i < tensor_axis.size(); ++i) {
if (shape.size() > i && shape[i] == 1) {
infered_dims_mapping.push_back(-1);
} else {
auto itr = axis_to_dim_map.find(tensor_axis.substr(i, 1));
if (itr == axis_to_dim_map.end()) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Tensor axis [%s] is not in axis_to_dim_map.",
tensor_axis.substr(i, 1)));
}
infered_dims_mapping.push_back(itr->second);
}
}
if (trans_axis) {
std::iter_swap(infered_dims_mapping.end() - 2,
infered_dims_mapping.end() - 1);
}
dist_attr_.set_dims_mapping(infered_dims_mapping);
return dist_attr_;
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
MatmulSPMDRule::InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of MatmulSPMDRule is NOT implemented yet."));
return {};
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iterator>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
TensorDistAttr GetInferedDistAttr(
const TensorDistAttr& origin_dist_attr,
const std::vector<int64_t>& shape,
const std::string& tensor_axes,
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const bool trans_axis);
class MatmulSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
// TODO(ljz) Automate this process in the cmake file.
namespace paddle {
namespace distributed {
namespace auto_parallel {
// matmul rule
REGISTER_SPMD_RULE(matmul, MatmulSPMDRule);
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
@@ -13,7 +13,6 @@ cc_test(
SRCS dist_attr_test.cc
DEPS phi proto_desc)
cc_test(
dist_mapper_test
SRCS dist_mapper_test.cc
DEPS phi)
cc_test_old(dist_mapper_test SRCS dist_mapper_test.cc DEPS phi)
cc_test_old(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rule)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <sstream>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
TEST(MatmulSPMDRule, Ctor) {
// build input data class
std::vector<int64_t> x_shape = {64, 32};
std::vector<int64_t> y_shape = {32, 48};
std::vector<int64_t> mesh_shape = {2, 3};
std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
std::vector<std::string> dim_names = {"x", "y"};
ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
TensorDistAttr x_dist_attr = TensorDistAttr();
x_dist_attr.set_process_mesh(process_mesh);
x_dist_attr.set_dims_mapping(std::vector<int64_t>({1, -1}));
x_dist_attr.set_batch_dim(-1);
x_dist_attr.set_dynamic_dims(std::vector<bool>({false, false}));
TensorDistAttr y_dist_attr = TensorDistAttr();
y_dist_attr.set_process_mesh(process_mesh);
y_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, -1}));
y_dist_attr.set_batch_dim(-1);
y_dist_attr.set_dynamic_dims(std::vector<bool>({false, false}));
DistTensorSpec x_dist_tensor_spec = DistTensorSpec(x_shape, x_dist_attr);
DistTensorSpec y_dist_tensor_spec = DistTensorSpec(y_shape, y_dist_attr);
paddle::framework::AttributeMap attrs;
attrs["trans_x"] = false;
attrs["trans_y"] = false;
SPMDRuleBase* matmul_rule = SPMDRuleMap::Instance().Get("matmul");
// mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = mn[1, -1] partial[]
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
infered_dist_attrs = matmul_rule->InferForward(
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs);
size_t input_size = 2;
size_t output_size = 1;
EXPECT_EQ(infered_dist_attrs.first.size(), input_size);
EXPECT_EQ(infered_dist_attrs.second.size(), output_size);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({1, -1}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({-1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({1, -1}));
VLOG(4) << "test1 done." << std::endl << std::endl << std::endl;
// mk[-1,-1],kn[-1,0] --> mk[-1,-1],kn[-1,0] = mn[-1,0] partial[]
x_dist_tensor_spec.set_dims_mapping({-1, -1});
y_dist_tensor_spec.set_dims_mapping({-1, 0});
infered_dist_attrs = matmul_rule->InferForward(
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({-1, -1}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({-1, 0}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({-1, 0}));
VLOG(4) << "test2 done." << std::endl << std::endl << std::endl;
// mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = mn[1, -1] partial[0]: done
x_dist_tensor_spec.set_dims_mapping({1, 0});
y_dist_tensor_spec.set_dims_mapping({-1, -1});
infered_dist_attrs = matmul_rule->InferForward(
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({1, 0}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({0, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({1, -1}));
VLOG(4) << "test3 done." << std::endl << std::endl << std::endl;
// mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = mn[-1, 0] partial[1]: done
x_dist_tensor_spec.set_dims_mapping({-1, -1});
y_dist_tensor_spec.set_dims_mapping({1, 0});
infered_dist_attrs = matmul_rule->InferForward(
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({-1, 1}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({1, 0}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({-1, 0}));
VLOG(4) << "test4 done." << std::endl << std::endl << std::endl;
// abcmk[0, 1, -1, -1],kn[-1, -1] --> abcmk[0, 1, -1, -1],kn[-1, -1] =
// abcmn[0, 1, -1, -1] partial[]: done
x_dist_tensor_spec.set_shape({512, 48, 64, 32});
x_dist_tensor_spec.set_dims_mapping({0, 1, -1, -1});
y_dist_tensor_spec.set_dims_mapping({-1, -1});
infered_dist_attrs = matmul_rule->InferForward(
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({0, 1, -1, -1}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({-1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({0, 1, -1, -1}));
VLOG(4) << "test5 done." << std::endl << std::endl << std::endl;
// abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1,
// -1, -1, -1] partial[0]: done
x_dist_tensor_spec.set_dims_mapping({1, -1, -1, 0});
y_dist_tensor_spec.set_dims_mapping({-1, -1});
infered_dist_attrs = matmul_rule->InferForward(
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({1, -1, -1, 0}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({0, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({1, -1, -1, -1}));
VLOG(4) << "test6 done." << std::endl << std::endl << std::endl;
// abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] =
// abcmn[1, -1, 0, -1] partial[]: done
x_dist_tensor_spec.set_dims_mapping({1, -1, -1, 0});
y_dist_tensor_spec.set_dims_mapping({-1, -1});
attrs["trans_x"] = true;
infered_dist_attrs = matmul_rule->InferForward(
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({1, -1, -1, 0}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({-1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({1, -1, 0, -1}));
VLOG(4) << "test7 done." << std::endl << std::endl << std::endl;
// abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
// abcmn[-1, -1, -1, 1] partial[0]: done
x_dist_tensor_spec.set_dims_mapping({-1, -1, -1, -1});
y_dist_tensor_spec.set_dims_mapping({1, 0});
attrs["trans_x"] = false;
attrs["trans_y"] = true;
infered_dist_attrs = matmul_rule->InferForward(
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({-1, -1, -1, 0}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({1, 0}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({-1, -1, -1, 1}));
VLOG(4) << "test8 done." << std::endl << std::endl << std::endl;
// abcmk[-1, -1, 0, 1], kn[1, 0] --> abcmk[-1, -1, 0, 1],kn[-1, 0] =
// abcmn[-1, -1, 1, -1] partial[0]: done
x_dist_tensor_spec.set_dims_mapping({-1, -1, 0, 1});
y_dist_tensor_spec.set_dims_mapping({1, 0});
attrs["trans_y"] = true;
attrs["trans_x"] = true;
infered_dist_attrs = matmul_rule->InferForward(
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs);
EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
std::vector<int64_t>({-1, -1, 0, 1}));
EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
std::vector<int64_t>({-1, 0}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({-1, -1, 1, -1}));
VLOG(4) << "test9 done." << std::endl << std::endl << std::endl;
// abcmk[-1, -1, 1, 0], kn[1, 0] --> error:
// one mesh dim shards multiple tensor axes
x_dist_tensor_spec.set_dims_mapping({-1, -1, 1, 0});
y_dist_tensor_spec.set_dims_mapping({1, 0});
attrs["trans_y"] = true;
attrs["trans_x"] = true;
EXPECT_ANY_THROW(infered_dist_attrs = matmul_rule->InferForward(
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs));
// Error
VLOG(4) << "test10 done." << std::endl << std::endl << std::endl;
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
@@ -47,7 +47,8 @@ set(PYBIND_DEPS
jit_property
prim_utils
static_tensor_operants
type_info)
type_info
auto_parallel)
if(WITH_PSCORE)
set(PYBIND_DEPS ${PYBIND_DEPS} ps_service)
@@ -24,12 +24,18 @@
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/utils/optional.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
using paddle::distributed::auto_parallel::DistTensorSpec;
using paddle::distributed::auto_parallel::OperatorDistAttr;
using paddle::distributed::auto_parallel::SPMDRuleBase;
using paddle::distributed::auto_parallel::SPMDRuleMap;
using paddle::framework::OpDesc;
using paddle::framework::VarDesc;
using phi::distributed::auto_parallel::Device;
@@ -281,6 +287,29 @@ void BindAutoParallel(py::module *m) {
py::arg("memo"))
.def("__str__", &TensorDistAttr::to_string);
py::class_<SPMDRuleBase>(*m, "SPMDRuleBase")
.def("infer_forward", &SPMDRuleBase::InferForward)
.def("infer_backward", &SPMDRuleBase::InferBackward);
py::class_<DistTensorSpec>(*m, "DistTensorSpec")
.def(py::init<>())
.def(py::init<const DistTensorSpec &>())
.def(py::init<const std::vector<int64_t> &, const TensorDistAttr &>())
.def("dims_mapping", &DistTensorSpec::dims_mapping)
.def("set_dims_mapping", &DistTensorSpec::set_dims_mapping)
.def("process_mesh", &DistTensorSpec::process_mesh)
.def("set_process_mesh", &DistTensorSpec::set_process_mesh)
.def_property("shape", &DistTensorSpec::shape, &DistTensorSpec::set_shape)
.def("__str__", &DistTensorSpec::to_string)
.def("__copy__",
[](const DistTensorSpec &self) { return DistTensorSpec(self); })
.def(
"__deepcopy__",
[](const DistTensorSpec &self, py::dict) {
return DistTensorSpec(self);
},
py::arg("memo"));
py::class_<OperatorDistAttr>(*m, "OperatorDistAttr")
.def(py::init<>())
.def(py::init<const OpDesc &>())
@@ -384,6 +413,13 @@ void BindAutoParallel(py::module *m) {
py::arg("memo"))
.def("__str__", &OperatorDistAttr::to_string);
m->def(
"get_spmd_rule",
[](const std::string op_type) {
return SPMDRuleMap::Instance().Get(op_type);
},
py::return_value_policy::reference);
// TODO(liuzhenhai): DistributedMapper is not used for now, but
// dist_mapper_test needs the symbols for DistributedMapper to be linked,
// remove it later
@@ -1278,6 +1278,17 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
}}
"""
def gen_dist_tensor_code(self):
# define the DistTensorSpec vector for input tensors
api_code = " \n std::vector<paddle::distributed::auto_parallel::DistTensorSpec> input_specs;\n"
# get DistTensorSpec for each input tensor
for tensor_name in self.inputs['names']:
api_code += f" input_specs.emplace_back(paddle::distributed::auto_parallel::DistTensorSpec({tensor_name}));\n"
api_code += "\n"
return api_code
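# Illustrative output of gen_dist_tensor_code (assuming input tensors named
# 'x' and 'y'): the generated snippet would read
#   std::vector<paddle::distributed::auto_parallel::DistTensorSpec> input_specs;
#   input_specs.emplace_back(paddle::distributed::auto_parallel::DistTensorSpec(x));
#   input_specs.emplace_back(paddle::distributed::auto_parallel::DistTensorSpec(y));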
def gene_base_api_code(self, inplace_flag=False):
api_func_name = self.get_api_func_name()
if inplace_flag and api_func_name[-1] != '_':
@@ -1286,6 +1297,8 @@ PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
PADDLE_API {self.get_return_type(inplace_flag)} {api_func_name}({self.get_define_args(inplace_flag)}) {{
{self.gene_kernel_select()}
"""
# if api_func_name == 'matmul':
# api_code += self.gen_dist_tensor_code()
if len(self.kernel['func']) > 1:
kernel_dispatch_code = ''
@@ -379,6 +379,8 @@ def source_include(header_file_path):
#include "paddle/phi/api/profiler/event_tracing.h"
#include "paddle/phi/api/profiler/supplement_tracing.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
DECLARE_bool(conv2d_disable_cudnn);
DECLARE_int32(low_precision_op_list);
"""
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <map>
#include <sstream>
#include <string>
#include <unordered_map>
@@ -16,6 +16,7 @@ import copy
import logging
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.core import get_spmd_rule # noqa: F401
from paddle.framework import core
from ..process_mesh import ProcessMesh, compute_compatible_process_mesh
@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License
from paddle.fluid.core import DistTensorSpec # noqa: F401
from paddle.fluid.core import OperatorDistAttr # noqa: F401
from paddle.fluid.core import TensorDistAttr # noqa: F401
@@ -28,7 +28,7 @@ from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter
from paddle.static import Variable
from ..process_mesh import ProcessMesh
from .dist_attribute import OperatorDistAttr, TensorDistAttr
from .dist_attribute import DistTensorSpec, OperatorDistAttr, TensorDistAttr
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
@@ -2380,3 +2380,66 @@ def use_new_executor():
'True',
'true',
]
def wrap_data_for_completion(
dist_op, input_names: list, output_names: list, attr_names: list
):
"""
Get data used in inferring distributed attributes, including:
1. DistTensorSpec for each input and output tensor of this dist_op.
2. Operator attributes of this dist_op, e.g. trans_x in the matmul op.
Args:
dist_op: the DistributedOperator
input_names: list, name of the dist_op's input tensors
output_names: list, name of the dist_op's output tensors
attr_names: list, attribute name of the dist_op's corresponding serial op
Returns:
input_specs: list, DistTensorSpec for each input tensor of the dist_op
output_specs: list, DistTensorSpec for each output tensor of the dist_op
attrs: dict, attribute map of the dist op
Usage:
op_desc = dist_op.serial_op.desc
input_name_list = []
output_name_list = []
input_name_list.append(op_desc.input('X')[0]) # 'X' is the arg name for op
input_name_list.append(op_desc.input('Y')[0])
output_name_list.append(op_desc.output('Out')[0])
attr_name_list = ['trans_x', 'trans_y']
input_specs, output_specs, attrs = wrap_data_for_completion(
dist_op,
input_name_list,
output_name_list,
attr_name_list)
"""
input_specs = []
output_specs = []
attrs = {}
serial_op = dist_op.serial_op
# Construct each input tensor's DistTensorSpec with shape and dist_attr
for name in input_names:
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(name)
var = serial_op.block._var_recursive(name)
tensor_shape = var.shape
dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
input_specs.append(dist_spec)
# Construct each output tensor's DistTensorSpec with shape and dist_attr
for name in output_names:
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(name)
var = serial_op.block._var_recursive(name)
tensor_shape = var.shape
dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr)
output_specs.append(dist_spec)
for attr_name in attr_names:
attrs[attr_name] = serial_op.desc.attr(attr_name)
return input_specs, output_specs, attrs
# file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
# string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
if(WITH_DISTRIBUTE AND WITH_GPU)
# NOTE(zyl): unittests WITH single card and WITHOUT timeout
py_test_modules(test_matmul_rule MODULES test_matmul_rule)
# End of unittests WITH single card WITHOUT timeout
endif()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
DistTensorSpec,
TensorDistAttr,
)
from paddle.distributed.fleet import auto
class TestMatmulSPMDRule(unittest.TestCase):
def setUp(self):
self.rule = get_spmd_rule("matmul")
x_shape = [64, 32]
y_shape = [32, 48]
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.dims_mapping = [1, 0]
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
y_tensor_dist_attr = TensorDistAttr()
y_tensor_dist_attr.dims_mapping = [0, -1]
y_tensor_dist_attr.process_mesh = process_mesh
self.y_dist_tensor_spec = DistTensorSpec(y_shape, y_tensor_dist_attr)
self.attrs = {
'trans_x': False,
'trans_y': False,
}
def test_matmul_infer_forward(self):
# TODO test partial: mk[1, 0],kn[0, -1] --> mk[1, 0],kn[0, -1] = mn[1, -1] partial[0]
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 2)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1])
# test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = mn[1, -1] partial[]
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1])
# test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = mn[1, -1] partial[]
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1])
# test n parallel: mk[-1, -1],kn[-1, 0] --> mk[-1, -1],kn[-1, 0] = mn[-1, 0] partial[]
self.x_dist_tensor_spec.set_dims_mapping([-1, -1])
self.y_dist_tensor_spec.set_dims_mapping([-1, 0])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, 0])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0])
# test partial with propagation: mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = mn[1, -1] partial[0]
self.x_dist_tensor_spec.set_dims_mapping([1, 0])
self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1])
# mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]:
self.x_dist_tensor_spec.set_dims_mapping([-1, -1])
self.y_dist_tensor_spec.set_dims_mapping([1, 0])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, 0])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0])
# abcmk[0, 1, -1, -1],kn[-1, -1] --> abcmk[0, 1, -1, -1],kn[-1, -1] = abcmn[0, 1, -1, -1] partial[]: done
self.x_dist_tensor_spec.shape = [512, 48, 64, 32]
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)
# abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1,-1, -1, -1] partial[0]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, 0]
)
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1]
)
# trans_x = True, abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = abcmn[1, -1, 0, -1] partial[]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
self.attrs['trans_x'] = True
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, 0]
)
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1]
)
# trans_y = True, abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = abcmn[-1, -1, -1, 1] partial[0]: done
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1, -1])
self.y_dist_tensor_spec.set_dims_mapping([1, 0])
self.attrs['trans_x'] = False
self.attrs['trans_y'] = True
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, 0]
)
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, 0])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 1]
)
# trans_y = True, trans_x = True, abcmk[-1, -1, 0, 1], kn[1, 0] --> abcmk[-1, -1, 0, 1]],kn[-1, 0] = abcmn[-1, -1, 1, -1] partial[0]
# multiple mesh dim shard same tensor axis
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1])
self.y_dist_tensor_spec.set_dims_mapping([1, 0])
self.attrs['trans_x'] = True
self.attrs['trans_y'] = True
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, 1]
)
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, 0])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, 1, -1]
)
# trans_y = True, trans_x = True, abcmk[-1, -1, 1, 0], kn[1, 0] --> error:
# one mesh dim shard multiple tensor axes
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 1, 0])
self.y_dist_tensor_spec.set_dims_mapping([1, 0])
self.attrs['trans_x'] = True
self.attrs['trans_y'] = True
with self.assertRaises(NotImplementedError):
self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
)
if __name__ == "__main__":
unittest.main()