diff --git a/static/slim/nas/blazeface.yml b/static/slim/nas/blazeface.yml index 2ba0799c55b8dfab88f010294445ca3c1292c8a7..ac8af9b718a0f211b7b518d61349d7583ee26cd8 100644 --- a/static/slim/nas/blazeface.yml +++ b/static/slim/nas/blazeface.yml @@ -93,3 +93,20 @@ EvalReader: mean: [104, 117, 123] std: [127.502231, 127.502231, 127.502231] batch_size: 1 + +TestReader: + inputs_def: + fields: ['image', 'im_id', 'im_shape'] + dataset: + !ImageFolder + use_default_label: true + sample_transforms: + - !DecodeImage + to_rgb: true + - !NormalizeImage + is_channel_first: false + is_scale: false + mean: [123, 117, 104] + std: [127.502231, 127.502231, 127.502231] + - !Permute {} + batch_size: 1 diff --git a/static/slim/nas/search_space/blazefacespace_nas.py b/static/slim/nas/search_space/blazefacespace_nas.py index 06b1edefc35e4588a966b93f851e6bfdb2c33567..fe73007d8f437a865f5ad9eb0ac063f5cb585b32 100644 --- a/static/slim/nas/search_space/blazefacespace_nas.py +++ b/static/slim/nas/search_space/blazefacespace_nas.py @@ -33,7 +33,9 @@ class BlazeFaceNasSpace(SearchSpaceBase): self.mid_filter_num = np.array([8, 12, 16, 20, 24, 32]) self.double_filter_num = np.array( [8, 12, 16, 24, 32, 40, 48, 64, 72, 80, 88, 96]) - self.use_5x5kernel = np.array([0, 1]) + self.use_5x5kernel = np.array( + [0] + ) ### if constraint is latency, use 3x3 kernel, otherwise self.use_5x5kernel = np.array([0, 1]) def init_tokens(self): return [2, 1, 3, 8, 2, 1, 2, 1, 1] @@ -74,7 +76,8 @@ class BlazeFaceNasSpace(SearchSpaceBase): self.double_filter_num[tokens[3]] ]] - is_5x5kernel = True if self.use_5x5kernel[tokens[8]] else False + ### if constraint is latency, use 3x3 kernel, otherwise is_5x5kernel = True if self.use_5x5kernel[tokens[8]] else False + is_5x5kernel = False ###True if self.use_5x5kernel[tokens[8]] else False return blaze_filters, double_blaze_filters, is_5x5kernel def token2arch(self, tokens=None): diff --git a/static/slim/nas/train_nas.py b/static/slim/nas/train_nas.py index d0836d4c32c83670126e82daf6c9e857da4e30df..addb92a1a413ce1bfc09475720dbde9642597de5 100644 --- a/static/slim/nas/train_nas.py +++ b/static/slim/nas/train_nas.py @@ -62,6 +62,8 @@ except ImportError as e: from paddleslim.analysis import flops, TableLatencyEvaluator from paddleslim.nas import SANAS +### register search space to paddleslim +import search_space @register