Commit 156c128d authored by 刘琦

Merge branch 'master' into 'master'

Finish conv 1x1 with stride 2 (OpenCL kernel)

See merge request !103
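Note: this merge adds a stride-2 path to the OpenCL 1x1 convolution kernel. For orientation, here is a scalar reference of what a strided 1x1 convolution computes in NCHW layout (a sketch only, not MACE code; the function name and signature are illustrative and padding is ignored):

/* Scalar reference for a 1x1 convolution with stride `s`, NCHW layout.
 * Output sizes are taken as given; filter layout is [out_c][in_c] as in the kernel below. */
static void conv1x1_reference(const float *in, const float *weight, const float *bias,
                              float *out, int batch, int in_c, int out_c,
                              int in_h, int in_w, int out_h, int out_w, int s) {
  for (int b = 0; b < batch; ++b)
    for (int oc = 0; oc < out_c; ++oc)
      for (int oh = 0; oh < out_h; ++oh)
        for (int ow = 0; ow < out_w; ++ow) {
          float acc = bias ? bias[oc] : 0.0f;
          for (int ic = 0; ic < in_c; ++ic)
            acc += weight[oc * in_c + ic] *
                   in[((b * in_c + ic) * in_h + oh * s) * in_w + ow * s];
          out[((b * out_c + oc) * out_h + oh) * out_w + ow] = acc;
        }
}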
......@@ -21,26 +21,46 @@ cc_library(
]),
copts = ["-std=c++11"],
deps = [
"core",
":logging",
"@opencl_headers//:opencl20_headers",
],
alwayslink = 1,
)
cc_library(
name = "core",
srcs = glob([
"*.cc",
]),
hdrs = glob([
"*.h",
]),
name = "logging",
srcs = [
"logging.cc",
],
hdrs = [
"logging.h",
],
copts = ["-std=c++11"],
linkopts = if_android([
"-llog",
]),
)
cc_library(
name = "core",
srcs = glob(
["*.cc",],
exclude=[
"logging.cc"
]),
hdrs = glob(
["*.h"],
exclude=[
"logging.h"
]),
copts = ["-std=c++11"],
linkopts = if_android([
"-pie",
]),
deps = [
":logging",
":opencl_runtime",
"//mace/proto:cc_proto",
"//mace/proto:stats_proto",
"//mace/utils",
......
......@@ -3,6 +3,7 @@
//
#include "mace/core/allocator.h"
#include "mace/core/opencl_allocator.h"
namespace mace {
......@@ -22,5 +23,6 @@ Allocator *GetDeviceAllocator(DeviceType type) {
MACE_REGISTER_ALLOCATOR(DeviceType::CPU, new CPUAllocator());
MACE_REGISTER_ALLOCATOR(DeviceType::NEON, new CPUAllocator());
MACE_REGISTER_ALLOCATOR(DeviceType::OPENCL, new OpenCLAllocator());
} // namespace mace
......@@ -4,6 +4,7 @@
#include "mace/core/net.h"
#include "mace/utils/utils.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace {
......@@ -15,7 +16,7 @@ NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type)
: NetBase(net_def, ws, type) {
: NetBase(net_def, ws, type), device_type_(type){
VLOG(1) << "Constructing SimpleNet " << net_def->name();
for (int idx = 0; idx < net_def->op_size(); ++idx) {
const auto &operator_def = net_def->op(idx);
......@@ -47,6 +48,8 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
return false;
}
if (device_type_ == DeviceType::OPENCL)
OpenCLRuntime::Get()->command_queue().finish();
if (op_stats) {
op_stats->set_op_end_rel_micros(NowInMicroSec() -
op_stats->all_start_micros());
......
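Note: the added command_queue().finish() blocks until the enqueued OpenCL work has completed, so the per-op end timestamp measures actual kernel execution rather than just the enqueue call. A simplified sketch of the resulting pattern, using the names from this diff:

int64_t start_micros = NowInMicroSec();
op->Run();                                          // for OpenCL ops this only enqueues work
if (device_type_ == DeviceType::OPENCL)
  OpenCLRuntime::Get()->command_queue().finish();   // wait for the GPU to drain the queue
int64_t end_micros = NowInMicroSec();               // now reflects real execution time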
......@@ -40,6 +40,7 @@ class SimpleNet : public NetBase {
protected:
vector<unique_ptr<OperatorBase> > operators_;
DeviceType device_type_;
DISABLE_COPY_AND_ASSIGN(SimpleNet);
};
......
......@@ -3,7 +3,7 @@
//
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_allocator.h"
#include "mace/core/opencl_allocator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace {
......@@ -49,6 +49,5 @@ void OpenCLAllocator::Unmap(void *buffer, void *mapped_ptr) {
bool OpenCLAllocator::OnHost() { return false; }
MACE_REGISTER_ALLOCATOR(DeviceType::OPENCL, new OpenCLAllocator());
} // namespace mace
......@@ -20,7 +20,6 @@ cc_library(
linkopts = if_android(["-lm"]),
deps = [
"//mace/core",
"//mace/core:opencl_runtime",
"//mace/utils",
"//mace/utils:tuner",
],
......
......@@ -24,33 +24,87 @@ __kernel void conv_2d_1x1_naive(__global const float *input, /* n, c, h, w */
}
}
#define vec_conv_2d_1x1_s1 \
float4 in0 = vload4(0, input_ptr); \
float4 in1 = vload4(0, input_ptr + in_pixel); \
float4 in2 = vload4(0, input_ptr + 2 * in_pixel); \
float4 in3 = vload4(0, input_ptr + 3 * in_pixel);
#define vec_conv_2d_1x1_s2 \
float4 in00 = vload4(0, input_ptr); \
float3 in01 = vload3(0, input_ptr + 4); \
float4 in10 = vload4(0, input_ptr + in_pixel); \
float3 in11 = vload3(0, input_ptr + in_pixel + 4); \
float4 in20 = vload4(0, input_ptr + 2 * in_pixel); \
float3 in21 = vload3(0, input_ptr + 2 * in_pixel + 4);\
float4 in30 = vload4(0, input_ptr + 3 * in_pixel); \
float3 in31 = vload3(0, input_ptr + 3 * in_pixel + 4); \
float4 in0 = (float4)(in00.s02, in01.s02); \
float4 in1 = (float4)(in10.s02, in11.s02); \
float4 in2 = (float4)(in20.s02, in21.s02); \
float4 in3 = (float4)(in30.s02, in31.s02);
#define vec_conv_2d_1x1_compute_loop \
for (int oc = 0; oc < 4; ++oc) { \
float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \
float4 out = vload4(0, output_ptr + oc * out_pixel); \
out += in0 * weights.x; \
out += in1 * weights.y; \
out += in2 * weights.z; \
out += in3 * weights.w; \
vstore4(out, 0, output_ptr + oc * out_pixel); \
}
#define vec_conv_2d_1x1_compute \
float4 weights = vload4(0, filter_ptr); \
float4 out = vload4(0, output_ptr); \
out += in0 * weights.x; \
out += in1 * weights.y; \
out += in2 * weights.z; \
out += in3 * weights.w; \
vstore4(out, 0, output_ptr);
__kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */
__global const float *filter, /* o, i, kh, kw */
__global const float *bias, /* o */
__global float *output, /* n, c, h, w */
__private const int in_chan_num,
__private const int out_chan_num,
__private const int pixel_num) {
__private const int in_height,
__private const int in_width,
__private const int out_height,
__private const int out_width,
__private const int stride) {
int batch = get_global_id(0);
int out_chan_blk = get_global_id(1);
int out_pixel_blk = get_global_id(2);
const int in_pixel = in_height * in_width;
const int out_pixel = out_height * out_width;
const int round_out_width = (out_width + 3) / 4;
const int out_pixel_height = out_pixel_blk / round_out_width;
const int out_pixel_width = out_pixel_blk % round_out_width;
const int out_chan_begin = out_chan_blk * 4;
const int out_chan_end = min(out_chan_begin + 4, out_chan_num);
const int out_pixel_begin = out_pixel_blk * 4;
const int out_pixel_end = min(out_pixel_begin + 4, pixel_num);
const int out_pixel_begin = out_pixel_height * out_width + out_pixel_width * 4;
const int out_pixel_end = min(out_pixel_begin + 4, (out_pixel_height + 1) * out_width);
const int in_pixel_begin = out_pixel_height * stride * in_width + out_pixel_width * stride * 4;
const int in_offset = batch * in_chan_num * pixel_num;
const int out_offset = batch * out_chan_num * pixel_num;
const int in_offset = batch * in_chan_num * in_pixel;
const int out_offset = batch * out_chan_num * out_pixel;
const float *input_base = input + in_offset + out_pixel_begin;
const float *input_base = input + in_offset + in_pixel_begin;
float *output_base = output + out_offset + out_pixel_begin;
int out_chan_len = out_chan_end - out_chan_begin;
int pixel_len = out_pixel_end - out_pixel_begin;
for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) {
float *output_ptr = output_base + out_chan * pixel_num;
float *output_ptr = output_base + out_chan * out_pixel;
float bias_value = bias == NULL ? 0 : bias[out_chan];
for (int p = 0; p < pixel_len; ++p) {
output_ptr[p] = bias_value;
......@@ -59,53 +113,51 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */
int in_chan = 0;
if (pixel_len == 4) {
for (; in_chan + 3 < in_chan_num; in_chan += 4) {
const float *input_ptr = input_base + in_chan * pixel_num;
int out_chan = out_chan_begin;
for (; out_chan + 3 < out_chan_end; out_chan += 4) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * pixel_num;
float4 in0 = vload4(0, input_ptr);
float4 in1 = vload4(0, input_ptr + pixel_num);
float4 in2 = vload4(0, input_ptr + 2 * pixel_num);
float4 in3 = vload4(0, input_ptr + 3 * pixel_num);
#pragma unroll
for (int oc = 0; oc < 4; ++oc) {
float4 weights = vload4(0, filter_ptr + oc * in_chan_num);
float4 out = vload4(0, output_ptr + oc * pixel_num);
out += in0 * weights.x;
out += in1 * weights.y;
out += in2 * weights.z;
out += in3 * weights.w;
vstore4(out, 0, output_ptr + oc * pixel_num);
if (stride == 1) {
for (; in_chan + 3 < in_chan_num; in_chan += 4) {
const float *input_ptr = input_base + in_chan * in_pixel;
int out_chan = out_chan_begin;
for (; out_chan + 3 < out_chan_end; out_chan += 4) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s1;
vec_conv_2d_1x1_compute_loop;
}
for (; out_chan < out_chan_end; ++out_chan) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s1;
vec_conv_2d_1x1_compute;
}
}
for (; out_chan < out_chan_end; ++out_chan) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * pixel_num;
float4 weights = vload4(0, filter_ptr);
float4 in0 = vload4(0, input_ptr);
float4 in1 = vload4(0, input_ptr + pixel_num);
float4 in2 = vload4(0, input_ptr + 2 * pixel_num);
float4 in3 = vload4(0, input_ptr + 3 * pixel_num);
float4 out = vload4(0, output_ptr);
out += in0 * weights.x;
out += in1 * weights.y;
out += in2 * weights.z;
out += in3 * weights.w;
vstore4(out, 0, output_ptr);
} else if (stride == 2) {
for (; in_chan + 3 < in_chan_num; in_chan += 4) {
const float *input_ptr = input_base + in_chan * in_pixel;
int out_chan = out_chan_begin;
for (; out_chan + 3 < out_chan_end; out_chan += 4) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s2;
vec_conv_2d_1x1_compute_loop;
}
for (; out_chan < out_chan_end; ++out_chan) {
const float* filter_ptr = filter + out_chan * in_chan_num + in_chan;
float *output_ptr = output_base + out_chan * out_pixel;
vec_conv_2d_1x1_s2;
vec_conv_2d_1x1_compute;
}
}
}
}
for (; in_chan < in_chan_num; ++in_chan) {
const float *input_ptr = input_base + in_chan * pixel_num;
const float *input_ptr = input_base + in_chan * in_pixel;
for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) {
float weights = filter[out_chan * in_chan_num + in_chan];
float *output_ptr = output_base + out_chan * pixel_num;
float *output_ptr = output_base + out_chan * out_pixel;
for (int p = 0; p < pixel_len; ++p) {
float in = input_ptr[p];
float in = input_ptr[p*stride];
output_ptr[p] += in * weights;
}
}
......
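Note: the vec_conv_2d_1x1_s2 macro above gathers four stride-2 input pixels per channel: vload4 and vload3 together read seven consecutive floats, and the .s02 swizzles keep elements 0, 2, 4 and 6. A standalone illustration of that gather (assuming `p` points at at least seven valid floats):

float4 lo = vload4(0, p);                   // p[0], p[1], p[2], p[3]
float3 hi = vload3(0, p + 4);               // p[4], p[5], p[6]
float4 strided = (float4)(lo.s02, hi.s02);  // (p[0], p[2], p[4], p[6]) -- four pixels spaced two apart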
......@@ -41,14 +41,19 @@ void kernel conv_2d_3x3(global const float *input,
if (pixels == 4) {
float4 res = bias == NULL ? 0 : (float4)bias[i];
for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
const float* input_ptr = input_base + in_chan_idx * in_pixel;
const float* filter_ptr = filter_base + in_chan_idx * 9;
if (stride_w == 1) {
if (stride_w == 1) {
for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
const float* input_ptr = input_base + in_chan_idx * in_pixel;
const float* filter_ptr = filter_base + in_chan_idx * 9;
res += conv1x3_s1(input_ptr + 0 * in_width, filter_ptr + 0 * 3);
res += conv1x3_s1(input_ptr + 1 * in_width, filter_ptr + 1 * 3);
res += conv1x3_s1(input_ptr + 2 * in_width, filter_ptr + 2 * 3);
} else {
}
} else {
for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) {
const float* input_ptr = input_base + in_chan_idx * in_pixel;
const float* filter_ptr = filter_base + in_chan_idx * 9;
res += conv1x3_s2(input_ptr + 0 * in_width, filter_ptr + 0 * 3);
res += conv1x3_s2(input_ptr + 1 * in_width, filter_ptr + 1 * 3);
res += conv1x3_s2(input_ptr + 2 * in_width, filter_ptr + 2 * 3);
......
......@@ -10,6 +10,9 @@ namespace kernels {
extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter,
const Tensor *bias, Tensor *output);
extern void Conv2dOpenclK1x1S2(const Tensor *input, const Tensor *filter,
const Tensor *bias, Tensor *output);
extern void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter,
const Tensor *bias, Tensor *output);
......@@ -24,7 +27,7 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
const Tensor *bias, Tensor *output);
// Selection matrix: kernel_size x stride_size
static const Conv2dOpenclFunction selector[5][2] = {
{Conv2dOpenclK1x1S1, nullptr},
{Conv2dOpenclK1x1S1, Conv2dOpenclK1x1S2},
{nullptr, nullptr},
{Conv2dOpenclK3x3S1, Conv2dOpenclK3x3S2},
{nullptr, nullptr},
......
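Note: the selection matrix is indexed by kernel size and stride (rows 0-4 correspond to kernel sizes 1-5, columns to strides 1 and 2), so registering Conv2dOpenclK1x1S2 in slot [0][1] is what enables the new stride-2 path. A sketch of how such a table is typically consulted (the actual lookup code is outside this hunk, so the index names here are assumptions):

// Assumed dispatch pattern: row = kernel_size - 1, column = stride - 1.
Conv2dOpenclFunction conv_func = selector[kernel_h - 1][stride_h - 1];
if (conv_func != nullptr) {
  conv_func(input, filter, bias, output);  // specialized OpenCL kernel
} else {
  // fall back to a generic implementation for unsupported kernel/stride combinations
}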
......@@ -45,6 +45,7 @@ void Conv1x1Naive(const Tensor *input,
void Conv1x1V2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
const int stride,
Tensor *output) {
const index_t batch = output->dim(0);
const index_t channels = output->dim(1);
......@@ -54,9 +55,8 @@ void Conv1x1V2(const Tensor *input,
auto runtime = OpenCLRuntime::Get();
auto program = runtime->program();
const index_t pixels = height * width;
const index_t channel_blocks = (channels + 3) / 4;
const index_t pixel_blocks = (pixels + 3) / 4;
const index_t pixel_blocks = (width + 3) / 4 * height;
// TODO KernelFunctor has an extra clReleaseCommandQueue due to a copy
// TODO check weird clReleaseCommandQueue latency
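Note: pixel_blocks now counts 4-wide blocks per output row instead of blocks over the flattened pixel count, so no block straddles a row boundary (the kernel clamps the tail block through out_pixel_end). A quick check with hypothetical sizes:

// Hypothetical example: out_width = 7, out_height = 5
// old: pixel_blocks = (7 * 5 + 3) / 4 = 9    (a block could span two rows)
// new: pixel_blocks = (7 + 3) / 4 * 5 = 10   (2 blocks per row; the last block holds 3 pixels)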
......@@ -77,7 +77,11 @@ void Conv1x1V2(const Tensor *input,
conv_2d_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
conv_2d_kernel.setArg(idx++, static_cast<int>(input_channels));
conv_2d_kernel.setArg(idx++, static_cast<int>(channels));
conv_2d_kernel.setArg(idx++, static_cast<int>(pixels));
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(2)));
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(3)));
conv_2d_kernel.setArg(idx++, static_cast<int>(height));
conv_2d_kernel.setArg(idx++, static_cast<int>(width));
conv_2d_kernel.setArg(idx++, stride);
auto command_queue = runtime->command_queue();
cl_int error = command_queue.enqueueNDRangeKernel(
......@@ -189,7 +193,16 @@ extern void Conv2dOpenclK1x1S1(const Tensor *input,
MACE_CHECK(input_batch == batch && input_height == height &&
input_width == width);
Conv1x1V2(input, filter, bias, output);
Conv1x1V2(input, filter, bias, 1, output);
};
extern void Conv2dOpenclK1x1S2(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output) {
MACE_CHECK(input->dim(0) == output->dim(0));
Conv1x1V2(input, filter, bias, 2, output);
};
} // namespace kernels
......
......@@ -17,7 +17,6 @@ cc_library(
],
deps = [
"//mace/core",
"//mace/core:opencl_runtime",
"@gtest//:gtest",
],
)
......
......@@ -42,6 +42,7 @@ bool SplitAndParseToInts(const string &str,
tmp = tmp.substr(next_offset + 1);
}
}
return true;
}
} // namespace str_util
......@@ -254,6 +255,10 @@ int Main(int argc, char **argv) {
stats_options.show_summary = show_summary;
stats.reset(new StatSummarizer(stats_options));
DeviceType device_type;
DeviceType_Parse(device, &device_type);
VLOG(0) << device_type;
// load model
std::ifstream model_file_stream(model_file, std::ios::in | std::ios::binary);
if (!model_file_stream.is_open()) {
......@@ -265,29 +270,30 @@ int Main(int argc, char **argv) {
model_file_stream.close();
Workspace ws;
ws.LoadModelTensor(net_def, DeviceType::CPU);
ws.LoadModelTensor(net_def, device_type);
// Load inputs
for (size_t i = 0; i < inputs_count; ++i) {
Tensor *input_tensor =
ws.CreateTensor(input_layers[i], GetDeviceAllocator(DeviceType::CPU), DT_FLOAT);
ws.CreateTensor(input_layers[i], GetDeviceAllocator(device_type), DT_FLOAT);
vector<index_t> shapes;
str_util::SplitAndParseToInts(input_layer_shapes[i], ',', &shapes);
input_tensor->Resize(shapes);
float *input_data = input_tensor->mutable_data<float>();
// load input
if (i < input_layer_files.size()) {
std::ifstream in_file(input_layer_files[i],
std::ios::in | std::ios::binary);
in_file.read(reinterpret_cast<char *>(input_data),
input_tensor->size() * sizeof(float));
in_file.close();
{
Tensor::MappingGuard input_guard(input_tensor);
float *input_data = input_tensor->mutable_data<float>();
// load input
if (i < input_layer_files.size()) {
std::ifstream in_file(input_layer_files[i],
std::ios::in | std::ios::binary);
in_file.read(reinterpret_cast<char *>(input_data),
input_tensor->size() * sizeof(float));
in_file.close();
}
}
}
// create net
DeviceType device_type;
DeviceType_Parse(device, &device_type);
auto net = CreateNet(net_def, &ws, device_type);
int64_t warmup_time_us = 0;
......
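Note: wrapping the input write in a Tensor::MappingGuard scope maps the tensor's buffer into host memory before mutable_data<float>() is used and unmaps it when the scope ends, which matters once the tensor may live in an OpenCL buffer. A minimal usage sketch based on the names in this diff (the zero-fill is only a placeholder for the real file read):

{
  Tensor::MappingGuard guard(input_tensor);           // map the (possibly GPU) buffer for host access
  float *data = input_tensor->mutable_data<float>();  // host-visible pointer, valid within this scope
  for (index_t i = 0; i < input_tensor->size(); ++i)  // placeholder: real code reads the input file here
    data[i] = 0.0f;
}                                                     // guard unmaps on destruction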