提交 cd02d9c8 编写于 作者: Z zhhsplendid

Splite PR, test=develop

上级 f4f5f3f2
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import logging
import time
import unittest
import numpy as np
import paddle
PRINT_STEP = 20
SEED = 2020
program_translator = paddle.fluid.dygraph.dygraph_to_static.ProgramTranslator()
class SimpleLSTMRNN(paddle.fluid.Layer):
def __init__(self,
hidden_size,
num_steps,
num_layers=2,
init_scale=0.1,
dropout=None):
super(SimpleLSTMRNN, self).__init__()
self._hidden_size = hidden_size
self._num_layers = num_layers
self._init_scale = init_scale
self._dropout = dropout
self._num_steps = num_steps
self.cell_array = []
self.hidden_array = []
self.weight_1_arr = []
self.weight_2_arr = []
self.bias_arr = []
self.mask_array = []
for i in range(self._num_layers):
weight_1 = self.create_parameter(
attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Uniform(
low=-self._init_scale, high=self._init_scale)),
shape=[self._hidden_size * 2, self._hidden_size * 4],
dtype="float32",
default_initializer=paddle.nn.initializer.Uniform(
low=-self._init_scale, high=self._init_scale))
self.weight_1_arr.append(self.add_parameter('w_%d' % i, weight_1))
bias_1 = self.create_parameter(
attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Uniform(
low=-self._init_scale, high=self._init_scale)),
shape=[self._hidden_size * 4],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0.0))
self.bias_arr.append(self.add_parameter('b_%d' % i, bias_1))
def forward(self, input_embedding, init_hidden=None, init_cell=None):
cell_array = []
hidden_array = []
for i in range(self._num_layers):
hidden_array.append(init_hidden[i])
cell_array.append(init_cell[i])
res = []
for index in range(self._num_steps):
step_input = input_embedding[:, index, :]
for k in range(self._num_layers):
pre_hidden = hidden_array[k]
pre_cell = cell_array[k]
weight_1 = self.weight_1_arr[k]
bias = self.bias_arr[k]
nn = paddle.concat(x=[step_input, pre_hidden], axis=1)
gate_input = paddle.matmul(x=nn, y=weight_1)
gate_input = paddle.add(x=gate_input, y=bias)
i, j, f, o = paddle.split(
x=gate_input, num_or_sections=4, axis=-1)
c = pre_cell * paddle.nn.functional.sigmoid(
f) + paddle.nn.functional.sigmoid(i) * paddle.tanh(j)
m = paddle.tanh(c) * paddle.nn.functional.sigmoid(o)
hidden_array[k] = m
cell_array[k] = c
step_input = m
if self._dropout is not None and self._dropout > 0.0:
step_input = paddle.fluid.layers.dropout(
step_input,
dropout_prob=self._dropout,
dropout_implementation='upscale_in_train')
res.append(step_input)
real_res = paddle.concat(x=res, axis=1)
real_res = paddle.fluid.layers.reshape(
real_res, [-1, self._num_steps, self._hidden_size])
last_hidden = paddle.concat(x=hidden_array, axis=1)
last_hidden = paddle.fluid.layers.reshape(
last_hidden, shape=[-1, self._num_layers, self._hidden_size])
last_hidden = paddle.transpose(x=last_hidden, perm=[1, 0, 2])
last_cell = paddle.concat(x=cell_array, axis=1)
last_cell = paddle.fluid.layers.reshape(
last_cell, shape=[-1, self._num_layers, self._hidden_size])
last_cell = paddle.transpose(x=last_cell, perm=[1, 0, 2])
return real_res, last_hidden, last_cell
class PtbModel(paddle.fluid.Layer):
def __init__(self,
hidden_size,
vocab_size,
num_layers=2,
num_steps=20,
init_scale=0.1,
dropout=None):
super(PtbModel, self).__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.init_scale = init_scale
self.num_layers = num_layers
self.num_steps = num_steps
self.dropout = dropout
self.simple_lstm_rnn = SimpleLSTMRNN(
hidden_size,
num_steps,
num_layers=num_layers,
init_scale=init_scale,
dropout=dropout)
self.embedding = paddle.fluid.dygraph.nn.Embedding(
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False,
param_attr=paddle.ParamAttr(
name='embedding_para',
initializer=paddle.nn.initializer.Uniform(
low=-init_scale, high=init_scale)))
self.softmax_weight = self.create_parameter(
attr=paddle.ParamAttr(),
shape=[self.hidden_size, self.vocab_size],
dtype="float32",
default_initializer=paddle.nn.initializer.Uniform(
low=-self.init_scale, high=self.init_scale))
self.softmax_bias = self.create_parameter(
attr=paddle.ParamAttr(),
shape=[self.vocab_size],
dtype="float32",
default_initializer=paddle.nn.initializer.Uniform(
low=-self.init_scale, high=self.init_scale))
def build_once(self, input, label, init_hidden, init_cell):
pass
@paddle.fluid.dygraph.jit.declarative
def forward(self, input, label, init_hidden, init_cell):
init_h = paddle.fluid.layers.reshape(
init_hidden, shape=[self.num_layers, -1, self.hidden_size])
init_c = paddle.fluid.layers.reshape(
init_cell, shape=[self.num_layers, -1, self.hidden_size])
x_emb = self.embedding(input)
x_emb = paddle.fluid.layers.reshape(
x_emb, shape=[-1, self.num_steps, self.hidden_size])
if self.dropout is not None and self.dropout > 0.0:
x_emb = paddle.fluid.layers.dropout(
x_emb,
dropout_prob=self.dropout,
dropout_implementation='upscale_in_train')
rnn_out, last_hidden, last_cell = self.simple_lstm_rnn(x_emb, init_h,
init_c)
projection = paddle.matmul(x=rnn_out, y=self.softmax_weight)
projection = paddle.add(x=projection, y=self.softmax_bias)
loss = paddle.nn.functional.softmax_with_cross_entropy(
logits=projection, label=label, soft_label=False)
loss = paddle.fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = paddle.reduce_mean(loss, dim=[0])
loss = paddle.reduce_sum(loss)
return loss, last_hidden, last_cell
def debug_emb(self):
np.save("emb_grad", self.x_emb.gradient())
def train(place):
num_layers = 1
batch_size = 4
hidden_size = 10
num_steps = 3
init_scale = 0.1
max_epoch = 1
dropout = 0.0
vocab_size = 1000
batch_num = 200
paddle.disable_static(place)
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
ptb_model = PtbModel(
hidden_size=hidden_size,
vocab_size=vocab_size,
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale,
dropout=dropout)
sgd = paddle.optimizer.SGD(learning_rate=1e-3,
parameters=ptb_model.parameters())
for epoch_id in range(max_epoch):
total_loss = 0.0
iters = 0.0
total_sample = 0
init_hidden_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
init_cell_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
init_hidden = paddle.to_tensor(
data=init_hidden_data, dtype=None, place=None, stop_gradient=True)
init_cell = paddle.to_tensor(
data=init_cell_data, dtype=None, place=None, stop_gradient=True)
for step_id in range(batch_num):
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
y_data = y_data.reshape((-1, 1))
x_data = x_data.reshape((-1, num_steps, 1))
y_data = y_data.reshape((-1, num_steps, 1))
x = paddle.to_tensor(
data=x_data, dtype=None, place=None, stop_gradient=True)
y = paddle.to_tensor(
data=y_data, dtype=None, place=None, stop_gradient=True)
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
init_cell)
out_loss = dy_loss.numpy()
dy_loss.backward()
sgd.minimize(dy_loss)
ptb_model.clear_gradients()
total_loss += out_loss
iters += num_steps
total_sample += 1
if step_id % PRINT_STEP == 0:
if step_id == 0:
logging.info("epoch %d | step %d, loss %0.3f" %
(epoch_id, step_id, total_loss / total_sample))
avg_batch_time = time.time()
else:
speed = PRINT_STEP / (time.time() - avg_batch_time)
logging.info(
"epoch %d | step %d, loss %0.3f, speed %.3f steps/s" %
(epoch_id, step_id, total_loss / total_sample, speed))
avg_batch_time = time.time()
ret = out_loss, last_hidden.numpy(), last_cell.numpy()
paddle.enable_static()
return ret
def train_dygraph(place):
program_translator.enable(False)
return train(place)
def train_static(place):
program_translator.enable(True)
return train(place)
class TestPtb(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0) if paddle.fluid.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_check_result(self):
loss_1, hidden_1, cell_1 = train_static(self.place)
loss_2, hidden_2, cell_2 = train_dygraph(self.place)
self.assertTrue(
np.allclose(loss_1, loss_2),
msg="static loss: {} \ndygraph loss: {}".format(loss_1, loss_2))
self.assertTrue(
np.allclose(hidden_1, hidden_2),
msg="static hidden: {} \ndygraph acc1: {}".format(hidden_1,
hidden_2))
self.assertTrue(
np.allclose(cell_1, cell_2),
msg="static cell: {} \ndygraph cell: {}".format(cell_1, cell_2))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import math
import time
import unittest
import numpy as np
import paddle
from predictor_utils import PredictorTools
SEED = 2020
IMAGENET1000 = 1281167
base_lr = 0.001
momentum_rate = 0.9
l2_decay = 1e-4
# NOTE: Reduce batch_size from 8 to 2 to avoid unittest timeout.
batch_size = 2
epoch_num = 1
place = paddle.CUDAPlace(0) if paddle.fluid.is_compiled_with_cuda() \
else paddle.CPUPlace()
MODEL_SAVE_PATH = "./resnet_v2.inference.model"
DY_STATE_DICT_SAVE_PATH = "./resnet_v2.dygraph"
program_translator = paddle.jit.ProgramTranslator()
if paddle.fluid.is_compiled_with_cuda():
paddle.fluid.set_flags({'FLAGS_cudnn_deterministic': True})
def optimizer_setting(parameter_list=None):
optimizer = paddle.optimizer.Momentum(
learning_rate=base_lr,
momentum=momentum_rate,
weight_decay=paddle.fluid.regularizer.L2Decay(l2_decay),
parameters=parameter_list)
return optimizer
class ConvBNLayer(paddle.nn.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__()
self._conv = paddle.nn.Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
bias_attr=False)
self._batch_norm = paddle.nn.BatchNorm(num_filters, act=act)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(paddle.nn.Layer):
def __init__(self, num_channels, num_filters, stride, shortcut=True):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
layer_helper = paddle.fluid.layer_helper.LayerHelper(
self.full_name(), act='relu')
return layer_helper.append_activation(y)
class ResNet(paddle.nn.Layer):
def __init__(self, layers=50, class_dim=102):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_channels = [64, 256, 512, 1024]
num_filters = [64, 128, 256, 512]
self.conv = ConvBNLayer(
num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu')
self.pool2d_max = paddle.fluid.dygraph.nn.Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
self.bottleneck_block_list = []
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut))
self.bottleneck_block_list.append(bottleneck_block)
shortcut = True
self.pool2d_avg = paddle.fluid.dygraph.nn.Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
self.pool2d_avg_output = num_filters[len(num_filters) - 1] * 4 * 1 * 1
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.out = paddle.nn.Linear(
in_features=self.pool2d_avg_output,
out_features=class_dim,
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Uniform(-stdv, stdv)))
@paddle.fluid.dygraph.declarative
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.pool2d_avg(y)
y = paddle.fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
pred = self.out(y)
pred = paddle.nn.functional.softmax(pred)
return pred
def reader_decorator(reader):
def __reader__():
for item in reader():
img = np.array(item[0]).astype('float32').reshape(3, 224, 224)
label = np.array(item[1]).astype('int64').reshape(1)
yield img, label
return __reader__
def train(to_static):
"""
Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode.
"""
paddle.disable_static(place)
np.random.seed(SEED)
paddle.manual_seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
train_reader = paddle.batch(
reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
batch_size=batch_size,
drop_last=True)
data_loader = paddle.io.DataLoader.from_generator(capacity=5, iterable=True)
data_loader.set_sample_list_generator(train_reader)
resnet = ResNet()
optimizer = optimizer_setting(parameter_list=resnet.parameters())
for epoch in range(epoch_num):
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
for batch_id, data in enumerate(data_loader()):
start_time = time.time()
img, label = data
pred = resnet(img)
loss = paddle.fluid.layers.cross_entropy(input=pred, label=label)
avg_loss = paddle.mean(x=loss)
acc_top1 = paddle.metric.accuracy(input=pred, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=pred, label=label, k=5)
avg_loss.backward()
optimizer.minimize(avg_loss)
resnet.clear_gradients()
total_loss += avg_loss
total_acc1 += acc_top1
total_acc5 += acc_top5
total_sample += 1
end_time = time.time()
if batch_id % 2 == 0:
print( "epoch %d | batch step %d, loss %0.3f, acc1 %0.3f, acc5 %0.3f, time %f" % \
( epoch, batch_id, total_loss.numpy() / total_sample, \
total_acc1.numpy() / total_sample, total_acc5.numpy() / total_sample, end_time-start_time))
if batch_id == 10:
if to_static:
paddle.fluid.dygraph.jit.save(resnet, MODEL_SAVE_PATH)
else:
paddle.fluid.dygraph.save_dygraph(resnet.state_dict(),
DY_STATE_DICT_SAVE_PATH)
# avoid dataloader throw abort signaal
data_loader._reset()
break
paddle.enable_static()
return total_loss.numpy()
def predict_dygraph(data):
program_translator.enable(False)
paddle.disable_static(place)
resnet = ResNet()
model_dict, _ = paddle.fluid.dygraph.load_dygraph(DY_STATE_DICT_SAVE_PATH)
resnet.set_dict(model_dict)
resnet.eval()
pred_res = resnet(
paddle.to_tensor(
data=data, dtype=None, place=None, stop_gradient=True))
ret = pred_res.numpy()
paddle.enable_static()
return ret
def predict_static(data):
exe = paddle.static.Executor(place)
[inference_program, feed_target_names,
fetch_targets] = paddle.io.load_inference_model(
MODEL_SAVE_PATH,
executor=exe,
params_filename=paddle.fluid.dygraph.io.VARIABLE_FILENAME)
pred_res = exe.run(inference_program,
feed={feed_target_names[0]: data},
fetch_list=fetch_targets)
return pred_res[0]
def predict_dygraph_jit(data):
paddle.disable_static(place)
resnet = paddle.fluid.dygraph.jit.load(MODEL_SAVE_PATH)
resnet.eval()
pred_res = resnet(data)
ret = pred_res.numpy()
paddle.enable_static()
return ret
def predict_analysis_inference(data):
output = PredictorTools(MODEL_SAVE_PATH,
paddle.fluid.dygraph.io.VARIABLE_FILENAME, [data])
out = output()
return out
class TestResnet(unittest.TestCase):
def train(self, to_static):
program_translator.enable(to_static)
return train(to_static)
def verify_predict(self):
image = np.random.random([1, 3, 224, 224]).astype('float32')
dy_pre = predict_dygraph(image)
st_pre = predict_static(image)
dy_jit_pre = predict_dygraph_jit(image)
predictor_pre = predict_analysis_inference(image)
self.assertTrue(
np.allclose(dy_pre, st_pre),
msg="dy_pre:\n {}\n, st_pre: \n{}.".format(dy_pre, st_pre))
self.assertTrue(
np.allclose(dy_jit_pre, st_pre),
msg="dy_jit_pre:\n {}\n, st_pre: \n{}.".format(dy_jit_pre, st_pre))
self.assertTrue(
np.allclose(predictor_pre, st_pre),
msg="predictor_pre:\n {}\n, st_pre: \n{}.".format(predictor_pre,
st_pre))
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
self.assertTrue(
np.allclose(static_loss, dygraph_loss),
msg="static_loss: {} \n dygraph_loss: {}".format(static_loss,
dygraph_loss))
self.verify_predict()
def test_in_static_mode_mkldnn(self):
paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
train(to_static=True)
finally:
paddle.fluid.set_flags({'FLAGS_use_mkldnn': False})
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册