提交 15a0c2b2 编写于 作者: L liuqi

Winograd script support multiple type.

上级 917f19e9
...@@ -32,12 +32,12 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -32,12 +32,12 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(
const index_t round_w = (output_shape[2] + 1) / 2; const index_t round_w = (output_shape[2] + 1) / 2;
const index_t out_width = input_tensor->dim(0) * round_h * round_w; const index_t out_width = input_tensor->dim(0) * round_h * round_w;
if (kernel_.get() == nullptr) {
output_shape = {16, input_tensor->dim(3), out_width, 1}; output_shape = {16, input_tensor->dim(3), out_width, 1};
std::vector<size_t> image_shape; std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, image_shape); CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, image_shape);
output_tensor->ResizeImage(output_shape, image_shape); output_tensor->ResizeImage(output_shape, image_shape);
if (kernel_.get() == nullptr) {
std::string obfuscated_kernel_name = std::string obfuscated_kernel_name =
MACE_OBFUSCATE_SYMBOL("winograd_transform_2x2"); MACE_OBFUSCATE_SYMBOL("winograd_transform_2x2");
std::set<std::string> built_options; std::set<std::string> built_options;
......
...@@ -2,22 +2,89 @@ import numpy as np ...@@ -2,22 +2,89 @@ import numpy as np
import math import math
import tensorflow as tf import tensorflow as tf
A_T = np.array([[1, 1, 1, 0], [0, 1, -1, -1]]).astype(np.float32) A_T = {}
A = np.transpose(A_T) A = {}
B_T = np.array([ B_T = {}
B = {}
G = {}
G_T = {}
# f(2, 3)
A_T[4] = np.array([[1, 1, 1, 0], [0, 1, -1, -1]]).astype(np.float32)
A[4] = np.transpose(A_T[4])
B_T[4] = np.array([
[1, 0, -1, 0], [1, 0, -1, 0],
[0, 1, 1, 0], [0, 1, 1, 0],
[0, -1, 1, 0], [0, -1, 1, 0],
[0, 1, 0, -1] [0, 1, 0, -1]
]).astype(np.float32) ]).astype(np.float32)
B = np.transpose(B_T) B[4] = np.transpose(B_T[4])
G = np.array([ G[4] = np.array([
[1, 0, 0], [1, 0, 0],
[0.5, 0.5, 0.5], [0.5, 0.5, 0.5],
[0.5, -0.5, 0.5], [0.5, -0.5, 0.5],
[0, 0, 1], [0, 0, 1],
]).astype(np.float32) ]).astype(np.float32)
G_T = np.transpose(G) G_T[4] = np.transpose(G[4])
# f(4, 3)
A_T[6] = np.array([
[1, 1, 1, 1, 1, 0],
[0, 1, -1, 2, -2, 0],
[0, 1, 1, 4, 4, 0],
[0, 1, -1, 8, -8, 1],
]).astype(np.float32)
A[6] = np.transpose(A_T[6])
B_T[6] = np.array([
[4, 0, -5, 0, 1, 0],
[0, -4, -4, 1, 1, 0],
[0, 4, -4, -1, 1, 0],
[0, -2, -1, 2, 1, 0],
[0, 2, -1, -2, 1, 0],
[0, 4, 0, -5, 0, 1],
]).astype(np.float32)
B[6] = np.transpose(B_T[6])
G[6] = np.array([
[1/4.0 , 0 , 0 ],
[-1/6.0, -1/6.0 , -1/6.0],
[-1/6.0, 1/6.0 , -1/6.0],
[1/24.0, 1/12.0 , 1/6.0 ],
[1/24.0, -1/12.0, 1/6.0 ],
[ 0 , 0 , 1 ],
]).astype(np.float32)
G_T[6] = np.transpose(G[6])
# f(6, 3)
A_T[8] = np.array([
[1, 1, 1 , 1 , 1 , 1 , 1 , 0],
[0, 1, -1, 2 , -2 , 1/2. , -1/2. , 0],
[0, 1, 1 , 4 , 4 , 1/4. , 1/4. , 0],
[0, 1, -1, 8 , -8 , 1/8. , -1/8. , 0],
[0, 1, 1 , 16, 16 , 1/16., 1/16. , 0],
[0, 1, -1, 32, -32, 1/32., -1/32., 1],
]).astype(np.float32)
A[8] = np.transpose(A_T[8])
B_T[8] = np.array([
[1, 0 , -21/4., 0 , 21/4., 0 , -1, 0],
[0, 1 , 1 , -17/4., -17/4., 1 , 1 , 0],
[0, -1 , 1 , 17/4. , -17/4., -1 , 1 , 0],
[0, 1/2. , 1/4. , -5/2. , -5/4., 2 , 1 , 0],
[0, -1/2., 1/4. , 5/2. , -5/4., -2 , 1 , 0],
[0, 2 , 4 , -5/2. , -5 , 1/2. , 1 , 0],
[0, -2 , 4 , 5/2. , -5 , -1/2. , 1 , 0],
[0, -1 , 0 , 21/4. , 0 , -21/4., 0 , 1],
]).astype(np.float32)
B[8] = np.transpose(B_T[8])
G[8] = np.array([
[ 1 , 0 , 0 ],
[-2/9. , -2/9. , -2/9.],
[-2/9. , 2/9. , -2/9.],
[1/90. , 1/45. , 2/45.],
[1/90. , -1/45. , 2/45.],
[32/45., 16/45. , 8/45.],
[32/45., -16/45., 8/45.],
[ 0 , 0 , 1 ],
]).astype(np.float32)
G_T[8] = np.transpose(G[8])
def output_shape(input_shape, filter_shape): def output_shape(input_shape, filter_shape):
...@@ -29,55 +96,54 @@ def output_shape(input_shape, filter_shape): ...@@ -29,55 +96,54 @@ def output_shape(input_shape, filter_shape):
return out_shape return out_shape
def winog_conv(input, filter): def winog_conv(m, r, input, filter):
m = 2
r = 3
alpha = m + r - 1 alpha = m + r - 1
print 'Winograd(m = %d, r = %d, tile size=%d' % (m, r, alpha)
alpha_square = alpha * alpha
input_shape = input.shape input_shape = input.shape
filter_shape = filter.shape filter_shape = filter.shape
out_shape = output_shape(input_shape, filter_shape) out_shape = output_shape(input_shape, filter_shape)
K = filter_shape[0] K = filter_shape[0]
C = input_shape[1] C = input_shape[1]
U = np.zeros((K * 16, C)) U = np.zeros((K * alpha_square, C))
for k in range(K): for k in range(K):
for c in range(C): for c in range(C):
u = np.dot(np.dot(G, filter[k, c, :, :]), G_T) u = np.dot(np.dot(G[alpha], filter[k, c, :, :]), G_T[alpha])
for i in range(4): for i in range(alpha):
for j in range(4) : for j in range(alpha) :
U[(i * 4 + j) * K + k, c] = u[i, j] U[(i * alpha + j) * K + k, c] = u[i, j]
print 'filter out: ', U.shape print 'filter out: ', U.shape
print U[0, 0]
U.astype(np.float32).tofile("filter_out")
rounded_h = int(math.ceil(out_shape[2] / 2.0)) rounded_h = int(math.ceil(out_shape[2] / (m * 1.0)))
rounded_w = int(math.ceil(out_shape[3] / 2.0)) rounded_w = int(math.ceil(out_shape[3] / (m * 1.0)))
P = input_shape[0] * rounded_h * rounded_w P = input_shape[0] * rounded_h * rounded_w
V = np.zeros((C * 16, P)) V = np.zeros((C * alpha_square, P))
for p in range(P): for p in range(P):
for c in range(C): for c in range(C):
n = p / (rounded_w * rounded_h) n = p / (rounded_w * rounded_h)
t = p % (rounded_h * rounded_w) t = p % (rounded_h * rounded_w)
h_idx = t / rounded_w h_idx = t / rounded_w
w_idx = t % rounded_w w_idx = t % rounded_w
h_start = h_idx * 2 h_start = h_idx * m
w_start = w_idx * 2 w_start = w_idx * m
h_end = min(h_start+4, input_shape[2]) h_end = min(h_start+alpha, input_shape[2])
w_end = min(w_start+4, input_shape[3]) w_end = min(w_start+alpha, input_shape[3])
d = np.zeros((4, 4)) d = np.zeros((alpha, alpha))
d[0:h_end-h_start, 0:w_end-w_start] = input[n, c, h_start:h_end, w_start:w_end] d[0:h_end-h_start, 0:w_end-w_start] = \
v = np.dot(np.dot(B_T, d), B) input[n, c, h_start:h_end, w_start:w_end]
for i in range(4): v = np.dot(np.dot(B_T[alpha], d), B[alpha])
for j in range(4): for i in range(alpha):
V[(i*4+j)*C + c, p] = v[i, j] for j in range(alpha):
V[(i*alpha+j)*C + c, p] = v[i, j]
tmp = V.reshape(16, C, P, 1)
tmp = V.reshape(alpha_square, C, P, 1)
print 'input out: ', tmp.shape print 'input out: ', tmp.shape
tmp.astype(np.float32).tofile("C") tmp.astype(np.float32).tofile("C")
M = np.zeros((16 * K, P)) M = np.zeros((alpha_square * K, P))
for i in range(alpha * alpha): for i in range(alpha_square):
u = U[i * K : (i+1) * K, :] u = U[i * K : (i+1) * K, :]
v = V[i * C : (i+1) * C, :] v = V[i * C : (i+1) * C, :]
M[i * K : (i+1) * K, :] = np.dot(u, v) M[i * K : (i+1) * K, :] = np.dot(u, v)
...@@ -87,17 +153,17 @@ def winog_conv(input, filter): ...@@ -87,17 +153,17 @@ def winog_conv(input, filter):
res = np.zeros((out_shape[0], out_shape[2], out_shape[3], out_shape[1])) res = np.zeros((out_shape[0], out_shape[2], out_shape[3], out_shape[1]))
for k in range(K): for k in range(K):
for b in range(P): for b in range(P):
m = np.zeros((4, 4)) tm = np.zeros((alpha, alpha))
for i in range(4): for i in range(alpha):
for j in range(4): for j in range(alpha):
m[i][j] = M[(i*4+j) * K + k, b] tm[i][j] = M[(i*alpha+j) * K + k, b]
y = np.dot(np.dot(A_T, m), A) y = np.dot(np.dot(A_T[alpha], tm), A[alpha])
for i in range(2): for i in range(m):
for j in range(2): for j in range(m):
n = b / (rounded_h * rounded_w) n = b / (rounded_h * rounded_w)
t = b % (rounded_h * rounded_w) t = b % (rounded_h * rounded_w)
p = (t / rounded_w) * 2 + i p = (t / rounded_w) * m + i
q = (t % rounded_w) * 2 + j q = (t % rounded_w) * m + j
if p >= out_shape[2] or q >= out_shape[3]: if p >= out_shape[2] or q >= out_shape[3]:
continue continue
res[n, p, q, k] = y[i, j] res[n, p, q, k] = y[i, j]
...@@ -115,23 +181,25 @@ def tf_conv(input, filter): ...@@ -115,23 +181,25 @@ def tf_conv(input, filter):
def main(): def main():
input = np.random.random([7, 61, 71, 31]).astype(np.float32) input = np.random.random([5, 23, 29, 15]).astype(np.float32)
# input = np.fromfile(file="A", dtype=np.float32) # input = np.fromfile(file="A", dtype=np.float32)
# input = input.reshape(1, 3, 3, 5) # input = input.reshape(1, 3, 3, 5)
print 'input shape: ', input.shape print 'input shape: ', input.shape
input.tofile("A") # input.tofile("A")
filter = np.random.random([3, 3, 31, 31]).astype(np.float32) filter = np.random.random([3, 3, 15, 13]).astype(np.float32)
tf_out = tf_conv(input, filter) tf_out = tf_conv(input, filter)
input = input.transpose((0, 3, 1, 2)) input = input.transpose((0, 3, 1, 2))
filter = filter.transpose((3, 2, 0, 1)) filter = filter.transpose((3, 2, 0, 1))
print 'filter shape: ', filter.shape print 'filter shape: ', filter.shape
filter.tofile("filter_in") # filter.tofile("filter_in")
winog_out = winog_conv(input, filter) for i in [2, 4, 6]:
print "==========f(%d,3)==========" % i
winog_out = winog_conv(i, 3, input, filter)
res = np.allclose(tf_out, winog_out) res = np.allclose(tf_out, winog_out)
if res: if res:
print "=========Pass=========" print "=========Pass========="
else: else:
print "=========Failed=========" print "=========Failed======="
print "TF: ", tf_out print "TF: ", tf_out
print "Winograd: ", winog_out print "Winograd: ", winog_out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册