提交 c8d2dff9 编写于 作者: Q qingqing01

Change fluid.dygraph.Layer to paddle.nn.Layer

上级 080748fc
...@@ -15,6 +15,7 @@ from __future__ import print_function ...@@ -15,6 +15,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle.fluid.layers import BeamSearchDecoder from paddle.fluid.layers import BeamSearchDecoder
...@@ -22,7 +23,7 @@ from paddle.fluid.layers import BeamSearchDecoder ...@@ -22,7 +23,7 @@ from paddle.fluid.layers import BeamSearchDecoder
from paddle.text import RNNCell, RNN, DynamicDecode from paddle.text import RNNCell, RNN, DynamicDecode
class ConvBNPool(fluid.dygraph.Layer): class ConvBNPool(paddle.nn.Layer):
def __init__(self, def __init__(self,
in_ch, in_ch,
out_ch, out_ch,
...@@ -81,7 +82,7 @@ class ConvBNPool(fluid.dygraph.Layer): ...@@ -81,7 +82,7 @@ class ConvBNPool(fluid.dygraph.Layer):
return out return out
class CNN(fluid.dygraph.Layer): class CNN(paddle.nn.Layer):
def __init__(self, in_ch=1, is_test=False): def __init__(self, in_ch=1, is_test=False):
super(CNN, self).__init__() super(CNN, self).__init__()
self.conv_bn1 = ConvBNPool(in_ch, 16) self.conv_bn1 = ConvBNPool(in_ch, 16)
...@@ -134,7 +135,7 @@ class GRUCell(RNNCell): ...@@ -134,7 +135,7 @@ class GRUCell(RNNCell):
return [self.hidden_size] return [self.hidden_size]
class Encoder(fluid.dygraph.Layer): class Encoder(paddle.nn.Layer):
def __init__( def __init__(
self, self,
in_channel=1, in_channel=1,
...@@ -185,7 +186,7 @@ class Encoder(fluid.dygraph.Layer): ...@@ -185,7 +186,7 @@ class Encoder(fluid.dygraph.Layer):
return gru_bwd, encoded_vector, encoded_proj return gru_bwd, encoded_vector, encoded_proj
class Attention(fluid.dygraph.Layer): class Attention(paddle.nn.Layer):
""" """
Neural Machine Translation by Jointly Learning to Align and Translate. Neural Machine Translation by Jointly Learning to Align and Translate.
https://arxiv.org/abs/1409.0473 https://arxiv.org/abs/1409.0473
...@@ -230,7 +231,7 @@ class DecoderCell(RNNCell): ...@@ -230,7 +231,7 @@ class DecoderCell(RNNCell):
return hidden, hidden return hidden, hidden
class Decoder(fluid.dygraph.Layer): class Decoder(paddle.nn.Layer):
def __init__(self, num_classes, emb_dim, encoder_size, decoder_size): def __init__(self, num_classes, emb_dim, encoder_size, decoder_size):
super(Decoder, self).__init__() super(Decoder, self).__init__()
self.decoder_attention = RNN(DecoderCell(encoder_size, decoder_size)) self.decoder_attention = RNN(DecoderCell(encoder_size, decoder_size))
...@@ -247,7 +248,7 @@ class Decoder(fluid.dygraph.Layer): ...@@ -247,7 +248,7 @@ class Decoder(fluid.dygraph.Layer):
return pred return pred
class Seq2SeqAttModel(fluid.dygraph.Layer): class Seq2SeqAttModel(paddle.nn.Layer):
def __init__( def __init__(
self, self,
in_channle=1, in_channle=1,
...@@ -320,7 +321,7 @@ class Seq2SeqAttInferModel(Seq2SeqAttModel): ...@@ -320,7 +321,7 @@ class Seq2SeqAttInferModel(Seq2SeqAttModel):
return rs return rs
class WeightCrossEntropy(fluid.dygraph.Layer): class WeightCrossEntropy(paddle.nn.Layer):
def __init__(self): def __init__(self):
super(WeightCrossEntropy, self).__init__() super(WeightCrossEntropy, self).__init__()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册