未验证 提交 8e876f9e 编写于 作者: W Wang Meng 提交者: GitHub

Merge pull request #601 from will-am/fix_seresnext

Refine SE-ResNeXt
import os
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import reader
......@@ -100,7 +99,11 @@ def SE_ResNeXt(input, class_dim, infer=False):
return out
def train(learning_rate, batch_size, num_passes, model_save_dir='model'):
def train(learning_rate,
batch_size,
num_passes,
init_model=None,
model_save_dir='model'):
class_dim = 1000
image_shape = [3, 224, 224]
......@@ -129,6 +132,9 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if init_model is not None:
fluid.io.load_persistables(exe, init_model)
train_reader = paddle.batch(reader.train(), batch_size=batch_size)
test_reader = paddle.batch(reader.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
......@@ -145,16 +151,18 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'):
test_accuracy.reset(exe)
for data in test_reader():
out, acc = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=[avg_cost] + test_accuracy.metrics)
loss, acc = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=[avg_cost] + test_accuracy.metrics)
test_pass_acc = test_accuracy.eval(exe)
print("End pass {0}, train_acc {1}, test_acc {2}".format(
pass_id, pass_acc, test_pass_acc))
model_path = os.path.join(model_save_dir, str(pass_id))
fluid.io.save_inference_model(model_path, ['image'], [out], exe)
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
if __name__ == '__main__':
train(learning_rate=0.1, batch_size=8, num_passes=100)
train(learning_rate=0.1, batch_size=8, num_passes=100, init_model=None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册