From c809077c0b2121388480f91ee904e44d0fec95e6 Mon Sep 17 00:00:00 2001 From: wangxiao1021 Date: Mon, 3 Feb 2020 16:32:01 +0800 Subject: [PATCH] fix bugs --- download_models.py | 34 ------------------- examples/predict/run.py | 2 +- paddlepalm/head/cls.py | 10 +++--- .../lr_sched/slanted_triangular_schedualer.py | 2 +- 4 files changed, 8 insertions(+), 40 deletions(-) delete mode 100644 download_models.py diff --git a/download_models.py b/download_models.py deleted file mode 100644 index e3cfd80..0000000 --- a/download_models.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: UTF-8 -*- -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddlepalm as palm -import sys -import argparse - -# create parser -parser = argparse.ArgumentParser(prog='download_models.py', usage='python %(prog)s -l | -d [-h]\n\nFor example,\n\tpython %(prog)s -d bert-en-uncased-large ',description = 'Download pretrain models for initializing params of backbones. ') -parser1= parser.add_argument_group("required arguments") -parser1.add_argument('-l','--list', action = 'store_true', help = 'show the list of available pretrain models', default = False) -parser1.add_argument('-d','--download', action = 'store', help = 'download pretrain models. The available pretrain models can be listed by run "python download_models.py -l"') -args = parser.parse_args() - -if(args.list): - palm.downloader.ls('pretrain') -elif(args.download): - print('download~~~') - print(args.download) - palm.downloader.download('pretrain', args.download) -else: - print (parser.parse_args(['-h'])) diff --git a/examples/predict/run.py b/examples/predict/run.py index 5f05f96..6e00c26 100644 --- a/examples/predict/run.py +++ b/examples/predict/run.py @@ -43,7 +43,7 @@ if __name__ == '__main__': trainer.build_predict_forward(pred_ernie, cls_pred_head) # step 6: load pretrained model - pred_model = trainer.load_ckpt(pre_params) + pred_model = trainer.load_pretrain(pre_params) # step 7: fit prepared reader and data trainer.fit_reader(predict_cls_reader, phase='predict') diff --git a/paddlepalm/head/cls.py b/paddlepalm/head/cls.py index da63227..e6ca016 100644 --- a/paddlepalm/head/cls.py +++ b/paddlepalm/head/cls.py @@ -38,6 +38,7 @@ class Classify(Head): self._param_initializer = fluid.initializer.TruncatedNormal( scale=param_initializer_range) self._preds = [] + self._probs = [] @property def inputs_attrs(self): @@ -52,7 +53,9 @@ class Classify(Head): if self._is_training: return {'loss': [[1], 'float32']} else: - return {'logits': [[-1, self.num_classes], 'float32']} + return {'logits': [[-1, self.num_classes], 'float32'], + 'probs': [[-1, self.num_classes], 'float32']} + def build(self, inputs, scope_name=''): sent_emb = inputs['backbone']['sentence_embedding'] @@ -71,11 +74,10 @@ class Classify(Head): initializer=self._param_initializer), bias_attr=fluid.ParamAttr( name=scope_name+"cls_out_b", initializer=fluid.initializer.Constant(0.))) - + probs = fluid.layers.softmax(logits) if self._is_training: - inputs = fluid.layers.softmax(logits) loss = fluid.layers.cross_entropy( - input=inputs, label=label_ids) + input=probs, label=label_ids) loss = layers.mean(loss) return {"loss": loss} else: diff --git a/paddlepalm/lr_sched/slanted_triangular_schedualer.py b/paddlepalm/lr_sched/slanted_triangular_schedualer.py index 26d3fc4..e94c51e 100644 --- a/paddlepalm/lr_sched/slanted_triangular_schedualer.py +++ b/paddlepalm/lr_sched/slanted_triangular_schedualer.py @@ -13,7 +13,7 @@ class TriangularSchedualer(Schedualer): num_train_steps: the number of train steps. """ - BaseSchedualer.__init__(self) + Schedualer.__init__(self) assert num_train_steps > warmup_steps > 0 self.warmup_steps = warmup_steps self.num_train_steps = num_train_steps -- GitLab