Unverified commit 423b057c, authored by whs, committed by GitHub

Remove and update fluid API (#1508)

1. Remove fluid
2. Remove unused code
2.1 Remove adabert
2.2 Remove the DDPG controller of NAS
2.3 Remove paddleslim.models
2.4 Remove PC-DARTS
Parent 7d90a673
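The diff below applies the fluid-to-2.x renames mechanically across the repo. For orientation, a minimal sketch of the most common replacements it makes (assumes PaddlePaddle >= 2.0; the tensor shapes and layer sizes are illustrative, not taken from this commit):

import paddle
import paddle.nn.functional as F

# functional ops: fluid.layers.* -> paddle.* / paddle.nn.functional.*
x = paddle.rand([2, 8, 4, 4])
y = F.relu(x)                      # was fluid.layers.relu(x)
z = paddle.concat([x, y], axis=1)  # was fluid.layers.concat(input=[...], axis=1)

# initializers: MSRAInitializer / ConstantInitializer ->
# paddle.nn.initializer.KaimingUniform / Constant
conv = paddle.nn.Conv2D(
    16, 16, kernel_size=3, padding=1,
    weight_attr=paddle.ParamAttr(
        initializer=paddle.nn.initializer.KaimingUniform()),
    bias_attr=False)

# optimizers: fluid.optimizer.MomentumOptimizer(regularization=..., parameter_list=...)
# -> paddle.optimizer.Momentum(weight_decay=..., parameters=...)
opt = paddle.optimizer.Momentum(
    learning_rate=0.1, momentum=0.9,
    parameters=conv.parameters(),
    weight_decay=paddle.regularizer.L2Decay(3e-4))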
......@@ -22,7 +22,6 @@ import six
import numpy as np
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
from paddle.fluid import core
......@@ -63,19 +62,19 @@ def parse_args():
def transform_and_save_int8_model(original_path, save_path):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
inference_scope = paddle.static.global_scope()
model_filename = 'model.pdmodel'
params_filename = 'model.pdiparams'
with fluid.scope_guard(inference_scope):
with paddle.static.scope_guard(inference_scope):
if os.path.exists(os.path.join(original_path, '__model__')):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(original_path, exe)
[inference_program, feed_target_names, fetch_targets
] = paddle.static.load_inference_model(original_path, exe)
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
original_path, exe, model_filename, params_filename)
[inference_program, feed_target_names,
fetch_targets] = paddle.static.load_inference_model(
original_path, exe,
model_filename=model_filename,
params_filename=params_filename)
ops_to_quantize = set()
......@@ -98,15 +97,13 @@ def transform_and_save_int8_model(original_path, save_path):
_debug=test_args.debug)
graph = transform_to_mkldnn_int8_pass.apply(graph)
inference_program = graph.to_program()
with fluid.scope_guard(inference_scope):
fluid.io.save_inference_model(
save_path,
feed_target_names,
fetch_targets,
exe,
inference_program,
model_filename=model_filename,
params_filename=params_filename)
with paddle.static.scope_guard(inference_scope):
feed_vars = [
inference_program.global_block().var(name)
for name in feed_target_names
]
paddle.static.save_inference_model(
save_path,
feed_vars,
fetch_targets,
exe,
program=inference_program)
print(
"Success! INT8 model obtained from the Quant model can be found at {}\n"
.format(save_path))
......
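The first hunk above moves the INT8 transform script from fluid.io to the paddle.static (de)serialization API. A minimal, self-contained round trip under that API (the toy network and paths are hypothetical); note that save_inference_model takes feed/fetch Variables rather than names, while load_inference_model returns feed names plus fetch Variables:

import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    image = paddle.static.data(name="image", shape=[None, 8], dtype="float32")
    out = paddle.static.nn.fc(image, 4)
exe.run(startup_prog)

# save: a path prefix (no suffix) plus feed/fetch Variables
paddle.static.save_inference_model(
    "./demo_model/inference", [image], [out], exe, program=main_prog)

# load: returns [program, feed_target_names, fetch_targets]
[program, feed_names, fetch_targets] = paddle.static.load_inference_model(
    "./demo_model/inference", exe)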
CUDA_VISIBLE_DEVICES=0 python2 -u train_cell_base.py
import numpy as np
from itertools import izip
import paddle.fluid as fluid
from paddleslim.teachers.bert.reader.cls import *
from paddleslim.nas.darts.search_space import AdaBERTClassifier
from paddle.fluid.dygraph.base import to_variable
from tqdm import tqdm
import os
import pickle
import logging
from paddleslim.common import AvgrageMeter, get_logger
logger = get_logger(__name__, level=logging.INFO)
def valid_one_epoch(model, valid_loader, epoch, log_freq):
accs = AvgrageMeter()
ce_losses = AvgrageMeter()
model.student.eval()
step_id = 0
for valid_data in valid_loader():
try:
loss, acc, ce_loss, _, _ = model._layers.loss(valid_data, epoch)
except AttributeError:  # model is not wrapped by DataParallel
loss, acc, ce_loss, _, _ = model.loss(valid_data, epoch)
batch_size = valid_data[0].shape[0]
ce_losses.update(ce_loss.numpy(), batch_size)
accs.update(acc.numpy(), batch_size)
step_id += 1
return ce_losses.avg[0], accs.avg[0]
def train_one_epoch(model, train_loader, optimizer, epoch, use_data_parallel,
log_freq):
total_losses = AvgrageMeter()
accs = AvgrageMeter()
ce_losses = AvgrageMeter()
kd_losses = AvgrageMeter()
model.student.train()
step_id = 0
for train_data in train_loader():
batch_size = train_data[0].shape[0]
if use_data_parallel:
total_loss, acc, ce_loss, kd_loss, _ = model._layers.loss(
train_data, epoch)
else:
total_loss, acc, ce_loss, kd_loss, _ = model.loss(train_data,
epoch)
if use_data_parallel:
total_loss = model.scale_loss(total_loss)
total_loss.backward()
model.apply_collective_grads()
else:
total_loss.backward()
optimizer.minimize(total_loss)
model.clear_gradients()
total_losses.update(total_loss.numpy(), batch_size)
accs.update(acc.numpy(), batch_size)
ce_losses.update(ce_loss.numpy(), batch_size)
kd_losses.update(kd_loss.numpy(), batch_size)
if step_id % log_freq == 0:
logger.info(
"Train Epoch {}, Step {}, Lr {:.6f} total_loss {:.6f}; ce_loss {:.6f}, kd_loss {:.6f}, train_acc {:.6f};".
format(epoch, step_id,
optimizer.current_step_lr(), total_losses.avg[0],
ce_losses.avg[0], kd_losses.avg[0], accs.avg[0]))
step_id += 1
def main():
# whether to use multiple GPUs
use_data_parallel = False
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env(
).dev_id) if use_data_parallel else fluid.CUDAPlace(0)
BERT_BASE_PATH = "./data/pretrained_models/uncased_L-12_H-768_A-12"
vocab_path = BERT_BASE_PATH + "/vocab.txt"
do_lower_case = True
# augmented dataset nums
# num_samples = 8016987
max_seq_len = 128
batch_size = 192
hidden_size = 768
emb_size = 768
epoch = 80
log_freq = 10
task_name = 'mnli'
if task_name == 'mrpc':
data_dir = "./data/glue_data/MRPC/"
teacher_model_dir = "./data/teacher_model/mrpc"
num_samples = 3668
max_layer = 4
num_labels = 2
processor_func = MrpcProcessor
elif task_name == 'mnli':
data_dir = "./data/glue_data/MNLI/"
teacher_model_dir = "./data/teacher_model/steps_23000"
num_samples = 392702
max_layer = 8
num_labels = 3
processor_func = MnliProcessor
device_num = fluid.dygraph.parallel.Env().nranks
use_fixed_gumbel = True
train_phase = "train"
val_phase = "dev"
step_per_epoch = int(num_samples / (batch_size * device_num))
with fluid.dygraph.guard(place):
if use_fixed_gumbel:
# make sure gumbel arch is constant
np.random.seed(1)
fluid.default_main_program().random_seed = 1
model = AdaBERTClassifier(
num_labels,
n_layer=max_layer,
hidden_size=hidden_size,
task_name=task_name,
emb_size=emb_size,
teacher_model=teacher_model_dir,
data_dir=data_dir,
use_fixed_gumbel=use_fixed_gumbel)
learning_rate = fluid.dygraph.CosineDecay(2e-2, step_per_epoch, epoch)
model_parameters = []
for p in model.parameters():
if (p.name not in [a.name for a in model.arch_parameters()] and
p.name not in
[a.name for a in model.teacher.parameters()]):
model_parameters.append(p)
optimizer = fluid.optimizer.MomentumOptimizer(
learning_rate,
0.9,
regularization=fluid.regularizer.L2DecayRegularizer(3e-4),
parameter_list=model_parameters)
processor = processor_func(
data_dir=data_dir,
vocab_path=vocab_path,
max_seq_len=max_seq_len,
do_lower_case=do_lower_case,
in_tokens=False)
train_reader = processor.data_generator(
batch_size=batch_size,
phase=train_phase,
epoch=1,
dev_count=1,
shuffle=True)
dev_reader = processor.data_generator(
batch_size=batch_size,
phase=val_phase,
epoch=1,
dev_count=1,
shuffle=False)
if use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
train_loader = fluid.io.DataLoader.from_generator(
capacity=128,
use_double_buffer=True,
iterable=True,
return_list=True)
dev_loader = fluid.io.DataLoader.from_generator(
capacity=128,
use_double_buffer=True,
iterable=True,
return_list=True)
train_loader.set_batch_generator(train_reader, places=place)
dev_loader.set_batch_generator(dev_reader, places=place)
if use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
model = fluid.dygraph.parallel.DataParallel(model, strategy)
best_valid_acc = 0
for epoch_id in range(epoch):
train_one_epoch(model, train_loader, optimizer, epoch_id,
use_data_parallel, log_freq)
loss, acc = valid_one_epoch(model, dev_loader, epoch_id, log_freq)
if acc > best_valid_acc:
best_valid_acc = acc
logger.info(
"dev set, ce_loss {:.6f}; acc {:.6f}, best_acc {:.6f};".format(
loss, acc, best_valid_acc))
if __name__ == '__main__':
main()
import numpy as np
from itertools import izip
import paddle.fluid as fluid
from paddleslim.teachers.bert.reader.cls import *
from paddleslim.nas.darts.search_space import AdaBERTClassifier
from paddle.fluid.dygraph.base import to_variable
from tqdm import tqdm
import os
import pickle
import logging
from paddleslim.common import AvgrageMeter, get_logger
logger = get_logger(__name__, level=logging.INFO)
def valid_one_epoch(model, valid_loader, epoch, log_freq):
accs = AvgrageMeter()
ce_losses = AvgrageMeter()
model.student.eval()
step_id = 0
for valid_data in valid_loader():
try:
loss, acc, ce_loss, _, _ = model._layers.loss(valid_data, epoch)
except AttributeError:  # model is not wrapped by DataParallel
loss, acc, ce_loss, _, _ = model.loss(valid_data, epoch)
batch_size = valid_data[0].shape[0]
ce_losses.update(ce_loss.numpy(), batch_size)
accs.update(acc.numpy(), batch_size)
step_id += 1
return ce_losses.avg[0], accs.avg[0]
def train_one_epoch(model, train_loader, valid_loader, optimizer,
arch_optimizer, epoch, use_data_parallel, log_freq):
total_losses = AvgrageMeter()
accs = AvgrageMeter()
ce_losses = AvgrageMeter()
kd_losses = AvgrageMeter()
val_accs = AvgrageMeter()
model.student.train()
step_id = 0
for train_data, valid_data in izip(train_loader(), valid_loader()):
batch_size = train_data[0].shape[0]
# make sure the arch on every GPU is the same, otherwise an error will occur
np.random.seed(step_id * 2 * (epoch + 1))
if use_data_parallel:
total_loss, acc, ce_loss, kd_loss, _ = model._layers.loss(
train_data, epoch)
else:
total_loss, acc, ce_loss, kd_loss, _ = model.loss(train_data,
epoch)
if use_data_parallel:
total_loss = model.scale_loss(total_loss)
total_loss.backward()
model.apply_collective_grads()
else:
total_loss.backward()
optimizer.minimize(total_loss)
model.clear_gradients()
total_losses.update(total_loss.numpy(), batch_size)
accs.update(acc.numpy(), batch_size)
ce_losses.update(ce_loss.numpy(), batch_size)
kd_losses.update(kd_loss.numpy(), batch_size)
# make sure the arch on every GPU is the same, otherwise an error will occur
np.random.seed(step_id * 2 * (epoch + 1) + 1)
if use_data_parallel:
arch_loss, _, _, _, arch_logits = model._layers.loss(valid_data,
epoch)
else:
arch_loss, _, _, _, arch_logits = model.loss(valid_data, epoch)
if use_data_parallel:
arch_loss = model.scale_loss(arch_loss)
arch_loss.backward()
model.apply_collective_grads()
else:
arch_loss.backward()
arch_optimizer.minimize(arch_loss)
model.clear_gradients()
probs = fluid.layers.softmax(arch_logits[-1])
val_acc = fluid.layers.accuracy(input=probs, label=valid_data[4])
val_accs.update(val_acc.numpy(), batch_size)
if step_id % log_freq == 0:
logger.info(
"Train Epoch {}, Step {}, Lr {:.6f} total_loss {:.6f}; ce_loss {:.6f}, kd_loss {:.6f}, train_acc {:.6f}, search_valid_acc {:.6f};".
format(epoch, step_id,
optimizer.current_step_lr(), total_losses.avg[
0], ce_losses.avg[0], kd_losses.avg[0], accs.avg[0],
val_accs.avg[0]))
step_id += 1
def main():
# whether to use multiple GPUs
use_data_parallel = False
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env(
).dev_id) if use_data_parallel else fluid.CUDAPlace(0)
BERT_BASE_PATH = "./data/pretrained_models/uncased_L-12_H-768_A-12"
vocab_path = BERT_BASE_PATH + "/vocab.txt"
data_dir = "./data/glue_data/MNLI/"
teacher_model_dir = "./data/teacher_model/steps_23000"
do_lower_case = True
num_samples = 392702
# augmented dataset nums
# num_samples = 8016987
max_seq_len = 128
batch_size = 128
hidden_size = 768
emb_size = 768
max_layer = 8
epoch = 80
log_freq = 10
device_num = fluid.dygraph.parallel.Env().nranks
use_fixed_gumbel = False
train_phase = "search_train"
val_phase = "search_valid"
step_per_epoch = int(num_samples * 0.5 / (batch_size * device_num))
with fluid.dygraph.guard(place):
model = AdaBERTClassifier(
3,
n_layer=max_layer,
hidden_size=hidden_size,
emb_size=emb_size,
teacher_model=teacher_model_dir,
data_dir=data_dir,
use_fixed_gumbel=use_fixed_gumbel)
learning_rate = fluid.dygraph.CosineDecay(2e-2, step_per_epoch, epoch)
model_parameters = []
for p in model.parameters():
if (p.name not in [a.name for a in model.arch_parameters()] and
p.name not in
[a.name for a in model.teacher.parameters()]):
model_parameters.append(p)
optimizer = fluid.optimizer.MomentumOptimizer(
learning_rate,
0.9,
regularization=fluid.regularizer.L2DecayRegularizer(3e-4),
parameter_list=model_parameters)
arch_optimizer = fluid.optimizer.Adam(
3e-4,
0.5,
0.999,
regularization=fluid.regularizer.L2Decay(1e-3),
parameter_list=model.arch_parameters())
processor = MnliProcessor(
data_dir=data_dir,
vocab_path=vocab_path,
max_seq_len=max_seq_len,
do_lower_case=do_lower_case,
in_tokens=False)
train_reader = processor.data_generator(
batch_size=batch_size,
phase=train_phase,
epoch=1,
dev_count=1,
shuffle=True)
valid_reader = processor.data_generator(
batch_size=batch_size,
phase=val_phase,
epoch=1,
dev_count=1,
shuffle=True)
dev_reader = processor.data_generator(
batch_size=batch_size,
phase="dev",
epoch=1,
dev_count=1,
shuffle=False)
if use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
valid_reader = fluid.contrib.reader.distributed_batch_reader(
valid_reader)
train_loader = fluid.io.DataLoader.from_generator(
capacity=128,
use_double_buffer=True,
iterable=True,
return_list=True)
valid_loader = fluid.io.DataLoader.from_generator(
capacity=128,
use_double_buffer=True,
iterable=True,
return_list=True)
dev_loader = fluid.io.DataLoader.from_generator(
capacity=128,
use_double_buffer=True,
iterable=True,
return_list=True)
train_loader.set_batch_generator(train_reader, places=place)
valid_loader.set_batch_generator(valid_reader, places=place)
dev_loader.set_batch_generator(dev_reader, places=place)
if use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
model = fluid.dygraph.parallel.DataParallel(model, strategy)
for epoch_id in range(epoch):
train_one_epoch(model, train_loader, valid_loader, optimizer,
arch_optimizer, epoch_id, use_data_parallel,
log_freq)
loss, acc = valid_one_epoch(model, dev_loader, epoch_id, log_freq)
logger.info("dev set, ce_loss {:.6f}; acc: {:.6f};".format(loss,
acc))
if use_data_parallel:
print(model._layers.student._encoder.alphas.numpy())
else:
print(model.student._encoder.alphas.numpy())
print("=" * 100)
if __name__ == '__main__':
main()
import paddle.fluid as fluid
from paddleslim.teachers.bert import BERTClassifier
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
with fluid.dygraph.guard(place):
bert = BERTClassifier(3)
bert.fit("./data/glue_data/MNLI/",
5,
batch_size=32,
use_data_parallel=True,
learning_rate=0.00005,
save_steps=1000)
......@@ -18,8 +18,7 @@ from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import ConstantInitializer, MSRAInitializer
from paddle.nn.initializer import Constant, KaimingUniform
from paddle.nn import Conv2D
from paddle.fluid.dygraph.nn import Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
......@@ -28,7 +27,7 @@ from genotypes import Genotype
from operations import *
class ConvBN(fluid.dygraph.Layer):
class ConvBN(paddle.nn.Layer):
def __init__(self, c_curr, c_out, kernel_size, padding, stride, name=None):
super(ConvBN, self).__init__()
self.conv = Conv2D(
......@@ -37,18 +36,18 @@ class ConvBN(fluid.dygraph.Layer):
filter_size=kernel_size,
stride=stride,
padding=padding,
param_attr=fluid.ParamAttr(
param_attr=paddle.ParamAttr(
name=name + "_conv" if name is not None else None,
initializer=MSRAInitializer()),
initializer=KaimingUniform()),
bias_attr=False)
self.bn = BatchNorm(
num_channels=c_out,
param_attr=fluid.ParamAttr(
param_attr=paddle.ParamAttr(
name=name + "_bn_scale" if name is not None else None,
initializer=ConstantInitializer(value=1)),
bias_attr=fluid.ParamAttr(
initializer=Constant(value=1)),
bias_attr=paddle.ParamAttr(
name=name + "_bn_offset" if name is not None else None,
initializer=ConstantInitializer(value=0)),
initializer=Constant(value=0)),
moving_mean_name=name + "_bn_mean" if name is not None else None,
moving_variance_name=name + "_bn_variance"
if name is not None else None)
......@@ -59,19 +58,19 @@ class ConvBN(fluid.dygraph.Layer):
return bn
class Classifier(fluid.dygraph.Layer):
class Classifier(paddle.nn.Layer):
def __init__(self, input_dim, num_classes, name=None):
super(Classifier, self).__init__()
self.pool2d = Pool2D(pool_type='avg', global_pooling=True)
self.fc = Linear(
input_dim=input_dim,
output_dim=num_classes,
param_attr=fluid.ParamAttr(
param_attr=paddle.ParamAttr(
name=name + "_fc_weights" if name is not None else None,
initializer=MSRAInitializer()),
bias_attr=fluid.ParamAttr(
initializer=KaimingUniform()),
bias_attr=paddle.ParamAttr(
name=name + "_fc_bias" if name is not None else None,
initializer=MSRAInitializer()))
initializer=KaimingUniform()))
def forward(self, x):
x = self.pool2d(x)
......@@ -90,7 +89,7 @@ def drop_path(x, drop_prob):
return x
class Cell(fluid.dygraph.Layer):
class Cell(paddle.nn.Layer):
def __init__(self, genotype, c_prev_prev, c_prev, c_curr, reduction,
reduction_prev):
super(Cell, self).__init__()
......@@ -144,11 +143,11 @@ class Cell(fluid.dygraph.Layer):
if not isinstance(op2, Identity):
h2 = drop_path(h2, drop_prob)
states += [h1 + h2]
out = fluid.layers.concat(input=states[-self._multiplier:], axis=1)
out = paddle.concat(states[-self._multiplier:], axis=1)
return out
class AuxiliaryHeadCIFAR(fluid.dygraph.Layer):
class AuxiliaryHeadCIFAR(paddle.nn.Layer):
def __init__(self, C, num_classes):
super(AuxiliaryHeadCIFAR, self).__init__()
self.avgpool = Pool2D(
......@@ -170,17 +169,17 @@ class AuxiliaryHeadCIFAR(fluid.dygraph.Layer):
self.classifier = Classifier(768, num_classes, 'aux')
def forward(self, x):
x = fluid.layers.relu(x)
x = paddle.nn.functional.relu(x)
x = self.avgpool(x)
conv1 = self.conv_bn1(x)
conv1 = fluid.layers.relu(conv1)
conv1 = paddle.nn.functional.relu(conv1)
conv2 = self.conv_bn2(conv1)
conv2 = fluid.layers.relu(conv2)
conv2 = paddle.nn.functional.relu(conv2)
out = self.classifier(conv2)
return out
class NetworkCIFAR(fluid.dygraph.Layer):
class NetworkCIFAR(paddle.nn.Layer):
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkCIFAR, self).__init__()
self._layers = layers
......@@ -226,7 +225,7 @@ class NetworkCIFAR(fluid.dygraph.Layer):
return logits, logits_aux
class AuxiliaryHeadImageNet(fluid.dygraph.Layer):
class AuxiliaryHeadImageNet(paddle.nn.Layer):
def __init__(self, C, num_classes):
super(AuxiliaryHeadImageNet, self).__init__()
self.avgpool = Pool2D(
......@@ -248,17 +247,17 @@ class AuxiliaryHeadImageNet(fluid.dygraph.Layer):
self.classifier = Classifier(768, num_classes, 'aux')
def forward(self, x):
x = fluid.layers.relu(x)
x = paddle.nn.functional.relu(x)
x = self.avgpool(x)
conv1 = self.conv_bn1(x)
conv1 = fluid.layers.relu(conv1)
conv1 = paddle.nn.functional.relu(conv1)
conv2 = self.conv_bn2(conv1)
conv2 = fluid.layers.relu(conv2)
conv2 = paddle.nn.functional.relu(conv2)
out = self.classifier(conv2)
return out
class NetworkImageNet(fluid.dygraph.Layer):
class NetworkImageNet(paddle.nn.Layer):
def __init__(self, C, num_classes, layers, auxiliary, genotype):
super(NetworkImageNet, self).__init__()
self._layers = layers
......@@ -299,9 +298,9 @@ class NetworkImageNet(fluid.dygraph.Layer):
def forward(self, input, training):
logits_aux = None
s0 = self.stem_a0(input)
s0 = fluid.layers.relu(s0)
s0 = paddle.nn.functional.relu(s0)
s0 = self.stem_a1(s0)
s1 = fluid.layers.relu(s0)
s1 = paddle.nn.functional.relu(s0)
s1 = self.stem_b(s1)
for i, cell in enumerate(self.cells):
......
......@@ -17,8 +17,7 @@ from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import NormalInitializer, MSRAInitializer, ConstantInitializer
from paddle.nn.initializer import Normal, KaimingUniform, Constant
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from genotypes import PRIMITIVES
......@@ -30,20 +29,20 @@ def channel_shuffle(x, groups):
channels_per_group = num_channels // groups
# reshape
x = fluid.layers.reshape(
x, [batchsize, groups, channels_per_group, height, width])
x = paddle.reshape(x,
[batchsize, groups, channels_per_group, height, width])
x = fluid.layers.transpose(x, [0, 2, 1, 3, 4])
# flatten
x = fluid.layers.reshape(x, [batchsize, num_channels, height, width])
x = paddle.reshape(x, [batchsize, num_channels, height, width])
return x
class MixedOp(fluid.dygraph.Layer):
class MixedOp(paddle.nn.Layer):
def __init__(self, c_cur, stride, method):
super(MixedOp, self).__init__()
self._method = method
self._k = 4 if self._method == "PC-DARTS" else 1
self._k = 1
self.mp = Pool2D(
pool_size=2,
pool_stride=2,
......@@ -52,11 +51,11 @@ class MixedOp(fluid.dygraph.Layer):
for primitive in PRIMITIVES:
op = OPS[primitive](c_cur // self._k, stride, False)
if 'pool' in primitive:
gama = ParamAttr(
initializer=fluid.initializer.Constant(value=1),
gama = paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=1),
trainable=False)
beta = ParamAttr(
initializer=fluid.initializer.Constant(value=0),
beta = paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0),
trainable=False)
BN = BatchNorm(
c_cur // self._k, param_attr=gama, bias_attr=beta)
......@@ -65,28 +64,13 @@ class MixedOp(fluid.dygraph.Layer):
self._ops = fluid.dygraph.LayerList(ops)
def forward(self, x, weights):
if self._method == "PC-DARTS":
dim_2 = x.shape[1]
xtemp = x[:, :dim_2 // self._k, :, :]
xtemp2 = x[:, dim_2 // self._k:, :, :]
temp1 = fluid.layers.sums(
[weights[i] * op(xtemp) for i, op in enumerate(self._ops)])
if temp1.shape[2] == x.shape[2]:
out = fluid.layers.concat([temp1, xtemp2], axis=1)
else:
out = fluid.layers.concat([temp1, self.mp(xtemp2)], axis=1)
out = channel_shuffle(out, self._k)
else:
out = fluid.layers.sums(
[weights[i] * op(x) for i, op in enumerate(self._ops)])
return out
return fluid.layers.sums(
[weights[i] * op(x) for i, op in enumerate(self._ops)])
class Cell(fluid.dygraph.Layer):
def __init__(self, steps, multiplier, c_prev_prev, c_prev, c_cur,
reduction, reduction_prev, method):
class Cell(paddle.nn.Layer):
def __init__(self, steps, multiplier, c_prev_prev, c_prev, c_cur, reduction,
reduction_prev, method):
super(Cell, self).__init__()
self.reduction = reduction
......@@ -114,24 +98,17 @@ class Cell(fluid.dygraph.Layer):
states = [s0, s1]
offset = 0
for i in range(self._steps):
if self._method == "PC-DARTS":
s = fluid.layers.sums([
weights2[offset + j] *
self._ops[offset + j](h, weights[offset + j])
for j, h in enumerate(states)
])
else:
s = fluid.layers.sums([
self._ops[offset + j](h, weights[offset + j])
for j, h in enumerate(states)
])
s = fluid.layers.sums([
self._ops[offset + j](h, weights[offset + j])
for j, h in enumerate(states)
])
offset += len(states)
states.append(s)
out = fluid.layers.concat(input=states[-self._multiplier:], axis=1)
out = paddle.concat(states[-self._multiplier:], axis=1)
return out
class Network(fluid.dygraph.Layer):
class Network(paddle.nn.Layer):
def __init__(self,
c_in,
num_classes,
......@@ -156,14 +133,12 @@ class Network(fluid.dygraph.Layer):
num_filters=c_cur,
filter_size=3,
padding=1,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False),
BatchNorm(
num_channels=c_cur,
param_attr=fluid.ParamAttr(
initializer=ConstantInitializer(value=1)),
bias_attr=fluid.ParamAttr(
initializer=ConstantInitializer(value=0))))
param_attr=paddle.ParamAttr(initializer=Constant(value=1)),
bias_attr=paddle.ParamAttr(initializer=Constant(value=0))))
c_prev_prev, c_prev, c_cur = c_cur, c_cur, c_in
cells = []
......@@ -184,8 +159,8 @@ class Network(fluid.dygraph.Layer):
self.classifier = Linear(
input_dim=c_prev,
output_dim=num_classes,
param_attr=ParamAttr(initializer=MSRAInitializer()),
bias_attr=ParamAttr(initializer=MSRAInitializer()))
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=paddle.ParamAttr(initializer=KaimingUniform()))
self._initialize_alphas()
......@@ -194,31 +169,9 @@ class Network(fluid.dygraph.Layer):
weights2 = None
for i, cell in enumerate(self.cells):
if cell.reduction:
weights = fluid.layers.softmax(self.alphas_reduce)
if self._method == "PC-DARTS":
n = 3
start = 2
weights2 = fluid.layers.softmax(self.betas_reduce[0:2])
for i in range(self._steps - 1):
end = start + n
tw2 = fluid.layers.softmax(self.betas_reduce[start:
end])
start = end
n += 1
weights2 = fluid.layers.concat([weights2, tw2])
weights = paddle.nn.functional.softmax(self.alphas_reduce)
else:
weights = fluid.layers.softmax(self.alphas_normal)
if self._method == "PC-DARTS":
n = 3
start = 2
weights2 = fluid.layers.softmax(self.betas_normal[0:2])
for i in range(self._steps - 1):
end = start + n
tw2 = fluid.layers.softmax(self.betas_normal[start:
end])
start = end
n += 1
weights2 = fluid.layers.concat([weights2, tw2])
weights = paddle.nn.functional.softmax(self.alphas_normal)
s0, s1 = s1, cell(s0, s1, weights, weights2)
out = self.global_pooling(s1)
out = fluid.layers.squeeze(out, axes=[2, 3])
......@@ -228,7 +181,7 @@ class Network(fluid.dygraph.Layer):
def _loss(self, input, target):
logits = self(input)
loss = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits, target))
paddle.nn.functional.softmax_with_cross_entropy(logits, target))
return loss
def new(self):
......@@ -239,32 +192,20 @@ class Network(fluid.dygraph.Layer):
def _initialize_alphas(self):
k = sum(1 for i in range(self._steps) for n in range(2 + i))
num_ops = len(self._primitives)
self.alphas_normal = fluid.layers.create_parameter(
self.alphas_normal = paddle.static.create_parameter(
shape=[k, num_ops],
dtype="float32",
default_initializer=NormalInitializer(
loc=0.0, scale=1e-3))
default_initializer=Normal(
mean=0.0, std=1e-3))
self.alphas_reduce = fluid.layers.create_parameter(
self.alphas_reduce = paddle.static.create_parameter(
shape=[k, num_ops],
dtype="float32",
default_initializer=NormalInitializer(
loc=0.0, scale=1e-3))
default_initializer=Normal(
mean=0.0, std=1e-3))
self._arch_parameters = [
self.alphas_normal,
self.alphas_reduce,
]
if self._method == "PC-DARTS":
self.betas_normal = fluid.layers.create_parameter(
shape=[k],
dtype="float32",
default_initializer=NormalInitializer(
loc=0.0, scale=1e-3))
self.betas_reduce = fluid.layers.create_parameter(
shape=[k],
dtype="float32",
default_initializer=NormalInitializer(
loc=0.0, scale=1e-3))
self._arch_parameters += [self.betas_normal, self.betas_reduce]
def arch_parameters(self):
return self._arch_parameters
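For reference, the alphas above can be created the same way under the 2.x initializer names; note that paddle.nn.initializer.Normal takes mean/std rather than the old loc/scale of NormalInitializer. A minimal sketch (the k and num_ops values are illustrative):

import paddle
from paddle.nn.initializer import Normal

paddle.enable_static()  # create_parameter here is the paddle.static helper
k, num_ops = 14, 8      # illustrative sizes
alphas_normal = paddle.static.create_parameter(
    shape=[k, num_ops],
    dtype="float32",
    default_initializer=Normal(mean=0.0, std=1e-3))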
......@@ -15,8 +15,7 @@
import paddle.fluid as fluid
from paddle.nn import Conv2D
from paddle.fluid.dygraph.nn import Pool2D, BatchNorm
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import ConstantInitializer, MSRAInitializer
from paddle.nn.initializer import Constant, KaimingUniform
OPS = {
......@@ -59,12 +58,12 @@ OPS = {
def bn_param_config(affine=False):
gama = ParamAttr(initializer=ConstantInitializer(value=1), trainable=affine)
beta = ParamAttr(initializer=ConstantInitializer(value=0), trainable=affine)
gama = paddle.ParamAttr(initializer=Constant(value=1), trainable=affine)
beta = paddle.ParamAttr(initializer=Constant(value=0), trainable=affine)
return gama, beta
class Zero(fluid.dygraph.Layer):
class Zero(paddle.nn.Layer):
def __init__(self, stride):
super(Zero, self).__init__()
self.stride = stride
......@@ -77,7 +76,7 @@ class Zero(fluid.dygraph.Layer):
return x
class Identity(fluid.dygraph.Layer):
class Identity(paddle.nn.Layer):
def __init__(self):
super(Identity, self).__init__()
......@@ -85,7 +84,7 @@ class Identity(fluid.dygraph.Layer):
return x
class FactorizedReduce(fluid.dygraph.Layer):
class FactorizedReduce(paddle.nn.Layer):
def __init__(self, c_in, c_out, affine=True):
super(FactorizedReduce, self).__init__()
assert c_out % 2 == 0
......@@ -95,7 +94,7 @@ class FactorizedReduce(fluid.dygraph.Layer):
filter_size=1,
stride=2,
padding=0,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False)
self.conv2 = Conv2D(
num_channels=c_in,
......@@ -103,20 +102,20 @@ class FactorizedReduce(fluid.dygraph.Layer):
filter_size=1,
stride=2,
padding=0,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn = BatchNorm(num_channels=c_out, param_attr=gama, bias_attr=beta)
def forward(self, x):
x = fluid.layers.relu(x)
out = fluid.layers.concat(
input=[self.conv1(x), self.conv2(x[:, :, 1:, 1:])], axis=1)
x = paddle.nn.functional.relu(x)
out = paddle.concat(
[self.conv1(x), self.conv2(x[:, :, 1:, 1:])], axis=1)
out = self.bn(out)
return out
class SepConv(fluid.dygraph.Layer):
class SepConv(paddle.nn.Layer):
def __init__(self, c_in, c_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__()
self.conv1 = Conv2D(
......@@ -127,7 +126,7 @@ class SepConv(fluid.dygraph.Layer):
padding=padding,
groups=c_in,
use_cudnn=False,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False)
self.conv2 = Conv2D(
num_channels=c_in,
......@@ -135,7 +134,7 @@ class SepConv(fluid.dygraph.Layer):
filter_size=1,
stride=1,
padding=0,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn1 = BatchNorm(num_channels=c_in, param_attr=gama, bias_attr=beta)
......@@ -147,7 +146,7 @@ class SepConv(fluid.dygraph.Layer):
padding=padding,
groups=c_in,
use_cudnn=False,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False)
self.conv4 = Conv2D(
num_channels=c_in,
......@@ -155,25 +154,25 @@ class SepConv(fluid.dygraph.Layer):
filter_size=1,
stride=1,
padding=0,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn2 = BatchNorm(
num_channels=c_out, param_attr=gama, bias_attr=beta)
def forward(self, x):
x = fluid.layers.relu(x)
x = paddle.nn.functional.relu(x)
x = self.conv1(x)
x = self.conv2(x)
bn1 = self.bn1(x)
x = fluid.layers.relu(bn1)
x = paddle.nn.functional.relu(bn1)
x = self.conv3(x)
x = self.conv4(x)
bn2 = self.bn2(x)
return bn2
class DilConv(fluid.dygraph.Layer):
class DilConv(paddle.nn.Layer):
def __init__(self,
c_in,
c_out,
......@@ -192,28 +191,28 @@ class DilConv(fluid.dygraph.Layer):
dilation=dilation,
groups=c_in,
use_cudnn=False,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False)
self.conv2 = Conv2D(
num_channels=c_in,
num_filters=c_out,
filter_size=1,
padding=0,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn1 = BatchNorm(
num_channels=c_out, param_attr=gama, bias_attr=beta)
def forward(self, x):
x = fluid.layers.relu(x)
x = paddle.nn.functional.relu(x)
x = self.conv1(x)
x = self.conv2(x)
out = self.bn1(x)
return out
class Conv_7x1_1x7(fluid.dygraph.Layer):
class Conv_7x1_1x7(paddle.nn.Layer):
def __init__(self, c_in, c_out, stride, affine=True):
super(Conv_7x1_1x7, self).__init__()
self.conv1 = Conv2D(
......@@ -221,28 +220,28 @@ class Conv_7x1_1x7(fluid.dygraph.Layer):
num_filters=c_out,
filter_size=(1, 7),
padding=(0, 3),
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False)
self.conv2 = Conv2D(
num_channels=c_in,
num_filters=c_out,
filter_size=(7, 1),
padding=(3, 0),
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn1 = BatchNorm(
num_channels=c_out, param_attr=gama, bias_attr=beta)
def forward(self, x):
x = fluid.layers.relu(x)
x = paddle.nn.functional.relu(x)
x = self.conv1(x)
x = self.conv2(x)
out = self.bn1(x)
return out
class ReLUConvBN(fluid.dygraph.Layer):
class ReLUConvBN(paddle.nn.Layer):
def __init__(self, c_in, c_out, kernel_size, stride, padding, affine=True):
super(ReLUConvBN, self).__init__()
self.conv = Conv2D(
......@@ -251,13 +250,13 @@ class ReLUConvBN(fluid.dygraph.Layer):
filter_size=kernel_size,
stride=stride,
padding=padding,
param_attr=fluid.ParamAttr(initializer=MSRAInitializer()),
param_attr=paddle.ParamAttr(initializer=KaimingUniform()),
bias_attr=False)
gama, beta = bn_param_config(affine)
self.bn = BatchNorm(num_channels=c_out, param_attr=gama, bias_attr=beta)
def forward(self, x):
x = fluid.layers.relu(x)
x = paddle.nn.functional.relu(x)
x = self.conv(x)
out = self.bn(x)
return out
......@@ -59,11 +59,11 @@ add_arg('use_data_parallel', ast.literal_eval, False, "The flag indicating whet
def main(args):
if not args.use_gpu:
place = fluid.CPUPlace()
place = paddle.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
place = paddle.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
place = paddle.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
train_reader, valid_reader = reader.train_search(
batch_size=args.batch_size,
......
......@@ -77,13 +77,14 @@ def train(model, train_reader, optimizer, epoch, drop_path_prob, args):
label.stop_gradient = True
logits, logits_aux = model(image, drop_path_prob, True)
prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
prec1 = paddle.static.accuracy(input=logits, label=label, k=1)
prec5 = paddle.static.accuracy(input=logits, label=label, k=5)
loss = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits, label))
paddle.nn.functional.softmax_with_cross_entropy(logits, label))
if args.auxiliary:
loss_aux = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits_aux, label))
paddle.nn.functional.softmax_with_cross_entropy(logits_aux,
label))
loss = loss + args.auxiliary_weight * loss_aux
if args.use_data_parallel:
......@@ -119,10 +120,10 @@ def valid(model, valid_reader, epoch, args):
image = to_variable(image_np)
label = to_variable(label_np)
logits, _ = model(image, 0, False)
prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
prec1 = paddle.static.accuracy(input=logits, label=label, k=1)
prec5 = paddle.static.accuracy(input=logits, label=label, k=5)
loss = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits, label))
paddle.nn.functional.softmax_with_cross_entropy(logits, label))
n = image.shape[0]
objs.update(loss.numpy(), n)
......@@ -136,8 +137,8 @@ def valid(model, valid_reader, epoch, args):
def main(args):
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
place = paddle.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else paddle.CUDAPlace(0)
with fluid.dygraph.guard(place):
genotype = eval("genotypes.%s" % args.arch)
......@@ -156,7 +157,7 @@ def main(args):
learning_rate = fluid.dygraph.CosineDecay(args.learning_rate,
step_per_epoch, args.epochs)
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=args.grad_clip)
optimizer = fluid.optimizer.MomentumOptimizer(
optimizer = paddle.optimizer.Momentum(
learning_rate,
momentum=args.momentum,
weight_decay=paddle.regularizer.L2Decay(args.weight_decay),
......@@ -212,8 +213,8 @@ def main(args):
if valid_top1 > best_acc:
best_acc = valid_top1
if save_parameters:
fluid.save_dygraph(model.state_dict(),
args.model_save_dir + "/best_model")
paddle.save(model.state_dict(),
args.model_save_dir + "/best_model")
logger.info("Epoch {}, valid_acc {:.6f}, best_valid_acc {:.6f}".
format(epoch, valid_top1, best_acc))
......
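The training script above now builds paddle.optimizer.Momentum; in the 2.x signature, the fluid keyword arguments regularization and parameter_list become weight_decay and parameters, and gradient clipping is passed as grad_clip. A minimal sketch under those assumptions (the Linear model is a hypothetical stand-in for the searched network):

import paddle

model = paddle.nn.Linear(8, 2)  # hypothetical stand-in
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
    learning_rate=0.025, T_max=50)  # 2.x counterpart of fluid.dygraph.CosineDecay
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=5.0)
optimizer = paddle.optimizer.Momentum(
    learning_rate=scheduler,
    momentum=0.9,
    parameters=model.parameters(),
    weight_decay=paddle.regularizer.L2Decay(3e-4),
    grad_clip=clip)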
......@@ -67,7 +67,7 @@ add_arg('use_data_parallel', ast.literal_eval, False, "The flag indicating whet
def cross_entropy_label_smooth(preds, targets, epsilon):
preds = fluid.layers.softmax(preds)
preds = paddle.nn.functional.softmax(preds)
targets_one_hot = fluid.one_hot(input=targets, depth=args.class_num)
targets_smooth = fluid.layers.label_smooth(
targets_one_hot, epsilon=epsilon, dtype="float32")
......@@ -89,8 +89,8 @@ def train(model, train_reader, optimizer, epoch, args):
label.stop_gradient = True
logits, logits_aux = model(image, True)
prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
prec1 = paddle.static.accuracy(input=logits, label=label, k=1)
prec5 = paddle.static.accuracy(input=logits, label=label, k=5)
loss = fluid.layers.reduce_mean(
cross_entropy_label_smooth(logits, label, args.label_smooth))
......@@ -133,8 +133,8 @@ def valid(model, valid_reader, epoch, args):
image = to_variable(image_np)
label = to_variable(label_np)
logits, _ = model(image, False)
prec1 = fluid.layers.accuracy(input=logits, label=label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=label, k=5)
prec1 = paddle.static.accuracy(input=logits, label=label, k=1)
prec5 = paddle.static.accuracy(input=logits, label=label, k=5)
loss = fluid.layers.reduce_mean(
cross_entropy_label_smooth(logits, label, args.label_smooth))
......@@ -150,8 +150,8 @@ def valid(model, valid_reader, epoch, args):
def main(args):
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
place = paddle.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else paddle.CUDAPlace(0)
with fluid.dygraph.guard(place):
genotype = eval("genotypes.%s" % args.arch)
......@@ -166,16 +166,12 @@ def main(args):
count_parameters_in_MB(model.parameters())))
device_num = fluid.dygraph.parallel.Env().nranks
step_per_epoch = int(args.trainset_num /
(args.batch_size * device_num))
step_per_epoch = int(args.trainset_num / (args.batch_size * device_num))
learning_rate = fluid.dygraph.ExponentialDecay(
args.learning_rate,
step_per_epoch,
args.decay_rate,
staircase=True)
args.learning_rate, step_per_epoch, args.decay_rate, staircase=True)
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=args.grad_clip)
optimizer = fluid.optimizer.MomentumOptimizer(
optimizer = paddle.optimizer.Momentum(
learning_rate,
momentum=args.momentum,
weight_decay=paddle.regularizer.L2Decay(args.weight_decay),
......@@ -216,18 +212,17 @@ def main(args):
fluid.dygraph.parallel.Env().local_rank == 0)
best_top1 = 0
for epoch in range(args.epochs):
logger.info('Epoch {}, lr {:.6f}'.format(
epoch, optimizer.current_step_lr()))
logger.info('Epoch {}, lr {:.6f}'.format(epoch, optimizer.get_lr()))
train_top1, train_top5 = train(model, train_loader, optimizer,
epoch, args)
logger.info("Epoch {}, train_top1 {:.6f}, train_top5 {:.6f}".
format(epoch, train_top1, train_top5))
logger.info("Epoch {}, train_top1 {:.6f}, train_top5 {:.6f}".format(
epoch, train_top1, train_top5))
valid_top1, valid_top5 = valid(model, valid_loader, epoch, args)
if valid_top1 > best_top1:
best_top1 = valid_top1
if save_parameters:
fluid.save_dygraph(model.state_dict(),
args.model_save_dir + "/best_model")
paddle.save(model.state_dict(),
args.model_save_dir + "/best_model")
logger.info(
"Epoch {}, valid_top1 {:.6f}, valid_top5 {:.6f}, best_valid_top1 {:6f}".
format(epoch, valid_top1, valid_top5, best_top1))
......
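The checkpointing change above replaces fluid.save_dygraph with paddle.save; unlike the fluid helper, paddle.save expects a full file path including a suffix. A minimal save/restore sketch (model and paths are hypothetical):

import paddle

model = paddle.nn.Linear(8, 2)  # hypothetical model
paddle.save(model.state_dict(), "./save_dir/best_model.pdparams")

state = paddle.load("./save_dir/best_model.pdparams")
model.set_state_dict(state)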
......@@ -24,8 +24,7 @@ import logging
import paddle
from paddleslim.common import AvgrageMeter, get_logger
from paddleslim.dist import DML
from paddleslim.models.dygraph import MobileNetV1
from paddleslim.models.dygraph import ResNet
from paddle.vision.models import MobileNetV1, resnet34
import cifar100_reader as reader
sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
from utility import add_arguments, print_arguments
......@@ -152,13 +151,13 @@ def main(args):
# 2. Define neural network
if args.models == "mobilenet-mobilenet":
models = [
MobileNetV1(class_dim=args.class_num),
MobileNetV1(class_dim=args.class_num)
MobileNetV1(num_classes=args.class_num),
MobileNetV1(num_classes=args.class_num)
]
elif args.models == "mobilenet-resnet50":
models = [
MobileNetV1(class_dim=args.class_num),
ResNet(class_dim=args.class_num)
MobileNetV1(num_classes=args.class_num),
resnet34(num_classes=args.class_num)
]
else:
logger.info("You can define the model as you wish")
......
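The DML demo now draws its networks from paddle.vision.models instead of the removed paddleslim.models. A minimal construction sketch (the class count and input shape are illustrative for CIFAR-100):

import paddle
from paddle.vision.models import MobileNetV1, resnet34

models = [
    MobileNetV1(num_classes=100),
    resnet34(num_classes=100),
]
x = paddle.rand([2, 3, 32, 32])
logits = [m(x) for m in models]  # each entry has shape [2, 100]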
......@@ -212,9 +212,10 @@ def compress(args):
format(epoch_id, step_id, val_loss[0], val_acc1[0],
val_acc5[0]))
if args.save_inference:
paddle.fluid.io.save_inference_model(
os.path.join("./saved_models", str(epoch_id)), ["image"],
[out], exe, student_program)
paddle.static.save_inference_model(
os.path.join("./saved_models", str(epoch_id)), [image], [out],
exe,
program=student_program)
_logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format(
epoch_id, np.mean(val_acc1s), np.mean(val_acc5s)))
......
......@@ -148,16 +148,16 @@ class SampleTester(unittest.TestCase):
batch_size=1,
batch_num=1,
skip_batch_num=0):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
inference_scope = paddle.static.global_scope()
with paddle.static.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
[inference_program, feed_target_names, fetch_targets
] = paddle.static.load_inference_model(model_path, exe)
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params')
[inference_program, feed_target_names,
fetch_targets] = paddle.static.load_inference_model(
model_path, exe, model_filename='model', params_filename='params')
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
......
......@@ -18,7 +18,6 @@ from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
......@@ -148,14 +147,14 @@ class MobileNetV3(nn.Layer):
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(name="last_1x1_conv_weights"),
weight_attr=paddle.ParamAttr(name="last_1x1_conv_weights"),
bias_attr=False)
self.out = Linear(
self.cls_ch_expand,
class_dim,
weight_attr=ParamAttr("fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
weight_attr=paddle.ParamAttr("fc_weights"),
bias_attr=paddle.ParamAttr(name="fc_offset"))
def forward(self, inputs, label=None):
x = self.conv1(inputs)
......@@ -197,14 +196,14 @@ class ConvBNLayer(nn.Layer):
stride=stride,
padding=padding,
groups=num_groups,
weight_attr=ParamAttr(name=name + "_weights"),
weight_attr=paddle.ParamAttr(name=name + "_weights"),
bias_attr=False)
self.bn = BatchNorm(
num_channels=out_c,
act=None,
param_attr=ParamAttr(
param_attr=paddle.ParamAttr(
name=name + "_bn_scale", regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(
bias_attr=paddle.ParamAttr(
name=name + "_bn_offset", regularizer=L2Decay(0.0)),
moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance")
......@@ -291,16 +290,16 @@ class SEModule(nn.Layer):
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(name=name + "_1_weights"),
bias_attr=ParamAttr(name=name + "_1_offset"))
weight_attr=paddle.ParamAttr(name=name + "_1_weights"),
bias_attr=paddle.ParamAttr(name=name + "_1_offset"))
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(name + "_2_weights"),
bias_attr=ParamAttr(name=name + "_2_offset"))
weight_attr=paddle.ParamAttr(name + "_2_weights"),
bias_attr=paddle.ParamAttr(name=name + "_2_offset"))
if skip_se_quant:
self.conv1.skip_quant = True
self.conv2.skip_quant = True
......
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.nn.initializer import KaimingUniform
__all__ = ['MobileNet']
......@@ -128,12 +128,12 @@ class MobileNet():
pool_type='avg',
global_pooling=True)
with fluid.name_scope('last_fc'):
output = fluid.layers.fc(input=input,
size=class_dim,
param_attr=ParamAttr(
initializer=MSRA(),
name="fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
output = paddle.static.nn.fc(
input,
class_dim,
weight_attr=paddle.ParamAttr(
initializer=KaimingUniform(), name="fc7_weights"),
bias_attr=paddle.ParamAttr(name="fc7_offset"))
return output
......@@ -148,7 +148,7 @@ class MobileNet():
act='relu',
use_cudnn=True,
name=None):
conv = fluid.layers.conv2d(
conv = paddle.static.nn.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
......@@ -157,15 +157,15 @@ class MobileNet():
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(
initializer=MSRA(), name=name + "_weights"),
param_attr=paddle.ParamAttr(
initializer=KaimingUniform(), name=name + "_weights"),
bias_attr=False)
bn_name = name + "_bn"
return fluid.layers.batch_norm(
return paddle.static.nn.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
param_attr=paddle.ParamAttr(name=bn_name + "_scale"),
bias_attr=paddle.ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
......
......@@ -15,9 +15,9 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.nn.initializer import KaimingUniform
__all__ = [
'MobileNetV2', 'MobileNetV2_x0_25',
......@@ -108,10 +108,11 @@ class MobileNetV2():
pool_type='avg',
global_pooling=True)
output = fluid.layers.fc(input=input,
size=class_dim,
param_attr=ParamAttr(name='fc10_weights'),
bias_attr=ParamAttr(name='fc10_offset'))
output = paddle.static.nn.fc(
input,
class_dim,
weight_attr=paddle.ParamAttr(name='fc10_weights'),
bias_attr=paddle.ParamAttr(name='fc10_offset'))
return output
def conv_bn_layer(self,
......@@ -125,7 +126,7 @@ class MobileNetV2():
if_act=True,
name=None,
use_cudnn=True):
conv = fluid.layers.conv2d(
conv = paddle.static.nn.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
......@@ -134,17 +135,17 @@ class MobileNetV2():
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + '_weights'),
param_attr=paddle.ParamAttr(name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
bn = paddle.static.nn.batch_norm(
input=conv,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
param_attr=paddle.ParamAttr(name=bn_name + "_scale"),
bias_attr=paddle.ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
return fluid.layers.relu6(bn)
return paddle.nn.functional.relu6(bn)
else:
return bn
......
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.nn.initializer import KaimingUniform
import math
__all__ = [
......@@ -105,20 +105,21 @@ class MobileNetV3():
name='conv_last')
conv = fluid.layers.pool2d(
input=conv, pool_type='avg', global_pooling=True, use_cudnn=False)
conv = fluid.layers.conv2d(
conv = paddle.static.nn.conv2d(
input=conv,
num_filters=cls_ch_expand,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name='last_1x1_conv_weights'),
param_attr=paddle.ParamAttr(name='last_1x1_conv_weights'),
bias_attr=False)
conv = fluid.layers.hard_swish(conv)
out = fluid.layers.fc(input=conv,
size=class_dim,
param_attr=ParamAttr(name='fc_weights'),
bias_attr=ParamAttr(name='fc_offset'))
out = paddle.static.nn.fc(
conv,
class_dim,
weight_attr=paddle.ParamAttr(name='fc_weights'),
bias_attr=paddle.ParamAttr(name='fc_offset'))
return out
def conv_bn_layer(self,
......@@ -132,7 +133,7 @@ class MobileNetV3():
act=None,
name=None,
use_cudnn=True):
conv = fluid.layers.conv2d(
conv = paddle.static.nn.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
......@@ -141,16 +142,16 @@ class MobileNetV3():
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + '_weights'),
param_attr=paddle.ParamAttr(name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
bn = paddle.static.nn.batch_norm(
input=conv,
param_attr=ParamAttr(
param_attr=paddle.ParamAttr(
name=bn_name + "_scale",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
bias_attr=ParamAttr(
bias_attr=paddle.ParamAttr(
name=bn_name + "_offset",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
......@@ -158,32 +159,32 @@ class MobileNetV3():
moving_variance_name=bn_name + '_variance')
if if_act:
if act == 'relu':
bn = fluid.layers.relu(bn)
bn = paddle.nn.functional.relu(bn)
elif act == 'hard_swish':
bn = fluid.layers.hard_swish(bn)
return bn
def hard_swish(self, x):
return x * fluid.layers.relu6(x + 3) / 6.
return x * paddle.nn.functional.relu6(x + 3) / 6.
def se_block(self, input, num_out_filter, ratio=4, name=None):
num_mid_filter = int(num_out_filter // ratio)
pool = fluid.layers.pool2d(
input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
conv1 = fluid.layers.conv2d(
conv1 = paddle.static.nn.conv2d(
input=pool,
filter_size=1,
num_filters=num_mid_filter,
act='relu',
param_attr=ParamAttr(name=name + '_1_weights'),
bias_attr=ParamAttr(name=name + '_1_offset'))
conv2 = fluid.layers.conv2d(
param_attr=paddle.ParamAttr(name=name + '_1_weights'),
bias_attr=paddle.ParamAttr(name=name + '_1_offset'))
conv2 = paddle.static.nn.conv2d(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
act='hard_sigmoid',
param_attr=ParamAttr(name=name + '_2_weights'),
bias_attr=ParamAttr(name=name + '_2_offset'))
param_attr=paddle.ParamAttr(name=name + '_2_weights'),
bias_attr=paddle.ParamAttr(name=name + '_2_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
return scale
......
......@@ -3,9 +3,7 @@ from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.nn.initializer import KaimingUniform
import os, sys, time, math
import numpy as np
from collections import namedtuple
......@@ -84,11 +82,12 @@ class PVANet():
conv5 = self._bn(conv5, 'relu', 'conv5_4_last_bn')
end_points['conv5'] = conv5
output = fluid.layers.fc(input=input,
size=class_dim,
param_attr=ParamAttr(
initializer=MSRA(), name="fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
output = paddle.static.nn.fc(
input,
class_dim,
weight_attr=paddle.ParamAttr(
initializer=KaimingUniform(), name="fc_weights"),
bias_attr=paddle.ParamAttr(name="fc_offset"))
return output
......@@ -193,7 +192,7 @@ class PVANet():
path_net = self._conv_bn_relu(path_net, pool_path_outputs, 1,
name + '_poolproj')
paths.append(path_net)
block_net = fluid.layers.concat(paths, axis=1)
block_net = paddle.concat(paths, axis=1)
block_net = self._conv(block_net, inception_outputs, 1,
name + '_out_conv')
......@@ -210,23 +209,23 @@ class PVANet():
prefix = name + '_'
scale_shape = input.shape[axis:axis + num_axes]
param_attr = fluid.ParamAttr(name=prefix + 'gamma')
scale_param = fluid.layers.create_parameter(
param_attr = paddle.ParamAttr(name=prefix + 'gamma')
scale_param = paddle.static.create_parameter(
shape=scale_shape,
dtype=input.dtype,
name=name,
attr=param_attr,
is_bias=True,
default_initializer=fluid.initializer.Constant(value=1.0))
default_initializer=paddle.nn.initializer.Constant(value=1.0))
offset_attr = fluid.ParamAttr(name=prefix + 'beta')
offset_param = fluid.layers.create_parameter(
offset_attr = paddle.ParamAttr(name=prefix + 'beta')
offset_param = paddle.static.create_parameter(
shape=scale_shape,
dtype=input.dtype,
name=name,
attr=offset_attr,
is_bias=True,
default_initializer=fluid.initializer.Constant(value=0.0))
default_initializer=paddle.nn.initializer.Constant(value=0.0))
output = fluid.layers.elementwise_mul(
input, scale_param, axis=axis, name=prefix + 'mul')
......@@ -242,7 +241,7 @@ class PVANet():
stride=1,
groups=1,
act=None):
net = fluid.layers.conv2d(
net = paddle.static.nn.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
......@@ -251,20 +250,20 @@ class PVANet():
groups=groups,
act=act,
use_cudnn=True,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=ParamAttr(name=name + '_bias'),
param_attr=paddle.ParamAttr(name=name + '_weights'),
bias_attr=paddle.ParamAttr(name=name + '_bias'),
name=name)
return net
def _bn(self, input, act, name):
net = fluid.layers.batch_norm(
net = paddle.static.nn.batch_norm(
input=input,
act=act,
name=name,
moving_mean_name=name + '_mean',
moving_variance_name=name + '_variance',
param_attr=ParamAttr(name=name + '_scale'),
bias_attr=ParamAttr(name=name + '_offset'))
param_attr=paddle.ParamAttr(name=name + '_scale'),
bias_attr=paddle.ParamAttr(name=name + '_offset'))
return net
def _bn_relu_conv(self,
......@@ -295,9 +294,9 @@ class PVANet():
def _bn_crelu(self, input, name):
net = self._bn(input, None, name + '_bn_1')
neg_net = fluid.layers.scale(net, scale=-1.0, name=name + '_neg')
net = fluid.layers.concat([net, neg_net], axis=1)
net = paddle.concat([net, neg_net], axis=1)
net = self._scale(net, name + '_scale')
net = fluid.layers.relu(net, name=name + '_relu')
net = paddle.nn.functional.relu(net, name=name + '_relu')
return net
def _conv_bn_crelu(self,
......@@ -335,15 +334,15 @@ class PVANet():
act='relu',
name=None):
"""Deconv bn layer."""
deconv = fluid.layers.conv2d_transpose(
deconv = paddle.static.nn.conv2d_transpose(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=ParamAttr(name=name + '_bias'),
param_attr=paddle.ParamAttr(name=name + '_weights'),
bias_attr=paddle.ParamAttr(name=name + '_bias'),
name=name + 'deconv')
return self._bn(deconv, act, name + '_bn')
......@@ -388,31 +387,31 @@ def Detector_Header(f_common, net, class_num):
f_geo = net.conv_bn_layer(f_common, 64, 1, name='geo_1')
f_geo = net.conv_bn_layer(f_geo, 64, 3, name='geo_2')
f_geo = net.conv_bn_layer(f_geo, 64, 1, name='geo_3')
f_geo = fluid.layers.conv2d(
f_geo = paddle.static.nn.conv2d(
f_geo,
8,
1,
use_cudnn=True,
param_attr=ParamAttr(name='geo_4_conv_weights'),
bias_attr=ParamAttr(name='geo_4_conv_bias'),
param_attr=paddle.ParamAttr(name='geo_4_conv_weights'),
bias_attr=paddle.ParamAttr(name='geo_4_conv_bias'),
name='geo_4_conv')
name = 'score_class_num' + str(class_num + 1)
f_score = net.conv_bn_layer(f_common, 64, 1, 'score_1')
f_score = net.conv_bn_layer(f_score, 64, 3, 'score_2')
f_score = net.conv_bn_layer(f_score, 64, 1, 'score_3')
f_score = fluid.layers.conv2d(
f_score = paddle.static.nn.conv2d(
f_score,
class_num + 1,
1,
use_cudnn=True,
param_attr=ParamAttr(name=name + '_conv_weights'),
bias_attr=ParamAttr(name=name + '_conv_bias'),
param_attr=paddle.ParamAttr(name=name + '_conv_weights'),
bias_attr=paddle.ParamAttr(name=name + '_conv_bias'),
name=name + '_conv')
f_score = fluid.layers.transpose(f_score, perm=[0, 2, 3, 1])
f_score = fluid.layers.reshape(f_score, shape=[-1, class_num + 1])
f_score = fluid.layers.softmax(input=f_score)
f_score = paddle.reshape(f_score, shape=[-1, class_num + 1])
f_score = paddle.nn.functional.softmax(f_score)
return f_score, f_geo
......@@ -440,7 +439,7 @@ def east(input, class_num=31):
j // 8,
name='fusion_' + str(len(blocks)) + '_2')
blocks.append(conv)
conv = fluid.layers.concat(blocks, axis=1)
conv = paddle.concat(blocks, axis=1)
f_score, f_geo = Detector_Header(conv, net, class_num)
return f_score, f_geo
......@@ -488,7 +487,7 @@ def loss(f_score, f_geo, l_score, l_geo, l_mask, class_num=1):
l_score.stop_gradient = True
l_score = fluid.layers.transpose(l_score, perm=[0, 2, 3, 1])
l_score.stop_gradient = True
l_score = fluid.layers.reshape(l_score, shape=[-1, 1])
l_score = paddle.reshape(l_score, shape=[-1, 1])
l_score.stop_gradient = True
l_score = fluid.layers.cast(x=l_score, dtype="int64")
l_score.stop_gradient = True
......
......@@ -4,7 +4,6 @@ from __future__ import print_function
import paddle
import paddle.fluid as fluid
import math
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
......@@ -79,13 +78,13 @@ class ResNet():
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
fc_name = fc_name if fc_name is None else prefix_name + fc_name
out = fluid.layers.fc(input=pool,
size=class_dim,
act='softmax',
name=fc_name,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(
-stdv, stdv)))
out = paddle.static.nn.fc(
pool,
class_dim,
activation='softmax',
name=fc_name,
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Uniform(-stdv, stdv)))
else:
for block in range(len(depth)):
for i in range(depth[block]):
......@@ -102,12 +101,12 @@ class ResNet():
input=conv, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
fc_name = fc_name if fc_name is None else prefix_name + fc_name
out = fluid.layers.fc(
input=pool,
size=class_dim,
out = paddle.static.nn.fc(
pool,
class_dim,
name=fc_name,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Uniform(-stdv, stdv)))
return out
......@@ -119,7 +118,7 @@ class ResNet():
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
conv = paddle.static.nn.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
......@@ -127,7 +126,7 @@ class ResNet():
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
param_attr=paddle.ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + '.conv2d.output.1')
if self.prefix_name == '':
......@@ -141,12 +140,12 @@ class ResNet():
else:
bn_name = name.split("_", 1)[0] + "_bn" + name.split("_",
1)[1][3:]
return fluid.layers.batch_norm(
return paddle.static.nn.batch_norm(
input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
param_attr=paddle.ParamAttr(name=bn_name + '_scale'),
bias_attr=paddle.ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
......
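The recurring migration pattern in these model files, condensed into a minimal self-contained sketch (tensor and attribute names are illustrative, not from this diff):

    import paddle

    paddle.enable_static()
    pool = paddle.static.data(name='pool', shape=[None, 2048], dtype='float32')
    attr = paddle.ParamAttr(
        name='fc_weights',
        initializer=paddle.nn.initializer.Uniform(-0.1, 0.1))
    # fluid.layers.fc(input=..., size=..., act=..., param_attr=...) becomes:
    out = paddle.static.nn.fc(pool, 1000, activation='softmax', weight_attr=attr)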
......@@ -20,7 +20,6 @@ import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = [
"ResNet", "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd",
......@@ -119,11 +118,11 @@ class ResNet():
input=conv, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
out = paddle.static.nn.fc(
pool,
class_dim,
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Uniform(-stdv, stdv)))
return out
......@@ -135,7 +134,7 @@ class ResNet():
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
conv = paddle.static.nn.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
......@@ -143,17 +142,17 @@ class ResNet():
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
param_attr=paddle.ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
return paddle.static.nn.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
param_attr=paddle.ParamAttr(name=bn_name + '_scale'),
bias_attr=paddle.ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
......@@ -173,7 +172,7 @@ class ResNet():
pool_type='avg',
ceil_mode=True)
conv = fluid.layers.conv2d(
conv = paddle.static.nn.conv2d(
input=pool,
num_filters=num_filters,
filter_size=filter_size,
......@@ -181,17 +180,17 @@ class ResNet():
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
param_attr=paddle.ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
return paddle.static.nn.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
param_attr=paddle.ParamAttr(name=bn_name + '_scale'),
bias_attr=paddle.ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
......
......@@ -18,8 +18,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.nn.initializer import KaimingUniform
class SlimFaceNet():
......@@ -143,7 +142,7 @@ class SlimFaceNet():
num_groups=out_c,
if_act=False,
name='global_dw_conv7x7')
x = fluid.layers.conv2d(
x = paddle.static.nn.conv2d(
x,
num_filters=128,
filter_size=1,
......@@ -152,30 +151,30 @@ class SlimFaceNet():
groups=1,
act=None,
use_cudnn=True,
param_attr=ParamAttr(
param_attr=paddle.ParamAttr(
name='linear_conv1x1_weights',
initializer=MSRA(),
initializer=KaimingUniform(),
regularizer=fluid.regularizer.L2Decay(4e-4)),
bias_attr=False)
bn_name = 'linear_conv1x1_bn'
x = fluid.layers.batch_norm(
x = paddle.static.nn.batch_norm(
x,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
param_attr=paddle.ParamAttr(name=bn_name + "_scale"),
bias_attr=paddle.ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
x = fluid.layers.reshape(x, shape=[x.shape[0], x.shape[1]])
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
if self.extract_feature:
return x
out = self.arc_margin_product(
x, label, self.class_dim, s=32.0, m=0.50, mode=2)
softmax = fluid.layers.softmax(input=out)
        softmax = paddle.nn.functional.softmax(out)
cost = fluid.layers.cross_entropy(input=softmax, label=label)
loss = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=out, label=label, k=1)
acc = paddle.static.accuracy(input=out, label=label, k=1)
return loss, acc
def residual_unit(self,
......@@ -235,26 +234,26 @@ class SlimFaceNet():
num_mid_filter = int(num_out_filter // ratio)
pool = fluid.layers.pool2d(
input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
conv1 = fluid.layers.conv2d(
conv1 = paddle.static.nn.conv2d(
input=pool,
filter_size=1,
num_filters=num_mid_filter,
act=None,
param_attr=ParamAttr(name=name + '_1_weights'),
bias_attr=ParamAttr(name=name + '_1_offset'))
param_attr=paddle.ParamAttr(name=name + '_1_weights'),
bias_attr=paddle.ParamAttr(name=name + '_1_offset'))
conv1 = fluid.layers.prelu(
conv1,
mode='channel',
param_attr=ParamAttr(
param_attr=paddle.ParamAttr(
name=name + '_prelu',
regularizer=fluid.regularizer.L2Decay(0.0)))
conv2 = fluid.layers.conv2d(
conv2 = paddle.static.nn.conv2d(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
act='hard_sigmoid',
param_attr=ParamAttr(name=name + '_2_weights'),
bias_attr=ParamAttr(name=name + '_2_offset'))
param_attr=paddle.ParamAttr(name=name + '_2_weights'),
bias_attr=paddle.ParamAttr(name=name + '_2_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
return scale
......@@ -268,7 +267,7 @@ class SlimFaceNet():
if_act=True,
name=None,
use_cudnn=True):
conv = fluid.layers.conv2d(
conv = paddle.static.nn.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
......@@ -277,48 +276,47 @@ class SlimFaceNet():
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(
name=name + '_weights', initializer=MSRA()),
param_attr=paddle.ParamAttr(
name=name + '_weights', initializer=KaimingUniform()),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
bn = paddle.static.nn.batch_norm(
input=conv,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
param_attr=paddle.ParamAttr(name=bn_name + "_scale"),
bias_attr=paddle.ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
return fluid.layers.prelu(
bn,
mode='channel',
param_attr=ParamAttr(
param_attr=paddle.ParamAttr(
name=name + '_prelu',
regularizer=fluid.regularizer.L2Decay(0.0)))
else:
return bn
def arc_margin_product(self, input, label, out_dim, s=32.0, m=0.50,
mode=2):
def arc_margin_product(self, input, label, out_dim, s=32.0, m=0.50, mode=2):
input_norm = fluid.layers.sqrt(
fluid.layers.reduce_sum(
fluid.layers.square(input), dim=1))
paddle.square(input), dim=1))
input = fluid.layers.elementwise_div(input, input_norm, axis=0)
weight = fluid.layers.create_parameter(
weight = paddle.static.create_parameter(
shape=[out_dim, input.shape[1]],
dtype='float32',
name='weight_norm',
attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Xavier(),
attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.XavierUniform(),
regularizer=fluid.regularizer.L2Decay(4e-4)))
weight_norm = fluid.layers.sqrt(
fluid.layers.reduce_sum(
fluid.layers.square(weight), dim=1))
paddle.square(weight), dim=1))
weight = fluid.layers.elementwise_div(weight, weight_norm, axis=0)
weight = fluid.layers.transpose(weight, perm=[1, 0])
cosine = fluid.layers.mul(input, weight)
sine = fluid.layers.sqrt(1.0 - fluid.layers.square(cosine))
sine = fluid.layers.sqrt(1.0 - paddle.square(cosine))
cos_m = math.cos(m)
sin_m = math.sin(m)
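Several fluid ops in arc_margin_product are untouched by this diff; their paddle 2.x counterparts would read roughly as follows (a sketch, not applied here):

    input_norm = paddle.sqrt(
        paddle.sum(paddle.square(input), axis=1, keepdim=True))
    input = input / input_norm                     # was elementwise_div(axis=0)
    weight = paddle.transpose(weight, perm=[1, 0])
    cosine = paddle.matmul(input, weight)          # was fluid.layers.mul
    sine = paddle.sqrt(paddle.clip(1.0 - paddle.square(cosine), min=0.0))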
......@@ -368,7 +366,8 @@ def SlimFaceNet_C_x0_75(class_dim=None, scale=0.6, arch=None):
if __name__ == "__main__":
paddle.enable_static()
x = fluid.data(name='x', shape=[-1, 3, 112, 112], dtype='float32')
x = paddle.static.data(name='x', shape=[-1, 3, 112, 112], dtype='float32')
print(x.shape)
model = SlimFaceNet(10000, arch=[1, 3, 3, 1, 1, 0, 0, 1, 0, 1, 1, 0, 5, 5, 3])
model = SlimFaceNet(
10000, arch=[1, 3, 3, 1, 1, 0, 0, 1, 0, 1, 1, 0, 5, 5, 3])
y = model.net(x)
......@@ -10,7 +10,6 @@ import paddle.nn as nn
import paddle.nn.functional as F
import paddle.vision.transforms as T
import paddle.static as static
from paddle import ParamAttr
from paddleslim.analysis import flops
from paddleslim.nas import SANAS
from paddleslim.common import get_logger
......@@ -38,14 +37,14 @@ def conv_bn_layer(input,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + '_weights'),
param_attr=paddle.ParamAttr(name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
return static.nn.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(name=bn_name + '_offset'),
param_attr=paddle.ParamAttr(name=bn_name + '_scale'),
bias_attr=paddle.ParamAttr(name=bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
......@@ -130,8 +129,8 @@ def search_mobilenetv2_block(config, args, image_size):
output = static.nn.fc(
x=data,
size=args.class_dim,
weight_attr=ParamAttr(name='mobilenetv2_fc_weights'),
bias_attr=ParamAttr(name='mobilenetv2_fc_offset'))
weight_attr=paddle.ParamAttr(name='mobilenetv2_fc_weights'),
bias_attr=paddle.ParamAttr(name='mobilenetv2_fc_offset'))
softmax_out = F.softmax(output)
cost = F.cross_entropy(softmax_out, label=label)
......
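One subtlety of this hunk: paddle.nn.functional.cross_entropy applies softmax to its input by default, unlike fluid.layers.cross_entropy, so feeding it softmax_out applies softmax twice. Two consistent variants (assuming paddle >= 2.1 for the use_softmax flag):

    import paddle.nn.functional as F

    # pass logits and let cross_entropy apply softmax internally ...
    cost = F.cross_entropy(output, label=label)
    # ... or declare the input to already be probabilities:
    cost = F.cross_entropy(F.softmax(output), label=label, use_softmax=False)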
import sys
sys.path.append('..')
import numpy as np
import argparse
import ast
import time
import logging
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.vision.transforms as T
from paddleslim.nas import RLNAS
from paddleslim.common import get_logger
from optimizer import create_optimizer
import imagenet_reader
_logger = get_logger(__name__, level=logging.INFO)
def build_program(main_program,
startup_program,
image_shape,
dataset,
archs,
args,
places,
is_test=False):
with static.program_guard(main_program, startup_program):
with paddle.utils.unique_name.guard():
data_shape = [None] + image_shape
data = static.data(name='data', shape=data_shape, dtype='float32')
label = static.data(name='label', shape=[None, 1], dtype='int64')
if args.data == 'cifar10':
paddle.assign(paddle.reshape(label, [-1, 1]), label)
if is_test:
data_loader = paddle.io.DataLoader(
dataset,
places=places,
feed_list=[data, label],
drop_last=False,
batch_size=args.batch_size,
return_list=False,
shuffle=False)
else:
data_loader = paddle.io.DataLoader(
dataset,
places=places,
feed_list=[data, label],
drop_last=True,
batch_size=args.batch_size,
return_list=False,
shuffle=True,
use_shared_memory=True,
num_workers=4)
output = archs(data)
output = static.nn.fc(output, size=args.class_dim)
softmax_out = F.softmax(output)
cost = F.cross_entropy(softmax_out, label=label)
avg_cost = paddle.mean(cost)
acc_top1 = paddle.metric.accuracy(
input=softmax_out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(
input=softmax_out, label=label, k=5)
if is_test == False:
optimizer = create_optimizer(args)
optimizer.minimize(avg_cost)
return data_loader, avg_cost, acc_top1, acc_top5
def search_mobilenetv2(config, args, image_size, is_server=True):
places = static.cuda_places() if args.use_gpu else static.cpu_places()
place = places[0]
if is_server:
### start a server and a client
rl_nas = RLNAS(
key='ddpg',
configs=config,
is_sync=False,
obs_dim=26, ### step + length_of_token
server_addr=(args.server_address, args.port))
else:
### start a client
rl_nas = RLNAS(
key='ddpg',
configs=config,
is_sync=False,
obs_dim=26,
server_addr=(args.server_address, args.port),
is_server=False)
image_shape = [3, image_size, image_size]
if args.data == 'cifar10':
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
mode='train', transform=transform, backend='cv2')
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', transform=transform, backend='cv2')
elif args.data == 'imagenet':
train_dataset = imagenet_reader.ImageNetDataset(mode='train')
val_dataset = imagenet_reader.ImageNetDataset(mode='val')
for step in range(args.search_steps):
if step == 0:
action_prev = [1. for _ in rl_nas.range_tables]
else:
action_prev = rl_nas.tokens[0]
obs = [step]
obs.extend(action_prev)
archs = rl_nas.next_archs(obs=obs)[0][0]
train_program = static.Program()
test_program = static.Program()
startup_program = static.Program()
train_loader, avg_cost, acc_top1, acc_top5 = build_program(
train_program, startup_program, image_shape, train_dataset, archs,
args, places)
test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program(
test_program,
startup_program,
image_shape,
val_dataset,
archs,
args,
place,
is_test=True)
test_program = test_program.clone(for_test=True)
exe = static.Executor(place)
exe.run(startup_program)
build_strategy = static.BuildStrategy()
train_compiled_program = static.CompiledProgram(
train_program).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy)
for epoch_id in range(args.retain_epoch):
for batch_id, data in enumerate(train_loader()):
fetches = [avg_cost.name]
s_time = time.time()
outs = exe.run(train_compiled_program,
feed=data,
fetch_list=fetches)[0]
batch_time = time.time() - s_time
if batch_id % 10 == 0:
_logger.info(
'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
format(step, epoch_id, batch_id, outs[0], batch_time))
reward = []
for batch_id, data in enumerate(test_loader()):
test_fetches = [
test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
]
batch_reward = exe.run(test_program,
feed=data,
fetch_list=test_fetches)
reward_avg = np.mean(np.array(batch_reward), axis=1)
reward.append(reward_avg)
_logger.info(
'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.
format(step, batch_id, batch_reward[0], batch_reward[1],
batch_reward[2]))
finally_reward = np.mean(np.array(reward), axis=0)
_logger.info(
'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
finally_reward[0], finally_reward[1], finally_reward[2]))
obs = np.expand_dims(obs, axis=0).astype('float32')
actions = rl_nas.tokens
obs_next = [step + 1]
obs_next.extend(actions[0])
obs_next = np.expand_dims(obs_next, axis=0).astype('float32')
if step == args.search_steps - 1:
            terminal = np.expand_dims([True], axis=0).astype(bool)
        else:
            terminal = np.expand_dims([False], axis=0).astype(bool)
rl_nas.reward(
np.expand_dims(
np.float32(finally_reward[1]), axis=0),
obs=obs,
actions=actions.astype('float32'),
obs_next=obs_next,
terminal=terminal)
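        # NOTE: exits after three search steps, presumably a smoke-test
        # shortcut; remove this to honor --search_steps.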
if step == 2:
sys.exit(0)
if __name__ == '__main__':
paddle.enable_static()
parser = argparse.ArgumentParser(
        description='RL NAS MobileNetV2 cifar10 argparse')
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='Whether to use GPU in train/test model.')
parser.add_argument(
'--batch_size', type=int, default=256, help='batch size.')
parser.add_argument(
'--class_dim', type=int, default=10, help='classify number.')
parser.add_argument(
'--data',
type=str,
default='cifar10',
choices=['cifar10', 'imagenet'],
        help='dataset name.')
parser.add_argument(
'--is_server',
type=ast.literal_eval,
default=True,
help='Whether to start a server.')
parser.add_argument(
'--search_steps',
type=int,
default=100,
        help='number of search steps.')
parser.add_argument(
'--server_address', type=str, default="", help='server ip.')
parser.add_argument('--port', type=int, default=8881, help='server port')
parser.add_argument(
'--retain_epoch', type=int, default=5, help='epoch for each token.')
parser.add_argument('--lr', type=float, default=0.1, help='learning rate.')
args = parser.parse_args()
print(args)
if args.data == 'cifar10':
image_size = 32
block_num = 3
elif args.data == 'imagenet':
image_size = 224
block_num = 6
else:
raise NotImplementedError(
'data must in [cifar10, imagenet], but received: {}'.format(
args.data))
config = [('MobileNetV2Space')]
search_mobilenetv2(config, args, image_size, is_server=args.is_server)
......@@ -12,7 +12,6 @@ import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.vision.transforms as T
from paddle import ParamAttr
from paddleslim.analysis import flops
from paddleslim.nas import SANAS
from paddleslim.common import get_logger
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import ast
import numpy as np
from PIL import Image
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.optimizer import AdamOptimizer
from paddle.nn import Conv2D
from paddle.fluid.dygraph.nn import Pool2D, Linear
from paddle.fluid.dygraph.base import to_variable
from paddleslim.nas.one_shot import SuperMnasnet
from paddleslim.nas.one_shot import OneShotSearch
def parse_args():
parser = argparse.ArgumentParser("Training for Mnist.")
parser.add_argument(
"--use_data_parallel",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use data parallel mode to train the model."
)
parser.add_argument("-e", "--epoch", default=5, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce")
args = parser.parse_args()
return args
class SimpleImgConv(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
conv_stride=1,
conv_padding=0,
conv_dilation=1,
conv_groups=1,
act=None,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConv, self).__init__()
        self._conv2d = Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=conv_stride,
            padding=conv_padding,
            dilation=conv_dilation,
            groups=conv_groups,
            weight_attr=param_attr,
            bias_attr=bias_attr)
        # paddle.nn.Conv2D has no act/use_cudnn arguments; apply the
        # activation in forward() instead.
        self._act = act
def forward(self, inputs):
        x = self._conv2d(inputs)
        if self._act == "relu":
            x = paddle.nn.functional.relu(x)
        return x
class MNIST(fluid.dygraph.Layer):
def __init__(self):
super(MNIST, self).__init__()
self._simple_img_conv_pool_1 = SimpleImgConv(1, 20, 2, act="relu")
self.arch = SuperMnasnet(
name_scope="super_net", input_channels=20, out_channels=20)
self._simple_img_conv_pool_2 = SimpleImgConv(20, 50, 2, act="relu")
self.pool_2_shape = 50 * 13 * 13
SIZE = 10
scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
self._fc = Linear(
self.pool_2_shape,
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs, label=None, tokens=None):
x = self._simple_img_conv_pool_1(inputs)
        x = self.arch(x, tokens=tokens)
x = self._simple_img_conv_pool_2(x)
x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
x = self._fc(x)
if label is not None:
acc = fluid.layers.accuracy(input=x, label=label)
return x, acc
else:
return x
def test_mnist(model, tokens=None):
acc_set = []
avg_loss_set = []
batch_size = 64
test_reader = paddle.fluid.io.batch(
paddle.dataset.mnist.test(), batch_size=batch_size, drop_last=True)
for batch_id, data in enumerate(test_reader()):
dy_x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(batch_size, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True
prediction, acc = model.forward(img, label, tokens=tokens)
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
acc_set.append(float(acc.numpy()))
avg_loss_set.append(float(avg_loss.numpy()))
if batch_id % 100 == 0:
print("Test - batch_id: {}".format(batch_id))
# get test acc and loss
acc_val_mean = np.array(acc_set).mean()
avg_loss_val_mean = np.array(avg_loss_set).mean()
return acc_val_mean
def train_mnist(args, model, tokens=None):
epoch_num = args.epoch
BATCH_SIZE = 64
adam = AdamOptimizer(learning_rate=0.001, parameter_list=model.parameters())
train_reader = paddle.fluid.io.batch(
paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True
cost, acc = model.forward(img, label, tokens=tokens)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
if args.use_data_parallel:
avg_loss = model.scale_loss(avg_loss)
avg_loss.backward()
model.apply_collective_grads()
else:
avg_loss.backward()
adam.minimize(avg_loss)
# save checkpoint
model.clear_gradients()
if batch_id % 1 == 0:
print("Loss at epoch {} step {}: {:}".format(epoch, batch_id,
avg_loss.numpy()))
model.eval()
test_acc = test_mnist(model, tokens=tokens)
model.train()
print("Loss at epoch {} , acc is: {}".format(epoch, test_acc))
save_parameters = (not args.use_data_parallel) or (
args.use_data_parallel and fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
fluid.save_dygraph(model.state_dict(), "save_temp")
print("checkpoint saved")
if __name__ == '__main__':
args = parse_args()
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
model = MNIST()
# step 1: training super net
#train_mnist(args, model)
# step 2: search
best_tokens = OneShotSearch(model, test_mnist)
# step 3: final training
# train_mnist(args, model, best_tokens)
......@@ -438,14 +438,10 @@ def compress(args):
_logger.info("final acc:{}".format(final_acc1))
# 4. Save inference model
paddle.fluid.io.save_inference_model(
dirname=model_path,
feeded_var_names=[image.name],
target_vars=[out],
executor=exe,
main_program=float_program,
model_filename=model_path + '/model.pdmodel',
params_filename=model_path + '/model.pdiparams')
paddle.static.save_inference_model(
os.path.join(model_path, 'model'), [image], [out],
exe,
program=float_program)
def main():
......
......@@ -309,14 +309,12 @@ def compress(args):
if not os.path.isdir(model_path):
os.makedirs(model_path)
paddle.fluid.io.save_inference_model(
dirname=float_path,
feeded_var_names=[image.name],
target_vars=[out],
executor=exe,
main_program=float_program,
model_filename=float_path + '/model',
params_filename=float_path + '/params')
paddle.static.save_inference_model(
os.path.join(float_path, "model"),
[image],
[out],
exe,
        program=float_program)
def main():
......
......@@ -40,7 +40,7 @@ def eval(args):
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
exe = paddle.static.Executor(place)
val_program, feed_target_names, fetch_targets = paddle.fluid.io.load_inference_model(
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
args.model_path,
exe,
model_filename=args.model_name,
......
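The 2.x static API pairs a path prefix with feed/fetch variables instead of names and filenames; a minimal sketch of the save/load round trip assumed by these hunks:

    import paddle

    paddle.enable_static()
    image = paddle.static.data(name='image', shape=[None, 1, 28, 28], dtype='float32')
    out = paddle.static.nn.fc(paddle.flatten(image, 1), 10)
    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())
    # save: writes output/model.pdmodel and output/model.pdiparams
    paddle.static.save_inference_model(
        'output/model', [image], [out], exe,
        program=paddle.static.default_main_program())
    # load: model_filename/params_filename are only needed for legacy layouts
    prog, feed_names, fetch_targets = paddle.static.load_inference_model(
        'output/model', exe)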
......@@ -78,7 +78,7 @@ class CASIA_Face(object):
if __name__ == '__main__':
data_dir = 'PATH to CASIA dataset'
place = fluid.CPUPlace()
place = paddle.CPUPlace()
with fluid.dygraph.guard(place):
dataset = CASIA_Face(root=data_dir)
print(len(dataset))
......
......@@ -88,8 +88,8 @@ def evaluation_10_fold(root='result.mat'):
flags = np.squeeze(flags)
mu = np.mean(
np.concatenate(
(featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0)
np.concatenate((featureLs[valFold[0], :], featureRs[valFold[0], :]),
0), 0)
mu = np.expand_dims(mu, 0)
featureLs = featureLs - mu
featureRs = featureRs - mu
......@@ -145,8 +145,7 @@ if __name__ == "__main__":
'--use_gpu', default=0, type=int, help='Use GPU or not, 0 is not used')
parser.add_argument(
'--test_data_dir', default='./lfw', type=str, help='lfw_data_dir')
parser.add_argument(
'--resume', default='output/0', type=str, help='resume')
parser.add_argument('--resume', default='output/0', type=str, help='resume')
parser.add_argument(
'--feature_save_dir',
default='result.mat',
......@@ -154,7 +153,7 @@ if __name__ == "__main__":
help='The path of the extract features save, must be .mat file')
args = parser.parse_args()
place = fluid.CPUPlace() if args.use_gpu == 0 else fluid.CUDAPlace(0)
place = paddle.CPUPlace() if args.use_gpu == 0 else paddle.CUDAPlace(0)
with fluid.dygraph.guard(place):
train_dataset = CASIA_Face(root=args.train_data_dir)
nl, nr, flods, flags = parse_filelist(args.test_data_dir)
......
......@@ -31,6 +31,7 @@ from paddleslim.quant import quant_post_static
paddle.enable_static()
def now():
return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
......@@ -142,12 +143,10 @@ def train(exe, train_program, train_out, test_program, test_out, args):
best_ave = temp_ave
print('Best AVE: {}'.format(best_ave))
out_feature, test_reader, flods, flags = test_out
fluid.io.save_inference_model(
executor=exe,
dirname='./out_inference',
feeded_var_names=['image_test'],
target_vars=[out_feature],
main_program=test_program)
    paddle.static.save_inference_model(
        './out_inference',
        # feed_vars expects Variable objects rather than name strings
        [test_program.global_block().var('image_test')],
        [out_feature],
        exe,
        program=test_program)
def build_program(program, startup, args, is_train=True):
......@@ -155,21 +154,23 @@ def build_program(program, startup, args, is_train=True):
num_trainers = fluid.core.get_cuda_device_count()
else:
num_trainers = int(os.environ.get('CPU_NUM', 1))
places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace()
places = fluid.cuda_places() if args.use_gpu else paddle.CPUPlace()
train_dataset = CASIA_Face(root=args.train_data_dir)
trainset_scale = len(train_dataset)
with fluid.program_guard(main_program=program, startup_program=startup):
with paddle.static.program_guard(
main_program=program, startup_program=startup):
with fluid.unique_name.guard():
# Model construction
model = models.__dict__[args.model](
class_dim=train_dataset.class_nums)
if is_train:
image = fluid.data(
image = paddle.static.data(
name='image', shape=[-1, 3, 112, 96], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64')
train_reader = fluid.io.batch(
train_dataset.reader,
batch_size=args.train_batchsize // num_trainers,
......@@ -195,24 +196,16 @@ def build_program(program, startup, args, is_train=True):
test_dataset.reader,
batch_size=args.test_batchsize,
drop_last=False)
image_test = fluid.data(
image_test = paddle.static.data(
name='image_test', shape=[-1, 3, 112, 96], dtype='float32')
image_test1 = fluid.data(
name='image_test1',
shape=[-1, 3, 112, 96],
dtype='float32')
image_test2 = fluid.data(
name='image_test2',
shape=[-1, 3, 112, 96],
dtype='float32')
image_test3 = fluid.data(
name='image_test3',
shape=[-1, 3, 112, 96],
dtype='float32')
image_test4 = fluid.data(
name='image_test4',
shape=[-1, 3, 112, 96],
dtype='float32')
image_test1 = paddle.static.data(
name='image_test1', shape=[-1, 3, 112, 96], dtype='float32')
image_test2 = paddle.static.data(
name='image_test2', shape=[-1, 3, 112, 96], dtype='float32')
image_test3 = paddle.static.data(
name='image_test3', shape=[-1, 3, 112, 96], dtype='float32')
image_test4 = paddle.static.data(
name='image_test4', shape=[-1, 3, 112, 96], dtype='float32')
reader = fluid.io.DataLoader.from_generator(
feed_list=[
image_test1, image_test2, image_test3, image_test4
......@@ -223,7 +216,7 @@ def build_program(program, startup, args, is_train=True):
reader.set_sample_list_generator(
test_reader,
places=fluid.cuda_places()
if args.use_gpu else fluid.CPUPlace())
if args.use_gpu else paddle.CPUPlace())
model.extract_feature = True
feature = model.net(image_test)
......@@ -317,16 +310,16 @@ def main():
f.writelines('num_trainers: {}'.format(num_trainers) + '\n')
if args.action == 'train':
train_program = fluid.Program()
test_program = fluid.Program()
startup_program = fluid.Program()
train_program = paddle.static.Program()
test_program = paddle.static.Program()
startup_program = paddle.static.Program()
if args.action == 'train':
train_out = build_program(train_program, startup_program, args, True)
test_out = build_program(test_program, startup_program, args, False)
test_program = test_program.clone(for_test=True)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_program)
if args.action == 'train':
......@@ -345,7 +338,7 @@ def main():
batch_nums=np.random.randint(4, 10))
elif args.action == 'test':
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
fetch_targets] = paddle.static.load_inference_model(
                 './quant_model/',
model_filename=None,
params_filename=None,
......@@ -356,15 +349,15 @@ def main():
test_dataset.reader,
batch_size=args.test_batchsize,
drop_last=False)
image_test = fluid.data(
image_test = paddle.static.data(
name='image_test', shape=[-1, 3, 112, 96], dtype='float32')
image_test1 = fluid.data(
image_test1 = paddle.static.data(
name='image_test1', shape=[-1, 3, 112, 96], dtype='float32')
image_test2 = fluid.data(
image_test2 = paddle.static.data(
name='image_test2', shape=[-1, 3, 112, 96], dtype='float32')
image_test3 = fluid.data(
image_test3 = paddle.static.data(
name='image_test3', shape=[-1, 3, 112, 96], dtype='float32')
image_test4 = fluid.data(
image_test4 = paddle.static.data(
name='image_test4', shape=[-1, 3, 112, 96], dtype='float32')
reader = fluid.io.DataLoader.from_generator(
feed_list=[image_test1, image_test2, image_test3, image_test4],
......@@ -373,7 +366,7 @@ def main():
return_list=False)
reader.set_sample_list_generator(
test_reader,
places=fluid.cuda_places() if args.use_gpu else fluid.CPUPlace())
places=fluid.cuda_places() if args.use_gpu else paddle.CPUPlace())
test_out = (fetch_targets, reader, flods, flags)
print('fetch_targets[0]: ', fetch_targets[0])
print('feed_target_names: ', feed_target_names)
......
......@@ -27,9 +27,7 @@ import hashlib
import tarfile
import zipfile
import logging
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import Program
from paddle.static import Program
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
......@@ -79,90 +77,6 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
**kwargs)
def save_persistable_nodes(executor, dirname, graph):
"""
Save persistable nodes to the given directory by the executor.
Args:
executor(Executor): The executor to run for saving node values.
dirname(str): The directory path.
graph(IrGraph): All the required persistable nodes in the graph will be saved.
"""
persistable_node_names = set()
persistable_nodes = []
all_persistable_nodes = graph.all_persistable_nodes()
for node in all_persistable_nodes:
name = node.name()
if name not in persistable_node_names:
persistable_node_names.add(name)
persistable_nodes.append(node)
program = Program()
var_list = []
for node in persistable_nodes:
var_desc = node.var()
if var_desc.type() == core.VarDesc.VarType.RAW or \
var_desc.type() == core.VarDesc.VarType.READER:
continue
var = program.global_block().create_var(
name=var_desc.name(),
shape=var_desc.shape(),
dtype=var_desc.dtype(),
type=var_desc.type(),
lod_level=var_desc.lod_level(),
persistable=var_desc.persistable())
var_list.append(var)
fluid.io.save_vars(executor=executor, dirname=dirname, vars=var_list)
def load_persistable_nodes(executor, dirname, graph):
"""
Load persistable node values from the given directory by the executor.
Args:
executor(Executor): The executor to run for loading node values.
dirname(str): The directory path.
graph(IrGraph): All the required persistable nodes in the graph will be loaded.
"""
persistable_node_names = set()
persistable_nodes = []
all_persistable_nodes = graph.all_persistable_nodes()
for node in all_persistable_nodes:
name = node.name()
if name not in persistable_node_names:
persistable_node_names.add(name)
persistable_nodes.append(node)
program = Program()
var_list = []
def _exist(var):
return os.path.exists(os.path.join(dirname, var.name))
def _load_var(name, scope):
return np.array(scope.find_var(name).get_tensor())
def _store_var(name, array, scope, place):
tensor = scope.find_var(name).get_tensor()
tensor.set(array, place)
for node in persistable_nodes:
var_desc = node.var()
if var_desc.type() == core.VarDesc.VarType.RAW or \
var_desc.type() == core.VarDesc.VarType.READER:
continue
var = program.global_block().create_var(
name=var_desc.name(),
shape=var_desc.shape(),
dtype=var_desc.dtype(),
type=var_desc.type(),
lod_level=var_desc.lod_level(),
persistable=var_desc.persistable())
if _exist(var):
var_list.append(var)
else:
_logger.info("Cannot find the var %s!!!" % (node.name()))
fluid.io.load_vars(executor=executor, dirname=dirname, vars=var_list)
def _download(url, path, md5sum=None):
"""
Download from url, save to path.
......
......@@ -15,11 +15,11 @@ if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
paddle.enable_static()
place = fluid.CPUPlace()
place = paddle.CPUPlace()
exe = paddle.static.Executor(paddle.CPUPlace())
[inference_program, feed_target_names,
fetch_targets] = paddle.fluid.io.load_inference_model(
fetch_targets] = paddle.static.load_inference_model(
         args.model_dir,
         exe,
model_filename=args.model_filename,
......
......@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import absolute_import
from paddleslim import models
from paddleslim import prune
from paddleslim import nas
from paddleslim import analysis
......@@ -22,8 +21,13 @@ from paddleslim import quant
from paddleslim import dygraph
from paddleslim import auto_compression
__all__ = [
'models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'dygraph',
'auto_compression'
'prune',
'nas',
'analysis',
'dist',
'quant',
'dygraph',
'auto_compression',
]
from paddleslim.dygraph import *
......
......@@ -119,7 +119,7 @@ def save_cls_model(model, input_shape, save_dir, data_type):
if data_type == 'int8':
paddle.enable_static()
exe = paddle.fluid.Executor(paddle.fluid.CPUPlace())
exe = paddle.static.Executor(paddle.CPUPlace())
save_dir = os.path.dirname(model_file)
quantize_model_path = os.path.join(save_dir, 'int8model')
......
......@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import Program
from ..core import GraphWrapper, OpWrapper
import paddle
__all__ = ["LatencyEvaluator", "TableLatencyEvaluator"]
......@@ -288,7 +288,7 @@ class TableLatencyEvaluator(LatencyEvaluator):
latency(float): The latency of given graph on current evaluator.
"""
total_latency = 0
if isinstance(graph, Program):
if isinstance(graph, paddle.static.Program):
graph = GraphWrapper(graph)
assert isinstance(graph, GraphWrapper)
for op in self._get_ops_from_graph(graph, only_conv):
......
......@@ -198,8 +198,7 @@ class TableLatencyPredictor(LatencyPredictor):
paddle.enable_static()
with open(pbmodel_file, "rb") as f:
fluid_program = paddle.fluid.framework.Program.parse_from_string(
f.read())
fluid_program = paddle.static.Program.parse_from_string(f.read())
graph = GraphWrapper(fluid_program)
......
......@@ -2,7 +2,6 @@ import os
import time
import numpy as np
import paddle
import paddle.static as static
from ...prune import Pruner
from ...core import GraphWrapper
from ...common.load_model import load_inference_model
......@@ -34,8 +33,8 @@ def get_sparse_model(executor, places, model_file, param_file, ratio,
else:
param_name = param_file.split('/')[-1]
main_prog = static.Program()
startup_prog = static.Program()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
executor.run(startup_prog)
inference_program, feed_target_names, fetch_targets = load_inference_model(
......@@ -90,7 +89,7 @@ def get_sparse_model(executor, places, model_file, param_file, ratio,
model_name = '.'.join(model_name.split('.')
[:-1]) if model_name is not None else 'model'
save_path = os.path.join(save_path, model_name)
static.save_inference_model(
paddle.static.save_inference_model(
save_path,
feed_vars=feed_vars,
fetch_vars=fetch_targets,
......@@ -124,9 +123,9 @@ def get_prune_model(executor, places, model_file, param_file, ratio, save_path):
else:
param_name = param_file.split('/')[-1]
main_prog = static.Program()
startup_prog = static.Program()
scope = static.global_scope()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
scope = paddle.static.global_scope()
executor.run(startup_prog)
inference_program, feed_target_names, fetch_targets = load_inference_model(
......@@ -166,7 +165,7 @@ def get_prune_model(executor, places, model_file, param_file, ratio, save_path):
model_name = '.'.join(model_name.split('.')
[:-1]) if model_name is not None else 'model'
save_path = os.path.join(save_path, model_name)
static.save_inference_model(
paddle.static.save_inference_model(
save_path,
feed_vars=feed_vars,
fetch_vars=fetch_targets,
......
......@@ -15,7 +15,6 @@
import os
import types
import paddle
import paddle.fluid as fluid
import numpy as np
from collections import defaultdict
import matplotlib
......
......@@ -17,7 +17,6 @@ import copy
import math
import numpy as np
import paddle
import paddle.fluid as fluid
__all__ = ['EvolutionaryController', 'RLBaseController']
......@@ -64,12 +63,6 @@ class RLBaseController(object):
def update(self, *args, **kwargs):
raise NotImplementedError('Abstract method.')
def save_controller(self, program, output_dir):
fluid.save(program, output_dir)
def load_controller(self, program, load_dir):
fluid.load(program, load_dir)
def get_params(self, program):
var_dict = {}
for var in program.global_block().all_parameters():
......
......@@ -13,9 +13,7 @@
# limitations under the License.
import six
from paddle.fluid.framework import Parameter
from paddle.fluid import unique_name
from paddle.fluid import core
import paddle
from ..core import GraphWrapper
__all__ = ['recover_inference_program']
......@@ -43,10 +41,11 @@ def _recover_reserve_space_with_bn(program):
if "ReserveSpace" not in op.output_names or len(
op.output("ReserveSpace")) == 0:
reserve_space = block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
name=paddle.fluid.unique_name.
generate_with_ignorable_key(".".join(
["reserve_space", 'tmp'])),
dtype=block.var(op.input("X")[0]).dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
type=paddle.fluid.core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
op.desc.set_output("ReserveSpace", [reserve_space.name])
......@@ -59,7 +58,7 @@ def _recover_param_attr(program):
all_weights = [param for param in program.list_vars() \
if param.persistable is True and param.name != 'feed' and param.name != 'fetch']
for w in all_weights:
new_w = Parameter(
new_w = paddle.fluid.framework.Parameter(
block=program.block(0),
shape=w.shape,
dtype=w.dtype,
......
......@@ -15,10 +15,6 @@
import logging
from ..log_helper import get_logger
_logger = get_logger(__name__, level=logging.INFO)
try:
from .ddpg import *
except ImportError as e:
pass
from .lstm import *
from .utils import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
try:
from .ddpg_controller import *
except Exception as e:
logging.warning(e)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl
from parl import layers
import paddle
from paddle import fluid
from ..utils import RLCONTROLLER, action_mapping
from ...controller import RLBaseController
from .ddpg_model import DefaultDDPGModel as default_ddpg_model
from .noise import AdaptiveNoiseSpec as default_noise
from parl.utils import ReplayMemory
__all__ = ['DDPG']
class DDPGAgent(parl.Agent):
def __init__(self, algorithm, obs_dim, act_dim):
assert isinstance(obs_dim, int)
assert isinstance(act_dim, int)
self.obs_dim = obs_dim
self.act_dim = act_dim
super(DDPGAgent, self).__init__(algorithm)
# Attention: In the beginning, sync target model totally.
self.alg.sync_target(decay=0)
def build_program(self):
self.pred_program = paddle.static.Program()
self.learn_program = paddle.static.Program()
with paddle.static.program_guard(self.pred_program):
obs = fluid.data(
name='obs', shape=[None, self.obs_dim], dtype='float32')
self.pred_act = self.alg.predict(obs)
with paddle.static.program_guard(self.learn_program):
obs = fluid.data(
name='obs', shape=[None, self.obs_dim], dtype='float32')
act = fluid.data(
name='act', shape=[None, self.act_dim], dtype='float32')
reward = fluid.data(name='reward', shape=[None], dtype='float32')
next_obs = fluid.data(
name='next_obs', shape=[None, self.obs_dim], dtype='float32')
terminal = fluid.data(
name='terminal', shape=[None, 1], dtype='bool')
_, self.critic_cost = self.alg.learn(obs, act, reward, next_obs,
terminal)
def predict(self, obs):
act = self.fluid_executor.run(self.pred_program,
feed={'obs': obs},
fetch_list=[self.pred_act])[0]
return act
def learn(self, obs, act, reward, next_obs, terminal):
feed = {
'obs': obs,
'act': act,
'reward': reward,
'next_obs': next_obs,
'terminal': terminal
}
critic_cost = self.fluid_executor.run(self.learn_program,
feed=feed,
fetch_list=[self.critic_cost])[0]
self.alg.sync_target()
return critic_cost
@RLCONTROLLER.register
class DDPG(RLBaseController):
def __init__(self, range_tables, use_gpu=False, **kwargs):
self.use_gpu = use_gpu
self.range_tables = range_tables - np.asarray(1)
self.act_dim = len(self.range_tables)
self.obs_dim = kwargs.get('obs_dim')
self.model = kwargs.get(
'model') if 'model' in kwargs else default_ddpg_model
self.actor_lr = kwargs.get('actor_lr') if 'actor_lr' in kwargs else 1e-4
self.critic_lr = kwargs.get(
'critic_lr') if 'critic_lr' in kwargs else 1e-3
self.gamma = kwargs.get('gamma') if 'gamma' in kwargs else 0.99
self.tau = kwargs.get('tau') if 'tau' in kwargs else 0.001
self.memory_size = kwargs.get(
'memory_size') if 'memory_size' in kwargs else 10
self.reward_scale = kwargs.get(
'reward_scale') if 'reward_scale' in kwargs else 0.1
self.batch_size = kwargs.get(
'controller_batch_size') if 'controller_batch_size' in kwargs else 1
self.actions_noise = kwargs.get(
'actions_noise') if 'actions_noise' in kwargs else default_noise
self.action_dist = 0.0
self.place = paddle.CUDAPlace(0) if self.use_gpu else paddle.CPUPlace()
model = self.model(self.act_dim)
if self.actions_noise:
self.actions_noise = self.actions_noise()
algorithm = parl.algorithms.DDPG(
model,
gamma=self.gamma,
tau=self.tau,
actor_lr=self.actor_lr,
critic_lr=self.critic_lr)
self.agent = DDPGAgent(algorithm, self.obs_dim, self.act_dim)
self.rpm = ReplayMemory(self.memory_size, self.obs_dim, self.act_dim)
self.pred_program = self.agent.pred_program
self.learn_program = self.agent.learn_program
self.param_dict = self.get_params(self.learn_program)
def next_tokens(self, obs, params_dict, is_inference=False):
batch_obs = np.expand_dims(obs, axis=0)
self.set_params(self.pred_program, params_dict, self.place)
actions = self.agent.predict(batch_obs.astype('float32'))
### add noise to action
if self.actions_noise and is_inference == False:
actions_noise = np.clip(
np.random.normal(
actions, scale=self.actions_noise.stdev_curr),
-1.0,
1.0)
self.action_dist = np.mean(np.abs(actions_noise - actions))
else:
actions_noise = actions
actions_noise = action_mapping(actions_noise, self.range_tables)
return actions_noise
def _update_noise(self, actions_dist):
self.actions_noise.update(actions_dist)
def update(self, rewards, params_dict, obs, actions, obs_next, terminal):
self.set_params(self.learn_program, params_dict, self.place)
self.rpm.append(obs, actions, self.reward_scale * rewards, obs_next,
terminal)
if self.actions_noise:
self._update_noise(self.action_dist)
if self.rpm.size() > self.memory_size:
            obs, actions, rewards, obs_next, terminal = self.rpm.sample_batch(
self.batch_size)
self.agent.learn(obs, actions, rewards, obs_next, terminal)
params_dict = self.get_params(self.learn_program)
return params_dict
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl
from parl import layers
class DefaultDDPGModel(parl.Model):
def __init__(self, act_dim):
self.actor_model = ActorModel(act_dim)
self.critic_model = CriticModel()
def policy(self, obs):
return self.actor_model.policy(obs)
def value(self, obs, act):
return self.critic_model.value(obs, act)
def get_actor_params(self):
return self.actor_model.parameters()
class ActorModel(parl.Model):
def __init__(self, act_dim):
hid1_size = 400
hid2_size = 300
self.fc1 = layers.fc(size=hid1_size, act='relu')
self.fc2 = layers.fc(size=hid2_size, act='relu')
self.fc3 = layers.fc(size=act_dim, act='tanh')
def policy(self, obs):
hid1 = self.fc1(obs)
hid2 = self.fc2(hid1)
means = self.fc3(hid2)
return means
class CriticModel(parl.Model):
def __init__(self):
hid1_size = 400
hid2_size = 300
self.fc1 = layers.fc(size=hid1_size, act='relu')
self.fc2 = layers.fc(size=hid2_size, act='relu')
self.fc3 = layers.fc(size=1, act=None)
def value(self, obs, act):
hid1 = self.fc1(obs)
concat = layers.concat([hid1, act], axis=1)
hid2 = self.fc2(concat)
Q = self.fc3(hid2)
Q = layers.squeeze(Q, axes=[1])
return Q
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['AdaptiveNoiseSpec']
class AdaptiveNoiseSpec(object):
def __init__(self):
self.stdev_curr = 1.0
def reset(self):
self.stdev_curr = 1.0
def update(self, action_dist):
if action_dist > 1e-2:
self.stdev_curr /= 1.03
else:
self.stdev_curr *= 1.03
......@@ -287,4 +287,4 @@ class LSTM(RLBaseController):
_logger.info("Controller: current reward is {}, loss is {}".format(
rewards, loss))
params_dict = self.get_params(self.learn_program)
return params_dict
return params_dict
\ No newline at end of file
......@@ -39,7 +39,7 @@ class Counter:
return counter_wrapper
class FuncWrapper(nn.Layer):
class FuncWrapper(paddle.nn.Layer):
"""
"""
......@@ -121,10 +121,11 @@ def functional2layer():
not_convert = ['linear', 'conv1d', 'conv1d_transpose', \
'conv2d', 'conv2d_transpose', 'conv3d', \
'conv3d_transpose', 'one_hot', 'embedding']
for f in dir(F):
for f in dir(paddle.nn.functional):
if not f.startswith('__') and f not in not_convert and not f.startswith(
'origin_'):
setattr(F, 'origin_{}'.format(f), eval('F.{}'.format(f)))
setattr(paddle.nn.functional, 'origin_{}'.format(f),
eval('F.{}'.format(f)))
if inspect.isfunction(eval('F.{}'.format(f))):
new_fn = convert_fn(eval('F.{}'.format(f)))
setattr(F, '{}'.format(f), new_fn)
setattr(paddle.nn.functional, '{}'.format(f), new_fn)
......@@ -3,11 +3,6 @@ import paddle
import collections
import logging
import numpy as np
from paddle.fluid import core
from paddle.fluid.framework import _dygraph_tracer, dygraph_only, _dygraph_guard, program_guard, in_dygraph_mode
from paddle.fluid.dygraph.base import program_desc_tracing_guard, _switch_declarative_mode_guard_
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.framework import Block, ParamBase, Program, Variable
from ..common import get_logger
__all__ = ["dygraph2program"]
......@@ -69,32 +64,6 @@ def _create_tensors(shapes, dtypes=None, is_static=False):
return tensors
def extract_vars(inputs):
"""
Extract a list of variables from inputs.
Args:
inputs(Variable | list<Object> | dict):
"""
vars = []
if isinstance(inputs, Variable):
vars = [inputs]
elif isinstance(inputs, dict):
for _key, _value in inputs.items():
if isinstance(_value, Variable):
vars.append(_value)
else:
_logger.warn(
f"Variable is excepted, but get an element with type({type(_value)}) from inputs whose type is dict. And the key of element is {_key}."
)
elif isinstance(inputs, (tuple, list)):
for _value in inputs:
vars.extend(extract_vars(_value))
if len(vars) == 0:
_logger.warn(f"Extract none variables from inputs.")
return vars
def _to_var(x):
"""
Convert Variable or np.array into Placeholder.
......@@ -105,99 +74,49 @@ def _to_var(x):
return paddle.static.data(shape=shape, dtype=dtype, name=name)
def to_variables(inputs, is_static=False):
def to_variables(inputs):
"""
Find and rename variables. Find np.ndarray and convert it to variable.
"""
if isinstance(inputs, (Variable, paddle.Tensor)) or isinstance(inputs,
np.ndarray):
if is_static:
return _to_var(inputs)
else:
return paddle.fluid.dygraph.to_variable(inputs)
if isinstance(inputs,
(paddle.static.Variable, paddle.Tensor)) or isinstance(
inputs, np.ndarray):
return _to_var(inputs)
elif isinstance(inputs, dict):
ret = {}
for _key in inputs:
ret[_key] = to_variables(inputs[_key], is_static)
ret[_key] = to_variables(inputs[_key])
return ret
elif isinstance(inputs, list):
ret = []
for _value in inputs:
ret.append(to_variables(_value, is_static))
ret.append(to_variables(_value))
return ret
@dygraph_only
def dygraph2program(layer,
inputs,
feed_prefix='feed_',
fetch_prefix='fetch_',
tmp_prefix='t_',
extract_inputs_fn=None,
extract_outputs_fn=None,
dtypes=None):
print(type(layer))
assert isinstance(layer, Layer)
extract_inputs_fn = extract_inputs_fn if extract_inputs_fn is not None else extract_vars
extract_outputs_fn = extract_outputs_fn if extract_outputs_fn is not None else extract_vars
if in_dygraph_mode():
return _dy2prog(layer, inputs, feed_prefix, fetch_prefix, tmp_prefix,
extract_inputs_fn, extract_outputs_fn, dtypes)
tracer = _dygraph_tracer()._get_program_desc_tracer()
with program_desc_tracing_guard(True):
if _is_shape(inputs):
shapes = [inputs]
inputs = _create_tensors(shapes, dtypes=dtypes)
input_var_list = inputs
elif _is_shapes(inputs):
inputs = _create_tensors(inputs, dtypes=dtypes)
input_var_list = inputs
else:
inputs = to_variables(inputs)
input_var_list = extract_inputs_fn(inputs)
original_outputs = layer(*inputs)
# 'original_outputs' may be dict, so we should convert it to list of varibles.
# And should not create new varibles in 'extract_vars'.
out_var_list = extract_outputs_fn(original_outputs)
program_desc, feed_names, fetch_names, parameters = tracer.create_program_desc(
input_var_list, feed_prefix, out_var_list, fetch_prefix, tmp_prefix)
tracer.reset()
with _dygraph_guard(None):
program = Program()
program.desc = program_desc
program.blocks = [Block(program, 0)]
program._sync_with_cpp()
return program
@paddle.fluid.framework.dygraph_only
def dygraph2program(layer, inputs, dtypes=None):
assert isinstance(layer, paddle.nn.Layer)
return _dy2prog(layer, inputs, dtypes)
def _dy2prog(layer,
inputs,
feed_prefix='feed_',
fetch_prefix='fetch_',
tmp_prefix='t_',
extract_inputs_fn=None,
extract_outputs_fn=None,
dtypes=None):
def _dy2prog(layer, inputs, dtypes=None):
"""
Tracing program in Eager Mode.
"""
paddle.enable_static()
program = Program()
program = paddle.static.Program()
# convert ParamBase into Parameter automatically by _switch_declarative_mode_guard_
with program_guard(program), _switch_declarative_mode_guard_(True):
with paddle.static.program_guard(
program), paddle.fluid.dygraph.base._switch_declarative_mode_guard_(
True):
if _is_shape(inputs):
shapes = [inputs]
inputs = _create_tensors(shapes, dtypes=dtypes, is_static=True)
inputs = _create_tensors(shapes, dtypes=dtypes)
elif _is_shapes(inputs):
inputs = _create_tensors(inputs, dtypes=dtypes, is_static=True)
inputs = _create_tensors(inputs, dtypes=dtypes)
else:
inputs = to_variables(inputs, is_static=True)
inputs = to_variables(inputs)
if isinstance(inputs, list):
outputs = layer(*inputs)
else:
......
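For reference, the slimmed-down entry point traces a dygraph layer into a static Program; a hedged usage sketch (import path assumed, call made from default dygraph mode as required by the dygraph_only decorator):

    import paddle
    from paddleslim.core.dygraph import dygraph2program  # path assumed

    layer = paddle.nn.Sequential(paddle.nn.Conv2D(3, 8, 3), paddle.nn.ReLU())
    program = dygraph2program(layer, inputs=[1, 3, 32, 32])  # one input shape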
......@@ -18,7 +18,7 @@ import pickle
import numpy as np
from collections import OrderedDict
from collections.abc import Iterable
from paddle.fluid.framework import Program, program_guard, Parameter, Variable
import paddle
__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper']
......@@ -37,7 +37,7 @@ OPTIMIZER_OPS = [
class VarWrapper(object):
def __init__(self, var, graph):
assert isinstance(var, (Variable, Parameter))
# assert isinstance(var, paddle.static.Variable), f"The type is {type(var)}"
assert isinstance(graph, GraphWrapper)
self._var = var
self._graph = graph
......@@ -104,9 +104,6 @@ class VarWrapper(object):
ops.append(op)
return ops
def is_parameter(self):
return isinstance(self._var, Parameter)
class OpWrapper(object):
def __init__(self, op, graph):
......@@ -240,7 +237,7 @@ class GraphWrapper(object):
"""
"""
super(GraphWrapper, self).__init__()
self.program = Program() if program is None else program
self.program = paddle.static.Program() if program is None else program
self.persistables = {}
self.teacher_persistables = {}
for var in self.program.list_vars():
......@@ -266,15 +263,6 @@ class GraphWrapper(object):
params.append(VarWrapper(param, self))
return params
def is_parameter(self, var):
"""
Whether the given variable is parameter.
Args:
var(VarWrapper): The given varibale.
"""
return isinstance(var._var, Parameter)
def is_persistable(self, var):
"""
Whether the given variable is persistable.
......@@ -363,18 +351,6 @@ class GraphWrapper(object):
ops.append(p)
return sorted(ops)
def get_param_by_op(self, op):
"""
Get the parameters used by target operator.
"""
assert isinstance(op, OpWrapper)
params = []
for var in op.all_inputs():
if isinstance(var._var, Parameter):
params.append(var)
assert len(params) > 0
return params
def numel_params(self):
"""
Get the number of elements in all parameters.
......
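# A minimal sketch of GraphWrapper over a freshly built static Program
# (assumes paddleslim is importable); it only touches APIs defined above.
import paddle
from paddleslim.core import GraphWrapper

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[None, 3, 32, 32], dtype='float32')
    y = paddle.static.nn.conv2d(x, num_filters=8, filter_size=3)
graph = GraphWrapper(main)
print(graph.numel_params())  # element count over all parameters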
......@@ -18,12 +18,9 @@ from __future__ import print_function
import copy
import paddle
import paddle.nn as nn
from paddle.nn import LogSoftmax
class DML(nn.Layer):
class DML(paddle.nn.Layer):
def __init__(self, model, use_parallel=False):
super(DML, self).__init__()
self.model = model
......@@ -66,11 +63,11 @@ class DML(nn.Layer):
cur_kl_loss = 0
for j in range(self.model_num):
if i != j:
log_softmax = LogSoftmax(axis=1)
log_softmax = paddle.nn.LogSoftmax(axis=1)
x = log_softmax(logits[i])
y = nn.functional.softmax(logits[j], axis=1)
cur_kl_loss += nn.functional.kl_div(
y = paddle.nn.functional.softmax(logits[j], axis=1)
cur_kl_loss += paddle.nn.functional.kl_div(
x, y, reduction='batchmean')
kl_losses.append(cur_kl_loss / (self.model_num - 1))
return kl_losses
......
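# Standalone sketch of the mutual-learning KL term computed above, for two
# peer models (model_num == 2) on dummy logits; shapes are illustrative.
import paddle

logits = [paddle.randn([4, 10]), paddle.randn([4, 10])]
log_softmax = paddle.nn.LogSoftmax(axis=1)
x = log_softmax(logits[0])                           # log-probs of student i
y = paddle.nn.functional.softmax(logits[1], axis=1)  # probs of peer j
kl_loss = paddle.nn.functional.kl_div(x, y, reduction='batchmean')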
......@@ -203,8 +203,8 @@ def soft_label(teacher_var_name,
teacher_var = paddle.nn.functional.softmax(teacher_var /
teacher_temperature)
soft_label_loss = paddle.mean(
paddle.fluid.layers.cross_entropy(
student_var, teacher_var, soft_label=True))
paddle.nn.functional.cross_entropy(
input=student_var, label=teacher_var, soft_label=True))
return soft_label_loss
......
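# Sketch of the soft-label distillation loss above on random tensors;
# variable names mirror the function, the temperature value is illustrative.
import paddle

teacher_temperature = 2.0
student_var = paddle.nn.functional.softmax(paddle.randn([8, 100]), axis=1)
teacher_var = paddle.nn.functional.softmax(
    paddle.randn([8, 100]) / teacher_temperature, axis=1)
soft_label_loss = paddle.mean(
    paddle.nn.functional.cross_entropy(
        input=student_var, label=teacher_var, soft_label=True))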
......@@ -16,7 +16,6 @@ import copy
import collections
import numpy as np
import paddle
import paddle.nn as nn
from ...common.wrapper_function import init_index, functional2layer
from . import losses
from .losses.basic_loss import BASIC_LOSS
......@@ -91,7 +90,7 @@ def _remove_hooks(hooks):
hook.remove()
class Distill(nn.Layer):
class Distill(paddle.nn.Layer):
"""
Distill API.
configs(list(dict) | str): a list of distill configs, or the path of a YAML file that contains the distill config.
......@@ -111,9 +110,9 @@ class Distill(nn.Layer):
super(Distill, self).__init__()
if convert_fn:
functional2layer()
if isinstance(students, nn.Layer):
if isinstance(students, paddle.nn.Layer):
students = [students]
if isinstance(teachers, nn.Layer):
if isinstance(teachers, paddle.nn.Layer):
teachers = [teachers]
if isinstance(configs, list):
......@@ -125,8 +124,8 @@ class Distill(nn.Layer):
raise NotImplementedError("distill config file type error!")
else:
raise NotImplementedError("distill config error!")
self._student_models = nn.LayerList(students)
self._teacher_models = nn.LayerList(teachers)
self._student_models = paddle.nn.LayerList(students)
self._teacher_models = paddle.nn.LayerList(teachers)
self._return_model_outputs = return_model_outputs
self._loss_config_list = []
......
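# Hypothetical wiring sketch (the argument names follow the attributes used
# in __init__ above and are not a verified public signature):
import paddle

student = paddle.vision.models.mobilenet_v2(num_classes=10)
teacher = paddle.vision.models.resnet50(num_classes=10)
distill_model = Distill(
    configs='distill_config.yaml',  # or a list of per-loss config dicts
    students=student,               # a single Layer is wrapped into a list
    teachers=teacher)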
......@@ -14,7 +14,6 @@
import copy
import paddle
import paddle.nn as nn
from . import basic_loss
from . import distillation_loss
......@@ -22,7 +21,7 @@ from . import distillation_loss
from .distillation_loss import DistillationLoss
class CombinedLoss(nn.Layer):
class CombinedLoss(paddle.nn.Layer):
"""
CombinedLoss: a combination of loss function.
Args:
......@@ -40,7 +39,7 @@ class CombinedLoss(nn.Layer):
def __init__(self, loss_config_list=None):
super(CombinedLoss, self).__init__()
loss_config_list = copy.deepcopy(loss_config_list)
self.loss_func = nn.LayerList()
self.loss_func = paddle.nn.LayerList()
self.loss_weight = []
assert isinstance(loss_config_list, list), (
'operator config should be a list')
......
......@@ -14,14 +14,13 @@
import numpy as np
import paddle
import paddle.nn as nn
from .basic_loss import BASIC_LOSS
__all__ = ["DistillationLoss", "ShapeAlign"]
class DistillationLoss(nn.Layer):
class DistillationLoss(paddle.nn.Layer):
"""
DistillationLoss
Args:
......@@ -78,7 +77,7 @@ class DistillationLoss(nn.Layer):
return loss_dict
class ShapeAlign(nn.Layer):
class ShapeAlign(paddle.nn.Layer):
"""
Align the feature map between student and teacher.
Args:
......@@ -171,5 +170,7 @@ class ShapeAlign(nn.Layer):
bias_attr=bias_attr)
def forward(self, feat):
if isinstance(feat, tuple):
feat = feat[0]
out = self.align_op(feat)
return out
......@@ -3,7 +3,6 @@ import collections
import numpy as np
import logging
from paddleslim.common import get_logger
from paddle.fluid import core
_logger = get_logger(__name__, level=logging.INFO)
__all__ = ['PruningPlan', 'PruningMask']
......@@ -115,7 +114,7 @@ class PruningPlan():
elif p.is_cuda_pinned_place():
place = paddle.CUDAPinnedPlace()
else:
p = core.Place()
p = paddle.fluid.core.Place()
p.set_place(t_value._place())
place = paddle.CUDAPlace(p.gpu_device_id())
......@@ -154,7 +153,7 @@ class PruningPlan():
elif p.is_cuda_pinned_place():
place = paddle.CUDAPinnedPlace()
else:
p = core.Place()
p = paddle.fluid.core.Place()
p.set_place(t_value._place())
place = paddle.CUDAPlace(p.gpu_device_id())
......@@ -196,7 +195,7 @@ class PruningPlan():
elif p.is_cuda_pinned_place():
place = paddle.CUDAPinnedPlace()
else:
p = core.Place()
p = paddle.fluid.core.Place()
p.set_place(t_value._place())
place = paddle.CUDAPlace(p.gpu_device_id())
......@@ -247,7 +246,7 @@ class PruningPlan():
elif p.is_cuda_pinned_place():
place = paddle.CUDAPinnedPlace()
else:
p = core.Place()
p = paddle.fluid.core.Place()
p.set_place(t_value._place())
place = paddle.CUDAPlace(p.gpu_device_id())
t_value.set(pruned_value, place)
......@@ -275,7 +274,7 @@ class PruningPlan():
elif p.is_cuda_pinned_place():
place = paddle.CUDAPinnedPlace()
else:
p = core.Place()
p = paddle.fluid.core.Place()
p.set_place(t_value._place())
place = paddle.CUDAPlace(p.gpu_device_id())
......
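# The place-resolution branch repeated above, lifted into one helper; a
# sketch that assumes the elided first branch checks p.is_cpu_place().
import paddle

def resolve_place(t_value):
    p = t_value._place()
    if p.is_cpu_place():
        return paddle.CPUPlace()
    if p.is_cuda_pinned_place():
        return paddle.CUDAPinnedPlace()
    core_place = paddle.fluid.core.Place()
    core_place.set_place(p)
    return paddle.CUDAPlace(core_place.gpu_device_id())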
......@@ -90,6 +90,7 @@ class UnstructuredPruner():
tmp_array = np.ones(param.shape, dtype=np.float32)
mask_name = "_".join([param.name.replace(".", "_"), "mask"])
if mask_name not in sub_layer._buffers:
print(f"target type: {type(paddle.to_tensor(tmp_array))}")
sub_layer.register_buffer(mask_name,
paddle.to_tensor(tmp_array))
self.masks[param.name] = sub_layer._buffers[mask_name]
......
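# Minimal sketch of the mask registration above on a single sublayer.
import numpy as np
import paddle

sub_layer = paddle.nn.Linear(4, 4)
param = sub_layer.weight
tmp_array = np.ones(param.shape, dtype=np.float32)
mask_name = "_".join([param.name.replace(".", "_"), "mask"])
if mask_name not in sub_layer._buffers:
    sub_layer.register_buffer(mask_name, paddle.to_tensor(tmp_array))
masks = {param.name: sub_layer._buffers[mask_name]}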
import numpy as np
import logging
import paddle
from paddle.fluid.dygraph import TracedLayer
from paddleslim.core import GraphWrapper, dygraph2program
from paddleslim.prune import PruningCollections
from paddleslim.common import get_logger
......
......@@ -16,7 +16,6 @@ import copy
import logging
import paddle
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from ...common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
......@@ -204,7 +203,7 @@ class QAT(object):
# TODO: remove try-except when the version is stable
try:
self.imperative_qat = ImperativeQuantAware(
self.imperative_qat = paddle.fluid.contrib.slim.quantization.ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
......@@ -221,7 +220,7 @@ class QAT(object):
onnx_format=self.config['onnx_format'], # support Paddle >= 2.4
)
except:
self.imperative_qat = ImperativeQuantAware(
self.imperative_qat = paddle.fluid.contrib.slim.quantization.ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
......@@ -292,7 +291,7 @@ class QAT(object):
def _remove_preprocess(self, model):
state_dict = model.state_dict()
try:
self.imperative_qat = ImperativeQuantAware(
self.imperative_qat = paddle.fluid.contrib.slim.quantization.ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
......@@ -303,7 +302,7 @@ class QAT(object):
onnx_format=self.config['onnx_format'], # support Paddle >= 2.4
)
except:
self.imperative_qat = ImperativeQuantAware(
self.imperative_qat = paddle.fluid.contrib.slim.quantization.ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
......@@ -312,11 +311,12 @@ class QAT(object):
moving_rate=self.config['moving_rate'],
quantizable_layer_type=self.config['quantizable_layer_type'])
with paddle.utils.unique_name.guard():
if hasattr(model, "_layers"):
model = model._layers
model = self._model
self.imperative_qat.quantize(model)
model.set_state_dict(state_dict)
paddle.disable_static()
if hasattr(model, "_layers"):
model = model._layers
model = self._model
self.imperative_qat.quantize(model)
model.set_state_dict(state_dict)
paddle.enable_static()
return model
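# Hedged end-to-end sketch for the QAT wrapper above; the config keys match
# the fields the class reads, and the values shown are common defaults.
import paddle
from paddleslim import QAT

quant_config = {
    'weight_bits': 8,
    'activation_bits': 8,
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'moving_rate': 0.9,
    'quantizable_layer_type': ['Conv2D', 'Linear'],
}
model = paddle.vision.models.mobilenet_v2()
quanter = QAT(quant_config)
model = quanter.quantize(model)  # insert fake-quant ops for training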
# SlimX Series Small Models
The PaddleSlim model compression toolkit has released SlimX series small models for a range of tasks, including face recognition, OCR, general image classification, and object detection:
- `SlimMobileNet` series
- `SlimFaceNet` series
## SlimMobileNet Series Metrics
SlimMobileNet models are produced with Baidu's in-house AutoDL technology from the [GP-NAS paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Li_GP-NAS_Gaussian_Process_Based_Neural_Architecture_Search_CVPR_2020_paper.pdf) (CVPR 2020), combined with Baidu's in-house distillation techniques.
Compared with MobileNetV3, SlimMobileNet_V1 improves top-1 accuracy by 1.7 points while compressing FLOPs by 1.38x (163M vs. 225M).
Since its accuracy is 1.7 points above MobileNetV3, SlimMobileNet_V1 remains more accurate than MobileNetV3 even after quantization; the quantized SlimMobileNet_V1 compresses FLOPs by 5.52x while still beating MobileNetV3 on accuracy. SlimMobileNet_V4_x1_1 is the industry's first published small classification model with under 300M FLOPs and over 80% ImageNet top-1 accuracy.
|Method|FLOPs (M)|Top-1 Acc (%)|
|------|-----|-----|
|MobileNetV3_large_x1_0|225|75.2|
|MobileNetV3_large_x1_25|357|76.6|
|GhostNet_x1_3|220|75.7|
|SlimMobileNet_V1|163|76.9|
|SlimMobileNet_V4_x1_1|296|80.1|
|SlimMobileNet_V5|390|80.4|
## [SlimFaceNet](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/slimfacenet/README.md) Series Metrics
SlimFaceNet models are likewise produced with Baidu's in-house GP-NAS AutoDL technology and Baidu's in-house self-supervised supernet training algorithm. Compared with MobileNetV2, SlimFaceNet_A_x0_60 compresses FLOPs by 2.16x and runs 4.28x faster on an RK3288. PaddleSlim's post-training quantization compresses the models further: relative to MobileNetV2, SlimFaceNet_A_x0_60_quant compresses FLOPs by 8.65x and is 6.43x faster on RK3288 hardware. To stay consistent with the paper, the LFW accuracy is reported at 112x96 input resolution; to match the deployment scenario, FLOPs and speed are reported at 112x112 input, with latency measured on the RK3288.
|Method|LFW Acc|FLOPs|Latency (RK3288)|
|------|-----|-----|-----|
|MobileNetV2|98.58%|277M|270ms|
|MobileFaceNet|99.18%|224M|102ms|
|SlimFaceNet_A_x0_60|99.21%|128M|63ms|
|SlimFaceNet_B_x0_75|99.22%|151M|70ms|
|SlimFaceNet_A_x0_60_quant|99.17%|32M|42ms|
|SlimFaceNet_B_x0_75_quant|99.21%|38M|45ms|
## Industry-Leading AutoDL Technology
GP-NAS models NAS from a Bayesian perspective and designs customized Gaussian-process mean and kernel functions for different search spaces. Concretely, given the GP-NAS hyperparameters, the performance of any architecture in the search space can be predicted efficiently, so architecture search reduces to estimating the hyperparameters of the GP-NAS Gaussian process. A mutual-information-maximization sampling algorithm then samples architectures effectively, and the measured performance of the sampled networks progressively updates the posterior distribution over the GP-NAS hyperparameters. From the estimated hyperparameters, the optimal architecture under a given latency constraint can be predicted; see the GP-NAS paper for the technical details. A sketch of the standard Gaussian-process posterior behind such predictions is given below.
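As a reminder of the machinery involved (this is the textbook GP regression posterior, not a formula quoted from the paper), the predicted accuracy of a candidate architecture $x_*$, given sampled architectures $X$ with measured accuracies $y$, kernel $K$, and observation noise $\sigma^2$, is the posterior mean

$$\mu(x_*) = k(x_*, X)\left[K(X, X) + \sigma^2 I\right]^{-1} y$$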
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .util import image_classification
from .slimfacenet import SlimFaceNet_A_x0_60, SlimFaceNet_B_x0_75, SlimFaceNet_C_x0_75
from .slim_mobilenet import SlimMobileNet_v1, SlimMobileNet_v2, SlimMobileNet_v3, SlimMobileNet_v4, SlimMobileNet_v5
from ..models import mobilenet
from .mobilenet import *
from ..models import resnet
from .resnet import *
__all__ = ["image_classification"]
__all__ += mobilenet.__all__
__all__ += resnet.__all__
from __future__ import absolute_import
from .mobilenet import MobileNet
from .resnet import ResNet34, ResNet50
from .mobilenet_v2 import MobileNetV2
__all__ = ["model_list", "MobileNet", "ResNet34", "ResNet50", "MobileNetV2"]
model_list = ['MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2']
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .mobilenet import MobileNetV1
from .resnet import ResNet
__all__ = ["MobileNetV1", "ResNet"]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#order: standard library, third party, local library
import os
import time
import sys
import math
import numpy as np
import argparse
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.nn import Conv2D
from paddle.fluid.dygraph.nn import Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import framework
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True,
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "_weights"),
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=self.full_name() + "_bn" + "_scale"),
bias_attr=ParamAttr(name=self.full_name() + "_bn" + "_offset"),
moving_mean_name=self.full_name() + "_bn" + '_mean',
moving_variance_name=self.full_name() + "_bn" + '_variance')
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class DepthwiseSeparable(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
name=None):
super(DepthwiseSeparable, self).__init__()
self._depthwise_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=int(num_filters1 * scale),
filter_size=3,
stride=stride,
padding=1,
num_groups=int(num_groups * scale),
use_cudnn=False)
self._pointwise_conv = ConvBNLayer(
num_channels=int(num_filters1 * scale),
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
y = self._pointwise_conv(y)
return y
class MobileNetV1(fluid.dygraph.Layer):
def __init__(self, scale=1.0, class_dim=100):
super(MobileNetV1, self).__init__()
self.scale = scale
self.dwsl = []
self.conv1 = ConvBNLayer(
num_channels=3,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=1,
padding=1)
dws21 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(32 * scale),
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale),
name="conv2_1")
self.dwsl.append(dws21)
dws22 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(64 * scale),
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=1,
scale=scale),
name="conv2_2")
self.dwsl.append(dws22)
dws31 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale),
name="conv3_1")
self.dwsl.append(dws31)
dws32 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=2,
scale=scale),
name="conv3_2")
self.dwsl.append(dws32)
dws41 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale),
name="conv4_1")
self.dwsl.append(dws41)
dws42 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=2,
scale=scale),
name="conv4_2")
self.dwsl.append(dws42)
for i in range(5):
tmp = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
scale=scale),
name="conv5_" + str(i + 1))
self.dwsl.append(tmp)
dws56 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=2,
scale=scale),
name="conv5_6")
self.dwsl.append(dws56)
dws6 = self.add_sublayer(
sublayer=DepthwiseSeparable(
num_channels=int(1024 * scale),
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=1,
scale=scale),
name="conv6")
self.dwsl.append(dws6)
self.pool2d_avg = Pool2D(pool_type='avg', global_pooling=True)
self.out = Linear(
int(1024 * scale),
class_dim,
param_attr=ParamAttr(
initializer=MSRA(), name=self.full_name() + "fc7_weights"),
bias_attr=ParamAttr(name=self.full_name() + "fc7_offset"))
def forward(self, inputs):
y = self.conv1(inputs)
for dws in self.dwsl:
y = dws(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, 1024])
y = self.out(y)
return y
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.nn import Conv2D
from paddle.fluid.dygraph.nn import Pool2D, BatchNorm, Linear
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
self._batch_norm = BatchNorm(num_filters, act=act)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters, stride, shortcut=True):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(y)
class ResNet(fluid.dygraph.Layer):
def __init__(self, layers=50, class_dim=100):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_channels = [64, 256, 512, 1024]
num_filters = [64, 128, 256, 512]
self.conv = ConvBNLayer(
num_channels=3, num_filters=64, filter_size=7, stride=1, act='relu')
self.pool2d_max = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
self.bottleneck_block_list = []
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut))
self.bottleneck_block_list.append(bottleneck_block)
shortcut = True
self.pool2d_avg = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
self.pool2d_avg_output = num_filters[len(num_filters) - 1] * 4 * 1 * 1
import math
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.out = Linear(
self.pool2d_avg_output,
class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
def forward(self, inputs):
y = self.conv(inputs)
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
y = self.out(y)
return y
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
__all__ = ['MobileNet']
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [10, 16, 30],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class MobileNet():
def __init__(self):
self.params = train_parameters
def net(self, input, class_dim=1000, scale=1.0):
# conv1: 112x112
input = self.conv_bn_layer(
input,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1,
name="conv1")
# 56x56
input = self.depthwise_separable(
input,
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale,
name="conv2_1")
input = self.depthwise_separable(
input,
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=2,
scale=scale,
name="conv2_2")
# 28x28
input = self.depthwise_separable(
input,
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale,
name="conv3_1")
input = self.depthwise_separable(
input,
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=2,
scale=scale,
name="conv3_2")
# 14x14
input = self.depthwise_separable(
input,
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale,
name="conv4_1")
input = self.depthwise_separable(
input,
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=2,
scale=scale,
name="conv4_2")
# 14x14
for i in range(5):
input = self.depthwise_separable(
input,
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
scale=scale,
name="conv5" + "_" + str(i + 1))
# 7x7
input = self.depthwise_separable(
input,
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=2,
scale=scale,
name="conv5_6")
input = self.depthwise_separable(
input,
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=1,
scale=scale,
name="conv6")
input = fluid.layers.pool2d(
input=input,
pool_size=0,
pool_stride=1,
pool_type='avg',
global_pooling=True)
output = fluid.layers.fc(input=input,
size=class_dim,
act='softmax',
param_attr=ParamAttr(
initializer=MSRA(), name="fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
return output
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(
initializer=MSRA(), name=name + "_weights"),
bias_attr=False)
bn_name = name + "_bn"
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def depthwise_separable(self,
input,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
name=None):
depthwise_conv = self.conv_bn_layer(
input=input,
filter_size=3,
num_filters=int(num_filters1 * scale),
stride=stride,
padding=1,
num_groups=int(num_groups * scale),
use_cudnn=False,
name=name + "_dw")
pointwise_conv = self.conv_bn_layer(
input=depthwise_conv,
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0,
name=name + "_sep")
return pointwise_conv
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
__all__ = [
'MobileNetV2', 'MobileNetV2_x0_25, '
'MobileNetV2_x0_5', 'MobileNetV2_x1_0', 'MobileNetV2_x1_5',
'MobileNetV2_x2_0', 'MobileNetV2_scale'
]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class MobileNetV2():
def __init__(self, scale=1.0, change_depth=False):
self.params = train_parameters
self.scale = scale
self.change_depth = change_depth
def net(self, input, class_dim=1000):
scale = self.scale
change_depth = self.change_depth
# if change_depth is True, the network is 1.4 times as deep as before.
bottleneck_params_list = [
(1, 16, 1, 1),
(6, 24, 2, 2),
(6, 32, 3, 2),
(6, 64, 4, 2),
(6, 96, 3, 1),
(6, 160, 3, 2),
(6, 320, 1, 1),
] if change_depth == False else [
(1, 16, 1, 1),
(6, 24, 2, 2),
(6, 32, 5, 2),
(6, 64, 7, 2),
(6, 96, 5, 1),
(6, 160, 3, 2),
(6, 320, 1, 1),
]
#conv1
input = self.conv_bn_layer(
input,
num_filters=int(32 * scale),
filter_size=3,
stride=2,
padding=1,
if_act=True,
name='conv1_1')
# bottleneck sequences
i = 1
in_c = int(32 * scale)
for layer_setting in bottleneck_params_list:
t, c, n, s = layer_setting
i += 1
input = self.invresi_blocks(
input=input,
in_c=in_c,
t=t,
c=int(c * scale),
n=n,
s=s,
name='conv' + str(i))
in_c = int(c * scale)
#last_conv
input = self.conv_bn_layer(
input=input,
num_filters=int(1280 * scale) if scale > 1.0 else 1280,
filter_size=1,
stride=1,
padding=0,
if_act=True,
name='conv9')
input = fluid.layers.pool2d(
input=input,
pool_size=7,
pool_stride=1,
pool_type='avg',
global_pooling=True)
output = fluid.layers.fc(input=input,
size=class_dim,
act='softmax',
param_attr=ParamAttr(name='fc10_weights'),
bias_attr=ParamAttr(name='fc10_offset'))
return output
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
if_act=True,
name=None,
use_cudnn=True):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
return fluid.layers.relu6(bn)
else:
return bn
def shortcut(self, input, data_residual):
return fluid.layers.elementwise_add(input, data_residual)
def inverted_residual_unit(self,
input,
num_in_filter,
num_filters,
ifshortcut,
stride,
filter_size,
padding,
expansion_factor,
name=None):
num_expfilter = int(round(num_in_filter * expansion_factor))
channel_expand = self.conv_bn_layer(
input=input,
num_filters=num_expfilter,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
name=name + '_expand')
bottleneck_conv = self.conv_bn_layer(
input=channel_expand,
num_filters=num_expfilter,
filter_size=filter_size,
stride=stride,
padding=padding,
num_groups=num_expfilter,
if_act=True,
name=name + '_dwise',
use_cudnn=False)
linear_out = self.conv_bn_layer(
input=bottleneck_conv,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=False,
name=name + '_linear')
if ifshortcut:
out = self.shortcut(input=input, data_residual=linear_out)
return out
else:
return linear_out
def invresi_blocks(self, input, in_c, t, c, n, s, name=None):
first_block = self.inverted_residual_unit(
input=input,
num_in_filter=in_c,
num_filters=c,
ifshortcut=False,
stride=s,
filter_size=3,
padding=1,
expansion_factor=t,
name=name + '_1')
last_residual_block = first_block
last_c = c
for i in range(1, n):
last_residual_block = self.inverted_residual_unit(
input=last_residual_block,
num_in_filter=last_c,
num_filters=c,
ifshortcut=True,
stride=1,
filter_size=3,
padding=1,
expansion_factor=t,
name=name + '_' + str(i + 1))
return last_residual_block
def MobileNetV2_x0_25():
model = MobileNetV2(scale=0.25)
return model
def MobileNetV2_x0_5():
model = MobileNetV2(scale=0.5)
return model
def MobileNetV2_x1_0():
model = MobileNetV2(scale=1.0)
return model
def MobileNetV2_x1_5():
model = MobileNetV2(scale=1.5)
return model
def MobileNetV2_x2_0():
model = MobileNetV2(scale=2.0)
return model
def MobileNetV2_scale():
model = MobileNetV2(scale=1.2, change_depth=True)
return model
from __future__ import absolute_import
import paddle.fluid as fluid
from ..models import classification_models
__all__ = ["image_classification"]
model_list = classification_models.model_list
def image_classification(model, image_shape, class_num, use_gpu=False):
assert model in model_list
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
model = classification_models.__dict__[model]()
out = model.net(input=image, class_dim=class_num)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
val_program = fluid.default_main_program().clone(for_test=True)
opt = fluid.optimizer.Momentum(0.1, 0.9)
opt.minimize(avg_cost)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
return exe, train_program, val_program, (image, label), (
acc_top1.name, acc_top5.name, avg_cost.name, out.name)
......@@ -15,9 +15,6 @@
from __future__ import absolute_import
from ..darts import train_search
from .train_search import *
from ..darts import search_space
from .search_space import *
__all__ = []
__all__ += train_search.__all__
__all__ += search_space.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
class Architect(object):
def __init__(self, model, eta, arch_learning_rate, place, unrolled):
self.network_momentum = 0.9
self.network_weight_decay = 1e-3
self.eta = eta
self.model = model
self.optimizer = fluid.optimizer.Adam(
arch_learning_rate,
0.5,
0.999,
regularization=fluid.regularizer.L2Decay(1e-3),
parameter_list=self.model.arch_parameters())
self.place = place
self.unrolled = unrolled
if self.unrolled:
self.unrolled_model = self.model.new()
self.unrolled_model_params = [
p for p in self.unrolled_model.parameters()
if p.name not in [
a.name for a in self.unrolled_model.arch_parameters()
] and p.trainable
]
self.unrolled_optimizer = fluid.optimizer.MomentumOptimizer(
self.eta,
self.network_momentum,
regularization=fluid.regularizer.L2DecayRegularizer(
self.network_weight_decay),
parameter_list=self.unrolled_model_params)
def step(self, train_data, valid_data, epoch):
if self.unrolled:
params_grads = self._backward_step_unrolled(train_data, valid_data)
self.optimizer.apply_gradients(params_grads)
else:
loss = self._backward_step(valid_data, epoch)
self.optimizer.minimize(loss)
self.optimizer.clear_gradients()
def _backward_step(self, valid_data, epoch):
loss = self.model.loss(valid_data, epoch)
loss[0].backward()
return loss[0]
def _backward_step_unrolled(self, train_data, valid_data):
self._compute_unrolled_model(train_data)
unrolled_loss = self.unrolled_model.loss(valid_data)
unrolled_loss.backward()
vector = [
to_variable(param._grad_ivar().numpy())
for param in self.unrolled_model_params
]
arch_params_grads = [
(alpha, to_variable(ualpha._grad_ivar().numpy()))
for alpha, ualpha in zip(self.model.arch_parameters(),
self.unrolled_model.arch_parameters())
]
self.unrolled_model.clear_gradients()
implicit_grads = self._hessian_vector_product(vector, train_data)
for (p, g), ig in zip(arch_params_grads, implicit_grads):
new_g = g - (ig * self.unrolled_optimizer.current_step_lr())
g.value().get_tensor().set(new_g.numpy(), self.place)
return arch_params_grads
def _compute_unrolled_model(self, data):
for x, y in zip(self.unrolled_model.parameters(),
self.model.parameters()):
x.value().get_tensor().set(y.numpy(), self.place)
loss = self.unrolled_model._loss(data)
loss.backward()
self.unrolled_optimizer.minimize(loss)
self.unrolled_model.clear_gradients()
def _hessian_vector_product(self, vector, data, r=1e-2):
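# Finite-difference Hessian-vector product as in the DARTS paper:
# evaluate grad_alpha(loss) at w + R*v and w - R*v, then return
# (grads_p - grads_n) / (2R) with R = r / ||v||, approximating
# (d^2 loss / dw dalpha) . v without forming second derivatives.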
R = r * fluid.layers.rsqrt(
fluid.layers.sum([
fluid.layers.reduce_sum(fluid.layers.square(v)) for v in vector
]))
model_params = [
p for p in self.model.parameters()
if p.name not in [a.name for a in self.model.arch_parameters()] and
p.trainable
]
for param, grad in zip(model_params, vector):
param_p = param + grad * R
param.value().get_tensor().set(param_p.numpy(), self.place)
loss = self.model.loss(data)
loss.backward()
grads_p = [
to_variable(param._grad_ivar().numpy())
for param in self.model.arch_parameters()
]
for param, grad in zip(model_params, vector):
param_n = param - grad * R * 2
param.value().get_tensor().set(param_n.numpy(), self.place)
self.model.clear_gradients()
loss = self.model.loss(data)
loss.backward()
grads_n = [
to_variable(param._grad_ivar().numpy())
for param in self.model.arch_parameters()
]
for param, grad in zip(model_params, vector):
param_o = param + grad * R
param.value().get_tensor().set(param_o.numpy(), self.place)
self.model.clear_gradients()
arch_grad = [(p - n) / (2 * R) for p, n in zip(grads_p, grads_n)]
return arch_grad
......@@ -15,8 +15,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from collections import namedtuple
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
......@@ -48,26 +48,11 @@ def get_genotype(model):
weightsr2 = None
weightsn2 = None
if model._method == "PC-DARTS":
n = 3
start = 2
weightsr2 = fluid.layers.softmax(model.betas_reduce[0:2])
weightsn2 = fluid.layers.softmax(model.betas_normal[0:2])
for i in range(model._steps - 1):
end = start + n
tw2 = fluid.layers.softmax(model.betas_reduce[start:end])
tn2 = fluid.layers.softmax(model.betas_normal[start:end])
start = end
n += 1
weightsr2 = fluid.layers.concat([weightsr2, tw2])
weightsn2 = fluid.layers.concat([weightsn2, tn2])
weightsr2 = weightsr2.numpy()
weightsn2 = weightsn2.numpy()
gene_normal = _parse(
fluid.layers.softmax(model.alphas_normal).numpy(), weightsn2)
paddle.nn.functional.softmax(model.alphas_normal).numpy(), weightsn2)
gene_reduce = _parse(
fluid.layers.softmax(model.alphas_reduce).numpy(), weightsr2)
paddle.nn.functional.softmax(model.alphas_reduce).numpy(), weightsr2)
concat = range(2 + model._steps - model._multiplier, model._steps + 2)
genotype = Genotype(
......