Commit eb20b652, authored by G guosheng

Add more unit tests for APIs in text.py.

Rename some APIs in text.py.
Parent 60917f41
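For reference, the public names renamed by this commit (defined in hapi.text.text and re-exported from hapi.text) map old to new as follows; a short sketch of the new-style imports:

    # Renamed in this commit:
    #   GRUEncoderLayer  -> GRUEncoder
    #   Linear_chain_crf -> LinearChainCRF
    #   Crf_decoding     -> CRFDecoding
    from hapi.text.text import GRUEncoder, LinearChainCRF, CRFDecoding, SequenceTagging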
@@ -16,7 +16,7 @@ from paddle.fluid.dygraph.nn import Linear, Embedding
 from paddle.fluid.dygraph.base import to_variable
 import numpy as np
 from hapi.model import Model
-from hapi.text.text import GRUEncoderLayer as BiGRUEncoder
+from hapi.text.text import GRUEncoder as BiGRUEncoder
 from hapi.text.test import BOWEncoder, CNNEncoder, GRUEncoder
@@ -36,14 +36,18 @@ class CNN(Model):
             dict_size=self.dict_dim + 1,
             emb_dim=self.emb_dim,
             seq_len=self.seq_len,
-            filter_size= self.win_size,
-            num_filters= self.hid_dim,
-            hidden_dim= self.hid_dim,
+            filter_size=self.win_size,
+            num_filters=self.hid_dim,
+            hidden_dim=self.hid_dim,
             padding_idx=None,
             act='tanh')
-        self._fc1 = Linear(input_dim = self.hid_dim*self.seq_len, output_dim=self.fc_hid_dim, act="softmax")
-        self._fc_prediction = Linear(input_dim = self.fc_hid_dim,
-                                     output_dim = self.class_dim,
-                                     act="softmax")
+        self._fc1 = Linear(
+            input_dim=self.hid_dim * self.seq_len,
+            output_dim=self.fc_hid_dim,
+            act="softmax")
+        self._fc_prediction = Linear(
+            input_dim=self.fc_hid_dim,
+            output_dim=self.class_dim,
+            act="softmax")

     def forward(self, inputs):
@@ -69,10 +73,13 @@ class BOW(Model):
             padding_idx=None,
             bow_dim=self.hid_dim,
             seq_len=self.seq_len)
-        self._fc1 = Linear(input_dim = self.hid_dim, output_dim=self.hid_dim, act="tanh")
-        self._fc2 = Linear(input_dim = self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
-        self._fc_prediction = Linear(input_dim = self.fc_hid_dim,
-                                     output_dim = self.class_dim,
-                                     act="softmax")
+        self._fc1 = Linear(
+            input_dim=self.hid_dim, output_dim=self.hid_dim, act="tanh")
+        self._fc2 = Linear(
+            input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
+        self._fc_prediction = Linear(
+            input_dim=self.fc_hid_dim,
+            output_dim=self.class_dim,
+            act="softmax")

     def forward(self, inputs):
@@ -94,8 +101,10 @@ class GRU(Model):
         self.class_dim = 2
         self.batch_size = batch_size
         self.seq_len = seq_len
-        self._fc1 = Linear(input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
-        self._fc_prediction = Linear(input_dim=self.fc_hid_dim,
-                                     output_dim=self.class_dim,
-                                     act="softmax")
+        self._fc1 = Linear(
+            input_dim=self.hid_dim, output_dim=self.fc_hid_dim, act="tanh")
+        self._fc_prediction = Linear(
+            input_dim=self.fc_hid_dim,
+            output_dim=self.class_dim,
+            act="softmax")
         self._encoder = GRUEncoder(
@@ -130,9 +139,11 @@ class BiGRU(Model):
             is_sparse=False)
         h_0 = np.zeros((self.batch_size, self.hid_dim), dtype="float32")
         h_0 = to_variable(h_0)
-        self._fc1 = Linear(input_dim = self.hid_dim, output_dim=self.hid_dim*3)
-        self._fc2 = Linear(input_dim = self.hid_dim*2, output_dim=self.fc_hid_dim, act="tanh")
-        self._fc_prediction = Linear(input_dim=self.fc_hid_dim,
-                                     output_dim=self.class_dim,
-                                     act="softmax")
+        self._fc1 = Linear(input_dim=self.hid_dim, output_dim=self.hid_dim * 3)
+        self._fc2 = Linear(
+            input_dim=self.hid_dim * 2, output_dim=self.fc_hid_dim, act="tanh")
+        self._fc_prediction = Linear(
+            input_dim=self.fc_hid_dim,
+            output_dim=self.class_dim,
+            act="softmax")
         self._encoder = BiGRUEncoder(
@@ -144,7 +155,8 @@ class BiGRU(Model):

     def forward(self, inputs):
         emb = self.embedding(inputs)
-        emb = fluid.layers.reshape(emb, shape=[self.batch_size, -1, self.hid_dim])
+        emb = fluid.layers.reshape(
+            emb, shape=[self.batch_size, -1, self.hid_dim])
         fc_1 = self._fc1(emb)
         encoded_vector = self._encoder(fc_1)
         encoded_vector = fluid.layers.tanh(encoded_vector)
...
@@ -21,7 +21,7 @@ import paddle.fluid.layers as layers
 from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer
 from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
 from hapi.model import Model, CrossEntropy, Loss
-from hapi.text import TransformerCell, TransformerBeamSearchDecoder, DynamicDecode
+from hapi.text import TransformerBeamSearchDecoder, DynamicDecode


 def position_encoding_init(n_position, d_pos_vec):
@@ -606,6 +606,27 @@ class Transformer(Model):
         return predict


+class TransformerCell(Layer):
+    """
+    Let inputs=(trg_word, trg_pos), states=cache to make Transformer can be
+    used as RNNCell
+    """
+
+    def __init__(self, decoder):
+        super(TransformerCell, self).__init__()
+        self.decoder = decoder
+
+    def forward(self, inputs, states, trg_src_attn_bias, enc_output,
+                static_caches):
+        trg_word, trg_pos = inputs
+        for cache, static_cache in zip(states, static_caches):
+            cache.update(static_cache)
+        logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
+                              enc_output, states)
+        new_states = [{"k": cache["k"], "v": cache["v"]} for cache in states]
+        return logits, new_states
+
+
 class InferTransformer(Transformer):
     """
     model for prediction
...
@@ -25,8 +25,8 @@ from paddle.fluid.dygraph import Embedding, Linear, Layer
 from paddle.fluid.layers import BeamSearchDecoder
 import hapi.text as text
 from hapi.model import Model, Input, set_device
-from hapi.text import BasicLSTMCell, BasicGRUCell, RNN, DynamicDecode, MultiHeadAttention, TransformerEncoder
-from hapi.text import *
+# from hapi.text.text import BasicLSTMCell, BasicGRUCell, RNN, DynamicDecode, MultiHeadAttention, TransformerEncoder, TransformerCell
+from hapi.text.text import *


 def sigmoid(x):
@@ -187,7 +187,7 @@ class TestBasicLSTM(ModuleApiTest):
             Input(
                 [None, None, self.inputs[-1].shape[-1]],
                 "float32",
-                name="input")
+                name="input"),
         ]
         return inputs
@@ -216,7 +216,7 @@ class TestBasicGRU(ModuleApiTest):
             Input(
                 [None, None, self.inputs[-1].shape[-1]],
                 "float32",
-                name="input")
+                name="input"),
         ]
         return inputs
@@ -270,10 +270,9 @@ class TestBeamSearch(ModuleApiTest):
             Input(
                 [None, self.inputs[0].shape[-1]],
                 "float32",
-                name="init_hidden"), Input(
-                    [None, self.inputs[1].shape[-1]],
-                    "float32",
-                    name="init_cell")
+                name="init_hidden"),
+            Input(
+                [None, self.inputs[1].shape[-1]], "float32", name="init_cell"),
         ]
         return inputs
@@ -328,10 +327,11 @@ class TestTransformerEncoder(ModuleApiTest):
             Input(
                 [None, None, self.inputs[0].shape[-1]],
                 "float32",
-                name="enc_input"), Input(
-                    [None, self.inputs[1].shape[1], None, None],
-                    "float32",
-                    name="attn_bias")
+                name="enc_input"),
+            Input(
+                [None, self.inputs[1].shape[1], None, None],
+                "float32",
+                name="attn_bias"),
         ]
         return inputs
@@ -395,16 +395,19 @@ class TestTransformerDecoder(TestTransformerEncoder):
             Input(
                 [None, None, self.inputs[0].shape[-1]],
                 "float32",
-                name="dec_input"), Input(
-                    [None, None, self.inputs[0].shape[-1]],
-                    "float32",
-                    name="enc_output"), Input(
-                        [None, self.inputs[-1].shape[1], None, None],
-                        "float32",
-                        name="self_attn_bias"), Input(
-                            [None, self.inputs[-1].shape[1], None, None],
-                            "float32",
-                            name="cross_attn_bias")
+                name="dec_input"),
+            Input(
+                [None, None, self.inputs[0].shape[-1]],
+                "float32",
+                name="enc_output"),
+            Input(
+                [None, self.inputs[-1].shape[1], None, None],
+                "float32",
+                name="self_attn_bias"),
+            Input(
+                [None, self.inputs[-1].shape[1], None, None],
+                "float32",
+                name="cross_attn_bias"),
         ]
         return inputs
@@ -414,16 +417,21 @@ class TestTransformerDecoder(TestTransformerEncoder):
 class TestTransformerBeamSearchDecoder(ModuleApiTest):
     def setUp(self):
-        shape = (8, 32)
         self.inputs = [
-            np.random.random(shape).astype("float32"),
-            np.random.random(shape).astype("float32")
+            # encoder output: [batch_size, seq_len, hidden_size]
+            np.random.random([2, 5, 128]).astype("float32"),
+            # cross attention bias: [batch_size, n_head, seq_len, seq_len]
+            np.random.randint(0, 1, [2, 2, 1, 5]).astype("float32") * -1e9
         ]
         self.outputs = None
         self.attrs = {
             "vocab_size": 100,
-            "embed_dim": 32,
-            "hidden_size": 32,
+            "n_layer": 2,
+            "n_head": 2,
+            "d_key": 64,
+            "d_value": 64,
+            "d_model": 128,
+            "d_inner_hid": 128
         }
         self.param_states = {}
@@ -445,13 +453,24 @@ class TestTransformerBeamSearchDecoder(ModuleApiTest):
                    eos_id=1,
                    beam_size=4,
                    max_step_num=20):
-        embedder = Embedding(size=[vocab_size, d_model])
+        self.beam_size = beam_size
+
+        def embeder_init(self, size):
+            Layer.__init__(self)
+            self.embedder = Embedding(size)
+
+        Embedder = type("Embedder", (Layer, ), {
+            "__init__": embeder_init,
+            "forward": lambda self, word, pos: self.embedder(word)
+        })
+        embedder = Embedder(size=[vocab_size, d_model])
         output_layer = Linear(d_model, vocab_size)
-        decoder = TransformerDecoder(n_layer, n_head, d_key, d_value, d_model,
-                                     d_inner_hid, prepostprocess_dropout,
-                                     attention_dropout, relu_dropout,
-                                     preprocess_cmd, postprocess_cmd)
-        transformer_cell = TransformerCell(decoder)
+        self.decoder = TransformerDecoder(
+            n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
+            prepostprocess_dropout, attention_dropout, relu_dropout,
+            preprocess_cmd, postprocess_cmd)
+        transformer_cell = TransformerCell(self.decoder, embedder,
+                                           output_layer)
         self.beam_search_decoder = DynamicDecode(
             TransformerBeamSearchDecoder(
                 transformer_cell,
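The `type("Embedder", ...)` call above builds a small Layer on the fly so that the cell's embedding_fn accepts `(word, pos)` but only embeds the word. Written as an ordinary class, an equivalent sketch (not part of the commit) would be:

    class Embedder(Layer):
        # Embeds trg_word and ignores trg_pos, mirroring the dynamically
        # created Embedder in the test above.
        def __init__(self, size):
            super(Embedder, self).__init__()
            self.embedder = Embedding(size)

        def forward(self, word, pos):
            return self.embedder(word)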
@@ -464,23 +483,12 @@ class TestTransformerBeamSearchDecoder(ModuleApiTest):
     @staticmethod
     def model_forward(self, enc_output, trg_src_attn_bias):
-        caches = [{
-            "k": layers.fill_constant_batch_size_like(
-                input=enc_output,
-                shape=[-1, self.n_head, 0, self.d_key],
-                dtype=enc_output.dtype,
-                value=0),
-            "v": layers.fill_constant_batch_size_like(
-                input=enc_output,
-                shape=[-1, self.n_head, 0, self.d_value],
-                dtype=enc_output.dtype,
-                value=0),
-        } for i in range(self.n_layer)]
+        caches = self.decoder.prepare_incremental_cache(enc_output)
         enc_output = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
             enc_output, self.beam_size)
         trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
             trg_src_attn_bias, self.beam_size)
-        static_caches = self.decoder.decoder.prepare_static_cache(enc_output)
+        static_caches = self.decoder.prepare_static_cache(enc_output)
         rs, _ = self.beam_search_decoder(
             inits=caches,
             enc_output=enc_output,
@@ -491,12 +499,42 @@ class TestTransformerBeamSearchDecoder(ModuleApiTest):
     def make_inputs(self):
         inputs = [
             Input(
-                [None, self.inputs[0].shape[-1]],
-                "float32",
-                name="init_hidden"), Input(
-                    [None, self.inputs[1].shape[-1]],
-                    "float32",
-                    name="init_cell")
+                [None, None, self.inputs[0].shape[-1]],
+                "float32",
+                name="enc_output"),
+            Input(
+                [None, self.inputs[1].shape[1], None, None],
+                "float32",
+                name="trg_src_attn_bias"),
+        ]
+        return inputs
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestSequenceTagging(ModuleApiTest):
+    def setUp(self):
+        shape = (2, 4, 128)
+        self.inputs = [np.random.random(shape).astype("float32")]
+        self.outputs = None
+        self.attrs = {"input_size": 128, "hidden_size": 128}
+        self.param_states = {}
+
+    @staticmethod
+    def model_init(self, input_size, hidden_size):
+        self.module = SequenceTagging(input_size, hidden_size)
+
+    @staticmethod
+    def model_forward(self, inputs):
+        return self.gru(inputs)[0]
+
+    def make_inputs(self):
+        inputs = [
+            Input(
+                [None, None, self.inputs[-1].shape[-1]],
+                "float32",
+                name="input"),
         ]
         return inputs
...
@@ -28,6 +28,6 @@ from hapi.text.text import TransformerBeamSearchDecoder as TransformerBeamSearch
 from hapi.text.text import GRUCell as GRUCell
 from hapi.text.text import GRUEncoderCell as GRUEncoderCell
 from hapi.text.text import BiGRU as BiGRU
-from hapi.text.text import Linear_chain_crf as Linear_chain_crf
-from hapi.text.text import Crf_decoding as Crf_decoding
+from hapi.text.text import LinearChainCRF as LinearChainCRF
+from hapi.text.text import CRFDecoding as CRFDecoding
 from hapi.text.text import SequenceTagging as SequenceTagging
@@ -49,7 +49,7 @@ __all__ = [
     'BeamSearchDecoder', 'MultiHeadAttention', 'FFN',
     'TransformerEncoderLayer', 'TransformerEncoder', 'TransformerDecoderLayer',
     'TransformerDecoder', 'TransformerCell', 'TransformerBeamSearchDecoder',
-    'Linear_chain_crf', 'Crf_decoding', 'SequenceTagging', 'GRUEncoderLayer'
+    'LinearChainCRF', 'CRFDecoding', 'SequenceTagging', 'GRUEncoder'
 ]
@@ -1008,18 +1008,38 @@ class TransformerCell(Layer):
     used as RNNCell
     """

-    def __init__(self, decoder):
+    def __init__(self, decoder, embedding_fn=None, output_fn=None):
+        super(TransformerCell, self).__init__()
         self.decoder = decoder
+        self.embedding_fn = embedding_fn
+        self.output_fn = output_fn

-    def __call__(self, inputs, states, trg_src_attn_bias, enc_output,
-                 static_caches):
+    def forward(self, inputs, states, trg_src_attn_bias, enc_output,
+                static_caches):
         trg_word, trg_pos = inputs
         for cache, static_cache in zip(states, static_caches):
             cache.update(static_cache)
-        logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
-                              enc_output, states)
+        if self.embedding_fn is not None:
+            dec_input = self.embedding_fn(trg_word, trg_pos)
+            outputs = self.decoder(dec_input, enc_output, None,
+                                   trg_src_attn_bias, states)
+        else:
+            outputs = self.decoder(trg_word, trg_pos, enc_output, None,
+                                   trg_src_attn_bias, states)
+        if self.output_fn is not None:
+            outputs = self.output_fn(outputs)
+        if len(outputs.shape) == 3:
+            # squeeze to adapt to BeamSearchDecoder which use 2D logits
+            outputs = layers.squeeze(outputs, [1])
         new_states = [{"k": cache["k"], "v": cache["v"]} for cache in states]
-        return logits, new_states
+        return outputs, new_states
+
+    @property
+    def state_shape(self):
+        return [{
+            "k": [self.n_head, 0, self.d_key],
+            "v": [self.n_head, 0, self.d_value],
+        } for i in range(len(self.n_layer))]


 class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
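With the reworked cell, the decoder, an embedding callable, and an output projection are bundled together, and the decoder's caches serve as the RNN state. A minimal sketch of one decode step, based only on the signatures shown in this diff (tensor preparation elided; decoder, embedder and output_layer constructed as in the unit test above):

    cell = TransformerCell(decoder, embedder, output_layer)
    caches = decoder.prepare_incremental_cache(enc_output)    # per-layer empty k/v
    static_caches = decoder.prepare_static_cache(enc_output)  # cross-attention k/v
    # returns 2D logits (squeezed for BeamSearchDecoder) plus the updated caches
    logits, caches = cell((trg_word, trg_pos), caches,
                          trg_src_attn_bias, enc_output, static_caches)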
@@ -1521,6 +1541,11 @@ class TransformerDecoder(Layer):
                  preprocess_cmd, postprocess_cmd):
         super(TransformerDecoder, self).__init__()
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.d_key = d_key
+        self.d_value = d_value
+
         self.decoder_layers = list()
         for i in range(n_layer):
             self.decoder_layers.append(
@@ -1555,6 +1580,20 @@ class TransformerDecoder(Layer):
             for decoder_layer in self.decoder_layers
         ]

+    def prepare_incremental_cache(self, enc_output):
+        return [{
+            "k": layers.fill_constant_batch_size_like(
+                input=enc_output,
+                shape=[-1, self.n_head, 0, self.d_key],
+                dtype=enc_output.dtype,
+                value=0),
+            "v": layers.fill_constant_batch_size_like(
+                input=enc_output,
+                shape=[-1, self.n_head, 0, self.d_value],
+                dtype=enc_output.dtype,
+                value=0),
+        } for i in range(self.n_layer)]
+

 #TODO: we should merge GRUCell with BasicGRUCell
 class GRUCell(RNNCell):
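prepare_incremental_cache returns one dict per decoder layer whose "k" and "v" tensors start with length 0 along the time axis, giving incremental decoding a place to accumulate per-step keys and values; the companion prepare_static_cache (used alongside it in the unit test above) presumably holds the cross-attention keys/values computed once from the encoder output. A hedged sketch of the shapes involved:

    caches = decoder.prepare_incremental_cache(enc_output)
    # caches[i]["k"]: [batch_size, n_head, 0, d_key]   -> [batch_size, n_head, t, d_key]
    # caches[i]["v"]: [batch_size, n_head, 0, d_value] -> [batch_size, n_head, t, d_value]
    # (growth along axis 2 at each step is assumed to happen inside the
    #  decoder's self-attention, not shown in this diff)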
@@ -1651,9 +1690,9 @@ class BiGRU(fluid.dygraph.Layer):
         return bi_merge


-class Linear_chain_crf(fluid.dygraph.Layer):
+class LinearChainCRF(Layer):
     def __init__(self, param_attr, size=None, is_test=False, dtype='float32'):
-        super(Linear_chain_crf, self).__init__()
+        super(LinearChainCRF, self).__init__()
         self._param_attr = param_attr
         self._dtype = dtype
@@ -1702,9 +1741,9 @@ class Linear_chain_crf(fluid.dygraph.Layer):
         return log_likelihood


-class Crf_decoding(fluid.dygraph.Layer):
+class CRFDecoding(Layer):
     def __init__(self, param_attr, size=None, is_test=False, dtype='float32'):
-        super(Crf_decoding, self).__init__()
+        super(CRFDecoding, self).__init__()
         self._dtype = dtype
         self._size = size
@@ -1742,7 +1781,7 @@ class Crf_decoding(fluid.dygraph.Layer):
         return viterbi_path


-class GRUEncoderLayer(Layer):
+class GRUEncoder(Layer):
     def __init__(self,
                  input_dim,
                  grnn_hidden_dim,
@@ -1750,7 +1789,7 @@ class GRUEncoderLayer(Layer):
                  num_layers=1,
                  h_0=None,
                  is_bidirection=False):
-        super(GRUEncoderLayer, self).__init__()
+        super(GRUEncoder, self).__init__()
         self.h_0 = h_0
         self.num_layers = num_layers
         self.is_bidirection = is_bidirection
@@ -1849,7 +1888,7 @@ class SequenceTagging(fluid.dygraph.Layer):
             force_cpu=True,
             name='h_0')

-        self.gru_encoder = GRUEncoderLayer(
+        self.gru_encoder = GRUEncoder(
             input_dim=self.grnn_hidden_dim,
             grnn_hidden_dim=self.grnn_hidden_dim,
             init_bound=self.init_bound,
@@ -1866,12 +1905,12 @@ class SequenceTagging(fluid.dygraph.Layer):
                 regularizer=fluid.regularizer.L2DecayRegularizer(
                     regularization_coeff=1e-4)))

-        self.linear_chain_crf = Linear_chain_crf(
+        self.linear_chain_crf = LinearChainCRF(
             param_attr=fluid.ParamAttr(
                 name='linear_chain_crfw', learning_rate=self.crf_lr),
             size=self.num_labels)

-        self.crf_decoding = Crf_decoding(
+        self.crf_decoding = CRFDecoding(
             param_attr=fluid.ParamAttr(
                 name='crfw', learning_rate=self.crf_lr),
             size=self.num_labels)
...