// Patch summary (C++): adds a MaceEngine constructor that takes lists of input
// and output node names, adds a MaceEngine::Run overload that takes a vector of
// inputs and a map of outputs, replaces one PReLU alpha MACE_CHECK with
// MACE_CHECK_NOTNULL, and adds the same check to the OpenCL activation functor.
// NOTE(review): template argument lists (after std::vector, std::map,
// static_cast, std::multiplies, mutable_data/data, #include) appear to be
// missing throughout this text - presumably stripped during extraction; confirm
// against the real commit before applying.
diff --git a/mace/core/mace.cc b/mace/core/mace.cc
index 58591a78cc3401e07a93d0d73f522f6d9565c449..55b2449623e2a0a4bf2d81da1891912c919dfe09 100644
--- a/mace/core/mace.cc
+++ b/mace/core/mace.cc
@@ -545,6 +545,50 @@ MaceEngine::MaceEngine(const NetDef *net_def, DeviceType device_type) :
     net_ = std::move(CreateNet(op_registry_, *net_def, ws_.get(), device_type));
   }
 }
+MaceEngine::MaceEngine(const NetDef *net_def,
+                       DeviceType device_type,
+                       const std::vector &input_nodes,
+                       const std::vector &output_nodes) :
+    op_registry_(new OperatorRegistry()), device_type_(device_type),
+    ws_(new Workspace()), net_(nullptr), hexagon_controller_(nullptr) {
+  for (auto input_name : input_nodes) {  // one DT_FLOAT workspace tensor per input node name
+    ws_->CreateTensor(MakeString("mace_input_node_", input_name, ":0"),
+                      GetDeviceAllocator(device_type_),
+                      DT_FLOAT);
+  }
+  for (auto output_name : output_nodes) {  // likewise for every output node name
+    ws_->CreateTensor(MakeString("mace_output_node_", output_name, ":0"),
+                      GetDeviceAllocator(device_type_),
+                      DT_FLOAT);
+  }
+  if (device_type == HEXAGON) {  // HEXAGON: everything goes through hexagon_controller_
+    hexagon_controller_.reset(new HexagonControlWrapper());
+    MACE_CHECK(hexagon_controller_->Config(), "hexagon config error");
+    MACE_CHECK(hexagon_controller_->Init(), "hexagon init error");
+    hexagon_controller_->SetDebugLevel(
+      static_cast(mace::logging::LogMessage::MinVLogLevel()));
+    int dsp_mode = ArgumentHelper::GetSingleArgument(
+      *net_def, "dsp_mode", 0);  // reads the "dsp_mode" argument; 0 is passed as the fallback value
+    hexagon_controller_->SetGraphMode(dsp_mode);
+    MACE_CHECK(hexagon_controller_->SetupGraph(*net_def),
+               "hexagon setup graph error");
+    if (VLOG_IS_ON(2)) {
+      hexagon_controller_->PrintGraph();
+    }
+  } else {  // any other device: load tensors, run an INIT-mode net once, then build net_
+    ws_->LoadModelTensor(*net_def, device_type);
+
+    // Init model: a separate net created with NetMode::INIT is run once here.
+    auto net = CreateNet(op_registry_, *net_def, ws_.get(),
+                         device_type, NetMode::INIT);
+    if (!net->Run()) {
+      LOG(FATAL) << "Net init run failed";
+    }
+    ws_->RemoveUnsedTensor();
+    net_ = std::move(CreateNet(op_registry_, *net_def, ws_.get(), device_type));
+  }
+
+}
 MaceEngine::~MaceEngine() {
   if (device_type_ == HEXAGON) {
     if (VLOG_IS_ON(2)) {
@@ -597,4 +641,40 @@ bool MaceEngine::Run(const float *input,
   }
 }
 
+bool MaceEngine::Run(const std::vector &inputs,
+                     std::map &outputs,
+                     RunMetadata *run_metadata) {  // run_metadata is forwarded to net_->Run()
+  // Checks the device is not HEXAGON, copies every input into its tensor, runs net_, then copies every output out.
+  MACE_CHECK(device_type_ != HEXAGON, "HEXAGON not supports multiple outputs now");
+  for (auto input : inputs) {  // NOTE(review): iterates by value, so each element is copied
+    Tensor *input_tensor = ws_->GetTensor(MakeString("mace_input_node_", input.name, ":0"));
+    input_tensor->Resize(input.shape);  // NOTE(review): no null check here, unlike output_tensor below
+    {  // inner scope ends input_guard's lifetime right after the copy
+      Tensor::MappingGuard input_guard(input_tensor);
+      float *input_data = input_tensor->mutable_data();
+      memcpy(input_data, input.data, input_tensor->size() * sizeof(float));
+    }
+  }
+  if (!net_->Run(run_metadata)) {
+    LOG(FATAL) << "Net run failed";
+  }
+  for (auto output : outputs) {  // output.first: node name, output.second: destination buffer
+    Tensor *output_tensor = ws_->GetTensor(MakeString("mace_output_node_", output.first + ":0"));
+    // save output: copy only when both the tensor and the destination are non-null
+    if (output_tensor != nullptr && output.second != nullptr) {
+      Tensor::MappingGuard output_guard(output_tensor);
+      auto shape = output_tensor->shape();
+      int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,  // NOTE(review): init 1 is an int, so the product is accumulated in int - confirm
+                                            std::multiplies());
+      MACE_CHECK(!shape.empty()) << "Output's shape must greater than 0";  // NOTE(review): runs after output_size is already computed
+      // TODO: check for overflow exception (the capacity of output.second is not visible here).
+      std::memcpy(output.second, output_tensor->data(),
+                  output_size * sizeof(float));
+    } else {  // missing tensor or null destination: report failure to the caller
+      return false;
+    }
+  }
+  return true;
+}
+
 } // namespace mace
