diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index a9e1d6d2e06d56f837690ec95fa8f8d41a90725f..7c32eb0069f4075d72cd4c3654c83e3d5c98fb1c 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -2286,8 +2286,15 @@ class NormLayer(LayerBase): @config_layer('pool') class PoolLayer(LayerBase): + layer_type = 'pool' + def __init__(self, name, inputs, ceil_mode=True, **xargs): - super(PoolLayer, self).__init__(name, 'pool', 0, inputs=inputs, **xargs) + use_mkldnn = int(g_command_config_args.get("use_mkldnn", 0)) + if self.layer_type == "mkldnn_pool": + config_assert(use_mkldnn, "mkldnn_pool only support MKLDNN") + self.layer_type = 'mkldnn_pool' if use_mkldnn else 'pool' + super(PoolLayer, self).__init__( + name, self.layer_type, 0, inputs=inputs, **xargs) for input_index in xrange(len(self.inputs)): input_layer = self.get_input_layer(input_index) pool_conf = self.config.inputs[input_index].pool_conf @@ -2297,6 +2304,11 @@ class PoolLayer(LayerBase): pool_conf.channels) +@config_layer('mkldnn_pool') +class MKLDNNPoolLayer(PoolLayer): + layer_type = 'mkldnn_pool' + + @config_layer('pool3d') class Pool3DLayer(LayerBase): def __init__(self, name, inputs, ceil_mode=True, **xargs):