From 7762b906a5c673a3081818a7d60f717b88ebf202 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Tue, 22 Feb 2022 19:58:53 +0800 Subject: [PATCH] fix blazeface nas (#5241) --- static/slim/nas/blazeface.yml | 17 +++++++++++++++++ .../slim/nas/search_space/blazefacespace_nas.py | 7 +++++-- static/slim/nas/train_nas.py | 2 ++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/static/slim/nas/blazeface.yml b/static/slim/nas/blazeface.yml index 2ba0799c5..ac8af9b71 100644 --- a/static/slim/nas/blazeface.yml +++ b/static/slim/nas/blazeface.yml @@ -93,3 +93,20 @@ EvalReader: mean: [104, 117, 123] std: [127.502231, 127.502231, 127.502231] batch_size: 1 + +TestReader: + inputs_def: + fields: ['image', 'im_id', 'im_shape'] + dataset: + !ImageFolder + use_default_label: true + sample_transforms: + - !DecodeImage + to_rgb: true + - !NormalizeImage + is_channel_first: false + is_scale: false + mean: [123, 117, 104] + std: [127.502231, 127.502231, 127.502231] + - !Permute {} + batch_size: 1 diff --git a/static/slim/nas/search_space/blazefacespace_nas.py b/static/slim/nas/search_space/blazefacespace_nas.py index 06b1edefc..fe73007d8 100644 --- a/static/slim/nas/search_space/blazefacespace_nas.py +++ b/static/slim/nas/search_space/blazefacespace_nas.py @@ -33,7 +33,9 @@ class BlazeFaceNasSpace(SearchSpaceBase): self.mid_filter_num = np.array([8, 12, 16, 20, 24, 32]) self.double_filter_num = np.array( [8, 12, 16, 24, 32, 40, 48, 64, 72, 80, 88, 96]) - self.use_5x5kernel = np.array([0, 1]) + self.use_5x5kernel = np.array( + [0] + ) ### if constraint is latency, use 3x3 kernel, otherwise self.use_5x5kernel = np.array([0, 1]) def init_tokens(self): return [2, 1, 3, 8, 2, 1, 2, 1, 1] @@ -74,7 +76,8 @@ class BlazeFaceNasSpace(SearchSpaceBase): self.double_filter_num[tokens[3]] ]] - is_5x5kernel = True if self.use_5x5kernel[tokens[8]] else False + ### if constraint is latency, use 3x3 kernel, otherwise is_5x5kernel = True if self.use_5x5kernel[tokens[8]] else False + is_5x5kernel = False ###True if self.use_5x5kernel[tokens[8]] else False return blaze_filters, double_blaze_filters, is_5x5kernel def token2arch(self, tokens=None): diff --git a/static/slim/nas/train_nas.py b/static/slim/nas/train_nas.py index d0836d4c3..addb92a1a 100644 --- a/static/slim/nas/train_nas.py +++ b/static/slim/nas/train_nas.py @@ -62,6 +62,8 @@ except ImportError as e: from paddleslim.analysis import flops, TableLatencyEvaluator from paddleslim.nas import SANAS +### register search space to paddleslim +import search_space @register -- GitLab