提交 3620a940 编写于 作者: M Megvii Engine Team

fix(megdnn): fix tests with workspace_limit set to zero

GitOrigin-RevId: c4ec323361594550e9ed46fff2f7bf42b8301ed8
上级 0d720653
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "megbrain/opr/blas.h" #include "megbrain/opr/blas.h"
#include "megbrain/test/helper.h" #include "megbrain/test/helper.h"
#include "megdnn/oprs/base.h"
using namespace mgb; using namespace mgb;
...@@ -26,7 +27,8 @@ SymbolVar make_conv(SymbolVar inp, SymbolVar kern) { ...@@ -26,7 +27,8 @@ SymbolVar make_conv(SymbolVar inp, SymbolVar kern) {
using Conv = opr::Convolution; using Conv = opr::Convolution;
Conv::ExecutionPolicy poly; Conv::ExecutionPolicy poly;
poly.workspace_limit = 0; poly.workspace_limit = 0;
return Conv::make(inp, kern, {}, poly); SymbolVar conv = Conv::make(inp, kern, {}, poly);
return conv;
} }
// used for test NO_SYS_MEM_ALLOC // used for test NO_SYS_MEM_ALLOC
...@@ -74,9 +76,12 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(SharedDeviceTensorDirect); ...@@ -74,9 +76,12 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(SharedDeviceTensorDirect);
TEST(TestMemReuse, PureMLP0) { TEST(TestMemReuse, PureMLP0) {
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
HostTensorGenerator<> gen; HostTensorGenerator<> gen;
auto host_inp = gen({256, 1, 64, 64}), CompNode cn = CompNode::load("cpu0");
host_kern0 = gen({32, 1, 1, 1}), //! FIXME currently recursive chooser does not support workspace_limit in
host_kern1 = gen({32, 32, 1, 1}); //! heuristic
auto host_inp = gen({256, 1, 64, 64}, cn),
host_kern0 = gen({32, 1, 1, 1}, cn),
host_kern1 = gen({32, 32, 1, 1}, cn);
auto inp = opr::SharedDeviceTensor::make(*graph, *host_inp, {"inp"}), auto inp = opr::SharedDeviceTensor::make(*graph, *host_inp, {"inp"}),
kern0 = opr::SharedDeviceTensor::make(*graph, *host_kern0, {"kern0"}), kern0 = opr::SharedDeviceTensor::make(*graph, *host_kern0, {"kern0"}),
kern1 = opr::SharedDeviceTensor::make(*graph, *host_kern1, {"kern1"}); kern1 = opr::SharedDeviceTensor::make(*graph, *host_kern1, {"kern1"});
...@@ -102,9 +107,12 @@ TEST(TestMemReuse, PureMLP0) { ...@@ -102,9 +107,12 @@ TEST(TestMemReuse, PureMLP0) {
TEST(TestMemReuse, PureMLP1) { TEST(TestMemReuse, PureMLP1) {
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
HostTensorGenerator<> gen; HostTensorGenerator<> gen;
auto host_inp = gen({256, 1, 64, 64}), CompNode cn = CompNode::load("cpu0");
host_kern0 = gen({32, 1, 1, 1}), //! FIXME currently recursive chooser does not support workspace_limit in
host_kern1 = gen({32, 32, 1, 1}); //! heuristic
auto host_inp = gen({256, 1, 64, 64}, cn),
host_kern0 = gen({32, 1, 1, 1}, cn),
host_kern1 = gen({32, 32, 1, 1}, cn);
auto inp = opr::Host2DeviceCopy::make(*graph, host_inp, {"inp"}), auto inp = opr::Host2DeviceCopy::make(*graph, host_inp, {"inp"}),
kern0 = opr::SharedDeviceTensor::make(*graph, *host_kern0, {"kern0"}), kern0 = opr::SharedDeviceTensor::make(*graph, *host_kern0, {"kern0"}),
kern1 = opr::SharedDeviceTensor::make(*graph, *host_kern1, {"kern1"}), kern1 = opr::SharedDeviceTensor::make(*graph, *host_kern1, {"kern1"}),
...@@ -338,4 +346,3 @@ TEST(TestMemReuse, FwdNoSysMemAlloc) { ...@@ -338,4 +346,3 @@ TEST(TestMemReuse, FwdNoSysMemAlloc) {
} }
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
*/ */
#include "megbrain/opr/search_policy/algo_chooser.h" #include "megbrain/opr/search_policy/algo_chooser.h"
#include <limits>
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h" #include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/opr/search_policy/profiler.h" #include "megbrain/opr/search_policy/profiler.h"
...@@ -473,6 +474,13 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( ...@@ -473,6 +474,13 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy typename AlgoChooser<Opr>::ImplExecutionPolicy
AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const { AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
if (m_execution_policy.workspace_limit !=
std::numeric_limits<decltype(
m_execution_policy.workspace_limit)>::max()) {
mgb_log_warn(
"workspace_limit should not be setted if choose algo by "
"heuristic");
}
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
ImplExecutionPolicy policy; ImplExecutionPolicy policy;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册