提交 c809077c 编写于 作者: W wangxiao1021

fix bugs

上级 e3f25d2f
# -*- coding: UTF-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddlepalm as palm
import sys
import argparse
# create parser
parser = argparse.ArgumentParser(prog='download_models.py', usage='python %(prog)s -l | -d <model_name> [-h]\n\nFor example,\n\tpython %(prog)s -d bert-en-uncased-large ',description = 'Download pretrain models for initializing params of backbones. ')
parser1= parser.add_argument_group("required arguments")
parser1.add_argument('-l','--list', action = 'store_true', help = 'show the list of available pretrain models', default = False)
parser1.add_argument('-d','--download', action = 'store', help = 'download pretrain models. The available pretrain models can be listed by run "python download_models.py -l"')
args = parser.parse_args()
if(args.list):
palm.downloader.ls('pretrain')
elif(args.download):
print('download~~~')
print(args.download)
palm.downloader.download('pretrain', args.download)
else:
print (parser.parse_args(['-h']))
......@@ -43,7 +43,7 @@ if __name__ == '__main__':
trainer.build_predict_forward(pred_ernie, cls_pred_head)
# step 6: load pretrained model
pred_model = trainer.load_ckpt(pre_params)
pred_model = trainer.load_pretrain(pre_params)
# step 7: fit prepared reader and data
trainer.fit_reader(predict_cls_reader, phase='predict')
......
......@@ -38,6 +38,7 @@ class Classify(Head):
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=param_initializer_range)
self._preds = []
self._probs = []
@property
def inputs_attrs(self):
......@@ -52,7 +53,9 @@ class Classify(Head):
if self._is_training:
return {'loss': [[1], 'float32']}
else:
return {'logits': [[-1, self.num_classes], 'float32']}
return {'logits': [[-1, self.num_classes], 'float32'],
'probs': [[-1, self.num_classes], 'float32']}
def build(self, inputs, scope_name=''):
sent_emb = inputs['backbone']['sentence_embedding']
......@@ -71,11 +74,10 @@ class Classify(Head):
initializer=self._param_initializer),
bias_attr=fluid.ParamAttr(
name=scope_name+"cls_out_b", initializer=fluid.initializer.Constant(0.)))
probs = fluid.layers.softmax(logits)
if self._is_training:
inputs = fluid.layers.softmax(logits)
loss = fluid.layers.cross_entropy(
input=inputs, label=label_ids)
input=probs, label=label_ids)
loss = layers.mean(loss)
return {"loss": loss}
else:
......
......@@ -13,7 +13,7 @@ class TriangularSchedualer(Schedualer):
num_train_steps: the number of train steps.
"""
BaseSchedualer.__init__(self)
Schedualer.__init__(self)
assert num_train_steps > warmup_steps > 0
self.warmup_steps = warmup_steps
self.num_train_steps = num_train_steps
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册