未验证 提交 92756ead 编写于 作者: J Juncheng 提交者: GitHub

Optimize sbp equality (#4435)

Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 7da5f9b6
......@@ -18,7 +18,17 @@ limitations under the License.
namespace oneflow {
bool operator==(const SbpParallel& lhs, const SbpParallel& rhs) { return PbMd().Equals(lhs, rhs); }
bool operator==(const SbpParallel& lhs, const SbpParallel& rhs) {
if (lhs.parallel_type_case() == rhs.parallel_type_case()) {
if (lhs.has_split_parallel()) {
return lhs.split_parallel().axis() == rhs.split_parallel().axis();
} else {
return true;
}
} else {
return false;
}
}
bool operator!=(const SbpParallel& lhs, const SbpParallel& rhs) { return !(lhs == rhs); }
......@@ -46,7 +56,11 @@ SbpParallel GetDualSbpParallel(const SbpParallel& sbp_parallel) {
}
bool operator==(const ParallelDistribution& lhs, const ParallelDistribution& rhs) {
return PbMd().Equals(lhs, rhs);
if (lhs.sbp_parallel().size() != rhs.sbp_parallel().size()) { return false; }
for (int i = 0; i < lhs.sbp_parallel().size(); ++i) {
if (lhs.sbp_parallel().Get(i) != rhs.sbp_parallel().Get(i)) { return false; }
}
return true;
}
bool operator!=(const ParallelDistribution& lhs, const ParallelDistribution& rhs) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册