diff --git a/mace/kernels/activation.h b/mace/kernels/activation.h
index c9f8dac5873a3cedcac4ff4cddd92c89b91ef9a3..72e52b67cfef1e3a230c78fb94edc72fd5ca397f 100644
--- a/mace/kernels/activation.h
+++ b/mace/kernels/activation.h
@@ -116,7 +116,7 @@ class ActivationFunctor {
     const T *input_ptr = input->data();
     T *output_ptr = output->mutable_data();
     if (activation_ == PRELU) {
-      MACE_CHECK(alpha != nullptr) << "PReLU's alpha parameter shouldn't be null";
+      MACE_CHECK_NOTNULL(alpha);  // alpha is dereferenced on the next line
       const T *alpha_ptr = alpha->data();
       PReLUActivation(input_ptr, output->size(), input->dim(3), alpha_ptr, output_ptr);
     } else {
diff --git a/mace/kernels/opencl/activation_opencl.cc b/mace/kernels/opencl/activation_opencl.cc
index 99b8a6bc80bed3d92c8649155b7aca1210cbd0a7..75922a9eb17af3b9790a283efaf3b1d9581c2f8c 100644
--- a/mace/kernels/opencl/activation_opencl.cc
+++ b/mace/kernels/opencl/activation_opencl.cc
@@ -62,6 +62,7 @@ void ActivationFunctor::operator()(const Tensor *input,
   int idx = 0;
   kernel_.setArg(idx++, *(input->opencl_image()));
   if (activation_ == PRELU) {
+    MACE_CHECK_NOTNULL(alpha);  // alpha is dereferenced on the next line
     kernel_.setArg(idx++, *(alpha->opencl_image()));
   }
   kernel_.setArg(idx++, static_cast(relux_max_limit_));
diff --git a/mace/public/mace.h b/mace/public/mace.h
index 591987cb9eb0740b790dffb8f2b519adb17e887f..d5fd7a52d837a5d9376f98aa4799806d0c581795 100644
--- a/mace/public/mace.h
+++ b/mace/public/mace.h
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include  // NOTE(review): header name missing here; presumably <map>, since std::map is used below - confirm
 
 namespace mace {
 
@@ -364,18 +365,36 @@ class NetBase;
 class OperatorRegistry;
 class HexagonControlWrapper;
 
+struct MaceInputInfo {  // presumably the element type of the multi-input Run's `input` vector - confirm
+  std::string name;  // input node name
+  std::vector shape;  // shape the input tensor is resized to
+  const float *data;  // source buffer read by Run; not freed by the code shown here
+};
+
 class MaceEngine {
  public:
+  // Single input and output
   explicit MaceEngine(const NetDef *net_def, DeviceType device_type);
+  // Multiple input or output
+  explicit MaceEngine(const NetDef *net_def,
+                      DeviceType device_type,
+                      const std::vector &input_nodes,
+                      const std::vector &output_nodes);
   ~MaceEngine();
+  // Single input and output
   bool Run(const float *input,
            const std::vector &input_shape,
            float *output);
+  // Single input and output for benchmark
   bool Run(const float *input,
            const std::vector &input_shape,
            float *output,
            RunMetadata *run_metadata);
+  // Multiple input or output
+  bool Run(const std::vector &input,
+           std::map &output,
+           RunMetadata *run_metadata=nullptr);  // run_metadata may be omitted
   MaceEngine(const MaceEngine &) = delete;
   MaceEngine &operator=(const MaceEngine &) = delete;