Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1c591c39
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1c591c39
编写于
10月 23, 2018
作者:
J
jerrywgz
提交者:
GitHub
10月 23, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into fix_rpn_target_assign_op
上级
f06c6193
96e9b658
变更
46
展开全部
隐藏空白更改
内联
并排
Showing
46 changed file
with
1978 addition
and
553 deletion
+1978
-553
cmake/generic.cmake
cmake/generic.cmake
+7
-0
paddle/fluid/framework/details/var_handle.h
paddle/fluid/framework/details/var_handle.h
+2
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+5
-0
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
...uid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
+154
-0
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h
...luid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h
+38
-0
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
...mework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
+247
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+85
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+72
-0
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
+101
-0
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h
+38
-0
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+5
-11
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+6
-0
paddle/fluid/inference/analysis/analyzer.h
paddle/fluid/inference/analysis/analyzer.h
+15
-13
paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc
...le/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc
+7
-1
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-1
paddle/fluid/operators/conv_mkldnn_op.cc
paddle/fluid/operators/conv_mkldnn_op.cc
+36
-15
paddle/fluid/operators/conv_op.cc
paddle/fluid/operators/conv_op.cc
+8
-3
paddle/fluid/operators/detection/generate_proposals_op.cc
paddle/fluid/operators/detection/generate_proposals_op.cc
+131
-119
paddle/fluid/operators/detection/generate_proposals_op.cu
paddle/fluid/operators/detection/generate_proposals_op.cu
+90
-76
paddle/fluid/operators/distributed/grpc_client.cc
paddle/fluid/operators/distributed/grpc_client.cc
+7
-7
paddle/fluid/operators/distributed/grpc_serde.cc
paddle/fluid/operators/distributed/grpc_serde.cc
+2
-2
paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.cc
paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.cc
+229
-0
paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.h
paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.h
+42
-0
paddle/fluid/operators/gather.h
paddle/fluid/operators/gather.h
+2
-4
paddle/fluid/operators/math/fc_compute.h
paddle/fluid/operators/math/fc_compute.h
+15
-9
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+6
-0
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+88
-0
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+57
-0
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+10
-0
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+3
-0
paddle/fluid/platform/profiler.cc
paddle/fluid/platform/profiler.cc
+9
-0
paddle/fluid/platform/profiler.h
paddle/fluid/platform/profiler.h
+10
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-0
python/paddle/fluid/layer_helper.py
python/paddle/fluid/layer_helper.py
+12
-3
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+17
-16
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+33
-23
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+1
-1
python/paddle/fluid/layers/layer_function_generator.py
python/paddle/fluid/layers/layer_function_generator.py
+5
-3
python/paddle/fluid/layers/metric_op.py
python/paddle/fluid/layers/metric_op.py
+5
-5
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+186
-168
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+17
-14
python/paddle/fluid/regularizer.py
python/paddle/fluid/regularizer.py
+2
-2
python/paddle/fluid/tests/unittests/test_fusion_seqconv_eltadd_relu_op.py
...uid/tests/unittests/test_fusion_seqconv_eltadd_relu_op.py
+94
-0
python/paddle/fluid/tests/unittests/test_seq_conv.py
python/paddle/fluid/tests/unittests/test_seq_conv.py
+49
-50
python/paddle/fluid/tests/unittests/test_slice_var.py
python/paddle/fluid/tests/unittests/test_slice_var.py
+0
-1
python/paddle/fluid/transpiler/inference_transpiler.py
python/paddle/fluid/transpiler/inference_transpiler.py
+28
-6
未找到文件。
cmake/generic.cmake
浏览文件 @
1c591c39
...
...
@@ -261,6 +261,13 @@ function(cc_library TARGET_NAME)
add_dependencies
(
${
TARGET_NAME
}
mklml
)
target_link_libraries
(
${
TARGET_NAME
}
"-L
${
MKLML_LIB_DIR
}
-liomp5 -Wl,--as-needed"
)
endif
()
# remove link to python, see notes at:
# https://github.com/pybind/pybind11/blob/master/docs/compiling.rst#building-manually
if
(
"
${
cc_library_DEPS
}
;"
MATCHES
"python;"
)
list
(
REMOVE_ITEM cc_library_DEPS python
)
add_dependencies
(
${
TARGET_NAME
}
python
)
target_link_libraries
(
${
TARGET_NAME
}
"-Wl,-undefined,dynamic_lookup"
)
endif
()
target_link_libraries
(
${
TARGET_NAME
}
${
cc_library_DEPS
}
)
add_dependencies
(
${
TARGET_NAME
}
${
cc_library_DEPS
}
)
endif
()
...
...
paddle/fluid/framework/details/var_handle.h
浏览文件 @
1c591c39
...
...
@@ -49,6 +49,8 @@ struct VarHandleBase {
void
AddOutput
(
OpHandleBase
*
out
,
ir
::
Node
*
node
)
{
if
(
pending_ops_
.
find
(
out
)
==
pending_ops_
.
end
())
{
PADDLE_ENFORCE
(
out
!=
nullptr
,
"The output of %s should not be nullptr"
,
this
->
Node
()
->
Name
());
pending_ops_
.
insert
(
out
);
node_
->
outputs
.
push_back
(
node
);
}
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
1c591c39
...
...
@@ -37,6 +37,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library
(
fc_gru_fuse_pass inference
)
pass_library
(
seq_concat_fc_fuse_pass inference
)
pass_library
(
conv_bn_fuse_pass inference
)
pass_library
(
seqconv_eltadd_relu_fuse_pass inference
)
if
(
WITH_MKLDNN
)
pass_library
(
mkldnn_placement_pass base
)
pass_library
(
conv_bias_mkldnn_fuse_pass inference
)
...
...
@@ -44,6 +45,9 @@ if(WITH_MKLDNN)
endif
()
cc_library
(
fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector
)
if
(
WITH_MKLDNN
)
pass_library
(
conv_elementwise_add_mkldnn_fuse_pass inference
)
endif
()
set
(
GLOB_PASS_LIB
${
PASS_LIBRARY
}
CACHE INTERNAL
"Global PASS library"
)
...
...
@@ -57,4 +61,5 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test
(
test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto
)
if
(
WITH_MKLDNN
)
cc_test
(
test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass
)
cc_test
(
test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass
)
endif
()
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
0 → 100644
浏览文件 @
1c591c39
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional>
#include <utility>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
{
// The function keeps the graph consistent by replacing
// a node 'from' in the set of inputs nodes
// of the visited node by a node 'to'.
void
CorrectGraphEdges
(
Graph
*
graph
,
Node
*
from
,
Node
*
to
)
{
for
(
auto
&
node
:
GraphTraits
::
DFS
(
*
graph
))
{
auto
from_in_inputs
=
std
::
find
(
std
::
begin
(
node
.
inputs
),
std
::
end
(
node
.
inputs
),
from
);
if
(
from_in_inputs
!=
std
::
end
(
node
.
inputs
))
{
IR_NODE_LINK_TO
(
to
,
(
&
node
));
auto
inputs
=
node
.
Op
()
->
Inputs
();
using
input_type
=
VariableNameMap
::
value_type
;
std
::
for_each
(
std
::
begin
(
inputs
),
std
::
end
(
inputs
),
[
from
,
to
,
&
node
](
const
input_type
&
i
)
->
void
{
auto
param_names
=
i
.
second
;
auto
pi
=
std
::
find
(
std
::
begin
(
param_names
),
std
::
end
(
param_names
),
from
->
Name
());
if
(
pi
!=
std
::
end
(
param_names
))
{
node
.
Op
()
->
SetInput
(
i
.
first
,
{
to
->
Name
()});
}
});
}
}
}
}
// namespace
using
graph_ptr
=
std
::
unique_ptr
<
ir
::
Graph
>
;
graph_ptr
ConvElementwiseAddMKLDNNFusePass
::
ApplyImpl
(
graph_ptr
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
.
get
());
GraphPatternDetector
gpd
;
auto
pattern
=
gpd
.
mutable_pattern
();
patterns
::
Conv
conv_pattern
{
pattern
,
name_scope_
};
auto
conv_output
=
conv_pattern
();
patterns
::
ElementwiseAdd
elementwise_add_pattern
{
pattern
,
name_scope_
};
elementwise_add_pattern
(
conv_output
);
conv_output
->
AsIntermediate
();
auto
conv_op_has_bias
=
[](
const
Node
&
conv_op
)
->
std
::
pair
<
bool
,
Node
*>
{
auto
bias_input_names
=
conv_op
.
Op
()
->
Inputs
();
auto
bias_it
=
bias_input_names
.
find
(
"Bias"
);
if
(
bias_it
!=
std
::
end
(
bias_input_names
))
{
bool
has_bias
=
!
bias_it
->
second
.
empty
();
if
(
has_bias
)
{
auto
conv_bias_names
=
bias_it
->
second
;
auto
conv_bias_names_it
=
std
::
find_if
(
std
::
begin
(
conv_op
.
inputs
),
std
::
end
(
conv_op
.
inputs
),
[
&
conv_bias_names
](
Node
*
n
)
->
bool
{
return
n
->
Name
()
==
conv_bias_names
[
0
];
});
return
std
::
make_pair
(
has_bias
,
*
conv_bias_names_it
);
}
}
return
std
::
make_pair
(
false
,
nullptr
);
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
GET_IR_NODE_FROM_SUBGRAPH
(
conv_op
,
conv_op
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_input
,
conv_input
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_filter
,
conv_filter
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_output
,
conv_output
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_op
,
elementwise_add_op
,
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_x
,
elementwise_add_x
,
elementwise_add_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_add_out
,
elementwise_add_out
,
elementwise_add_pattern
);
if
(
FindFuseOption
(
*
conv_op
,
*
elementwise_add_op
)
!=
FUSE_MKLDNN
)
return
;
OpDesc
op_desc
;
op_desc
.
SetType
(
"conv2d"
);
op_desc
.
SetInput
(
"Input"
,
{
conv_input
->
Name
()});
op_desc
.
SetInput
(
"Filter"
,
{
conv_filter
->
Name
()});
op_desc
.
SetInput
(
"ResidualData"
,
{
elementwise_add_x
->
Name
()});
op_desc
.
SetOutput
(
"Output"
,
{
conv_output
->
Name
()});
bool
has_bias
;
Node
*
conv_bias
;
std
::
tie
(
has_bias
,
conv_bias
)
=
conv_op_has_bias
(
*
conv_op
);
if
(
has_bias
)
{
op_desc
.
SetInput
(
"Bias"
,
{
conv_bias
->
Name
()});
}
for
(
const
auto
&
attr
:
conv_op
->
Op
()
->
GetAttrMap
())
{
op_desc
.
SetAttr
(
attr
.
first
,
attr
.
second
);
}
op_desc
.
SetAttr
(
"fuse_residual_connection"
,
true
);
auto
fused_conv_op
=
g
->
CreateOpNode
(
&
op_desc
);
IR_NODE_LINK_TO
(
conv_input
,
fused_conv_op
);
IR_NODE_LINK_TO
(
conv_filter
,
fused_conv_op
);
IR_NODE_LINK_TO
(
elementwise_add_x
,
fused_conv_op
);
IR_NODE_LINK_TO
(
fused_conv_op
,
conv_output
);
if
(
has_bias
)
{
IR_NODE_LINK_TO
(
conv_bias
,
fused_conv_op
);
}
CorrectGraphEdges
(
g
,
elementwise_add_out
,
conv_output
);
GraphSafeRemoveNodes
(
g
,
{
elementwise_add_out
,
conv_op
,
elementwise_add_op
});
};
gpd
(
graph
.
get
(),
handler
);
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
conv_elementwise_add_mkldnn_fuse_pass
,
paddle
::
framework
::
ir
::
ConvElementwiseAddMKLDNNFusePass
);
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h
0 → 100644
浏览文件 @
1c591c39
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
ConvElementwiseAddMKLDNNFusePass
:
public
FusePassBase
{
public:
virtual
~
ConvElementwiseAddMKLDNNFusePass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
const
std
::
string
name_scope_
{
"residual_connections_fuse_pass"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
0 → 100644
浏览文件 @
1c591c39
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
{
constexpr
int
nodes_removed
=
3
;
constexpr
int
nodes_added
=
1
;
void
SetOp
(
ProgramDesc
*
prog
,
const
std
::
string
&
type
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>&
inputs
,
const
std
::
pair
<
std
::
string
,
std
::
string
>&
output
)
{
auto
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
op
->
SetAttr
(
"use_mkldnn"
,
true
);
for
(
const
auto
&
input
:
inputs
)
{
op
->
SetInput
(
input
.
first
,
{
input
.
second
});
}
op
->
SetOutput
(
output
.
first
,
{
output
.
second
});
}
struct
IsReachable
{
using
func
=
std
::
function
<
bool
(
const
std
::
string
&
,
const
std
::
string
&
)
>
;
auto
operator
()(
const
std
::
unique_ptr
<
ir
::
Graph
>&
graph
)
->
func
{
auto
find_node
=
[](
const
std
::
unique_ptr
<
ir
::
Graph
>&
graph
,
const
std
::
string
&
name
)
->
Node
*
{
for
(
auto
&
node
:
GraphTraits
::
DFS
(
*
graph
))
{
if
(
name
==
node
.
Name
())
{
return
&
node
;
}
}
return
nullptr
;
};
return
[
&
](
std
::
string
from
,
const
std
::
string
to
)
->
bool
{
if
(
from
==
to
)
return
true
;
std
::
map
<
std
::
string
,
bool
>
visited
;
for
(
auto
&
node
:
GraphTraits
::
DFS
(
*
graph
))
{
visited
[
node
.
Name
()]
=
false
;
}
visited
[
from
]
=
true
;
std
::
list
<
std
::
string
>
queue
;
queue
.
push_back
(
from
);
while
(
!
queue
.
empty
())
{
auto
cur
=
find_node
(
graph
,
queue
.
front
());
queue
.
pop_front
();
if
(
cur
==
nullptr
)
return
false
;
for
(
auto
n
:
cur
->
outputs
)
{
if
(
n
->
Name
()
==
to
)
return
true
;
if
(
!
visited
[
n
->
Name
()])
{
visited
[
n
->
Name
()]
=
true
;
queue
.
push_back
(
n
->
Name
());
}
}
}
return
false
;
};
}
};
void
AssertOpsCount
(
const
std
::
unique_ptr
<
ir
::
Graph
>&
graph
)
{
int
conv_count
=
0
;
int
elementwise_add_count
=
0
;
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"conv2d"
)
{
++
conv_count
;
}
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"elementwise_add"
)
{
++
elementwise_add_count
;
}
}
EXPECT_EQ
(
conv_count
,
1
);
EXPECT_EQ
(
elementwise_add_count
,
0
);
}
ProgramDesc
BuildProgramDesc
(
const
std
::
vector
<
std
::
string
>&
transient_vars
,
const
std
::
vector
<
std
::
string
>&
persistent_vars
)
{
ProgramDesc
prog
;
auto
add_var_to_prog
=
[
&
prog
](
const
std
::
string
&
var_name
)
->
VarDesc
*
{
auto
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
var_name
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
return
var
;
};
for
(
const
auto
&
v
:
transient_vars
)
{
add_var_to_prog
(
v
);
}
for
(
const
auto
&
v
:
persistent_vars
)
{
auto
var
=
add_var_to_prog
(
v
);
var
->
SetPersistable
(
true
);
}
return
prog
;
}
}
// namespace
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
ConvolutionWithElementwiseAddRelu
)
{
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
,
"f"
},
{
"bias"
,
"weights"
});
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"a"
},
{
"Bias"
,
"bias"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"b"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"b"
},
{
"Y"
,
"c"
}},
{
"Out"
,
"d"
});
SetOp
(
&
prog
,
"relu"
,
{{
"X"
,
"d"
}},
{
"Out"
,
"e"
});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
IsReachable
is_reachable
;
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"relu"
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
int
original_nodes_num
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"relu"
));
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
current_nodes_num
);
AssertOpsCount
(
graph
);
}
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
ConvolutionWithElementwiseAddReluNoBias
)
{
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
},
{
"weights"
});
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"a"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"b"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"b"
},
{
"Y"
,
"c"
}},
{
"Out"
,
"d"
});
SetOp
(
&
prog
,
"relu"
,
{{
"X"
,
"d"
}},
{
"Out"
,
"e"
});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
IsReachable
is_reachable
;
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"relu"
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
int
original_nodes_num
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"relu"
));
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
current_nodes_num
);
AssertOpsCount
(
graph
);
}
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
ConvolutionElementwiseAdd
)
{
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
},
{
"bias"
,
"weights"
});
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"a"
},
{
"Bias"
,
"bias"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"b"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"b"
},
{
"Y"
,
"c"
}},
{
"Out"
,
"d"
});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
IsReachable
is_reachable
;
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"d"
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
int
original_nodes_num
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_FALSE
(
is_reachable
(
graph
)(
"a"
,
"d"
));
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
current_nodes_num
);
AssertOpsCount
(
graph
);
}
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
SigmoidConvolutionAddElementwiseRelu
)
{
auto
prog
=
BuildProgramDesc
({
"a"
,
"b"
,
"c"
,
"d"
,
"e"
,
"f"
},
{
"bias"
,
"weights"
});
SetOp
(
&
prog
,
"sigmoid"
,
{{
"X"
,
"a"
}},
{
"Out"
,
"b"
});
SetOp
(
&
prog
,
"conv2d"
,
{{
"Input"
,
"b"
},
{
"Bias"
,
"bias"
},
{
"Filter"
,
"weights"
}},
{
"Output"
,
"c"
});
SetOp
(
&
prog
,
"elementwise_add"
,
{{
"X"
,
"c"
},
{
"Y"
,
"d"
}},
{
"Out"
,
"e"
});
SetOp
(
&
prog
,
"relu"
,
{{
"X"
,
"e"
}},
{
"Out"
,
"f"
});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
IsReachable
is_reachable
;
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"f"
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"conv_elementwise_add_mkldnn_fuse_pass"
);
int
original_nodes_num
=
graph
->
Nodes
().
size
();
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_TRUE
(
is_reachable
(
graph
)(
"a"
,
"f"
));
EXPECT_EQ
(
original_nodes_num
-
nodes_removed
+
nodes_added
,
current_nodes_num
);
AssertOpsCount
(
graph
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
conv_elementwise_add_mkldnn_fuse_pass
);
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
1c591c39
...
...
@@ -761,6 +761,51 @@ PDNode *patterns::ConvReLU::operator()(
return
relu_out_var
;
}
PDNode
*
patterns
::
SeqConvEltAddRelu
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
seqconv_input
)
{
// Create Operators
seqconv_input
->
assert_is_op_input
(
"sequence_conv"
,
"X"
);
auto
*
seqconv_op
=
pattern
->
NewNode
(
seqconv_repr
())
->
assert_is_op
(
"sequence_conv"
)
->
assert_op_attr
<
bool
>
(
"paddingTrainable"
,
false
)
->
assert_op_attr
<
int
>
(
"contextStride"
,
1
);
auto
*
eltadd_op
=
pattern
->
NewNode
(
eltadd_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
relu_op
=
pattern
->
NewNode
(
relu_repr
())
->
assert_is_op
(
"relu"
);
// Create variables
// Filter
auto
*
seqconv_weight_var
=
pattern
->
NewNode
(
seqconv_weight_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
"sequence_conv"
,
"Filter"
);
// Bias
auto
*
eltadd_bias_var
=
pattern
->
NewNode
(
eltadd_bias_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
);
// intermediate variable, will be removed in the IR after fuse.
auto
*
seqconv_out_var
=
pattern
->
NewNode
(
seqconv_out_repr
())
->
AsIntermediate
()
->
assert_is_only_output_of_op
(
"sequence_conv"
)
->
assert_is_op_input
(
"elementwise_add"
);
auto
*
eltadd_out_var
=
pattern
->
NewNode
(
eltadd_out_repr
())
->
AsIntermediate
()
->
assert_is_only_output_of_op
(
"elementwise_add"
)
->
assert_is_only_input_of_op
(
"relu"
);
// output
auto
*
relu_out_var
=
pattern
->
NewNode
(
relu_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"relu"
);
seqconv_op
->
LinksFrom
({
seqconv_input
,
seqconv_weight_var
})
.
LinksTo
({
seqconv_out_var
});
eltadd_op
->
LinksFrom
({
seqconv_out_var
,
eltadd_bias_var
})
.
LinksTo
({
eltadd_out_var
});
relu_op
->
LinksFrom
({
eltadd_out_var
}).
LinksTo
({
relu_out_var
});
return
relu_out_var
;
}
PDNode
*
patterns
::
FC
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
x
,
bool
with_bias
)
{
// Create shared nodes.
...
...
@@ -999,6 +1044,46 @@ PDNode *patterns::ConvBias::operator()(
return
eltwise_out_var
;
}
PDNode
*
patterns
::
Conv
::
operator
()()
{
auto
conv_op
=
pattern
->
NewNode
(
conv_op_repr
())
->
assert_is_op
(
"conv2d"
);
auto
input_var
=
pattern
->
NewNode
(
conv_input_repr
())
->
AsInput
()
->
assert_is_op_input
(
"conv2d"
,
"Input"
);
auto
filter_var
=
pattern
->
NewNode
(
conv_filter_repr
())
->
AsInput
()
->
assert_is_op_input
(
"conv2d"
,
"Filter"
);
auto
output_var
=
pattern
->
NewNode
(
conv_output_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"conv2d"
,
"Output"
);
conv_op
->
LinksFrom
({
input_var
,
filter_var
});
conv_op
->
LinksTo
({
output_var
});
return
output_var
;
}
PDNode
*
patterns
::
ElementwiseAdd
::
operator
()(
PDNode
*
x_var
)
{
auto
elementwise_add_op
=
pattern
->
NewNode
(
elementwise_add_op_repr
())
->
assert_is_op
(
"elementwise_add"
);
x_var
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
auto
y_var
=
pattern
->
NewNode
(
elementwise_add_x_repr
())
->
AsInput
()
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
out_var
=
pattern
->
NewNode
(
elementwise_add_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
);
elementwise_add_op
->
LinksFrom
({
x_var
,
y_var
});
elementwise_add_op
->
LinksTo
({
out_var
});
return
out_var
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
1c591c39
...
...
@@ -128,6 +128,15 @@ struct PDNode {
const
std
::
unordered_set
<
std
::
string
>&
op_types
,
const
std
::
string
&
argument
,
int
nth
);
template
<
typename
T
>
PDNode
*
assert_op_attr
(
const
std
::
string
&
attr_name
,
const
T
&
attr
)
{
asserts_
.
emplace_back
([
=
](
Node
*
x
)
{
return
x
&&
x
->
IsOp
()
&&
x
->
Op
()
->
HasAttr
(
attr_name
)
&&
boost
::
get
<
T
>
(
x
->
Op
()
->
GetAttr
(
attr_name
))
==
attr
;
});
return
this
;
}
private:
PDNode
(
PDPattern
*
pattern
,
const
std
::
string
&
name
=
""
,
Type
type
=
Type
::
kVar
)
...
...
@@ -434,6 +443,31 @@ struct ConvReLU : public PatternBase {
PATTERN_DECL_NODE
(
relu_out
);
};
// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
// seqconv_input, seqconv_weight,
// seqconv_out, seqconv,
// elementwise_add_bias, elementwise_add_out, elementwise_add
// relu_out, relu
struct
SeqConvEltAddRelu
:
public
PatternBase
{
SeqConvEltAddRelu
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"seqconv_eltadd_relu"
)
{}
PDNode
*
operator
()(
PDNode
*
seqconv_input
);
// declare operator node's name
PATTERN_DECL_NODE
(
seqconv
);
PATTERN_DECL_NODE
(
eltadd
);
PATTERN_DECL_NODE
(
relu
);
// declare variable node's name
PATTERN_DECL_NODE
(
seqconv_weight
);
PATTERN_DECL_NODE
(
seqconv_out
);
PATTERN_DECL_NODE
(
eltadd_bias
);
PATTERN_DECL_NODE
(
eltadd_out
);
PATTERN_DECL_NODE
(
relu_out
);
};
// FC with bias
// op: mul + elementwise_add
// named nodes:
...
...
@@ -599,6 +633,44 @@ struct ConvBias : public PatternBase {
PATTERN_DECL_NODE
(
eltwise_bias
);
PATTERN_DECL_NODE
(
eltwise_out
);
};
// Convolution op
// Forward pass for convolution.
// conv_input, conv_bias and conv_filter are inputs.
// conv_output is a result of the operator.
// residual_data is data used by skip connection.
// If residual connection fusion is on, the formula is:
// conv_output = conv_op(conv_filter, conv_input, conv_bias)
// + conv_residual_data
// If the fusion is off, conv_residual_data is not added.
struct
Conv
:
public
PatternBase
{
Conv
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"convolution"
)
{}
PDNode
*
operator
()();
PATTERN_DECL_NODE
(
conv_op
);
PATTERN_DECL_NODE
(
conv_input
);
PATTERN_DECL_NODE
(
conv_filter
);
PATTERN_DECL_NODE
(
conv_residual_data
);
PATTERN_DECL_NODE
(
conv_output
);
};
// ElementwiseAdd used in residual connections.
// y_var is used and convolution output.
// The operator is removed, when residual
// connection fusion is on.
struct
ElementwiseAdd
:
public
PatternBase
{
ElementwiseAdd
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"elementwise_add"
)
{}
PDNode
*
operator
()(
PDNode
*
x_var
);
PATTERN_DECL_NODE
(
elementwise_add_op
);
PATTERN_DECL_NODE
(
elementwise_add_x
);
PATTERN_DECL_NODE
(
elementwise_add_y
);
PATTERN_DECL_NODE
(
elementwise_add_out
);
};
}
// namespace patterns
// Link two ir::Nodes from each other.
...
...
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
0 → 100644
浏览文件 @
1c591c39
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
int
BuildFusion
(
Graph
*
graph
,
const
std
::
string
&
name_scope
,
Scope
*
scope
)
{
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
PDNode
*
x
=
pattern
->
NewNode
(
patterns
::
PDNodeName
(
name_scope
,
"X"
))
->
assert_is_op_input
(
"sequence_conv"
)
->
assert_var_not_persistable
();
patterns
::
SeqConvEltAddRelu
fuse_pattern
(
pattern
,
name_scope
);
fuse_pattern
(
x
);
// Create New OpDesc
auto
fuse_creator
=
[
&
](
Node
*
seqconv
,
Node
*
input
,
Node
*
seqconv_weight
,
Node
*
eltadd_bias
,
Node
*
relu_out
)
{
OpDesc
op_desc
;
op_desc
.
SetType
(
"fusion_seqconv_eltadd_relu"
);
op_desc
.
SetInput
(
"X"
,
{
input
->
Name
()});
op_desc
.
SetInput
(
"Filter"
,
{
seqconv_weight
->
Name
()});
op_desc
.
SetInput
(
"Bias"
,
{
eltadd_bias
->
Name
()});
op_desc
.
SetAttr
(
"contextLength"
,
seqconv
->
Op
()
->
GetAttr
(
"contextLength"
));
op_desc
.
SetAttr
(
"contextStart"
,
seqconv
->
Op
()
->
GetAttr
(
"contextStart"
));
op_desc
.
SetAttr
(
"contextStride"
,
seqconv
->
Op
()
->
GetAttr
(
"contextStride"
));
PADDLE_ENFORCE
(
graph
->
Has
(
kParamScopeAttr
));
auto
*
scope
=
graph
->
Get
<
Scope
*>
(
kParamScopeAttr
);
const
std
::
string
ColMat
=
patterns
::
UniqueKey
(
"SeqConvColMat"
);
op_desc
.
SetOutput
(
"ColMat"
,
{
ColMat
});
op_desc
.
SetOutput
(
"Out"
,
{
relu_out
->
Name
()});
scope
->
Var
(
ColMat
)
->
GetMutable
<
LoDTensor
>
();
auto
*
op
=
graph
->
CreateOpNode
(
&
op_desc
);
IR_NODE_LINK_TO
(
input
,
op
);
IR_NODE_LINK_TO
(
seqconv_weight
,
op
);
IR_NODE_LINK_TO
(
eltadd_bias
,
op
);
IR_NODE_LINK_TO
(
op
,
relu_out
);
return
op
;
};
int
fusion_count
{
0
};
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle SeqConv EltAdd Relu fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
seqconv
,
seqconv
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
seqconv_weight
,
seqconv_weight
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
seqconv_out
,
seqconv_out
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd
,
eltadd
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_bias
,
eltadd_bias
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
eltadd_out
,
eltadd_out
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
relu
,
relu
,
fuse_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
relu_out
,
relu_out
,
fuse_pattern
);
fuse_creator
(
seqconv
,
subgraph
.
at
(
x
),
seqconv_weight
,
eltadd_bias
,
relu_out
);
std
::
unordered_set
<
const
Node
*>
marked_nodes
(
{
seqconv
,
seqconv_out
,
eltadd
,
eltadd_out
,
relu
});
GraphSafeRemoveNodes
(
graph
,
marked_nodes
);
++
fusion_count
;
};
gpd
(
graph
,
handler
);
return
fusion_count
;
}
std
::
unique_ptr
<
ir
::
Graph
>
SeqConvEltAddReluFusePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
FusePassBase
::
Init
(
name_scope_
,
graph
.
get
());
int
fusion_count
=
BuildFusion
(
graph
.
get
(),
name_scope_
,
param_scope
());
AddStatis
(
fusion_count
);
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
seqconv_eltadd_relu_fuse_pass
,
paddle
::
framework
::
ir
::
SeqConvEltAddReluFusePass
);
paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h
0 → 100644
浏览文件 @
1c591c39
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
SeqConvEltAddReluFusePass
:
public
FusePassBase
{
public:
virtual
~
SeqConvEltAddReluFusePass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
const
std
::
string
name_scope_
{
"seqconv_eltadd_relu_fuse"
};
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/op_desc.cc
浏览文件 @
1c591c39
...
...
@@ -515,20 +515,14 @@ void OpDesc::InferShape(const BlockDesc &block) const {
}
void
OpDesc
::
InferVarType
(
BlockDesc
*
block
)
const
{
// There are a few places that var type can be set.
// When VarDesc is created, default set to LOD_TENSOR.
// When output variable is created, default is defaut set to LOD_TENSOR.
// We limit here to be the only place that operator defines its customized
// var type inference. Hence, we don't do any "default" setting here.
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
this
->
Type
());
if
(
info
.
infer_var_type_
)
{
info
.
infer_var_type_
(
*
this
,
block
);
}
else
{
// all output type is LoDTensor by default
VLOG
(
10
)
<<
this
->
Type
()
<<
" has not registered InferVarType. Set output variables to "
"LOD_TENSOR"
;
for
(
auto
&
out_pair
:
this
->
outputs_
)
{
for
(
auto
&
out_var_name
:
out_pair
.
second
)
{
block
->
FindRecursiveOrCreateVar
(
out_var_name
)
.
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
}
}
}
}
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
1c591c39
...
...
@@ -299,6 +299,12 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
}
ParallelExecutor
::~
ParallelExecutor
()
{
const
auto
dev_ctxs
=
platform
::
DeviceContextPool
::
Instance
().
GetAllDeviceContexts
();
for
(
auto
&
dev_ctx
:
dev_ctxs
)
{
dev_ctx
->
Wait
();
}
if
(
member_
->
own_local_scope_
)
{
for
(
size_t
i
=
1
;
i
<
member_
->
local_scopes_
.
size
();
++
i
)
{
Scope
*
local_scope
=
member_
->
local_scopes_
[
i
];
...
...
paddle/fluid/inference/analysis/analyzer.h
浏览文件 @
1c591c39
...
...
@@ -67,20 +67,22 @@ class Analyzer : public OrderedRegistry<PassManager> {
// larger fusion.
const
std
::
vector
<
std
::
string
>
all_ir_passes_
{{
// Manual update the passes here.
"infer_clean_graph_pass"
,
//
"attention_lstm_fuse_pass"
,
//
"embedding_fc_lstm_fuse_pass"
,
//
"fc_lstm_fuse_pass"
,
//
"mul_lstm_fuse_pass"
,
//
"fc_gru_fuse_pass"
,
//
"mul_gru_fuse_pass"
,
//
"seq_concat_fc_fuse_pass"
,
//
"fc_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
"infer_clean_graph_pass"
,
//
"attention_lstm_fuse_pass"
,
//
"seqconv_eltadd_relu_fuse_pass"
,
//
"embedding_fc_lstm_fuse_pass"
,
//
"fc_lstm_fuse_pass"
,
//
"mul_lstm_fuse_pass"
,
//
"fc_gru_fuse_pass"
,
//
"mul_gru_fuse_pass"
,
//
"seq_concat_fc_fuse_pass"
,
//
"fc_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
#ifdef PADDLE_WITH_MKLDNN
"conv_bias_mkldnn_fuse_pass"
,
//
"conv_relu_mkldnn_fuse_pass"
,
//
"conv_bias_mkldnn_fuse_pass"
,
//
"conv_relu_mkldnn_fuse_pass"
,
//
"conv_elementwise_add_mkldnn_fuse_pass"
,
//
#endif
}};
...
...
paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc
浏览文件 @
1c591c39
...
...
@@ -183,7 +183,13 @@ TEST(Analyzer_seq_conv1, fuse_statis) {
SetConfig
(
&
cfg
);
int
num_ops
;
auto
predictor
=
CreatePaddlePredictor
<
AnalysisConfig
>
(
cfg
);
GetFuseStatis
(
predictor
.
get
(),
&
num_ops
);
auto
fuse_statis
=
GetFuseStatis
(
predictor
.
get
(),
&
num_ops
);
ASSERT_TRUE
(
fuse_statis
.
count
(
"fc_fuse"
));
ASSERT_TRUE
(
fuse_statis
.
count
(
"seqconv_eltadd_relu_fuse"
));
EXPECT_EQ
(
fuse_statis
.
at
(
"fc_fuse"
),
2
);
EXPECT_EQ
(
fuse_statis
.
at
(
"seqconv_eltadd_relu_fuse"
),
6
);
EXPECT_EQ
(
num_ops
,
32
);
}
// Compare result of NativeConfig and AnalysisConfig
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
1c591c39
...
...
@@ -86,7 +86,7 @@ function(op_library TARGET)
# remove windows unsupported op, because windows has no nccl, no warpctc such ops.
foreach
(
windows_unsupport_op
"nccl_op"
"gen_nccl_id_op"
"warpctc_op"
"hierarchical_sigmoid_op"
"crf_decoding_op"
"select_op"
"lstmp_op"
"gru_op"
"fusion_gru_op"
"lstm_op"
"fusion_lstm_op"
"cumsum_op"
"channel_send_op"
"channel_create_op"
"channel_close_op"
"channel_recv_op"
)
"fusion_seqconv_eltadd_relu_op"
"channel_send_op"
"channel_create_op"
"channel_close_op"
"channel_recv_op"
)
if
(
"
${
TARGET
}
"
STREQUAL
"
${
windows_unsupport_op
}
"
)
return
()
endif
()
...
...
paddle/fluid/operators/conv_mkldnn_op.cc
浏览文件 @
1c591c39
...
...
@@ -300,10 +300,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std
::
vector
<
int
>
paddings
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
dilations
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"dilations"
);
bool
fuse_relu
=
ctx
.
Attr
<
bool
>
(
"fuse_relu"
);
bool
fuse_
eltwise
=
ctx
.
Attr
<
bool
>
(
"fuse_eltwise
"
);
bool
fuse_
residual_conn
=
ctx
.
Attr
<
bool
>
(
"fuse_residual_connection
"
);
int
groups
=
ctx
.
Attr
<
int
>
(
"groups"
);
// TODO: add support for dilation
// TODO
(tpatejko)
: add support for dilation
PADDLE_ENFORCE
(
dilations
.
size
()
==
2
&&
dilations
[
0
]
==
1
&&
dilations
[
1
]
==
1
,
"dilation in convolution is not implemented yet"
);
...
...
@@ -369,11 +369,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
memory
::
format
::
x
);
conv_pd
=
ConvFwdPrimitiveDesc
(
src_md
,
weights_md
,
bias_md
,
dst_md
,
strides
,
paddings
,
mkldnn_engine
,
fuse_relu
,
fuse_
eltwise
);
fuse_relu
,
fuse_
residual_conn
);
}
else
{
conv_pd
=
ConvFwdPrimitiveDesc
(
src_md
,
weights_md
,
dst_md
,
strides
,
paddings
,
mkldnn_engine
,
fuse_relu
,
fuse_
eltwise
);
mkldnn_engine
,
fuse_relu
,
fuse_
residual_conn
);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx
.
SetBlob
(
key_conv_pd
,
conv_pd
);
...
...
@@ -386,8 +386,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
user_weights_memory_p
=
handler
.
AcquireWeightsMemory
(
user_weights_md
,
to_void_cast
<
T
>
(
filter_data
));
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
(),
handler
.
GetDstMemorySize
());
T
*
output_data
=
nullptr
;
if
(
fuse_residual_conn
)
{
auto
residual_param
=
ctx
.
Input
<
Tensor
>
(
"ResidualData"
);
auto
residual_param_data
=
residual_param
->
data
<
T
>
();
PADDLE_ENFORCE
(
residual_param_data
!=
nullptr
,
"Provide data if you want MKLDNN conv+elementwise_add fusion"
);
PADDLE_ENFORCE_EQ
(
output
->
dims
(),
residual_param
->
dims
(),
"Output and elementwise parameter need to have the "
"same dimension sizes"
);
output
->
ShareDataWith
(
*
residual_param
);
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
else
{
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
(),
handler
.
GetDstMemorySize
());
}
// create reorder primitive if the input format is not the preferred one
auto
src_memory_p
=
handler
.
AcquireSrcMemoryFromPrimitive
(
user_src_memory_p
,
pipeline
);
...
...
@@ -424,14 +442,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
private:
mkldnn
::
primitive_attr
CreatePostOps
(
bool
fuse_relu
,
bool
fuse_
eltwise
)
const
{
bool
fuse_
residual_conn
)
const
{
mkldnn
::
primitive_attr
conv_attr
;
mkldnn
::
post_ops
post_operations
;
// Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_eltwise is true, the
// Output tensor contains the data coming from residual connection. The
// result of this post_op is: Output = scale * Output + Conv_Out.
if
(
fuse_eltwise
)
{
// the scale parameter. It is assumed that when fuse_residual_connection is
// true, the output tensor contains the data coming from residual
// connection. The result of this post_op is:
// Output = scale * Output + Conv_Out.
if
(
fuse_residual_conn
)
{
post_operations
.
append_sum
(
1.0
f
);
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
...
...
@@ -452,7 +471,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const
memory
::
desc
&
dst
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
mkldnn
::
engine
&
engine
,
const
bool
fuse_relu
,
const
bool
fuse_
eltwise
)
const
{
const
bool
fuse_
residual_conn
)
const
{
memory
::
dims
stride_dims
=
{
strides
[
0
],
strides
[
1
]};
memory
::
dims
padding_dims
=
{
paddings
[
0
],
paddings
[
1
]};
...
...
@@ -461,7 +480,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst
,
stride_dims
,
padding_dims
,
padding_dims
,
mkldnn
::
padding_kind
::
zero
);
mkldnn
::
primitive_attr
conv_attr
=
CreatePostOps
(
fuse_relu
,
fuse_eltwise
);
mkldnn
::
primitive_attr
conv_attr
=
CreatePostOps
(
fuse_relu
,
fuse_residual_conn
);
auto
p_conv_pd
=
new
mkldnn
::
convolution_forward
::
primitive_desc
(
conv_desc
,
conv_attr
,
engine
);
...
...
@@ -476,7 +496,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
const
mkldnn
::
engine
&
engine
,
const
bool
fuse_relu
,
const
bool
fuse_
eltwise
)
const
{
const
bool
fuse_
residual_conn
)
const
{
memory
::
dims
stride_dims
=
{
strides
[
0
],
strides
[
1
]};
memory
::
dims
padding_dims
=
{
paddings
[
0
],
paddings
[
1
]};
...
...
@@ -485,7 +505,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias
,
dst
,
stride_dims
,
padding_dims
,
padding_dims
,
mkldnn
::
padding_kind
::
zero
);
mkldnn
::
primitive_attr
conv_attr
=
CreatePostOps
(
fuse_relu
,
fuse_eltwise
);
mkldnn
::
primitive_attr
conv_attr
=
CreatePostOps
(
fuse_relu
,
fuse_residual_conn
);
auto
p_conv_pd
=
new
mkldnn
::
convolution_forward
::
primitive_desc
(
conv_desc
,
conv_attr
,
engine
);
...
...
paddle/fluid/operators/conv_op.cc
浏览文件 @
1c591c39
...
...
@@ -132,6 +132,11 @@ void Conv2DOpMaker::Make() {
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW."
)
.
Reuse
(
"Input"
);
AddInput
(
"ResidualData"
,
"(Tensor) Tensor with residual data "
"to which convolution output will be added."
"Used with fuse_residual_connection fusion."
)
.
AsDispensable
();
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
...
...
@@ -164,10 +169,10 @@ void Conv2DOpMaker::Make() {
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"fuse_relu"
,
"(bool, default false) Only used in mkldnn kernel"
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"fuse_
eltwise
"
,
AddAttr
<
bool
>
(
"fuse_
residual_connection
"
,
"(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is
connected via skip connection
"
"
to a previous layer
."
)
"whenever convolution output is
as an input to residual
"
"
connection
."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"data_format"
,
...
...
paddle/fluid/operators/detection/generate_proposals_op.cc
浏览文件 @
1c591c39
...
...
@@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/
framework/var_type
.h"
#include "paddle/fluid/
operators/detail/safe_ref
.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
...
...
@@ -25,21 +27,17 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
struct
AppendProposalsFunctor
{
LoDTensor
*
out_
;
int64_t
offset_
;
Tensor
*
to_add_
;
static
const
double
kBBoxClipDefault
=
std
::
log
(
1000.0
/
16.0
);
AppendProposalsFunctor
(
LoDTensor
*
out
,
int64_t
offset
,
Tensor
*
to_add
)
:
out_
(
out
),
offset_
(
offset
),
to_add_
(
to_add
)
{}
template
<
typename
T
>
void
apply
()
const
{
auto
*
out_data
=
out_
->
data
<
T
>
();
auto
*
to_add_data
=
to_add_
->
data
<
T
>
();
memcpy
(
out_data
+
offset_
,
to_add_data
,
to_add_
->
numel
()
*
sizeof
(
T
));
}
};
static
void
AppendProposals
(
Tensor
*
dst
,
int64_t
offset
,
const
Tensor
&
src
)
{
auto
*
out_data
=
dst
->
data
<
void
>
();
auto
*
to_add_data
=
src
.
data
<
void
>
();
size_t
size_of_t
=
framework
::
SizeOfType
(
src
.
type
());
offset
*=
size_of_t
;
std
::
memcpy
(
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uintptr_t
>
(
out_data
)
+
offset
),
to_add_data
,
src
.
numel
()
*
size_of_t
);
}
class
GenerateProposalsOp
:
public
framework
::
OperatorWithKernel
{
public:
...
...
@@ -75,8 +73,9 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
};
template
<
class
T
>
void
BoxCoder
(
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
all_anchors
,
Tensor
*
bbox_deltas
,
Tensor
*
variances
,
Tensor
*
proposals
)
{
static
inline
void
BoxCoder
(
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
all_anchors
,
Tensor
*
bbox_deltas
,
Tensor
*
variances
,
Tensor
*
proposals
)
{
T
*
proposals_data
=
proposals
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
row
=
all_anchors
->
dims
()[
0
];
...
...
@@ -108,11 +107,11 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
anchor_center_y
;
bbox_width
=
std
::
exp
(
std
::
min
<
T
>
(
variances_data
[
i
*
len
+
2
]
*
bbox_deltas_data
[
i
*
len
+
2
],
std
::
log
(
1000.0
/
16.0
)
))
*
kBBoxClipDefault
))
*
anchor_width
;
bbox_height
=
std
::
exp
(
std
::
min
<
T
>
(
variances_data
[
i
*
len
+
3
]
*
bbox_deltas_data
[
i
*
len
+
3
],
std
::
log
(
1000.0
/
16.0
)
))
*
kBBoxClipDefault
))
*
anchor_height
;
}
else
{
bbox_center_x
=
...
...
@@ -120,10 +119,10 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
bbox_center_y
=
bbox_deltas_data
[
i
*
len
+
1
]
*
anchor_height
+
anchor_center_y
;
bbox_width
=
std
::
exp
(
std
::
min
<
T
>
(
bbox_deltas_data
[
i
*
len
+
2
],
std
::
log
(
1000.0
/
16.0
)
))
*
kBBoxClipDefault
))
*
anchor_width
;
bbox_height
=
std
::
exp
(
std
::
min
<
T
>
(
bbox_deltas_data
[
i
*
len
+
3
],
std
::
log
(
1000.0
/
16.0
)
))
*
kBBoxClipDefault
))
*
anchor_height
;
}
...
...
@@ -136,30 +135,32 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
}
template
<
class
T
>
void
ClipTiledBoxes
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
im_info
,
Tensor
*
boxes
)
{
static
inline
void
ClipTiledBoxes
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
im_info
,
Tensor
*
boxes
)
{
T
*
boxes_data
=
boxes
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
im_info_data
=
im_info
.
data
<
T
>
();
T
zero
(
0
);
for
(
int64_t
i
=
0
;
i
<
boxes
->
numel
();
++
i
)
{
if
(
i
%
4
==
0
)
{
boxes_data
[
i
]
=
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
1
]
-
1
),
0.0
f
);
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
1
]
-
1
),
zero
);
}
else
if
(
i
%
4
==
1
)
{
boxes_data
[
i
]
=
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
0
]
-
1
),
0.0
f
);
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
0
]
-
1
),
zero
);
}
else
if
(
i
%
4
==
2
)
{
boxes_data
[
i
]
=
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
1
]
-
1
),
0.0
f
);
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
1
]
-
1
),
zero
);
}
else
{
boxes_data
[
i
]
=
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
0
]
-
1
),
0.0
f
);
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
0
]
-
1
),
zero
);
}
}
}
template
<
class
T
>
void
FilterBoxes
(
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
boxes
,
float
min_size
,
const
Tensor
&
im_info
,
Tensor
*
keep
)
{
static
inline
void
FilterBoxes
(
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
boxes
,
float
min_size
,
const
Tensor
&
im_info
,
Tensor
*
keep
)
{
const
T
*
im_info_data
=
im_info
.
data
<
T
>
();
T
*
boxes_data
=
boxes
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
im_scale
=
im_info_data
[
2
];
...
...
@@ -185,24 +186,24 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
keep
->
Resize
({
keep_len
});
}
bool
SortScorePairDescend
(
const
std
::
pair
<
float
,
int
>
&
pair1
,
const
std
::
pair
<
float
,
int
>
&
pair2
)
{
return
pair1
.
first
>
pair2
.
first
;
}
template
<
class
T
>
void
GetMaxScoreIndex
(
const
std
::
vector
<
T
>
&
scores
,
std
::
vector
<
std
::
pair
<
T
,
int
>>
*
sorted_indices
)
{
static
inline
std
::
vector
<
std
::
pair
<
T
,
int
>>
GetSortedScoreIndex
(
const
std
::
vector
<
T
>
&
scores
)
{
std
::
vector
<
std
::
pair
<
T
,
int
>>
sorted_indices
;
sorted_indices
.
reserve
(
scores
.
size
());
for
(
size_t
i
=
0
;
i
<
scores
.
size
();
++
i
)
{
sorted_indices
->
push_back
(
std
::
make_pair
(
scores
[
i
],
i
)
);
sorted_indices
.
emplace_back
(
scores
[
i
],
i
);
}
// Sort the score pair according to the scores in descending order
std
::
stable_sort
(
sorted_indices
->
begin
(),
sorted_indices
->
end
(),
SortScorePairDescend
);
std
::
stable_sort
(
sorted_indices
.
begin
(),
sorted_indices
.
end
(),
[](
const
std
::
pair
<
T
,
int
>
&
a
,
const
std
::
pair
<
T
,
int
>
&
b
)
{
return
a
.
first
<
b
.
first
;
});
return
sorted_indices
;
}
template
<
class
T
>
T
BBoxArea
(
const
T
*
box
,
const
bool
normalized
)
{
static
inline
T
BBoxArea
(
const
T
*
box
,
bool
normalized
)
{
if
(
box
[
2
]
<
box
[
0
]
||
box
[
3
]
<
box
[
1
])
{
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
...
...
@@ -220,7 +221,7 @@ T BBoxArea(const T *box, const bool normalized) {
}
template
<
class
T
>
T
JaccardOverlap
(
const
T
*
box1
,
const
T
*
box2
,
const
bool
normalized
)
{
static
inline
T
JaccardOverlap
(
const
T
*
box1
,
const
T
*
box2
,
bool
normalized
)
{
if
(
box2
[
0
]
>
box1
[
2
]
||
box2
[
2
]
<
box1
[
0
]
||
box2
[
1
]
>
box1
[
3
]
||
box2
[
3
]
<
box1
[
1
])
{
return
static_cast
<
T
>
(
0.
);
...
...
@@ -229,8 +230,8 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
const
T
inter_ymin
=
std
::
max
(
box1
[
1
],
box2
[
1
]);
const
T
inter_xmax
=
std
::
min
(
box1
[
2
],
box2
[
2
]);
const
T
inter_ymax
=
std
::
min
(
box1
[
3
],
box2
[
3
]);
const
T
inter_w
=
std
::
max
(
0.0
f
,
inter_xmax
-
inter_xmin
+
1
);
const
T
inter_h
=
std
::
max
(
0.0
f
,
inter_ymax
-
inter_ymin
+
1
);
const
T
inter_w
=
std
::
max
(
T
(
0
)
,
inter_xmax
-
inter_xmin
+
1
);
const
T
inter_h
=
std
::
max
(
T
(
0
)
,
inter_ymax
-
inter_ymin
+
1
);
const
T
inter_area
=
inter_w
*
inter_h
;
const
T
bbox1_area
=
BBoxArea
<
T
>
(
box1
,
normalized
);
const
T
bbox2_area
=
BBoxArea
<
T
>
(
box2
,
normalized
);
...
...
@@ -238,9 +239,21 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
}
}
template
<
typename
T
>
static
inline
Tensor
VectorToTensor
(
const
std
::
vector
<
T
>
&
selected_indices
,
int
selected_num
)
{
Tensor
keep_nms
;
keep_nms
.
Resize
({
selected_num
});
auto
*
keep_data
=
keep_nms
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
selected_num
;
++
i
)
{
keep_data
[
i
]
=
selected_indices
[
i
];
}
return
keep_nms
;
}
template
<
class
T
>
Tensor
NMS
(
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
bbox
,
Tensor
*
scores
,
const
T
nms_threshold
,
const
float
eta
)
{
static
inline
Tensor
NMS
(
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
bbox
,
Tensor
*
scores
,
T
nms_threshold
,
float
eta
)
{
PADDLE_ENFORCE_NOT_NULL
(
bbox
);
int64_t
num_boxes
=
bbox
->
dims
()[
0
];
// 4: [xmin ymin xmax ymax]
...
...
@@ -248,20 +261,18 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
std
::
vector
<
T
>
scores_data
(
num_boxes
);
std
::
copy_n
(
scores
->
data
<
T
>
(),
num_boxes
,
scores_data
.
begin
());
std
::
vector
<
std
::
pair
<
T
,
int
>>
sorted_indices
;
GetMaxScoreIndex
<
T
>
(
scores_data
,
&
sorted_indices
);
std
::
vector
<
std
::
pair
<
T
,
int
>>
sorted_indices
=
GetSortedScoreIndex
<
T
>
(
scores_data
);
std
::
vector
<
int
>
selected_indices
;
int
selected_num
=
0
;
T
adaptive_threshold
=
nms_threshold
;
const
T
*
bbox_data
=
bbox
->
data
<
T
>
();
bool
flag
;
while
(
sorted_indices
.
size
()
!=
0
)
{
int
idx
=
sorted_indices
.
front
().
second
;
flag
=
true
;
for
(
size_t
k
=
0
;
k
<
selected_indices
.
size
();
++
k
)
{
int
idx
=
sorted_indices
.
back
().
second
;
bool
flag
=
true
;
for
(
int
kept_idx
:
selected_indices
)
{
if
(
flag
)
{
const
int
kept_idx
=
selected_indices
[
k
];
T
overlap
=
JaccardOverlap
<
T
>
(
bbox_data
+
idx
*
box_size
,
bbox_data
+
kept_idx
*
box_size
,
false
);
flag
=
(
overlap
<=
adaptive_threshold
);
...
...
@@ -271,32 +282,29 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
}
if
(
flag
)
{
selected_indices
.
push_back
(
idx
);
selected_num
++
;
++
selected_num
;
}
sorted_indices
.
erase
(
sorted_indices
.
begin
());
sorted_indices
.
erase
(
sorted_indices
.
end
());
if
(
flag
&&
eta
<
1
&&
adaptive_threshold
>
0.5
)
{
adaptive_threshold
*=
eta
;
}
}
Tensor
keep_nms
;
keep_nms
.
Resize
({
selected_num
});
int
*
keep_data
=
keep_nms
.
mutable_data
<
int
>
(
ctx
.
GetPlace
());
for
(
int
i
=
0
;
i
<
selected_num
;
++
i
)
{
keep_data
[
i
]
=
selected_indices
[
i
];
}
return
keep_nms
;
return
VectorToTensor
(
selected_indices
,
selected_num
);
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
T
>
class
GenerateProposalsKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
scores
=
context
.
Input
<
Tensor
>
(
"Scores"
);
auto
*
bbox_deltas
=
context
.
Input
<
Tensor
>
(
"BboxDeltas"
);
auto
*
im_info
=
context
.
Input
<
Tensor
>
(
"ImInfo"
);
auto
*
anchors
=
context
.
Input
<
Tensor
>
(
"Anchors"
);
auto
*
variances
=
context
.
Input
<
Tensor
>
(
"Variances"
);
auto
anchors
=
detail
::
Ref
(
context
.
Input
<
Tensor
>
(
"Anchors"
),
"Cannot find input Anchors(%s) in scope"
,
context
.
Inputs
(
"Anchors"
)[
0
]);
auto
variances
=
detail
::
Ref
(
context
.
Input
<
Tensor
>
(
"Variances"
),
"Cannot find input Variances(%s) in scope"
,
context
.
Inputs
(
"Variances"
)[
0
]);
auto
*
rpn_rois
=
context
.
Output
<
LoDTensor
>
(
"RpnRois"
);
auto
*
rpn_roi_probs
=
context
.
Output
<
LoDTensor
>
(
"RpnRoiProbs"
);
...
...
@@ -307,15 +315,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
float
min_size
=
context
.
Attr
<
float
>
(
"min_size"
);
float
eta
=
context
.
Attr
<
float
>
(
"eta"
);
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CPUDeviceContext
>();
auto
scores_dim
=
scores
->
dims
();
auto
&
scores_dim
=
scores
->
dims
();
int64_t
num
=
scores_dim
[
0
];
int64_t
c_score
=
scores_dim
[
1
];
int64_t
h_score
=
scores_dim
[
2
];
int64_t
w_score
=
scores_dim
[
3
];
auto
bbox_dim
=
bbox_deltas
->
dims
();
auto
&
bbox_dim
=
bbox_deltas
->
dims
();
int64_t
c_bbox
=
bbox_dim
[
1
];
int64_t
h_bbox
=
bbox_dim
[
2
];
int64_t
w_bbox
=
bbox_dim
[
3
];
...
...
@@ -330,17 +339,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
scores_swap
.
mutable_data
<
T
>
({
num
,
h_score
,
w_score
,
c_score
},
dev_ctx
.
GetPlace
());
math
::
Transpose
<
DeviceContext
,
T
,
4
>
trans
;
math
::
Transpose
<
platform
::
CPU
DeviceContext
,
T
,
4
>
trans
;
std
::
vector
<
int
>
axis
=
{
0
,
2
,
3
,
1
};
trans
(
dev_ctx
,
*
bbox_deltas
,
&
bbox_deltas_swap
,
axis
);
trans
(
dev_ctx
,
*
scores
,
&
scores_swap
,
axis
);
framework
::
LoD
lod
;
std
::
vector
<
size_t
>
lod0
(
1
,
0
);
Tensor
*
anchor
=
const_cast
<
framework
::
Tensor
*>
(
anchors
)
;
anchor
->
Resize
({
anchors
->
numel
()
/
4
,
4
}
);
Tensor
*
var
=
const_cast
<
framework
::
Tensor
*>
(
variances
);
var
->
Resize
({
var
->
numel
()
/
4
,
4
});
lod
.
resize
(
1
);
auto
&
lod0
=
lod
[
0
]
;
lod0
.
push_back
(
0
);
anchors
.
Resize
({
anchors
.
numel
()
/
4
,
4
}
);
var
iances
.
Resize
({
variances
.
numel
()
/
4
,
4
});
int64_t
num_proposals
=
0
;
for
(
int64_t
i
=
0
;
i
<
num
;
++
i
)
{
...
...
@@ -352,24 +361,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
scores_slice
.
Resize
({
h_score
*
w_score
*
c_score
,
1
});
std
::
pair
<
Tensor
,
Tensor
>
tensor_pair
=
ProposalForOneImage
(
dev_ctx
,
im_info_slice
,
*
anchor
,
*
var
,
ProposalForOneImage
(
dev_ctx
,
im_info_slice
,
anchors
,
variances
,
bbox_deltas_slice
,
scores_slice
,
pre_nms_top_n
,
post_nms_top_n
,
nms_thresh
,
min_size
,
eta
);
Tensor
proposals
=
tensor_pair
.
first
;
Tensor
scores
=
tensor_pair
.
second
;
framework
::
VisitDataType
(
framework
::
ToDataType
(
rpn_rois
->
type
()),
AppendProposalsFunctor
(
rpn_rois
,
4
*
num_proposals
,
&
proposals
));
framework
::
VisitDataType
(
framework
::
ToDataType
(
rpn_roi_probs
->
type
()),
AppendProposalsFunctor
(
rpn_roi_probs
,
num_proposals
,
&
scores
));
Tensor
&
proposals
=
tensor_pair
.
first
;
Tensor
&
scores
=
tensor_pair
.
second
;
AppendProposals
(
rpn_rois
,
4
*
num_proposals
,
proposals
);
AppendProposals
(
rpn_roi_probs
,
num_proposals
,
scores
);
num_proposals
+=
proposals
.
dims
()[
0
];
lod0
.
emplace
_back
(
num_proposals
);
lod0
.
push
_back
(
num_proposals
);
}
lod
.
emplace_back
(
lod0
);
rpn_rois
->
set_lod
(
lod
);
rpn_roi_probs
->
set_lod
(
lod
);
rpn_rois
->
Resize
({
num_proposals
,
4
});
...
...
@@ -377,7 +379,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
}
std
::
pair
<
Tensor
,
Tensor
>
ProposalForOneImage
(
const
DeviceContext
&
ctx
,
const
Tensor
&
im_info_slice
,
const
platform
::
CPU
DeviceContext
&
ctx
,
const
Tensor
&
im_info_slice
,
const
Tensor
&
anchors
,
const
Tensor
&
variances
,
const
Tensor
&
bbox_deltas_slice
,
// [M, 4]
const
Tensor
&
scores_slice
,
// [N, 1]
...
...
@@ -392,10 +394,9 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
for
(
int
i
=
0
;
i
<
scores_slice
.
numel
();
++
i
)
{
index
[
i
]
=
i
;
}
std
::
function
<
bool
(
const
int64_t
&
,
const
int64_t
&
)
>
compare
=
[
scores_data
](
const
int64_t
&
i
,
const
int64_t
&
j
)
{
return
scores_data
[
i
]
>
scores_data
[
j
];
};
auto
compare
=
[
scores_data
](
const
int64_t
&
i
,
const
int64_t
&
j
)
{
return
scores_data
[
i
]
>
scores_data
[
j
];
};
if
(
pre_nms_top_n
<=
0
||
pre_nms_top_n
>=
scores_slice
.
numel
())
{
std
::
sort
(
index
,
index
+
scores_slice
.
numel
(),
compare
);
...
...
@@ -452,33 +453,45 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
class
GenerateProposalsOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Scores"
,
"The scores of anchors should be foreground."
);
AddInput
(
"BboxDeltas"
,
"bbox_deltas."
);
AddInput
(
"ImInfo"
,
"Information for image reshape."
);
AddInput
(
"Anchors"
,
"All anchors."
);
AddInput
(
"Variances"
,
" variances"
);
AddOutput
(
"RpnRois"
,
"Anchors."
);
AddOutput
(
"RpnRoiProbs"
,
"Anchors."
);
AddAttr
<
int
>
(
"pre_nms_topN"
,
"pre_nms_topN"
);
AddAttr
<
int
>
(
"post_nms_topN"
,
"post_nms_topN"
);
AddAttr
<
float
>
(
"nms_thresh"
,
"nms_thres"
);
AddAttr
<
float
>
(
"min_size"
,
"min size"
);
AddInput
(
"Scores"
,
"(Tensor) The scores from conv is in shape (N, A, H, W), "
"N is batch size, A is number of anchors, "
"H and W are height and width of the feature map"
);
AddInput
(
"BboxDeltas"
,
"(Tensor) Bounding box deltas from conv is in "
"shape (N, 4*A, H, W)."
);
AddInput
(
"ImInfo"
,
"(Tensor) Information for image reshape is in shape (N, 3), "
"in format (height, width, scale)"
);
AddInput
(
"Anchors"
,
"(Tensor) Bounding box anchors from anchor_generator_op "
"is in shape (A, H, W, 4)."
);
AddInput
(
"Variances"
,
"(Tensor) Bounding box variances with same shape as `Anchors`."
);
AddOutput
(
"RpnRois"
,
"(LoDTensor), Output proposals with shape (rois_num, 4)."
);
AddOutput
(
"RpnRoiProbs"
,
"(LoDTensor) Scores of proposals with shape (rois_num, 1)."
);
AddAttr
<
int
>
(
"pre_nms_topN"
,
"Number of top scoring RPN proposals to keep before "
"applying NMS."
);
AddAttr
<
int
>
(
"post_nms_topN"
,
"Number of top scoring RPN proposals to keep after "
"applying NMS"
);
AddAttr
<
float
>
(
"nms_thresh"
,
"NMS threshold used on RPN proposals."
);
AddAttr
<
float
>
(
"min_size"
,
"Proposal height and width both need to be greater "
"than this min_size."
);
AddAttr
<
float
>
(
"eta"
,
"The parameter for adaptive NMS."
);
AddComment
(
R"DOC(
Generate Proposals OP
This operator proposes rois according to each box with their probability to be a foreground object and
the box can be calculated by anchors. Bbox_deltais and scores are the output of RPN. Final proposals
could be used to train detection net.
Scores is the probability for each box to be an object. In format of (N, A, H, W) where N is batch size, A is number
of anchors, H and W are height and width of the feature map.
BboxDeltas is the differece between predicted box locatoin and anchor location. In format of (N, 4*A, H, W)
This operator Generate bounding box proposals for Faster RCNN.
The propoasls are generated for a list of images based on image
score 'Scores', bounding box regression result 'BboxDeltas' as
well as predefined bounding box shapes 'anchors'. Greedy
non-maximum suppression is applied to generate the final bounding
boxes.
For generating proposals, this operator transposes and resizes scores and bbox_deltas in size of (H*W*A, 1) and (H*W*A, 4) and
calculate box locations as proposals candidates. Then clip boxes to image and remove predicted boxes with small area.
Finally, apply nms to get final proposals as output.
)DOC"
);
}
};
...
...
@@ -490,6 +503,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
generate_proposals
,
ops
::
GenerateProposalsOp
,
ops
::
GenerateProposalsOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
generate_proposals
,
ops
::
GenerateProposalsKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
generate_proposals
,
ops
::
GenerateProposalsKernel
<
float
>
,
ops
::
GenerateProposalsKernel
<
double
>
);
paddle/fluid/operators/detection/generate_proposals_op.cu
浏览文件 @
1c591c39
...
...
@@ -16,10 +16,13 @@ limitations under the License. */
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -36,36 +39,38 @@ namespace {
int
const
kThreadsPerBlock
=
sizeof
(
uint64_t
)
*
8
;
template
<
typename
T
>
__global__
void
RangeInitKernel
(
const
T
start
,
const
T
delta
,
const
int
size
,
T
*
out
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
size
)
{
out
[
i
]
=
start
+
i
*
delta
;
}
}
static
const
double
kBBoxClipDefault
=
std
::
log
(
1000.0
/
16.0
);
struct
RangeInitFunctor
{
int
start_
;
int
delta_
;
int
*
out_
;
__device__
void
operator
()(
size_t
i
)
{
out_
[
i
]
=
start_
+
i
*
delta_
;
}
};
template
<
typename
T
>
void
SortDescending
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
value
,
Tensor
*
value_out
,
Tensor
*
index_out
)
{
int
num
=
value
.
numel
();
static
void
SortDescending
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
value
,
Tensor
*
value_out
,
Tensor
*
index_out
)
{
int
num
=
static_cast
<
int
>
(
value
.
numel
());
Tensor
index_in_t
;
int
*
idx_in
=
index_in_t
.
mutable_data
<
int
>
({
num
},
ctx
.
GetPlace
());
int
block
=
512
;
auto
stream
=
ctx
.
stream
(
);
RangeInitKernel
<<<
DIVUP
(
num
,
block
),
block
,
0
,
stream
>>>
(
0
,
1
,
num
,
idx_in
);
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
ctx
,
num
)
;
for_range
(
RangeInitFunctor
{
0
,
1
,
idx_in
}
);
int
*
idx_out
=
index_out
->
mutable_data
<
int
>
({
num
},
ctx
.
GetPlace
());
const
T
*
keys_in
=
value
.
data
<
T
>
();
T
*
keys_out
=
value_out
->
mutable_data
<
T
>
({
num
},
ctx
.
GetPlace
());
// Determine temporary device storage requirements
void
*
d_temp_storage
=
NULL
;
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
d_temp_storage
,
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
nullptr
,
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
// Allocate temporary storage
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
);
void
*
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
);
// Run sorting operation
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
...
...
@@ -76,22 +81,27 @@ void SortDescending(const platform::CUDADeviceContext &ctx, const Tensor &value,
}
template
<
typename
T
>
__device__
__forceinline__
T
Min
(
T
x
,
T
y
)
{
return
x
<
y
?
x
:
y
;
}
template
<
typename
T
>
__device__
__forceinline__
T
Max
(
T
x
,
T
y
)
{
return
x
>
y
?
x
:
y
;
}
template
<
typename
T
>
__global__
void
BoxDecodeAndClipKernel
(
const
T
*
anchor
,
const
T
*
deltas
,
const
T
*
var
,
const
int
*
index
,
const
T
*
im_info
,
const
int
num
,
T
*
proposals
)
{
T
kBBoxClipDefault
=
log
(
1000.0
/
16.0
);
CUDA_1D_KERNEL_LOOP
(
i
,
num
)
{
struct
BoxDecodeAndClipFunctor
{
const
T
*
anchor
;
const
T
*
deltas
;
const
T
*
var
;
const
int
*
index
;
const
T
*
im_info
;
T
*
proposals
;
BoxDecodeAndClipFunctor
(
const
T
*
anchor
,
const
T
*
deltas
,
const
T
*
var
,
const
int
*
index
,
const
T
*
im_info
,
T
*
proposals
)
:
anchor
(
anchor
),
deltas
(
deltas
),
var
(
var
),
index
(
index
),
im_info
(
im_info
),
proposals
(
proposals
)
{}
T
bbox_clip_default
{
static_cast
<
T
>
(
kBBoxClipDefault
)};
__device__
void
operator
()(
size_t
i
)
{
int
k
=
index
[
i
]
*
4
;
T
axmin
=
anchor
[
k
];
T
aymin
=
anchor
[
k
+
1
];
...
...
@@ -108,17 +118,17 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
T
dxmax
=
deltas
[
k
+
2
];
T
dymax
=
deltas
[
k
+
3
];
T
d_cx
=
0.
,
d_cy
=
0.
,
d_w
=
0.
,
d_h
=
0.
;
T
d_cx
,
d_cy
,
d_w
,
d_h
;
if
(
var
)
{
d_cx
=
cx
+
dxmin
*
w
*
var
[
k
];
d_cy
=
cy
+
dymin
*
h
*
var
[
k
+
1
];
d_w
=
exp
(
Min
<
T
>
(
dxmax
*
var
[
k
+
2
],
kBBoxClipD
efault
))
*
w
;
d_h
=
exp
(
Min
<
T
>
(
dymax
*
var
[
k
+
3
],
kBBoxClipD
efault
))
*
h
;
d_w
=
exp
(
Min
(
dxmax
*
var
[
k
+
2
],
bbox_clip_d
efault
))
*
w
;
d_h
=
exp
(
Min
(
dymax
*
var
[
k
+
3
],
bbox_clip_d
efault
))
*
h
;
}
else
{
d_cx
=
cx
+
dxmin
*
w
;
d_cy
=
cy
+
dymin
*
h
;
d_w
=
exp
(
Min
<
T
>
(
dxmax
,
kBBoxClipD
efault
))
*
w
;
d_h
=
exp
(
Min
<
T
>
(
dymax
,
kBBoxClipD
efault
))
*
h
;
d_w
=
exp
(
Min
(
dxmax
,
bbox_clip_d
efault
))
*
w
;
d_h
=
exp
(
Min
(
dymax
,
bbox_clip_d
efault
))
*
h
;
}
T
oxmin
=
d_cx
-
d_w
*
0.5
;
...
...
@@ -126,17 +136,21 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
T
oxmax
=
d_cx
+
d_w
*
0.5
-
1.
;
T
oymax
=
d_cy
+
d_h
*
0.5
-
1.
;
proposals
[
i
*
4
]
=
Max
<
T
>
(
Min
<
T
>
(
oxmin
,
im_info
[
1
]
-
1.
),
0.
);
proposals
[
i
*
4
+
1
]
=
Max
<
T
>
(
Min
<
T
>
(
oymin
,
im_info
[
0
]
-
1.
),
0.
);
proposals
[
i
*
4
+
2
]
=
Max
<
T
>
(
Min
<
T
>
(
oxmax
,
im_info
[
1
]
-
1.
),
0.
);
proposals
[
i
*
4
+
3
]
=
Max
<
T
>
(
Min
<
T
>
(
oymax
,
im_info
[
0
]
-
1.
),
0.
);
proposals
[
i
*
4
]
=
Max
(
Min
(
oxmin
,
im_info
[
1
]
-
1.
),
0.
);
proposals
[
i
*
4
+
1
]
=
Max
(
Min
(
oymin
,
im_info
[
0
]
-
1.
),
0.
);
proposals
[
i
*
4
+
2
]
=
Max
(
Min
(
oxmax
,
im_info
[
1
]
-
1.
),
0.
);
proposals
[
i
*
4
+
3
]
=
Max
(
Min
(
oymax
,
im_info
[
0
]
-
1.
),
0.
);
}
}
__device__
__forceinline__
T
Min
(
T
a
,
T
b
)
const
{
return
a
>
b
?
b
:
a
;
}
__device__
__forceinline__
T
Max
(
T
a
,
T
b
)
const
{
return
a
>
b
?
a
:
b
;
}
};
template
<
typename
T
,
int
BlockSize
>
__global__
void
FilterBBoxes
(
const
T
*
bboxes
,
const
T
*
im_info
,
const
T
min_size
,
const
int
num
,
int
*
keep_
num
,
int
*
keep
)
{
static
__global__
void
FilterBBoxes
(
const
T
*
bboxes
,
const
T
*
im_info
,
const
T
min_size
,
const
int
num
,
int
*
keep_num
,
int
*
keep
)
{
T
im_h
=
im_info
[
0
];
T
im_w
=
im_info
[
1
];
T
im_scale
=
im_info
[
2
];
...
...
@@ -181,7 +195,7 @@ __global__ void FilterBBoxes(const T *bboxes, const T *im_info,
}
}
__device__
inline
float
IoU
(
const
float
*
a
,
const
float
*
b
)
{
static
__device__
inline
float
IoU
(
const
float
*
a
,
const
float
*
b
)
{
float
left
=
max
(
a
[
0
],
b
[
0
]),
right
=
min
(
a
[
2
],
b
[
2
]);
float
top
=
max
(
a
[
1
],
b
[
1
]),
bottom
=
min
(
a
[
3
],
b
[
3
]);
float
width
=
max
(
right
-
left
+
1
,
0.
f
),
height
=
max
(
bottom
-
top
+
1
,
0.
f
);
...
...
@@ -191,8 +205,9 @@ __device__ inline float IoU(const float *a, const float *b) {
return
inter_s
/
(
s_a
+
s_b
-
inter_s
);
}
__global__
void
NMSKernel
(
const
int
n_boxes
,
const
float
nms_overlap_thresh
,
const
float
*
dev_boxes
,
uint64_t
*
dev_mask
)
{
static
__global__
void
NMSKernel
(
const
int
n_boxes
,
const
float
nms_overlap_thresh
,
const
float
*
dev_boxes
,
uint64_t
*
dev_mask
)
{
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
...
...
@@ -234,9 +249,9 @@ __global__ void NMSKernel(const int n_boxes, const float nms_overlap_thresh,
}
template
<
typename
T
>
void
NMS
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
proposals
,
const
Tensor
&
sorted_indices
,
const
T
nms_threshold
,
Tensor
*
keep_out
)
{
static
void
NMS
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
proposals
,
const
Tensor
&
sorted_indices
,
const
T
nms_threshold
,
Tensor
*
keep_out
)
{
int
boxes_num
=
proposals
.
dims
()[
0
];
PADDLE_ENFORCE_EQ
(
boxes_num
,
sorted_indices
.
dims
()[
0
]);
...
...
@@ -247,13 +262,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
const
T
*
boxes
=
proposals
.
data
<
T
>
();
auto
place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
int
size_bytes
=
boxes_num
*
col_blocks
*
sizeof
(
uint64_t
);
uint64_t
*
d_mask
=
reinterpret_cast
<
uint64_t
*>
(
memory
::
Alloc
(
place
,
size_bytes
));
NMSKernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_threshold
,
boxes
,
d_mask
);
uint64_t
*
h_mask
=
reinterpret_cast
<
uint64_t
*>
(
memory
::
Alloc
(
platform
::
CPUPlace
(),
size_bytes
));
memory
::
Copy
(
platform
::
CPUPlace
(),
h_mask
,
place
,
d_mask
,
size_bytes
,
0
);
framework
::
Vector
<
uint64_t
>
mask
(
boxes_num
*
col_blocks
);
NMSKernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_threshold
,
boxes
,
mask
.
CUDAMutableData
(
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
())));
std
::
vector
<
uint64_t
>
remv
(
col_blocks
);
memset
(
&
remv
[
0
],
0
,
sizeof
(
uint64_t
)
*
col_blocks
);
...
...
@@ -267,7 +279,7 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
if
(
!
(
remv
[
nblock
]
&
(
1ULL
<<
inblock
)))
{
++
num_to_keep
;
keep_vec
.
push_back
(
i
);
uint64_t
*
p
=
&
h_
mask
[
0
]
+
i
*
col_blocks
;
uint64_t
*
p
=
&
mask
[
0
]
+
i
*
col_blocks
;
for
(
int
j
=
nblock
;
j
<
col_blocks
;
j
++
)
{
remv
[
j
]
|=
p
[
j
];
}
...
...
@@ -276,12 +288,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
int
*
keep
=
keep_out
->
mutable_data
<
int
>
({
num_to_keep
},
ctx
.
GetPlace
());
memory
::
Copy
(
place
,
keep
,
platform
::
CPUPlace
(),
keep_vec
.
data
(),
sizeof
(
int
)
*
num_to_keep
,
0
);
memory
::
Free
(
place
,
d_mask
);
memory
::
Free
(
platform
::
CPUPlace
(),
h_mask
);
}
template
<
typename
T
>
std
::
pair
<
Tensor
,
Tensor
>
ProposalForOneImage
(
st
atic
st
d
::
pair
<
Tensor
,
Tensor
>
ProposalForOneImage
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
im_info
,
const
Tensor
&
anchors
,
const
Tensor
&
variances
,
const
Tensor
&
bbox_deltas
,
// [M, 4]
...
...
@@ -300,18 +310,20 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
// 2. box decode and clipping
Tensor
proposals
;
proposals
.
mutable_data
<
T
>
({
pre_nms_num
,
4
},
ctx
.
GetPlace
());
int
block
=
512
;
auto
stream
=
ctx
.
stream
();
BoxDecodeAndClipKernel
<
T
><<<
DIVUP
(
pre_nms_num
,
block
),
block
,
0
,
stream
>>>
(
anchors
.
data
<
T
>
(),
bbox_deltas
.
data
<
T
>
(),
variances
.
data
<
T
>
(),
index_sort
.
data
<
int
>
(),
im_info
.
data
<
T
>
(),
pre_nms_num
,
proposals
.
data
<
T
>
());
{
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
ctx
,
pre_nms_num
);
for_range
(
BoxDecodeAndClipFunctor
<
T
>
{
anchors
.
data
<
T
>
(),
bbox_deltas
.
data
<
T
>
(),
variances
.
data
<
T
>
(),
index_sort
.
data
<
int
>
(),
im_info
.
data
<
T
>
(),
proposals
.
data
<
T
>
()});
}
// 3. filter
Tensor
keep_index
,
keep_num_t
;
keep_index
.
mutable_data
<
int
>
({
pre_nms_num
},
ctx
.
GetPlace
());
keep_num_t
.
mutable_data
<
int
>
({
1
},
ctx
.
GetPlace
());
min_size
=
std
::
max
(
min_size
,
1.0
f
);
auto
stream
=
ctx
.
stream
();
FilterBBoxes
<
T
,
512
><<<
1
,
512
,
0
,
stream
>>>
(
proposals
.
data
<
T
>
(),
im_info
.
data
<
T
>
(),
min_size
,
pre_nms_num
,
keep_num_t
.
data
<
int
>
(),
keep_index
.
data
<
int
>
());
...
...
@@ -355,8 +367,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
auto
*
scores
=
context
.
Input
<
Tensor
>
(
"Scores"
);
auto
*
bbox_deltas
=
context
.
Input
<
Tensor
>
(
"BboxDeltas"
);
auto
*
im_info
=
context
.
Input
<
Tensor
>
(
"ImInfo"
);
auto
*
anchors
=
context
.
Input
<
Tensor
>
(
"Anchors"
);
auto
*
variances
=
context
.
Input
<
Tensor
>
(
"Variances"
);
auto
anchors
=
detail
::
Ref
(
context
.
Input
<
Tensor
>
(
"Anchors"
),
"Cannot find input Anchors(%s) in scope"
,
context
.
Inputs
(
"Anchors"
)[
0
]);
auto
variances
=
detail
::
Ref
(
context
.
Input
<
Tensor
>
(
"Variances"
),
"Cannot find input Variances(%s) in scope"
,
context
.
Inputs
(
"Variances"
)[
0
]);
auto
*
rpn_rois
=
context
.
Output
<
LoDTensor
>
(
"RpnRois"
);
auto
*
rpn_roi_probs
=
context
.
Output
<
LoDTensor
>
(
"RpnRoiProbs"
);
...
...
@@ -392,10 +408,8 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
trans
(
dev_ctx
,
*
bbox_deltas
,
&
bbox_deltas_swap
,
axis
);
trans
(
dev_ctx
,
*
scores
,
&
scores_swap
,
axis
);
Tensor
*
anchor
=
const_cast
<
framework
::
Tensor
*>
(
anchors
);
anchor
->
Resize
({
anchors
->
numel
()
/
4
,
4
});
Tensor
*
var
=
const_cast
<
framework
::
Tensor
*>
(
variances
);
var
->
Resize
({
var
->
numel
()
/
4
,
4
});
anchors
.
Resize
({
anchors
.
numel
()
/
4
,
4
});
variances
.
Resize
({
variances
.
numel
()
/
4
,
4
});
rpn_rois
->
mutable_data
<
T
>
({
bbox_deltas
->
numel
()
/
4
,
4
},
context
.
GetPlace
());
...
...
@@ -417,12 +431,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
scores_slice
.
Resize
({
h_score
*
w_score
*
c_score
,
1
});
std
::
pair
<
Tensor
,
Tensor
>
box_score_pair
=
ProposalForOneImage
<
T
>
(
dev_ctx
,
im_info_slice
,
*
anchor
,
*
var
,
ProposalForOneImage
<
T
>
(
dev_ctx
,
im_info_slice
,
anchors
,
variances
,
bbox_deltas_slice
,
scores_slice
,
pre_nms_top_n
,
post_nms_top_n
,
nms_thresh
,
min_size
,
eta
);
Tensor
proposals
=
box_score_pair
.
first
;
Tensor
scores
=
box_score_pair
.
second
;
Tensor
&
proposals
=
box_score_pair
.
first
;
Tensor
&
scores
=
box_score_pair
.
second
;
memory
::
Copy
(
place
,
rpn_rois_data
+
num_proposals
*
4
,
place
,
proposals
.
data
<
T
>
(),
sizeof
(
T
)
*
proposals
.
numel
(),
0
);
...
...
paddle/fluid/operators/distributed/grpc_client.cc
浏览文件 @
1c591c39
...
...
@@ -86,7 +86,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
// stub context
s
->
response_call_back_
=
nullptr
;
platform
::
RecordEvent
record_event
(
method
,
p_ctx
);
platform
::
Record
RPC
Event
record_event
(
method
,
p_ctx
);
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/SendVariable"
,
req
,
&
cq_
);
...
...
@@ -143,7 +143,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
// stub context
s
->
response_call_back_
=
ProcGetResponse
;
platform
::
RecordEvent
record_event
(
method
,
p_ctx
);
platform
::
Record
RPC
Event
record_event
(
method
,
p_ctx
);
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/GetVariable"
,
buf
,
&
cq_
);
...
...
@@ -191,7 +191,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
// stub context
s
->
response_call_back_
=
ProcGetResponse
;
platform
::
RecordEvent
record_event
(
method
,
p_ctx
);
platform
::
Record
RPC
Event
record_event
(
method
,
p_ctx
);
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/PrefetchVariable"
,
req
,
...
...
@@ -221,7 +221,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
BATCH_BARRIER_MESSAGE
);
platform
::
RecordEvent
record_event
(
method
,
nullptr
);
platform
::
Record
RPC
Event
record_event
(
method
,
nullptr
);
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
...
...
@@ -246,7 +246,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
FETCH_BARRIER_MESSAGE
);
platform
::
RecordEvent
record_event
(
method
,
nullptr
);
platform
::
Record
RPC
Event
record_event
(
method
,
nullptr
);
auto
rpc
=
s
->
stub_
->
AsyncGetVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
...
...
@@ -271,7 +271,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
COMPLETE_MESSAGE
);
platform
::
RecordEvent
record_event
(
method
,
nullptr
);
platform
::
Record
RPC
Event
record_event
(
method
,
nullptr
);
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
...
...
@@ -301,7 +301,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
req
.
set_varname
(
CHECKPOINT_SAVE_MESSAGE
);
req
.
set_out_varname
(
dir
);
platform
::
RecordEvent
record_event
(
method
,
nullptr
);
platform
::
Record
RPC
Event
record_event
(
method
,
nullptr
);
auto
rpc
=
s
->
stub_
->
AsyncCheckpointNotify
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
...
...
paddle/fluid/operators/distributed/grpc_serde.cc
浏览文件 @
1c591c39
...
...
@@ -36,7 +36,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
)
{
platform
::
RecordEvent
record_event
(
"serial"
,
&
ctx
);
platform
::
Record
RPC
Event
record_event
(
"serial"
,
&
ctx
);
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback
destroy_callback
=
[](
void
*
backing
)
{};
...
...
@@ -148,7 +148,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
)
{
platform
::
RecordEvent
record_event
(
"deserial"
,
&
ctx
);
platform
::
Record
RPC
Event
record_event
(
"deserial"
,
&
ctx
);
operators
::
distributed
::
GRPCVariableResponse
resp
(
scope
,
&
ctx
);
PADDLE_ENFORCE
(
resp
.
Parse
(
msg
)
==
0
,
"parse bytebuffer to tensor error!"
);
*
var
=
resp
.
GetVar
();
...
...
paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.cc
0 → 100644
浏览文件 @
1c591c39
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.h"
#include <algorithm> // for min, max
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
namespace
paddle
{
namespace
operators
{
void
FusionSeqConvEltAddReluOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FusionSeqConvEltAddReluOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Filter"
),
"Input(Filter) of FusionSeqConvEltAddReluOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
"Input(Bias) of FusionSeqConvEltAddReluOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FusionSeqConvEltAddReluOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ColMat"
),
"Output(ColMat) of FusionSeqConvEltAddReluOp should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
w_dims
=
ctx
->
GetInputDim
(
"Filter"
);
int
context_length
=
ctx
->
Attrs
().
Get
<
int
>
(
"contextLength"
);
PADDLE_ENFORCE
(
ctx
->
Attrs
().
Get
<
int
>
(
"contextStride"
)
==
1
,
"Currently, FusionSeqConvEltAddReluOp only supports contextStride=1."
);
PADDLE_ENFORCE
(
x_dims
.
size
()
==
2
&&
w_dims
.
size
()
==
2
,
"Input(X, Filter) should be 2-D tensor."
);
PADDLE_ENFORCE
(
x_dims
.
size
()
==
2
&&
w_dims
.
size
()
==
2
,
"Input(X, Filter) should be 2-D tensor."
);
PADDLE_ENFORCE
(
w_dims
[
0
]
==
context_length
*
x_dims
[
1
],
"Filter's height should be context_length * "
"input_hidden_size ."
);
PADDLE_ENFORCE_GT
(
context_length
+
ctx
->
Attrs
().
Get
<
int
>
(
"contextStart"
),
0
,
"contextStart size should be smaller than contextLength."
);
ctx
->
SetOutputDim
(
"Out"
,
{
x_dims
[
0
],
w_dims
[
1
]});
ctx
->
SetOutputDim
(
"ColMat"
,
{
x_dims
[
0
],
w_dims
[
0
]});
ctx
->
ShareLoD
(
"X"
,
"Out"
);
}
framework
::
OpKernelType
FusionSeqConvEltAddReluOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
void
FusionSeqConvEltAddReluOpMaker
::
Make
()
{
AddInput
(
"X"
,
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x."
);
// PaddingData only support false yet, should be ensured at pass.
AddInput
(
"Filter"
,
"(Tensor) same as the input(Filter) of sequence conv op is an "
"learnable parameter."
"This is a tensor with shape (K, N), where K is the "
"context_length * dim size of x, N is the output feature size."
);
AddInput
(
"Bias"
,
"(Tensor) the learnable weights. shape (1, N), where N is the "
"output feature size"
);
AddOutput
(
"Out"
,
"(LoDTensor) the output(Out) is a LodTensor, which support "
"variable-time length output sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T, N), where, T is the "
"total time steps in this mini-batch, N is the output feature size."
);
AddOutput
(
"ColMat"
,
"(Tensor) (T, K), where T is where T is the "
"total time steps in this mini-batch, K is height of Filter"
)
.
AsIntermediate
();
AddAttr
<
int
>
(
"contextLength"
,
"(int) the contextLength of FusionSeqConvEltAddReluOp is the "
"height of the convolution kernel."
)
.
GreaterThan
(
0
);
AddAttr
<
int
>
(
"contextStart"
,
"(int, default:0) the contextStart of FusionSeqConvEltAddReluOp "
"represents the beginning of the convolution of the number of "
"rows of sequence, which can be negative. The negative number "
"means to pad contextStart time-steps of zeros or learnable "
"parameters at the beginning of each instance. The positive "
"number means to skip contextStart time-steps of each "
"instance."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"contextStride"
,
"(int, default:1) the contextStride of FusionSeqConvEltAddReluOp "
"represents the stride length of convolution kernel. "
"Currently, FusionSeqConvEltAddReluOp only supports"
"contextStride=1."
)
.
SetDefault
(
1
)
.
GreaterThan
(
0
);
AddComment
(
R"DOC(
Fusion Sequence Conv and ElementwiseAdd Operator.
)DOC"
);
}
template
<
typename
T
>
class
FusionSeqConvEltAddReluKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
w
=
ctx
.
Input
<
Tensor
>
(
"Filter"
);
auto
*
b
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
y
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
auto
*
col
=
ctx
.
Output
<
Tensor
>
(
"ColMat"
);
auto
x_lod
=
x
->
lod
();
auto
x_dims
=
x
->
dims
();
auto
w_dims
=
w
->
dims
();
PADDLE_ENFORCE_EQ
(
b
->
numel
(),
w_dims
[
1
],
"bias size should be equal to output feature size."
);
PADDLE_ENFORCE_EQ
(
x_lod
.
size
(),
1UL
,
"Only support one level sequence now."
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
w_data
=
w
->
data
<
T
>
();
const
T
*
b_data
=
b
->
data
<
T
>
();
T
*
y_data
=
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
col_data
=
col
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
context_start
=
ctx
.
Attr
<
int
>
(
"contextStart"
);
int
context_length
=
ctx
.
Attr
<
int
>
(
"contextLength"
);
int
up_pad
=
std
::
max
(
0
,
-
context_start
);
int
down_pad
=
std
::
max
(
0
,
context_start
+
context_length
-
1
);
// im2col
int
src_mat_w
=
static_cast
<
int
>
(
x_dims
[
1
]);
int
src_mat_w_sz
=
src_mat_w
*
sizeof
(
T
);
int
col_mat_w
=
static_cast
<
int
>
(
w_dims
[
0
]);
int
col_mat_w_sz
=
col_mat_w
*
sizeof
(
T
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
x_lod
[
0
].
size
())
-
1
;
++
i
)
{
int
st
=
x_lod
[
0
][
i
];
int
ed
=
x_lod
[
0
][
i
+
1
];
const
T
*
src_data
=
x_data
+
st
*
src_mat_w
;
T
*
dst_data
=
col_data
+
st
*
col_mat_w
;
int
seq_len
=
ed
-
st
;
if
(
seq_len
>
up_pad
+
down_pad
)
{
// zero all up_pad and fill data
std
::
memset
(
dst_data
,
0
,
up_pad
*
col_mat_w_sz
);
dst_data
=
dst_data
+
up_pad
*
src_mat_w
;
int
copy_size
=
col_mat_w_sz
-
up_pad
*
src_mat_w_sz
;
for
(
int
j
=
0
;
j
<
up_pad
;
++
j
)
{
// blas.VCOPY?
std
::
memcpy
(
dst_data
,
src_data
,
copy_size
);
dst_data
+=
(
col_mat_w
-
src_mat_w
);
copy_size
+=
src_mat_w_sz
;
}
// fill data
for
(
int
j
=
0
;
j
<
seq_len
-
up_pad
-
down_pad
;
++
j
)
{
std
::
memcpy
(
dst_data
,
src_data
,
copy_size
);
dst_data
+=
col_mat_w
;
src_data
+=
src_mat_w
;
}
// zero all down_pad and fill data
std
::
memset
(
dst_data
,
0
,
down_pad
*
col_mat_w_sz
);
copy_size
-=
src_mat_w_sz
;
for
(
int
j
=
0
;
j
<
down_pad
;
++
j
)
{
std
::
memcpy
(
dst_data
,
src_data
,
copy_size
);
dst_data
+=
col_mat_w
;
src_data
+=
src_mat_w
;
copy_size
-=
src_mat_w_sz
;
}
}
else
{
PADDLE_ENFORCE_GE
(
context_length
,
up_pad
+
down_pad
+
1
);
std
::
memset
(
dst_data
,
0
,
seq_len
*
col_mat_w_sz
);
dst_data
=
dst_data
+
up_pad
*
src_mat_w
;
int
zero_sz
=
up_pad
*
src_mat_w_sz
;
int
cur_src_sz
=
seq_len
*
src_mat_w_sz
;
for
(
int
j
=
0
;
j
<
std
::
min
(
up_pad
,
seq_len
);
++
j
)
{
int
copy_size
=
std
::
min
(
cur_src_sz
,
col_mat_w_sz
-
zero_sz
);
std
::
memcpy
(
dst_data
,
src_data
,
copy_size
);
dst_data
+=
(
col_mat_w
-
src_mat_w
);
zero_sz
-=
src_mat_w_sz
;
}
// from bottom
dst_data
=
col_data
+
ed
*
col_mat_w
;
src_data
=
x_data
+
st
*
src_mat_w
;
zero_sz
=
down_pad
*
src_mat_w_sz
;
for
(
int
j
=
1
;
j
<=
std
::
min
(
down_pad
,
seq_len
);
++
j
)
{
int
copy_size
=
std
::
min
(
cur_src_sz
,
col_mat_w_sz
-
zero_sz
);
std
::
memcpy
(
dst_data
-
(
zero_sz
+
copy_size
)
/
sizeof
(
T
),
src_data
+
std
::
max
(
seq_len
-
j
-
up_pad
,
0
)
*
src_mat_w
,
copy_size
);
dst_data
-=
col_mat_w
;
zero_sz
-=
src_mat_w_sz
;
}
}
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
w_dims
[
1
],
w_dims
[
0
],
col_data
,
w_data
,
y_data
,
b_data
,
true
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
fusion_seqconv_eltadd_relu
,
ops
::
FusionSeqConvEltAddReluOp
,
ops
::
FusionSeqConvEltAddReluOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OP_CPU_KERNEL
(
fusion_seqconv_eltadd_relu
,
ops
::
FusionSeqConvEltAddReluKernel
<
float
>
,
ops
::
FusionSeqConvEltAddReluKernel
<
double
>
);
paddle/fluid/operators/fusion_seqconv_eltadd_relu_op.h
0 → 100644
浏览文件 @
1c591c39
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
using
Tensor
=
framework
::
Tensor
;
class
FusionSeqConvEltAddReluOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
class
FusionSeqConvEltAddReluOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
;
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/gather.h
浏览文件 @
1c591c39
...
...
@@ -39,11 +39,9 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()));
// check index of shape 1-D
PADDLE_ENFORCE
(
index
.
dims
().
size
()
==
1
);
int
index_size
=
index
.
dims
()[
0
];
int
64_t
index_size
=
index
.
dims
()[
0
];
auto
src_dims
=
src
.
dims
();
framework
::
DDim
output_dims
(
src_dims
);
output_dims
[
0
]
=
index_size
;
const
T
*
p_src
=
src
.
data
<
T
>
();
const
int
*
p_index
=
index
.
data
<
int
>
();
...
...
@@ -55,7 +53,7 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const
size_t
slice_bytes
=
slice_size
*
sizeof
(
T
);
for
(
int
i
=
0
;
i
<
index_size
;
++
i
)
{
for
(
int
64_t
i
=
0
;
i
<
index_size
;
++
i
)
{
int
index_
=
p_index
[
i
];
memcpy
(
p_output
+
i
*
slice_size
,
p_src
+
index_
*
slice_size
,
slice_bytes
);
}
...
...
paddle/fluid/operators/math/fc_compute.h
浏览文件 @
1c591c39
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
DECLARE_int32
(
paddle_num_threads
);
...
...
@@ -30,20 +31,25 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
if
(
B
==
NULL
)
{
return
;
}
if
(
relu
)
{
const
auto
&
vaddrelu
=
jitkernel
::
KernelPool
::
Instance
()
.
template
Get
<
jitkernel
::
VAddReluKernel
<
T
>
>
(
N
);
for
(
int
i
=
0
;
i
<
M
;
i
++
)
{
T
*
dst
=
Y
+
i
*
N
;
vaddrelu
->
Compute
(
B
,
dst
,
dst
);
}
}
else
{
const
auto
&
vadd
=
jitkernel
::
KernelPool
::
Instance
()
.
template
Get
<
jitkernel
::
VAddKernel
<
T
>
>
(
N
);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for
(
int
i
=
0
;
i
<
M
;
i
++
)
{
blas
.
AXPY
(
N
,
static_cast
<
T
>
(
1
),
B
,
Y
+
i
*
N
);
for
(
int
i
=
0
;
i
<
M
;
i
++
)
{
T
*
dst
=
Y
+
i
*
N
;
vadd
->
Compute
(
B
,
dst
,
dst
);
}
}
if
(
!
relu
)
{
return
;
}
// TODO(TJ): fuse relu
LOG
(
FATAL
)
<<
"Not implemented!"
;
}
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
1c591c39
...
...
@@ -86,6 +86,12 @@ class VAddBiasKernel : public Kernel {
virtual
void
Compute
(
const
T
a
,
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
class
VAddReluKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
=
0
;
};
template
<
typename
T
>
class
VActKernel
:
public
Kernel
{
public:
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
1c591c39
...
...
@@ -378,11 +378,99 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
void
Compute
(
const
T
*
x
,
T
*
y
)
const
override
{}
};
/* VAddRelu JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VAddReluKernelImpl
:
public
VAddReluKernel
<
T
>
{
public:
explicit
VAddReluKernelImpl
(
int
d
)
:
VAddReluKernel
<
T
>
()
{
this
->
num_
=
d
;
}
void
Compute
(
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
override
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
z
[
i
]
=
z
[
i
]
>
0
?
z
[
i
]
:
0
;
}
}
};
#define INTRI8_FLOAT(isa) \
template <> \
void VAddReluKernelImpl<float, isa, kEQ8>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 tmpx = _mm256_loadu_ps(x); \
__m256 tmpy = _mm256_loadu_ps(y); \
tmpy = _mm256_add_ps(tmpx, tmpy); \
tmpy = _mm256_max_ps(tmpy, _mm256_setzero_ps()); \
_mm256_storeu_ps(z, tmpy); \
}
#define INTRI16_FLOAT(isa) \
template <> \
void VAddReluKernelImpl<float, isa, kEQ16>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(y); \
tmp0 = _mm256_add_ps(tmp0, tmp1); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_loadu_ps(x + 8); \
__m256 tmp2 = _mm256_loadu_ps(y + 8); \
tmp1 = _mm256_add_ps(tmp1, tmp2); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(z, tmp0); \
_mm256_storeu_ps(z + 8, tmp1); \
}
#define INTRI_COMMON_FLOAT(isa, block) \
template <> \
VAddReluKernelImpl<float, isa, block>::VAddReluKernelImpl(int d) \
: VAddReluKernel<float>() { \
this->num_ = d; \
this->end_ = d - d % AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \
} \
template <> \
void VAddReluKernelImpl<float, isa, block>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 zeros = _mm256_setzero_ps(); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmpx = _mm256_loadu_ps(x + i); \
__m256 tmpy = _mm256_loadu_ps(y + i); \
tmpy = _mm256_add_ps(tmpx, tmpy); \
tmpy = _mm256_max_ps(tmpy, zeros); \
_mm256_storeu_ps(z + i, tmpy); \
} \
for (int i = this->end_; i < this->num_; ++i) { \
z[i] = x[i] + y[i]; \
z[i] = z[i] > 0 ? z[i] : 0; \
} \
}
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
INTRI16_FLOAT
(
jit
::
avx
);
INTRI_COMMON_FLOAT
(
jit
::
avx
,
kGT16
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
INTRI16_FLOAT
(
jit
::
avx2
);
INTRI_COMMON_FLOAT
(
jit
::
avx2
,
kGT16
);
#endif
#ifdef __AVX512F__
// TODO(TJ): refine avx512
INTRI8_FLOAT
(
jit
::
avx512f
);
INTRI16_FLOAT
(
jit
::
avx512f
);
INTRI_COMMON_FLOAT
(
jit
::
avx512f
,
kGT16
);
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_COMMON_FLOAT
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddb
,
VAddBiasKernel
);
REGISTER_JITKERNEL
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
REGISTER_JITKERNEL
(
videntity
,
VIdentityKernel
);
}
// namespace jitkernel
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
1c591c39
...
...
@@ -712,6 +712,63 @@ TEST(JitKernel, vadd) {
}
}
void
vaddrelu_ref
(
const
int
n
,
const
float
*
x
,
const
float
*
y
,
float
*
z
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
+
y
[
i
];
z
[
i
]
=
z
[
i
]
>
0
?
z
[
i
]
:
0
;
}
}
void
vaddrelu_better
(
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VAddKernel
<
float
>>&
vadd
,
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VReluKernel
<
float
>>&
vrelu
,
const
float
*
x
,
const
float
*
y
,
float
*
z
)
{
vadd
->
Compute
(
x
,
y
,
z
);
vrelu
->
Compute
(
z
,
z
);
}
TEST
(
JitKernel
,
vaddrelu
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
for
(
int
d
:
{
7
,
8
,
15
,
16
,
30
,
256
,
512
})
{
std
::
vector
<
float
>
x
(
d
),
y
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
());
RandomVec
<
float
>
(
d
,
y
.
data
());
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VAddReluKernel
<
float
>
>
(
d
);
const
auto
&
vadd
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VAddKernel
<
float
>
>
(
d
);
const
auto
&
vrelu
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VReluKernel
<
float
>
>
(
d
);
const
float
*
x_data
=
x
.
data
();
const
float
*
y_data
=
y
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
auto
trefs
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vadd_ref
(
d
,
x_data
,
y_data
,
zref_data
);
}
auto
trefe
=
GetCurrentUS
();
auto
tmkls
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vaddrelu_better
(
vadd
,
vrelu
,
x_data
,
y_data
,
zref_data
);
}
auto
tmkle
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
y_data
,
ztgt_data
);
}
auto
ttgte
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
<<
" us, better takes: "
<<
(
tmkle
-
tmkls
)
/
repeat
<<
" us, "
<<
"tgt takes: "
<<
(
ttgte
-
ttgts
)
/
repeat
;
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ztgt_data
[
i
],
zref_data
[
i
],
1e-3
);
}
}
}
TEST
(
JitKernel
,
pool
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
const
int
frame_size
=
4
;
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
1c591c39
...
...
@@ -35,6 +35,16 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
return
it
->
second
.
get
();
}
const
std
::
vector
<
const
DeviceContext
*>
DeviceContextPool
::
GetAllDeviceContexts
()
const
{
std
::
vector
<
const
DeviceContext
*>
all_device_ctx
;
all_device_ctx
.
reserve
(
device_contexts_
.
size
());
for
(
auto
&
dev_ctx
:
device_contexts_
)
{
all_device_ctx
.
emplace_back
(
dev_ctx
.
second
.
get
());
}
return
all_device_ctx
;
}
DeviceContextPool
::
DeviceContextPool
(
const
std
::
vector
<
platform
::
Place
>&
places
)
{
PADDLE_ENFORCE_GT
(
places
.
size
(),
0
);
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
1c591c39
...
...
@@ -217,6 +217,9 @@ class DeviceContextPool {
/*! \brief Return handle of single device context. */
platform
::
DeviceContext
*
Get
(
const
platform
::
Place
&
place
);
/*! \brief Return all the device contexts. */
const
std
::
vector
<
const
DeviceContext
*>
GetAllDeviceContexts
()
const
;
template
<
typename
Place
>
const
typename
DefaultDeviceContextType
<
Place
>::
TYPE
*
GetByPlace
(
const
Place
&
place
)
{
...
...
paddle/fluid/platform/profiler.cc
浏览文件 @
1c591c39
...
...
@@ -30,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/string/printf.h"
DEFINE_bool
(
enable_rpc_profiler
,
false
,
"Enable rpc profiler or not."
);
namespace
paddle
{
namespace
platform
{
...
...
@@ -193,6 +195,13 @@ RecordEvent::~RecordEvent() {
PopEvent
(
name_
,
dev_ctx_
);
}
RecordRPCEvent
::
RecordRPCEvent
(
const
std
::
string
&
name
,
const
DeviceContext
*
dev_ctx
)
{
if
(
FLAGS_enable_rpc_profiler
)
{
event_
.
reset
(
new
platform
::
RecordEvent
(
name
,
dev_ctx
));
}
}
RecordBlock
::
RecordBlock
(
int
block_id
)
:
is_enabled_
(
false
),
start_ns_
(
PosixInNsec
())
{
std
::
lock_guard
<
std
::
mutex
>
l
(
profiler_mu
);
...
...
paddle/fluid/platform/profiler.h
浏览文件 @
1c591c39
...
...
@@ -87,6 +87,16 @@ struct RecordEvent {
std
::
string
full_name_
;
};
class
RecordRPCEvent
{
public:
// dev_ctx can be set to nullptr if device is cpu.
RecordRPCEvent
(
const
std
::
string
&
name
,
const
DeviceContext
*
dev_ctx
);
~
RecordRPCEvent
()
{}
private:
std
::
unique_ptr
<
RecordEvent
>
event_
;
};
struct
RecordBlock
{
explicit
RecordBlock
(
int
block_id
);
~
RecordBlock
();
...
...
python/paddle/fluid/__init__.py
浏览文件 @
1c591c39
...
...
@@ -120,6 +120,7 @@ def __bootstrap__():
read_env_flags
.
append
(
'rpc_deadline'
)
read_env_flags
.
append
(
'rpc_server_profile_period'
)
read_env_flags
.
append
(
'rpc_server_profile_path'
)
read_env_flags
.
append
(
'enable_rpc_profiler'
)
if
core
.
is_compiled_with_cuda
():
read_env_flags
+=
[
...
...
python/paddle/fluid/layer_helper.py
浏览文件 @
1c591c39
...
...
@@ -324,10 +324,19 @@ class LayerHelper(object):
raise
ValueError
(
"no Parameter name %s found"
%
name
)
return
param
def
create_tmp_variable
(
self
,
dtype
,
stop_gradient
=
False
):
def
create_variable_for_type_inference
(
self
,
dtype
,
stop_gradient
=
False
):
"""Create a temporary variable that should be type inferred layer.
Note:
The default type will be set to LOD_TENSOR. However, when
the var is used as operator output, its type will be updated
based on operator's `VarTypeInference` implementation in
infer_var_type.
"""
return
self
.
main_program
.
current_block
().
create_var
(
name
=
unique_name
.
generate
(
"."
.
join
([
self
.
name
,
'tmp'
])),
dtype
=
dtype
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
stop_gradient
=
stop_gradient
)
...
...
@@ -388,7 +397,7 @@ class LayerHelper(object):
b
=
self
.
create_parameter
(
attr
=
bias_attr
,
shape
=
size
,
dtype
=
input_var
.
dtype
,
is_bias
=
True
)
tmp
=
self
.
create_
tmp_variabl
e
(
dtype
=
input_var
.
dtype
)
tmp
=
self
.
create_
variable_for_type_inferenc
e
(
dtype
=
input_var
.
dtype
)
self
.
append_op
(
type
=
'elementwise_add'
,
inputs
=
{
'X'
:
[
input_var
],
...
...
@@ -414,7 +423,7 @@ class LayerHelper(object):
tmp
=
input_var
# NOTE(dzhwinter): some activation support inplace compution.
if
not
core
.
IsInplace
(
act_type
):
tmp
=
self
.
create_
tmp_variabl
e
(
dtype
=
input_var
.
dtype
)
tmp
=
self
.
create_
variable_for_type_inferenc
e
(
dtype
=
input_var
.
dtype
)
self
.
append_op
(
type
=
act_type
,
inputs
=
{
"X"
:
[
input_var
]},
...
...
python/paddle/fluid/layers/control_flow.py
浏览文件 @
1c591c39
...
...
@@ -80,8 +80,8 @@ def split_lod_tensor(input, mask, level=0):
"""
helper
=
LayerHelper
(
'split_lod_tensor'
,
**
locals
())
out_true
=
helper
.
create_
tmp_variabl
e
(
dtype
=
input
.
dtype
)
out_false
=
helper
.
create_
tmp_variabl
e
(
dtype
=
input
.
dtype
)
out_true
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
input
.
dtype
)
out_false
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
input
.
dtype
)
helper
.
append_op
(
type
=
'split_lod_tensor'
,
inputs
=
{
...
...
@@ -131,7 +131,7 @@ def merge_lod_tensor(in_true, in_false, x, mask, level=0):
in_true=out_true, in_false=out_false, mask=y, x=x, level=level)
"""
helper
=
LayerHelper
(
'merge_lod_tensor'
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
in_true
.
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
in_true
.
dtype
)
helper
.
append_op
(
type
=
'merge_lod_tensor'
,
inputs
=
{
'X'
:
x
,
...
...
@@ -524,7 +524,7 @@ class StaticRNN(object):
if
not
isinstance
(
o
,
Variable
):
raise
TypeError
(
"step output takes a Variable"
)
tmp_o
=
self
.
helper
.
create_
tmp_variabl
e
(
dtype
=
o
.
dtype
)
tmp_o
=
self
.
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
o
.
dtype
)
self
.
helper
.
append_op
(
type
=
'rnn_memory_helper'
,
inputs
=
{
'X'
:
[
o
]},
...
...
@@ -606,7 +606,8 @@ class StaticRNN(object):
pre_memories
.
append
(
mem
.
pre_mem
.
name
)
mem_var
=
rnn_block
.
var
(
mem
.
mem
.
name
)
assert
isinstance
(
mem_var
,
Variable
)
new_mem
=
self
.
helper
.
create_tmp_variable
(
dtype
=
mem_var
.
dtype
)
new_mem
=
self
.
helper
.
create_variable_for_type_inference
(
dtype
=
mem_var
.
dtype
)
rnn_block
.
append_op
(
type
=
'rnn_memory_helper'
,
...
...
@@ -813,7 +814,7 @@ def max_sequence_len(rank_table):
${out_comment}.
"""
helper
=
LayerHelper
(
"max_seqence_len"
,
**
locals
())
res
=
helper
.
create_
tmp_variabl
e
(
dtype
=
"int64"
)
res
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
"int64"
)
helper
.
append_op
(
type
=
"max_sequence_len"
,
inputs
=
{
"RankTable"
:
rank_table
},
...
...
@@ -884,7 +885,7 @@ def array_to_lod_tensor(x, table):
lod_tensor = fluid.layers.array_to_lod_tensor(array, table)
"""
helper
=
LayerHelper
(
"array_to_lod_tensor"
,
**
locals
())
tmp
=
helper
.
create_
tmp_variabl
e
(
dtype
=
x
.
dtype
)
tmp
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
"array_to_lod_tensor"
,
inputs
=
{
'X'
:
x
,
...
...
@@ -915,7 +916,7 @@ def increment(x, value=1.0, in_place=True):
"""
helper
=
LayerHelper
(
"increment"
,
**
locals
())
if
not
in_place
:
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
x
.
dtype
)
else
:
out
=
x
helper
.
append_op
(
...
...
@@ -1012,7 +1013,7 @@ def less_than(x, y, force_cpu=None, cond=None, **ignored):
"""
helper
=
LayerHelper
(
"less_than"
,
**
locals
())
if
cond
is
None
:
cond
=
helper
.
create_
tmp_variabl
e
(
dtype
=
'bool'
)
cond
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
'bool'
)
cond
.
stop_gradient
=
True
attrs
=
dict
()
...
...
@@ -1051,7 +1052,7 @@ def equal(x, y, cond=None, **ignored):
"""
helper
=
LayerHelper
(
"equal"
,
**
locals
())
if
cond
is
None
:
cond
=
helper
.
create_
tmp_variabl
e
(
dtype
=
'bool'
)
cond
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
'bool'
)
cond
.
stop_gradient
=
True
helper
.
append_op
(
...
...
@@ -1098,7 +1099,7 @@ def array_read(array, i):
array
,
Variable
)
or
array
.
type
!=
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
:
raise
TypeError
(
"array should be tensor array vairable"
)
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
array
.
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
array
.
dtype
)
helper
.
append_op
(
type
=
'read_from_array'
,
inputs
=
{
'X'
:
[
array
],
...
...
@@ -1133,7 +1134,7 @@ def shrink_memory(x, i, table):
usage.
"""
helper
=
LayerHelper
(
'shrink_memory'
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
'shrink_rnn_memory'
,
inputs
=
{
'X'
:
[
x
],
...
...
@@ -1170,7 +1171,7 @@ def array_length(array):
"""
helper
=
LayerHelper
(
'array_length'
,
**
locals
())
tmp
=
helper
.
create_
tmp_variabl
e
(
dtype
=
'int64'
)
tmp
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
'int64'
)
tmp
.
stop_gradient
=
True
helper
.
append_op
(
type
=
'lod_array_length'
,
inputs
=
{
'X'
:
[
array
]},
outputs
=
{
'Out'
:
[
tmp
]})
...
...
@@ -1590,7 +1591,7 @@ class DynamicRNN(object):
self
.
mem_dict
=
dict
()
self
.
output_array
=
[]
self
.
outputs
=
[]
self
.
cond
=
self
.
helper
.
create_
tmp_variabl
e
(
dtype
=
'bool'
)
self
.
cond
=
self
.
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
'bool'
)
self
.
cond
.
stop_gradient
=
False
self
.
while_op
=
While
(
self
.
cond
)
self
.
input_array
=
[]
...
...
@@ -1924,7 +1925,7 @@ def reorder_lod_tensor_by_rank(x, rank_table):
helper
.
is_instance
(
'x'
,
Variable
)
helper
.
is_instance
(
'rank_table'
,
Variable
)
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
'reorder_lod_tensor_by_rank'
,
inputs
=
{
'X'
:
[
x
],
...
...
@@ -1958,7 +1959,7 @@ def is_empty(x, cond=None, **ignored):
"""
helper
=
LayerHelper
(
"is_empty"
,
**
locals
())
if
cond
is
None
:
cond
=
helper
.
create_
tmp_variabl
e
(
dtype
=
'bool'
)
cond
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
'bool'
)
cond
.
stop_gradient
=
True
elif
not
isinstance
(
cond
,
Variable
):
raise
TypeError
(
"cond takes a variable"
)
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
1c591c39
...
...
@@ -287,7 +287,8 @@ def detection_output(loc,
scores
=
nn
.
reshape
(
x
=
scores
,
shape
=
compile_shape
,
actual_shape
=
run_shape
)
scores
=
nn
.
transpose
(
scores
,
perm
=
[
0
,
2
,
1
])
scores
.
stop_gradient
=
True
nmsed_outs
=
helper
.
create_tmp_variable
(
dtype
=
decoded_box
.
dtype
)
nmsed_outs
=
helper
.
create_variable_for_type_inference
(
dtype
=
decoded_box
.
dtype
)
helper
.
append_op
(
type
=
"multiclass_nms"
,
inputs
=
{
'Scores'
:
scores
,
...
...
@@ -319,7 +320,7 @@ def iou_similarity(x, y, name=None):
"""
helper
=
LayerHelper
(
"iou_similarity"
,
**
locals
())
if
name
is
None
:
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
x
.
dtype
)
else
:
out
=
helper
.
create_variable
(
name
=
name
,
dtype
=
x
.
dtype
,
persistable
=
False
)
...
...
@@ -356,7 +357,8 @@ def box_coder(prior_box,
helper
=
LayerHelper
(
"box_coder"
,
**
locals
())
if
name
is
None
:
output_box
=
helper
.
create_tmp_variable
(
dtype
=
prior_box
.
dtype
)
output_box
=
helper
.
create_variable_for_type_inference
(
dtype
=
prior_box
.
dtype
)
else
:
output_box
=
helper
.
create_variable
(
name
=
name
,
dtype
=
prior_box
.
dtype
,
persistable
=
False
)
...
...
@@ -387,7 +389,7 @@ def polygon_box_transform(input, name=None):
"""
helper
=
LayerHelper
(
"polygon_box_transform"
,
**
locals
())
if
name
is
None
:
output
=
helper
.
create_
tmp_variabl
e
(
dtype
=
input
.
dtype
)
output
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
input
.
dtype
)
else
:
output
=
helper
.
create_variable
(
name
=
name
,
dtype
=
prior_box
.
input
,
persistable
=
False
)
...
...
@@ -455,7 +457,7 @@ def detection_map(detect_res,
helper
=
LayerHelper
(
"detection_map"
,
**
locals
())
def
__create_var
(
type
):
return
helper
.
create_
tmp_variabl
e
(
dtype
=
type
)
return
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
type
)
map_out
=
__create_var
(
'float32'
)
accum_pos_count_out
=
out_states
[
0
]
if
out_states
else
__create_var
(
'int32'
)
...
...
@@ -562,8 +564,9 @@ def bipartite_match(dist_matrix,
>>> matched_indices, matched_dist = fluid.layers.bipartite_match(iou)
"""
helper
=
LayerHelper
(
'bipartite_match'
,
**
locals
())
match_indices
=
helper
.
create_tmp_variable
(
dtype
=
'int32'
)
match_distance
=
helper
.
create_tmp_variable
(
dtype
=
dist_matrix
.
dtype
)
match_indices
=
helper
.
create_variable_for_type_inference
(
dtype
=
'int32'
)
match_distance
=
helper
.
create_variable_for_type_inference
(
dtype
=
dist_matrix
.
dtype
)
helper
.
append_op
(
type
=
'bipartite_match'
,
inputs
=
{
'DistMat'
:
dist_matrix
},
...
...
@@ -649,8 +652,8 @@ def target_assign(input,
gt, matched_indices, mismatch_value=0)
"""
helper
=
LayerHelper
(
'target_assign'
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
input
.
dtype
)
out_weight
=
helper
.
create_
tmp_variabl
e
(
dtype
=
'float32'
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
input
.
dtype
)
out_weight
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
'float32'
)
helper
.
append_op
(
type
=
'target_assign'
,
inputs
=
{
...
...
@@ -821,9 +824,10 @@ def ssd_loss(location,
conf_loss
=
nn
.
reshape
(
x
=
conf_loss
,
shape
=
(
num
,
num_prior
),
actual_shape
=
actual_shape
)
conf_loss
.
stop_gradient
=
True
neg_indices
=
helper
.
create_
tmp_variabl
e
(
dtype
=
'int32'
)
neg_indices
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
'int32'
)
dtype
=
matched_indices
.
dtype
updated_matched_indices
=
helper
.
create_tmp_variable
(
dtype
=
dtype
)
updated_matched_indices
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
helper
.
append_op
(
type
=
'mine_hard_examples'
,
inputs
=
{
...
...
@@ -1003,8 +1007,8 @@ def prior_box(input,
max_sizes
=
[
max_sizes
]
attrs
[
'max_sizes'
]
=
max_sizes
box
=
helper
.
create_
tmp_variabl
e
(
dtype
)
var
=
helper
.
create_
tmp_variabl
e
(
dtype
)
box
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
)
var
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
)
helper
.
append_op
(
type
=
"prior_box"
,
inputs
=
{
"Input"
:
input
,
...
...
@@ -1342,8 +1346,8 @@ def anchor_generator(input,
'offset'
:
offset
}
anchor
=
helper
.
create_
tmp_variabl
e
(
dtype
)
var
=
helper
.
create_
tmp_variabl
e
(
dtype
)
anchor
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
)
var
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
)
helper
.
append_op
(
type
=
"anchor_generator"
,
inputs
=
{
"Input"
:
input
},
...
...
@@ -1389,7 +1393,7 @@ def roi_perspective_transform(input,
"""
helper
=
LayerHelper
(
'roi_perspective_transform'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
out
=
helper
.
create_
tmp_variabl
e
(
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
)
helper
.
append_op
(
type
=
"roi_perspective_transform"
,
inputs
=
{
"X"
:
input
,
...
...
@@ -1423,11 +1427,15 @@ def generate_proposal_labels(rpn_rois,
helper
=
LayerHelper
(
'generate_proposal_labels'
,
**
locals
())
rois
=
helper
.
create_tmp_variable
(
dtype
=
rpn_rois
.
dtype
)
labels_int32
=
helper
.
create_tmp_variable
(
dtype
=
gt_classes
.
dtype
)
bbox_targets
=
helper
.
create_tmp_variable
(
dtype
=
rpn_rois
.
dtype
)
bbox_inside_weights
=
helper
.
create_tmp_variable
(
dtype
=
rpn_rois
.
dtype
)
bbox_outside_weights
=
helper
.
create_tmp_variable
(
dtype
=
rpn_rois
.
dtype
)
rois
=
helper
.
create_variable_for_type_inference
(
dtype
=
rpn_rois
.
dtype
)
labels_int32
=
helper
.
create_variable_for_type_inference
(
dtype
=
gt_classes
.
dtype
)
bbox_targets
=
helper
.
create_variable_for_type_inference
(
dtype
=
rpn_rois
.
dtype
)
bbox_inside_weights
=
helper
.
create_variable_for_type_inference
(
dtype
=
rpn_rois
.
dtype
)
bbox_outside_weights
=
helper
.
create_variable_for_type_inference
(
dtype
=
rpn_rois
.
dtype
)
helper
.
append_op
(
type
=
"generate_proposal_labels"
,
...
...
@@ -1509,8 +1517,10 @@ def generate_proposals(scores,
"""
helper
=
LayerHelper
(
'generate_proposals'
,
**
locals
())
rpn_rois
=
helper
.
create_tmp_variable
(
dtype
=
bbox_deltas
.
dtype
)
rpn_roi_probs
=
helper
.
create_tmp_variable
(
dtype
=
scores
.
dtype
)
rpn_rois
=
helper
.
create_variable_for_type_inference
(
dtype
=
bbox_deltas
.
dtype
)
rpn_roi_probs
=
helper
.
create_variable_for_type_inference
(
dtype
=
scores
.
dtype
)
helper
.
append_op
(
type
=
"generate_proposals"
,
inputs
=
{
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
1c591c39
...
...
@@ -954,7 +954,7 @@ def read_file(reader):
"""
helper
=
LayerHelper
(
'read_file'
)
out
=
[
helper
.
create_
tmp_variabl
e
(
helper
.
create_
variable_for_type_inferenc
e
(
stop_gradient
=
True
,
dtype
=
'float32'
)
for
_
in
range
(
len
(
reader
.
desc
.
shapes
()))
]
...
...
python/paddle/fluid/layers/layer_function_generator.py
浏览文件 @
1c591c39
...
...
@@ -202,10 +202,12 @@ def generate_layer_fn(op_type):
out_var
=
out
[
0
]
if
(
isinstance
(
out
,
list
)
or
isinstance
(
out
,
tuple
))
else
out
else
:
out_var
=
helper
.
create_
tmp_variabl
e
(
dtype
=
dtype
)
out_var
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
dtype
)
outputs
[
o_name
]
=
[
out_var
]
for
name
in
intermediate_output_names
:
outputs
[
name
]
=
[
helper
.
create_tmp_variable
(
dtype
=
dtype
)]
outputs
[
name
]
=
[
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
]
helper
.
append_op
(
type
=
op_type
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
kwargs
)
return
helper
.
append_activation
(
out_var
)
...
...
@@ -229,7 +231,7 @@ def generate_layer_fn_noattr(op_type):
def
func
(
x
,
name
=
None
):
helper
=
LayerHelper
(
op_type
,
**
locals
())
output
=
helper
.
create_
tmp_variabl
e
(
dtype
=
x
.
dtype
)
output
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
op_type
,
inputs
=
{
"X"
:
x
},
outputs
=
{
"Out"
:
output
})
return
output
...
...
python/paddle/fluid/layers/metric_op.py
浏览文件 @
1c591c39
...
...
@@ -58,11 +58,11 @@ def accuracy(input, label, k=1, correct=None, total=None):
"""
helper
=
LayerHelper
(
"accuracy"
,
**
locals
())
topk_out
,
topk_indices
=
nn
.
topk
(
input
,
k
=
k
)
acc_out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
"float32"
)
acc_out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
"float32"
)
if
correct
is
None
:
correct
=
helper
.
create_
tmp_variabl
e
(
dtype
=
"int64"
)
correct
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
"int64"
)
if
total
is
None
:
total
=
helper
.
create_
tmp_variabl
e
(
dtype
=
"int64"
)
total
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
"int64"
)
helper
.
append_op
(
type
=
"accuracy"
,
inputs
=
{
...
...
@@ -124,8 +124,8 @@ def auc(input,
auc_out=fluid.layers.auc(input=prediction, label=label)
"""
helper
=
LayerHelper
(
"auc"
,
**
locals
())
auc_out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
"float64"
)
batch_auc_out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
"float64"
)
auc_out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
"float64"
)
batch_auc_out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
"float64"
)
# make tp, tn, fp, fn persistable, so that can accumulate all batches.
# for batch auc
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
1c591c39
此差异已折叠。
点击以展开。
python/paddle/fluid/layers/tensor.py
浏览文件 @
1c591c39
...
...
@@ -152,7 +152,7 @@ def cast(x, dtype):
result = fluid.layers.cast(x=data, dtype='float64')
"""
helper
=
LayerHelper
(
'cast'
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
dtype
)
helper
.
append_op
(
type
=
'cast'
,
inputs
=
{
'X'
:
[
x
]},
...
...
@@ -184,7 +184,7 @@ def concat(input, axis=0, name=None):
out = fluid.layers.concat(input=[Efirst, Esecond, Ethird, Efourth])
"""
helper
=
LayerHelper
(
'concat'
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
helper
.
input_dtype
())
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
helper
.
input_dtype
())
helper
.
append_op
(
type
=
'concat'
,
inputs
=
{
'X'
:
input
},
...
...
@@ -221,7 +221,8 @@ def sums(input, out=None):
"""
helper
=
LayerHelper
(
'sum'
,
**
locals
())
if
out
is
None
:
out
=
helper
.
create_tmp_variable
(
dtype
=
helper
.
input_dtype
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
helper
.
input_dtype
())
helper
.
append_op
(
type
=
'sum'
,
inputs
=
{
'X'
:
input
},
...
...
@@ -252,7 +253,7 @@ def assign(input, output=None):
"""
helper
=
LayerHelper
(
'assign'
,
**
locals
())
if
output
is
None
:
output
=
helper
.
create_
tmp_variabl
e
(
dtype
=
input
.
dtype
)
output
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
input
.
dtype
)
if
isinstance
(
input
,
Variable
):
helper
.
append_op
(
type
=
'assign'
,
inputs
=
{
'X'
:
[
input
]},
outputs
=
{
'Out'
:
[
output
]})
...
...
@@ -311,7 +312,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
helper
=
LayerHelper
(
"fill_constant"
,
**
locals
())
if
out
is
None
:
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
dtype
)
helper
.
append_op
(
type
=
'fill_constant'
,
inputs
=
{},
...
...
@@ -358,7 +359,7 @@ def fill_constant_batch_size_like(input,
${out_comment}.
"""
helper
=
LayerHelper
(
"fill_constant_batch_size_like"
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
dtype
)
helper
.
append_op
(
type
=
'fill_constant_batch_size_like'
,
inputs
=
{
'Input'
:
input
},
...
...
@@ -396,7 +397,7 @@ def argmin(x, axis=0):
out = fluid.layers.argmin(x=in, axis=-1)
"""
helper
=
LayerHelper
(
"arg_min"
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
VarDesc
.
VarType
.
INT64
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
VarDesc
.
VarType
.
INT64
)
helper
.
append_op
(
type
=
'arg_min'
,
inputs
=
{
'X'
:
x
},
...
...
@@ -427,7 +428,7 @@ def argmax(x, axis=0):
out = fluid.layers.argmax(x=in, axis=-1)
"""
helper
=
LayerHelper
(
"arg_max"
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
VarDesc
.
VarType
.
INT64
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
VarDesc
.
VarType
.
INT64
)
helper
.
append_op
(
type
=
'arg_max'
,
inputs
=
{
'X'
:
x
},
...
...
@@ -477,8 +478,10 @@ def argsort(input, axis=-1, name=None):
out, indices = fluid.layers.argsort(input, axis=0)
"""
helper
=
LayerHelper
(
"argsort"
,
**
locals
())
out
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
,
stop_gradient
=
True
)
ids
=
helper
.
create_tmp_variable
(
VarDesc
.
VarType
.
INT64
,
stop_gradient
=
True
)
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
,
stop_gradient
=
True
)
ids
=
helper
.
create_variable_for_type_inference
(
VarDesc
.
VarType
.
INT64
,
stop_gradient
=
True
)
helper
.
append_op
(
type
=
'argsort'
,
inputs
=
{
'X'
:
input
},
...
...
@@ -562,7 +565,7 @@ def reverse(x, axis):
if
isinstance
(
axis
,
int
):
axis
=
[
axis
]
helper
=
LayerHelper
(
"reverse"
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
'reverse'
,
inputs
=
{
'Input'
:
x
},
...
...
@@ -654,7 +657,7 @@ def has_inf(x):
Variable: The tensor variable storing the output, only a bool value.
"""
helper
=
LayerHelper
(
"isinf"
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
"isinf"
,
inputs
=
{
"X"
:
x
},
outputs
=
{
"Out"
:
out
})
return
out
...
...
@@ -670,7 +673,7 @@ def has_nan(x):
Variable: The tensor variable storing the output, only a bool value.
"""
helper
=
LayerHelper
(
"isnan"
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
"isnan"
,
inputs
=
{
"X"
:
x
},
outputs
=
{
"Out"
:
out
})
return
out
...
...
@@ -687,6 +690,6 @@ def isfinite(x):
Variable: The tensor variable storing the output, contains a bool value.
"""
helper
=
LayerHelper
(
"isfinite"
,
**
locals
())
out
=
helper
.
create_
tmp_variabl
e
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_
variable_for_type_inferenc
e
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
"isfinite"
,
inputs
=
{
"X"
:
x
},
outputs
=
{
"Out"
:
out
})
return
out
python/paddle/fluid/regularizer.py
浏览文件 @
1c591c39
...
...
@@ -151,7 +151,7 @@ class L2DecayRegularizer(WeightDecayRegularizer):
decay
=
block
.
create_var
(
dtype
=
"float32"
,
shape
=
param
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
)
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
)
block
.
append_op
(
type
=
'extract_rows'
,
inputs
=
{
'X'
:
grad
},
outputs
=
{
'Out'
:
idx
})
block
.
append_op
(
...
...
@@ -228,7 +228,7 @@ class L1DecayRegularizer(WeightDecayRegularizer):
decay
=
block
.
create_var
(
dtype
=
"float32"
,
shape
=
param
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
)
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
)
block
.
append_op
(
type
=
'extract_rows'
,
inputs
=
{
'X'
:
grad
},
outputs
=
{
'Out'
:
idx
})
block
.
append_op
(
...
...
python/paddle/fluid/tests/unittests/test_fusion_seqconv_eltadd_relu_op.py
0 → 100644
浏览文件 @
1c591c39
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
random
from
op_test
import
OpTest
from
test_seq_conv
import
seqconv
class
TestSeqConvEltAddRelu
(
OpTest
):
def
set_conf
(
self
):
pass
def
setUp
(
self
):
self
.
op_type
=
'fusion_seqconv_eltadd_relu'
self
.
lod
=
[[
6
,
4
]]
self
.
in_fea_size
=
16
self
.
out_fea_size
=
8
self
.
context_length
=
4
self
.
context_stride
=
1
self
.
context_start
=
0
self
.
set_conf
()
assert
self
.
context_stride
==
1
T
=
sum
(
self
.
lod
[
0
])
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
T
,
self
.
in_fea_size
]).
astype
(
'float32'
)
w
=
np
.
random
.
uniform
(
-
1
,
1
,
[
self
.
in_fea_size
*
self
.
context_length
,
self
.
out_fea_size
]).
astype
(
'float32'
)
b
=
np
.
random
.
uniform
(
-
2
,
1
,
[
1
,
self
.
out_fea_size
]).
astype
(
'float32'
)
out
=
seqconv
(
x
,
self
.
lod
,
w
,
self
.
context_length
,
self
.
context_start
)
out
=
np
.
maximum
(
out
+
b
,
0
)
self
.
inputs
=
{
'X'
:
(
x
,
self
.
lod
),
'Filter'
:
w
,
'Bias'
:
b
}
self
.
attrs
=
{
'contextStart'
:
self
.
context_start
,
'contextLength'
:
self
.
context_length
,
'contextStride'
:
self
.
context_stride
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestSeqConvEltAddReluBS1
(
TestSeqConvEltAddRelu
):
def
set_conf
(
self
):
self
.
lod
=
[[
10
]]
class
TestSeqConvEltAddReluBS1Case2
(
TestSeqConvEltAddRelu
):
def
set_conf
(
self
):
self
.
lod
=
[[
2
]]
class
TestSeqConvEltAddReluCase1
(
TestSeqConvEltAddRelu
):
def
set_conf
(
self
):
self
.
lod
=
[[
3
,
5
,
1
,
6
]]
self
.
context_length
=
3
self
.
context_start
=
-
2
class
TestSeqConvEltAddReluCase2
(
TestSeqConvEltAddRelu
):
def
set_conf
(
self
):
self
.
lod
=
[[
10
,
1
,
2
,
4
,
1
,
5
,
6
]]
self
.
in_fea_size
=
2
self
.
context_length
=
4
self
.
context_start
=
-
1
class
TestSeqConvEltAddReluCase3
(
TestSeqConvEltAddRelu
):
def
set_conf
(
self
):
self
.
lod
=
[[
10
,
1
,
2
,
4
,
1
,
5
,
6
]]
self
.
context_length
=
5
self
.
context_start
=
-
4
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_seq_conv.py
浏览文件 @
1c591c39
...
...
@@ -20,6 +20,53 @@ import random
from
op_test
import
OpTest
def
seqconv
(
x
,
lod
,
filter
,
context_length
,
context_start
,
padding_trainable
=
False
,
padding_data
=
None
):
[
T
,
M
]
=
x
.
shape
col
=
np
.
zeros
((
T
,
context_length
*
M
)).
astype
(
'float32'
)
offset
=
[
0
]
for
seq_len
in
lod
[
0
]:
offset
.
append
(
offset
[
-
1
]
+
seq_len
)
begin_pad
=
np
.
max
([
0
,
-
context_start
])
for
i
in
range
(
len
(
offset
)
-
1
):
for
j
in
range
(
context_length
):
in_begin
=
offset
[
i
]
+
context_start
+
j
in_end
=
offset
[
i
+
1
]
+
context_start
+
j
out_begin
=
offset
[
i
]
out_end
=
offset
[
i
+
1
]
if
in_begin
<
offset
[
i
]:
pad_size
=
np
.
min
(
[
offset
[
i
]
-
in_begin
,
offset
[
i
+
1
]
-
offset
[
i
]])
if
padding_trainable
:
sub_w
=
padding_data
[
j
:
j
+
pad_size
,
:]
col
[
offset
[
i
]:
offset
[
i
]
+
pad_size
,
j
*
M
:(
j
+
1
)
*
M
]
=
sub_w
out_begin
=
offset
[
i
]
+
pad_size
in_begin
=
offset
[
i
]
if
in_end
>
offset
[
i
+
1
]:
pad_size
=
np
.
min
(
[
in_end
-
offset
[
i
+
1
],
offset
[
i
+
1
]
-
offset
[
i
]])
if
padding_trainable
:
sub_w
=
padding_data
[
begin_pad
+
context_start
+
j
-
pad_size
:
begin_pad
+
context_start
+
j
,
:]
col
[
offset
[
i
+
1
]
-
pad_size
:
offset
[
i
+
1
],
j
*
M
:(
j
+
1
)
*
M
]
=
sub_w
in_end
=
offset
[
i
+
1
]
out_end
=
offset
[
i
+
1
]
-
pad_size
if
in_end
<=
in_begin
:
continue
in_sub
=
x
[
in_begin
:
in_end
,
:]
col
[
out_begin
:
out_end
,
j
*
M
:(
j
+
1
)
*
M
]
+=
in_sub
return
np
.
dot
(
col
,
filter
)
class
TestSeqProject
(
OpTest
):
def
setUp
(
self
):
self
.
init_test_case
()
...
...
@@ -66,57 +113,9 @@ class TestSeqProject(OpTest):
'paddingTrainable'
:
self
.
padding_trainable
,
'contextStride'
:
self
.
context_stride
}
out
=
np
.
zeros
(
(
self
.
input_size
[
0
],
self
.
output_represention
)).
astype
(
'float32'
)
out
=
seqconv
(
x
,
self
.
lod
,
w
,
self
.
context_length
,
self
.
context_start
,
self
.
padding_trainable
,
self
.
pad_data
)
self
.
outputs
=
{
'Out'
:
out
}
self
.
compute
()
def
compute
(
self
):
x
,
lod
=
self
.
inputs
[
'X'
]
filter
=
self
.
inputs
[
'Filter'
]
pading_data
=
self
.
pad_data
out
=
np
.
zeros
((
self
.
input_size
[
0
],
self
.
context_length
*
self
.
input_size
[
1
])).
astype
(
'float32'
)
offset
=
[
0
]
for
seq_len
in
lod
[
0
]:
offset
.
append
(
offset
[
-
1
]
+
seq_len
)
begin_pad
=
np
.
max
([
0
,
-
self
.
context_start
])
for
i
in
range
(
len
(
offset
)
-
1
):
for
j
in
range
(
self
.
context_length
):
in_begin
=
offset
[
i
]
+
self
.
context_start
+
j
in_end
=
offset
[
i
+
1
]
+
self
.
context_start
+
j
out_begin
=
offset
[
i
]
out_end
=
offset
[
i
+
1
]
if
in_begin
<
offset
[
i
]:
pad_size
=
np
.
min
(
[
offset
[
i
]
-
in_begin
,
offset
[
i
+
1
]
-
offset
[
i
]])
if
self
.
padding_trainable
:
sub_w
=
pading_data
[
j
:
j
+
pad_size
,
:]
out
[
offset
[
i
]:
offset
[
i
]
+
pad_size
,
j
*
self
.
input_size
[
1
]:(
j
+
1
)
*
self
.
input_size
[
1
]]
=
sub_w
out_begin
=
offset
[
i
]
+
pad_size
in_begin
=
offset
[
i
]
if
in_end
>
offset
[
i
+
1
]:
pad_size
=
np
.
min
(
[
in_end
-
offset
[
i
+
1
],
offset
[
i
+
1
]
-
offset
[
i
]])
if
self
.
padding_trainable
:
sub_w
=
pading_data
[
begin_pad
+
self
.
context_start
+
j
-
pad_size
:
begin_pad
+
self
.
context_start
+
j
,
:]
out
[
offset
[
i
+
1
]
-
pad_size
:
offset
[
i
+
1
],
j
*
self
.
input_size
[
1
]:(
j
+
1
)
*
self
.
input_size
[
1
]]
=
sub_w
in_end
=
offset
[
i
+
1
]
out_end
=
offset
[
i
+
1
]
-
pad_size
if
in_end
<=
in_begin
:
continue
in_sub
=
x
[
in_begin
:
in_end
,
:]
out
[
out_begin
:
out_end
,
j
*
self
.
input_size
[
1
]:(
j
+
1
)
*
self
.
input_size
[
1
]]
+=
in_sub
np
.
dot
(
out
,
filter
,
out
=
self
.
outputs
[
'Out'
])
def
test_check_output
(
self
):
self
.
check_output
()
...
...
python/paddle/fluid/tests/unittests/test_slice_var.py
浏览文件 @
1c591c39
...
...
@@ -30,7 +30,6 @@ class TestSliceVar(unittest.TestCase):
var
=
program
.
global_block
().
create_var
(
name
=
str
(
random
.
randint
(
10000
,
99999
)),
persistable
=
True
,
# dtype=core.VarDesc.VarType.LOD_TENSOR,
shape
=
shape
)
var_list
.
append
(
var
)
blocks
=
slice_variable
(
var_list
,
10
,
min_size
)
...
...
python/paddle/fluid/transpiler/inference_transpiler.py
浏览文件 @
1c591c39
...
...
@@ -74,7 +74,7 @@ class InferenceTranspiler(object):
'''
Transpile the program fusing elementwise_add into conv for MKLDNN
program. Elementwise add following convolution OP can be fused by adding
'fuse_
eltwise
' attribute to convolution OP and replacing its output
'fuse_
residual_connection
' attribute to convolution OP and replacing its output
Tensor with second parameter of elementwise_add.
The result of fuse is:
- before:
...
...
@@ -92,7 +92,8 @@ class InferenceTranspiler(object):
if
current_op
.
type
in
[
'conv2d'
]:
next_op
=
self
.
block
.
ops
[
i
+
1
]
if
next_op
.
type
==
'elementwise_add'
:
self
.
_fuse_conv_eltwise
(
current_op
,
next_op
)
self
.
_fuse_conv_eltwise
(
i
,
current_op
,
next_op
)
self
.
block
.
_remove_op
(
i
+
1
)
# Remove old conv
self
.
block
.
_remove_op
(
i
+
1
)
# Remove elementwise_add
i
=
i
+
1
self
.
_adjust_input
()
...
...
@@ -444,7 +445,7 @@ class InferenceTranspiler(object):
outputs
=
{
"Output"
:
out_var
},
attrs
=
attrs
)
def
_fuse_conv_eltwise
(
self
,
conv_op
,
eltwise_op
):
def
_fuse_conv_eltwise
(
self
,
index
,
conv_op
,
eltwise_op
):
'''
fuse the conv op with elementwise_add
...
...
@@ -454,9 +455,30 @@ class InferenceTranspiler(object):
:type eltwise_op: Operator
'''
conv_op
.
_set_attr
(
"fuse_eltwise"
,
True
)
self
.
input_map
[
conv_op
.
output
(
"Output"
)[
0
]]
=
eltwise_op
.
input
(
"Y"
)[
0
]
self
.
input_map
[
eltwise_op
.
output
(
"Out"
)[
0
]]
=
eltwise_op
.
input
(
"Y"
)[
0
]
eltwise_input
=
"X"
if
eltwise_op
.
input
(
"X"
)[
0
]
==
conv_op
.
output
(
"Output"
)[
0
]:
eltwise_input
=
"Y"
residual_var
=
self
.
block
.
vars
[
eltwise_op
.
input
(
eltwise_input
)[
0
]]
out_var
=
self
.
block
.
vars
[
eltwise_op
.
output
(
"Out"
)[
0
]]
filter_var
=
self
.
block
.
vars
[
conv_op
.
input
(
"Filter"
)[
0
]]
in_var
=
self
.
block
.
vars
[
conv_op
.
input
(
"Input"
)[
0
]]
bias_var
=
self
.
block
.
vars
[
conv_op
.
input
(
"Bias"
)[
0
]]
conv_op
.
_set_attr
(
"fuse_residual_connection"
,
True
)
attrs
=
{
name
:
conv_op
.
attr
(
name
)
for
name
in
conv_op
.
attr_names
}
self
.
block
.
_insert_op
(
index
,
type
=
"conv2d"
,
inputs
=
{
"Input"
:
in_var
,
"Filter"
:
filter_var
,
"Bias"
:
bias_var
,
"ResidualData"
:
residual_var
},
outputs
=
{
"Output"
:
out_var
},
attrs
=
attrs
)
def
_adjust_input
(
self
):
for
i
in
range
(
len
(
self
.
block
.
ops
)):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录