提交 07280ea3 编写于 作者: L liuqi

Add multiple input and output API.

上级 5e228887
......@@ -545,6 +545,50 @@ MaceEngine::MaceEngine(const NetDef *net_def, DeviceType device_type) :
net_ = std::move(CreateNet(op_registry_, *net_def, ws_.get(), device_type));
}
}
MaceEngine::MaceEngine(const NetDef *net_def,
DeviceType device_type,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes) :
op_registry_(new OperatorRegistry()), device_type_(device_type),
ws_(new Workspace()), net_(nullptr), hexagon_controller_(nullptr) {
for (auto input_name : input_nodes) {
ws_->CreateTensor(MakeString("mace_input_node_", input_name, ":0"),
GetDeviceAllocator(device_type_),
DT_FLOAT);
}
for (auto output_name : output_nodes) {
ws_->CreateTensor(MakeString("mace_output_node_", output_name, ":0"),
GetDeviceAllocator(device_type_),
DT_FLOAT);
}
if (device_type == HEXAGON) {
hexagon_controller_.reset(new HexagonControlWrapper());
MACE_CHECK(hexagon_controller_->Config(), "hexagon config error");
MACE_CHECK(hexagon_controller_->Init(), "hexagon init error");
hexagon_controller_->SetDebugLevel(
static_cast<int>(mace::logging::LogMessage::MinVLogLevel()));
int dsp_mode = ArgumentHelper::GetSingleArgument<NetDef, int>(
*net_def, "dsp_mode", 0);
hexagon_controller_->SetGraphMode(dsp_mode);
MACE_CHECK(hexagon_controller_->SetupGraph(*net_def),
"hexagon setup graph error");
if (VLOG_IS_ON(2)) {
hexagon_controller_->PrintGraph();
}
} else {
ws_->LoadModelTensor(*net_def, device_type);
// Init model
auto net = CreateNet(op_registry_, *net_def, ws_.get(),
device_type, NetMode::INIT);
if (!net->Run()) {
LOG(FATAL) << "Net init run failed";
}
ws_->RemoveUnsedTensor();
net_ = std::move(CreateNet(op_registry_, *net_def, ws_.get(), device_type));
}
}
MaceEngine::~MaceEngine() {
if (device_type_ == HEXAGON) {
if (VLOG_IS_ON(2)) {
......@@ -597,4 +641,40 @@ bool MaceEngine::Run(const float *input,
}
}
bool MaceEngine::Run(const std::vector<MaceInputInfo> &inputs,
std::map<std::string, float *> &outputs,
RunMetadata *run_metadata) {
MACE_CHECK(device_type_ != HEXAGON, "HEXAGON not supports multiple outputs now");
for (auto input : inputs) {
Tensor *input_tensor = ws_->GetTensor(MakeString("mace_input_node_", input.name, ":0"));
input_tensor->Resize(input.shape);
{
Tensor::MappingGuard input_guard(input_tensor);
float *input_data = input_tensor->mutable_data<float>();
memcpy(input_data, input.data, input_tensor->size() * sizeof(float));
}
}
if (!net_->Run(run_metadata)) {
LOG(FATAL) << "Net run failed";
}
for (auto output : outputs) {
Tensor *output_tensor = ws_->GetTensor(MakeString("mace_output_node_", output.first + ":0"));
// save output
if (output_tensor != nullptr && output.second != nullptr) {
Tensor::MappingGuard output_guard(output_tensor);
auto shape = output_tensor->shape();
int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
MACE_CHECK(!shape.empty()) << "Output's shape must greater than 0";
// TODO: check for overflow exception.
std::memcpy(output.second, output_tensor->data<float>(),
output_size * sizeof(float));
} else {
return false;
}
}
return true;
}
} // namespace mace
......@@ -62,6 +62,7 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
int idx = 0;
kernel_.setArg(idx++, *(input->opencl_image()));
if (activation_ == PRELU) {
MACE_CHECK(alpha != nullptr) << "PReLU's alpha parameter shouldn't be null";
kernel_.setArg(idx++, *(alpha->opencl_image()));
}
kernel_.setArg(idx++, static_cast<float>(relux_max_limit_));
......
......@@ -9,6 +9,7 @@
#include <vector>
#include <string>
#include <memory>
#include <map>
namespace mace {
......@@ -364,18 +365,36 @@ class NetBase;
class OperatorRegistry;
class HexagonControlWrapper;
struct MaceInputInfo {
std::string name;
std::vector<int64_t> shape;
const float *data;
};
class MaceEngine {
public:
// Single input and output
explicit MaceEngine(const NetDef *net_def,
DeviceType device_type);
// Multiple input or output
explicit MaceEngine(const NetDef *net_def,
DeviceType device_type,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes);
~MaceEngine();
// Single input and output
bool Run(const float *input,
const std::vector<int64_t> &input_shape,
float *output);
// Single input and output for benchmark
bool Run(const float *input,
const std::vector<int64_t> &input_shape,
float *output,
RunMetadata *run_metadata);
// Multiple input or output
bool Run(const std::vector<MaceInputInfo> &input,
std::map<std::string, float *> &output,
RunMetadata *run_metadata=nullptr);
MaceEngine(const MaceEngine &) = delete;
MaceEngine &operator=(const MaceEngine &) = delete;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册