未验证 提交 7546121b 编写于 作者: D dyning 提交者: GitHub

Merge pull request #46 from WuHaobo/master

Add comments for train
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
files: \.(md|yml)$
- id: trailing-whitespace
files: \.(md|yml)$
- repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1
hooks:
- id: forbid-crlf
files: \.(md|yml)$
- id: remove-crlf
files: \.(md|yml)$
- id: forbid-tabs
files: \.(md|yml)$
- id: remove-tabs
files: \.(md|yml)$
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: flake8
args: ['--ignore=E265']
- id: check-yaml
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
files: \.(md|yml)$
- id: trailing-whitespace
files: \.(md|yml)$
- repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1
hooks:
- id: forbid-crlf
files: \.(md|yml)$
- id: remove-crlf
files: \.(md|yml)$
- id: forbid-tabs
files: \.(md|yml)$
- id: remove-tabs
files: \.(md|yml)$
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import environment
from . import model_zoo
from . import misc
from . import logger
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as pff
trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
trainer_id = int(os.environ.get("PADDLE_TRAINER_ID", 0))
def place():
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
return fluid.CUDAPlace(gpu_id)
def places():
"""
Returns available running places, the numbers are usually
indicated by 'export CUDA_VISIBLE_DEVICES= '
Args:
"""
if trainers_num <= 1:
return pff.cuda_places()
else:
return place()
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import argparse
import paddle
import paddle.fluid as fluid
import program
from ppcls.data import Reader
import ppcls.utils.environment as env
from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model, save_model
from ppcls.utils import logger
from ppcls.utils.save_load import init_model
from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.base import role_maker
......@@ -40,7 +37,7 @@ def parse_args():
'-c',
'--config',
type=str,
default='configs/eval.yaml',
default='./configs/eval.yaml',
help='config file path')
parser.add_argument(
'-o',
......@@ -58,7 +55,8 @@ def main(args):
fleet.init(role)
config = get_config(args.config, overrides=args.override, show=True)
place = env.place()
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id)
startup_prog = fluid.Program()
valid_prog = fluid.Program()
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
......@@ -18,18 +18,14 @@ from __future__ import print_function
import argparse
import os
import sys
import paddle
import paddle.fluid as fluid
import program
from ppcls.data import Reader
import ppcls.utils.environment as env
from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model, save_model
from ppcls.utils import logger
from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.base import role_maker
......@@ -41,7 +37,7 @@ def parse_args():
'-c',
'--config',
type=str,
default='configs/ResNet/ResNet18_vd.yaml',
default='configs/ResNet/ResNet50.yaml',
help='config file path')
parser.add_argument(
'-o',
......@@ -58,8 +54,12 @@ def main(args):
fleet.init(role)
config = get_config(args.config, overrides=args.override, show=True)
place = env.place()
# assign the place
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id)
# startup_prog is used to do some parameter init work,
# and train prog is used to hold the network
startup_prog = fluid.Program()
train_prog = fluid.Program()
......@@ -70,11 +70,15 @@ def main(args):
valid_prog = fluid.Program()
valid_dataloader, valid_fetchs = program.build(
config, valid_prog, startup_prog, is_train=False)
# clone to prune some content which is irrelevant in valid_prog
valid_prog = valid_prog.clone(for_test=True)
exe = fluid.Executor(place)
# create the "Executor" with the statement of which place
exe = fluid.Executor(place=place)
# only run startup_prog once to init
exe.run(startup_prog)
# load model from checkpoint or pretrained model
init_model(config, train_prog, exe)
train_reader = Reader(config, 'train')()
......@@ -87,13 +91,15 @@ def main(args):
compiled_train_prog = fleet.main_program
for epoch_id in range(config.epochs):
# 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
epoch_id, 'train')
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
program.run(valid_dataloader, exe, compiled_valid_prog,
valid_fetchs, epoch_id, 'valid')
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册