Commit 06fd61f8 authored by J jinyuKing

update text.py

Parent f3e8f301
@@ -19,7 +19,6 @@ from __future__ import print_function
 import os
 import six
 import sys
 if six.PY2:
     reload(sys)
     sys.setdefaultencoding('utf8')
@@ -50,8 +49,8 @@ __all__ = [
     'BeamSearchDecoder', 'MultiHeadAttention', 'FFN',
     'TransformerEncoderLayer', 'TransformerEncoder', 'TransformerDecoderLayer',
     'TransformerDecoder', 'TransformerBeamSearchDecoder', 'Linear_chain_crf',
-    'Crf_decoding', 'SequenceTagging', 'GRUEncoderLayer', 'CNNEncoder',
-    'BOWEncoder', 'SimpleConvPoolLayer', 'GRUEncoder', 'DynamicGRU', 'LSTMEncoder'
+    'Crf_decoding', 'SequenceTagging', 'GRUEncoderLayer', 'SimCNNEncoder',
+    'SimBOWEncoder', 'SimpleConvPoolLayer', 'SimGRUEncoder', 'DynamicGRU', 'SimLSTMEncoder'
 ]
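The `__all__` change above renames the simnet encoders with a `Sim` prefix, which breaks existing imports. A minimal sketch of the updated import, assuming text.py is importable as a module named `text` (the real package path depends on where this file lives in your checkout):

```python
# Hypothetical import path -- the commit only tells us the file is text.py.
from text import SimCNNEncoder, SimBOWEncoder, SimGRUEncoder, SimLSTMEncoder

# The old names are no longer exported, so e.g.
# from text import CNNEncoder   # fails after this commit
```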
@@ -89,12 +88,12 @@ class RNNCell(Layer):
         batch_ref = flatten(batch_ref)[0]
         def _is_shape_sequence(seq):
-            if sys.version_info < (3,):
+            if sys.version_info < (3, ):
                 integer_types = (
                     int,
-                    long,)
+                    long, )
             else:
-                integer_types = (int,)
+                integer_types = (int, )
             """For shape, list/tuple of integer is the finest-grained objection"""
             if (isinstance(seq, list) or isinstance(seq, tuple)):
                 if reduce(
@@ -249,8 +248,8 @@ class BasicLSTMCell(RNNCell):
         self.use_customized_weight = False
         for _weights in [
                 forget_gate_weights, input_gate_weights, output_gate_weights,
                 cell_weights
         ]:
             for _key in _weights:
                 if _weights[_key] is not None:
@@ -275,7 +274,7 @@ class BasicLSTMCell(RNNCell):
                     is_bias=True)
         else:
             if "w" in forget_gate_weights and forget_gate_weights[
                     "w"] is not None:
                 self.fg_w = forget_gate_weights["w"]
             else:
                 if self._param_attr is not None and self._param_attr.name is not None:
@@ -289,7 +288,7 @@ class BasicLSTMCell(RNNCell):
                     dtype=self._dtype)
             if "h" in forget_gate_weights and forget_gate_weights[
                     "h"] is not None:
                 self.fg_h = forget_gate_weights["h"]
             else:
                 if self._param_attr is not None and self._param_attr.name is not None:
@@ -303,7 +302,7 @@ class BasicLSTMCell(RNNCell):
                     dtype=self._dtype)
             if "b" in forget_gate_weights and forget_gate_weights[
                     "b"] is not None:
                 self.fg_b = forget_gate_weights["b"]
             else:
                 if self._bias_attr is not None and self._bias_attr.name is not None:
@@ -318,7 +317,7 @@ class BasicLSTMCell(RNNCell):
                     is_bias=True)
             if "w" in input_gate_weights and input_gate_weights[
                     "w"] is not None:
                 self.ig_w = input_gate_weights["w"]
             else:
                 if self._param_attr is not None and self._param_attr.name is not None:
@@ -333,7 +332,7 @@ class BasicLSTMCell(RNNCell):
                     dtype=self._dtype)
             if "h" in input_gate_weights and input_gate_weights[
                     "h"] is not None:
                 self.ig_h = input_gate_weights["h"]
             else:
                 if self._param_attr is not None and self._param_attr.name is not None:
@@ -348,7 +347,7 @@ class BasicLSTMCell(RNNCell):
                     dtype=self._dtype)
             if "b" in input_gate_weights and input_gate_weights[
                     "b"] is not None:
                 self.ig_b = input_gate_weights["b"]
             else:
                 if self._bias_attr is not None and self._bias_attr.name is not None:
@@ -363,7 +362,7 @@ class BasicLSTMCell(RNNCell):
                     is_bias=True)
             if "w" in output_gate_weights and output_gate_weights[
                     "w"] is not None:
                 self.og_w = output_gate_weights["w"]
             else:
                 if self._param_attr is not None and self._param_attr.name is not None:
@@ -377,7 +376,7 @@ class BasicLSTMCell(RNNCell):
                     dtype=self._dtype)
             if "h" in output_gate_weights and output_gate_weights[
                     "h"] is not None:
                 self.og_h = output_gate_weights["h"]
             else:
                 if self._param_attr is not None and self._param_attr.name is not None:
@@ -392,7 +391,7 @@ class BasicLSTMCell(RNNCell):
                     dtype=self._dtype)
             if "b" in output_gate_weights and output_gate_weights[
                     "b"] is not None:
                 self.og_b = output_gate_weights["b"]
             else:
                 if self._bias_attr is not None and self._bias_attr.name is not None:
@@ -547,7 +546,7 @@ class BasicGRUCell(RNNCell):
         self.use_customized_weight = False
         for _weights in [
                 update_gate_weights, reset_gate_weights, cell_weights
         ]:
             for _key in _weights:
                 if _weights[_key] is not None:
@@ -603,7 +602,7 @@ class BasicGRUCell(RNNCell):
             # create the parameters of gates in gru
             if "w" in update_gate_weights and update_gate_weights[
                     "w"] is not None:
                 self.ug_w = update_gate_weights["w"]
             else:
                 if gate_param_attr is not None and gate_param_attr.name is not None:
@@ -617,7 +616,7 @@ class BasicGRUCell(RNNCell):
                     dtype=self._dtype)
             if "h" in update_gate_weights and update_gate_weights[
                     "h"] is not None:
                 self.ug_h = update_gate_weights["h"]
             else:
                 if gate_param_attr is not None and gate_param_attr.name is not None:
@@ -631,7 +630,7 @@ class BasicGRUCell(RNNCell):
                     dtype=self._dtype)
             if "b" in update_gate_weights and update_gate_weights[
                     "b"] is not None:
                 self.ug_b = update_gate_weights["b"]
             else:
                 if gate_bias_attr is not None and gate_bias_attr.name is not None:
@@ -647,7 +646,7 @@ class BasicGRUCell(RNNCell):
             # reset gate parameters
             if "w" in reset_gate_weights and reset_gate_weights[
                     "w"] is not None:
                 self.rg_w = reset_gate_weights["w"]
             else:
                 if gate_param_attr is not None and gate_param_attr.name is not None:
@@ -661,7 +660,7 @@ class BasicGRUCell(RNNCell):
                     dtype=self._dtype)
             if "h" in reset_gate_weights and reset_gate_weights[
                     "h"] is not None:
                 self.rg_h = reset_gate_weights["h"]
             else:
                 if gate_param_attr is not None and gate_param_attr.name is not None:
@@ -675,7 +674,7 @@ class BasicGRUCell(RNNCell):
                     dtype=self._dtype)
             if "b" in reset_gate_weights and reset_gate_weights[
                     "b"] is not None:
                 self.rg_b = reused_params["b"]
             else:
                 if gate_bias_attr is not None and gate_bias_attr.name is not None:
@@ -803,7 +802,7 @@ class RNN(fluid.dygraph.Layer):
                 new_state = fluid.layers.elementwise_mul(
                     new_state, step_mask,
                     axis=0) - fluid.layers.elementwise_mul(
                         state, (step_mask - 1), axis=0)
                 return new_state
         flat_inputs = flatten(inputs)
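The elementwise_mul pattern in the hunk above is a masked state update: step_mask is 1 for positions still inside their sequence and 0 for padding, so `new_state * mask - state * (mask - 1)` equals `mask * new_state + (1 - mask) * state`. A small NumPy check of that identity (NumPy stands in for the fluid tensors):

```python
import numpy as np

step_mask = np.array([1.0, 0.0])        # row 0 still running, row 1 finished
state = np.array([[10.0], [20.0]])      # state carried from the previous step
new_state = np.array([[11.0], [99.0]])  # candidate state for this step

# new * mask - old * (mask - 1)  ==  mask * new + (1 - mask) * old
masked = new_state * step_mask[:, None] - state * (step_mask[:, None] - 1)
print(masked)  # [[11.], [20.]] -- the finished row keeps its old state
```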
@@ -849,8 +848,8 @@ class RNN(fluid.dygraph.Layer):
             outputs = map_structure(
                 lambda x: ArrayWrapper(x),
                 step_outputs) if i == 0 else map_structure(
                     lambda x, x_array: x_array.append(x), step_outputs,
                     outputs)
         final_outputs = map_structure(
             lambda x: fluid.layers.stack(x.array,
@@ -919,7 +918,7 @@ class DynamicDecode(Layer):
             step_mask.stop_gradient = True
             new_state = layers.elementwise_mul(
                 state, step_mask, axis=0) - layers.elementwise_mul(
                     new_state, (step_mask - 1), axis=0)
             if convert_dtype(state_dtype) in ["bool"]:
                 new_state = layers.cast(new_state, dtype=state_dtype)
             return new_state
@@ -961,8 +960,8 @@ class DynamicDecode(Layer):
                 outputs = map_structure(
                     lambda x: ArrayWrapper(x),
                     step_outputs) if step_idx == 0 else map_structure(
                         lambda x, x_array: x_array.append(x), step_outputs,
                         outputs)
                 inputs, states, finished, sequence_lengths = (
                     next_inputs, next_states, next_finished,
                     next_sequence_lengths)
@@ -991,7 +990,7 @@ class DynamicDecode(Layer):
             return (final_outputs, final_states,
                     sequence_lengths) if self.return_length else (
                         final_outputs, final_states)
         else:
             return fluid.layers.dynamic_decode(
                 self.decoder,
@@ -1042,7 +1041,7 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
         x = layers.reshape(
             x, [0] * (len(x.shape) - var_dim_in_state
                       ) + [self.batch_size * self.beam_size] +
             [int(size) for size in x.shape[-var_dim_in_state + 2:]])
         x = layers.transpose(
             x,
             list(range((len(x.shape) + 1 - var_dim_in_state), len(x.shape))) +
@@ -1053,9 +1052,9 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
         var_dim_size = layers.shape(x)[self.var_dim_in_state]
         x = layers.reshape(
             x, [-1, self.beam_size] +
             [int(size)
              for size in x.shape[1:self.var_dim_in_state]] + [var_dim_size] +
             [int(size) for size in x.shape[self.var_dim_in_state + 1:]])
         return x
     def step(self, time, inputs, states, **kwargs):
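The reshape/transpose pair above is the usual beam-search bookkeeping: decoder state of shape [batch_size, beam_size, ...] is merged into [batch_size * beam_size, ...] so the cell sees an ordinary batch, then split back after the step. A NumPy sketch of the merge/split round trip (shapes are illustrative only):

```python
import numpy as np

batch_size, beam_size, hidden = 2, 4, 8
state = np.random.rand(batch_size, beam_size, hidden)

merged = state.reshape(batch_size * beam_size, hidden)  # merge batch and beam dims
split = merged.reshape(batch_size, beam_size, hidden)   # split them back apart

assert np.array_equal(state, split)  # the round trip loses nothing
```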
@@ -1118,7 +1117,7 @@ class PrePostProcessLayer(Layer):
             elif cmd == "d":  # add dropout
                 self.functors.append(lambda x: layers.dropout(
                     x, dropout_prob=dropout_rate, is_test=False)
                     if dropout_rate else x)
     def forward(self, x, residual=None):
         for i, cmd in enumerate(self.process_cmd):
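PrePostProcessLayer maps each letter of process_cmd to a functor and applies them in order in forward; note that the conditional sits inside the dropout lambda, so a falsy dropout_rate turns that functor into the identity. A plain-Python sketch of the same shape (fake_dropout is a stand-in for layers.dropout, not the real API):

```python
def make_dropout_functor(dropout_rate):
    # Stand-in for layers.dropout: scale by the keep probability.
    fake_dropout = lambda x: x * (1.0 - dropout_rate)
    # Same shape as the lambda above: identity when dropout_rate is falsy.
    return lambda x: fake_dropout(x) if dropout_rate else x

assert make_dropout_functor(0.0)(5.0) == 5.0  # no rate -> pass-through
assert make_dropout_functor(0.5)(4.0) == 2.0  # rate set -> dropout applied
```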
@@ -1219,7 +1218,7 @@ class MultiHeadAttention(Layer):
         # scale dot product attention
         product = layers.matmul(
-            x=q, y=k, transpose_y=True, alpha=self.d_model ** -0.5)
+            x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5)
         if attn_bias:
             product += attn_bias
         weights = layers.softmax(product)
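The only change in this hunk is whitespace around `**`, but the line is the heart of scaled dot-product attention: scores are q @ kᵀ scaled by d_model ** -0.5 (this class scales by the model width; the textbook formulation divides by the square root of the per-head key dimension). A NumPy sketch of the whole computation:

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, d_model):
    product = q @ k.swapaxes(-1, -2) * d_model ** -0.5   # scaled scores
    weights = np.exp(product - product.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)            # softmax over keys
    return weights @ v

q = np.random.rand(2, 5, 16)  # [batch, query_len, dim]
k = np.random.rand(2, 7, 16)  # [batch, key_len, dim]
v = np.random.rand(2, 7, 16)
print(scaled_dot_product_attention(q, k, v, d_model=16).shape)  # (2, 5, 16)
```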
@@ -1309,6 +1308,7 @@ class TransformerEncoderLayer(Layer):
                  reused_ffn_weights={"reused_fc1": None,
                                      "reused_fc2": None},
                  reused_post_ffn_layernorm=None):
         super(TransformerEncoderLayer, self).__init__()
         self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
@@ -1556,7 +1556,7 @@ class TransformerDecoder(Layer):
     ]
-# TODO: we should merge GRUCell with BasicGRUCell
+#TODO: we should merge GRUCell with BasicGRUCell
 class GRUCell(RNNCell):
     def __init__(self,
                  input_size,
@@ -1590,7 +1590,7 @@ class GRUCell(RNNCell):
         return [self.hidden_size]
-# TODO: we should merge GRUCell with BasicGRUCell
+#TODO: we should merge GRUCell with BasicGRUCell
 class GRUEncoderCell(RNNCell):
     def __init__(self,
                  num_layers,
@@ -1606,7 +1606,7 @@ class GRUEncoderCell(RNNCell):
             self.gru_cells.append(
                 self.add_sublayer(
                     "gru_%d" % i,
-                    # BasicGRUCell(
+                    #BasicGRUCell(
                     GRUCell(
                         input_size=input_size if i == 0 else hidden_size,
                         hidden_size=hidden_size,
@@ -1673,6 +1673,7 @@ class Linear_chain_crf(fluid.dygraph.Layer):
         self._transition = value
     def forward(self, input, label, length=None):
         alpha = self._helper.create_variable_for_type_inference(
             dtype=self._dtype)
         emission_exps = self._helper.create_variable_for_type_inference(
@@ -1723,6 +1724,7 @@ class Crf_decoding(fluid.dygraph.Layer):
         self._transition = value
     def forward(self, input, label=None, length=None):
         viterbi_path = self._helper.create_variable_for_type_inference(
             dtype=self._dtype)
         this_inputs = {
@@ -1919,7 +1921,7 @@ class SimpleConvPoolLayer(Layer):
         return x
-class CNNEncoder(Layer):
+class SimCNNEncoder(Layer):
     """
     simple CNNEncoder for simnet
     """
@@ -1933,7 +1935,7 @@ class CNNEncoder(Layer):
                  padding_idx,
                  act
                  ):
-        super(CNNEncoder, self).__init__()
+        super(SimCNNEncoder, self).__init__()
         self.dict_size = dict_size
         self.emb_dim = emb_dim
         self.filter_size = filter_size
@@ -1962,7 +1964,7 @@ class CNNEncoder(Layer):
         emb_out=self.cnn_layer(emb_reshape)
         return emb_out
-class BOWEncoder(Layer):
+class SimBOWEncoder(Layer):
     """
     simple BOWEncoder for simnet
     """
@@ -1973,7 +1975,7 @@ class BOWEncoder(Layer):
                  seq_len,
                  padding_idx
                  ):
-        super(BOWEncoder, self).__init__()
+        super(SimBOWEncoder, self).__init__()
         self.dict_size = dict_size
         self.bow_dim = bow_dim
         self.seq_len = seq_len
@@ -2034,7 +2036,7 @@ class DynamicGRU(fluid.dygraph.Layer):
         res = fluid.layers.concat(res, axis=1)
         return res
-class GRUEncoder(Layer):
+class SimGRUEncoder(Layer):
     """
     simple GRUEncoder for simnet
     """
@@ -2046,7 +2048,7 @@ class GRUEncoder(Layer):
                  padding_idx,
                  seq_len
                  ):
-        super(GRUEncoder, self).__init__()
+        super(SimGRUEncoder, self).__init__()
         self.dict_size = dict_size
         self.emb_dim = emb_dim
         self.gru_dim = gru_dim
@@ -2071,7 +2073,7 @@ class GRUEncoder(Layer):
         gru = fluid.layers.tanh(gru)
         return gru
-class LSTMEncoder(Layer):
+class SimLSTMEncoder(Layer):
     """
     simple LSTMEncoder for simnet
     """
@@ -2087,7 +2089,7 @@ class LSTMEncoder(Layer):
         """
         initialize
         """
-        super(LSTMEncoder, self).__init__()
+        super(SimLSTMEncoder, self).__init__()
         self.dict_size = dict_size
         self.emb_dim = emb_dim
         self.lstm_dim = lstm_dim
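Every hunk from SimCNNEncoder onward is a pure rename: the class name and its super() call change together and the bodies are untouched, so behavior is identical. If downstream code still uses the old names, backward-compatible aliases are a one-line fix each (a sketch; these aliases are not part of this commit):

```python
# Hypothetical compatibility shims, placed after the class definitions.
CNNEncoder = SimCNNEncoder
BOWEncoder = SimBOWEncoder
GRUEncoder = SimGRUEncoder
LSTMEncoder = SimLSTMEncoder
```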