提交 5d9ac970 作者: Megvii Engine Team

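fix(mgb): fix fastrun compnode

The fastrun profiling parameters now store the comp node's physical locator (`m_cn.locator()`) and logical locator (`m_cn.locator_logical()`) in separate fields, and `TimedProfiler` passes both to `CompNode::load` instead of passing the same single locator for both arguments.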

GitOrigin-RevId: 8db93facb94b8358237e5a7e273a1371a0d531c1
上级 56c1b626
...@@ -853,7 +853,8 @@ AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo( ...@@ -853,7 +853,8 @@ AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo(
src.to_string().c_str()); src.to_string().c_str());
param.dtypes[i] = src.dtype.enumv(); param.dtypes[i] = src.dtype.enumv();
} }
param.comp_node_loc = m_cn.locator(); param.comp_node_physical = m_cn.locator();
param.comp_node_logical = m_cn.locator_logical();
mgb_assert(param.shapes.size() == m_fastrun_layouts.size()); mgb_assert(param.shapes.size() == m_fastrun_layouts.size());
for (size_t i = 0; i < param.shapes.size(); ++i) for (size_t i = 0; i < param.shapes.size(); ++i)
param.shapes[i] = m_fastrun_layouts[i]; param.shapes[i] = m_fastrun_layouts[i];
......
...@@ -222,7 +222,8 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl( ...@@ -222,7 +222,8 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
mgb_assert(miopen_algo_search_enabled, "MIOpen algo search not enabled"); mgb_assert(miopen_algo_search_enabled, "MIOpen algo search not enabled");
#endif #endif
auto&& param = raw_param.as_single_pod<Param>(); auto&& param = raw_param.as_single_pod<Param>();
CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc); CompNode cn =
CompNode::load(param.comp_node_physical, param.comp_node_logical);
auto megdnn_opr = intl::create_megdnn_opr<Opr>(cn); auto megdnn_opr = intl::create_megdnn_opr<Opr>(cn);
std::array<TensorLayout, arity> layouts; std::array<TensorLayout, arity> layouts;
...@@ -395,7 +396,8 @@ void TimedProfiler<Opr>::prof_init_device(const TParam& raw_param) { ...@@ -395,7 +396,8 @@ void TimedProfiler<Opr>::prof_init_device(const TParam& raw_param) {
megcore::enableMIOpenAlgoSearch(true); megcore::enableMIOpenAlgoSearch(true);
#endif #endif
auto&& param = raw_param.as_single_pod<Param>(); auto&& param = raw_param.as_single_pod<Param>();
CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc); CompNode cn =
CompNode::load(param.comp_node_physical, param.comp_node_logical);
// wait for cuda init, so its time does not get accounted in timeout // wait for cuda init, so its time does not get accounted in timeout
cn.sync(); cn.sync();
MIDOUT_E MIDOUT_E
......
...@@ -122,7 +122,7 @@ public: ...@@ -122,7 +122,7 @@ public:
ExecutionPolicyBlob execution_policy; ExecutionPolicyBlob execution_policy;
size_t workspace; size_t workspace;
megdnn::DTypeEnum dtypes[arity]; megdnn::DTypeEnum dtypes[arity];
CompNode::Locator comp_node_loc; CompNode::Locator comp_node_physical, comp_node_logical;
TensorShapeArray shapes; TensorShapeArray shapes;
typename Opr::Param opr_param; typename Opr::Param opr_param;
bool allow_weight_preprocess; bool allow_weight_preprocess;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册