From dad2dab0154cf1e4573995c20d5beca9d6814995 Mon Sep 17 00:00:00 2001
From: littletomatodonkey <2120160898@bit.edu.cn>
Date: Wed, 27 Jan 2021 17:13:26 +0800
Subject: [PATCH] add device num for xpu (#581)

* add device num for xpu

* fix vgg
---
 ppcls/data/reader.py                | 25 +++++++++++++++++--------
 ppcls/modeling/architectures/vgg.py | 23 ++++++++++++++---------
 tools/program.py                    |  2 +-
 3 files changed, 32 insertions(+), 18 deletions(-)

diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py
index b8ed405e..6bd4a58a 100755
--- a/ppcls/data/reader.py
+++ b/ppcls/data/reader.py
@@ -25,7 +25,7 @@ from . import imaug
 from .imaug import transform
 from ppcls.utils import logger
 
-trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 0))
+trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
 trainer_id = int(os.environ.get("PADDLE_TRAINER_ID", 0))
 
 
@@ -257,6 +257,9 @@ class Reader:
             raise ModeException(mode=mode)
 
         self.use_gpu = config.get("use_gpu", True)
+        self.use_xpu = config.get("use_xpu", False)
+        self.is_distributed = config.get("is_distributed", True)
+
         use_mix = config.get('use_mix')
         self.params['mode'] = mode
         if seed is not None:
@@ -265,14 +268,20 @@
         if use_mix and mode == "train":
             self.batch_ops = create_operators(self.params['mix'])
 
+    def get_device_num(self):
+        if self.is_distributed:
+            device_num = trainers_num
+        elif self.use_gpu:
+            device_num = fluid.core.get_cuda_device_count()
+        elif self.use_xpu:
+            xpus = os.environ.get("FLAGS_selected_xpus", '1')
+            device_num = len(xpus.split(','))
+        else:
+            device_num = int(os.environ.get('CPU_NUM', 1))
+        return device_num
+
     def __call__(self):
-        device_num = trainers_num
-        # non-distributed launch
-        if trainers_num <= 0:
-            if self.use_gpu:
-                device_num = fluid.core.get_cuda_device_count()
-            else:
-                device_num = int(os.environ.get('CPU_NUM', 1))
+        device_num = self.get_device_num()
         batch_size = int(self.params['batch_size']) // device_num
 
         def wrapper():
diff --git a/ppcls/modeling/architectures/vgg.py b/ppcls/modeling/architectures/vgg.py
index 2c5b77ea..c7f2795e 100644
--- a/ppcls/modeling/architectures/vgg.py
+++ b/ppcls/modeling/architectures/vgg.py
@@ -23,8 +23,9 @@ __all__ = ["VGGNet", "VGG11", "VGG13", "VGG16", "VGG19"]
 
 
 class VGGNet():
-    def __init__(self, layers=16):
+    def __init__(self, layers=16, stop_grad_layers=0, **args):
         self.layers = layers
+        self.stop_grad_layers = stop_grad_layers
 
     def net(self, input, class_dim=1000):
         layers = self.layers
@@ -44,6 +45,10 @@
         conv4 = self.conv_block(conv3, 512, nums[3], name="conv4_")
         conv5 = self.conv_block(conv4, 512, nums[4], name="conv5_")
 
+        for idx, conv in enumerate([conv1, conv2, conv3, conv4, conv5]):
+            if self.stop_grad_layers >= idx + 1:
+                conv.stop_gradient = True
+
         fc_dim = 4096
         fc_name = ["fc6", "fc7", "fc8"]
         fc1 = fluid.layers.fc(
@@ -88,21 +93,21 @@
             input=conv, pool_size=2, pool_type='max', pool_stride=2)
 
 
-def VGG11():
-    model = VGGNet(layers=11)
+def VGG11(stop_grad_layers=0, **args):
+    model = VGGNet(layers=11, stop_grad_layers=stop_grad_layers, **args)
     return model
 
 
-def VGG13():
-    model = VGGNet(layers=13)
+def VGG13(stop_grad_layers=0, **args):
+    model = VGGNet(layers=13, stop_grad_layers=stop_grad_layers, **args)
     return model
 
 
-def VGG16():
-    model = VGGNet(layers=16)
+def VGG16(stop_grad_layers=0, **args):
+    model = VGGNet(layers=16, stop_grad_layers=stop_grad_layers, **args)
     return model
 
 
-def VGG19():
-    model = VGGNet(layers=19)
+def VGG19(stop_grad_layers=0, **args):
+    model = VGGNet(layers=19, stop_grad_layers=stop_grad_layers, **args)
     return model
diff --git a/tools/program.py b/tools/program.py
index 90770ed0..ea2ce536 100644
--- a/tools/program.py
+++ b/tools/program.py
@@ -282,6 +282,7 @@ def create_optimizer(config):
     # create optimizer instance
     opt_config = config['OPTIMIZER']
     opt = OptimizerBuilder(**opt_config)
+
     return opt(lr)
 
 
@@ -305,7 +306,6 @@
     dist_strategy.fuse_all_reduce_ops = True
     dist_strategy.exec_strategy = exec_strategy
     optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
-
     return optimizer
--
GitLab
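
Note on the reader change: Reader.get_device_num() picks the device count in a fixed order (trainer count for distributed runs, otherwise the CUDA device count, then the number of entries in FLAGS_selected_xpus for XPU runs, and finally CPU_NUM), and __call__ divides the configured batch size by that count. The standalone sketch below only restates that precedence for illustration; the function name pick_device_num and its parameters are not part of the patch, and the environment-variable defaults mirror the ones used above.

import os

def pick_device_num(is_distributed, use_gpu, use_xpu, trainers_num, cuda_count):
    # Same precedence as Reader.get_device_num in the patch:
    # distributed trainers, then GPUs, then XPUs, then CPU_NUM fallback.
    if is_distributed:
        return trainers_num
    if use_gpu:
        return cuda_count
    if use_xpu:
        # e.g. FLAGS_selected_xpus="0,1,2,3" means 4 devices
        xpus = os.environ.get("FLAGS_selected_xpus", "1")
        return len(xpus.split(","))
    return int(os.environ.get("CPU_NUM", 1))

# A non-distributed 4-XPU run with batch_size=256 in the config would then
# read 256 // 4 = 64 samples per device in Reader.__call__.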
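
Note on the VGG change: stop_grad_layers freezes the first N conv blocks by marking their outputs with stop_gradient = True inside VGGNet.net(), so gradients do not propagate back into those blocks. A minimal usage sketch, assuming the model is built directly through the architecture module rather than through a training config; freezing two blocks is just an example value.

from ppcls.modeling.architectures.vgg import VGG16

# Freeze the conv1_* and conv2_* blocks and fine-tune the rest.
model = VGG16(stop_grad_layers=2)
# out = model.net(input=image, class_dim=1000)  # called inside a fluid program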