diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py index 4a5b39b5b8d11cce276e64e47ca209d6e2459c89..9596a0552a5f268846d5def1c7d9c2b09491d7a6 100644 --- a/paddleslim/common/sa_controller.py +++ b/paddleslim/common/sa_controller.py @@ -48,7 +48,11 @@ class SAController(EvolutionaryController): reduce_rate(float): The decay rate of temperature. init_temperature(float): Init temperature. max_try_times(int): max try times before get legal tokens. - init_tokens(list): The initial tokens. + init_tokens(list): The initial tokens. Default: None. + reward(float): The reward of current tokens. Default: -1. + max_reward(float): The max reward in the search of sanas, in general, best tokens get max reward. Default: -1. + iters(int): The iteration of sa controller. Default: 0. + best_tokens(list): The best tokens in the search of sanas, in general, best tokens get max reward. Default: None. constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. checkpoints(str): if checkpoint is None, donnot save checkpoints, else save scene to checkpoints file. """ diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index b151195420bd812b3cf36e4bf2e28de652045d8c..90e65d42aae47d0fd7d3c15955c3d2f7bd0693a1 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -41,7 +41,7 @@ class SANAS(object): reduce_rate=0.85, search_steps=300, key="sa_nas", - save_checkpoint=None, + save_checkpoint='nas_checkpoint', load_checkpoint=None, is_server=False): """ @@ -55,6 +55,8 @@ class SANAS(object): reduce_rate(float): The decay rate used in simulated annealing search strategy. search_steps(int): The steps of searching. key(str): Identity used in communication between controller server and clients. + save_checkpoint(string|None): The directory of checkpoint to save, if set to None, not save checkpoint. Default: 'nas_checkpoint'. + load_checkpoint(string|None): The directory of checkpoint to load, if set to None, not load checkpoint. Default: None. is_server(bool): Whether current host is controller server. Default: True. """ if not is_server: