diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc
index 783a30243407653cc660375b542f2c8f896ac52e..e220d34463212887bbaaf927288a15ad9549ba32 100644
--- a/mace/kernels/opencl/helper.cc
+++ b/mace/kernels/opencl/helper.cc
@@ -11,7 +11,7 @@ namespace kernels {
 
 // [(c+3)/4*W, N * H]
 void CalInOutputImageShape(const std::vector<index_t> &shape, /* NHWC */
-                           std::vector<size_t> &image_shape) {
+                           std::vector<size_t> &image_shape) {
   MACE_CHECK(shape.size() == 4);
   image_shape.resize(2);
   image_shape[0] = RoundUpDiv4(shape[3]) * shape[2];
@@ -40,41 +40,30 @@ void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
                      const BufferType type,
                      std::vector<size_t> &image_shape) {
   switch (type) {
-    case FILTER:
-      CalFilterImageShape(shape, image_shape);
+    case FILTER:CalFilterImageShape(shape, image_shape);
       break;
-    case IN_OUT:
-      CalInOutputImageShape(shape, image_shape);
+    case IN_OUT:CalInOutputImageShape(shape, image_shape);
       break;
-    case ARGUMENT:
-      CalArgImageShape(shape, image_shape);
+    case ARGUMENT:CalArgImageShape(shape, image_shape);
       break;
-    default:
-      LOG(FATAL) << "Mace not supported yet.";
+    default:LOG(FATAL) << "Mace not supported yet.";
   }
 }
-
 std::string DtToCLDt(const DataType dt) {
   switch (dt) {
-    case DT_FLOAT:
-      return "float";
-    case DT_HALF:
-      return "half";
-    default:
-      LOG(FATAL) << "Unsupported data type";
+    case DT_FLOAT:return "float";
+    case DT_HALF:return "half";
+    default:LOG(FATAL) << "Unsupported data type";
       return "";
   }
 }
 
 std::string DtToCLCMDDt(const DataType dt) {
   switch (dt) {
-    case DT_FLOAT:
-      return "f";
-    case DT_HALF:
-      return "h";
-    default:
-      LOG(FATAL) << "Not supported data type for opencl cmd data type";
+    case DT_FLOAT:return "f";
+    case DT_HALF:return "h";
+    default:LOG(FATAL) << "Not supported data type for opencl cmd data type";
       return "";
   }
 }
@@ -82,10 +71,8 @@ std::string DtToCLCMDDt(const DataType dt) {
 std::string DtToUpstreamCLDt(const DataType dt) {
   switch (dt) {
     case DT_FLOAT:
-    case DT_HALF:
-      return "float";
-    default:
-      LOG(FATAL) << "Unsupported data type";
+    case DT_HALF:return "float";
+    default:LOG(FATAL) << "Unsupported data type";
       return "";
   }
 }
@@ -93,15 +80,12 @@ std::string DtToUpstreamCLDt(const DataType dt) {
 std::string DtToUpstreamCLCMDDt(const DataType dt) {
   switch (dt) {
     case DT_FLOAT:
-    case DT_HALF:
-      return "f";
-    default:
-      LOG(FATAL) << "Not supported data type for opencl cmd data type";
+    case DT_HALF:return "f";
+    default:LOG(FATAL) << "Not supported data type for opencl cmd data type";
       return "";
   }
 }
-
 void TuningOrRun3DKernel(cl::Kernel &kernel,
                          const std::string tuning_key,
                          const uint32_t *gws,
@@ -137,10 +121,13 @@ void TuningOrRun3DKernel(cl::Kernel &kernel,
     };
   };
   cl::Event event;
-  auto func = [&](std::vector<uint32_t> &params, Timer *timer) -> cl_int {
+  auto func = [&](const std::vector<uint32_t> &params,
+                  Timer *timer,
+                  std::vector<uint32_t> *tuning_result) -> cl_int {
+    MACE_CHECK(params.size() == 4) << "Tuning parameters of 3D kernel must be 4D";
     cl_int error = CL_SUCCESS;
     if (timer == nullptr) {
-      uint32_t num_blocks = params.back();
+      uint32_t num_blocks = params[3];
       const uint32_t block_size = gws[2] / num_blocks;
       if (gws[2] % num_blocks > 0) num_blocks++;
       for (uint32_t i = 0; i < num_blocks; ++i) {
@@ -153,27 +140,31 @@ void TuningOrRun3DKernel(cl::Kernel &kernel,
         MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
       }
     } else {
-      timer->StartTiming();
+      timer->ClearTiming();
       error = runtime->command_queue().enqueueNDRangeKernel(
           kernel, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
          cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
       MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
-      timer->StopTiming();
-      double elapse_time = timer->ElapsedMicros();
-      timer->ClearTiming();
-      uint32_t num_blocks = std::min(static_cast<uint32_t>(elapse_time / kMaxKernelExeTime) + 1, gws[2]);
-      params.back() = num_blocks;
-      const uint32_t block_size = gws[2] / num_blocks;
-      if (gws[2] % num_blocks > 0) num_blocks++;
-      for (uint32_t i = 0; i < num_blocks; ++i) {
-        uint32_t gws2 = (i == num_blocks - 1) ? (gws[2] - (i * block_size)) : block_size;
-        error = runtime->command_queue().enqueueNDRangeKernel(
-            kernel,
-            cl::NDRange(0, 0, i * block_size),
-            cl::NDRange(gws[0], gws[1], gws2),
-            cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
-        MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
-        timer->AccumulateTiming();
+      timer->AccumulateTiming();
+      tuning_result->assign(params.begin(), params.end());
+
+      if (LimitKernelTime()) {
+        double elapse_time = timer->AccumulatedMicros();
+        timer->ClearTiming();
+        uint32_t num_blocks = std::min(static_cast<uint32_t>(elapse_time / kMaxKernelExeTime) + 1, gws[2]);
+        (*tuning_result)[3] = num_blocks;
+        const uint32_t block_size = gws[2] / num_blocks;
+        if (gws[2] % num_blocks > 0) num_blocks++;
+        for (uint32_t i = 0; i < num_blocks; ++i) {
+          uint32_t gws2 = (i == num_blocks - 1) ? (gws[2] - (i * block_size)) : block_size;
+          error = runtime->command_queue().enqueueNDRangeKernel(
+              kernel,
+              cl::NDRange(0, 0, i * block_size),
+              cl::NDRange(gws[0], gws[1], gws2),
+              cl::NDRange(params[0], params[1], params[2]), nullptr, &event);
+          MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
+          timer->AccumulateTiming();
+        }
       }
     }
     return error;
@@ -217,10 +208,13 @@ void TuningOrRun2DKernel(cl::Kernel &kernel,
     };
   };
   cl::Event event;
-  auto func = [&](std::vector<uint32_t> &params, Timer *timer) -> cl_int {
+  auto func = [&](const std::vector<uint32_t> &params,
+                  Timer *timer,
+                  std::vector<uint32_t> *tuning_result) -> cl_int {
+    MACE_CHECK(params.size() == 3) << "Tuning parameters of 2D kernel must be 3D";
     cl_int error = CL_SUCCESS;
     if (timer == nullptr) {
-      uint32_t num_blocks = params.back();
+      uint32_t num_blocks = params[2];
       const uint32_t block_size = gws[1] / num_blocks;
       if (gws[1] % num_blocks > 0) num_blocks++;
       for (uint32_t i = 0; i < num_blocks; ++i) {
@@ -234,28 +228,32 @@ void TuningOrRun2DKernel(cl::Kernel &kernel,
         MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
       }
     } else {
-      timer->StartTiming();
+      timer->ClearTiming();
       error = runtime->command_queue().enqueueNDRangeKernel(
           kernel, cl::NullRange,
          cl::NDRange(gws[0], gws[1]),
          cl::NDRange(params[0], params[1]), nullptr, &event);
       MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
-      timer->StopTiming();
-      double elapse_time = timer->ElapsedMicros();
-      timer->ClearTiming();
-      uint32_t num_blocks = std::min(static_cast<uint32_t>(elapse_time / kMaxKernelExeTime) + 1, gws[1]);
-      params.back() = num_blocks;
-      const uint32_t block_size = gws[1] / num_blocks;
-      if (gws[1] % num_blocks > 0) num_blocks++;
-      for (uint32_t i = 0; i < num_blocks; ++i) {
-        uint32_t gws1 = (i == num_blocks - 1) ? (gws[1] - (i * block_size)) : block_size;
-        error = runtime->command_queue().enqueueNDRangeKernel(
-            kernel,
-            cl::NDRange(0, i * block_size),
-            cl::NDRange(gws[0], gws1),
-            cl::NDRange(params[0], params[1]), nullptr, &event);
-        MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
-        timer->AccumulateTiming();
+      timer->AccumulateTiming();
+      tuning_result->assign(params.begin(), params.end());
+
+      if (LimitKernelTime()) {
+        double elapse_time = timer->AccumulatedMicros();
+        timer->ClearTiming();
+        uint32_t num_blocks = std::min(static_cast<uint32_t>(elapse_time / kMaxKernelExeTime) + 1, gws[1]);
+        (*tuning_result)[2] = num_blocks;
+        const uint32_t block_size = gws[1] / num_blocks;
+        if (gws[1] % num_blocks > 0) num_blocks++;
+        for (uint32_t i = 0; i < num_blocks; ++i) {
+          uint32_t gws1 = (i == num_blocks - 1) ? (gws[1] - (i * block_size)) : block_size;
+          error = runtime->command_queue().enqueueNDRangeKernel(
+              kernel,
+              cl::NDRange(0, i * block_size),
+              cl::NDRange(gws[0], gws1),
+              cl::NDRange(params[0], params[1]), nullptr, &event);
+          MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
+          timer->AccumulateTiming();
+        }
       }
     }
     return error;
diff --git a/mace/kernels/opencl/helper.h b/mace/kernels/opencl/helper.h
index cfbef59f7038285d462d635d172b892bd6de56a1..466064b6d8b6ab98a09ec001fb46cace22447b78 100644
--- a/mace/kernels/opencl/helper.h
+++ b/mace/kernels/opencl/helper.h
@@ -58,6 +58,11 @@ inline void SetFuture(StatsFuture *future, const cl::Event &event) {
   }
 }
 
+inline bool LimitKernelTime() {
+  const char *flag = getenv("MACE_LIMIT_OPENCL_KERNEL_TIME");
+  return flag != nullptr && strlen(flag) == 1 && flag[0] == '1';
+}
+
 namespace {
 template <typename T>
 void AppendToStream(std::stringstream *ss, const std::string &delimiter, T v) {
diff --git a/mace/utils/tuner.h b/mace/utils/tuner.h
index e2797fa9bb8086563c44c588cdc609fafe013100..369152819afb67c554c8c057777fc91d9b3e1349 100644
--- a/mace/utils/tuner.h
+++ b/mace/utils/tuner.h
@@ -44,7 +44,7 @@ class Tuner {
       std::vector<param_type> &default_param,
       const std::function<std::vector<std::vector<param_type>>()>
          &param_generator,
-      const std::function<RetType(const std::vector<param_type> &, Timer *)> &func,
+      const std::function<RetType(const std::vector<param_type> &, Timer *, std::vector<param_type> *)> &func,
       Timer *timer) {
     std::string obfucated_param_key = MACE_OBFUSCATE_SYMBOL(param_key);
     if (IsTuning() && param_generator != nullptr) {
@@ -60,12 +60,12 @@ class Tuner {
       if (param_table_.find(obfucated_param_key) != param_table_.end()) {
         VLOG(1) << param_key << ": "
                << internal::MakeString(param_table_[obfucated_param_key]);
-        return func(param_table_[obfucated_param_key], nullptr);
+        return func(param_table_[obfucated_param_key], nullptr, nullptr);
       } else {
 #ifndef MACE_DISABLE_NO_TUNING_WARNING
         LOG(WARNING) << "Fallback to default parameter: " << param_key;
 #endif
-        return func(default_param, nullptr);
+        return func(default_param, nullptr, nullptr);
       }
     }
   }
@@ -119,15 +119,16 @@ class Tuner {
 
   template <typename RetType>
   inline RetType Run(
-      const std::function<RetType(const std::vector<param_type> &, Timer *)> &func,
+      const std::function<RetType(const std::vector<param_type> &, Timer *, std::vector<param_type> *)> &func,
       std::vector<param_type> &params,
       Timer *timer,
       int num_runs,
-      double *time_us) {
+      double *time_us,
+      std::vector<param_type> *tuning_result) {
     RetType res;
     int64_t total_time_us = 0;
     for (int i = 0; i < num_runs; ++i) {
-      res = func(params, timer);
+      res = func(params, timer, tuning_result);
       total_time_us += timer->AccumulatedMicros();
     }
 
@@ -139,24 +140,25 @@ class Tuner {
   inline RetType Tune(
       const std::function<std::vector<std::vector<param_type>>()>
          &param_generator,
-      const std::function<RetType(const std::vector<param_type> &, Timer *)> &func,
+      const std::function<RetType(const std::vector<param_type> &, Timer *, std::vector<param_type> *)> &func,
       Timer *timer,
       std::vector<param_type> *opt_params) {
     RetType res;
     double opt_time = std::numeric_limits<double>::max();
     auto params = param_generator();
+    std::vector<param_type> tuning_result;
     for (auto param : params) {
       double tmp_time = 0.0;
       // warm up
-      Run(func, param, timer, 2, &tmp_time);
+      Run(func, param, timer, 2, &tmp_time, &tuning_result);
       // run
-      RetType tmp_res = Run(func, param, timer, 10, &tmp_time);
+      RetType tmp_res = Run(func, param, timer, 10, &tmp_time, &tuning_result);
       // Check the execution time
       if (tmp_time < opt_time) {
         opt_time = tmp_time;
-        *opt_params = param;
+        *opt_params = tuning_result;
         res = tmp_res;
       }
     }
diff --git a/tools/export_lib.sh b/tools/export_lib.sh
index 446330a50e389c4808106660dc6bb044988de407..55b2b3a392670bd6d36a0556fa061386e5ba5a1b 100755
--- a/tools/export_lib.sh
+++ b/tools/export_lib.sh
@@ -65,7 +65,6 @@ build_target()
     --copt="-D_GLIBCXX_USE_C99_MATH_TR1" \
     --copt="-Werror=return-type" \
     --copt="-DMACE_OBFUSCATE_LITERALS" \
-    $TUNING_MODE_BUILD_FLAGS \
     $DSP_MODE_BUILD_FLAGS || exit -1
 }
 
diff --git a/tools/wino_conv.py b/tools/wino_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8cdf3d8e88586b10dd3256de3670978c2a2e5f2
--- /dev/null
+++ b/tools/wino_conv.py
@@ -0,0 +1,141 @@
+import numpy as np
+import math
+import tensorflow as tf
+
+A_T = np.array([[1, 1, 1, 0], [0, 1, -1, -1]]).astype(np.float32)
+A = np.transpose(A_T)
+B_T = np.array([
+    [1, 0, -1, 0],
+    [0, 1, 1, 0],
+    [0, -1, 1, 0],
+    [0, 1, 0, -1]
+]).astype(np.float32)
+B = np.transpose(B_T)
+G = np.array([
+    [1, 0, 0],
+    [0.5, 0.5, 0.5],
+    [0.5, -0.5, 0.5],
+    [0, 0, 1],
+]).astype(np.float32)
+G_T = np.transpose(G)
+
+
+def output_shape(input_shape, filter_shape):
+    out_shape = np.zeros(4).astype(np.int32)
+    out_shape[0] = input_shape[0]
+    out_shape[1] = filter_shape[0]
+    out_shape[2] = input_shape[2] - 2
+    out_shape[3] = input_shape[3] - 2
+    return out_shape
+
+
+def winog_conv(input, filter):
+    m = 2
+    r = 3
+    alpha = m + r - 1
+    input_shape = input.shape
+    filter_shape = filter.shape
+    out_shape = output_shape(input_shape, filter_shape)
+
+    K = filter_shape[0]
+    C = input_shape[1]
+    U = np.zeros((K * 16, C))
+
+    for k in range(K):
+        for c in range(C):
+            u = np.dot(np.dot(G, filter[k, c, :, :]), G_T)
+            for i in range(4):
+                for j in range(4):
+                    U[(i * 4 + j) * K + k, c] = u[i, j]
+
+    print 'filter out: ', U.shape
+    print U[0, 0]
+    U.astype(np.float32).tofile("filter_out")
+
+    rounded_h = int(math.ceil(out_shape[2] / 2.0))
+    rounded_w = int(math.ceil(out_shape[3] / 2.0))
+    P = input_shape[0] * rounded_h * rounded_w
+    V = np.zeros((C * 16, P))
+    for p in range(P):
+        for c in range(C):
+            n = p / (rounded_w * rounded_h)
+            t = p % (rounded_h * rounded_w)
+            h_idx = t / rounded_w
+            w_idx = t % rounded_w
+            h_start = h_idx * 2
+            w_start = w_idx * 2
+            h_end = min(h_start+4, input_shape[2])
+            w_end = min(w_start+4, input_shape[3])
+            d = np.zeros((4, 4))
+            d[0:h_end-h_start, 0:w_end-w_start] = input[n, c, h_start:h_end, w_start:w_end]
+            v = np.dot(np.dot(B_T, d), B)
+            for i in range(4):
+                for j in range(4):
+                    V[(i*4+j)*C + c, p] = v[i, j]
+
+    tmp = V.reshape(16, C, P, 1)
+    print 'input out: ', tmp.shape
+    tmp.astype(np.float32).tofile("C")
+    M = np.zeros((16 * K, P))
+    for i in range(alpha * alpha):
+        u = U[i * K : (i+1) * K, :]
+        v = V[i * C : (i+1) * C, :]
+        M[i * K : (i+1) * K, :] = np.dot(u, v)
+
+    print 'M shape: ', M.shape
+    M.astype(np.float32).tofile("gemm")
+    res = np.zeros((out_shape[0], out_shape[2], out_shape[3], out_shape[1]))
+    for k in range(K):
+        for b in range(P):
+            m = np.zeros((4, 4))
+            for i in range(4):
+                for j in range(4):
+                    m[i][j] = M[(i*4+j) * K + k, b]
+            y = np.dot(np.dot(A_T, m), A)
+            for i in range(2):
+                for j in range(2):
+                    n = b / (rounded_h * rounded_w)
+                    t = b % (rounded_h * rounded_w)
+                    p = (t / rounded_w) * 2 + i
+                    q = (t % rounded_w) * 2 + j
+                    if p >= out_shape[2] or q >= out_shape[3]:
+                        continue
+                    res[n, p, q, k] = y[i, j]
+
+    print 'Res shape: ', res.shape
+    res.astype(np.float32).tofile("res")
+
+    return res
+
+def tf_conv(input, filter):
+    conv_op = tf.nn.conv2d(input, filter, [1, 1, 1, 1], 'VALID')
+    with tf.Session() as sess:
+        res = sess.run(conv_op)
+    return res
+
+
+def main():
+    input = np.random.random([7, 61, 71, 31]).astype(np.float32)
+    # input = np.fromfile(file="A", dtype=np.float32)
+    # input = input.reshape(1, 3, 3, 5)
+    print 'input shape: ', input.shape
+    input.tofile("A")
+    filter = np.random.random([3, 3, 31, 31]).astype(np.float32)
+    tf_out = tf_conv(input, filter)
+    input = input.transpose((0, 3, 1, 2))
+    filter = filter.transpose((3, 2, 0, 1))
+    print 'filter shape: ', filter.shape
+    filter.tofile("filter_in")
+    winog_out = winog_conv(input, filter)
+    res = np.allclose(tf_out, winog_out)
+    if res:
+        print "=========Pass========="
+    else:
+        print "=========Failed========="
+        print "TF: ", tf_out
+        print "Winograd: ", winog_out
+
+
+if __name__ == '__main__':
+    main()
+
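A quick way to sanity-check the F(2x2, 3x3) transform matrices used by tools/wino_conv.py is to verify, on a single 4x4 input tile, that the Winograd result A_T [(G g G_T) * (B_T d B)] A matches a direct valid 3x3 correlation. The standalone sketch below is not part of the patch; it assumes only NumPy, and the helper names (winograd_tile, direct_tile) are illustrative.

import numpy as np

# Same F(2x2, 3x3) transform matrices as tools/wino_conv.py.
A_T = np.array([[1, 1, 1, 0], [0, 1, -1, -1]], dtype=np.float32)
B_T = np.array([[1, 0, -1, 0], [0, 1, 1, 0],
                [0, -1, 1, 0], [0, 1, 0, -1]], dtype=np.float32)
G = np.array([[1, 0, 0], [0.5, 0.5, 0.5],
              [0.5, -0.5, 0.5], [0, 0, 1]], dtype=np.float32)


def winograd_tile(d, g):
    # 2x2 output tile from a 4x4 input tile d and a 3x3 filter g:
    # transform filter and input, multiply element-wise, inverse-transform.
    u = np.dot(np.dot(G, g), np.transpose(G))
    v = np.dot(np.dot(B_T, d), np.transpose(B_T))
    return np.dot(np.dot(A_T, u * v), np.transpose(A_T))


def direct_tile(d, g):
    # Direct valid 3x3 correlation over the same 4x4 tile (2x2 output).
    out = np.zeros((2, 2), dtype=np.float32)
    for i in range(2):
        for j in range(2):
            out[i, j] = np.sum(d[i:i + 3, j:j + 3] * g)
    return out


d = np.random.rand(4, 4).astype(np.float32)
g = np.random.rand(3, 3).astype(np.float32)
assert np.allclose(winograd_tile(d, g), direct_tile(d, g), atol=1e-4)

The same identity, applied per channel and summed over C, is what winog_conv computes in batched form: U holds the transformed filters, V the transformed input tiles, and the 16 per-position GEMMs produce M before the inverse transform.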