提交 8b7d8d29 编写于 作者: M Megvii Engine Team

fix(core): fix json dump when weight preprocess

GitOrigin-RevId: 6cd882b10ddc0405f4475d2c118e030715153e37
上级 ec65e1f9
......@@ -528,6 +528,7 @@ std::shared_ptr<json::Value> VarNode::to_json() const {
CHK(PERSISTENT_DEVICE_VALUE);
CHK(DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC);
CHK(DISALLOW_VAR_SANITY_CHECK);
CHK(MEMORY_NO_NEED);
#undef CHK
mgb_assert(flag_checked == static_cast<size_t>(m_flag));
......
......@@ -25,6 +25,7 @@
#include "megbrain/utils/timer.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/test/helper.h"
#include "megdnn/oprs/base.h"
......@@ -1993,6 +1994,12 @@ typename megdnn::ExecutionPolicy try_find_any_bias_preprocess_algo(
void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
#if MGB_ENABLE_JSON
std::unique_ptr<GraphProfiler> profiler;
if(!record_level){
profiler = std::make_unique<GraphProfiler>(graph.get());
}
#endif
graph->options().graph_opt.weight_preprocess = true;
graph->options().comp_node_seq_record_level = record_level;
auto mkvar = [&](const char* name, const TensorShape& shp) {
......@@ -2055,6 +2062,13 @@ void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) {
if (wp2.val()) {
check(w2);
}
#if MGB_ENABLE_JSON
if (profiler) {
func->wait();
profiler->to_json_full(func.get())
->writeto_fpath(output_file("weight_preprocess.json"));
}
#endif
}
} // anonymous namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册