提交 f08bbbdf 编写于 作者: W wangmeng28

Add load pretrained model in SE-ResNeXt

上级 16347d3e
......@@ -100,7 +100,11 @@ def SE_ResNeXt(input, class_dim, infer=False):
return out
def train(learning_rate, batch_size, num_passes, model_save_dir='model'):
def train(learning_rate,
batch_size,
num_passes,
init_model=None,
model_save_dir='model'):
class_dim = 1000
image_shape = [3, 224, 224]
......@@ -129,6 +133,9 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if init_model is not None:
fluid.io.load_params(exe, init_model)
train_reader = paddle.batch(reader.train(), batch_size=batch_size)
test_reader = paddle.batch(reader.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
......@@ -157,4 +164,4 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'):
if __name__ == '__main__':
train(learning_rate=0.1, batch_size=8, num_passes=100)
train(learning_rate=0.1, batch_size=8, num_passes=100, init_model=None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册