Commit 0be6ca88 authored by Megvii Engine Team

fix(src/core): fix record change ptr bug on comp node copy

GitOrigin-RevId: 0f689662113123e00862698269a0ea7aa42af825
Parent: 84baf3df
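The hunks below replace raw-pointer copy entry points with `megdnn::RefPtr`-based ones. The likely failure mode, as the commit title and the record-mode tests suggest, is this: when a CPU comp node is recording, a dispatched copy lambda that captures a raw pointer bakes that address into the recorded task, so replaying the record after the tensor storage has been reset (e.g. via `only_reset_raw_storage`) reads or writes through a stale pointer. A handle with one level of indirection is resolved at execution time instead. Below is a minimal standalone sketch of that idea — not MegEngine code; `RecorderQueue` and `RefPtrLike` are hypothetical stand-ins for the comp-node recorder and `megdnn::RefPtr`:

```cpp
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <vector>

// A recorder captures dispatched tasks once and replays them later.
struct RecorderQueue {
    std::vector<std::function<void()>> tasks;
    void dispatch(std::function<void()> t) { tasks.push_back(std::move(t)); }
    void replay() {
        for (auto& t : tasks)
            t();
    }
};

// RefPtr-like indirection: a shared slot whose pointee can be swapped
// between recording and replay.
struct RefPtrLike {
    std::shared_ptr<void*> slot;
    explicit RefPtrLike(void* p) : slot(std::make_shared<void*>(p)) {}
    void* get_ptr() const { return *slot; }
    void reset(void* p) { *slot = p; }
};

int main() {
    RecorderQueue q;
    char buf_a[4] = "AAA", buf_b[4] = "BBB", dst[4] = "";

    // BUG pattern: capture the raw pointer; replay always reads buf_a,
    // even after the tensor storage has been swapped to buf_b.
    char* raw = buf_a;
    q.dispatch([raw, &dst]() { std::memcpy(dst, raw, 4); });

    // FIX pattern: capture the indirection; replay reads the current pointer.
    RefPtrLike ref(buf_a);
    q.dispatch([ref, &dst]() { std::memcpy(dst, ref.get_ptr(), 4); });

    ref.reset(buf_b);  // storage swapped between record and replay
    q.replay();
    std::printf("%s\n", dst);  // prints "BBB": the RefPtr copy saw the new buffer
}
```

The `*_ref` methods added in this commit play the role of the second dispatch above: they capture the `RefPtr` handle by value and call `get_ptr()` only inside the dispatched task.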
@@ -306,11 +306,37 @@ public:
         m_env.cpu_env().dispatch(do_copy);
     }

+    void copy_to_host_ref(
+            megdnn::RefPtr& host_ref_ptr, megdnn::RefPtr& device_ref_ptr,
+            size_t size) override {
+        // use lambda capture to avoid memory allocation in std::bind
+        auto do_copy = [host_ref_ptr, device_ref_ptr, size]() {
+            std::memcpy(host_ref_ptr.get_ptr(), device_ref_ptr.get_ptr(), size);
+        };
+        m_env.cpu_env().dispatch(do_copy);
+    }
+
+    void copy_to_device_ref(
+            megdnn::RefPtr& device_ref_ptr, megdnn::RefPtr& host_ref_ptr,
+            size_t size) override {
+        // use lambda capture to avoid memory allocation in std::bind
+        auto do_copy = [device_ref_ptr, host_ref_ptr, size]() {
+            std::memcpy(device_ref_ptr.get_ptr(), host_ref_ptr.get_ptr(), size);
+        };
+        m_env.cpu_env().dispatch(do_copy);
+    }
+
     void peer_copy_to(
             Impl* dest_impl, void* dest, const void* src, size_t size) override {
         dest_impl->copy_to_device(dest, src, size);
     }

+    void peer_copy_to_ref(
+            Impl* dest_impl, megdnn::RefPtr& dest, megdnn::RefPtr& src,
+            size_t size) override {
+        dest_impl->copy_to_device_ref(dest, src, size);
+    }
+
     size_t get_mem_addr_alignment() override { return m_env.property().mem_alignment; }
     void dispatch(Task&& task) override { m_env.cpu_env().dispatch(std::move(task)); }
@@ -733,6 +759,24 @@ public:
         CompNodeBaseImpl::copy_to_device(device_ptr, host_ptr, size);
     }

+    void copy_to_host_ref(
+            megdnn::RefPtr& host_ref_ptr, megdnn::RefPtr& device_ref_ptr,
+            size_t size) override {
+        if (m_worker_queue) {
+            m_worker_queue->check_exception();
+        }
+        CompNodeBaseImpl::copy_to_host_ref(host_ref_ptr, device_ref_ptr, size);
+    }
+
+    void copy_to_device_ref(
+            megdnn::RefPtr& device_ref_ptr, megdnn::RefPtr& host_ref_ptr,
+            size_t size) override {
+        if (m_worker_queue) {
+            m_worker_queue->check_exception();
+        }
+        CompNodeBaseImpl::copy_to_device_ref(device_ref_ptr, host_ref_ptr, size);
+    }
+
     void peer_copy_to(
             Impl* dest_impl, void* dest, const void* src, size_t size) override {
         //! copy to default_cpu
@@ -774,6 +818,48 @@ public:
         dest_impl->copy_to_device(dest, src, size);
     }

+    void peer_copy_to_ref(
+            Impl* dest_impl, megdnn::RefPtr& dest, megdnn::RefPtr& src,
+            size_t size) override {
+        //! copy to default_cpu
+        if (dest_impl->same_type<CpuCompNode::CompNodeDefaultImpl>()) {
+            CompNodeBaseImpl::peer_copy_to_ref(dest_impl, dest, src, size);
+            return;
+        }
+        if (!dest_impl->same_type<CpuCompNode::CompNodeRecorderImpl>()) {
+            if (dest_impl->env().property().type == DeviceType::ATLAS) {
+#if MGB_ATLAS
+                dest_impl->copy_to_device(dest.get_ptr(), src.get_ptr(), size);
+                return;
+#else
+                mgb_throw(
+                        MegBrainError,
+                        "Atlas comp_node used but "
+                        "ATLAS BUILD not enabled");
+#endif
+            } else if (dest_impl->env().property().type == DeviceType::CAMBRICON) {
+#if MGB_CAMBRICON
+                dest_impl->copy_to_device(dest.get_ptr(), src.get_ptr(), size);
+                return;
+#else
+                mgb_throw(
+                        MegBrainError,
+                        "Cambricon comp_node used but "
+                        "CAMBRICON BUILD not enabled");
+#endif
+            } else {
+                mgb_assert(
+                        locator().device == Locator::DEVICE_CPU_DEFAULT,
+                        "currently only peer copy from default cpu comp "
+                        "nodes is implemented");
+            }
+        }
+        dest_impl->copy_to_device_ref(dest, src, size);
+    }
+
     std::unique_ptr<Event> create_event(size_t flags) override {
         if (m_worker_queue) {
             m_worker_queue->check_exception();
......
@@ -81,9 +81,8 @@ const DeviceTensorStorage& StaticDeviceMemoryManager::alloc(
 void StaticDeviceMemoryManager::prefault() {
     for (auto&& i : m_storage) {
         if (i.first.device_type() == CompNode::DeviceType::CPU) {
-            auto set = [ptr = i.second.ptr(), size = i.second.size()]() {
-                memset(ptr, 0, size);
-            };
+            auto storage = i.second;
+            auto set = [storage]() { memset(storage.ptr(), 0, storage.size()); };
             CompNodeEnv::from_comp_node(i.first).cpu_env().dispatch(set);
             i.first.sync();
         }
......
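The `prefault` change above follows the same record-safety rule: capture the refcounted storage handle rather than `ptr()`/`size()` evaluated at capture time, so the dispatched task keeps the buffer alive and resolves the pointer when it actually runs. A hedged sketch of the difference, with `Storage` as a hypothetical stand-in for `DeviceTensorStorage`:

```cpp
#include <cstddef>
#include <cstring>
#include <memory>

// Hypothetical stand-in for a refcounted storage handle.
struct Storage {
    std::shared_ptr<char[]> buf;
    size_t len = 0;
    char* ptr() const { return buf.get(); }
    size_t size() const { return len; }
};

// Before: freezes the address and does not extend the buffer's lifetime; if
// the storage is released before the dispatched task runs, the memset writes
// through a dangling pointer.
auto make_task_unsafe(const Storage& s) {
    return [ptr = s.ptr(), size = s.size()]() { std::memset(ptr, 0, size); };
}

// After: the lambda owns a copy of the handle, so the buffer stays alive and
// ptr()/size() are read only when the task executes.
auto make_task_safe(Storage s) {
    return [s]() { std::memset(s.ptr(), 0, s.size()); };
}
```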
@@ -379,7 +379,9 @@ MGE_WIN_DECLSPEC_FUC void TensorStorage<HostTensorStorageTrait>::copy_from(
             need_sync = true;
         }
     }
-    src.comp_node().copy_to_host(ptr(), src.ptr(), size);
+    megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
+    megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
+    src.comp_node().copy_to_host_ref(dst_ptr, src_ptr, size);
     if (need_sync)
         src.comp_node().sync();
 }
@@ -390,7 +392,9 @@ template <>
 MGE_WIN_DECLSPEC_FUC void TensorStorage<DeviceTensorStorageTrait>::copy_from(
         const TensorStorage<HostTensorStorageTrait>& src, size_t size) const {
     mgb_assert(size <= this->size() && size <= src.size());
-    m_comp_node.copy_to_device(ptr(), src.ptr(), size);
+    megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
+    megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
+    m_comp_node.copy_to_device_ref(dst_ptr, src_ptr, size);
 }

 // device to device
@@ -417,9 +421,13 @@ MGE_WIN_DECLSPEC_FUC void TensorStorage<DeviceTensorStorageTrait>::copy_from(
         // to pin the memory of src tensor, so it does not require synchronization
         // and is more efficient
         src.comp_node().sync();
-        comp_node().copy_to_device(ptr(), src.ptr(), size);
+        megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
+        megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
+        comp_node().copy_to_device_ref(dst_ptr, src_ptr, size);
     } else {
-        src.comp_node().peer_copy_to(m_comp_node, ptr(), src.ptr(), size);
+        megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
+        megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
+        src.comp_node().peer_copy_to_ref(m_comp_node, dst_ptr, src_ptr, size);
     }
 }
@@ -712,32 +720,34 @@ const typename TensorND<TensorStorage>::ChainReturnType& TensorND<
 void mgb::dev_tensor_memset(const DeviceTensorND& tensor, int val) {
     auto&& env = CompNodeEnv::from_comp_node(tensor.comp_node());
     env.activate();
-    void* ptr = tensor.raw_ptr();
     size_t size = tensor.layout().span().dist_byte();
     switch (env.property().type) {
 #if MGB_CUDA
         case CompNode::DeviceType::CUDA:
-            MGB_CUDA_CHECK(cudaMemsetAsync(ptr, val, size, env.cuda_env().stream));
+            MGB_CUDA_CHECK(cudaMemsetAsync(
+                    tensor.raw_ptr(), val, size, env.cuda_env().stream));
             break;
 #endif
 #if MGB_ATLAS
         case CompNode::DeviceType::ATLAS:
 #if MGB_USE_ATLAS_ASYNC_API
-            MGB_ATLAS_CHECK(
-                    aclrtMemsetAsync(ptr, -1, val, size, env.atlas_env().stream));
+            MGB_ATLAS_CHECK(aclrtMemsetAsync(
+                    tensor.raw_ptr(), -1, val, size, env.atlas_env().stream));
 #else
-            MGB_ATLAS_CHECK(aclrtMemset(ptr, -1, val, size));
+            MGB_ATLAS_CHECK(aclrtMemset(tensor.raw_ptr(), -1, val, size));
 #endif
             break;
 #endif
 #if MGB_CAMBRICON
         case CompNode::DeviceType::CAMBRICON:
             MGB_CNRT_CHECK(cnrtSyncQueue(env.cnrt_env().queue));
-            MGB_CNRT_CHECK(cnrtMemset(ptr, val, size));
+            MGB_CNRT_CHECK(cnrtMemset(tensor.raw_ptr(), val, size));
             break;
 #endif
         case CompNode::DeviceType::CPU: {
-            auto fill = [ptr, size, val]() { std::memset(ptr, val, size); };
+            auto fill = [tensor, size, val]() {
+                std::memset(tensor.as_megdnn().raw_ptr(), val, size);
+            };
             env.cpu_env().dispatch(fill);
         } break;
         default:
......
@@ -242,6 +242,20 @@ public:
         return m_impl->copy_to_device(device_ptr, host_ptr, size);
     }

+    //! copy from underlying device to host
+    void copy_to_host_ref(
+            megdnn::RefPtr& host_ref_ptr, megdnn::RefPtr& device_ref_ptr,
+            size_t size) const {
+        return m_impl->copy_to_host_ref(host_ref_ptr, device_ref_ptr, size);
+    }
+
+    //! copy from host to underlying device
+    void copy_to_device_ref(
+            megdnn::RefPtr& device_ref_ptr, megdnn::RefPtr& host_ref_ptr,
+            size_t size) const {
+        return m_impl->copy_to_device_ref(device_ref_ptr, host_ref_ptr, size);
+    }
+
     /*!
      * \brief copy from this device to another device; would use the
      * computing resource on dest_node
@@ -253,6 +267,14 @@ public:
                 reinterpret_cast<Impl*>(dest_node.m_impl), dest, src, size);
     }

+    void peer_copy_to_ref(
+            CompNode dest_node, megdnn::RefPtr& dst_ref_ptr,
+            megdnn::RefPtr& src_ref_ptr, size_t size) const {
+        return m_impl->peer_copy_to_ref(
+                reinterpret_cast<Impl*>(dest_node.m_impl), dst_ref_ptr, src_ref_ptr,
+                size);
+    }
+
     //! get alignment requirement in bytes; guaranteed to be power of 2
     size_t get_mem_addr_alignment() const { return m_impl->get_mem_addr_alignment(); }
@@ -517,9 +539,25 @@ protected:
             void* host_ptr, const void* device_ptr, size_t size) = 0;
     virtual void copy_to_device(
             void* device_ptr, const void* host_ptr, size_t size) = 0;
+    virtual void copy_to_host_ref(
+            megdnn::RefPtr& host_ref_ptr, megdnn::RefPtr& device_ref_ptr,
+            size_t size) {
+        copy_to_host(host_ref_ptr.get_ptr(), device_ref_ptr.get_ptr(), size);
+    }
+    virtual void copy_to_device_ref(
+            megdnn::RefPtr& device_ref_ptr, megdnn::RefPtr& host_ref_ptr,
+            size_t size) {
+        copy_to_device(device_ref_ptr.get_ptr(), host_ref_ptr.get_ptr(), size);
+    }
     virtual void peer_copy_to(
             Impl* dest_impl, void* dest, const void* src, size_t size) = 0;
+    virtual void peer_copy_to_ref(
+            Impl* dest_impl, megdnn::RefPtr& dest, megdnn::RefPtr& src,
+            size_t size) {
+        peer_copy_to(dest_impl, dest.get_ptr(), src.get_ptr(), size);
+    }
     virtual size_t get_mem_addr_alignment() = 0;
     virtual size_t get_mem_padding();
......
@@ -100,6 +100,10 @@ SymbolVar Network::add_type_cvt(SymbolVar f, DType out_dtype) {
     return opr::TypeCvt::make(f, out_dtype);
 }

+SymbolVar Network::add_concat(SymbolVar f, SymbolVar g, int axis) {
+    return opr::Concat::make({f, g}, axis);
+}
+
 SymbolVar mgb::create_block(
         Network& network, SymbolVar f_in, size_t stride, size_t num_outputs1,
         bool has_proj, DType out_dtype) {
......
@@ -60,6 +60,7 @@ public:
             Padding padding = {0, 0},
             opr::Pooling::Param::Mode mode = opr::Pooling::Param::Mode::MAX);
     SymbolVar add_type_cvt(SymbolVar f, DType out_dtype = dtype::Float32());
+    SymbolVar add_concat(SymbolVar f, SymbolVar g, int axis = 0);
 };

 SymbolVar create_block(
......
@@ -41,7 +41,8 @@ struct TestGraph {
         f = m_network->add_elemwise(
                 {f}, dtype::Float32(), opr::Elemwise::Param::Mode::EXP);
         f = m_network->add_conv(f, 8, {3, 3}, dtype::Float32(), true, {1, 1}, {1, 1});
-        m_out_var = m_network->add_pooling(f, {2, 2}, {2, 2});
+        f = m_network->add_pooling(f, {2, 2}, {2, 2});
+        m_out_var = m_network->add_concat(f, -f);
     }

     void create_graph_with_subtensor_forward() {
@@ -63,7 +64,8 @@ struct TestGraph {
         f = m_network->add_elemwise(
                 {f}, dtype::Float32(), opr::Elemwise::Param::Mode::EXP);
         f = m_network->add_conv(f, 8, {3, 3}, dtype::Float32(), true, {1, 1}, {1, 1});
-        m_out_var = m_network->add_pooling(f, {2, 2}, {2, 2});
+        f = m_network->add_pooling(f, {2, 2}, {2, 2});
+        m_out_var = m_network->add_concat(f, -f);
     }

     void create_graph_with_subtensor_relayout() {
@@ -86,7 +88,8 @@ struct TestGraph {
         f = m_network->add_elemwise(
                 {f}, dtype::Float32(), opr::Elemwise::Param::Mode::EXP);
         f = m_network->add_conv(f, 8, {3, 3}, dtype::Float32(), true, {1, 1}, {1, 1});
-        m_out_var = m_network->add_pooling(f, {2, 2}, {2, 2});
+        f = m_network->add_pooling(f, {2, 2}, {2, 2});
+        m_out_var = m_network->add_concat(f, -f);
     }

     void create_graph_with_setsubtensor() {
@@ -113,7 +116,8 @@ struct TestGraph {
         f = m_network->add_elemwise(
                 {f}, dtype::Float32(), opr::Elemwise::Param::Mode::EXP);
         f = m_network->add_conv(f, 8, {3, 3}, dtype::Float32(), true, {1, 1}, {1, 1});
-        m_out_var = m_network->add_pooling(f, {2, 2}, {2, 2});
+        f = m_network->add_pooling(f, {2, 2}, {2, 2});
+        m_out_var = m_network->add_concat(f, -f);
     }

     std::unique_ptr<cg::AsyncExecutable> compile_without_copy() {
@@ -173,8 +177,8 @@ TEST(TestNoCopy, IONoCopyPtrEQ) {
     test_graph.create_graph();
     auto func = test_graph.compile_without_copy();
     auto&& outvar = func->get_output_vars()[0];
-    DeviceTensorND dv0(test_graph.m_cn, {1, 8, 7, 7});
-    DeviceTensorND dv1(test_graph.m_cn, {1, 8, 7, 7});
+    DeviceTensorND dv0(test_graph.m_cn, {2, 8, 7, 7});
+    DeviceTensorND dv1(test_graph.m_cn, {2, 8, 7, 7});
     size_t times = 10;
     for (size_t i = 0; i < times; i++) {
         auto input_tensor = test_graph.input_tensor;
@@ -229,7 +233,7 @@ TEST(TestNoCopy, IONoCopyCorrect) {
             ptr[d] = i / 5 + 3;
         }
         input_tensor->reset(storage, layout);
-        DeviceTensorND dv(test_graph.m_cn, {1, 8, 7, 7});
+        DeviceTensorND dv(test_graph.m_cn, {2, 8, 7, 7});
         outvar->init_mem_plan(&dv);
         outvar->reset_dev_tensor_from_tensor(dv);
@@ -258,7 +262,7 @@ TEST(TestNoCopy, IONoCopyRecord) {
     HostTensorND truth;
     auto func = test_graph.compile_without_copy();
     auto&& outvar = func->get_output_vars()[0];
-    DeviceTensorND tmp(test_graph.m_cn, {1, 8, 7, 7});
+    DeviceTensorND tmp(test_graph.m_cn, {2, 8, 7, 7});
     outvar->init_mem_plan(&tmp);
     size_t times = 10;
     for (size_t i = 0; i < times; i++) {
@@ -272,7 +276,7 @@ TEST(TestNoCopy, IONoCopyRecord) {
             ptr[d] = i / 5 + 3;
         }
         input_tensor->only_reset_raw_storage(storage);
-        DeviceTensorND dv(test_graph.m_cn, {1, 8, 7, 7});
+        DeviceTensorND dv(test_graph.m_cn, {2, 8, 7, 7});
         dv.raw_ptr();
         auto& dev_tensor = outvar->mutable_dev_tensor();
@@ -306,7 +310,7 @@ void test_subtensor_record(int level) {
     HostTensorND truth;
     auto func = test_graph.compile_without_copy();
     auto&& outvar = func->get_output_vars()[0];
-    DeviceTensorND tmp(test_graph.m_cn, {1, 8, 7, 7});
+    DeviceTensorND tmp(test_graph.m_cn, {2, 8, 7, 7});
     outvar->init_mem_plan(&tmp);
     size_t times = 10;
     for (size_t i = 0; i < times; i++) {
@@ -320,7 +324,7 @@ void test_subtensor_record(int level) {
             ptr[d] = i / 5 + 3;
         }
         input_tensor->only_reset_raw_storage(storage);
-        DeviceTensorND dv(test_graph.m_cn, {1, 8, 7, 7});
+        DeviceTensorND dv(test_graph.m_cn, {2, 8, 7, 7});
         dv.raw_ptr();
         auto& dev_tensor = outvar->mutable_dev_tensor();
......
@@ -139,11 +139,9 @@ void NMSKeep::CPUKern::exec(
     // See CUDAKern::exec for more explanation on output comp nodes.
     CompNode comp_node = out_idx.comp_node();

-    auto inp_ptr = inp.ptr<float>();
-    auto out_idx_ptr = reinterpret_cast<uint32_t*>(out_idx.ptr<int32_t>()),
-         out_size_ptr = reinterpret_cast<uint32_t*>(out_size.ptr<int32_t>());
     size_t batch = inp.shape(0), nr_boxes = inp.shape(1);
     if (nr_boxes == 0) {
+        auto out_size_ptr = reinterpret_cast<uint32_t*>(out_size.ptr<int32_t>());
         for (size_t i = 0; i < batch; ++i) {
             *(out_size_ptr + i) = 0;
         }
@@ -157,6 +155,11 @@ void NMSKeep::CPUKern::exec(
     // be dispatched on a different thread
     auto kern = [=]() {
         for (size_t i = 0; i < batch; ++i) {
+            auto inp_ptr = inp.as_megdnn().ptr<float>();
+            auto out_idx_ptr =
+                    reinterpret_cast<uint32_t*>(out_idx.as_megdnn().ptr<int32_t>());
+            auto out_size_ptr =
+                    reinterpret_cast<uint32_t*>(out_size.as_megdnn().ptr<int32_t>());
             nms::cpu_kern(
                     nr_boxes, param.max_output, param.iou_thresh,
                     inp_ptr + i * nr_boxes * 4, out_idx_ptr + i * param.max_output,
......