提交 b4d29614 编写于 作者: W wanghaoshuang

Fix teacher model of BERT.

上级 1ab71d6a
...@@ -45,18 +45,24 @@ class ClsModelLayer(Layer): ...@@ -45,18 +45,24 @@ class ClsModelLayer(Layer):
self.is_training = is_training self.is_training = is_training
self.use_fp16 = use_fp16 self.use_fp16 = use_fp16
self.loss_scaling = loss_scaling self.loss_scaling = loss_scaling
self.n_layers = config['num_hidden_layers']
self.bert_layer = BertModelLayer( self.bert_layer = BertModelLayer(
config=self.config, return_pooled_out=True, use_fp16=self.use_fp16) config=self.config, return_pooled_out=True, use_fp16=self.use_fp16)
self.cls_fc = Linear( self.cls_fc = list()
input_dim=self.config["hidden_size"], for i in range(self.n_layers):
output_dim=num_labels, fc = Linear(
param_attr=fluid.ParamAttr( input_dim=self.config["hidden_size"],
name="cls_out_w", output_dim=num_labels,
initializer=fluid.initializer.TruncatedNormal(scale=0.02)), param_attr=fluid.ParamAttr(
bias_attr=fluid.ParamAttr( name="cls_out_%d_w" % i,
name="cls_out_b", initializer=fluid.initializer.Constant(0.))) initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_out_%d_b" % i,
initializer=fluid.initializer.Constant(0.)))
fc = self.add_sublayer("cls_fc_%d" % i, fc)
self.cls_fc.append(fc)
def forward(self, data_ids): def forward(self, data_ids):
""" """
...@@ -73,13 +79,13 @@ class ClsModelLayer(Layer): ...@@ -73,13 +79,13 @@ class ClsModelLayer(Layer):
logits = [] logits = []
losses = [] losses = []
accuracys = [] accuracys = []
for next_sent_feat in next_sent_feats: for next_sent_feat, fc in zip(next_sent_feats, self.cls_fc):
cls_feat = fluid.layers.dropout( cls_feat = fluid.layers.dropout(
x=next_sent_feat, x=next_sent_feat,
dropout_prob=0.1, dropout_prob=0.1,
dropout_implementation="upscale_in_train") dropout_implementation="upscale_in_train")
logit = self.cls_fc(cls_feat) logit = fc(cls_feat)
logits.append(logit) logits.append(logit)
ce_loss, probs = fluid.layers.softmax_with_cross_entropy( ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册