提交 eff17a68 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #3049 from lcy-seso/fix_v2_bachnorm_parse

enable v2 automatically sets using cudnn-batch norm.
......@@ -2055,8 +2055,7 @@ class BatchNormLayer(LayerBase):
# Automatically select cudnn_batch_norm for GPU and batch_norm for CPU.
# Also based on cudnn version.
use_cudnn = use_gpu and batch_norm_type != "batch_norm" and \
((not parallel_nn) or self.config.device > -1) and \
cudnn_version >= 4007
((not parallel_nn) or self.config.device > -1)
self.layer_type = "cudnn_batch_norm" if use_cudnn else "batch_norm"
super(BatchNormLayer, self).__init__(
name, self.layer_type, 0, inputs=inputs, **xargs)
......
......@@ -34,6 +34,7 @@ import minibatch
import plot
import image
import model
import paddle.trainer.config_parser as cp
__all__ = [
'optimizer',
......@@ -58,6 +59,8 @@ __all__ = [
'model',
]
cp.begin_parse()
def init(**kwargs):
import py_paddle.swig_paddle as api
......@@ -73,6 +76,11 @@ def init(**kwargs):
for key in args_dict.keys():
args.append('--%s=%s' % (key, str(args_dict[key])))
if 'use_gpu' in kwargs:
cp.g_command_config_args['use_gpu'] = kwargs['use_gpu']
assert 'parallel_nn' not in kwargs, ("currently 'parallel_nn' is not "
"supported in v2 APIs.")
api.initPaddle(*args)
......
......@@ -324,6 +324,3 @@ def parse_network(output_layers, extra_layers=None):
def get_layer(name):
return config_base.__layer_map__.get(name)
cp.begin_parse()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册