提交 c9eb1a89 编写于 作者: L liuqi

Add env variable for opening opencl kernel tuning.

上级 faa8459b
......@@ -11,7 +11,7 @@ namespace kernels {
// [(c+3)/4*W, N * H]
void CalInOutputImageShape(const std::vector<index_t> &shape, /* NHWC */
std::vector<size_t> &image_shape) {
std::vector<size_t> &image_shape) {
MACE_CHECK(shape.size() == 4);
image_shape.resize(2);
image_shape[0] = RoundUpDiv4(shape[3]) * shape[2];
......@@ -40,41 +40,30 @@ void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
const BufferType type,
std::vector<size_t> &image_shape) {
switch (type) {
case FILTER:
CalFilterImageShape(shape, image_shape);
case FILTER:CalFilterImageShape(shape, image_shape);
break;
case IN_OUT:
CalInOutputImageShape(shape, image_shape);
case IN_OUT:CalInOutputImageShape(shape, image_shape);
break;
case ARGUMENT:
CalArgImageShape(shape, image_shape);
case ARGUMENT:CalArgImageShape(shape, image_shape);
break;
default:
LOG(FATAL) << "Mace not supported yet.";
default:LOG(FATAL) << "Mace not supported yet.";
}
}
std::string DtToCLDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:
return "float";
case DT_HALF:
return "half";
default:
LOG(FATAL) << "Unsupported data type";
case DT_FLOAT:return "float";
case DT_HALF:return "half";
default:LOG(FATAL) << "Unsupported data type";
return "";
}
}
std::string DtToCLCMDDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:
return "f";
case DT_HALF:
return "h";
default:
LOG(FATAL) << "Not supported data type for opencl cmd data type";
case DT_FLOAT:return "f";
case DT_HALF:return "h";
default:LOG(FATAL) << "Not supported data type for opencl cmd data type";
return "";
}
}
......@@ -82,10 +71,8 @@ std::string DtToCLCMDDt(const DataType dt) {
std::string DtToUpstreamCLDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:
case DT_HALF:
return "float";
default:
LOG(FATAL) << "Unsupported data type";
case DT_HALF:return "float";
default:LOG(FATAL) << "Unsupported data type";
return "";
}
}
......@@ -93,15 +80,12 @@ std::string DtToUpstreamCLDt(const DataType dt) {
std::string DtToUpstreamCLCMDDt(const DataType dt) {
switch (dt) {
case DT_FLOAT:
case DT_HALF:
return "f";
default:
LOG(FATAL) << "Not supported data type for opencl cmd data type";
case DT_HALF:return "f";
default:LOG(FATAL) << "Not supported data type for opencl cmd data type";
return "";
}
}
void TuningOrRun3DKernel(cl::Kernel &kernel,
const std::string tuning_key,
const uint32_t *gws,
......@@ -137,10 +121,13 @@ void TuningOrRun3DKernel(cl::Kernel &kernel,
};
};
cl::Event event;
auto func = [&](std::vector<uint32_t> &params, Timer *timer) -> cl_int {
auto func = [&](const std::vector<uint32_t> &params,
Timer *timer,
std::vector<uint32_t> *tuning_result) -> cl_int {
MACE_CHECK(params.size() == 4) << "Tuning parameters of 3D kernel must be 4D";
cl_int error = CL_SUCCESS;
if (timer == nullptr) {
uint32_t num_blocks = params.back();
uint32_t num_blocks = params[3];
const uint32_t block_size = gws[2] / num_blocks;
if (gws[2] % num_blocks > 0) num_blocks++;
for (uint32_t i = 0; i < num_blocks; ++i) {
......@@ -153,27 +140,31 @@ void TuningOrRun3DKernel(cl::Kernel &kernel,
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
}
} else {
timer->StartTiming();
timer->ClearTiming();
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
timer->StopTiming();
double elapse_time = timer->ElapsedMicros();
timer->ClearTiming();
uint32_t num_blocks = std::min(static_cast<uint32_t>(elapse_time / kMaxKernelExeTime) + 1, gws[2]);
params.back() = num_blocks;
const uint32_t block_size = gws[2] / num_blocks;
if (gws[2] % num_blocks > 0) num_blocks++;
for (uint32_t i = 0; i < num_blocks; ++i) {
uint32_t gws2 = (i == num_blocks - 1) ? (gws[2] - (i * block_size)) : block_size;
error = runtime->command_queue().enqueueNDRangeKernel(
kernel,
cl::NDRange(0, 0, i * block_size),
cl::NDRange(gws[0], gws[1], gws2),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
timer->AccumulateTiming();
timer->AccumulateTiming();
tuning_result->assign(params.begin(), params.end());
if (LimitKernelTime()) {
double elapse_time = timer->AccumulatedMicros();
timer->ClearTiming();
uint32_t num_blocks = std::min(static_cast<uint32_t>(elapse_time / kMaxKernelExeTime) + 1, gws[2]);
(*tuning_result)[3] = num_blocks;
const uint32_t block_size = gws[2] / num_blocks;
if (gws[2] % num_blocks > 0) num_blocks++;
for (uint32_t i = 0; i < num_blocks; ++i) {
uint32_t gws2 = (i == num_blocks - 1) ? (gws[2] - (i * block_size)) : block_size;
error = runtime->command_queue().enqueueNDRangeKernel(
kernel,
cl::NDRange(0, 0, i * block_size),
cl::NDRange(gws[0], gws[1], gws2),
cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
timer->AccumulateTiming();
}
}
}
return error;
......@@ -217,10 +208,13 @@ void TuningOrRun2DKernel(cl::Kernel &kernel,
};
};
cl::Event event;
auto func = [&](std::vector<uint32_t> &params, Timer *timer) -> cl_int {
auto func = [&](const std::vector<uint32_t> &params,
Timer *timer,
std::vector<uint32_t> *tuning_result) -> cl_int {
MACE_CHECK(params.size() == 3) << "Tuning parameters of 2D kernel must be 3d";
cl_int error = CL_SUCCESS;
if (timer == nullptr) {
uint32_t num_blocks = params.back();
uint32_t num_blocks = params[2];
const uint32_t block_size = gws[1] / num_blocks;
if (gws[1] % num_blocks > 0) num_blocks++;
for (uint32_t i = 0; i < num_blocks; ++i) {
......@@ -234,28 +228,32 @@ void TuningOrRun2DKernel(cl::Kernel &kernel,
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
}
} else {
timer->StartTiming();
timer->ClearTiming();
error = runtime->command_queue().enqueueNDRangeKernel(
kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1]),
cl::NDRange(params[0], params[1]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
timer->StopTiming();
double elapse_time = timer->ElapsedMicros();
timer->ClearTiming();
uint32_t num_blocks = std::min(static_cast<uint32_t>(elapse_time / kMaxKernelExeTime) + 1, gws[1]);
params.back() = num_blocks;
const uint32_t block_size = gws[1] / num_blocks;
if (gws[1] % num_blocks > 0) num_blocks++;
for (uint32_t i = 0; i < num_blocks; ++i) {
uint32_t gws1 = (i == num_blocks - 1) ? (gws[1] - (i * block_size)) : block_size;
error = runtime->command_queue().enqueueNDRangeKernel(
kernel,
cl::NDRange(0, i * block_size),
cl::NDRange(gws[0], gws1),
cl::NDRange(params[0], params[1]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
timer->AccumulateTiming();
timer->AccumulateTiming();
tuning_result->assign(params.begin(), params.end());
if (LimitKernelTime()) {
double elapse_time = timer->AccumulatedMicros();
timer->ClearTiming();
uint32_t num_blocks = std::min(static_cast<uint32_t>(elapse_time / kMaxKernelExeTime) + 1, gws[1]);
(*tuning_result)[2] = num_blocks;
const uint32_t block_size = gws[1] / num_blocks;
if (gws[1] % num_blocks > 0) num_blocks++;
for (uint32_t i = 0; i < num_blocks; ++i) {
uint32_t gws1 = (i == num_blocks - 1) ? (gws[1] - (i * block_size)) : block_size;
error = runtime->command_queue().enqueueNDRangeKernel(
kernel,
cl::NDRange(0, i * block_size),
cl::NDRange(gws[0], gws1),
cl::NDRange(params[0], params[1]), nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
timer->AccumulateTiming();
}
}
}
return error;
......
......@@ -58,6 +58,11 @@ inline void SetFuture(StatsFuture *future, const cl::Event &event) {
}
}
inline bool LimitKernelTime() {
const char *flag = getenv("MACE_LIMIT_OPENCL_KERNEL_TIME");
return flag != nullptr && strlen(flag) == 1 && flag[0] == '1';
}
namespace {
template<typename T>
void AppendToStream(std::stringstream *ss, const std::string &delimiter, T v) {
......
......@@ -44,7 +44,7 @@ class Tuner {
std::vector<param_type> &default_param,
const std::function<std::vector<std::vector<param_type>>()>
&param_generator,
const std::function<RetType(std::vector<param_type> &, Timer *)> &func,
const std::function<RetType(const std::vector<param_type> &, Timer *, std::vector<param_type> *)> &func,
Timer *timer) {
std::string obfucated_param_key = MACE_OBFUSCATE_SYMBOL(param_key);
if (IsTuning() && param_generator != nullptr) {
......@@ -60,12 +60,12 @@ class Tuner {
if (param_table_.find(obfucated_param_key) != param_table_.end()) {
VLOG(1) << param_key << ": "
<< internal::MakeString(param_table_[obfucated_param_key]);
return func(param_table_[obfucated_param_key], nullptr);
return func(param_table_[obfucated_param_key], nullptr, nullptr);
} else {
#ifndef MACE_DISABLE_NO_TUNING_WARNING
LOG(WARNING) << "Fallback to default parameter: " << param_key;
#endif
return func(default_param, nullptr);
return func(default_param, nullptr, nullptr);
}
}
}
......@@ -119,15 +119,16 @@ class Tuner {
template <typename RetType>
inline RetType Run(
const std::function<RetType(std::vector<param_type> &, Timer *)> &func,
const std::function<RetType(const std::vector<param_type> &, Timer *, std::vector<param_type> *)> &func,
std::vector<param_type> &params,
Timer *timer,
int num_runs,
double *time_us) {
double *time_us,
std::vector<param_type> *tuning_result) {
RetType res;
int64_t total_time_us = 0;
for (int i = 0; i < num_runs; ++i) {
res = func(params, timer);
res = func(params, timer, tuning_result);
total_time_us += timer->AccumulatedMicros();
}
......@@ -139,24 +140,25 @@ class Tuner {
inline RetType Tune(
const std::function<std::vector<std::vector<param_type>>()>
&param_generator,
const std::function<RetType(std::vector<param_type> &, Timer *)> &func,
const std::function<RetType(const std::vector<param_type> &, Timer *, std::vector<param_type> *)> &func,
Timer *timer,
std::vector<param_type> *opt_params) {
RetType res;
double opt_time = std::numeric_limits<double>::max();
auto params = param_generator();
std::vector<param_type> tuning_result;
for (auto param : params) {
double tmp_time = 0.0;
// warm up
Run<RetType>(func, param, timer, 2, &tmp_time);
Run<RetType>(func, param, timer, 2, &tmp_time, &tuning_result);
// run
RetType tmp_res = Run<RetType>(func, param, timer, 10, &tmp_time);
RetType tmp_res = Run<RetType>(func, param, timer, 10, &tmp_time, &tuning_result);
// Check the execution time
if (tmp_time < opt_time) {
opt_time = tmp_time;
*opt_params = param;
*opt_params = tuning_result;
res = tmp_res;
}
}
......
......@@ -65,7 +65,6 @@ build_target()
--copt="-D_GLIBCXX_USE_C99_MATH_TR1" \
--copt="-Werror=return-type" \
--copt="-DMACE_OBFUSCATE_LITERALS" \
$TUNING_MODE_BUILD_FLAGS \
$DSP_MODE_BUILD_FLAGS || exit -1
}
......
import numpy as np
import math
import tensorflow as tf
A_T = np.array([[1, 1, 1, 0], [0, 1, -1, -1]]).astype(np.float32)
A = np.transpose(A_T)
B_T = np.array([
[1, 0, -1, 0],
[0, 1, 1, 0],
[0, -1, 1, 0],
[0, 1, 0, -1]
]).astype(np.float32)
B = np.transpose(B_T)
G = np.array([
[1, 0, 0],
[0.5, 0.5, 0.5],
[0.5, -0.5, 0.5],
[0, 0, 1],
]).astype(np.float32)
G_T = np.transpose(G)
def output_shape(input_shape, filter_shape):
out_shape = np.zeros(4).astype(np.int32)
out_shape[0] = input_shape[0]
out_shape[1] = filter_shape[0]
out_shape[2] = input_shape[2] - 2
out_shape[3] = input_shape[3] - 2
return out_shape
def winog_conv(input, filter):
m = 2
r = 3
alpha = m + r - 1
input_shape = input.shape
filter_shape = filter.shape
out_shape = output_shape(input_shape, filter_shape)
K = filter_shape[0]
C = input_shape[1]
U = np.zeros((K * 16, C))
for k in range(K):
for c in range(C):
u = np.dot(np.dot(G, filter[k, c, :, :]), G_T)
for i in range(4):
for j in range(4) :
U[(i * 4 + j) * K + k, c] = u[i, j]
print 'filter out: ', U.shape
print U[0, 0]
U.astype(np.float32).tofile("filter_out")
rounded_h = int(math.ceil(out_shape[2] / 2.0))
rounded_w = int(math.ceil(out_shape[3] / 2.0))
P = input_shape[0] * rounded_h * rounded_w
V = np.zeros((C * 16, P))
for p in range(P):
for c in range(C):
n = p / (rounded_w * rounded_h)
t = p % (rounded_h * rounded_w)
h_idx = t / rounded_w
w_idx = t % rounded_w
h_start = h_idx * 2
w_start = w_idx * 2
h_end = min(h_start+4, input_shape[2])
w_end = min(w_start+4, input_shape[3])
d = np.zeros((4, 4))
d[0:h_end-h_start, 0:w_end-w_start] = input[n, c, h_start:h_end, w_start:w_end]
v = np.dot(np.dot(B_T, d), B)
for i in range(4):
for j in range(4):
V[(i*4+j)*C + c, p] = v[i, j]
tmp = V.reshape(16, C, P, 1)
print 'input out: ', tmp.shape
tmp.astype(np.float32).tofile("C")
M = np.zeros((16 * K, P))
for i in range(alpha * alpha):
u = U[i * K : (i+1) * K, :]
v = V[i * C : (i+1) * C, :]
M[i * K : (i+1) * K, :] = np.dot(u, v)
print 'M shape: ', M.shape
M.astype(np.float32).tofile("gemm")
res = np.zeros((out_shape[0], out_shape[2], out_shape[3], out_shape[1]))
for k in range(K):
for b in range(P):
m = np.zeros((4, 4))
for i in range(4):
for j in range(4):
m[i][j] = M[(i*4+j) * K + k, b]
y = np.dot(np.dot(A_T, m), A)
for i in range(2):
for j in range(2):
n = b / (rounded_h * rounded_w)
t = b % (rounded_h * rounded_w)
p = (t / rounded_w) * 2 + i
q = (t % rounded_w) * 2 + j
if p >= out_shape[2] or q >= out_shape[3]:
continue
res[n, p, q, k] = y[i, j]
print 'Res shape: ', res.shape
res.astype(np.float32).tofile("res")
return res
def tf_conv(input, filter):
conv_op = tf.nn.conv2d(input, filter, [1, 1, 1, 1], 'VALID')
with tf.Session() as sess:
res = sess.run(conv_op)
return res
def main():
input = np.random.random([7, 61, 71, 31]).astype(np.float32)
# input = np.fromfile(file="A", dtype=np.float32)
# input = input.reshape(1, 3, 3, 5)
print 'input shape: ', input.shape
input.tofile("A")
filter = np.random.random([3, 3, 31, 31]).astype(np.float32)
tf_out = tf_conv(input, filter)
input = input.transpose((0, 3, 1, 2))
filter = filter.transpose((3, 2, 0, 1))
print 'filter shape: ', filter.shape
filter.tofile("filter_in")
winog_out = winog_conv(input, filter)
res = np.allclose(tf_out, winog_out)
if res:
print "=========Pass========="
else:
print "=========Failed========="
print "TF: ", tf_out
print "Winograd: ", winog_out
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册