提交 886d57f2 编写于 作者: M moran

Add ST & Optimize template

上级 05816bdf
......@@ -11,7 +11,7 @@ These are examples of training AlexNet with dataset in MindSpore.
- Download the dataset, the directory structure is as follows:
{% if dataset=='Cifar10' %}
CIFAR-10
Cifar10
```
└─Data
......
......@@ -28,7 +28,7 @@ cfg = edict({
'lr': 0.002,
"momentum": 0.9,
{% elif optimizer=='SGD' %}
'lr': 0.1,
'lr': 0.01,
{% else %}
'lr': 0.001,
{% endif %}
......
......@@ -27,7 +27,8 @@ from src.alexnet import AlexNet
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
......
......@@ -21,9 +21,9 @@ cfg = edict({
'num_classes': 10,
{% if optimizer=='Momentum' %}
'lr': 0.01,
"momentum": 0.9,
'momentum': 0.9,
{% elif optimizer=='SGD' %}
'lr': 0.1,
'lr': 0.01,
{% else %}
'lr': 0.001,
{% endif %}
......
......@@ -21,10 +21,11 @@ import os
import argparse
import mindspore.nn as nn
from mindspore import context, ParallelMode
from mindspore import context
from mindspore.communication.management import init, get_rank
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.metrics import Accuracy
......
......@@ -11,7 +11,7 @@ These are examples of training ResNet50 with dataset in MindSpore.
- Download the dataset, the directory structure is as follows:
{% if dataset=='Cifar10' %}
CIFAR-10
Cifar10
```
└─Data
......@@ -50,7 +50,6 @@ ImageNet
└── run_standalone_train_gpu.sh # launch gpu standalone training(1 pcs)
├── src
├── config.py # parameter configuration
├── crossentropy.py # loss definition for ImageNet2012 dataset
├── dataset.py # data preprocessing
├── lr_generator.py # generate learning rate for each step
└── resnet50.py # resNet50 network definition
......
......@@ -37,7 +37,7 @@ np.random.seed(1)
de.config.set_seed(1)
from src.resnet50 import resnet50 as resnet
from src.resnet50 import resnet50
from src.config import cfg
......@@ -57,7 +57,7 @@ if __name__ == '__main__':
step_size = dataset.get_dataset_size()
# define net
net = resnet(class_num=cfg.num_classes)
net = resnet50(class_num=cfg.num_classes)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
......
......@@ -23,36 +23,35 @@ cfg = ed({
{% elif dataset=='ImageNet' %}
'num_classes': 1001,
{% endif %}
"batch_size": 32,
"loss_scale": 1024,
'batch_size': 32,
'loss_scale': 1024,
{% if optimizer=='Momentum' %}
"lr": 0.01,
"momentum": 0.9,
"lr": 0.01,
'lr': 0.01,
'momentum': 0.9,
{% elif optimizer=='SGD' %}
'lr': 0.1,
'lr': 0.01,
{% else %}
'lr': 0.001,
{% endif %}
"image_height": 224,
"image_width": 224,
"weight_decay": 1e-4,
"epoch_size": 1,
"pretrain_epoch_size": 1,
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
'image_height': 224,
'image_width': 224,
'weight_decay': 1e-4,
'epoch_size': 1,
'pretrain_epoch_size': 1,
'save_checkpoint': True,
'save_checkpoint_epochs': 5,
'keep_checkpoint_max': 10,
'save_checkpoint_path': './',
{% if dataset=='ImageNet' %}
"warmup_epochs": 0,
"lr_decay_mode": "cosine",
'warmup_epochs': 0,
'lr_decay_mode': 'cosine',
{% elif dataset=='Cifar10' %}
"warmup_epochs": 5,
"lr_decay_mode": "poly",
'warmup_epochs': 5,
'lr_decay_mode': 'poly',
{% endif %}
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.01,
"lr_end": 0.00001,
"lr_max": 0.1
'use_label_smooth': True,
'label_smooth_factor': 0.1,
'lr_init': 0.01,
'lr_end': 0.00001,
'lr_max': 0.1
})
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0., num_classes=1001):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss
......@@ -21,7 +21,8 @@ from mindspore import context
from mindspore import Tensor
from mindspore import dataset as de
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
......@@ -46,7 +47,7 @@ np.random.seed(1)
de.config.set_seed(1)
from src.resnet50 import resnet50 as resnet
from src.resnet50 import resnet50
from src.config import cfg
from src.dataset import create_dataset
......@@ -80,7 +81,7 @@ if __name__ == '__main__':
step_size = dataset.get_dataset_size()
# define net
net = resnet(class_num=cfg.num_classes)
net = resnet50(class_num=cfg.num_classes)
# init weight
if args_opt.pre_trained:
......@@ -156,13 +157,10 @@ if __name__ == '__main__':
{% elif loss=='SoftmaxCrossEntropyExpand' %}
loss = nn.SoftmaxCrossEntropyExpand(sparse=True)
{% endif %}
{% if optimizer=='Momentum' %}
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=lr, momentum=cfg.momentum)
{% else %}
opt = nn.{{optimizer}}(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=lr)
{% endif %}
{% endif %}
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
loss_scale = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=True)
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
Test the various combinations based on AlexNet.
"""
import os
import pytest
from mindinsight.wizard.base.utility import load_network_maker
NETWORK_NAME = 'alexnet'
class TestAlexNet:
"""Test AlexNet Module"""
@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize('params', [{
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'Momentum',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'Adam',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'SGD',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'Momentum',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'Adam',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'SGD',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'Momentum',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'Adam',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'SGD',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'Momentum',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'Adam',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'SGD',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}])
def test_combinations(self, params):
"""Do testing"""
network_maker_name = NETWORK_NAME
config = params['config']
dataset_loader_name = params['dataset_loader_name']
network_maker = load_network_maker(network_maker_name)
network_maker.configure(config)
self.source_files = network_maker.generate(**config)
self.check_scripts()
self.check_src(dataset_loader_name, config)
self.check_train_eval_readme(config['dataset'], config['loss'], config['optimizer'])
def check_src(self, dataset_name, config):
"""Check src file"""
dataset_is_right = False
config_dataset_is_right = False
config_optimizer_is_right = False
network_is_right = False
generator_lr_is_right = False
for source_file in self.source_files:
if source_file.file_relative_path == 'src/dataset.py':
if dataset_name in source_file.content:
dataset_is_right = True
if source_file.file_relative_path == os.path.join('src', NETWORK_NAME.lower()+'.py'):
network_is_right = True
if source_file.file_relative_path == 'src/generator_lr.py':
generator_lr_is_right = True
if source_file.file_relative_path == 'src/config.py':
content = source_file.content
if config['dataset'] == 'Cifar10':
if "'num_classes': 10" in content:
config_dataset_is_right = True
elif config['dataset'] == 'ImageNet':
if "'num_classes': 1001" in content:
config_dataset_is_right = True
if config['optimizer'] == 'Momentum':
if "'lr': 0.002" in content:
config_optimizer_is_right = True
elif config['optimizer'] == 'SGD':
if "'lr': 0.01" in content:
config_optimizer_is_right = True
else:
if "'lr': 0.001" in content:
config_optimizer_is_right = True
assert dataset_is_right
assert config_dataset_is_right
assert config_optimizer_is_right
assert network_is_right
assert generator_lr_is_right
def check_train_eval_readme(self, dataset_name, loss_name, optimizer_name):
"""Check train and eval"""
train_is_right = False
eval_is_right = False
readme_is_right = False
for source_file in self.source_files:
if source_file.file_relative_path == 'train.py':
content = source_file.content
if 'alexnet' in content and loss_name in content and optimizer_name in content:
train_is_right = True
if source_file.file_relative_path == 'eval.py':
content = source_file.content
if 'alexnet' in content and loss_name in content:
eval_is_right = True
if source_file.file_relative_path == 'README.md':
content = source_file.content
if 'AlexNet' in content and dataset_name in content:
readme_is_right = True
assert train_is_right
assert eval_is_right
assert readme_is_right
def check_scripts(self):
"""Check scripts"""
exist_run_distribute_train = False
exist_run_distribute_train_gpu = False
exist_run_eval = False
exist_run_eval_gpu = False
exist_run_standalone_train = False
exist_run_standalone_train_gpu = False
for source_file in self.source_files:
if source_file.file_relative_path == 'scripts/run_distribute_train.sh':
exist_run_distribute_train = True
if source_file.file_relative_path == 'scripts/run_distribute_train_gpu.sh':
exist_run_distribute_train_gpu = True
if source_file.file_relative_path == 'scripts/run_eval.sh':
exist_run_eval = True
if source_file.file_relative_path == 'scripts/run_eval_gpu.sh':
exist_run_eval_gpu = True
if source_file.file_relative_path == 'scripts/run_standalone_train.sh':
exist_run_standalone_train = True
if source_file.file_relative_path == 'scripts/run_standalone_train_gpu.sh':
exist_run_standalone_train_gpu = True
assert exist_run_distribute_train
assert exist_run_distribute_train_gpu
assert exist_run_eval
assert exist_run_eval_gpu
assert exist_run_standalone_train
assert exist_run_standalone_train_gpu
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
Test the various combinations based on LeNet.
"""
import os
import pytest
from mindinsight.wizard.base.utility import load_network_maker
NETWORK_NAME = 'lenet'
class TestLeNet:
"""Test LeNet Module"""
@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize('params', [{
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'Momentum',
'dataset': 'MNIST'},
'dataset_loader_name': 'MnistDataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'Adam',
'dataset': 'MNIST'},
'dataset_loader_name': 'MnistDataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'SGD',
'dataset': 'MNIST'},
'dataset_loader_name': 'MnistDataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'Momentum',
'dataset': 'MNIST'},
'dataset_loader_name': 'MnistDataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'Adam',
'dataset': 'MNIST'},
'dataset_loader_name': 'MnistDataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'SGD',
'dataset': 'MNIST'},
'dataset_loader_name': 'MnistDataset'
}])
def test_combinations(self, params):
"""Do testing"""
network_maker_name = NETWORK_NAME
config = params['config']
dataset_loader_name = params['dataset_loader_name']
network_maker = load_network_maker(network_maker_name)
network_maker.configure(config)
self.source_files = network_maker.generate(**config)
self.check_scripts()
self.check_src(dataset_loader_name, config)
self.check_train_eval_readme(config['loss'], config['optimizer'])
def check_src(self, dataset_name, config):
"""Check src file"""
dataset_is_right = False
config_optimizer_is_right = False
network_is_right = False
for source_file in self.source_files:
if source_file.file_relative_path == 'src/dataset.py':
if dataset_name in source_file.content:
dataset_is_right = True
if source_file.file_relative_path == os.path.join('src', NETWORK_NAME.lower() + '.py'):
network_is_right = True
if source_file.file_relative_path == 'src/config.py':
content = source_file.content
if config['optimizer'] == 'Momentum':
if "'lr': 0.01" in content and \
"'momentum': 0.9" in content:
config_optimizer_is_right = True
elif config['optimizer'] == 'SGD':
if "'lr': 0.01" in content:
config_optimizer_is_right = True
else:
if "'lr': 0.001" in content:
config_optimizer_is_right = True
assert dataset_is_right
assert config_optimizer_is_right
assert network_is_right
def check_train_eval_readme(self, loss_name, optimizer_name):
"""Check train and eval"""
train_is_right = False
eval_is_right = False
readme_is_right = False
for source_file in self.source_files:
if source_file.file_relative_path == 'train.py':
content = source_file.content
if 'LeNet5' in content and loss_name in content and optimizer_name in content:
train_is_right = True
if source_file.file_relative_path == 'eval.py':
content = source_file.content
if 'LeNet5' in content and loss_name in content:
eval_is_right = True
if source_file.file_relative_path == 'README.md':
content = source_file.content
if 'LeNet' in content:
readme_is_right = True
assert train_is_right
assert eval_is_right
assert readme_is_right
def check_scripts(self):
"""Check scripts"""
exist_run_distribute_train = False
exist_run_distribute_train_gpu = False
exist_run_eval = False
exist_run_eval_gpu = False
exist_run_standalone_train = False
exist_run_standalone_train_gpu = False
for source_file in self.source_files:
if source_file.file_relative_path == 'scripts/run_distribute_train.sh':
exist_run_distribute_train = True
if source_file.file_relative_path == 'scripts/run_distribute_train_gpu.sh':
exist_run_distribute_train_gpu = True
if source_file.file_relative_path == 'scripts/run_eval.sh':
exist_run_eval = True
if source_file.file_relative_path == 'scripts/run_eval_gpu.sh':
exist_run_eval_gpu = True
if source_file.file_relative_path == 'scripts/run_standalone_train.sh':
exist_run_standalone_train = True
if source_file.file_relative_path == 'scripts/run_standalone_train_gpu.sh':
exist_run_standalone_train_gpu = True
assert exist_run_distribute_train
assert exist_run_distribute_train_gpu
assert exist_run_eval
assert exist_run_eval_gpu
assert exist_run_standalone_train
assert exist_run_standalone_train_gpu
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
Test the various combinations based on ResNet50.
"""
import os
import pytest
from mindinsight.wizard.base.utility import load_network_maker
NETWORK_NAME = 'resnet50'
class TestResNet50:
"""Test ResNet50 Module"""
@pytest.mark.level0
@pytest.mark.env_single
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.parametrize('params', [{
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'Momentum',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'Adam',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'SGD',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'Momentum',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'Adam',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'SGD',
'dataset': 'Cifar10'},
'dataset_loader_name': 'Cifar10Dataset'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'Momentum',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'Adam',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}, {
'config': {'loss': 'SoftmaxCrossEntropyWithLogits',
'optimizer': 'SGD',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'Momentum',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'Adam',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}, {
'config': {'loss': 'SoftmaxCrossEntropyExpand',
'optimizer': 'SGD',
'dataset': 'ImageNet'},
'dataset_loader_name': 'ImageFolderDatasetV2'
}])
def test_combinations(self, params):
"""Do testing"""
network_maker_name = NETWORK_NAME
config = params['config']
dataset_loader_name = params['dataset_loader_name']
network_maker = load_network_maker(network_maker_name)
network_maker.configure(config)
self.source_files = network_maker.generate(**config)
self.check_scripts()
self.check_src(dataset_loader_name, config)
self.check_train_eval_readme(config['dataset'], config['loss'], config['optimizer'])
def check_src(self, dataset_name, config):
"""Check src file"""
dataset_is_right = False
config_dataset_is_right = False
config_optimizer_is_right = False
network_is_right = False
generator_lr_is_right = False
for source_file in self.source_files:
if source_file.file_relative_path == 'src/dataset.py':
if dataset_name in source_file.content:
dataset_is_right = True
if source_file.file_relative_path == os.path.join('src', NETWORK_NAME.lower() + '.py'):
network_is_right = True
if source_file.file_relative_path == 'src/lr_generator.py':
generator_lr_is_right = True
if source_file.file_relative_path == 'src/config.py':
content = source_file.content
if config['dataset'] == 'Cifar10':
if "'num_classes': 10" in content \
and "'warmup_epochs': 5" in content \
and "'lr_decay_mode': 'poly'" in content:
config_dataset_is_right = True
elif config['dataset'] == 'ImageNet':
if "'num_classes': 1001" in content \
and "'warmup_epochs': 0" in content \
and "'lr_decay_mode': 'cosine'":
config_dataset_is_right = True
if config['optimizer'] == 'Momentum':
if "'lr': 0.01" in content and \
"'momentum': 0.9" in content:
config_optimizer_is_right = True
elif config['optimizer'] == 'SGD':
if "'lr': 0.01" in content:
config_optimizer_is_right = True
else:
if "'lr': 0.001" in content:
config_optimizer_is_right = True
assert dataset_is_right
assert config_dataset_is_right
assert config_optimizer_is_right
assert network_is_right
assert generator_lr_is_right
def check_train_eval_readme(self, dataset_name, loss_name, optimizer_name):
"""Check train and eval"""
train_is_right = False
eval_is_right = False
readme_is_right = False
for source_file in self.source_files:
if source_file.file_relative_path == 'train.py':
content = source_file.content
if 'resnet50' in content and loss_name in content and optimizer_name in content:
train_is_right = True
if source_file.file_relative_path == 'eval.py':
content = source_file.content
if 'resnet50' in content and loss_name in content:
eval_is_right = True
if source_file.file_relative_path == 'README.md':
content = source_file.content
if 'ResNet50' in content and dataset_name in content:
readme_is_right = True
assert train_is_right
assert eval_is_right
assert readme_is_right
def check_scripts(self):
"""Check scripts"""
exist_run_distribute_train = False
exist_run_distribute_train_gpu = False
exist_run_eval = False
exist_run_eval_gpu = False
exist_run_standalone_train = False
exist_run_standalone_train_gpu = False
for source_file in self.source_files:
if source_file.file_relative_path == 'scripts/run_distribute_train.sh':
exist_run_distribute_train = True
if source_file.file_relative_path == 'scripts/run_distribute_train_gpu.sh':
exist_run_distribute_train_gpu = True
if source_file.file_relative_path == 'scripts/run_eval.sh':
exist_run_eval = True
if source_file.file_relative_path == 'scripts/run_eval_gpu.sh':
exist_run_eval_gpu = True
if source_file.file_relative_path == 'scripts/run_standalone_train.sh':
exist_run_standalone_train = True
if source_file.file_relative_path == 'scripts/run_standalone_train_gpu.sh':
exist_run_standalone_train_gpu = True
assert exist_run_distribute_train
assert exist_run_distribute_train_gpu
assert exist_run_eval
assert exist_run_eval_gpu
assert exist_run_standalone_train
assert exist_run_standalone_train_gpu
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册