Commit eaf049c4 authored by Jiabin Yang, committed by Hongyu Liu

test=develop, refine ocr attention model (#17763)

* test=develop, refine ocr attention model

* test=develop, refine code, remove cpu only part

Parent 5e4f99dd
@@ -90,7 +90,7 @@ class ConvBNPool(fluid.dygraph.Layer):
             3,
             padding=1,
             param_attr=conv_param_0,
-            bias_attr=None,
+            bias_attr=False,
             act=None,
             use_cudnn=use_cudnn)
         self.bn_0_layer = BatchNorm(
@@ -102,7 +102,7 @@ class ConvBNPool(fluid.dygraph.Layer):
             filter_size=3,
             padding=1,
             param_attr=conv_param_1,
-            bias_attr=None,
+            bias_attr=False,
             act=None,
             use_cudnn=use_cudnn)
         self.bn_1_layer = BatchNorm(
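
Note on the `bias_attr` change in the two hunks above: in fluid 1.x, `bias_attr=None` asks the layer to create a default bias parameter, while `bias_attr=False` creates no bias at all. Since each convolution here feeds straight into `BatchNorm`, whose learnable shift absorbs any constant offset, the conv bias is redundant. A minimal sketch, assuming the fluid 1.x dygraph API this test targets (layer sizes are illustrative, not taken from the diff):

```python
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, BatchNorm

class ConvBN(fluid.dygraph.Layer):
    def __init__(self, name_scope):
        super(ConvBN, self).__init__(name_scope)
        # bias_attr=False: no bias Variable is created for this conv;
        # the BatchNorm that follows supplies the learnable offset instead.
        self.conv = Conv2D(
            self.full_name(), 16, 3, padding=1, bias_attr=False, act=None)
        self.bn = BatchNorm(self.full_name(), 16, act='relu')

    def forward(self, x):
        return self.bn(self.conv(x))
```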
@@ -301,7 +301,11 @@ class SimpleAttention(fluid.dygraph.Layer):
                        decoder_size,
                        act=None,
                        bias_attr=False)
-        self.fc_2 = FC(self.full_name(), 1, act=None, bias_attr=False)
+        self.fc_2 = FC(self.full_name(),
+                       1,
+                       num_flatten_dims=2,
+                       act=None,
+                       bias_attr=False)
 
     def _build_once(self, encoder_vec, encoder_proj, decoder_state):
         pass
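
The added `num_flatten_dims=2` changes how `FC` treats its 3-D attention input. With the default `num_flatten_dims=1`, an `[N, T, D]` tensor is flattened to `[N, T*D]` and the layer emits one score per sample; with `num_flatten_dims=2` the first two dims stay as batch dims and the layer emits `[N, T, 1]`, one logit per encoder timestep. A numpy sketch of the same contraction (shapes are illustrative):

```python
import numpy as np

N, T, D = 4, 48, 128                 # batch, encoder steps, feature width
x = np.random.randn(N, T, D).astype('float32')
w = np.random.randn(D, 1).astype('float32')

# num_flatten_dims=2 behaves like a matmul over the last axis only:
scores = x @ w                       # [N, T, 1]: one logit per timestep
assert scores.shape == (N, T, 1)

# num_flatten_dims=1 (the old default) would instead flatten to [N, T*D]
# and produce a single [N, 1] score, losing the per-timestep structure.
```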
@@ -317,19 +321,17 @@ class SimpleAttention(fluid.dygraph.Layer):
                                                 decoder_state_expand)
         concated = fluid.layers.tanh(x=concated)
         attention_weight = self.fc_2(concated)
         weights_reshape = fluid.layers.reshape(
-            x=attention_weight, shape=[-1], inplace=False)
+            x=attention_weight,
+            shape=[attention_weight.shape[0], attention_weight.shape[1]],
+            inplace=False)
         weights_reshape = fluid.layers.softmax(weights_reshape)
         scaled = fluid.layers.elementwise_mul(
             x=encoder_vec, y=weights_reshape, axis=0)
-        scaled = fluid.layers.transpose(scaled, [0, 2, 1])
-        scaled = fluid.layers.reshape(
-            scaled, [-1, scaled.shape[1], scaled.shape[2], 1], inplace=False)
-        context = fluid.layers.pool2d(
-            input=scaled,
-            pool_size=[scaled.shape[2], scaled.shape[3]],
-            pool_type='avg')
-        context = fluid.layers.reshape(
-            context, [-1, context.shape[1]], inplace=False)
+        context = fluid.layers.reduce_sum(scaled, dim=1)
         return context
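
The hunk above collapses the transpose/reshape/`pool2d`/reshape chain into a single `reduce_sum` over the time axis, which is the standard attention context: a weighted sum of encoder states. (The old `avg` pool additionally divided by the sequence length, scaling the context by 1/T.) A numpy sketch of the new path, with illustrative shapes:

```python
import numpy as np

N, T, C = 4, 48, 200                 # batch, encoder steps, channels
encoder_vec = np.random.randn(N, T, C).astype('float32')
logits = np.random.randn(N, T).astype('float32')   # fc_2 output, squeezed

# softmax over the time axis: one attention distribution per sample
weights = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

scaled = encoder_vec * weights[:, :, None]  # elementwise_mul with axis=0
context = scaled.sum(axis=1)                # reduce_sum(scaled, dim=1)
assert context.shape == (N, C)
```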
@@ -381,7 +383,7 @@ class GRUDecoderWithAttention(fluid.dygraph.Layer):
             out = self.out_layer(h)
             res.append(out)
-        res1 = fluid.layers.concat(res, axis=0)
+        res1 = fluid.layers.concat(res, axis=1)
         return res1
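
The axis change matters because the labels are flattened sample-major later in the test (`label_out` is reshaped to `[-1, 1]`). Concatenating the per-step outputs along `axis=0` produces step-major rows, silently misaligning predictions and labels; `axis=1` followed by a reshape keeps each sample's timesteps contiguous. A numpy sketch (sizes and names are illustrative):

```python
import numpy as np

N, T, C = 2, 3, 5
# one decoder output per timestep, tagged with its step index
steps = [np.full((N, C), t, dtype='float32') for t in range(T)]

good = np.concatenate(steps, axis=1).reshape(N * T, C)  # sample-major
bad = np.concatenate(steps, axis=0)                     # step-major

# sample-major: rows 0..T-1 all belong to sample 0, in step order
assert [good[t, 0] for t in range(T)] == [0.0, 1.0, 2.0]
# step-major: rows 0..N-1 are step 0 of every sample instead
assert [bad[n, 0] for n in range(N)] == [0.0, 0.0]
```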
@@ -427,7 +429,11 @@ class TestDygraphOCRAttention(unittest.TestCase):
     def test_while_op(self):
         seed = 90
         epoch_num = 2
-        batch_num = 20
+        if core.is_compiled_with_cuda():
+            batch_num = 20
+        else:
+            print("in CPU")
+            batch_num = 2
         np.random.seed = seed
         image_np = np.random.randn(Config.batch_size, Config.DATA_SHAPE[0],
                                    Config.DATA_SHAPE[1],
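
The hunk above replaces a fixed `batch_num = 20` with a build-time check so the test still finishes quickly on CPU-only builds. A minimal sketch of the same gating, using the `core` module the test already imports:

```python
from paddle.fluid import core

# Run the full 20 batches only when Paddle was built with CUDA;
# a CPU-only build gets a 2-batch smoke run instead.
batch_num = 20 if core.is_compiled_with_cuda() else 2
```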
@@ -441,7 +447,6 @@ class TestDygraphOCRAttention(unittest.TestCase):
                 i * Config.max_length,
                 dtype='int64').reshape([1, Config.max_length])))
-        print(label_in_np.shape)
         label_out_np = np.arange(
             0, Config.max_length,
             dtype='int64').reshape([1, Config.max_length])
@@ -450,7 +455,6 @@ class TestDygraphOCRAttention(unittest.TestCase):
                 (i - 1) * Config.max_length,
                 i * Config.max_length,
                 dtype='int64').reshape([1, Config.max_length])))
-        print(label_out_np.shape)
         #if Config.use_gpu:
         #    place = fluid.CUDAPlace(0)
         #else:
@@ -484,6 +488,8 @@ class TestDygraphOCRAttention(unittest.TestCase):
                 dy_prediction = ocr_attention(img, label_in)
                 label_out = fluid.layers.reshape(
                     label_out, [-1, 1], inplace=False)
+                dy_prediction = fluid.layers.reshape(
+                    dy_prediction, [label_out.shape[0], -1], inplace=False)
                 loss = fluid.layers.cross_entropy(
                     input=dy_prediction, label=label_out)
                 avg_loss = fluid.layers.reduce_sum(loss)
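
The added reshape pairs the prediction with the label layout that `cross_entropy` expects: input `[M, C]` against label `[M, 1]` with the same `M`. Because the decoder now concatenates along `axis=1`, `dy_prediction` arrives as `[N, T*C]` and must be unrolled to one row per (sample, step). A numpy sketch of the shape bookkeeping (sizes are illustrative):

```python
import numpy as np

N, T, C = 2, 3, 5
dy_prediction = np.random.rand(N, T * C).astype('float32')  # axis=1 concat
label = np.random.randint(0, C, size=(N, T))

label_out = label.reshape(-1, 1)                        # [N*T, 1]
pred2d = dy_prediction.reshape(label_out.shape[0], -1)  # [N*T, C]
assert pred2d.shape == (N * T, C) and label_out.shape == (N * T, 1)
```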
@@ -536,6 +542,9 @@ class TestDygraphOCRAttention(unittest.TestCase):
             static_prediction = ocr_attention(images, static_label_in)
+            static_prediction = fluid.layers.reshape(
+                static_prediction, shape=[-1, Config.num_classes + 2])
             cost = fluid.layers.cross_entropy(
                 input=static_prediction, label=static_label_out)
             static_avg_loss = fluid.layers.reduce_sum(cost)
@@ -558,8 +567,6 @@ class TestDygraphOCRAttention(unittest.TestCase):
                 static_param_init_value[static_param_name_list[i]] = out[i]
             fetch_list = [static_avg_loss.name]
-            # print(static_test.name)
-            # fetch_list = [static_avg_loss.name, static_test.name]
             fetch_list.extend(static_param_name_list)
             fetch_list.extend(static_grad_name_list)
             for epoch in range(epoch_num):
......