Unverified · Commit 38af05fe authored by JesseyXujin, committed by GitHub

adapt senta for paddle 1.7 release (#4194)

* adapt senta

* upgrade senta

* modify gru network

* fix bugs in gru and bigru net

* fix channels in cnn net
Parent ac4aa52d
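The changes below follow the standard Paddle 1.6 → 1.7 dygraph migration: layers drop the name-scope constructor argument, FC becomes Linear with explicit input_dim/output_dim, and optimizers receive an explicit parameter_list. A minimal before/after sketch of the constructor pattern (TinyNet and its sizes are illustrative, not part of this commit):

    import paddle.fluid as fluid
    from paddle.fluid.dygraph.nn import Linear

    # Paddle 1.7 dygraph style: no scope name, explicit dimensions.
    class TinyNet(fluid.dygraph.Layer):
        def __init__(self):
            super(TinyNet, self).__init__()   # 1.6 took a scope_name argument here
            # 1.6 equivalent: FC(self.full_name(), size=2, act="softmax")
            self._fc = Linear(input_dim=8, output_dim=2, act="softmax")

        def forward(self, x):
            return self._fc(x)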
@@ -133,18 +133,18 @@ def train():
         epoch=args.epoch,
         shuffle=False)
     if args.model_type == 'cnn_net':
-        model = nets.CNN("cnn_net", args.vocab_size, args.batch_size,
+        model = nets.CNN(args.vocab_size, args.batch_size,
                          args.padding_size)
     elif args.model_type == 'bow_net':
-        model = nets.BOW("bow_net", args.vocab_size, args.batch_size,
+        model = nets.BOW(args.vocab_size, args.batch_size,
                          args.padding_size)
     elif args.model_type == 'gru_net':
-        model = nets.GRU("gru_net", args.vocab_size, args.batch_size,
+        model = nets.GRU(args.vocab_size, args.batch_size,
                          args.padding_size)
     elif args.model_type == 'bigru_net':
-        model = nets.BiGRU("bigru_net", args.vocab_size, args.batch_size,
+        model = nets.BiGRU(args.vocab_size, args.batch_size,
                            args.padding_size)
-    sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.lr)
+    sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.lr, parameter_list=model.parameters())
     steps = 0
     total_cost, total_acc, total_num_seqs = [], [], []
     gru_hidden_data = np.zeros((args.batch_size, 128), dtype='float32')
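In 1.7 dygraph mode an optimizer no longer discovers parameters implicitly; it must be handed the parameters it should update, hence parameter_list=model.parameters() above. A minimal sketch of the pattern (the Linear stand-in and sizes are illustrative):

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        model = fluid.dygraph.Linear(2, 2)   # stand-in for nets.CNN/BOW/GRU/BiGRU
        optimizer = fluid.optimizer.Adagrad(
            learning_rate=0.01, parameter_list=model.parameters())
        x = fluid.dygraph.to_variable(np.ones((4, 2), dtype='float32'))
        loss = fluid.layers.reduce_mean(model(x))
        loss.backward()
        optimizer.minimize(loss)
        model.clear_gradients()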
@@ -162,7 +162,7 @@ def train():
                 'constant',
                 constant_values=(args.vocab_size))
             for x in data
-        ]).astype('int64').reshape(-1, 1))
+        ]).astype('int64').reshape(-1))
         label = to_variable(
             np.array([x[1] for x in data]).astype('int64').reshape(
                 args.batch_size, 1))
@@ -203,11 +203,11 @@ def train():
                 'constant',
                 constant_values=(args.vocab_size))
             for x in eval_data
-        ]).astype('int64').reshape(1, -1)
+        ]).astype('int64').reshape(-1)
         eval_label = to_variable(
             np.array([x[1] for x in eval_data]).astype(
                 'int64').reshape(args.batch_size, 1))
-        eval_doc = to_variable(eval_np_doc.reshape(-1, 1))
+        eval_doc = to_variable(eval_np_doc)
         eval_avg_cost, eval_prediction, eval_acc = model(
             eval_doc, eval_label)
         eval_np_mask = (
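The reshape changes reflect the 1.7 dygraph Embedding contract: token ids travel as a flat int64 tensor instead of the old (-1, 1) column, and the network reshapes internally. A small sketch of the pad-then-flatten step (vocab_size, pad_len, and docs are illustrative):

    import numpy as np

    vocab_size, pad_len = 100, 6          # illustrative values
    docs = [[1, 2, 3], [4, 5]]            # token ids per document
    np_doc = np.array([
        np.pad(np.array(d, dtype='int64'), (0, pad_len - len(d)),
               'constant', constant_values=vocab_size)
        for d in docs
    ]).astype('int64').reshape(-1)        # flat ids; no trailing (-1, 1)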
@@ -262,16 +262,16 @@ def infer():
         epoch=args.epoch,
         shuffle=False)
     if args.model_type == 'cnn_net':
-        model_infer = nets.CNN("cnn_net", args.vocab_size, args.batch_size,
+        model_infer = nets.CNN(args.vocab_size, args.batch_size,
                                args.padding_size)
     elif args.model_type == 'bow_net':
-        model_infer = nets.BOW("bow_net", args.vocab_size, args.batch_size,
+        model_infer = nets.BOW(args.vocab_size, args.batch_size,
                                args.padding_size)
     elif args.model_type == 'gru_net':
-        model_infer = nets.GRU("gru_net", args.vocab_size, args.batch_size,
+        model_infer = nets.GRU(args.vocab_size, args.batch_size,
                                args.padding_size)
     elif args.model_type == 'bigru_net':
-        model_infer = nets.BiGRU("bigru_net", args.vocab_size, args.batch_size,
+        model_infer = nets.BiGRU(args.vocab_size, args.batch_size,
                                  args.padding_size)
     print('Do inferring ...... ')
     restore, _ = fluid.load_dygraph(args.checkpoints)
@@ -288,8 +288,8 @@ def infer():
                 'constant',
                 constant_values=(args.vocab_size))
             for x in data
-        ]).astype('int64').reshape(-1, 1)
-        doc = to_variable(np_doc.reshape(-1, 1))
+        ]).astype('int64').reshape(-1)
+        doc = to_variable(np_doc)
         label = to_variable(
             np.array([x[1] for x in data]).astype('int64').reshape(
                 args.batch_size, 1))
......
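fluid.load_dygraph returns a (parameter dict, optimizer dict) pair; inference only needs the former, loaded via set_dict. A minimal sketch, assuming a checkpoint written earlier with fluid.save_dygraph (the Linear stand-in and path are illustrative):

    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        model = fluid.dygraph.Linear(2, 2)            # stand-in for nets.CNN etc.
        fluid.save_dygraph(model.state_dict(), './checkpoints/final')
        restore, _ = fluid.load_dygraph('./checkpoints/final')
        model.set_dict(restore)                       # load weights for inference
        model.eval()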
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import paddle.fluid as fluid
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC, Embedding
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, Embedding
 from paddle.fluid.dygraph import GRUUnit
 from paddle.fluid.dygraph.base import to_variable
 import numpy as np
@@ -20,7 +20,6 @@ import numpy as np
 class DynamicGRU(fluid.dygraph.Layer):
     def __init__(self,
-                 scope_name,
                  size,
                  param_attr=None,
                  bias_attr=None,
@@ -30,9 +29,8 @@ class DynamicGRU(fluid.dygraph.Layer):
                  h_0=None,
                  origin_mode=False,
                  init_size=None):
-        super(DynamicGRU, self).__init__(scope_name)
+        super(DynamicGRU, self).__init__()
         self.gru_unit = GRUUnit(
-            self.full_name(),
             size * 3,
             param_attr=param_attr,
             bias_attr=bias_attr,
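DynamicGRU's forward pass is elided in this diff; the sketch below shows the usual way GRUUnit is stepped over a padded batch (run_dynamic_gru is a hypothetical helper, not code from this commit):

    import paddle.fluid as fluid

    # inputs: [batch, seq_len, 3 * hid_dim] gate inputs; h_0: [batch, hid_dim].
    def run_dynamic_gru(gru_unit, inputs, h_0, is_reverse=False):
        hidden, res = h_0, []
        for i in range(inputs.shape[1]):
            j = inputs.shape[1] - 1 - i if is_reverse else i
            step = fluid.layers.reshape(inputs[:, j:j + 1, :],
                                        [-1, inputs.shape[2]])
            # GRUUnit returns (hidden, reset_hidden_pre, gate)
            hidden, _, _ = gru_unit(step, hidden)
            res.append(fluid.layers.reshape(hidden, [-1, 1, hidden.shape[1]]))
        if is_reverse:
            res = res[::-1]
        return fluid.layers.concat(res, axis=1)       # [batch, seq_len, hid_dim]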
@@ -60,15 +58,14 @@ class DynamicGRU(fluid.dygraph.Layer):
 class SimpleConvPool(fluid.dygraph.Layer):
     def __init__(self,
-                 name_scope,
+                 num_channels,
                  num_filters,
                  filter_size,
                  use_cudnn=False,
                  batch_size=None):
-        super(SimpleConvPool, self).__init__(name_scope)
+        super(SimpleConvPool, self).__init__()
         self.batch_size = batch_size
-        self._conv2d = Conv2D(
-            self.full_name(),
+        self._conv2d = Conv2D(num_channels=num_channels,
             num_filters=num_filters,
             filter_size=filter_size,
             padding=[1, 1],
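Conv2D in 1.7 is built without a scope name but must be told its input channel count up front, which is why SimpleConvPool now takes num_channels. A shape sketch of this configuration (batch and height values are illustrative):

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.nn import Conv2D

    with fluid.dygraph.guard():
        conv = Conv2D(num_channels=1, num_filters=128,
                      filter_size=[3, 128], padding=[1, 1])
        x = fluid.dygraph.to_variable(
            np.zeros((4, 1, 16, 128), dtype='float32'))   # NCHW input
        y = conv(x)   # [4, 128, 16, 3] with this padding and filter size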
@@ -83,38 +80,38 @@ class SimpleConvPool(fluid.dygraph.Layer):
 class CNN(fluid.dygraph.Layer):
-    def __init__(self, name_scope, dict_dim, batch_size, seq_len):
-        super(CNN, self).__init__(name_scope)
+    def __init__(self, dict_dim, batch_size, seq_len):
+        super(CNN, self).__init__()
         self.dict_dim = dict_dim
         self.emb_dim = 128
         self.hid_dim = 128
         self.fc_hid_dim = 96
         self.class_dim = 2
+        self.channels = 1
         self.win_size = [3, self.hid_dim]
         self.batch_size = batch_size
         self.seq_len = seq_len
         self.embedding = Embedding(
-            self.full_name(),
             size=[self.dict_dim + 1, self.emb_dim],
             dtype='float32',
             is_sparse=False)
         self._simple_conv_pool_1 = SimpleConvPool(
-            self.full_name(),
+            self.channels,
             self.hid_dim,
             self.win_size,
             batch_size=self.batch_size)
-        self._fc1 = FC(self.full_name(), size=self.fc_hid_dim, act="softmax")
-        self._fc_prediction = FC(self.full_name(),
-                                 size=self.class_dim,
-                                 act="softmax")
+        self._fc1 = Linear(input_dim=self.hid_dim * self.seq_len, output_dim=self.fc_hid_dim, act="softmax")
+        self._fc_prediction = Linear(input_dim=self.fc_hid_dim,
+                                     output_dim=self.class_dim,
+                                     act="softmax")
     def forward(self, inputs, label=None):
         emb = self.embedding(inputs)
-        o_np_mask = (inputs.numpy() != self.dict_dim).astype('float32')
+        o_np_mask = (inputs.numpy().reshape(-1, 1) != self.dict_dim).astype('float32')
         mask_emb = fluid.layers.expand(
             to_variable(o_np_mask), [1, self.hid_dim])
         emb = emb * mask_emb
         emb = fluid.layers.reshape(
-            emb, shape=[-1, 1, self.seq_len, self.hid_dim])
+            emb, shape=[-1, self.channels, self.seq_len, self.hid_dim])
         conv_3 = self._simple_conv_pool_1(emb)
         fc_1 = self._fc1(conv_3)
         prediction = self._fc_prediction(fc_1)
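The first Linear's input_dim of hid_dim * seq_len encodes a shape assumption: with a [3, hid_dim] filter and [1, 1] padding, the conv emits hid_dim feature maps of height seq_len and width 3, and the pooling step (elided in this diff) presumably reduces that width to 1 before flattening. A quick arithmetic check under that assumption:

    # Illustrative check of the CNN's first Linear input width.
    hid_dim, seq_len = 128, 16                 # seq_len is illustrative
    conv_out = (hid_dim, seq_len, 3)           # [filters, height, width] per sample
    pooled = (hid_dim, seq_len, 1)             # assumed max-pool over the width
    assert hid_dim * seq_len == pooled[0] * pooled[1] * pooled[2]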
@@ -128,8 +125,8 @@ class CNN(fluid.dygraph.Layer):
 class BOW(fluid.dygraph.Layer):
-    def __init__(self, name_scope, dict_dim, batch_size, seq_len):
-        super(BOW, self).__init__(name_scope)
+    def __init__(self, dict_dim, batch_size, seq_len):
+        super(BOW, self).__init__()
         self.dict_dim = dict_dim
         self.emb_dim = 128
         self.hid_dim = 128
@@ -138,18 +135,17 @@ class BOW(fluid.dygraph.Layer):
         self.batch_size = batch_size
         self.seq_len = seq_len
         self.embedding = Embedding(
-            self.full_name(),
             size=[self.dict_dim + 1, self.emb_dim],
             dtype='float32',
             is_sparse=False)
-        self._fc1 = FC(self.full_name(), size=self.hid_dim, act="tanh")
-        self._fc2 = FC(self.full_name(), size=self.fc_hid_dim, act="tanh")
-        self._fc_prediction = FC(self.full_name(),
-                                 size=self.class_dim,
-                                 act="softmax")
+        self._fc1 = Linear(input_dim=self.hid_dim, output_dim=self.hid_dim, act="tanh")
+        self._fc2 = Linear(input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
+        self._fc_prediction = Linear(input_dim=self.fc_hid_dim,
+                                     output_dim=self.class_dim,
+                                     act="softmax")
     def forward(self, inputs, label=None):
         emb = self.embedding(inputs)
-        o_np_mask = (inputs.numpy() != self.dict_dim).astype('float32')
+        o_np_mask = (inputs.numpy().reshape(-1, 1) != self.dict_dim).astype('float32')
         mask_emb = fluid.layers.expand(
             to_variable(o_np_mask), [1, self.hid_dim])
         emb = emb * mask_emb
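All four nets zero out the embeddings of padding tokens (id == dict_dim) by broadcasting a 0/1 mask across the embedding width; with ids now flat, the mask must first be reshaped to a column, which is what the .reshape(-1, 1) above does. A toy sketch of the mask construction (sizes are illustrative):

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.base import to_variable

    with fluid.dygraph.guard():
        dict_dim, hid_dim = 100, 4                                   # toy sizes
        ids = np.array([1, 2, dict_dim, dict_dim], dtype='int64')    # last two are padding
        mask = (ids.reshape(-1, 1) != dict_dim).astype('float32')    # [N, 1] of 0/1
        mask_emb = fluid.layers.expand(to_variable(mask), [1, hid_dim])  # [N, hid_dim]
        # multiplying the embedding output by mask_emb zeroes every padding row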
@@ -170,8 +166,8 @@ class BOW(fluid.dygraph.Layer):
 class GRU(fluid.dygraph.Layer):
-    def __init__(self, name_scope, dict_dim, batch_size, seq_len):
-        super(GRU, self).__init__(name_scope)
+    def __init__(self, dict_dim, batch_size, seq_len):
+        super(GRU, self).__init__()
         self.dict_dim = dict_dim
         self.emb_dim = 128
         self.hid_dim = 128
@@ -180,22 +176,21 @@ class GRU(fluid.dygraph.Layer):
         self.batch_size = batch_size
         self.seq_len = seq_len
         self.embedding = Embedding(
-            self.full_name(),
             size=[self.dict_dim + 1, self.emb_dim],
             dtype='float32',
             param_attr=fluid.ParamAttr(learning_rate=30),
             is_sparse=False)
         h_0 = np.zeros((self.batch_size, self.hid_dim), dtype="float32")
         h_0 = to_variable(h_0)
-        self._fc1 = FC(self.full_name(), size=self.hid_dim * 3, num_flatten_dims=2)
-        self._fc2 = FC(self.full_name(), size=self.fc_hid_dim, act="tanh")
-        self._fc_prediction = FC(self.full_name(),
-                                 size=self.class_dim,
-                                 act="softmax")
-        self._gru = DynamicGRU(self.full_name(), size=self.hid_dim, h_0=h_0)
+        self._fc1 = Linear(input_dim=self.hid_dim, output_dim=self.hid_dim * 3)
+        self._fc2 = Linear(input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
+        self._fc_prediction = Linear(input_dim=self.fc_hid_dim,
+                                     output_dim=self.class_dim,
+                                     act="softmax")
+        self._gru = DynamicGRU(size=self.hid_dim, h_0=h_0)
     def forward(self, inputs, label=None):
         emb = self.embedding(inputs)
-        o_np_mask = to_variable(inputs.numpy() != self.dict_dim).astype('float32')
+        o_np_mask = to_variable(inputs.numpy().reshape(-1, 1) != self.dict_dim).astype('float32')
         mask_emb = fluid.layers.expand(
             to_variable(o_np_mask), [1, self.hid_dim])
         emb = emb * mask_emb
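_fc1 widens each embedding to hid_dim * 3 because GRUUnit expects its update, reset, and candidate gates stacked in one input; since the 1.7 Linear applies along the last axis of a 3-D tensor, the old num_flatten_dims=2 is no longer needed. A dimension sketch (batch and sequence sizes are illustrative):

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.nn import Linear

    with fluid.dygraph.guard():
        hid_dim = 128
        fc1 = Linear(input_dim=hid_dim, output_dim=hid_dim * 3)
        emb = fluid.dygraph.to_variable(
            np.zeros((4, 5, hid_dim), dtype='float32'))   # [batch, seq, hid]
        gates = fc1(emb)                                  # [4, 5, 384]: 3 gates per step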
@@ -216,8 +211,8 @@ class GRU(fluid.dygraph.Layer):
 class BiGRU(fluid.dygraph.Layer):
-    def __init__(self, name_scope, dict_dim, batch_size, seq_len):
-        super(BiGRU, self).__init__(name_scope)
+    def __init__(self, dict_dim, batch_size, seq_len):
+        super(BiGRU, self).__init__()
         self.dict_dim = dict_dim
         self.emb_dim = 128
         self.hid_dim = 128
@@ -226,24 +221,23 @@ class BiGRU(fluid.dygraph.Layer):
         self.batch_size = batch_size
         self.seq_len = seq_len
         self.embedding = Embedding(
-            self.full_name(),
             size=[self.dict_dim + 1, self.emb_dim],
             dtype='float32',
             param_attr=fluid.ParamAttr(learning_rate=30),
             is_sparse=False)
         h_0 = np.zeros((self.batch_size, self.hid_dim), dtype="float32")
         h_0 = to_variable(h_0)
-        self._fc1 = FC(self.full_name(), size=self.hid_dim * 3, num_flatten_dims=2)
-        self._fc2 = FC(self.full_name(), size=self.fc_hid_dim, act="tanh")
-        self._fc_prediction = FC(self.full_name(),
-                                 size=self.class_dim,
-                                 act="softmax")
-        self._gru_forward = DynamicGRU(self.full_name(), size=self.hid_dim, h_0=h_0, is_reverse=False)
-        self._gru_backward = DynamicGRU(self.full_name(), size=self.hid_dim, h_0=h_0, is_reverse=True)
+        self._fc1 = Linear(input_dim=self.hid_dim, output_dim=self.hid_dim * 3)
+        self._fc2 = Linear(input_dim=self.hid_dim * 2, output_dim=self.fc_hid_dim, act="tanh")
+        self._fc_prediction = Linear(input_dim=self.fc_hid_dim,
+                                     output_dim=self.class_dim,
+                                     act="softmax")
+        self._gru_forward = DynamicGRU(size=self.hid_dim, h_0=h_0, is_reverse=False)
+        self._gru_backward = DynamicGRU(size=self.hid_dim, h_0=h_0, is_reverse=True)
     def forward(self, inputs, label=None):
         emb = self.embedding(inputs)
-        o_np_mask = to_variable(inputs.numpy() != self.dict_dim).astype('float32')
+        o_np_mask = to_variable(inputs.numpy().reshape(-1, 1) != self.dict_dim).astype('float32')
         mask_emb = fluid.layers.expand(
             to_variable(o_np_mask), [1, self.hid_dim])
         emb = emb * mask_emb
@@ -255,6 +249,7 @@ class BiGRU(fluid.dygraph.Layer):
         gru_backward_tanh = fluid.layers.tanh(gru_backward)
         encoded_vector = fluid.layers.concat(
             input=[gru_forward_tanh, gru_backward_tanh], axis=2)
+        encoded_vector = fluid.layers.reduce_max(encoded_vector, dim=1)
         fc_2 = self._fc2(encoded_vector)
         prediction = self._fc_prediction(fc_2)
         if label:
......
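The added reduce_max pools the concatenated forward/backward states over the time axis, turning [batch, seq_len, 2 * hid_dim] into [batch, 2 * hid_dim]; that is what lets _fc2 declare a fixed input_dim of hid_dim * 2. A shape sketch (batch and sequence sizes are illustrative):

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        batch, seq_len, hid_dim = 4, 5, 128
        fwd = fluid.dygraph.to_variable(
            np.zeros((batch, seq_len, hid_dim), dtype='float32'))
        bwd = fluid.dygraph.to_variable(
            np.zeros((batch, seq_len, hid_dim), dtype='float32'))
        enc = fluid.layers.concat(input=[fwd, bwd], axis=2)   # [4, 5, 256]
        enc = fluid.layers.reduce_max(enc, dim=1)             # [4, 256] -> _fc2 input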