提交 c17d6038 编写于 作者: M Megvii Engine Team

feat(mgb): allow output tensor's ptr change when record

GitOrigin-RevId: c610c8bf9a6d97885d9da3f2e447b57f25747dee
上级 26634db7
...@@ -207,6 +207,9 @@ void NetworkImplDft::use_tensorrt() { ...@@ -207,6 +207,9 @@ void NetworkImplDft::use_tensorrt() {
//! set the callback in async model //! set the callback in async model
void NetworkImplDft::set_async_callback(const AsyncCallback& callback) { void NetworkImplDft::set_async_callback(const AsyncCallback& callback) {
LITE_ASSERT(!m_is_cpu_inplace_mode, "cpu inplace mode not support async mode"); LITE_ASSERT(!m_is_cpu_inplace_mode, "cpu inplace mode not support async mode");
LITE_ASSERT(
m_user_config->options.comp_node_seq_record_level == 0,
"record mode not support async mode");
LITE_ASSERT( LITE_ASSERT(
m_user_config->device_type == LiteDeviceType::LITE_CPU || m_user_config->device_type == LiteDeviceType::LITE_CPU ||
m_user_config->device_type == LiteDeviceType::LITE_CUDA, m_user_config->device_type == LiteDeviceType::LITE_CUDA,
......
...@@ -659,4 +659,21 @@ void CompNode::ImplBase::add_callback(megdnn::thin_function<void()>&&) { ...@@ -659,4 +659,21 @@ void CompNode::ImplBase::add_callback(megdnn::thin_function<void()>&&) {
locator().to_string().c_str()); locator().to_string().c_str());
} }
void CompNode::ImplBase::enable_dispatch() {
mgb_throw(
MegBrainError,
"Unsupported add callback to "
"comp node %s",
locator().to_string().c_str());
}
void CompNode::ImplBase::disable_dispatch(bool* flag) {
MGB_MARK_USED_VAR(flag);
mgb_throw(
MegBrainError,
"Unsupported add callback to "
"comp node %s",
locator().to_string().c_str());
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -810,6 +810,12 @@ public: ...@@ -810,6 +810,12 @@ public:
task(); task();
} }
} }
void enable_dispatch() override { m_env.cpu_env().enable_dispatch(); }
void disable_dispatch(bool* flag) override {
m_env.cpu_env().disable_dispatch(flag);
}
}; };
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeRecorderImpl); MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompNodeRecorderImpl);
#if MGB_HAVE_THREAD #if MGB_HAVE_THREAD
......
...@@ -474,4 +474,35 @@ void CompNodeEnv::on_bad_device_type(DeviceType expected) const { ...@@ -474,4 +474,35 @@ void CompNodeEnv::on_bad_device_type(DeviceType expected) const {
MGB_VERSION_SYMBOL3(MEGDNN, MEGDNN_MAJOR, MEGDNN_MINOR, MEGDNN_PATCH); MGB_VERSION_SYMBOL3(MEGDNN, MEGDNN_MAJOR, MEGDNN_MINOR, MEGDNN_PATCH);
void CompNodeEnv::CpuEnv::enable_dispatch() {
do_task_inplace = nullptr;
}
void CompNodeEnv::CpuEnv::disable_dispatch(bool* flag) {
do_task_inplace = flag;
}
void CompNodeEnv::CpuEnv::dispatch(Task&& task) const {
if (do_task_inplace && *do_task_inplace) {
task();
} else {
dispatcher->dispatch(std::move(task));
}
}
void CompNodeEnv::CpuEnv::dispatch(
MultiThreadingTask&& task, size_t parallelism) const {
if (do_task_inplace && *do_task_inplace) {
for (size_t i = 0; i < parallelism; ++i) {
task(i, 0);
}
} else {
dispatcher->dispatch(std::move(task), parallelism);
}
}
#if MGB_HAVE_THREAD
MGB_THREAD_LOCAL_PTR(bool) CompNodeEnv::CpuEnv::do_task_inplace = nullptr;
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -168,15 +168,35 @@ MGB_DEFINE_OPR_CLASS( ...@@ -168,15 +168,35 @@ MGB_DEFINE_OPR_CLASS(
ComputingGraphImpl::CallbackCaller, SingleCNOperatorNodeBase) // { ComputingGraphImpl::CallbackCaller, SingleCNOperatorNodeBase) // {
std::vector<std::vector<ComputingGraph::Callback>> m_cb; std::vector<std::vector<ComputingGraph::Callback>> m_cb;
//! CallbackCaller supports change memory address in output tensor record mode(only
//! on CPU). The whole callback will be dispatched(like dispatching tensor copy
//! instead of dispatching memcpy).
//! Side effect: sync() is not supported in callback anymore. Users should call
//! func->wait() instead out of callback to sync data from Device to Host.
//! Note : only record level 1 supports change memory address in output tensor.
//! HostTensor captured in callback should not on cpu default.
void scn_do_execute() override { void scn_do_execute() override {
for (size_t i = 0; i < input().size(); ++i) { for (size_t i = 0; i < input().size(); ++i) {
auto&& in = input(i)->dev_tensor(); auto&& in = input(i)->dev_tensor();
for (auto&& callback : m_cb[i]) { for (auto&& callback : m_cb[i]) {
if (this->owner_graph()->options().comp_node_seq_record_level == 1 &&
in.comp_node().device_type() == CompNode::DeviceType::CPU &&
in.comp_node() != CompNode::default_cpu()) {
auto record_cb = [&in, &callback]() {
auto comp_node = in.comp_node();
bool do_task_inplace = true;
comp_node.disable_dispatch(&do_task_inplace);
callback(const_cast<DeviceTensorND&>(in));
comp_node.enable_dispatch();
};
in.comp_node().add_callback(record_cb);
} else {
// const cast for backward API compatibility // const cast for backward API compatibility
callback(const_cast<DeviceTensorND&>(in)); callback(const_cast<DeviceTensorND&>(in));
} }
} }
} }
}
void init_output_static_infer_desc() override { void init_output_static_infer_desc() override {
using namespace cg::static_infer; using namespace cg::static_infer;
......
...@@ -412,6 +412,16 @@ public: ...@@ -412,6 +412,16 @@ public:
return m_impl->add_callback(std::move(cb)); return m_impl->add_callback(std::move(cb));
} }
/*!
* enable dispatcher
*/
void enable_dispatch() { m_impl->enable_dispatch(); }
/*!
* disable dispatcher so that task will be done inplace
*/
void disable_dispatch(bool* flag) { m_impl->disable_dispatch(flag); }
enum class Flag : uint32_t { enum class Flag : uint32_t {
//! Whether computing recorder is supported on this comp node (i.e. //! Whether computing recorder is supported on this comp node (i.e.
//! whether non-zero comp_node_seq_record_level is allowed) //! whether non-zero comp_node_seq_record_level is allowed)
...@@ -532,6 +542,10 @@ protected: ...@@ -532,6 +542,10 @@ protected:
virtual void add_callback(megdnn::thin_function<void()>&&); virtual void add_callback(megdnn::thin_function<void()>&&);
virtual void enable_dispatch();
virtual void disable_dispatch(bool* flag);
virtual uint64_t get_uid() { virtual uint64_t get_uid() {
mgb_throw(MegBrainError, "get_uid is not impl yet"); mgb_throw(MegBrainError, "get_uid is not impl yet");
}; };
......
...@@ -503,12 +503,19 @@ public: ...@@ -503,12 +503,19 @@ public:
using AffinityCallBack = thin_function<void(size_t)>; using AffinityCallBack = thin_function<void(size_t)>;
std::shared_ptr<CPUDispatcher> dispatcher; std::shared_ptr<CPUDispatcher> dispatcher;
#if MGB_HAVE_THREAD
static MGB_THREAD_LOCAL_PTR(bool) do_task_inplace;
#else
bool* do_task_inplace = nullptr;
#endif
void dispatch(Task&& task) const { dispatcher->dispatch(std::move(task)); } void enable_dispatch();
void dispatch(MultiThreadingTask&& task, size_t parallelism) const { void disable_dispatch(bool* flag);
dispatcher->dispatch(std::move(task), parallelism);
} void dispatch(Task&& task) const;
void dispatch(MultiThreadingTask&& task, size_t parallelism) const;
void set_affinity(AffinityCallBack&& cb) const { void set_affinity(AffinityCallBack&& cb) const {
dispatcher->set_affinity(std::move(cb)); dispatcher->set_affinity(std::move(cb));
...@@ -521,6 +528,12 @@ public: ...@@ -521,6 +528,12 @@ public:
return m_cpu_env; return m_cpu_env;
} }
CpuEnv& cpu_env() {
if (mgb_unlikely(m_property.type != DeviceType::CPU))
on_bad_device_type(DeviceType::CPU);
return m_cpu_env;
}
//! init this as a cpu env //! init this as a cpu env
void init_cpu(const CpuEnv& env, CompNode comp_node); void init_cpu(const CpuEnv& env, CompNode comp_node);
......
...@@ -44,7 +44,7 @@ void run_comp_seq_rec_basic(CompNode cn, bool fake_first) { ...@@ -44,7 +44,7 @@ void run_comp_seq_rec_basic(CompNode cn, bool fake_first) {
graph->options().fake_next_exec = true; graph->options().fake_next_exec = true;
graph->options().var_sanity_check_first_run = false; graph->options().var_sanity_check_first_run = false;
} }
auto func = graph->compile({make_callback_copy(z, host_z)}); auto func = graph->compile({make_callback_copy(z, host_z, false)});
if (fake_first) { if (fake_first) {
func->execute(); // first exec func->execute(); // first exec
} }
...@@ -55,6 +55,8 @@ void run_comp_seq_rec_basic(CompNode cn, bool fake_first) { ...@@ -55,6 +55,8 @@ void run_comp_seq_rec_basic(CompNode cn, bool fake_first) {
} }
host_x->copy_from_fixlayout(*gen(host_x->shape(), cn)); host_x->copy_from_fixlayout(*gen(host_x->shape(), cn));
func->execute(); func->execute();
func->wait();
host_z.sync();
auto expect = eval_conv_cpu<opr::Convolution>(*host_x, *host_y, param); auto expect = eval_conv_cpu<opr::Convolution>(*host_x, *host_y, param);
MGB_ASSERT_TENSOR_NEAR(expect, host_z, 1e-3) << "iter " << iter; MGB_ASSERT_TENSOR_NEAR(expect, host_z, 1e-3) << "iter " << iter;
} }
...@@ -70,6 +72,28 @@ void run_comp_seq_rec_basic(CompNode cn, bool fake_first) { ...@@ -70,6 +72,28 @@ void run_comp_seq_rec_basic(CompNode cn, bool fake_first) {
ASSERT_EQ(executed[2], change); ASSERT_EQ(executed[2], change);
// create new recorder, exec with recorder // create new recorder, exec with recorder
ASSERT_EQ(executed[3], change + 1); ASSERT_EQ(executed[3], change + 1);
//! then we change host_z's ptr each time and check result
HostTensorND host_iter;
host_iter.copy_from(host_z);
std::vector<std::shared_ptr<HostTensorND>> m_hosts(10);
for (size_t i = 0; i < 10; i++) {
m_hosts[i] = gen(host_z.shape(), host_z.comp_node());
}
iter = 0;
for (; iter < 10; ++iter) {
auto host_tmp = m_hosts[iter];
auto host_z_storage = host_z.storage();
auto origin_ptr = host_z_storage.raw_storage();
host_z_storage.reset(
host_z.comp_node(), host_z_storage.size(),
host_tmp->storage().raw_storage());
auto changed_ptr = host_z_storage.raw_storage();
ASSERT_TRUE(origin_ptr != changed_ptr);
func->execute();
func->wait();
MGB_ASSERT_TENSOR_NEAR(host_iter, host_z, 1e-3) << "iter " << iter;
}
} }
void run_comp_seq_rec_basic_level2(CompNode cn) { void run_comp_seq_rec_basic_level2(CompNode cn) {
...@@ -154,7 +178,7 @@ void run_comp_seq_rec_dyn_elemwise(CompNode cn, bool fake_first) { ...@@ -154,7 +178,7 @@ void run_comp_seq_rec_dyn_elemwise(CompNode cn, bool fake_first) {
w = opr::Elemwise::make({x, y, z}, opr::Elemwise::Mode::FUSE_MUL_ADD3); w = opr::Elemwise::make({x, y, z}, opr::Elemwise::Mode::FUSE_MUL_ADD3);
HostTensorND host_w; HostTensorND host_w;
auto func = graph->compile({make_callback_copy(w, host_w)}); auto func = graph->compile({make_callback_copy(w, host_w, false)});
if (fake_first) { if (fake_first) {
func->execute(); func->execute();
} }
...@@ -166,9 +190,30 @@ void run_comp_seq_rec_dyn_elemwise(CompNode cn, bool fake_first) { ...@@ -166,9 +190,30 @@ void run_comp_seq_rec_dyn_elemwise(CompNode cn, bool fake_first) {
} }
host_x->copy_from(*gen(host_x->shape(), cn)); host_x->copy_from(*gen(host_x->shape(), cn));
func->execute(); func->execute();
func->wait();
auto expect = check(); auto expect = check();
MGB_ASSERT_TENSOR_EQ(expect, host_w) << "iter " << i; MGB_ASSERT_TENSOR_EQ(expect, host_w) << "iter " << i;
} }
//! then we change host_z's ptr each time and check result
HostTensorND host_iter;
host_iter.copy_from(host_w);
std::vector<std::shared_ptr<HostTensorND>> m_hosts(10);
for (size_t i = 0; i < 10; i++) {
m_hosts[i] = gen(host_w.shape(), host_w.comp_node());
}
for (size_t iter = 0; iter < 10; ++iter) {
auto host_tmp = m_hosts[iter];
auto host_w_storage = host_w.storage();
auto origin_ptr = host_w_storage.raw_storage();
host_w_storage.reset(
host_w.comp_node(), host_w_storage.size(),
host_tmp->storage().raw_storage());
auto changed_ptr = host_w_storage.raw_storage();
ASSERT_TRUE(origin_ptr != changed_ptr);
func->execute();
func->wait();
MGB_ASSERT_TENSOR_EQ(host_iter, host_w) << "iter " << iter;
}
} }
void run_level2(CompNode cn, bool use_multi_holder) { void run_level2(CompNode cn, bool use_multi_holder) {
...@@ -381,6 +426,9 @@ void run<sync_from_func>(CompNode cn) { ...@@ -381,6 +426,9 @@ void run<sync_from_func>(CompNode cn) {
HostTensorND host_y; HostTensorND host_y;
graph->options().var_sanity_check_first_run = false; graph->options().var_sanity_check_first_run = false;
graph->options().comp_node_seq_record_level = level; graph->options().comp_node_seq_record_level = level;
if (level == 1) {
sync = false;
}
auto cb = [&](const DeviceTensorND& dv) { auto cb = [&](const DeviceTensorND& dv) {
host_y.copy_from(dv); host_y.copy_from(dv);
if (sync) { if (sync) {
...@@ -418,6 +466,9 @@ void run<cb_non_contig>(CompNode cn) { ...@@ -418,6 +466,9 @@ void run<cb_non_contig>(CompNode cn) {
HostTensorND host_y; HostTensorND host_y;
graph->options().var_sanity_check_first_run = false; graph->options().var_sanity_check_first_run = false;
graph->options().comp_node_seq_record_level = level; graph->options().comp_node_seq_record_level = level;
if (level == 1) {
sync = false;
}
auto cb = [&](const DeviceTensorND& dv) { auto cb = [&](const DeviceTensorND& dv) {
host_y.copy_from(dv); host_y.copy_from(dv);
if (sync) { if (sync) {
...@@ -428,8 +479,8 @@ void run<cb_non_contig>(CompNode cn) { ...@@ -428,8 +479,8 @@ void run<cb_non_contig>(CompNode cn) {
if (level == 2) { if (level == 2) {
ComputingGraph::assert_destroy(graph); ComputingGraph::assert_destroy(graph);
} }
for (int i = 0; i < 3; ++i) { for (int k = 0; k < 3; ++k) {
host_x->copy_from(*gen(host_x->shape())); host_x->copy_from(*gen(host_x->shape(), cn));
HostTensorND expect{host_x->comp_node(), {5, 4}}; HostTensorND expect{host_x->comp_node(), {5, 4}};
auto px = host_x->ptr<float>(), py = expect.ptr<float>(); auto px = host_x->ptr<float>(), py = expect.ptr<float>();
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
...@@ -504,14 +555,16 @@ void run<multi_recorder_run>(CompNode cn) { ...@@ -504,14 +555,16 @@ void run<multi_recorder_run>(CompNode cn) {
y = opr::Host2DeviceCopy::make(*graph, host_y), y = opr::Host2DeviceCopy::make(*graph, host_y),
z = opr::Convolution::make(x, y, param); z = opr::Convolution::make(x, y, param);
graph->options().comp_node_seq_record_level = 1; graph->options().comp_node_seq_record_level = 1;
return graph->compile({make_callback_copy(z, host_z_v[graph_id])}); return graph->compile({make_callback_copy(z, host_z_v[graph_id], false)});
}; };
funcs.push_back(gen_graph(0)); funcs.push_back(gen_graph(0));
funcs.push_back(gen_graph(1)); funcs.push_back(gen_graph(1));
for (int iter = 0; iter < 10; ++iter) { for (int iter = 0; iter < 10; ++iter) {
host_x->copy_from_fixlayout(*gen(host_x->shape(), cn)); host_x->copy_from_fixlayout(*gen(host_x->shape(), cn));
funcs[0]->execute(); funcs[0]->execute();
funcs[0]->wait();
funcs[1]->execute(); funcs[1]->execute();
funcs[1]->wait();
auto expect = eval_conv_cpu<opr::Convolution>(*host_x, *host_y, param); auto expect = eval_conv_cpu<opr::Convolution>(*host_x, *host_y, param);
MGB_ASSERT_TENSOR_NEAR(expect, host_z_v[0], 1e-3) << "iter " << iter; MGB_ASSERT_TENSOR_NEAR(expect, host_z_v[0], 1e-3) << "iter " << iter;
MGB_ASSERT_TENSOR_NEAR(expect, host_z_v[1], 1e-3) << "iter " << iter; MGB_ASSERT_TENSOR_NEAR(expect, host_z_v[1], 1e-3) << "iter " << iter;
......
...@@ -1375,13 +1375,17 @@ TEST(TestGraph, CompNodeFinalize) { ...@@ -1375,13 +1375,17 @@ TEST(TestGraph, CompNodeFinalize) {
graph->options().var_sanity_check_first_run = false; graph->options().var_sanity_check_first_run = false;
graph->options().comp_node_seq_record_level = rec; graph->options().comp_node_seq_record_level = rec;
} }
auto func = graph->compile({make_callback_copy(z, host_z)}); auto sync = (rec != 1);
auto func = graph->compile({make_callback_copy(z, host_z, sync)});
if (rec == 2) { if (rec == 2) {
ComputingGraph::assert_destroy(graph); ComputingGraph::assert_destroy(graph);
} }
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
host_x->copy_from(*gen({1}, cn)); host_x->copy_from(*gen({1}, cn));
func->execute(); func->execute();
if (!sync) {
func->wait();
}
MGB_ASSERT_FLOAT_EQ( MGB_ASSERT_FLOAT_EQ(
host_x->ptr<float>()[0] + host_y->ptr<float>()[0], host_x->ptr<float>()[0] + host_y->ptr<float>()[0],
host_z.ptr<float>()[0]); host_z.ptr<float>()[0]);
...@@ -1933,6 +1937,7 @@ void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) { ...@@ -1933,6 +1937,7 @@ void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) {
#endif #endif
graph->options().graph_opt.weight_preprocess = true; graph->options().graph_opt.weight_preprocess = true;
graph->options().comp_node_seq_record_level = record_level; graph->options().comp_node_seq_record_level = record_level;
auto sync = (record_level != 1);
auto mkvar = [&](const char* name, const TensorShape& shp) { auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
}; };
...@@ -1970,11 +1975,17 @@ void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) { ...@@ -1970,11 +1975,17 @@ void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) {
}); });
HostTensorND host_y; HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)}); auto func = graph->compile({make_callback_copy(y, host_y, sync)});
//! flag the no need memory of var //! flag the no need memory of var
func->execute(); func->execute();
if (!sync) {
func->wait();
}
//! free the no need memory of var //! free the no need memory of var
func->execute(); func->execute();
if (!sync) {
func->wait();
}
auto check = [&](SymbolVar v) { auto check = [&](SymbolVar v) {
ASSERT_TRUE(v.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED)); ASSERT_TRUE(v.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED));
ASSERT_TRUE(v.node()->dev_tensor().empty()); ASSERT_TRUE(v.node()->dev_tensor().empty());
......
...@@ -213,9 +213,13 @@ TEST(TestGraph, MultiThreadRecorder) { ...@@ -213,9 +213,13 @@ TEST(TestGraph, MultiThreadRecorder) {
z = opr::Convolution::make(x, y, param); z = opr::Convolution::make(x, y, param);
graph->options().comp_node_seq_record_level = record_level; graph->options().comp_node_seq_record_level = record_level;
graph->options().var_sanity_check_first_run = false; graph->options().var_sanity_check_first_run = false;
auto func = graph->compile({make_callback_copy(z, host_z)}); auto sync = (record_level != 1);
auto func = graph->compile({make_callback_copy(z, host_z, sync)});
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
func->execute(); func->execute();
if (!sync) {
func->wait();
}
} }
auto expect = eval_conv_cpu<opr::Convolution>(*host_x, *host_y, param); auto expect = eval_conv_cpu<opr::Convolution>(*host_x, *host_y, param);
MGB_ASSERT_TENSOR_NEAR(expect, host_z, 1e-3); MGB_ASSERT_TENSOR_NEAR(expect, host_z, 1e-3);
......
...@@ -50,6 +50,7 @@ void run_test(CompNode cn, const PluginMaker& plugin_maker) { ...@@ -50,6 +50,7 @@ void run_test(CompNode cn, const PluginMaker& plugin_maker) {
graph->options().var_sanity_check_first_run = false; graph->options().var_sanity_check_first_run = false;
graph->options().comp_node_seq_record_level = record; graph->options().comp_node_seq_record_level = record;
graph->options().graph_opt_level = 0; graph->options().graph_opt_level = 0;
auto sync = (record != 1);
auto plug = plugin_maker(graph.get(), record); auto plug = plugin_maker(graph.get(), record);
// make a non-contiguous value, also introduce some shape dependencies // make a non-contiguous value, also introduce some shape dependencies
...@@ -76,11 +77,14 @@ void run_test(CompNode cn, const PluginMaker& plugin_maker) { ...@@ -76,11 +77,14 @@ void run_test(CompNode cn, const PluginMaker& plugin_maker) {
cg::DepOprIter{cb_rename}.add(y); cg::DepOprIter{cb_rename}.add(y);
HostTensorND host_y; HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)}); auto func = graph->compile({make_callback_copy(y, host_y, sync)});
if (record == 2) { if (record == 2) {
ComputingGraph::assert_destroy(graph); ComputingGraph::assert_destroy(graph);
} }
func->execute(); func->execute();
if (!sync) {
func->wait();
}
plug->flush_lazy(); plug->flush_lazy();
MGB_ASSERT_TENSOR_EQ(make_expect(), host_y); MGB_ASSERT_TENSOR_EQ(make_expect(), host_y);
...@@ -91,10 +95,16 @@ void run_test(CompNode cn, const PluginMaker& plugin_maker) { ...@@ -91,10 +95,16 @@ void run_test(CompNode cn, const PluginMaker& plugin_maker) {
*host_x = *gen(host_x->shape(), cn); *host_x = *gen(host_x->shape(), cn);
} }
func->execute(); func->execute();
if (!sync) {
func->wait();
}
MGB_ASSERT_TENSOR_EQ(make_expect(), host_y); MGB_ASSERT_TENSOR_EQ(make_expect(), host_y);
for (int i = 0; i < 2; ++i) { for (int i = 0; i < 2; ++i) {
host_x->copy_from(*gen(host_x->shape(), cn)); host_x->copy_from(*gen(host_x->shape(), cn));
func->execute(); func->execute();
if (!sync) {
func->wait();
}
MGB_ASSERT_TENSOR_EQ(make_expect(), host_y); MGB_ASSERT_TENSOR_EQ(make_expect(), host_y);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册