Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)
Commit d4e8d707
Authored Nov 05, 2018 by dongzhihong

    Merge remote-tracking branch 'origin/develop' into fix/sequence_pad

    test=develop

Parents: 7f3c6ea4, daed473d
Showing 42 changed files with 1836 additions and 120 deletions (+1836, −120)
paddle/fluid/API.spec  +4 −3
paddle/fluid/framework/details/CMakeLists.txt  +4 −2
paddle/fluid/framework/details/build_strategy.cc  +11 −0
paddle/fluid/framework/details/build_strategy.h  +2 −0
paddle/fluid/framework/details/sequential_execution_pass.cc  +109 −0
paddle/fluid/framework/details/sequential_execution_pass.h  +34 −0
paddle/fluid/framework/ir/CMakeLists.txt  +2 −0
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h  +2 −1
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc  +3 −0
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.cc  +58 −0
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h  +34 −0
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass_tester.cc  +123 −0
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc  +3 −0
paddle/fluid/framework/ir/graph.cc  +54 −0
paddle/fluid/inference/analysis/analyzer.h  +1 −0
paddle/fluid/inference/analysis/data_flow_graph_tester.cc  +3 −0
paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc  +132 −0
paddle/fluid/operators/grid_sampler_op.cc  +203 −0
paddle/fluid/operators/grid_sampler_op.h  +322 −0
paddle/fluid/operators/math/pooling.cc  +14 −8
paddle/fluid/operators/math/pooling.cu  +30 −25
paddle/fluid/operators/math/pooling.h  +4 −4
paddle/fluid/operators/pool_cudnn_op.cu.cc  +6 −2
paddle/fluid/operators/pool_op.cc  +29 −0
paddle/fluid/operators/pool_op.h  +8 −6
paddle/fluid/operators/softmax_with_cross_entropy_op.cc  +6 −0
paddle/fluid/operators/softmax_with_cross_entropy_op.cu  +166 −21
paddle/fluid/operators/spp_op.h  +5 −3
paddle/fluid/platform/cudnn_helper.h  +8 −3
paddle/fluid/pybind/pybind.cc  +7 −0
python/paddle/fluid/layers/control_flow.py  +18 −2
python/paddle/fluid/layers/nn.py  +174 −21
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py  +3 −1
python/paddle/fluid/tests/unittests/test_dist_save_load.py  +1 −0
python/paddle/fluid/tests/unittests/test_grid_sampler_op.py  +123 −0
python/paddle/fluid/tests/unittests/test_layers.py  +9 −0
python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py  +40 −0
python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py  +2 −0
python/paddle/fluid/tests/unittests/test_pool2d_op.py  +27 −8
python/paddle/fluid/tests/unittests/test_pool3d_op.py  +28 −8
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py  +23 −1
python/setup.py.in  +1 −1
paddle/fluid/API.spec

@@ -67,8 +67,8 @@ paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size',
 paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type', 'is_test'], varargs=None, keywords=None, defaults=(False,))
 paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None))
 paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
-paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
-paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
+paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
+paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
 paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
 paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
@@ -103,7 +103,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
 paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
-paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
+paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode'], varargs=None, keywords=None, defaults=(False, -100, False))
 paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
@@ -178,6 +178,7 @@ paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], var
 paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
 paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
+paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
 paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
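For context, a hedged sketch of how the three API changes recorded above surface in user code. The layer and argument names come straight from the ArgSpec lines; the surrounding variables (data, pooled, logits, label) and shapes are illustrative only, not part of this commit:

import paddle.fluid as fluid

data = fluid.layers.data(name='x', shape=[3, 32, 32], dtype='float32')

# New 'exclusive' flag on pool2d/pool3d: when False, average pooling divides
# by the full kernel area (ksize_h * ksize_w) instead of only the in-bounds
# window, which matters for windows clipped at padded borders.
pooled = fluid.layers.pool2d(input=data, pool_size=2, pool_type='avg',
                             pool_padding=1, exclusive=False)

# New 'numeric_stable_mode' flag on softmax_with_cross_entropy
# (defaults to False per the ArgSpec above).
logits = fluid.layers.fc(input=pooled, size=10)
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
loss = fluid.layers.softmax_with_cross_entropy(logits=logits, label=label,
                                               numeric_stable_mode=True)

# New grid_sampler layer: bilinear sampling of 'x' at 'grid' locations.
grid = fluid.layers.data(name='grid', shape=[32, 32, 2], dtype='float32')
sampled = fluid.layers.grid_sampler(x=data, grid=grid)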
paddle/fluid/framework/details/CMakeLists.txt

@@ -35,13 +35,15 @@ if(WITH_GPU)
     all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
 endif()

+cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
+
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
     scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)

 if(WITH_GPU)
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
+  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass
+    sequential_execution_pass)
 else()
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
+  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
+    sequential_execution_pass)
 endif()

 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
paddle/fluid/framework/details/build_strategy.cc

@@ -16,6 +16,7 @@ limitations under the License. */

 #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
+#include "paddle/fluid/framework/details/sequential_execution_pass.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
@@ -27,6 +28,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
  public:
   explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
       : ir::PassBuilder(), strategy_(strategy) {
+    if (strategy_.enable_sequential_execution_) {
+      AppendPass("sequential_execution_pass");
+    }
+
     // Add a graph viz pass to record a graph.
     if (!strategy_.debug_graphviz_path_.empty()) {
       auto viz_pass = AppendPass("graph_viz_pass");
@@ -110,6 +115,11 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       pass->Erase("nccl_ctxs");
       pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
+    } else if (pass->Type() == "sequential_execution_pass") {
+      pass->Erase(kAllOpDescs);
+      pass->Set<const std::vector<OpDesc *>>(
+          kAllOpDescs,
+          new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
     }
     graph = pass->Apply(std::move(graph));
   }
@@ -125,3 +135,4 @@ USE_PASS(multi_batch_merge_pass);
 USE_PASS(multi_devices_pass);
 USE_PASS(multi_devices_check_pass);
 USE_PASS(multi_devices_print_pass);
+USE_PASS(sequential_execution_pass);
paddle/fluid/framework/details/build_strategy.h

@@ -69,6 +69,8 @@ struct BuildStrategy {
   bool enable_data_balance_{false};

+  bool enable_sequential_execution_{false};
+
   bool fuse_broadcast_op_{false};

   // User normally doesn't need to call this API.
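The new enable_sequential_execution_ flag above is consumed by ParallelExecutorPassBuilder and BuildStrategy::Apply in build_strategy.cc. Assuming the pybind.cc change in this commit exposes it to Python (that hunk is not expanded here), enabling it would look roughly like the sketch below; the exact Python attribute spelling is an assumption mirroring the C++ member name:

import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
# assumed Python spelling of the C++ member enable_sequential_execution_
build_strategy.enable_sequential_execution = True

exe = fluid.ParallelExecutor(use_cuda=True,
                             loss_name=loss.name,  # 'loss' from the user model
                             build_strategy=build_strategy)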
paddle/fluid/framework/details/sequential_execution_pass.cc (new file, mode 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"

namespace paddle {
namespace framework {
namespace details {

static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
  return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
         op1->Outputs() == op2->Outputs();
}

std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  // FIXME(zjl): Insert dependencies between some distributed ops may cause
  // the multi_devices_graph_pass fails. So we skip these ops here.
  // Indeed, maybe we should not insert dependencies between these ops
  // casually, which may cause deadlock easily.
  // We should add more skipped distributed ops when found errors in
  // multi_devices_graph_pass
  static std::unordered_set<std::string> skip_dist_ops{
      "send", "recv", "send_barrier", "fetch_barrier"};

  auto &ops = Get<const std::vector<OpDesc *>>(kAllOpDescs);
  std::vector<ir::Node *> op_node_list;
  op_node_list.reserve(ops.size());

  std::unordered_map<ir::Node *, size_t> op_deps;
  std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
  std::unordered_set<ir::Node *> ready_ops;

  for (ir::Node *node : graph->Nodes()) {
    if (!node->IsOp()) continue;
    std::unordered_set<ir::Node *> preceding_ops;
    for (auto *in : node->inputs) {
      PADDLE_ENFORCE(in->IsVar(),
                     "Preceding Node of Op Nodes must be Var Node");
      if (in->inputs.empty()) continue;
      PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
                     "Preceding Op Node of Var Node must be unique");
      preceding_ops.insert(in->inputs[0]);
      pending_ops[in->inputs[0]].insert(node);
    }
    op_deps[node] = preceding_ops.size();
    if (preceding_ops.empty()) {
      ready_ops.insert(node);
    }
  }

  for (auto *op_desc : ops) {
    ir::Node *found_node = nullptr;
    for (auto *node : ready_ops) {
      if (IsSameOpDesc(op_desc, node->Op())) {
        PADDLE_ENFORCE(found_node == nullptr,
                       "Found multiple op_desc in graph: %s", op_desc->Type());
        found_node = node;
      }
    }

    PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
                            op_desc->Type());
    for (auto *pending_op : pending_ops[found_node]) {
      if (--op_deps.at(pending_op) == 0) {
        ready_ops.insert(pending_op);
      }
    }
    ready_ops.erase(found_node);
    if (skip_dist_ops.count(op_desc->Type()) == 0) {
      op_node_list.push_back(found_node);
    }
  }

  for (size_t i = 1; i < op_node_list.size(); ++i) {
    auto *dep_var = graph->CreateControlDepVar();
    op_node_list[i]->inputs.push_back(dep_var);
    op_node_list[i - 1]->outputs.push_back(dep_var);
    dep_var->outputs.push_back(op_node_list[i]);
    dep_var->inputs.push_back(op_node_list[i - 1]);
    VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
             << " and " << op_node_list[i]->Name();
  }
  return graph;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(sequential_execution_pass,
              paddle::framework::details::SequentialExecutionPass)
    .RequirePassAttr(paddle::framework::details::kAllOpDescs);
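The pass above does three things: it counts graph predecessors for every op node (Kahn-style), walks the OpDescs in program order and matches each against the current ready set (the real matching uses IsSameOpDesc over type, inputs, and outputs), and finally chains consecutive matched nodes with control-dependency variables so the executor must run them in program order. A minimal Python sketch of the same scheduling idea, using toy node objects rather than Paddle APIs and matching by name only:

from dataclasses import dataclass, field

@dataclass(eq=False)
class Node:
    name: str
    preds: set = field(default_factory=set)
    succs: set = field(default_factory=set)

def sequentialize(nodes, program_order):
    """Return the control-dependency edges the pass would insert."""
    deps = {n: len(n.preds) for n in nodes}          # op_deps
    ready = {n for n in nodes if deps[n] == 0}       # ready_ops
    chain = []
    for name in program_order:                       # OpDesc order
        found = next(n for n in ready if n.name == name)
        for succ in found.succs:                     # release pending ops
            deps[succ] -= 1
            if deps[succ] == 0:
                ready.add(succ)
        ready.remove(found)
        chain.append(found)
    # one control-dep edge between each consecutive pair
    return [(a.name, b.name) for a, b in zip(chain, chain[1:])]

a, b, c = Node("a"), Node("b"), Node("c")
b.preds, a.succs = {a}, {b}                          # a -> b; c independent
print(sequentialize([a, b, c], ["a", "c", "b"]))     # [('a', 'c'), ('c', 'b')]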
paddle/fluid/framework/details/sequential_execution_pass.h (new file, mode 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {

constexpr char kAllOpDescs[] = "all_op_descs";

class SequentialExecutionPass : public ir::Pass {
 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/CMakeLists.txt

@@ -41,6 +41,7 @@ pass_library(conv_bn_fuse_pass inference)
 pass_library(seqconv_eltadd_relu_fuse_pass inference)
 if(WITH_MKLDNN)
   pass_library(mkldnn_placement_pass base)
+  pass_library(depthwise_conv_mkldnn_pass base)
   pass_library(conv_bias_mkldnn_fuse_pass inference)
   pass_library(conv_relu_mkldnn_fuse_pass inference)
   pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
@@ -59,6 +60,7 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
 cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
 cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
 if (WITH_MKLDNN)
+  cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
   cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
   cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
 endif()
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h

@@ -31,7 +31,8 @@ class ConvReLUFusePass : public FusePassBase {
   virtual ~ConvReLUFusePass() {}

  protected:
-  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+  std::unique_ptr<ir::Graph> ApplyImpl(
+      std::unique_ptr<ir::Graph> graph) const override;
 };

 }  // namespace ir
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc

@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"

 #include <gtest/gtest.h>
+#include "paddle/fluid/framework/op_proto_maker.h"

 namespace paddle {
 namespace framework {
@@ -36,6 +37,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     op->SetInput("X", inputs);
   }
   op->SetOutput("Out", outputs);
+  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(OpRole::kForward));
 }

 // a->OP0->b
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.cc (new file, mode 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

#define GET_NODE(id, pattern)                               \
  PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \
                 "pattern has no Node called %s", #id);     \
  auto* id = subgraph.at(pattern.RetrieveNode(#id));        \
  PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);

std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init("depthwise_conv_mkldnn_pass", graph.get());
  GraphPatternDetector gpd;

  auto* pattern = gpd.mutable_pattern();
  pattern->NewNode("depthwise_conv")
      ->assert_is_op("depthwise_conv2d")
      ->assert_op_attr("use_mkldnn", true);

  int found_depthwise_conv_mkldnn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(3) << "handle DepthwiseConvMKLDNN fuse";
    GET_NODE(depthwise_conv, (*pattern));
    depthwise_conv->Op()->SetType("conv2d");
    found_depthwise_conv_mkldnn_count++;
  };

  gpd(graph.get(), handler);
  AddStatis(found_depthwise_conv_mkldnn_count);
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(depthwise_conv_mkldnn_pass,
              paddle::framework::ir::DepthwiseConvMKLDNNPass);
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h (new file, mode 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

class DepthwiseConvMKLDNNPass : public FusePassBase {
 public:
  virtual ~DepthwiseConvMKLDNNPass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass_tester.cc (new file, mode 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h"

#include <gtest/gtest.h>

namespace paddle {
namespace framework {
namespace ir {

void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs, bool use_mkldnn = false) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  op->SetAttr("use_mkldnn", use_mkldnn);
  op->SetAttr("name", name);
  op->SetInput("Input", {inputs[0]});
  op->SetInput("Filter", {inputs[1]});
  op->SetInput("Bias", {inputs[2]});
  op->SetOutput("Out", outputs);
}

// (a, weights, bias)->depthwise conv mkldnn->b
// (b, weights2, bias2)->depthwise conv no mkldnn->c
// (c, weights3, bias3)->conv mkldnn->d
// (d, weights3, bias3)->conv no mkldnn->e
ProgramDesc BuildProgramDesc() {
  ProgramDesc prog;
  for (auto& v : std::vector<std::string>(
           {"a", "b", "c", "d", "e", "weights", "bias", "weights2", "bias2",
            "weights3", "bias3", "weights4", "bias4"})) {
    auto* var = prog.MutableBlock(0)->Var(v);
    var->SetType(proto::VarType::SELECTED_ROWS);
    if (v == "weights" || v == "bias" || v == "weights2" || v == "bias2" ||
        v == "weights3" || v == "bias3" || v == "weights4" || v == "bias4") {
      var->SetPersistable(true);
    }
  }

  // depthwise conv with MKL-DNN
  SetOp(&prog, "depthwise_conv2d", "conv1",
        std::vector<std::string>({"a", "weights", "bias"}),
        std::vector<std::string>({"b"}), true);
  // depthwise conv without MKL-DNN
  SetOp(&prog, "depthwise_conv2d", "conv2",
        std::vector<std::string>({"b", "weights2", "bias2"}),
        std::vector<std::string>({"c"}), false);
  // conv with MKL-DNN
  SetOp(&prog, "conv2d", "conv3",
        std::vector<std::string>({"c", "weights3", "bias3"}),
        std::vector<std::string>({"d"}), true);
  // conv without MKL-DNN
  SetOp(&prog, "conv2d", "conv4",
        std::vector<std::string>({"d", "weights4", "bias4"}),
        std::vector<std::string>({"e"}), false);

  return prog;
}

TEST(DepthwiseConvMKLDNNPass, basic) {
  auto prog = BuildProgramDesc();

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  auto pass = PassRegistry::Instance().Get("depthwise_conv_mkldnn_pass");

  struct counters {
    int mkldnn_depthwise_conv_nodes;
    int other_depthwise_conv_nodes;
    int mkldnn_conv_nodes;
    int other_conv_nodes;
  };

  counters before{1, 1, 1, 1};

  graph = pass->Apply(std::move(graph));

  // initialize counters before loop
  counters after{0, 0, 0, 0};

  for (auto* node : graph->Nodes()) {
    if (node->IsOp()) {
      auto* op = node->Op();
      if (op->Type() == "conv2d") {
        if (boost::get<bool>(op->GetAttr("use_mkldnn")))
          after.mkldnn_conv_nodes++;
        else
          after.other_conv_nodes++;
      } else if (op->Type() == "depthwise_conv2d") {
        if (boost::get<bool>(op->GetAttr("use_mkldnn")))
          after.mkldnn_depthwise_conv_nodes++;
        else
          after.other_depthwise_conv_nodes++;
      }
    }
  }

  EXPECT_EQ(after.other_depthwise_conv_nodes,
            before.other_depthwise_conv_nodes);
  EXPECT_EQ(after.other_conv_nodes, before.other_conv_nodes);
  EXPECT_EQ(after.mkldnn_depthwise_conv_nodes,
            before.mkldnn_depthwise_conv_nodes - 1);
  EXPECT_EQ(after.mkldnn_conv_nodes, before.mkldnn_conv_nodes + 1);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(depthwise_conv_mkldnn_pass);
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc

@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/fc_fuse_pass.h"

 #include <gtest/gtest.h>
+#include "paddle/fluid/framework/op_proto_maker.h"

 namespace paddle {
 namespace framework {
@@ -32,6 +33,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
     op->SetInput("X", inputs);
   }
   op->SetOutput("Out", outputs);
+  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(OpRole::kForward));
 }

 // a->OP0->b
paddle/fluid/framework/ir/graph.cc

@@ -23,8 +23,62 @@ limitations under the License. */

 namespace paddle {
 namespace framework {
 namespace ir {
+namespace {
+
+void CheckProgram(const ProgramDesc &program) {
+  std::map<int, bool> visit;
+#define _INT(role) static_cast<int>(role)
+
+  for (size_t i = 0; i < program.Size(); ++i) {
+    for (OpDesc *op : program.Block(i).AllOps()) {
+      // For backward compatibility, some program doesn't have role added.
+      if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
+      int role_id = boost::get<int>(
+          op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
+      visit[role_id] = true;
+      switch (role_id) {
+        case _INT(OpRole::kForward):
+          PADDLE_ENFORCE(
+              visit.find(_INT(OpRole::kBackward)) == visit.end(),
+              "Cannot add forward operator before backward operator.");
+          break;
+        case _INT(OpRole::kBackward):
+        case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
+          PADDLE_ENFORCE(
+              visit.find(_INT(OpRole::kOptimize)) == visit.end(),
+              "Cannot add backward operator before optimize operator.");
+          break;
+        case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
+          PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
+                                    _INT(OpRole::kLoss)) == visit.end(),
+                         "Cannot add backward|loss operator before "
+                         "forward|loss operator.");
+          PADDLE_ENFORCE(
+              visit.find(_INT(OpRole::kOptimize)) == visit.end(),
+              "Cannot add backward operator before optimize operator.");
+          break;
+        case _INT(OpRole::kOptimize):
+        case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
+          PADDLE_ENFORCE(
+              visit.find(_INT(OpRole::kBackward)) != visit.end(),
+              "Optimize operators must follow backward operator.");
+          break;
+        case _INT(OpRole::kLRSched):
+        case _INT(OpRole::kDist):
+        case _INT(OpRole::kRPC):
+        case _INT(OpRole::kNotSpecified):
+          break;
+        default:
+          LOG(FATAL) << "Unknown operator role. Don't add new role because "
+                        "you don't know what you are doing.";
+      }
+    }
+  }
+#undef _INT
+}
+
+}  // namespace

 Graph::Graph(const ProgramDesc &program) : program_(program) {
+  CheckProgram(program_);
   // Make the nodes id start from 0.
   Node::ResetId();
   auto var_nodes = InitFromProgram(program_);
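CheckProgram enforces a coarse ordering invariant over op roles: once a backward op has been seen, no more forward ops may appear; once an optimize op has been seen, no more backward ops; and optimize ops require that some backward op came first. A condensed Python sketch of that check (role ids are illustrative, and the kLoss/kLRSched bit combinations handled above are omitted):

FORWARD, BACKWARD, OPTIMIZE = 0, 1, 2   # illustrative role ids

def check_program(op_roles):
    seen = set()
    for role in op_roles:
        if role == FORWARD:
            assert BACKWARD not in seen, \
                "Cannot add forward operator before backward operator."
        elif role == BACKWARD:
            assert OPTIMIZE not in seen, \
                "Cannot add backward operator before optimize operator."
        elif role == OPTIMIZE:
            assert BACKWARD in seen, \
                "Optimize operators must follow backward operator."
        seen.add(role)

check_program([FORWARD, FORWARD, BACKWARD, OPTIMIZE])   # passes
# check_program([BACKWARD, FORWARD]) would raise AssertionError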
paddle/fluid/inference/analysis/analyzer.h

@@ -79,6 +79,7 @@ class Analyzer : public OrderedRegistry<PassManager> {
       "conv_bn_fuse_pass",                      //
       "conv_eltwiseadd_bn_fuse_pass",           //
 #ifdef PADDLE_WITH_MKLDNN
+      "depthwise_conv_mkldnn_pass",             //
       "conv_bias_mkldnn_fuse_pass",             //
       "conv_relu_mkldnn_fuse_pass",             //
       "conv_elementwise_add_mkldnn_fuse_pass",  //
paddle/fluid/inference/analysis/data_flow_graph_tester.cc

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/inference/analysis/data_flow_graph.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/inference/analysis/ut_helper.h"
@@ -130,6 +131,8 @@ void SetOp(framework::ProgramDesc* prog, const std::string& type,
   op->SetType(type);
   op->SetInput("Xs", inputs);
   op->SetOutput("Xs", outputs);
+  op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(framework::OpRole::kForward));
 }

 TEST(DataFlowGraph, Build_IR_Graph) {
paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc (new file, mode 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using ScopedSpatialTransformerDescriptor =
    platform::ScopedSpatialTransformerDescriptor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;

template <typename T>
class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace");
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();
    auto* input = ctx.Input<Tensor>("X");
    auto* grid = ctx.Input<Tensor>("Grid");
    auto* output = ctx.Output<Tensor>("Output");

    int n = input->dims()[0];
    int c = input->dims()[1];
    int h = input->dims()[2];
    int w = input->dims()[3];
    const int size[4] = {n, c, h, w};

    const T* input_data = input->data<T>();
    const T* grid_data = grid->data<T>();
    T* output_data = output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());

    ScopedSpatialTransformerDescriptor st_desc;
    cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
        st_desc.descriptor<T>(4, size);

    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor output_desc;
    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        DataLayout::kNCHW, framework::vectorize2int(input->dims()));
    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
        DataLayout::kNCHW, framework::vectorize2int(output->dims()));

    CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerForward(
        handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc,
        input_data, grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc,
        output_data));
  }
};

template <typename T>
class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace");
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();
    auto* input = ctx.Input<Tensor>("X");
    auto* grid = ctx.Input<Tensor>("Grid");
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));

    auto output_grad_dims = output_grad->dims();
    const int n = output_grad_dims[0];
    const int c = output_grad_dims[1];
    const int h = output_grad_dims[2];
    const int w = output_grad_dims[3];
    const int size[4] = {n, c, h, w};

    ScopedSpatialTransformerDescriptor st_dest;
    cudnnSpatialTransformerDescriptor_t cudnn_st_dest =
        st_dest.descriptor<T>(4, size);

    const T* input_data = input->data<T>();
    const T* grid_data = grid->data<T>();
    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data =
        input_grad->mutable_data<T>(output_grad_dims, ctx.GetPlace());
    T* grid_grad_data =
        grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());

    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor input_grad_desc;
    ScopedTensorDescriptor output_grad_desc;
    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        DataLayout::kNCHW, framework::vectorize2int(input->dims()));
    cudnnTensorDescriptor_t cudnn_input_grad_desc =
        input_grad_desc.descriptor<T>(
            DataLayout::kNCHW, framework::vectorize2int(input_grad->dims()));
    cudnnTensorDescriptor_t cudnn_output_grad_desc =
        output_grad_desc.descriptor<T>(
            DataLayout::kNCHW, framework::vectorize2int(output_grad->dims()));

    CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerBackward(
        handle, cudnn_st_dest, CudnnDataType<T>::kOne(), cudnn_input_desc,
        input_data, CudnnDataType<T>::kZero(), cudnn_input_grad_desc,
        input_grad_data, CudnnDataType<T>::kOne(), cudnn_output_grad_desc,
        output_grad_data, grid_data, CudnnDataType<T>::kZero(),
        grid_grad_data));
  }
};

}  // namespace operators
}  // namespace paddle

namespace plat = paddle::platform;
REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNGridSampleOpKernel<float>,
                   paddle::operators::CUDNNGridSampleOpKernel<double>);
REGISTER_OP_KERNEL(grid_sampler_grad, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNGridSampleGradOpKernel<float>,
                   paddle::operators::CUDNNGridSampleGradOpKernel<double>);
paddle/fluid/operators/grid_sampler_op.cc (new file, mode 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class GridSampleOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of GridSampleOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Grid"),
                   "Input(Grid) of GridSampleOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Output"),
                   "Output(Output) of GridSampleOp should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto grid_dims = ctx->GetInputDim("Grid");
    PADDLE_ENFORCE(x_dims.size() == 4,
                   "Input(X) of GridSampleOp should be 4-D Tensor.");
    PADDLE_ENFORCE(grid_dims.size() == 4,
                   "Input(Grid) of GridSampleOp should be 4-D Tensor.");
    PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
    PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
                      "Input(X) and Input(Grid) dims[0] should be equal.");
    PADDLE_ENFORCE_EQ(
        grid_dims[1], x_dims[2],
        "Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
    PADDLE_ENFORCE_EQ(
        grid_dims[2], x_dims[3],
        "Input(X) dims[3] and Input(Grid) dims[2] should be equal.");

    ctx->SetOutputDim("Output", x_dims);
    ctx->ShareLoD("X", "Output");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
    if (platform::CanCUDNNBeUsed(ctx)) {
      library_ = framework::LibraryType::kCUDNN;
    }
#endif
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
        framework::DataLayout::kAnyLayout, library_);
  }
};

class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) The input data of GridSampleOp, "
             "This is a 4-D tensor with shape of [N, C, H, W]");
    AddInput(
        "Grid",
        "(Tensor) The input grid of GridSampleOp generated by AffineGridOp, "
        "This is a 4-D tensor with shape of [N, H, W, 2] is the concatenation "
        "of x and y coordinates with shape [N, H, W] in last dimension");
    AddOutput("Output", "(Tensor) Output tensor with shape [N, C, H, W]");
    AddAttr<bool>(
        "use_cudnn",
        "(bool, default true) Only used in cudnn kernel, need install cudnn")
        .SetDefault(true);

    AddComment(R"DOC(
      This operation samples input X by using bilinear interpolation based on
      flow field grid, which is usually generated by affine_grid. The grid of
      shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
      with shape [N, H, W] each, where grid_x is indexing the 4th dimension
      (in width dimension) of input data x and grid_y is indexing the 3rd
      dimension (in height dimension); the final result is the bilinear
      interpolation value of the 4 nearest corner points.

      Step 1:
        Get (x, y) grid coordinates and scale to [0, H-1/W-1].

        grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
        grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)

      Step 2:
        Index input data X with grid (x, y) in each [H, W] area, and
        bilinearly interpolate the point value from the 4 nearest points.

          wn ------- y_n ------- en
          |           |           |
          |          d_n          |
          |           |           |
         x_w --d_w-- grid --d_e-- x_e
          |           |           |
          |          d_s          |
          |           |           |
          ws ------- y_s ------- es

        x_w = floor(x)          // west side x coord
        x_e = x_w + 1           // east side x coord
        y_n = floor(y)          // north side y coord
        y_s = y_n + 1           // south side y coord

        d_w = grid_x - x_w      // distance to west side
        d_e = x_e - grid_x      // distance to east side
        d_n = grid_y - y_n      // distance to north side
        d_s = y_s - grid_y      // distance to south side

        wn = X[:, :, y_n, x_w]  // north-west point value
        en = X[:, :, y_n, x_e]  // north-east point value
        ws = X[:, :, y_s, x_w]  // south-west point value
        es = X[:, :, y_s, x_e]  // south-east point value

        output = wn * d_e * d_s + en * d_w * d_s
               + ws * d_e * d_n + es * d_w * d_n
)DOC");
  }
};

class GridSampleOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    auto input_dims = ctx->GetInputDim("X");
    auto grid_dims = ctx->GetInputDim("Grid");
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
    }
    if (ctx->HasOutput(framework::GradVarName("Grid"))) {
      ctx->SetOutputDim(framework::GradVarName("Grid"), grid_dims);
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
    if (platform::CanCUDNNBeUsed(ctx)) {
      library_ = framework::LibraryType::kCUDNN;
    }
#endif
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
        framework::DataLayout::kAnyLayout, library_);
  }
};

class GridSampleGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto* op = new framework::OpDesc();
    op->SetType("grid_sampler_grad");
    op->SetInput("X", Input("X"));
    op->SetInput("Grid", Input("Grid"));
    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));

    op->SetAttrMap(Attrs());

    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetOutput(framework::GradVarName("Grid"), InputGrad("Grid"));
    return std::unique_ptr<framework::OpDesc>(op);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker,
                  ops::GridSampleGradMaker);
REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad);

REGISTER_OP_CPU_KERNEL(
    grid_sampler,
    ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    grid_sampler_grad,
    ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
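To make the DOC formulas above concrete, here is a worked single-point instance in plain NumPy (single batch and channel; the input values and grid point are arbitrary):

import numpy as np

X = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)  # [N, C, H, W]
gx, gy = 0.25, -0.5            # one normalized grid point in [-1, 1]

H, W = 4, 4
grid_x = 0.5 * (gx + 1) * (W - 1)   # = 1.875
grid_y = 0.5 * (gy + 1) * (H - 1)   # = 0.75

x_w, y_n = int(np.floor(grid_x)), int(np.floor(grid_y))   # 1, 0
x_e, y_s = x_w + 1, y_n + 1                               # 2, 1
d_w, d_e = grid_x - x_w, x_e - grid_x                     # 0.875, 0.125
d_n, d_s = grid_y - y_n, y_s - grid_y                     # 0.75, 0.25

wn, en = X[0, 0, y_n, x_w], X[0, 0, y_n, x_e]   # 1.0, 2.0
ws, es = X[0, 0, y_s, x_w], X[0, 0, y_s, x_e]   # 5.0, 6.0
out = wn * d_e * d_s + en * d_w * d_s + ws * d_e * d_n + es * d_w * d_n
print(out)   # 4.875, the bilinear value at (x=1.875, y=0.75)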
paddle/fluid/operators/grid_sampler_op.h
0 → 100644
浏览文件 @
d4e8d707
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
using
Array3
=
Eigen
::
DSizes
<
int64_t
,
3
>
;
using
Array4
=
Eigen
::
DSizes
<
int64_t
,
4
>
;
template
<
typename
T
>
static
inline
bool
isInBound
(
T
x
,
T
y
,
T
x_max
,
T
y_max
)
{
if
(
x
<
0
||
x
>
x_max
||
y
<
0
||
y
>
y_max
)
{
return
false
;
}
return
true
;
}
template
<
typename
T
>
static
void
CalcGridLocations
(
const
platform
::
CPUDeviceContext
&
ctx
,
const
Tensor
&
grid
,
Tensor
*
x_w
,
Tensor
*
x_e
,
Tensor
*
y_n
,
Tensor
*
y_s
,
Tensor
*
d_w
,
Tensor
*
d_e
,
Tensor
*
d_n
,
Tensor
*
d_s
)
{
auto
&
place
=
*
ctx
.
eigen_device
();
const
int
n
=
grid
.
dims
()[
0
];
const
int
h
=
grid
.
dims
()[
1
];
const
int
w
=
grid
.
dims
()[
2
];
const
T
x_max
=
static_cast
<
T
>
(
w
-
1
);
const
T
y_max
=
static_cast
<
T
>
(
h
-
1
);
// split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim
Tensor
grid_x
,
grid_y
;
T
*
grid_x_data
=
grid_x
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
T
*
grid_y_data
=
grid_y
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
const
T
*
grid_data
=
grid
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
n
*
h
*
w
;
i
++
)
{
grid_x_data
[
i
]
=
grid_data
[
2
*
i
];
grid_y_data
[
i
]
=
grid_data
[(
2
*
i
)
+
1
];
}
Tensor
ones
;
ones
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
auto
ones_t
=
EigenTensor
<
T
,
3
>::
From
(
ones
).
setConstant
(
1.0
);
// scale grid to [0, h-1/w-1]
auto
grid_x_t
=
EigenTensor
<
T
,
3
>::
From
(
grid_x
);
auto
grid_y_t
=
EigenTensor
<
T
,
3
>::
From
(
grid_y
);
grid_x_t
.
device
(
place
)
=
0.5
*
((
grid_x_t
+
ones_t
)
*
x_max
);
grid_y_t
.
device
(
place
)
=
0.5
*
((
grid_y_t
+
ones_t
)
*
y_max
);
// calculate coords of 4 corner points
x_w
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
x_e
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
y_n
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
y_s
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
auto
x_w_t
=
EigenTensor
<
T
,
3
>::
From
(
*
x_w
);
auto
x_e_t
=
EigenTensor
<
T
,
3
>::
From
(
*
x_e
);
auto
y_n_t
=
EigenTensor
<
T
,
3
>::
From
(
*
y_n
);
auto
y_s_t
=
EigenTensor
<
T
,
3
>::
From
(
*
y_s
);
x_w_t
.
device
(
place
)
=
grid_x_t
.
floor
();
x_e_t
.
device
(
place
)
=
x_w_t
+
ones_t
;
y_n_t
.
device
(
place
)
=
grid_y_t
.
floor
();
y_s_t
.
device
(
place
)
=
y_n_t
+
ones_t
;
// calculate distances to 4 sides
d_w
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
d_e
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
d_n
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
d_s
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
auto
d_w_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_w
);
auto
d_e_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_e
);
auto
d_n_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_n
);
auto
d_s_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_s
);
d_w_t
.
device
(
place
)
=
grid_x_t
-
x_w_t
;
d_e_t
.
device
(
place
)
=
x_e_t
-
grid_x_t
;
d_n_t
.
device
(
place
)
=
grid_y_t
-
y_n_t
;
d_s_t
.
device
(
place
)
=
y_s_t
-
grid_y_t
;
}
template
<
typename
T
>
static
void
GetGridPointValue
(
const
Tensor
&
input
,
Tensor
*
output
,
const
Tensor
&
x
,
const
Tensor
&
y
)
{
const
int
n
=
input
.
dims
()[
0
];
const
int
c
=
input
.
dims
()[
1
];
const
int
h
=
input
.
dims
()[
2
];
const
int
w
=
input
.
dims
()[
3
];
auto
x_t
=
EigenTensor
<
T
,
3
>::
From
(
x
);
auto
y_t
=
EigenTensor
<
T
,
3
>::
From
(
y
);
auto
output_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output
).
setConstant
((
T
)
0
);
auto
input_t
=
EigenTensor
<
T
,
4
>::
From
(
input
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
if
(
isInBound
(
x_t
(
i
,
k
,
l
),
y_t
(
i
,
k
,
l
),
(
T
)(
w
-
1
),
(
T
)(
h
-
1
)))
{
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
output_t
(
i
,
j
,
k
,
l
)
=
input_t
(
i
,
j
,
static_cast
<
int
>
(
round
(
y_t
(
i
,
k
,
l
))),
static_cast
<
int
>
(
round
(
x_t
(
i
,
k
,
l
))));
}
}
}
}
}
}
template
<
typename
T
>
static
void
GatherOutputGradToInputGrad
(
const
Tensor
&
output_grad
,
Tensor
*
input_grad
,
const
Tensor
&
x
,
const
Tensor
&
y
,
const
Tensor
&
d1
,
const
Tensor
&
d2
)
{
const
int
n
=
output_grad
.
dims
()[
0
];
const
int
c
=
output_grad
.
dims
()[
1
];
const
int
h
=
output_grad
.
dims
()[
2
];
const
int
w
=
output_grad
.
dims
()[
3
];
auto
x_t
=
EigenTensor
<
T
,
3
>::
From
(
x
);
auto
y_t
=
EigenTensor
<
T
,
3
>::
From
(
y
);
auto
d1_t
=
EigenTensor
<
T
,
3
>::
From
(
d1
);
auto
d2_t
=
EigenTensor
<
T
,
3
>::
From
(
d2
);
auto
input_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
*
input_grad
);
auto
output_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
output_grad
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
if
(
isInBound
(
x_t
(
i
,
k
,
l
),
y_t
(
i
,
k
,
l
),
(
T
)(
w
-
1
),
(
T
)(
h
-
1
)))
{
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
input_grad_t
(
i
,
j
,
static_cast
<
int
>
(
round
(
y_t
(
i
,
k
,
l
))),
static_cast
<
int
>
(
round
(
x_t
(
i
,
k
,
l
))))
+=
output_grad_t
(
i
,
j
,
k
,
l
)
*
d1_t
(
i
,
k
,
l
)
*
d2_t
(
i
,
k
,
l
);
}
}
}
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
GridSampleOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
grid
=
ctx
.
Input
<
Tensor
>
(
"Grid"
);
const
int
n
=
input
->
dims
()[
0
];
const
int
c
=
input
->
dims
()[
1
];
const
int
h
=
input
->
dims
()[
2
];
const
int
w
=
input
->
dims
()[
3
];
// calc locations and distances of 4 corner points
Tensor
x_w
,
x_e
,
y_n
,
y_s
;
Tensor
d_w
,
d_e
,
d_n
,
d_s
;
CalcGridLocations
<
T
>
(
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>(),
*
grid
,
&
x_w
,
&
x_e
,
&
y_n
,
&
y_s
,
&
d_w
,
&
d_e
,
&
d_n
,
&
d_s
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Output"
);
output
->
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
output
,
static_cast
<
T
>
(
0
));
// calc 4 corner points value
Tensor
v_wn
,
v_en
,
v_ws
,
v_es
;
v_wn
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_en
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_ws
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_es
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
GetGridPointValue
<
T
>
(
*
input
,
&
v_wn
,
x_w
,
y_n
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_en
,
x_e
,
y_n
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_ws
,
x_w
,
y_s
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_es
,
x_e
,
y_s
);
auto
d_w_t
=
EigenTensor
<
T
,
3
>::
From
(
d_w
);
auto
d_e_t
=
EigenTensor
<
T
,
3
>::
From
(
d_e
);
auto
d_n_t
=
EigenTensor
<
T
,
3
>::
From
(
d_n
);
auto
d_s_t
=
EigenTensor
<
T
,
3
>::
From
(
d_s
);
auto
d_w_scaled_t
=
d_w_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
d_e_scaled_t
=
d_e_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
d_n_scaled_t
=
d_n_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
d_s_scaled_t
=
d_s_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
v_wn_t
=
EigenTensor
<
T
,
4
>::
From
(
v_wn
);
auto
v_en_t
=
EigenTensor
<
T
,
4
>::
From
(
v_en
);
auto
v_ws_t
=
EigenTensor
<
T
,
4
>::
From
(
v_ws
);
auto
v_es_t
=
EigenTensor
<
T
,
4
>::
From
(
v_es
);
auto
output_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output
);
// bilinear interpolaetion by 4 corner points
output_t
.
device
(
place
)
=
v_wn_t
*
d_e_scaled_t
*
d_s_scaled_t
+
v_en_t
*
d_w_scaled_t
*
d_s_scaled_t
+
v_ws_t
*
d_e_scaled_t
*
d_n_scaled_t
+
v_es_t
*
d_w_scaled_t
*
d_n_scaled_t
;
}
};
template <typename DeviceContext, typename T>
class GridSampleGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* grid = ctx.Input<Tensor>("Grid");
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));

    const int n = input->dims()[0];
    const int c = input->dims()[1];
    const int h = input->dims()[2];
    const int w = input->dims()[3];

    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
    math::SetConstant<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), input_grad,
        static_cast<T>(0));
    auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
    grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
    math::SetConstant<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), grid_grad,
        static_cast<T>(0));

    Tensor x_w, x_e, y_n, y_s;
    Tensor d_w, d_e, d_n, d_s;
    CalcGridLocations<T>(
        ctx.template device_context<platform::CPUDeviceContext>(), *grid,
        &x_w, &x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s);

    // gather output grad value to input grad by corner point coords and weight
    GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_n, d_e, d_s);
    GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_s, d_e, d_n);
    GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_n, d_w, d_s);
    GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_s, d_w, d_n);

    // calc 4 corner points value
    Tensor v_wn, v_en, v_ws, v_es;
    v_wn.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
    v_en.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
    v_ws.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
    v_es.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
    GetGridPointValue<T>(*input, &v_wn, x_w, y_n);
    GetGridPointValue<T>(*input, &v_en, x_e, y_n);
    GetGridPointValue<T>(*input, &v_ws, x_w, y_s);
    GetGridPointValue<T>(*input, &v_es, x_e, y_s);

    auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
    auto v_en_t = EigenTensor<T, 4>::From(v_en);
    auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
    auto v_es_t = EigenTensor<T, 4>::From(v_es);

    auto d_w_t = EigenTensor<T, 3>::From(d_w);
    auto d_e_t = EigenTensor<T, 3>::From(d_e);
    auto d_n_t = EigenTensor<T, 3>::From(d_n);
    auto d_s_t = EigenTensor<T, 3>::From(d_s);

    auto output_grad_t = EigenTensor<T, 4>::From(*output_grad);

    Tensor grid_grad_x, grid_grad_y;
    grid_grad_x.mutable_data<T>({n, h, w}, ctx.GetPlace());
    grid_grad_y.mutable_data<T>({n, h, w}, ctx.GetPlace());
    auto grid_grad_x_t = EigenTensor<T, 3>::From(grid_grad_x).setConstant(0.0);
    auto grid_grad_y_t = EigenTensor<T, 3>::From(grid_grad_y).setConstant(0.0);
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < c; j++) {
        for (int k = 0; k < h; k++) {
          for (int l = 0; l < w; l++) {
            grid_grad_x_t(i, k, l) +=
                ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
                 (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
                output_grad_t(i, j, k, l);
            grid_grad_y_t(i, k, l) +=
                ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
                 (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
                output_grad_t(i, j, k, l);
          }
        }
      }
    }
    const T x_max = static_cast<T>(w - 1);
    const T y_max = static_cast<T>(h - 1);
    grid_grad_x_t = grid_grad_x_t * (x_max / (T)2);
    grid_grad_y_t = grid_grad_y_t * (y_max / (T)2);

    // gather grid_grad [x, y] in 3rd Dim
    T* grid_grad_data = grid_grad->data<T>();
    T* grid_grad_x_data = grid_grad_x.data<T>();
    T* grid_grad_y_data = grid_grad_y.data<T>();
    for (int i = 0; i < n * h * w; i++) {
      grid_grad_data[2 * i] = grid_grad_x_data[i];
      grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
    }
  }
};

}  // namespace operators
}  // namespace paddle
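Reviewer note: the final (x_max / 2) and (y_max / 2) scaling of the grid gradients is the chain rule applied to the coordinate rescaling from the forward pass:

$$
grid\_x = \frac{(grid[:, :, :, 0] + 1)(W - 1)}{2}
\quad\Rightarrow\quad
\frac{\partial\, grid\_x}{\partial\, grid[:, :, :, 0]} = \frac{W - 1}{2} = \frac{x\_max}{2}
$$

and analogously for the y coordinate with y_max = H - 1.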
paddle/fluid/operators/math/pooling.cc
@@ -31,7 +31,7 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
                   const framework::Tensor& input, const std::vector<int>& ksize,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_process,
-                  framework::Tensor* output) {
+                  bool exclusive, framework::Tensor* output) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
...
@@ -68,7 +68,8 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
             pool_process.compute(input_data[h * input_width + w], &ele);
           }
         }
-        int pool_size = (hend - hstart) * (wend - wstart);
+        int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
+                                  : ksize_height * ksize_width;
         pool_process.finalize(static_cast<T>(pool_size), &ele);
         output_data[ph * output_width + pw] = ele;
       }
...
@@ -93,7 +94,7 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad,
                   const std::vector<int>& ksize, const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_grad_process,
-                  framework::Tensor* input_grad) {
+                  bool exclusive, framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
...
@@ -124,7 +125,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
           int wstart = pw * stride_width - padding_width;
           int wend = std::min(wstart + ksize_width, input_width);
           wstart = std::max(wstart, 0);
-          int pool_size = (hend - hstart) * (wend - wstart);
+          int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
+                                    : ksize_height * ksize_width;
           float scale = 1.0 / pool_size;
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
...
@@ -249,7 +251,7 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
                   const framework::Tensor& input, const std::vector<int>& ksize,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_process,
-                  framework::Tensor* output) {
+                  bool exclusive, framework::Tensor* output) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
...
@@ -300,7 +302,9 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
               }
             }
-            int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+            int pool_size =
+                exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
+                          : ksize_depth * ksize_height * ksize_width;
             pool_process.finalize(static_cast<T>(pool_size), &ele);
             output_data[output_idx] = ele;
           }
...
@@ -326,7 +330,7 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad,
                   const std::vector<int>& ksize, const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_grad_process,
-                  framework::Tensor* input_grad) {
+                  bool exclusive, framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
...
@@ -369,7 +373,9 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
             wstart = std::max(wstart, 0);
-            int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+            int pool_size =
+                exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
+                          : ksize_depth * ksize_height * ksize_width;
             float scale = 1.0 / pool_size;
             for (int d = dstart; d < dend; ++d) {
               for (int h = hstart; h < hend; ++h) {
...
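Reviewer note: the semantics of the new `exclusive` flag, restated as a small NumPy sketch (illustrative only; names follow the C++ above):

    import numpy as np

    def avg_pool_window(x, hstart, hend, wstart, wend, ksize_h, ksize_w,
                        exclusive=True):
        # The window bounds are already clipped to the input, so
        # (hend - hstart) * (wend - wstart) counts only real elements.
        window_sum = x[hstart:hend, wstart:wend].sum()
        pool_size = ((hend - hstart) * (wend - wstart) if exclusive
                     else ksize_h * ksize_w)  # inclusive also counts padded zeros
        return window_sum / pool_size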
paddle/fluid/operators/math/pooling.cu
@@ -29,7 +29,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
                              const int ksize_width, const int stride_height,
                              const int stride_width, const int padding_height,
                              const int padding_width, PoolProcess pool_process,
-                             T* output_data) {
+                             bool exclusive, T* output_data) {
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
...
@@ -52,7 +52,8 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
         pool_process.compute(input_data[h * input_width + w], &ele);
       }
     }
-    int pool_size = (hend - hstart) * (wend - wstart);
+    int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
+                              : ksize_height * ksize_width;
     pool_process.finalize(static_cast<T>(pool_size), &ele);
     output_data[index] = ele;
   }
...
@@ -65,7 +66,7 @@ __global__ void KernelPool2DGrad(
     const int input_width, const int output_height, const int output_width,
     const int ksize_height, const int ksize_width, const int stride_height,
     const int stride_width, const int padding_height, const int padding_width,
-    PoolProcess pool_process, T* input_grad) {
+    PoolProcess pool_process, bool exclusive, T* input_grad) {
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int offsetW = index % input_width + padding_width;
...
@@ -95,7 +96,8 @@ __global__ void KernelPool2DGrad(
       int wend = min(wstart + ksize_width, input_width);
       hstart = max(hstart, 0);
      wstart = max(wstart, 0);
-      int pool_size = (hend - hstart) * (wend - wstart);
+      int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
+                                : ksize_height * ksize_width;
       int output_sub_idx = ph * output_width + pw;
       pool_process.compute(input, output_data[output_sub_idx],
                            output_grad[output_sub_idx],
...
@@ -163,7 +165,7 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
                   const framework::Tensor& input, const std::vector<int>& ksize,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_process,
-                  framework::Tensor* output) {
+                  bool exclusive, framework::Tensor* output) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
...
@@ -189,7 +191,8 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
     KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
         nthreads, input_data, input_channels, input_height, input_width,
         output_height, output_width, ksize_height, ksize_width, stride_height,
-        stride_width, padding_height, padding_width, pool_process, output_data);
+        stride_width, padding_height, padding_width, pool_process, exclusive,
+        output_data);
   }
 };
...
@@ -208,7 +211,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
                   const std::vector<int>& ksize, const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_process,
-                  framework::Tensor* input_grad) {
+                  bool exclusive, framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
...
@@ -236,7 +239,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
         nthreads, input_data, output_data, output_grad_data, input_channels,
         input_height, input_width, output_height, output_width, ksize_height,
         ksize_width, stride_height, stride_width, padding_height, padding_width,
-        pool_process, input_grad_data);
+        pool_process, exclusive, input_grad_data);
   }
 };
...
@@ -313,16 +316,14 @@ template class Pool2dGradFunctor<platform::CUDADeviceContext, double>;
 template <typename PoolProcess, typename T>
-__global__ void KernelPool3D(const int nthreads, const T* input_data,
-                             const int channels, const int input_depth,
-                             const int input_height, const int input_width,
-                             const int output_depth, const int output_height,
-                             const int output_width, const int ksize_depth,
-                             const int ksize_height, const int ksize_width,
-                             const int stride_depth, const int stride_height,
-                             const int stride_width, const int padding_depth,
-                             const int padding_height, const int padding_width,
-                             PoolProcess pool_process, T* output_data) {
+__global__ void KernelPool3D(
+    const int nthreads, const T* input_data, const int channels,
+    const int input_depth, const int input_height, const int input_width,
+    const int output_depth, const int output_height, const int output_width,
+    const int ksize_depth, const int ksize_height, const int ksize_width,
+    const int stride_depth, const int stride_height, const int stride_width,
+    const int padding_depth, const int padding_height, const int padding_width,
+    PoolProcess pool_process, bool exclusive, T* output_data) {
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
...
@@ -351,7 +352,9 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data,
         }
       }
     }
-    int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+    int pool_size = exclusive
+                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
+                        : ksize_depth * ksize_height * ksize_width;
     pool_process.finalize(static_cast<T>(pool_size), &ele);
     output_data[index] = ele;
   }
...
@@ -366,7 +369,7 @@ __global__ void KernelPool3DGrad(
     const int ksize_height, const int ksize_width, const int stride_depth,
     const int stride_height, const int stride_width, const int padding_depth,
     const int padding_height, const int padding_width, PoolProcess pool_process,
-    T* input_grad) {
+    bool exclusive, T* input_grad) {
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int offsetW = index % input_width + padding_width;
...
@@ -409,7 +412,9 @@ __global__ void KernelPool3DGrad(
       dstart = max(dstart, 0);
       hstart = max(hstart, 0);
       wstart = max(wstart, 0);
-      int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
+      int pool_size = exclusive
+                          ? (dend - dstart) * (hend - hstart) * (wend - wstart)
+                          : ksize_depth * ksize_height * ksize_width;
       int output_sub_idx = (pd * output_height + ph) * output_width + pw;
       pool_process.compute(input, output_data[output_sub_idx],
                            output_grad[output_sub_idx],
...
@@ -484,7 +489,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
                   const framework::Tensor& input, const std::vector<int>& ksize,
                   const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_process,
-                  framework::Tensor* output) {
+                  bool exclusive, framework::Tensor* output) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
...
@@ -517,7 +522,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
         nthreads, input_data, input_channels, input_depth, input_height,
         input_width, output_depth, output_height, output_width, ksize_depth,
         ksize_height, ksize_width, stride_depth, stride_height, stride_width,
-        padding_depth, padding_height, padding_width, pool_process,
+        padding_depth, padding_height, padding_width, pool_process, exclusive,
         output_data);
   }
 };
...
@@ -537,7 +542,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
                   const std::vector<int>& ksize, const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_process,
-                  framework::Tensor* input_grad) {
+                  bool exclusive, framework::Tensor* input_grad) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
...
@@ -573,7 +578,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
         input_depth, input_height, input_width, output_depth, output_height,
         output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
         stride_height, stride_width, padding_depth, padding_height,
-        padding_width, pool_process, input_grad_data);
+        padding_width, pool_process, exclusive, input_grad_data);
   }
 };
...
paddle/fluid/operators/math/pooling.h
@@ -89,7 +89,7 @@ class Pool2dFunctor {
                   const std::vector<int>& ksize, const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_compute,
-                  framework::Tensor* output);
+                  bool exclusive, framework::Tensor* output);
 };

 template <typename DeviceContext, typename PoolProcess, typename T>
...
@@ -101,7 +101,7 @@ class Pool2dGradFunctor {
                   const std::vector<int>& ksize, const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_compute,
-                  framework::Tensor* input_grad);
+                  bool exclusive, framework::Tensor* input_grad);
 };

 template <typename DeviceContext, class T>
...
@@ -123,7 +123,7 @@ class Pool3dFunctor {
                   const std::vector<int>& ksize, const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_compute,
-                  framework::Tensor* output);
+                  bool exclusive, framework::Tensor* output);
 };

 template <typename DeviceContext, typename PoolProcess, typename T>
...
@@ -135,7 +135,7 @@ class Pool3dGradFunctor {
                   const std::vector<int>& ksize, const std::vector<int>& strides,
                   const std::vector<int>& paddings, PoolProcess pool_compute,
-                  framework::Tensor* input_grad);
+                  bool exclusive, framework::Tensor* input_grad);
 };

 template <typename DeviceContext, class T>
...
paddle/fluid/operators/pool_cudnn_op.cu.cc
@@ -41,6 +41,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
     T* output_data = output->mutable_data<T>(ctx.GetPlace());

     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
+    bool exclusive = ctx.Attr<bool>("exclusive");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
...
@@ -72,7 +73,8 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
     if (pooling_type == "max") {
       pooling_mode = PoolingMode::kMaximum;
     } else {
-      pooling_mode = PoolingMode::kAverage;
+      pooling_mode = exclusive ? PoolingMode::kAverageExclusive
+                               : PoolingMode::kAverageInclusive;
     }

     cudnnPoolingDescriptor_t cudnn_pool_desc =
...
@@ -101,6 +103,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
     Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
+    bool exclusive = ctx.Attr<bool>("exclusive");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
...
@@ -141,7 +144,8 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
         pooling_mode = PoolingMode::kMaximum;
       }
     } else {
-      pooling_mode = PoolingMode::kAverage;
+      pooling_mode = exclusive ? PoolingMode::kAverageExclusive
+                               : PoolingMode::kAverageInclusive;
     }

     cudnnPoolingDescriptor_t cudnn_pool_desc =
...
paddle/fluid/operators/pool_op.cc
...
@@ -180,6 +180,12 @@ void Pool2dOpMaker::Make() {
                              "operator."
                              "If global_pooling = true, paddings and ksize will be ignored.")
       .SetDefault({0, 0});
+  AddAttr<bool>(
+      "exclusive",
+      "(bool, default True) When true, will exclude the zero-padding in the "
+      "averaging calculation, otherwise, include the zero-padding. Note, it "
+      "is only used when pooling_type is avg. The default is True.")
+      .SetDefault(true);
   AddAttr<bool>(
       "use_cudnn",
       "(bool, default false) Only used in cudnn kernel, need install cudnn")
...
@@ -236,6 +242,23 @@ Example:
        W_{out} = \\frac{(W_{in} - ksize[1] + 2 * paddings[1] + strides[1] - 1)}{strides[1]} + 1
   $$

+  For exclusive = false (the zero-padding is counted):
+  $$
+  hstart = i * strides[0] - paddings[0]
+  hend = hstart + ksize[0]
+  wstart = j * strides[1] - paddings[1]
+  wend = wstart + ksize[1]
+  Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{ksize[0] * ksize[1]}
+  $$
+  For exclusive = true (only valid input elements are counted):
+  $$
+  hstart = max(0, i * strides[0] - paddings[0])
+  hend = min(H, hstart + ksize[0])
+  wstart = max(0, j * strides[1] - paddings[1])
+  wend = min(W, wstart + ksize[1])
+  Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)}
+  $$

 )DOC");
 }
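Reviewer note, as a concrete check of the two formulas above: with a 3 x 3 kernel, stride 1 and padding 1, the window for Output(0, 0) overlaps only a 2 x 2 patch of real input, so exclusive = true divides the window sum by 4, while exclusive = false divides it by 9, counting the five zero-padded positions.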
...
@@ -283,6 +306,12 @@ void Pool3dOpMaker::Make() {
                              "If global_pooling = true, ksize and paddings will be ignored.")
       .SetDefault({0, 0});  // TODO(Chengduo): Add checker. (Currently,
                             // TypedAttrChecker don't support vector type.)
+  AddAttr<bool>(
+      "exclusive",
+      "(bool, default True) When true, will exclude the zero-padding in the "
+      "averaging calculation, otherwise, include the zero-padding. Note, it "
+      "is only used when pooling_type is avg. The default is True.")
+      .SetDefault(true);
   AddAttr<bool>(
       "use_cudnn",
...
paddle/fluid/operators/pool_op.h
@@ -69,6 +69,7 @@ class PoolKernel : public framework::OpKernel<T> {
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    bool exclusive = context.Attr<bool>("exclusive");

     if (context.Attr<bool>("global_pooling")) {
       for (size_t i = 0; i < ksize.size(); ++i) {
         paddings[i] = 0;
...
@@ -84,7 +85,7 @@ class PoolKernel : public framework::OpKernel<T> {
             pool2d_forward;
         paddle::operators::math::MaxPool<T> pool_process;
         pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
-                       out);
+                       true, out);

       } else if (pooling_type == "avg") {
         paddle::operators::math::Pool2dFunctor<
...
@@ -92,7 +93,7 @@ class PoolKernel : public framework::OpKernel<T> {
             pool2d_forward;
         paddle::operators::math::AvgPool<T> pool_process;
         pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
-                       out);
+                       exclusive, out);
       }
     } break;
     case 3: {
...
@@ -102,14 +103,14 @@ class PoolKernel : public framework::OpKernel<T> {
             pool3d_forward;
         paddle::operators::math::MaxPool<T> pool_process;
         pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
-                       out);
+                       true, out);

       } else if (pooling_type == "avg") {
         paddle::operators::math::Pool3dFunctor<
             DeviceContext, paddle::operators::math::AvgPool<T>, T>
             pool3d_forward;
         paddle::operators::math::AvgPool<T> pool_process;
         pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
-                       out);
+                       exclusive, out);
       }
     } break;
     default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
...
@@ -131,6 +132,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    bool exclusive = context.Attr<bool>("exclusive");

     if (context.Attr<bool>("global_pooling")) {
       for (size_t i = 0; i < ksize.size(); ++i) {
...
@@ -157,7 +159,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
             pool2d_backward;
         paddle::operators::math::AvgPoolGrad<T> pool_process;
         pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
-                        paddings, pool_process, in_x_grad);
+                        paddings, pool_process, exclusive, in_x_grad);
       }
     } break;
     case 3: {
...
@@ -172,7 +174,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
             pool3d_backward;
         paddle::operators::math::AvgPoolGrad<T> pool_process;
         pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
-                        paddings, pool_process, in_x_grad);
+                        paddings, pool_process, exclusive, in_x_grad);
       }
     } break;
     default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
...
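Reviewer note: the max-pooling paths pass a literal `true` here because `pool_size` only feeds the averaging divisor; the flag has no effect when taking a maximum, so only the `avg` branches thread the user-visible `exclusive` attribute through.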
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
...
@@ -44,6 +44,12 @@ class SoftmaxWithCrossEntropyOpMaker
               "(bool, default: false), A flag to indicate whether to interpret "
               "the given labels as soft labels.")
         .SetDefault(false);
+    AddAttr<bool>(
+        "numeric_stable_mode",
+        "(bool, default: false), A flag to indicate whether to use a more "
+        "numerically stable algorithm. This flag is only valid when "
+        "soft_label is false and GPU is used.")
+        .SetDefault(false);
     AddAttr<int>(
         "ignore_index",
         "(int, default -100), Specifies a target value that is ignored and"
...
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <cub/cub.cuh>
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
+#include "paddle/fluid/platform/for_range.h"

 namespace paddle {
 namespace operators {
...
@@ -117,8 +118,8 @@ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
 // Make sure that BlockDim <= feature_size
 // This kernel is used to calculate the max element of each row
 template <typename T, int BlockDim>
-__global__ void RowReductionForMax(const T* logits_data, T* max_data,
-                                   int feature_size) {
+static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
+                                          int feature_size) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

   auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
...
@@ -141,9 +142,10 @@ __global__ void RowReductionForMax(const T* logits_data, T* max_data,
 }

 // Make sure that BlockDim <= feature_size
-template <typename T, int BlockDim>
-__global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
-                                          T* softmax, int feature_size) {
+template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
+static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
+                                                 T* max_data, T* softmax,
+                                                 int feature_size) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

   auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
...
@@ -153,24 +155,34 @@ __global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
   softmax[beg_idx] = logits_data[beg_idx] - block_max;
   T diff_max_sum = real_exp(softmax[beg_idx]);
-  beg_idx += BlockDim;
-  while (beg_idx < end_idx) {
-    softmax[beg_idx] = logits_data[beg_idx] - block_max;
-    diff_max_sum += real_exp(softmax[beg_idx]);
-    beg_idx += BlockDim;
+  auto idx = beg_idx + BlockDim;
+  while (idx < end_idx) {
+    softmax[idx] = logits_data[idx] - block_max;
+    diff_max_sum += real_exp(softmax[idx]);
+    idx += BlockDim;
   }

   diff_max_sum =
       BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
   if (threadIdx.x == 0) max_data[blockIdx.x] = real_log(diff_max_sum);
+
+  if (!CalculateLogSoftmax) return;
+  __syncthreads();
+  diff_max_sum = max_data[blockIdx.x];
+  softmax[beg_idx] -= diff_max_sum;
+  beg_idx += BlockDim;
+  while (beg_idx < end_idx) {
+    softmax[beg_idx] -= diff_max_sum;
+    beg_idx += BlockDim;
+  }
+  if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
 }

 // Make sure that BlockDim <= feature_size
 template <typename T, int BlockDim>
-__global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
-                                                      const T* labels_data,
-                                                      T* loss_data, T* softmax,
-                                                      int feature_size) {
+static __global__ void RowReductionForSoftmaxAndCrossEntropy(
+    const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
+    int feature_size) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;

   auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
...
@@ -194,11 +206,134 @@ __global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
 }

 template <typename T>
+struct HardLabelSoftmaxWithCrossEntropyFunctor {
+ public:
+  HardLabelSoftmaxWithCrossEntropyFunctor(const T* logits,
+                                          const int64_t* labels, T* loss,
+                                          T* log_softmax, int feature_size)
+      : logits_(logits),
+        labels_(labels),
+        loss_(loss),
+        log_softmax_(log_softmax),
+        feature_size_(feature_size) {}
+
+  __device__ void operator()(int idx) const {
+    auto row_idx = idx / feature_size_;
+    auto col_idx = idx % feature_size_;
+    if (col_idx != labels_[row_idx]) {
+      log_softmax_[idx] = real_exp(log_softmax_[idx]);
+    } else {
+      auto softmax = log_softmax_[idx];
+      log_softmax_[idx] = real_exp(softmax);
+      loss_[row_idx] = -softmax;
+    }
+  }
+
+ private:
+  const T* logits_;
+  const int64_t* labels_;
+  T* loss_;
+  T* log_softmax_;
+  int feature_size_;
+};
+
+template <typename T>
+struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
+ public:
+  HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(
+      const T* logits, const int64_t* labels, T* loss, T* log_softmax,
+      int feature_size, int ignore_idx)
+      : logits_(logits),
+        labels_(labels),
+        loss_(loss),
+        log_softmax_(log_softmax),
+        feature_size_(feature_size),
+        ignore_idx_(ignore_idx) {}
+
+  __device__ void operator()(int idx) const {
+    auto row_idx = idx / feature_size_;
+    auto col_idx = idx % feature_size_;
+    if (col_idx != labels_[row_idx] || col_idx == ignore_idx_) {
+      log_softmax_[idx] = real_exp(log_softmax_[idx]);
+    } else {
+      auto softmax = log_softmax_[idx];
+      log_softmax_[idx] = real_exp(softmax);
+      loss_[row_idx] = -softmax;
+    }
+  }
+
+ private:
+  const T* logits_;
+  const int64_t* labels_;
+  T* loss_;
+  T* log_softmax_;
+  int feature_size_;
+  int ignore_idx_;
+};
+
+template <typename T>
-__global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int batch_size) {
+static __global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out,
+                                                           int batch_size) {
   auto idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < batch_size) out[idx] = static_cast<T>(1);
 }

+template <typename T>
+static void HardLabelSoftmaxWithCrossEntropy(
+    const platform::CUDADeviceContext& ctx, const T* logits_data,
+    const int64_t* labels_data, T* loss_data, T* softmax_data, int batch_size,
+    int feature_size, int ignore_idx) {
+  constexpr int kMaxBlockDim = 512;
+  int block_dim = feature_size >= kMaxBlockDim
+                      ? kMaxBlockDim
+                      : (1 << static_cast<int>(std::log2(feature_size)));
+  auto stream = ctx.stream();
+
+#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)    \
+  case BlockDim: {                                                           \
+    RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(    \
+        logits_data, loss_data, feature_size);                               \
+    RowReductionForDiffMaxSum<T, BlockDim,                                   \
+                              true><<<batch_size, BlockDim, 0, stream>>>(    \
+        logits_data, loss_data, softmax_data, feature_size);                 \
+    platform::ForRange<platform::CUDADeviceContext> for_range(               \
+        ctx, batch_size* feature_size);                                      \
+    if (ignore_idx >= 0 && ignore_idx < feature_size) {                      \
+      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(     \
+          logits_data, labels_data, loss_data, softmax_data, feature_size,   \
+          ignore_idx));                                                      \
+    } else {                                                                 \
+      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(                  \
+          logits_data, labels_data, loss_data, softmax_data, feature_size)); \
+    }                                                                        \
+  } break
+
+  switch (block_dim) {
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
+    CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
+    case 1:
+      SetSoftmaxToOneWhenFeatureSizeIsOne<<<(batch_size + kMaxBlockDim - 1) /
+                                                kMaxBlockDim,
+                                            kMaxBlockDim, 0, stream>>>(
+          softmax_data, batch_size);
+      cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream);
+      break;
+    default:
+      PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
+      break;
+  }
+#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
+}
+
 template <typename T>
 static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
                                                const T* labels_data,
...
@@ -237,7 +372,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
                                                 kMaxBlockDim,
                                             kMaxBlockDim, 0, stream>>>(
           softmax_data, batch_size);
-      cudaMemsetAsync(loss_data, 0, batch_size, stream);
+      cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream);
       break;
     default:
       PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
...
@@ -272,11 +407,21 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
           logits_data, labels_data, softmax_data, loss_data, batch_size,
           feature_size, context.cuda_device_context().stream());
     } else {
-      math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
-                                     softmax);
-      math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
-          context.cuda_device_context(), loss, softmax, labels, false,
-          ignore_index);
+      if (!context.Attr<bool>("numeric_stable_mode")) {
+        math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
+                                       softmax);
+        math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
+            context.cuda_device_context(), loss, softmax, labels, false,
+            ignore_index);
+      } else {
+        int batch_size = logits->dims()[0];
+        int feature_size = logits->dims()[1];
+        auto* logits_data = logits->data<T>();
+        auto* labels_data = labels->data<int64_t>();
+        HardLabelSoftmaxWithCrossEntropy<T>(
+            context.cuda_device_context(), logits_data, labels_data, loss_data,
+            softmax_data, batch_size, feature_size, ignore_index);
+      }
     }
   }
 };
...
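Reviewer note: what the fused hard-label path computes, restated as a small NumPy sketch (illustrative, not part of the patch): first the row max, then the log-sum-exp of the shifted logits, yielding a log-softmax from which the loss is read at the label column.

    import numpy as np

    def stable_softmax_with_nll(logits, labels):
        # logits: [batch, classes]; labels: [batch] of int class indices
        row_max = logits.max(axis=1, keepdims=True)              # RowReductionForMax
        shifted = logits - row_max
        log_sum = np.log(np.exp(shifted).sum(axis=1, keepdims=True))
        log_softmax = shifted - log_sum                          # RowReductionForDiffMaxSum
        loss = -log_softmax[np.arange(len(labels)), labels]      # functor: loss = -log p(label)
        return np.exp(log_softmax), loss                         # softmax output, per-row loss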
paddle/fluid/operators/spp_op.h
...
@@ -56,12 +56,14 @@ class SppKernel : public framework::OpKernel<T> {
         math::Pool2dFunctor<DeviceContext, math::MaxPool<T>, T> pool_forward;
         math::MaxPool<T> max_process;
         pool_forward(context.template device_context<DeviceContext>(), *in_x,
-                     kernel_size, strides, paddings, max_process, &out_level);
+                     kernel_size, strides, paddings, max_process, true,
+                     &out_level);
       } else if (pooling_type == "avg") {
         math::Pool2dFunctor<DeviceContext, math::AvgPool<T>, T> pool_forward;
         math::AvgPool<T> avg_process;
         pool_forward(context.template device_context<DeviceContext>(), *in_x,
-                     kernel_size, strides, paddings, avg_process, &out_level);
+                     kernel_size, strides, paddings, avg_process, true,
+                     &out_level);
       }
       // flatten pooling output shape
       int output_flatten_w = in_x->dims()[1] * bins * bins;
...
@@ -154,7 +156,7 @@ class SppGradKernel : public framework::OpKernel<T> {
         math::AvgPoolGrad<T> avg_process;
         pool_backward(context.template device_context<DeviceContext>(), *in_x,
                       *&out_level, *&outgrad_level, kernel_size, strides,
-                      paddings, avg_process, in_x_grad);
+                      paddings, avg_process, true, in_x_grad);
       }
     }
   }
...
paddle/fluid/platform/cudnn_helper.h
...
@@ -76,8 +76,9 @@ enum class DataLayout {  // Not use
 enum class PoolingMode {
   kMaximum,
-  kAverage,
   kMaximumDeterministic,
+  kAverageExclusive,
+  kAverageInclusive,
 };

 #if CUDNN_VERSION < 6000
...
@@ -91,8 +92,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
   switch (mode) {
     case PoolingMode::kMaximumDeterministic:
       return CUDNN_POOLING_MAX;
-    case PoolingMode::kAverage:
+    case PoolingMode::kAverageExclusive:
       return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+    case PoolingMode::kAverageInclusive:
+      return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
     case PoolingMode::kMaximum:
       return CUDNN_POOLING_MAX;
     default:
...
@@ -105,8 +108,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
   switch (mode) {
     case PoolingMode::kMaximumDeterministic:
       return CUDNN_POOLING_MAX_DETERMINISTIC;
-    case PoolingMode::kAverage:
+    case PoolingMode::kAverageExclusive:
       return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+    case PoolingMode::kAverageInclusive:
+      return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
     case PoolingMode::kMaximum:
       return CUDNN_POOLING_MAX;
     default:
...
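Reviewer note: splitting the old kAverage into kAverageExclusive and kAverageInclusive is what lets the cuDNN pooling descriptor honor the new `exclusive` attribute. Previously every average pool was forced to CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; the inclusive variant, CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING, was unreachable.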
paddle/fluid/pybind/pybind.cc
...
@@ -821,6 +821,13 @@ All parameter, weight, gradient are variables in Paddle.
            [](BuildStrategy &self, bool b) {
              self.enable_data_balance_ = b;
            })  // FIXME(chengduo): enable_data_balance seems not important
+      .def_property(
+          "enable_sequential_execution",
+          [](const BuildStrategy &self) {
+            return self.enable_sequential_execution_;
+          },
+          [](BuildStrategy &self, bool b) {
+            self.enable_sequential_execution_ = b;
+          })
       .def_property(
           "fuse_elewise_add_act_ops",
           [](const BuildStrategy &self) {
...
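Reviewer note: a minimal usage sketch of the new property from Python, as the tests below exercise it (assumes a built network whose loss variable is `loss`):

    import paddle.fluid as fluid

    build_strategy = fluid.BuildStrategy()
    build_strategy.enable_sequential_execution = True  # run ops in program order
    exe = fluid.ParallelExecutor(use_cuda=True,
                                 loss_name=loss.name,
                                 build_strategy=build_strategy)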
python/paddle/fluid/layers/control_flow.py
...
@@ -1586,8 +1586,7 @@ class DynamicRNN(object):
         self.lod_rank_table = None
         self.max_seq_len = None
         self.step_idx = None
-        self.zero_idx = fill_constant(
-            shape=[1], value=0, dtype='int64', force_cpu=True)
+        self.zero_idx = None
         self.mem_dict = dict()
         self.output_array = []
         self.outputs = []
...
@@ -1792,6 +1791,7 @@ class DynamicRNN(object):
         """
         self._assert_in_rnn_block_('memory')
+        self._init_zero_idx_()
         if init is not None:
             if not isinstance(init, Variable):
                 raise TypeError(
...
@@ -1905,6 +1905,22 @@ class DynamicRNN(object):
             array_write(x=each, i=self.step_idx, array=outside_array)
             self.output_array.append(outside_array)

+    def _init_zero_idx_(self):
+        if self.zero_idx is None:
+            parent_block = self._parent_block_()
+            self.zero_idx = parent_block.create_var(
+                name=unique_name.generate('zero_idx'), dtype='int64')
+            parent_block.append_op(
+                type='fill_constant',
+                inputs={},
+                outputs={'Out': [self.zero_idx]},
+                attrs={
+                    'shape': [1],
+                    'dtype': self.zero_idx.dtype,
+                    'value': float(0),
+                    'force_cpu': True
+                })
+
     def _parent_block_(self):
         prog = self.helper.main_program
         parent_idx = prog.current_block().parent_idx
...
python/paddle/fluid/layers/nn.py
...
@@ -158,6 +158,7 @@ __all__ = [
     'sequence_reverse',
     'affine_channel',
     'hash',
+    'grid_sampler',
     'log_loss',
     'add_position_encoding',
 ]
...
@@ -2101,7 +2102,8 @@ def pool2d(input,
            global_pooling=False,
            use_cudnn=True,
            ceil_mode=False,
-           name=None):
+           name=None,
+           exclusive=True):
     """
     ${comment}
...
@@ -2115,11 +2117,13 @@ def pool2d(input,
         pool_type: ${pooling_type_comment}
         pool_stride (int): stride of the pooling layer.
         pool_padding (int): padding size.
-        global_pooling: ${global_pooling_comment}
-        use_cudnn: ${use_cudnn_comment}
-        ceil_mode: ${ceil_mode_comment}
+        global_pooling (bool): ${global_pooling_comment}
+        use_cudnn (bool): ${use_cudnn_comment}
+        ceil_mode (bool): ${ceil_mode_comment}
         name (str|None): A name for this layer(optional). If set None, the
                 layer will be named automatically.
+        exclusive (bool): Whether to exclude padding points in average pooling
+                          mode, default is true

     Returns:
         Variable: The pooling result.
...
@@ -2177,7 +2181,8 @@ def pool2d(input,
             "paddings": pool_padding,
             "use_cudnn": use_cudnn,
             "ceil_mode": ceil_mode,
-            "use_mkldnn": False
+            "use_mkldnn": False,
+            "exclusive": exclusive,
         })

     return pool_out
...
@@ -2191,7 +2196,8 @@ def pool3d(input,
            global_pooling=False,
            use_cudnn=True,
            ceil_mode=False,
-           name=None):
+           name=None,
+           exclusive=True):
     """
     This function adds the operator for pooling in 3-dimensions, using the
     pooling configurations mentioned in input parameters.
...
@@ -2207,6 +2213,8 @@ def pool3d(input,
         ceil_mode (bool): ${ceil_mode_comment}
         name (str): A name for this layer(optional). If set None, the layer
             will be named automatically.
+        exclusive (bool): Whether to exclude padding points in average pooling
+                          mode, default is true

     Returns:
         Variable: output of pool3d layer.
...
@@ -2245,7 +2253,8 @@ def pool3d(input,
             "paddings": pool_padding,
             "use_cudnn": use_cudnn,
             "ceil_mode": ceil_mode,
-            "use_mkldnn": False
+            "use_mkldnn": False,
+            "exclusive": exclusive,
         })

     return pool_out
...
@@ -4713,7 +4722,8 @@ def multiplex(inputs, index):
 def softmax_with_cross_entropy(logits,
                                label,
                                soft_label=False,
-                               ignore_index=-100):
+                               ignore_index=-100,
+                               numeric_stable_mode=False):
     """
     **Softmax With Cross Entropy Operator.**
...
@@ -4747,6 +4757,18 @@ def softmax_with_cross_entropy(logits,
        \\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K}
        \\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K

+    3) If numeric_stable_mode is True, softmax is calculated first by:
+
+    .. math::
+
+        max_j = \\max_{i=0}^{K}{\\text{logit}_i}
+
+        log\\_max\\_sum_j = \\log\\sum_{i=0}^{K}\\exp(logit_i - max_j)
+
+        softmax_j = \\exp(logit_j - max_j - {log\\_max\\_sum}_j)
+
+    and then cross entropy loss is calculated by softmax and label.
+
     Args:
         logits (Variable): The unscaled log probabilities, which is a 2-D tensor
             with shape [N x K]. N is the batch_size, and K is the class number.
...
@@ -4758,6 +4780,13 @@ def softmax_with_cross_entropy(logits,
         ignore_index (int): Specifies a target value that is ignored and does
                             not contribute to the input gradient. Only valid
                             if soft_label is set to False. Default: -100
+        numeric_stable_mode (bool): A flag to indicate whether to use a more
+                                    numerically stable algorithm. Only valid
+                                    when soft_label is False and GPU is used.
+                                    When soft_label is True or CPU is used,
+                                    the algorithm is always numerically stable.
+                                    Note that the speed may be slower when the
+                                    stable algorithm is used. Default: False

     Returns:
         Variable: The cross entropy loss is a 2-D tensor with shape [N x 1].
...
@@ -4780,8 +4809,11 @@ def softmax_with_cross_entropy(logits,
                 'Label': label},
         outputs={'Softmax': softmax,
                  'Loss': loss},
-        attrs={'soft_label': soft_label,
-               'ignore_index': ignore_index})
+        attrs={
+            'soft_label': soft_label,
+            'ignore_index': ignore_index,
+            'numeric_stable_mode': numeric_stable_mode
+        })
     return loss
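Reviewer note: a minimal call sketch for the new flag (names are illustrative, assuming `logits` and `label` variables already exist):

    loss = fluid.layers.softmax_with_cross_entropy(
        logits=logits, label=label, numeric_stable_mode=True)  # stable GPU path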
...
@@ -7712,19 +7744,59 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):

 def hash(input, hash_size, num_hash=1, name=None):
     """
-    hash the input
-    Args:
-        input (Variable): The input variable which is a one-hot word.
-        hash_size (int): The space size for hash algorithm.
+    Hash the input to an integer whose value is less than the given hash size.
+
+    The hash algorithm we used was xxHash - Extremely fast hash algorithm
+    (https://github.com/Cyan4973/xxHash/tree/v0.6.5)
+
+    A simple example as below:
+
+    .. code-block:: text
+
+        Given:
+
+        # shape [2, 2]
+        input.data = [
+            [[1], [2]],
+            [[3], [4]],
+        ]
+
+        input.lod = [[0, 2]]
+
+        hash_size = 10000
+
+        num_hash = 4
+
+        Then:
+
+        Hash op will take all number in input's 2nd dimension as hash algorithm's
+        input for each time. Each input will be hashed for 4 times, and get an
+        array whose length is 4. Each value in the array ranges from 0 to 9999.
+
+        # shape [2, 4]
+        output.data = [
+            [[9662], [9217], [1129], [8487]],
+            [[8310], [1327], [1654], [4567]],
+        ]
+
+        output.lod = [[0, 2]]
+
+    Args:
+        input (Variable): The input variable which is a one-hot word. The
+            dimensions of the input variable must be 2.
+        hash_size (int): The space size for hash algorithm. The output value
+            will keep in the range :math:`[0, hash_size - 1]`.
         num_hash (int): The times of hash, default 1.
         name (str, default None): The name of this layer.
-    Returns:
-       Variable: The hash result variable which is a LoDTensor.
-    Examples:
-       .. code-block:: python
-           word_dict = paddle.dataset.imdb.word_dict()
-           x = fluid.layers.data(shape[1], dtype='int32', lod_level=1)
-           out = fluid.layers.hash(input=x, len(word_dict))
+
+    Returns:
+        Variable: The hash result variable which is a LoDTensor.
+
+    Examples:
+        .. code-block:: python
+
+            word_dict = paddle.dataset.imdb.word_dict()
+            x = fluid.layers.data(shape=[1], dtype='int32', lod_level=1)
+            out = fluid.layers.hash(input=x, num_hash=4, hash_size=1000)
     """
     helper = LayerHelper('hash', **locals())
     out = helper.create_variable_for_type_inference(
...
@@ -7738,6 +7810,87 @@ def hash(input, hash_size, num_hash=1, name=None):
     return out


+@templatedoc()
+def grid_sampler(x, grid, name=None):
+    """
+    This operation samples input X by using bilinear interpolation based on
+    flow field grid, which is usually generated by affine_grid. The grid of
+    shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
+    with shape [N, H, W] each, where grid_x is indexing the 4th dimension
+    (in width dimension) of input data x and grid_y is indexing the 3rd
+    dimension (in height dimension); finally the result is the bilinear
+    interpolation value of the 4 nearest corner points.
+
+    Step 1:
+    Get (x, y) grid coordinates and scale to [0, H-1/W-1].
+
+    grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
+    grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
+
+    Step 2:
+    Index input data X with grid (x, y) in each [H, W] area, and bilinearly
+    interpolate the point value by the 4 nearest points.
+
+      wn ------- y_n ------- en
+      |           |           |
+      |          d_n          |
+      |           |           |
+     x_w --d_w-- grid --d_e-- x_e
+      |           |           |
+      |          d_s          |
+      |           |           |
+      ws ------- y_s ------- es
+
+    x_w = floor(x)              // west side x coord
+    x_e = x_w + 1               // east side x coord
+    y_n = floor(y)              // north side y coord
+    y_s = y_n + 1               // south side y coord
+
+    d_w = grid_x - x_w          // distance to west side
+    d_e = x_e - grid_x          // distance to east side
+    d_n = grid_y - y_n          // distance to north side
+    d_s = y_s - grid_y          // distance to south side
+
+    wn = X[:, :, y_n, x_w]      // north-west point value
+    en = X[:, :, y_n, x_e]      // north-east point value
+    ws = X[:, :, y_s, x_w]      // south-west point value
+    es = X[:, :, y_s, x_e]      // south-east point value
+
+    output = wn * d_e * d_s + en * d_w * d_s
+           + ws * d_e * d_n + es * d_w * d_n
+
+    Args:
+        x(Variable): Input data of shape [N, C, H, W].
+        grid(Variable): Input grid tensor of shape [N, H, W, 2].
+        name (str, default None): The name of this layer.
+
+    Returns:
+        out(Variable): Output of shape [N, C, H, W] data samples input X
+        using bilinear interpolation based on input grid.
+
+    Examples:
+        .. code-block:: python
+
+            x = fluid.layers.data(name='x', shape=[3, 10, 32, 32], dtype='float32')
+            theta = fluid.layers.data(name='theta', shape=[3, 2, 3], dtype='float32')
+            grid = fluid.layers.affine_grid(input=theta, size=[3, 10, 32, 32])
+            out = fluid.layers.grid_sampler(x=x, grid=grid)
+    """
+    helper = LayerHelper("grid_sampler", **locals())
+
+    if not isinstance(x, Variable):
+        raise ValueError("The x should be a Variable")
+
+    if not isinstance(grid, Variable):
+        raise ValueError("The grid should be a Variable")
+
+    out = helper.create_variable_for_type_inference(x.dtype)
+    ipts = {'X': x, 'Grid': grid}
+
+    helper.append_op(type='grid_sampler', inputs=ipts, outputs={'Output': out})
+
+    return out
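Reviewer note: the Step 1 rescaling maps the normalized grid onto pixel coordinates; a one-line NumPy check (illustrative):

    import numpy as np
    g = np.array([-1.0, 0.0, 1.0])     # normalized grid_x values
    W = 32
    print(0.5 * (g + 1) * (W - 1))     # -> [ 0.  15.5 31. ]: borders map to 0 and W-1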
def log_loss(input, label, epsilon=1e-4, name=None):
    """
    **Negative Log Loss Layer**
...
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
...
@@ -40,7 +40,8 @@ class TestParallelExecutorBase(unittest.TestCase):
                               use_reduce=False,
                               fuse_elewise_add_act_ops=False,
                               optimizer=fluid.optimizer.Adam,
-                              use_fast_executor=False):
+                              use_fast_executor=False,
+                              enable_sequential_execution=False):
         def run_executor(exe, feed, fetch_list, program=None):
             if isinstance(exe, fluid.ParallelExecutor):
                 res = exe.run(fetch_list=fetch_list, feed=feed)
...
@@ -80,6 +81,7 @@ class TestParallelExecutorBase(unittest.TestCase):
             build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
                 if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
             build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
+            build_strategy.enable_sequential_execution = enable_sequential_execution

             if use_parallel_executor:
                 exe = fluid.ParallelExecutor(
...
python/paddle/fluid/tests/unittests/test_dist_save_load.py
...
@@ -72,6 +72,7 @@ class TestDistSaveLoadDense2x2(TestDistBase):
         self.assertAlmostEqual(local_np.all(), train1_np.all(), delta=delta)
         self.assertAlmostEqual(train0_np.all(), train1_np.all(), delta=delta)

+    @unittest.skip(reason="CI fail")
     def test_dist(self):
         need_envs = {
             "IS_DISTRIBUTED": '0',
...
python/paddle/fluid/tests/unittests/test_grid_sampler_op.py
0 → 100644
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


def AffineGrid(theta, size):
    n = size[0]
    h = size[2]
    w = size[3]
    h_idx = np.repeat(
        np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis]
    w_idx = np.repeat(
        np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis]
    grid = np.concatenate([w_idx, h_idx, np.ones([h, w, 1])], axis=2)  # h * w * 3
    grid = np.repeat(grid[np.newaxis, :], size[0], axis=0)  # n * h * w * 3

    ret = np.zeros([n, h * w, 2])
    theta = theta.transpose([0, 2, 1])
    for i in range(len(theta)):
        ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i])

    return ret.reshape([n, h, w, 2]).astype("float32")


def getGridPointValue(data, x, y):
    data_shape = data.shape
    N = data_shape[0]
    H = data_shape[2]
    W = data_shape[3]

    out = np.zeros(data_shape, dtype='float')
    for i in range(N):
        for j in range(H):
            for k in range(W):
                if y[i, j, k] < 0 or y[i, j, k] > H - 1 or x[i, j, k] < 0 or x[
                        i, j, k] > W - 1:
                    out[i, :, j, k] = 0
                else:
                    out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]]

    return out


def GridSampler(data, grid):
    dims = data.shape
    N = dims[0]
    C = dims[1]
    H = dims[2]
    W = dims[3]

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]
    y_max = H - 1
    x_max = W - 1

    x = 0.5 * ((x.astype('float32') + 1.0) * x_max)
    y = 0.5 * ((y.astype('float32') + 1.0) * y_max)

    x0 = np.floor(x).astype('int32')
    x1 = x0 + 1
    y0 = np.floor(y).astype('int32')
    y1 = y0 + 1

    wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
    wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))
    wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
    wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))

    va = getGridPointValue(data, x0, y0)
    vb = getGridPointValue(data, x0, y1)
    vc = getGridPointValue(data, x1, y0)
    vd = getGridPointValue(data, x1, y1)

    out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float32')
    return out


class TestGridSamplerOp(OpTest):
    def setUp(self):
        self.initTestCase()
        self.op_type = 'grid_sampler'
        x = np.random.randint(0, 255, self.x_shape).astype('float32')

        theta = np.zeros(self.theta_shape).astype('float32')
        for i in range(self.theta_shape[0]):
            for j in range(2):
                for k in range(3):
                    theta[i, j, k] = np.random.rand(1)[0]
        grid = AffineGrid(theta, self.x_shape)

        self.inputs = {'X': x, 'Grid': grid}
        self.attrs = {'use_cudnn': True}
        self.outputs = {'Output': GridSampler(x, grid)}

    def test_check_output(self):
        self.check_output(atol=1e-3)

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61)

    def initTestCase(self):
        self.x_shape = (2, 5, 7, 3)
        self.grid_shape = (2, 7, 3, 2)
        self.theta_shape = (2, 2, 3)


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/test_layers.py
...
@@ -865,6 +865,15 @@ class TestBook(unittest.TestCase):
             self.assertIsNotNone(out)
             print(str(program))

+    def test_grid_sampler(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(name='x', shape=[3, 5, 7], dtype='float32')
+            grid = layers.data(name='grid', shape=[5, 7, 2], dtype='float32')
+            out = layers.grid_sampler(x, grid)
+            self.assertIsNotNone(out)
+            print(str(program))
+
     def test_affine_grid(self):
         program = Program()
         with program_guard(program):
...
python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
...
@@ -232,6 +232,46 @@ class TestResnet(TestParallelExecutorBase):
         for loss in zip(all_reduce_last_loss, reduce_last_loss):
             self.assertAlmostEquals(loss[0], loss[1], delta=delta2)

+        if not use_cuda:
+            return
+
+        all_reduce_first_loss_seq, all_reduce_last_loss_seq = self.check_network_convergence(
+            model,
+            feed_dict={"image": img, "label": label},
+            iter=iter,
+            batch_size=batch_size,
+            use_cuda=use_cuda,
+            use_reduce=False,
+            optimizer=optimizer,
+            enable_sequential_execution=True)
+
+        reduce_first_loss_seq, reduce_last_loss_seq = self.check_network_convergence(
+            model,
+            feed_dict={"image": img, "label": label},
+            iter=iter,
+            batch_size=batch_size,
+            use_cuda=use_cuda,
+            use_reduce=True,
+            optimizer=optimizer,
+            enable_sequential_execution=True)
+
+        for loss in zip(all_reduce_first_loss, all_reduce_first_loss_seq):
+            self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
+
+        for loss in zip(all_reduce_last_loss, all_reduce_last_loss_seq):
+            self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
+
+        for loss in zip(reduce_first_loss, reduce_first_loss_seq):
+            self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
+
+        for loss in zip(reduce_last_loss, reduce_last_loss_seq):
+            self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
+
+        for loss in zip(all_reduce_first_loss_seq, reduce_first_loss_seq):
+            self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
+
+        for loss in zip(all_reduce_last_loss_seq, reduce_last_loss_seq):
+            self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
+
     def _check_resnet_convergence(self,
                                   model,
                                   use_cuda=True,
...
python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py
...
@@ -173,6 +173,8 @@ class TestTransformer(TestParallelExecutorBase):
     def test_main(self):
         if core.is_compiled_with_cuda():
             self.check_network_convergence(transformer, use_cuda=True)
+            self.check_network_convergence(
+                transformer, use_cuda=True, enable_sequential_execution=True)
         self.check_network_convergence(transformer, use_cuda=False, iter=5)
...
python/paddle/fluid/tests/unittests/test_pool2d_op.py
...
@@ -26,7 +26,8 @@ def max_pool2D_forward_naive(x,
                              strides,
                              paddings,
                              global_pool=0,
-                             ceil_mode=False):
+                             ceil_mode=False,
+                             exclusive=True):
     N, C, H, W = x.shape
     if global_pool == 1:
         ksize = [H, W]
...
@@ -54,7 +55,8 @@ def avg_pool2D_forward_naive(x,
                              strides,
                              paddings,
                              global_pool=0,
-                             ceil_mode=False):
+                             ceil_mode=False,
+                             exclusive=True):
     N, C, H, W = x.shape
     if global_pool == 1:
         ksize = [H, W]
...
@@ -73,8 +75,9 @@ def avg_pool2D_forward_naive(x,
             c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
             x_masked = x[:, :, r_start:r_end, c_start:c_end]

-            out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / (
-                (r_end - r_start) * (c_end - c_start))
+            field_size = ((r_end - r_start) * (c_end - c_start)) if exclusive \
+                else (ksize[0] * ksize[1])
+            out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size
     return out
...
@@ -89,12 +92,13 @@ class TestPool2d_Op(OpTest):
         self.init_kernel_type()
         self.init_pool_type()
         self.init_ceil_mode()
+        self.init_exclusive()
         if self.global_pool:
             self.paddings = [0 for _ in range(len(self.paddings))]
         input = np.random.random(self.shape).astype(self.dtype)
-        output = self.pool2D_forward_naive(input, self.ksize, self.strides,
-                                           self.paddings, self.global_pool,
-                                           self.ceil_mode).astype(self.dtype)
+        output = self.pool2D_forward_naive(
+            input, self.ksize, self.strides, self.paddings, self.global_pool,
+            self.ceil_mode, self.exclusive).astype(self.dtype)
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}

         self.attrs = {
...
@@ -106,7 +110,9 @@ class TestPool2d_Op(OpTest):
             'use_cudnn': self.use_cudnn,
             'use_mkldnn': self.use_mkldnn,
             'ceil_mode': self.ceil_mode,
-            'data_format': 'AnyLayout'  # TODO(dzhwinter) : should be fix latter
+            'data_format': 'AnyLayout',  # TODO(dzhwinter) : should be fix latter
+            'exclusive': self.exclusive
         }

         self.outputs = {'Out': output}
...
@@ -150,6 +156,9 @@ class TestPool2d_Op(OpTest):
     def init_ceil_mode(self):
         self.ceil_mode = False

+    def init_exclusive(self):
+        self.exclusive = True
+

 class TestCase1(TestPool2d_Op):
     def init_test_case(self):
...
@@ -322,5 +331,15 @@ class TestCeilModeCase4(TestCase2):
         self.ceil_mode = True


+class TestAvgInclude(TestCase2):
+    def init_exclusive(self):
+        self.exclusive = False
+
+
+class TestCUDNNAvgInclude(TestCUDNNCase3):
+    def init_exclusive(self):
+        self.exclusive = False
+
+
 if __name__ == '__main__':
     unittest.main()
python/paddle/fluid/tests/unittests/test_pool3d_op.py
浏览文件 @
d4e8d707
...
...
@@ -26,7 +26,8 @@ def max_pool3D_forward_naive(x,
strides
,
paddings
,
global_pool
=
0
,
ceil_mode
=
False
):
ceil_mode
=
False
,
exclusive
=
True
):
N
,
C
,
D
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
ksize
=
[
D
,
H
,
W
]
...
...
@@ -60,7 +61,8 @@ def avg_pool3D_forward_naive(x,
strides
,
paddings
,
global_pool
=
0
,
ceil_mode
=
False
):
ceil_mode
=
False
,
exclusive
=
True
):
N
,
C
,
D
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
ksize
=
[
D
,
H
,
W
]
...
...
@@ -85,8 +87,10 @@ def avg_pool3D_forward_naive(x,
w_end
=
np
.
min
((
j
*
strides
[
1
]
+
ksize
[
1
]
-
paddings
[
1
],
W
))
x_masked
=
x
[:,
:,
d_start
:
d_end
,
h_start
:
h_end
,
w_start
:
w_end
]
out
[:,
:,
k
,
i
,
j
]
=
np
.
sum
(
x_masked
,
axis
=
(
2
,
3
,
4
))
/
(
(
d_end
-
d_start
)
*
(
h_end
-
h_start
)
*
(
w_end
-
w_start
))
field_size
=
(
d_end
-
d_start
)
*
(
h_end
-
h_start
)
*
(
w_end
-
w_start
)
\
if
exclusive
else
ksize
[
0
]
*
ksize
[
1
]
*
ksize
[
2
]
out
[:,
:,
k
,
i
,
j
]
=
np
.
sum
(
x_masked
,
axis
=
(
2
,
3
,
4
))
/
field_size
return
out
...
...
@@ -100,13 +104,14 @@ class TestPool3d_Op(OpTest):
         self.init_kernel_type()
         self.init_pool_type()
         self.init_ceil_mode()
+        self.init_exclusive()
         if self.global_pool:
             self.paddings = [0 for _ in range(len(self.paddings))]
         input = np.random.random(self.shape).astype(self.dtype)
-        output = self.pool3D_forward_naive(input, self.ksize, self.strides,
-                                           self.paddings, self.global_pool,
-                                           self.ceil_mode).astype(self.dtype)
+        output = self.pool3D_forward_naive(
+            input, self.ksize, self.strides, self.paddings, self.global_pool,
+            self.ceil_mode, self.exclusive).astype(self.dtype)
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
 
         self.attrs = {
...
...
@@ -117,7 +122,9 @@ class TestPool3d_Op(OpTest):
             'global_pooling': self.global_pool,
             'use_cudnn': self.use_cudnn,
             'ceil_mode': self.ceil_mode,
-            'data_format': 'AnyLayout'  # TODO(dzhwinter) : should be fix latter
+            'data_format': 'AnyLayout',  # TODO(dzhwinter) : should be fix latter
+            'exclusive': self.exclusive
         }
 
         self.outputs = {'Out': output}
...
...
@@ -161,6 +168,9 @@ class TestPool3d_Op(OpTest):
     def init_ceil_mode(self):
         self.ceil_mode = False
 
+    def init_exclusive(self):
+        self.exclusive = True
+
 
 class TestCase1(TestPool3d_Op):
     def init_test_case(self):
...
...
@@ -333,5 +343,15 @@ class TestCeilModeCase4(TestCase2):
         self.ceil_mode = True
 
 
+class TestAvgInclude(TestCase2):
+    def init_exclusive(self):
+        self.exclusive = False
+
+
+class TestCUDNNAvgInclude(TestCUDNNCase3):
+    def init_exclusive(self):
+        self.exclusive = False
+
+
 if __name__ == '__main__':
     unittest.main()
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
...
...
@@ -26,7 +26,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
     Test softmax with cross entropy operator with discreate one-hot labels.
     """
 
+    def initParams(self):
+        self.numeric_stable_mode = False
+
     def setUp(self):
+        self.initParams()
         self.op_type = "softmax_with_cross_entropy"
         batch_size = 41
         class_num = 37
...
...
@@ -46,6 +50,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
             "Softmax": softmax.astype("float64"),
             "Loss": cross_entropy.astype("float64")
         }
+        self.attrs = {"numeric_stable_mode": self.numeric_stable_mode}
 
     def test_check_output(self):
         self.check_output()
...
...
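numeric_stable_mode selects a formulation that subtracts the row maximum before exponentiating, so large logits cannot overflow exp. A NumPy reference sketch of that trick (an illustration of the math, not the operator's CUDA kernel):

    import numpy as np

    def stable_softmax_xent(logits, labels):
        # logits: [N, C] float array; labels: [N] integer class ids.
        shifted = logits - logits.max(axis=1, keepdims=True)  # exp stays finite
        log_z = np.log(np.exp(shifted).sum(axis=1))           # log partition
        return log_z - shifted[np.arange(len(labels)), labels]

    logits = np.random.randn(41, 37) * 50.0  # large magnitudes, no overflow
    labels = np.random.randint(0, 37, size=41)
    print(stable_softmax_xent(logits, labels).mean())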
@@ -54,6 +59,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
         self.check_grad(["Logits"], "Loss")
 
 
+class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp):
+    def initParams(self):
+        self.numeric_stable_mode = True
+
+
 class TestSoftmaxWithCrossEntropyOp2(OpTest):
     """
     Test softmax with cross entropy operator with soft labels.
...
...
@@ -93,7 +103,11 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
     Test softmax with cross entropy operator with ignore_index.
     """
 
+    def initParams(self):
+        self.numeric_stable_mode = False
+
     def setUp(self):
+        self.initParams()
         self.op_type = "softmax_with_cross_entropy"
         batch_size = 41
         class_num = 37
...
...
@@ -114,7 +128,10 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
             "Softmax": softmax.astype("float64"),
             "Loss": cross_entropy.astype("float64")
         }
-        self.attrs = {"ignore_index": ignore_index}
+        self.attrs = {
+            "ignore_index": ignore_index,
+            "numeric_stable_mode": self.numeric_stable_mode
+        }
 
     def test_check_output(self):
         self.check_output()
...
...
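With ignore_index in play, samples whose label equals it are expected to contribute zero loss; the new attribute only adds the stable-mode switch on top of that. A small sketch of the masking convention (illustrative arrays, assuming the usual ignore_index semantics):

    import numpy as np

    ignore_index = 7
    labels = np.array([2, 7, 5])
    per_sample_loss = np.array([0.9, 1.3, 0.4])
    masked = np.where(labels == ignore_index, 0.0, per_sample_loss)
    print(masked)  # [0.9 0.  0.4]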
@@ -123,5 +140,10 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
         self.check_grad(["Logits"], "Loss")
 
 
+class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
+    def initParams(self):
+        self.numeric_stable_mode = True
+
+
 if __name__ == "__main__":
     unittest.main()
python/setup.py.in
...
...
@@ -27,7 +27,7 @@ def _get_version_detail(idx):
     if re.match('@TAG_VERSION_REGEX@', '@PADDLE_VERSION@'):
         version_details = '@PADDLE_VERSION@'.split('.')
 
-        if len(version_details) == 3:
+        if len(version_details) >= 3:
             return version_details[idx]
 
     return 0
...
...
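The one-character change matters because a version string with extra fields splits into more than three parts, so the old equality test silently fell through to return 0. A quick illustration ('1.1.0.post97' is a hypothetical tag, not taken from the repo):

    for v in ['1.1.0', '1.1.0.post97']:
        fields = v.split('.')
        print(v, len(fields) == 3, len(fields) >= 3)
    # 1.1.0         True  True
    # 1.1.0.post97  False True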