Commit 9b5a2831
Authored Dec 11, 2018 by TensorFlower Gardener

Merge pull request #24086 from Intel-tensorflow:nhasabni/fusedconv

PiperOrigin-RevId: 225099426
Parents: 5269f8ac, 7250f853

Showing 5 changed files with 544 additions and 19 deletions (+544, -19)
tensorflow/core/graph/mkl_layout_pass.cc        +63  -3
tensorflow/core/graph/mkl_layout_pass_test.cc   +127 -1
tensorflow/core/kernels/mkl_conv_ops.cc         +74  -5
tensorflow/core/kernels/mkl_fused_ops_test.cc   +253 -10
tensorflow/core/ops/mkl_nn_ops.cc               +27  -0
tensorflow/core/graph/mkl_layout_pass.cc
@@ -260,6 +260,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
     csinfo_.fused_batch_norm = "FusedBatchNorm";
     csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
+    csinfo_.fused_conv2d = "_FusedConv2D";
     csinfo_.identity = "Identity";
     csinfo_.lrn = "LRN";
     csinfo_.lrn_grad = "LRNGrad";
@@ -274,6 +275,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
     csinfo_.mkl_conv2d_grad_filter_with_bias = "_MklConv2DBackpropFilterWithBias";
+    csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
     csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
     csinfo_.pad = "Pad";
     csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
@@ -380,6 +382,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     {csinfo_.fused_batch_norm_grad,
      mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
      CopyAttrsFusedBatchNorm, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
+                      CopyAttrsFusedConv2D, FusedConv2DRewrite});
     rinfo_.push_back({csinfo_.identity,
                       mkl_op_registry::GetMklOpName(csinfo_.identity),
                       CopyAttrsDataType, AlwaysRewrite});
@@ -665,6 +669,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string conv3d_grad_filter;
     string fused_batch_norm;
     string fused_batch_norm_grad;
+    string fused_conv2d;
     string identity;
     string lrn;
     string lrn_grad;
@@ -679,6 +684,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string mkl_conv2d_grad_filter;
     string mkl_conv2d_grad_filter_with_bias;
     string mkl_conv2d_with_bias;
+    string mkl_fused_conv2d;
     string mkl_pad_with_conv2d;
     string mul;
     string pad;
@@ -1174,6 +1180,23 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     return false;
   }

+  static bool FusedConv2DRewrite(const Node* n) {
+    // MKL DNN currently doesn't support all fusions that grappler fuses
+    // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
+    // it includes those we support.
+    DataType T;
+    if (!GetNodeAttr(n->def(), "T", &T).ok() ||
+        !mkl_op_registry::IsMklOp(csinfo_.mkl_fused_conv2d, T)) {
+      return false;
+    }
+
+    std::vector<string> fused_ops;
+    TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
+    return (fused_ops == std::vector<string>{"BiasAdd"} ||
+            fused_ops == std::vector<string>{"Relu"} ||
+            fused_ops == std::vector<string>{"BiasAdd", "Relu"});
+  }
+
   // Rewrites input node to a new node specified by its matching rewrite info.
   //
   // Method first searches matching rewrite info for input node and then
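Note that the whitelist in FusedConv2DRewrite is an exact match on the fused_ops list, not a subset check, so order matters ({"Relu", "BiasAdd"} would be rejected). A minimal standalone sketch of the same pattern matching, with a hypothetical IsRewriteCandidate helper standing in for the member function:

#include <string>
#include <vector>

// Hypothetical stand-in for the fusion check in FusedConv2DRewrite: accepts
// exactly the three fused_ops patterns the MKL path currently supports.
bool IsRewriteCandidate(const std::vector<std::string>& fused_ops) {
  using Pattern = std::vector<std::string>;
  return fused_ops == Pattern{"BiasAdd"} || fused_ops == Pattern{"Relu"} ||
         fused_ops == Pattern{"BiasAdd", "Relu"};
}

// IsRewriteCandidate({"BiasAdd", "Relu"}) -> true
// IsRewriteCandidate({"FusedBatchNorm"})  -> false: the node stays
// _FusedConv2D and runs on the default (Eigen) kernel.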
@@ -1335,6 +1358,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                                 bool change_format = false);
   static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb,
                                       bool change_format = false);
+  static void CopyAttrsFusedConv2D(const Node* orig_node, NodeBuilder* nb,
+                                   bool change_format = false);
   static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
                            bool change_format = false);
   static void CopyAttrsPadWithConv2D(const Node* orig_node, NodeBuilder* nb,
@@ -1554,12 +1579,13 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
   CHECK_NOTNULL(filter_node);

   // Now check which nodes receive from filter_node. Filter feeds as
-  // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias.
+  // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
+  // _MklFusedConv2D.
   for (const Edge* e : filter_node->out_edges()) {
     if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
          // add check for mkl_pad_with_conv2d
          e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
-         e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias) &&
+         e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
+         e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
         e->dst_input() == kConv2DFilterInputSlotIdx
         /* filter is 2nd input of Conv2D and _MklConv2D. */) {
       if (conv2d_node != nullptr) {
@@ -2234,6 +2260,39 @@ void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
   nb->Attr("is_training", is_training);
 }

+void MklLayoutRewritePass::CopyAttrsFusedConv2D(const Node* orig_node,
+                                                NodeBuilder* nb,
+                                                bool change_format) {
+  DataType T;
+  int num_args;
+  float epsilon;
+  string data_format;
+  string padding;
+  std::vector<int32> strides;
+  std::vector<int32> dilations;
+  std::vector<string> fused_ops;
+
+  // Get all attributes from old node.
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_args", &num_args));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "fused_ops", &fused_ops));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
+
+  // Add attributes to new node.
+  nb->Attr("T", T);
+  nb->Attr("num_args", num_args);
+  nb->Attr("strides", strides);
+  nb->Attr("padding", padding);
+  nb->Attr("data_format", data_format);
+  nb->Attr("dilations", dilations);
+  nb->Attr("fused_ops", fused_ops);
+  nb->Attr("epsilon", epsilon);
+}
+
 //////////////////////////////////////////////////////////////////////////
 // Helper functions related to node merge pass
 //////////////////////////////////////////////////////////////////////////
@@ -2881,6 +2940,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
     if (n->type_string() != csinfo_.conv2d_with_bias &&
         n->type_string() != csinfo_.pad_with_conv2d &&
         n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
+        n->type_string() != csinfo_.fused_conv2d &&
         !mkl_op_registry::IsMklOp(
             mkl_op_registry::GetMklOpName(n->type_string()), T)) {
       return nullptr;
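The extra fused_conv2d clause is needed because the generic fallback derives the MKL op name mechanically. Assuming GetMklOpName simply prepends the "_Mkl" prefix (a sketch, not the registry's exact code), the derived name would not match the registered op:

#include <string>

// Sketch: the "_Mkl" prefix applied to an original op name.
std::string MklOpNameSketch(const std::string& op) { return "_Mkl" + op; }

// MklOpNameSketch("Conv2D")       -> "_MklConv2D"       (registered)
// MklOpNameSketch("_FusedConv2D") -> "_Mkl_FusedConv2D" (not registered;
// the MKL variant is named "_MklFusedConv2D", so _FusedConv2D must be
// special-cased above).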
tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -133,6 +133,7 @@ REGISTER_OP("Input").Output("o: float").SetIsStateful();
REGISTER_OP
(
"InputList"
).
Output
(
"o: N * float"
).
Attr
(
"N: int"
).
SetIsStateful
();
REGISTER_OP
(
"HalfInput"
).
Output
(
"o: half"
).
SetIsStateful
();
REGISTER_OP
(
"Int32Input"
).
Output
(
"o: int32"
).
SetIsStateful
();
REGISTER_OP
(
"DoubleInput"
).
Output
(
"o: double"
).
SetIsStateful
();
REGISTER_OP
(
"_MklInput"
).
Output
(
"o: uint8"
).
SetIsStateful
();
REGISTER_OP
(
"_MklInput2"
)
.
Output
(
"o: uint8"
)
@@ -142,7 +143,7 @@ REGISTER_OP("Output2").Input("i: float").Input("i1: float").SetIsStateful();
 REGISTER_OP("Output").Input("i: float").SetIsStateful();

 /////////////////////////////////////////////////////////////////////
-// Unit tests related to node merge optiimization
+// Unit tests related to node merge optimization
 /////////////////////////////////////////////////////////////////////

 TEST_F(MklLayoutPassTest, Basic) {
@@ -1096,6 +1097,131 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
             "A->C;B->C:1;B->D;C->D:1");
 }

+// Rewrite test for _FusedConv2D Op with BiasAdd fusion
+TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive1) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: '_FusedConv2D'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'num_args' value { i: 1 } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
+      " attr { key: 'epsilon' value { f: 0.001 }}"
+      " input: ['A', 'B', 'C']}"
+      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'C'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklFusedConv2D);DMT/_0(Const);"
+            "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;"
+            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
+            "A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;"
+            "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
+// Rewrite test for _FusedConv2D Op with Relu fusion
+TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: '_FusedConv2D'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'num_args' value { i: 1 } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'fused_ops' value { list: {s: 'Relu'} } }"
+      " attr { key: 'epsilon' value { f: 0.001 }}"
+      " input: ['A', 'B', 'C']}"
+      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'C'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklFusedConv2D);DMT/_0(Const);"
+            "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;"
+            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
+            "A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;"
+            "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
+// Rewrite test for _FusedConv2D Op with BiasAdd+Relu fusion
+TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive3) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: '_FusedConv2D'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'num_args' value { i: 1 } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'fused_ops'"
+      "        value { list: {s: 'BiasAdd', s: 'Relu'} } }"
+      " attr { key: 'epsilon' value { f: 0.001 }}"
+      " input: ['A', 'B', 'C']}"
+      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'C'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklFusedConv2D);DMT/_0(Const);"
+            "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;"
+            "A:control->DMT/_0:control;A:control->DMT/_1:control;"
+            "A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;"
+            "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
+// Rewrite test for _FusedConv2D Op with unsupported fusion
+TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Negative1) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: '_FusedConv2D'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'num_args' value { i: 1 } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'fused_ops' value { list: {s: 'Unsupported'} } }"
+      " attr { key: 'epsilon' value { f: 0.001 }}"
+      " input: ['A', 'B', 'C']}"
+      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'C'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_FusedConv2D);E(Zeta)|A->D;"
+            "B->D:1;C->D:2;C->E:1;D->E");
+}
+
+// Rewrite test for _FusedConv2D Op with unsupported type
+TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Negative2) {
+  InitGraph(
+      "node { name: 'A' op: 'DoubleInput'}"
+      "node { name: 'B' op: 'DoubleInput'}"
+      "node { name: 'C' op: 'DoubleInput'}"
+      "node { name: 'D' op: '_FusedConv2D'"
+      " attr { key: 'T' value { type: DT_DOUBLE } }"
+      " attr { key: 'num_args' value { i: 1 } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
+      " attr { key: 'epsilon' value { f: 0.001 }}"
+      " input: ['A', 'B', 'C']}"
+      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_DOUBLE } }"
+      " input: ['D', 'C'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(DoubleInput);B(DoubleInput);C(DoubleInput);"
+            "D(_FusedConv2D);E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E");
+}
+
 TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
   InitGraph(
       "node { name: 'A' op: 'Input'}"
tensorflow/core/kernels/mkl_conv_ops.cc
@@ -1022,7 +1022,7 @@ class MklConvOp : public OpKernel {
       // get a conv2d fwd from primitive pool
       MklConvFwdPrimitive<float, Tinput, Tfilter, Tbias, Ttemp_output>*
           conv_fwd = nullptr;
-      if (biasEnabled) {
+      if (fuse_biasadd_) {
         memory::dims bias_dims = {};
         conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
         MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
@@ -1094,7 +1094,7 @@ class MklConvOp : public OpKernel {
       }

       // execute convolution
-      if (biasEnabled) {
+      if (fuse_biasadd_) {
         const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
         Tbias* bias_data =
             this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
@@ -1154,6 +1154,12 @@ class MklConvOp : public OpKernel {
   }

  protected:
+  void set_fuse_biasadd(bool fuse_biasadd) { fuse_biasadd_ = fuse_biasadd; }
+  void set_fuse_relu(bool fuse_relu) { fuse_relu_ = fuse_relu; }
+
+  // This method is for the base class MklConvOp, which handles the
+  // floating point implementation of Conv. The quantized conv implementations
+  // will use overridden versions of this method.
   virtual void ExtendConvFwdParams(OpKernelContext* context,
                                    MklConvFwdParams& params) {
     // Create a string from data types of input, filter, bias, and output.
@@ -1161,6 +1167,11 @@ class MklConvOp : public OpKernel {
     params.dtypes.append(typeid(Tfilter).name());
     params.dtypes.append(typeid(Tbias).name());
     params.dtypes.append(typeid(Toutput).name());
+
+    // Add fusions as post ops
+    // Note: Fusion of BiasAdd is handled directly inside MklConvOp by
+    // checking fuse_biasadd_ flag.
+    if (fuse_relu_) params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
   }

   virtual Tbias* GetBiasHandle(
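The triple {1.0, 0.0, 0.0} stored with the "relu" post-op corresponds to (scale, alpha, beta) of an MKL-DNN eltwise post-op. An illustrative sketch of how such a parameter set maps onto the MKL-DNN 0.x API when the primitive is built (the actual translation lives in MklConvFwdPrimitive, which is not part of this diff):

#include "mkldnn.hpp"

// Sketch: build a primitive_attr carrying a ReLU post-op so the convolution
// primitive applies ReLU to its output without a separate kernel launch.
mkldnn::primitive_attr MakeReluPostOpAttr() {
  mkldnn::post_ops ops;
  ops.append_eltwise(/*scale=*/1.0f, mkldnn::algorithm::eltwise_relu,
                     /*alpha=*/0.0f, /*beta=*/0.0f);
  mkldnn::primitive_attr attr;
  attr.set_post_ops(ops);  // later passed to convolution_forward::primitive_desc
  return attr;
}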
@@ -1168,7 +1179,7 @@ class MklConvOp : public OpKernel {
       std::shared_ptr<mkldnn::convolution_forward::primitive_desc>&
           conv2d_fwd_pd,
       const Tensor& bias_tensor) {
-    if (biasEnabled) {
+    if (fuse_biasadd_) {
       return static_cast<Tbias*>(
           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
     } else {
@@ -1214,6 +1225,11 @@ class MklConvOp : public OpKernel {
   std::vector<int32> dilations_;
   Padding padding_;
   TensorFormat data_format_;

+  // Initialize to values the template is instantiated with
+  bool fuse_biasadd_ = biasEnabled;
+  bool fuse_relu_ = false;
+
   const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2;
   const int kInputIndex_Pad = 2;
   const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
@@ -1267,12 +1283,12 @@ class MklConvOp : public OpKernel {
     // Create convolution primitive and add it to net.
     std::vector<primitive> net;
     if (bias) {
-      DCHECK(biasEnabled);
+      DCHECK(fuse_biasadd_);
       net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
                                         filter->GetOpMem(), bias->GetOpMem(),
                                         output->GetOpMem()));
     } else {
-      DCHECK(!biasEnabled);
+      DCHECK(!fuse_biasadd_);
       net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
                                         filter->GetOpMem(),
                                         output->GetOpMem()));
@@ -1282,6 +1298,49 @@ class MklConvOp : public OpKernel {
   }
 };

+// Base class for fused convolution forward operations
+template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
+          typename Toutput, typename Ttemp_output>
+class MklFusedConvOp
+    : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
+                       int32, false, false> {
+ public:
+  explicit MklFusedConvOp(OpKernelConstruction* context)
+      : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
+                  int32, false, false>(context) {
+    // Since we came here through the registration of _MklFusedConv2D, get
+    // all information from 'fused_ops' and 'num_args'
+    std::vector<string> fused_ops;
+    OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
+
+    int num_args;
+    OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
+    OP_REQUIRES(context, !fused_ops.empty(),
+                errors::InvalidArgument(
+                    "Fused Conv2D must have at least one fused op."));
+
+    if (fused_ops == std::vector<string>{"BiasAdd"}) {
+      this->set_fuse_biasadd(true);
+      OP_REQUIRES(context, num_args == 1,
+                  errors::InvalidArgument(
+                      "Fused Conv2D must have one extra argument: bias."));
+    } else if (fused_ops == std::vector<string>{"Relu"}) {
+      this->set_fuse_relu(true);
+    } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
+      this->set_fuse_biasadd(true);
+      this->set_fuse_relu(true);
+      OP_REQUIRES(context, num_args == 1,
+                  errors::InvalidArgument(
+                      "Fused Conv2D must have one extra argument: bias."));
+    } else {
+      OP_REQUIRES(context, false,
+                  errors::Unimplemented("Fusion is not implemented: [",
+                                        str_util::Join(fused_ops, ","), "]"));
+    }
+  }
+
+  virtual ~MklFusedConvOp() {}
+};
+
 // We create new class for each version of Quantized Convolution and inherit
 // from the FP32 version of the base class
 template <typename Device, typename Tbias, typename Toutput,
@@ -1881,6 +1940,16 @@ REGISTER_KERNEL_BUILDER(
 TF_CALL_float(REGISTER_MKL_CPU_2D);

+#define REGISTER_MKL_CPU_2D_FUSED(T)                                \
+  REGISTER_KERNEL_BUILDER(Name("_MklFusedConv2D")                   \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklFusedConvOp<CPUDevice, T, T, T, T, T>);
+
+// We check the fused_ops attributes to decide if bias is enabled or not.
+TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED);
+
 // Register 3D operations
 #define REGISTER_MKL_CPU_3D(T) \
   REGISTER_KERNEL_BUILDER(     \
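In the registration above, the six type arguments of MklFusedConvOp<CPUDevice, T, T, T, T, T> line up with <Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output> from the class template, and TF_CALL_float instantiates the all-float kernel only, matching the op's "T: {float}" constraint. A hypothetical alias spelling out that mapping:

// Hypothetical alias, not in the actual source; makes the template-argument
// mapping used by REGISTER_MKL_CPU_2D_FUSED explicit.
template <typename T>
using FusedConv2DCpuKernel =
    MklFusedConvOp<CPUDevice, /*Tinput=*/T, /*Tfilter=*/T, /*Tbias=*/T,
                   /*Toutput=*/T, /*Ttemp_output=*/T>;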
tensorflow/core/kernels/mkl_fused_ops_test.cc
@@ -32,17 +32,17 @@ limitations under the License.
 namespace tensorflow {

-// Helper class for converting MKL tesnors to TF tensors and comparing to
+// Helper class for converting MKL tensors to TF tensors and comparing to
 // expected values
 static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
 static const TensorShape dummy_shape({8});

+template <typename T>
 class ConvMklToTF : public OpsTestBase {
  public:
-  template <typename T>
-  void ConvertAndCompare(DataType dtype, const Tensor& first,
-                         const Tensor& second, const Tensor& expected) {
+  void PerformConversion(DataType dtype, const Tensor& tensor,
+                         const Tensor& mkl_meta_tensor, Tensor* output) {
     // Create an MKL to TF conversion node and execute it
     TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
                      .Input(FakeInput(dtype))  // Input

@@ -51,16 +51,259 @@ class ConvMklToTF : public OpsTestBase {
                      .Attr("_kernel", "MklOp")
                      .Finalize(node_def()));
     TF_EXPECT_OK(InitOp());
-    AddInputFromArray<T>(first.shape(), first.flat<T>());
-    AddInputFromArray<uint8>(second.shape(), second.flat<uint8>());
+    AddInputFromArray<T>(tensor.shape(), tensor.flat<T>());
+    AddInputFromArray<uint8>(mkl_meta_tensor.shape(),
+                             mkl_meta_tensor.flat<uint8>());
     TF_ASSERT_OK(RunOpKernel());

-    const Tensor& output = *GetOutput(0);
+    *output = *GetOutput(0);
+  }
+
+  void ConvertAndCompare(DataType dtype, const Tensor& tensor,
+                         const Tensor& mkl_meta_tensor,
+                         const Tensor& expected) {
+    Tensor output;
+    PerformConversion(dtype, tensor, mkl_meta_tensor, &output);
     test::ExpectTensorNear<T>(expected, output, 1e-5);
   }
-  void TestBody(){};
+  void TestBody() {}
 };
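A minimal usage sketch of the refactored helper (the tensor names here are hypothetical; in the fused-conv tests below, PerformConversion fetches the converted output for further checks, while FusedPadConvOpTest keeps using ConvertAndCompare):

// Inside a test body, given an MKL op's data output and metadata output:
ConvMklToTF<float> conv_comp;
Tensor converted;
conv_comp.PerformConversion(DT_FLOAT, output_tensor, output_meta_tensor,
                            &converted);  // converted now holds TF-layout data
// ...or convert and check against a reference in one step:
conv_comp.ConvertAndCompare(DT_FLOAT, output_tensor, output_meta_tensor,
                            expected_tensor);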
+// Testing MKL's fused convolution ops
+template <typename T>
+class MklFusedConv2DOpTest : public OpsTestBase {
+ protected:
+  static constexpr int kDepth = 3;
+  static constexpr int kImageWidth = 32;
+  static constexpr int kImageHeight = 32;
+  static constexpr int kImageBatchCount = 8;
+
+  using BiasAddGraphRunner =
+      std::function<void(const Tensor& input_data, const Tensor& filter_data,
+                         const Tensor& bias_data, Tensor* out)>;
+
+  // Runs a Tensorflow graph defined by the root scope, and fetches the result
+  // of 'fetch' node into the output Tensor.
+  void RunAndFetch(const tensorflow::Scope& root, const string& fetch,
+                   Tensor* output) {
+    tensorflow::GraphDef graph;
+    TF_ASSERT_OK(root.ToGraphDef(&graph));
+
+    std::unique_ptr<tensorflow::Session> session(
+        tensorflow::NewSession(tensorflow::SessionOptions()));
+    TF_ASSERT_OK(session->Create(graph));
+
+    std::vector<Tensor> unfused_tensors;
+    TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors));
+
+    *output = unfused_tensors[0];
+  }
+
+  void RunConv2DWithBias(const Tensor& input_data, const Tensor& filter_data,
+                         const Tensor& bias_data, Tensor* output,
+                         int stride = 1) {
+    auto root = tensorflow::Scope::NewRootScope();
+
+    auto conv = ops::Conv2D(
+        root.WithOpName("conv"),
+        ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
+        ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
+        {1, stride, stride, 1}, "SAME");
+
+    auto with_bias = ops::BiasAdd(
+        root.WithOpName("with_bias"), conv,
+        ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
+
+    RunAndFetch(root, "with_bias", output);
+  }
+
+  void RunConv2DWithBiasAndRelu(const Tensor& input_data,
+                                const Tensor& filter_data,
+                                const Tensor& bias_data, Tensor* output,
+                                int stride = 1) {
+    auto root = tensorflow::Scope::NewRootScope();
+
+    auto conv = ops::Conv2D(
+        root.WithOpName("conv"),
+        ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
+        ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
+        {1, stride, stride, 1}, "SAME");
+
+    auto with_bias = ops::BiasAdd(
+        root.WithOpName("with_bias"), conv,
+        ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
+
+    auto with_relu = ops::Relu(root.WithOpName("with_relu"), with_bias);
+
+    RunAndFetch(root, "with_relu", output);
+  }
+
+  void RunMklFusedConv2DOp(const Tensor& image, const Tensor& filter,
+                           const std::vector<Tensor>& args,
+                           const std::vector<string>& fused_ops,
+                           Tensor* output, int stride = 1) {
+    DataType dtype = DataTypeToEnum<T>::v();
+    int num_args = static_cast<int>(args.size());
+
+    TF_EXPECT_OK(NodeDefBuilder("fused_conv_op", "_MklFusedConv2D")
+                     .Input(FakeInput(dtype))
+                     .Input(FakeInput(dtype))
+                     .Attr("num_args", num_args)
+                     .Input(FakeInput(num_args, dtype))
+                     .Input(FakeInput(DT_UINT8))
+                     .Input(FakeInput(DT_UINT8))
+                     .Input(FakeInput(num_args, DT_UINT8))
+                     .Attr("T", dtype)
+                     .Attr("strides", {1, stride, stride, 1})
+                     .Attr("padding", "SAME")
+                     .Attr("fused_ops", fused_ops)
+                     .Attr("_kernel", "MklOp")
+                     .Finalize(node_def()));
+
+    TF_EXPECT_OK(InitOp());
+
+    AddInputFromArray<T>(image.shape(), image.flat<T>());
+    AddInputFromArray<T>(filter.shape(), filter.flat<T>());
+    for (const Tensor& arg : args)
+      AddInputFromArray<T>(arg.shape(), arg.flat<T>());
+    AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+    AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+    for (const Tensor& arg : args)
+      AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+
+    TF_ASSERT_OK(RunOpKernel());
+
+    // Compare output to expected results
+    const Tensor& output_tensor = *GetOutput(0);
+    // Index 2 will need to be changed if the number of outputs produced
+    // by MklConv2D change.
+    const Tensor& output_meta_tensor = *GetOutput(2);
+    ConvMklToTF<T> conv_comp;
+    conv_comp.PerformConversion(dtype, output_tensor, output_meta_tensor,
+                                output);
+  }
+
+  void VerifyBiasAddTensorsNear(int depth, int image_width, int image_height,
+                                int image_batch_count, int filter_size,
+                                int filter_count,
+                                const BiasAddGraphRunner& run_default,
+                                const BiasAddGraphRunner& run_fused) {
+    DataType dtype = DataTypeToEnum<T>::v();
+
+    Tensor image(dtype, {image_batch_count, image_height, image_width, depth});
+    image.flat<T>() = image.flat<T>().setRandom();
+
+    Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
+    filter.flat<T>() = filter.flat<T>().setRandom();
+
+    const int bias_size = filter_count;
+    Tensor bias(dtype, {bias_size});
+    bias.flat<T>() = bias.flat<T>().setRandom();
+
+    Tensor conv_2d;
+    Tensor fused_conv_2d;
+
+    run_default(image, filter, bias, &conv_2d);
+    run_fused(image, filter, bias, &fused_conv_2d);
+
+    ASSERT_EQ(conv_2d.dtype(), fused_conv_2d.dtype());
+    ASSERT_EQ(conv_2d.shape(), fused_conv_2d.shape());
+
+    test::ExpectClose(conv_2d, fused_conv_2d);
+  }
+
+  // Verifies that computing Conv2D+BiasAdd in a graph is identical to
+  // FusedConv2D.
+  void VerifyConv2DWithBias(int filter_size, int filter_count,
+                            int depth = kDepth, int image_width = kImageWidth,
+                            int image_height = kImageHeight,
+                            int image_batch_count = kImageBatchCount) {
+    const BiasAddGraphRunner run_default =
+        [this](const Tensor& input_data, const Tensor& filter_data,
+               const Tensor& bias_data, Tensor* out) {
+          RunConv2DWithBias(input_data, filter_data, bias_data, out);
+        };
+
+    const BiasAddGraphRunner run_fused =
+        [this](const Tensor& input_data, const Tensor& filter_data,
+               const Tensor& bias_data, Tensor* out) {
+          RunMklFusedConv2DOp(input_data, filter_data, {bias_data},
+                              {"BiasAdd"}, out);
+        };
+
+    VerifyBiasAddTensorsNear(depth, image_width, image_height,
+                             image_batch_count, filter_size, filter_count,
+                             run_default, run_fused);
+  }
+
+  // Verifies that computing Conv2D+BiasAdd+Relu in a graph is identical to
+  // FusedConv2D.
+  void VerifyConv2DWithBiasAndRelu(
+      int filter_size, int filter_count, int depth = kDepth,
+      int image_width = kImageWidth, int image_height = kImageHeight,
+      int image_batch_count = kImageBatchCount) {
+    const BiasAddGraphRunner run_default =
+        [this](const Tensor& input_data, const Tensor& filter_data,
+               const Tensor& bias_data, Tensor* out) {
+          RunConv2DWithBiasAndRelu(input_data, filter_data, bias_data, out);
+        };
+
+    const BiasAddGraphRunner run_fused =
+        [this](const Tensor& input_data, const Tensor& filter_data,
+               const Tensor& bias_data, Tensor* out) {
+          RunMklFusedConv2DOp(input_data, filter_data, {bias_data},
+                              {"BiasAdd", "Relu"}, out);
+        };
+
+    VerifyBiasAddTensorsNear(depth, image_width, image_height,
+                             image_batch_count, filter_size, filter_count,
+                             run_default, run_fused);
+  }
+};
+
+template <typename T>
+class MklFusedConv2DWithBiasOpTest : public MklFusedConv2DOpTest<T> {};
+
+TYPED_TEST_CASE_P(MklFusedConv2DWithBiasOpTest);
+
+// -------------------------------------------------------------------------- //
+// Conv2D + BiasAdd + {Relu}                                                  //
+// -------------------------------------------------------------------------- //
+
+TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolution) {
+  const int filter_size = 1;
+  const int filter_count = 12;
+  this->VerifyConv2DWithBias(filter_size, filter_count);
+}
+
+TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolution) {
+  const int filter_size = 3;
+  const int filter_count = 12;
+  this->VerifyConv2DWithBias(filter_size, filter_count);
+}
+
+TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolutionAndRelu) {
+  const int filter_size = 1;
+  const int filter_count = 12;
+  this->VerifyConv2DWithBiasAndRelu(filter_size, filter_count);
+}
+
+TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolutionAndRelu) {
+  const int filter_size = 3;
+  const int filter_count = 12;
+  this->VerifyConv2DWithBiasAndRelu(filter_size, filter_count);
+}
+
+REGISTER_TYPED_TEST_CASE_P(MklFusedConv2DWithBiasOpTest,  //
+                           OneByOneConvolution,           //
+                           SpatialConvolution,            //
+                           OneByOneConvolutionAndRelu,    //
+                           SpatialConvolutionAndRelu);
+
+using MklFusedBiasAddDataTypes = ::testing::Types<float>;
+INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedConv2DWithBiasOpTest,
+                              MklFusedBiasAddDataTypes);
+
 // Testing fusion of pad and convolution
 class FusedPadConvOpTest : public OpsTestBase {
@@ -98,8 +341,8 @@ class FusedPadConvOpTest : public OpsTestBase {
     // Compare output to expected results
     const Tensor& first = *GetOutput(0);
     const Tensor& second = *GetOutput(2);
-    ConvMklToTF conv_comp;
-    conv_comp.ConvertAndCompare<T>(dtype, first, second, expected);
+    ConvMklToTF<T> conv_comp;
+    conv_comp.ConvertAndCompare(dtype, first, second, expected);
   }
 };
tensorflow/core/ops/mkl_nn_ops.cc
@@ -32,6 +32,33 @@ using shape_inference::DimensionHandle;
 using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;

+REGISTER_OP("_MklFusedConv2D")
+    .Input("input: T")
+    .Input("filter: T")
+    .Input("args: num_args * T")
+    .Input("mkl_input: uint8")
+    .Input("mkl_filter: uint8")
+    .Input("mkl_args: num_args * uint8")
+    .Output("output: T")
+    .Output("filter_output: T")
+    .Output("mkl_output: uint8")
+    .Output("mkl_filter_output: uint8")
+    .Attr("T: {float}")
+    .Attr("num_args: int >= 0")
+    .Attr("strides: list(int)")
+    .Attr(GetPaddingAttrString())
+    .Attr(GetConvnetDataFormatAttrString())
+    .Attr("dilations: list(int) = [1, 1, 1, 1]")
+    .Attr("fused_ops: list(string) = []")
+    // Attributes for the FusedBatchNorm ------------------------------------ //
+    .Attr("epsilon: float = 0.0001")
+    // ---------------------------------------------------------------------- //
+    .SetShapeFn(shape_inference::Conv2DShape)
+    .Doc(R"doc(
+*NOTE*: Do not invoke this operator directly in Python. MKL DNN graph transformer
+is expected to create these operators.
+)doc");
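Input counting for this op follows the MKL convention of pairing every data input with a uint8 metadata tensor, which is why the layout-pass tests above expect three DMT/_* Const nodes feeding slots 3-5 when num_args is 1. A small sketch of the arithmetic:

// Sketch: number of inputs _MklFusedConv2D consumes for a given num_args.
int FusedConv2DInputCount(int num_args) {
  const int data_inputs = 2 + num_args;  // input, filter, args
  return 2 * data_inputs;                // plus one uint8 meta tensor each
}
// num_args = 1 (BiasAdd): 6 inputs -> slots 0-2 data, slots 3-5 metadata.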
 REGISTER_OP("_MklQuantizedMaxPool")
     .Input("input: T")
     .Input("min_input: float")