Unverified commit 44e359c4, authored by Bai Yifan, committed by GitHub

split two text input (#336)

Parent e4d61e5e
@@ -54,8 +54,7 @@ def train_one_epoch(model, architect, train_loader, valid_loader, optimizer,
         else:
             loss.backward()
-            grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5.0)
-            optimizer.minimize(loss, grad_clip=grad_clip)
+            optimizer.minimize(loss)
         model.clear_gradients()
         batch_size = train_data[0].shape[0]
@@ -161,11 +160,13 @@ def main():
         if p.name not in [a.name for a in model.arch_parameters()]
     ]
+    clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)
     optimizer = fluid.optimizer.MomentumOptimizer(
         learning_rate,
         0.9,
         regularization=fluid.regularizer.L2DecayRegularizer(3e-4),
-        parameter_list=model_parameters)
+        parameter_list=model_parameters,
+        grad_clip=clip)
     train_loader = fluid.io.DataLoader.from_generator(
         capacity=1024,
......
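The hunks above move gradient clipping from the per-step minimize() call to the optimizer itself. A minimal sketch of the new pattern, assuming Paddle 1.8-era dygraph APIs as used in this repo (the tiny Linear model, learning-rate value, and input data are illustrative, not part of the commit):

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    model = fluid.dygraph.Linear(4, 2)  # stand-in for the real search model

    # clipping is now configured once on the optimizer, instead of being
    # passed to minimize() on every step via fluid.dygraph_grad_clip
    clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)
    optimizer = fluid.optimizer.MomentumOptimizer(
        learning_rate=0.025,  # illustrative value
        momentum=0.9,
        regularization=fluid.regularizer.L2DecayRegularizer(3e-4),
        parameter_list=model.parameters(),
        grad_clip=clip)

    x = fluid.dygraph.to_variable(np.random.rand(8, 4).astype('float32'))
    loss = fluid.layers.reduce_mean(model(x))
    loss.backward()
    optimizer.minimize(loss)  # gradients are clipped internally before the update
    model.clear_gradients()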
@@ -113,16 +113,46 @@ class BertModelLayer(Layer):
         """
         forward
         """
-        src_emb = self._src_emb(src_ids)
-        pos_emb = self._pos_emb(position_ids)
-        sent_emb = self._sent_emb(sentence_ids)
-        emb_out = src_emb + pos_emb
-        emb_out = emb_out + sent_emb
-        emb_out = self._emb_fac(emb_out)
+        ids = np.squeeze(src_ids.numpy())
+        sids = np.squeeze(sentence_ids.numpy())
+        batchsize = ids.shape[0]
+        ids_0 = ids[((sids == 0) & (ids != 0))]
+        seqlen_0 = ((sids == 0) & (ids != 0)).astype(np.int64).sum(1)
+        y_0 = np.concatenate([np.arange(s) for s in seqlen_0])
+        x_0 = np.concatenate([
+            np.ones(
+                [s], dtype=np.int64) * i for i, s in enumerate(seqlen_0)
+        ])
+        ids0 = np.zeros([batchsize, seqlen_0.max()], dtype=np.int64)
+        ids0[(x_0, y_0)] = ids_0
+        ids_1 = ids[(sids == 1) & (ids != 0)]
+        seqlen_1 = ((sids == 1) & (ids != 0)).astype(np.int64).sum(1)
+        y_1 = np.concatenate([np.arange(s) for s in seqlen_1])
+        x_1 = np.concatenate([
+            np.ones(
+                [s], dtype=np.int64) * i for i, s in enumerate(seqlen_1)
+        ])
+        ids1 = np.zeros([batchsize, seqlen_1.max()], dtype=np.int64)
+        ids1[(x_1, y_1)] = ids_1
+        msl = max(seqlen_0.max(), seqlen_1.max())
+        ids0 = np.pad(ids0, [[0, 0], [0, msl - seqlen_0.max()]],
+                      mode='constant')
+        ids1 = np.pad(ids1, [[0, 0], [0, msl - seqlen_1.max()]],
+                      mode='constant')
+        ids0 = fluid.dygraph.to_variable(ids0)
+        ids1 = fluid.dygraph.to_variable(ids1)
+        src_emb_0 = self._src_emb(ids0)
+        src_emb_1 = self._src_emb(ids1)
+        emb_out_0 = self._emb_fac(src_emb_0)
+        emb_out_1 = self._emb_fac(src_emb_1)
         # (bs, seq_len, 768)
-        enc_output = self._encoder(emb_out, flops=flops, model_size=model_size)
+        enc_output = self._encoder(
+            emb_out_0, emb_out_1, flops=flops, model_size=model_size)
         return enc_output
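For reference, the new forward pass above splits the packed token ids into one tensor per text segment before embedding them: tokens with sentence_id 0 and sentence_id 1 are gathered separately, left-packed row by row, and both results are padded to a common max length. A small self-contained NumPy sketch of that splitting step, using hypothetical toy values (pad id 0, sentence ids 0/1 as in the diff):

import numpy as np

# toy batch: 0 is the pad token, sids marks segment A (0) / segment B (1)
ids = np.array([[101, 7, 8, 102, 9, 102, 0, 0],
                [101, 5, 102, 6, 7, 8, 102, 0]], dtype=np.int64)
sids = np.array([[0, 0, 0, 0, 1, 1, 0, 0],
                 [0, 0, 0, 1, 1, 1, 1, 0]], dtype=np.int64)

def split_by_sentence(ids, sids, sent):
    mask = (sids == sent) & (ids != 0)      # real tokens of this segment
    seqlen = mask.astype(np.int64).sum(1)   # per-row token count
    rows = np.concatenate([np.full(s, i, dtype=np.int64)
                           for i, s in enumerate(seqlen)])
    cols = np.concatenate([np.arange(s) for s in seqlen])
    out = np.zeros([ids.shape[0], seqlen.max()], dtype=np.int64)
    out[rows, cols] = ids[mask]             # left-pack tokens per row
    return out, seqlen

ids0, len0 = split_by_sentence(ids, sids, 0)
ids1, len1 = split_by_sentence(ids, sids, 1)

# pad both to a common max sequence length, as the diff does
msl = max(len0.max(), len1.max())
ids0 = np.pad(ids0, [[0, 0], [0, msl - ids0.shape[1]]], mode='constant')
ids1 = np.pad(ids1, [[0, 0], [0, msl - ids1.shape[1]]], mode='constant')
print(ids0)
print(ids1)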
@@ -278,19 +278,22 @@ class EncoderLayer(Layer):
             bias_attr=ParamAttr(initializer=MSRA()))
         self.use_fixed_gumbel = use_fixed_gumbel
-        self.gumbel_alphas = gumbel_softmax(self.alphas)
+        self.gumbel_alphas = gumbel_softmax(self.alphas).detach()
-    def forward(self, enc_input, flops=[], model_size=[]):
-        tmp = fluid.layers.reshape(
-            enc_input, [-1, 1, enc_input.shape[1], enc_input.shape[2]])
+    def forward(self, enc_input_0, enc_input_1, flops=[], model_size=[]):
+        alphas = self.gumbel_alphas if self.use_fixed_gumbel else gumbel_softmax(
+            self.alphas)
+        s0 = fluid.layers.reshape(
+            enc_input_0, [-1, 1, enc_input_0.shape[1], enc_input_0.shape[2]])
+        s1 = fluid.layers.reshape(
+            enc_input_1, [-1, 1, enc_input_1.shape[1], enc_input_1.shape[2]])
         # (bs, 1, seq_len, hidden_size)
-        tmp = self.stem(tmp)
+        s0 = self.stem(s0)
+        s1 = self.stem(s1)
         # (bs, n_channel, seq_len, 1)
-        alphas = self.gumbel_alphas if self.use_fixed_gumbel else gumbel_softmax(
-            self.alphas)
-        s0 = s1 = tmp
         for i in range(self._n_layer):
             s0, s1 = s1, self._cells[i](s0, s1, alphas)
             # (bs, n_channel, seq_len, 1)
......
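The EncoderLayer change mirrors the input split: the forward pass now takes two inputs, runs each through its own copy of the stem, and the fixed gumbel sample is detach()-ed so no gradient flows back into the architecture parameters when use_fixed_gumbel is set. A simplified sketch of the resulting control flow, assuming the attributes shown in the diff (stem, _cells, _n_layer, alphas, gumbel_softmax); the final return value is outside the shown hunk and is illustrative:

import paddle.fluid as fluid

def encoder_forward(self, enc_input_0, enc_input_1, flops=[], model_size=[]):
    # architecture weights: either the fixed (detached) gumbel sample from
    # __init__, or a fresh sample drawn every forward pass
    alphas = (self.gumbel_alphas if self.use_fixed_gumbel
              else gumbel_softmax(self.alphas))

    # (bs, seq_len, hidden_size) -> (bs, 1, seq_len, hidden_size) for the conv stem
    s0 = fluid.layers.reshape(
        enc_input_0, [-1, 1, enc_input_0.shape[1], enc_input_0.shape[2]])
    s1 = fluid.layers.reshape(
        enc_input_1, [-1, 1, enc_input_1.shape[1], enc_input_1.shape[2]])

    s0 = self.stem(s0)  # (bs, n_channel, seq_len, 1)
    s1 = self.stem(s1)

    # DARTS-style chaining: each cell consumes the two previous states
    for i in range(self._n_layer):
        s0, s1 = s1, self._cells[i](s0, s1, alphas)
    return s1  # illustrative; the actual return is not part of the shown hunk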