diff --git a/src/gopt/impl/dynamic_programming_solver.cpp b/src/gopt/impl/dynamic_programming_solver.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..847ea63b7a0156a78e91834162f298e39e285e2f
--- /dev/null
+++ b/src/gopt/impl/dynamic_programming_solver.cpp
@@ -0,0 +1,547 @@
+/**
+ * \file src/gopt/impl/dynamic_programming_solver.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include <queue>
+#include "./utils.h"
+#include "megbrain/gopt/global_layout_transform.h"
+
+using namespace mgb;
+using namespace gopt;
+using namespace cg;
+
+/* ================= DynamicProgrammingSolver::Impl ==================*/
+class DynamicProgrammingSolver::Impl {
+public:
+    Impl(size_t max_states) : m_max_states{max_states} {}
+    ~Impl() = default;
+    Solution solve(const ProfilerBase* profiler, const Problem& problem);
+
+private:
+    using TensorFormatsBitSet = uint32_t;
+    using State = SmallVector<TensorFormatsBitSet>;
+    static constexpr uint32_t MAX_TENSOR_FORMATS =
+            sizeof(TensorFormatsBitSet) * 8;
+    TensorFormatsBitSet add(TensorFormatsBitSet& set, TensorFormats fmt) {
+        mgb_assert(static_cast<uint32_t>(fmt) < MAX_TENSOR_FORMATS);
+        set |= (1 << static_cast<uint32_t>(fmt));
+        return set;
+    }
+    bool valid(const TensorFormatsBitSet& set, TensorFormats fmt) {
+        mgb_assert(static_cast<uint32_t>(fmt) < MAX_TENSOR_FORMATS);
+        bool val = set & (1 << static_cast<uint32_t>(fmt));
+        return val;
+    }
+    struct Value {
+        OperatorNodeBase* opr;
+        const State* prev;
+        OprFormat opr_fmt;
+        float time;
+        ///! index in the topo order of the corresponding operator
+        size_t opr_idx;
+    };
+
+    struct StateHash {
+        size_t operator()(const State& key) const {
+            size_t h = 0;
+            for (auto&& v : key) {
+                h = mgb::hash_pair_combine(
+                        h, std::hash<TensorFormatsBitSet>{}(v));
+            }
+            return h;
+        }
+    };
+    struct StateEqual {
+        size_t operator()(const State& lhs, const State& rhs) const {
+            if (lhs.size() != rhs.size())
+                return false;
+            for (size_t i = 0; i < lhs.size(); ++i) {
+                if (lhs[i] != rhs[i])
+                    return false;
+            }
+            return true;
+        }
+    };
+    using StateTable = std::unordered_map<State, Value, StateHash, StateEqual>;
+    struct Cut {
+        StateTable states;
+    };
+    using ProfilingResult = ProfilerBase::ProfilingResult;
+    using OprConfigTrait = LayoutTransformContext::OprConfigTrait;
+    struct Context {
+        const std::vector<OperatorNodeBase*>& topo;
+        const ProfilingResult& rst;
+        const OprConfigTrait& opr_configs;
+        const SmallVector<TensorFormats>& available_tensor_formats;
+    };
+    /*!
+     * \brief get the tensor formats configuration for the operator with a
+     * particular op format
+     * \param[out] var2fmts hashmap that maps each varnode to the actual tensor
+     * format of the op format configuration
+     * \param[in] opr given operator
+     * \param[in] opr_fmt given op format, an enum type argument which
+     * indicates the op format configuration
+     * \param[in] ctx context
+     */
+    TensorFormats get_io_formats(ThinHashMap<VarNode*, TensorFormats>& var2fmts,
+                                 const OperatorNodeBase* opr, OprFormat opr_fmt,
+                                 const Context& ctx);
+    /*!
+     * \brief compute the distance of two states of the given varnode
+     * \param[in] from the source state
+     * \param[in] to the target state
+     * \param[in] var given varnode
+     * \param[in] ctx context
+     */
+    float distance(const TensorFormatsBitSet& from,
+                   const TensorFormatsBitSet& to, VarNode* var,
+                   const Context& ctx);
+    /*!
+     * \brief compute the distance of two states of the given cut edges
+     * \param[in] from the source state
+     * \param[in] to the target state
+     * \param[in] edge a VarNodeArray, the given cut edges
+     * \param[in] ctx context
+     */
+    float state_distance(const State& from, const State& to,
+                         const VarNodeArray& edge, const Context& ctx);
+    /*!
+     * \brief analyze the edges of each cut
+     * \param[out] edges the returned edges of the cuts
+     * \param[out] edge2idx hashmaps that map each edge (varnode) to its index
+     * \param[in] ctx context
+     */
+    void analyze_edges(SmallVector<VarNodeArray>& edges,
+                       SmallVector<ThinHashMap<VarNode*, int>>& edge2idx,
+                       const Context& ctx);
+    /*!
+     * \brief prune states using the distance of states
+     */
+    void prune(StateTable& states, const VarNodeArray& edge,
+               const Context& ctx);
+    /*!
+     * \brief force prune states, reserving the smallest MAX_STATES states
+     */
+    void force_prune(StateTable& states);
+
+private:
+    size_t m_max_states;
+};
+
+TensorFormats DynamicProgrammingSolver::Impl::get_io_formats(
+        ThinHashMap<VarNode*, TensorFormats>& var2fmts,
+        const OperatorNodeBase* opr, OprFormat opr_fmt, const Context& ctx) {
+    auto&& rst = ctx.rst;
+    auto&& opr_configs = ctx.opr_configs;
+
+    auto iter = opr_configs.find(opr->dyn_typeinfo());
+    Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
+    if (iter != opr_configs.end()) {
+        fmtcfg = (*iter->second.at(opr_fmt))(opr);
+    }
+    TensorFormats out_fmt;
+    if (fmtcfg.valid())
+        out_fmt = fmtcfg.val().output_tensor_formats[0];
+    else
+        out_fmt = opr_format_to_tensor_formats(opr_fmt);
+    for (size_t i = 0; i < opr->input().size(); ++i) {
+        auto&& var = opr->input(i);
+        auto iter = rst.var_record.find(var);
+        if (iter != rst.var_record.end()) {
+            if (fmtcfg.valid())
+                var2fmts[var] = fmtcfg.val().input_tensor_formats[i];
+            else
+                var2fmts[var] = opr_format_to_tensor_formats(opr_fmt);
+        }
+    }
+    return out_fmt;
+}
+
+float DynamicProgrammingSolver::Impl::distance(const TensorFormatsBitSet& from,
+                                               const TensorFormatsBitSet& to,
+                                               VarNode* var,
+                                               const Context& ctx) {
+    auto&& costs = ctx.rst.var_record.at(var).costs;
+    auto&& available_tensor_formats = ctx.available_tensor_formats;
+
+    float dist = 0.f;
+    if ((from & to) == to)
+        return dist;
+    auto to_set = ((from | to) ^ from);
+    for (auto o : available_tensor_formats) {
+        if (valid(to_set, o)) {
+            float o_cost = std::numeric_limits<float>::max();
+            for (auto i : available_tensor_formats) {
+                if (valid(from, i)) {
+                    float cost = costs.at({i, o});
+                    o_cost = std::min(o_cost, cost);
+                }
+            }
+            dist += o_cost;
+        }
+    }
+    return dist;
+}
+
+float DynamicProgrammingSolver::Impl::state_distance(const State& from,
+                                                     const State& to,
+                                                     const VarNodeArray& edge,
+                                                     const Context& ctx) {
+    float dist = 0.f;
+    mgb_assert(from.size() == to.size() && from.size() == edge.size());
+    for (size_t i = 0; i < edge.size(); ++i) {
+        dist += distance(from[i], to[i], edge[i], ctx);
+    }
+    return dist;
+}
+
+void DynamicProgrammingSolver::Impl::analyze_edges(
+        SmallVector<VarNodeArray>& edges,
+        SmallVector<ThinHashMap<VarNode*, int>>& edge2idx,
+        const Context& ctx) {
+    auto&& topo = ctx.topo;
+    auto&& rst = ctx.rst;
+
+    size_t nr_oprs = topo.size();
+
+    edges.resize(nr_oprs);
+    edge2idx.resize(nr_oprs);
+
+    ThinHashSet<VarNode*> cur_edge;
+    size_t cur = nr_oprs - 1;
+    int idx = 0;
+    for (auto&& ov : topo[cur]->usable_output()) {
+        edges[cur].push_back(ov);
+        edge2idx[cur].emplace(ov, idx++);
+    }
+    cur--;
+    for (const auto& opr : reverse_adaptor(topo)) {
+        for (const auto& i : opr->input()) {
+            if (rst.var_record.count(i) > 0) {
+                cur_edge.insert(i);
+            }
+        }
+        for (auto&& ov : opr->usable_output()) {
+            cur_edge.erase(ov);
+        }
edges[cur].insert(edges[cur].begin(), cur_edge.begin(), cur_edge.end()); + int i = 0; + for (auto&& e : edges[cur]) { + edge2idx[cur][e] = i++; + } + if (cur == 0) + break; + cur--; + } +} + +void DynamicProgrammingSolver::Impl::prune(StateTable& states, + const VarNodeArray& edge, + const Context& ctx) { + struct Item { + decltype(states.begin()) iter; + }; + std::list list; + for (auto it = states.begin(); it != states.end(); ++it) { + list.emplace_back(Item{it}); + } + SmallVector removed_states; + for (auto i = list.begin(); i != list.end();) { + bool advanced_i = false; + for (auto j = std::next(i, 1); j != list.end();) { + if (i->iter->second.time > j->iter->second.time && + state_distance(j->iter->first, i->iter->first, edge, ctx) < + i->iter->second.time - j->iter->second.time) { + removed_states.push_back(i->iter->first); + i = list.erase(i); + advanced_i = true; + break; + } else if (i->iter->second.time < j->iter->second.time && + state_distance(i->iter->first, j->iter->first, edge, + ctx) < + j->iter->second.time - i->iter->second.time) { + removed_states.push_back(j->iter->first); + j = list.erase(j); + } else { + j = std::next(j, 1); + } + } + if (!advanced_i) + i = std::next(i, 1); + } + for (auto&& state : removed_states) + states.erase(state); +} + +void DynamicProgrammingSolver::Impl::force_prune(StateTable& states) { + if (states.size() < m_max_states) + return; + struct Item { + decltype(states.begin()) iter; + }; + auto cmp = [](Item lhs, Item rhs) { + return lhs.iter->second.time < rhs.iter->second.time; + }; + std::priority_queue, decltype(cmp)> pq(cmp); + for (auto it = states.begin(); it != states.end(); ++it) { + if (pq.size() < m_max_states) + pq.push(Item{it}); + else { + auto i = pq.top(); + if (it->second.time < i.iter->second.time) { + pq.pop(); + pq.push(Item{it}); + } + } + } + StateTable active_state; + while (!pq.empty()) { + auto i = pq.top(); + active_state.insert(*i.iter); + pq.pop(); + } + states.swap(active_state); +} + +DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( + const ProfilerBase* profiler, const Problem& problem) { + const auto rst = profiler->profile(problem); + const auto& partition = problem.graph_partition(); + const auto& opr_configs = problem.opr_configs(); + const auto& base_fmt = problem.base_format(); + const auto& available_tensor_formats = problem.available_tensor_formats(); + const auto& topo = partition.all_oprs(); + Context ctx{topo, rst, opr_configs, available_tensor_formats}; + + SmallVector edges; + SmallVector> edge2idx; + /// analyze edges of each cuts + analyze_edges(edges, edge2idx, ctx); + + SmallVector cuts; + size_t cur = 0; + + /// initialize states + auto init = [&, this](OperatorNodeBase* opr) { + auto it = rst.opr_record.find(opr); + if (it == rst.opr_record.end()) + return; + ThinHashSet ovar_set; + for (auto&& ov : opr->usable_output()) { + ovar_set.insert(ov); + } + const auto& records = it->second.costs; + cuts.emplace_back(Cut{}); + auto& states = cuts.back().states; + for (const auto& record : records) { + auto opr_fmt = record.first; + float opr_time = record.second; + ThinHashMap ivar2fmts; + auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx); + const auto& edge = edges[cur]; + State state(edge.size(), 0); + Value value{opr, nullptr, opr_fmt, 0.f, cur}; + float ovar_time = 0.f; + for (size_t i = 0; i < edge.size(); ++i) { + auto&& var = edge[i]; + auto&& costs = rst.var_record.at(var).costs; + if (ovar_set.count(var) > 0) { + add(state[i], out_fmt); + if 
(partition.output().count(var) > 0 && + out_fmt != base_fmt) { + ovar_time += costs.at({out_fmt, base_fmt}); + add(state[i], base_fmt); + } + } else { + add(state[i], base_fmt); + } + } + float ivar_time = 0.f; + for (const auto& kv : ivar2fmts) { + auto&& v = kv.first; + auto&& costs = rst.var_record.at(v).costs; + auto to = kv.second; + float min_time = std::numeric_limits::max(); + if (base_fmt == to) { + min_time = 0.f; + } else { + min_time = costs.at({base_fmt, to}); + if (edge2idx[cur].count(v) > 0) { + add(state[edge2idx[cur][v]], to); + } + } + ivar_time += min_time; + } + value.time = opr_time + ivar_time + ovar_time; + states[state] = value; + } + }; + + /// update the states + auto body = [&, this](OperatorNodeBase* opr) { + auto it = rst.opr_record.find(opr); + if (it == rst.opr_record.end()) + return; + ThinHashSet ovar_set; + for (auto&& ov : opr->usable_output()) { + ovar_set.insert(ov); + } + const auto& records = it->second.costs; + StateTable states; + for (const auto& record : records) { + auto opr_fmt = record.first; + float opr_time = record.second; + ThinHashMap ivar2fmts; + auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx); + for (const auto& kv : cuts.back().states) { + auto&& prev_state = kv.first; + float prev_time = kv.second.time; + const auto& edge = edges[cur]; + State state(edge.size(), 0); + Value value{opr, &prev_state, opr_fmt, 0.f, cur}; + float ovar_time = 0.f; + for (size_t i = 0; i < edge.size(); ++i) { + auto&& var = edge[i]; + auto&& costs = rst.var_record.at(var).costs; + auto iter = edge2idx[cur - 1].find(var); + if (iter != edge2idx[cur - 1].end()) { + state[i] = prev_state[iter->second]; + } else { + mgb_assert(ovar_set.count(var) > 0); + add(state[i], out_fmt); + if (partition.output().count(var) > 0 && + out_fmt != base_fmt) { + ovar_time += costs.at({out_fmt, base_fmt}); + + add(state[i], base_fmt); + } + } + } + float ivar_time = 0.f; + for (const auto& kv : ivar2fmts) { + auto&& v = kv.first; + auto&& costs = rst.var_record.at(v).costs; + auto to = kv.second; + auto it1 = edge2idx[cur - 1].find(v); + float min_time = std::numeric_limits::max(); + if (valid(prev_state[it1->second], to)) { + min_time = 0.f; + } else { + for (auto&& from : available_tensor_formats) { + if (valid(prev_state[it1->second], from)) { + float cost = costs.at({from, to}); + min_time = std::min(min_time, cost); + } + } + } + auto it2 = edge2idx[cur].find(v); + if (it2 != edge2idx[cur].end()) { + add(state[it2->second], to); + } + ivar_time += min_time; + } + value.time = prev_time + opr_time + ivar_time + ovar_time; + auto iter = states.find(state); + if (iter == states.end()) { + states[state] = value; + } else { + float time = iter->second.time; + if (value.time < time) { + iter->second = value; + } + } + } + } + cuts.emplace_back(Cut{}); + cuts.back().states.swap(states); + }; + + /// forward pass to generate all states + for (auto&& opr : topo) { + if (cuts.empty()) { + init(opr); + } else { + body(opr); + } + if (!cuts.empty()) { + auto& states = cuts.back().states; + prune(states, edges[cur], ctx); + force_prune(states); + } + cur++; + } + + Solution solution; + + /// backward pass to generate the solution + float min_time = std::numeric_limits::max(); + OperatorNodeBase* cur_opr; + OprFormat min_fmt; + const State* pstate = nullptr; + for (auto&& kv : cuts.back().states) { + auto&& v = kv.second; + if (v.time < min_time) { + cur_opr = v.opr; + pstate = v.prev; + min_time = v.time; + min_fmt = v.opr_fmt; + ///! 
just to check the tensor formats of the output varnode + auto&& k = kv.first; + size_t opr_idx = v.opr_idx; + for (size_t i = 0; i < k.size(); ++i) { + auto&& fmt_set = k[i]; + auto&& var = edges[opr_idx][i]; + if (partition.output().count(var)) { + mgb_assert(valid(fmt_set, base_fmt)); + } + } + } + } + mgb_log_debug("opr:%s;format:%s;time:%f", cur_opr->cname(), + opr_format_to_string(min_fmt), min_time); + + solution.insert({cur_opr, min_fmt}); + cur = cuts.size() - 2; + while (pstate) { + auto val = cuts[cur].states[*pstate]; + ///! just to check the tensor formats of the output varnode + size_t opr_idx = val.opr_idx; + for (size_t i = 0; i < pstate->size(); ++i) { + auto&& fmt_set = pstate->operator[](i); + auto&& var = edges[opr_idx][i]; + if (partition.output().count(var)) { + mgb_assert(valid(fmt_set, base_fmt)); + } + } + mgb_log_debug("opr:%s;format:%s;time:%f", val.opr->cname(), + opr_format_to_string(val.opr_fmt), val.time); + solution.insert({val.opr, val.opr_fmt}); + pstate = val.prev; + cur--; + } + return solution; +} + +/* =================== DynamicProgrammingSolver ======================*/ +DynamicProgrammingSolver::Solution DynamicProgrammingSolver::do_solve( + const Problem& problem) const { + constexpr size_t MAX_STATES = 1024; + Impl impl(MAX_STATES); + return impl.solve(m_profiler.get(), problem); +} + +bool DynamicProgrammingSolver::can_solve(const Problem& problem) const { + auto&& available_tensor_formats = problem.available_tensor_formats(); + for (auto&& tensor_format : available_tensor_formats) { + if (static_cast(tensor_format) >= 32) + return false; + } + return true; +} + +// vim: syntax=cpp.doxygen diff --git a/src/gopt/impl/layout_transform_context.cpp b/src/gopt/impl/layout_transform_context.cpp new file mode 100644 index 0000000000000000000000000000000000000000..98881203b7f84510764ad60f1289f18baa9f7c3a --- /dev/null +++ b/src/gopt/impl/layout_transform_context.cpp @@ -0,0 +1,40 @@ +/** + * \file src/gopt/impl/layout_transform_context.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "./utils.h" +#include "megbrain/gopt/global_layout_transform.h" + +using namespace mgb; +using namespace gopt; + +/* ================= LayoutTransformContext ==================*/ +LayoutTransformContext& LayoutTransformContext::add_opr_config( + Typeinfo* opr, OprFormat opr_format) { + auto& dispatchers = m_opr_configs[opr]; + dispatchers[opr_format] = + OprTensorFormatsConfiguration::find_dispatcher_by_type_format( + opr, opr_format); + return *this; +} + +LayoutTransformContext& LayoutTransformContext::add_opr_config( + Typeinfo* opr, SmallVector opr_formats) { + auto& dispatchers = m_opr_configs[opr]; + for (auto opr_fmt : opr_formats) { + dispatchers[opr_fmt] = + OprTensorFormatsConfiguration::find_dispatcher_by_type_format( + opr, opr_fmt); + } + return *this; +} + +// vim: syntax=cpp.doxygen diff --git a/src/gopt/impl/profiler_impl.cpp b/src/gopt/impl/profiler_impl.cpp index aa95296443f2ddb542be59dbc88ef0a2e3aa85e3..38f7e6ca90946def2fefcd39fa8086cd51ee87f0 100644 --- a/src/gopt/impl/profiler_impl.cpp +++ b/src/gopt/impl/profiler_impl.cpp @@ -17,6 +17,7 @@ #include "megbrain/graph/event.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/imgproc.h" +#include "megbrain/opr/nn_int.h" #include "megbrain/opr/io.h" #include "megbrain/plugin/base.h" #include "megbrain/serialization/sereg.h" @@ -265,6 +266,10 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( record.opr = opr; auto& costs = record.costs; for (auto&& i : available_configs) { + /// XXXX remove later + if (i.opr_format == OprFormat::NCHW && + opr->input(0)->dtype().enumv() != DTypeEnum::Float32) + continue; costs[i.opr_format] = profile_operator(opr, base_config, i); } return record; @@ -414,37 +419,42 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( cb(Resize, 1), #undef cb }; + static const ThinHashSet skip_opr_types = { + TypeCvt::typeinfo(), Elemwise::typeinfo(), + ElemwiseMultiType::typeinfo()}; ThinHashSet vars; ThinHashSet oprs; - { - auto cb = [&cvprop, &vars, &oprs](OperatorNodeBase* opr) { - if (cvprop.is_const(opr)) - return; - oprs.insert(opr); - auto find = format_aware_input_tensors.find(opr->dyn_typeinfo()); - if (find == format_aware_input_tensors.end()) { - for (auto&& i : opr->input()) { - if (!cvprop.is_const(i)) { - vars.insert(i); - } + ThinHashSet skip_oprs; + for (auto&& opr : problem.graph_partition().all_oprs()) { + if (cvprop.is_const(opr)) + continue; + bool skip = true; + for (auto&& i : opr->input()) { + skip &= problem.graph_partition().input().count(i) > 0 || + skip_oprs.count(i->owner_opr()) > 0; + } + skip &= skip_opr_types.count(opr->dyn_typeinfo()); + if (skip) + skip_oprs.insert(opr); + oprs.insert(opr); + auto find = format_aware_input_tensors.find(opr->dyn_typeinfo()); + if (find == format_aware_input_tensors.end()) { + for (auto&& i : opr->input()) { + if (!cvprop.is_const(i)) { + vars.insert(i); } - } else { - size_t nr_input_tensor = - std::min(find->second, opr->input().size()); - for (size_t i = 0; i < nr_input_tensor; ++i) { - if (!cvprop.is_const(opr->input(i))) { - vars.insert(opr->input(i)); - } + } + } else { + size_t nr_input_tensor = + std::min(find->second, opr->input().size()); + for (size_t i = 0; i < nr_input_tensor; ++i) { + if (!cvprop.is_const(opr->input(i))) { + vars.insert(opr->input(i)); } } - vars.insert(opr->output(0)); - }; - DepOprIter iter{cb}; - for (auto&& i : problem.graph_partition().input()) { - iter.set_visited(i->owner_opr()); } - for (auto&& o : problem.graph_partition().output()) { - 
iter.add(o->owner_opr()); + for (auto&& ov : opr->usable_output()) { + vars.insert(ov); } } @@ -462,8 +472,14 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( auto&& opr_configs = problem.opr_configs(); auto find = opr_configs.find(opr->dyn_typeinfo()); if (find == opr_configs.end()) { - opr_record[opr] = profile_operator(opr, base_format, - available_tensor_formats); + if (skip_oprs.count(opr) > 0) { + SmallVector tensor_formats = {base_format}; + opr_record[opr] = + profile_operator(opr, base_format, tensor_formats); + } else { + opr_record[opr] = profile_operator(opr, base_format, + available_tensor_formats); + } } else { auto&& dispatchers = find->second; SmallVector configs; diff --git a/src/gopt/impl/profiling_based_solver.cpp b/src/gopt/impl/profiling_based_solver.cpp new file mode 100644 index 0000000000000000000000000000000000000000..760d70b97df0cbcf7b56172ca30f220b6a4eade8 --- /dev/null +++ b/src/gopt/impl/profiling_based_solver.cpp @@ -0,0 +1,56 @@ +/** + * \file src/gopt/impl/profiling_based_solver.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain/gopt/global_layout_transform.h" +#include "megbrain/opr/dnn/pooling.h" +#include "megbrain/opr/imgproc.h" + +using namespace mgb; +using namespace gopt; +using namespace opr; + +/* =================== ProfilingBasedSolverSolver ======================*/ +ProfilingBasedSolver::ProfilingBasedSolver( + std::unique_ptr profiler) + : m_profiler{std::move(profiler)} { + static const ThinHashSet format_aware_oprs = { +#define cb(_Opr) _Opr::typeinfo() + cb(Convolution), + cb(ConvBiasForward), + cb(ConvolutionBackwardData), + cb(PoolingForward), + cb(WarpPerspective), + cb(Resize), + }; + + m_graph_partition_filter = [](const GraphPartition& partition) { + bool has_format_aware_opr = false; + for (auto&& opr : partition.all_oprs()) { + if (!has_format_aware_opr && + format_aware_oprs.count(opr->dyn_typeinfo())) { + has_format_aware_opr = true; + break; + } + } + return has_format_aware_opr; + }; +} + +ProfilingBasedSolver::Solution ProfilingBasedSolver::solve( + const Problem& problem) const { + const auto& partition = problem.graph_partition(); + if (!m_graph_partition_filter(partition)) + return Solution{}; + return do_solve(problem); +} + +// vim: syntax=cpp.doxygen diff --git a/src/gopt/impl/reformat_manager.cpp b/src/gopt/impl/reformat_manager.cpp index 79a4e2c819f89cbf61765e235faf941f675bd405..d69cae29be09d275438e714152ccea8d7ef27c65 100644 --- a/src/gopt/impl/reformat_manager.cpp +++ b/src/gopt/impl/reformat_manager.cpp @@ -11,9 +11,9 @@ */ #include "megbrain/gopt/reformat_manager.h" +#include "./utils.h" #include "megbrain/opr/tensor_manip.h" #include "megbrain/utils/arith_helper.h" -#include "./utils.h" using namespace mgb; using namespace gopt; @@ -87,21 +87,6 @@ bool ReformatManager::ReformatKey::Equal::operator()( lhs.attribute == rhs.attribute; } -ReformatManager::ReformatKey& -ReformatManager::ReformatKey::deduce_reformat_dtype_enum(const DType& dt) { - static const ThinHashSet> set = { - {TensorFormats::NCHW, TensorFormats::NCHWc64}, - {TensorFormats::NCHWc64, TensorFormats::NCHW}, - {TensorFormats::NCHW, TensorFormats::NHWC}, - {TensorFormats::NHWC, TensorFormats::NCHW}}; - 
if (set.count({input_format, output_format}) > 0 && - (dt.enumv() == DTypeEnum::QuantizedS4 || - dt.enumv() == DTypeEnum::Quantized4Asymm)) { - input_dtype = output_dtype = dt.enumv(); - } - return *this; -} - // =================== ReformatManager ====================*/ ReformatManager::ReformatManager() { using Attribute = ReformatKey::Attribute; @@ -378,7 +363,7 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_featrue( divup(orig_channel, input_alignment) * input_alignment; size_t aligned_out_channel = divup(orig_channel, output_alignment) * output_alignment; - size_t common_alignment = input_alignment * output_alignment / + size_t common_alignment = input_alignment * output_alignment / gcd(input_alignment, output_alignment); size_t aligned_channel = divup(orig_channel, common_alignment) * common_alignment; @@ -427,11 +412,11 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( for (size_t i = 0; i < input_shape.ndim; ++i) { if (input_shape[i].name() == Dimension::Name::C && input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) { - in_channels = orig_var->shape()[i]; + in_channels = orig_var->shape()[i] * input_shape[i].stride(); input_channel_idx = i; - mgb_assert(input_shape[i].stride() == 1, - "unsupport weight format(got:%s)", - input_shape.to_string().c_str()); +// mgb_assert(input_shape[i].stride() == 1, +// "unsupport weight format(got:%s)", +// input_shape.to_string().c_str()); } else if ((input_shape[i].name() == Dimension::Name::K || input_shape[i].name() == Dimension::Name::N) && input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) { @@ -536,7 +521,8 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var, "formats(var:%s;shp:%s;fmt:%s)", var->cname(), oshp.to_string().c_str(), orig_shape.to_string().c_str()); - if (oshp.is_scalar()) return oshp; + if (oshp.is_scalar()) + return oshp; TensorShape tshp; ThinHashMap name2dominant; for (size_t i = 0; i < orig_shape.ndim; ++i) { @@ -597,4 +583,32 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var, return tshp; } +ReformatManager::AlignmentDesc mgb::gopt::make_aligned_desc( + TensorFormats weight_format, TensorFormats out_feature_format) { + using AlignmentDesc = ReformatManager::AlignmentDesc; + using Name = Dimension::Name; + auto weight_shape = tensor_formats_to_named_tensor_shape(weight_format); + auto out_shape = tensor_formats_to_named_tensor_shape(out_feature_format); + size_t out_channel_alignment = 1; + for (size_t i = 0; i < out_shape.ndim; ++i) { + auto name = out_shape[i].name(); + auto extent = out_shape[i].extent(); + if ((name == Name::C || name == Name::K) && + extent == Dimension::UNDETERMINED_EXTENT) { + out_channel_alignment = out_shape[i].stride(); + break; + } + } + Name out_channel_name; + for (size_t i = 0; i < weight_shape.ndim; ++i) { + auto name = weight_shape[i].name(); + auto extent = weight_shape[i].extent(); + if ((name == Name::N || name == Name::K) && + extent == Dimension::UNDETERMINED_EXTENT) { + out_channel_name = name; + } + } + return AlignmentDesc{out_channel_name, out_channel_alignment}; +} + // vim: syntax=cpp.doxygen diff --git a/src/gopt/impl/subgraph_extractor.cpp b/src/gopt/impl/subgraph_extractor.cpp index 5e0d88f520a917907ec767e30f72e6368fc6696c..e7c9cbb1d1bf1b704a6a9cf17ea29a5e1527aa80 100644 --- a/src/gopt/impl/subgraph_extractor.cpp +++ b/src/gopt/impl/subgraph_extractor.cpp @@ -304,10 +304,15 @@ std::vector SubGraphExtractor::extract( } } partition->opr_set().insert(opr); + 
partition->all_oprs().push_back(opr); for (const auto& i : opr->input()) partition->input().insert(i); } } + for (auto&& partition : partitions) { + auto& all_oprs = partition.all_oprs(); + std::reverse(all_oprs.begin(), all_oprs.end()); + } return partitions; } diff --git a/src/gopt/impl/utils.h b/src/gopt/impl/utils.h index 335302081b4a3783f61560f7c7b0a3c8b3a0a228..325f2c7ed8fe5d72eebf21bfe4ae0ba82b8b4e79 100644 --- a/src/gopt/impl/utils.h +++ b/src/gopt/impl/utils.h @@ -36,6 +36,28 @@ static inline const char* opr_format_to_string( #undef cb } +static inline TensorFormats opr_format_to_tensor_formats( + OprTensorFormatsConfiguration::OprFormat opr_format) { + using OprFormat = OprTensorFormatsConfiguration::OprFormat; + switch (opr_format) { + case OprFormat::NCHW: + return TensorFormats::NCHW; + case OprFormat::NHWC: + return TensorFormats::NHWC; + case OprFormat::NCHW4: + return TensorFormats::NCHWc4; + case OprFormat::NCHW32: + return TensorFormats::NCHWc32; + case OprFormat::NCHW64: + return TensorFormats::NCHWc64; + case OprFormat::CHWN4: + return TensorFormats::CHWNc4; + default: + mgb_throw(AssertionError, "format(%s) is not supported", + opr_format_to_string(opr_format)); + }; +} + static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape( TensorFormats format) { switch (format) { diff --git a/src/gopt/include/megbrain/gopt/global_layout_transform.h b/src/gopt/include/megbrain/gopt/global_layout_transform.h index 98432204d3eaebec7bf719ae5c21de89c7deba5f..9a851869a42ed3483ed55adce0a3720a2134a280 100644 --- a/src/gopt/include/megbrain/gopt/global_layout_transform.h +++ b/src/gopt/include/megbrain/gopt/global_layout_transform.h @@ -11,6 +11,7 @@ */ #pragma once +#include "megbrain/gopt/framework.h" #include "megbrain/gopt/reformat_manager.h" #include "megbrain/gopt/subgraph_extractor.h" #include "megbrain/opr/dnn/convolution.h" @@ -41,14 +42,16 @@ struct OprTensorFormatsConfiguration { /*! * \brief A structure that describes the global layout transform problem */ -class Problem { +class LayoutTransformContext { public: + using OprList = SubGraphExtractor::OprList; using OprFormat = OprTensorFormatsConfiguration::OprFormat; using OprTensorFormatsDispatcher = OprTensorFormatsConfiguration::OprTensorFormatsDispatcher; using OprConfigTrait = ThinHashMap>; + using ReformatAttribute = ReformatManager::ReformatKey::Attribute; struct Attribute { OprFormat base_opr_format; /// the base opr format indicates that the /// network to be optimized is constructed @@ -62,58 +65,110 @@ public: /// (like elemwise, elemwise multi type, /// typecvt etc.) are built in the base /// tensor format. 
+ ReformatAttribute + reformat_attribute; /// additional reformat attribute, which + /// indicates whether to pad nhwc layout + /// automatically or to enable nhwcd4 format + /// on opencl platform to use image object }; - Problem(const GraphPartition& graph_partition, - const SmallVector& available_tensor_formats, - const OprConfigTrait& opr_config, const Attribute& attribute) - : m_graph_partition{graph_partition}, - m_available_tensor_formats{available_tensor_formats}, - m_opr_configs{opr_config}, + LayoutTransformContext() = delete; + LayoutTransformContext(OprList opr_list, + SmallVector available_tensor_formats, + Attribute attribute) + : m_opr_list{std::move(opr_list)}, + m_available_tensor_formats{std::move(available_tensor_formats)}, + m_attribute{attribute} {} + LayoutTransformContext(OprList opr_list, + SmallVector available_tensor_formats, + OprConfigTrait opr_configs, Attribute attribute) + : m_opr_list{std::move(opr_list)}, + m_available_tensor_formats{std::move(available_tensor_formats)}, + m_opr_configs{std::move(opr_configs)}, m_attribute{attribute} {} + const OprList& opr_list() const { return m_opr_list; } + const SmallVector& available_tensor_formats() const { + return m_available_tensor_formats; + } + const OprConfigTrait& opr_configs() const { return m_opr_configs; } + Attribute attribute() const { return m_attribute; } + /*! + * \brief add an op format configuration for a particular operator type + * \param opr runtime typeinfo of operator + * \param opr_format op format configuration which to be enabled in the + * layout transform problem + */ + LayoutTransformContext& add_opr_config(Typeinfo* opr, OprFormat opr_format); + /*! + * \brief add a vector of op format configurations for a particular operator + * type + * \param opr runtime typeinfo of operator + * \param opr_format op format configuration which to be enabled in the + * layout transform problem + */ + LayoutTransformContext& add_opr_config(Typeinfo* opr, + SmallVector opr_formats); + +private: + OprList m_opr_list; /// supported operator list + SmallVector + m_available_tensor_formats; /// the available tensor formats, used + /// for format agnostic operators (like + /// elemwise, elemwise multi type, + /// typecvt, etc. + OprConfigTrait m_opr_configs; /// the available opr format configurations, + /// used for format aware operators (like + /// conv, deconv, conv_bias, etc. + Attribute m_attribute; /// the extra attributes to describe the problem +}; + +class Problem { +public: + using OprFormat = OprTensorFormatsConfiguration::OprFormat; + using OprTensorFormatsDispatcher = + OprTensorFormatsConfiguration::OprTensorFormatsDispatcher; + using OprConfigTrait = LayoutTransformContext::OprConfigTrait; + using Attribute = LayoutTransformContext::Attribute; + + Problem(const GraphPartition& graph_partition, + const LayoutTransformContext& ctx) + : m_graph_partition{graph_partition}, m_ctx{ctx} {} ~Problem() noexcept = default; const GraphPartition& graph_partition() const { return m_graph_partition; } - const OprConfigTrait& opr_configs() const { return m_opr_configs; } + const OprConfigTrait& opr_configs() const { return m_ctx.opr_configs(); } const SmallVector& available_tensor_formats() const { - return m_available_tensor_formats; + return m_ctx.available_tensor_formats(); } TensorFormats base_format() const { - return m_attribute.base_tensor_formats; + return m_ctx.attribute().base_tensor_formats; } + /*! 
+ * \brief return the tensor formats configuration of an operator in the + * default op format + */ OprTensorFormatsConfiguration base_config( const cg::OperatorNodeBase* opr) const { auto _ = OprTensorFormatsConfiguration::find_dispatcher_by_type_format( - opr->dyn_typeinfo(), m_attribute.base_opr_format); + opr->dyn_typeinfo(), m_ctx.attribute().base_opr_format); auto rst = (*_)(opr); if (rst.valid()) return rst.val(); OprTensorFormatsConfiguration config; config.typeinfo = opr->dyn_typeinfo(); - config.opr_format = m_attribute.base_opr_format; + config.opr_format = m_ctx.attribute().base_opr_format; for (const auto& i : opr->input()) { config.input_dtypes.emplace_back(i->dtype().enumv()); - config.input_tensor_formats.emplace_back( - m_attribute.base_tensor_formats); + config.input_tensor_formats.emplace_back(base_format()); config.input_tensor_types.emplace_back(TensorType::FEATURE); } config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); - config.output_tensor_formats.emplace_back( - m_attribute.base_tensor_formats); + config.output_tensor_formats.emplace_back(base_format()); return config; } private: const GraphPartition& m_graph_partition; /// the graph partition - const SmallVector& - m_available_tensor_formats; /// the available tensor formats, used - /// for format agnostic operators (like - /// elemwise, elemwise multi type, - /// typecvt, etc. - const OprConfigTrait& - m_opr_configs; /// the available opr format configurations, used - /// for format aware operators (like conv, deconv, - /// conv_bias, etc. - Attribute m_attribute; /// the extra attributes to describe the problem + const LayoutTransformContext& m_ctx; }; /*! @@ -170,6 +225,92 @@ public: static std::unique_ptr make_profiler(); }; +/*! + * \brief abstract solver + */ +class SolverBase { +public: + using OprFormat = Problem::OprFormat; + using Solution = ThinHashMap; + SolverBase() = default; + virtual ~SolverBase() = default; + /*! + * \brief solve the given problem + */ + virtual Solution solve(const Problem& problem) const = 0; + /*! + * \brief check whether the given problem can be solved by the + * algorithm(i.e. solver). + */ + virtual bool can_solve(const Problem& problem) const = 0; +}; + +/*! + * \brief solvers that will first collect the costs of operators in different op + * format and the costs of layout transform of varnode with a user provided + * profiler on the target device. This will lead to time consuming. + */ +class ProfilingBasedSolver : public SolverBase { +public: + using GraphPartitionFilter = + thin_function; + ProfilingBasedSolver(std::unique_ptr profiler); + /*! + * \note some graph partition (for example, graph partition without format + * aware operators like conv, deconv, warp, resize etc.) will be filtered by + * the GraphPartitionFilter, which can reduce the profiling time. */ + ProfilingBasedSolver(std::unique_ptr profiler, + GraphPartitionFilter graph_partition_filter) + : m_profiler{std::move(profiler)}, + m_graph_partition_filter{std::move(graph_partition_filter)} {} + virtual ~ProfilingBasedSolver() = default; + Solution solve(const Problem& problem) const override; + virtual Solution do_solve(const Problem& problem) const = 0; + +protected: + std::unique_ptr m_profiler; + +private: + GraphPartitionFilter m_graph_partition_filter; +}; + +/*! + * \brief A solver that solves the layout selection problem using dynamic + * programming algorithm (Markov decision process). 
+ */ +class DynamicProgrammingSolver final : public ProfilingBasedSolver { +public: + DynamicProgrammingSolver(std::unique_ptr profiler) + : ProfilingBasedSolver(std::move(profiler)){}; + DynamicProgrammingSolver(std::unique_ptr profiler, + GraphPartitionFilter graph_partition_filter) + : ProfilingBasedSolver(std::move(profiler), + std::move(graph_partition_filter)){}; + ~DynamicProgrammingSolver() noexcept = default; + Solution do_solve(const Problem& problem) const override; + bool can_solve(const Problem& problem) const override; + +private: + class Impl; +}; + +/*! + * \brief A layout transform pass, which convert the operator's format to the + * optimal format using the results of the solver. + */ +class LayoutTransformPass final : public Pass { +public: + const char* name() const override { return "layout assignment pass"; } + void apply(OptState& opt) const override; + LayoutTransformPass(std::unique_ptr ctx, + std::unique_ptr solver) + : m_ctx{std::move(ctx)}, m_solver{std::move(solver)} {} + +private: + std::unique_ptr m_ctx; + std::unique_ptr m_solver; +}; + } // namespace gopt } // namespace mgb diff --git a/src/gopt/include/megbrain/gopt/reformat_manager.h b/src/gopt/include/megbrain/gopt/reformat_manager.h index 9b1c2652af70149fc59741cd693b495b56dbf5bf..7464dea2a76cf8c5343120ed2dd01d30c82185e1 100644 --- a/src/gopt/include/megbrain/gopt/reformat_manager.h +++ b/src/gopt/include/megbrain/gopt/reformat_manager.h @@ -84,7 +84,7 @@ public: output_dtype{DTypeEnum::Float32}, attribute{Attribute::DEFAULT} {} ReformatKey(TensorFormats input_format_, TensorFormats output_format_, - Attribute attribute_ = Attribute::DEFAULT, + Attribute attribute_, DTypeEnum input_dtype_ = DTypeEnum::Float32, DTypeEnum output_dtype_ = DTypeEnum::Float32) : input_format{input_format_}, @@ -92,6 +92,15 @@ public: input_dtype{input_dtype_}, output_dtype{output_dtype_}, attribute{attribute_} {} + ReformatKey(TensorFormats input_format_, TensorFormats output_format_, + DTypeEnum input_dtype_ = DTypeEnum::Float32, + DTypeEnum output_dtype_ = DTypeEnum::Float32, + Attribute attribute_ = Attribute::DEFAULT) + : input_format{input_format_}, + output_format{output_format_}, + input_dtype{input_dtype_}, + output_dtype{output_dtype_}, + attribute{attribute_} {} struct Hash { size_t operator()(const ReformatKey& key) const; }; @@ -99,7 +108,6 @@ public: bool operator()(const ReformatKey& lhs, const ReformatKey& rhs) const; }; - ReformatKey& deduce_reformat_dtype_enum(const DType& dt); }; using ReformatCache = std::unordered_map; using OperatorNodeSet = ThinHashSet; + using OperatorNodeList = std::vector; class InputPlaceholder; @@ -32,15 +33,18 @@ public: const OperatorNodeSet& opr_set() const { return m_opr_set; } const VarNodeSet& input() const { return m_inputs; } const VarNodeSet& output() const { return m_outputs; } + const OperatorNodeList& all_oprs() const { return m_oprs; } OperatorNodeSet& opr_set() { return m_opr_set; } + OperatorNodeList& all_oprs() { return m_oprs; } VarNodeSet& input() { return m_inputs; } VarNodeSet& output() { return m_outputs; } private: + std::pair replace_graph_by_placeholder() const; OperatorNodeSet m_opr_set; + OperatorNodeList m_oprs; VarNodeSet m_inputs; VarNodeSet m_outputs; - std::pair replace_graph_by_placeholder() const; }; class SubGraphExtractor { diff --git a/src/gopt/test/profiler.cpp b/src/gopt/test/profiler.cpp index b3be17e2e6e36cd4532997121c410a8bbf2bd1af..686dd677896c472dc900ae633d5f96e68e8148df 100644 --- a/src/gopt/test/profiler.cpp +++ 
b/src/gopt/test/profiler.cpp @@ -10,6 +10,7 @@ * implied. */ +#include "megbrain/plugin/profiler.h" #include "./helper.h" #include "megbrain/gopt/global_layout_transform.h" #include "megbrain/gopt/inference.h" @@ -22,123 +23,59 @@ using namespace mgb; using namespace gopt; using namespace serialization; +#if MGB_CUDA namespace { -class LayoutTransformContext : public NonCopyableObj { -public: - using OprList = SubGraphExtractor::OprList; - using OprFormat = Problem::OprFormat; - using OprConfigTrait = Problem::OprConfigTrait; - - LayoutTransformContext() = delete; - LayoutTransformContext(OprList opr_list, - SmallVector available_tensor_formats, - OprConfigTrait opr_configs) - : m_opr_list{std::move(opr_list)}, - m_available_tensor_formats{std::move(available_tensor_formats)}, - m_opr_configs{std::move(opr_configs)} {} - const OprList& opr_list() const { return m_opr_list; } - const SmallVector& available_tensor_formats() const { - return m_available_tensor_formats; - } - const OprConfigTrait& opr_configs() const { return m_opr_configs; } - static std::unique_ptr make() { - OprList opr_list = { - opr::ConvBiasForward::typeinfo(), - opr::ConvolutionForward::typeinfo(), - opr::ConvolutionBackwardData::typeinfo(), - opr::ElemwiseMultiType::typeinfo(), - opr::Elemwise::typeinfo(), - opr::TypeCvt::typeinfo(), - opr::PoolingForward::typeinfo(), - opr::WarpPerspectiveForward::typeinfo(), - }; - OprConfigTrait opr_configs; - { - auto& dispatchers = opr_configs[opr::ConvBias::typeinfo()]; -#define cb(_fmt) \ - dispatchers[OprFormat::_fmt] = \ - OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \ - opr::ConvBias::typeinfo(), OprFormat::_fmt); - cb(NCHW4); - cb(NCHW32); - cb(NHWC); - cb(NCHW64); - cb(CHWN4); -#undef cb - } - { - auto& dispatchers = - opr_configs[opr::ConvolutionBackwardData::typeinfo()]; -#define cb(_fmt) \ - dispatchers[OprFormat::_fmt] = \ - OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \ - opr::ConvolutionBackwardData::typeinfo(), \ - OprFormat::_fmt); - cb(NCHW4); -#undef cb - } - - { - auto& dispatchers = - opr_configs[opr::ConvolutionForward::typeinfo()]; -#define cb(_fmt) \ - dispatchers[OprFormat::_fmt] = \ - OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \ - opr::ConvolutionForward::typeinfo(), OprFormat::_fmt); - cb(NCHW4); -#undef cb - } - - { - auto& dispatchers = opr_configs[opr::PoolingForward::typeinfo()]; -#define cb(_fmt) \ - dispatchers[OprFormat::_fmt] = \ - OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \ - opr::PoolingForward::typeinfo(), OprFormat::_fmt); - cb(NCHW4); - cb(NCHW32); - cb(NHWC); - cb(NCHW64); - cb(CHWN4); -#undef cb - } - - { - auto& dispatchers = - opr_configs[opr::WarpPerspectiveForward::typeinfo()]; -#define cb(_fmt) \ - dispatchers[OprFormat::_fmt] = \ - OprTensorFormatsConfiguration::find_dispatcher_by_type_format( \ - opr::WarpPerspectiveForward::typeinfo(), OprFormat::_fmt); - cb(NHWC); - cb(NCHW4); - cb(NCHW64); -#undef cb - } - - SmallVector available_tensor_formats = { - TensorFormats::NHWC, TensorFormats::NCHWc4, - TensorFormats::NCHWc32, TensorFormats::NCHWc64}; - return std::make_unique( - std::move(opr_list), std::move(available_tensor_formats), - std::move(opr_configs)); - } +std::unique_ptr make_ctx() { + using OprFormat = LayoutTransformContext::OprFormat; + using OprList = LayoutTransformContext::OprList; + using ReformatAttribute = LayoutTransformContext::ReformatAttribute; + using Attribute = LayoutTransformContext::Attribute; + OprList opr_list = { + 
opr::ConvBiasForward::typeinfo(), + opr::ConvolutionForward::typeinfo(), + opr::ConvolutionBackwardData::typeinfo(), + opr::ElemwiseMultiType::typeinfo(), + opr::Elemwise::typeinfo(), + opr::TypeCvt::typeinfo(), + opr::PoolingForward::typeinfo(), + opr::WarpPerspectiveForward::typeinfo(), + }; -private: - OprList m_opr_list; - SmallVector m_available_tensor_formats; - OprConfigTrait m_opr_configs; -}; -}; // namespace + SmallVector available_tensor_formats = { + TensorFormats::NCHW, TensorFormats::NHWC, + TensorFormats::NCHWc4, TensorFormats::NCHWc32, + TensorFormats::NCHWc64, TensorFormats::CHWNc4}; + Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, + ReformatAttribute::DEFAULT}; + auto ctx = std::make_unique( + std::move(opr_list), std::move(available_tensor_formats), + attribute); + ctx->add_opr_config( + opr::ConvBiasForward::typeinfo(), + {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, + OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4}) + .add_opr_config(opr::ConvolutionForward::typeinfo(), + {OprFormat::NCHW, OprFormat::NCHW4}) + .add_opr_config(opr::ConvolutionBackwardData::typeinfo(), + {OprFormat::NCHW, OprFormat::NCHW4}) + .add_opr_config( + opr::PoolingForward::typeinfo(), + {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC, + OprFormat::NCHW64, OprFormat::CHWN4}) + .add_opr_config( + opr::WarpPerspectiveForward::typeinfo(), + {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); + return ctx; +} +} // namespace -#if MGB_CUDA #if CUDA_VERSION >= 10020 TEST(TestProfiler, Conv) { REQUIRE_GPU(1); auto cn = CompNode::load("gpu0"); cn.activate(); REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5); - auto ctx = LayoutTransformContext::make(); + auto ctx = make_ctx(); HostTensorGenerator gen; auto graph = ComputingGraph::make(); @@ -177,14 +114,10 @@ TEST(TestProfiler, Conv) { using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; S strategy = S::PROFILE; gopt::modify_opr_algo_strategy_inplace({c2}, strategy); - using OprFormat = OprTensorFormatsConfiguration::OprFormat; SubGraphExtractor extractor(ctx->opr_list()); auto partitions = extractor.extract({c2}); ASSERT_EQ(partitions.size(), 1u); - using Attribute = Problem::Attribute; - Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW}; - Problem problem(partitions[0], ctx->available_tensor_formats(), - ctx->opr_configs(), attribute); + Problem problem(partitions[0], *ctx); auto profiler = ProfilerBase::make_profiler(); auto rst = profiler->profile(problem); const auto& opr_rst = rst.opr_record; @@ -204,7 +137,7 @@ TEST(TestProfiler, Deconv) { auto cn = CompNode::load("gpu0"); cn.activate(); REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5); - auto ctx = LayoutTransformContext::make(); + auto ctx = make_ctx(); HostTensorGenerator gen; auto graph = ComputingGraph::make(); @@ -238,14 +171,10 @@ TEST(TestProfiler, Deconv) { using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; S strategy = S::PROFILE; gopt::modify_opr_algo_strategy_inplace({c2}, strategy); - using OprFormat = OprTensorFormatsConfiguration::OprFormat; SubGraphExtractor extractor(ctx->opr_list()); auto partitions = extractor.extract({c2}); ASSERT_EQ(partitions.size(), 1u); - using Attribute = Problem::Attribute; - Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW}; - Problem problem(partitions[0], ctx->available_tensor_formats(), - ctx->opr_configs(), attribute); + Problem problem(partitions[0], *ctx); auto profiler = ProfilerBase::make_profiler(); auto rst = profiler->profile(problem); const auto& opr_rst = 
rst.opr_record; @@ -262,7 +191,7 @@ TEST(TestProfiler, Warp) { auto cn = CompNode::load("gpu0"); cn.activate(); REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5); - auto ctx = LayoutTransformContext::make(); + auto ctx = make_ctx(); constexpr size_t INP_H = 10, INP_W = 10, N = 16; @@ -307,14 +236,9 @@ TEST(TestProfiler, Warp) { using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; S strategy = S::PROFILE; gopt::modify_opr_algo_strategy_inplace({w1}, strategy); - using OprFormat = OprTensorFormatsConfiguration::OprFormat; SubGraphExtractor extractor(ctx->opr_list()); auto partitions = extractor.extract({w1}); - ASSERT_EQ(partitions.size(), 1u); - using Attribute = Problem::Attribute; - Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW}; - Problem problem(partitions[0], ctx->available_tensor_formats(), - ctx->opr_configs(), attribute); + Problem problem(partitions[0], *ctx); auto profiler = ProfilerBase::make_profiler(); auto rst = profiler->profile(problem); const auto& opr_rst = rst.opr_record; @@ -330,7 +254,7 @@ TEST(TestProfiler, Pooling) { auto cn = CompNode::load("gpu0"); cn.activate(); REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5); - auto ctx = LayoutTransformContext::make(); + auto ctx = make_ctx(); HostTensorGenerator gen; auto graph = ComputingGraph::make(); @@ -353,14 +277,10 @@ TEST(TestProfiler, Pooling) { using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; S strategy = S::PROFILE; gopt::modify_opr_algo_strategy_inplace({p2}, strategy); - using OprFormat = OprTensorFormatsConfiguration::OprFormat; SubGraphExtractor extractor(ctx->opr_list()); auto partitions = extractor.extract({p2}); ASSERT_EQ(partitions.size(), 1u); - using Attribute = Problem::Attribute; - Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW}; - Problem problem(partitions[0], ctx->available_tensor_formats(), - ctx->opr_configs(), attribute); + Problem problem(partitions[0], *ctx); auto profiler = ProfilerBase::make_profiler(); auto rst = profiler->profile(problem); const auto& opr_rst = rst.opr_record; @@ -373,8 +293,7 @@ TEST(TestProfiler, Elemwise) { REQUIRE_GPU(1); auto cn = CompNode::load("gpu0"); cn.activate(); - REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5); - auto ctx = LayoutTransformContext::make(); + auto ctx = make_ctx(); HostTensorGenerator gen; auto graph = ComputingGraph::make(); @@ -403,14 +322,10 @@ TEST(TestProfiler, Elemwise) { OperatorNodeConfig( dtype::Quantized4Asymm(13.f, static_cast(4)))); - using OprFormat = OprTensorFormatsConfiguration::OprFormat; SubGraphExtractor extractor(ctx->opr_list()); auto partitions = extractor.extract({q4e}); ASSERT_EQ(partitions.size(), 1u); - using Attribute = Problem::Attribute; - Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW}; - Problem problem(partitions[0], ctx->available_tensor_formats(), - ctx->opr_configs(), attribute); + Problem problem(partitions[0], *ctx); auto profiler = ProfilerBase::make_profiler(); auto rst = profiler->profile(problem); const auto& opr_rst = rst.opr_record; @@ -423,7 +338,6 @@ TEST(TestProfiler, Elemwise) { EXPECT_TRUE(var_rst.count(q8a.node()) > 0); EXPECT_TRUE(var_rst.count(q8b.node()) > 0); } - #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/test/reformat_manager.cpp b/src/gopt/test/reformat_manager.cpp index d639f5240153270dc7ff3de4cf100ccd181cdaa8..43b16f9cb136664dc10392e7b74eed7bba2180f6 100644 --- a/src/gopt/test/reformat_manager.cpp +++ b/src/gopt/test/reformat_manager.cpp @@ -447,6 +447,7 @@ TEST(TestReformatManager, 
AutoAlignedFeatureProfiling) {
     for (size_t i = 0; i < RUNS; ++i)
         func->execute();
     double time_profiler = profiler->duration() * 1e6;
+    printf("time: %f, %f\n", time_cuda_evt, time_profiler);
     MGB_CUDA_CHECK(cudaEventDestroy(evt0));
     MGB_CUDA_CHECK(cudaEventDestroy(evt1));
 }
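
Note on the DP state encoding: DynamicProgrammingSolver::Impl packs the set of tensor formats a varnode has already been materialized in into a 32-bit TensorFormatsBitSet, and distance() only charges for formats required by the target state but absent from the source state, each at the cheapest conversion from a format already present. The following standalone C++ sketch (toy format indices and a made-up cost table, not part of the patch) illustrates that behaviour:

#include <algorithm>
#include <cstdint>
#include <limits>
#include <map>
#include <utility>
#include <vector>

using FormatsBitSet = uint32_t;

FormatsBitSet add_format(FormatsBitSet set, uint32_t fmt) {
    return set | (1u << fmt);
}
bool has_format(FormatsBitSet set, uint32_t fmt) {
    return (set & (1u << fmt)) != 0;
}

// distance(from, to): for every format required by `to` but not yet present
// in `from`, pay the cheapest conversion starting from any format in `from`.
float distance(FormatsBitSet from, FormatsBitSet to,
               const std::map<std::pair<uint32_t, uint32_t>, float>& costs,
               const std::vector<uint32_t>& all_formats) {
    if ((from & to) == to)
        return 0.f;                              // target formats already covered
    FormatsBitSet missing = (from | to) ^ from;  // formats still to materialize
    float dist = 0.f;
    for (uint32_t o : all_formats) {
        if (!has_format(missing, o))
            continue;
        float best = std::numeric_limits<float>::max();
        for (uint32_t i : all_formats)
            if (has_format(from, i))
                best = std::min(best, costs.at({i, o}));
        dist += best;
    }
    return dist;
}

int main() {
    std::vector<uint32_t> fmts = {0, 1};  // e.g. 0 stands for NCHW, 1 for NHWC
    std::map<std::pair<uint32_t, uint32_t>, float> costs = {{{0, 1}, 2.0f},
                                                            {{1, 0}, 2.0f}};
    FormatsBitSet from = add_format(0, 0);  // var currently held in format 0
    FormatsBitSet to = add_format(0, 1);    // target additionally needs format 1
    return distance(from, to, costs, fmts) > 0.f ? 0 : 1;  // pays the 0 -> 1 cost
}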
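Note on pruning: prune() drops a DP state whenever another state is strictly cheaper and converting from that state's formats would still cost less than the time gap. A minimal self-contained sketch of this dominance rule, with an illustrative stand-in for the real per-edge state and distance metric (names and the metric are made up for the sketch):

#include <cstddef>
#include <cstdlib>
#include <vector>

struct StateCost {
    float time;  // accumulated time of this DP state
    int id;      // stand-in for the real per-edge format state
};

// Made-up conversion metric between two states, only for the sketch.
float state_distance(const StateCost& a, const StateCost& b) {
    return 0.5f * std::abs(a.id - b.id);
}

// Keep a state only if no other state dominates it: state j dominates state i
// when j is cheaper and time(j) + state_distance(j, i) < time(i).
std::vector<StateCost> prune(const std::vector<StateCost>& states) {
    std::vector<StateCost> kept;
    for (size_t i = 0; i < states.size(); ++i) {
        bool dominated = false;
        for (size_t j = 0; j < states.size(); ++j) {
            if (i == j)
                continue;
            if (states[j].time < states[i].time &&
                state_distance(states[j], states[i]) <
                        states[i].time - states[j].time) {
                dominated = true;
                break;
            }
        }
        if (!dominated)
            kept.push_back(states[i]);
    }
    return kept;
}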
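Note on usage: the flow exercised by the updated tests in src/gopt/test/profiler.cpp composes the new classes roughly as sketched below; make_ctx() stands for a context-building helper such as the one defined in that test, and the remaining names come from the headers in this patch:

#include "megbrain/gopt/global_layout_transform.h"

using namespace mgb;
using namespace gopt;

// Assumed helper that builds a LayoutTransformContext (opr list, available
// tensor formats, per-opr format configs), e.g. the make_ctx() in the test.
std::unique_ptr<LayoutTransformContext> make_ctx();

SolverBase::Solution solve_layouts(SymbolVar endpoint) {
    auto ctx = make_ctx();
    // Extract format-relevant partitions from the endpoint's dependency graph.
    SubGraphExtractor extractor(ctx->opr_list());
    auto partitions = extractor.extract({endpoint});
    // One Problem per graph partition; take the first one for brevity.
    Problem problem(partitions[0], *ctx);
    // The DP solver profiles each opr format and each var reformat cost,
    // then picks an OprFormat per operator that minimizes the summed time.
    DynamicProgrammingSolver solver(ProfilerBase::make_profiler());
    return solver.solve(problem);
}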