未验证 提交 d8b971d6 编写于 作者: Y yukavio 提交者: GitHub

Fix prune ci (#512)

* remove sensitive_pruner module

* migrate prune module

* fix some bug for migrating the prune module to 2.0-rc

* update some 1.8 api

* change num_worker from 16 to 1

* migrate common and core module

* fix a api in rl controller

* fix dataloader
上级 74511818
......@@ -5,6 +5,7 @@ import functools
import numpy as np
import paddle
from PIL import Image, ImageEnhance
from paddle.io import Dataset
random.seed(0)
np.random.seed(0)
......@@ -194,3 +195,54 @@ def val(data_dir=DATA_DIR):
def test(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'test_list.txt')
return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
class ImageNetDataset(Dataset):
def __init__(self, data_dir=DATA_DIR, mode='train'):
super(ImageNetDataset, self).__init__()
train_file_list = os.path.join(data_dir, 'train_list.txt')
val_file_list = os.path.join(data_dir, 'val_list.txt')
test_file_list = os.path.join(data_dir, 'test_list.txt')
self.mode = mode
if mode == 'train':
with open(train_file_list) as flist:
full_lines = [line.strip() for line in flist]
np.random.shuffle(full_lines)
if os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(
trainer_id + 1) * per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines,
len(lines), len(full_lines)))
else:
lines = full_lines
self.data = [line.split() for line in lines]
else:
with open(val_file_list) as flist:
lines = [line.strip() for line in flist]
self.data = [line.split() for line in lines]
def __getitem__(self, index):
sample = self.data[index]
data_path = os.path.join(DATA_DIR, sample[0])
if self.mode == 'train':
data, label = process_image(
[data_path, sample[1]],
mode='train',
color_jitter=False,
rotate=False)
if self.mode == 'val':
data, label = process_image(
[data_path, sample[1]],
mode='val',
color_jitter=False,
rotate=False)
return data, np.array([label]).astype('int64')
def __len__(self):
return len(self.data)
......@@ -34,13 +34,12 @@ def eval(args):
train_reader = None
test_reader = None
if args.data == "mnist":
val_reader = paddle.dataset.mnist.test()
val_dataset = paddle.vision.datasets.MNIST(mode='test')
class_dim = 10
image_shape = "1,28,28"
elif args.data == "imagenet":
import imagenet_reader as reader
train_reader = reader.train()
val_reader = reader.val()
val_dataset = reader.ImageNetDataset(mode='val')
class_dim = 1000
image_shape = "3,224,224"
else:
......@@ -61,14 +60,13 @@ def eval(args):
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
valid_loader = paddle.io.DataLoader.from_generator(
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
feed_list=[image, label],
capacity=64,
use_double_buffer=True,
iterable=True)
valid_loader.set_sample_list_generator(val_reader, place)
drop_last=False,
batch_size=args.batch_size,
shuffle=False)
load_model(exe, val_program, args.model_path)
......
......@@ -79,8 +79,8 @@ def piecewise_decay(args):
def cosine_decay(args):
step = int(math.ceil(float(args.total_images) / args.batch_size))
learning_rate = paddle.optimizer.lr.cosine_decay(
learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs)
learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=args.lr, T_max=args.num_epochs)
optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
......@@ -99,14 +99,14 @@ def compress(args):
train_reader = None
test_reader = None
if args.data == "mnist":
train_reader = paddle.dataset.mnist.train()
val_reader = paddle.dataset.mnist.test()
train_dataset = paddle.vision.datasets.MNIST(mode='train')
val_dataset = paddle.vision.datasets.MNIST(mode='test')
class_dim = 10
image_shape = "1,28,28"
elif args.data == "imagenet":
import imagenet_reader as reader
train_reader = reader.train()
val_reader = reader.val()
train_dataset = reader.ImageNetDataset(mode='train')
val_dataset = reader.ImageNetDataset(mode='val')
class_dim = 1000
image_shape = "3,224,224"
else:
......@@ -143,22 +143,23 @@ def compress(args):
paddle.fluid.io.load_vars(
exe, args.pretrained_model, predicate=if_exist)
val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
train_reader = paddle.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
train_loader = paddle.io.DataLoader.from_generator(
train_loader = paddle.io.DataLoader(
train_dataset,
places=places,
feed_list=[image, label],
capacity=64,
use_double_buffer=True,
iterable=True)
valid_loader = paddle.io.DataLoader.from_generator(
drop_last=True,
batch_size=args.batch_size,
shuffle=True,
use_shared_memory=False,
num_workers=16)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
feed_list=[image, label],
capacity=64,
use_double_buffer=True,
iterable=True)
train_loader.set_sample_list_generator(train_reader, places)
valid_loader.set_sample_list_generator(val_reader, place)
drop_last=False,
use_shared_memory=False,
batch_size=args.batch_size,
shuffle=False)
def test(epoch, program):
acc_top1_ns = []
......@@ -237,8 +238,8 @@ def compress(args):
if args.save_inference:
infer_model_path = os.path.join(args.model_path, "infer_models",
str(i))
paddle.fluid.io.save_inference_model(infer_model_path, ["image"],
[out], exe, pruned_val_program)
paddle.static.save_inference_model(infer_model_path, ["image"],
[out], exe, pruned_val_program)
_logger.info("Saved inference model into [{}]".format(
infer_model_path))
......
......@@ -33,13 +33,12 @@ model_list = [m for m in dir(models) if "__" not in m]
def compress(args):
test_reader = None
if args.data == "mnist":
import paddle.dataset.mnist as reader
val_reader = reader.test()
val_dataset = paddle.vision.datasets.MNIST(mode='test')
class_dim = 10
image_shape = "1,28,28"
elif args.data == "imagenet":
import imagenet_reader as reader
val_reader = reader.val()
val_dataset = reader.ImageNetDataset(mode='val')
class_dim = 1000
image_shape = "3,224,224"
else:
......@@ -70,14 +69,13 @@ def compress(args):
paddle.fluid.io.load_vars(
exe, args.pretrained_model, predicate=if_exist)
val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
valid_loader = paddle.io.DataLoader.from_generator(
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
feed_list=[image, label],
capacity=64,
use_double_buffer=True,
iterable=True)
valid_loader.set_sample_list_generator(val_reader, place)
drop_last=False,
batch_size=args.batch_size,
shuffle=False)
def test(program):
acc_top1_ns = []
......
......@@ -70,7 +70,7 @@ class VarCollector(object):
scope=None):
self.program = program
self.var_names = var_names
self.scope = fluid.global_scope() if scope is None else scope
self.scope = paddle.static.global_scope() if scope is None else scope
self.use_ema = use_ema
self.set_up()
if self.use_ema:
......@@ -104,8 +104,8 @@ class VarCollector(object):
def run(self, reader, exe, step=None, loss_name=None):
if not hasattr(self.program, '_program'):
# Compile the native program to speed up
program = fluid.CompiledProgram(self.program).with_data_parallel(
loss_name=loss_name)
program = paddle.static.CompiledProgram(
self.program).with_data_parallel(loss_name=loss_name)
for idx, data in enumerate(reader):
vars_np = exe.run(program=program,
......@@ -122,17 +122,17 @@ class VarCollector(object):
def abs_max_run(self, reader, exe, step=None, loss_name=None):
fetch_list = []
with fluid.program_guard(self.program):
with paddle.static.program_guard(self.program):
for act_name in self.real_names:
act = self.program.global_block().var(act_name)
act = fluid.layers.reduce_max(
fluid.layers.abs(act), name=act_name + "_reduced")
paddle.abs(act), name=act_name + "_reduced")
fetch_list.append(act_name + "_reduced.tmp_0")
if not hasattr(self.program, '_program'):
# Compile the native program to speed up
program = fluid.CompiledProgram(self.program).with_data_parallel(
loss_name=loss_name)
program = paddle.static.CompiledProgram(
self.program).with_data_parallel(loss_name=loss_name)
for idx, data in enumerate(reader):
vars_np = exe.run(program=program, feed=data, fetch_list=fetch_list)
vars_np = [np.max(var) for var in vars_np]
......
......@@ -16,6 +16,7 @@
import copy
import math
import numpy as np
import paddle
import paddle.fluid as fluid
__all__ = ['EvolutionaryController', 'RLBaseController']
......@@ -72,11 +73,11 @@ class RLBaseController(object):
def get_params(self, program):
var_dict = {}
for var in program.global_block().all_parameters():
var_dict[var.name] = np.array(fluid.global_scope().find_var(
var_dict[var.name] = np.array(paddle.static.global_scope().find_var(
var.name).get_tensor())
return var_dict
def set_params(self, program, params_dict, place):
for var in program.global_block().all_parameters():
fluid.global_scope().find_var(var.name).get_tensor().set(
paddle.static.global_scope().find_var(var.name).get_tensor().set(
params_dict[var.name], place)
......@@ -15,6 +15,7 @@
import numpy as np
import parl
from parl import layers
import paddle
from paddle import fluid
from ..utils import RLCONTROLLER, action_mapping
from ...controller import RLBaseController
......@@ -37,15 +38,15 @@ class DDPGAgent(parl.Agent):
self.alg.sync_target(decay=0)
def build_program(self):
self.pred_program = fluid.Program()
self.learn_program = fluid.Program()
self.pred_program = paddle.static.Program()
self.learn_program = paddle.static.Program()
with fluid.program_guard(self.pred_program):
with paddle.static.program_guard(self.pred_program):
obs = fluid.data(
name='obs', shape=[None, self.obs_dim], dtype='float32')
self.pred_act = self.alg.predict(obs)
with fluid.program_guard(self.learn_program):
with paddle.static.program_guard(self.learn_program):
obs = fluid.data(
name='obs', shape=[None, self.obs_dim], dtype='float32')
act = fluid.data(
......@@ -88,8 +89,7 @@ class DDPG(RLBaseController):
self.obs_dim = kwargs.get('obs_dim')
self.model = kwargs.get(
'model') if 'model' in kwargs else default_ddpg_model
self.actor_lr = kwargs.get(
'actor_lr') if 'actor_lr' in kwargs else 1e-4
self.actor_lr = kwargs.get('actor_lr') if 'actor_lr' in kwargs else 1e-4
self.critic_lr = kwargs.get(
'critic_lr') if 'critic_lr' in kwargs else 1e-3
self.gamma = kwargs.get('gamma') if 'gamma' in kwargs else 0.99
......@@ -103,7 +103,7 @@ class DDPG(RLBaseController):
self.actions_noise = kwargs.get(
'actions_noise') if 'actions_noise' in kwargs else default_noise
self.action_dist = 0.0
self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()
self.place = paddle.CUDAPlace(0) if self.use_gpu else paddle.CPUPlace()
model = self.model(self.act_dim)
......
......@@ -15,6 +15,7 @@
import math
import logging
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from paddle.fluid.layers import RNNCell, LSTMCell, rnn
......@@ -39,8 +40,7 @@ class lstm_cell(RNNCell):
bias_attr = ParamAttr(initializer=uniform_initializer(
1.0 / math.sqrt(hidden_size)))
for i in range(num_layers):
self.lstm_cells.append(
LSTMCell(hidden_size, param_attr, bias_attr))
self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr))
def call(self, inputs, states):
new_states = []
......@@ -75,26 +75,27 @@ class LSTM(RLBaseController):
self._create_parameter()
self._build_program()
self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()
self.exe = fluid.Executor(self.place)
self.exe.run(fluid.default_startup_program())
self.place = paddle.CUDAPlace(0) if self.use_gpu else paddle.CPUPlace()
self.exe = paddle.static.Executor(self.place)
self.exe.run(paddle.static.default_startup_program())
self.param_dict = self.get_params(self.learn_program)
def _lstm(self, inputs, hidden, cell, token_idx):
cells = lstm_cell(self.lstm_num_layers, self.hidden_size)
output, new_states = cells.call(inputs, states=([[hidden, cell]]))
logits = fluid.layers.fc(new_states[0], self.range_tables[token_idx])
logits = paddle.static.nn.fc(new_states[0],
self.range_tables[token_idx])
if self.temperature is not None:
logits = logits / self.temperature
if self.tanh_constant is not None:
logits = self.tanh_constant * fluid.layers.tanh(logits)
logits = self.tanh_constant * paddle.tanh(logits)
return logits, output, new_states
def _create_parameter(self):
self.g_emb = fluid.layers.create_parameter(
self.g_emb = paddle.static.create_parameter(
name='emb_g',
shape=(self.controller_batch_size, self.hidden_size),
dtype='float32',
......@@ -120,12 +121,12 @@ class LSTM(RLBaseController):
logits, output, states = self._lstm(
inputs, hidden, cell, token_idx=idx)
hidden, cell = np.squeeze(states)
probs = fluid.layers.softmax(logits, axis=1)
probs = paddle.nn.functional.softmax(logits, axis=1)
if is_inference:
action = fluid.layers.argmax(probs, axis=1)
action = paddle.argmax(probs, axis=1)
else:
if init_actions:
action = fluid.layers.slice(
action = paddle.slice(
init_actions,
axes=[1],
starts=[idx],
......@@ -135,51 +136,48 @@ class LSTM(RLBaseController):
else:
action = fluid.layers.sampling_id(probs)
actions.append(action)
log_prob = fluid.layers.softmax_with_cross_entropy(
log_prob = paddle.nn.functional.softmax_with_cross_entropy(
logits,
fluid.layers.reshape(
action, shape=[fluid.layers.shape(action), 1]),
paddle.reshape(
action, shape=[paddle.shape(action), 1]),
axis=1)
sample_log_probs.append(log_prob)
entropy = log_prob * fluid.layers.exp(-1 * log_prob)
entropy = log_prob * paddle.exp(-1 * log_prob)
entropy.stop_gradient = True
entropies.append(entropy)
action_emb = fluid.layers.cast(action, dtype=np.int64)
inputs = fluid.embedding(
action_emb = paddle.cast(action, dtype=np.int64)
inputs = paddle.static.nn.embedding(
action_emb,
size=(self.max_range_table, self.hidden_size),
param_attr=fluid.ParamAttr(
param_attr=paddle.ParamAttr(
name='emb_w', initializer=uniform_initializer(1.0)))
self.sample_log_probs = fluid.layers.concat(
sample_log_probs, axis=0)
self.sample_log_probs = paddle.concat(sample_log_probs, axis=0)
entropies = fluid.layers.stack(entropies)
entropies = paddle.stack(entropies)
self.sample_entropies = fluid.layers.reduce_sum(entropies)
return actions
def _build_program(self, is_inference=False):
self.pred_program = fluid.Program()
self.learn_program = fluid.Program()
with fluid.program_guard(self.pred_program):
self.g_emb = fluid.layers.create_parameter(
self.pred_program = paddle.static.Program()
self.learn_program = paddle.static.Program()
with paddle.static.program_guard(self.pred_program):
self.g_emb = paddle.static.create_parameter(
name='emb_g',
shape=(self.controller_batch_size, self.hidden_size),
dtype='float32',
default_initializer=uniform_initializer(1.0))
fluid.layers.assign(
fluid.layers.uniform_random(shape=self.g_emb.shape),
self.g_emb)
paddle.assign(
fluid.layers.uniform_random(shape=self.g_emb.shape), self.g_emb)
hidden = fluid.data(name='hidden', shape=[None, self.hidden_size])
cell = fluid.data(name='cell', shape=[None, self.hidden_size])
self.tokens = self._network(
hidden, cell, is_inference=is_inference)
self.tokens = self._network(hidden, cell, is_inference=is_inference)
with fluid.program_guard(self.learn_program):
with paddle.static.program_guard(self.learn_program):
hidden = fluid.data(name='hidden', shape=[None, self.hidden_size])
cell = fluid.data(name='cell', shape=[None, self.hidden_size])
init_actions = fluid.data(
......@@ -197,18 +195,18 @@ class LSTM(RLBaseController):
self.sample_log_probs = fluid.layers.reduce_sum(
self.sample_log_probs)
fluid.layers.assign(self.baseline - (1.0 - self.decay) *
(self.baseline - self.rewards), self.baseline)
paddle.assign(self.baseline - (1.0 - self.decay) *
(self.baseline - self.rewards), self.baseline)
self.loss = self.sample_log_probs * (self.rewards - self.baseline)
clip = fluid.clip.GradientClipByNorm(clip_norm=5.0)
if self.decay_steps is not None:
lr = fluid.layers.exponential_decay(
self.controller_lr,
decay_steps=self.decay_steps,
decay_rate=self.decay_rate)
lr = paddle.optimizer.lr.ExponentialDecay(
learning_rate=self.controller_lr,
gamma=self.decay_rate,
verbose=False)
else:
lr = self.controller_lr
optimizer = fluid.optimizer.Adam(learning_rate=lr, grad_clip=clip)
optimizer = paddle.optimizer.Adam(learning_rate=lr, grad_clip=clip)
optimizer.minimize(self.loss)
def _create_input(self, is_test=True, actual_rewards=None):
......
......@@ -184,10 +184,10 @@ class AutoPruner(object):
Prune program with latest tokens generated by controller.
Args:
program(fluid.Program): The program to be pruned.
program(paddle.static.Program): The program to be pruned.
Returns:
paddle.fluid.Program: The pruned program.
paddle.static.Program: The pruned program.
"""
self._current_ratios = self._next_ratios()
pruned_program, self._param_backup, _ = self._pruner.prune(
......
......@@ -42,7 +42,7 @@ def collect_convs(params, graph, visited={}):
Args:
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.fluid.Program | GraphWrapper): The graph used to search the groups.
graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
Returns:
list<list<tuple>>: The groups.
......
......@@ -17,7 +17,7 @@ def save_model(exe, graph, dirname):
Save weights of model and information of shapes into filesystem.
Args:
exe(paddle.fluid.Executor): The executor used to save model.
exe(paddle.static.Executor): The executor used to save model.
graph(Program|Graph): The graph to be saved.
dirname(str): The directory that the model saved into.
"""
......
......@@ -64,7 +64,7 @@ class Pruner():
Args:
program(fluid.Program): The program to be pruned.
program(paddle.static.Program): The program to be pruned.
scope(fluid.Scope): The scope storing paramaters to be pruned.
params(list<str>): A list of parameter names to be pruned.
ratios(list<float>): A list of ratios to be used to pruning parameters.
......
......@@ -57,10 +57,10 @@ def sensitivity(program,
Args:
program(paddle.fluid.Program): The program to be analysised.
place(fluid.CPUPlace | fluid.CUDAPlace): The device place of filter parameters.
program(paddle.static.Program): The program to be analysised.
place(paddle.CPUPlace | paddle.CUDAPlace): The device place of filter parameters.
param_names(list): The parameter names of convolutions to be analysised.
eval_func(function): The callback function used to evaluate the model. It should accept a instance of `paddle.fluid.Program` as argument and return a score on test dataset.
eval_func(function): The callback function used to evaluate the model. It should accept a instance of `paddle.static.Program` as argument and return a score on test dataset.
sensitivities_file(str): The file to save the sensitivities. It will append the latest computed sensitivities into the file. And the sensitivities in the file would not be computed again. This file can be loaded by `pickle` library.
pruned_ratios(list): The ratios to be pruned. default: ``[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]``.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册