提交 a82eef9f 编写于 作者: W WuHaobo

add comment

上级 33c96900
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import environment
from . import model_zoo
from . import misc
from . import logger
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as pff
trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
trainer_id = int(os.environ.get("PADDLE_TRAINER_ID", 0))
def place():
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
return fluid.CUDAPlace(gpu_id)
def places():
"""
Returns available running places, the numbers are usually
indicated by 'export CUDA_VISIBLE_DEVICES= '
Args:
"""
if trainers_num <= 1:
return pff.cuda_places()
else:
return place()
......@@ -25,7 +25,6 @@ import paddle.fluid as fluid
import program
from ppcls.data import Reader
import ppcls.utils.environment as env
from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model, save_model
from ppcls.utils import logger
......@@ -58,7 +57,8 @@ def main(args):
fleet.init(role)
config = get_config(args.config, overrides=args.override, show=True)
place = env.place()
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id)
startup_prog = fluid.Program()
valid_prog = fluid.Program()
......@@ -69,7 +69,7 @@ def main(args):
exe = fluid.Executor(place)
exe.run(startup_prog)
init_model(config, valid_prog, exe)
init_model(config, valid_prog, exe, 'ppcls')
valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, place)
......
......@@ -19,8 +19,6 @@ from __future__ import print_function
import argparse
import os
import sys
sys.path.append(os.getcwd())
import paddle
import paddle.fluid as fluid
......@@ -28,7 +26,6 @@ import paddle.fluid as fluid
import program
from ppcls.data import Reader
import ppcls.utils.environment as env
from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model, save_model
from ppcls.utils import logger
......@@ -60,8 +57,12 @@ def main(args):
fleet.init(role)
config = get_config(args.config, overrides=args.override, show=True)
place = env.place()
# assign the place
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id)
# startup_prog is used to do some parameter init work,
# and train prog is used to hold the network
startup_prog = fluid.Program()
train_prog = fluid.Program()
......@@ -72,11 +73,15 @@ def main(args):
valid_prog = fluid.Program()
valid_dataloader, valid_fetchs = program.build(
config, valid_prog, startup_prog, is_train=False)
# clone to prune some content which is irrelevant in valid_prog
valid_prog = valid_prog.clone(for_test=True)
exe = fluid.Executor(place)
# create the "Executor" with the statement of which place
exe = fluid.Executor(place=place)
# only run startup_prog once to init
exe.run(startup_prog)
# load model from checkpoint or pretrained model
init_model(config, train_prog, exe)
train_reader = Reader(config, 'train')()
......@@ -89,13 +94,15 @@ def main(args):
compiled_train_prog = fleet.main_program
for epoch_id in range(config.epochs):
# 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
epoch_id, 'train')
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
program.run(valid_dataloader, exe, compiled_valid_prog,
valid_fetchs, epoch_id, 'valid')
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册