diff --git a/paddleslim/nas/one_shot/super_mnasnet.py b/paddleslim/nas/one_shot/super_mnasnet.py index 852b40383af5223668b1635a5ded05f39966f5fd..169d1050ba18043fcf8221a3cb3c52773dec7f44 100644 --- a/paddleslim/nas/one_shot/super_mnasnet.py +++ b/paddleslim/nas/one_shot/super_mnasnet.py @@ -2,7 +2,7 @@ import paddle from paddle import fluid from paddle.fluid.layer_helper import LayerHelper import numpy as np -from one_shot_nas import OneShotSuperNet +from .one_shot_nas import OneShotSuperNet __all__ = ['SuperMnasnet'] @@ -209,14 +209,14 @@ class SuperMnasnet(OneShotSuperNet): def init_tokens(self): return [ - 3, 3, 6, 6, 6, 6, 3, 3, 3, 6, 6, 6, 3, 3, 3, 3, 6, 6, 3, 3, 3, 6, - 6, 6, 3, 3, 3, 6, 6, 6, 3, 6, 6, 6, 6, 6 + 3, 3, 6, 6, 6, 6, 3, 3, 3, 6, 6, 6, 3, 3, 3, 3, 6, 6, 3, 3, 3, 6, 6, + 6, 3, 3, 3, 6, 6, 6, 3, 6, 6, 6, 6, 6 ] def range_table(self): max_v = [ - 6, 6, 10, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 6, 6, 6, 10, 10, 6, - 6, 6, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 10, 10, 10, 10, 10 + 6, 6, 10, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 6, 6, 6, 10, 10, 6, 6, + 6, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 10, 10, 10, 10, 10 ] return (len(max_v) * [0], max_v)