Commit eaf049c4 authored by Jiabin Yang, committed by Hongyu Liu

test=develop, refine ocr attention model (#17763)

* test=develop, refine ocr attention model

* test=develop, refine code, remove cpu only part
Parent 5e4f99dd
@@ -90,7 +90,7 @@ class ConvBNPool(fluid.dygraph.Layer):
             3,
             padding=1,
             param_attr=conv_param_0,
-            bias_attr=None,
+            bias_attr=False,
             act=None,
             use_cudnn=use_cudnn)
         self.bn_0_layer = BatchNorm(
@@ -102,7 +102,7 @@ class ConvBNPool(fluid.dygraph.Layer):
             filter_size=3,
             padding=1,
             param_attr=conv_param_1,
-            bias_attr=None,
+            bias_attr=False,
             act=None,
             use_cudnn=use_cudnn)
         self.bn_1_layer = BatchNorm(
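Note on the two hunks above: both Conv2D layers in ConvBNPool feed straight into a BatchNorm, and switching from bias_attr=None (use a default bias parameter) to bias_attr=False (create no bias at all) drops a parameter that batch normalization makes redundant. A minimal NumPy sketch of why, with hypothetical shapes:

import numpy as np

# Normalization subtracts the per-channel mean, so a constant per-channel
# bias b added before it cancels out exactly.
x = np.random.randn(8, 4)   # 8 samples, 4 channels (hypothetical sizes)
b = np.random.randn(4)      # per-channel conv bias

def batch_norm(v):
    return (v - v.mean(axis=0)) / np.sqrt(v.var(axis=0) + 1e-5)

assert np.allclose(batch_norm(x), batch_norm(x + b))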
@@ -301,7 +301,11 @@ class SimpleAttention(fluid.dygraph.Layer):
             decoder_size,
             act=None,
             bias_attr=False)
-        self.fc_2 = FC(self.full_name(), 1, act=None, bias_attr=False)
+        self.fc_2 = FC(self.full_name(),
+                       1,
+                       num_flatten_dims=2,
+                       act=None,
+                       bias_attr=False)

     def _build_once(self, encoder_vec, encoder_proj, decoder_state):
         pass
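The new num_flatten_dims=2 keeps the first two dimensions of the 3-D input intact, so fc_2 produces one attention score per encoder timestep instead of flattening each sample to a single row. A rough NumPy equivalent, with hypothetical sizes:

import numpy as np

batch, seq, hidden = 4, 10, 32          # hypothetical sizes
concated = np.random.randn(batch, seq, hidden)
w = np.random.randn(hidden, 1)          # fc_2: size 1, no bias

# num_flatten_dims=2: multiply along the last axis only.
attention_weight = concated @ w
print(attention_weight.shape)           # (4, 10, 1)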
@@ -317,19 +321,17 @@ class SimpleAttention(fluid.dygraph.Layer):
                                                 decoder_state_expand)
         concated = fluid.layers.tanh(x=concated)
         attention_weight = self.fc_2(concated)
         weights_reshape = fluid.layers.reshape(
-            x=attention_weight, shape=[-1], inplace=False)
+            x=attention_weight,
+            shape=[attention_weight.shape[0], attention_weight.shape[1]],
+            inplace=False)
+        weights_reshape = fluid.layers.softmax(weights_reshape)
         scaled = fluid.layers.elementwise_mul(
             x=encoder_vec, y=weights_reshape, axis=0)
-        scaled = fluid.layers.transpose(scaled, [0, 2, 1])
-        scaled = fluid.layers.reshape(
-            scaled, [-1, scaled.shape[1], scaled.shape[2], 1], inplace=False)
-        context = fluid.layers.pool2d(
-            input=scaled,
-            pool_size=[scaled.shape[2], scaled.shape[3]],
-            pool_type='avg')
-        context = fluid.layers.reshape(
-            context, [-1, context.shape[1]], inplace=False)
+        context = fluid.layers.reduce_sum(scaled, dim=1)
         return context
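This forward-pass rewrite is the core of the refinement: the old code, as shown, flattened the raw scores to a 1-D vector and averaged the scaled encoder states with a transpose, two reshapes, and a pool2d; the new code keeps a [batch, seq] score matrix, adds a softmax over the timesteps, scales each encoder state by its weight, and sums over the time axis. A NumPy sketch of the new computation, with hypothetical sizes:

import numpy as np

batch, seq, hidden = 4, 10, 32                     # hypothetical sizes
encoder_vec = np.random.randn(batch, seq, hidden)
attention_weight = np.random.randn(batch, seq, 1)  # fc_2 output

w = attention_weight.reshape(batch, seq)           # drop the trailing 1
e = np.exp(w - w.max(axis=1, keepdims=True))
w = e / e.sum(axis=1, keepdims=True)               # softmax over timesteps

scaled = encoder_vec * w[:, :, None]  # elementwise_mul(..., axis=0)
context = scaled.sum(axis=1)          # reduce_sum(scaled, dim=1)
print(context.shape)                  # (4, 32)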
@@ -381,7 +383,7 @@ class GRUDecoderWithAttention(fluid.dygraph.Layer):
             out = self.out_layer(h)
             res.append(out)
-        res1 = fluid.layers.concat(res, axis=0)
+        res1 = fluid.layers.concat(res, axis=1)
         return res1
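With per-step decoder outputs, concatenating along axis=1 keeps the batch dimension first instead of stacking timesteps on top of each other; the reshape added before cross_entropy further down relies on this layout. The difference, in plain NumPy with hypothetical sizes:

import numpy as np

res = [np.random.randn(4, 7) for _ in range(5)]  # 5 steps, batch 4, 7 classes

print(np.concatenate(res, axis=0).shape)  # (20, 7): timesteps stacked over batch
print(np.concatenate(res, axis=1).shape)  # (4, 35): batch dimension preserved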
@@ -427,7 +429,11 @@ class TestDygraphOCRAttention(unittest.TestCase):
     def test_while_op(self):
         seed = 90
         epoch_num = 2
-        batch_num = 20
+        if core.is_compiled_with_cuda():
+            batch_num = 20
+        else:
+            print("in CPU")
+            batch_num = 2
         np.random.seed = seed
         image_np = np.random.randn(Config.batch_size, Config.DATA_SHAPE[0],
                                    Config.DATA_SHAPE[1],
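The test now scales its workload to the build: 20 batches when Paddle is compiled with CUDA, 2 batches otherwise, which keeps CPU-only CI runs fast. The pattern condenses to:

from paddle.fluid import core

# Size the run to the build: full run on GPU builds, quick smoke run on CPU.
batch_num = 20 if core.is_compiled_with_cuda() else 2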
@@ -441,7 +447,6 @@ class TestDygraphOCRAttention(unittest.TestCase):
                         i * Config.max_length,
                         dtype='int64').reshape([1, Config.max_length])))
-        print(label_in_np.shape)
         label_out_np = np.arange(
             0, Config.max_length,
             dtype='int64').reshape([1, Config.max_length])
@@ -450,7 +455,6 @@ class TestDygraphOCRAttention(unittest.TestCase):
                         (i - 1) * Config.max_length,
                         i * Config.max_length,
                         dtype='int64').reshape([1, Config.max_length])))
-        print(label_out_np.shape)
         #if Config.use_gpu:
         #    place = fluid.CUDAPlace(0)
         #else:
@@ -484,6 +488,8 @@ class TestDygraphOCRAttention(unittest.TestCase):
                 dy_prediction = ocr_attention(img, label_in)
                 label_out = fluid.layers.reshape(
                     label_out, [-1, 1], inplace=False)
+                dy_prediction = fluid.layers.reshape(
+                    dy_prediction, [label_out.shape[0], -1], inplace=False)
                 loss = fluid.layers.cross_entropy(
                     input=dy_prediction, label=label_out)
                 avg_loss = fluid.layers.reduce_sum(loss)
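cross_entropy expects input [N, num_classes] paired with a label of [N, 1], and since the decoder now concatenates along axis=1, the dygraph prediction has to be re-flattened so its rows line up with the flattened labels. Shape-wise, with hypothetical sizes:

import numpy as np

# batch 4, max_length 5, num_classes + 2 == 7 (hypothetical sizes)
label_out = np.zeros((4 * 5, 1), dtype='int64')  # after reshape to [-1, 1]
dy_prediction = np.random.randn(4, 5 * 7)        # after concat(res, axis=1)

dy_prediction = dy_prediction.reshape(label_out.shape[0], -1)
print(dy_prediction.shape)                       # (20, 7)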
@@ -536,6 +542,9 @@ class TestDygraphOCRAttention(unittest.TestCase):
                 static_prediction = ocr_attention(images, static_label_in)

+                static_prediction = fluid.layers.reshape(
+                    static_prediction, shape=[-1, Config.num_classes + 2])
+
                 cost = fluid.layers.cross_entropy(
                     input=static_prediction, label=static_label_out)
                 static_avg_loss = fluid.layers.reduce_sum(cost)
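The static-graph branch needs the same flattening, written here with the explicit class count Config.num_classes + 2, presumably because the batch dimension is symbolic at graph-construction time and reshape can infer only one -1 axis. The equivalent shape arithmetic:

import numpy as np

# hypothetical sizes: batch 4, max_length 5, num_classes + 2 == 7
static_prediction = np.random.randn(4, 5 * 7)
static_prediction = static_prediction.reshape(-1, 7)  # [-1, num_classes + 2]
print(static_prediction.shape)                        # (20, 7)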
@@ -558,8 +567,6 @@ class TestDygraphOCRAttention(unittest.TestCase):
                     static_param_init_value[static_param_name_list[i]] = out[i]

                 fetch_list = [static_avg_loss.name]
-                # print(static_test.name)
-                # fetch_list = [static_avg_loss.name, static_test.name]
                 fetch_list.extend(static_param_name_list)
                 fetch_list.extend(static_grad_name_list)
                 for epoch in range(epoch_num):