diff --git a/BERT/run_classifier.py b/BERT/run_classifier.py index b3969b8d11de10a09d8dff087a666a3614706b9c..5e724d8a377ed454c0570ab5d43e3180e4d144b0 100644 --- a/BERT/run_classifier.py +++ b/BERT/run_classifier.py @@ -32,6 +32,7 @@ from model.classifier import create_model from optimization import optimization from utils.args import ArgumentGroup, print_arguments, check_cuda from utils.init import init_pretraining_params, init_checkpoint +from utils.cards import get_cards import dist_utils num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) @@ -435,14 +436,6 @@ def main(args): [loss.name, accuracy.name, num_seqs.name], "test") -def get_cards(): - num = 0 - cards = os.environ.get('CUDA_VISIBLE_DEVICES', '') - if cards != '': - num = len(cards.split(",")) - return num - - if __name__ == '__main__': print_arguments(args) check_cuda(args.use_cuda) diff --git a/BERT/utils/cards.py b/BERT/utils/cards.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba9aa6d2ee81eebfc8c02bdef5d50dff7d96f6e --- /dev/null +++ b/BERT/utils/cards.py @@ -0,0 +1,28 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +def get_cards(): + """ + get gpu cards number + """ + num = 0 + cards = os.environ.get('CUDA_VISIBLE_DEVICES', '') + if cards != '': + num = len(cards.split(",")) + return num + +