提交 b85977e8 编写于 作者: L liuqi

Refactor the core BUILD file for better dependencies.

上级 ba4ca883
...@@ -21,26 +21,46 @@ cc_library( ...@@ -21,26 +21,46 @@ cc_library(
]), ]),
copts = ["-std=c++11"], copts = ["-std=c++11"],
deps = [ deps = [
"core", ":logging",
"@opencl_headers//:opencl20_headers", "@opencl_headers//:opencl20_headers",
], ],
alwayslink = 1, alwayslink = 1,
) )
cc_library( cc_library(
name = "core", name = "logging",
srcs = glob([ srcs = [
"*.cc", "logging.cc",
]), ],
hdrs = glob([ hdrs = [
"*.h", "logging.h",
]), ],
copts = ["-std=c++11"], copts = ["-std=c++11"],
linkopts = if_android([ linkopts = if_android([
"-llog", "-llog",
]),
)
cc_library(
name = "core",
srcs = glob(
["*.cc",],
exclude=[
"logging.cc"
]),
hdrs = glob(
["*.h"],
exclude=[
"logging.h"
]),
copts = ["-std=c++11"],
linkopts = if_android([
"-pie", "-pie",
]), ]),
deps = [ deps = [
":logging",
":opencl_runtime",
"//mace/proto:cc_proto", "//mace/proto:cc_proto",
"//mace/proto:stats_proto", "//mace/proto:stats_proto",
"//mace/utils", "//mace/utils",
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
// //
#include "mace/core/allocator.h" #include "mace/core/allocator.h"
#include "mace/core/opencl_allocator.h"
namespace mace { namespace mace {
...@@ -22,5 +23,6 @@ Allocator *GetDeviceAllocator(DeviceType type) { ...@@ -22,5 +23,6 @@ Allocator *GetDeviceAllocator(DeviceType type) {
MACE_REGISTER_ALLOCATOR(DeviceType::CPU, new CPUAllocator()); MACE_REGISTER_ALLOCATOR(DeviceType::CPU, new CPUAllocator());
MACE_REGISTER_ALLOCATOR(DeviceType::NEON, new CPUAllocator()); MACE_REGISTER_ALLOCATOR(DeviceType::NEON, new CPUAllocator());
MACE_REGISTER_ALLOCATOR(DeviceType::OPENCL, new OpenCLAllocator());
} // namespace mace } // namespace mace
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "mace/core/net.h" #include "mace/core/net.h"
#include "mace/utils/utils.h" #include "mace/utils/utils.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace { namespace mace {
...@@ -15,7 +16,7 @@ NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def, ...@@ -15,7 +16,7 @@ NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def, SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws, Workspace *ws,
DeviceType type) DeviceType type)
: NetBase(net_def, ws, type) { : NetBase(net_def, ws, type), device_type_(type){
VLOG(1) << "Constructing SimpleNet " << net_def->name(); VLOG(1) << "Constructing SimpleNet " << net_def->name();
for (int idx = 0; idx < net_def->op_size(); ++idx) { for (int idx = 0; idx < net_def->op_size(); ++idx) {
const auto &operator_def = net_def->op(idx); const auto &operator_def = net_def->op(idx);
...@@ -47,6 +48,8 @@ bool SimpleNet::Run(RunMetadata *run_metadata) { ...@@ -47,6 +48,8 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def()); LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
return false; return false;
} }
if (device_type_ == DeviceType::OPENCL)
OpenCLRuntime::Get()->command_queue().finish();
if (op_stats) { if (op_stats) {
op_stats->set_op_end_rel_micros(NowInMicroSec() - op_stats->set_op_end_rel_micros(NowInMicroSec() -
op_stats->all_start_micros()); op_stats->all_start_micros());
......
...@@ -40,6 +40,7 @@ class SimpleNet : public NetBase { ...@@ -40,6 +40,7 @@ class SimpleNet : public NetBase {
protected: protected:
vector<unique_ptr<OperatorBase> > operators_; vector<unique_ptr<OperatorBase> > operators_;
DeviceType device_type_;
DISABLE_COPY_AND_ASSIGN(SimpleNet); DISABLE_COPY_AND_ASSIGN(SimpleNet);
}; };
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
// //
#include "mace/core/runtime/opencl/cl2_header.h" #include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_allocator.h" #include "mace/core/opencl_allocator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace { namespace mace {
...@@ -49,6 +49,5 @@ void OpenCLAllocator::Unmap(void *buffer, void *mapped_ptr) { ...@@ -49,6 +49,5 @@ void OpenCLAllocator::Unmap(void *buffer, void *mapped_ptr) {
bool OpenCLAllocator::OnHost() { return false; } bool OpenCLAllocator::OnHost() { return false; }
MACE_REGISTER_ALLOCATOR(DeviceType::OPENCL, new OpenCLAllocator());
} // namespace mace } // namespace mace
...@@ -20,7 +20,6 @@ cc_library( ...@@ -20,7 +20,6 @@ cc_library(
linkopts = if_android(["-lm"]), linkopts = if_android(["-lm"]),
deps = [ deps = [
"//mace/core", "//mace/core",
"//mace/core:opencl_runtime",
"//mace/utils", "//mace/utils",
"//mace/utils:tuner", "//mace/utils:tuner",
], ],
......
...@@ -17,7 +17,6 @@ cc_library( ...@@ -17,7 +17,6 @@ cc_library(
], ],
deps = [ deps = [
"//mace/core", "//mace/core",
"//mace/core:opencl_runtime",
"@gtest//:gtest", "@gtest//:gtest",
], ],
) )
......
...@@ -42,6 +42,7 @@ bool SplitAndParseToInts(const string &str, ...@@ -42,6 +42,7 @@ bool SplitAndParseToInts(const string &str,
tmp = tmp.substr(next_offset + 1); tmp = tmp.substr(next_offset + 1);
} }
} }
return true;
} }
} // namespace str_util } // namespace str_util
...@@ -254,6 +255,10 @@ int Main(int argc, char **argv) { ...@@ -254,6 +255,10 @@ int Main(int argc, char **argv) {
stats_options.show_summary = show_summary; stats_options.show_summary = show_summary;
stats.reset(new StatSummarizer(stats_options)); stats.reset(new StatSummarizer(stats_options));
DeviceType device_type;
DeviceType_Parse(device, &device_type);
VLOG(0) << device_type;
// load model // load model
std::ifstream model_file_stream(model_file, std::ios::in | std::ios::binary); std::ifstream model_file_stream(model_file, std::ios::in | std::ios::binary);
if (!model_file_stream.is_open()) { if (!model_file_stream.is_open()) {
...@@ -265,29 +270,30 @@ int Main(int argc, char **argv) { ...@@ -265,29 +270,30 @@ int Main(int argc, char **argv) {
model_file_stream.close(); model_file_stream.close();
Workspace ws; Workspace ws;
ws.LoadModelTensor(net_def, DeviceType::CPU); ws.LoadModelTensor(net_def, device_type);
// Load inputs // Load inputs
for (size_t i = 0; i < inputs_count; ++i) { for (size_t i = 0; i < inputs_count; ++i) {
Tensor *input_tensor = Tensor *input_tensor =
ws.CreateTensor(input_layers[i], GetDeviceAllocator(DeviceType::CPU), DT_FLOAT); ws.CreateTensor(input_layers[i], GetDeviceAllocator(device_type), DT_FLOAT);
vector<index_t> shapes; vector<index_t> shapes;
str_util::SplitAndParseToInts(input_layer_shapes[i], ',', &shapes); str_util::SplitAndParseToInts(input_layer_shapes[i], ',', &shapes);
input_tensor->Resize(shapes); input_tensor->Resize(shapes);
float *input_data = input_tensor->mutable_data<float>(); {
Tensor::MappingGuard input_guard(input_tensor);
// load input float *input_data = input_tensor->mutable_data<float>();
if (i < input_layer_files.size()) {
std::ifstream in_file(input_layer_files[i], // load input
std::ios::in | std::ios::binary); if (i < input_layer_files.size()) {
in_file.read(reinterpret_cast<char *>(input_data), std::ifstream in_file(input_layer_files[i],
input_tensor->size() * sizeof(float)); std::ios::in | std::ios::binary);
in_file.close(); in_file.read(reinterpret_cast<char *>(input_data),
input_tensor->size() * sizeof(float));
in_file.close();
}
} }
} }
// create net // create net
DeviceType device_type;
DeviceType_Parse(device, &device_type);
auto net = CreateNet(net_def, &ws, device_type); auto net = CreateNet(net_def, &ws, device_type);
int64_t warmup_time_us = 0; int64_t warmup_time_us = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册