提交 b85977e8 编写于 作者: liuqi

Refactor the core BUILD file for better dependencies.

上级 ba4ca883
......@@ -21,26 +21,46 @@ cc_library(
]),
copts = ["-std=c++11"],
deps = [
"core",
":logging",
"@opencl_headers//:opencl20_headers",
],
alwayslink = 1,
)
cc_library(
name = "core",
srcs = glob([
"*.cc",
]),
hdrs = glob([
"*.h",
]),
name = "logging",
srcs = [
"logging.cc",
],
hdrs = [
"logging.h",
],
copts = ["-std=c++11"],
linkopts = if_android([
"-llog",
]),
)
cc_library(
name = "core",
srcs = glob(
["*.cc",],
exclude=[
"logging.cc"
]),
hdrs = glob(
["*.h"],
exclude=[
"logging.h"
]),
copts = ["-std=c++11"],
linkopts = if_android([
"-pie",
]),
deps = [
":logging",
":opencl_runtime",
"//mace/proto:cc_proto",
"//mace/proto:stats_proto",
"//mace/utils",
......
......@@ -3,6 +3,7 @@
//
#include "mace/core/allocator.h"
#include "mace/core/opencl_allocator.h"
namespace mace {
......@@ -22,5 +23,6 @@ Allocator *GetDeviceAllocator(DeviceType type) {
MACE_REGISTER_ALLOCATOR(DeviceType::CPU, new CPUAllocator());
MACE_REGISTER_ALLOCATOR(DeviceType::NEON, new CPUAllocator());
MACE_REGISTER_ALLOCATOR(DeviceType::OPENCL, new OpenCLAllocator());
} // namespace mace
......@@ -4,6 +4,7 @@
#include "mace/core/net.h"
#include "mace/utils/utils.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace {
......@@ -15,7 +16,7 @@ NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type)
: NetBase(net_def, ws, type) {
: NetBase(net_def, ws, type), device_type_(type){
VLOG(1) << "Constructing SimpleNet " << net_def->name();
for (int idx = 0; idx < net_def->op_size(); ++idx) {
const auto &operator_def = net_def->op(idx);
......@@ -47,6 +48,8 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
return false;
}
if (device_type_ == DeviceType::OPENCL)
OpenCLRuntime::Get()->command_queue().finish();
if (op_stats) {
op_stats->set_op_end_rel_micros(NowInMicroSec() -
op_stats->all_start_micros());
......
......@@ -40,6 +40,7 @@ class SimpleNet : public NetBase {
protected:
vector<unique_ptr<OperatorBase> > operators_;
DeviceType device_type_;
DISABLE_COPY_AND_ASSIGN(SimpleNet);
};
......
......@@ -3,7 +3,7 @@
//
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_allocator.h"
#include "mace/core/opencl_allocator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace {
......@@ -49,6 +49,5 @@ void OpenCLAllocator::Unmap(void *buffer, void *mapped_ptr) {
bool OpenCLAllocator::OnHost() { return false; }
MACE_REGISTER_ALLOCATOR(DeviceType::OPENCL, new OpenCLAllocator());
} // namespace mace
......@@ -20,7 +20,6 @@ cc_library(
linkopts = if_android(["-lm"]),
deps = [
"//mace/core",
"//mace/core:opencl_runtime",
"//mace/utils",
"//mace/utils:tuner",
],
......
......@@ -17,7 +17,6 @@ cc_library(
],
deps = [
"//mace/core",
"//mace/core:opencl_runtime",
"@gtest//:gtest",
],
)
......
......@@ -42,6 +42,7 @@ bool SplitAndParseToInts(const string &str,
tmp = tmp.substr(next_offset + 1);
}
}
return true;
}
} // namespace str_util
......@@ -254,6 +255,10 @@ int Main(int argc, char **argv) {
stats_options.show_summary = show_summary;
stats.reset(new StatSummarizer(stats_options));
DeviceType device_type;
DeviceType_Parse(device, &device_type);
VLOG(0) << device_type;
// load model
std::ifstream model_file_stream(model_file, std::ios::in | std::ios::binary);
if (!model_file_stream.is_open()) {
......@@ -265,29 +270,30 @@ int Main(int argc, char **argv) {
model_file_stream.close();
Workspace ws;
ws.LoadModelTensor(net_def, DeviceType::CPU);
ws.LoadModelTensor(net_def, device_type);
// Load inputs
for (size_t i = 0; i < inputs_count; ++i) {
Tensor *input_tensor =
ws.CreateTensor(input_layers[i], GetDeviceAllocator(DeviceType::CPU), DT_FLOAT);
ws.CreateTensor(input_layers[i], GetDeviceAllocator(device_type), DT_FLOAT);
vector<index_t> shapes;
str_util::SplitAndParseToInts(input_layer_shapes[i], ',', &shapes);
input_tensor->Resize(shapes);
float *input_data = input_tensor->mutable_data<float>();
// load input
if (i < input_layer_files.size()) {
std::ifstream in_file(input_layer_files[i],
std::ios::in | std::ios::binary);
in_file.read(reinterpret_cast<char *>(input_data),
input_tensor->size() * sizeof(float));
in_file.close();
{
Tensor::MappingGuard input_guard(input_tensor);
float *input_data = input_tensor->mutable_data<float>();
// load input
if (i < input_layer_files.size()) {
std::ifstream in_file(input_layer_files[i],
std::ios::in | std::ios::binary);
in_file.read(reinterpret_cast<char *>(input_data),
input_tensor->size() * sizeof(float));
in_file.close();
}
}
}
// create net
DeviceType device_type;
DeviceType_Parse(device, &device_type);
auto net = CreateNet(net_def, &ws, device_type);
int64_t warmup_time_us = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册