diff --git a/demo/imagenet_reader.py b/demo/imagenet_reader.py
index 947fd023b9c6eea5f4d6d0a5d52337b1ba97cc3f..8f80345baaadf9356bab97d73235cc6244048e4f 100644
--- a/demo/imagenet_reader.py
+++ b/demo/imagenet_reader.py
@@ -5,6 +5,7 @@ import functools
 import numpy as np
 import paddle
 from PIL import Image, ImageEnhance
+from paddle.io import Dataset
 
 random.seed(0)
 np.random.seed(0)
@@ -194,3 +195,54 @@ def val(data_dir=DATA_DIR):
 def test(data_dir=DATA_DIR):
     file_list = os.path.join(data_dir, 'test_list.txt')
     return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
+
+
+class ImageNetDataset(Dataset):
+    def __init__(self, data_dir=DATA_DIR, mode='train'):
+        super(ImageNetDataset, self).__init__()
+        train_file_list = os.path.join(data_dir, 'train_list.txt')
+        val_file_list = os.path.join(data_dir, 'val_list.txt')
+        test_file_list = os.path.join(data_dir, 'test_list.txt')
+        self.mode = mode
+        if mode == 'train':
+            with open(train_file_list) as flist:
+                full_lines = [line.strip() for line in flist]
+                np.random.shuffle(full_lines)
+                if os.getenv('PADDLE_TRAINING_ROLE'):
+                    # distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
+                    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+                    trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
+                    per_node_lines = len(full_lines) // trainer_count
+                    lines = full_lines[trainer_id * per_node_lines:(
+                        trainer_id + 1) * per_node_lines]
+                    print(
+                        "read images from %d, length: %d, lines length: %d, total: %d"
+                        % (trainer_id * per_node_lines, per_node_lines,
+                           len(lines), len(full_lines)))
+                else:
+                    lines = full_lines
+                self.data = [line.split() for line in lines]
+        else:
+            with open(val_file_list) as flist:
+                lines = [line.strip() for line in flist]
+                self.data = [line.split() for line in lines]
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        data_path = os.path.join(DATA_DIR, sample[0])
+        if self.mode == 'train':
+            data, label = process_image(
+                [data_path, sample[1]],
+                mode='train',
+                color_jitter=False,
+                rotate=False)
+        if self.mode == 'val':
+            data, label = process_image(
+                [data_path, sample[1]],
+                mode='val',
+                color_jitter=False,
+                rotate=False)
+        return data, np.array([label]).astype('int64')
+
+    def __len__(self):
+        return len(self.data)
diff --git a/demo/prune/eval.py b/demo/prune/eval.py
index 6cbc25409c53d42e6385cb900abbb60bbf3c16c5..5448fe90083d0fd221fea8e7f8ec51fd3b1233d3 100644
--- a/demo/prune/eval.py
+++ b/demo/prune/eval.py
@@ -34,13 +34,12 @@ def eval(args):
     train_reader = None
     test_reader = None
     if args.data == "mnist":
-        val_reader = paddle.dataset.mnist.test()
+        val_dataset = paddle.vision.datasets.MNIST(mode='test')
         class_dim = 10
         image_shape = "1,28,28"
     elif args.data == "imagenet":
         import imagenet_reader as reader
-        train_reader = reader.train()
-        val_reader = reader.val()
+        val_dataset = reader.ImageNetDataset(mode='val')
         class_dim = 1000
         image_shape = "3,224,224"
     else:
@@ -61,14 +60,13 @@ def eval(args):
     exe = paddle.static.Executor(place)
     exe.run(paddle.static.default_startup_program())
 
-    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
-
-    valid_loader = paddle.io.DataLoader.from_generator(
+    valid_loader = paddle.io.DataLoader(
+        val_dataset,
+        places=place,
         feed_list=[image, label],
-        capacity=64,
-        use_double_buffer=True,
-        iterable=True)
-    valid_loader.set_sample_list_generator(val_reader, place)
+        drop_last=False,
+        batch_size=args.batch_size,
+        shuffle=False)
 
     load_model(exe, val_program, args.model_path)
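
The eval.py change above replaces the generator-based pipeline with a map-style dataset handed straight to paddle.io.DataLoader. A minimal sketch of the resulting pattern, using the MNIST branch so it runs without ImageNet data (the ImageNetDataset added above drops into the same slot); names and shapes follow the demo code:

    import paddle

    paddle.enable_static()
    image = paddle.static.data(name='image', shape=[None, 1, 28, 28], dtype='float32')
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')

    val_dataset = paddle.vision.datasets.MNIST(mode='test')
    place = paddle.CPUPlace()
    valid_loader = paddle.io.DataLoader(
        val_dataset,
        places=place,
        feed_list=[image, label],  # binds each batch to these variables
        batch_size=64,
        shuffle=False,
        drop_last=False)

    exe = paddle.static.Executor(place)
    exe.run(paddle.static.default_startup_program())
    for data in valid_loader():  # each `data` can be passed as the feed
        pass                     # e.g. exe.run(val_program, feed=data, ...)
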
diff --git a/demo/prune/train.py b/demo/prune/train.py
index 9748ca4edcd1c98979444b966ae387f27af9e7bc..581cbfc3d0c75746d7f60afe2fda0be7036c9f8d 100644
--- a/demo/prune/train.py
+++ b/demo/prune/train.py
@@ -79,8 +79,8 @@ def piecewise_decay(args):
 
 def cosine_decay(args):
     step = int(math.ceil(float(args.total_images) / args.batch_size))
-    learning_rate = paddle.optimizer.lr.cosine_decay(
-        learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs)
+    learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
+        learning_rate=args.lr, T_max=args.num_epochs)
     optimizer = paddle.optimizer.Momentum(
         learning_rate=learning_rate,
         momentum=args.momentum_rate,
@@ -99,14 +99,14 @@ def compress(args):
     train_reader = None
     test_reader = None
     if args.data == "mnist":
-        train_reader = paddle.dataset.mnist.train()
-        val_reader = paddle.dataset.mnist.test()
+        train_dataset = paddle.vision.datasets.MNIST(mode='train')
+        val_dataset = paddle.vision.datasets.MNIST(mode='test')
         class_dim = 10
         image_shape = "1,28,28"
     elif args.data == "imagenet":
         import imagenet_reader as reader
-        train_reader = reader.train()
-        val_reader = reader.val()
+        train_dataset = reader.ImageNetDataset(mode='train')
+        val_dataset = reader.ImageNetDataset(mode='val')
         class_dim = 1000
         image_shape = "3,224,224"
     else:
@@ -143,22 +143,23 @@ def compress(args):
         paddle.fluid.io.load_vars(
             exe, args.pretrained_model, predicate=if_exist)
 
-    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
-    train_reader = paddle.batch(
-        train_reader, batch_size=args.batch_size, drop_last=True)
-
-    train_loader = paddle.io.DataLoader.from_generator(
+    train_loader = paddle.io.DataLoader(
+        train_dataset,
+        places=places,
         feed_list=[image, label],
-        capacity=64,
-        use_double_buffer=True,
-        iterable=True)
-    valid_loader = paddle.io.DataLoader.from_generator(
+        drop_last=True,
+        batch_size=args.batch_size,
+        shuffle=True,
+        use_shared_memory=False,
+        num_workers=16)
+    valid_loader = paddle.io.DataLoader(
+        val_dataset,
+        places=place,
         feed_list=[image, label],
-        capacity=64,
-        use_double_buffer=True,
-        iterable=True)
-    train_loader.set_sample_list_generator(train_reader, places)
-    valid_loader.set_sample_list_generator(val_reader, place)
+        drop_last=False,
+        use_shared_memory=False,
+        batch_size=args.batch_size,
+        shuffle=False)
 
     def test(epoch, program):
         acc_top1_ns = []
@@ -237,8 +238,8 @@ def compress(args):
         if args.save_inference:
             infer_model_path = os.path.join(args.model_path, "infer_models",
                                             str(i))
-            paddle.fluid.io.save_inference_model(infer_model_path, ["image"],
-                                                 [out], exe, pruned_val_program)
+            paddle.static.save_inference_model(infer_model_path, [image],
+                                               [out], exe, program=pruned_val_program)
             _logger.info("Saved inference model into [{}]".format(
                 infer_model_path))
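
Note the semantics shift in cosine_decay above: the old API annealed per batch step (hence step_each_epoch), while CosineAnnealingDecay anneals over T_max calls to step(), so with T_max=args.num_epochs the scheduler is meant to be stepped once per epoch. A sketch with illustrative values:

    import paddle

    num_epochs = 120
    lr = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.1, T_max=num_epochs)
    optimizer = paddle.optimizer.Momentum(learning_rate=lr, momentum=0.9)

    for epoch in range(num_epochs):
        # ... run all training batches for this epoch ...
        lr.step()  # advance the cosine schedule by one epoch

Note also that paddle.static.save_inference_model expects feed/fetch Variables rather than names, and takes the program as a keyword argument, which is why the updated call above passes [image] and program=pruned_val_program.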
diff --git a/demo/sensitive/train.py b/demo/sensitive/train.py
index 78c4c1e54c428239188f8d1dab009af664b9d256..90d537253cb7e8449904fad932b52cdac4eb219d 100644
--- a/demo/sensitive/train.py
+++ b/demo/sensitive/train.py
@@ -33,13 +33,12 @@ model_list = [m for m in dir(models) if "__" not in m]
 def compress(args):
     test_reader = None
     if args.data == "mnist":
-        import paddle.dataset.mnist as reader
-        val_reader = reader.test()
+        val_dataset = paddle.vision.datasets.MNIST(mode='test')
         class_dim = 10
         image_shape = "1,28,28"
     elif args.data == "imagenet":
         import imagenet_reader as reader
-        val_reader = reader.val()
+        val_dataset = reader.ImageNetDataset(mode='val')
         class_dim = 1000
         image_shape = "3,224,224"
     else:
@@ -70,14 +69,13 @@ def compress(args):
         paddle.fluid.io.load_vars(
             exe, args.pretrained_model, predicate=if_exist)
 
-    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
-    valid_loader = paddle.io.DataLoader.from_generator(
+    valid_loader = paddle.io.DataLoader(
+        val_dataset,
+        places=place,
         feed_list=[image, label],
-        capacity=64,
-        use_double_buffer=True,
-        iterable=True)
-
-    valid_loader.set_sample_list_generator(val_reader, place)
+        drop_last=False,
+        batch_size=args.batch_size,
+        shuffle=False)
 
     def test(program):
         acc_top1_ns = []
diff --git a/paddleslim/common/analyze_helper.py b/paddleslim/common/analyze_helper.py
index 09879a074407134a7f2e16dd90dff0ee819b4faf..958526ffaa559d242d5e66250160c64b2dec4678 100644
--- a/paddleslim/common/analyze_helper.py
+++ b/paddleslim/common/analyze_helper.py
@@ -70,7 +70,7 @@ class VarCollector(object):
                  scope=None):
         self.program = program
         self.var_names = var_names
-        self.scope = fluid.global_scope() if scope is None else scope
+        self.scope = paddle.static.global_scope() if scope is None else scope
         self.use_ema = use_ema
         self.set_up()
         if self.use_ema:
@@ -104,8 +104,8 @@ class VarCollector(object):
     def run(self, reader, exe, step=None, loss_name=None):
         if not hasattr(self.program, '_program'):
             # Compile the native program to speed up
-            program = fluid.CompiledProgram(self.program).with_data_parallel(
-                loss_name=loss_name)
+            program = paddle.static.CompiledProgram(
+                self.program).with_data_parallel(loss_name=loss_name)
         for idx, data in enumerate(reader):
             vars_np = exe.run(program=program,
@@ -122,17 +122,17 @@ class VarCollector(object):
 
     def abs_max_run(self, reader, exe, step=None, loss_name=None):
         fetch_list = []
-        with fluid.program_guard(self.program):
+        with paddle.static.program_guard(self.program):
             for act_name in self.real_names:
                 act = self.program.global_block().var(act_name)
                 act = fluid.layers.reduce_max(
-                    fluid.layers.abs(act), name=act_name + "_reduced")
+                    paddle.abs(act), name=act_name + "_reduced")
                 fetch_list.append(act_name + "_reduced.tmp_0")
         if not hasattr(self.program, '_program'):
             # Compile the native program to speed up
-            program = fluid.CompiledProgram(self.program).with_data_parallel(
-                loss_name=loss_name)
+            program = paddle.static.CompiledProgram(
+                self.program).with_data_parallel(loss_name=loss_name)
         for idx, data in enumerate(reader):
             vars_np = exe.run(program=program, feed=data, fetch_list=fetch_list)
             vars_np = [np.max(var) for var in vars_np]
diff --git a/paddleslim/common/controller.py b/paddleslim/common/controller.py
index 87c887bdc85a1e535e15daf03e27358f2c6a529e..34def9b013b4eaffa3a148dc1da50ee0b3d6b7e6 100644
--- a/paddleslim/common/controller.py
+++ b/paddleslim/common/controller.py
@@ -16,6 +16,7 @@ import copy
 import math
 import numpy as np
+import paddle
 import paddle.fluid as fluid
 
 __all__ = ['EvolutionaryController', 'RLBaseController']
@@ -72,11 +73,11 @@ class RLBaseController(object):
     def get_params(self, program):
         var_dict = {}
         for var in program.global_block().all_parameters():
-            var_dict[var.name] = np.array(fluid.global_scope().find_var(
+            var_dict[var.name] = np.array(paddle.static.global_scope().find_var(
                 var.name).get_tensor())
         return var_dict
 
     def set_params(self, program, params_dict, place):
         for var in program.global_block().all_parameters():
-            fluid.global_scope().find_var(var.name).get_tensor().set(
+            paddle.static.global_scope().find_var(var.name).get_tensor().set(
                 params_dict[var.name], place)
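
get_params/set_params above round-trip parameters through paddle.static.global_scope(). A self-contained sketch of that read/write pattern (the network and values are illustrative):

    import numpy as np
    import paddle

    paddle.enable_static()
    main = paddle.static.Program()
    startup = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
        y = paddle.static.nn.fc(x, size=2)

    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    exe.run(startup)

    for var in main.global_block().all_parameters():
        tensor = paddle.static.global_scope().find_var(var.name).get_tensor()
        value = np.array(tensor)        # read: tensor -> numpy
        tensor.set(value * 0.5, place)  # write: numpy -> tensor
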
diff --git a/paddleslim/common/rl_controller/ddpg/ddpg_controller.py b/paddleslim/common/rl_controller/ddpg/ddpg_controller.py
index 50216adbdaec7f152d64e3d2f16007d059510efb..03138e57755c18a66a9321bff381c91c96450358 100644
--- a/paddleslim/common/rl_controller/ddpg/ddpg_controller.py
+++ b/paddleslim/common/rl_controller/ddpg/ddpg_controller.py
@@ -15,6 +15,7 @@
 import numpy as np
 import parl
 from parl import layers
+import paddle
 from paddle import fluid
 from ..utils import RLCONTROLLER, action_mapping
 from ...controller import RLBaseController
@@ -37,15 +38,15 @@ class DDPGAgent(parl.Agent):
         self.alg.sync_target(decay=0)
 
     def build_program(self):
-        self.pred_program = fluid.Program()
-        self.learn_program = fluid.Program()
+        self.pred_program = paddle.static.Program()
+        self.learn_program = paddle.static.Program()
 
-        with fluid.program_guard(self.pred_program):
+        with paddle.static.program_guard(self.pred_program):
             obs = fluid.data(
                 name='obs', shape=[None, self.obs_dim], dtype='float32')
             self.pred_act = self.alg.predict(obs)
 
-        with fluid.program_guard(self.learn_program):
+        with paddle.static.program_guard(self.learn_program):
             obs = fluid.data(
                 name='obs', shape=[None, self.obs_dim], dtype='float32')
             act = fluid.data(
@@ -88,8 +89,7 @@ class DDPG(RLBaseController):
         self.obs_dim = kwargs.get('obs_dim')
         self.model = kwargs.get(
             'model') if 'model' in kwargs else default_ddpg_model
-        self.actor_lr = kwargs.get(
-            'actor_lr') if 'actor_lr' in kwargs else 1e-4
+        self.actor_lr = kwargs.get('actor_lr') if 'actor_lr' in kwargs else 1e-4
         self.critic_lr = kwargs.get(
             'critic_lr') if 'critic_lr' in kwargs else 1e-3
         self.gamma = kwargs.get('gamma') if 'gamma' in kwargs else 0.99
@@ -103,7 +103,7 @@ class DDPG(RLBaseController):
         self.actions_noise = kwargs.get(
             'actions_noise') if 'actions_noise' in kwargs else default_noise
         self.action_dist = 0.0
-        self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()
+        self.place = paddle.CUDAPlace(0) if self.use_gpu else paddle.CPUPlace()
 
         model = self.model(self.act_dim)
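
The DDPGAgent migration keeps parl's two-program layout: one paddle.static.Program for inference and one for learning, built under separate program_guard contexts and sharing parameters through the global scope. A stripped-down sketch of that pattern (dimensions and layers are placeholders, not the parl model):

    import paddle

    paddle.enable_static()
    pred_program = paddle.static.Program()
    learn_program = paddle.static.Program()

    with paddle.static.program_guard(pred_program):
        obs = paddle.static.data(name='obs', shape=[None, 8], dtype='float32')
        act = paddle.static.nn.fc(obs, size=2)  # forward pass only

    with paddle.static.program_guard(learn_program):
        obs = paddle.static.data(name='obs', shape=[None, 8], dtype='float32')
        act = paddle.static.data(name='act', shape=[None, 2], dtype='float32')
        # ... critic, loss and optimizer would be defined here ...
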
diff --git a/paddleslim/common/rl_controller/lstm/lstm_controller.py b/paddleslim/common/rl_controller/lstm/lstm_controller.py
index 920b29eac64c4eec2d8dca28d60c40c694a09512..fcc43d40d54958b64fef738c546e4187f4343455 100644
--- a/paddleslim/common/rl_controller/lstm/lstm_controller.py
+++ b/paddleslim/common/rl_controller/lstm/lstm_controller.py
@@ -15,6 +15,7 @@
 import math
 import logging
 import numpy as np
+import paddle
 import paddle.fluid as fluid
 from paddle.fluid import ParamAttr
 from paddle.fluid.layers import RNNCell, LSTMCell, rnn
@@ -39,8 +40,7 @@ class lstm_cell(RNNCell):
         bias_attr = ParamAttr(initializer=uniform_initializer(
             1.0 / math.sqrt(hidden_size)))
         for i in range(num_layers):
-            self.lstm_cells.append(
-                LSTMCell(hidden_size, param_attr, bias_attr))
+            self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr))
 
     def call(self, inputs, states):
         new_states = []
@@ -75,26 +75,27 @@ class LSTM(RLBaseController):
         self._create_parameter()
         self._build_program()
 
-        self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()
-        self.exe = fluid.Executor(self.place)
-        self.exe.run(fluid.default_startup_program())
+        self.place = paddle.CUDAPlace(0) if self.use_gpu else paddle.CPUPlace()
+        self.exe = paddle.static.Executor(self.place)
+        self.exe.run(paddle.static.default_startup_program())
 
         self.param_dict = self.get_params(self.learn_program)
 
     def _lstm(self, inputs, hidden, cell, token_idx):
         cells = lstm_cell(self.lstm_num_layers, self.hidden_size)
         output, new_states = cells.call(inputs, states=([[hidden, cell]]))
-        logits = fluid.layers.fc(new_states[0], self.range_tables[token_idx])
+        logits = paddle.static.nn.fc(new_states[0],
+                                     self.range_tables[token_idx])
         if self.temperature is not None:
             logits = logits / self.temperature
         if self.tanh_constant is not None:
-            logits = self.tanh_constant * fluid.layers.tanh(logits)
+            logits = self.tanh_constant * paddle.tanh(logits)
         return logits, output, new_states
 
     def _create_parameter(self):
-        self.g_emb = fluid.layers.create_parameter(
+        self.g_emb = paddle.static.create_parameter(
             name='emb_g',
             shape=(self.controller_batch_size, self.hidden_size),
             dtype='float32',
@@ -120,12 +121,12 @@ class LSTM(RLBaseController):
             logits, output, states = self._lstm(
                 inputs, hidden, cell, token_idx=idx)
             hidden, cell = np.squeeze(states)
-            probs = fluid.layers.softmax(logits, axis=1)
+            probs = paddle.nn.functional.softmax(logits, axis=1)
             if is_inference:
-                action = fluid.layers.argmax(probs, axis=1)
+                action = paddle.argmax(probs, axis=1)
             else:
                 if init_actions:
-                    action = fluid.layers.slice(
+                    action = paddle.slice(
                         init_actions,
                         axes=[1],
                         starts=[idx],
@@ -135,51 +136,48 @@ class LSTM(RLBaseController):
                 else:
                     action = fluid.layers.sampling_id(probs)
                 actions.append(action)
-                log_prob = fluid.layers.softmax_with_cross_entropy(
+                log_prob = paddle.nn.functional.softmax_with_cross_entropy(
                     logits,
-                    fluid.layers.reshape(
-                        action, shape=[fluid.layers.shape(action), 1]),
+                    paddle.reshape(
+                        action, shape=[paddle.shape(action), 1]),
                     axis=1)
                 sample_log_probs.append(log_prob)
 
-                entropy = log_prob * fluid.layers.exp(-1 * log_prob)
+                entropy = log_prob * paddle.exp(-1 * log_prob)
                 entropy.stop_gradient = True
                 entropies.append(entropy)
 
-                action_emb = fluid.layers.cast(action, dtype=np.int64)
-                inputs = fluid.embedding(
+                action_emb = paddle.cast(action, dtype=np.int64)
+                inputs = paddle.static.nn.embedding(
                     action_emb,
                     size=(self.max_range_table, self.hidden_size),
-                    param_attr=fluid.ParamAttr(
+                    param_attr=paddle.ParamAttr(
                         name='emb_w', initializer=uniform_initializer(1.0)))
 
-            self.sample_log_probs = fluid.layers.concat(
-                sample_log_probs, axis=0)
+            self.sample_log_probs = paddle.concat(sample_log_probs, axis=0)
 
-            entropies = fluid.layers.stack(entropies)
+            entropies = paddle.stack(entropies)
             self.sample_entropies = fluid.layers.reduce_sum(entropies)
 
         return actions
 
     def _build_program(self, is_inference=False):
-        self.pred_program = fluid.Program()
-        self.learn_program = fluid.Program()
-        with fluid.program_guard(self.pred_program):
-            self.g_emb = fluid.layers.create_parameter(
+        self.pred_program = paddle.static.Program()
+        self.learn_program = paddle.static.Program()
+        with paddle.static.program_guard(self.pred_program):
+            self.g_emb = paddle.static.create_parameter(
                 name='emb_g',
                 shape=(self.controller_batch_size, self.hidden_size),
                 dtype='float32',
                 default_initializer=uniform_initializer(1.0))
-            fluid.layers.assign(
-                fluid.layers.uniform_random(shape=self.g_emb.shape),
-                self.g_emb)
+            paddle.assign(
+                fluid.layers.uniform_random(shape=self.g_emb.shape), self.g_emb)
             hidden = fluid.data(name='hidden', shape=[None, self.hidden_size])
             cell = fluid.data(name='cell', shape=[None, self.hidden_size])
-            self.tokens = self._network(
-                hidden, cell, is_inference=is_inference)
+            self.tokens = self._network(hidden, cell, is_inference=is_inference)
 
-        with fluid.program_guard(self.learn_program):
+        with paddle.static.program_guard(self.learn_program):
            hidden = fluid.data(name='hidden', shape=[None, self.hidden_size])
            cell = fluid.data(name='cell', shape=[None, self.hidden_size])
            init_actions = fluid.data(
@@ -197,18 +195,18 @@ class LSTM(RLBaseController):
             self.sample_log_probs = fluid.layers.reduce_sum(
                 self.sample_log_probs)
 
-            fluid.layers.assign(self.baseline - (1.0 - self.decay) *
-                                (self.baseline - self.rewards), self.baseline)
+            paddle.assign(self.baseline - (1.0 - self.decay) *
+                          (self.baseline - self.rewards), self.baseline)
             self.loss = self.sample_log_probs * (self.rewards - self.baseline)
             clip = fluid.clip.GradientClipByNorm(clip_norm=5.0)
             if self.decay_steps is not None:
-                lr = fluid.layers.exponential_decay(
-                    self.controller_lr,
-                    decay_steps=self.decay_steps,
-                    decay_rate=self.decay_rate)
+                lr = paddle.optimizer.lr.ExponentialDecay(
+                    learning_rate=self.controller_lr,
+                    gamma=self.decay_rate,
+                    verbose=False)
             else:
                 lr = self.controller_lr
-            optimizer = fluid.optimizer.Adam(learning_rate=lr, grad_clip=clip)
+            optimizer = paddle.optimizer.Adam(learning_rate=lr, grad_clip=clip)
             optimizer.minimize(self.loss)
 
     def _create_input(self, is_test=True, actual_rewards=None):
diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py
index 4e4e56c67ecb233232502447d4359071adb6ea98..bb9e8f67a5feaad18888089d1ebbcb87fc69650c 100644
--- a/paddleslim/prune/auto_pruner.py
+++ b/paddleslim/prune/auto_pruner.py
@@ -184,10 +184,10 @@ class AutoPruner(object):
         Prune program with latest tokens generated by controller.
 
         Args:
-            program(fluid.Program): The program to be pruned.
+            program(paddle.static.Program): The program to be pruned.
 
         Returns:
-            paddle.fluid.Program: The pruned program.
+            paddle.static.Program: The pruned program.
         """
         self._current_ratios = self._next_ratios()
         pruned_program, self._param_backup, _ = self._pruner.prune(
diff --git a/paddleslim/prune/group_param.py b/paddleslim/prune/group_param.py
index 61077c2b5db88dd68e1dc0ca7b512c26f5cc6eeb..34533b250a859c5ecf27d03eb02e87bcfa65dff5 100644
--- a/paddleslim/prune/group_param.py
+++ b/paddleslim/prune/group_param.py
@@ -42,7 +42,7 @@ def collect_convs(params, graph, visited={}):
     Args:
       params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
-      graph(paddle.fluid.Program | GraphWrapper): The graph used to search the groups.
+      graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
 
     Returns:
       list<list<tuple>>: The groups.
diff --git a/paddleslim/prune/prune_io.py b/paddleslim/prune/prune_io.py
index 901b22a1cab206bd4f75409d52435b3857ea8e30..463b661a73899e76f4d3d9fd1f78cee1a78a25b9 100644
--- a/paddleslim/prune/prune_io.py
+++ b/paddleslim/prune/prune_io.py
@@ -17,7 +17,7 @@ def save_model(exe, graph, dirname):
     Save weights of model and information of shapes into filesystem.
 
     Args:
-        exe(paddle.fluid.Executor): The executor used to save model.
+        exe(paddle.static.Executor): The executor used to save model.
         graph(Program|Graph): The graph to be saved.
         dirname(str): The directory that the model saved into.
     """
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index dc8c83e2507fba6de4376dfe940efeda871c93e5..9c15cb10730e9facec195b1dfdadd2b5bc59bef6 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -64,7 +64,7 @@ class Pruner():
 
         Args:
-            program(fluid.Program): The program to be pruned.
+            program(paddle.static.Program): The program to be pruned.
             scope(fluid.Scope): The scope storing paramaters to be pruned.
             params(list): A list of parameter names to be pruned.
             ratios(list): A list of ratios to be used to pruning parameters.
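
The docstring updates above reflect how Pruner is driven from paddle.static code. A hedged usage sketch (the conv layer is illustrative, and the parameter name assumes fluid's default naming for a layer called `conv1`):

    import paddle
    from paddleslim.prune import Pruner

    paddle.enable_static()
    main = paddle.static.Program()
    startup = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data(name='x', shape=[None, 3, 32, 32], dtype='float32')
        y = paddle.static.nn.conv2d(x, num_filters=8, filter_size=3, name='conv1')

    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    exe.run(startup)

    pruner = Pruner()
    pruned_program, _, _ = pruner.prune(
        main,
        paddle.static.global_scope(),
        params=['conv1.w_0'],  # assumed default weight name for `conv1`
        ratios=[0.5],          # drop half of conv1's filters
        place=place)
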
diff --git a/paddleslim/prune/sensitive.py b/paddleslim/prune/sensitive.py
index 20a19be2d29647b4f4cd01607213bfe43d493d54..ce220e9a47a053793da24df8391e3260087510df 100644
--- a/paddleslim/prune/sensitive.py
+++ b/paddleslim/prune/sensitive.py
@@ -57,10 +57,10 @@ def sensitivity(program,
 
     Args:
-        program(paddle.fluid.Program): The program to be analysised.
-        place(fluid.CPUPlace | fluid.CUDAPlace): The device place of filter parameters.
+        program(paddle.static.Program): The program to be analyzed.
+        place(paddle.CPUPlace | paddle.CUDAPlace): The device place of filter parameters.
         param_names(list): The parameter names of convolutions to be analysised.
-        eval_func(function): The callback function used to evaluate the model. It should accept a instance of `paddle.fluid.Program` as argument and return a score on test dataset.
+        eval_func(function): The callback function used to evaluate the model. It should accept an instance of `paddle.static.Program` as an argument and return a score on the test dataset.
         sensitivities_file(str): The file to save the sensitivities. It will append the latest computed sensitivities into the file. And the sensitivities in the file would not be computed again. This file can be loaded by `pickle` library.
         pruned_ratios(list): The ratios to be pruned. default: ``[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]``.
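
To close the loop on the sensitivity() docstring above, a hedged sketch of a caller: eval_func receives a paddle.static.Program and returns a scalar score. `val_program` and the parameter name are placeholders assumed to exist in the surrounding code:

    import paddle
    from paddleslim.prune import sensitivity

    def eval_func(program):
        # run `program` over the validation set and return e.g. top-1 accuracy
        return 0.0  # placeholder score

    sens = sensitivity(
        val_program,              # a paddle.static.Program built elsewhere
        paddle.CPUPlace(),
        ['conv1.w_0'],            # placeholder parameter name
        eval_func,
        sensitivities_file='sensitivities.pickle',
        pruned_ratios=[0.1, 0.2, 0.3])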