未验证 提交 7546121b 编写于 作者: D dyning 提交者: GitHub

Merge pull request #46 from WuHaobo/master

Add comments for train
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git - repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks: hooks:
- id: yapf - id: yapf
files: \.py$ files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464 - repo: https://github.com/pre-commit/pre-commit-hooks
hooks: sha: a11d9314b22d8f8c7556443875b731ef05965464
- id: check-merge-conflict hooks:
- id: check-symlinks - id: flake8
- id: detect-private-key args: ['--ignore=E265']
files: (?!.*paddle)^.*$ - id: check-yaml
- id: end-of-file-fixer - id: check-merge-conflict
files: \.(md|yml)$ - id: check-symlinks
- id: trailing-whitespace - id: detect-private-key
files: \.(md|yml)$ files: (?!.*paddle)^.*$
- repo: https://github.com/Lucas-C/pre-commit-hooks - id: end-of-file-fixer
sha: v1.0.1 files: \.(md|yml)$
hooks: - id: trailing-whitespace
- id: forbid-crlf files: \.(md|yml)$
files: \.(md|yml)$
- id: remove-crlf - repo: https://github.com/Lucas-C/pre-commit-hooks
files: \.(md|yml)$ sha: v1.0.1
- id: forbid-tabs hooks:
files: \.(md|yml)$ - id: forbid-crlf
- id: remove-tabs files: \.(md|yml)$
files: \.(md|yml)$ - id: remove-crlf
files: \.(md|yml)$
- id: forbid-tabs
files: \.(md|yml)$
- id: remove-tabs
files: \.(md|yml)$
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import environment
from . import model_zoo from . import model_zoo
from . import misc from . import misc
from . import logger from . import logger
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as pff
trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
trainer_id = int(os.environ.get("PADDLE_TRAINER_ID", 0))
def place():
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
return fluid.CUDAPlace(gpu_id)
def places():
"""
Returns available running places, the numbers are usually
indicated by 'export CUDA_VISIBLE_DEVICES= '
Args:
"""
if trainers_num <= 1:
return pff.cuda_places()
else:
return place()
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys import os
import argparse import argparse
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import program import program
from ppcls.data import Reader from ppcls.data import Reader
import ppcls.utils.environment as env
from ppcls.utils.config import get_config from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model, save_model from ppcls.utils.save_load import init_model
from ppcls.utils import logger
from paddle.fluid.incubate.fleet.collective import fleet from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.base import role_maker
...@@ -40,7 +37,7 @@ def parse_args(): ...@@ -40,7 +37,7 @@ def parse_args():
'-c', '-c',
'--config', '--config',
type=str, type=str,
default='configs/eval.yaml', default='./configs/eval.yaml',
help='config file path') help='config file path')
parser.add_argument( parser.add_argument(
'-o', '-o',
...@@ -58,7 +55,8 @@ def main(args): ...@@ -58,7 +55,8 @@ def main(args):
fleet.init(role) fleet.init(role)
config = get_config(args.config, overrides=args.override, show=True) config = get_config(args.config, overrides=args.override, show=True)
place = env.place() gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id)
startup_prog = fluid.Program() startup_prog = fluid.Program()
valid_prog = fluid.Program() valid_prog = fluid.Program()
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -18,18 +18,14 @@ from __future__ import print_function ...@@ -18,18 +18,14 @@ from __future__ import print_function
import argparse import argparse
import os import os
import sys
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import program import program
from ppcls.data import Reader from ppcls.data import Reader
import ppcls.utils.environment as env
from ppcls.utils.config import get_config from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model, save_model from ppcls.utils.save_load import init_model, save_model
from ppcls.utils import logger
from paddle.fluid.incubate.fleet.collective import fleet from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.base import role_maker
...@@ -41,7 +37,7 @@ def parse_args(): ...@@ -41,7 +37,7 @@ def parse_args():
'-c', '-c',
'--config', '--config',
type=str, type=str,
default='configs/ResNet/ResNet18_vd.yaml', default='configs/ResNet/ResNet50.yaml',
help='config file path') help='config file path')
parser.add_argument( parser.add_argument(
'-o', '-o',
...@@ -58,8 +54,12 @@ def main(args): ...@@ -58,8 +54,12 @@ def main(args):
fleet.init(role) fleet.init(role)
config = get_config(args.config, overrides=args.override, show=True) config = get_config(args.config, overrides=args.override, show=True)
place = env.place() # assign the place
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id)
# startup_prog is used to do some parameter init work,
# and train prog is used to hold the network
startup_prog = fluid.Program() startup_prog = fluid.Program()
train_prog = fluid.Program() train_prog = fluid.Program()
...@@ -70,11 +70,15 @@ def main(args): ...@@ -70,11 +70,15 @@ def main(args):
valid_prog = fluid.Program() valid_prog = fluid.Program()
valid_dataloader, valid_fetchs = program.build( valid_dataloader, valid_fetchs = program.build(
config, valid_prog, startup_prog, is_train=False) config, valid_prog, startup_prog, is_train=False)
# clone to prune some content which is irrelevant in valid_prog
valid_prog = valid_prog.clone(for_test=True) valid_prog = valid_prog.clone(for_test=True)
exe = fluid.Executor(place) # create the "Executor" with the statement of which place
exe = fluid.Executor(place=place)
# only run startup_prog once to init
exe.run(startup_prog) exe.run(startup_prog)
# load model from checkpoint or pretrained model
init_model(config, train_prog, exe) init_model(config, train_prog, exe)
train_reader = Reader(config, 'train')() train_reader = Reader(config, 'train')()
...@@ -87,13 +91,15 @@ def main(args): ...@@ -87,13 +91,15 @@ def main(args):
compiled_train_prog = fleet.main_program compiled_train_prog = fleet.main_program
for epoch_id in range(config.epochs): for epoch_id in range(config.epochs):
# 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs, program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
epoch_id, 'train') epoch_id, 'train')
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0: if config.validate and epoch_id % config.valid_interval == 0:
program.run(valid_dataloader, exe, compiled_valid_prog, program.run(valid_dataloader, exe, compiled_valid_prog,
valid_fetchs, epoch_id, 'valid') valid_fetchs, epoch_id, 'valid')
# 3. save the persistable model
if epoch_id % config.save_interval == 0: if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir, model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"]) config.ARCHITECTURE["name"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册