Commit 8fe02dcc authored by Megvii Engine Team

fix(opr): fix no profile on shape change

GitOrigin-RevId: 560ab738badaa1ea33192b55f739fbd15e487f7d
Parent commit: c6844dd9
......@@ -54,6 +54,13 @@ size_t AlgoChooser<Opr>::setup_algo(
layouts, megdnn_opr, param_str, mgb_opr->comp_node(),
mgb_opr->execution_policy(), allow_weight_preprocess, desc);
bool no_profiling_on_shape_change = cg->options().no_profiling_on_shape_change;
//! if no profile on shape change is set and the algo policy is valid,
//! get the workspace directly
if (no_profiling_on_shape_change && megdnn_opr->execution_policy().algo.valid()) {
return helper.get_workspace_size_bytes(megdnn_opr->execution_policy(), layouts);
}
ImplExecutionPolicy policy;
if (auto algo_choose_hook = mgb_opr->algo_chooser()) {
policy = algo_choose_hook(mgb_opr);
......
......@@ -309,6 +309,56 @@ TEST(TestOprDNN, FastrunIgnoreBatchSizeBatchedMatrixMul) {
{TensorShape{4, 6, 8}, TensorShape{4, 8, 4}});
}
TEST(TestOprDNN, NoProfileWhenShapeChange) {
    using CacheMem = std::pair<const void*, size_t>;
    using Policy = opr::ConvBias::ExecutionPolicy;
    using S = Policy::Strategy;

    //! Hook the persistent algo cache: every profiling result that gets written
    //! to the cache is recorded in cache_set_history, so cache writes are our
    //! observable proxy for "profiling happened".
    auto on_get = [](const std::string&, const void*, size_t, const void*, size_t) {};
    std::vector<std::pair<CacheMem, CacheMem>> cache_set_history;
    auto on_set = [&cache_set_history](
                          const std::string&, const void* key, size_t key_size,
                          const void* val, size_t val_size) {
        cache_set_history.emplace_back(
                std::make_pair(key, key_size), std::make_pair(val, val_size));
    };
    PersistentCacheHook cache_hook{on_get, on_set};

    HostTensorGenerator<> gen;
    auto cn = CompNode::load("xpu0");
    auto graph = ComputingGraph::make();
    //! the option under test: once an algo is chosen, a later shape change
    //! must NOT trigger re-profiling (and hence no new cache writes).
    graph->options().no_profiling_on_shape_change = true;
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
    };

    auto host_x = gen({1, 4, 16, 16}, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);

    opr::ConvBias::Param param_conv;
    Policy policy;
    policy.strategy = S::PROFILE;
    param_conv.pad_h = param_conv.pad_w = 1;
    //! fix: bias tensor is named "b1" (original mistakenly reused "w1")
    auto w1 = mkcvar("w1", {8, 4, 3, 3}), b1 = mkcvar("b1", {1, 8, 1, 1}),
         conv1 = opr::ConvBias::make(
                 x, w1, b1, param_conv, policy, OperatorNodeConfig("conv1"));
    auto w2 = mkcvar("w2", {8, 8, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}),
         out = opr::ConvBias::make(
                 conv1, w2, b2, param_conv, policy, OperatorNodeConfig("conv2"));

    std::unique_ptr<cg::AsyncExecutable> func = graph->compile({{out, {}}});
    func->execute().wait();
    //! there are two convbias oprs, so the first run must profile exactly twice.
    ASSERT_EQ(cache_set_history.size(), 2);

    //! change the input shape and run again; without this second execute() the
    //! assertion below would be vacuous (it would re-check unchanged state).
    host_x->resize({5, 4, 32, 32});
    func->execute().wait();
    //! no profiling on the shape change => no additional cache writes.
    ASSERT_EQ(cache_set_history.size(), 2);
}
#endif // MGB_ENABLE_FASTRUN
#endif // MGB_CUDA
......
......@@ -508,15 +508,6 @@ AlgoChooser<Opr>::AlgoChooserHelper::AlgoChooserHelper(
m_fastrun_layouts, m_dnn_opr->param(), fastrun_batch_size);
}
if (m_desc.no_profiling_on_shape_change) {
for (size_t i = 0; i < m_incache_layouts.size(); i++) {
for (size_t j = 0; j < m_incache_layouts.at(i).ndim; j++) {
m_incache_layouts.at(i)[j] = 0;
}
m_incache_layouts.at(i).init_contiguous_stride();
}
}
mgb_assert(m_fastrun_layouts.size() == layouts.size());
static_assert(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.