diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index 90e65d42aae47d0fd7d3c15955c3d2f7bd0693a1..e921ef6d091ce50b21419474c523e24937c47bfb 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -40,21 +40,19 @@ class SANAS(object): init_temperature=100, reduce_rate=0.85, search_steps=300, - key="sa_nas", save_checkpoint='nas_checkpoint', load_checkpoint=None, is_server=False): """ Search a group of ratios used to prune program. Args: - configs(list): A list of search space configuration with format (key, input_size, output_size, block_num). + configs(list): A list of search space configuration with format [(key, {input_size, output_size, block_num, block_mask})]. `key` is the name of search space with data type str. `input_size` and `output_size` are - input size and output size of searched sub-network. `block_num` is the number of blocks in searched network. + input size and output size of searched sub-network. `block_num` is the number of blocks in searched network, `block_mask` is a list consists by 0 and 1, 0 means normal block, 1 means reduction block. server_addr(tuple): A tuple of server ip and server port for controller server. init_temperature(float): The init temperature used in simulated annealing search strategy. reduce_rate(float): The decay rate used in simulated annealing search strategy. search_steps(int): The steps of searching. - key(str): Identity used in communication between controller server and clients. save_checkpoint(string|None): The directory of checkpoint to save, if set to None, not save checkpoint. Default: 'nas_checkpoint'. load_checkpoint(string|None): The directory of checkpoint to load, if set to None, not load checkpoint. Default: None. is_server(bool): Whether current host is controller server. Default: True.