未验证 提交 020d7573 编写于 作者: H Hongyu Liu 提交者: GitHub

Add resnet eval mode (#2336)

* add ptb lm; test=develop

* add dynamic ocr recognition; test=develop

* fix renet bug; test=develop

* fix ocr save model; test=develop

* fix format; test=develop
上级 225db53d
......@@ -20,12 +20,37 @@ from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, FC
from paddle.fluid.dygraph.base import to_variable
batch_size = 8
epoch = 10
from paddle.fluid import framework
import math
import sys
batch_size = 32
epoch = 120
IMAGENET1000 = 1281167
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4
def optimizer_setting():
return fluid.optimizer.SGD(learning_rate=0.01)
total_images = IMAGENET1000
step = int(math.ceil(float(total_images) / batch_size))
epochs = [30, 60, 90]
bd = [step * e for e in epochs]
lr = []
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=momentum_rate,
regularization=fluid.regularizer.L2Decay(l2_decay))
return optimizer
class ConvBNLayer(fluid.dygraph.Layer):
......@@ -186,19 +211,81 @@ class ResNet(fluid.dygraph.Layer):
return y
def eval(model, data):
model.eval()
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
for batch_id, data in enumerate(data()):
dy_x_data = np.array(
[x[0].reshape(3, 224, 224) for x in data]).astype('float32')
if len(np.array([x[1] for x in data]).astype('int64')) != batch_size:
continue
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
batch_size, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label._stop_gradient = True
out = model(img)
#loss = fluid.layers.cross_entropy(input=out, label=label)
#avg_loss = fluid.layers.mean(x=loss)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
#dy_out = avg_loss.numpy()
#total_loss += dy_out
total_acc1 += acc_top1.numpy()
total_acc5 += acc_top5.numpy()
total_sample += 1
# print("epoch id: %d, batch step: %d, loss: %f" % (eop, batch_id, dy_out))
if batch_id % 10 == 0:
print("test | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \
( batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))
print("final eval loss %0.3f acc1 %0.3f acc5 %0.3f" % \
(total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))
def train_resnet():
with fluid.dygraph.guard():
resnet = ResNet("resnet")
optimizer = optimizer_setting()
train_reader = paddle.batch(
paddle.dataset.flowers.train(),
batch_size=batch_size)
paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size)
#file_name = './model/epoch_0.npz'
#model_data = np.load( file_name )
for eop in range(epoch):
resnet.train()
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
#dict_state = resnet.state_dict()
#resnet.load_dict( model_data )
print("load finished")
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array(
[x[0].reshape(3, 224, 224) for x in data]).astype('float32')
if len(np.array([x[1] for x in data]).astype('int64')) != batch_size:
if len(np.array([x[1]
for x in data]).astype('int64')) != batch_size:
continue
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
batch_size, 1)
......@@ -211,13 +298,33 @@ def train_resnet():
loss = fluid.layers.cross_entropy(input=out, label=label)
avg_loss = fluid.layers.mean(x=loss)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
dy_out = avg_loss.numpy()
avg_loss.backward()
optimizer.minimize(avg_loss)
resnet.clear_gradients()
print("epoch id: %d, batch step: %d, loss: %f" % (eop, batch_id, dy_out))
framework._dygraph_tracer_._clear_ops()
total_loss += dy_out
total_acc1 += acc_top1.numpy()
total_acc5 += acc_top5.numpy()
total_sample += 1
#print("epoch id: %d, batch step: %d, loss: %f" % (eop, batch_id, dy_out))
if batch_id % 10 == 0:
print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \
( eop, batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))
print("epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \
(eop, batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))
resnet.eval()
eval(resnet, test_reader)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册