Unverified · Commit 40e0684a authored by ceci3, committed by GitHub

Refine ofa (#527)

* support 2.0
Parent f3b898c8
paddleslim/nas/ofa/__init__.py
@@ -14,4 +14,10 @@
 from .ofa import OFA, RunConfig, DistillConfig
 from .convert_super import supernet
-from .layers import *
+from .utils.utils import get_paddle_version
+
+pd_ver = get_paddle_version()
+if pd_ver == 185:
+    from .layers import *
+else:
+    from .layers_new import *
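The package now resolves its layer implementations at import time: Paddle 1.8.x keeps the original `layers` module, while 2.0 gets the rewritten `layers_new`, so the public names stay identical either way. A minimal sketch of the same feature-gated import idiom (the helper name is illustrative, not part of this commit):

```python
import importlib

def load_layers_backend(package):
    """Pick a backend module by probing the installed Paddle API.

    Paddle 2.x exposes paddle.nn.Conv1D; 1.8.x does not, so the probe
    doubles as a version check without parsing version strings.
    """
    import paddle
    modern = hasattr(getattr(paddle, 'nn', None), 'Conv1D')
    return importlib.import_module(
        '.layers_new' if modern else '.layers', package=package)
```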
This diff is collapsed.
This diff is collapsed.
paddleslim/nas/ofa/ofa.py
@@ -16,10 +16,15 @@ import logging
 import numpy as np
 from collections import namedtuple
 import paddle
+#import paddle.nn as nn
 import paddle.fluid as fluid
-from paddle.fluid.dygraph import Conv2D
-from .layers import BaseBlock, Block, SuperConv2D, SuperBatchNorm
+from .utils.utils import get_paddle_version
+pd_ver = get_paddle_version()
+if pd_ver == 185:
+    from .layers import BaseBlock, SuperConv2D
+    Layer = paddle.fluid.dygraph.Layer
+else:
+    from .layers_new import BaseBlock, SuperConv2D
+    Layer = paddle.nn.Layer
 from .utils.utils import search_idx
 from ...common import get_logger
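Binding `Layer` once, to whichever base class the installed Paddle provides, keeps every class definition and `isinstance` check below version-agnostic. A self-contained sketch of the pattern:

```python
import paddle

# Bind one name to whichever base class this Paddle build provides.
try:
    Layer = paddle.nn.Layer             # Paddle 2.x
except AttributeError:
    Layer = paddle.fluid.dygraph.Layer  # Paddle 1.8.x

class Demo(Layer):
    """Subclasses the aliased base, so it works on either version."""
    def forward(self, x):
        return x

# The same alias serves runtime type checks (e.g. for teacher models).
assert isinstance(Demo(), Layer)
```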
@@ -40,7 +45,7 @@ DistillConfig = namedtuple('DistillConfig', [
 DistillConfig.__new__.__defaults__ = (None, ) * len(DistillConfig._fields)


-class OFABase(fluid.dygraph.Layer):
+class OFABase(Layer):
     def __init__(self, model):
         super(OFABase, self).__init__()
         self.model = model
@@ -169,8 +174,7 @@ class OFA(OFABase):
             )
             ### instance model by user can input super-param easily.
-            assert isinstance(self.distill_config.teacher_model,
-                              paddle.fluid.dygraph.Layer)
+            assert isinstance(self.distill_config.teacher_model, Layer)

             # load teacher parameter
             if self.distill_config.teacher_model_path != None:
@@ -190,9 +194,10 @@ class OFA(OFABase):
         for name, sublayer in self.model.named_sublayers():
             if name in mapping_layers:
                 netA = SuperConv2D(
-                    sublayer._num_filters,
-                    sublayer._num_filters,
-                    filter_size=1)
+                    getattr(sublayer, '_num_filters',
+                            sublayer._out_channels),
+                    getattr(sublayer, '_num_filters',
+                            sublayer._out_channels), 1)
                 self.netAs_param.extend(netA.parameters())
                 self.netAs.append(netA)
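Paddle 1.8's conv layers store their output width in `_num_filters`, while the 2.0 layers call it `_out_channels`; the `getattr` fallback reads whichever attribute exists. The same probe could be factored into a small helper rather than repeated inline (a sketch; `out_channels_of` is hypothetical, not part of this commit):

```python
def out_channels_of(conv_layer):
    """Output-channel count of a conv layer on either Paddle API.

    Tries the 1.8.x attribute first (`_num_filters`) and falls back to
    the 2.0 name (`_out_channels`), so call sites need no version branch.
    """
    return getattr(conv_layer, '_num_filters', conv_layer._out_channels)
```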
@@ -288,7 +293,8 @@ class OFA(OFABase):
             n = self.distill_config.mapping_layers[i]
             Tact = self.Tacts[n]
             Sact = self.Sacts[n]
-            Sact = netA(Sact, channel=netA._num_filters)
+            Sact = netA(
+                Sact, channel=getattr(netA, '_num_filters', netA._out_channels))
             if self.distill_config.distill_fn == None:
                 loss = fluid.layers.mse_loss(Sact, Tact)
             else:
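For context: `calc_distill_loss` pairs each mapped student activation (`Sact`) with the teacher's (`Tact`), runs the student's through the trainable 1x1 `SuperConv2D` adapter (`netA`) to align channel counts, and defaults to an MSE penalty when no custom `distill_fn` is given. A version-neutral sketch of that computation (function and argument names are illustrative):

```python
import paddle

def distill_pair_loss(student_act, teacher_act, adapter):
    """MSE between the teacher activation and the adapted student one.

    `adapter` plays the role of netA above: a 1x1 conv that projects
    the student's channels onto the teacher's before the comparison.
    """
    aligned = adapter(student_act)
    return paddle.nn.functional.mse_loss(aligned, teacher_act)
```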
paddleslim/nas/ofa/utils/utils.py
@@ -44,3 +44,13 @@ def search_idx(num, sorted_nestlist):
             return idx, phase_idx
     assert num > max_num
     return len(sorted_nestlist) - 1, max_idx
+
+
+def get_paddle_version():
+    import paddle
+    pd_ver = 185
+    if hasattr(paddle, 'nn'):
+        if hasattr(paddle.nn, 'Conv1D'):  ### detect 2.0 alpha and later
+            pd_ver = 200
+    return pd_ver
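Note that the check probes the API surface (`paddle.nn.Conv1D` exists only from the 2.0 alpha onward) instead of parsing `paddle.__version__`, which development builds often report as `0.0.0`. Typical usage, per this commit:

```python
from paddleslim.nas.ofa.utils.utils import get_paddle_version

pd_ver = get_paddle_version()  # 185 on Paddle 1.8.x, 200 on 2.0+
if pd_ver == 200:
    import paddle.nn as nn     # safe: the 2.0 namespace is present
```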
tests/test_convert_super.py (new file)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa.convert_super import Convert, supernet


class TestConvertSuper(unittest.TestCase):
    def setUp(self):
        self.model = mobilenet_v1()

    def test_convert(self):
        sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
        sp_model = Convert(sp_net_config).convert(self.model)
        assert len(sp_model.sublayers()) == 151


if __name__ == '__main__':
    unittest.main()
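This test exercises the high-level conversion entry point: `supernet(...)` declares the search space, and `Convert(config).convert(model)` walks the model and swaps each plain layer for its elastic `Super*` counterpart, which is why the resulting sublayer count is fixed and assertable. A hedged usage sketch against the same API:

```python
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa.convert_super import Convert, supernet

model = mobilenet_v1()
# Elastic kernel sizes and expansion ratios for every convertible layer.
config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
super_model = Convert(config).convert(model)

# Each converted sublayer is a Super* variant that can run at any
# sampled kernel size / width within the declared space.
for name, sublayer in list(super_model.named_sublayers())[:5]:
    print(name, type(sublayer).__name__)
```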
tests/test_ofa.py
@@ -17,16 +17,15 @@ sys.path.append("../")
 import numpy as np
 import unittest
 import paddle
-import paddle.fluid as fluid
-import paddle.fluid.dygraph.nn as nn
+import paddle.nn as nn
 from paddle.nn import ReLU
 from paddleslim.nas import ofa
 from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
 from paddleslim.nas.ofa.convert_super import supernet
-from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D
+from paddleslim.nas.ofa.layers_new import Block, SuperSeparableConv2D


-class ModelConv(fluid.dygraph.Layer):
+class ModelConv(nn.Layer):
     def __init__(self):
         super(ModelConv, self).__init__()
         with supernet(
@@ -35,16 +34,13 @@ class ModelConv(nn.Layer):
                 (8, 12, 16))) as ofa_super:
             models = []
             models += [nn.Conv2D(3, 4, 3, padding=1)]
-            models += [nn.InstanceNorm(4)]
+            models += [nn.InstanceNorm2D(4)]
             models += [ReLU()]
             models += [nn.Conv2D(4, 4, 3, groups=4)]
-            models += [nn.InstanceNorm(4)]
+            models += [nn.InstanceNorm2D(4)]
             models += [ReLU()]
-            models += [
-                nn.Conv2DTranspose(
-                    4, 4, 3, groups=4, padding=1, use_cudnn=True)
-            ]
-            models += [nn.BatchNorm(4)]
+            models += [nn.Conv2DTranspose(4, 4, 3, groups=4, padding=1)]
+            models += [nn.BatchNorm2D(4)]
             models += [ReLU()]
             models += [nn.Conv2D(4, 3, 3)]
             models += [ReLU()]
@@ -60,21 +56,23 @@ class ModelConv(nn.Layer):
             kernel_size=(3, 5, 7), expand_ratio=(1, 2, 4)) as ofa_super:
             models1 = []
             models1 += [nn.Conv2D(6, 4, 3)]
-            models1 += [nn.BatchNorm(4)]
+            models1 += [nn.BatchNorm2D(4)]
             models1 += [ReLU()]
             models1 += [nn.Conv2D(4, 4, 3, groups=2)]
-            models1 += [nn.InstanceNorm(4)]
+            models1 += [nn.InstanceNorm2D(4)]
             models1 += [ReLU()]
             models1 += [nn.Conv2DTranspose(4, 4, 3, groups=2)]
-            models1 += [nn.BatchNorm(4)]
+            models1 += [nn.BatchNorm2D(4)]
             models1 += [ReLU()]
             models1 += [nn.Conv2DTranspose(4, 4, 3)]
-            models1 += [nn.BatchNorm(4)]
+            models1 += [nn.BatchNorm2D(4)]
+            models1 += [ReLU()]
+            models1 += [nn.Conv2DTranspose(4, 4, 1)]
+            models1 += [nn.BatchNorm2D(4)]
             models1 += [ReLU()]
             models1 = ofa_super.convert(models1)
             models += models1
         self.models = paddle.nn.Sequential(*models)

     def forward(self, inputs, depth=None):
@@ -89,16 +87,61 @@ class ModelConv(nn.Layer):
         return inputs


-class ModelLinear(fluid.dygraph.Layer):
+class ModelConv2(nn.Layer):
+    def __init__(self):
+        super(ModelConv2, self).__init__()
+        with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
+            models = []
+            models += [nn.Conv2DTranspose(4, 4, 3)]
+            models += [nn.BatchNorm2D(4)]
+            models += [ReLU()]
+            models += [nn.Conv2D(4, 4, 3)]
+            models += [nn.BatchNorm2D(4)]
+            models += [ReLU()]
+            models = ofa_super.convert(models)
+
+        with supernet(channel=((4, 6, 8), (4, 6, 8))) as ofa_super:
+            models1 = []
+            models1 += [nn.Conv2DTranspose(4, 4, 3)]
+            models1 += [nn.BatchNorm2D(4)]
+            models1 += [ReLU()]
+            models1 += [nn.Conv2DTranspose(4, 4, 3)]
+            models1 += [nn.BatchNorm2D(4)]
+            models1 += [ReLU()]
+            models1 = ofa_super.convert(models1)
+            models += models1
+
+        with supernet(kernel_size=(3, 5, 7)) as ofa_super:
+            models2 = []
+            models2 += [nn.Conv2D(4, 4, 3)]
+            models2 += [nn.BatchNorm2D(4)]
+            models2 += [ReLU()]
+            models2 += [nn.Conv2DTranspose(4, 4, 3)]
+            models2 += [nn.BatchNorm2D(4)]
+            models2 += [ReLU()]
+            models2 += [nn.Conv2D(4, 4, 3)]
+            models2 += [nn.BatchNorm2D(4)]
+            models2 += [ReLU()]
+            models2 = ofa_super.convert(models2)
+            models += models2
+        self.models = paddle.nn.Sequential(*models)
+
+
+class ModelLinear(nn.Layer):
     def __init__(self):
         super(ModelLinear, self).__init__()
-        models = []
         with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
+            models = []
+            models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
+            models += [nn.Linear(64, 128)]
+            models += [nn.LayerNorm(128)]
+            models += [nn.Linear(128, 256)]
+            models = ofa_super.convert(models)
+
+        with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
             models1 = []
-            models1 += [nn.Embedding(size=(64, 64))]
-            models1 += [nn.Linear(64, 128)]
-            models1 += [nn.LayerNorm(128)]
-            models1 += [nn.Linear(128, 256)]
+            models1 += [nn.Linear(256, 256)]
             models1 = ofa_super.convert(models1)
             models += models1
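These rewritten test models demonstrate the intended composition pattern: each `with supernet(...)` scope converts only the layers declared inside it, so one network can chain independently searched segments (elastic expand ratio here, explicit channel candidates there, elastic kernels elsewhere). A condensed, self-contained sketch of the same idea (layer shapes are illustrative):

```python
import paddle.nn as nn
from paddle.nn import ReLU
from paddleslim.nas.ofa.convert_super import supernet

def build_mixed_supernet():
    # Segment 1: search over expansion ratios of these layers only.
    with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
        seg1 = ofa_super.convert(
            [nn.Conv2D(3, 4, 3), nn.BatchNorm2D(4), ReLU()])
    # Segment 2: search over explicit channel candidates instead.
    with supernet(channel=((4, 6, 8), )) as ofa_super:
        seg2 = ofa_super.convert(
            [nn.Conv2D(4, 4, 3), nn.BatchNorm2D(4), ReLU()])
    # The segments concatenate into one model with a mixed search space.
    return nn.Sequential(*(seg1 + seg2))
```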
@@ -116,17 +159,21 @@ class ModelLinear(nn.Layer):
         return inputs


-class ModelLinear1(fluid.dygraph.Layer):
+class ModelLinear1(nn.Layer):
     def __init__(self):
         super(ModelLinear1, self).__init__()
-        models = []
         with supernet(channel=((64, 128, 256), (64, 128, 256),
                                (64, 128, 256))) as ofa_super:
+            models = []
+            models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
+            models += [nn.Linear(64, 128)]
+            models += [nn.LayerNorm(128)]
+            models += [nn.Linear(128, 256)]
+            models = ofa_super.convert(models)
+
+        with supernet(channel=((64, 128, 256), )) as ofa_super:
             models1 = []
-            models1 += [nn.Embedding(size=(64, 64))]
-            models1 += [nn.Linear(64, 128)]
-            models1 += [nn.LayerNorm(128)]
-            models1 += [nn.Linear(128, 256)]
+            models1 += [nn.Linear(256, 256)]
             models1 = ofa_super.convert(models1)
             models += models1
@@ -145,20 +192,16 @@ class ModelLinear1(nn.Layer):
         return inputs


-class ModelLinear2(fluid.dygraph.Layer):
+class ModelLinear2(nn.Layer):
     def __init__(self):
         super(ModelLinear2, self).__init__()
-        models = []
         with supernet(expand_ratio=None) as ofa_super:
-            models1 = []
-            models1 += [nn.Embedding(size=(64, 64))]
-            models1 += [nn.Linear(64, 128)]
-            models1 += [nn.LayerNorm(128)]
-            models1 += [nn.Linear(128, 256)]
-            models1 = ofa_super.convert(models1)
-            models += models1
+            models = []
+            models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
+            models += [nn.Linear(64, 128)]
+            models += [nn.LayerNorm(128)]
+            models += [nn.Linear(128, 256)]
+            models = ofa_super.convert(models)
         self.models = paddle.nn.Sequential(*models)

     def forward(self, inputs, depth=None):
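`ModelLinear2` passes `expand_ratio=None`: the layers are still wrapped in their `Super*` forms, but no elastic dimension is attached, which is the degenerate path later exercised by `set_net_config({'expand_ratio': None})`. A minimal sketch of that configuration:

```python
import paddle.nn as nn
from paddleslim.nas.ofa.convert_super import supernet

# expand_ratio=None wraps layers without attaching a width search space,
# so the supernet keeps the original, fixed layer widths.
with supernet(expand_ratio=None) as ofa_super:
    models = ofa_super.convert(
        [nn.Embedding(num_embeddings=64, embedding_dim=64),
         nn.Linear(64, 128)])
```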
@@ -175,7 +218,6 @@ class ModelLinear2(nn.Layer):
 class TestOFA(unittest.TestCase):
     def setUp(self):
-        fluid.enable_dygraph()
         self.init_model_and_data()
         self.init_config()
@@ -185,7 +227,7 @@ class TestOFA(unittest.TestCase):
         data_np = np.random.random((1, 3, 10, 10)).astype(np.float32)
         label_np = np.random.random((1)).astype(np.float32)

-        self.data = fluid.dygraph.to_variable(data_np)
+        self.data = paddle.to_tensor(data_np)

     def init_config(self):
         default_run_config = {
@@ -217,10 +259,9 @@ class TestOFA(unittest.TestCase):
                 cur_idx = self.run_config.n_epochs[idx]
                 for ph_idx in range(len(cur_idx)):
                     cur_lr = self.run_config.init_learning_rate[idx][ph_idx]
-                    adam = fluid.optimizer.Adam(
+                    adam = paddle.optimizer.Adam(
                         learning_rate=cur_lr,
-                        parameter_list=(
-                            ofa_model.parameters() + ofa_model.netAs_param))
+                        parameters=(ofa_model.parameters() + ofa_model.netAs_param))
                     for epoch_id in range(start_epoch,
                                           self.run_config.n_epochs[idx][ph_idx]):
                         if epoch_id == 0:
@@ -228,7 +269,7 @@ class TestOFA(unittest.TestCase):
                         for model_no in range(self.run_config.dynamic_batch_size[
                                 idx]):
                             output, _ = ofa_model(self.data)
-                            loss = fluid.layers.reduce_mean(output)
+                            loss = paddle.mean(output)
                             if self.distill_config.mapping_layers != None:
                                 dis_loss = ofa_model.calc_distill_loss()
                                 loss += dis_loss
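The loop migrates three 1.x APIs at once: `fluid.optimizer.Adam(parameter_list=...)` becomes `paddle.optimizer.Adam(parameters=...)`, `fluid.layers.reduce_mean` becomes `paddle.mean`, and inputs come from `paddle.to_tensor` instead of `fluid.dygraph.to_variable`. A minimal 2.0-style training step under the same assumptions (the helper name and learning rate are illustrative):

```python
import paddle

def ofa_train_step(ofa_model, data_np, lr=1e-3):
    adam = paddle.optimizer.Adam(
        learning_rate=lr,
        parameters=ofa_model.parameters() + ofa_model.netAs_param)
    data = paddle.to_tensor(data_np)       # replaces fluid.dygraph.to_variable
    output, _ = ofa_model(data)
    loss = paddle.mean(output)             # replaces fluid.layers.reduce_mean
    loss += ofa_model.calc_distill_loss()  # optional distillation term
    loss.backward()
    adam.step()
    adam.clear_grad()
    return float(loss.numpy())
```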
@@ -249,7 +290,7 @@ class TestOFACase1(TestOFA):
         self.teacher_model = ModelLinear()
         data_np = np.random.random((3, 64)).astype(np.int64)

-        self.data = fluid.dygraph.to_variable(data_np)
+        self.data = paddle.to_tensor(data_np)

     def init_config(self):
         default_run_config = {
@@ -275,7 +316,7 @@ class TestOFACase2(TestOFACase1):
         self.teacher_model = ModelLinear1()
         data_np = np.random.random((3, 64)).astype(np.int64)

-        self.data = fluid.dygraph.to_variable(data_np)
+        self.data = paddle.to_tensor(data_np)


 class TestOFACase3(unittest.TestCase):
@@ -285,5 +326,10 @@ class TestOFACase3(unittest.TestCase):
         ofa_model.set_net_config({'expand_ratio': None})


+class TestOFACase4(unittest.TestCase):
+    def test_ofa(self):
+        self.model = ModelConv2()
+
+
 if __name__ == '__main__':
     unittest.main()
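Taken together, the intended 2.0-era workflow looks roughly like this end-to-end sketch (hedged: the input size and search space are illustrative, and `OFA`'s defaults are relied on for the run configuration):

```python
import numpy as np
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa import OFA, DistillConfig
from paddleslim.nas.ofa.convert_super import Convert, supernet

# Teacher: a fixed reference network. Student: the same architecture
# converted into an elastic supernet.
teacher = mobilenet_v1()
student = Convert(supernet(expand_ratio=[1, 2, 4])).convert(mobilenet_v1())

# OFA samples a sub-network each forward pass and can distill
# intermediate activations from the teacher.
ofa_model = OFA(student, distill_config=DistillConfig(teacher_model=teacher))

data = paddle.to_tensor(np.random.random((1, 3, 224, 224)).astype('float32'))
output, _ = ofa_model(data)
```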