Unverified commit 9f43bbcc, authored by ceci3, committed by GitHub

fix bug for OFA (#464)

* fix bugs for ernie
Parent c6fdcc3f
@@ -14,14 +14,14 @@
import numpy as np
import paddle
- import paddle.fluid as fluid
- import paddle.fluid.dygraph.nn as nn
+ import paddle.nn as nn
+ import paddle.nn.functional as F
from paddle.nn import ReLU
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
from paddleslim.nas.ofa import supernet

- class Model(fluid.dygraph.Layer):
+ class Model(nn.Layer):
def __init__(self):
super(Model, self).__init__()
with supernet(
@@ -50,18 +50,20 @@ class Model(fluid.dygraph.Layer):
for idx, layer in enumerate(models):
if idx == 6:
- inputs = fluid.layers.flatten(inputs, 1)
+ inputs = paddle.flatten(inputs, 1)
inputs = layer(inputs)
- inputs = fluid.layers.softmax(inputs)
+ inputs = F.softmax(inputs)
return inputs

def test_ofa():
+ model = Model()
+ teacher_model = Model()
default_run_config = {
'train_batch_size': 256,
- 'eval_batch_size': 64,
'n_epochs': [[1], [2, 3], [4, 5]],
'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
'dynamic_batch_size': [1, 1, 1],
@@ -72,42 +74,46 @@ def test_ofa():
default_distill_config = {
'lambda_distill': 0.01,
- 'teacher_model': Model,
+ 'teacher_model': teacher_model,
'mapping_layers': ['models.0.fn']
}
distill_config = DistillConfig(**default_distill_config)
- fluid.enable_dygraph()
- model = Model()
ofa_model = OFA(model, run_config, distill_config=distill_config)
- train_reader = paddle.fluid.io.batch(
- paddle.dataset.mnist.train(), batch_size=256, drop_last=True)
+ train_dataset = paddle.vision.datasets.MNIST(
+ mode='train', backend='cv2', transform=transform)
+ train_loader = paddle.io.DataLoader(
+ train_dataset,
+ places=place,
+ feed_list=[image, label],
+ drop_last=True,
+ batch_size=64)
start_epoch = 0
for idx in range(len(run_config.n_epochs)):
cur_idx = run_config.n_epochs[idx]
for ph_idx in range(len(cur_idx)):
cur_lr = run_config.init_learning_rate[idx][ph_idx]
- adam = fluid.optimizer.Adam(
+ adam = paddle.optimizer.Adam(
learning_rate=cur_lr,
parameter_list=(ofa_model.parameters() + ofa_model.netAs_param))
for epoch_id in range(start_epoch,
run_config.n_epochs[idx][ph_idx]):
- for batch_id, data in enumerate(train_reader()):
+ for batch_id, data in enumerate(train_loader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
- img = fluid.dygraph.to_variable(dy_x_data)
- label = fluid.dygraph.to_variable(y_data)
+ img = paddle.dygraph.to_variable(dy_x_data)
+ label = paddle.dygraph.to_variable(y_data)
label.stop_gradient = True
for model_no in range(run_config.dynamic_batch_size[idx]):
output, _ = ofa_model(img, label)
- loss = fluid.layers.reduce_mean(output)
+ loss = F.mean(output)
dis_loss = ofa_model.calc_distill_loss()
loss += dis_loss
loss.backward()
...
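Two API changes drive this demo update: `RunConfig` no longer carries an `eval_batch_size` field, and `DistillConfig` now expects an already constructed teacher model instead of the model class. A minimal sketch of the new wiring, using only names that appear in the demo above:

```python
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig

model = Model()          # student supernet
teacher_model = Model()  # the caller instantiates the teacher

run_config = RunConfig(
    train_batch_size=256,
    n_epochs=[[1], [2, 3], [4, 5]],
    init_learning_rate=[[0.001], [0.003, 0.001], [0.003, 0.001]],
    dynamic_batch_size=[1, 1, 1])

# the teacher is passed as an instance, not as the Model class
distill_config = DistillConfig(
    lambda_distill=0.01,
    teacher_model=teacher_model,
    mapping_layers=['models.0.fn'])

ofa_model = OFA(model, run_config, distill_config=distill_config)
```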
@@ -19,6 +19,7 @@ from .sa_nas import *
from .rl_nas import *
from ..nas import darts
from .darts import *
+ from .ofa import *

__all__ = []
__all__ += sa_nas.__all__
...
@@ -16,9 +16,8 @@ import inspect
import decorator
import logging
import paddle
- import paddle.fluid as fluid
- from paddle.fluid import framework
- from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm
+ import numbers
+ from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm, LayerNorm, Embedding
from .layers import *
from ...common import get_logger
@@ -26,7 +25,7 @@ _logger = get_logger(__name__, level=logging.INFO)
__all__ = ['supernet']

- WEIGHT_LAYER = ['conv', 'linear']
+ WEIGHT_LAYER = ['conv', 'linear', 'embedding']

### TODO: add decorator
@@ -45,7 +44,7 @@ class Convert:
cur_channel = None
for idx, layer in enumerate(model):
cls_name = layer.__class__.__name__.lower()
- if 'conv' in cls_name or 'linear' in cls_name:
+ if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
weight_layer_count += 1
last_weight_layer_idx = idx
if first_weight_layer_idx == -1:
@@ -63,7 +62,7 @@ class Convert:
new_attr_name = [
'_stride', '_dilation', '_groups', '_param_attr',
- '_bias_attr', '_use_cudnn', '_act', '_dtype'
+ '_bias_attr', '_use_cudnn', '_act', '_dtype', '_padding'
]
new_attr_dict = dict()
@@ -179,6 +178,8 @@ class Convert:
layer._parameters['weight'].shape[0])
elif self.context.channel:
new_attr_dict['num_channels'] = max(cur_channel)
+ else:
+ new_attr_dict['num_channels'] = attr_dict['_num_channels']
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
@@ -196,7 +197,8 @@ class Convert:
new_attr_name = [
'_stride', '_dilation', '_groups', '_param_attr',
- '_bias_attr', '_use_cudnn', '_act', '_dtype', '_output_size'
+ '_padding', '_bias_attr', '_use_cudnn', '_act', '_dtype',
+ '_output_size'
]
assert attr_dict[
'_filter_size'] != None, "Conv2DTranspose only support filter size != None now"
@@ -371,6 +373,8 @@ class Convert:
layer._parameters['scale'].shape[0])
elif self.context.channel:
new_attr_dict['num_channels'] = max(cur_channel)
+ else:
+ new_attr_dict['num_channels'] = attr_dict['_num_channels']
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
@@ -380,6 +384,76 @@ class Convert:
layer = SuperInstanceNorm(**new_attr_dict)
model[idx] = layer
elif isinstance(layer, LayerNorm) and (
getattr(self.context, 'expand', None) != None or
getattr(self.context, 'channel', None) != None):
### TODO(ceci3): fix when normalized_shape != last_dim_of_input
if idx > last_weight_layer_idx:
continue
attr_dict = layer.__dict__
new_attr_name = [
'_scale', '_shift', '_param_attr', '_bias_attr', '_act',
'_dtype', '_epsilon'
]
new_attr_dict = dict()
if self.context.expand:
new_attr_dict[
'normalized_shape'] = self.context.expand * int(
attr_dict['_normalized_shape'][0])
elif self.context.channel:
new_attr_dict['normalized_shape'] = max(cur_channel)
else:
new_attr_dict['normalized_shape'] = attr_dict[
'_normalized_shape']
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
del layer, attr_dict
layer = SuperLayerNorm(**new_attr_dict)
model[idx] = layer
elif isinstance(layer, Embedding) and (
getattr(self.context, 'expand', None) != None or
getattr(self.context, 'channel', None) != None):
attr_dict = layer.__dict__
key = attr_dict['_full_name']
new_attr_name = [
'_is_sparse', '_is_distributed', '_padding_idx',
'_param_attr', '_dtype'
]
new_attr_dict = dict()
new_attr_dict['candidate_config'] = dict()
bef_size = attr_dict['_size']
if self.context.expand:
new_attr_dict['size'] = [
bef_size[0], self.context.expand * bef_size[1]
]
new_attr_dict['candidate_config'].update({
'expand_ratio': self.context.expand_ratio
})
elif self.context.channel:
cur_channel = self.context.channel[0]
self.context.channel = self.context.channel[1:]
new_attr_dict['size'] = [bef_size[0], max(cur_channel)]
new_attr_dict['candidate_config'].update({
'channel': cur_channel
})
pre_channel = cur_channel
else:
new_attr_dict['size'] = bef_size
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
del layer, attr_dict
layer = Block(SuperEmbedding(**new_attr_dict), key=key)
model[idx] = layer
return model
...
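With these additions, `supernet().convert()` also recognizes `Embedding` and `LayerNorm` layers and swaps them for `SuperEmbedding` / `SuperLayerNorm`. A minimal sketch mirroring the updated unit test further below (fluid dygraph layer API, as used throughout this patch):

```python
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
from paddleslim.nas.ofa import supernet

fluid.enable_dygraph()

# Embedding and LayerNorm may now appear inside a supernet context;
# convert() replaces them with their Super* counterparts.
with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
    layers = [
        nn.Embedding(size=(64, 64)),
        nn.Linear(64, 128),
        nn.LayerNorm(128),
        nn.Linear(128, 256),
    ]
    layers = ofa_super.convert(layers)

model = paddle.nn.Sequential(*layers)
```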
@@ -28,7 +28,7 @@ __all__ = [
'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
'SuperBatchNorm', 'SuperLinear', 'SuperInstanceNorm', 'Block',
'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
- 'SuperDepthwiseConv2DTranspose'
+ 'SuperDepthwiseConv2DTranspose', 'SuperLayerNorm', 'SuperEmbedding'
]
_logger = get_logger(__name__, level=logging.INFO)
@@ -70,9 +70,10 @@ class Block(BaseBlock):
key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None.
"""

- def __init__(self, fn, key=None):
+ def __init__(self, fn, fixed=False, key=None):
super(Block, self).__init__(key)
self.fn = fn
+ self.fixed = fixed
self.candidate_config = self.fn.candidate_config

def forward(self, *inputs, **kwargs):
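`Block` gains a `fixed` flag: a fixed block still executes its supernet layer, but it is excluded from the searchable candidate set (see the `OFABase` changes further down). Usage as in the updated test, sketched here (the import path assumes the `layers` module of the ofa package):

```python
from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D

# fixed=True: the block runs as configured and contributes no elastic task
block = Block(
    SuperSeparableConv2D(
        3, 6, 1, padding=1, candidate_config={'channel': (3, 6)}),
    fixed=True)
```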
@@ -208,7 +209,6 @@ class SuperConv2D(fluid.dygraph.Conv2D):
act=None,
dtype='float32'):
### NOTE: padding always is 0, add padding in forward because of kernel size is uncertain
- ### TODO: change padding to any padding
super(SuperConv2D, self).__init__(
num_channels, num_filters, filter_size, stride, padding, dilation,
groups, param_attr, bias_attr, use_cudnn, act, dtype)
@@ -228,7 +228,7 @@ class SuperConv2D(fluid.dygraph.Conv2D):
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.channel = candidate_config[
'channel'] if 'channel' in candidate_config else None
- self.base_channel = None
+ self.base_channel = self._num_filters
if self.expand_ratio != None:
self.base_channel = int(self._num_filters / max(self.expand_ratio))
@@ -296,6 +296,11 @@ class SuperConv2D(fluid.dygraph.Conv2D):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
+ self.cur_config = {
+ 'kernel_size': kernel_size,
+ 'expand_ratio': expand_ratio,
+ 'channel': channel
+ }
in_nc = int(input.shape[1])
assert (
expand_ratio == None or channel == None
@@ -313,7 +318,11 @@ class SuperConv2D(fluid.dygraph.Conv2D):
out_nc)
weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
- padding = convert_to_list(get_same_padding(ks), 2)
+ if kernel_size != None or 'kernel_size' in self.candidate_config.keys():
+ padding = convert_to_list(get_same_padding(ks), 2)
+ else:
+ padding = self._padding
if self._l_type == 'conv2d':
attrs = ('strides', self._stride, 'paddings', padding, 'dilations',
@@ -488,7 +497,6 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
use_cudnn=True,
act=None,
dtype='float32'):
- ### NOTE: padding always is 0, add padding in forward because of kernel size is uncertain
super(SuperConv2DTranspose, self).__init__(
num_channels, num_filters, filter_size, output_size, padding,
stride, dilation, groups, param_attr, bias_attr, use_cudnn, act,
@@ -507,7 +515,7 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.channel = candidate_config[
'channel'] if 'channel' in candidate_config else None
- self.base_channel = None
+ self.base_channel = self._num_filters
if self.expand_ratio:
self.base_channel = int(self._num_filters / max(self.expand_ratio))
@@ -572,6 +580,11 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
+ self.cur_config = {
+ 'kernel_size': kernel_size,
+ 'expand_ratio': expand_ratio,
+ 'channel': channel
+ }
in_nc = int(input.shape[1])
assert (
expand_ratio == None or channel == None
@@ -590,7 +603,10 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
out_nc)
weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
- padding = convert_to_list(get_same_padding(ks), 2)
+ if kernel_size != None or 'kernel_size' in self.candidate_config.keys():
+ padding = convert_to_list(get_same_padding(ks), 2)
+ else:
+ padding = self._padding
op = getattr(core.ops, self._op_type)
out = op(input, weight, 'output_size', self._output_size, 'strides',
@@ -701,7 +717,7 @@ class SuperSeparableConv2D(fluid.dygraph.Layer):
self.conv.extend([norm_layer(num_channels * scale_factor)])
self.conv.extend([
- Conv2D(
+ fluid.dygraph.nn.Conv2D(
num_channels=num_channels * scale_factor,
num_filters=num_filters,
filter_size=1,
@@ -713,14 +729,16 @@ class SuperSeparableConv2D(fluid.dygraph.Layer):
self.candidate_config = candidate_config
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
- self.base_output_dim = None
+ self.base_output_dim = self.conv[0]._num_filters
if self.expand_ratio != None:
- self.base_output_dim = int(self.output_dim / max(self.expand_ratio))
+ self.base_output_dim = int(self.conv[0]._num_filters /
+ max(self.expand_ratio))

def forward(self, input, expand_ratio=None, channel=None):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
+ self.cur_config = {'expand_ratio': expand_ratio, 'channel': channel}
in_nc = int(input.shape[1])
assert (
expand_ratio == None or channel == None
@@ -809,7 +827,7 @@ class SuperLinear(fluid.dygraph.Linear):
self.candidate_config = candidate_config
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
- self.base_output_dim = None
+ self.base_output_dim = self.output_dim
if self.expand_ratio != None:
self.base_output_dim = int(self.output_dim / max(self.expand_ratio))
@@ -817,8 +835,9 @@ class SuperLinear(fluid.dygraph.Linear):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
+ self.cur_config = {'expand_ratio': expand_ratio, 'channel': channel}
### weight: (Cin, Cout)
- in_nc = int(input.shape[1])
+ in_nc = int(input.shape[-1])
assert (
expand_ratio == None or channel == None
), "expand_ratio and channel CANNOT be NOT None at the same time."
@@ -927,3 +946,77 @@ class SuperInstanceNorm(fluid.dygraph.InstanceNorm):
out, _, _ = core.ops.instance_norm(input, scale, bias, 'epsilon',
self._epsilon)
return out
class SuperLayerNorm(fluid.dygraph.LayerNorm):
def __init__(self,
normalized_shape,
candidate_config={},
scale=True,
shift=True,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
act=None,
dtype='float32'):
super(SuperLayerNorm,
self).__init__(normalized_shape, scale, shift, epsilon,
param_attr, bias_attr, act, dtype)
def forward(self, input):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
input_shape = list(input.shape)
input_ndim = len(input_shape)
normalized_ndim = len(self._normalized_shape)
self._begin_norm_axis = input_ndim - normalized_ndim
### TODO(ceci3): fix if normalized_shape is not a single number
feature_dim = int(input.shape[-1])
weight = self.weight[:feature_dim]
bias = self.bias[:feature_dim]
pre_act, _, _ = core.ops.layer_norm(input, weight, bias, 'epsilon',
self._epsilon, 'begin_norm_axis',
self._begin_norm_axis)
return dygraph_utils._append_activation_in_dygraph(
pre_act, act=self._act)
class SuperEmbedding(fluid.dygraph.Embedding):
def __init__(self,
size,
candidate_config={},
is_sparse=False,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32'):
super(SuperEmbedding, self).__init__(size, is_sparse, is_distributed,
padding_idx, param_attr, dtype)
self.candidate_config = candidate_config
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.base_output_dim = self._size[-1]
if self.expand_ratio != None:
self.base_output_dim = int(self._size[-1] / max(self.expand_ratio))
def forward(self, input, expand_ratio=None, channel=None):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
assert (
expand_ratio == None or channel == None
), "expand_ratio and channel CANNOT be NOT None at the same time."
if expand_ratio != None:
out_nc = int(expand_ratio * self.base_output_dim)
elif channel != None:
out_nc = int(channel)
else:
out_nc = self._size[-1]
weight = self.weight[:, :out_nc]
return core.ops.lookup_table_v2(
weight, input, 'is_sparse', self._is_sparse, 'is_distributed',
self._is_distributed, 'remote_prefetch', self._remote_prefetch,
'padding_idx', self._padding_idx)
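`SuperEmbedding` follows the same pattern as `SuperLinear`: the weight is created at full size and `forward` only uses the first `out_nc` columns. A small illustrative sketch (dygraph mode; the shapes and the direct construction outside `convert()` are illustrative assumptions):

```python
import numpy as np
import paddle.fluid as fluid
from paddleslim.nas.ofa.layers import SuperEmbedding

fluid.enable_dygraph()

# full weight is 64 x 64; base_output_dim = 64 / max(expand_ratio) = 32
emb = SuperEmbedding(size=(64, 64), candidate_config={'expand_ratio': (1, 2)})

ids = fluid.dygraph.to_variable(
    np.random.randint(0, 64, size=(3, 16)).astype('int64'))
small = emb(ids, expand_ratio=1)  # lookup uses weight[:, :32]
full = emb(ids, expand_ratio=2)   # lookup uses weight[:, :64]
```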
@@ -16,7 +16,7 @@ import logging
import numpy as np
from collections import namedtuple
import paddle
- import paddle.nn as nn
+ #import paddle.nn as nn
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D
from .layers import BaseBlock, Block, SuperConv2D, SuperBatchNorm
@@ -28,9 +28,8 @@ _logger = get_logger(__name__, level=logging.INFO)
__all__ = ['OFA', 'RunConfig', 'DistillConfig']

RunConfig = namedtuple('RunConfig', [
- 'train_batch_size', 'eval_batch_size', 'n_epochs', 'save_frequency',
- 'eval_frequency', 'init_learning_rate', 'total_images', 'elastic_depth',
- 'dynamic_batch_size'
+ 'train_batch_size', 'n_epochs', 'save_frequency', 'eval_frequency',
+ 'init_learning_rate', 'total_images', 'elastic_depth', 'dynamic_batch_size'
])
RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
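Since every `RunConfig` field defaults to None, a partial config is now valid: `eval_batch_size` is gone, and the consistency checks in `OFA.__init__` below only run for the fields that are actually supplied. For example (a sketch):

```python
from paddleslim.nas.ofa import RunConfig

# only set what you need; omitted fields stay None
run_config = RunConfig(
    train_batch_size=256,
    n_epochs=[[1], [2, 3], [4, 5]],
    dynamic_batch_size=[1, 1, 1])
```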
@@ -53,20 +52,26 @@ class OFABase(fluid.dygraph.Layer):
for name, sublayer in self.model.named_sublayers():
if isinstance(sublayer, BaseBlock):
sublayer.set_supernet(self)
- layers[sublayer.key] = sublayer.candidate_config
- for k in sublayer.candidate_config.keys():
- elastic_task.add(k)
+ if not sublayer.fixed:
+ layers[sublayer.key] = sublayer.candidate_config
+ for k in sublayer.candidate_config.keys():
+ elastic_task.add(k)
return layers, elastic_task

def forward(self, *inputs, **kwargs):
raise NotImplementedError

+ # NOTE: config means set forward config for layers, used in distill.
def layers_forward(self, block, *inputs, **kwargs):
if getattr(self, 'current_config', None) != None:
- assert block.key in self.current_config, 'DONNT have {} layer in config.'.format(
- block.key)
- config = self.current_config[block.key]
+ ### if block is fixed, donnot join key into candidate
+ ### concrete config as parameter in kwargs
+ if block.fixed == False:
+ assert block.key in self.current_config, 'DONNT have {} layer in config.'.format(
+ block.key)
+ config = self.current_config[block.key]
+ else:
+ config = dict()
+ config.update(kwargs)
else:
config = dict()
logging.debug(self.model, config)
@@ -81,7 +86,7 @@ class OFABase(fluid.dygraph.Layer):
class OFA(OFABase):
def __init__(self,
model,
- run_config,
+ run_config=None,
net_config=None,
distill_config=None,
elastic_order=None,
@@ -92,7 +97,6 @@ class OFA(OFABase):
self.distill_config = distill_config
self.elastic_order = elastic_order
self.train_full = train_full
- self.iter_per_epochs = self.run_config.total_images // self.run_config.train_batch_size
self.iter = 0
self.dynamic_iter = 0
self.manual_set_task = False
@@ -100,18 +104,16 @@ class OFA(OFABase):
self._add_teacher = False
self.netAs_param = []
- for idx in range(len(run_config.n_epochs)):
- assert isinstance(
- run_config.init_learning_rate[idx],
- list), "each candidate in init_learning_rate must be list"
- assert isinstance(run_config.n_epochs[idx],
- list), "each candidate in n_epochs must be list"
### if elastic_order is none, use default order
if self.elastic_order is not None:
assert isinstance(self.elastic_order,
list), 'elastic_order must be a list'
+ if getattr(self.run_config, 'elastic_depth', None) != None:
+ depth_list = list(set(self.run_config.elastic_depth))
+ depth_list.sort()
+ self.layers['depth'] = depth_list
if self.elastic_order is None:
self.elastic_order = []
# zero, elastic resulotion, write in demo
@@ -133,16 +135,26 @@ class OFA(OFABase):
if 'channel' in self._elastic_task and 'width' not in self.elastic_order:
self.elastic_order.append('width')
- assert len(self.run_config.n_epochs) == len(self.elastic_order)
- assert len(self.run_config.n_epochs) == len(
- self.run_config.dynamic_batch_size)
- assert len(self.run_config.n_epochs) == len(
- self.run_config.init_learning_rate)
+ if getattr(self.run_config, 'n_epochs', None) != None:
+ assert len(self.run_config.n_epochs) == len(self.elastic_order)
+ for idx in range(len(run_config.n_epochs)):
+ assert isinstance(
+ run_config.n_epochs[idx],
+ list), "each candidate in n_epochs must be list"
+ if self.run_config.dynamic_batch_size != None:
+ assert len(self.run_config.n_epochs) == len(
+ self.run_config.dynamic_batch_size)
+ if self.run_config.init_learning_rate != None:
+ assert len(self.run_config.n_epochs) == len(
+ self.run_config.init_learning_rate)
+ for idx in range(len(run_config.n_epochs)):
+ assert isinstance(
+ run_config.init_learning_rate[idx], list
+ ), "each candidate in init_learning_rate must be list"

### ================= add distill prepare ======================
- if self.distill_config != None and (
- self.distill_config.lambda_distill != None and
- self.distill_config.lambda_distill > 0):
+ if self.distill_config != None:
self._add_teacher = True
self._prepare_distill()
@@ -153,9 +165,10 @@ class OFA(OFABase):
if self.distill_config.teacher_model == None:
logging.error(
- 'If you want to add distill, please input class of teacher model'
+ 'If you want to add distill, please input instance of teacher model'
)
+ ### instance model by user can input super-param easily.
assert isinstance(self.distill_config.teacher_model,
paddle.fluid.dygraph.Layer)
@@ -171,7 +184,7 @@ class OFA(OFABase):
# add hook if mapping layers is not None
# if mapping layer is None, return the output of the teacher model,
# if mapping layer is NOT None, add hook and compute distill loss about mapping layers.
- mapping_layers = self.distill_config.mapping_layers
+ mapping_layers = getattr(self.distill_config, 'mapping_layers', None)
if mapping_layers != None:
self.netAs = []
for name, sublayer in self.model.named_sublayers():
@@ -199,9 +212,16 @@ class OFA(OFABase):
def _compute_epochs(self):
if getattr(self, 'epoch', None) == None:
+ assert self.run_config.total_images is not None, \
+ "if not use set_epoch() to set epoch, please set total_images in run_config."
+ assert self.run_config.train_batch_size is not None, \
+ "if not use set_epoch() to set epoch, please set train_batch_size in run_config."
+ assert self.run_config.n_epochs is not None, \
+ "if not use set_epoch() to set epoch, please set n_epochs in run_config."
+ self.iter_per_epochs = self.run_config.total_images // self.run_config.train_batch_size
epoch = self.iter // self.iter_per_epochs
else:
- epoch = self.epochs
+ epoch = self.epoch
return epoch

def _sample_from_nestdict(self, cands, sample_type, task, phase):
@@ -284,6 +304,9 @@ class OFA(OFABase):
def export(self, config):
pass

+ def set_net_config(self, net_config):
+ self.net_config = net_config

def forward(self, *inputs, **kwargs):
# ===================== teacher process =====================
teacher_output = None
@@ -293,11 +316,12 @@ class OFA(OFABase):
# ============================================================
# ==================== student process =====================
- self.dynamic_iter += 1
- if self.dynamic_iter == self.run_config.dynamic_batch_size[
- self.task_idx]:
- self.iter += 1
- self.dynamic_iter = 0
+ if getattr(self.run_config, 'dynamic_batch_size', None) != None:
+ self.dynamic_iter += 1
+ if self.dynamic_iter == self.run_config.dynamic_batch_size[
+ self.task_idx]:
+ self.iter += 1
+ self.dynamic_iter = 0
if self.net_config == None:
if self.train_full == True:
@@ -314,6 +338,6 @@ class OFA(OFABase):
_logger.debug("Current config is {}".format(self.current_config))
if 'depth' in self.current_config:
- kwargs['depth'] = int(self.current_config['depth'])
+ kwargs['depth'] = self.current_config['depth']
return self.model.forward(*inputs, **kwargs), teacher_output
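`run_config` is now optional when constructing `OFA`, and a concrete sub-network can be pinned with the new `set_net_config` instead of sampling one per step. A sketch of the minimal path, following `TestOFACase3` in the updated test below:

```python
from paddleslim.nas.ofa import OFA

model = ModelLinear2()   # a supernet converted with expand_ratio=None (see the test)
ofa_model = OFA(model)   # no RunConfig / DistillConfig required

# fix the forward configuration instead of sampling a candidate each step
ofa_model.set_net_config({'expand_ratio': None})
```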
@@ -17,7 +17,6 @@ sys.path.append("../")
import numpy as np
import unittest
import paddle
- from static_case import StaticCase
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
from paddle.nn import ReLU
@@ -35,13 +34,16 @@ class ModelConv(fluid.dygraph.Layer):
channel=((4, 8, 12), (8, 12, 16), (8, 12, 16),
(8, 12, 16))) as ofa_super:
models = []
- models += [nn.Conv2D(3, 4, 3)]
+ models += [nn.Conv2D(3, 4, 3, padding=1)]
models += [nn.InstanceNorm(4)]
models += [ReLU()]
models += [nn.Conv2D(4, 4, 3, groups=4)]
models += [nn.InstanceNorm(4)]
models += [ReLU()]
- models += [nn.Conv2DTranspose(4, 4, 3, groups=4, use_cudnn=True)]
+ models += [
+ nn.Conv2DTranspose(
+ 4, 4, 3, groups=4, padding=1, use_cudnn=True)
+ ]
models += [nn.BatchNorm(4)]
models += [ReLU()]
models += [nn.Conv2D(4, 3, 3)]
@@ -51,7 +53,8 @@ class ModelConv(fluid.dygraph.Layer):
models += [
Block(
SuperSeparableConv2D(
- 3, 6, 1, candidate_config={'channel': (3, 6)}))
+ 3, 6, 1, padding=1, candidate_config={'channel': (3, 6)}),
+ fixed=True)
]
with supernet(
kernel_size=(3, 5, 7), expand_ratio=(1, 2, 4)) as ofa_super:
@@ -92,15 +95,37 @@ class ModelLinear(fluid.dygraph.Layer):
models = []
with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
models1 = []
+ models1 += [nn.Embedding(size=(64, 64))]
models1 += [nn.Linear(64, 128)]
+ models1 += [nn.LayerNorm(128)]
models1 += [nn.Linear(128, 256)]
models1 = ofa_super.convert(models1)
models += models1
+ self.models = paddle.nn.Sequential(*models)
+
+ def forward(self, inputs, depth=None):
+ if depth != None:
+ assert isinstance(depth, int)
+ assert depth < len(self.models)
+ else:
+ depth = len(self.models)
+ for idx in range(depth):
+ layer = self.models[idx]
+ inputs = layer(inputs)
+ return inputs
+
- with supernet(channel=((64, 128, 256), (64, 128, 256))) as ofa_super:
+ class ModelLinear1(fluid.dygraph.Layer):
+ def __init__(self):
+ super(ModelLinear1, self).__init__()
+ models = []
+ with supernet(channel=((64, 128, 256), (64, 128, 256),
+ (64, 128, 256))) as ofa_super:
models1 = []
- models1 += [nn.Linear(256, 128)]
+ models1 += [nn.Embedding(size=(64, 64))]
+ models1 += [nn.Linear(64, 128)]
+ models1 += [nn.LayerNorm(128)]
models1 += [nn.Linear(128, 256)]
models1 = ofa_super.convert(models1)
@@ -120,7 +145,35 @@ class ModelLinear(fluid.dygraph.Layer):
return inputs

- class TestOFA(StaticCase):
+ class ModelLinear2(fluid.dygraph.Layer):
def __init__(self):
super(ModelLinear2, self).__init__()
models = []
with supernet(expand_ratio=None) as ofa_super:
models1 = []
models1 += [nn.Embedding(size=(64, 64))]
models1 += [nn.Linear(64, 128)]
models1 += [nn.LayerNorm(128)]
models1 += [nn.Linear(128, 256)]
models1 = ofa_super.convert(models1)
models += models1
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs, depth=None):
if depth != None:
assert isinstance(depth, int)
assert depth < len(self.models)
else:
depth = len(self.models)
for idx in range(depth):
layer = self.models[idx]
inputs = layer(inputs)
return inputs
+ class TestOFA(unittest.TestCase):
def setUp(self):
fluid.enable_dygraph()
self.init_model_and_data()
@@ -137,7 +190,6 @@ class TestOFA(StaticCase):
def init_config(self):
default_run_config = {
'train_batch_size': 1,
- 'eval_batch_size': 1,
'n_epochs': [[1], [2, 3], [4, 5]],
'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
'dynamic_batch_size': [1, 1, 1],
@@ -152,11 +204,13 @@
'mapping_layers': ['models.0.fn']
}
self.distill_config = DistillConfig(**default_distill_config)
+ self.elastic_order = ['kernel_size', 'width', 'depth']

def test_ofa(self):
ofa_model = OFA(self.model,
self.run_config,
- distill_config=self.distill_config)
+ distill_config=self.distill_config,
+ elastic_order=self.elastic_order)
start_epoch = 0
for idx in range(len(self.run_config.n_epochs)):
@@ -169,6 +223,8 @@
ofa_model.parameters() + ofa_model.netAs_param))
for epoch_id in range(start_epoch,
self.run_config.n_epochs[idx][ph_idx]):
+ if epoch_id == 0:
+ ofa_model.set_epoch(epoch_id)
for model_no in range(self.run_config.dynamic_batch_size[
idx]):
output, _ = ofa_model(self.data)
@@ -191,14 +247,13 @@ class TestOFACase1(TestOFA):
def init_model_and_data(self):
self.model = ModelLinear()
self.teacher_model = ModelLinear()
- data_np = np.random.random((3, 64)).astype(np.float32)
+ data_np = np.random.random((3, 64)).astype(np.int64)
self.data = fluid.dygraph.to_variable(data_np)

def init_config(self):
default_run_config = {
'train_batch_size': 1,
- 'eval_batch_size': 1,
'n_epochs': [[2, 5]],
'init_learning_rate': [[0.003, 0.001]],
'dynamic_batch_size': [1],
@@ -211,6 +266,23 @@ class TestOFACase1(TestOFA):
'teacher_model': self.teacher_model,
}
self.distill_config = DistillConfig(**default_distill_config)
+ self.elastic_order = None
class TestOFACase2(TestOFACase1):
def init_model_and_data(self):
self.model = ModelLinear1()
self.teacher_model = ModelLinear1()
data_np = np.random.random((3, 64)).astype(np.int64)
self.data = fluid.dygraph.to_variable(data_np)
class TestOFACase3(unittest.TestCase):
def test_ofa(self):
self.model = ModelLinear2()
ofa_model = OFA(self.model)
ofa_model.set_net_config({'expand_ratio': None})
if __name__ == '__main__':
...