Commit bb622b24 authored by shippingwang

fix bugs

Parent 04e9720b
@@ -33,6 +33,3 @@
     - id: trailing-whitespace
       files: \.(md|yml)$
     - id: check-case-conflict
-    - id: flake8
-      args: ['--ignore=E265']
@@ -46,10 +46,10 @@ TRAIN:
            channel_first: False
        - RandCropImage:
            size: 224
-           interpolation: 2
+           #interpolation: 2
        - RandFlipImage:
            flip_code: 1
-       - AutoArgument:
+       - AutoAugment:
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
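Each entry in the TRAIN transform list maps an operator class name to its keyword arguments, so enabling AutoAugment amounts to adding a bare "- AutoAugment:" item, which parses to a None payload. As a minimal sketch of how such a list is typically turned into callables — the factory name and module handle here are illustrative, not taken from this commit:

    def create_operators(op_cfg_list, ops_module):
        # Each YAML list item parses as {ClassName: kwargs-or-None}.
        ops = []
        for cfg in op_cfg_list:
            name, params = list(cfg.items())[0]
            ops.append(getattr(ops_module, name)(**(params or {})))
        return ops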
@@ -25,7 +25,8 @@ import random
 import cv2
 import numpy as np
-from autoargument import ImageNetPolicy
+from .autoaugment import ImageNetPolicy


 class OperatorParamError(ValueError):
     """ OperatorParamError
@@ -172,12 +173,12 @@ class RandFlipImage(object):
         else:
             return img


-class AutoArgument(object):
+class AutoAugment(object):
     def __init__(self):
         self.policy = ImageNetPolicy()

-    def __call__(self,img):
+    def __call__(self, img):
         from PIL import Image
         img = np.ascontiguousarray(img)
         img = Image.fromarray(img)
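The hunk above both renames AutoArgument to AutoAugment and normalizes the __call__ signature. Since the extract truncates the method, here is a sketch of what the complete operator plausibly looks like; the policy application and the conversion back to numpy are assumptions based on the surrounding pipeline, which passes numpy arrays between operators:

    class AutoAugment(object):
        def __init__(self):
            self.policy = ImageNetPolicy()

        def __call__(self, img):
            from PIL import Image
            img = np.ascontiguousarray(img)  # PIL needs contiguous memory
            img = Image.fromarray(img)       # numpy (H, W, C) -> PIL image
            img = self.policy(img)           # apply a randomly chosen sub-policy
            return np.asarray(img)           # back to numpy for the next operator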
@@ -383,7 +383,9 @@ class EfficientNet():
             use_bias=True,
             padding_type=self.padding_type,
             name=name + '_se_expand')
-        se_out = inputs * fluid.layers.sigmoid(x_squeezed)
+        #se_out = inputs * fluid.layers.sigmoid(x_squeezed)
+        se_out = fluid.layers.elementwise_mul(
+            inputs, fluid.layers.sigmoid(x_squeezed), axis=-1)
         return se_out

     def extract_features(self, inputs, is_test):
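The squeeze-and-excitation fix replaces Python-operator multiplication with an explicit fluid.layers.elementwise_mul so that the [N, C, 1, 1] gate broadcasts cleanly against the [N, C, H, W] feature map. A numpy sketch of the intended gating, with shapes assumed from the surrounding code:

    import numpy as np

    def se_gate(inputs, x_squeezed):
        # inputs: (N, C, H, W) features; x_squeezed: (N, C, 1, 1) gate logits
        gate = 1.0 / (1.0 + np.exp(-x_squeezed))  # sigmoid
        return inputs * gate                      # broadcasts over H and W

    out = se_gate(np.ones((2, 8, 4, 4)), np.zeros((2, 8, 1, 1)))
    assert out.shape == (2, 8, 4, 4)  # zero logits gate every value by 0.5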
@@ -467,8 +469,8 @@ class BlockDecoder(object):
         # Check stride
         cond_1 = ('s' in options and len(options['s']) == 1)
-        cond_2 = ((len(options['s']) == 2)
-                  and (options['s'][0] == options['s'][1]))
+        cond_2 = ((len(options['s']) == 2) and
+                  (options['s'][0] == options['s'][1]))
         assert (cond_1 or cond_2)

         return BlockArgs(
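This hunk only moves the "and" to satisfy the formatter; the logic is unchanged. The check accepts a stride given either as a single value or as an equal pair — a compact restatement:

    def check_stride(s):
        # s: stride components parsed from the block string, e.g. 's22' -> [2, 2]
        assert len(s) == 1 or (len(s) == 2 and s[0] == s[1])
        return s[0]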
@@ -1,15 +1,15 @@
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
 #
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
 #
 #    http://www.apache.org/licenses/LICENSE-2.0
 #
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from __future__ import absolute_import
 from __future__ import division
@@ -130,7 +130,7 @@ class CosineWarmup(object):
         with fluid.layers.control_flow.Switch() as switch:
             with switch.case(epoch < self.warmup_epoch):
                 decayed_lr = self.lr * \
-                    (global_step / (self.step_each_epoch * self.warmup_epoch))
+                    (global_step / (self.step_each_epoch * self.warmup_epoch))
                 fluid.layers.tensor.assign(
                     input=decayed_lr, output=learning_rate)
             with switch.default():
@@ -146,7 +146,6 @@ class CosineWarmup(object):
 class ExponentialWarmup(object):
     """
     Exponential learning rate decay with warmup
-
     [0, warmup_epoch): linear warmup
@@ -160,8 +159,14 @@ class ExponentialWarmup(object):
         warmup_epoch(int): epoch num of warmup
     """

-    def __init__(self, lr, step_each_epoch, decay_epochs=2.4, decay_rate=0.97, warmup_epoch=5, **kwargs):
-        super(CosineWarmup, self).__init__()
+    def __init__(self,
+                 lr,
+                 step_each_epoch,
+                 decay_epochs=2.4,
+                 decay_rate=0.97,
+                 warmup_epoch=5,
+                 **kwargs):
+        super(ExponentialWarmup, self).__init__()
         self.lr = lr
         self.step_each_epoch = step_each_epoch
         self.decay_epochs = decay_epochs * self.step_each_epoch
@@ -185,19 +190,20 @@ class ExponentialWarmup(object):
         with fluid.layers.control_flow.Switch() as switch:
             with switch.case(epoch < self.warmup_epoch):
                 decayed_lr = self.lr * \
-                    (global_step / (self.step_each_epoch * self.warmup_epoch))
+                    (global_step / (self.step_each_epoch * self.warmup_epoch))
                 fluid.layers.tensor.assign(
                     input=decayed_lr, output=learning_rate)
             with switch.default():
                 rest_step = global_step - self.warmup_epoch * self.step_each_epoch
                 div_res = ops.floor(rest_step / self.decay_epochs)
-                decayed_lr = self.lr*(self.decay_rate**div_res)
+                decayed_lr = self.lr * (self.decay_rate**div_res)
                 fluid.layers.tensor.assign(
                     input=decayed_lr, output=learning_rate)

         return learning_rate


 class LearningRateBuilder():
     """
     Build learning rate variable
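Beyond the reformatting, the __init__ hunk fixes a real bug: the old super(CosineWarmup, self).__init__() raises a TypeError inside ExponentialWarmup, since self is not a CosineWarmup instance. The schedule itself is linear warmup followed by stepwise exponential decay; a plain-Python sketch of the same curve, mirroring the Switch logic above (step-based, with decay_epochs converted to steps as in __init__):

    import math

    def exponential_warmup_lr(global_step, lr, step_each_epoch,
                              decay_epochs=2.4, decay_rate=0.97,
                              warmup_epoch=5):
        warmup_steps = step_each_epoch * warmup_epoch
        if global_step < warmup_steps:
            return lr * global_step / warmup_steps      # linear warmup
        rest_step = global_step - warmup_steps
        div_res = math.floor(rest_step / (decay_epochs * step_each_epoch))
        return lr * decay_rate ** div_res               # stepwise exponential decay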
@@ -36,7 +36,7 @@ from ppcls.utils import logger
 from paddle.fluid.incubate.fleet.collective import fleet
 from paddle.fluid.incubate.fleet.collective import DistributedStrategy

-import ema
+from ema import ExponentialMovingAverage


 def create_feeds(image_shape, use_mix=None):
@@ -359,10 +359,12 @@ def build(config, main_prog, startup_prog, is_train=True):
             optimizer.minimize(fetchs['loss'][0])
             if config.get('use_ema'):
-                global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter()
-                ema = ExponentialMovingAverage(config.get('ema_decay'), thres_steps=global_steps)
+                global_steps = fluid.layers.learning_rate_scheduler._decay_step_counter(
+                )
+                ema = ExponentialMovingAverage(
+                    config.get('ema_decay'), thres_steps=global_steps)
                 ema.update()
-                fetchs['ema'] = ema
+                return dataloader, fetchs, ema

     return dataloader, fetchs
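The contract of build() changes here: with use_ema set, it now returns the ExponentialMovingAverage object alongside the dataloader and fetchs instead of smuggling it through the fetchs dict, so the trainer can swap EMA weights in during validation. A hedged sketch of the expected call pattern — the apply/restore context-manager semantics are assumed to match fluid's ExponentialMovingAverage, which the local ema module appears to mirror, and run_validation is a placeholder:

    dataloader, fetchs, ema = build(config, train_prog, startup_prog,
                                    is_train=True)
    # training loop runs as usual; ema.update() is already wired in
    # after optimizer.minimize, so the averages track every step
    with ema.apply(exe):      # temporarily substitute EMA weights
        run_validation(exe)   # evaluate with the averaged parameters
    # original weights are restored on exiting the context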
@@ -396,7 +398,13 @@ def compile(config, program, loss_name=None):
 total_step = 0


-def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
+def run(dataloader,
+        exe,
+        program,
+        fetchs,
+        epoch=0,
+        mode='train',
+        vdl_writer=None):
     """
     Feed data to the model and fetch the measures and loss

@@ -410,6 +418,7 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None
     Returns:
     """
+    print(fetchs)
     fetch_list = [f[0] for f in fetchs.values()]
     metric_list = [f[1] for f in fetchs.values()]
     for m in metric_list:
@@ -70,8 +70,12 @@ def main(args):
     best_top1_acc = 0.0  # best top1 acc record

-    train_dataloader, train_fetchs = program.build(
-        config, train_prog, startup_prog, is_train=True)
+    if not config.get('use_ema'):
+        train_dataloader, train_fetchs = program.build(
+            config, train_prog, startup_prog, is_train=True)
+    else:
+        train_dataloader, train_fetchs, ema = program.build(
+            config, train_prog, startup_prog, is_train=True)

     if config.validate:
         valid_prog = fluid.Program()
@@ -81,11 +85,11 @@ def main(args):
         valid_prog = valid_prog.clone(for_test=True)

     # create the "Executor" with the statement of which place
-    exe = fluid.Executor(place=place)
-    # only run startup_prog once to init
+    exe = fluid.Executor(place)
+    # Parameter initialization
     exe.run(startup_prog)

-    # load model from checkpoint or pretrained model
+    # load model from 1. checkpoint to resume training, 2. pretrained model to finetune
     init_model(config, train_prog, exe)

     train_reader = Reader(config, 'train')()
@@ -110,8 +114,8 @@ def main(args):
                 logger.info(logger.coloring("EMA validate start..."))
                 with train_fetchs('ema').apply(exe):
                     top1_acc = program.run(valid_dataloader, exe,
-                                           compiled_valid_prog, valid_fetchs,
-                                           epoch_id, 'valid')
+                                           compiled_valid_prog,
+                                           valid_fetchs, epoch_id, 'valid')
                 logger.info(logger.coloring("EMA validate over!"))
             top1_acc = program.run(valid_dataloader, exe,