From dac6d74e6675cf903ac59e417312462762dae4b2 Mon Sep 17 00:00:00 2001 From: tianxin Date: Thu, 18 Jul 2019 19:00:39 +0800 Subject: [PATCH] add utils/init.py test=develop fix #215 --- ELMo/train.py | 22 +++++++-------- ELMo/utils/init.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 11 deletions(-) create mode 100644 ELMo/utils/init.py diff --git a/ELMo/train.py b/ELMo/train.py index 1e455bf..facf91a 100755 --- a/ELMo/train.py +++ b/ELMo/train.py @@ -17,24 +17,27 @@ from __future__ import division from __future__ import print_function import six -import numpy as np +import os import random import time -import os import math +import pickle +import logging + +import numpy as np import paddle import paddle.fluid as fluid import paddle.fluid.core as core import paddle.fluid.framework as framework from paddle.fluid.executor import Executor + import data -from args import * -from utils.cards import get_cards import lm_model -import logging +from args import parse_args +from utils.cards import get_cards +from utils.init import init_pretraining_params logging.basicConfig() -import pickle def prepare_batch_input(batch, args): @@ -576,11 +579,8 @@ def train_loop(args, ce_time = ce_info[-2][1] except: print("ce info error") - print("kpis\ttrain_duration_card%s\t%s" % - (card_num, ce_time)) - print("kpis\ttrain_loss_card%s\t%f" % - (card_num, ce_loss)) - + print("kpis\ttrain_duration_card%s\t%s" % (card_num, ce_time)) + print("kpis\ttrain_loss_card%s\t%f" % (card_num, ce_loss)) end_time = time.time() total_time += end_time - start_time diff --git a/ELMo/utils/init.py b/ELMo/utils/init.py new file mode 100644 index 0000000..35d0e4f --- /dev/null +++ b/ELMo/utils/init.py @@ -0,0 +1,67 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import os
+import six
+import ast
+import copy
+
+import numpy as np
+import paddle.fluid as fluid
+
+
+def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False):
+    assert os.path.exists(
+        init_checkpoint_path), "[%s] can't be found." % init_checkpoint_path
+
+    def existed_persistables(var):
+        if not fluid.io.is_persistable(var):
+            return False
+        return os.path.exists(os.path.join(init_checkpoint_path, var.name))
+
+    fluid.io.load_vars(
+        exe,
+        init_checkpoint_path,
+        main_program=main_program,
+        predicate=existed_persistables)
+    print("Load model from {}".format(init_checkpoint_path))
+
+    if use_fp16:
+        cast_fp32_to_fp16(exe, main_program)
+
+
+def init_pretraining_params(exe,
+                            pretraining_params_path,
+                            main_program,
+                            use_fp16=False):
+    assert os.path.exists(pretraining_params_path
+                          ), "[%s] can't be found." % pretraining_params_path
+
+    def existed_params(var):
+        if not isinstance(var, fluid.framework.Parameter):
+            return False
+        return os.path.exists(os.path.join(pretraining_params_path, var.name))
+
+    fluid.io.load_vars(
+        exe,
+        pretraining_params_path,
+        main_program=main_program,
+        predicate=existed_params)
+    print("Load pretraining parameters from {}.".format(
+        pretraining_params_path))
+
+    if use_fp16:
+        cast_fp32_to_fp16(exe, main_program)
-- GitLab