Commit dc0ab9b6 authored by Megvii Engine Team

feat(lite): replace warp when src is discrete input

GitOrigin-RevId: 2bf7980ac6373b691081ab7be9975ec6fa57f8ae
Parent: 58b682ca
@@ -117,6 +117,9 @@ struct LITE_API Options {
*
* @param auto_optimize_inference lite will detect the device information and
* set the options heuristically
*
* @param discrete_input_name specify which input is composed of multiple
* discrete tensors
*/
struct LITE_API Config {
bool has_compression = false;
@@ -126,6 +129,7 @@ struct LITE_API Config {
std::string bare_model_cryption_name = {};
Options options = {};
bool auto_optimize_inference = false;
std::string discrete_input_name = {};
};
/*!
@@ -289,9 +293,22 @@ public:
std::shared_ptr<Tensor> get_io_tensor(
std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_IO);
/** @brief get the discrete tensors that compose the network input with the given
* name, each with layout (1, c, h, w)
*
* @param io_name the name of the input tensor
* @param phase indicates that the tensor is an input tensor
*/
std::vector<std::shared_ptr<Tensor>> get_io_tensors(
std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT);
//! get the network input tensor by index
std::shared_ptr<Tensor> get_input_tensor(size_t index);
//! get the discrete tensors that compose the network input, by input index
std::vector<std::shared_ptr<Tensor>> get_input_tensors(size_t index);
//! get the network output tensor by index
std::shared_ptr<Tensor> get_output_tensor(size_t index);
...
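For orientation before the per-file hunks, here is a minimal C++ sketch of how the new discrete-input API fits together, condensed from the TestNetWork.Discrete_Input test further down; the header path, model path and shapes are illustrative assumptions, not part of this commit.

#include <memory>
#include <vector>
#include "lite/network.h"  // assumed public header exposing lite::Network/Config/NetworkIO

void run_discrete_input_example() {
    using namespace lite;
    Config config;
    config.device_type = LiteDeviceType::LITE_CUDA;  // the new tests below run on CUDA
    config.discrete_input_name = "data";             // this input is fed as discrete tensors

    // Declare the logical batched layout (n, c, h, w) of the discrete input.
    NetworkIO ios;
    Layout layout{{3, 3, 224, 224}, 4, LiteDataType::LITE_FLOAT};
    ios.inputs.push_back({"data", /*is_host=*/true, LiteIOType::LITE_IO_VALUE, layout});

    auto network = std::make_shared<Network>(config, ios);
    network->load_model("./test_discrete_input.mge");  // illustrative model path

    // One (1, c, h, w) tensor per batch element; fill each slice separately,
    // e.g. via share_memory_with() or a host copy, as in the tests below.
    std::vector<std::shared_ptr<Tensor>> slices = network->get_io_tensors("data");
    (void)slices;

    network->forward();
    network->wait();
}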
@@ -103,6 +103,9 @@ extern LITE_API const LiteOptions default_option;
*\param auto_optimize_inference lite will detect the device information and
* set the options heuristically
*
* \param discrete_input_name specify which input is composed of multiple
* discrete tensors
*/
typedef struct LiteConfig {
int has_compression;
@@ -112,6 +115,7 @@ typedef struct LiteConfig {
const char* bare_model_cryption_name;
LiteOptions options;
int auto_optimize_inference;
const char* discrete_input_name;
} LiteConfig;
//! get default config
@@ -298,6 +302,19 @@ LITE_API int LITE_get_io_tensor(
LiteNetwork network, const char* io_name, LiteTensorPhase phase,
LiteTensor* tensor);
/**
* \brief get the n_idx'th tensor of the network input named io_name, which
* consists of multiple discrete tensors, each with layout (1, c, h, w)
* \param[in] network The loaded model
* \param[in] io_name The input name
* \param[in] n_idx The index of the tensor
* \param[in] phase The tensor phase
* \param[out] tensor The IO tensor obtained from the network
*/
LITE_API int LITE_get_io_tensors(
LiteNetwork network, const char* io_name, size_t n_idx, LiteTensorPhase phase,
LiteTensor* tensor);
/**
* \brief get the input tensor name in the order in loaded model
* \param[in] network The loaded model
...
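A corresponding C-API sketch, mirroring TestCapiNetWork.Discrete_Input further down. The include path and the LITE_forward/LITE_wait calls are assumptions about the existing C API; the explicit input layout registration done in the test via LiteIO/LiteNetworkIO is omitted here for brevity.

#include <stddef.h>
#include "lite-c/network_c.h"  // assumed header declaring the LITE_* C API

// Returns non-zero on failure; real code should also query the error message.
static int run_discrete_input_c(void* slice_data[3], size_t slice_bytes) {
    LiteConfig config = *default_config();
    config.discrete_input_name = "data";  // input "data" is fed as discrete tensors

    LiteNetwork network = NULL;
    if (LITE_make_network(&network, config, *default_network_io())) return -1;
    if (LITE_load_model_from_path(network, "./test_discrete_input.mge")) return -1;

    // Fetch the i'th (1, c, h, w) slice of input "data" and attach user memory.
    for (size_t i = 0; i < 3; ++i) {
        LiteTensor slice = NULL;
        if (LITE_get_io_tensors(network, "data", i, LITE_INPUT, &slice)) return -1;
        if (LITE_reset_tensor_memory(slice, slice_data[i], slice_bytes)) return -1;
    }

    if (LITE_forward(network)) return -1;  // assumed existing C API call
    if (LITE_wait(network)) return -1;     // assumed existing C API call
    return LITE_destroy_network(network);
}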
@@ -43,7 +43,8 @@ LiteConfig default_config_t = {
.backend = LiteBackend::LITE_DEFAULT,
.bare_model_cryption_name = nullptr,
.options = default_option,
.auto_optimize_inference = false,
.discrete_input_name = nullptr};
LiteConfig* default_config() {
return &default_config_t;
}
@@ -135,6 +136,9 @@ lite::Config convert_to_lite_config(const LiteConfig c_config) {
lite_config.options.enable_nchw64 = c_config.options.enable_nchw64;
lite_config.auto_optimize_inference = c_config.auto_optimize_inference;
if (c_config.discrete_input_name) {
lite_config.discrete_input_name = c_config.discrete_input_name;
}
return lite_config;
}
@@ -274,6 +278,20 @@ int LITE_get_io_tensor(
LITE_CAPI_END();
}
int LITE_get_io_tensors(
LiteNetwork network, const char* io_name, size_t n_idx, LiteTensorPhase phase,
LiteTensor* tensor) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
auto io_tensors =
static_cast<lite::Network*>(network)->get_io_tensors(io_name, phase);
LITE_ASSERT(
n_idx < io_tensors.size(), "n_idx should be less than %zu",
io_tensors.size());
*tensor = io_tensors[n_idx].get();
LITE_CAPI_END();
}
int LITE_get_input_name(const LiteNetwork network, size_t index, const char** name) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network && name, "The network pass to LITE api is null");
...
@@ -173,6 +173,8 @@ class LiteConfig(Structure):
auto_optimize_inference: lite will detect the device information and set the options heuristically
discrete_input_name: specify which input is composed of multiple discrete tensors
Examples:
.. code-block::
@@ -193,6 +195,7 @@ class LiteConfig(Structure):
("_bare_model_cryption_name", c_char_p),
("options", LiteOptions),
("auto_optimize_inference", c_int),
("discrete_input_name", c_char_p),
]
def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
@@ -207,6 +210,7 @@ class LiteConfig(Structure):
self.has_compression = 0
self.backend = LiteBackend.LITE_DEFAULT
self.auto_optimize_inference = 0
self.discrete_input_name = c_char_p(b"")
@property
def bare_model_cryption_name(self):
@@ -229,6 +233,7 @@ class LiteConfig(Structure):
"bare_model_cryption_name": self.bare_model_cryption_name,
"options": self.options,
"auto_optimize_inference": self.auto_optimize_inference,
"discrete_input_name": self.discrete_input_name,
}
return data.__repr__()
@@ -536,6 +541,10 @@ class _NetworkAPI(_LiteCObjBase):
[c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
),
("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
(
"LITE_get_io_tensors",
[_Cnetwork, c_char_p, c_size_t, c_int, POINTER(_Ctensor)],
),
]
@@ -736,6 +745,30 @@ class LiteNetwork(object):
tensor.update()
return tensor
def get_io_tensors(self, name, n_idx, phase=LiteTensorPhase.LITE_INPUT):
"""
get the n_idx'th tensor of the network input whose name is ``name`` and
which consists of multiple discrete tensors
Args:
name: the name of the input tensor
n_idx: the index of the tensor
phase: the phase of the LiteTensor, this is useful to separate input tensors with the same name
Returns:
the tensor with the given name, index and phase
"""
if type(name) == str:
c_name = c_char_p(name.encode("utf-8"))
else:
c_name = c_char_p(name)
tensor = LiteTensor(physic_construct=False)
self._api.LITE_get_io_tensors(
self._network, c_name, n_idx, phase, byref(tensor._tensor)
)
tensor.update()
return tensor
def get_input_name(self, index):
"""
get the input name by the index in the network
...
@@ -500,3 +500,45 @@ class TestNetwork(TestShuffleNet):
os.remove(fast_run_cache)
os.remove(global_layout_transform_model)
class TestDiscreteInputNet(unittest.TestCase):
source_dir = os.getenv("LITE_TEST_RESOURCE")
data0_path = os.path.join(source_dir, "data0.npy")
data1_path = os.path.join(source_dir, "data1.npy")
data2_path = os.path.join(source_dir, "data2.npy")
model_path = os.path.join(source_dir, "test_discrete_input.mge")
data0 = np.load(data0_path)
data1 = np.load(data1_path)
data2 = np.load(data2_path)
def do_forward(self, network, times=3):
data_name = network.get_input_name(1)
datas = []
datas.append(network.get_io_tensors(data_name, 0))
datas.append(network.get_io_tensors(data_name, 1))
datas.append(network.get_io_tensors(data_name, 2))
datas[0].set_data_by_copy(self.data0)
datas[1].set_data_by_copy(self.data1)
datas[2].set_data_by_copy(self.data2)
for i in range(times):
network.forward()
network.wait()
class TestDiscreteInput(TestDiscreteInputNet):
def test_discrete_input(self):
config = LiteConfig()
config.discrete_input_name = "data".encode("utf-8")
input_io = LiteIO(
"data",
is_host=True,
io_type=LiteIOType.LITE_IO_VALUE,
layout=LiteLayout([3, 3, 224, 224]),
)
ios = LiteNetworkIO()
ios.add_input(input_io)
network = LiteNetwork(config, ios)
network.load(self.model_path)
self.do_forward(network)
@@ -13,6 +13,7 @@
#include "megbrain/comp_node_env.h"
#include "megbrain/graph.h"
#include "megbrain/graph/cg.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h" #include "megbrain/tensor.h"
...@@ -259,6 +260,88 @@ void NetworkImplDft::make_output_spec() { ...@@ -259,6 +260,88 @@ void NetworkImplDft::make_output_spec() {
} }
} }
void NetworkImplDft::replace_src_discrete_input_opr_pass() {
mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
auto dest_with_extra_deps =
get_dest_vars_with_extra_deps(m_load_result.output_var_list);
gopt::SubGraph graph{dest_with_extra_deps};
auto rewriter = graph.make_rewriter();
auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
if (opr->same_type<mgb::opr::WarpPerspective>()) {
bool is_h2d = true;
if (opr->input(0)->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>())
is_h2d = true;
else if (opr->input(0)
->owner_opr()
->same_type<mgb::opr::VolatileSharedDeviceTensor>())
is_h2d = false;
else
return;
SymbolVarArray srcs;
if (is_h2d) {
auto h2d = opr->input(0)->owner_opr();
for (auto&& inp : get_io_tensors(m_user_config->discrete_input_name)) {
auto val = TensorHelper::implement(inp)
->cast_final_safe<TensorImplDft>()
.m_host_tensor;
LITE_ASSERT(val);
srcs.push_back(mgb::opr::Host2DeviceCopy::make(
*m_load_result.graph, val, h2d->config()));
}
} else {
auto volatiled = opr->input(0)->owner_opr();
for (auto&& inp : get_io_tensors(m_user_config->discrete_input_name)) {
auto val = TensorHelper::implement(inp)
->cast_final_safe<TensorImplDft>()
.m_dev_tensor;
LITE_ASSERT(val);
srcs.push_back(mgb::opr::VolatileSharedDeviceTensor::make(
*m_load_result.graph, val, volatiled->config()));
}
}
auto& warp = opr->cast_final<mgb::opr::WarpPerspective>();
SymbolVar new_out;
if (opr->input().size() == 3) {
new_out = mgb::opr::WarpPerspective::make(
srcs, warp.input(1), warp.input(2), warp.param(),
warp.config());
} else {
LITE_ASSERT(opr->input().size() == 4);
new_out = mgb::opr::WarpPerspective::make(
srcs, warp.input(1), warp.input(2), warp.input(3), warp.param(),
warp.config());
}
rewriter.replace_var(
warp.output(0), new_out.node(),
"replace WarpPerspective to WarpPerspective multi src version.");
} else {
rewriter.auto_replace_outputs(opr);
}
};
graph.iter(on_opr);
rewriter.apply_inplace();
auto new_ovar = graph.endpoint_vars();
new_ovar.resize(m_load_result.output_var_list.size());
for (size_t i = 0; i < new_ovar.size(); ++i) {
out_var_map[m_load_result.output_var_list[i]] = new_ovar[i];
}
for (auto&& i : m_load_result.output_var_map) {
i.second = out_var_map.at(i.second);
}
for (auto&& i : m_load_result.output_var_map_id) {
i.second = out_var_map.at(i.second);
}
for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
new_ovar[i].rename(m_load_result.output_var_list[i].node()->name());
}
m_load_result.output_var_list = std::move(new_ovar);
}
void NetworkImplDft::replace_dev_input_pass() {
mgb::CompNode::Locator locator;
m_load_config.comp_node_mapper(locator);
@@ -528,6 +611,8 @@ void NetworkImplDft::configure_after_loaded() {
void NetworkImplDft::compile_graph() {
replace_dev_input_pass();
if (!m_user_config->discrete_input_name.empty())
replace_src_discrete_input_opr_pass();
make_output_spec();
m_execute_func = m_load_result.graph_compile(m_output_spec);
}
@@ -691,6 +776,11 @@ void NetworkImplDft::update_input() {
m_network_io->inputs.push_back(io_in);
}
}
if (!m_user_config->discrete_input_name.empty()) {
update_input_lite_tensors();
}
//! delete the IO that is not the network
for (auto it = m_network_io->inputs.begin(); it != m_network_io->inputs.end();) {
if (it->lite_tensor == nullptr) {
@@ -702,6 +792,79 @@ void NetworkImplDft::update_input() {
}
}
void NetworkImplDft::update_input_lite_tensors() {
auto device_type = m_user_config->device_type;
auto device_id = m_compnode_locator.device;
auto stream_id = m_compnode_locator.stream;
for (auto&& in_tensor_iter : m_load_result.tensor_map) {
if (in_tensor_iter.first != m_user_config->discrete_input_name) {
continue;
}
bool found = false;
for (auto&& config_in : m_network_io->inputs) {
if (in_tensor_iter.first == config_in.name) {
found = true;
size_t bs = in_tensor_iter.second->shape(0);
auto shape = in_tensor_iter.second->shape();
shape.shape[0] = 1;
if (config_in.config_layout.ndim) {
bs = config_in.config_layout.shapes[0];
shape.shape[1] = config_in.config_layout.shapes[1];
shape.shape[2] = config_in.config_layout.shapes[2];
shape.shape[3] = config_in.config_layout.shapes[3];
}
HostTensorND tensor(
in_tensor_iter.second->comp_node(), shape,
in_tensor_iter.second->dtype(),
in_tensor_iter.second->format());
for (size_t i = 0; i < bs; ++i) {
if (config_in.is_host) {
config_in.lite_tensors.push_back(std::make_shared<Tensor>(
device_id, stream_id, device_type, true));
TensorHelper::implement(config_in.lite_tensors[i])
->cast_final_safe<TensorImplDft>()
.m_host_tensor = std::make_shared<HostTensorND>(tensor);
config_in.lite_tensors[i]->update_from_implement();
} else {
config_in.lite_tensors.push_back(std::make_shared<Tensor>(
device_id, stream_id, device_type));
config_in.lite_tensors[i]->set_layout(
to_lite_layout(tensor.layout()));
}
TensorHelper::implement(config_in.lite_tensors[i])
->cast_final_safe<TensorImplDft>()
.m_record_reset =
m_user_config->options.comp_node_seq_record_level > 0;
}
}
}
if (!found) {
size_t bs = in_tensor_iter.second->shape(0);
auto shape = in_tensor_iter.second->shape();
shape.shape[0] = 1;
HostTensorND tensor(
in_tensor_iter.second->comp_node(), shape,
in_tensor_iter.second->dtype(), in_tensor_iter.second->format());
IOInner io_in;
io_in.name = in_tensor_iter.first;
for (size_t i = 0; i < bs; ++i) {
io_in.lite_tensors.push_back(std::make_shared<Tensor>(
device_id, stream_id, device_type, true));
TensorHelper::implement(io_in.lite_tensors[i])
->cast_final_safe<TensorImplDft>()
.m_host_tensor = std::make_shared<HostTensorND>(tensor);
TensorHelper::implement(io_in.lite_tensors[i])
->cast_final_safe<TensorImplDft>()
.m_record_reset =
m_user_config->options.comp_node_seq_record_level > 0;
io_in.lite_tensors[i]->update_from_implement();
}
m_network_io->inputs.push_back(io_in);
}
}
}
void NetworkImplDft::update_output() {
auto device_type = m_user_config->device_type;
auto device_id = m_compnode_locator.device;
@@ -855,10 +1018,29 @@ std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
return nullptr;
}
std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_io_tensors(
std::string io_name, LiteTensorPhase phase) {
if (phase == LiteTensorPhase::LITE_INPUT) {
for (auto&& config_in : m_network_io->inputs) {
if (io_name == config_in.name &&
config_in.name == m_user_config->discrete_input_name) {
return config_in.lite_tensors;
}
}
}
LITE_THROW(mgb::ssprintf(
"no discrete input tensors found for input named %s.", io_name.c_str()));
return {};
}
std::shared_ptr<Tensor> NetworkImplDft::get_input_tensor(size_t index) {
return get_io_tensor(get_input_name(index));
}
std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_input_tensors(size_t index) {
return get_io_tensors(get_input_name(index));
}
std::shared_ptr<Tensor> NetworkImplDft::get_output_tensor(size_t index) {
return get_io_tensor(get_output_name(index));
}
...
@@ -57,9 +57,19 @@ public:
std::string io_name,
LiteTensorPhase phase = LiteTensorPhase::LITE_IO) override;
//! get the discrete tensors that compose the network input with the given name,
//! each with layout (1, c, h, w)
std::vector<std::shared_ptr<Tensor>> get_io_tensors(
std::string io_name,
LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) override;
//! get the input tensor by index in the load_result tensormap
std::shared_ptr<Tensor> get_input_tensor(size_t index) override;
//! get the discrete tensors that compose the network input, by input index
std::vector<std::shared_ptr<Tensor>> get_input_tensors(size_t index) override;
//! get the output tensor by index in the load_result output_var_list
std::shared_ptr<Tensor> get_output_tensor(size_t index) override;
@@ -190,6 +200,11 @@ private:
//! VolatileSharedDeviceTensor Opr
void replace_dev_input_pass();
//! if an input of the network is a list of discrete tensors, this pass replaces
//! the opr that consumes it with the corresponding multi-src version; currently
//! only WarpPerspective is supported
void replace_src_discrete_input_opr_pass();
//! check whether the model is cross compnode
void cross_compnode_model_detect();
@@ -199,6 +214,8 @@ private:
void update_input();
void update_output();
//! initialize lite_tensors when an input is composed of multiple discrete tensors
void update_input_lite_tensors();
//! when the model info has been loaded, update the config according to the
//! model info, and finally use it in the compute graph
...
@@ -127,6 +127,15 @@ std::shared_ptr<Tensor> Network::get_io_tensor(
LITE_ERROR_HANDLER_END
}
std::vector<std::shared_ptr<Tensor>> Network::get_io_tensors(
std::string name, LiteTensorPhase phase) {
LITE_ERROR_HANDLER_BEGIN
LITE_ASSERT(m_loaded, "get_io_tensor should be used after model loaded.");
LITE_CHECK_NON_NULL_POINTER(m_impl);
return m_impl->get_io_tensors(name, phase);
LITE_ERROR_HANDLER_END
}
std::shared_ptr<Tensor> Network::get_input_tensor(size_t index) {
LITE_ERROR_HANDLER_BEGIN
LITE_ASSERT(m_loaded, "get_input_tensor should be used after model loaded.");
@@ -135,6 +144,14 @@ std::shared_ptr<Tensor> Network::get_input_tensor(size_t index) {
LITE_ERROR_HANDLER_END
}
std::vector<std::shared_ptr<Tensor>> Network::get_input_tensors(size_t index) {
LITE_ERROR_HANDLER_BEGIN
LITE_ASSERT(m_loaded, "get_input_tensor should be used after model loaded.");
LITE_CHECK_NON_NULL_POINTER(m_impl);
return m_impl->get_input_tensors(index);
LITE_ERROR_HANDLER_END
}
std::shared_ptr<Tensor> Network::get_output_tensor(size_t index) {
LITE_ERROR_HANDLER_BEGIN
LITE_ASSERT(m_loaded, "get_output_tensor should be used after model loaded.");
...
@@ -42,6 +42,9 @@ public:
bool have_sync = false;
//! Real input and output data location
std::shared_ptr<Tensor> lite_tensor = nullptr;
//! If the input consists of multiple discrete tensors, lite_tensors holds the
//! real input data locations
std::vector<std::shared_ptr<Tensor>> lite_tensors;
IOInner() = default;
IOInner(const IO& io) {
@@ -86,9 +89,22 @@ public:
virtual std::shared_ptr<Tensor> get_io_tensor(
std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_IO) = 0;
//! get the discrete tensors that compose the network input with the given name,
//! each with layout (1, c, h, w)
virtual std::vector<std::shared_ptr<Tensor>> get_io_tensors(
std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) {
return {};
}
//! get the input tensor by index in the load_result tensormap
virtual std::shared_ptr<Tensor> get_input_tensor(size_t index) = 0;
//! get the discrete tensors that compose the network input, by input index
virtual std::vector<std::shared_ptr<Tensor>> get_input_tensors(size_t index) {
return {};
}
//! get the output tensor by index in the load_result output_var_list
virtual std::shared_ptr<Tensor> get_output_tensor(size_t index) = 0;
...
@@ -1387,6 +1387,96 @@ TEST(TestNetWork, DeviceAsyncExec) {
}
#endif
TEST(TestNetWork, Discrete_Input) {
auto data = get_input_data("./data_b3.npy");
auto data_0 = get_input_data("./data0.npy");
auto data_1 = get_input_data("./data1.npy");
auto data_2 = get_input_data("./data2.npy");
std::string model_path = "./test_discrete_input.mge";
Config config;
config.device_type = LiteDeviceType::LITE_CUDA;
std::shared_ptr<Network> network0 = std::make_shared<Network>(config);
network0->load_model(model_path);
std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data");
data_tensor->share_memory_with(*data);
network0->forward();
network0->wait();
std::shared_ptr<Tensor> output_tensor0 = network0->get_output_tensor(0);
config.discrete_input_name = "data";
NetworkIO ios;
bool is_host = true;
Layout d_ly{{3, 3, 224, 224}, 4, LiteDataType::LITE_FLOAT};
ios.inputs.push_back({"data", is_host, LiteIOType::LITE_IO_VALUE, d_ly});
std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios);
network1->load_model(model_path);
std::vector<std::shared_ptr<Tensor>> data_tensors =
network1->get_io_tensors("data");
data_tensors[0]->share_memory_with(*data_0);
data_tensors[1]->share_memory_with(*data_1);
data_tensors[2]->share_memory_with(*data_2);
network1->forward();
network1->wait();
std::shared_ptr<Tensor> output_tensor1 = network1->get_output_tensor(0);
compare_lite_tensor<float>(output_tensor0, output_tensor1);
}
TEST(TestNetWork, Discrete_Input_Device) {
auto data = get_input_data("./data_b3.npy");
auto data_0 = get_input_data("./data0.npy");
auto data_1 = get_input_data("./data1.npy");
auto data_2 = get_input_data("./data2.npy");
std::string model_path = "./test_discrete_input.mge";
Config config;
config.device_type = LiteDeviceType::LITE_CUDA;
std::shared_ptr<Network> network0 = std::make_shared<Network>(config);
network0->load_model(model_path);
std::shared_ptr<Tensor> data_tensor = network0->get_io_tensor("data");
data_tensor->share_memory_with(*data);
network0->forward();
network0->wait();
std::shared_ptr<Tensor> output_tensor0 = network0->get_output_tensor(0);
config.discrete_input_name = "data";
NetworkIO ios;
bool is_host = false;
Layout d_ly{{3, 3, 224, 224}, 4, LiteDataType::LITE_FLOAT};
ios.inputs.push_back({"data", is_host, LiteIOType::LITE_IO_VALUE, d_ly});
std::shared_ptr<Network> network1 = std::make_shared<Network>(config, ios);
network1->load_model(model_path);
std::vector<std::shared_ptr<Tensor>> data_tensors =
network1->get_io_tensors("data");
auto d0_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly);
auto d1_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly);
auto d2_cuda = Tensor(LiteDeviceType::LITE_CUDA, d_ly);
d0_cuda.copy_from(*data_0);
d1_cuda.copy_from(*data_1);
d2_cuda.copy_from(*data_2);
data_tensors[0]->share_memory_with(d0_cuda);
data_tensors[1]->share_memory_with(d1_cuda);
data_tensors[2]->share_memory_with(d2_cuda);
network1->forward();
network1->wait();
std::shared_ptr<Tensor> output_tensor1 = network1->get_output_tensor(0);
compare_lite_tensor<float>(output_tensor0, output_tensor1);
}
#endif
#if MGB_ATLAS || MGB_CAMBRICON
...
@@ -290,6 +290,48 @@ TEST(TestCapiNetWork, GetAllNameAhead) {
ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[1] == 1000);
}
TEST(TestCapiNetWork, Discrete_Input) {
std::vector<std::shared_ptr<lite::Tensor>> datas;
datas.push_back(lite::get_input_data("./data0.npy"));
datas.push_back(lite::get_input_data("./data1.npy"));
datas.push_back(lite::get_input_data("./data2.npy"));
size_t data_length_in_byte = datas[0]->get_tensor_total_size_in_byte();
LiteIO input_io = default_io;
input_io.is_host = true;
input_io.name = "data";
LiteLayout d_ly;
d_ly.ndim = 4;
d_ly.data_type = LiteDataType::LITE_FLOAT;
std::vector<size_t> input_shape = {3, 3, 224, 224};
for (size_t i = 0; i < d_ly.ndim; i++) {
d_ly.shapes[i] = input_shape[i];
}
input_io.config_layout = d_ly;
LiteNetworkIO network_io = *default_network_io();
network_io.inputs = &input_io;
network_io.input_size = 1;
LiteConfig c_config = *default_config();
c_config.discrete_input_name = "data";
LiteNetwork c_network;
LITE_CAPI_CHECK(LITE_make_network(&c_network, c_config, network_io));
std::string model_path = "./test_discrete_input.mge";
LITE_CAPI_CHECK(LITE_load_model_from_path(c_network, model_path.c_str()));
std::vector<LiteTensor> c_data_tensors(3, nullptr);
for (size_t i = 0; i < 3; i++) {
LITE_CAPI_CHECK(LITE_get_io_tensors(
c_network, "data", i, LITE_INPUT, &c_data_tensors[i]));
LITE_CAPI_CHECK(LITE_reset_tensor_memory(
c_data_tensors[i], datas[i]->get_memory_ptr(), data_length_in_byte));
}
ForwardNetwork;
LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}
#if LITE_BUILD_WITH_RKNPU
static int GetTop(
...
@@ -381,7 +381,7 @@ public:
};
//! shortcut for calling ExtraDependencyMerger
MGE_WIN_DECLSPEC_FUC SymbolVarArray get_dest_vars_with_extra_deps(
const SymbolVarArray& dest_vars, SpecialOprStat* sopr_stat = nullptr);
} // namespace cg
......
@@ -44,13 +44,14 @@ public:
//! rewrite vars in a graph
class Rewriter;
MGE_WIN_DECLSPEC_FUC SubGraph(const SymbolVarArray& endpoint_vars);
//! get the associated ComputingGraph
ComputingGraph* comp_graph() const { return m_comp_graph; }
//! iterate in topology order
MGE_WIN_DECLSPEC_FUC void iter(
const Callback& cb, std::shared_ptr<ExtraDep> = nullptr) const;
//! make a Rewriter bound to this graph
inline Rewriter make_rewriter();
@@ -99,7 +100,7 @@ public:
* \return new operator that uses new inputs; it would be
* opr if no input is changed
*/
MGE_WIN_DECLSPEC_FUC OperatorNodeBase* auto_replace_outputs(OperatorNodeBase* opr);
//! get current var: if var has been replaced, return its
//! new value; otherwise return var itself
@@ -119,11 +120,11 @@ public:
*
* \param msg see OptState::on_var_replaced
*/
MGE_WIN_DECLSPEC_FUC void replace_var(VarNode* src, VarNode* dst, const char* msg);
//! apply this rewriter to the owner graph and modify owner
//! SubGraph inplace
MGE_WIN_DECLSPEC_FUC void apply_inplace() const;
};
SubGraph::Rewriter SubGraph::make_rewriter() {
return {this};
...
@@ -160,18 +160,6 @@ void WarpPerspectiveForward::outshape_by_symvar_do_get_output_shape(
"out2d=%s",
imgshp.to_string().c_str(), matshp.to_string().c_str(),
oshp2d.to_string().c_str());
if (input().size() - m_srcs_size == 2) {
mgb_assert(
m_srcs_size == matshp[0], "batchsize mismatch: img=%zu mat=%zu",
m_srcs_size, matshp[0]);
} else {
mgb_assert(input().size() - m_srcs_size == 3);
mat_idx_shp = shpinfo.shape_inp_shp.at(m_srcs_size + 1);
mgb_assert(
mat_idx_shp[0] == matshp[0] && mat_idx_shp.ndim == 1,
"invalid mat_idx shape: mat=%zu mat_idx=%s", matshp[0],
mat_idx_shp.to_string().c_str());
}
size_t height_idx = 0;
if (param().format == Param::Format::NCHW) {
height_idx = 2;
...
@@ -22,7 +22,7 @@ namespace opr {
* Impl note: this operator might have 3 or 4 inputs depending on whether
* \p mat_idx is given
*/
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
WarpPerspectiveForward,
intl::WorkspaceSizeInfer<intl::OutshapeBySymvarSCNOpr<
mixin::MegDNNOprHolderImpl<megdnn::WarpPerspectiveForward>>>) // {
...