diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py b/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py index f63c82856bbcc0a0741e563c251f547361432daa..f978ae58bac912eab6ca6a6524f9cc8ef6cb2108 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_ocr_attention_model.py @@ -90,7 +90,7 @@ class ConvBNPool(fluid.dygraph.Layer): 3, padding=1, param_attr=conv_param_0, - bias_attr=None, + bias_attr=False, act=None, use_cudnn=use_cudnn) self.bn_0_layer = BatchNorm( @@ -102,7 +102,7 @@ class ConvBNPool(fluid.dygraph.Layer): filter_size=3, padding=1, param_attr=conv_param_1, - bias_attr=None, + bias_attr=False, act=None, use_cudnn=use_cudnn) self.bn_1_layer = BatchNorm( @@ -301,7 +301,11 @@ class SimpleAttention(fluid.dygraph.Layer): decoder_size, act=None, bias_attr=False) - self.fc_2 = FC(self.full_name(), 1, act=None, bias_attr=False) + self.fc_2 = FC(self.full_name(), + 1, + num_flatten_dims=2, + act=None, + bias_attr=False) def _build_once(self, encoder_vec, encoder_proj, decoder_state): pass @@ -317,19 +321,17 @@ class SimpleAttention(fluid.dygraph.Layer): decoder_state_expand) concated = fluid.layers.tanh(x=concated) attention_weight = self.fc_2(concated) + weights_reshape = fluid.layers.reshape( - x=attention_weight, shape=[-1], inplace=False) + x=attention_weight, + shape=[attention_weight.shape[0], attention_weight.shape[1]], + inplace=False) + + weights_reshape = fluid.layers.softmax(weights_reshape) scaled = fluid.layers.elementwise_mul( x=encoder_vec, y=weights_reshape, axis=0) - scaled = fluid.layers.transpose(scaled, [0, 2, 1]) - scaled = fluid.layers.reshape( - scaled, [-1, scaled.shape[1], scaled.shape[2], 1], inplace=False) - context = fluid.layers.pool2d( - input=scaled, - pool_size=[scaled.shape[2], scaled.shape[3]], - pool_type='avg') - context = fluid.layers.reshape( - context, [-1, context.shape[1]], inplace=False) + context = fluid.layers.reduce_sum(scaled, dim=1) + return context @@ -381,7 +383,7 @@ class GRUDecoderWithAttention(fluid.dygraph.Layer): out = self.out_layer(h) res.append(out) - res1 = fluid.layers.concat(res, axis=0) + res1 = fluid.layers.concat(res, axis=1) return res1 @@ -427,7 +429,11 @@ class TestDygraphOCRAttention(unittest.TestCase): def test_while_op(self): seed = 90 epoch_num = 2 - batch_num = 20 + if core.is_compiled_with_cuda(): + batch_num = 20 + else: + print("in CPU") + batch_num = 2 np.random.seed = seed image_np = np.random.randn(Config.batch_size, Config.DATA_SHAPE[0], Config.DATA_SHAPE[1], @@ -441,7 +447,6 @@ class TestDygraphOCRAttention(unittest.TestCase): i * Config.max_length, dtype='int64').reshape([1, Config.max_length]))) - print(label_in_np.shape) label_out_np = np.arange( 0, Config.max_length, dtype='int64').reshape([1, Config.max_length]) @@ -450,7 +455,6 @@ class TestDygraphOCRAttention(unittest.TestCase): (i - 1) * Config.max_length, i * Config.max_length, dtype='int64').reshape([1, Config.max_length]))) - print(label_out_np.shape) #if Config.use_gpu: # place = fluid.CUDAPlace(0) #else: @@ -484,6 +488,8 @@ class TestDygraphOCRAttention(unittest.TestCase): dy_prediction = ocr_attention(img, label_in) label_out = fluid.layers.reshape( label_out, [-1, 1], inplace=False) + dy_prediction = fluid.layers.reshape( + dy_prediction, [label_out.shape[0], -1], inplace=False) loss = fluid.layers.cross_entropy( input=dy_prediction, label=label_out) avg_loss = fluid.layers.reduce_sum(loss) @@ -536,6 +542,9 @@ class TestDygraphOCRAttention(unittest.TestCase): static_prediction = ocr_attention(images, static_label_in) + static_prediction = fluid.layers.reshape( + static_prediction, shape=[-1, Config.num_classes + 2]) + cost = fluid.layers.cross_entropy( input=static_prediction, label=static_label_out) static_avg_loss = fluid.layers.reduce_sum(cost) @@ -558,8 +567,6 @@ class TestDygraphOCRAttention(unittest.TestCase): static_param_init_value[static_param_name_list[i]] = out[i] fetch_list = [static_avg_loss.name] - # print(static_test.name) - # fetch_list = [static_avg_loss.name, static_test.name] fetch_list.extend(static_param_name_list) fetch_list.extend(static_grad_name_list) for epoch in range(epoch_num):