提交 bd3b9cb6 编写于 作者: M Megvii Engine Team

fix(mge/oprmm): fix grad for collective comm

GitOrigin-RevId: 8e28f46c905002857431bf725ef414afc4a48704
上级 8d02d104
......@@ -49,8 +49,26 @@ cg::OperatorNodeBase* apply_on_var_node(
dev_buffer_arr, config, disable));
}
/*!
 * \brief Split a "host:port" string at its last ':' character.
 *
 * \param address_and_port combined address string; must contain a ':'
 * \return tuple of (text before the last ':', text after the last ':')
 *
 * Asserts via mgb_assert when no ':' is present.
 */
std::tuple<std::string, std::string> split_address(const std::string& address_and_port){
    // Search from the back so that only the final ':' acts as the separator.
    const auto sep_pos = address_and_port.rfind(':');
    mgb_assert(sep_pos != std::string::npos, "missing ':' in server address");
    std::string host_part = address_and_port.substr(0, sep_pos);
    std::string port_part = address_and_port.substr(sep_pos + 1);
    return std::make_tuple(host_part, port_part);
}
/*!
 * \brief Build a CollectiveComm OpDef from an existing graph operator node.
 *
 * The node is cast to opr::CollectiveComm; its key, device count, rank, root
 * flag, local-grad flag, mode, dtype and backend are copied over. The server
 * address and port are recovered by splitting the group client's address
 * string, and the comp node is taken from the node config as its logical
 * string form.
 *
 * \param node operator node to convert; must be an opr::CollectiveComm
 * \return newly allocated CollectiveComm op definition
 */
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node) {
    auto&& comm_opr = node->cast_final_safe<opr::CollectiveComm>();

    // The group client address is split into its host and port parts.
    auto [server_addr, server_port] =
            split_address(comm_opr.group_client()->get_addr());

    auto logical_cn =
            node->config().get_single_comp_node().to_string_logical();

    return std::make_shared<CollectiveComm>(
            comm_opr.key(), comm_opr.nr_devices(), comm_opr.rank(),
            comm_opr.is_root(), comm_opr.local_grad(), server_addr,
            std::stoi(server_port), comm_opr.param().mode, comm_opr.dtype(),
            comm_opr.backend(), logical_cn);
}
// Register the op trait for CollectiveComm (paired with opr::CollectiveComm):
//  - apply_on_var_node: the function defined above in this file
//  - make_from_op_node: rebuilds the OpDef from an existing graph operator
// NOTE(review): .fallback() presumably fills the remaining trait entries with
// default implementations -- confirm against the OP_TRAIT_REG definition.
OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm)
.apply_on_var_node(apply_on_var_node)
.make_from_op_node(make_from_op_node)
.fallback();
} // anonymous namespace
......
......@@ -13,6 +13,8 @@
#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"
namespace mgb::imperative {
class GetVarShape : public OpDefImplBase<GetVarShape> {
......@@ -41,6 +43,33 @@ public:
std::vector<dt_int32> offsets;
std::vector<std::vector<size_t>> shapes;
size_t hash() const override {
XXHash builder;
for (auto&& offset : offsets) {
builder.update(&offset, sizeof(offset));
}
auto&& offset_cnt = offsets.size();
builder.update(&offset_cnt, sizeof(offset_cnt));
for (auto&& shape : shapes) {
for (auto&& dim_len : shape) {
builder.update(&dim_len, sizeof(dim_len));
}
auto&& dim_cnt = shape.size();
builder.update(&dim_cnt, sizeof(dim_cnt));
}
auto&& shape_cnt = shapes.size();
builder.update(&shape_cnt, sizeof(shape_cnt));
return builder.digest();
}
/*!
 * \brief Structural equality against another Hashable.
 *
 * \return true only when \p rhs is a ParamPackSplit whose offsets and shapes
 *         both compare equal to this object's; false otherwise.
 */
bool is_same_st(const Hashable& rhs) const override {
    const auto* other = rhs.try_cast_final<ParamPackSplit>();
    // Short-circuit keeps the null check ahead of any member access.
    return other != nullptr && offsets == other->offsets &&
           shapes == other->shapes;
}
};
class ParamPackConcat : public OpDefImplBase<ParamPackConcat> {
......@@ -53,6 +82,24 @@ public:
: offsets(offsets_) {}
std::vector<dt_int32> offsets;
size_t hash() const override {
XXHash builder;
for (auto&& offset : offsets) {
builder.update(&offset, sizeof(offset));
}
auto&& offset_cnt = offsets.size();
builder.update(&offset_cnt, sizeof(offset_cnt));
return builder.digest();
}
/*!
 * \brief Structural equality against another Hashable.
 *
 * \return true only when \p rhs is a ParamPackConcat whose offsets compare
 *         equal to this object's; false otherwise.
 */
bool is_same_st(const Hashable& rhs) const override {
    const auto* other = rhs.try_cast_final<ParamPackConcat>();
    // Short-circuit keeps the null check ahead of any member access.
    return other != nullptr && offsets == other->offsets;
}
};
} // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册