未验证 提交 f0c08fc0 编写于 作者: H hysunflower 提交者: GitHub

add_max_iter_for_resnet (#4617)

上级 0952a2de
...@@ -25,6 +25,7 @@ from paddle.fluid import framework ...@@ -25,6 +25,7 @@ from paddle.fluid import framework
import math import math
import sys import sys
import time
IMAGENET1000 = 1281167 IMAGENET1000 = 1281167
base_lr = 0.1 base_lr = 0.1
...@@ -45,6 +46,9 @@ def parse_args(): ...@@ -45,6 +46,9 @@ def parse_args():
parser.add_argument( parser.add_argument(
"-b", "--batch_size", default=32, type=int, help="set epoch") "-b", "--batch_size", default=32, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce") parser.add_argument("--ce", action="store_true", help="run ce")
# NOTE:used in benchmark
parser.add_argument("--max_iter", default=0, type=int, help="the max iters to train, used in benchmark")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -310,6 +314,9 @@ def train_resnet(): ...@@ -310,6 +314,9 @@ def train_resnet():
#file_name = './model/epoch_0.npz' #file_name = './model/epoch_0.npz'
#model_data = np.load( file_name ) #model_data = np.load( file_name )
#NOTE: used in benchmark
total_batch_num = 0
for eop in range(epoch): for eop in range(epoch):
resnet.train() resnet.train()
...@@ -325,6 +332,12 @@ def train_resnet(): ...@@ -325,6 +332,12 @@ def train_resnet():
print("load finished") print("load finished")
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
#NOTE: used in benchmark
if args.max_iter and total_batch_num == args.max_iter:
return
batch_start = time.time()
dy_x_data = np.array( dy_x_data = np.array(
[x[0].reshape(3, 224, 224) for x in data]).astype('float32') [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
if len(np.array([x[1] if len(np.array([x[1]
...@@ -356,15 +369,18 @@ def train_resnet(): ...@@ -356,15 +369,18 @@ def train_resnet():
optimizer.minimize(avg_loss) optimizer.minimize(avg_loss)
resnet.clear_gradients() resnet.clear_gradients()
batch_end = time.time()
train_batch_cost = batch_end - batch_start
total_loss += dy_out total_loss += dy_out
total_acc1 += acc_top1.numpy() total_acc1 += acc_top1.numpy()
total_acc5 += acc_top5.numpy() total_acc5 += acc_top5.numpy()
total_sample += 1 total_sample += 1
total_batch_num = total_batch_num + 1 #this is for benchmark
#print("epoch id: %d, batch step: %d, loss: %f" % (eop, batch_id, dy_out)) #print("epoch id: %d, batch step: %d, loss: %f" % (eop, batch_id, dy_out))
if batch_id % 10 == 0: if batch_id % 10 == 0:
print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \ print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f, batch cost: %.5f" % \
( eop, batch_id, total_loss / total_sample, \ ( eop, batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample)) total_acc1 / total_sample, total_acc5 / total_sample, train_batch_cost))
if args.ce: if args.ce:
print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample)) print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册