未验证 提交 16c01ffb 编写于 作者: D dyning 提交者: GitHub

Merge pull request #74 from WuHaobo/googlenet

fix mixup while using GoogleNet
...@@ -16,6 +16,7 @@ from __future__ import absolute_import ...@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import sys import sys
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -36,7 +37,7 @@ def check_version(): ...@@ -36,7 +37,7 @@ def check_version():
try: try:
fluid.require_version('1.7.0') fluid.require_version('1.7.0')
except Exception as e: except Exception:
logger.error(err) logger.error(err)
sys.exit(1) sys.exit(1)
...@@ -62,7 +63,8 @@ def check_architecture(architecture): ...@@ -62,7 +63,8 @@ def check_architecture(architecture):
assert isinstance(architecture, dict), \ assert isinstance(architecture, dict), \
("the type of architecture({}) should be dict". format(architecture)) ("the type of architecture({}) should be dict". format(architecture))
assert "name" in architecture, \ assert "name" in architecture, \
("name must be in the architecture keys, just contains: {}". format(architecture.keys())) ("name must be in the architecture keys, just contains: {}". format(
architecture.keys()))
similar_names = similar_architectures(architecture["name"], similar_names = similar_architectures(architecture["name"],
get_architectures()) get_architectures())
...@@ -83,7 +85,8 @@ def check_mix(architecture, use_mix=False): ...@@ -83,7 +85,8 @@ def check_mix(architecture, use_mix=False):
err = "Cannot use mix processing in GoogLeNet, " \ err = "Cannot use mix processing in GoogLeNet, " \
"please set use_mix = False." "please set use_mix = False."
try: try:
if architecture["name"] == "GoogLeNet": assert use_mix == False if architecture["name"] == "GoogLeNet":
assert use_mix is not True
except AssertionError: except AssertionError:
logger.error(err) logger.error(err)
sys.exit(1) sys.exit(1)
......
...@@ -100,7 +100,7 @@ def check_config(config): ...@@ -100,7 +100,7 @@ def check_config(config):
architecture = config.get('ARCHITECTURE') architecture = config.get('ARCHITECTURE')
check.check_architecture(architecture) check.check_architecture(architecture)
use_mix = config.get('use_mix') use_mix = config.get('use_mix', False)
check.check_mix(architecture, use_mix) check.check_mix(architecture, use_mix)
classes_num = config.get('classes_num') classes_num = config.get('classes_num')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册