Commit 3620a940 authored by: Megvii Engine Team

fix(megdnn): fix tests with workspace_limit set to zero

GitOrigin-RevId: c4ec323361594550e9ed46fff2f7bf42b8301ed8
Parent 0d720653
......@@ -17,6 +17,7 @@
#include "megbrain/opr/blas.h"
#include "megbrain/test/helper.h"
#include "megdnn/oprs/base.h"
using namespace mgb;
......@@ -26,7 +27,8 @@ SymbolVar make_conv(SymbolVar inp, SymbolVar kern) {
using Conv = opr::Convolution;
Conv::ExecutionPolicy poly;
poly.workspace_limit = 0;
return Conv::make(inp, kern, {}, poly);
SymbolVar conv = Conv::make(inp, kern, {}, poly);
return conv;
}
// used for test NO_SYS_MEM_ALLOC
......@@ -74,9 +76,12 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(SharedDeviceTensorDirect);
TEST(TestMemReuse, PureMLP0) {
auto graph = ComputingGraph::make();
HostTensorGenerator<> gen;
auto host_inp = gen({256, 1, 64, 64}),
host_kern0 = gen({32, 1, 1, 1}),
host_kern1 = gen({32, 32, 1, 1});
CompNode cn = CompNode::load("cpu0");
//! FIXME currently recursive chooser does not support workspace_limit in
//! heuristic
auto host_inp = gen({256, 1, 64, 64}, cn),
host_kern0 = gen({32, 1, 1, 1}, cn),
host_kern1 = gen({32, 32, 1, 1}, cn);
auto inp = opr::SharedDeviceTensor::make(*graph, *host_inp, {"inp"}),
kern0 = opr::SharedDeviceTensor::make(*graph, *host_kern0, {"kern0"}),
kern1 = opr::SharedDeviceTensor::make(*graph, *host_kern1, {"kern1"});
......@@ -102,9 +107,12 @@ TEST(TestMemReuse, PureMLP0) {
TEST(TestMemReuse, PureMLP1) {
auto graph = ComputingGraph::make();
HostTensorGenerator<> gen;
auto host_inp = gen({256, 1, 64, 64}),
host_kern0 = gen({32, 1, 1, 1}),
host_kern1 = gen({32, 32, 1, 1});
CompNode cn = CompNode::load("cpu0");
//! FIXME currently recursive chooser does not support workspace_limit in
//! heuristic
auto host_inp = gen({256, 1, 64, 64}, cn),
host_kern0 = gen({32, 1, 1, 1}, cn),
host_kern1 = gen({32, 32, 1, 1}, cn);
auto inp = opr::Host2DeviceCopy::make(*graph, host_inp, {"inp"}),
kern0 = opr::SharedDeviceTensor::make(*graph, *host_kern0, {"kern0"}),
kern1 = opr::SharedDeviceTensor::make(*graph, *host_kern1, {"kern1"}),
......@@ -338,4 +346,3 @@ TEST(TestMemReuse, FwdNoSysMemAlloc) {
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -11,6 +11,7 @@
*/
#include "megbrain/opr/search_policy/algo_chooser.h"
#include <limits>
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/opr/search_policy/profiler.h"
......@@ -473,6 +474,13 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy
AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
if (m_execution_policy.workspace_limit !=
std::numeric_limits<decltype(
m_execution_policy.workspace_limit)>::max()) {
mgb_log_warn(
"workspace_limit should not be setted if choose algo by "
"heuristic");
}
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
ImplExecutionPolicy policy;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register