提交 6f927abb 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: wangguanzhong

modify nonlocal (#56)

上级 acf2cb24
......@@ -97,8 +97,12 @@ def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner, max_pool_stride =
# reshape back
# e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14)
t_shape = t.shape
t_re = fluid.layers.reshape(t, shape=list(theta_shape), actual_shape=theta_shape_op )
n = fluid.layers.slice(theta_shape_op, axes=[0], starts=[0], ends=[1])
h = fluid.layers.slice(theta_shape_op, axes=[0], starts=[2], ends=[3])
w = fluid.layers.slice(theta_shape_op, axes=[0], starts=[3], ends=[4])
ch = int(theta_shape[1])
t_re = fluid.layers.reshape(t, shape=[n, ch, h, w])
blob_out = t_re
blob_out = fluid.layers.conv2d(input = blob_out, num_filters = dim_out, \
filter_size = [1, 1], stride = [1, 1], padding = [0, 0], \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册