提交 b33ec46e 编写于 作者: M Megvii Engine Team

fix(mgb/gopt): fix layout gtrans when graph partition has opr with different format

GitOrigin-RevId: 326fdebb0cdf5ec0d950d3c8e9c451bf8f262d7e
上级 fe15239a
...@@ -221,6 +221,8 @@ void DynamicProgrammingSolver::Impl::analyze_edges( ...@@ -221,6 +221,8 @@ void DynamicProgrammingSolver::Impl::analyze_edges(
edges[cur].push_back(ov); edges[cur].push_back(ov);
edge2idx[cur].emplace(ov, idx++); edge2idx[cur].emplace(ov, idx++);
} }
if (cur == 0)
return;
cur--; cur--;
for (const auto& opr : reverse_adaptor(topo)) { for (const auto& opr : reverse_adaptor(topo)) {
for (const auto& i : opr->input()) { for (const auto& i : opr->input()) {
......
...@@ -19,19 +19,41 @@ using namespace mgb; ...@@ -19,19 +19,41 @@ using namespace mgb;
using namespace gopt; using namespace gopt;
using namespace opr; using namespace opr;
namespace {
using OprFormat = SolverBase::OprFormat;
template <typename Opr>
bool check_format_aware_opr_valid(const OperatorNodeBase* opr_, OprFormat opr_format) {
auto&& opr = opr_->cast_final_safe<Opr>();
return opr.param().format == opr_format;
}
} // namespace
/* =================== ProfilingBasedSolverSolver ======================*/ /* =================== ProfilingBasedSolverSolver ======================*/
ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler) ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler)
: m_profiler{std::move(profiler)} { : m_profiler{std::move(profiler)} {
static const ThinHashSet<Typeinfo*> format_aware_oprs = { static const ThinHashMap<
#define cb(_Opr) _Opr::typeinfo() Typeinfo*,
cb(Convolution), cb(ConvBiasForward), cb(ConvolutionBackwardData), thin_function<bool(const OperatorNodeBase*, OprFormat opr_format)>>
cb(PoolingForward), cb(WarpPerspective), cb(Resize), format_aware_opr_validators = {
}; #define cb(t) \
{opr::t::typeinfo(), std::bind( \
check_format_aware_opr_valid<opr::t>, \
std::placeholders::_1, std::placeholders::_2)}
cb(Convolution),
cb(ConvBiasForward),
cb(ConvolutionBackwardData),
cb(PoolingForward),
cb(WarpPerspective),
cb(Resize),
};
m_graph_partition_filter = [](const GraphPartition& partition) { m_problem_filter = [](const Problem& problem) {
auto&& base_opr_format = problem.attribute().base_opr_format;
bool has_format_aware_opr = false; bool has_format_aware_opr = false;
for (auto&& opr : partition.all_oprs()) { for (auto&& opr : problem.graph_partition().all_oprs()) {
if (!has_format_aware_opr && format_aware_oprs.count(opr->dyn_typeinfo())) { auto iter = format_aware_opr_validators.find(opr->dyn_typeinfo());
if (iter != format_aware_opr_validators.end() &&
iter->second(opr, base_opr_format)) {
has_format_aware_opr = true; has_format_aware_opr = true;
break; break;
} }
...@@ -42,8 +64,7 @@ ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profile ...@@ -42,8 +64,7 @@ ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profile
ProfilingBasedSolver::Solution ProfilingBasedSolver::solve( ProfilingBasedSolver::Solution ProfilingBasedSolver::solve(
const Problem& problem) const { const Problem& problem) const {
const auto& partition = problem.graph_partition(); if (!m_problem_filter(problem))
if (!m_graph_partition_filter(partition))
return Solution{}; return Solution{};
return do_solve(problem); return do_solve(problem);
} }
......
...@@ -49,18 +49,16 @@ public: ...@@ -49,18 +49,16 @@ public:
*/ */
class ProfilingBasedSolver : public SolverBase { class ProfilingBasedSolver : public SolverBase {
public: public:
using GraphPartitionFilter = using ProblemFilter = thin_function<bool(const Problem&)>;
thin_function<bool(const GraphPartition& graph_partition)>;
ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler); ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler);
/*! /*!
* \note some graph partition (for example, graph partition without format * \note some graph partition (for example, graph partition without format
* aware operators like conv, deconv, warp, resize etc.) will be filtered by * aware operators like conv, deconv, warp, resize etc.) will be filtered by
* the GraphPartitionFilter, which can reduce the profiling time. */ * the ProblemFilter, which can reduce the profiling time. */
ProfilingBasedSolver( ProfilingBasedSolver(
std::unique_ptr<ProfilerBase> profiler, std::unique_ptr<ProfilerBase> profiler, ProblemFilter problem_filter)
GraphPartitionFilter graph_partition_filter)
: m_profiler{std::move(profiler)}, : m_profiler{std::move(profiler)},
m_graph_partition_filter{std::move(graph_partition_filter)} {} m_problem_filter{std::move(problem_filter)} {}
virtual ~ProfilingBasedSolver() = default; virtual ~ProfilingBasedSolver() = default;
Solution solve(const Problem& problem) const override; Solution solve(const Problem& problem) const override;
virtual Solution do_solve(const Problem& problem) const = 0; virtual Solution do_solve(const Problem& problem) const = 0;
...@@ -69,7 +67,7 @@ protected: ...@@ -69,7 +67,7 @@ protected:
std::unique_ptr<ProfilerBase> m_profiler; std::unique_ptr<ProfilerBase> m_profiler;
private: private:
GraphPartitionFilter m_graph_partition_filter; ProblemFilter m_problem_filter;
}; };
/*! /*!
...@@ -81,10 +79,8 @@ public: ...@@ -81,10 +79,8 @@ public:
DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler) DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler)
: ProfilingBasedSolver(std::move(profiler)){}; : ProfilingBasedSolver(std::move(profiler)){};
DynamicProgrammingSolver( DynamicProgrammingSolver(
std::unique_ptr<ProfilerBase> profiler, std::unique_ptr<ProfilerBase> profiler, ProblemFilter problem_filter)
GraphPartitionFilter graph_partition_filter) : ProfilingBasedSolver(std::move(profiler), std::move(problem_filter)){};
: ProfilingBasedSolver(
std::move(profiler), std::move(graph_partition_filter)){};
~DynamicProgrammingSolver() noexcept = default; ~DynamicProgrammingSolver() noexcept = default;
Solution do_solve(const Problem& problem) const override; Solution do_solve(const Problem& problem) const override;
bool can_solve(const Problem& problem) const override; bool can_solve(const Problem& problem) const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册