提交 e9c036cc 编写于 作者: M Megvii Engine Team

fix(mge/imperative): fix op specializations and tuple hash

GitOrigin-RevId: e3df93fd5b1303aeab76019f4a75596a051a5bf1
上级 b5ec83c5
......@@ -18,23 +18,28 @@ using namespace megdnn;
// FIXME: remove this when mgb::hash support tuple_hash
namespace mgb {
namespace {
template<typename T, size_t ...Ns>
auto tail(T t, std::index_sequence<Ns...>) {
return std::make_tuple(std::get<Ns+1>(t)...);
struct HashWrapper {
size_t hash;
constexpr operator size_t() {return hash;}
constexpr HashWrapper operator+(HashWrapper rhs) {
// NOTE: use a + b + c + d, not a + (b + (c + d)) !!!
return {hash * 20141203 + rhs.hash};
}
};
template <typename... Args>
constexpr size_t hash_many(const Args&... args) {
return (... + HashWrapper{mgb::hash(args)});
}
} // anonymous namespace
template<typename T, typename ...Args>
class HashTrait<std::tuple<T, Args...>> {
constexpr static size_t length = sizeof...(Args);
public:
struct HashTrait<std::tuple<T, Args...>> {
static size_t eval(const std::tuple<T, Args...> &t) {
const T& val = std::get<0>(t);
if constexpr (!length) {
return mgb::hash(val);
} else {
return mgb::hash_pair_combine(mgb::hash(val),
mgb::hash(tail(t, std::make_index_sequence<length - 1>{})));
}
return std::apply(hash_many<T, Args...>, t);
}
};
} // namespace mgb
......@@ -43,4 +48,4 @@ namespace mgb::imperative {
#include "./opdef.cpp.inl"
} // namespace mgb::imperative
\ No newline at end of file
} // namespace mgb::imperative
......@@ -158,7 +158,13 @@ auto apply_on_var_node(
}
}
OP_TRAIT_REG(Reduce, Reduce)
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Reduce>();
return Reduce::make(node->param());
}
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // reduce
......@@ -439,12 +445,13 @@ OP_TRAIT_REG(GaussianRNG, GaussianRNG)
}} // gaussian_rng
namespace { namespace roi_align {
auto apply_on_var_node(
VarNodeArray apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIAlign&>(def);
mgb_assert(inputs.size() == 2);
return opr::ROIAlign::make(inputs[0], inputs[1], op.param());
auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param()).node()->owner_opr();
return {opr->output(0), opr->output(1)};
}
OP_TRAIT_REG(ROIAlign, ROIAlign)
.apply_on_var_node(apply_on_var_node)
......@@ -496,12 +503,13 @@ OP_TRAIT_REG(Eye, Eye)
}} // eye
namespace { namespace roi_pooling {
auto apply_on_var_node(
VarNodeArray apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIPooling&>(def);
mgb_assert(inputs.size() == 3);
return opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param());
auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr();
return {opr->output(0), opr->output(1)};
}
OP_TRAIT_REG(ROIPooling, ROIPooling)
.apply_on_var_node(apply_on_var_node)
......@@ -620,11 +628,11 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& op = static_cast<const SVD&>(def);
mgb_assert(inputs.size() == 1);
return opr::SVD::make(inputs[0], op.param());
return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output();
}
OP_TRAIT_REG(SVD, SVD)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // svd
} // namespace mgb::imperative
\ No newline at end of file
} // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册