提交 8a80a7da 编写于 作者: littletomatodonkey's avatar littletomatodonkey

remove python path config and support cpu train/val/infer

上级 c7a8c89f
......@@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from ppcls import model_zoo
import argparse
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
def parse_args():
......
......@@ -13,19 +13,22 @@
# limitations under the License.
from __future__ import absolute_import
import program
from ppcls.utils import logger
from ppcls.utils.save_load import init_model
from ppcls.utils.config import get_config
from ppcls.data import Reader
import paddle.fluid as fluid
import paddle
import argparse
from __future__ import division
from __future__ import print_function
import os
import argparse
from ppcls.data import Reader
from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model
from ppcls.utils import logger
from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.base import role_maker
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
def parse_args():
......@@ -47,21 +50,26 @@ def parse_args():
def main(args):
# assign the place
gpu_id = fluid.dygraph.parallel.Env().dev_id
place = fluid.CUDAPlace(gpu_id)
config = get_config(args.config, overrides=args.override, show=True)
# assign place
use_gpu = config.get("use_gpu", True)
if use_gpu:
gpu_id = fluid.dygraph.ParallelEnv().dev_id
place = fluid.CUDAPlace(gpu_id)
else:
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
pre_weights_dict = fluid.dygraph.load_dygraph(config.pretrained_model)[0]
strategy = fluid.dygraph.parallel.prepare_context()
net = program.create_model(config.ARCHITECTURE, config.classes_num)
net = fluid.dygraph.parallel.DataParallel(net, strategy)
net.set_dict(pre_weights_dict)
init_model(config, net, optimizer=None)
valid_dataloader = program.create_dataloader()
valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, place)
net.eval()
top1_acc = program.run(valid_dataloader, config, net, None, 0, 'valid')
if __name__ == '__main__':
args = parse_args()
main(args)
export PYTHONPATH=$PWD:$PYTHONPATH
python -m paddle.distributed.launch \
--selected_gpus="0" \
tools/eval.py \
-c ./configs/eval.yaml
-c ./configs/eval.yaml \
-o load_static_weights=True \
-o use_gpu=False
#!/usr/bin/env bash
export PYTHONPATH=$PWD:$PYTHONPATH
python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \
tools/train.py \
......
#!/usr/bin/env bash
export PYTHONPATH=$PWD:$PYTHONPATH
python tools/download.py -a ResNet34 -p ./pretrained/ -d 1
......@@ -13,19 +13,21 @@
# limitations under the License.
from __future__ import absolute_import
import program
from ppcls.utils import logger
from ppcls.utils.save_load import init_model, save_model
from ppcls.utils.config import get_config
from ppcls.data import Reader
import paddle.fluid as fluid
from __future__ import division
from __future__ import print_function
import argparse
import os
import paddle.fluid as fluid
from ppcls.data import Reader
from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model, save_model
from ppcls.utils import logger
import program
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
def parse_args():
......@@ -49,8 +51,12 @@ def parse_args():
def main(args):
config = get_config(args.config, overrides=args.override, show=True)
# assign the place
gpu_id = fluid.dygraph.parallel.Env().dev_id
place = fluid.CUDAPlace(gpu_id)
use_gpu = config.get("use_gpu", True)
if use_gpu:
gpu_id = fluid.dygraph.ParallelEnv().dev_id
place = fluid.CUDAPlace(gpu_id)
else:
place = fluid.CPUPlace()
use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1
config["use_data_parallel"] = use_data_parallel
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册