From a6a73ef624e94d7ddd649cef501b665cb8e754f5 Mon Sep 17 00:00:00 2001 From: zhengya01 Date: Wed, 10 Jul 2019 05:31:50 +0000 Subject: [PATCH] add ce for BERT --- BERT/run_classifier.py | 9 +-------- BERT/utils/cards.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 8 deletions(-) create mode 100644 BERT/utils/cards.py diff --git a/BERT/run_classifier.py b/BERT/run_classifier.py index b3969b8..5e724d8 100644 --- a/BERT/run_classifier.py +++ b/BERT/run_classifier.py @@ -32,6 +32,7 @@ from model.classifier import create_model from optimization import optimization from utils.args import ArgumentGroup, print_arguments, check_cuda from utils.init import init_pretraining_params, init_checkpoint +from utils.cards import get_cards import dist_utils num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) @@ -435,14 +436,6 @@ def main(args): [loss.name, accuracy.name, num_seqs.name], "test") -def get_cards(): - num = 0 - cards = os.environ.get('CUDA_VISIBLE_DEVICES', '') - if cards != '': - num = len(cards.split(",")) - return num - - if __name__ == '__main__': print_arguments(args) check_cuda(args.use_cuda) diff --git a/BERT/utils/cards.py b/BERT/utils/cards.py new file mode 100644 index 0000000..9ba9aa6 --- /dev/null +++ b/BERT/utils/cards.py @@ -0,0 +1,28 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +def get_cards(): + """ + get gpu cards number + """ + num = 0 + cards = os.environ.get('CUDA_VISIBLE_DEVICES', '') + if cards != '': + num = len(cards.split(",")) + return num + + -- GitLab