Commit ec234135 authored by: Megvii Engine Team

feat(lite): support discrete inputs

GitOrigin-RevId: 25ce8da275d17d50d771986d36744e7866a0094f
Parent dc0ab9b6
@@ -554,9 +554,7 @@ void WarpPerspectiveForwardImpl::exec(
         cuda_check(cudaMemcpyAsync(
                 bundle.get(i), workspace_cpu.get(0), workspace_cpu.get_size(0),
                 cudaMemcpyHostToDevice, stream));
-        cuda_check(cudaStreamAddCallback(
-                stream, callback_free, static_cast<void*>(workspace_cpu_raw),
-                0));
+        free(workspace_cpu_raw);
         warp_perspective::forward_proxy_multi_src(
                 is_nhwc, srcs_gpu, mat.ptr<dt_float32>(),
                 mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
@@ -579,9 +577,7 @@ void WarpPerspectiveForwardImpl::exec(
         cuda_check(cudaMemcpyAsync(
                 bundle.get(0), workspace_cpu.get(0), workspace_cpu.get_size(0),
                 cudaMemcpyHostToDevice, stream));
-        cuda_check(cudaStreamAddCallback(
-                stream, callback_free, static_cast<void*>(workspace_cpu_raw),
-                0));
+        free(workspace_cpu_raw);
         warp_perspective::forward_proxy_multi_src(
                 is_nhwc, srcs_gpu, mat.ptr<dt_float32>(),
                 mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
...
@@ -299,7 +299,7 @@ public:
      * @param io_name the name of the tensor
      * @param phase indicate the tensor is input tensor
      */
-    std::vector<std::shared_ptr<Tensor>> get_io_tensors(
+    std::vector<std::shared_ptr<Tensor>> get_discrete_tensors(
             std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT);

     //! get the network input tensor by index
...
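For callers, the renamed `get_discrete_tensors` returns one tensor per batch slice of the configured discrete input. A minimal usage sketch (assuming the model and input names from the tests later in this diff, `./test_discrete_input.mge` and `"data"`; error handling omitted):

```cpp
#include "lite/network.h"  // assumed public header for lite::Network

using namespace lite;

// Feed input "data" as three per-sample tensors instead of one batched tensor.
void run_discrete(std::shared_ptr<Tensor> d0, std::shared_ptr<Tensor> d1,
                  std::shared_ptr<Tensor> d2) {
    Config config;
    config.discrete_input_name = "data";  // mark "data" as a discrete input
    auto network = std::make_shared<Network>(config);
    network->load_model("./test_discrete_input.mge");

    // One tensor per batch slice of "data".
    std::vector<std::shared_ptr<Tensor>> slices =
            network->get_discrete_tensors("data");
    slices[0]->share_memory_with(*d0);
    slices[1]->share_memory_with(*d1);
    slices[2]->share_memory_with(*d2);

    network->forward();
    network->wait();
}
```

Note that after this change, `get_io_tensor("data")` on a discrete input throws (see the `NetworkImplDft::get_io_tensor` hunk below), so the per-slice accessor is the only way to reach such an input.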
@@ -311,7 +311,7 @@ LITE_API int LITE_get_io_tensor(
  * \param[in] phase The tensor phase
  * \param[out] tensor The IO tensor get from the network
  */
-LITE_API int LITE_get_io_tensors(
+LITE_API int LITE_get_discrete_tensor(
         LiteNetwork network, const char* io_name, size_t n_idx, LiteTensorPhase phase,
         LiteTensor* tensor);
...
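The C symbol is renamed to match and stays a per-index accessor. A sketch of the call pattern, mirroring the capi test at the end of this diff (assuming `c_network` already holds a loaded network and `datas[i]` are three equally sized host buffers):

```cpp
// Bind user-owned memory to each discrete slice of input "data".
std::vector<LiteTensor> c_data_tensors(3, nullptr);
for (size_t i = 0; i < 3; i++) {
    LITE_CAPI_CHECK(LITE_get_discrete_tensor(
            c_network, "data", i, LITE_INPUT, &c_data_tensors[i]));
    LITE_CAPI_CHECK(LITE_reset_tensor_memory(
            c_data_tensors[i], datas[i]->get_memory_ptr(), data_length_in_byte));
}
```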
@@ -278,13 +278,13 @@ int LITE_get_io_tensor(
     LITE_CAPI_END();
 }

-int LITE_get_io_tensors(
+int LITE_get_discrete_tensor(
         LiteNetwork network, const char* io_name, size_t n_idx, LiteTensorPhase phase,
         LiteTensor* tensor) {
     LITE_CAPI_BEGIN();
     LITE_ASSERT(network, "The network pass to LITE api is null");
     auto io_tensors =
-            static_cast<lite::Network*>(network)->get_io_tensors(io_name, phase);
+            static_cast<lite::Network*>(network)->get_discrete_tensors(io_name, phase);
     LITE_ASSERT(
             n_idx < io_tensors.size(), "n_idx should be less than %zu",
             io_tensors.size());
...
@@ -542,7 +542,7 @@ class _NetworkAPI(_LiteCObjBase):
         ),
         ("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
         (
-            "LITE_get_io_tensors",
+            "LITE_get_discrete_tensor",
             [_Cnetwork, c_char_p, c_size_t, c_int, POINTER(_Ctensor)],
         ),
     ]
@@ -745,7 +745,7 @@ class LiteNetwork(object):
         tensor.update()
         return tensor

-    def get_io_tensors(self, name, n_idx, phase=LiteTensorPhase.LITE_INPUT):
+    def get_discrete_tensor(self, name, n_idx, phase=LiteTensorPhase.LITE_INPUT):
         """
         get the n_idx'th tensor in the network input tensors whose
         input consists of discrete multiple tensors and tensor name is name
@@ -763,7 +763,7 @@ class LiteNetwork(object):
         else:
             c_name = c_char_p(name)
         tensor = LiteTensor(physic_construct=False)
-        self._api.LITE_get_io_tensors(
+        self._api.LITE_get_discrete_tensor(
             self._network, c_name, n_idx, phase, byref(tensor._tensor)
         )
         tensor.update()
...
@@ -504,28 +504,59 @@ class TestNetwork(TestShuffleNet):
 class TestDiscreteInputNet(unittest.TestCase):
     source_dir = os.getenv("LITE_TEST_RESOURCE")
+    data_path = os.path.join(source_dir, "data_b3.npy")
     data0_path = os.path.join(source_dir, "data0.npy")
     data1_path = os.path.join(source_dir, "data1.npy")
     data2_path = os.path.join(source_dir, "data2.npy")
+    roi_path = os.path.join(source_dir, "roi.npy")
     model_path = os.path.join(source_dir, "test_discrete_input.mge")
+    data = np.load(data_path)
     data0 = np.load(data0_path)
     data1 = np.load(data1_path)
     data2 = np.load(data2_path)
+    roi = np.load(roi_path)

-    def do_forward(self, network, times=3):
+    def check_correct(self, out_data, error=1e-4):
+        out_data = out_data.flatten()
+        config = LiteConfig()
+        net = LiteNetwork(config)
+        net.load(self.model_path)
+        input_tensor = net.get_io_tensor("data")
+        input_tensor.set_data_by_share(self.data)
+        roi_tensor = net.get_io_tensor("roi")
+        roi_tensor.set_data_by_share(self.roi)
+        output_name = net.get_output_name(0)
+        output_tensor = net.get_io_tensor(output_name)
+        net.forward()
+        net.wait()
+        correct_data = output_tensor.to_numpy().flatten()
+        assert correct_data.size == out_data.size
+        for i in range(out_data.size):
+            assert abs(out_data[i] - correct_data[i]) < error
+
+    def do_forward(self, network, times=1):
         data_name = network.get_input_name(1)
         datas = []
-        datas.append(network.get_io_tensors(data_name, 0))
-        datas.append(network.get_io_tensors(data_name, 1))
-        datas.append(network.get_io_tensors(data_name, 2))
-        datas[0].set_data_by_copy(self.data0)
-        datas[1].set_data_by_copy(self.data1)
-        datas[2].set_data_by_copy(self.data2)
+        datas.append(network.get_discrete_tensor(data_name, 0))
+        datas.append(network.get_discrete_tensor(data_name, 1))
+        datas.append(network.get_discrete_tensor(data_name, 2))
+        datas[0].set_data_by_share(self.data0)
+        datas[1].set_data_by_share(self.data1)
+        datas[2].set_data_by_share(self.data2)
+        roi_tensor = network.get_io_tensor("roi")
+        roi_tensor.set_data_by_share(self.roi)
+        out_name = network.get_output_name(0)
+        out_tensor = network.get_io_tensor(out_name)
         for i in range(times):
             network.forward()
             network.wait()
+        out_data = out_tensor.to_numpy()
+        self.check_correct(out_data)

 class TestDiscreteInput(TestDiscreteInputNet):
     def test_discrete_input(self):
...
@@ -268,57 +268,69 @@ void NetworkImplDft::replace_src_discrete_input_opr_pass() {
     gopt::SubGraph graph{dest_with_extra_deps};
     auto rewriter = graph.make_rewriter();
-    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
-        if (opr->same_type<mgb::opr::WarpPerspective>()) {
-            bool is_h2d = true;
-            if (opr->input(0)->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>())
-                is_h2d = true;
-            else if (opr->input(0)
-                             ->owner_opr()
-                             ->same_type<mgb::opr::VolatileSharedDeviceTensor>())
-                is_h2d = false;
-            else
-                return;
-            SymbolVarArray srcs;
-            if (is_h2d) {
-                auto h2d = opr->input(0)->owner_opr();
-                for (auto&& inp : get_io_tensors(m_user_config->discrete_input_name)) {
-                    auto val = TensorHelper::implement(inp)
-                                       ->cast_final_safe<TensorImplDft>()
-                                       .m_host_tensor;
-                    LITE_ASSERT(val);
-                    srcs.push_back(mgb::opr::Host2DeviceCopy::make(
-                            *m_load_result.graph, val, h2d->config()));
-                }
-            } else {
-                auto volatiled = opr->input(0)->owner_opr();
-                for (auto&& inp : get_io_tensors(m_user_config->discrete_input_name)) {
-                    auto val = TensorHelper::implement(inp)
-                                       ->cast_final_safe<TensorImplDft>()
-                                       .m_dev_tensor;
-                    LITE_ASSERT(val);
-                    srcs.push_back(mgb::opr::VolatileSharedDeviceTensor::make(
-                            *m_load_result.graph, val, volatiled->config()));
-                }
-            }
-            auto& warp = opr->cast_final<mgb::opr::WarpPerspective>();
-            SymbolVar new_out;
-            if (opr->input().size() == 3) {
-                new_out = mgb::opr::WarpPerspective::make(
-                        srcs, warp.input(1), warp.input(2), warp.param(),
-                        warp.config());
-            } else {
-                LITE_ASSERT(opr->input().size() == 4);
-                new_out = mgb::opr::WarpPerspective::make(
-                        srcs, warp.input(1), warp.input(2), warp.input(3), warp.param(),
-                        warp.config());
-            }
-            rewriter.replace_var(
-                    warp.output(0), new_out.node(),
-                    "replace WarpPerspective to WarpPerspective multi src version.");
-        } else {
+    auto on_opr = [&](cg::OperatorNodeBase* opr) {
+        bool replace_output = false;
+        for (auto inp : opr->input()) {
+            if ((inp->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>() ||
+                 inp->owner_opr()->same_type<mgb::opr::VolatileSharedDeviceTensor>()) &&
+                inp->name() == m_user_config->discrete_input_name) {
+                bool is_h2d = true;
+                if (inp->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>()) {
+                    is_h2d = true;
+                } else {
+                    is_h2d = false;
+                }
+                SymbolVarArray srcs;
+                if (is_h2d) {
+                    auto h2d = inp->owner_opr();
+                    for (auto&& i :
+                         get_discrete_tensors(m_user_config->discrete_input_name)) {
+                        auto val = TensorHelper::implement(i)
+                                           ->cast_final_safe<TensorImplDft>()
+                                           .m_host_tensor;
+                        LITE_ASSERT(val);
+                        srcs.push_back(mgb::opr::Host2DeviceCopy::make(
+                                *m_load_result.graph, val, h2d->config()));
+                    }
+                } else {
+                    auto volatiled = inp->owner_opr();
+                    for (auto&& i :
+                         get_discrete_tensors(m_user_config->discrete_input_name)) {
+                        auto val = TensorHelper::implement(i)
+                                           ->cast_final_safe<TensorImplDft>()
+                                           .m_dev_tensor;
+                        LITE_ASSERT(val);
+                        srcs.push_back(mgb::opr::VolatileSharedDeviceTensor::make(
+                                *m_load_result.graph, val, volatiled->config()));
+                    }
+                }
+                if (opr->same_type<mgb::opr::WarpPerspective>()) {
+                    auto& warp = opr->cast_final<mgb::opr::WarpPerspective>();
+                    SymbolVar new_out;
+                    if (opr->input().size() == 3) {
+                        new_out = mgb::opr::WarpPerspective::make(
+                                srcs, warp.input(1), warp.input(2), warp.param(),
+                                warp.config());
+                    } else {
+                        LITE_ASSERT(opr->input().size() == 4);
+                        new_out = mgb::opr::WarpPerspective::make(
+                                srcs, warp.input(1), warp.input(2), warp.input(3),
+                                warp.param(), warp.config());
+                    }
+                    rewriter.replace_var(
+                            warp.output(0), new_out.node(),
+                            "replace WarpPerspective to WarpPerspective multi src "
+                            "version.");
+                    replace_output = true;
+                } else {
+                    auto concat = mgb::opr::Concat::make(srcs, 0);
+                    rewriter.replace_var(inp, concat.node(), "add a concat opr.");
+                }
+            }
+        }
+        if (!replace_output) {
             rewriter.auto_replace_outputs(opr);
         }
     };
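The reworked pass no longer requires the discrete input to feed a WarpPerspective directly: any operator reading that input gets its sources rebuilt, with WarpPerspective switched to the multi-source kernel and every other consumer fed through a Concat on the batch axis, so the rest of the graph is unchanged. A stand-alone sketch of just that dispatch (simplified stand-in types for illustration, not MegEngine's real graph API):

```cpp
#include <string>

// What the rewriter does with an operator input that matches the
// configured discrete input name.
enum class Rewrite {
    None,                     // input is not the discrete one
    MultiSrcWarpPerspective,  // dedicated multi-source kernel
    ConcatOnBatch,            // generic fallback: re-join slices on axis 0
};

Rewrite plan_rewrite(const std::string& opr_type, const std::string& inp_name,
                     const std::string& discrete_input_name) {
    if (inp_name != discrete_input_name)
        return Rewrite::None;
    if (opr_type == "WarpPerspective")
        return Rewrite::MultiSrcWarpPerspective;
    return Rewrite::ConcatOnBatch;
}
```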
@@ -385,6 +397,10 @@ void NetworkImplDft::replace_dev_input_pass() {
             inp_var_map[host_val2var.at(host_val.get())] = dev_var;
             name2dev_tensor[config_in.name] = dev_val;
         }
+        //! reset lite_tensor in discrete mode
+        if (config_in.name == m_user_config->discrete_input_name) {
+            config_in.lite_tensor.reset();
+        }
     }
     auto new_ovar = mgb::cg::replace_vars(m_load_result.output_var_list, inp_var_map);
     for (size_t i = 0; i < new_ovar.size(); ++i) {
@@ -611,8 +627,9 @@ void NetworkImplDft::configure_after_loaded() {
 void NetworkImplDft::compile_graph() {
     replace_dev_input_pass();
-    if (!m_user_config->discrete_input_name.empty())
+    if (!m_user_config->discrete_input_name.empty()) {
         replace_src_discrete_input_opr_pass();
+    }
     make_output_spec();
     m_execute_func = m_load_result.graph_compile(m_output_spec);
 }
@@ -792,6 +809,7 @@ void NetworkImplDft::update_input() {
     }
 }

+//! initialize lite_tensors when the input is composed of discrete multiple tensors
 void NetworkImplDft::update_input_lite_tensors() {
     auto device_type = m_user_config->device_type;
     auto device_id = m_compnode_locator.device;
@@ -801,24 +819,22 @@ void NetworkImplDft::update_input_lite_tensors() {
         if (in_tensor_iter.first != m_user_config->discrete_input_name) {
             continue;
         }
-        bool found = false;
         for (auto&& config_in : m_network_io->inputs) {
             if (in_tensor_iter.first == config_in.name) {
-                found = true;
                 size_t bs = in_tensor_iter.second->shape(0);
                 auto shape = in_tensor_iter.second->shape();
-                shape.shape[0] = 1;
                 if (config_in.config_layout.ndim) {
                     bs = config_in.config_layout.shapes[0];
-                    shape.shape[1] = config_in.config_layout.shapes[1];
-                    shape.shape[2] = config_in.config_layout.shapes[2];
-                    shape.shape[3] = config_in.config_layout.shapes[3];
+                    for (size_t i = 0; i < config_in.config_layout.ndim; ++i) {
+                        shape.shape[i] = config_in.config_layout.shapes[i];
+                    }
                 }
-                HostTensorND tensor(
-                        in_tensor_iter.second->comp_node(), shape,
-                        in_tensor_iter.second->dtype(),
-                        in_tensor_iter.second->format());
+                shape.shape[0] = 1;
                 for (size_t i = 0; i < bs; ++i) {
+                    HostTensorND tensor(
+                            in_tensor_iter.second->comp_node(), shape,
+                            in_tensor_iter.second->dtype(),
+                            in_tensor_iter.second->format());
                     if (config_in.is_host) {
                         config_in.lite_tensors.push_back(std::make_shared<Tensor>(
                                 device_id, stream_id, device_type, true));
@@ -839,29 +855,6 @@ void NetworkImplDft::update_input_lite_tensors() {
                 }
             }
         }
-        if (!found) {
-            size_t bs = in_tensor_iter.second->shape(0);
-            auto shape = in_tensor_iter.second->shape();
-            shape.shape[0] = 1;
-            HostTensorND tensor(
-                    in_tensor_iter.second->comp_node(), shape,
-                    in_tensor_iter.second->dtype(), in_tensor_iter.second->format());
-            IOInner io_in;
-            io_in.name = in_tensor_iter.first;
-            for (size_t i = 0; i < bs; ++i) {
-                io_in.lite_tensors.push_back(std::make_shared<Tensor>(
-                        device_id, stream_id, device_type, true));
-                TensorHelper::implement(io_in.lite_tensors[i])
-                        ->cast_final_safe<TensorImplDft>()
-                        .m_host_tensor = std::make_shared<HostTensorND>(tensor);
-                TensorHelper::implement(io_in.lite_tensors[i])
-                        ->cast_final_safe<TensorImplDft>()
-                        .m_record_reset =
-                        m_user_config->options.comp_node_seq_record_level > 0;
-                io_in.lite_tensors[i]->update_from_implement();
-            }
-            m_network_io->inputs.push_back(io_in);
-        }
     }
 }
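The shape handling above is the behavioral core of this hunk: the batch size comes from the model input shape, or from the user-configured layout, which now overrides every dimension rather than only dims 1 through 3, and each discrete tensor then gets batch dimension 1. A small self-contained sketch of that computation (plain structs for illustration, not the real `TensorShape`/`Layout` types):

```cpp
#include <cstddef>
#include <vector>

struct Shape {
    size_t dims[4];
    size_t ndim;
};

// Split one (bs, c, h, w) input into bs per-sample shapes of (1, c, h, w);
// a non-empty user-configured layout takes precedence over the model shape.
std::vector<Shape> discrete_shapes(Shape model_shape, const Shape* configured) {
    size_t bs = model_shape.dims[0];
    Shape s = model_shape;
    if (configured && configured->ndim) {
        bs = configured->dims[0];
        for (size_t i = 0; i < configured->ndim; ++i)
            s.dims[i] = configured->dims[i];
    }
    s.dims[0] = 1;  // each discrete tensor holds exactly one sample
    return std::vector<Shape>(bs, s);
}
```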
@@ -997,7 +990,15 @@ std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
     if (phase == LiteTensorPhase::LITE_INPUT || phase == LiteTensorPhase::LITE_IO) {
         for (auto&& config_in : m_network_io->inputs) {
             if (io_name == config_in.name) {
-                return config_in.lite_tensor;
+                if (config_in.lite_tensor) {
+                    return config_in.lite_tensor;
+                } else {
+                    LITE_THROW(mgb::ssprintf(
+                            "%s input tensor is in discrete mode, you can use "
+                            "get_discrete_tensors to get this input.",
+                            io_name.c_str()));
+                    return nullptr;
+                }
             }
         }
     }
@@ -1018,7 +1019,7 @@ std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
     return nullptr;
 }

-std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_io_tensors(
+std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_discrete_tensors(
         std::string io_name, LiteTensorPhase phase) {
     if (phase == LiteTensorPhase::LITE_INPUT) {
         for (auto&& config_in : m_network_io->inputs) {
@@ -1038,7 +1039,7 @@ std::shared_ptr<Tensor> NetworkImplDft::get_input_tensor(size_t index) {
 }

 std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_input_tensors(size_t index) {
-    return get_io_tensors(get_input_name(index));
+    return get_discrete_tensors(get_input_name(index));
 }

 std::shared_ptr<Tensor> NetworkImplDft::get_output_tensor(size_t index) {
...
@@ -59,7 +59,7 @@ public:
     //! get the network input tensors which input consists of discrete multiple tensors,
     //! layout (1, c, h, w)
-    std::vector<std::shared_ptr<Tensor>> get_io_tensors(
+    std::vector<std::shared_ptr<Tensor>> get_discrete_tensors(
             std::string io_name,
             LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) override;
...
@@ -127,12 +127,12 @@ std::shared_ptr<Tensor> Network::get_io_tensor(
     LITE_ERROR_HANDLER_END
 }

-std::vector<std::shared_ptr<Tensor>> Network::get_io_tensors(
+std::vector<std::shared_ptr<Tensor>> Network::get_discrete_tensors(
         std::string name, LiteTensorPhase phase) {
     LITE_ERROR_HANDLER_BEGIN
     LITE_ASSERT(m_loaded, "get_io_tensor should be used after model loaded.");
     LITE_CHECK_NON_NULL_POINTER(m_impl);
-    return m_impl->get_io_tensors(name, phase);
+    return m_impl->get_discrete_tensors(name, phase);
     LITE_ERROR_HANDLER_END
 }
...
@@ -91,8 +91,10 @@ public:
     //! get the network input tensors which input consists of discrete multiple tensors,
     //! layout (1, c, h, w)
-    virtual std::vector<std::shared_ptr<Tensor>> get_io_tensors(
+    virtual std::vector<std::shared_ptr<Tensor>> get_discrete_tensors(
             std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) {
+        LITE_MARK_USED_VAR(io_name);
+        LITE_MARK_USED_VAR(phase);
         return {};
     }
@@ -102,6 +104,7 @@ public:
     //! get the network input tensors which input consists of discrete multiple tensors
     //! by index
     virtual std::vector<std::shared_ptr<Tensor>> get_input_tensors(size_t index) {
+        LITE_MARK_USED_VAR(index);
         return {};
     }
...
@@ -1393,6 +1393,7 @@ TEST(TestNetWork, Discrete_Input) {
     auto data_0 = get_input_data("./data0.npy");
     auto data_1 = get_input_data("./data1.npy");
     auto data_2 = get_input_data("./data2.npy");
+    auto roi = get_input_data("./roi.npy");
     std::string model_path = "./test_discrete_input.mge";

     Config config;
@@ -1403,6 +1404,8 @@ TEST(TestNetWork, Discrete_Input) {
     std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data");
     data_tensor->share_memory_with(*data);
+    std::shared_ptr<Tensor> roi_tensor = network0->get_io_tensor("roi");
+    roi_tensor->share_memory_with(*roi);

     network0->forward();
     network0->wait();
@@ -1417,8 +1420,11 @@ TEST(TestNetWork, Discrete_Input) {
     std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios);
     network1->load_model(model_path);

+    std::shared_ptr<Tensor> roi_tensor1 = network1->get_io_tensor("roi");
+    roi_tensor1->copy_from(*roi);
     std::vector<std::shared_ptr<Tensor>> data_tensors =
-            network1->get_io_tensors("data");
+            network1->get_discrete_tensors("data");
     data_tensors[0]->share_memory_with(*data_0);
     data_tensors[1]->share_memory_with(*data_1);
     data_tensors[2]->share_memory_with(*data_2);
@@ -1435,6 +1441,7 @@ TEST(TestNetWork, Discrete_Input_Device) {
     auto data_0 = get_input_data("./data0.npy");
     auto data_1 = get_input_data("./data1.npy");
     auto data_2 = get_input_data("./data2.npy");
+    auto roi = get_input_data("./roi.npy");
     std::string model_path = "./test_discrete_input.mge";

     Config config;
@@ -1444,7 +1451,9 @@ TEST(TestNetWork, Discrete_Input_Device) {
     network0->load_model(model_path);

     std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data");
-    data_tensor->share_memory_with(*data);
+    data_tensor->copy_from(*data);
+    std::shared_ptr<Tensor> roi_tensor = network0->get_io_tensor("roi");
+    roi_tensor->copy_from(*roi);

     network0->forward();
     network0->wait();
@@ -1459,8 +1468,10 @@ TEST(TestNetWork, Discrete_Input_Device) {
     std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios);
     network1->load_model(model_path);

+    std::shared_ptr<Tensor> roi_tensor1 = network1->get_io_tensor("roi");
+    roi_tensor1->copy_from(*roi);
     std::vector<std::shared_ptr<Tensor>> data_tensors =
-            network1->get_io_tensors("data");
+            network1->get_discrete_tensors("data");

     auto d0_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly);
     auto d1_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly);
     auto d2_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly);
@@ -1477,6 +1488,48 @@ TEST(TestNetWork, Discrete_Input_Device) {
     compare_lite_tensor<float>(output_tensor0, output_tensor1);
 }

+TEST(TestNetWork, Discrete_Input_Concat) {
+    auto data = get_input_data("./data_b3.npy");
+    auto data_0 = get_input_data("./data0.npy");
+    auto data_1 = get_input_data("./data1.npy");
+    auto data_2 = get_input_data("./data2.npy");
+    std::string model_path = "./test_discrete_input_concat.mge";
+
+    Config config;
+    config.device_type = LiteDeviceType::LITE_CUDA;
+    std::shared_ptr<Network> network0 = std::make_shared<Network>(config);
+    network0->load_model(model_path);
+
+    std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data");
+    data_tensor->copy_from(*data);
+
+    network0->forward();
+    network0->wait();
+
+    std::shared_ptr<Tensor> output_tensor0 = network0->get_output_tensor(0);
+
+    config.discrete_input_name = "data";
+    NetworkIO ios;
+    bool is_host = true;
+    Layout d_ly{{3, 3, 224, 224}, 4, LiteDataType::LITE_FLOAT};
+    ios.inputs.push_back({"data", is_host, LiteIOType::LITE_IO_VALUE, d_ly});
+
+    std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios);
+    network1->load_model(model_path);
+
+    std::vector<std::shared_ptr<Tensor>> data_tensors =
+            network1->get_discrete_tensors("data");
+    data_tensors[0]->copy_from(*data_0);
+    data_tensors[1]->copy_from(*data_1);
+    data_tensors[2]->copy_from(*data_2);
+
+    network1->forward();
+    network1->wait();
+
+    std::shared_ptr<Tensor> output_tensor1 = network1->get_output_tensor(0);
+    compare_lite_tensor<float>(output_tensor0, output_tensor1);
+}
+
 #endif

 #if MGB_ATLAS || MGB_CAMBRICON
...
@@ -322,7 +322,7 @@ TEST(TestCapiNetWork, Discrete_Input) {
     std::vector<LiteTensor> c_data_tensors(3, nullptr);
     for (size_t i = 0; i < 3; i++) {
-        LITE_CAPI_CHECK(LITE_get_io_tensors(
+        LITE_CAPI_CHECK(LITE_get_discrete_tensor(
                 c_network, "data", i, LITE_INPUT, &c_data_tensors[i]));
         LITE_CAPI_CHECK(LITE_reset_tensor_memory(
                 c_data_tensors[i], datas[i]->get_memory_ptr(), data_length_in_byte));
...