未验证 提交 d8b971d6 编写于 作者: Y yukavio 提交者: GitHub

Fix prune ci (#512)

* remove sensitive_pruner module

* migrate prune module

* fix some bug for migrating the prune module to 2.0-rc

* update some 1.8 api

* change num_worker from 16 to 1

* migrate common and core module

* fix a api in rl controller

* fix dataloader
上级 74511818
...@@ -5,6 +5,7 @@ import functools ...@@ -5,6 +5,7 @@ import functools
import numpy as np import numpy as np
import paddle import paddle
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
from paddle.io import Dataset
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
...@@ -194,3 +195,54 @@ def val(data_dir=DATA_DIR): ...@@ -194,3 +195,54 @@ def val(data_dir=DATA_DIR):
def test(data_dir=DATA_DIR): def test(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'test_list.txt') file_list = os.path.join(data_dir, 'test_list.txt')
return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir) return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
class ImageNetDataset(Dataset):
def __init__(self, data_dir=DATA_DIR, mode='train'):
super(ImageNetDataset, self).__init__()
train_file_list = os.path.join(data_dir, 'train_list.txt')
val_file_list = os.path.join(data_dir, 'val_list.txt')
test_file_list = os.path.join(data_dir, 'test_list.txt')
self.mode = mode
if mode == 'train':
with open(train_file_list) as flist:
full_lines = [line.strip() for line in flist]
np.random.shuffle(full_lines)
if os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(
trainer_id + 1) * per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines,
len(lines), len(full_lines)))
else:
lines = full_lines
self.data = [line.split() for line in lines]
else:
with open(val_file_list) as flist:
lines = [line.strip() for line in flist]
self.data = [line.split() for line in lines]
def __getitem__(self, index):
sample = self.data[index]
data_path = os.path.join(DATA_DIR, sample[0])
if self.mode == 'train':
data, label = process_image(
[data_path, sample[1]],
mode='train',
color_jitter=False,
rotate=False)
if self.mode == 'val':
data, label = process_image(
[data_path, sample[1]],
mode='val',
color_jitter=False,
rotate=False)
return data, np.array([label]).astype('int64')
def __len__(self):
return len(self.data)
...@@ -34,13 +34,12 @@ def eval(args): ...@@ -34,13 +34,12 @@ def eval(args):
train_reader = None train_reader = None
test_reader = None test_reader = None
if args.data == "mnist": if args.data == "mnist":
val_reader = paddle.dataset.mnist.test() val_dataset = paddle.vision.datasets.MNIST(mode='test')
class_dim = 10 class_dim = 10
image_shape = "1,28,28" image_shape = "1,28,28"
elif args.data == "imagenet": elif args.data == "imagenet":
import imagenet_reader as reader import imagenet_reader as reader
train_reader = reader.train() val_dataset = reader.ImageNetDataset(mode='val')
val_reader = reader.val()
class_dim = 1000 class_dim = 1000
image_shape = "3,224,224" image_shape = "3,224,224"
else: else:
...@@ -61,14 +60,13 @@ def eval(args): ...@@ -61,14 +60,13 @@ def eval(args):
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
val_reader = paddle.batch(val_reader, batch_size=args.batch_size) valid_loader = paddle.io.DataLoader(
val_dataset,
valid_loader = paddle.io.DataLoader.from_generator( places=place,
feed_list=[image, label], feed_list=[image, label],
capacity=64, drop_last=False,
use_double_buffer=True, batch_size=args.batch_size,
iterable=True) shuffle=False)
valid_loader.set_sample_list_generator(val_reader, place)
load_model(exe, val_program, args.model_path) load_model(exe, val_program, args.model_path)
......
...@@ -79,8 +79,8 @@ def piecewise_decay(args): ...@@ -79,8 +79,8 @@ def piecewise_decay(args):
def cosine_decay(args): def cosine_decay(args):
step = int(math.ceil(float(args.total_images) / args.batch_size)) step = int(math.ceil(float(args.total_images) / args.batch_size))
learning_rate = paddle.optimizer.lr.cosine_decay( learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs) learning_rate=args.lr, T_max=args.num_epochs)
optimizer = paddle.optimizer.Momentum( optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate, learning_rate=learning_rate,
momentum=args.momentum_rate, momentum=args.momentum_rate,
...@@ -99,14 +99,14 @@ def compress(args): ...@@ -99,14 +99,14 @@ def compress(args):
train_reader = None train_reader = None
test_reader = None test_reader = None
if args.data == "mnist": if args.data == "mnist":
train_reader = paddle.dataset.mnist.train() train_dataset = paddle.vision.datasets.MNIST(mode='train')
val_reader = paddle.dataset.mnist.test() val_dataset = paddle.vision.datasets.MNIST(mode='test')
class_dim = 10 class_dim = 10
image_shape = "1,28,28" image_shape = "1,28,28"
elif args.data == "imagenet": elif args.data == "imagenet":
import imagenet_reader as reader import imagenet_reader as reader
train_reader = reader.train() train_dataset = reader.ImageNetDataset(mode='train')
val_reader = reader.val() val_dataset = reader.ImageNetDataset(mode='val')
class_dim = 1000 class_dim = 1000
image_shape = "3,224,224" image_shape = "3,224,224"
else: else:
...@@ -143,22 +143,23 @@ def compress(args): ...@@ -143,22 +143,23 @@ def compress(args):
paddle.fluid.io.load_vars( paddle.fluid.io.load_vars(
exe, args.pretrained_model, predicate=if_exist) exe, args.pretrained_model, predicate=if_exist)
val_reader = paddle.batch(val_reader, batch_size=args.batch_size) train_loader = paddle.io.DataLoader(
train_reader = paddle.batch( train_dataset,
train_reader, batch_size=args.batch_size, drop_last=True) places=places,
train_loader = paddle.io.DataLoader.from_generator(
feed_list=[image, label], feed_list=[image, label],
capacity=64, drop_last=True,
use_double_buffer=True, batch_size=args.batch_size,
iterable=True) shuffle=True,
valid_loader = paddle.io.DataLoader.from_generator( use_shared_memory=False,
num_workers=16)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
feed_list=[image, label], feed_list=[image, label],
capacity=64, drop_last=False,
use_double_buffer=True, use_shared_memory=False,
iterable=True) batch_size=args.batch_size,
train_loader.set_sample_list_generator(train_reader, places) shuffle=False)
valid_loader.set_sample_list_generator(val_reader, place)
def test(epoch, program): def test(epoch, program):
acc_top1_ns = [] acc_top1_ns = []
...@@ -237,8 +238,8 @@ def compress(args): ...@@ -237,8 +238,8 @@ def compress(args):
if args.save_inference: if args.save_inference:
infer_model_path = os.path.join(args.model_path, "infer_models", infer_model_path = os.path.join(args.model_path, "infer_models",
str(i)) str(i))
paddle.fluid.io.save_inference_model(infer_model_path, ["image"], paddle.static.save_inference_model(infer_model_path, ["image"],
[out], exe, pruned_val_program) [out], exe, pruned_val_program)
_logger.info("Saved inference model into [{}]".format( _logger.info("Saved inference model into [{}]".format(
infer_model_path)) infer_model_path))
......
...@@ -33,13 +33,12 @@ model_list = [m for m in dir(models) if "__" not in m] ...@@ -33,13 +33,12 @@ model_list = [m for m in dir(models) if "__" not in m]
def compress(args): def compress(args):
test_reader = None test_reader = None
if args.data == "mnist": if args.data == "mnist":
import paddle.dataset.mnist as reader val_dataset = paddle.vision.datasets.MNIST(mode='test')
val_reader = reader.test()
class_dim = 10 class_dim = 10
image_shape = "1,28,28" image_shape = "1,28,28"
elif args.data == "imagenet": elif args.data == "imagenet":
import imagenet_reader as reader import imagenet_reader as reader
val_reader = reader.val() val_dataset = reader.ImageNetDataset(mode='val')
class_dim = 1000 class_dim = 1000
image_shape = "3,224,224" image_shape = "3,224,224"
else: else:
...@@ -70,14 +69,13 @@ def compress(args): ...@@ -70,14 +69,13 @@ def compress(args):
paddle.fluid.io.load_vars( paddle.fluid.io.load_vars(
exe, args.pretrained_model, predicate=if_exist) exe, args.pretrained_model, predicate=if_exist)
val_reader = paddle.batch(val_reader, batch_size=args.batch_size) valid_loader = paddle.io.DataLoader(
valid_loader = paddle.io.DataLoader.from_generator( val_dataset,
places=place,
feed_list=[image, label], feed_list=[image, label],
capacity=64, drop_last=False,
use_double_buffer=True, batch_size=args.batch_size,
iterable=True) shuffle=False)
valid_loader.set_sample_list_generator(val_reader, place)
def test(program): def test(program):
acc_top1_ns = [] acc_top1_ns = []
......
...@@ -70,7 +70,7 @@ class VarCollector(object): ...@@ -70,7 +70,7 @@ class VarCollector(object):
scope=None): scope=None):
self.program = program self.program = program
self.var_names = var_names self.var_names = var_names
self.scope = fluid.global_scope() if scope is None else scope self.scope = paddle.static.global_scope() if scope is None else scope
self.use_ema = use_ema self.use_ema = use_ema
self.set_up() self.set_up()
if self.use_ema: if self.use_ema:
...@@ -104,8 +104,8 @@ class VarCollector(object): ...@@ -104,8 +104,8 @@ class VarCollector(object):
def run(self, reader, exe, step=None, loss_name=None): def run(self, reader, exe, step=None, loss_name=None):
if not hasattr(self.program, '_program'): if not hasattr(self.program, '_program'):
# Compile the native program to speed up # Compile the native program to speed up
program = fluid.CompiledProgram(self.program).with_data_parallel( program = paddle.static.CompiledProgram(
loss_name=loss_name) self.program).with_data_parallel(loss_name=loss_name)
for idx, data in enumerate(reader): for idx, data in enumerate(reader):
vars_np = exe.run(program=program, vars_np = exe.run(program=program,
...@@ -122,17 +122,17 @@ class VarCollector(object): ...@@ -122,17 +122,17 @@ class VarCollector(object):
def abs_max_run(self, reader, exe, step=None, loss_name=None): def abs_max_run(self, reader, exe, step=None, loss_name=None):
fetch_list = [] fetch_list = []
with fluid.program_guard(self.program): with paddle.static.program_guard(self.program):
for act_name in self.real_names: for act_name in self.real_names:
act = self.program.global_block().var(act_name) act = self.program.global_block().var(act_name)
act = fluid.layers.reduce_max( act = fluid.layers.reduce_max(
fluid.layers.abs(act), name=act_name + "_reduced") paddle.abs(act), name=act_name + "_reduced")
fetch_list.append(act_name + "_reduced.tmp_0") fetch_list.append(act_name + "_reduced.tmp_0")
if not hasattr(self.program, '_program'): if not hasattr(self.program, '_program'):
# Compile the native program to speed up # Compile the native program to speed up
program = fluid.CompiledProgram(self.program).with_data_parallel( program = paddle.static.CompiledProgram(
loss_name=loss_name) self.program).with_data_parallel(loss_name=loss_name)
for idx, data in enumerate(reader): for idx, data in enumerate(reader):
vars_np = exe.run(program=program, feed=data, fetch_list=fetch_list) vars_np = exe.run(program=program, feed=data, fetch_list=fetch_list)
vars_np = [np.max(var) for var in vars_np] vars_np = [np.max(var) for var in vars_np]
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import copy import copy
import math import math
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
__all__ = ['EvolutionaryController', 'RLBaseController'] __all__ = ['EvolutionaryController', 'RLBaseController']
...@@ -72,11 +73,11 @@ class RLBaseController(object): ...@@ -72,11 +73,11 @@ class RLBaseController(object):
def get_params(self, program): def get_params(self, program):
var_dict = {} var_dict = {}
for var in program.global_block().all_parameters(): for var in program.global_block().all_parameters():
var_dict[var.name] = np.array(fluid.global_scope().find_var( var_dict[var.name] = np.array(paddle.static.global_scope().find_var(
var.name).get_tensor()) var.name).get_tensor())
return var_dict return var_dict
def set_params(self, program, params_dict, place): def set_params(self, program, params_dict, place):
for var in program.global_block().all_parameters(): for var in program.global_block().all_parameters():
fluid.global_scope().find_var(var.name).get_tensor().set( paddle.static.global_scope().find_var(var.name).get_tensor().set(
params_dict[var.name], place) params_dict[var.name], place)
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import numpy as np import numpy as np
import parl import parl
from parl import layers from parl import layers
import paddle
from paddle import fluid from paddle import fluid
from ..utils import RLCONTROLLER, action_mapping from ..utils import RLCONTROLLER, action_mapping
from ...controller import RLBaseController from ...controller import RLBaseController
...@@ -37,15 +38,15 @@ class DDPGAgent(parl.Agent): ...@@ -37,15 +38,15 @@ class DDPGAgent(parl.Agent):
self.alg.sync_target(decay=0) self.alg.sync_target(decay=0)
def build_program(self): def build_program(self):
self.pred_program = fluid.Program() self.pred_program = paddle.static.Program()
self.learn_program = fluid.Program() self.learn_program = paddle.static.Program()
with fluid.program_guard(self.pred_program): with paddle.static.program_guard(self.pred_program):
obs = fluid.data( obs = fluid.data(
name='obs', shape=[None, self.obs_dim], dtype='float32') name='obs', shape=[None, self.obs_dim], dtype='float32')
self.pred_act = self.alg.predict(obs) self.pred_act = self.alg.predict(obs)
with fluid.program_guard(self.learn_program): with paddle.static.program_guard(self.learn_program):
obs = fluid.data( obs = fluid.data(
name='obs', shape=[None, self.obs_dim], dtype='float32') name='obs', shape=[None, self.obs_dim], dtype='float32')
act = fluid.data( act = fluid.data(
...@@ -88,8 +89,7 @@ class DDPG(RLBaseController): ...@@ -88,8 +89,7 @@ class DDPG(RLBaseController):
self.obs_dim = kwargs.get('obs_dim') self.obs_dim = kwargs.get('obs_dim')
self.model = kwargs.get( self.model = kwargs.get(
'model') if 'model' in kwargs else default_ddpg_model 'model') if 'model' in kwargs else default_ddpg_model
self.actor_lr = kwargs.get( self.actor_lr = kwargs.get('actor_lr') if 'actor_lr' in kwargs else 1e-4
'actor_lr') if 'actor_lr' in kwargs else 1e-4
self.critic_lr = kwargs.get( self.critic_lr = kwargs.get(
'critic_lr') if 'critic_lr' in kwargs else 1e-3 'critic_lr') if 'critic_lr' in kwargs else 1e-3
self.gamma = kwargs.get('gamma') if 'gamma' in kwargs else 0.99 self.gamma = kwargs.get('gamma') if 'gamma' in kwargs else 0.99
...@@ -103,7 +103,7 @@ class DDPG(RLBaseController): ...@@ -103,7 +103,7 @@ class DDPG(RLBaseController):
self.actions_noise = kwargs.get( self.actions_noise = kwargs.get(
'actions_noise') if 'actions_noise' in kwargs else default_noise 'actions_noise') if 'actions_noise' in kwargs else default_noise
self.action_dist = 0.0 self.action_dist = 0.0
self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace() self.place = paddle.CUDAPlace(0) if self.use_gpu else paddle.CPUPlace()
model = self.model(self.act_dim) model = self.model(self.act_dim)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import math import math
import logging import logging
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import ParamAttr from paddle.fluid import ParamAttr
from paddle.fluid.layers import RNNCell, LSTMCell, rnn from paddle.fluid.layers import RNNCell, LSTMCell, rnn
...@@ -39,8 +40,7 @@ class lstm_cell(RNNCell): ...@@ -39,8 +40,7 @@ class lstm_cell(RNNCell):
bias_attr = ParamAttr(initializer=uniform_initializer( bias_attr = ParamAttr(initializer=uniform_initializer(
1.0 / math.sqrt(hidden_size))) 1.0 / math.sqrt(hidden_size)))
for i in range(num_layers): for i in range(num_layers):
self.lstm_cells.append( self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr))
LSTMCell(hidden_size, param_attr, bias_attr))
def call(self, inputs, states): def call(self, inputs, states):
new_states = [] new_states = []
...@@ -75,26 +75,27 @@ class LSTM(RLBaseController): ...@@ -75,26 +75,27 @@ class LSTM(RLBaseController):
self._create_parameter() self._create_parameter()
self._build_program() self._build_program()
self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace() self.place = paddle.CUDAPlace(0) if self.use_gpu else paddle.CPUPlace()
self.exe = fluid.Executor(self.place) self.exe = paddle.static.Executor(self.place)
self.exe.run(fluid.default_startup_program()) self.exe.run(paddle.static.default_startup_program())
self.param_dict = self.get_params(self.learn_program) self.param_dict = self.get_params(self.learn_program)
def _lstm(self, inputs, hidden, cell, token_idx): def _lstm(self, inputs, hidden, cell, token_idx):
cells = lstm_cell(self.lstm_num_layers, self.hidden_size) cells = lstm_cell(self.lstm_num_layers, self.hidden_size)
output, new_states = cells.call(inputs, states=([[hidden, cell]])) output, new_states = cells.call(inputs, states=([[hidden, cell]]))
logits = fluid.layers.fc(new_states[0], self.range_tables[token_idx]) logits = paddle.static.nn.fc(new_states[0],
self.range_tables[token_idx])
if self.temperature is not None: if self.temperature is not None:
logits = logits / self.temperature logits = logits / self.temperature
if self.tanh_constant is not None: if self.tanh_constant is not None:
logits = self.tanh_constant * fluid.layers.tanh(logits) logits = self.tanh_constant * paddle.tanh(logits)
return logits, output, new_states return logits, output, new_states
def _create_parameter(self): def _create_parameter(self):
self.g_emb = fluid.layers.create_parameter( self.g_emb = paddle.static.create_parameter(
name='emb_g', name='emb_g',
shape=(self.controller_batch_size, self.hidden_size), shape=(self.controller_batch_size, self.hidden_size),
dtype='float32', dtype='float32',
...@@ -120,12 +121,12 @@ class LSTM(RLBaseController): ...@@ -120,12 +121,12 @@ class LSTM(RLBaseController):
logits, output, states = self._lstm( logits, output, states = self._lstm(
inputs, hidden, cell, token_idx=idx) inputs, hidden, cell, token_idx=idx)
hidden, cell = np.squeeze(states) hidden, cell = np.squeeze(states)
probs = fluid.layers.softmax(logits, axis=1) probs = paddle.nn.functional.softmax(logits, axis=1)
if is_inference: if is_inference:
action = fluid.layers.argmax(probs, axis=1) action = paddle.argmax(probs, axis=1)
else: else:
if init_actions: if init_actions:
action = fluid.layers.slice( action = paddle.slice(
init_actions, init_actions,
axes=[1], axes=[1],
starts=[idx], starts=[idx],
...@@ -135,51 +136,48 @@ class LSTM(RLBaseController): ...@@ -135,51 +136,48 @@ class LSTM(RLBaseController):
else: else:
action = fluid.layers.sampling_id(probs) action = fluid.layers.sampling_id(probs)
actions.append(action) actions.append(action)
log_prob = fluid.layers.softmax_with_cross_entropy( log_prob = paddle.nn.functional.softmax_with_cross_entropy(
logits, logits,
fluid.layers.reshape( paddle.reshape(
action, shape=[fluid.layers.shape(action), 1]), action, shape=[paddle.shape(action), 1]),
axis=1) axis=1)
sample_log_probs.append(log_prob) sample_log_probs.append(log_prob)
entropy = log_prob * fluid.layers.exp(-1 * log_prob) entropy = log_prob * paddle.exp(-1 * log_prob)
entropy.stop_gradient = True entropy.stop_gradient = True
entropies.append(entropy) entropies.append(entropy)
action_emb = fluid.layers.cast(action, dtype=np.int64) action_emb = paddle.cast(action, dtype=np.int64)
inputs = fluid.embedding( inputs = paddle.static.nn.embedding(
action_emb, action_emb,
size=(self.max_range_table, self.hidden_size), size=(self.max_range_table, self.hidden_size),
param_attr=fluid.ParamAttr( param_attr=paddle.ParamAttr(
name='emb_w', initializer=uniform_initializer(1.0))) name='emb_w', initializer=uniform_initializer(1.0)))
self.sample_log_probs = fluid.layers.concat( self.sample_log_probs = paddle.concat(sample_log_probs, axis=0)
sample_log_probs, axis=0)
entropies = fluid.layers.stack(entropies) entropies = paddle.stack(entropies)
self.sample_entropies = fluid.layers.reduce_sum(entropies) self.sample_entropies = fluid.layers.reduce_sum(entropies)
return actions return actions
def _build_program(self, is_inference=False): def _build_program(self, is_inference=False):
self.pred_program = fluid.Program() self.pred_program = paddle.static.Program()
self.learn_program = fluid.Program() self.learn_program = paddle.static.Program()
with fluid.program_guard(self.pred_program): with paddle.static.program_guard(self.pred_program):
self.g_emb = fluid.layers.create_parameter( self.g_emb = paddle.static.create_parameter(
name='emb_g', name='emb_g',
shape=(self.controller_batch_size, self.hidden_size), shape=(self.controller_batch_size, self.hidden_size),
dtype='float32', dtype='float32',
default_initializer=uniform_initializer(1.0)) default_initializer=uniform_initializer(1.0))
fluid.layers.assign( paddle.assign(
fluid.layers.uniform_random(shape=self.g_emb.shape), fluid.layers.uniform_random(shape=self.g_emb.shape), self.g_emb)
self.g_emb)
hidden = fluid.data(name='hidden', shape=[None, self.hidden_size]) hidden = fluid.data(name='hidden', shape=[None, self.hidden_size])
cell = fluid.data(name='cell', shape=[None, self.hidden_size]) cell = fluid.data(name='cell', shape=[None, self.hidden_size])
self.tokens = self._network( self.tokens = self._network(hidden, cell, is_inference=is_inference)
hidden, cell, is_inference=is_inference)
with fluid.program_guard(self.learn_program): with paddle.static.program_guard(self.learn_program):
hidden = fluid.data(name='hidden', shape=[None, self.hidden_size]) hidden = fluid.data(name='hidden', shape=[None, self.hidden_size])
cell = fluid.data(name='cell', shape=[None, self.hidden_size]) cell = fluid.data(name='cell', shape=[None, self.hidden_size])
init_actions = fluid.data( init_actions = fluid.data(
...@@ -197,18 +195,18 @@ class LSTM(RLBaseController): ...@@ -197,18 +195,18 @@ class LSTM(RLBaseController):
self.sample_log_probs = fluid.layers.reduce_sum( self.sample_log_probs = fluid.layers.reduce_sum(
self.sample_log_probs) self.sample_log_probs)
fluid.layers.assign(self.baseline - (1.0 - self.decay) * paddle.assign(self.baseline - (1.0 - self.decay) *
(self.baseline - self.rewards), self.baseline) (self.baseline - self.rewards), self.baseline)
self.loss = self.sample_log_probs * (self.rewards - self.baseline) self.loss = self.sample_log_probs * (self.rewards - self.baseline)
clip = fluid.clip.GradientClipByNorm(clip_norm=5.0) clip = fluid.clip.GradientClipByNorm(clip_norm=5.0)
if self.decay_steps is not None: if self.decay_steps is not None:
lr = fluid.layers.exponential_decay( lr = paddle.optimizer.lr.ExponentialDecay(
self.controller_lr, learning_rate=self.controller_lr,
decay_steps=self.decay_steps, gamma=self.decay_rate,
decay_rate=self.decay_rate) verbose=False)
else: else:
lr = self.controller_lr lr = self.controller_lr
optimizer = fluid.optimizer.Adam(learning_rate=lr, grad_clip=clip) optimizer = paddle.optimizer.Adam(learning_rate=lr, grad_clip=clip)
optimizer.minimize(self.loss) optimizer.minimize(self.loss)
def _create_input(self, is_test=True, actual_rewards=None): def _create_input(self, is_test=True, actual_rewards=None):
......
...@@ -184,10 +184,10 @@ class AutoPruner(object): ...@@ -184,10 +184,10 @@ class AutoPruner(object):
Prune program with latest tokens generated by controller. Prune program with latest tokens generated by controller.
Args: Args:
program(fluid.Program): The program to be pruned. program(paddle.static.Program): The program to be pruned.
Returns: Returns:
paddle.fluid.Program: The pruned program. paddle.static.Program: The pruned program.
""" """
self._current_ratios = self._next_ratios() self._current_ratios = self._next_ratios()
pruned_program, self._param_backup, _ = self._pruner.prune( pruned_program, self._param_backup, _ = self._pruner.prune(
......
...@@ -42,7 +42,7 @@ def collect_convs(params, graph, visited={}): ...@@ -42,7 +42,7 @@ def collect_convs(params, graph, visited={}):
Args: Args:
params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters. params(list): A list of convolution layer's parameter names. It will collect all the groups that contains anyone of these parameters.
graph(paddle.fluid.Program | GraphWrapper): The graph used to search the groups. graph(paddle.static.Program | GraphWrapper): The graph used to search the groups.
Returns: Returns:
list<list<tuple>>: The groups. list<list<tuple>>: The groups.
......
...@@ -17,7 +17,7 @@ def save_model(exe, graph, dirname): ...@@ -17,7 +17,7 @@ def save_model(exe, graph, dirname):
Save weights of model and information of shapes into filesystem. Save weights of model and information of shapes into filesystem.
Args: Args:
exe(paddle.fluid.Executor): The executor used to save model. exe(paddle.static.Executor): The executor used to save model.
graph(Program|Graph): The graph to be saved. graph(Program|Graph): The graph to be saved.
dirname(str): The directory that the model saved into. dirname(str): The directory that the model saved into.
""" """
......
...@@ -64,7 +64,7 @@ class Pruner(): ...@@ -64,7 +64,7 @@ class Pruner():
Args: Args:
program(fluid.Program): The program to be pruned. program(paddle.static.Program): The program to be pruned.
scope(fluid.Scope): The scope storing paramaters to be pruned. scope(fluid.Scope): The scope storing paramaters to be pruned.
params(list<str>): A list of parameter names to be pruned. params(list<str>): A list of parameter names to be pruned.
ratios(list<float>): A list of ratios to be used to pruning parameters. ratios(list<float>): A list of ratios to be used to pruning parameters.
......
...@@ -57,10 +57,10 @@ def sensitivity(program, ...@@ -57,10 +57,10 @@ def sensitivity(program,
Args: Args:
program(paddle.fluid.Program): The program to be analysised. program(paddle.static.Program): The program to be analysised.
place(fluid.CPUPlace | fluid.CUDAPlace): The device place of filter parameters. place(paddle.CPUPlace | paddle.CUDAPlace): The device place of filter parameters.
param_names(list): The parameter names of convolutions to be analysised. param_names(list): The parameter names of convolutions to be analysised.
eval_func(function): The callback function used to evaluate the model. It should accept a instance of `paddle.fluid.Program` as argument and return a score on test dataset. eval_func(function): The callback function used to evaluate the model. It should accept a instance of `paddle.static.Program` as argument and return a score on test dataset.
sensitivities_file(str): The file to save the sensitivities. It will append the latest computed sensitivities into the file. And the sensitivities in the file would not be computed again. This file can be loaded by `pickle` library. sensitivities_file(str): The file to save the sensitivities. It will append the latest computed sensitivities into the file. And the sensitivities in the file would not be computed again. This file can be loaded by `pickle` library.
pruned_ratios(list): The ratios to be pruned. default: ``[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]``. pruned_ratios(list): The ratios to be pruned. default: ``[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]``.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册