diff --git a/examples/ocr/seq2seq_attn.py b/examples/ocr/seq2seq_attn.py index 2315e982c9474884f4c6dc5f7932be5cd90eadbe..e0f19e2e4f8372ff700be0997fb0c19654437152 100644 --- a/examples/ocr/seq2seq_attn.py +++ b/examples/ocr/seq2seq_attn.py @@ -15,6 +15,7 @@ from __future__ import print_function import numpy as np +import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers from paddle.fluid.layers import BeamSearchDecoder @@ -22,7 +23,7 @@ from paddle.fluid.layers import BeamSearchDecoder from paddle.text import RNNCell, RNN, DynamicDecode -class ConvBNPool(fluid.dygraph.Layer): +class ConvBNPool(paddle.nn.Layer): def __init__(self, in_ch, out_ch, @@ -81,7 +82,7 @@ class ConvBNPool(fluid.dygraph.Layer): return out -class CNN(fluid.dygraph.Layer): +class CNN(paddle.nn.Layer): def __init__(self, in_ch=1, is_test=False): super(CNN, self).__init__() self.conv_bn1 = ConvBNPool(in_ch, 16) @@ -134,7 +135,7 @@ class GRUCell(RNNCell): return [self.hidden_size] -class Encoder(fluid.dygraph.Layer): +class Encoder(paddle.nn.Layer): def __init__( self, in_channel=1, @@ -185,7 +186,7 @@ class Encoder(fluid.dygraph.Layer): return gru_bwd, encoded_vector, encoded_proj -class Attention(fluid.dygraph.Layer): +class Attention(paddle.nn.Layer): """ Neural Machine Translation by Jointly Learning to Align and Translate. https://arxiv.org/abs/1409.0473 @@ -230,7 +231,7 @@ class DecoderCell(RNNCell): return hidden, hidden -class Decoder(fluid.dygraph.Layer): +class Decoder(paddle.nn.Layer): def __init__(self, num_classes, emb_dim, encoder_size, decoder_size): super(Decoder, self).__init__() self.decoder_attention = RNN(DecoderCell(encoder_size, decoder_size)) @@ -247,7 +248,7 @@ class Decoder(fluid.dygraph.Layer): return pred -class Seq2SeqAttModel(fluid.dygraph.Layer): +class Seq2SeqAttModel(paddle.nn.Layer): def __init__( self, in_channle=1, @@ -320,7 +321,7 @@ class Seq2SeqAttInferModel(Seq2SeqAttModel): return rs -class WeightCrossEntropy(fluid.dygraph.Layer): +class WeightCrossEntropy(paddle.nn.Layer): def __init__(self): super(WeightCrossEntropy, self).__init__()