未验证 提交 c2034c50 编写于 作者: L littletomatodonkey 提交者: GitHub

Merge pull request #147 from littletomatodonkey/fix_effnet

fix EfficientNet
mode: 'train' mode: 'train'
ARCHITECTURE: ARCHITECTURE:
name: "EfficientNetB0" name: "EfficientNetB0"
drop_connect_rate: 0.1 params:
padding_type : "SAME" is_test: False
padding_type : "SAME"
override_params:
drop_connect_rate: 0.1
pretrained_model: "" pretrained_model: ""
model_save_dir: "./output/" model_save_dir: "./output/"
classes_num: 1000 classes_num: 1000
......
...@@ -103,7 +103,8 @@ def create_model(architecture, image, classes_num, is_train): ...@@ -103,7 +103,8 @@ def create_model(architecture, image, classes_num, is_train):
""" """
name = architecture["name"] name = architecture["name"]
params = architecture.get("params", {}) params = architecture.get("params", {})
params['is_test'] = not is_train if "is_test" in params:
params['is_test'] = not is_train
model = architectures.__dict__[name](**params) model = architectures.__dict__[name](**params)
out = model.net(input=image, class_dim=classes_num) out = model.net(input=image, class_dim=classes_num)
return out return out
...@@ -418,7 +419,6 @@ def run(dataloader, ...@@ -418,7 +419,6 @@ def run(dataloader,
Returns: Returns:
""" """
print(fetchs)
fetch_list = [f[0] for f in fetchs.values()] fetch_list = [f[0] for f in fetchs.values()]
metric_list = [f[1] for f in fetchs.values()] metric_list = [f[1] for f in fetchs.values()]
for m in metric_list: for m in metric_list:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册