提交 44acc751 编写于 作者: T tensor-tang

add python interface for mkldnn_conv

上级 c39b771a
...@@ -2054,18 +2054,27 @@ class ConvLayerBase(LayerBase): ...@@ -2054,18 +2054,27 @@ class ConvLayerBase(LayerBase):
if num_filters is not None: if num_filters is not None:
self.config.num_filters = num_filters self.config.num_filters = num_filters
use_mkldnn = int(g_command_config_args.get("use_mkldnn", 0))
use_gpu = int(g_command_config_args.get("use_gpu", 0)) use_gpu = int(g_command_config_args.get("use_gpu", 0))
parallel_nn = int(g_command_config_args.get("parallel_nn", 0)) parallel_nn = int(g_command_config_args.get("parallel_nn", 0))
# Automatically select cudnn_type for GPU and exconv for CPU # Automatically select cudnn_type for GPU, exconv for CPU
# and mkldnn_conv for MKLDNN
# if set type=conv, but still reserve the way user specify # if set type=conv, but still reserve the way user specify
# exconv or cudnn_conv manually. # exconv, mkldnn_conv or cudnn_conv manually.
if self.layer_type == "cudnn_conv": if self.layer_type == "cudnn_conv":
config_assert(use_gpu, "cudnn_conv only support GPU") config_assert(use_gpu, "cudnn_conv only support GPU")
if self.layer_type == "mkldnn_conv":
config_assert(use_mkldnn, "mkldnn_conv only support MKLDNN")
if (use_gpu == 1 and self.layer_type != "exconv" and if (use_gpu == 1 and self.layer_type != "exconv" and
self.layer_type != "mkldnn_conv" and
(parallel_nn == 0 or self.config.device > -1)): (parallel_nn == 0 or self.config.device > -1)):
self.layer_type = "cudnn_conv" self.layer_type = "cudnn_conv"
else:
if (use_mkldnn == 1):
self.layer_type = "mkldnn_conv"
else: else:
self.layer_type = "exconv" self.layer_type = "exconv"
# need to specify layer in config # need to specify layer in config
...@@ -2099,6 +2108,11 @@ class ConvLayer(ConvLayerBase): ...@@ -2099,6 +2108,11 @@ class ConvLayer(ConvLayerBase):
layer_type = 'exconv' layer_type = 'exconv'
@config_layer('mkldnn_conv')
class ConvLayer(ConvLayerBase):
layer_type = 'mkldnn_conv'
@config_layer('cudnn_conv') @config_layer('cudnn_conv')
class ConvLayer(ConvLayerBase): class ConvLayer(ConvLayerBase):
layer_type = 'cudnn_conv' layer_type = 'cudnn_conv'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册