提交 f13d725a 编写于 作者: T tensor-tang

add mkldnn_lrn python interface and add it to simple net

上级 343b1a96
......@@ -51,6 +51,8 @@ tmp = img_pool_layer(input=tmp,
padding=1,
pool_type=MaxPooling())
tmp = img_cmrnorm_layer(input=tmp, size=5, scale=0.0001, power=0.75)
tmp = fc_layer(input=tmp,
size=channels,
bias_attr=False,
......
......@@ -2287,11 +2287,17 @@ class Conv3DLayer(Conv3DLayerBase):
class NormLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
super(NormLayer, self).__init__(name, 'norm', 0, inputs=inputs, **xargs)
use_mkldnn = bool(int(g_command_config_args.get("use_mkldnn", 0)))
use_mkldnn = True if use_mkldnn and self.inputs[
0].norm.norm_type == 'cmrnorm-projection' else False
self.config.type = 'mkldnn_lrn' if use_mkldnn else self.config.type
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
norm_conf = self.config.inputs[input_index].norm_conf
parse_norm(self.inputs[input_index].norm, input_layer.name,
norm_conf)
norm_conf.scale = self.inputs[
input_index].norm.scale if use_mkldnn else norm_conf.scale
self.set_cnn_layer(name, norm_conf.output_y, norm_conf.output_x,
norm_conf.channels, False)
if norm_conf.norm_type == "cross-channel-norm":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册