未验证 提交 e3b6e02f 编写于 作者: J JZ-LIANG 提交者: GitHub

[Semi AutoParall] Support Partial Semantic I (#55508)

上级 dd1379e4
......@@ -160,6 +160,7 @@ MatmulSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
// Step2.3.1 Output Partial
std::vector<int64_t> partial_on_dims =
ResoluteOutputPartialDimension(axis_to_dim_map, out_axes);
output_dist_attr_dst.set_partial_status(partial_on_dims);
// Step2.3.2 handle input tensor partial (TODO)
VLOG(4) << "MatmulSPMDRule InferForward: "
......
......@@ -88,13 +88,15 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
output_dist_attr.set_dims_mapping(output_dims_mapping);
std::vector<TensorDistAttr> output_dist_attrs;
output_dist_attrs.emplace_back(output_dist_attr);
// step2.4: handle partial
// Step2.4.1 Output Partial
std::vector<int64_t> partial_on_dims =
ResoluteOutputPartialDimension(axis_to_dim_map, output_axes);
output_dist_attr.set_partial_status(
partial_on_dims /*, handle reduce_type in future */);
std::vector<TensorDistAttr> output_dist_attrs;
output_dist_attrs.emplace_back(output_dist_attr);
// Step2.4.2 handle input tensor partial (TODO)
// If the op is a linear op, i.e. `linearity` is true, it supports
......
......@@ -293,7 +293,11 @@ void BindAutoParallel(py::module *m) {
return TensorDistAttr(self);
},
py::arg("memo"))
.def("__str__", &TensorDistAttr::to_string);
.def("__str__", &TensorDistAttr::to_string)
.def("_is_partial", &TensorDistAttr::is_partial)
.def("_partial_dims", &TensorDistAttr::partial_dims)
.def("_clean_partial_dims", &TensorDistAttr::clean_partial_dims)
.def("_clean_partial_status", &TensorDistAttr::clean_partial_status);
py::class_<SPMDRuleBase>(*m, "SPMDRuleBase")
.def("infer_forward", &SPMDRuleBase::InferForward)
......
......@@ -24,6 +24,7 @@ namespace phi {
namespace distributed {
namespace auto_parallel {
// partial is not allowed to be annotated by the user for now.
std::vector<std::string> TensorDistAttr::fields_{
"process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"};
......@@ -44,6 +45,7 @@ TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
std::swap(this->batch_dim_, tmp.batch_dim_);
std::swap(this->dynamic_dims_, tmp.dynamic_dims_);
std::swap(this->annotated_, tmp.annotated_);
std::swap(this->partial_status_, tmp.partial_status_);
return *this;
}
......@@ -53,6 +55,7 @@ void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
set_batch_dim(dist_attr.batch_dim());
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
set_partial_status(dist_attr.partial_status());
}
void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
......@@ -77,6 +80,44 @@ void TensorDistAttr::set_annotated(
annotated_ = annotated;
}
// Returns the mesh dims this tensor is partial on, i.e. the keys of
// partial_status_, collected into an ordered std::set.
const std::set<int64_t> TensorDistAttr::partial_dims() const {
  std::set<int64_t> dims;
  for (const auto& entry : partial_status_) {
    dims.insert(entry.first);
  }
  return dims;
}
// Replaces the entire partial map at once (<mesh dim -> reduce type>).
// No validation is performed here; the new map is taken as-is.
void TensorDistAttr::set_partial_status(
    const paddle::flat_hash_map<int64_t, ReduceType>& partial_status) {
  partial_status_ = partial_status;
}
// Marks every mesh dim in `dims` as partial with reduce type `type`.
// Throws InvalidArgument if a dim is already marked partial.
void TensorDistAttr::set_partial_status(const std::vector<int64_t>& dims,
                                        const ReduceType& type) {
  for (int64_t mesh_dim : dims) {
    if (partial_status_.find(mesh_dim) != partial_status_.end()) {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "Trying to Set dim %d as Partial which is already a Partial dim.",
          mesh_dim));
    }
    partial_status_.emplace(mesh_dim, type);
  }
}
// Drops every partial mark; afterwards the tensor is not partial on any mesh dim.
void TensorDistAttr::clean_partial_status() { partial_status_.clear(); }
// Removes the partial mark of each mesh dim in `dims`.
// Throws InvalidArgument if a given dim is not currently partial.
void TensorDistAttr::clean_partial_dims(const std::vector<int64_t>& dims) {
  for (int64_t mesh_dim : dims) {
    auto itr = partial_status_.find(mesh_dim);
    if (itr == partial_status_.end()) {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "Trying to clean Partial on dim %d but it is not Partial.",
          mesh_dim));
    }
    partial_status_.erase(itr);
  }
}
void TensorDistAttr::set_default_dims_mapping(
const std::vector<int64_t>& tensor_shape) {
if (!tensor_shape.empty()) {
......@@ -178,6 +219,20 @@ bool TensorDistAttr::verify_annotated(
return true;
}
// Validates every entry of partial_status_: the key must be a mesh dim inside
// [0, process_mesh_.ndim()), and the value must be a valid ReduceType.
// NOTE(review): the second check reads `itr.second <= ReduceType::ALL`, which
// holds for EVERY enum value, so this function returns false whenever
// partial_status_ is non-empty. The stage-I tests in this commit assert that
// behavior (verify fails for runtime-only partial), but it looks like
// `itr.second > ReduceType::ALL` may have been intended — confirm.
bool TensorDistAttr::verify_partial_status() const {
  VLOG(4) << "[TensorDistAttr verify_partial_status] "
          << partial_status_string();
  for (auto& itr : partial_status_) {
    // Mesh dim out of range.
    if (itr.first < 0 || itr.first >= process_mesh_.ndim()) {
      return false;
    }
    if (itr.second < ReduceType::SUM || itr.second <= ReduceType::ALL) {
      return false;
    }
  }
  return true;
}
bool TensorDistAttr::verify(const std::vector<int64_t>& tensor_shape) const {
if (!verify_process_mesh(process_mesh_)) {
return false;
......@@ -194,6 +249,9 @@ bool TensorDistAttr::verify(const std::vector<int64_t>& tensor_shape) const {
if (!verify_annotated(annotated_)) {
return false;
}
if (!verify_partial_status()) {
return false;
}
return true;
}
......@@ -203,7 +261,8 @@ std::string TensorDistAttr::to_string() const {
dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], ";
dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", ";
dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], ";
dist_str += "annotated: [" + str_join(annotated_) + "]}";
dist_str += "annotated: [" + str_join(annotated_) + "], ";
dist_str += "partial: " + partial_status_string() + ".}";
return dist_str;
}
......@@ -267,9 +326,23 @@ bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
if (lhs.dynamic_dims() != rhs.dynamic_dims()) {
return false;
}
if (lhs.partial_status() != rhs.partial_status()) {
return false;
}
return true;
}
std::string TensorDistAttr::partial_status_string() const {
std::string partial_status_str = "[";
for (auto& itr : partial_status_) {
partial_status_str += "Partial(dims:" + std::to_string(itr.first) + ", " +
ReduceTypeStrings[static_cast<int>(itr.second)] +
"), ";
}
partial_status_str += "]";
return partial_status_str;
}
} // namespace auto_parallel
} // namespace distributed
} // namespace phi
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/flat_hash_map.h"
namespace phi {
namespace distributed {
......@@ -32,13 +33,25 @@ namespace auto_parallel {
constexpr const char* kDefault = "default";
// Reduce semantics attached to a partial tensor: the cross-rank reduction
// required to materialize the full value from its partial shards.
enum class ReduceType : std::uint8_t {
  SUM = 0,
  AVG,
  MAX,
  MIN,
  PRODUCT,
  ANY,
  ALL
};

// Printable names, indexed by static_cast<int>(ReduceType).
constexpr const char* ReduceTypeStrings[] = {
    "SUM", "AVG", "MAX", "MIN", "PRODUCT", "ANY", "ALL"};

// Compile-time guard: keep the string table in lockstep with the enum so a
// new ReduceType cannot silently print a wrong (or out-of-bounds) name.
static_assert(sizeof(ReduceTypeStrings) / sizeof(ReduceTypeStrings[0]) ==
                  static_cast<int>(ReduceType::ALL) + 1,
              "ReduceTypeStrings must stay in sync with ReduceType");
class TensorDistAttr {
public:
TensorDistAttr() = default;
explicit TensorDistAttr(const std::vector<int64_t>& tensor_shape);
TensorDistAttr(const TensorDistAttr& tensor);
TensorDistAttr(const TensorDistAttr& dist_attr);
TensorDistAttr& operator=(const TensorDistAttr& dist_attr);
......@@ -52,6 +65,29 @@ class TensorDistAttr {
void set_dims_mapping(const std::vector<int64_t>& dims_mapping);
// true if tensor is partial on any mesh dim.
bool is_partial() const { return !partial_status_.empty(); }
// return vector of mesh dims on which the this tensor is partial on
const std::set<int64_t> partial_dims() const;
const paddle::flat_hash_map<int64_t, ReduceType>& partial_status() const {
return partial_status_;
}
// by map
void set_partial_status(
const paddle::flat_hash_map<int64_t, ReduceType>& partial_status);
// by each dim
void set_partial_status(const std::vector<int64_t>& dims,
const ReduceType& type = ReduceType::SUM);
// all
void clean_partial_status();
// clean by dims
void clean_partial_dims(const std::vector<int64_t>& dims);
void set_default_dims_mapping(const std::vector<int64_t>& tensor_shape);
int64_t batch_dim() const { return batch_dim_; }
......@@ -89,11 +125,17 @@ class TensorDistAttr {
bool verify_annotated(const std::map<std::string, bool>& annotated) const;
bool verify_partial_status() const;
bool verify(const std::vector<int64_t>& tensor_shape) const;
// TensorDistAttr from_string(const std::string& dist_str);
std::string to_string() const;
std::string partial_status_string() const;
// in partial-support-stage-I partial will always be a runtime attribute,
// there is not need to serialize it. support the partial serialization in
// future partial-support-stage-II.
void from_proto(const TensorDistAttrProto& proto);
TensorDistAttrProto to_proto() const;
......@@ -109,6 +151,10 @@ class TensorDistAttr {
int64_t batch_dim_{0};
std::vector<bool> dynamic_dims_;
std::map<std::string, bool> annotated_;
// The partial map is expected to be small (fewer entries than mesh.size()),
// and iteration (copy and comparison) is more frequent than random element
// access, hence a flat hash map. <key: dim on mesh, value: reduce type>
paddle::flat_hash_map<int64_t, ReduceType> partial_status_;
};
inline std::ostream& operator<<(std::ostream& os, const TensorDistAttr& obj) {
......
......@@ -60,6 +60,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0])
# test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[]
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
......@@ -73,6 +75,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[]
self.x_dist_tensor_spec.set_dims_mapping([1, -1])
......@@ -85,6 +88,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# test n parallel: mk[-1, -1],kn[-1, 0] --> mk[-1, -1],kn[-1, 0] = nm[-1, 0] partial[]
self.x_dist_tensor_spec.set_dims_mapping([-1, -1])
......@@ -97,6 +101,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, 0])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# test partial with propogation: mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]
self.x_dist_tensor_spec.set_dims_mapping([1, 0])
......@@ -109,6 +114,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0])
# mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]:
self.x_dist_tensor_spec.set_dims_mapping([-1, -1])
......@@ -121,6 +128,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, 0])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [1])
# abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] = abcmn[1, 0, -1, -1] partial[]: done
self.x_dist_tensor_spec.shape = [512, 48, 64, 32]
......@@ -138,6 +147,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1,-1, -1, -1] partial[0]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
......@@ -154,6 +164,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1]
)
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0])
# trans_x = True, abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = abcmn[1, -1, 0, -1] partial[]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
......@@ -171,6 +183,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1]
)
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# trans_y = True, abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = abcmn[-1, -1, -1, 1] partial[0]: done
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1, -1])
......@@ -189,6 +202,10 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, 1]
)
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0])
infered_output_dist_attrs[0]._clean_partial_dims([0])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# trans_y = True, trans_x = True, abcmk[-1, -1, 0, 1], kn[1, 0] --> abcmk[-1, -1, 0, 1]],kn[-1, 0] = abcmn[-1, -1, 1, -1] partial[0]
# multiple mesh dim shard same tensor axis
......@@ -208,6 +225,10 @@ class TestMatmulSPMDRule(unittest.TestCase):
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, 1, -1]
)
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0])
infered_output_dist_attrs[0]._clean_partial_status()
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# trans_y = True, trans_x = True, abcmk[-1, -1, 1, 0], kn[1, 0] --> error:
# one mesh dim shard multiple tensor axes
......
......@@ -62,6 +62,8 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0])
# reduce on dim 0, keep_dim = true
# [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0]
......@@ -76,6 +78,8 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0])
# reduce on dim 1, keep_dim = false
# [0, -1] --> [0, -1], [0], partial_on_dim:[]
......@@ -90,6 +94,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# reduce on dim 1, keep_dim = true
# [0, -1] --> [0, -1], [0, -1], partial_on_dim:[]
......@@ -104,6 +109,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# reduce on dim 0 and 1, keep_dim = false
# [0, -1] --> [0, -1], [], partial_on_dim:[0]
......@@ -118,6 +124,8 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0])
# reduce on dim 0 and 1, keep_dim = true
# [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0]
......@@ -132,6 +140,8 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0])
def test_multi_mesh_dim(self):
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
......@@ -170,6 +180,10 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0, 1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [0, 1])
infered_output_dist_attrs[0]._clean_partial_status()
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# reduction on dim 1, 2, keep_dim = false
# [1, -1, -1] --> [1, -1, -1], [1], partial_on_dim:[]
self.attrs['keep_dim'] = False
......@@ -183,6 +197,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# reduction on dim 1, 2, keep_dim = false
# [0, 1, -1] --> [0, 1, -1], [0], partial_on_dim:[1]
......@@ -197,6 +212,10 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [1])
infered_output_dist_attrs[0]._clean_partial_status()
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), False)
# reduction on dim 1, 2, keep_dim = true
# [0, 1, -1] --> [0, 1, -1], [0, -1, -1], partial_on_dim:[1]
......@@ -211,6 +230,8 @@ class TestReductionSPMDRule(unittest.TestCase):
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), [1])
if __name__ == "__main__":
......
......@@ -70,6 +70,7 @@ TEST(MatmulSPMDRule, Ctor) {
std::vector<int64_t>({-1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
VLOG(4) << "test1 done." << std::endl << std::endl << std::endl;
// mk[-1,-1],kn[-1,0] --> mk[-1,-1],kn[-1,0] = nm[-1,0] partial[]
......@@ -83,6 +84,7 @@ TEST(MatmulSPMDRule, Ctor) {
std::vector<int64_t>({-1, 0}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({-1, 0}));
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
VLOG(4) << "test2 done." << std::endl << std::endl << std::endl;
// mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]: done
......@@ -96,6 +98,9 @@ TEST(MatmulSPMDRule, Ctor) {
std::vector<int64_t>({0, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true);
EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(),
std::set<int64_t>({0}));
VLOG(4) << "test3 done." << std::endl << std::endl << std::endl;
// mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]: done
......@@ -109,6 +114,9 @@ TEST(MatmulSPMDRule, Ctor) {
std::vector<int64_t>({1, 0}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({-1, 0}));
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true);
EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(),
std::set<int64_t>({1}));
VLOG(4) << "test4 done." << std::endl << std::endl << std::endl;
// abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] =
......@@ -124,6 +132,7 @@ TEST(MatmulSPMDRule, Ctor) {
std::vector<int64_t>({-1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({0, 1, -1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
VLOG(4) << "test5 done." << std::endl << std::endl << std::endl;
// abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1,
......@@ -138,6 +147,9 @@ TEST(MatmulSPMDRule, Ctor) {
std::vector<int64_t>({0, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({1, -1, -1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true);
EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(),
std::set<int64_t>({0}));
VLOG(4) << "test6 done." << std::endl << std::endl << std::endl;
// abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] =
......@@ -153,6 +165,7 @@ TEST(MatmulSPMDRule, Ctor) {
std::vector<int64_t>({-1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({1, -1, 0, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
VLOG(4) << "test7 done." << std::endl << std::endl << std::endl;
// abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
......@@ -169,6 +182,11 @@ TEST(MatmulSPMDRule, Ctor) {
std::vector<int64_t>({1, 0}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({-1, -1, -1, 1}));
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true);
EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(),
std::set<int64_t>({0}));
infered_dist_attrs.second[0].clean_partial_dims(std::vector<int64_t>({0}));
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
VLOG(4) << "test8 done." << std::endl << std::endl << std::endl;
// abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
......@@ -185,6 +203,13 @@ TEST(MatmulSPMDRule, Ctor) {
std::vector<int64_t>({-1, 0}));
EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
std::vector<int64_t>({-1, -1, 1, -1}));
EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(),
std::set<int64_t>({0}));
VLOG(4) << infered_dist_attrs.second[0].to_string();
infered_dist_attrs.second[0].clean_partial_status();
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
infered_dist_attrs.second[0].set_partial_status(std::vector<int64_t>({1}));
EXPECT_EQ(infered_dist_attrs.second[0].verify_partial_status(), false);
VLOG(4) << "test9 done." << std::endl << std::endl << std::endl;
// abcmk[-1, -1, 1, 0], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
......@@ -197,6 +222,28 @@ TEST(MatmulSPMDRule, Ctor) {
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs));
// Error
VLOG(4) << "test10 done." << std::endl << std::endl << std::endl;
// abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
// abcmn[-1, -1, -1, 1] partial[0]:
x_dist_tensor_spec.set_dims_mapping({-1, -1, 0, 1});
y_dist_tensor_spec.set_dims_mapping({1, 0});
attrs["trans_y"] = true;
attrs["trans_x"] = true;
infered_dist_attrs = matmul_rule->InferForward(
{x_dist_tensor_spec, y_dist_tensor_spec}, attrs);
EXPECT_ANY_THROW(infered_dist_attrs.second[0].clean_partial_dims(
std::vector<int64_t>({1})));
infered_dist_attrs.second[0].set_partial_status(std::vector<int64_t>({1}));
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true);
EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(),
std::set<int64_t>({0, 1}));
infered_dist_attrs.second[0].clean_partial_dims(std::vector<int64_t>({1}));
EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(),
std::set<int64_t>({0}));
infered_dist_attrs.second[0].clean_partial_dims(std::vector<int64_t>({0}));
EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
VLOG(4) << "test11 done." << std::endl << std::endl << std::endl;
}
TEST(LayerNormSPMDRule, Ctor) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册