提交 180edcc2 编写于 作者: L liangjianzhong

matmul main logic done

上级 f7e39d75
......@@ -31,19 +31,11 @@ class SPMDRuleBase {
virtual std::vector<DistTensorSpec> InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(
phi::errors::Unimplemented("InferForward should be called from a "
"derived class of SPMDRuleBase !"));
}
const paddle::framework::AttributeMap& attrs);
virtual std::vector<DistTensorSpec> InferBackward(
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(
phi::errors::Unimplemented("InferBackward should be called from a "
"derived class of SPMDRuleBase !"));
}
const paddle::framework::AttributeMap& attrs);
template <typename T>
inline const T& ExtractAttr(
......@@ -62,54 +54,11 @@ class SPMDRuleBase {
platform::errors::NotFound("(%s) is not found in AttributeMap."));
return iter->second;
}
}
};
std::unordered_map<std::string, int64_t>
ShardingMergeForTensors(
std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
tensor_notation_to_dim_pairs) {
std::unordered_map<std::string, int64_t> axis_to_dim_map;
std::unordered_map<int64_t, std::string> dim_to_axis_map;
int64_t merge_dim;
for (auto& pair : tensor_notation_to_dim_pairs) {
for (int i = 0; i < pair.second.size(); i++) {
auto tensor_axis = pair.first.substr(i, 1);
auto mesh_dim = pair.second[i];
if (axis_to_dim_map.count(tensor_axis) == 0) {
merge_dim = mesh_dim;
} else {
merge_dim = ShardingMergeForAxis(
tensor_axis, mesh_dim, axis_to_dim_map[tensor_axis]);
}
axis_to_dim_map.insert({tensor_axis, merge_dim});
if (dim_to_axis_map.count(merge_dim) == 0) {
dim_to_axis_map.insert({merge_dim, tensor_axis});
} else {
dim_to_axis_map[merge_dim] += tensor_axis;
}
}
}
// Resolute "mesh_dim shard by more than one axis" confict.
// Now we just naive pick the first axis naively.
// (TODO) use local cost model to pick the axis with lowest cost(in concern of
// memory or communication or computation).
for (auto& it : dim_to_axis_map) {
if (it.second.size() > 1) {
VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first
<< "] are Sharding Multiple Tensor Axis: [" << it.second
<< "]. The Axis: [" << it.second[0] << "] is Picked.";
for (int i = 1; i < it.second.size(); i++) {
axis_to_dim_map[it.second.substr(i, 1)] = -1;
}
}
}
return axis_to_dim_map;
}
tensor_notation_to_dim_pairs);
// Rule1: A replicated dimension could be merged by any sharded dimension.
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
......@@ -117,26 +66,13 @@ ShardingMergeForTensors(
// multiple dimension case.)
// Merges two mesh-dimension assignments proposed for the same tensor axis.
//
// `axis` is only used in the error message. `mesh_dim1` / `mesh_dim2` are the
// two candidate mesh dimensions; -1 is treated as "not sharded" and yields to
// the other value.
//
// Returns the common value when both agree, otherwise whichever one is not
// -1. Throws Unimplemented (via PADDLE_THROW) when the axis is sharded by two
// different mesh dimensions.
int64_t ShardingMergeForAxis(const std::string axis,
                             const int64_t mesh_dim1,
                             const int64_t mesh_dim2) {
  // Identical assignments (including both being -1) need no merging.
  if (mesh_dim1 == mesh_dim2) {
    return mesh_dim1;
  }
  if (mesh_dim1 == -1) {
    return mesh_dim2;
  }
  if (mesh_dim2 == -1) {
    return mesh_dim1;
  }
  // (TODO) local cost model here.
  PADDLE_THROW(
      phi::errors::Unimplemented("Tensor Axis[%s] is Sharded by two "
                                 "different mesh dimension [%d] and [%d].",
                                 axis,
                                 mesh_dim1,
                                 mesh_dim2));
}
TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr);
std::vector<int64_t> ResoluteOutputPartialDimension(
const std::unordered_map<std::string, int64_t>& in_axis_to_dim_map,
const std::string& out_axis);
} // namespace auto_parallel
} // namespace distributed
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
namespace paddle {
......@@ -129,10 +127,61 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
input_pairs;
input_pairs.push_back(x_pair);
input_pairs.push_back(y_pair);
auto dim_to_sharding = ShardingMerge(input_pairs);
auto axis_to_dim_map = ShardingMergeForTensors(input_pairs);
// step2.2: fill output's dim mapping.
TensorDistAttr output_dist_attr_dst =
CopyTensorDistAttrForOutput(input_specs[0].DistAttr) std::vector<int64_t>
out_dims_mapping;
out_dims_mapping.reserve(out_string.size());
for (int i = 0; i < out_string.size(); ++i) {
out_dims_mapping.push_back(axis_to_dim_map[out_string.substr(i, 1)]);
}
output_dist_attr_dst.set_dims_mapping(out_dims_mapping);
// step2.3: fill input's dim mapping.
TensorDistAttr x_dist_attr_dst = GetInferedDistAttr(
input_specs[0].DistAttr, input_specs[0].shape, x_string, axis_to_dim_map);
TensorDistAttr y_dist_attr_dst = GetInferedDistAttr(
input_specs[1].DistAttr, input_specs[1].shape, y_string, axis_to_dim_map);
// step2.3: Handle Broadcast
// step2.3: Handle Partial
// Step2.3.1 Output Partial
std::vector<int64_t> partial_on_dims =
ResoluteOutputPartialDimension(axis_to_dim_map, out_string);
// Step2.3.2 handle input tensor partial (TODO)
VLOG(4) << "MatmulSPMDRule InferForward: "
<< "X shape: " << input_specs[0].shape
<< ", src_dims_mapping: " << x_dims_mapping
<< ", dst_dims_mapping: " << x_dist_attr_dst.dims_mapping
<< "; Y shape: " << input_specs[1].shape
<< ", src_dims_mapping: " << x_dims_mapping
<< ", dst_dims_mapping: " << y_dist_attr_dst.dims_mapping
<< "; Output dims_mapping: " << out_dims_mapping
<< ", partial_on_dims: " << partial_on_dims;
}
// Builds the distributed attribute of one tensor from the merged
// axis-to-mesh-dimension mapping.
//
// `origin_dist_attr` is passed to CopyTensorDistAttrForOutput to seed the
// result. `shape` is the tensor shape and `tensor_axis` its axis notation,
// one character per dimension. `axis_to_dim_map` maps each one-character
// axis name to a mesh dimension.
//
// Returns the seeded attribute with its dims mapping replaced: dimension i
// gets -1 when shape[i] == 1 (presumably because a size-1 dimension cannot be
// split -- TODO confirm), otherwise the mesh dimension stored for its axis.
// Throws std::out_of_range if an axis is missing from `axis_to_dim_map`.
TensorDistAttr GetInferedDistAttr(
    const TensorDistAttr& origin_dist_attr,
    const std::vector<int>& shape,
    const std::string& tensor_axis,
    const std::unordered_map<std::string, int64_t>& axis_to_dim_map) {
  TensorDistAttr dist_attr_ = CopyTensorDistAttrForOutput(origin_dist_attr);

  std::vector<int64_t> infered_dims_mapping;
  infered_dims_mapping.reserve(tensor_axis.size());

  for (size_t i = 0; i < tensor_axis.size(); ++i) {
    if (i < shape.size() && shape[i] == 1) {
      infered_dims_mapping.push_back(-1);
    } else {
      // The map is const, so operator[] is unavailable; at() also avoids
      // silently inventing a mapping for an unknown axis.
      infered_dims_mapping.push_back(
          axis_to_dim_map.at(tensor_axis.substr(i, 1)));
    }
  }

  dist_attr_.set_dims_mapping(infered_dims_mapping);
  return dist_attr_;
}
std::vector<DistTensorSpec> MatmulSPMDRule::InferBackward(
......@@ -142,40 +191,3 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferBackward(
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/// @brief
// int max_dim = 0;
// int ndim = 0;
// std::vector<int> intput_ndims;
// for (auto& input_spec : input_specs){
// ndim = input_spec.shape().size();
// intput_ndims.push_back(ndim);
// if (ndim > max_dim) {
// max_dim = ndim;
// }
// }
// std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
// std::vector<std::string> input_dim_chars;
// for (auto& intput_ndim : intput_ndims){
// input_dim_chars.push_back(alphabet.substr(max_dim - intput_ndim,
// intput_ndim));
// }
// int max_dim = 0;
// int ndim = 0;
// std::vector<int> intput_ndims;
// for (auto& input_spec : input_specs){
// ndim = input_spec.shape().size();
// intput_ndims.push_back(ndim);
// if (ndim > max_dim) {
// max_dim = ndim;
// }
// }
// std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
// std::vector<std::string> input_dim_chars;
// for (auto& intput_ndim : intput_ndims){
// input_dim_chars.push_back(alphabet.substr(max_dim - intput_ndim,
// intput_ndim));
// }
......@@ -25,6 +25,12 @@ namespace paddle {
namespace distributed {
namespace auto_parallel {
// Derives a tensor's distributed attribute from a merged axis-to-mesh-dim map.
//
// origin_dist_attr: attribute used to seed the result (copied through
//                   CopyTensorDistAttrForOutput).
// shape:            tensor shape; a dimension of size 1 is mapped to -1.
// tensor_axis:      axis notation of the tensor, one character per dimension.
// axis_to_dim_map:  mesh dimension for each one-character axis name.
//
// Returns the seeded attribute with its dims mapping set accordingly.
TensorDistAttr GetInferedDistAttr(
const TensorDistAttr& origin_dist_attr,
const std::vector<int>& shape,
const std::string& tensor_axis,
const std::unordered_map<std::string, int64_t>& axis_to_dim_map);
class MatmulSPMDRule : public SPMDRuleBase {
public:
std::vector<DistTensorSpec> InferForward(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册