Seq2Seq 示例中 `BaseModel` 类（`class BaseModel(object):`）的 `_build_data` 方法里的 `label` 是什么意思？
Created by: peterzsj6
class BaseModel(object):
    """Base Seq2Seq model: hyper-parameters, embedding closures and feeds.

    NOTE(review): ``fluid`` and ``uniform_initializer`` are not defined in
    this snippet; they must exist at module level (presumably
    ``paddle.fluid`` and a helper returning a uniform initializer for a
    given scale) — confirm.
    """

    def __init__(self,
                 hidden_size,
                 src_vocab_size,
                 tar_vocab_size,
                 batch_size,
                 num_layers=1,
                 init_scale=0.1,
                 dropout=None,
                 beam_start_token=1,
                 beam_end_token=2,
                 beam_max_step_num=100):
        """Store the model hyper-parameters and build the embedding closures.

        Args:
            hidden_size: width of each embedding vector (second dim of both
                embedding tables).
            src_vocab_size: number of rows in the source embedding table.
            tar_vocab_size: number of rows in the target embedding table.
            batch_size: stored as ``self.batch_size``; not used here.
            num_layers: stored only; presumably the RNN depth used by code
                not shown here.
            init_scale: passed to ``uniform_initializer`` for both
                embedding tables.
            dropout: stored only; ``None`` by default.
            beam_start_token, beam_end_token, beam_max_step_num: stored only;
                presumably the start/end token ids and step limit for beam
                search decoding implemented elsewhere — confirm.
        """
        self.hidden_size = hidden_size
        # Honour the constructor arguments instead of a module-level
        # constant, so source and target vocabularies may differ in size.
        self.src_vocab_size = src_vocab_size
        self.tar_vocab_size = tar_vocab_size
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.init_scale = init_scale
        self.dropout = dropout
        self.beam_start_token = beam_start_token
        self.beam_end_token = beam_end_token
        self.beam_max_step_num = beam_max_step_num

        # The embedders are closures: no fluid op is created until they are
        # called. They read the vocab sizes from ``self`` at call time, but
        # capture this call's ``init_scale`` argument. The fixed parameter
        # names presumably make repeated calls share one weight table —
        # confirm against fluid's ParamAttr naming rules.
        self.src_embeder = lambda x: fluid.embedding(
            input=x,
            size=[self.src_vocab_size, self.hidden_size],
            dtype='float32',
            is_sparse=False,
            param_attr=fluid.ParamAttr(
                name='source_embedding',
                initializer=uniform_initializer(init_scale)))
        self.tar_embeder = lambda x: fluid.embedding(
            input=x,
            size=[self.tar_vocab_size, self.hidden_size],
            dtype='float32',
            is_sparse=False,
            param_attr=fluid.ParamAttr(
                name='target_embedding',
                initializer=uniform_initializer(init_scale)))

    def _build_data(self):
        """Declare the input placeholders fed to the program.

        Creates ``self.src``, ``self.src_sequence_length``, ``self.tar``,
        ``self.tar_sequence_length`` and ``self.label`` with ``fluid.data``.
        A ``None`` entry in ``shape`` leaves that dimension unspecified
        (presumably batch size and sequence length — confirm with the
        data reader).
        """
        # Source token ids: 2-D int64.
        self.src = fluid.data(name="src", shape=[None, None], dtype='int64')
        # 1-D int32; presumably the true length of each source row.
        self.src_sequence_length = fluid.data(
            name="src_sequence_length", shape=[None], dtype='int32')

        # Target token ids: 2-D int64; presumably the decoder input.
        self.tar = fluid.data(name="tar", shape=[None, None], dtype='int64')
        # 1-D int32; presumably the true length of each target row.
        self.tar_sequence_length = fluid.data(
            name="tar_sequence_length", shape=[None], dtype='int32')
        # ``label`` is 3-D int64 with a trailing dimension of 1.
        # NOTE(review): presumably the gold token id the decoder must predict
        # at each target position (in the usual seq2seq reader, ``tar``
        # shifted left by one: teacher forcing). It is then used as the hard
        # label of the cross-entropy loss, whose ``[..., 1]`` hard-label
        # layout explains the trailing 1. Confirm against the reader and loss
        # code, which are not shown here.
        self.label = fluid.data(
            name="label", shape=[None, None, 1], dtype='int64')