Commit 156c128d authored by 刘琦

Merge branch 'master' into 'master'

Finish conv 1x1  with stride 2 (opencl kernel)

See merge request !103
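The interesting part of this change is how the stride-2 path of the 1x1 kernel keeps the 4-pixel vectorization: four adjacent output pixels at stride 2 span seven consecutive input floats (columns 0 through 6), so the new vec_conv_2d_1x1_s2 macro issues a vload4 plus a vload3 and keeps only the even columns with the .s02 swizzle, then reuses the same float4 compute macros as the stride-1 path. Below is a minimal plain-C sketch of that gather; it is not part of the patch, and the names (input_row, gathered) are illustrative only.

#include <stdio.h>

/* Host-side sketch of the gather performed by vec_conv_2d_1x1_s2:
 * vload4 covers input columns 0-3, vload3 covers columns 4-6, and the
 * .s02 swizzle keeps columns 0, 2, 4, 6 -- one float4 feeding 4 outputs. */
int main(void) {
  const float input_row[7] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  float gathered[4];

  for (int i = 0; i < 4; ++i) {
    gathered[i] = input_row[2 * i];  /* every other input column (stride 2) */
  }

  for (int i = 0; i < 4; ++i) {
    printf("%.1f ", gathered[i]);    /* prints: 0.0 2.0 4.0 6.0 */
  }
  printf("\n");
  return 0;
}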
@@ -21,26 +21,46 @@ cc_library(
     ]),
     copts = ["-std=c++11"],
     deps = [
-        "core",
+        ":logging",
         "@opencl_headers//:opencl20_headers",
     ],
     alwayslink = 1,
 )
 cc_library(
-    name = "core",
-    srcs = glob([
-        "*.cc",
-    ]),
-    hdrs = glob([
-        "*.h",
-    ]),
+    name = "logging",
+    srcs = [
+        "logging.cc",
+    ],
+    hdrs = [
+        "logging.h",
+    ],
     copts = ["-std=c++11"],
     linkopts = if_android([
         "-llog",
+    ]),
+)
+cc_library(
+    name = "core",
+    srcs = glob(
+        ["*.cc",],
+        exclude=[
+            "logging.cc"
+        ]),
+    hdrs = glob(
+        ["*.h"],
+        exclude=[
+            "logging.h"
+        ]),
+    copts = ["-std=c++11"],
+    linkopts = if_android([
         "-pie",
     ]),
     deps = [
+        ":logging",
+        ":opencl_runtime",
         "//mace/proto:cc_proto",
         "//mace/proto:stats_proto",
         "//mace/utils",
......
@@ -3,6 +3,7 @@
 //
 #include "mace/core/allocator.h"
+#include "mace/core/opencl_allocator.h"
 namespace mace {
@@ -22,5 +23,6 @@ Allocator *GetDeviceAllocator(DeviceType type) {
 MACE_REGISTER_ALLOCATOR(DeviceType::CPU, new CPUAllocator());
 MACE_REGISTER_ALLOCATOR(DeviceType::NEON, new CPUAllocator());
+MACE_REGISTER_ALLOCATOR(DeviceType::OPENCL, new OpenCLAllocator());
 } // namespace mace
@@ -4,6 +4,7 @@
 #include "mace/core/net.h"
 #include "mace/utils/utils.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
 namespace mace {
@@ -15,7 +16,7 @@ NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
 SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
                      Workspace *ws,
                      DeviceType type)
-    : NetBase(net_def, ws, type) {
+    : NetBase(net_def, ws, type), device_type_(type) {
   VLOG(1) << "Constructing SimpleNet " << net_def->name();
   for (int idx = 0; idx < net_def->op_size(); ++idx) {
     const auto &operator_def = net_def->op(idx);
@@ -47,6 +48,8 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
       LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
       return false;
     }
+    if (device_type_ == DeviceType::OPENCL)
+      OpenCLRuntime::Get()->command_queue().finish();
     if (op_stats) {
       op_stats->set_op_end_rel_micros(NowInMicroSec() -
                                       op_stats->all_start_micros());
......
@@ -40,6 +40,7 @@ class SimpleNet : public NetBase {
  protected:
   vector<unique_ptr<OperatorBase> > operators_;
+  DeviceType device_type_;
   DISABLE_COPY_AND_ASSIGN(SimpleNet);
 };
......
@@ -3,7 +3,7 @@
 //
 #include "mace/core/runtime/opencl/cl2_header.h"
-#include "mace/core/runtime/opencl/opencl_allocator.h"
+#include "mace/core/opencl_allocator.h"
 #include "mace/core/runtime/opencl/opencl_runtime.h"
 namespace mace {
@@ -49,6 +49,5 @@ void OpenCLAllocator::Unmap(void *buffer, void *mapped_ptr) {
 bool OpenCLAllocator::OnHost() { return false; }
-MACE_REGISTER_ALLOCATOR(DeviceType::OPENCL, new OpenCLAllocator());
 } // namespace mace
@@ -20,7 +20,6 @@ cc_library(
     linkopts = if_android(["-lm"]),
     deps = [
         "//mace/core",
-        "//mace/core:opencl_runtime",
         "//mace/utils",
         "//mace/utils:tuner",
     ],
......
@@ -24,33 +24,87 @@ __kernel void conv_2d_1x1_naive(__global const float *input, /* n, c, h, w */
   }
 }
+#define vec_conv_2d_1x1_s1 \
+  float4 in0 = vload4(0, input_ptr); \
+  float4 in1 = vload4(0, input_ptr + in_pixel); \
+  float4 in2 = vload4(0, input_ptr + 2 * in_pixel); \
+  float4 in3 = vload4(0, input_ptr + 3 * in_pixel);
+#define vec_conv_2d_1x1_s2 \
+  float4 in00 = vload4(0, input_ptr); \
+  float3 in01 = vload3(0, input_ptr + 4); \
+  float4 in10 = vload4(0, input_ptr + in_pixel); \
+  float3 in11 = vload3(0, input_ptr + in_pixel + 4); \
+  float4 in20 = vload4(0, input_ptr + 2 * in_pixel); \
+  float3 in21 = vload3(0, input_ptr + 2 * in_pixel + 4); \
+  float4 in30 = vload4(0, input_ptr + 3 * in_pixel); \
+  float3 in31 = vload3(0, input_ptr + 3 * in_pixel + 4); \
+  float4 in0 = (float4)(in00.s02, in01.s02); \
+  float4 in1 = (float4)(in10.s02, in11.s02); \
+  float4 in2 = (float4)(in20.s02, in21.s02); \
+  float4 in3 = (float4)(in30.s02, in31.s02);
+#define vec_conv_2d_1x1_compute_loop \
+  for (int oc = 0; oc < 4; ++oc) { \
+    float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \
+    float4 out = vload4(0, output_ptr + oc * out_pixel); \
+    out += in0 * weights.x; \
+    out += in1 * weights.y; \
+    out += in2 * weights.z; \
+    out += in3 * weights.w; \
+    vstore4(out, 0, output_ptr + oc * out_pixel); \
+  }
+#define vec_conv_2d_1x1_compute \
+  float4 weights = vload4(0, filter_ptr); \
+  float4 out = vload4(0, output_ptr); \
+  out += in0 * weights.x; \
+  out += in1 * weights.y; \
+  out += in2 * weights.z; \
+  out += in3 * weights.w; \
+  vstore4(out, 0, output_ptr);
 __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */
                              __global const float *filter, /* o, i, kh, kw */
                              __global const float *bias, /* o */
                              __global float *output, /* n, c, h, w */
                              __private const int in_chan_num,
                              __private const int out_chan_num,
-                             __private const int pixel_num) {
+                             __private const int in_height,
+                             __private const int in_width,
+                             __private const int out_height,
+                             __private const int out_width,
+                             __private const int stride) {
   int batch = get_global_id(0);
   int out_chan_blk = get_global_id(1);
   int out_pixel_blk = get_global_id(2);
+  const int in_pixel = in_height * in_width;
+  const int out_pixel = out_height * out_width;
+  const int round_out_width = (out_width + 3) / 4;
+  const int out_pixel_height = out_pixel_blk / round_out_width;
+  const int out_pixel_width = out_pixel_blk % round_out_width;
   const int out_chan_begin = out_chan_blk * 4;
   const int out_chan_end = min(out_chan_begin + 4, out_chan_num);
-  const int out_pixel_begin = out_pixel_blk * 4;
-  const int out_pixel_end = min(out_pixel_begin + 4, pixel_num);
-  const int in_offset = batch * in_chan_num * pixel_num;
-  const int out_offset = batch * out_chan_num * pixel_num;
-  const float *input_base = input + in_offset + out_pixel_begin;
+  const int out_pixel_begin = out_pixel_height * out_width + out_pixel_width * 4;
+  const int out_pixel_end = min(out_pixel_begin + 4, (out_pixel_height + 1) * out_width);
+  const int in_pixel_begin = out_pixel_height * stride * in_width + out_pixel_width * stride * 4;
+  const int in_offset = batch * in_chan_num * in_pixel;
+  const int out_offset = batch * out_chan_num * out_pixel;
+  const float *input_base = input + in_offset + in_pixel_begin;
   float *output_base = output + out_offset + out_pixel_begin;
   int out_chan_len = out_chan_end - out_chan_begin;
   int pixel_len = out_pixel_end - out_pixel_begin;
   for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) {
-    float *output_ptr = output_base + out_chan * pixel_num;
+    float *output_ptr = output_base + out_chan * out_pixel;
     float bias_value = bias == NULL ? 0 : bias[out_chan];
     for (int p = 0; p < pixel_len; ++p) {
       output_ptr[p] = bias_value;
@@ -59,53 +113,51 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */
   int in_chan = 0;
   if (pixel_len == 4) {
-    for (; in_chan + 3 < in_chan_num; in_chan += 4) {
-      const float *input_ptr = input_base + in_chan * pixel_num;
-      int out_chan = out_chan_begin;
-      for (; out_chan + 3 < out_chan_end; out_chan += 4) {
-        const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
-        float *output_ptr = output_base + out_chan * pixel_num;
-        float4 in0 = vload4(0, input_ptr);
-        float4 in1 = vload4(0, input_ptr + pixel_num);
-        float4 in2 = vload4(0, input_ptr + 2 * pixel_num);
-        float4 in3 = vload4(0, input_ptr + 3 * pixel_num);
-#pragma unroll
-        for (int oc = 0; oc < 4; ++oc) {
-          float4 weights = vload4(0, filter_ptr + oc * in_chan_num);
-          float4 out = vload4(0, output_ptr + oc * pixel_num);
-          out += in0 * weights.x;
-          out += in1 * weights.y;
-          out += in2 * weights.z;
-          out += in3 * weights.w;
-          vstore4(out, 0, output_ptr + oc * pixel_num);
-        }
-      }
-      for (; out_chan < out_chan_end; ++out_chan) {
-        const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
-        float *output_ptr = output_base + out_chan * pixel_num;
-        float4 weights = vload4(0, filter_ptr);
-        float4 in0 = vload4(0, input_ptr);
-        float4 in1 = vload4(0, input_ptr + pixel_num);
-        float4 in2 = vload4(0, input_ptr + 2 * pixel_num);
-        float4 in3 = vload4(0, input_ptr + 3 * pixel_num);
-        float4 out = vload4(0, output_ptr);
-        out += in0 * weights.x;
-        out += in1 * weights.y;
-        out += in2 * weights.z;
-        out += in3 * weights.w;
-        vstore4(out, 0, output_ptr);
-      }
-    }
+    if (stride == 1) {
+      for (; in_chan + 3 < in_chan_num; in_chan += 4) {
+        const float *input_ptr = input_base + in_chan * in_pixel;
+        int out_chan = out_chan_begin;
+        for (; out_chan + 3 < out_chan_end; out_chan += 4) {
+          const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
+          float *output_ptr = output_base + out_chan * out_pixel;
+          vec_conv_2d_1x1_s1;
+          vec_conv_2d_1x1_compute_loop;
+        }
+        for (; out_chan < out_chan_end; ++out_chan) {
+          const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
+          float *output_ptr = output_base + out_chan * out_pixel;
+          vec_conv_2d_1x1_s1;
+          vec_conv_2d_1x1_compute;
+        }
+      }
+    } else if (stride == 2) {
+      for (; in_chan + 3 < in_chan_num; in_chan += 4) {
+        const float *input_ptr = input_base + in_chan * in_pixel;
+        int out_chan = out_chan_begin;
+        for (; out_chan + 3 < out_chan_end; out_chan += 4) {
+          const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
+          float *output_ptr = output_base + out_chan * out_pixel;
+          vec_conv_2d_1x1_s2;
+          vec_conv_2d_1x1_compute_loop;
+        }
+        for (; out_chan < out_chan_end; ++out_chan) {
+          const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
+          float *output_ptr = output_base + out_chan * out_pixel;
+          vec_conv_2d_1x1_s2;
+          vec_conv_2d_1x1_compute;
+        }
+      }
+    }
   }
   for (; in_chan < in_chan_num; ++in_chan) {
-    const float *input_ptr = input_base + in_chan * pixel_num;
+    const float *input_ptr = input_base + in_chan * in_pixel;
     for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) {
       float weights = filter[out_chan * in_chan_num + in_chan];
-      float *output_ptr = output_base + out_chan * pixel_num;
+      float *output_ptr = output_base + out_chan * out_pixel;
       for (int p = 0; p < pixel_len; ++p) {
-        float in = input_ptr[p];
+        float in = input_ptr[p * stride];
         output_ptr[p] += in * weights;
       }
     }
......
@@ -41,14 +41,19 @@ void kernel conv_2d_3x3(global const float *input,
   if (pixels == 4) {
     float4 res = bias == NULL ? 0 : (float4)bias[i];
-    for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
-      const float* input_ptr = input_base + in_chan_idx * in_pixel;
-      const float* filter_ptr = filter_base + in_chan_idx * 9;
-      if (stride_w == 1) {
+    if (stride_w == 1) {
+      for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
+        const float* input_ptr = input_base + in_chan_idx * in_pixel;
+        const float* filter_ptr = filter_base + in_chan_idx * 9;
         res += conv1x3_s1(input_ptr + 0 * in_width, filter_ptr + 0 * 3);
         res += conv1x3_s1(input_ptr + 1 * in_width, filter_ptr + 1 * 3);
         res += conv1x3_s1(input_ptr + 2 * in_width, filter_ptr + 2 * 3);
-      } else {
+      }
+    } else {
+      for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
+        const float* input_ptr = input_base + in_chan_idx * in_pixel;
+        const float* filter_ptr = filter_base + in_chan_idx * 9;
         res += conv1x3_s2(input_ptr + 0 * in_width, filter_ptr + 0 * 3);
         res += conv1x3_s2(input_ptr + 1 * in_width, filter_ptr + 1 * 3);
         res += conv1x3_s2(input_ptr + 2 * in_width, filter_ptr + 2 * 3);
......
@@ -10,6 +10,9 @@ namespace kernels {
 extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter,
                                const Tensor *bias, Tensor *output);
+extern void Conv2dOpenclK1x1S2(const Tensor *input, const Tensor *filter,
+                               const Tensor *bias, Tensor *output);
 extern void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter,
                                const Tensor *bias, Tensor *output);
@@ -24,7 +27,7 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
                                const Tensor *bias, Tensor *output);
 // Selection matrix: kernel_size x stride_size
 static const Conv2dOpenclFunction selector[5][2] = {
-    {Conv2dOpenclK1x1S1, nullptr},
+    {Conv2dOpenclK1x1S1, Conv2dOpenclK1x1S2},
     {nullptr, nullptr},
     {Conv2dOpenclK3x3S1, Conv2dOpenclK3x3S2},
     {nullptr, nullptr},
......
@@ -45,6 +45,7 @@ void Conv1x1Naive(const Tensor *input,
 void Conv1x1V2(const Tensor *input,
                const Tensor *filter,
                const Tensor *bias,
+               const int stride,
                Tensor *output) {
   const index_t batch = output->dim(0);
   const index_t channels = output->dim(1);
@@ -54,9 +55,8 @@ void Conv1x1V2(const Tensor *input,
   auto runtime = OpenCLRuntime::Get();
   auto program = runtime->program();
-  const index_t pixels = height * width;
   const index_t channel_blocks = (channels + 3) / 4;
-  const index_t pixel_blocks = (pixels + 3) / 4;
+  const index_t pixel_blocks = (width + 3) / 4 * height;
   // TODO KernelFunctor has an extra clReleaseCommandQueue due to a copy
   // TODO check wired clReleaseCommandQueue latency
@@ -77,7 +77,11 @@ void Conv1x1V2(const Tensor *input,
   conv_2d_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
   conv_2d_kernel.setArg(idx++, static_cast<int>(input_channels));
   conv_2d_kernel.setArg(idx++, static_cast<int>(channels));
-  conv_2d_kernel.setArg(idx++, static_cast<int>(pixels));
+  conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(2)));
+  conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(3)));
+  conv_2d_kernel.setArg(idx++, static_cast<int>(height));
+  conv_2d_kernel.setArg(idx++, static_cast<int>(width));
+  conv_2d_kernel.setArg(idx++, stride);
   auto command_queue = runtime->command_queue();
   cl_int error = command_queue.enqueueNDRangeKernel(
@@ -189,7 +193,16 @@ extern void Conv2dOpenclK1x1S1(const Tensor *input,
   MACE_CHECK(input_batch == batch && input_height == height &&
              input_width == width);
-  Conv1x1V2(input, filter, bias, output);
+  Conv1x1V2(input, filter, bias, 1, output);
+};
+extern void Conv2dOpenclK1x1S2(const Tensor *input,
+                               const Tensor *filter,
+                               const Tensor *bias,
+                               Tensor *output) {
+  MACE_CHECK(input->dim(0) == output->dim(0));
+  Conv1x1V2(input, filter, bias, 2, output);
 };
 } // namespace kernels
......
@@ -17,7 +17,6 @@ cc_library(
     ],
     deps = [
         "//mace/core",
-        "//mace/core:opencl_runtime",
         "@gtest//:gtest",
     ],
 )
......
@@ -42,6 +42,7 @@ bool SplitAndParseToInts(const string &str,
       tmp = tmp.substr(next_offset + 1);
     }
   }
+  return true;
 }
 } // namespace str_util
@@ -254,6 +255,10 @@ int Main(int argc, char **argv) {
   stats_options.show_summary = show_summary;
   stats.reset(new StatSummarizer(stats_options));
+  DeviceType device_type;
+  DeviceType_Parse(device, &device_type);
+  VLOG(0) << device_type;
   // load model
   std::ifstream model_file_stream(model_file, std::ios::in | std::ios::binary);
   if (!model_file_stream.is_open()) {
@@ -265,29 +270,30 @@ int Main(int argc, char **argv) {
   model_file_stream.close();
   Workspace ws;
-  ws.LoadModelTensor(net_def, DeviceType::CPU);
+  ws.LoadModelTensor(net_def, device_type);
   // Load inputs
   for (size_t i = 0; i < inputs_count; ++i) {
     Tensor *input_tensor =
-        ws.CreateTensor(input_layers[i], GetDeviceAllocator(DeviceType::CPU), DT_FLOAT);
+        ws.CreateTensor(input_layers[i], GetDeviceAllocator(device_type), DT_FLOAT);
     vector<index_t> shapes;
     str_util::SplitAndParseToInts(input_layer_shapes[i], ',', &shapes);
     input_tensor->Resize(shapes);
-    float *input_data = input_tensor->mutable_data<float>();
-    // load input
-    if (i < input_layer_files.size()) {
-      std::ifstream in_file(input_layer_files[i],
-                            std::ios::in | std::ios::binary);
-      in_file.read(reinterpret_cast<char *>(input_data),
-                   input_tensor->size() * sizeof(float));
-      in_file.close();
+    {
+      Tensor::MappingGuard input_guard(input_tensor);
+      float *input_data = input_tensor->mutable_data<float>();
+      // load input
+      if (i < input_layer_files.size()) {
+        std::ifstream in_file(input_layer_files[i],
+                              std::ios::in | std::ios::binary);
+        in_file.read(reinterpret_cast<char *>(input_data),
+                     input_tensor->size() * sizeof(float));
+        in_file.close();
+      }
     }
   }
   // create net
-  DeviceType device_type;
-  DeviceType_Parse(device, &device_type);
   auto net = CreateNet(net_def, &ws, device_type);
   int64_t warmup_time_us = 0;
......