Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
fe4cd502
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fe4cd502
编写于
11月 06, 2018
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into optimize-thread-pool
test=develop
上级
ac415c00
d6a6a130
变更
120
隐藏空白更改
内联
并排
Showing
120 changed file
with
4363 addition
and
470 deletion
+4363
-470
CMakeLists.txt
CMakeLists.txt
+0
-1
cmake/configure.cmake
cmake/configure.cmake
+1
-5
cmake/simd.cmake
cmake/simd.cmake
+3
-1
paddle/fluid/API.spec
paddle/fluid/API.spec
+5
-3
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+4
-2
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+11
-0
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+2
-0
paddle/fluid/framework/details/rpc_op_handle.cc
paddle/fluid/framework/details/rpc_op_handle.cc
+5
-8
paddle/fluid/framework/details/sequential_execution_pass.cc
paddle/fluid/framework/details/sequential_execution_pass.cc
+109
-0
paddle/fluid/framework/details/sequential_execution_pass.h
paddle/fluid/framework/details/sequential_execution_pass.h
+34
-0
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+3
-1
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h
+2
-1
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc
...e/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc
+3
-0
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.cc
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.cc
+58
-0
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h
+34
-0
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass_tester.cc
...e/fluid/framework/ir/depthwise_conv_mkldnn_pass_tester.cc
+123
-0
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
+3
-0
paddle/fluid/framework/ir/graph.cc
paddle/fluid/framework/ir/graph.cc
+59
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+17
-4
paddle/fluid/framework/lod_tensor.cc
paddle/fluid/framework/lod_tensor.cc
+1
-1
paddle/fluid/framework/tensor_util.cc
paddle/fluid/framework/tensor_util.cc
+6
-0
paddle/fluid/inference/CMakeLists.txt
paddle/fluid/inference/CMakeLists.txt
+3
-0
paddle/fluid/inference/analysis/CMakeLists.txt
paddle/fluid/inference/analysis/CMakeLists.txt
+11
-16
paddle/fluid/inference/analysis/analyzer.h
paddle/fluid/inference/analysis/analyzer.h
+1
-0
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
+3
-0
paddle/fluid/inference/api/CMakeLists.txt
paddle/fluid/inference/api/CMakeLists.txt
+10
-32
paddle/fluid/inference/api/api_impl_tester.cc
paddle/fluid/inference/api/api_impl_tester.cc
+8
-6
paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc
...luid/inference/api/api_tensorrt_subgraph_engine_tester.cc
+2
-2
paddle/fluid/inference/api/demo_ci/run.sh
paddle/fluid/inference/api/demo_ci/run.sh
+1
-1
paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
+2
-6
paddle/fluid/inference/test.cmake
paddle/fluid/inference/test.cmake
+31
-0
paddle/fluid/inference/tests/api/CMakeLists.txt
paddle/fluid/inference/tests/api/CMakeLists.txt
+0
-14
paddle/fluid/operators/affine_grid_cudnn_op.cu.cc
paddle/fluid/operators/affine_grid_cudnn_op.cu.cc
+112
-0
paddle/fluid/operators/affine_grid_op.cc
paddle/fluid/operators/affine_grid_op.cc
+231
-0
paddle/fluid/operators/affine_grid_op.h
paddle/fluid/operators/affine_grid_op.h
+174
-0
paddle/fluid/operators/checkpoint_notify_op.cc
paddle/fluid/operators/checkpoint_notify_op.cc
+3
-1
paddle/fluid/operators/conv_mkldnn_op.cc
paddle/fluid/operators/conv_mkldnn_op.cc
+52
-12
paddle/fluid/operators/delete_var_op.cc
paddle/fluid/operators/delete_var_op.cc
+7
-1
paddle/fluid/operators/distributed/grpc_client.cc
paddle/fluid/operators/distributed/grpc_client.cc
+6
-2
paddle/fluid/operators/distributed/grpc_serde.cc
paddle/fluid/operators/distributed/grpc_serde.cc
+5
-3
paddle/fluid/operators/distributed/grpc_serde.h
paddle/fluid/operators/distributed/grpc_serde.h
+3
-2
paddle/fluid/operators/distributed/grpc_server.cc
paddle/fluid/operators/distributed/grpc_server.cc
+9
-4
paddle/fluid/operators/distributed/grpc_variable_response.cc
paddle/fluid/operators/distributed/grpc_variable_response.cc
+8
-0
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+1
-0
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+17
-0
paddle/fluid/operators/distributed/request_handler_impl.h
paddle/fluid/operators/distributed/request_handler_impl.h
+18
-2
paddle/fluid/operators/distributed/rpc_client.cc
paddle/fluid/operators/distributed/rpc_client.cc
+1
-0
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+6
-3
paddle/fluid/operators/distributed/rpc_server_test.cc
paddle/fluid/operators/distributed/rpc_server_test.cc
+2
-2
paddle/fluid/operators/distributed/send_recv.proto.in
paddle/fluid/operators/distributed/send_recv.proto.in
+1
-0
paddle/fluid/operators/distributed/variable_response.h
paddle/fluid/operators/distributed/variable_response.h
+2
-0
paddle/fluid/operators/fetch_barrier_op.cc
paddle/fluid/operators/fetch_barrier_op.cc
+3
-1
paddle/fluid/operators/gen_nccl_id_op.cc
paddle/fluid/operators/gen_nccl_id_op.cc
+1
-1
paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
+132
-0
paddle/fluid/operators/grid_sampler_op.cc
paddle/fluid/operators/grid_sampler_op.cc
+203
-0
paddle/fluid/operators/grid_sampler_op.h
paddle/fluid/operators/grid_sampler_op.h
+322
-0
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+30
-15
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+12
-0
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+2
-2
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+73
-0
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+60
-0
paddle/fluid/operators/math/jit_gen.cc
paddle/fluid/operators/math/jit_gen.cc
+90
-0
paddle/fluid/operators/math/jit_gen.h
paddle/fluid/operators/math/jit_gen.h
+80
-0
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+2
-1
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+68
-54
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+1
-1
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+3
-3
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+93
-32
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+20
-20
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+7
-7
paddle/fluid/operators/math/pooling.cc
paddle/fluid/operators/math/pooling.cc
+14
-8
paddle/fluid/operators/math/pooling.cu
paddle/fluid/operators/math/pooling.cu
+30
-25
paddle/fluid/operators/math/pooling.h
paddle/fluid/operators/math/pooling.h
+4
-4
paddle/fluid/operators/pool_cudnn_op.cu.cc
paddle/fluid/operators/pool_cudnn_op.cu.cc
+6
-2
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+29
-0
paddle/fluid/operators/pool_op.h
paddle/fluid/operators/pool_op.h
+8
-6
paddle/fluid/operators/prefetch_op.cc
paddle/fluid/operators/prefetch_op.cc
+3
-1
paddle/fluid/operators/read_op.cc
paddle/fluid/operators/read_op.cc
+13
-0
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+3
-1
paddle/fluid/operators/ref_by_trainer_id_op.cc
paddle/fluid/operators/ref_by_trainer_id_op.cc
+79
-0
paddle/fluid/operators/ref_by_trainer_id_op.cu.cc
paddle/fluid/operators/ref_by_trainer_id_op.cu.cc
+26
-0
paddle/fluid/operators/ref_by_trainer_id_op.h
paddle/fluid/operators/ref_by_trainer_id_op.h
+49
-0
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+3
-1
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+3
-1
paddle/fluid/operators/sign_op.cc
paddle/fluid/operators/sign_op.cc
+2
-1
paddle/fluid/operators/sign_op.cu
paddle/fluid/operators/sign_op.cu
+5
-1
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
+6
-0
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+166
-21
paddle/fluid/operators/spp_op.h
paddle/fluid/operators/spp_op.h
+5
-3
paddle/fluid/operators/test_send_nccl_id.cc
paddle/fluid/operators/test_send_nccl_id.cc
+1
-1
paddle/fluid/platform/cudnn_helper.h
paddle/fluid/platform/cudnn_helper.h
+30
-3
paddle/fluid/platform/dynload/cudnn.h
paddle/fluid/platform/dynload/cudnn.h
+45
-38
paddle/fluid/platform/init.cc
paddle/fluid/platform/init.cc
+37
-9
paddle/fluid/platform/mkldnn_helper.h
paddle/fluid/platform/mkldnn_helper.h
+23
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+7
-0
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+0
-2
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+6
-2
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+18
-2
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+1
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+339
-35
python/paddle/fluid/tests/unittests/dist_save_load.py
python/paddle/fluid/tests/unittests/dist_save_load.py
+174
-0
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
...ddle/fluid/tests/unittests/parallel_executor_test_base.py
+3
-1
python/paddle/fluid/tests/unittests/test_affine_grid_op.py
python/paddle/fluid/tests/unittests/test_affine_grid_op.py
+79
-0
python/paddle/fluid/tests/unittests/test_dist_base.py
python/paddle/fluid/tests/unittests/test_dist_base.py
+11
-5
python/paddle/fluid/tests/unittests/test_dist_mnist.py
python/paddle/fluid/tests/unittests/test_dist_mnist.py
+9
-0
python/paddle/fluid/tests/unittests/test_dist_save_load.py
python/paddle/fluid/tests/unittests/test_dist_save_load.py
+90
-0
python/paddle/fluid/tests/unittests/test_grid_sampler_op.py
python/paddle/fluid/tests/unittests/test_grid_sampler_op.py
+123
-0
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+25
-0
python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
...fluid/tests/unittests/test_parallel_executor_seresnext.py
+40
-0
python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py
...uid/tests/unittests/test_parallel_executor_transformer.py
+2
-0
python/paddle/fluid/tests/unittests/test_pool2d_op.py
python/paddle/fluid/tests/unittests/test_pool2d_op.py
+27
-8
python/paddle/fluid/tests/unittests/test_pool3d_op.py
python/paddle/fluid/tests/unittests/test_pool3d_op.py
+28
-8
python/paddle/fluid/tests/unittests/test_py_reader_lod_level_share.py
...e/fluid/tests/unittests/test_py_reader_lod_level_share.py
+43
-0
python/paddle/fluid/tests/unittests/test_py_reader_pin_memory.py
...paddle/fluid/tests/unittests/test_py_reader_pin_memory.py
+130
-0
python/paddle/fluid/tests/unittests/test_ref_by_trainer_id_op.py
...paddle/fluid/tests/unittests/test_ref_by_trainer_id_op.py
+36
-0
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
...uid/tests/unittests/test_softmax_with_cross_entropy_op.py
+23
-1
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+113
-6
python/paddle/fluid/transpiler/inference_transpiler.py
python/paddle/fluid/transpiler/inference_transpiler.py
+28
-0
python/setup.py.in
python/setup.py.in
+1
-1
未找到文件。
CMakeLists.txt
浏览文件 @
fe4cd502
...
...
@@ -62,7 +62,6 @@ option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option
(
USE_EIGEN_FOR_BLAS
"Use matrix multiplication in Eigen"
OFF
)
option
(
EIGEN_USE_THREADS
"Compile with multi-threaded Eigen"
OFF
)
option
(
WITH_ARM_FP16
"Use half precision support on armv8.2-a cpu"
OFF
)
option
(
WITH_FAST_BUNDLE_TEST
"Bundle tests that can be run in a single process together to reduce launch overhead"
OFF
)
option
(
WITH_CONTRIB
"Compile the third-party contributation"
OFF
)
option
(
REPLACE_ENFORCE_GLOG
"Replace PADDLE_ENFORCE with glog/CHECK for better debug."
OFF
)
option
(
WITH_ANAKIN
"Compile with Anakin library"
OFF
)
...
...
cmake/configure.cmake
浏览文件 @
fe4cd502
...
...
@@ -50,11 +50,7 @@ if(NOT WITH_PROFILER)
endif
(
NOT WITH_PROFILER
)
if
(
NOT CMAKE_CROSSCOMPILING
)
if
(
WITH_AVX AND AVX512F_FOUND
)
set
(
SIMD_FLAG
${
AVX512F_FLAG
}
)
elseif
(
WITH_AVX AND AVX2_FOUND
)
set
(
SIMD_FLAG
${
AVX2_FLAG
}
)
elseif
(
WITH_AVX AND AVX_FOUND
)
if
(
WITH_AVX AND AVX_FOUND
)
set
(
SIMD_FLAG
${
AVX_FLAG
}
)
elseif
(
SSE3_FOUND
)
set
(
SIMD_FLAG
${
SSE3_FLAG
}
)
...
...
cmake/simd.cmake
浏览文件 @
fe4cd502
...
...
@@ -89,7 +89,9 @@ CHECK_CXX_SOURCE_RUNS("
#include <immintrin.h>
int main()
{
__m512i a = _mm512_undefined_epi32();
__m512i a = _mm512_set_epi32 (-1, 2, -3, 4, -1, 2, -3, 4,
13, -5, 6, -7, 9, 2, -6, 3);
__m512i result = _mm512_abs_epi32 (a);
return 0;
}"
AVX512F_FOUND
)
...
...
paddle/fluid/API.spec
浏览文件 @
fe4cd502
...
...
@@ -67,8 +67,8 @@ paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size',
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type', 'is_test'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'
], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, Non
e))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'
], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, Non
e))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'
, 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, Tru
e))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'
, 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, Tru
e))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
...
...
@@ -103,7 +103,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index'
], varargs=None, keywords=None, defaults=(False, -100
))
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index'
, 'numeric_stable_mode'], varargs=None, keywords=None, defaults=(False, -100, False
))
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
...
...
@@ -174,9 +174,11 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
fe4cd502
...
...
@@ -35,13 +35,15 @@ if(WITH_GPU)
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass
)
endif
()
cc_library
(
sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass
)
cc_library
(
multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle
)
if
(
WITH_GPU
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass
sequential_execution_pass
)
else
()
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
sequential_execution_pass
)
endif
()
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
fe4cd502
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
...
...
@@ -27,6 +28,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public:
explicit
ParallelExecutorPassBuilder
(
const
BuildStrategy
&
strategy
)
:
ir
::
PassBuilder
(),
strategy_
(
strategy
)
{
if
(
strategy_
.
enable_sequential_execution_
)
{
AppendPass
(
"sequential_execution_pass"
);
}
// Add a graph viz pass to record a graph.
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
AppendPass
(
"graph_viz_pass"
);
...
...
@@ -110,6 +115,11 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass
->
Erase
(
"nccl_ctxs"
);
pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
#endif
}
else
if
(
pass
->
Type
()
==
"sequential_execution_pass"
)
{
pass
->
Erase
(
kAllOpDescs
);
pass
->
Set
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
,
new
std
::
vector
<
OpDesc
*>
(
main_program
.
Block
(
0
).
AllOps
()));
}
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
}
...
...
@@ -125,3 +135,4 @@ USE_PASS(multi_batch_merge_pass);
USE_PASS
(
multi_devices_pass
);
USE_PASS
(
multi_devices_check_pass
);
USE_PASS
(
multi_devices_print_pass
);
USE_PASS
(
sequential_execution_pass
);
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
fe4cd502
...
...
@@ -69,6 +69,8 @@ struct BuildStrategy {
bool
enable_data_balance_
{
false
};
bool
enable_sequential_execution_
{
false
};
bool
fuse_broadcast_op_
{
false
};
// User normally doesn't need to call this API.
...
...
paddle/fluid/framework/details/rpc_op_handle.cc
浏览文件 @
fe4cd502
...
...
@@ -29,22 +29,19 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
place_
(
place
)
{}
void
RPCOpHandle
::
RunImpl
()
{
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done
for
(
auto
*
in
:
inputs_
)
{
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
// FIXME(Yancey1989): need a better solution instead of use DebugString()
if
(
ir
::
IsControlDepVar
(
*
in
->
Node
()))
{
// HACK
if
(
ir
::
IsControlDepVar
(
*
in
->
Node
()))
{
continue
;
}
if
(
in
->
GeneratedOp
())
{
in
->
GeneratedOp
()
->
RecordWaitEventOnCtx
(
dev_ctxes_
.
at
(
p
));
}
}
auto
&
tmp_scope
=
local_scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
// lock.
op_
->
Run
(
*
tmp_scope
,
place_
);
this
->
RunAndRecordEvent
([
this
]
{
op_
->
Run
(
*
local_scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
(),
place_
);
}
);
}
std
::
string
RPCOpHandle
::
Name
()
const
{
return
name_
;
}
...
...
paddle/fluid/framework/details/sequential_execution_pass.cc
0 → 100644
浏览文件 @
fe4cd502
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
static
bool
IsSameOpDesc
(
OpDesc
*
op1
,
OpDesc
*
op2
)
{
return
op1
->
Type
()
==
op2
->
Type
()
&&
op1
->
Inputs
()
==
op2
->
Inputs
()
&&
op1
->
Outputs
()
==
op2
->
Outputs
();
}
std
::
unique_ptr
<
ir
::
Graph
>
SequentialExecutionPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
// FIXME(zjl): Insert dependencies between some distributed ops may cause
// the multi_devices_graph_pass fails. So we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
// casually, which may cause deadlock easily.
// We should add more skipped distributed ops when found errors in
// multi_devices_graph_pass
static
std
::
unordered_set
<
std
::
string
>
skip_dist_ops
{
"send"
,
"recv"
,
"send_barrier"
,
"fetch_barrier"
};
auto
&
ops
=
Get
<
const
std
::
vector
<
OpDesc
*>>
(
kAllOpDescs
);
std
::
vector
<
ir
::
Node
*>
op_node_list
;
op_node_list
.
reserve
(
ops
.
size
());
std
::
unordered_map
<
ir
::
Node
*
,
size_t
>
op_deps
;
std
::
unordered_map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
pending_ops
;
std
::
unordered_set
<
ir
::
Node
*>
ready_ops
;
for
(
ir
::
Node
*
node
:
graph
->
Nodes
())
{
if
(
!
node
->
IsOp
())
continue
;
std
::
unordered_set
<
ir
::
Node
*>
preceding_ops
;
for
(
auto
*
in
:
node
->
inputs
)
{
PADDLE_ENFORCE
(
in
->
IsVar
(),
"Preceding Node of Op Nodes must be Var Node"
);
if
(
in
->
inputs
.
empty
())
continue
;
PADDLE_ENFORCE
(
in
->
inputs
.
size
()
==
1
&&
in
->
inputs
[
0
]
->
IsOp
(),
"Preceding Op Node of Var Node must be unique"
);
preceding_ops
.
insert
(
in
->
inputs
[
0
]);
pending_ops
[
in
->
inputs
[
0
]].
insert
(
node
);
}
op_deps
[
node
]
=
preceding_ops
.
size
();
if
(
preceding_ops
.
empty
())
{
ready_ops
.
insert
(
node
);
}
}
for
(
auto
*
op_desc
:
ops
)
{
ir
::
Node
*
found_node
=
nullptr
;
for
(
auto
*
node
:
ready_ops
)
{
if
(
IsSameOpDesc
(
op_desc
,
node
->
Op
()))
{
PADDLE_ENFORCE
(
found_node
==
nullptr
,
"Found multiple op_desc in graph: %s"
,
op_desc
->
Type
());
found_node
=
node
;
}
}
PADDLE_ENFORCE_NOT_NULL
(
found_node
,
"Cannot find op_desc in graph: %s"
,
op_desc
->
Type
());
for
(
auto
*
pending_op
:
pending_ops
[
found_node
])
{
if
(
--
op_deps
.
at
(
pending_op
)
==
0
)
{
ready_ops
.
insert
(
pending_op
);
}
}
ready_ops
.
erase
(
found_node
);
if
(
skip_dist_ops
.
count
(
op_desc
->
Type
())
==
0
)
{
op_node_list
.
push_back
(
found_node
);
}
}
for
(
size_t
i
=
1
;
i
<
op_node_list
.
size
();
++
i
)
{
auto
*
dep_var
=
graph
->
CreateControlDepVar
();
op_node_list
[
i
]
->
inputs
.
push_back
(
dep_var
);
op_node_list
[
i
-
1
]
->
outputs
.
push_back
(
dep_var
);
dep_var
->
outputs
.
push_back
(
op_node_list
[
i
]);
dep_var
->
inputs
.
push_back
(
op_node_list
[
i
-
1
]);
VLOG
(
10
)
<<
"Add dependencies between "
<<
op_node_list
[
i
-
1
]
->
Name
()
<<
" and "
<<
op_node_list
[
i
]
->
Name
();
}
return
graph
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
sequential_execution_pass
,
paddle
::
framework
::
details
::
SequentialExecutionPass
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kAllOpDescs
);
paddle/fluid/framework/details/sequential_execution_pass.h
0 → 100644
浏览文件 @
fe4cd502
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
constexpr
char
kAllOpDescs
[]
=
"all_op_descs"
;
class
SequentialExecutionPass
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/executor.cc
浏览文件 @
fe4cd502
...
...
@@ -85,8 +85,10 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
void
Executor
::
Close
()
{
#ifdef PADDLE_WITH_DISTRIBUTE
// TODO(typhoonzero): complete message will need to use real trainer_id,
// except 0.
::
paddle
::
operators
::
distributed
::
RPCClient
::
GetInstance
<
::
paddle
::
operators
::
distributed
::
GRPCClient
>
()
::
paddle
::
operators
::
distributed
::
GRPCClient
>
(
0
)
->
SendComplete
();
#endif
}
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
fe4cd502
...
...
@@ -41,6 +41,7 @@ pass_library(conv_bn_fuse_pass inference)
pass_library
(
seqconv_eltadd_relu_fuse_pass inference
)
if
(
WITH_MKLDNN
)
pass_library
(
mkldnn_placement_pass base
)
pass_library
(
depthwise_conv_mkldnn_pass base
)
pass_library
(
conv_bias_mkldnn_fuse_pass inference
)
pass_library
(
conv_relu_mkldnn_fuse_pass inference
)
pass_library
(
conv_elementwise_add_mkldnn_fuse_pass inference
)
...
...
@@ -59,6 +60,7 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
cc_test
(
test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector
)
cc_test
(
test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto
)
if
(
WITH_MKLDNN
)
cc_test
(
test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass
)
cc_test
(
test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass
)
cc_test
(
test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass
)
endif
()
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h
浏览文件 @
fe4cd502
...
...
@@ -31,7 +31,8 @@ class ConvReLUFusePass : public FusePassBase {
virtual
~
ConvReLUFusePass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
;
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
}
// namespace ir
...
...
paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc
浏览文件 @
fe4cd502
...
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -36,6 +37,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op
->
SetInput
(
"X"
,
inputs
);
}
op
->
SetOutput
(
"Out"
,
outputs
);
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
}
// a->OP0->b
...
...
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.cc
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
#define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
std
::
unique_ptr
<
ir
::
Graph
>
DepthwiseConvMKLDNNPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
PADDLE_ENFORCE
(
graph
.
get
());
FusePassBase
::
Init
(
"depthwise_conv_mkldnn_pass"
,
graph
.
get
());
GraphPatternDetector
gpd
;
auto
*
pattern
=
gpd
.
mutable_pattern
();
pattern
->
NewNode
(
"depthwise_conv"
)
->
assert_is_op
(
"depthwise_conv2d"
)
->
assert_op_attr
(
"use_mkldnn"
,
true
);
int
found_depthwise_conv_mkldnn_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
3
)
<<
"handle DepthwiseConvMKLDNN fuse"
;
GET_NODE
(
depthwise_conv
,
(
*
pattern
));
depthwise_conv
->
Op
()
->
SetType
(
"conv2d"
);
found_depthwise_conv_mkldnn_count
++
;
};
gpd
(
graph
.
get
(),
handler
);
AddStatis
(
found_depthwise_conv_mkldnn_count
);
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
depthwise_conv_mkldnn_pass
,
paddle
::
framework
::
ir
::
DepthwiseConvMKLDNNPass
);
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
DepthwiseConvMKLDNNPass
:
public
FusePassBase
{
public:
virtual
~
DepthwiseConvMKLDNNPass
()
{}
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass_tester.cc
0 → 100644
浏览文件 @
fe4cd502
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
SetOp
(
ProgramDesc
*
prog
,
const
std
::
string
&
type
,
const
std
::
string
&
name
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
,
bool
use_mkldnn
=
false
)
{
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
op
->
SetAttr
(
"use_mkldnn"
,
use_mkldnn
);
op
->
SetAttr
(
"name"
,
name
);
op
->
SetInput
(
"Input"
,
{
inputs
[
0
]});
op
->
SetInput
(
"Filter"
,
{
inputs
[
1
]});
op
->
SetInput
(
"Bias"
,
{
inputs
[
2
]});
op
->
SetOutput
(
"Out"
,
outputs
);
}
// (a, weights, bias)->depthwise conv mkldnn->b
// (b, weights2, bias2)->depthwise conv no mkldnn->c
// (c, weights3, bias3)->conv mkldnn->d
// (d, weights3, bias3)->conv no mkldnn->e
ProgramDesc
BuildProgramDesc
()
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
(
{
"a"
,
"b"
,
"c"
,
"d"
,
"e"
,
"weights"
,
"bias"
,
"weights2"
,
"bias2"
,
"weights3"
,
"bias3"
,
"weights4"
,
"bias4"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
if
(
v
==
"weights"
||
v
==
"bias"
||
v
==
"weights2"
||
v
==
"bias2"
||
v
==
"weights3"
||
v
==
"bias3"
||
v
==
"weights4"
||
v
==
"bias4"
)
{
var
->
SetPersistable
(
true
);
}
}
// depthwise conv with MKL-DNN
SetOp
(
&
prog
,
"depthwise_conv2d"
,
"conv1"
,
std
::
vector
<
std
::
string
>
({
"a"
,
"weights"
,
"bias"
}),
std
::
vector
<
std
::
string
>
({
"b"
}),
true
);
// depthwise conv without MKL-DNN
SetOp
(
&
prog
,
"depthwise_conv2d"
,
"conv2"
,
std
::
vector
<
std
::
string
>
({
"b"
,
"weights2"
,
"bias2"
}),
std
::
vector
<
std
::
string
>
({
"c"
}),
false
);
// conv with MKL-DNN
SetOp
(
&
prog
,
"conv2d"
,
"conv3"
,
std
::
vector
<
std
::
string
>
({
"c"
,
"weights3"
,
"bias3"
}),
std
::
vector
<
std
::
string
>
({
"d"
}),
true
);
// conv without MKL-dNN
SetOp
(
&
prog
,
"conv2d"
,
"conv4"
,
std
::
vector
<
std
::
string
>
({
"d"
,
"weights4"
,
"bias4"
}),
std
::
vector
<
std
::
string
>
({
"e"
}),
false
);
return
prog
;
}
TEST
(
DepthwiseConvMKLDNNPass
,
basic
)
{
auto
prog
=
BuildProgramDesc
();
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"depthwise_conv_mkldnn_pass"
);
struct
counters
{
int
mkldnn_depthwise_conv_nodes
;
int
other_depthwise_conv_nodes
;
int
mkldnn_conv_nodes
;
int
other_conv_nodes
;
};
counters
before
{
1
,
1
,
1
,
1
};
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
// initialize counters before loop
counters
after
{
0
,
0
,
0
,
0
};
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
())
{
auto
*
op
=
node
->
Op
();
if
(
op
->
Type
()
==
"conv2d"
)
{
if
(
boost
::
get
<
bool
>
(
op
->
GetAttr
(
"use_mkldnn"
)))
after
.
mkldnn_conv_nodes
++
;
else
after
.
other_conv_nodes
++
;
}
else
if
(
op
->
Type
()
==
"depthwise_conv2d"
)
{
if
(
boost
::
get
<
bool
>
(
op
->
GetAttr
(
"use_mkldnn"
)))
after
.
mkldnn_depthwise_conv_nodes
++
;
else
after
.
other_depthwise_conv_nodes
++
;
}
}
}
EXPECT_EQ
(
after
.
other_depthwise_conv_nodes
,
before
.
other_depthwise_conv_nodes
);
EXPECT_EQ
(
after
.
other_conv_nodes
,
before
.
other_conv_nodes
);
EXPECT_EQ
(
after
.
mkldnn_depthwise_conv_nodes
,
before
.
mkldnn_depthwise_conv_nodes
-
1
);
EXPECT_EQ
(
after
.
mkldnn_conv_nodes
,
before
.
mkldnn_conv_nodes
+
1
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
depthwise_conv_mkldnn_pass
);
paddle/fluid/framework/ir/fc_fuse_pass_tester.cc
浏览文件 @
fe4cd502
...
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -32,6 +33,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op
->
SetInput
(
"X"
,
inputs
);
}
op
->
SetOutput
(
"Out"
,
outputs
);
op
->
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
OpRole
::
kForward
));
}
// a->OP0->b
...
...
paddle/fluid/framework/ir/graph.cc
浏览文件 @
fe4cd502
...
...
@@ -23,8 +23,67 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
{
void
CheckProgram
(
const
ProgramDesc
&
program
)
{
#define _INT(role) static_cast<int>(role)
std
::
map
<
int
,
bool
>
visit
;
for
(
OpDesc
*
op
:
program
.
Block
(
0
).
AllOps
())
{
// For backward compatibility, some program doesn't have role added.
if
(
!
op
->
HasAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
continue
;
int
role_id
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()));
visit
[
role_id
]
=
true
;
switch
(
role_id
)
{
case
_INT
(
OpRole
::
kForward
):
if
(
visit
.
find
(
_INT
(
OpRole
::
kBackward
))
!=
visit
.
end
())
{
LOG
(
ERROR
)
<<
"Cannot add backward operator before forward operator %s."
<<
op
->
Type
();
}
break
;
case
_INT
(
OpRole
::
kBackward
):
case
_INT
(
OpRole
::
kBackward
)
|
_INT
(
OpRole
::
kLoss
):
PADDLE_ENFORCE
(
visit
.
find
(
_INT
(
OpRole
::
kOptimize
))
==
visit
.
end
(),
"Cannot add backward operator %s after optimize operator."
,
op
->
Type
());
break
;
case
_INT
(
OpRole
::
kForward
)
|
_INT
(
OpRole
::
kLoss
):
PADDLE_ENFORCE
(
visit
.
find
(
_INT
(
OpRole
::
kBackward
)
|
_INT
(
OpRole
::
kLoss
))
==
visit
.
end
(),
"Cannot add backward|loss operator before "
"forward|loss operator %s."
,
op
->
Type
());
PADDLE_ENFORCE
(
visit
.
find
(
_INT
(
OpRole
::
kOptimize
))
==
visit
.
end
(),
"Cannot add forward|loss operator %s after optimize operator."
,
op
->
Type
());
break
;
case
_INT
(
OpRole
::
kOptimize
):
case
_INT
(
OpRole
::
kOptimize
)
|
_INT
(
OpRole
::
kLRSched
):
PADDLE_ENFORCE
(
visit
.
find
(
_INT
(
OpRole
::
kBackward
))
!=
visit
.
end
(),
"Optimize operators %s must follow backward operator."
,
op
->
Type
());
break
;
case
_INT
(
OpRole
::
kLRSched
):
case
_INT
(
OpRole
::
kDist
):
case
_INT
(
OpRole
::
kRPC
):
case
_INT
(
OpRole
::
kNotSpecified
):
break
;
default:
LOG
(
FATAL
)
<<
"Unknown operator role. Don't add new role because "
"you don't know what you are doing."
;
}
}
#undef _INT
}
}
// namespace
Graph
::
Graph
(
const
ProgramDesc
&
program
)
:
program_
(
program
)
{
CheckProgram
(
program_
);
// Make the nodes id start from 0.
Node
::
ResetId
();
auto
var_nodes
=
InitFromProgram
(
program_
);
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
fe4cd502
...
...
@@ -259,6 +259,15 @@ GraphPatternDetector::DetectPatterns() {
return
result
;
}
bool
GraphItemCMP
(
const
std
::
pair
<
PDNode
*
,
Node
*>
&
a
,
const
std
::
pair
<
PDNode
*
,
Node
*>
&
b
)
{
if
(
a
.
first
!=
b
.
first
)
{
return
a
.
first
<
b
.
first
;
}
else
{
return
a
.
second
<
b
.
second
;
}
}
// TODO(Superjomn) enhance the function as it marks unique unique as duplicates
// see https://github.com/PaddlePaddle/Paddle/issues/13550
void
GraphPatternDetector
::
UniquePatterns
(
...
...
@@ -267,12 +276,16 @@ void GraphPatternDetector::UniquePatterns(
std
::
vector
<
GraphPatternDetector
::
subgraph_t
>
result
;
std
::
unordered_set
<
size_t
>
set
;
std
::
hash
<
std
::
string
>
hasher
;
for
(
auto
&
g
:
*
subgraphs
)
{
size_t
key
=
0
;
for
(
auto
&
item
:
g
)
{
key
^=
std
::
hash
<
void
*>
{}(
item
.
first
);
key
^=
std
::
hash
<
void
*>
{}(
item
.
second
);
// Sort the items in the sub-graph, and transform to a string key.
std
::
vector
<
std
::
pair
<
PDNode
*
,
Node
*>>
sorted_keys
(
g
.
begin
(),
g
.
end
());
std
::
sort
(
sorted_keys
.
begin
(),
sorted_keys
.
end
(),
GraphItemCMP
);
std
::
stringstream
ss
;
for
(
auto
&
item
:
sorted_keys
)
{
ss
<<
item
.
first
<<
":"
<<
item
.
second
;
}
auto
key
=
hasher
(
ss
.
str
());
if
(
!
set
.
count
(
key
))
{
result
.
emplace_back
(
g
);
set
.
insert
(
key
);
...
...
paddle/fluid/framework/lod_tensor.cc
浏览文件 @
fe4cd502
...
...
@@ -418,7 +418,7 @@ void LoDTensor::MergeLoDTensor(
PADDLE_ENFORCE_EQ
(
new_lod
.
size
(),
lod
.
size
());
for
(
size_t
j
=
0
;
j
<
lod
.
size
();
++
j
)
{
auto
&
sub_lod
=
new_lod
[
j
];
auto
&
offset
=
sub_lod
.
back
();
size_t
offset
=
sub_lod
.
back
();
for
(
size_t
k
=
1
;
k
<
lod
[
j
].
size
();
++
k
)
{
sub_lod
.
push_back
(
lod
[
j
][
k
]
+
offset
);
}
...
...
paddle/fluid/framework/tensor_util.cc
浏览文件 @
fe4cd502
...
...
@@ -153,6 +153,12 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
auto
src_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
src_place
);
auto
dst_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dst_place
);
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
nullptr
);
}
else
if
(
platform
::
is_cuda_pinned_place
(
src_place
)
&&
platform
::
is_gpu_place
(
dst_place
))
{
auto
src_pinned_place
=
boost
::
get
<
platform
::
CUDAPinnedPlace
>
(
src_place
);
auto
dst_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dst_place
);
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_pinned_place
,
src_ptr
,
size
,
nullptr
);
}
#endif
}
...
...
paddle/fluid/inference/CMakeLists.txt
浏览文件 @
fe4cd502
if
(
WITH_TESTING
)
include
(
test.cmake
)
# some generic cmake funtion for inference
endif
()
# analysis and tensorrt must be added before creating static library,
# otherwise, there would be undefined reference to them in static library.
add_subdirectory
(
analysis
)
...
...
paddle/fluid/inference/analysis/CMakeLists.txt
浏览文件 @
fe4cd502
...
...
@@ -20,22 +20,17 @@ cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test
(
test_dot SRCS dot_tester.cc DEPS analysis
)
cc_binary
(
inference_analyzer SRCS analyzer_main.cc DEPS analysis paddle_fluid
)
function
(
inference_analysis_test TARGET
)
if
(
WITH_TESTING
)
set
(
options
""
)
set
(
oneValueArgs
""
)
set
(
multiValueArgs SRCS ARGS EXTRA_DEPS
)
cmake_parse_arguments
(
analysis_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
set
(
mem_opt
""
)
if
(
WITH_GPU
)
set
(
mem_opt
"--fraction_of_gpu_memory_to_use=0.5"
)
endif
()
cc_test
(
${
TARGET
}
SRCS
"
${
analysis_test_SRCS
}
"
DEPS analysis pass
${
GLOB_PASS_LIB
}
${
analysis_test_EXTRA_DEPS
}
ARGS --inference_model_dir=
${
PYTHON_TESTS_DIR
}
/book/word2vec.inference.model
${
mem_opt
}
${
analysis_test_ARGS
}
)
set_tests_properties
(
${
TARGET
}
PROPERTIES DEPENDS test_word2vec
)
endif
(
WITH_TESTING
)
function
(
inference_analysis_test TARGET
)
if
(
WITH_TESTING
)
set
(
options
""
)
set
(
oneValueArgs
""
)
set
(
multiValueArgs SRCS ARGS EXTRA_DEPS
)
cmake_parse_arguments
(
analysis_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
inference_base_test
(
${
TARGET
}
SRCS
${
analysis_test_SRCS
}
DEPS analysis pass
${
GLOB_PASS_LIB
}
${
analysis_test_EXTRA_DEPS
}
ARGS --inference_model_dir=
${
WORD2VEC_MODEL_DIR
}
${
analysis_test_ARGS
}
)
endif
()
endfunction
(
inference_analysis_test
)
inference_analysis_test
(
test_analyzer SRCS analyzer_tester.cc EXTRA_DEPS paddle_inference_api
)
...
...
paddle/fluid/inference/analysis/analyzer.h
浏览文件 @
fe4cd502
...
...
@@ -79,6 +79,7 @@ class Analyzer : public OrderedRegistry<PassManager> {
"conv_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
#ifdef PADDLE_WITH_MKLDNN
"depthwise_conv_mkldnn_pass"
,
//
"conv_bias_mkldnn_fuse_pass"
,
//
"conv_relu_mkldnn_fuse_pass"
,
//
"conv_elementwise_add_mkldnn_fuse_pass"
,
//
...
...
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
浏览文件 @
fe4cd502
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
...
...
@@ -130,6 +131,8 @@ void SetOp(framework::ProgramDesc* prog, const std::string& type,
op
->
SetType
(
type
);
op
->
SetInput
(
"Xs"
,
inputs
);
op
->
SetOutput
(
"Xs"
,
outputs
);
op
->
SetAttr
(
framework
::
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
static_cast
<
int
>
(
framework
::
OpRole
::
kForward
));
}
TEST
(
DataFlowGraph
,
Build_IR_Graph
)
{
...
...
paddle/fluid/inference/api/CMakeLists.txt
浏览文件 @
fe4cd502
...
...
@@ -17,39 +17,12 @@ if(APPLE)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wno-error=pessimizing-move"
)
endif
(
APPLE
)
set
(
inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager naive_executor
${
GLOB_PASS_LIB
}
)
set
(
inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager naive_executor
${
GLOB_PASS_LIB
}
)
if
(
WITH_GPU AND TENSORRT_FOUND
)
set
(
inference_deps
${
inference_deps
}
paddle_inference_tensorrt_subgraph_engine analysis_predictor
)
endif
()
function
(
inference_api_test TARGET_NAME
)
if
(
WITH_TESTING
)
set
(
options
""
)
set
(
oneValueArgs SRC
)
set
(
multiValueArgs ARGS
)
cmake_parse_arguments
(
inference_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
if
(
WITH_GPU
)
cc_test
(
${
TARGET_NAME
}
SRCS
${
inference_test_SRC
}
DEPS
"
${
inference_deps
}
"
ARGS --dirname=
${
PYTHON_TESTS_DIR
}
/book/ --fraction_of_gpu_memory_to_use=0.15
)
else
()
cc_test
(
${
TARGET_NAME
}
SRCS
${
inference_test_SRC
}
DEPS
"
${
inference_deps
}
"
ARGS --dirname=
${
PYTHON_TESTS_DIR
}
/book/
)
endif
()
if
(
inference_test_ARGS
)
set_tests_properties
(
${
TARGET_NAME
}
PROPERTIES DEPENDS
"
${
inference_test_ARGS
}
"
)
endif
()
endif
(
WITH_TESTING
)
endfunction
(
inference_api_test
)
cc_library
(
reset_tensor_array SRCS details/reset_tensor_array.cc DEPS lod_tensor scope
)
cc_library
(
paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS reset_tensor_array lod_tensor scope
)
cc_library
(
analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor
)
...
...
@@ -59,8 +32,11 @@ cc_test(test_paddle_inference_api
SRCS api_tester.cc
DEPS paddle_inference_api
)
inference_api_test
(
test_api_impl SRC api_impl_tester.cc
ARGS test_word2vec test_image_classification
)
if
(
WITH_TESTING
)
inference_base_test
(
test_api_impl SRCS api_impl_tester.cc DEPS
${
inference_deps
}
ARGS --word2vec_dirname=
${
WORD2VEC_MODEL_DIR
}
--book_dirname=
${
PYTHON_TESTS_DIR
}
/book
)
set_tests_properties
(
test_api_impl PROPERTIES DEPENDS test_image_classification
)
endif
()
cc_test
(
test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor
${
inference_deps
}
paddle_inference_api
ARGS --dirname=
${
PYTHON_TESTS_DIR
}
/book
)
...
...
@@ -68,8 +44,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
cc_library
(
paddle_inference_tensorrt_subgraph_engine
SRCS api_tensorrt_subgraph_engine.cc
DEPS paddle_inference_api analysis tensorrt_engine paddle_inference_api paddle_fluid_api tensorrt_converter zero_copy_tensor_dummy
)
inference_api_test
(
test_api_tensorrt_subgraph_engine SRC api_tensorrt_subgraph_engine_tester.cc ARGS test_word2vec
)
if
(
WITH_TESTING
)
inference_base_test
(
test_api_tensorrt_subgraph_engine SRCS api_tensorrt_subgraph_engine_tester.cc DEPS
${
inference_deps
}
ARGS --dirname=
${
WORD2VEC_MODEL_DIR
}
)
endif
()
endif
()
if
(
WITH_ANAKIN AND WITH_MKL
)
# only needed in CI
...
...
paddle/fluid/inference/api/api_impl_tester.cc
浏览文件 @
fe4cd502
...
...
@@ -22,12 +22,14 @@ limitations under the License. */
#include "paddle/fluid/inference/tests/test_helper.h"
#ifdef __clang__
#define ACC_DIFF 4e-
2
#define ACC_DIFF 4e-
3
#else
#define ACC_DIFF 1e-
2
#define ACC_DIFF 1e-
3
#endif
DEFINE_string
(
dirname
,
""
,
"Directory of the inference model."
);
DEFINE_string
(
word2vec_dirname
,
""
,
"Directory of the word2vec inference model."
);
DEFINE_string
(
book_dirname
,
""
,
"Directory of the book inference model."
);
namespace
paddle
{
...
...
@@ -49,7 +51,7 @@ PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
NativeConfig
GetConfig
()
{
NativeConfig
config
;
config
.
model_dir
=
FLAGS_
dirname
+
"/word2vec.inference.model"
;
config
.
model_dir
=
FLAGS_
word2vec_dirname
;
LOG
(
INFO
)
<<
"dirname "
<<
config
.
model_dir
;
config
.
fraction_of_gpu_memory
=
0.15
;
#ifdef PADDLE_WITH_CUDA
...
...
@@ -116,7 +118,7 @@ void MainImageClassification(bool use_gpu) {
NativeConfig
config
=
GetConfig
();
config
.
use_gpu
=
use_gpu
;
config
.
model_dir
=
FLAGS_dirname
+
"/image_classification_resnet.inference.model"
;
FLAGS_
book_
dirname
+
"/image_classification_resnet.inference.model"
;
const
bool
is_combined
=
false
;
std
::
vector
<
std
::
vector
<
int64_t
>>
feed_target_shapes
=
...
...
@@ -220,7 +222,7 @@ void MainThreadsImageClassification(bool use_gpu) {
NativeConfig
config
=
GetConfig
();
config
.
use_gpu
=
use_gpu
;
config
.
model_dir
=
FLAGS_dirname
+
"/image_classification_resnet.inference.model"
;
FLAGS_
book_
dirname
+
"/image_classification_resnet.inference.model"
;
auto
main_predictor
=
CreatePaddlePredictor
<
NativeConfig
>
(
config
);
std
::
vector
<
framework
::
LoDTensor
>
jobs
(
num_jobs
);
...
...
paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc
浏览文件 @
fe4cd502
...
...
@@ -29,13 +29,13 @@ void CompareTensorRTWithFluid(bool enable_tensorrt) {
//# 1. Create PaddlePredictor with a config.
NativeConfig
config0
;
config0
.
model_dir
=
FLAGS_dirname
+
"word2vec.inference.model"
;
config0
.
model_dir
=
FLAGS_dirname
;
config0
.
use_gpu
=
true
;
config0
.
fraction_of_gpu_memory
=
0.3
;
config0
.
device
=
0
;
MixedRTConfig
config1
;
config1
.
model_dir
=
FLAGS_dirname
+
"word2vec.inference.model"
;
config1
.
model_dir
=
FLAGS_dirname
;
config1
.
use_gpu
=
true
;
config1
.
fraction_of_gpu_memory
=
0.3
;
config1
.
device
=
0
;
...
...
paddle/fluid/inference/api/demo_ci/run.sh
浏览文件 @
fe4cd502
...
...
@@ -62,7 +62,7 @@ for WITH_STATIC_LIB in ON OFF; do
-DWITH_GPU
=
$TEST_GPU_CPU
\
-DWITH_STATIC_LIB
=
$WITH_STATIC_LIB
make
-j
word2vec_model
=
$
{
PADDLE_ROOT
}
'/build/python/paddle/fluid/tests/book
/word2vec.inference.model'
word2vec_model
=
$
DATA_DIR
'/word2vec
/word2vec.inference.model'
if
[
-d
$word2vec_model
]
;
then
for
use_gpu
in
$use_gpu_list
;
do
./simple_on_word2vec
\
...
...
paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
浏览文件 @
fe4cd502
...
...
@@ -70,12 +70,8 @@ void Main(bool use_gpu) {
// The outputs' buffers are in CPU memory.
for
(
size_t
i
=
0
;
i
<
std
::
min
(
static_cast
<
size_t
>
(
5
),
num_elements
);
i
++
)
{
// Here will result random fail, for that the model is trained by CI, the
// train phase is not stable, so the result will be random.
// TODO(Superjomn) will restore after the model is upload.
// CHECK_NEAR(static_cast<float*>(outputs.front().data.data())[i],
// result[i],
// 0.001);
CHECK_NEAR
(
static_cast
<
float
*>
(
outputs
.
front
().
data
.
data
())[
i
],
result
[
i
],
0.001
);
}
}
}
...
...
paddle/fluid/inference/test.cmake
0 → 100644
浏览文件 @
fe4cd502
set
(
INFERENCE_URL
"http://paddle-inference-dist.cdn.bcebos.com"
CACHE STRING
"inference download url"
)
set
(
INFERENCE_DEMO_INSTALL_DIR
"
${
THIRD_PARTY_PATH
}
/inference_demo"
CACHE STRING
"A path setting inference demo download directories."
)
function
(
inference_download install_dir url filename
)
message
(
STATUS
"Download inference test stuff from
${
url
}
/
${
filename
}
"
)
execute_process
(
COMMAND bash -c
"mkdir -p
${
install_dir
}
"
)
execute_process
(
COMMAND bash -c
"cd
${
install_dir
}
&& wget -q
${
url
}
/
${
filename
}
"
)
message
(
STATUS
"finish downloading
${
filename
}
"
)
endfunction
()
function
(
inference_download_and_uncompress install_dir url filename
)
inference_download
(
${
install_dir
}
${
url
}
${
filename
}
)
execute_process
(
COMMAND bash -c
"cd
${
install_dir
}
&& tar xzf
${
filename
}
"
)
endfunction
()
set
(
WORD2VEC_INSTALL_DIR
"
${
INFERENCE_DEMO_INSTALL_DIR
}
/word2vec"
)
if
(
NOT EXISTS
${
WORD2VEC_INSTALL_DIR
}
)
inference_download_and_uncompress
(
${
WORD2VEC_INSTALL_DIR
}
${
INFERENCE_URL
}
"word2vec.inference.model.tar.gz"
)
endif
()
set
(
WORD2VEC_MODEL_DIR
"
${
WORD2VEC_INSTALL_DIR
}
/word2vec.inference.model"
)
function
(
inference_base_test TARGET
)
set
(
options
""
)
set
(
oneValueArgs
""
)
set
(
multiValueArgs SRCS ARGS DEPS
)
cmake_parse_arguments
(
base_test
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
if
(
WITH_GPU
)
set
(
mem_opt
"--fraction_of_gpu_memory_to_use=0.5"
)
endif
()
cc_test
(
${
TARGET
}
SRCS
${
base_test_SRCS
}
DEPS
${
base_test_DEPS
}
ARGS
${
mem_opt
}
${
base_test_ARGS
}
)
endfunction
()
paddle/fluid/inference/tests/api/CMakeLists.txt
浏览文件 @
fe4cd502
set
(
INFERENCE_URL
"http://paddle-inference-dist.cdn.bcebos.com"
)
set
(
INFERENCE_DEMO_INSTALL_DIR
"
${
THIRD_PARTY_PATH
}
/inference_demo"
CACHE STRING
"A path setting inference demo download directories."
)
set
(
INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor
)
function
(
inference_download install_dir url filename
)
message
(
STATUS
"Download inference test stuff from
${
url
}
/
${
filename
}
"
)
execute_process
(
COMMAND bash -c
"mkdir -p
${
install_dir
}
"
)
execute_process
(
COMMAND bash -c
"cd
${
install_dir
}
&& wget -q
${
url
}
/
${
filename
}
"
)
message
(
STATUS
"finish downloading
${
filename
}
"
)
endfunction
()
function
(
inference_download_and_uncompress install_dir url filename
)
inference_download
(
${
install_dir
}
${
url
}
${
filename
}
)
execute_process
(
COMMAND bash -c
"cd
${
install_dir
}
&& tar xzf
${
filename
}
"
)
endfunction
()
function
(
download_model_and_data install_dir model_name data_name
)
if
(
NOT EXISTS
${
install_dir
}
)
...
...
paddle/fluid/operators/affine_grid_cudnn_op.cu.cc
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
ScopedSpatialTransformerDescriptor
=
platform
::
ScopedSpatialTransformerDescriptor
;
template
<
typename
T
>
class
CUDNNAffineGridOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use CUDAPlace."
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
*
theta
=
ctx
.
Input
<
Tensor
>
(
"Theta"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Output"
);
const
T
*
theta_data
=
theta
->
data
<
T
>
();
int
n
=
theta
->
dims
()[
0
];
auto
size_attr
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"output_shape"
);
Tensor
h_sizes
;
int
*
h_size_data
;
if
(
size_attr
.
size
()
==
0
)
{
auto
*
output_shape
=
ctx
.
Input
<
Tensor
>
(
"OutputShape"
);
framework
::
TensorCopy
(
*
output_shape
,
platform
::
CPUPlace
(),
&
h_sizes
);
h_size_data
=
h_sizes
.
data
<
int
>
();
}
else
{
h_size_data
=
h_sizes
.
mutable_data
<
int
>
({
4
},
platform
::
CPUPlace
());
h_size_data
[
0
]
=
n
;
h_size_data
[
1
]
=
size_attr
[
1
];
h_size_data
[
2
]
=
size_attr
[
2
];
h_size_data
[
3
]
=
size_attr
[
3
];
}
T
*
output_data
=
output
->
mutable_data
<
T
>
(
{
n
,
h_size_data
[
2
],
h_size_data
[
3
],
2
},
ctx
.
GetPlace
());
ScopedSpatialTransformerDescriptor
st_desc
;
cudnnSpatialTransformerDescriptor_t
cudnn_st_desc
=
st_desc
.
descriptor
<
T
>
(
4
,
h_size_data
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cudnnSpatialTfGridGeneratorForward
(
handle
,
cudnn_st_desc
,
theta_data
,
output_data
));
}
};
template
<
typename
T
>
class
CUDNNAffineGridGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use CUDAPlace."
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
auto
theta_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Theta"
));
int
n
=
output_grad
->
dims
()[
0
];
auto
size_attr
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"output_shape"
);
Tensor
h_sizes
;
int
*
h_size_data
;
if
(
size_attr
.
size
()
==
0
)
{
auto
*
output_shape
=
ctx
.
Input
<
Tensor
>
(
"OutputShape"
);
framework
::
TensorCopy
(
*
output_shape
,
platform
::
CPUPlace
(),
&
h_sizes
);
h_size_data
=
h_sizes
.
data
<
int
>
();
}
else
{
h_size_data
=
h_sizes
.
mutable_data
<
int
>
({
4
},
platform
::
CPUPlace
());
h_size_data
[
0
]
=
n
;
h_size_data
[
1
]
=
size_attr
[
1
];
h_size_data
[
2
]
=
size_attr
[
2
];
h_size_data
[
3
]
=
size_attr
[
3
];
}
ScopedSpatialTransformerDescriptor
st_desc
;
cudnnSpatialTransformerDescriptor_t
cudnn_st_desc
=
st_desc
.
descriptor
<
T
>
(
4
,
h_size_data
);
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
T
*
theta_grad_data
=
theta_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
PADDLE_ENFORCE
(
platform
::
dynload
::
cudnnSpatialTfGridGeneratorBackward
(
handle
,
cudnn_st_desc
,
output_grad_data
,
theta_grad_data
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_KERNEL
(
affine_grid
,
CUDNN
,
plat
::
CUDAPlace
,
paddle
::
operators
::
CUDNNAffineGridOpKernel
<
float
>
,
paddle
::
operators
::
CUDNNAffineGridOpKernel
<
double
>
);
REGISTER_OP_KERNEL
(
affine_grid_grad
,
CUDNN
,
plat
::
CUDAPlace
,
paddle
::
operators
::
CUDNNAffineGridGradOpKernel
<
float
>
,
paddle
::
operators
::
CUDNNAffineGridGradOpKernel
<
double
>
);
paddle/fluid/operators/affine_grid_op.cc
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/affine_grid_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
struct
Linspace
<
paddle
::
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
T
start
,
T
end
,
int
count
,
framework
::
Tensor
*
numbers
,
const
framework
::
ExecutionContext
&
ctx
)
{
T
*
number_data
=
numbers
->
mutable_data
<
T
>
({
count
},
platform
::
CPUPlace
());
T
slice
=
(
end
-
start
)
/
(
T
)(
count
-
1
);
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
number_data
[
i
]
=
start
+
(
T
)
i
*
slice
;
}
}
};
class
AffineGridOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Theta"
),
"Input(Theta) of AffineGridOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Output"
),
"Output(Output) of AffineGridOp should not be null."
);
auto
theta_dims
=
ctx
->
GetInputDim
(
"Theta"
);
PADDLE_ENFORCE
(
theta_dims
.
size
()
==
3
,
"AffineGrid's Input(Theta) should be 3-D tensor."
);
auto
output_shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"output_shape"
);
if
(
output_shape
.
size
()
==
0
)
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"OutputShape"
),
"Input(OutputShape) of AffineGridOp should not be null if "
"attr(output_shape) is not configured."
);
auto
output_shape_dims
=
ctx
->
GetInputDim
(
"OutputShape"
);
PADDLE_ENFORCE
(
output_shape_dims
.
size
()
==
1
,
"AffineGrid's Input(OutputShape) should be 1-D tensor."
);
}
else
{
PADDLE_ENFORCE
(
output_shape
.
size
()
==
4
,
"The size of attr(output_shape) should be 4."
);
}
PADDLE_ENFORCE
(
theta_dims
[
1
]
==
2
,
"Input(theta) dims[1] should be 2."
);
PADDLE_ENFORCE
(
theta_dims
[
2
]
==
3
,
"Input(theta) dims[2] should be 3."
);
// N * H * W * 2
ctx
->
SetOutputDim
(
"Output"
,
framework
::
make_ddim
({
theta_dims
[
0
],
-
1
,
-
1
,
2
}));
ctx
->
ShareLoD
(
"Theta"
,
"Output"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library
=
framework
::
LibraryType
::
kCUDNN
;
}
#endif
auto
data_type
=
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Theta"
)
->
type
());
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
library
);
}
};
class
AffineGridOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Theta"
,
"(Tensor) A batch of affine transform parameters with shape [N, 2, 3]. "
"It is used to transform coordinate (x_0, y_0) to coordinate (x_1, "
"y_1)."
);
AddInput
(
"OutputShape"
,
"(Tensor) The shape of target image with format [N, C, H, W]."
)
.
AsDispensable
();
AddOutput
(
"Output"
,
"(Tensor) Output Tensor with shape [N, H, W, 2]."
);
AddAttr
<
bool
>
(
"use_cudnn"
,
"(bool, default false) Only used in cudnn kernel, need install cudnn"
)
.
SetDefault
(
true
);
AddAttr
<
std
::
vector
<
int
>>
(
"output_shape"
,
"The target output image shape with format [N, C, H, W]."
)
.
SetDefault
(
std
::
vector
<
int
>
());
AddComment
(
R"DOC(
It generates a grid of (x,y) coordinates using the parameters of the
affine transformation that correspond to a set of points where the input
feature map should be sampled to produce the transformed output feature map.
Given:
Theta = [[[x_11, x_12, x_13]
[x_14, x_15, x_16]]
[[x_21, x_22, x_23]
[x_24, x_25, x_26]]]
OutputShape = [2, 3, 5, 5]
Step 1:
Generate relative coordinates according to OutputShape.
The values of relative coordinates are in the interval between -1 and 1.
The shape of the relative coordinates is [2, H, W] as below:
C = [[[-1. -1. -1. -1. -1. ]
[-0.5 -0.5 -0.5 -0.5 -0.5]
[ 0. 0. 0. 0. 0. ]
[ 0.5 0.5 0.5 0.5 0.5]
[ 1. 1. 1. 1. 1. ]]
[[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]]]
C[0] is the coordinates in height axis and C[1] is the coordinates in width axis.
Step2:
Tanspose and reshape C to shape [H * W, 2] and append ones to last dimension. The we get:
C_ = [[-1. -1. 1. ]
[-0.5 -1. 1. ]
[ 0. -1. 1. ]
[ 0.5 -1. 1. ]
[ 1. -1. 1. ]
[-1. -0.5 1. ]
[-0.5 -0.5 1. ]
[ 0. -0.5 1. ]
[ 0.5 -0.5 1. ]
[ 1. -0.5 1. ]
[-1. 0. 1. ]
[-0.5 0. 1. ]
[ 0. 0. 1. ]
[ 0.5 0. 1. ]
[ 1. 0. 1. ]
[-1. 0.5 1. ]
[-0.5 0.5 1. ]
[ 0. 0.5 1. ]
[ 0.5 0.5 1. ]
[ 1. 0.5 1. ]
[-1. 1. 1. ]
[-0.5 1. 1. ]
[ 0. 1. 1. ]
[ 0.5 1. 1. ]
[ 1. 1. 1. ]]
Step3:
Compute output by equation $$Output[i] = C_ * Theta[i]^T$$
)DOC"
);
}
};
class
AffineGridOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
theta_dims
=
ctx
->
GetInputDim
(
"Theta"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Theta"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Theta"
),
theta_dims
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
#endif
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Theta"
)
->
type
()),
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
library_
);
}
};
class
AffineGridGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
op
=
new
framework
::
OpDesc
();
op
->
SetType
(
"affine_grid_grad"
);
op
->
SetInput
(
"Theta"
,
Input
(
"Theta"
));
op
->
SetInput
(
"OutputShape"
,
Input
(
"OutputShape"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Output"
),
OutputGrad
(
"Output"
));
op
->
SetAttrMap
(
Attrs
());
op
->
SetOutput
(
framework
::
GradVarName
(
"Theta"
),
InputGrad
(
"Theta"
));
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
affine_grid
,
ops
::
AffineGridOp
,
ops
::
AffineGridOpMaker
,
ops
::
AffineGridGradMaker
);
REGISTER_OPERATOR
(
affine_grid_grad
,
ops
::
AffineGridOpGrad
);
REGISTER_OP_CPU_KERNEL
(
affine_grid
,
ops
::
AffineGridOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
AffineGridOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
affine_grid_grad
,
ops
::
AffineGridGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
AffineGridGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/affine_grid_op.h
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
using
Array1
=
Eigen
::
DSizes
<
int64_t
,
1
>
;
using
Array2
=
Eigen
::
DSizes
<
int64_t
,
2
>
;
using
Array3
=
Eigen
::
DSizes
<
int64_t
,
3
>
;
using
Array4
=
Eigen
::
DSizes
<
int64_t
,
4
>
;
/**
*Return a tensor with evenly spaced numbers over a specified interval.
*/
template
<
typename
DeviceContext
,
typename
T
>
struct
Linspace
{
void
operator
()(
T
start
,
T
end
,
int
count
,
framework
::
Tensor
*
numbers
,
const
framework
::
ExecutionContext
&
ctx
);
};
template
<
typename
DeviceContext
,
typename
T
>
inline
void
GetIdxMap
(
int
n
,
int
h
,
int
w
,
Tensor
*
grid
,
const
framework
::
ExecutionContext
&
ctx
)
{
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
grid
->
mutable_data
<
T
>
({
n
,
h
,
w
,
3
},
ctx
.
GetPlace
());
auto
grid_t
=
EigenTensor
<
T
,
4
>::
From
(
*
grid
);
// Get indexes of height with shape [height, width, 1]
Tensor
h_idx
;
Linspace
<
DeviceContext
,
T
>
linspace
;
linspace
((
T
)
-
1
,
(
T
)
1
,
h
,
&
h_idx
,
ctx
);
auto
h_idx_t
=
EigenTensor
<
T
,
1
>::
From
(
h_idx
);
// Get indexes of width with shape [height, width, 1]
Tensor
w_idx
;
linspace
((
T
)
-
1
,
(
T
)
1
,
w
,
&
w_idx
,
ctx
);
auto
w_idx_t
=
EigenTensor
<
T
,
1
>::
From
(
w_idx
);
// Get constant ones tensor with shape [height, width, 1]
Tensor
ones
;
ones
.
mutable_data
<
T
>
({
h
,
w
,
1
},
ctx
.
GetPlace
());
auto
ones_t
=
EigenTensor
<
T
,
3
>::
From
(
ones
).
setConstant
((
T
)
1
);
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
// ones
Tensor
w_idx_map
;
w_idx_map
.
mutable_data
<
T
>
({
h
,
w
,
1
},
ctx
.
GetPlace
());
auto
w_idx_map_t
=
EigenTensor
<
T
,
3
>::
From
(
w_idx_map
);
Tensor
h_idx_map
;
h_idx_map
.
mutable_data
<
T
>
({
h
,
w
,
1
},
ctx
.
GetPlace
());
auto
h_idx_map_t
=
EigenTensor
<
T
,
3
>::
From
(
h_idx_map
);
Tensor
w_h_idx_map
;
w_h_idx_map
.
mutable_data
<
T
>
({
h
,
w
,
2
},
ctx
.
GetPlace
());
auto
w_h_idx_map_t
=
EigenTensor
<
T
,
3
>::
From
(
w_h_idx_map
);
Tensor
w_h_one_idx_map
;
w_h_one_idx_map
.
mutable_data
<
T
>
({
h
,
w
,
3
},
ctx
.
GetPlace
());
auto
w_h_one_idx_map_t
=
EigenTensor
<
T
,
3
>::
From
(
w_h_one_idx_map
);
w_idx_map_t
.
device
(
place
)
=
w_idx_t
.
reshape
(
Array2
(
1
,
w
))
.
broadcast
(
Array2
(
h
,
1
))
.
reshape
(
Array3
(
h
,
w
,
1
));
h_idx_map_t
.
device
(
place
)
=
h_idx_t
.
reshape
(
Array2
(
1
,
h
))
.
broadcast
(
Array2
(
w
,
1
))
.
shuffle
(
Array2
(
1
,
0
))
.
reshape
(
Array3
(
h
,
w
,
1
));
w_h_idx_map_t
.
device
(
place
)
=
w_idx_map_t
.
concatenate
(
h_idx_map_t
,
2
);
w_h_one_idx_map_t
.
device
(
place
)
=
w_h_idx_map_t
.
concatenate
(
ones_t
,
2
);
grid_t
.
device
(
place
)
=
w_h_one_idx_map_t
.
reshape
(
Array4
(
1
,
h
,
w
,
3
))
.
broadcast
(
Array4
(
n
,
1
,
1
,
1
));
}
template
<
typename
DeviceContext
,
typename
T
>
class
AffineGridOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
theta
=
ctx
.
Input
<
Tensor
>
(
"Theta"
);
int
n
=
theta
->
dims
()[
0
];
auto
size_attr
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"output_shape"
);
int
h
=
0
;
int
w
=
0
;
if
(
size_attr
.
size
()
==
0
)
{
auto
*
output_shape
=
ctx
.
Input
<
Tensor
>
(
"OutputShape"
);
Tensor
h_sizes
;
framework
::
TensorCopy
(
*
output_shape
,
platform
::
CPUPlace
(),
&
h_sizes
);
const
int
*
h_size_data
=
h_sizes
.
data
<
int
>
();
h
=
h_size_data
[
2
];
w
=
h_size_data
[
3
];
}
else
{
h
=
size_attr
[
2
];
w
=
size_attr
[
3
];
}
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Output"
);
output
->
mutable_data
<
T
>
({
n
,
h
,
w
,
2
},
ctx
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
output
,
static_cast
<
T
>
(
0
));
Tensor
grid
;
GetIdxMap
<
DeviceContext
,
T
>
(
n
,
h
,
w
,
&
grid
,
ctx
);
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
Tensor
sliced_grid
=
grid
.
Slice
(
i
,
i
+
1
).
Resize
({
h
*
w
,
3
});
Tensor
sliced_theta
=
theta
->
Slice
(
i
,
i
+
1
).
Resize
({
2
,
3
});
Tensor
sliced_out
=
output
->
Slice
(
i
,
i
+
1
).
Resize
({
h
*
w
,
2
});
blas
.
MatMul
(
sliced_grid
,
false
,
sliced_theta
,
true
,
T
(
1
),
&
sliced_out
,
T
(
0
));
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
AffineGridGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
auto
theta_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Theta"
));
int
n
=
output_grad
->
dims
()[
0
];
auto
size_attr
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"output_shape"
);
int
h
=
0
;
int
w
=
0
;
if
(
size_attr
.
size
()
==
0
)
{
auto
*
output_shape
=
ctx
.
Input
<
Tensor
>
(
"OutputShape"
);
Tensor
h_sizes
;
framework
::
TensorCopy
(
*
output_shape
,
platform
::
CPUPlace
(),
&
h_sizes
);
const
int
*
h_size_data
=
h_sizes
.
data
<
int
>
();
h
=
h_size_data
[
2
];
w
=
h_size_data
[
3
];
}
else
{
h
=
size_attr
[
2
];
w
=
size_attr
[
3
];
}
theta_grad
->
mutable_data
<
T
>
({
n
,
2
,
3
},
ctx
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
theta_grad
,
static_cast
<
T
>
(
0
));
Tensor
grid
;
GetIdxMap
<
DeviceContext
,
T
>
(
n
,
h
,
w
,
&
grid
,
ctx
);
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
Tensor
sliced_grid
=
grid
.
Slice
(
i
,
i
+
1
).
Resize
({
h
*
w
,
3
});
Tensor
sliced_out_grad
=
output_grad
->
Slice
(
i
,
i
+
1
).
Resize
({
h
*
w
,
2
});
Tensor
sliced_theta_grad
=
theta_grad
->
Slice
(
i
,
i
+
1
).
Resize
({
2
,
3
});
blas
.
MatMul
(
sliced_out_grad
,
true
,
sliced_grid
,
false
,
T
(
1
),
&
sliced_theta_grad
,
T
(
0
));
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/checkpoint_notify_op.cc
浏览文件 @
fe4cd502
...
...
@@ -38,9 +38,10 @@ class CheckpointNotifyOp : public framework::OperatorBase {
std
::
vector
<
std
::
string
>
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
std
::
string
dir
=
Attr
<
std
::
string
>
(
"dir"
);
std
::
string
lookup_table_name
=
Attr
<
std
::
string
>
(
"lookup_table"
);
int
trainer_id
=
Attr
<
int
>
(
"trainer_id"
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
trainer_id
);
for
(
size_t
i
=
0
;
i
<
epmap
.
size
();
i
++
)
{
auto
lookup_table_save_dir
=
string
::
Sprintf
(
"%s/%s_%d"
,
dir
,
lookup_table_name
,
i
);
...
...
@@ -63,6 +64,7 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
"dir"
,
"(string, default '') indicate the folder checkpoint will use"
);
AddAttr
<
std
::
string
>
(
"lookup_table"
,
"(string, default '') the lookup table name"
);
AddAttr
<
int
>
(
"trainer_id"
,
"trainer id from 0 ~ worker_num."
).
SetDefault
(
0
);
AddComment
(
R"DOC(
CheckpointNotify operator
...
...
paddle/fluid/operators/conv_mkldnn_op.cc
浏览文件 @
fe4cd502
...
...
@@ -15,6 +15,8 @@
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/framework/data_layout_transform.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -57,6 +59,11 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
return
conv_pd_
->
dst_primitive_desc
().
get_size
();
}
mkldnn
::
memory
::
format
GetDstFormat
()
const
{
return
static_cast
<
mkldnn
::
memory
::
format
>
(
conv_pd_
->
dst_primitive_desc
().
desc
().
data
.
format
);
}
size_t
GetDiffWeightsMemorySize
()
const
{
return
conv_bwd_weights_pd_
->
diff_weights_primitive_desc
().
get_size
();
}
...
...
@@ -108,6 +115,20 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
"@data-weights_mem_p"
,
pipeline
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireResidualDataMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
return
this
->
AcquireMemory
(
md
,
ptr
,
"@user_residual_data_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemoryFromResidualDataMemory
(
const
std
::
shared_ptr
<
mkldnn
::
memory
>&
user_residual_memory_p
,
void
*
dst_ptr
,
std
::
vector
<
mkldnn
::
primitive
>&
pipeline
)
{
// NOLINT
return
this
->
AcquireMemory
(
user_residual_memory_p
,
this
->
AcquireDstMemoryFromPrimitive
(
dst_ptr
),
"@residual_data_mem_p"
,
pipeline
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffSrcMemoryFromDataPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
...
...
@@ -386,7 +407,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
user_weights_memory_p
=
handler
.
AcquireWeightsMemory
(
user_weights_md
,
to_void_cast
<
T
>
(
filter_data
));
T
*
output_data
=
nullptr
;
// create reorder primitive if the input format is not the preferred one
auto
src_memory_p
=
handler
.
AcquireSrcMemoryFromPrimitive
(
user_src_memory_p
,
pipeline
);
auto
weights_memory_p
=
handler
.
AcquireWeightsMemoryFromPrimitive
(
user_weights_memory_p
,
pipeline
,
is_test
);
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
;
if
(
fuse_residual_conn
)
{
auto
residual_param
=
ctx
.
Input
<
Tensor
>
(
"ResidualData"
);
...
...
@@ -399,21 +426,34 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Output and elementwise parameter need to have the "
"same dimension sizes"
);
output
->
ShareDataWith
(
*
residual_param
);
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
residual_param
->
format
()
!=
handler
.
GetDstFormat
())
{
auto
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
(),
handler
.
GetDstMemorySize
());
auto
residual_data_tz
=
paddle
::
framework
::
vectorize2int
(
residual_param
->
dims
());
auto
residual_data_type
=
paddle
::
framework
::
ToMKLDNNDataType
(
residual_param
->
type
());
auto
user_residual_md
=
platform
::
MKLDNNMemDesc
(
residual_data_tz
,
residual_data_type
,
residual_param
->
format
());
auto
user_residual_memory_p
=
handler
.
AcquireResidualDataMemory
(
user_residual_md
,
to_void_cast
<
T
>
(
residual_param_data
));
dst_memory_p
=
handler
.
AcquireDstMemoryFromResidualDataMemory
(
user_residual_memory_p
,
to_void_cast
<
T
>
(
output_data
),
pipeline
);
}
else
{
output
->
ShareDataWith
(
*
residual_param
);
auto
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dst_memory_p
=
handler
.
AcquireDstMemoryFromPrimitive
(
to_void_cast
<
T
>
(
output_data
));
}
}
else
{
output_data
=
auto
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
(),
handler
.
GetDstMemorySize
());
dst_memory_p
=
handler
.
AcquireDstMemoryFromPrimitive
(
to_void_cast
<
T
>
(
output_data
));
}
// create reorder primitive if the input format is not the preferred one
auto
src_memory_p
=
handler
.
AcquireSrcMemoryFromPrimitive
(
user_src_memory_p
,
pipeline
);
auto
weights_memory_p
=
handler
.
AcquireWeightsMemoryFromPrimitive
(
user_weights_memory_p
,
pipeline
,
is_test
);
auto
dst_memory_p
=
handler
.
AcquireDstMemoryFromPrimitive
(
to_void_cast
<
T
>
(
output_data
));
// create convolution op primitive
std
::
shared_ptr
<
mkldnn
::
convolution_forward
>
conv_p
;
if
(
bias
)
{
...
...
paddle/fluid/operators/delete_var_op.cc
浏览文件 @
fe4cd502
...
...
@@ -32,6 +32,11 @@ class DeleteVarOp : public framework::OperatorBase {
}
};
class
DeleteVarOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
class
DeleteVarOpInfoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
...
...
@@ -48,4 +53,5 @@ It should not be configured by users directly.
REGISTER_OPERATOR
(
delete_var
,
paddle
::
operators
::
DeleteVarOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
paddle
::
operators
::
DeleteVarOpInfoMaker
);
paddle
::
operators
::
DeleteVarOpInfoMaker
,
paddle
::
operators
::
DeleteVarOpShapeInference
);
paddle/fluid/operators/distributed/grpc_client.cc
浏览文件 @
fe4cd502
...
...
@@ -79,7 +79,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
auto
*
var
=
p_scope
->
FindVar
(
var_name_val
);
::
grpc
::
ByteBuffer
req
;
SerializeToByteBuffer
(
var_name_val
,
var
,
*
p_ctx
,
&
req
);
SerializeToByteBuffer
(
var_name_val
,
var
,
*
p_ctx
,
&
req
,
""
,
trainer_id_
);
VLOG
(
3
)
<<
s
->
GetVarHandlePtr
()
->
String
()
<<
" begin"
;
...
...
@@ -105,7 +105,10 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
void
ProcGetResponse
(
const
VarHandle
&
var_h
,
const
::
grpc
::
ByteBuffer
&
ret_msg
)
{
framework
::
Variable
*
outvar
=
nullptr
;
DeserializeFromByteBuffer
(
ret_msg
,
*
var_h
.
ctx
(),
var_h
.
scope
(),
&
outvar
);
// get response's trainer_id is not used
int
trainer_id
;
DeserializeFromByteBuffer
(
ret_msg
,
*
var_h
.
ctx
(),
var_h
.
scope
(),
&
outvar
,
&
trainer_id
);
}
template
<
typename
T
>
...
...
@@ -135,6 +138,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
// prepare input
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name_val
);
req
.
set_trainer_id
(
trainer_id_
);
::
grpc
::
ByteBuffer
buf
;
RequestToByteBuffer
<
sendrecv
::
VariableMessage
>
(
req
,
&
buf
);
...
...
paddle/fluid/operators/distributed/grpc_serde.cc
浏览文件 @
fe4cd502
...
...
@@ -34,8 +34,8 @@ namespace distributed {
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
)
{
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
,
const
int
trainer_id
)
{
platform
::
RecordRPCEvent
record_event
(
"serial"
,
&
ctx
);
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
...
...
@@ -45,6 +45,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
size_t
payload_size
;
request
.
set_varname
(
name
);
request
.
set_trainer_id
(
trainer_id
);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
...
...
@@ -147,11 +148,12 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
)
{
framework
::
Variable
**
var
,
int
*
trainer_id
)
{
platform
::
RecordRPCEvent
record_event
(
"deserial"
,
&
ctx
);
operators
::
distributed
::
GRPCVariableResponse
resp
(
scope
,
&
ctx
);
PADDLE_ENFORCE
(
resp
.
Parse
(
msg
)
==
0
,
"parse bytebuffer to tensor error!"
);
*
var
=
resp
.
GetVar
();
*
trainer_id
=
resp
.
GetTrainerId
();
}
}
// namespace distributed
...
...
paddle/fluid/operators/distributed/grpc_serde.h
浏览文件 @
fe4cd502
...
...
@@ -38,12 +38,13 @@ typedef void (*DestroyCallback)(void*);
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_varname
=
std
::
string
());
const
std
::
string
&
out_varname
=
std
::
string
(),
const
int
trainer_id
=
0
);
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
*
scope
,
framework
::
Variable
**
var
);
framework
::
Variable
**
var
,
int
*
trainer_id
);
}
// namespace distributed
}
// namespace operators
...
...
paddle/fluid/operators/distributed/grpc_server.cc
浏览文件 @
fe4cd502
...
...
@@ -102,9 +102,10 @@ class RequestSend final : public RequestBase {
auto
scope
=
request_
->
GetMutableLocalScope
();
auto
invar
=
request_
->
GetVar
();
int
trainer_id
=
request_
->
GetTrainerId
();
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
);
Finish
(
reply_
,
&
responder_
);
}
...
...
@@ -133,13 +134,14 @@ class RequestGet final : public RequestBase {
void
Process
()
override
{
// proc request.
std
::
string
varname
=
request_
.
varname
();
int
trainer_id
=
request_
.
trainer_id
();
VLOG
(
4
)
<<
"RequestGet "
<<
varname
;
auto
scope
=
request_handler_
->
scope
();
auto
invar
=
scope
->
FindVar
(
varname
);
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
);
if
(
outvar
)
{
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
...
...
@@ -179,6 +181,7 @@ class RequestPrefetch final : public RequestBase {
// prefetch process...
std
::
string
in_var_name
=
request_
->
Varname
();
std
::
string
out_var_name
=
request_
->
OutVarname
();
int
trainer_id
=
request_
->
GetTrainerId
();
VLOG
(
4
)
<<
"RequestPrefetch, in_var_name: "
<<
in_var_name
<<
" out_var_name: "
<<
out_var_name
;
...
...
@@ -187,7 +190,8 @@ class RequestPrefetch final : public RequestBase {
// out var must be created in local scope!
framework
::
Variable
*
outvar
=
scope
->
Var
(
out_var_name
);
request_handler_
->
Handle
(
in_var_name
,
scope
,
invar
,
&
outvar
,
out_var_name
);
request_handler_
->
Handle
(
in_var_name
,
scope
,
invar
,
&
outvar
,
trainer_id
,
out_var_name
);
SerializeToByteBuffer
(
out_var_name
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
...
...
@@ -225,12 +229,13 @@ class RequestCheckpointNotify final : public RequestBase {
std
::
string
checkpoint_notify
=
request_
->
Varname
();
std
::
string
checkpoint_dir
=
request_
->
OutVarname
();
int
trainer_id
=
request_
->
GetTrainerId
();
VLOG
(
4
)
<<
"RequestCheckpointNotify notify: "
<<
checkpoint_notify
<<
", dir: "
<<
checkpoint_dir
;
request_handler_
->
Handle
(
checkpoint_notify
,
scope
,
nullptr
,
nullptr
,
checkpoint_dir
);
trainer_id
,
checkpoint_dir
);
Finish
(
reply_
,
&
responder_
);
}
...
...
paddle/fluid/operators/distributed/grpc_variable_response.cc
浏览文件 @
fe4cd502
...
...
@@ -293,6 +293,14 @@ int GRPCVariableResponse::Parse(Source* source) {
}
break
;
}
case
sendrecv
::
VariableMessage
::
kTrainerIdFieldNumber
:
{
uint64_t
trainer_id
=
0
;
if
(
!
input
.
ReadVarint64
(
&
trainer_id
))
{
return
tag
;
}
meta_
.
set_trainer_id
(
trainer_id
);
break
;
}
default:
{
// Unknown tag, return unknown error.
return
-
1
;
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
fe4cd502
...
...
@@ -190,6 +190,7 @@ class RequestHandler {
// }
virtual
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
)
=
0
;
protected:
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
fe4cd502
...
...
@@ -36,6 +36,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
)
{
VLOG
(
4
)
<<
"RequestSendHandler:"
<<
varname
;
...
...
@@ -76,6 +77,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
)
{
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
;
if
(
sync_mode_
)
{
...
...
@@ -88,6 +90,19 @@ bool RequestGetHandler::Handle(const std::string& varname,
}
}
else
{
if
(
varname
!=
FETCH_BARRIER_MESSAGE
&&
varname
!=
COMPLETE_MESSAGE
)
{
if
(
enable_dc_asgd_
)
{
// NOTE: the format is determined by distributed_transpiler.py
std
::
string
param_bak_name
=
string
::
Sprintf
(
"%s.trainer_%d_bak"
,
varname
,
trainer_id
);
VLOG
(
3
)
<<
"getting "
<<
param_bak_name
<<
" trainer_id "
<<
trainer_id
;
auto
var
=
scope_
->
FindVar
(
varname
);
auto
t_orig
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
param_bak
=
scope_
->
Var
(
param_bak_name
);
auto
t
=
param_bak
->
GetMutable
<
framework
::
LoDTensor
>
();
t
->
mutable_data
(
dev_ctx_
->
GetPlace
(),
t_orig
.
type
());
VLOG
(
3
)
<<
"copying "
<<
varname
<<
" to "
<<
param_bak_name
;
framework
::
TensorCopy
(
t_orig
,
dev_ctx_
->
GetPlace
(),
t
);
}
*
outvar
=
scope_
->
FindVar
(
varname
);
}
}
...
...
@@ -98,6 +113,7 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
)
{
VLOG
(
4
)
<<
"RequestPrefetchHandler "
<<
varname
;
...
...
@@ -113,6 +129,7 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework
::
Scope
*
scope
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
)
{
PADDLE_ENFORCE
(
checkpoint_notify_id
!=
-
1
,
...
...
paddle/fluid/operators/distributed/request_handler_impl.h
浏览文件 @
fe4cd502
...
...
@@ -36,20 +36,34 @@ namespace distributed {
class
RequestSendHandler
final
:
public
RequestHandler
{
public:
explicit
RequestSendHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
explicit
RequestSendHandler
(
bool
sync_mode
,
bool
enable_dc_asgd
=
false
)
:
RequestHandler
(
sync_mode
)
{
enable_dc_asgd_
=
enable_dc_asgd
;
}
virtual
~
RequestSendHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
)
override
;
private:
bool
enable_dc_asgd_
;
};
class
RequestGetHandler
final
:
public
RequestHandler
{
public:
explicit
RequestGetHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
explicit
RequestGetHandler
(
bool
sync_mode
,
bool
enable_dc_asgd
=
false
)
:
RequestHandler
(
sync_mode
)
{
enable_dc_asgd_
=
enable_dc_asgd
;
}
virtual
~
RequestGetHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
)
override
;
private:
bool
enable_dc_asgd_
;
};
class
RequestPrefetchHandler
final
:
public
RequestHandler
{
...
...
@@ -58,6 +72,7 @@ class RequestPrefetchHandler final : public RequestHandler {
virtual
~
RequestPrefetchHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
)
override
;
};
...
...
@@ -70,6 +85,7 @@ class RequestCheckpointHandler final : public RequestHandler {
virtual
~
RequestCheckpointHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
)
override
;
private:
...
...
paddle/fluid/operators/distributed/rpc_client.cc
浏览文件 @
fe4cd502
...
...
@@ -24,6 +24,7 @@ namespace distributed {
std
::
once_flag
RPCClient
::
init_flag_
;
std
::
unique_ptr
<
RPCClient
>
RPCClient
::
rpc_client_
(
nullptr
);
int
RPCClient
::
trainer_id_
=
0
;
}
// namespace distributed
}
// namespace operators
...
...
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
fe4cd502
...
...
@@ -72,14 +72,15 @@ class RPCClient {
virtual
bool
Wait
()
=
0
;
template
<
typename
T
>
static
RPCClient
*
GetInstance
()
{
std
::
call_once
(
init_flag_
,
&
RPCClient
::
Init
<
T
>
);
static
RPCClient
*
GetInstance
(
int
trainer_id
)
{
std
::
call_once
(
init_flag_
,
&
RPCClient
::
Init
<
T
>
,
trainer_id
);
return
rpc_client_
.
get
();
}
// Init is called by GetInstance.
template
<
typename
T
>
static
void
Init
()
{
static
void
Init
(
int
trainer_id
)
{
trainer_id_
=
trainer_id
;
if
(
rpc_client_
.
get
()
==
nullptr
)
{
rpc_client_
.
reset
(
new
T
());
rpc_client_
->
InitImpl
();
...
...
@@ -88,6 +89,8 @@ class RPCClient {
protected:
virtual
void
InitImpl
()
{}
// each trainer have exact one trainer id, it should be static
static
int
trainer_id_
;
private:
static
std
::
once_flag
init_flag_
;
...
...
paddle/fluid/operators/distributed/rpc_server_test.cc
浏览文件 @
fe4cd502
...
...
@@ -125,7 +125,7 @@ TEST(PREFETCH, CPU) {
g_req_handler
.
reset
(
new
distributed
::
RequestPrefetchHandler
(
true
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
1
));
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
std
::
thread
server_thread
(
StartServer
,
distributed
::
kRequestPrefetch
);
g_rpc_service
->
WaitServerReady
();
...
...
@@ -165,7 +165,7 @@ TEST(COMPLETE, CPU) {
g_req_handler
.
reset
(
new
distributed
::
RequestSendHandler
(
true
));
g_rpc_service
.
reset
(
new
RPCSERVER_T
(
"127.0.0.1:0"
,
2
));
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
PADDLE_ENFORCE
(
client
!=
nullptr
);
std
::
thread
server_thread
(
StartServer
,
distributed
::
kRequestSend
);
g_rpc_service
->
WaitServerReady
();
...
...
paddle/fluid/operators/distributed/send_recv.proto.in
浏览文件 @
fe4cd502
...
...
@@ -79,6 +79,7 @@ message VariableMessage {
// server stops profiling and generates a profile to /tmp/profile_ps_*
// when profile switches from 1 to 2.
int64 profile = 11;
int64 trainer_id = 12;
}
message VoidMessage {}
paddle/fluid/operators/distributed/variable_response.h
浏览文件 @
fe4cd502
...
...
@@ -92,6 +92,8 @@ class VariableResponse {
return
scope_
->
FindVar
(
meta_
.
varname
());
}
int
GetTrainerId
()
{
return
static_cast
<
int
>
(
meta_
.
trainer_id
());
}
protected:
bool
ReadRaw
(
::
google
::
protobuf
::
io
::
CodedInputStream
*
input
,
const
platform
::
DeviceContext
&
dev_ctx
,
platform
::
Place
place
,
...
...
paddle/fluid/operators/fetch_barrier_op.cc
浏览文件 @
fe4cd502
...
...
@@ -37,7 +37,8 @@ class FetchBarrierOp : public framework::OperatorBase {
const
platform
::
Place
&
place
)
const
override
{
std
::
vector
<
std
::
string
>
eps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
Attr
<
int
>
(
"trainer_id"
));
PADDLE_ENFORCE
(
rpc_client
->
Wait
(),
"internal error in RPCClient"
);
...
...
@@ -61,6 +62,7 @@ This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would knew all variables have been sent.
)DOC"
);
AddAttr
<
int
>
(
"trainer_id"
,
"trainer id from 0 ~ worker_num."
).
SetDefault
(
0
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
,
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to."
)
...
...
paddle/fluid/operators/gen_nccl_id_op.cc
浏览文件 @
fe4cd502
...
...
@@ -61,7 +61,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
std
::
vector
<
std
::
string
>
endpoint_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoint_list"
);
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
for
(
auto
&
ep
:
endpoint_list
)
{
VLOG
(
3
)
<<
"sending nccl id to "
<<
ep
;
...
...
paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
using
ScopedTensorDescriptor
=
platform
::
ScopedTensorDescriptor
;
using
DataLayout
=
platform
::
DataLayout
;
using
ScopedSpatialTransformerDescriptor
=
platform
::
ScopedSpatialTransformerDescriptor
;
template
<
typename
T
>
using
CudnnDataType
=
platform
::
CudnnDataType
<
T
>
;
template
<
typename
T
>
class
CUDNNGridSampleOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use CUDAPlace"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
grid
=
ctx
.
Input
<
Tensor
>
(
"Grid"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Output"
);
int
n
=
input
->
dims
()[
0
];
int
c
=
input
->
dims
()[
1
];
int
h
=
input
->
dims
()[
2
];
int
w
=
input
->
dims
()[
3
];
const
int
size
[
4
]
=
{
n
,
c
,
h
,
w
};
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
grid_data
=
grid
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
ScopedSpatialTransformerDescriptor
st_desc
;
cudnnSpatialTransformerDescriptor_t
cudnn_st_desc
=
st_desc
.
descriptor
<
T
>
(
4
,
size
);
ScopedTensorDescriptor
input_desc
;
ScopedTensorDescriptor
output_desc
;
cudnnTensorDescriptor_t
cudnn_input_desc
=
input_desc
.
descriptor
<
T
>
(
DataLayout
::
kNCHW
,
framework
::
vectorize2int
(
input
->
dims
()));
cudnnTensorDescriptor_t
cudnn_output_desc
=
output_desc
.
descriptor
<
T
>
(
DataLayout
::
kNCHW
,
framework
::
vectorize2int
(
output
->
dims
()));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSpatialTfSamplerForward
(
handle
,
cudnn_st_desc
,
CudnnDataType
<
T
>::
kOne
(),
cudnn_input_desc
,
input_data
,
grid_data
,
CudnnDataType
<
T
>::
kZero
(),
cudnn_output_desc
,
output_data
));
}
};
template
<
typename
T
>
class
CUDNNGridSampleGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use CUDAPlace"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
grid
=
ctx
.
Input
<
Tensor
>
(
"Grid"
);
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
grid_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Grid"
));
auto
output_grad_dims
=
output_grad
->
dims
();
const
int
n
=
output_grad_dims
[
0
];
const
int
c
=
output_grad_dims
[
1
];
const
int
h
=
output_grad_dims
[
2
];
const
int
w
=
output_grad_dims
[
3
];
const
int
size
[
4
]
=
{
n
,
c
,
h
,
w
};
ScopedSpatialTransformerDescriptor
st_dest
;
cudnnSpatialTransformerDescriptor_t
cudnn_st_dest
=
st_dest
.
descriptor
<
T
>
(
4
,
size
);
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
grid_data
=
grid
->
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
output_grad_dims
,
ctx
.
GetPlace
());
T
*
grid_grad_data
=
grid_grad
->
mutable_data
<
T
>
({
n
,
h
,
w
,
2
},
ctx
.
GetPlace
());
ScopedTensorDescriptor
input_desc
;
ScopedTensorDescriptor
input_grad_desc
;
ScopedTensorDescriptor
output_grad_desc
;
cudnnTensorDescriptor_t
cudnn_input_desc
=
input_desc
.
descriptor
<
T
>
(
DataLayout
::
kNCHW
,
framework
::
vectorize2int
(
input
->
dims
()));
cudnnTensorDescriptor_t
cudnn_input_grad_desc
=
input_grad_desc
.
descriptor
<
T
>
(
DataLayout
::
kNCHW
,
framework
::
vectorize2int
(
input_grad
->
dims
()));
cudnnTensorDescriptor_t
cudnn_output_grad_desc
=
output_grad_desc
.
descriptor
<
T
>
(
DataLayout
::
kNCHW
,
framework
::
vectorize2int
(
output_grad
->
dims
()));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSpatialTfSamplerBackward
(
handle
,
cudnn_st_dest
,
CudnnDataType
<
T
>::
kOne
(),
cudnn_input_desc
,
input_data
,
CudnnDataType
<
T
>::
kZero
(),
cudnn_input_grad_desc
,
input_grad_data
,
CudnnDataType
<
T
>::
kOne
(),
cudnn_output_grad_desc
,
output_grad_data
,
grid_data
,
CudnnDataType
<
T
>::
kZero
(),
grid_grad_data
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_KERNEL
(
grid_sampler
,
CUDNN
,
plat
::
CUDAPlace
,
paddle
::
operators
::
CUDNNGridSampleOpKernel
<
float
>
,
paddle
::
operators
::
CUDNNGridSampleOpKernel
<
double
>
);
REGISTER_OP_KERNEL
(
grid_sampler_grad
,
CUDNN
,
plat
::
CUDAPlace
,
paddle
::
operators
::
CUDNNGridSampleGradOpKernel
<
float
>
,
paddle
::
operators
::
CUDNNGridSampleGradOpKernel
<
double
>
);
paddle/fluid/operators/grid_sampler_op.cc
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
GridSampleOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of GridSampleOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grid"
),
"Input(Grid) of GridSampleOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Output"
),
"Output(Output) of GridSampleOp should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
grid_dims
=
ctx
->
GetInputDim
(
"Grid"
);
PADDLE_ENFORCE
(
x_dims
.
size
()
==
4
,
"Input(X) of GridSampleOp should be 4-D Tensor."
);
PADDLE_ENFORCE
(
grid_dims
.
size
()
==
4
,
"Input(Grid) of GridSampleOp should be 4-D Tensor."
);
PADDLE_ENFORCE
(
grid_dims
[
3
]
==
2
,
"Input(Grid) dims[3] should be 2."
);
PADDLE_ENFORCE_EQ
(
grid_dims
[
0
],
x_dims
[
0
],
"Input(X) and Input(Grid) dims[0] should be equal."
);
PADDLE_ENFORCE_EQ
(
grid_dims
[
1
],
x_dims
[
2
],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal."
);
PADDLE_ENFORCE_EQ
(
grid_dims
[
2
],
x_dims
[
3
],
"Input(X) dims[3] and Input(Grid) dims[2] should be equal."
);
ctx
->
SetOutputDim
(
"Output"
,
x_dims
);
ctx
->
ShareLoD
(
"X"
,
"Output"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
#endif
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
library_
);
}
};
class
GridSampleOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) The input data of GridSampleOp, "
"This is a 4-D tensor with shape of [N, C, H, W]"
);
AddInput
(
"Grid"
,
"(Tensor) The input grid of GridSampleOp generated by AffineGridOp, "
"This is a 4-D tensor with shape of [N, H, W, 2] is the concatenation "
"of x and y coordinates with shape [N, H, W] in last dimention"
);
AddOutput
(
"Output"
,
"(Tensor) Output tensor with shape [N, C, H, W]"
);
AddAttr
<
bool
>
(
"use_cudnn"
,
"(bool, default true) Only used in cudnn kernel, need install cudnn"
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
This operation samples input X by using bilinear interpolation based on
flow field grid, which is usually gennerated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexng the 3rd
dimention (in height dimension), finally results is the bilinear
interpolation value of 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points.
wn ------- y_n ------- en
| | |
| d_n |
| | |
x_w --d_w-- grid--d_e-- x_e
| | |
| d_s |
| | |
ws ------- y_s ------- wn
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_s + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-east point value
es = X[:, :, y_s, x_w] // north-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
)DOC"
);
}
};
class
GridSampleOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
input_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
grid_dims
=
ctx
->
GetInputDim
(
"Grid"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
input_dims
);
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Grid"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Grid"
),
grid_dims
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
#endif
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
library_
);
}
};
class
GridSampleGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
op
=
new
framework
::
OpDesc
();
op
->
SetType
(
"grid_sampler_grad"
);
op
->
SetInput
(
"X"
,
Input
(
"X"
));
op
->
SetInput
(
"Grid"
,
Input
(
"Grid"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Output"
),
OutputGrad
(
"Output"
));
op
->
SetAttrMap
(
Attrs
());
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Grid"
),
InputGrad
(
"Grid"
));
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
grid_sampler
,
ops
::
GridSampleOp
,
ops
::
GridSampleOpMaker
,
ops
::
GridSampleGradMaker
);
REGISTER_OPERATOR
(
grid_sampler_grad
,
ops
::
GridSampleOpGrad
);
REGISTER_OP_CPU_KERNEL
(
grid_sampler
,
ops
::
GridSampleOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
GridSampleOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
grid_sampler_grad
,
ops
::
GridSampleGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
GridSampleGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/grid_sampler_op.h
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
using
Array3
=
Eigen
::
DSizes
<
int64_t
,
3
>
;
using
Array4
=
Eigen
::
DSizes
<
int64_t
,
4
>
;
template
<
typename
T
>
static
inline
bool
isInBound
(
T
x
,
T
y
,
T
x_max
,
T
y_max
)
{
if
(
x
<
0
||
x
>
x_max
||
y
<
0
||
y
>
y_max
)
{
return
false
;
}
return
true
;
}
template
<
typename
T
>
static
void
CalcGridLocations
(
const
platform
::
CPUDeviceContext
&
ctx
,
const
Tensor
&
grid
,
Tensor
*
x_w
,
Tensor
*
x_e
,
Tensor
*
y_n
,
Tensor
*
y_s
,
Tensor
*
d_w
,
Tensor
*
d_e
,
Tensor
*
d_n
,
Tensor
*
d_s
)
{
auto
&
place
=
*
ctx
.
eigen_device
();
const
int
n
=
grid
.
dims
()[
0
];
const
int
h
=
grid
.
dims
()[
1
];
const
int
w
=
grid
.
dims
()[
2
];
const
T
x_max
=
static_cast
<
T
>
(
w
-
1
);
const
T
y_max
=
static_cast
<
T
>
(
h
-
1
);
// split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim
Tensor
grid_x
,
grid_y
;
T
*
grid_x_data
=
grid_x
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
T
*
grid_y_data
=
grid_y
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
const
T
*
grid_data
=
grid
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
n
*
h
*
w
;
i
++
)
{
grid_x_data
[
i
]
=
grid_data
[
2
*
i
];
grid_y_data
[
i
]
=
grid_data
[(
2
*
i
)
+
1
];
}
Tensor
ones
;
ones
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
auto
ones_t
=
EigenTensor
<
T
,
3
>::
From
(
ones
).
setConstant
(
1.0
);
// scale grid to [0, h-1/w-1]
auto
grid_x_t
=
EigenTensor
<
T
,
3
>::
From
(
grid_x
);
auto
grid_y_t
=
EigenTensor
<
T
,
3
>::
From
(
grid_y
);
grid_x_t
.
device
(
place
)
=
0.5
*
((
grid_x_t
+
ones_t
)
*
x_max
);
grid_y_t
.
device
(
place
)
=
0.5
*
((
grid_y_t
+
ones_t
)
*
y_max
);
// calculate coords of 4 corner points
x_w
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
x_e
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
y_n
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
y_s
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
auto
x_w_t
=
EigenTensor
<
T
,
3
>::
From
(
*
x_w
);
auto
x_e_t
=
EigenTensor
<
T
,
3
>::
From
(
*
x_e
);
auto
y_n_t
=
EigenTensor
<
T
,
3
>::
From
(
*
y_n
);
auto
y_s_t
=
EigenTensor
<
T
,
3
>::
From
(
*
y_s
);
x_w_t
.
device
(
place
)
=
grid_x_t
.
floor
();
x_e_t
.
device
(
place
)
=
x_w_t
+
ones_t
;
y_n_t
.
device
(
place
)
=
grid_y_t
.
floor
();
y_s_t
.
device
(
place
)
=
y_n_t
+
ones_t
;
// calculate distances to 4 sides
d_w
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
d_e
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
d_n
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
d_s
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
auto
d_w_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_w
);
auto
d_e_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_e
);
auto
d_n_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_n
);
auto
d_s_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_s
);
d_w_t
.
device
(
place
)
=
grid_x_t
-
x_w_t
;
d_e_t
.
device
(
place
)
=
x_e_t
-
grid_x_t
;
d_n_t
.
device
(
place
)
=
grid_y_t
-
y_n_t
;
d_s_t
.
device
(
place
)
=
y_s_t
-
grid_y_t
;
}
template
<
typename
T
>
static
void
GetGridPointValue
(
const
Tensor
&
input
,
Tensor
*
output
,
const
Tensor
&
x
,
const
Tensor
&
y
)
{
const
int
n
=
input
.
dims
()[
0
];
const
int
c
=
input
.
dims
()[
1
];
const
int
h
=
input
.
dims
()[
2
];
const
int
w
=
input
.
dims
()[
3
];
auto
x_t
=
EigenTensor
<
T
,
3
>::
From
(
x
);
auto
y_t
=
EigenTensor
<
T
,
3
>::
From
(
y
);
auto
output_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output
).
setConstant
((
T
)
0
);
auto
input_t
=
EigenTensor
<
T
,
4
>::
From
(
input
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
if
(
isInBound
(
x_t
(
i
,
k
,
l
),
y_t
(
i
,
k
,
l
),
(
T
)(
w
-
1
),
(
T
)(
h
-
1
)))
{
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
output_t
(
i
,
j
,
k
,
l
)
=
input_t
(
i
,
j
,
static_cast
<
int
>
(
round
(
y_t
(
i
,
k
,
l
))),
static_cast
<
int
>
(
round
(
x_t
(
i
,
k
,
l
))));
}
}
}
}
}
}
template
<
typename
T
>
static
void
GatherOutputGradToInputGrad
(
const
Tensor
&
output_grad
,
Tensor
*
input_grad
,
const
Tensor
&
x
,
const
Tensor
&
y
,
const
Tensor
&
d1
,
const
Tensor
&
d2
)
{
const
int
n
=
output_grad
.
dims
()[
0
];
const
int
c
=
output_grad
.
dims
()[
1
];
const
int
h
=
output_grad
.
dims
()[
2
];
const
int
w
=
output_grad
.
dims
()[
3
];
auto
x_t
=
EigenTensor
<
T
,
3
>::
From
(
x
);
auto
y_t
=
EigenTensor
<
T
,
3
>::
From
(
y
);
auto
d1_t
=
EigenTensor
<
T
,
3
>::
From
(
d1
);
auto
d2_t
=
EigenTensor
<
T
,
3
>::
From
(
d2
);
auto
input_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
*
input_grad
);
auto
output_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
output_grad
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
if
(
isInBound
(
x_t
(
i
,
k
,
l
),
y_t
(
i
,
k
,
l
),
(
T
)(
w
-
1
),
(
T
)(
h
-
1
)))
{
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
input_grad_t
(
i
,
j
,
static_cast
<
int
>
(
round
(
y_t
(
i
,
k
,
l
))),
static_cast
<
int
>
(
round
(
x_t
(
i
,
k
,
l
))))
+=
output_grad_t
(
i
,
j
,
k
,
l
)
*
d1_t
(
i
,
k
,
l
)
*
d2_t
(
i
,
k
,
l
);
}
}
}
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
GridSampleOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
grid
=
ctx
.
Input
<
Tensor
>
(
"Grid"
);
const
int
n
=
input
->
dims
()[
0
];
const
int
c
=
input
->
dims
()[
1
];
const
int
h
=
input
->
dims
()[
2
];
const
int
w
=
input
->
dims
()[
3
];
// calc locations and distances of 4 corner points
Tensor
x_w
,
x_e
,
y_n
,
y_s
;
Tensor
d_w
,
d_e
,
d_n
,
d_s
;
CalcGridLocations
<
T
>
(
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>(),
*
grid
,
&
x_w
,
&
x_e
,
&
y_n
,
&
y_s
,
&
d_w
,
&
d_e
,
&
d_n
,
&
d_s
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Output"
);
output
->
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
output
,
static_cast
<
T
>
(
0
));
// calc 4 corner points value
Tensor
v_wn
,
v_en
,
v_ws
,
v_es
;
v_wn
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_en
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_ws
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_es
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
GetGridPointValue
<
T
>
(
*
input
,
&
v_wn
,
x_w
,
y_n
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_en
,
x_e
,
y_n
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_ws
,
x_w
,
y_s
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_es
,
x_e
,
y_s
);
auto
d_w_t
=
EigenTensor
<
T
,
3
>::
From
(
d_w
);
auto
d_e_t
=
EigenTensor
<
T
,
3
>::
From
(
d_e
);
auto
d_n_t
=
EigenTensor
<
T
,
3
>::
From
(
d_n
);
auto
d_s_t
=
EigenTensor
<
T
,
3
>::
From
(
d_s
);
auto
d_w_scaled_t
=
d_w_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
d_e_scaled_t
=
d_e_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
d_n_scaled_t
=
d_n_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
d_s_scaled_t
=
d_s_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
v_wn_t
=
EigenTensor
<
T
,
4
>::
From
(
v_wn
);
auto
v_en_t
=
EigenTensor
<
T
,
4
>::
From
(
v_en
);
auto
v_ws_t
=
EigenTensor
<
T
,
4
>::
From
(
v_ws
);
auto
v_es_t
=
EigenTensor
<
T
,
4
>::
From
(
v_es
);
auto
output_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output
);
// bilinear interpolaetion by 4 corner points
output_t
.
device
(
place
)
=
v_wn_t
*
d_e_scaled_t
*
d_s_scaled_t
+
v_en_t
*
d_w_scaled_t
*
d_s_scaled_t
+
v_ws_t
*
d_e_scaled_t
*
d_n_scaled_t
+
v_es_t
*
d_w_scaled_t
*
d_n_scaled_t
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
GridSampleGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
grid
=
ctx
.
Input
<
Tensor
>
(
"Grid"
);
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
const
int
n
=
input
->
dims
()[
0
];
const
int
c
=
input
->
dims
()[
1
];
const
int
h
=
input
->
dims
()[
2
];
const
int
w
=
input
->
dims
()[
3
];
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
input_grad
->
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
input_grad
,
static_cast
<
T
>
(
0
));
auto
*
grid_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Grid"
));
grid_grad
->
mutable_data
<
T
>
({
n
,
h
,
w
,
2
},
ctx
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
grid_grad
,
static_cast
<
T
>
(
0
));
Tensor
x_w
,
x_e
,
y_n
,
y_s
;
Tensor
d_w
,
d_e
,
d_n
,
d_s
;
CalcGridLocations
<
T
>
(
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>(),
*
grid
,
&
x_w
,
&
x_e
,
&
y_n
,
&
y_s
,
&
d_w
,
&
d_e
,
&
d_n
,
&
d_s
);
// gather output grad value to input grad by corner point coords and weight
GatherOutputGradToInputGrad
<
T
>
(
*
output_grad
,
input_grad
,
x_w
,
y_n
,
d_e
,
d_s
);
GatherOutputGradToInputGrad
<
T
>
(
*
output_grad
,
input_grad
,
x_w
,
y_s
,
d_e
,
d_n
);
GatherOutputGradToInputGrad
<
T
>
(
*
output_grad
,
input_grad
,
x_e
,
y_n
,
d_w
,
d_s
);
GatherOutputGradToInputGrad
<
T
>
(
*
output_grad
,
input_grad
,
x_e
,
y_s
,
d_w
,
d_n
);
// calc 4 corner points value
Tensor
v_wn
,
v_en
,
v_ws
,
v_es
;
v_wn
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_en
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_ws
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_es
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
GetGridPointValue
<
T
>
(
*
input
,
&
v_wn
,
x_w
,
y_n
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_en
,
x_e
,
y_n
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_ws
,
x_w
,
y_s
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_es
,
x_e
,
y_s
);
auto
v_wn_t
=
EigenTensor
<
T
,
4
>::
From
(
v_wn
);
auto
v_en_t
=
EigenTensor
<
T
,
4
>::
From
(
v_en
);
auto
v_ws_t
=
EigenTensor
<
T
,
4
>::
From
(
v_ws
);
auto
v_es_t
=
EigenTensor
<
T
,
4
>::
From
(
v_es
);
auto
d_w_t
=
EigenTensor
<
T
,
3
>::
From
(
d_w
);
auto
d_e_t
=
EigenTensor
<
T
,
3
>::
From
(
d_e
);
auto
d_n_t
=
EigenTensor
<
T
,
3
>::
From
(
d_n
);
auto
d_s_t
=
EigenTensor
<
T
,
3
>::
From
(
d_s
);
auto
output_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output_grad
);
Tensor
grid_grad_x
,
grid_grad_y
;
grid_grad_x
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
grid_grad_y
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
auto
grid_grad_x_t
=
EigenTensor
<
T
,
3
>::
From
(
grid_grad_x
).
setConstant
(
0.0
);
auto
grid_grad_y_t
=
EigenTensor
<
T
,
3
>::
From
(
grid_grad_y
).
setConstant
(
0.0
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
grid_grad_x_t
(
i
,
k
,
l
)
+=
((
v_en_t
(
i
,
j
,
k
,
l
)
-
v_wn_t
(
i
,
j
,
k
,
l
))
*
d_s_t
(
i
,
k
,
l
)
+
(
v_es_t
(
i
,
j
,
k
,
l
)
-
v_ws_t
(
i
,
j
,
k
,
l
))
*
d_n_t
(
i
,
k
,
l
))
*
output_grad_t
(
i
,
j
,
k
,
l
);
grid_grad_y_t
(
i
,
k
,
l
)
+=
((
v_ws_t
(
i
,
j
,
k
,
l
)
-
v_wn_t
(
i
,
j
,
k
,
l
))
*
d_e_t
(
i
,
k
,
l
)
+
(
v_es_t
(
i
,
j
,
k
,
l
)
-
v_en_t
(
i
,
j
,
k
,
l
))
*
d_w_t
(
i
,
k
,
l
))
*
output_grad_t
(
i
,
j
,
k
,
l
);
}
}
}
}
const
T
x_max
=
static_cast
<
T
>
(
w
-
1
);
const
T
y_max
=
static_cast
<
T
>
(
h
-
1
);
grid_grad_x_t
=
grid_grad_x_t
*
(
x_max
/
(
T
)
2
);
grid_grad_y_t
=
grid_grad_y_t
*
(
y_max
/
(
T
)
2
);
// gather grid_grad [x, y] in 3rd Dim
T
*
grid_grad_data
=
grid_grad
->
data
<
T
>
();
T
*
grid_grad_x_data
=
grid_grad_x
.
data
<
T
>
();
T
*
grid_grad_y_data
=
grid_grad_y
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
n
*
h
*
w
;
i
++
)
{
grid_grad_data
[
2
*
i
]
=
grid_grad_x_data
[
i
];
grid_grad_data
[
2
*
i
+
1
]
=
grid_grad_y_data
[
i
];
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
fe4cd502
...
...
@@ -218,23 +218,26 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
)
const
{
VLOG
(
2
)
<<
"RunAsyncLoop"
;
// grad name to block id
std
::
unordered_map
<
std
::
string
,
int32_t
>
grad_to_block_id
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id_to_grad
;
auto
grad_to_block_id_str
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"grad_to_block_id"
);
for
(
const
auto
&
grad_and_id
:
grad_to_block_id_str
)
{
DoubleFindMap
<
std
::
string
,
int32_t
>
grad_to_block_id
;
auto
append_block_maps
=
[](
DoubleFindMap
<
std
::
string
,
int32_t
>
*
out_map
,
const
std
::
string
&
grad_and_id
)
{
std
::
vector
<
std
::
string
>
pieces
;
split
(
grad_and_id
,
':'
,
&
pieces
);
VLOG
(
3
)
<<
"after split,
grad
= "
<<
pieces
[
0
]
<<
", id="
<<
pieces
[
1
];
VLOG
(
3
)
<<
"after split,
key
= "
<<
pieces
[
0
]
<<
", id="
<<
pieces
[
1
];
PADDLE_ENFORCE_EQ
(
pieces
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
grad_to_block_id
.
count
(
pieces
[
0
]),
0
);
PADDLE_ENFORCE_EQ
(
out_map
->
count
(
pieces
[
0
]),
0
);
int
block_id
=
std
::
stoi
(
pieces
[
1
]);
grad_to_block_id
[
pieces
[
0
]]
=
block_id
;
id_to_grad
[
block_id
]
=
pieces
[
0
];
(
*
out_map
)[
pieces
[
0
]]
=
block_id
;
};
for
(
const
auto
&
grad_and_id
:
grad_to_block_id_str
)
{
append_block_maps
(
&
grad_to_block_id
,
grad_and_id
);
}
size_t
num_blocks
=
program
->
Size
();
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
"server program should have at least 2 blocks"
);
...
...
@@ -244,15 +247,22 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
block_list
.
push_back
(
blkid
);
}
auto
optimize_prepared
=
executor
->
Prepare
(
*
program
,
block_list
);
// execute global block if needed
if
(
block_list
[
0
]
==
1
&&
id_to_grad
.
count
(
1
)
==
0
)
{
// execute global block if needed, block id 1 in the program is global
// block if it's not bind to a grad var for it's update.
if
(
block_list
[
0
]
==
1
&&
grad_to_block_id
.
find_value
(
static_cast
<
int32_t
>
(
1
))
==
grad_to_block_id
.
end
())
{
executor
->
RunPreparedContext
(
optimize_prepared
[
0
].
get
(),
recv_scope
);
}
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
grad_to_prepared_ctx
;
grad_to_prepared_ctx
,
param_to_prepared_ctx
;
for
(
size_t
i
=
0
;
i
<
block_list
.
size
();
++
i
)
{
grad_to_prepared_ctx
[
id_to_grad
[
block_list
[
i
]]]
=
optimize_prepared
[
i
];
auto
blkid
=
block_list
[
i
];
auto
it
=
grad_to_block_id
.
find_value
(
blkid
);
if
(
it
!=
grad_to_block_id
.
end
())
{
grad_to_prepared_ctx
[
it
->
first
]
=
optimize_prepared
[
i
];
}
}
request_send_handler_
->
SetGradToPreparedCtx
(
&
grad_to_prepared_ctx
);
...
...
@@ -315,6 +325,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
bool
dc_sgd
=
Attr
<
bool
>
(
"dc_asgd"
);
auto
fan_in
=
Attr
<
int
>
(
"Fanin"
);
auto
inputs
=
Inputs
(
"X"
);
...
...
@@ -328,8 +339,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_
.
reset
(
new
RPCSERVER_T
(
endpoint
,
fan_in
));
request_send_handler_
.
reset
(
new
distributed
::
RequestSendHandler
(
sync_mode
));
request_get_handler_
.
reset
(
new
distributed
::
RequestGetHandler
(
sync_mode
));
request_send_handler_
.
reset
(
new
distributed
::
RequestSendHandler
(
sync_mode
,
dc_sgd
));
request_get_handler_
.
reset
(
new
distributed
::
RequestGetHandler
(
sync_mode
,
dc_sgd
));
request_prefetch_handler_
.
reset
(
new
distributed
::
RequestPrefetchHandler
(
sync_mode
));
request_checkpoint_handler_
.
reset
(
new
distributed
::
RequestCheckpointHandler
(
...
...
@@ -443,6 +456,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"a map from grad name to it's optimize block id"
)
.
SetDefault
({});
AddAttr
<
bool
>
(
"sync_mode"
,
"if works at sync_mode or not"
).
SetDefault
(
true
);
AddAttr
<
bool
>
(
"dc_asgd"
,
"set to true will enable DC-ASGD training."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
vector
<
framework
::
BlockDesc
*>>
(
kOptimizeBlocks
,
"Optimize blocks to run on server side."
)
.
SetDefault
({});
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
fe4cd502
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <atomic>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/executor.h"
...
...
@@ -37,6 +38,17 @@ constexpr char kCheckpointBlockId[] = "checkpint_block_id";
void
RunServer
(
std
::
shared_ptr
<
distributed
::
RPCServer
>
service
);
template
<
class
TKey
,
class
TValue
>
class
DoubleFindMap
:
public
std
::
unordered_map
<
TKey
,
TValue
>
{
public:
typename
std
::
unordered_map
<
TKey
,
TValue
>::
iterator
find_value
(
TValue
v
)
{
return
std
::
find_if
(
this
->
begin
(),
this
->
end
(),
[
&
v
](
const
std
::
pair
<
const
std
::
string
,
int
>
p
)
{
return
p
.
second
==
v
;
});
}
};
class
ListenAndServOp
:
public
framework
::
OperatorBase
{
public:
ListenAndServOp
(
const
std
::
string
&
type
,
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
fe4cd502
...
...
@@ -76,6 +76,6 @@ endif()
cc_test
(
concat_test SRCS concat_test.cc DEPS concat_and_split
)
cc_test
(
cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info
)
cc_library
(
jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
DEPS cpu_info cblas
)
SRCS jit_kernel.cc jit_
gen.cc jit_code.cc jit_
kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
DEPS cpu_info cblas
gflags enforce
)
cc_test
(
jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel
)
paddle/fluid/operators/math/jit_code.cc
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_code.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
using
namespace
platform
::
jit
;
// NOLINT
bool
VMulJitCode
::
init
(
int
d
)
{
// It's not necessary to use avx512 since it would slow down the frequency
// and this kernel is not compute bound.
return
MayIUse
(
avx
);
}
void
VMulJitCode
::
generate
()
{
// do not need push stack, and do not need save avx512reg if do not use avx512
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
vmovups
(
ymm_src2
,
ptr
[
param2
+
offset
]);
vmulps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
vmovups
(
ptr
[
param3
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX_FLOAT_BLOCK
;
}
int
rest
=
num_
%
AVX_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmovups
(
xmm_src2
,
ptr
[
param2
+
offset
]);
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vmovups
(
ptr
[
param3
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
4
;
rest
-=
4
;
}
if
(
rest
>=
2
)
{
vmovq
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmovq
(
xmm_src2
,
ptr
[
param2
+
offset
]);
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vmovq
(
ptr
[
param3
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
2
;
rest
-=
2
;
}
if
(
rest
>
0
)
{
vmovss
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmovss
(
xmm_src2
,
ptr
[
param2
+
offset
]);
vmulss
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
vmovss
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
ret
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_code.h
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/jit_gen.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
using
reg64_t
=
const
Xbyak
::
Reg64
;
using
reg32_t
=
const
Xbyak
::
Reg32
;
using
xmm_t
=
const
Xbyak
::
Xmm
;
using
ymm_t
=
const
Xbyak
::
Ymm
;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
class
VMulJitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
VMulJitCode
);
explicit
VMulJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param3
{
abi_param3
};
xmm_t
xmm_src1
=
xmm_t
(
0
);
xmm_t
xmm_src2
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
2
);
ymm_t
ymm_src1
=
ymm_t
(
0
);
ymm_t
ymm_src2
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
2
);
};
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_gen.cc
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_gen.h"
#include <fstream>
#include <iostream>
#include <sstream>
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool
(
dump_jitcode
,
false
,
"Whether to dump the jitcode to file"
);
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
constexpr
Xbyak
::
Operand
::
Code
g_abi_regs
[]
=
{
Xbyak
::
Operand
::
RBX
,
Xbyak
::
Operand
::
RBP
,
Xbyak
::
Operand
::
R12
,
Xbyak
::
Operand
::
R13
,
Xbyak
::
Operand
::
R14
,
Xbyak
::
Operand
::
R15
};
constexpr
int
num_g_abi_regs
=
sizeof
(
g_abi_regs
)
/
sizeof
(
g_abi_regs
[
0
]);
void
JitCode
::
preCode
()
{
for
(
int
i
=
0
;
i
<
num_g_abi_regs
;
++
i
)
{
push
(
Xbyak
::
Reg64
(
g_abi_regs
[
i
]));
}
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx512f
))
{
mov
(
reg_EVEX_max_8b_offt
,
2
*
EVEX_max_8b_offt
);
}
}
void
JitCode
::
postCode
()
{
for
(
int
i
=
0
;
i
<
num_g_abi_regs
;
++
i
)
{
pop
(
Xbyak
::
Reg64
(
g_abi_regs
[
num_g_abi_regs
-
1
-
i
]));
}
ret
();
}
void
JitCode
::
dumpCode
(
const
Xbyak
::
uint8
*
code
)
const
{
if
(
code
)
{
static
int
counter
=
0
;
std
::
ostringstream
filename
;
filename
<<
"paddle_jitcode_"
<<
name
()
<<
"."
<<
counter
<<
".bin"
;
counter
++
;
std
::
ofstream
fout
(
filename
.
str
(),
std
::
ios
::
out
);
if
(
fout
.
is_open
())
{
fout
.
write
(
reinterpret_cast
<
const
char
*>
(
code
),
getSize
());
fout
.
close
();
}
}
}
Xbyak
::
Address
JitCode
::
EVEX_compress_addr
(
Xbyak
::
Reg64
base
,
int
offt
,
bool
bcast
)
{
int
scale
=
0
;
if
(
EVEX_max_8b_offt
<=
offt
&&
offt
<
3
*
EVEX_max_8b_offt
)
{
offt
=
offt
-
2
*
EVEX_max_8b_offt
;
scale
=
1
;
}
else
if
(
3
*
EVEX_max_8b_offt
<=
offt
&&
offt
<
5
*
EVEX_max_8b_offt
)
{
offt
=
offt
-
4
*
EVEX_max_8b_offt
;
scale
=
2
;
}
auto
re
=
Xbyak
::
RegExp
()
+
base
+
offt
;
if
(
scale
)
{
re
=
re
+
reg_EVEX_max_8b_offt
*
scale
;
}
if
(
bcast
)
{
return
zword_b
[
re
];
}
else
{
return
zword
[
re
];
}
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_gen.h
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <type_traits>
#include "paddle/fluid/platform/macros.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
DECLARE_bool
(
dump_jitcode
);
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
#define DECLARE_JIT_CODE(codename) \
const char *name() const override { return #codename; }
// Application Binary Interface
constexpr
Xbyak
::
Operand
::
Code
abi_param1
(
Xbyak
::
Operand
::
RDI
),
abi_param2
(
Xbyak
::
Operand
::
RSI
),
abi_param3
(
Xbyak
::
Operand
::
RDX
),
abi_param4
(
Xbyak
::
Operand
::
RCX
),
abi_not_param1
(
Xbyak
::
Operand
::
RCX
);
class
JitCode
:
public
Xbyak
::
CodeGenerator
{
public:
explicit
JitCode
(
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
Xbyak
::
CodeGenerator
(
code_size
,
code_ptr
)
{}
virtual
~
JitCode
()
{}
virtual
const
char
*
name
()
const
=
0
;
virtual
void
generate
()
=
0
;
template
<
typename
FUNC
>
const
FUNC
getCode
()
{
this
->
generate
();
const
Xbyak
::
uint8
*
code
=
CodeGenerator
::
getCode
();
if
(
FLAGS_dump_jitcode
)
{
this
->
dumpCode
(
code
);
}
return
reinterpret_cast
<
const
FUNC
>
(
code
);
}
DISABLE_COPY_AND_ASSIGN
(
JitCode
);
protected:
Xbyak
::
Reg64
param1
{
abi_param1
};
const
int
EVEX_max_8b_offt
=
0x200
;
const
Xbyak
::
Reg64
reg_EVEX_max_8b_offt
=
rbp
;
void
preCode
();
void
postCode
();
void
dumpCode
(
const
Xbyak
::
uint8
*
code
)
const
;
void
L
(
const
char
*
label
)
{
Xbyak
::
CodeGenerator
::
L
(
label
);
}
void
L
(
const
Xbyak
::
Label
&
label
)
{
Xbyak
::
CodeGenerator
::
L
(
label
);
}
// Enhanced vector extension
Xbyak
::
Address
EVEX_compress_addr
(
Xbyak
::
Reg64
base
,
int
offt
,
bool
bcast
=
false
);
};
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
fe4cd502
...
...
@@ -39,6 +39,7 @@ class Kernel {
public:
Kernel
()
=
default
;
virtual
~
Kernel
()
=
default
;
// TODO(TJ): below members should be deprecated.
int
num_
{
0
};
int
end_
{
0
};
int
rest_
{
0
};
...
...
@@ -64,7 +65,7 @@ class KernelPool {
template
<
typename
T
>
class
VMulKernel
:
public
Kernel
{
public:
v
irtual
void
Compute
(
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
=
0
;
v
oid
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
;
};
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
fe4cd502
...
...
@@ -14,7 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_code.h"
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
...
...
@@ -27,65 +30,77 @@ namespace paddle {
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
template
<
typename
T
>
void
VMulRefer
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
}
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
void
VMulMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
>
void
VMulMKL
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsMul
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VMulMKL
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdMul
(
n
,
x
,
y
,
z
);
}
#endif
/* VMUL JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
this
->
num_
=
d
;
}
void
Compute
(
const
T
*
x
,
const
T
*
y
,
T
*
z
)
const
override
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
z
[
i
]
=
x
[
i
]
*
y
[
i
];
}
static
inline
std
::
string
name
(
int
d
)
{
PADDLE_THROW
(
"DType should be either float or double"
);
}
};
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
if
(
useJIT
(
d
))
{
// roughly estimate the size of code
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VMulJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
template <> \
void VMulKernelImpl<float, isa, block>::Compute( \
const float* x, const float* y, float* z) const { \
platform::dynload::vsMul(this->num_, x, y, z); \
if
(
useMKL
(
d
))
{
this
->
Compute
=
VMulMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
VMulRefer
<
T
>
;
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VMulKernelImpl<double, isa, block>::Compute( \
const double* x, const double* y, double* z) const { \
platform::dynload::vdMul(this->num_, x, y, z); \
}
private:
std
::
unique_ptr
<
gen
::
VMulJitCode
>
jitcode_
{
nullptr
};
};
FOR_EACH_ISA
(
MKL_FLOAT
,
kGT16
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
#endif
template
<
>
bool
VMulKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VMulJitCode
::
init
(
d
);
}
#define INTRI8_FLOAT(isa) \
template <> \
void VMulKernelImpl<float, isa, kEQ8>::Compute( \
const float* x, const float* y, float* z) const { \
__m256 tmpx, tmpy; \
tmpx = _mm256_loadu_ps(x); \
tmpy = _mm256_loadu_ps(y); \
tmpx = _mm256_mul_ps(tmpx, tmpy); \
_mm256_storeu_ps(z, tmpx); \
}
template
<
>
bool
VMulKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
jit
::
MayIUse
(
jit
::
avx512f
)
&&
d
>
512
;
}
// avx > for > mkl
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef MKL_FLOAT
#undef MKL_DOUBLE
template
<
>
bool
VMulKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
/* VADD JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
...
...
@@ -465,13 +480,12 @@ INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
#undef INTRI16_FLOAT
#undef INTRI_COMMON_FLOAT
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddb
,
VAddBiasKernel
);
REGISTER_JITKERNEL
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
REGISTER_JITKERNEL
(
videntity
,
VIdentityKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vaddb
,
VAddBiasKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
vaddrelu
,
VAddReluKernel
);
REGISTER_JITKERNEL_DEPRECATED
(
videntity
,
VIdentityKernel
);
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
浏览文件 @
fe4cd502
...
...
@@ -288,7 +288,7 @@ INTRIAVX512_FLOAT(kGT16);
#undef INIT_ALPHA
#undef UPDATE_ALPHA
REGISTER_JITKERNEL
(
crf_decode
,
CRFDecodeKernel
);
REGISTER_JITKERNEL
_DEPRECATED
(
crf_decode
,
CRFDecodeKernel
);
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
fe4cd502
...
...
@@ -250,7 +250,7 @@ INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
#undef MKL_FLOAT
#undef MKL_DOUBLE
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
REGISTER_JITKERNEL
_DEPRECATED
(
vexp
,
VExpKernel
);
/* VSigmoid JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
...
...
@@ -396,7 +396,7 @@ INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
#undef INTRI_GT16_FLOAT
#undef INTRI_VSIGMOID
REGISTER_JITKERNEL
(
vsigmoid
,
VSigmoidKernel
);
REGISTER_JITKERNEL
_DEPRECATED
(
vsigmoid
,
VSigmoidKernel
);
/* VTanh JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
...
...
@@ -531,7 +531,7 @@ INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
#undef INTRI_GT16_FLOAT
#undef INTRI_VTANH
REGISTER_JITKERNEL
(
vtanh
,
VTanhKernel
);
REGISTER_JITKERNEL
_DEPRECATED
(
vtanh
,
VTanhKernel
);
#undef JITKERNEL_NEW_ACT_IMPL
...
...
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
fe4cd502
...
...
@@ -21,8 +21,71 @@ namespace operators {
namespace
math
{
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(int d) { \
std::string key(#ker_key "f"); \
if (useJIT(d)) { \
/* only jit code need record d*/
\
return key + "jit" + std::to_string(d); \
} else if (useMKL(d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(int d) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/
\
if (useMKL(d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, int>(int d)
#define JITKERNEL_FIND_KEY(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(d)
#define JITKERNEL_IMPL(ker_class, ker_dtype) \
p = std::dynamic_pointer_cast<ker_class<ker_dtype>>( \
std::make_shared<ker_class##Impl<ker_dtype>>(d))
#define REGISTER_JITKERNEL_WITH_DTYPE(ker_class, ker_dtype, marco_declare, \
macro_find_key, macro_impl) \
marco_declare(ker_class, ker_dtype) { \
macro_find_key(ker_class, ker_dtype); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
macro_impl(ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
marco_declare, macro_find_key, macro_impl) \
marco_define_name(ker_key, ker_class); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, JITKERNEL_DECLARE, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, JITKERNEL_DECLARE, \
JITKERNEL_FIND_KEY, JITKERNEL_IMPL)
#define REGISTER_JITKERNEL(ker_key, ker_class) \
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \
JITKERNEL_IMPL)
namespace
jit
=
platform
::
jit
;
// TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < AVX_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kLT8); \
...
...
@@ -47,44 +110,42 @@ namespace jit = platform::jit;
SEARCH_BLOCK(macro_, ker, dtype, jit::isa_any); \
}
#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, int>(int d)
#define JITKERNEL_KEY(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d)
#define JITKERNEL_NEW_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
#define JITKERNEL_NEW_IMPL
_DEPRECATED
(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>(
\
std::make_shared<ker##Impl<dtype, isa, k>>(d))
#define JITKERNEL_WITH_DTYPE(ker_key, ker_class, ker_dtype, dtype_key, \
marco_declare, macro_key, macro_impl) \
marco_declare(ker_class, ker_dtype) { \
std::string key = macro_key(ker_key, dtype_key); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
#define JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, ker_dtype, \
dtype_key, marco_declare, macro_key, \
macro_impl) \
marco_declare(ker_class, ker_dtype) { \
std::string key = macro_key(ker_key, dtype_key); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, JITKERNEL_DECLARE, \
JITKERNEL_KEY, JITKERNEL_NEW_IMPL)
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_declare, macro_key, \
macro_impl) \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, float, f, marco_declare, macro_key, \
macro_impl); \
JITKERNEL_WITH_DTYPE(ker_key, ker_class, double, d, marco_declare, \
macro_key, macro_impl)
#define REGISTER_JITKERNEL_DEPRECATED(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, \
JITKERNEL_DECLARE, JITKERNEL_KEY, \
JITKERNEL_NEW_IMPL_DEPRECATED); \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
JITKERNEL_DECLARE, JITKERNEL_KEY, \
JITKERNEL_NEW_IMPL_DEPRECATED)
#define REGISTER_JITKERNEL_ARGS_DEPRECATED(ker_key, ker_class, marco_declare, \
macro_key, macro_impl) \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, marco_declare, \
macro_key, macro_impl); \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
marco_declare, macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \
macro_(jit::avx512f, block); \
...
...
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
fe4cd502
...
...
@@ -179,23 +179,23 @@ class LSTMKernelImpl : public LSTMKernel<T> {
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
private:
...
...
@@ -289,36 +289,36 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
/* get fgated and igated*/
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
);
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
,
d_
);
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
);
act_gate_d2_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
);
/* get ogated*/
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
);
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* get outgated, put W_oc * C_t on igated */
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
);
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
private:
...
...
@@ -352,8 +352,8 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
act_cell, d)); \
}
REGISTER_JITKERNEL_ARGS
(
lstm
,
LSTMKernel
,
JITKERNEL_DECLARE_LSTM
,
JITKERNEL_KEY_LSTM
,
JITKERNEL_NEW_LSTM_IMPL
);
REGISTER_JITKERNEL_ARGS
_DEPRECATED
(
lstm
,
LSTMKernel
,
JITKERNEL_DECLARE_LSTM
,
JITKERNEL_KEY_LSTM
,
JITKERNEL_NEW_LSTM_IMPL
);
#undef INTRI8_FLOAT
#undef JITKERNEL_DECLARE_LSTM
...
...
@@ -378,13 +378,13 @@ class GRUKernelImpl : public GRUKernel<T> {
void
ComputeH1
(
T
*
gates
,
T
*
ht
)
const
override
{
act_gate_d_
->
Compute
(
gates
,
gates
);
act_state_d_
->
Compute
(
gates
+
d2_
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
,
d_
);
}
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
// W: {W_update, W_reset; W_state}
act_gate_d2_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
ht_1
,
gates
+
d_
,
ht
);
vmul_d_
->
Compute
(
ht_1
,
gates
+
d_
,
ht
,
d_
);
}
void
ComputeHtPart2
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
...
...
@@ -472,8 +472,8 @@ INTRI8_FLOAT(jit::avx512f);
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d));
REGISTER_JITKERNEL_ARGS
(
gru
,
GRUKernel
,
JITKERNEL_DECLARE_GRU
,
JITKERNEL_KEY_GRU
,
JITKERNEL_NEW_GRU_IMPL
);
REGISTER_JITKERNEL_ARGS
_DEPRECATED
(
gru
,
GRUKernel
,
JITKERNEL_DECLARE_GRU
,
JITKERNEL_KEY_GRU
,
JITKERNEL_NEW_GRU_IMPL
);
#undef INTRI8_FLOAT
#undef JITKERNEL_NEW_GRU_IMPL
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
fe4cd502
...
...
@@ -369,12 +369,12 @@ void lstm_ctht_better(
int
d2
=
d
*
2
;
vsigmoid_3d
->
Compute
(
gates
+
d
,
gates
+
d
);
vtanh_d
->
Compute
(
gates
,
gates
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
vadd_d
->
Compute
(
gates
+
d
,
gates
+
d2
,
ct
);
/* H_t = act_cell(C_t) * ogated */
vtanh_d
->
Compute
(
ct
,
gates
+
d2
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
,
d
);
}
TEST
(
JitKernel
,
lstm
)
{
...
...
@@ -578,7 +578,7 @@ void vmul_mkl(const int n, const float* x, const float* y, float* z) {
TEST
(
JitKernel
,
vmul
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
for
(
int
d
:
{
7
,
8
,
15
,
16
,
30
,
256
,
512
})
{
for
(
int
d
:
{
7
,
8
,
15
,
16
,
20
,
30
,
256
,
512
,
1000
,
1024
})
{
std
::
vector
<
float
>
x
(
d
),
y
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
());
...
...
@@ -616,7 +616,7 @@ TEST(JitKernel, vmul) {
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
y_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
y_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
@@ -800,8 +800,8 @@ TEST(JitKernel, pool) {
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_f
)
!=
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_d
));
const
auto
&
pvmul_from_key
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf4"
);
const
auto
&
pvmul_from_key
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf
jit
4"
);
EXPECT_EQ
(
pvmul_f
,
pvmul_from_key
);
const
auto
&
pvmul_from_key2
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf
5
"
);
const
auto
&
pvmul_from_key2
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulf
jit
"
);
EXPECT_TRUE
(
pvmul_from_key2
==
nullptr
);
}
paddle/fluid/operators/math/pooling.cc
浏览文件 @
fe4cd502
...
...
@@ -31,7 +31,7 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
bool
exclusive
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -68,7 +68,8 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
pool_process
.
compute
(
input_data
[
h
*
input_width
+
w
],
&
ele
);
}
}
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_height
*
ksize_width
;
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
ph
*
output_width
+
pw
]
=
ele
;
}
...
...
@@ -93,7 +94,7 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
bool
exclusive
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -124,7 +125,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
int
wstart
=
pw
*
stride_width
-
padding_width
;
int
wend
=
std
::
min
(
wstart
+
ksize_width
,
input_width
);
wstart
=
std
::
max
(
wstart
,
0
);
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_height
*
ksize_width
;
float
scale
=
1.0
/
pool_size
;
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
...
...
@@ -249,7 +251,7 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
bool
exclusive
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -300,7 +302,9 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
}
}
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
exclusive
?
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_depth
*
ksize_height
*
ksize_width
;
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
output_idx
]
=
ele
;
}
...
...
@@ -326,7 +330,7 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
bool
exclusive
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -369,7 +373,9 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
wstart
=
std
::
max
(
wstart
,
0
);
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
exclusive
?
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_depth
*
ksize_height
*
ksize_width
;
float
scale
=
1.0
/
pool_size
;
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
...
...
paddle/fluid/operators/math/pooling.cu
浏览文件 @
fe4cd502
...
...
@@ -29,7 +29,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
T
*
output_data
)
{
bool
exclusive
,
T
*
output_data
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
...
...
@@ -52,7 +52,8 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
pool_process
.
compute
(
input_data
[
h
*
input_width
+
w
],
&
ele
);
}
}
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_height
*
ksize_width
;
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
index
]
=
ele
;
}
...
...
@@ -65,7 +66,7 @@ __global__ void KernelPool2DGrad(
const
int
input_width
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
T
*
input_grad
)
{
PoolProcess
pool_process
,
bool
exclusive
,
T
*
input_grad
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
offsetW
=
index
%
input_width
+
padding_width
;
...
...
@@ -95,7 +96,8 @@ __global__ void KernelPool2DGrad(
int
wend
=
min
(
wstart
+
ksize_width
,
input_width
);
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_height
*
ksize_width
;
int
output_sub_idx
=
ph
*
output_width
+
pw
;
pool_process
.
compute
(
input
,
output_data
[
output_sub_idx
],
output_grad
[
output_sub_idx
],
...
...
@@ -163,7 +165,7 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
bool
exclusive
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
...
...
@@ -189,7 +191,8 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
KernelPool2D
<
PoolProcess
,
T
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
nthreads
,
input_data
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
pool_process
,
output_data
);
stride_width
,
padding_height
,
padding_width
,
pool_process
,
exclusive
,
output_data
);
}
};
...
...
@@ -208,7 +211,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
bool
exclusive
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
...
...
@@ -236,7 +239,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
nthreads
,
input_data
,
output_data
,
output_grad_data
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
pool_process
,
input_grad_data
);
pool_process
,
exclusive
,
input_grad_data
);
}
};
...
...
@@ -313,16 +316,14 @@ template class Pool2dGradFunctor<platform::CUDADeviceContext,
double
>
;
template
<
typename
PoolProcess
,
typename
T
>
__global__
void
KernelPool3D
(
const
int
nthreads
,
const
T
*
input_data
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
T
*
output_data
)
{
__global__
void
KernelPool3D
(
const
int
nthreads
,
const
T
*
input_data
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
bool
exclusive
,
T
*
output_data
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
...
...
@@ -351,7 +352,9 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data,
}
}
}
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_depth
*
ksize_height
*
ksize_width
;
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
index
]
=
ele
;
}
...
...
@@ -366,7 +369,7 @@ __global__ void KernelPool3DGrad(
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
T
*
input_grad
)
{
bool
exclusive
,
T
*
input_grad
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
offsetW
=
index
%
input_width
+
padding_width
;
...
...
@@ -409,7 +412,9 @@ __global__ void KernelPool3DGrad(
dstart
=
max
(
dstart
,
0
);
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_depth
*
ksize_height
*
ksize_width
;
int
output_sub_idx
=
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
pool_process
.
compute
(
input
,
output_data
[
output_sub_idx
],
output_grad
[
output_sub_idx
],
...
...
@@ -484,7 +489,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
bool
exclusive
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
...
@@ -517,7 +522,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
nthreads
,
input_data
,
input_channels
,
input_depth
,
input_height
,
input_width
,
output_depth
,
output_height
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
pool_process
,
padding_depth
,
padding_height
,
padding_width
,
pool_process
,
exclusive
,
output_data
);
}
};
...
...
@@ -537,7 +542,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
bool
exclusive
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
...
@@ -573,7 +578,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
input_depth
,
input_height
,
input_width
,
output_depth
,
output_height
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
pool_process
,
input_grad_data
);
padding_width
,
pool_process
,
exclusive
,
input_grad_data
);
}
};
...
...
paddle/fluid/operators/math/pooling.h
浏览文件 @
fe4cd502
...
...
@@ -89,7 +89,7 @@ class Pool2dFunctor {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
output
);
bool
exclusive
,
framework
::
Tensor
*
output
);
};
template
<
typename
DeviceContext
,
typename
PoolProcess
,
typename
T
>
...
...
@@ -101,7 +101,7 @@ class Pool2dGradFunctor {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
bool
exclusive
,
framework
::
Tensor
*
input_grad
);
};
template
<
typename
DeviceContext
,
class
T
>
...
...
@@ -123,7 +123,7 @@ class Pool3dFunctor {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
output
);
bool
exclusive
,
framework
::
Tensor
*
output
);
};
template
<
typename
DeviceContext
,
typename
PoolProcess
,
typename
T
>
...
...
@@ -135,7 +135,7 @@ class Pool3dGradFunctor {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
bool
exclusive
,
framework
::
Tensor
*
input_grad
);
};
template
<
typename
DeviceContext
,
class
T
>
...
...
paddle/fluid/operators/pool_cudnn_op.cu.cc
浏览文件 @
fe4cd502
...
...
@@ -41,6 +41,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
bool
exclusive
=
ctx
.
Attr
<
bool
>
(
"exclusive"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
...
...
@@ -72,7 +73,8 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
if
(
pooling_type
==
"max"
)
{
pooling_mode
=
PoolingMode
::
kMaximum
;
}
else
{
pooling_mode
=
PoolingMode
::
kAverage
;
pooling_mode
=
exclusive
?
PoolingMode
::
kAverageExclusive
:
PoolingMode
::
kAverageInclusive
;
}
cudnnPoolingDescriptor_t
cudnn_pool_desc
=
...
...
@@ -101,6 +103,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
Tensor
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
bool
exclusive
=
ctx
.
Attr
<
bool
>
(
"exclusive"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
...
...
@@ -141,7 +144,8 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
pooling_mode
=
PoolingMode
::
kMaximum
;
}
}
else
{
pooling_mode
=
PoolingMode
::
kAverage
;
pooling_mode
=
exclusive
?
PoolingMode
::
kAverageExclusive
:
PoolingMode
::
kAverageInclusive
;
}
cudnnPoolingDescriptor_t
cudnn_pool_desc
=
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
fe4cd502
...
...
@@ -180,6 +180,12 @@ void Pool2dOpMaker::Make() {
"operator."
"If global_pooling = true, paddings and ksize will be ignored."
)
.
SetDefault
({
0
,
0
});
AddAttr
<
bool
>
(
"exclusive"
,
"(bool, default True) When true, will exclude the zero-padding in the "
"averaging calculating, otherwise, include the zero-padding. Note, it "
"is only used when pooling_type is avg. The defalut is True."
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"use_cudnn"
,
"(bool, default false) Only used in cudnn kernel, need install cudnn"
)
...
...
@@ -236,6 +242,23 @@ Example:
W_{out} = \\frac{(W_{in} - ksize[1] + 2 * paddings[1] + strides[1] - 1)}{strides[1]} + 1
$$
For exclusive = true:
$$
hstart = i * strides[0] - paddings[0]
hend = hstart + ksize[0]
wstart = j * strides[1] - paddings[1]
wend = wstart + ksize[1]
Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{ksize[0] * ksize[1]}
$$
For exclusive = false:
$$
hstart = max(0, i * strides[0] - paddings[0])
hend = min(H, hstart + ksize[0])
wstart = max(0, j * strides[1] - paddings[1])
wend = min(W, wstart + ksize[1])
Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)}
$$
)DOC"
);
}
...
...
@@ -283,6 +306,12 @@ void Pool3dOpMaker::Make() {
"If global_pooling = true, ksize and paddings will be ignored."
)
.
SetDefault
({
0
,
0
,
0
});
// TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr
<
bool
>
(
"exclusive"
,
"(bool, default True) When true, will exclude the zero-padding in the "
"averaging calculating, otherwise, include the zero-padding. Note, it "
"is only used when pooling_type is avg. The defalut is True."
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"use_cudnn"
,
...
...
paddle/fluid/operators/pool_op.h
浏览文件 @
fe4cd502
...
...
@@ -69,6 +69,7 @@ class PoolKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
ksize
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
bool
exclusive
=
context
.
Attr
<
bool
>
(
"exclusive"
);
if
(
context
.
Attr
<
bool
>
(
"global_pooling"
))
{
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
paddings
[
i
]
=
0
;
...
...
@@ -84,7 +85,7 @@ class PoolKernel : public framework::OpKernel<T> {
pool2d_forward
;
paddle
::
operators
::
math
::
MaxPool
<
T
>
pool_process
;
pool2d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
true
,
out
);
}
else
if
(
pooling_type
==
"avg"
)
{
paddle
::
operators
::
math
::
Pool2dFunctor
<
...
...
@@ -92,7 +93,7 @@ class PoolKernel : public framework::OpKernel<T> {
pool2d_forward
;
paddle
::
operators
::
math
::
AvgPool
<
T
>
pool_process
;
pool2d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
exclusive
,
out
);
}
}
break
;
case
3
:
{
...
...
@@ -102,14 +103,14 @@ class PoolKernel : public framework::OpKernel<T> {
pool3d_forward
;
paddle
::
operators
::
math
::
MaxPool
<
T
>
pool_process
;
pool3d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
true
,
out
);
}
else
if
(
pooling_type
==
"avg"
)
{
paddle
::
operators
::
math
::
Pool3dFunctor
<
DeviceContext
,
paddle
::
operators
::
math
::
AvgPool
<
T
>
,
T
>
pool3d_forward
;
paddle
::
operators
::
math
::
AvgPool
<
T
>
pool_process
;
pool3d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
exclusive
,
out
);
}
}
break
;
default:
{
PADDLE_THROW
(
"Pool op only supports 2D and 3D input."
);
}
...
...
@@ -131,6 +132,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
ksize
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
bool
exclusive
=
context
.
Attr
<
bool
>
(
"exclusive"
);
if
(
context
.
Attr
<
bool
>
(
"global_pooling"
))
{
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
...
...
@@ -157,7 +159,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
pool2d_backward
;
paddle
::
operators
::
math
::
AvgPoolGrad
<
T
>
pool_process
;
pool2d_backward
(
dev_ctx
,
*
in_x
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
,
pool_process
,
in_x_grad
);
paddings
,
pool_process
,
exclusive
,
in_x_grad
);
}
}
break
;
case
3
:
{
...
...
@@ -172,7 +174,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
pool3d_backward
;
paddle
::
operators
::
math
::
AvgPoolGrad
<
T
>
pool_process
;
pool3d_backward
(
dev_ctx
,
*
in_x
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
,
pool_process
,
in_x_grad
);
paddings
,
pool_process
,
exclusive
,
in_x_grad
);
}
}
break
;
default:
{
PADDLE_THROW
(
"Pool op only supports 2D and 3D input."
);
}
...
...
paddle/fluid/operators/prefetch_op.cc
浏览文件 @
fe4cd502
...
...
@@ -42,7 +42,8 @@ class PrefetchOp : public framework::OperatorBase {
auto
&
ctx
=
*
pool
.
Get
(
place
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
Attr
<
int
>
(
"trainer_id"
));
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
...
...
@@ -69,6 +70,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor) result "
"to be fetched from parameter server"
)
.
AsDuplicable
();
AddAttr
<
int
>
(
"trainer_id"
,
"trainer id from 0 ~ worker_num."
).
SetDefault
(
0
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
,
"(string vector, default 127.0.0.1:6164)"
...
...
paddle/fluid/operators/read_op.cc
浏览文件 @
fe4cd502
...
...
@@ -33,6 +33,19 @@ class ReadInferShape : public framework::InferShapeBase {
reader_dims
.
size
(),
out_names
.
size
(),
"The reader's dim number doesn't match the output number."
);
ctx
->
SetOutputsDim
(
"Out"
,
reader_dims
);
if
(
!
ctx
->
IsRuntime
())
{
auto
in_desc
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetInputVarPtrs
(
"Reader"
)[
0
]);
auto
in_lod_levels
=
in_desc
->
GetLoDLevels
();
auto
out_var_ptrs
=
ctx
->
GetOutputVarPtrs
(
"Out"
);
PADDLE_ENFORCE_EQ
(
in_lod_levels
.
size
(),
out_var_ptrs
.
size
(),
"LoDLevels of Input(Reader) must be the same as the "
"number of Outputs(Out)."
);
for
(
size_t
i
=
0
;
i
<
out_var_ptrs
.
size
();
++
i
)
{
auto
*
out_desc
=
boost
::
get
<
framework
::
VarDesc
*>
(
out_var_ptrs
[
i
]);
out_desc
->
SetLoDLevel
(
in_lod_levels
[
i
]);
}
}
}
};
...
...
paddle/fluid/operators/recv_op.cc
浏览文件 @
fe4cd502
...
...
@@ -42,7 +42,8 @@ class RecvOp : public framework::OperatorBase {
auto
&
ctx
=
*
pool
.
Get
(
place
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
Attr
<
int
>
(
"trainer_id"
));
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
i
++
)
{
...
...
@@ -73,6 +74,7 @@ This operator can get variables from server side.
"Server endpoints in the order of input "
"variables for mapping"
)
.
SetDefault
({});
AddAttr
<
int
>
(
"trainer_id"
,
"trainer id from 0 ~ worker_num."
).
SetDefault
(
0
);
AddAttr
<
int
>
(
"sync_mode"
,
"(int, default 0)"
"sync recv or async recv."
)
...
...
paddle/fluid/operators/ref_by_trainer_id_op.cc
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/ref_by_trainer_id_op.h"
#include <string>
namespace
paddle
{
namespace
operators
{
class
RefByTrainerIdOp
:
public
framework
::
OperatorWithKernel
{
public:
RefByTrainerIdOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"X"
),
"Input(X) of RefByTrainerIdOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"TrainerId"
),
"Input(TrainerId) of RefByTrainerIdOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of RefByTrainerIdOp should not be null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"TrainerId"
).
size
(),
1
,
"TrainerId should be a scalar."
);
// Out's shape is determined at runtime.
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
)[
0
]
->
type
()),
ctx
.
GetPlace
());
}
};
class
RefByTrainerIdOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) Input tensor list."
).
AsDuplicable
();
AddInput
(
"TrainerId"
,
"(Tensor) Scalar int, the trainer id runtime value."
);
AddOutput
(
"Out"
,
"(Tensor) Return one tensor reference of X[trainer_id]"
);
AddComment
(
R"DOC(
**RefByTrainerId operator**
Return a reference of a tensor, using trainer_id as the index to find from the input.
$$Out = X[TrainerId]$$
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
ref_by_trainer_id
,
ops
::
RefByTrainerIdOp
,
ops
::
RefByTrainerIdOpMaker
);
REGISTER_OP_CPU_KERNEL
(
ref_by_trainer_id
,
ops
::
RefByTrainerIdKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
RefByTrainerIdKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
RefByTrainerIdKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
RefByTrainerIdKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
paddle/fluid/operators/ref_by_trainer_id_op.cu.cc
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/ref_by_trainer_id_op.h"
REGISTER_OP_CUDA_KERNEL
(
ref_by_trainer_id
,
paddle
::
operators
::
RefByTrainerIdKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
paddle
::
operators
::
RefByTrainerIdKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
paddle
::
operators
::
RefByTrainerIdKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
paddle
::
operators
::
RefByTrainerIdKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
paddle/fluid/operators/ref_by_trainer_id_op.h
0 → 100644
浏览文件 @
fe4cd502
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdio.h>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
RefByTrainerIdKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
{
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
in_list
=
context
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
*
trainer_id_t
=
context
.
Input
<
framework
::
Tensor
>
(
"TrainerId"
);
int64_t
trainer_id
;
auto
*
trainer_id_data
=
trainer_id_t
->
data
<
int64_t
>
();
if
(
platform
::
is_gpu_place
(
context
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
auto
stream
=
context
.
cuda_device_context
().
stream
();
memory
::
Copy
<>
(
platform
::
CPUPlace
(),
&
trainer_id
,
boost
::
get
<
platform
::
CUDAPlace
>
(
context
.
GetPlace
()),
trainer_id_data
,
sizeof
(
int64_t
),
stream
);
#endif
}
else
{
trainer_id
=
*
trainer_id_data
;
}
printf
(
"after get trainer_id %lu
\n
"
,
trainer_id
);
PADDLE_ENFORCE_LT
(
trainer_id
,
in_list
.
size
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
ShareDataWith
(
*
(
in_list
[
trainer_id
]));
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
fe4cd502
...
...
@@ -39,7 +39,8 @@ class SendBarrierOp : public framework::OperatorBase {
std
::
vector
<
std
::
string
>
eps
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
Attr
<
int
>
(
"trainer_id"
));
VLOG
(
3
)
<<
"SendBarrierOp sync"
;
...
...
@@ -67,6 +68,7 @@ This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would knew all variables have been sent.
)DOC"
);
AddAttr
<
int
>
(
"trainer_id"
,
"trainer id from 0 ~ worker_num."
).
SetDefault
(
0
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"endpoints"
,
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to."
)
...
...
paddle/fluid/operators/send_op.cc
浏览文件 @
fe4cd502
...
...
@@ -44,7 +44,8 @@ class SendOp : public framework::OperatorBase {
auto
&
ctx
=
*
pool
.
Get
(
place
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
Attr
<
int
>
(
"trainer_id"
));
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
...
...
@@ -79,6 +80,7 @@ This operator will send variables to listen_and_serve op at the parameter server
"(int, default 0)"
"sync send or async send."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"trainer_id"
,
"trainer id from 0 ~ worker_num."
).
SetDefault
(
0
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
,
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
...
...
paddle/fluid/operators/sign_op.cc
浏览文件 @
fe4cd502
...
...
@@ -67,4 +67,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
sign
,
ops
::
SignOp
,
ops
::
SignOpMaker
<
float
>
,
ops
::
SignGradMaker
);
REGISTER_OP_CPU_KERNEL
(
sign
,
ops
::
SignKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
sign
,
ops
::
SignKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SignKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/sign_op.cu
浏览文件 @
fe4cd502
...
...
@@ -13,7 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sign_op.h"
#include "paddle/fluid/platform/float16.h"
REGISTER_OP_CUDA_KERNEL
(
sign
,
paddle
::
operators
::
SignKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
paddle
::
operators
::
SignKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
paddle
::
operators
::
SignKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
paddle
::
operators
::
SignKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
);
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
fe4cd502
...
...
@@ -44,6 +44,12 @@ class SoftmaxWithCrossEntropyOpMaker
"(bool, default: false), A flag to indicate whether to interpretate "
"the given labels as soft labels."
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"numeric_stable_mode"
,
"(bool, default: false), A flag to indicate whether to use more "
"numerically stable algorithm. This flag is only valid when "
"soft_label is false and GPU is used."
)
.
SetDefault
(
false
);
AddAttr
<
int
>
(
"ignore_index"
,
"(int, default -100), Specifies a target value that is ignored and"
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
fe4cd502
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -117,8 +118,8 @@ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
// Make sure that BlockDim <= feature_size
// This kernel is used to calculate the max element of each row
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForMax
(
const
T
*
logits_data
,
T
*
max_data
,
int
feature_size
)
{
static
__global__
void
RowReductionForMax
(
const
T
*
logits_data
,
T
*
max_data
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
...
...
@@ -141,9 +142,10 @@ __global__ void RowReductionForMax(const T* logits_data, T* max_data,
}
// Make sure that BlockDim <= feature_size
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForDiffMaxSum
(
const
T
*
logits_data
,
T
*
max_data
,
T
*
softmax
,
int
feature_size
)
{
template
<
typename
T
,
int
BlockDim
,
bool
CalculateLogSoftmax
=
false
>
static
__global__
void
RowReductionForDiffMaxSum
(
const
T
*
logits_data
,
T
*
max_data
,
T
*
softmax
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
...
...
@@ -153,24 +155,34 @@ __global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
softmax
[
beg_idx
]
=
logits_data
[
beg_idx
]
-
block_max
;
T
diff_max_sum
=
real_exp
(
softmax
[
beg_idx
]);
beg_idx
+=
BlockDim
;
while
(
beg_
idx
<
end_idx
)
{
softmax
[
beg_idx
]
=
logits_data
[
beg_
idx
]
-
block_max
;
diff_max_sum
+=
real_exp
(
softmax
[
beg_
idx
]);
beg_
idx
+=
BlockDim
;
auto
idx
=
beg_idx
+
BlockDim
;
while
(
idx
<
end_idx
)
{
softmax
[
idx
]
=
logits_data
[
idx
]
-
block_max
;
diff_max_sum
+=
real_exp
(
softmax
[
idx
]);
idx
+=
BlockDim
;
}
diff_max_sum
=
BlockReduce
<
T
,
BlockDim
>
(
temp_storage
).
Reduce
(
diff_max_sum
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
real_log
(
diff_max_sum
);
if
(
!
CalculateLogSoftmax
)
return
;
__syncthreads
();
diff_max_sum
=
max_data
[
blockIdx
.
x
];
softmax
[
beg_idx
]
-=
diff_max_sum
;
beg_idx
+=
BlockDim
;
while
(
beg_idx
<
end_idx
)
{
softmax
[
beg_idx
]
-=
diff_max_sum
;
beg_idx
+=
BlockDim
;
}
if
(
threadIdx
.
x
==
0
)
max_data
[
blockIdx
.
x
]
=
0
;
}
// Make sure that BlockDim <= feature_size
template
<
typename
T
,
int
BlockDim
>
__global__
void
RowReductionForSoftmaxAndCrossEntropy
(
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
loss_data
,
T
*
softmax
,
int
feature_size
)
{
static
__global__
void
RowReductionForSoftmaxAndCrossEntropy
(
const
T
*
logits_data
,
const
T
*
labels_data
,
T
*
loss_data
,
T
*
softmax
,
int
feature_size
)
{
__shared__
BlockReduceTempStorage
<
T
,
BlockDim
>
temp_storage
;
auto
beg_idx
=
feature_size
*
blockIdx
.
x
+
threadIdx
.
x
;
...
...
@@ -194,11 +206,134 @@ __global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
}
template
<
typename
T
>
__global__
void
SetSoftmaxToOneWhenFeatureSizeIsOne
(
T
*
out
,
int
batch_size
)
{
struct
HardLabelSoftmaxWithCrossEntropyFunctor
{
public:
HardLabelSoftmaxWithCrossEntropyFunctor
(
const
T
*
logits
,
const
int64_t
*
labels
,
T
*
loss
,
T
*
log_softmax
,
int
feature_size
)
:
logits_
(
logits
),
labels_
(
labels
),
loss_
(
loss
),
log_softmax_
(
log_softmax
),
feature_size_
(
feature_size
)
{}
__device__
void
operator
()(
int
idx
)
const
{
auto
row_idx
=
idx
/
feature_size_
;
auto
col_idx
=
idx
%
feature_size_
;
if
(
col_idx
!=
labels_
[
row_idx
])
{
log_softmax_
[
idx
]
=
real_exp
(
log_softmax_
[
idx
]);
}
else
{
auto
softmax
=
log_softmax_
[
idx
];
log_softmax_
[
idx
]
=
real_exp
(
softmax
);
loss_
[
row_idx
]
=
-
softmax
;
}
}
private:
const
T
*
logits_
;
const
int64_t
*
labels_
;
T
*
loss_
;
T
*
log_softmax_
;
int
feature_size_
;
};
template
<
typename
T
>
struct
HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx
{
public:
HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx
(
const
T
*
logits
,
const
int64_t
*
labels
,
T
*
loss
,
T
*
log_softmax
,
int
feature_size
,
int
ignore_idx
)
:
logits_
(
logits
),
labels_
(
labels
),
loss_
(
loss
),
log_softmax_
(
log_softmax
),
feature_size_
(
feature_size
),
ignore_idx_
(
ignore_idx
)
{}
__device__
void
operator
()(
int
idx
)
const
{
auto
row_idx
=
idx
/
feature_size_
;
auto
col_idx
=
idx
%
feature_size_
;
if
(
col_idx
!=
labels_
[
row_idx
]
||
col_idx
==
ignore_idx_
)
{
log_softmax_
[
idx
]
=
real_exp
(
log_softmax_
[
idx
]);
}
else
{
auto
softmax
=
log_softmax_
[
idx
];
log_softmax_
[
idx
]
=
real_exp
(
softmax
);
loss_
[
row_idx
]
=
-
softmax
;
}
}
private:
const
T
*
logits_
;
const
int64_t
*
labels_
;
T
*
loss_
;
T
*
log_softmax_
;
int
feature_size_
;
int
ignore_idx_
;
};
template
<
typename
T
>
static
__global__
void
SetSoftmaxToOneWhenFeatureSizeIsOne
(
T
*
out
,
int
batch_size
)
{
auto
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
batch_size
)
out
[
idx
]
=
static_cast
<
T
>
(
1
);
}
template
<
typename
T
>
static
void
HardLabelSoftmaxWithCrossEntropy
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
logits_data
,
const
int64_t
*
labels_data
,
T
*
loss_data
,
T
*
softmax_data
,
int
batch_size
,
int
feature_size
,
int
ignore_idx
)
{
constexpr
int
kMaxBlockDim
=
512
;
int
block_dim
=
feature_size
>=
kMaxBlockDim
?
kMaxBlockDim
:
(
1
<<
static_cast
<
int
>
(
std
::
log2
(
feature_size
)));
auto
stream
=
ctx
.
stream
();
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: { \
RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, feature_size); \
RowReductionForDiffMaxSum<T, BlockDim, \
true><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, softmax_data, feature_size); \
platform::ForRange<platform::CUDADeviceContext> for_range( \
ctx, batch_size* feature_size); \
if (ignore_idx >= 0 && ignore_idx < feature_size) { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>( \
logits_data, labels_data, loss_data, softmax_data, feature_size, \
ignore_idx)); \
} else { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
logits_data, labels_data, loss_data, softmax_data, feature_size)); \
} \
} break
switch
(
block_dim
)
{
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
512
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
256
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
128
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
64
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
32
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
16
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
8
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
4
);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
(
2
);
case
1
:
SetSoftmaxToOneWhenFeatureSizeIsOne
<<<
(
batch_size
+
kMaxBlockDim
-
1
)
/
kMaxBlockDim
,
kMaxBlockDim
,
0
,
stream
>>>
(
softmax_data
,
batch_size
);
cudaMemsetAsync
(
loss_data
,
0
,
batch_size
*
sizeof
(
T
),
stream
);
break
;
default:
PADDLE_THROW
(
"BlockDim must be 2^n in softmax_with_cross_entropy_op"
);
break
;
}
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}
template
<
typename
T
>
static
void
SoftmaxWithCrossEntropyFusedKernel
(
const
T
*
logits_data
,
const
T
*
labels_data
,
...
...
@@ -237,7 +372,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
kMaxBlockDim
,
kMaxBlockDim
,
0
,
stream
>>>
(
softmax_data
,
batch_size
);
cudaMemsetAsync
(
loss_data
,
0
,
batch_size
,
stream
);
cudaMemsetAsync
(
loss_data
,
0
,
batch_size
*
sizeof
(
T
)
,
stream
);
break
;
default:
PADDLE_THROW
(
"BlockDim must be 2^n in softmax_with_cross_entropy_op"
);
...
...
@@ -272,11 +407,21 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
logits_data
,
labels_data
,
softmax_data
,
loss_data
,
batch_size
,
feature_size
,
context
.
cuda_device_context
().
stream
());
}
else
{
math
::
SoftmaxCUDNNFunctor
<
T
>
()(
context
.
cuda_device_context
(),
logits
,
softmax
);
math
::
CrossEntropyFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
context
.
cuda_device_context
(),
loss
,
softmax
,
labels
,
false
,
ignore_index
);
if
(
!
context
.
Attr
<
bool
>
(
"numeric_stable_mode"
))
{
math
::
SoftmaxCUDNNFunctor
<
T
>
()(
context
.
cuda_device_context
(),
logits
,
softmax
);
math
::
CrossEntropyFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
context
.
cuda_device_context
(),
loss
,
softmax
,
labels
,
false
,
ignore_index
);
}
else
{
int
batch_size
=
logits
->
dims
()[
0
];
int
feature_size
=
logits
->
dims
()[
1
];
auto
*
logits_data
=
logits
->
data
<
T
>
();
auto
*
labels_data
=
labels
->
data
<
int64_t
>
();
HardLabelSoftmaxWithCrossEntropy
<
T
>
(
context
.
cuda_device_context
(),
logits_data
,
labels_data
,
loss_data
,
softmax_data
,
batch_size
,
feature_size
,
ignore_index
);
}
}
}
};
...
...
paddle/fluid/operators/spp_op.h
浏览文件 @
fe4cd502
...
...
@@ -56,12 +56,14 @@ class SppKernel : public framework::OpKernel<T> {
math
::
Pool2dFunctor
<
DeviceContext
,
math
::
MaxPool
<
T
>
,
T
>
pool_forward
;
math
::
MaxPool
<
T
>
max_process
;
pool_forward
(
context
.
template
device_context
<
DeviceContext
>(),
*
in_x
,
kernel_size
,
strides
,
paddings
,
max_process
,
&
out_level
);
kernel_size
,
strides
,
paddings
,
max_process
,
true
,
&
out_level
);
}
else
if
(
pooling_type
==
"avg"
)
{
math
::
Pool2dFunctor
<
DeviceContext
,
math
::
AvgPool
<
T
>
,
T
>
pool_forward
;
math
::
AvgPool
<
T
>
avg_process
;
pool_forward
(
context
.
template
device_context
<
DeviceContext
>(),
*
in_x
,
kernel_size
,
strides
,
paddings
,
avg_process
,
&
out_level
);
kernel_size
,
strides
,
paddings
,
avg_process
,
true
,
&
out_level
);
}
// flatten pooling output shape
int
output_flatten_w
=
in_x
->
dims
()[
1
]
*
bins
*
bins
;
...
...
@@ -154,7 +156,7 @@ class SppGradKernel : public framework::OpKernel<T> {
math
::
AvgPoolGrad
<
T
>
avg_process
;
pool_backward
(
context
.
template
device_context
<
DeviceContext
>(),
*
in_x
,
*&
out_level
,
*&
outgrad_level
,
kernel_size
,
strides
,
paddings
,
avg_process
,
in_x_grad
);
paddings
,
avg_process
,
true
,
in_x_grad
);
}
}
}
...
...
paddle/fluid/operators/test_send_nccl_id.cc
浏览文件 @
fe4cd502
...
...
@@ -92,7 +92,7 @@ TEST(SendNcclId, RPCServer) {
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
distributed
::
RPCClient
*
client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
();
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
LOG
(
INFO
)
<<
"connect to server"
<<
ep
;
client
->
AsyncSendVar
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
...
...
paddle/fluid/platform/cudnn_helper.h
浏览文件 @
fe4cd502
...
...
@@ -76,8 +76,9 @@ enum class DataLayout { // Not use
enum
class
PoolingMode
{
kMaximum
,
kAverage
,
kMaximumDeterministic
,
kAverageExclusive
,
kAverageInclusive
,
};
#if CUDNN_VERSION < 6000
...
...
@@ -91,8 +92,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
switch
(
mode
)
{
case
PoolingMode
::
kMaximumDeterministic
:
return
CUDNN_POOLING_MAX
;
case
PoolingMode
::
kAverage
:
case
PoolingMode
::
kAverage
Exclusive
:
return
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
;
case
PoolingMode
::
kAverageInclusive
:
return
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
;
case
PoolingMode
::
kMaximum
:
return
CUDNN_POOLING_MAX
;
default:
...
...
@@ -105,8 +108,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
switch
(
mode
)
{
case
PoolingMode
::
kMaximumDeterministic
:
return
CUDNN_POOLING_MAX_DETERMINISTIC
;
case
PoolingMode
::
kAverage
:
case
PoolingMode
::
kAverage
Exclusive
:
return
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
;
case
PoolingMode
::
kAverageInclusive
:
return
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
;
case
PoolingMode
::
kMaximum
:
return
CUDNN_POOLING_MAX
;
default:
...
...
@@ -341,6 +346,28 @@ class ScopedPoolingDescriptor {
DISABLE_COPY_AND_ASSIGN
(
ScopedPoolingDescriptor
);
};
class
ScopedSpatialTransformerDescriptor
{
public:
ScopedSpatialTransformerDescriptor
()
{
PADDLE_ENFORCE
(
dynload
::
cudnnCreateSpatialTransformerDescriptor
(
&
desc_
));
}
~
ScopedSpatialTransformerDescriptor
()
{
PADDLE_ENFORCE
(
dynload
::
cudnnDestroySpatialTransformerDescriptor
(
desc_
));
}
template
<
typename
T
>
inline
cudnnSpatialTransformerDescriptor_t
descriptor
(
const
int
nbDims
,
const
int
dimA
[])
{
PADDLE_ENFORCE
(
dynload
::
cudnnSetSpatialTransformerNdDescriptor
(
desc_
,
CUDNN_SAMPLER_BILINEAR
,
CudnnDataType
<
T
>::
type
,
nbDims
,
dimA
));
return
desc_
;
}
private:
cudnnSpatialTransformerDescriptor_t
desc_
;
DISABLE_COPY_AND_ASSIGN
(
ScopedSpatialTransformerDescriptor
);
};
inline
bool
CanCUDNNBeUsed
(
const
framework
::
ExecutionContext
&
ctx
)
{
bool
use_cudnn
=
ctx
.
Attr
<
bool
>
(
"use_cudnn"
);
use_cudnn
&=
paddle
::
platform
::
is_gpu_place
(
ctx
.
GetPlace
());
...
...
paddle/fluid/platform/dynload/cudnn.h
浏览文件 @
fe4cd502
...
...
@@ -65,44 +65,51 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
* include all needed cudnn functions in HPPL
* different cudnn version has different interfaces
**/
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor); \
__macro(cudnnSetTensor4dDescriptorEx); \
__macro(cudnnSetTensorNdDescriptor); \
__macro(cudnnGetTensorNdDescriptor); \
__macro(cudnnGetConvolutionNdForwardOutputDim); \
__macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnCreateTensorDescriptor); \
__macro(cudnnDestroyTensorDescriptor); \
__macro(cudnnCreateFilterDescriptor); \
__macro(cudnnSetFilter4dDescriptor); \
__macro(cudnnSetFilterNdDescriptor); \
__macro(cudnnGetFilterNdDescriptor); \
__macro(cudnnSetPooling2dDescriptor); \
__macro(cudnnSetPoolingNdDescriptor); \
__macro(cudnnGetPoolingNdDescriptor); \
__macro(cudnnDestroyFilterDescriptor); \
__macro(cudnnCreateConvolutionDescriptor); \
__macro(cudnnCreatePoolingDescriptor); \
__macro(cudnnDestroyPoolingDescriptor); \
__macro(cudnnSetConvolution2dDescriptor); \
__macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreate); \
__macro(cudnnDestroy); \
__macro(cudnnSetStream); \
__macro(cudnnActivationForward); \
__macro(cudnnConvolutionForward); \
__macro(cudnnConvolutionBackwardBias); \
__macro(cudnnGetConvolutionForwardWorkspaceSize); \
__macro(cudnnTransformTensor); \
__macro(cudnnPoolingForward); \
__macro(cudnnPoolingBackward); \
__macro(cudnnSoftmaxBackward); \
__macro(cudnnSoftmaxForward); \
__macro(cudnnGetVersion); \
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor); \
__macro(cudnnSetTensor4dDescriptorEx); \
__macro(cudnnSetTensorNdDescriptor); \
__macro(cudnnGetTensorNdDescriptor); \
__macro(cudnnGetConvolutionNdForwardOutputDim); \
__macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnCreateTensorDescriptor); \
__macro(cudnnDestroyTensorDescriptor); \
__macro(cudnnCreateFilterDescriptor); \
__macro(cudnnSetFilter4dDescriptor); \
__macro(cudnnSetFilterNdDescriptor); \
__macro(cudnnGetFilterNdDescriptor); \
__macro(cudnnSetPooling2dDescriptor); \
__macro(cudnnSetPoolingNdDescriptor); \
__macro(cudnnGetPoolingNdDescriptor); \
__macro(cudnnDestroyFilterDescriptor); \
__macro(cudnnCreateConvolutionDescriptor); \
__macro(cudnnCreatePoolingDescriptor); \
__macro(cudnnDestroyPoolingDescriptor); \
__macro(cudnnSetConvolution2dDescriptor); \
__macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreateSpatialTransformerDescriptor); \
__macro(cudnnSetSpatialTransformerNdDescriptor); \
__macro(cudnnDestroySpatialTransformerDescriptor); \
__macro(cudnnSpatialTfGridGeneratorForward); \
__macro(cudnnSpatialTfGridGeneratorBackward); \
__macro(cudnnSpatialTfSamplerForward); \
__macro(cudnnSpatialTfSamplerBackward); \
__macro(cudnnCreate); \
__macro(cudnnDestroy); \
__macro(cudnnSetStream); \
__macro(cudnnActivationForward); \
__macro(cudnnConvolutionForward); \
__macro(cudnnConvolutionBackwardBias); \
__macro(cudnnGetConvolutionForwardWorkspaceSize); \
__macro(cudnnTransformTensor); \
__macro(cudnnPoolingForward); \
__macro(cudnnPoolingBackward); \
__macro(cudnnSoftmaxBackward); \
__macro(cudnnSoftmaxForward); \
__macro(cudnnGetVersion); \
__macro(cudnnGetErrorString);
CUDNN_DNN_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_CUDNN_WRAP
)
...
...
paddle/fluid/platform/init.cc
浏览文件 @
fe4cd502
...
...
@@ -116,21 +116,49 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
platform
::
SetNumThreads
(
FLAGS_paddle_num_threads
);
#endif
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx
512f
))
{
#ifndef __AVX
512F
__
LOG
(
WARNING
)
<<
"AVX
512F
is available, Please re-compile on local machine"
;
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx
))
{
#ifndef __AVX__
LOG
(
WARNING
)
<<
"AVX is available, Please re-compile on local machine"
;
#endif
}
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx2
))
{
#ifndef __AVX2__
LOG
(
WARNING
)
<<
"AVX2 is available, Please re-compile on local machine"
;
// Throw some informations when CPU instructions mismatch.
#define AVX_GUIDE(compiletime, runtime) \
LOG(FATAL) \
<< "This version is compiled on higher instruction(" #compiletime \
") system, you may encounter illegal instruction error running on" \
" your local CPU machine. Please reinstall the " #runtime \
" version or compile from source code."
#ifdef __AVX512F__
if
(
!
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx512f
))
{
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx2
))
{
AVX_GUIDE
(
AVX512
,
AVX2
);
}
else
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx
))
{
AVX_GUIDE
(
AVX512
,
AVX
);
}
else
{
AVX_GUIDE
(
AVX512
,
NonAVX
);
}
}
#endif
#ifdef __AVX2__
if
(
!
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx2
))
{
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx
))
{
AVX_GUIDE
(
AVX2
,
AVX
);
}
else
{
AVX_GUIDE
(
AVX2
,
NonAVX
);
}
}
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx
))
{
#ifndef __AVX__
LOG
(
WARNING
)
<<
"AVX is available, Please re-compile on local machine"
;
#endif
#ifdef __AVX__
if
(
!
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx
))
{
AVX_GUIDE
(
AVX
,
NonAVX
);
}
#endif
#undef AVX_GUIDE
}
void
InitGLOG
(
const
std
::
string
&
prog_name
)
{
...
...
paddle/fluid/platform/mkldnn_helper.h
浏览文件 @
fe4cd502
...
...
@@ -187,6 +187,29 @@ class MKLDNNHandler {
return
mem_p
;
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemory
(
const
std
::
shared_ptr
<
mkldnn
::
memory
>&
user_memory_p
,
const
std
::
shared_ptr
<
mkldnn
::
memory
>&
target_memory_p
,
const
std
::
string
&
suffix
,
std
::
vector
<
mkldnn
::
primitive
>&
pipeline
)
{
// NOLINT
auto
local_key
=
key_
+
suffix
;
auto
key_reorder_p
=
key_
+
suffix
+
"reorder_p"
;
auto
stored_reorder_p
=
std
::
static_pointer_cast
<
mkldnn
::
reorder
>
(
dev_ctx_
.
GetBlob
(
key_reorder_p
));
if
(
stored_reorder_p
)
{
pipeline
.
push_back
(
*
stored_reorder_p
);
}
else
{
auto
reorder_p
=
std
::
make_shared
<
mkldnn
::
reorder
>
(
*
user_memory_p
,
*
target_memory_p
);
dev_ctx_
.
SetBlob
(
key_reorder_p
,
reorder_p
);
pipeline
.
push_back
(
*
reorder_p
);
}
return
target_memory_p
;
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemory
(
mkldnn
::
memory
::
primitive_desc
&
mpd
,
// NOLINT
mkldnn
::
memory
::
primitive_desc
&
user_mpd
,
// NOLINT
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
fe4cd502
...
...
@@ -821,6 +821,13 @@ All parameter, weight, gradient are variables in Paddle.
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
enable_data_balance_
=
b
;
})
// FIXME(chengudo): enable_data_balance seems not important
.
def_property
(
"enable_sequential_execution"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_sequential_execution_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
enable_sequential_execution_
=
b
;
})
.
def_property
(
"fuse_elewise_add_act_ops"
,
[](
const
BuildStrategy
&
self
)
{
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
fe4cd502
...
...
@@ -147,7 +147,6 @@ function cmake_gen() {
-DWITH_SWIG_PY=
${
WITH_SWIG_PY
:-
ON
}
-DCUDNN_ROOT=/usr/
-DWITH_TESTING=
${
WITH_TESTING
:-
ON
}
-DWITH_FAST_BUNDLE_TEST=ON
-DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
-DWITH_FLUID_ONLY=
${
WITH_FLUID_ONLY
:-
OFF
}
...
...
@@ -180,7 +179,6 @@ EOF
-DWITH_PYTHON
=
${
WITH_PYTHON
:-
ON
}
\
-DCUDNN_ROOT
=
/usr/
\
-DWITH_TESTING
=
${
WITH_TESTING
:-
ON
}
\
-DWITH_FAST_BUNDLE_TEST
=
ON
\
-DCMAKE_MODULE_PATH
=
/opt/rocm/hip/cmake
\
-DWITH_FLUID_ONLY
=
${
WITH_FLUID_ONLY
:-
OFF
}
\
-DCMAKE_EXPORT_COMPILE_COMMANDS
=
ON
\
...
...
python/paddle/fluid/io.py
浏览文件 @
fe4cd502
...
...
@@ -884,12 +884,13 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
load_prog
=
Program
()
load_block
=
load_prog
.
global_block
()
need_delete_vars
=
[]
for
var_tuple
in
slice_vars_and_attrs
:
orig_var
=
var_tuple
[
0
]
start
=
var_tuple
[
1
]
slice_var
=
var_tuple
[
2
]
end
=
start
+
reduce
(
lambda
x
,
y
:
x
*
y
,
slice_var
.
shape
)
end
=
start
+
slice_var
.
shape
[
0
]
clone_orig_var
=
load_block
.
create_var
(
name
=
orig_var
.
name
,
...
...
@@ -917,5 +918,8 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
attrs
=
{
'axes'
:
[
0
],
'starts'
:
[
start
],
'ends'
:
[
end
]})
need_delete_vars
.
append
(
clone_orig_var
)
load_block
.
append_op
(
type
=
'delete_var'
,
inputs
=
{
'X'
:
need_delete_vars
},
)
executor
.
run
(
load_prog
)
python/paddle/fluid/layers/control_flow.py
浏览文件 @
fe4cd502
...
...
@@ -1586,8 +1586,7 @@ class DynamicRNN(object):
self
.
lod_rank_table
=
None
self
.
max_seq_len
=
None
self
.
step_idx
=
None
self
.
zero_idx
=
fill_constant
(
shape
=
[
1
],
value
=
0
,
dtype
=
'int64'
,
force_cpu
=
True
)
self
.
zero_idx
=
None
self
.
mem_dict
=
dict
()
self
.
output_array
=
[]
self
.
outputs
=
[]
...
...
@@ -1792,6 +1791,7 @@ class DynamicRNN(object):
"""
self
.
_assert_in_rnn_block_
(
'memory'
)
self
.
_init_zero_idx_
()
if
init
is
not
None
:
if
not
isinstance
(
init
,
Variable
):
raise
TypeError
(
...
...
@@ -1905,6 +1905,22 @@ class DynamicRNN(object):
array_write
(
x
=
each
,
i
=
self
.
step_idx
,
array
=
outside_array
)
self
.
output_array
.
append
(
outside_array
)
def
_init_zero_idx_
(
self
):
if
self
.
zero_idx
is
None
:
parent_block
=
self
.
_parent_block_
()
self
.
zero_idx
=
parent_block
.
create_var
(
name
=
unique_name
.
generate
(
'zero_idx'
),
dtype
=
'int64'
)
parent_block
.
append_op
(
type
=
'fill_constant'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
self
.
zero_idx
]},
attrs
=
{
'shape'
:
[
1
],
'dtype'
:
self
.
zero_idx
.
dtype
,
'value'
:
float
(
0
),
'force_cpu'
:
True
})
def
_parent_block_
(
self
):
prog
=
self
.
helper
.
main_program
parent_idx
=
prog
.
current_block
().
parent_idx
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
fe4cd502
...
...
@@ -315,6 +315,7 @@ def _copy_reader_var_(block, var):
new_var
=
block
.
create_var
(
name
=
var
.
name
,
type
=
core
.
VarDesc
.
VarType
.
READER
)
new_var
.
desc
.
set_shapes
(
var
.
desc
.
shapes
())
new_var
.
desc
.
set_dtypes
(
var
.
desc
.
dtypes
())
new_var
.
desc
.
set_lod_levels
(
var
.
desc
.
lod_levels
())
new_var
.
persistable
=
True
return
new_var
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
fe4cd502
...
...
@@ -154,9 +154,11 @@ __all__ = [
'mul'
,
'sigmoid_cross_entropy_with_logits'
,
'maxout'
,
'affine_grid'
,
'sequence_reverse'
,
'affine_channel'
,
'hash'
,
'grid_sampler'
,
'log_loss'
,
'add_position_encoding'
,
]
...
...
@@ -710,8 +712,18 @@ def dynamic_gru(input,
The first part are weights of the update gate and reset gate with
shape :math:`(D
\\
times 2D)`, and the second part are weights for
candidate hidden state with shape :math:`(D
\\
times D)`.
bias_attr(ParamAttr): The parameter attribute for learnable the
hidden-hidden bias.
If it is set to None or one attribute of ParamAttr, dynamic_gru will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias
of GRU. Note that the bias with :math:`(1
\\
times 3D)` concatenates
the bias in the update gate, reset gate and candidate calculations.
If it is set to False, no bias will be applied to the update gate,
reset gate and candidate calculations. If it is set to None or one
attribute of ParamAttr, dynamic_gru will create ParamAttr as
bias_attr. If the Initializer of the bias_attr is not set, the bias
is initialized zero. Default: None.
is_reverse(bool): Whether to compute reversed GRU, default
:attr:`False`.
gate_activation(str): The activation for update gate and reset gate.
...
...
@@ -810,10 +822,29 @@ def gru_unit(input,
Args:
input (Variable): The fc transformed input value of current step.
hidden (Variable): The hidden value of
lstm
unit from previous step.
hidden (Variable): The hidden value of
gru
unit from previous step.
size (integer): The input dimension value.
param_attr (ParamAttr): The weight parameters for gru unit. Default: None
bias_attr (ParamAttr): The bias parameters for gru unit. Default: None
param_attr(ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weight matrix. Note:
- The shape of the weight matrix is :math:`(T
\\
times 3D)`, where
:math:`D` is the hidden size.
- All elements in the weight matrix can be divided into two parts.
The first part are weights of the update gate and reset gate with
shape :math:`(D
\\
times 2D)`, and the second part are weights for
candidate hidden state with shape :math:`(D
\\
times D)`.
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias
of GRU. Note that the bias with :math:`(1
\\
times 3D)` concatenates
the bias in the update gate, reset gate and candidate calculations.
If it is set to False, no bias will be applied to the update gate,
reset gate and candidate calculations. If it is set to None or one
attribute of ParamAttr, gru_unit will create ParamAttr as
bias_attr. If the Initializer of the bias_attr is not set, the bias
is initialized zero. Default: None.
activation (string): The activation type for cell (actNode).
Default: 'tanh'
gate_activation (string): The activation type for gates (actGate).
...
...
@@ -2071,7 +2102,8 @@ def pool2d(input,
global_pooling
=
False
,
use_cudnn
=
True
,
ceil_mode
=
False
,
name
=
None
):
name
=
None
,
exclusive
=
True
):
"""
${comment}
...
...
@@ -2085,11 +2117,13 @@ def pool2d(input,
pool_type: ${pooling_type_comment}
pool_stride (int): stride of the pooling layer.
pool_padding (int): padding size.
global_pooling: ${global_pooling_comment}
use_cudnn: ${use_cudnn_comment}
ceil_mode: ${ceil_mode_comment}
global_pooling
(bool)
: ${global_pooling_comment}
use_cudnn
(bool)
: ${use_cudnn_comment}
ceil_mode
(bool)
: ${ceil_mode_comment}
name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
Returns:
Variable: The pooling result.
...
...
@@ -2147,7 +2181,8 @@ def pool2d(input,
"paddings"
:
pool_padding
,
"use_cudnn"
:
use_cudnn
,
"ceil_mode"
:
ceil_mode
,
"use_mkldnn"
:
False
"use_mkldnn"
:
False
,
"exclusive"
:
exclusive
,
})
return
pool_out
...
...
@@ -2161,7 +2196,8 @@ def pool3d(input,
global_pooling
=
False
,
use_cudnn
=
True
,
ceil_mode
=
False
,
name
=
None
):
name
=
None
,
exclusive
=
True
):
"""
This function adds the operator for pooling in 3-dimensions, using the
pooling configurations mentioned in input parameters.
...
...
@@ -2177,6 +2213,8 @@ def pool3d(input,
ceil_mode (bool): ${ceil_mode_comment}
name (str): A name for this layer(optional). If set None, the layer
will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
Returns:
Variable: output of pool3d layer.
...
...
@@ -2215,7 +2253,8 @@ def pool3d(input,
"paddings"
:
pool_padding
,
"use_cudnn"
:
use_cudnn
,
"ceil_mode"
:
ceil_mode
,
"use_mkldnn"
:
False
"use_mkldnn"
:
False
,
"exclusive"
:
exclusive
,
})
return
pool_out
...
...
@@ -4443,7 +4482,10 @@ def transpose(x, perm, name=None):
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[5, 10, 15], dtype='float32')
# use append_batch_size=False to avoid prepending extra
# batch size in shape
x = fluid.layers.data(name='x', shape=[5, 10, 15],
dtype='float32', append_batch_size=False)
x_transposed = layers.transpose(x, perm=[1, 0, 2])
"""
...
...
@@ -4680,7 +4722,8 @@ def multiplex(inputs, index):
def
softmax_with_cross_entropy
(
logits
,
label
,
soft_label
=
False
,
ignore_index
=-
100
):
ignore_index
=-
100
,
numeric_stable_mode
=
False
):
"""
**Softmax With Cross Entropy Operator.**
...
...
@@ -4714,6 +4757,18 @@ def softmax_with_cross_entropy(logits,
\\
left(
\\
text{logit}_i -
\\
log
\\
left(
\\
sum_{i=0}^{K}
\\
exp(
\\
text{logit}_i)
\\
right)
\\
right), j = 1,...,K
3) If numeric_stable_mode is True, softmax is calculated first by:
.. math::
max_j =
\\
max_{i=0}^{K}{
\\
text{logit}_i}
log
\\
_max
\\
_sum_j =
\\
log
\\
sum_{i=0}^{K}
\\
exp(logit_i - max_j)
softmax_j =
\\
exp(logit_j - max_j - {log
\\
_max
\\
_sum}_j)
and then cross entropy loss is calculated by softmax and label.
Args:
logits (Variable): The unscaled log probabilities, which is a 2-D tensor
with shape [N x K]. N is the batch_size, and K is the class number.
...
...
@@ -4725,6 +4780,13 @@ def softmax_with_cross_entropy(logits,
ignore_index (int): Specifies a target value that is ignored and does
not contribute to the input gradient. Only valid
if soft_label is set to False. Default: -100
numeric_stable_mode (bool): A flag to indicate whether to use a more
numerically stable algorithm. Only valid
when soft_label is False and GPU is used.
When soft_label is True or CPU is used,
the algorithm is always numerically stable.
Note that the speed may be slower when use
stable algorithm. Default: False
Returns:
Variable: The cross entropy loss is a 2-D tensor with shape [N x 1].
...
...
@@ -4747,8 +4809,11 @@ def softmax_with_cross_entropy(logits,
'Label'
:
label
},
outputs
=
{
'Softmax'
:
softmax
,
'Loss'
:
loss
},
attrs
=
{
'soft_label'
:
soft_label
,
'ignore_index'
:
ignore_index
})
attrs
=
{
'soft_label'
:
soft_label
,
'ignore_index'
:
ignore_index
,
'numeric_stable_mode'
:
numeric_stable_mode
})
return
loss
...
...
@@ -6108,6 +6173,124 @@ def crop(x, shape=None, offsets=None, name=None):
return
out
def
affine_grid
(
theta
,
out_shape
,
name
=
None
):
"""
It generates a grid of (x,y) coordinates using the parameters of
the affine transformation that correspond to a set of points where
the input feature map should be sampled to produce the transformed
output feature map.
.. code-block:: text
* Case 1:
Given:
theta = [[[x_11, x_12, x_13]
[x_14, x_15, x_16]]
[[x_21, x_22, x_23]
[x_24, x_25, x_26]]]
out_shape = [2, 3, 5, 5]
Step 1:
Generate normalized coordinates according to out_shape.
The values of the normalized coordinates are in the interval between -1 and 1.
The shape of the normalized coordinates is [2, H, W] as below:
C = [[[-1. -1. -1. -1. -1. ]
[-0.5 -0.5 -0.5 -0.5 -0.5]
[ 0. 0. 0. 0. 0. ]
[ 0.5 0.5 0.5 0.5 0.5]
[ 1. 1. 1. 1. 1. ]]
[[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]]]
C[0] is the coordinates in height axis and C[1] is the coordinates in width axis.
Step2:
Tanspose and reshape C to shape [H * W, 2] and append ones to last dimension. The we get:
C_ = [[-1. -1. 1. ]
[-0.5 -1. 1. ]
[ 0. -1. 1. ]
[ 0.5 -1. 1. ]
[ 1. -1. 1. ]
[-1. -0.5 1. ]
[-0.5 -0.5 1. ]
[ 0. -0.5 1. ]
[ 0.5 -0.5 1. ]
[ 1. -0.5 1. ]
[-1. 0. 1. ]
[-0.5 0. 1. ]
[ 0. 0. 1. ]
[ 0.5 0. 1. ]
[ 1. 0. 1. ]
[-1. 0.5 1. ]
[-0.5 0.5 1. ]
[ 0. 0.5 1. ]
[ 0.5 0.5 1. ]
[ 1. 0.5 1. ]
[-1. 1. 1. ]
[-0.5 1. 1. ]
[ 0. 1. 1. ]
[ 0.5 1. 1. ]
[ 1. 1. 1. ]]
Step3:
Compute output by equation $$Output[i] = C_ * Theta[i]^T$$
Args:
theta (Variable): A batch of affine transform parameters with shape [N, 2, 3].
out_shape (Variable | list | tuple): The shape of target output with format [N, C, H, W].
out_shape can be a Variable or a list or tuple.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The output with shape [N, H, W, 2].
Raises:
ValueError: If the type of arguments is not supported.
Examples:
.. code-block:: python
theta = fluid.layers.data(name="x", shape=[2, 3], dtype="float32")
out_shape = fluid.layers.data(name="y", shape=[-1], dtype="float32")
data = fluid.layers.affine_grid(theta, out_shape)
# or
data = fluid.layers.affine_grid(theta, [5, 3, 28, 28])
"""
helper
=
LayerHelper
(
'affine_grid'
)
if
not
(
isinstance
(
out_shape
,
list
)
or
isinstance
(
out_shape
,
tuple
)
or
\
isinstance
(
out_shape
,
Variable
)):
raise
ValueError
(
"The out_shape should be a list, tuple or Variable."
)
if
not
isinstance
(
theta
,
Variable
):
raise
ValueError
(
"The theta should be a Variable."
)
out
=
helper
.
create_variable_for_type_inference
(
theta
.
dtype
)
ipts
=
{
'Theta'
:
theta
}
attrs
=
{}
if
isinstance
(
out_shape
,
Variable
):
ipts
[
'OutputShape'
]
=
out_shape
else
:
attrs
[
'output_shape'
]
=
out_shape
helper
.
append_op
(
type
=
'affine_grid'
,
inputs
=
ipts
,
outputs
=
{
'Output'
:
out
},
attrs
=
None
if
len
(
attrs
)
==
0
else
attrs
)
return
out
def
rank_loss
(
label
,
left
,
right
,
name
=
None
):
"""
**Rank loss layer for RankNet**
...
...
@@ -7322,10 +7505,10 @@ def clip(x, min, max, name=None):
helper
=
LayerHelper
(
"clip"
,
**
locals
())
if
name
is
None
:
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
else
:
out
=
helper
.
create_variable
(
name
=
name
,
dtype
=
x
.
dtype
,
persistable
=
False
)
name
=
unique_name
.
generate
(
"."
.
join
([
helper
.
name
,
'tmp'
])
)
out
=
helper
.
create_variable
(
type
=
x
.
type
,
name
=
name
,
dtype
=
x
.
dtype
,
persistable
=
False
)
helper
.
append_op
(
type
=
"clip"
,
...
...
@@ -7354,10 +7537,10 @@ def clip_by_norm(x, max_norm, name=None):
helper
=
LayerHelper
(
"clip_by_norm"
,
**
locals
())
if
name
is
None
:
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
else
:
out
=
helper
.
create_variable
(
name
=
name
,
dtype
=
x
.
dtype
,
persistable
=
False
)
name
=
unique_name
.
generate
(
"."
.
join
([
helper
.
name
,
'tmp'
])
)
out
=
helper
.
create_variable
(
type
=
x
.
type
,
name
=
name
,
dtype
=
x
.
dtype
,
persistable
=
False
)
helper
.
append_op
(
type
=
"clip_by_norm"
,
...
...
@@ -7561,19 +7744,59 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
def
hash
(
input
,
hash_size
,
num_hash
=
1
,
name
=
None
):
"""
hash the input
Args:
input (Variable): The input variable which is a one-hot word.
hash_size (int): The space size for hash algorithm.
Hash the input to an integer whose value is less than the given hash size.
The hash algorithm we used was xxHash - Extremely fast hash algorithm
(https://github.com/Cyan4973/xxHash/tree/v0.6.5)
A simple example as below:
.. code-block:: text
Given:
# shape [2, 2]
input.data = [
[[1], [2]],
[[3], [4]],
]
input.lod = [[0, 2]]
hash_size = 10000
num_hash = 4
Then:
Hash op will take all number in input's 2nd dimension as hash algorithm's
input for each time. Each input will be hashed for 4 times, and get an
array whose length is 4. Each value in the array ranges from 0 to 9999.
# shape [2, 4]
output.data = [
[[9662], [9217], [1129], [8487]],
[[8310], [1327], [1654], [4567]],
]
output.lod = [[0, 2]]
Args:
input (Variable): The input variable which is a one-hot word. The
dimensions of the input variable must be 2.
hash_size (int): The space size for hash algorithm. The output value
will keep in the range:math:`[0, hash_size - 1]`.
num_hash (int): The times of hash, default 1.
name (str, default None): The name of this layer.
Returns:
Variable: The hash result variable which is a LoDTensor.
Examples:
.. code-block:: python
word_dict = paddle.dataset.imdb.word_dict()
x = fluid.layers.data(shape[1], dtype='int32', lod_level=1)
out = fluid.layers.hash(input=x, len(word_dict))
Returns:
Variable: The hash result variable which is a LoDTensor.
Examples:
.. code-block:: python
word_dict = paddle.dataset.imdb.word_dict()
x = fluid.layers.data(shape[1], dtype='int32', lod_level=1)
out = fluid.layers.hash(input=x, num_hash=4, hash_size=1000)
"""
helper
=
LayerHelper
(
'hash'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
...
...
@@ -7587,6 +7810,87 @@ def hash(input, hash_size, num_hash=1, name=None):
return
out
@
templatedoc
()
def
grid_sampler
(
x
,
grid
,
name
=
None
):
"""
This operation samples input X by using bilinear interpolation based on
flow field grid, which is usually gennerated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexng the 3rd
dimention (in height dimension), finally results is the bilinear
interpolation value of 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points.
wn ------- y_n ------- en
| | |
| d_n |
| | |
x_w --d_w-- grid--d_e-- x_e
| | |
| d_s |
| | |
ws ------- y_s ------- wn
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_s + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-east point value
es = X[:, :, y_s, x_w] // north-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
Args:
x(Variable): Input data of shape [N, C, H, W].
grid(Variable): Input grid tensor of shape [N, H, W, 2].
name (str, default None): The name of this layer.
Returns:
out(Variable): Output of shape [N, C, H, W] data samples input X
using bilnear interpolation based on input grid.
Exmples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[3, 10, 32, 32], dtype='float32')
theta = fluid.layers.data(name='theta', shape=[3, 2, 3], dtype='float32')
grid = fluid.layers.affine_grid(input=theta, size=[3, 10, 32, 32]})
out = fluid.layers.grid_sampler(x=x, grid=grid)
"""
helper
=
LayerHelper
(
"grid_sampler"
,
**
locals
())
if
not
isinstance
(
x
,
Variable
):
return
ValueError
(
"The x should be a Variable"
)
if
not
isinstance
(
grid
,
Variable
):
return
ValueError
(
"The grid should be a Variable"
)
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
ipts
=
{
'X'
:
x
,
'Grid'
:
grid
}
helper
.
append_op
(
type
=
'grid_sampler'
,
inputs
=
ipts
,
outputs
=
{
'Output'
:
out
})
return
out
def
log_loss
(
input
,
label
,
epsilon
=
1e-4
,
name
=
None
):
"""
**Negative Log Loss Layer**
...
...
python/paddle/fluid/tests/unittests/dist_save_load.py
0 → 100644
浏览文件 @
fe4cd502
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
import
sys
import
signal
import
subprocess
import
argparse
import
time
import
math
import
random
from
multiprocessing
import
Process
from
functools
import
reduce
import
numpy
as
np
import
unittest
import
six
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
from
paddle.fluid
import
io
from
test_dist_base
import
TestDistRunnerBase
,
runtime_main
,
RUN_STEP
from
dist_simnet_bow
import
TestDistSimnetBow2x2
,
DATA_URL
,
DATA_MD5
class
TestDistSaveLoad2x2
(
TestDistSimnetBow2x2
):
def
_load_persistable_vars
(
self
,
executor
,
dirname
,
program
):
def
_is_checkpoint_var
(
var
):
"""
the checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
: param var(Variable)
"""
if
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FEED_MINIBATCH
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FETCH_LIST
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
RAW
:
return
False
# @GRAD are named for gradient variables, checkpoint will not save it.
if
"@GRAD"
in
var
.
name
:
return
False
# .trainer_ are named for distribute train variables, checkpoint will not save it.
if
".trainer_"
in
var
.
name
:
return
False
# .block is named for distribute train variables, checkpoint will not save it.
if
".block"
in
var
.
name
:
return
False
if
"tmp_"
in
var
.
name
:
return
False
return
var
.
persistable
io
.
load_vars
(
executor
,
dirname
=
dirname
,
main_program
=
program
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
def
run_pserver
(
self
,
args
):
self
.
get_model
(
batch_size
=
2
)
# NOTE: pserver should not call memory optimize
t
=
self
.
get_transpiler
(
args
.
trainer_id
,
fluid
.
default_main_program
(),
args
.
endpoints
,
args
.
trainers
,
args
.
sync_mode
)
pserver_prog
=
t
.
get_pserver_program
(
args
.
current_endpoint
)
startup_prog
=
t
.
get_startup_program
(
args
.
current_endpoint
,
pserver_prog
)
need_load
=
bool
(
int
(
os
.
getenv
(
"LOAD"
,
"0"
)))
model_dir
=
os
.
getenv
(
"MODEL_DIR"
,
""
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
if
need_load
and
model_dir
:
self
.
_load_persistable_vars
(
exe
,
model_dir
,
startup_prog
)
exe
.
run
(
pserver_prog
)
def
run_trainer
(
self
,
args
):
test_program
,
avg_cost
,
train_reader
,
test_reader
,
batch_acc
,
predict
=
\
self
.
get_model
(
batch_size
=
2
)
if
args
.
mem_opt
:
fluid
.
memory_optimize
(
fluid
.
default_main_program
(),
skip_grads
=
True
)
if
args
.
is_dist
:
t
=
self
.
get_transpiler
(
args
.
trainer_id
,
fluid
.
default_main_program
(),
args
.
endpoints
,
args
.
trainers
,
args
.
sync_mode
)
trainer_prog
=
t
.
get_trainer_program
()
else
:
trainer_prog
=
fluid
.
default_main_program
()
if
args
.
use_cuda
:
place
=
fluid
.
CUDAPlace
(
0
)
else
:
place
=
fluid
.
CPUPlace
()
startup_exe
=
fluid
.
Executor
(
place
)
startup_exe
.
run
(
fluid
.
default_startup_program
())
strategy
=
fluid
.
ExecutionStrategy
()
strategy
.
num_threads
=
1
strategy
.
allow_op_delay
=
False
build_stra
=
fluid
.
BuildStrategy
()
if
args
.
use_reduce
:
build_stra
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
Reduce
else
:
build_stra
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
AllReduce
exe
=
fluid
.
ParallelExecutor
(
args
.
use_cuda
,
loss_name
=
avg_cost
.
name
,
exec_strategy
=
strategy
,
build_strategy
=
build_stra
)
feed_var_list
=
[
var
for
var
in
trainer_prog
.
global_block
().
vars
.
values
()
if
var
.
is_data
]
feeder
=
fluid
.
DataFeeder
(
feed_var_list
,
place
)
reader_generator
=
train_reader
()
def
get_data
():
origin_batch
=
next
(
reader_generator
)
if
args
.
is_dist
and
args
.
use_reader_alloc
:
new_batch
=
[]
for
offset
,
item
in
enumerate
(
origin_batch
):
if
offset
%
2
==
args
.
trainer_id
:
new_batch
.
append
(
item
)
return
new_batch
else
:
return
origin_batch
need_save
=
bool
(
int
(
os
.
getenv
(
"SAVE"
,
"0"
)))
model_dir
=
os
.
getenv
(
"MODEL_DIR"
,
""
)
if
need_save
:
for
_
in
six
.
moves
.
xrange
(
RUN_STEP
):
loss
,
=
exe
.
run
(
fetch_list
=
[
avg_cost
.
name
],
feed
=
feeder
.
feed
(
get_data
()))
if
need_save
and
model_dir
:
io
.
save_persistables
(
startup_exe
,
model_dir
,
trainer_prog
)
var
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
'__fc_b__'
).
get_tensor
())
print
(
np
.
ravel
(
var
).
tolist
())
if
__name__
==
"__main__"
:
paddle
.
dataset
.
common
.
download
(
DATA_URL
,
'simnet'
,
DATA_MD5
,
"train"
)
runtime_main
(
TestDistSaveLoad2x2
)
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
浏览文件 @
fe4cd502
...
...
@@ -40,7 +40,8 @@ class TestParallelExecutorBase(unittest.TestCase):
use_reduce
=
False
,
fuse_elewise_add_act_ops
=
False
,
optimizer
=
fluid
.
optimizer
.
Adam
,
use_fast_executor
=
False
):
use_fast_executor
=
False
,
enable_sequential_execution
=
False
):
def
run_executor
(
exe
,
feed
,
fetch_list
,
program
=
None
):
if
isinstance
(
exe
,
fluid
.
ParallelExecutor
):
res
=
exe
.
run
(
fetch_list
=
fetch_list
,
feed
=
feed
)
...
...
@@ -80,6 +81,7 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
Reduce
\
if
use_reduce
else
fluid
.
BuildStrategy
.
ReduceStrategy
.
AllReduce
build_strategy
.
fuse_elewise_add_act_ops
=
fuse_elewise_add_act_ops
build_strategy
.
enable_sequential_execution
=
enable_sequential_execution
if
use_parallel_executor
:
exe
=
fluid
.
ParallelExecutor
(
...
...
python/paddle/fluid/tests/unittests/test_affine_grid_op.py
0 → 100644
浏览文件 @
fe4cd502
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
def
AffineGrid
(
theta
,
size
):
n
=
size
[
0
]
w
=
size
[
3
]
h
=
size
[
2
]
h_idx
=
np
.
repeat
(
np
.
linspace
(
-
1
,
1
,
h
)[
np
.
newaxis
,
:],
w
,
axis
=
0
).
T
[:,
:,
np
.
newaxis
]
w_idx
=
np
.
repeat
(
np
.
linspace
(
-
1
,
1
,
w
)[
np
.
newaxis
,
:],
h
,
axis
=
0
)[:,
:,
np
.
newaxis
]
grid
=
np
.
concatenate
(
[
w_idx
,
h_idx
,
np
.
ones
([
h
,
w
,
1
])],
axis
=
2
)
# h * w * 3
grid
=
np
.
repeat
(
grid
[
np
.
newaxis
,
:],
size
[
0
],
axis
=
0
)
# n * h * w *3
ret
=
np
.
zeros
([
n
,
h
*
w
,
2
])
theta
=
theta
.
transpose
([
0
,
2
,
1
])
for
i
in
range
(
len
(
theta
)):
ret
[
i
]
=
np
.
dot
(
grid
[
i
].
reshape
([
h
*
w
,
3
]),
theta
[
i
])
# print ret.reshape([h * w, 2]).astype("float32")
return
ret
.
reshape
([
n
,
h
,
w
,
2
]).
astype
(
"float32"
)
class
TestAffineGridOp
(
OpTest
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
op_type
=
"affine_grid"
theta
=
np
.
random
.
randint
(
1
,
3
,
self
.
theta_shape
).
astype
(
"float32"
)
theta
=
np
.
ones
(
self
.
theta_shape
).
astype
(
"float32"
)
self
.
inputs
=
{
'Theta'
:
theta
}
self
.
attrs
=
{
"use_cudnn"
:
True
}
if
self
.
dynamic_shape
:
self
.
inputs
[
'OutputShape'
]
=
self
.
output_shape
else
:
self
.
attrs
[
'output_shape'
]
=
self
.
output_shape
self
.
outputs
=
{
'Output'
:
AffineGrid
(
theta
,
self
.
output_shape
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad_normal
(
self
):
self
.
check_grad
(
[
'Theta'
],
'Output'
,
no_grad_set
=
[
'OutputShape'
],
max_relative_error
=
0.006
)
def
initTestCase
(
self
):
self
.
theta_shape
=
(
3
,
2
,
3
)
self
.
output_shape
=
np
.
array
([
3
,
2
,
5
,
7
]).
astype
(
"int32"
)
self
.
dynamic_shape
=
False
class
TestAffineGridOpCase1
(
TestAffineGridOp
):
def
initTestCase
(
self
):
self
.
theta_shape
=
(
3
,
2
,
3
)
self
.
output_shape
=
np
.
array
([
3
,
2
,
5
,
7
]).
astype
(
"int32"
)
self
.
dynamic_shape
=
True
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_dist_base.py
浏览文件 @
fe4cd502
...
...
@@ -37,10 +37,15 @@ class TestDistRunnerBase(object):
"get_model should be implemented by child classes."
)
@
staticmethod
def
get_transpiler
(
trainer_id
,
main_program
,
pserver_endpoints
,
trainers
,
sync_mode
):
def
get_transpiler
(
trainer_id
,
main_program
,
pserver_endpoints
,
trainers
,
sync_mode
,
dc_asgd
=
False
):
# NOTE: import fluid until runtime, or else forking processes will cause error.
config
=
fluid
.
DistributeTranspilerConfig
()
config
.
enable_dc_asgd
=
dc_asgd
t
=
fluid
.
DistributeTranspiler
(
config
=
config
)
t
.
transpile
(
trainer_id
=
trainer_id
,
...
...
@@ -55,7 +60,7 @@ class TestDistRunnerBase(object):
# NOTE: pserver should not call memory optimize
t
=
self
.
get_transpiler
(
args
.
trainer_id
,
fluid
.
default_main_program
(),
args
.
endpoints
,
args
.
trainers
,
args
.
sync_mode
)
args
.
trainers
,
args
.
sync_mode
,
args
.
dc_asgd
)
pserver_prog
=
t
.
get_pserver_program
(
args
.
current_endpoint
)
startup_prog
=
t
.
get_startup_program
(
args
.
current_endpoint
,
pserver_prog
)
...
...
@@ -75,8 +80,7 @@ class TestDistRunnerBase(object):
t
=
self
.
get_transpiler
(
args
.
trainer_id
,
fluid
.
default_main_program
(),
args
.
endpoints
,
args
.
trainers
,
args
.
sync_mode
)
args
.
sync_mode
,
args
.
dc_asgd
)
trainer_prog
=
t
.
get_trainer_program
()
else
:
trainer_prog
=
fluid
.
default_main_program
()
...
...
@@ -155,6 +159,7 @@ def runtime_main(test_class):
parser
.
add_argument
(
'--mem_opt'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--use_cuda'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--use_reduce'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--dc_asgd'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--use_reader_alloc'
,
action
=
'store_true'
,
required
=
False
)
parser
.
add_argument
(
'--batch_size'
,
required
=
False
,
type
=
int
,
default
=
2
)
...
...
@@ -200,6 +205,7 @@ class TestDistBase(unittest.TestCase):
self
.
_enforce_place
=
None
self
.
_mem_opt
=
False
self
.
_use_reduce
=
False
self
.
_dc_asgd
=
False
# must use with async mode
self
.
_use_reader_alloc
=
True
self
.
_setup_config
()
self
.
_after_setup_config
()
...
...
python/paddle/fluid/tests/unittests/test_dist_mnist.py
浏览文件 @
fe4cd502
...
...
@@ -53,6 +53,15 @@ class TestDistMnistAsync(TestDistBase):
self
.
check_with_place
(
"dist_mnist.py"
,
delta
=
200
)
class
TestDistMnistDcAsgd
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
False
self
.
_dc_asgd
=
True
def
test_se_resnext
(
self
):
self
.
check_with_place
(
"dist_mnist.py"
,
delta
=
200
)
# FIXME(typhoonzero): enable these tests once we have 4
# 4 GPUs on CI machine, and the base class should be updated.
#
...
...
python/paddle/fluid/tests/unittests/test_dist_save_load.py
0 → 100644
浏览文件 @
fe4cd502
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
import
shutil
import
unittest
import
tempfile
import
numpy
as
np
from
test_dist_base
import
TestDistBase
,
RUN_STEP
class
TestDistSaveLoadDense2x2
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
True
self
.
_enforce_place
=
"CPU"
def
check_with_place
(
self
,
model_file
,
delta
=
1e-3
,
check_error_log
=
False
,
need_envs
=
{}):
required_envs
=
{
"PATH"
:
os
.
getenv
(
"PATH"
,
""
),
"PYTHONPATH"
:
os
.
getenv
(
"PYTHONPATH"
,
""
),
"LD_LIBRARY_PATH"
:
os
.
getenv
(
"LD_LIBRARY_PATH"
,
""
),
"http_proxy"
:
""
}
required_envs
.
update
(
need_envs
)
if
check_error_log
:
required_envs
[
"GLOG_v"
]
=
"7"
required_envs
[
"GLOG_logtostderr"
]
=
"1"
model_dir
=
tempfile
.
mkdtemp
()
local_env
=
{}
local_env
[
"SAVE"
]
=
"1"
local_env
[
"MODEL_DIR"
]
=
model_dir
local_env
.
update
(
required_envs
)
cluster_env
=
{}
cluster_env
[
"LOAD"
]
=
"1"
cluster_env
[
"MODEL_DIR"
]
=
model_dir
cluster_env
.
update
(
required_envs
)
local_var
=
self
.
_run_local
(
model_file
,
local_env
,
check_error_log
)
tr0_var
,
tr1_var
=
self
.
_run_cluster
(
model_file
,
cluster_env
,
check_error_log
)
shutil
.
rmtree
(
model_dir
)
local_np
=
np
.
array
(
eval
(
local_var
[
0
]))
train0_np
=
np
.
array
(
eval
(
tr0_var
[
0
]))
train1_np
=
np
.
array
(
eval
(
tr1_var
[
0
]))
self
.
assertAlmostEqual
(
local_np
.
all
(),
train0_np
.
all
(),
delta
=
delta
)
self
.
assertAlmostEqual
(
local_np
.
all
(),
train1_np
.
all
(),
delta
=
delta
)
self
.
assertAlmostEqual
(
train0_np
.
all
(),
train1_np
.
all
(),
delta
=
delta
)
@
unittest
.
skip
(
reason
=
"CI fail"
)
def
test_dist
(
self
):
need_envs
=
{
"IS_DISTRIBUTED"
:
'0'
,
"IS_SPARSE"
:
'0'
,
'IS_SELF_CONTAINED_LR'
:
'1'
}
self
.
check_with_place
(
"dist_save_load.py"
,
delta
=
0
,
check_error_log
=
False
,
need_envs
=
need_envs
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_grid_sampler_op.py
0 → 100644
浏览文件 @
fe4cd502
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
def
AffineGrid
(
theta
,
size
):
n
=
size
[
0
]
h
=
size
[
2
]
w
=
size
[
3
]
h_idx
=
np
.
repeat
(
np
.
linspace
(
-
1
,
1
,
h
)[
np
.
newaxis
,
:],
w
,
axis
=
0
).
T
[:,
:,
np
.
newaxis
]
w_idx
=
np
.
repeat
(
np
.
linspace
(
-
1
,
1
,
w
)[
np
.
newaxis
,
:],
h
,
axis
=
0
)[:,
:,
np
.
newaxis
]
grid
=
np
.
concatenate
(
[
w_idx
,
h_idx
,
np
.
ones
([
h
,
w
,
1
])],
axis
=
2
)
# h * w * 3
grid
=
np
.
repeat
(
grid
[
np
.
newaxis
,
:],
size
[
0
],
axis
=
0
)
# n * h * w *3
ret
=
np
.
zeros
([
n
,
h
*
w
,
2
])
theta
=
theta
.
transpose
([
0
,
2
,
1
])
for
i
in
range
(
len
(
theta
)):
ret
[
i
]
=
np
.
dot
(
grid
[
i
].
reshape
([
h
*
w
,
3
]),
theta
[
i
])
return
ret
.
reshape
([
n
,
h
,
w
,
2
]).
astype
(
"float32"
)
def
getGridPointValue
(
data
,
x
,
y
):
data_shape
=
data
.
shape
N
=
data_shape
[
0
]
H
=
data_shape
[
2
]
W
=
data_shape
[
3
]
out
=
np
.
zeros
(
data_shape
,
dtype
=
'float'
)
for
i
in
range
(
N
):
for
j
in
range
(
H
):
for
k
in
range
(
W
):
if
y
[
i
,
j
,
k
]
<
0
or
y
[
i
,
j
,
k
]
>
H
-
1
or
x
[
i
,
j
,
k
]
<
0
or
x
[
i
,
j
,
k
]
>
W
-
1
:
out
[
i
,
:,
j
,
k
]
=
0
else
:
out
[
i
,
:,
j
,
k
]
=
data
[
i
,
:,
y
[
i
,
j
,
k
],
x
[
i
,
j
,
k
]]
return
out
def
GridSampler
(
data
,
grid
):
dims
=
data
.
shape
N
=
dims
[
0
]
C
=
dims
[
1
]
H
=
dims
[
2
]
W
=
dims
[
3
]
x
=
grid
[:,
:,
:,
0
]
y
=
grid
[:,
:,
:,
1
]
y_max
=
H
-
1
x_max
=
W
-
1
x
=
0.5
*
((
x
.
astype
(
'float32'
)
+
1.0
)
*
x_max
)
y
=
0.5
*
((
y
.
astype
(
'float32'
)
+
1.0
)
*
y_max
)
x0
=
np
.
floor
(
x
).
astype
(
'int32'
)
x1
=
x0
+
1
y0
=
np
.
floor
(
y
).
astype
(
'int32'
)
y1
=
y0
+
1
wa
=
np
.
tile
(((
x1
-
x
)
*
(
y1
-
y
)).
reshape
((
N
,
1
,
H
,
W
)),
(
1
,
C
,
1
,
1
))
wb
=
np
.
tile
(((
x1
-
x
)
*
(
y
-
y0
)).
reshape
((
N
,
1
,
H
,
W
)),
(
1
,
C
,
1
,
1
))
wc
=
np
.
tile
(((
x
-
x0
)
*
(
y1
-
y
)).
reshape
((
N
,
1
,
H
,
W
)),
(
1
,
C
,
1
,
1
))
wd
=
np
.
tile
(((
x
-
x0
)
*
(
y
-
y0
)).
reshape
((
N
,
1
,
H
,
W
)),
(
1
,
C
,
1
,
1
))
va
=
getGridPointValue
(
data
,
x0
,
y0
)
vb
=
getGridPointValue
(
data
,
x0
,
y1
)
vc
=
getGridPointValue
(
data
,
x1
,
y0
)
vd
=
getGridPointValue
(
data
,
x1
,
y1
)
out
=
(
wa
*
va
+
wb
*
vb
+
wc
*
vc
+
wd
*
vd
).
astype
(
'float32'
)
return
out
class
TestGridSamplerOp
(
OpTest
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
op_type
=
'grid_sampler'
x
=
np
.
random
.
randint
(
0
,
255
,
self
.
x_shape
).
astype
(
'float32'
)
theta
=
np
.
zeros
(
self
.
theta_shape
).
astype
(
'float32'
)
for
i
in
range
(
self
.
theta_shape
[
0
]):
for
j
in
range
(
2
):
for
k
in
range
(
3
):
theta
[
i
,
j
,
k
]
=
np
.
random
.
rand
(
1
)[
0
]
grid
=
AffineGrid
(
theta
,
self
.
x_shape
)
self
.
inputs
=
{
'X'
:
x
,
'Grid'
:
grid
}
self
.
attrs
=
{
'use_cudnn'
:
True
}
self
.
outputs
=
{
'Output'
:
GridSampler
(
x
,
grid
)}
def
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-3
)
def
test_check_grad_normal
(
self
):
self
.
check_grad
([
'X'
,
'Grid'
],
'Output'
,
max_relative_error
=
0.61
)
def
initTestCase
(
self
):
self
.
x_shape
=
(
2
,
5
,
7
,
3
)
self
.
grid_shape
=
(
2
,
7
,
3
,
2
)
self
.
theta_shape
=
(
2
,
2
,
3
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
fe4cd502
...
...
@@ -865,6 +865,31 @@ class TestBook(unittest.TestCase):
self
.
assertIsNotNone
(
out
)
print
(
str
(
program
))
def
test_grid_sampler
(
self
):
program
=
Program
()
with
program_guard
(
program
):
x
=
layers
.
data
(
name
=
'x'
,
shape
=
[
3
,
5
,
7
],
dtype
=
'float32'
)
grid
=
layers
.
data
(
name
=
'grid'
,
shape
=
[
5
,
7
,
2
],
dtype
=
'float32'
)
out
=
layers
.
grid_sampler
(
x
,
grid
)
self
.
assertIsNotNone
(
out
)
print
(
str
(
program
))
def
test_affine_grid
(
self
):
program
=
Program
()
with
program_guard
(
program
):
data
=
layers
.
data
(
name
=
'data'
,
shape
=
[
2
,
3
,
3
],
dtype
=
"float32"
)
out
,
ids
=
layers
.
argsort
(
input
=
data
,
axis
=
1
)
theta
=
layers
.
data
(
name
=
"theta"
,
shape
=
[
2
,
3
],
dtype
=
"float32"
)
out_shape
=
layers
.
data
(
name
=
"out_shape"
,
shape
=
[
-
1
],
dtype
=
"float32"
)
data_0
=
layers
.
affine_grid
(
theta
,
out_shape
)
data_1
=
layers
.
affine_grid
(
theta
,
[
5
,
3
,
28
,
28
])
self
.
assertIsNotNone
(
data_0
)
self
.
assertIsNotNone
(
data_1
)
print
(
str
(
program
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py
浏览文件 @
fe4cd502
...
...
@@ -232,6 +232,46 @@ class TestResnet(TestParallelExecutorBase):
for
loss
in
zip
(
all_reduce_last_loss
,
reduce_last_loss
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
if
not
use_cuda
:
return
all_reduce_first_loss_seq
,
all_reduce_last_loss_seq
=
self
.
check_network_convergence
(
model
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
},
iter
=
iter
,
batch_size
=
batch_size
,
use_cuda
=
use_cuda
,
use_reduce
=
False
,
optimizer
=
optimizer
,
enable_sequential_execution
=
True
)
reduce_first_loss_seq
,
reduce_last_loss_seq
=
self
.
check_network_convergence
(
model
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
},
iter
=
iter
,
batch_size
=
batch_size
,
use_cuda
=
use_cuda
,
use_reduce
=
True
,
optimizer
=
optimizer
,
enable_sequential_execution
=
True
)
for
loss
in
zip
(
all_reduce_first_loss
,
all_reduce_first_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
1e-6
)
for
loss
in
zip
(
all_reduce_last_loss
,
all_reduce_last_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
for
loss
in
zip
(
reduce_first_loss
,
reduce_first_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
1e-6
)
for
loss
in
zip
(
reduce_last_loss
,
reduce_last_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
for
loss
in
zip
(
all_reduce_first_loss_seq
,
reduce_first_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
1e-6
)
for
loss
in
zip
(
all_reduce_last_loss_seq
,
reduce_last_loss_seq
):
self
.
assertAlmostEquals
(
loss
[
0
],
loss
[
1
],
delta
=
delta2
)
def
_check_resnet_convergence
(
self
,
model
,
use_cuda
=
True
,
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py
浏览文件 @
fe4cd502
...
...
@@ -173,6 +173,8 @@ class TestTransformer(TestParallelExecutorBase):
def
test_main
(
self
):
if
core
.
is_compiled_with_cuda
():
self
.
check_network_convergence
(
transformer
,
use_cuda
=
True
)
self
.
check_network_convergence
(
transformer
,
use_cuda
=
True
,
enable_sequential_execution
=
True
)
self
.
check_network_convergence
(
transformer
,
use_cuda
=
False
,
iter
=
5
)
...
...
python/paddle/fluid/tests/unittests/test_pool2d_op.py
浏览文件 @
fe4cd502
...
...
@@ -26,7 +26,8 @@ def max_pool2D_forward_naive(x,
strides
,
paddings
,
global_pool
=
0
,
ceil_mode
=
False
):
ceil_mode
=
False
,
exclusive
=
True
):
N
,
C
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
ksize
=
[
H
,
W
]
...
...
@@ -54,7 +55,8 @@ def avg_pool2D_forward_naive(x,
strides
,
paddings
,
global_pool
=
0
,
ceil_mode
=
False
):
ceil_mode
=
False
,
exclusive
=
True
):
N
,
C
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
ksize
=
[
H
,
W
]
...
...
@@ -73,8 +75,9 @@ def avg_pool2D_forward_naive(x,
c_end
=
np
.
min
((
j
*
strides
[
1
]
+
ksize
[
1
]
-
paddings
[
1
],
W
))
x_masked
=
x
[:,
:,
r_start
:
r_end
,
c_start
:
c_end
]
out
[:,
:,
i
,
j
]
=
np
.
sum
(
x_masked
,
axis
=
(
2
,
3
))
/
(
(
r_end
-
r_start
)
*
(
c_end
-
c_start
))
field_size
=
((
r_end
-
r_start
)
*
(
c_end
-
c_start
))
if
exclusive
\
else
(
ksize
[
0
]
*
ksize
[
1
])
out
[:,
:,
i
,
j
]
=
np
.
sum
(
x_masked
,
axis
=
(
2
,
3
))
/
field_size
return
out
...
...
@@ -89,12 +92,13 @@ class TestPool2d_Op(OpTest):
self
.
init_kernel_type
()
self
.
init_pool_type
()
self
.
init_ceil_mode
()
self
.
init_exclusive
()
if
self
.
global_pool
:
self
.
paddings
=
[
0
for
_
in
range
(
len
(
self
.
paddings
))]
input
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
output
=
self
.
pool2D_forward_naive
(
input
,
self
.
ksize
,
self
.
strides
,
self
.
paddings
,
self
.
global_pool
,
self
.
ceil_mod
e
).
astype
(
self
.
dtype
)
output
=
self
.
pool2D_forward_naive
(
input
,
self
.
ksize
,
self
.
strides
,
self
.
paddings
,
self
.
global_pool
,
self
.
ceil_mode
,
self
.
exclusiv
e
).
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
input
)}
self
.
attrs
=
{
...
...
@@ -106,7 +110,9 @@ class TestPool2d_Op(OpTest):
'use_cudnn'
:
self
.
use_cudnn
,
'use_mkldnn'
:
self
.
use_mkldnn
,
'ceil_mode'
:
self
.
ceil_mode
,
'data_format'
:
'AnyLayout'
# TODO(dzhwinter) : should be fix latter
'data_format'
:
'AnyLayout'
,
# TODO(dzhwinter) : should be fix latter
'exclusive'
:
self
.
exclusive
}
self
.
outputs
=
{
'Out'
:
output
}
...
...
@@ -150,6 +156,9 @@ class TestPool2d_Op(OpTest):
def
init_ceil_mode
(
self
):
self
.
ceil_mode
=
False
def
init_exclusive
(
self
):
self
.
exclusive
=
True
class
TestCase1
(
TestPool2d_Op
):
def
init_test_case
(
self
):
...
...
@@ -322,5 +331,15 @@ class TestCeilModeCase4(TestCase2):
self
.
ceil_mode
=
True
class
TestAvgInclude
(
TestCase2
):
def
init_exclusive
(
self
):
self
.
exclusive
=
False
class
TestCUDNNAvgInclude
(
TestCUDNNCase3
):
def
init_exclusive
(
self
):
self
.
exclusive
=
False
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_pool3d_op.py
浏览文件 @
fe4cd502
...
...
@@ -26,7 +26,8 @@ def max_pool3D_forward_naive(x,
strides
,
paddings
,
global_pool
=
0
,
ceil_mode
=
False
):
ceil_mode
=
False
,
exclusive
=
True
):
N
,
C
,
D
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
ksize
=
[
D
,
H
,
W
]
...
...
@@ -60,7 +61,8 @@ def avg_pool3D_forward_naive(x,
strides
,
paddings
,
global_pool
=
0
,
ceil_mode
=
False
):
ceil_mode
=
False
,
exclusive
=
True
):
N
,
C
,
D
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
ksize
=
[
D
,
H
,
W
]
...
...
@@ -85,8 +87,10 @@ def avg_pool3D_forward_naive(x,
w_end
=
np
.
min
((
j
*
strides
[
1
]
+
ksize
[
1
]
-
paddings
[
1
],
W
))
x_masked
=
x
[:,
:,
d_start
:
d_end
,
h_start
:
h_end
,
w_start
:
w_end
]
out
[:,
:,
k
,
i
,
j
]
=
np
.
sum
(
x_masked
,
axis
=
(
2
,
3
,
4
))
/
(
(
d_end
-
d_start
)
*
(
h_end
-
h_start
)
*
(
w_end
-
w_start
))
field_size
=
(
d_end
-
d_start
)
*
(
h_end
-
h_start
)
*
(
w_end
-
w_start
)
\
if
exclusive
else
ksize
[
0
]
*
ksize
[
1
]
*
ksize
[
2
]
out
[:,
:,
k
,
i
,
j
]
=
np
.
sum
(
x_masked
,
axis
=
(
2
,
3
,
4
))
/
field_size
return
out
...
...
@@ -100,13 +104,14 @@ class TestPool3d_Op(OpTest):
self
.
init_kernel_type
()
self
.
init_pool_type
()
self
.
init_ceil_mode
()
self
.
init_exclusive
()
if
self
.
global_pool
:
self
.
paddings
=
[
0
for
_
in
range
(
len
(
self
.
paddings
))]
input
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
output
=
self
.
pool3D_forward_naive
(
input
,
self
.
ksize
,
self
.
strides
,
self
.
paddings
,
self
.
global_pool
,
self
.
ceil_mod
e
).
astype
(
self
.
dtype
)
output
=
self
.
pool3D_forward_naive
(
input
,
self
.
ksize
,
self
.
strides
,
self
.
paddings
,
self
.
global_pool
,
self
.
ceil_mode
,
self
.
exclusiv
e
).
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
input
)}
self
.
attrs
=
{
...
...
@@ -117,7 +122,9 @@ class TestPool3d_Op(OpTest):
'global_pooling'
:
self
.
global_pool
,
'use_cudnn'
:
self
.
use_cudnn
,
'ceil_mode'
:
self
.
ceil_mode
,
'data_format'
:
'AnyLayout'
# TODO(dzhwinter) : should be fix latter
'data_format'
:
'AnyLayout'
,
# TODO(dzhwinter) : should be fix latter
'exclusive'
:
self
.
exclusive
}
self
.
outputs
=
{
'Out'
:
output
}
...
...
@@ -161,6 +168,9 @@ class TestPool3d_Op(OpTest):
def
init_ceil_mode
(
self
):
self
.
ceil_mode
=
False
def
init_exclusive
(
self
):
self
.
exclusive
=
True
class
TestCase1
(
TestPool3d_Op
):
def
init_test_case
(
self
):
...
...
@@ -333,5 +343,15 @@ class TestCeilModeCase4(TestCase2):
self
.
ceil_mode
=
True
class
TestAvgInclude
(
TestCase2
):
def
init_exclusive
(
self
):
self
.
exclusive
=
False
class
TestCUDNNAvgInclude
(
TestCUDNNCase3
):
def
init_exclusive
(
self
):
self
.
exclusive
=
False
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_py_reader_lod_level_share.py
0 → 100644
浏览文件 @
fe4cd502
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.fluid
as
fluid
import
unittest
class
TestLoDLevelShare
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
use_double_buffer
=
False
def
test_lod_level_share
(
self
):
reader
=
fluid
.
layers
.
py_reader
(
capacity
=
16
,
shapes
=
([
-
1
,
256
],
[
-
1
,
512
],
[
-
1
,
100
]),
dtypes
=
(
'float32'
,
'int64'
,
'double'
),
lod_levels
=
(
1
,
2
,
0
),
use_double_buffer
=
self
.
use_double_buffer
)
x
,
y
,
z
=
fluid
.
layers
.
read_file
(
reader
)
self
.
assertEqual
(
x
.
lod_level
,
1
)
self
.
assertEqual
(
y
.
lod_level
,
2
)
self
.
assertEqual
(
z
.
lod_level
,
0
)
class
TestLoDLevelShare2
(
TestLoDLevelShare
):
def
setUp
(
self
):
self
.
use_double_buffer
=
True
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_py_reader_pin_memory.py
0 → 100644
浏览文件 @
fe4cd502
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
import
numpy
as
np
from
threading
import
Thread
def
user_reader
(
inputs
):
def
_reader
():
for
d
in
inputs
:
yield
d
return
_reader
def
batch_feeder
(
batch_reader
,
pin_memory
=
False
,
img_dtype
=
"float32"
):
def
_feeder
():
for
batch_data
in
batch_reader
():
sample_batch
=
[]
label_batch
=
[]
for
sample
,
label
in
batch_data
:
sample_batch
.
append
(
sample
)
label_batch
.
append
([
label
])
tensor
=
core
.
LoDTensor
()
label
=
core
.
LoDTensor
()
place
=
core
.
CUDAPinnedPlace
()
if
pin_memory
else
core
.
CPUPlace
()
tensor
.
set
(
np
.
array
(
sample_batch
,
dtype
=
img_dtype
),
place
)
label
.
set
(
np
.
array
(
label_batch
,
dtype
=
"int64"
),
place
)
yield
[
tensor
,
label
]
return
_feeder
class
TestPyReader
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
capacity
=
10
self
.
shapes
=
[(
-
1
,
3
,
2
,
1
),
(
-
1
,
1
)]
self
.
lod_levels
=
[
0
,
0
]
self
.
dtypes
=
[
'float32'
,
'int64'
]
def
test_pin_memory_pyreader
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
core
.
is_compiled_with_cuda
(
)
else
fluid
.
CPUPlace
()
executor
=
fluid
.
Executor
(
place
)
data_file
=
fluid
.
layers
.
py_reader
(
capacity
=
self
.
capacity
,
dtypes
=
self
.
dtypes
,
lod_levels
=
self
.
lod_levels
,
shapes
=
self
.
shapes
)
# feed_queue = data_file.queue
read_out_data
=
fluid
.
layers
.
read_file
(
data_file
)
self
.
inputs
=
[]
for
_
in
range
(
10
):
sample
=
np
.
random
.
uniform
(
low
=
0
,
high
=
1
,
size
=
[
3
,
2
,
1
]).
astype
(
"float32"
)
label
=
np
.
random
.
uniform
(
low
=
0
,
high
=
10
,
size
=
[
1
]).
astype
(
"int64"
)
self
.
inputs
.
append
((
sample
,
label
))
self
.
input_tensors
=
[]
for
d
,
l
in
batch_feeder
(
paddle
.
batch
(
user_reader
(
self
.
inputs
),
batch_size
=
2
),
pin_memory
=
True
if
fluid
.
core
.
is_compiled_with_cuda
()
else
False
)():
ta
=
fluid
.
LoDTensorArray
()
ta
.
append
(
d
)
ta
.
append
(
l
)
self
.
input_tensors
.
append
(
ta
)
self
.
batched_inputs
=
[]
for
batch
in
paddle
.
batch
(
user_reader
(
self
.
inputs
),
batch_size
=
2
)():
feed_d
=
[]
feed_l
=
[]
for
d
,
l
in
batch
:
feed_d
.
append
(
d
)
feed_l
.
append
([
l
])
self
.
batched_inputs
.
append
([
feed_d
,
feed_l
])
data_file
.
decorate_tensor_provider
(
batch_feeder
(
paddle
.
batch
(
user_reader
(
self
.
inputs
),
batch_size
=
2
),
pin_memory
=
True
if
fluid
.
core
.
is_compiled_with_cuda
()
else
False
))
executor
.
run
(
fluid
.
default_startup_program
())
self
.
outputs
=
[]
data_file
.
start
()
for
_
in
self
.
input_tensors
:
self
.
outputs
.
append
(
executor
.
run
(
fetch_list
=
list
(
read_out_data
)))
data_file
.
reset
()
self
.
validate
()
def
validate
(
self
):
self
.
assertEqual
(
len
(
self
.
batched_inputs
),
len
(
self
.
outputs
))
for
in_data_list
,
out_data_list
in
zip
(
self
.
batched_inputs
,
self
.
outputs
):
self
.
assertEqual
(
len
(
in_data_list
),
len
(
out_data_list
))
in_data_list_np
=
[
np
.
array
(
in_lod_tensor
)
for
in_lod_tensor
in
in_data_list
]
for
in_data
,
out_data
in
zip
(
in_data_list_np
,
out_data_list
):
self
.
assertTrue
((
in_data
==
out_data
).
all
())
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_ref_by_trainer_id_op.py
0 → 100644
浏览文件 @
fe4cd502
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
class
TestRefByTrainerIdOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"ref_by_trainer_id"
param_baks
=
[(
"x%d"
%
x
,
np
.
random
.
random
((
10
,
10
)).
astype
(
"float32"
))
for
x
in
range
(
10
)]
self
.
inputs
=
{
'X'
:
param_baks
,
'TrainerId'
:
np
.
array
([
8
]).
astype
(
"int64"
)
}
self
.
outputs
=
{
'Out'
:
param_baks
[
8
][
1
]}
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
浏览文件 @
fe4cd502
...
...
@@ -26,7 +26,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
Test softmax with cross entropy operator with discreate one-hot labels.
"""
def
initParams
(
self
):
self
.
numeric_stable_mode
=
False
def
setUp
(
self
):
self
.
initParams
()
self
.
op_type
=
"softmax_with_cross_entropy"
batch_size
=
41
class_num
=
37
...
...
@@ -46,6 +50,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
"Softmax"
:
softmax
.
astype
(
"float64"
),
"Loss"
:
cross_entropy
.
astype
(
"float64"
)
}
self
.
attrs
=
{
"numeric_stable_mode"
:
self
.
numeric_stable_mode
}
def
test_check_output
(
self
):
self
.
check_output
()
...
...
@@ -54,6 +59,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self
.
check_grad
([
"Logits"
],
"Loss"
)
class
TestSoftmaxWithCrossEntropyOpNoCudnn
(
TestSoftmaxWithCrossEntropyOp
):
def
initParams
(
self
):
self
.
numeric_stable_mode
=
True
class
TestSoftmaxWithCrossEntropyOp2
(
OpTest
):
"""
Test softmax with cross entropy operator with soft labels.
...
...
@@ -93,7 +103,11 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
Test softmax with cross entropy operator with ignore_index.
"""
def
initParams
(
self
):
self
.
numeric_stable_mode
=
False
def
setUp
(
self
):
self
.
initParams
()
self
.
op_type
=
"softmax_with_cross_entropy"
batch_size
=
41
class_num
=
37
...
...
@@ -114,7 +128,10 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
"Softmax"
:
softmax
.
astype
(
"float64"
),
"Loss"
:
cross_entropy
.
astype
(
"float64"
)
}
self
.
attrs
=
{
"ignore_index"
:
ignore_index
}
self
.
attrs
=
{
"ignore_index"
:
ignore_index
,
"numeric_stable_mode"
:
self
.
numeric_stable_mode
}
def
test_check_output
(
self
):
self
.
check_output
()
...
...
@@ -123,5 +140,10 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
self
.
check_grad
([
"Logits"
],
"Loss"
)
class
TestSoftmaxWithCrossEntropyOp3NoCudnn
(
TestSoftmaxWithCrossEntropyOp3
):
def
initParams
(
self
):
self
.
numeric_stable_mode
=
True
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
fe4cd502
...
...
@@ -38,7 +38,7 @@ import six
import
logging
from
.ps_dispatcher
import
RoundRobin
,
HashName
,
PSDispatcher
from
..
import
core
,
framework
from
..
import
core
,
framework
,
unique_name
from
..framework
import
Program
,
default_main_program
,
\
default_startup_program
,
Block
,
\
Parameter
,
grad_var_name
...
...
@@ -138,6 +138,7 @@ class DistributeTranspilerConfig(object):
slice_var_up
=
True
split_method
=
None
min_block_size
=
8192
enable_dc_asgd
=
False
# supported modes: pserver, nccl2
mode
=
"pserver"
print_log
=
False
...
...
@@ -252,6 +253,8 @@ class DistributeTranspiler(object):
n workers, the id may range from 0 ~ n-1
program (Program|None): program to transpile,
default is fluid.default_main_program().
startup_program (Program|None): startup_program to transpile,
default is fluid.default_startup_program().
pservers (str): comma separated ip:port string for the pserver
list.
trainers (int|str): in pserver mode this is the number of
...
...
@@ -383,6 +386,8 @@ class DistributeTranspiler(object):
outputs
=
{
"Out"
:
send_barrier_out
},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"sync_mode"
:
self
.
sync_mode
,
"trainer_id"
:
self
.
trainer_id
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
...
...
@@ -426,6 +431,7 @@ class DistributeTranspiler(object):
outputs
=
{
"Out"
:
splited_var
},
attrs
=
{
"epmap"
:
eps
,
"trainer_id"
:
self
.
trainer_id
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
OP_ROLE_VAR_ATTR_NAME
:
[
param_varname
,
recv_op_role_var_name
],
...
...
@@ -440,6 +446,7 @@ class DistributeTranspiler(object):
outputs
=
{
"Out"
:
all_recv_outputs
},
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"trainer_id"
:
self
.
trainer_id
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
...
...
@@ -651,6 +658,24 @@ in a single call.")
endpoint
,
op
):
opt_op_on_pserver
.
append
(
op
)
# step 3.3
# prepare if dc asgd is enabled
if
self
.
config
.
enable_dc_asgd
==
True
:
assert
(
self
.
sync_mode
==
False
)
self
.
param_bak_list
=
[]
# add param_bak for each trainer
for
p
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
# each parameter should have w_bak for each trainer id
for
i
in
range
(
self
.
trainer_num
):
param_bak_name
=
"%s.trainer_%d_bak"
%
(
p
.
name
,
i
)
tmpvar
=
pserver_program
.
global_block
().
create_var
(
# NOTE: this var name format is used in `request_get_handler`
name
=
param_bak_name
,
type
=
p
.
type
,
shape
=
p
.
shape
,
dtype
=
p
.
dtype
)
self
.
param_bak_list
.
append
((
p
,
tmpvar
))
# step 3.4
# Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then
# append it into the sub program.
...
...
@@ -741,7 +766,7 @@ in a single call.")
grad_to_block_id
,
merged_var
,
lr_ops
)
# dedup grad to ids list
# dedup grad to ids list
grad_to_block_id
=
list
(
set
(
grad_to_block_id
))
# append global ops
if
global_ops
:
...
...
@@ -787,6 +812,8 @@ in a single call.")
if
self
.
has_distributed_lookup_table
:
attrs
[
'checkpint_block_id'
]
=
checkpoint_block_id
if
self
.
config
.
enable_dc_asgd
:
attrs
[
'dc_asgd'
]
=
True
if
len
(
prefetch_var_name_to_block_id
)
>
0
:
attrs
[
...
...
@@ -903,6 +930,15 @@ to transpile() call.")
inputs
=
new_inputs
,
outputs
=
new_outputs
,
attrs
=
op
.
all_attrs
())
if
self
.
config
.
enable_dc_asgd
:
for
p
,
p_bak
in
self
.
param_bak_list
:
startup_param_var
=
s_prog
.
global_block
().
vars
[
p
.
name
]
startup_tmpvar
=
s_prog
.
global_block
().
vars
[
p_bak
.
name
]
# copy init random value to param_bak
s_prog
.
global_block
().
append_op
(
type
=
"assign"
,
inputs
=
{
"X"
:
startup_param_var
},
outputs
=
{
"Out"
:
startup_tmpvar
})
# add slice vars
s_prog
.
_slice_vars_and_attrs
=
self
.
_get_slice_vars_and_attrs
(
endpoint
)
...
...
@@ -920,11 +956,11 @@ to transpile() call.")
block_idx
=
int
(
block_name
.
split
(
block_suffix
)[
1
])
orig_var
=
self
.
origin_program
.
global_block
().
vars
[
orig_var_name
]
skip_
numel
=
0
skip_
dim0
=
0
slice_vars
=
self
.
param_var_mapping
[
orig_var_name
]
for
slice_var
in
slice_vars
[:
block_idx
]:
skip_
numel
+=
reduce
(
lambda
x
,
y
:
x
*
y
,
slice_var
.
shape
)
slice_vars_and_attrs
.
append
([
orig_var
,
skip_
numel
,
param
])
skip_
dim0
+=
slice_var
.
shape
[
0
]
slice_vars_and_attrs
.
append
([
orig_var
,
skip_
dim0
,
param
])
return
slice_vars_and_attrs
...
...
@@ -1175,6 +1211,7 @@ to transpile() call.")
attrs
=
{
"sync_mode"
:
not
self
.
sync_mode
,
"epmap"
:
pserver_endpoints
,
"trainer_id"
:
self
.
trainer_id
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
OP_ROLE_VAR_ATTR_NAME
:
[
self
.
grad_name_to_param_name
[
table_grad_name
],
...
...
@@ -1531,6 +1568,69 @@ to transpile() call.")
attrs
=
{
"scale"
:
1.0
/
float
(
self
.
trainer_num
)})
return
merged_var
def
_append_dc_asgd_ops
(
self
,
block
,
param_var
,
grad_var
):
# NOTE: can not use grammar candy here, should put ops in specific block
local_param_bak
=
block
.
create_var
(
name
=
"%s.local_bak"
%
param_var
.
name
,
shape
=
param_var
.
shape
,
type
=
param_var
.
type
,
dtype
=
param_var
.
dtype
,
persistable
=
False
)
# trainer_id_var is block local
trainer_id_var
=
block
.
create_var
(
name
=
"@TRAINER_ID@"
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
dtype
=
core
.
VarDesc
.
VarType
.
INT64
,
shape
=
[
1
],
persistable
=
False
)
# ref_inputs = [x[1] for x in self.param_bak_list]
ref_inputs
=
[]
for
p
,
p_bak
in
self
.
param_bak_list
:
if
p
.
name
==
param_var
.
name
:
print
(
"#### ref inputs: "
,
param_var
.
name
,
p_bak
.
name
)
ref_inputs
.
append
(
p_bak
)
block
.
append_op
(
type
=
"ref_by_trainer_id"
,
inputs
=
{
"X"
:
ref_inputs
,
"TrainerId"
:
trainer_id_var
},
outputs
=
{
"Out"
:
local_param_bak
})
def
__create_temp_var__
():
return
block
.
create_var
(
name
=
unique_name
.
generate
(
"tmp_dc_output"
),
shape
=
param_var
.
shape
,
type
=
param_var
.
type
,
dtype
=
param_var
.
dtype
,
persistable
=
False
)
o1
=
__create_temp_var__
()
block
.
append_op
(
type
=
"elementwise_sub"
,
inputs
=
{
"X"
:
param_var
,
"Y"
:
local_param_bak
},
outputs
=
{
"Out"
:
o1
})
o2
=
__create_temp_var__
()
block
.
append_op
(
type
=
"elementwise_mul"
,
inputs
=
{
"X"
:
o1
,
"Y"
:
grad_var
},
outputs
=
{
"Out"
:
o2
})
o3
=
__create_temp_var__
()
block
.
append_op
(
type
=
"elementwise_mul"
,
inputs
=
{
"X"
:
o2
,
"Y"
:
grad_var
},
outputs
=
{
"Out"
:
o3
})
# TODO(typhoonzero): append scale
o4
=
__create_temp_var__
()
block
.
append_op
(
type
=
"elementwise_add"
,
inputs
=
{
"X"
:
grad_var
,
"Y"
:
o3
},
outputs
=
{
"Out"
:
o4
})
return
o4
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
,
grad_to_block_id
,
origin_program
,
merged_var
):
program
=
optimize_block
.
program
...
...
@@ -1546,9 +1646,16 @@ to transpile() call.")
break
return
param_block
if
self
.
config
.
enable_dc_asgd
:
param_var
=
_get_param_block
(
opt_op
)
dc
=
self
.
_append_dc_asgd_ops
(
optimize_block
,
param_var
,
merged_var
)
for
key
in
opt_op
.
input_names
:
if
key
==
"Grad"
:
new_inputs
[
key
]
=
merged_var
if
self
.
config
.
enable_dc_asgd
:
new_inputs
[
key
]
=
dc
else
:
new_inputs
[
key
]
=
merged_var
elif
key
==
"Param"
:
param_block
=
_get_param_block
(
opt_op
)
if
not
param_block
:
...
...
python/paddle/fluid/transpiler/inference_transpiler.py
浏览文件 @
fe4cd502
...
...
@@ -61,6 +61,9 @@ class InferenceTranspiler(object):
raise
TypeError
(
"scope should be as Scope type or None"
)
use_mkldnn
=
bool
(
os
.
getenv
(
"FLAGS_use_mkldnn"
,
False
))
if
use_mkldnn
:
self
.
_depthwise_conv_mkldnn
(
program
)
self
.
_fuse_batch_norm
(
program
,
place
,
scope
)
if
use_mkldnn
:
self
.
_fuse_conv_bias_mkldnn
(
program
)
...
...
@@ -70,6 +73,31 @@ class InferenceTranspiler(object):
program
)
# ResNet residual block merging
self
.
_fuse_bn_relu_mkldnn
(
program
)
def
_depthwise_conv_mkldnn
(
self
,
program
):
'''
Transpile the program by replacing depthwise_conv2d to conv2d for MKLDNN program.
The result is:
- before:
- any_other_op->depthwise_conv->any_other_op
- after:
- any_other_op->conv->any_other_op
:param program: program to transpile
:type program: Program
'''
self
.
block
=
program
.
block
(
0
)
i
=
0
while
i
<
len
(
self
.
block
.
ops
):
current_op
=
self
.
block
.
ops
[
i
]
if
current_op
.
type
==
'depthwise_conv2d'
:
current_op
.
desc
.
set_type
(
"conv2d"
)
i
=
i
+
1
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program
=
program
.
clone
()
def
_fuse_conv_eltwise_mkldnn
(
self
,
program
):
'''
Transpile the program fusing elementwise_add into conv for MKLDNN
...
...
python/setup.py.in
浏览文件 @
fe4cd502
...
...
@@ -27,7 +27,7 @@ def _get_version_detail(idx):
if re.match('@TAG_VERSION_REGEX@', '@PADDLE_VERSION@'):
version_details = '@PADDLE_VERSION@'.split('.')
if len(version_details)
=
= 3:
if len(version_details)
>
= 3:
return version_details[idx]
return 0
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录