Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle; in sync with the upstream project)
Commit 3bd54ed7
Authored on Dec 17, 2018 by Qiao Longfei

    Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle
    into multithread-sparse-adam

Parents: fac87022, 3628d894

Showing 198 changed files with 3582 additions and 1107 deletions (+3582 / -1107)
README.md  +81 -0
benchmark/fluid/fluid_benchmark.py  +3 -1
cmake/external/brpc.cmake  +12 -8
cmake/external/gtest.cmake  +7 -3
cmake/external/leveldb.cmake  +2 -2
paddle/fluid/API.spec  +2 -0
paddle/fluid/framework/CMakeLists.txt  +6 -3
paddle/fluid/framework/data_layout_transform.cc  +3 -3
paddle/fluid/framework/data_layout_transform.h  +8 -8
paddle/fluid/framework/data_type.cc  +7 -17
paddle/fluid/framework/data_type.h  +45 -32
paddle/fluid/framework/data_type_test.cc  +4 -4
paddle/fluid/framework/details/CMakeLists.txt  +9 -2
paddle/fluid/framework/details/all_reduce_op_handle.cc  +1 -1
paddle/fluid/framework/details/fuse_vars_op_handle.h  +2 -2
paddle/fluid/framework/details/reduce_op_handle.cc  +7 -7
paddle/fluid/framework/dlpack_tensor.cc  +17 -20
paddle/fluid/framework/dlpack_tensor_test.cc  +4 -16
paddle/fluid/framework/executor.cc  +3 -3
paddle/fluid/framework/executor_thread_worker.cc  +13 -33
paddle/fluid/framework/ir/CMakeLists.txt  +2 -0
paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse.cc  +106 -0
paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc  +105 -0
paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h  +33 -0
paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc  +104 -0
paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h  +33 -0
paddle/fluid/framework/ir/graph_pattern_detector.cc  +112 -1
paddle/fluid/framework/ir/graph_pattern_detector.h  +45 -0
paddle/fluid/framework/lod_tensor.cc  +3 -3
paddle/fluid/framework/ngraph_operator.cc  +5 -9
paddle/fluid/framework/op_kernel_type_test.cc  +2 -1
paddle/fluid/framework/operator.cc  +7 -7
paddle/fluid/framework/selected_rows.cc  +2 -2
paddle/fluid/framework/tensor.cc  +2 -2
paddle/fluid/framework/tensor.h  +5 -5
paddle/fluid/framework/tensor_impl.h  +5 -7
paddle/fluid/framework/tensor_util.cc  +4 -6
paddle/fluid/inference/api/analysis_predictor.cc  +2 -2
paddle/fluid/inference/api/analysis_predictor_tester.cc  +6 -1
paddle/fluid/inference/api/api_impl.cc  +2 -2
paddle/fluid/inference/api/api_impl_tester.cc  +2 -2
paddle/fluid/inference/api/paddle_pass_builder.h  +4 -1
paddle/fluid/inference/io.cc  +1 -1
paddle/fluid/inference/tests/api/tester_helper.h  +2 -2
paddle/fluid/inference/tests/api/trt_models_tester.cc  +24 -1
paddle/fluid/operators/affine_grid_op.cc  +4 -4
paddle/fluid/operators/arg_max_op.cc  +0 -1
paddle/fluid/operators/arg_max_op.cu  +0 -2
paddle/fluid/operators/arg_min_op.cc  +0 -1
paddle/fluid/operators/arg_min_op.cu  +0 -2
paddle/fluid/operators/array_to_lod_tensor_op.cc  +2 -2
paddle/fluid/operators/attention_lstm_op.cc  +2 -3
paddle/fluid/operators/average_accumulates_op.cc  +2 -3
paddle/fluid/operators/batch_norm_op.cc  +7 -13
paddle/fluid/operators/beam_search_decode_op.cc  +1 -1
paddle/fluid/operators/beam_search_op.cc  +1 -2
paddle/fluid/operators/bpr_loss_op.cc  +4 -6
paddle/fluid/operators/controlflow/CMakeLists.txt  +1 -1
paddle/fluid/operators/controlflow/conditional_block_op.cc  +6 -7
paddle/fluid/operators/controlflow/while_op.cc  +1 -1
paddle/fluid/operators/conv_op.cc  +8 -8
paddle/fluid/operators/conv_transpose_op.cc  +4 -6
paddle/fluid/operators/crf_decoding_op.cc  +2 -3
paddle/fluid/operators/crop_op.cc  +3 -6
paddle/fluid/operators/cross_entropy_op.cc  +4 -6
paddle/fluid/operators/ctc_align_op.cc  +2 -3
paddle/fluid/operators/cudnn_lstm_op.cu.cc  +2 -0
paddle/fluid/operators/detection/anchor_generator_op.cc  +1 -2
paddle/fluid/operators/detection/bipartite_match_op.cc  +2 -3
paddle/fluid/operators/detection/density_prior_box_op.cc  +1 -2
paddle/fluid/operators/detection/generate_proposals_op.cc  +2 -3
paddle/fluid/operators/detection/mine_hard_examples_op.cc  +1 -2
paddle/fluid/operators/detection/multiclass_nms_op.cc  +1 -2
paddle/fluid/operators/detection/prior_box_op.cc  +1 -2
paddle/fluid/operators/detection/roi_perspective_transform_op.cc  +4 -6
paddle/fluid/operators/detection/rpn_target_assign_op.cc  +1 -2
paddle/fluid/operators/detection/target_assign_op.cc  +2 -3
paddle/fluid/operators/detection_map_op.cc  +1 -2
paddle/fluid/operators/distributed/CMakeLists.txt  +19 -12
paddle/fluid/operators/distributed/brpc_client.cc  +313 -58
paddle/fluid/operators/distributed/brpc_client.h  +82 -17
paddle/fluid/operators/distributed/brpc_rdma_pool.cc  +84 -0
paddle/fluid/operators/distributed/brpc_rdma_pool.h  +56 -0
paddle/fluid/operators/distributed/brpc_sendrecvop_utils.cc  +196 -0
paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h  +49 -0
paddle/fluid/operators/distributed/brpc_serde_test.cc  +175 -0
paddle/fluid/operators/distributed/brpc_server.cc  +235 -29
paddle/fluid/operators/distributed/brpc_variable_response.cc  +73 -0
paddle/fluid/operators/distributed/brpc_variable_response.h  +67 -0
paddle/fluid/operators/distributed/grpc_client.cc  +1 -2
paddle/fluid/operators/distributed/grpc_serde.cc  +1 -9
paddle/fluid/operators/distributed/rpc_server.h  +4 -0
paddle/fluid/operators/distributed/sendrecvop_utils.cc  +3 -5
paddle/fluid/operators/distributed/sendrecvop_utils.h  +14 -6
paddle/fluid/operators/distributed/variable_response.cc  +7 -8
paddle/fluid/operators/distributed_ops/CMakeLists.txt  +2 -2
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc  +4 -3
paddle/fluid/operators/distributed_ops/merge_ids_op.cc  +1 -3
paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc  +1 -3
paddle/fluid/operators/distributed_ops/send_op.cc  +2 -0
paddle/fluid/operators/elementwise/elementwise_op.h  +2 -2
paddle/fluid/operators/fake_quantize_op.cc  +4 -6
paddle/fluid/operators/fc_op.cc  +4 -6
paddle/fluid/operators/fill_constant_op.cc  +2 -2
paddle/fluid/operators/fill_op.cc  +2 -2
paddle/fluid/operators/fused/fused_elemwise_activation_op.cc  +4 -6
paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc  +1 -2
paddle/fluid/operators/fused/fusion_gru_op.cc  +2 -3
paddle/fluid/operators/fused/fusion_lstm_op.cc  +2 -3
paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc  +2 -3
paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc  +2 -3
paddle/fluid/operators/gather_op.cc  +4 -6
paddle/fluid/operators/grid_sampler_op.cc  +6 -6
paddle/fluid/operators/group_norm_op.cc  +1 -2
paddle/fluid/operators/hierarchical_sigmoid_op.cc  +4 -6
paddle/fluid/operators/interpolate_op.cc  +4 -4
paddle/fluid/operators/is_empty_op.cc  +1 -2
paddle/fluid/operators/isfinite_op.cc  +2 -3
paddle/fluid/operators/layer_norm_op.cc  +1 -2
paddle/fluid/operators/linear_chain_crf_op.cc  +3 -6
paddle/fluid/operators/load_combine_op.cc  +1 -1
paddle/fluid/operators/load_op.cc  +1 -1
paddle/fluid/operators/lod_reset_op.cc  +4 -6
paddle/fluid/operators/lod_tensor_to_array_op.cc  +1 -1
paddle/fluid/operators/lookup_sparse_table_op.cc  +1 -2
paddle/fluid/operators/lrn_op.cc  +2 -3
paddle/fluid/operators/lstm_op.cc  +2 -4
paddle/fluid/operators/lstmp_op.cc  +2 -4
paddle/fluid/operators/math/math_function.cc  +2 -4
paddle/fluid/operators/math/math_function.cu  +1 -1
paddle/fluid/operators/math/pooling.cc  +153 -62
paddle/fluid/operators/math/pooling.cu  +268 -147
paddle/fluid/operators/math/pooling.h  +22 -10
paddle/fluid/operators/mean_iou_op.cc  +2 -3
paddle/fluid/operators/mean_op.cc  +1 -3
paddle/fluid/operators/merge_lod_tensor_op.cc  +1 -3
paddle/fluid/operators/metrics/accuracy_op.cc  +2 -3
paddle/fluid/operators/metrics/auc_op.cc  +2 -3
paddle/fluid/operators/metrics/precision_recall_op.cc  +2 -3
paddle/fluid/operators/multiplex_op.cc  +4 -6
paddle/fluid/operators/nce_op.cc  +4 -6
paddle/fluid/operators/optimizers/adadelta_op.cc  +2 -3
paddle/fluid/operators/optimizers/adagrad_op.cc  +2 -3
paddle/fluid/operators/optimizers/adam_op.cc  +1 -2
paddle/fluid/operators/optimizers/adamax_op.cc  +2 -3
paddle/fluid/operators/optimizers/decayed_adagrad_op.cc  +2 -3
paddle/fluid/operators/optimizers/ftrl_op.cc  +1 -2
paddle/fluid/operators/optimizers/proximal_adagrad_op.cc  +2 -3
paddle/fluid/operators/optimizers/proximal_gd_op.cc  +2 -3
paddle/fluid/operators/pad2d_op.cc  +4 -4
paddle/fluid/operators/pad_constant_like_op.cc  +4 -6
paddle/fluid/operators/pool_op.cc  +65 -7
paddle/fluid/operators/pool_op.h  +10 -6
paddle/fluid/operators/pool_with_index_op.cc  +39 -9
paddle/fluid/operators/pool_with_index_op.h  +8 -4
paddle/fluid/operators/positive_negative_pair_op.cc  +2 -3
paddle/fluid/operators/prelu_op.cc  +4 -6
paddle/fluid/operators/print_op.cc  +1 -1
paddle/fluid/operators/psroi_pool_op.cc  +4 -6
paddle/fluid/operators/random_crop_op.cc  +2 -3
paddle/fluid/operators/reader/create_batch_reader_op.cc  +2 -2
paddle/fluid/operators/recurrent_op.cc  +1 -1
paddle/fluid/operators/reshape_op.cc  +5 -9
paddle/fluid/operators/rnn_memory_helper_op.cc  +1 -1
paddle/fluid/operators/roi_align_op.cc  +4 -6
paddle/fluid/operators/roi_pool_op.cc  +4 -6
paddle/fluid/operators/save_combine_op.cc  +1 -1
paddle/fluid/operators/save_op.cc  +1 -1
paddle/fluid/operators/scatter_op.cc  +4 -6
paddle/fluid/operators/sequence_ops/sequence_pool_op.cc  +2 -3
paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc  +4 -6
paddle/fluid/operators/sequence_ops/sequence_slice_op.cc  +4 -6
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc  +2 -2
paddle/fluid/operators/similarity_focus_op.cc  +2 -3
paddle/fluid/operators/slice_op.cc  +2 -3
paddle/fluid/operators/softmax_op.cc  +3 -4
paddle/fluid/operators/softmax_with_cross_entropy_op.cc  +3 -5
paddle/fluid/operators/spp_op.h  +3 -3
paddle/fluid/operators/sum_op.cc  +6 -7
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h  +1 -4
paddle/fluid/operators/transpose_op.cc  +3 -6
paddle/fluid/operators/unpool_op.cc  +4 -6
paddle/fluid/operators/warpctc_op.cc  +4 -6
paddle/fluid/operators/yolov3_loss_op.cc  +4 -6
paddle/fluid/platform/device_context.cc  +1 -0
paddle/fluid/platform/nccl_helper.h  +6 -5
paddle/fluid/pybind/pybind.cc  +10 -1
paddle/fluid/pybind/tensor_py.h  +1 -1
paddle/scripts/paddle_build.sh  +12 -0
python/paddle/fluid/__init__.py  +1 -0
python/paddle/fluid/layers/nn.py  +200 -0
python/paddle/fluid/tests/unittests/CMakeLists.txt  +7 -3
python/paddle/fluid/tests/unittests/ngraph/CMakeLists.txt  +6 -0
python/paddle/fluid/tests/unittests/ngraph/__init__.py  +13 -0
python/paddle/fluid/tests/unittests/test_layers.py  +23 -0
python/paddle/fluid/tests/unittests/test_pool2d_op.py  +65 -26
python/paddle/fluid/tests/unittests/test_pool3d_op.py  +86 -35
python/paddle/fluid/tests/unittests/test_pool_max_op.py  +77 -18

README.md

@@ -19,6 +19,15 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
 Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
 
+欢迎来到 PaddlePaddle GitHub
+
+PaddlePaddle (PArallel Distributed Deep LEarning) 是一个简单易用、高效灵活、可扩展的深度学习平台,最初由百度科学家和工程师共同开发,目的是将深度学习技术应用到百度的众多产品中。
+
+我们的愿景是让每个人都能通过PaddlePaddle接触深度学习
+
+跟进PaddlePaddle最新特性请参考我们的[版本说明](https://github.com/PaddlePaddle/Paddle/releases)
+
 ### Latest PaddlePaddle Release: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
 ### Install Latest Stable Release:
 ```
@@ -34,6 +43,23 @@ pip install paddlepaddle-gpu==1.2.0.post85
 # For installation on other platform, refer to http://paddlepaddle.org/
 ```
 
+### PaddlePaddle最新版本: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2)
+### 安装最新稳定版本:
+```
+# Linux CPU
+pip install paddlepaddle
+# Linux GPU cuda9cudnn7
+pip install paddlepaddle-gpu
+# Linux GPU cuda8cudnn7
+pip install paddlepaddle-gpu==1.2.0.post87
+# Linux GPU cuda8cudnn5
+pip install paddlepaddle-gpu==1.2.0.post85
+# 其他平台上的安装指引请参考 http://paddlepaddle.org/
+```
+
 ## Features
 - **Flexibility**
@@ -74,10 +100,38 @@ pip install paddlepaddle-gpu==1.2.0.post85
 Baidu and it has achieved a significant impact. We hope you can also explore
 the capability of PaddlePaddle to make an impact on your product.
 
+## 特点
+- **灵活性**
+  PaddlePaddle支持丰富的神经网络架构和优化算法。易于配置复杂模型,例如带有注意力机制或复杂记忆连接的神经网络机器翻译模型。
+- **高效性**
+  为了高效使用异步计算资源,PaddlePaddle对框架的不同层进行优化,包括计算、存储、架构和通信。下面是一些样例:
+  - 通过SSE/AVX 内置函数、BLAS库(例如MKL、OpenBLAS、cuBLAS)或定制的CPU/GPU内核优化数学操作。
+  - 通过MKL-DNN库优化CNN网络
+  - 高度优化循环网络,无需执行 `padding` 操作即可处理 **变长** 序列
+  - 针对高维稀疏数据模型,优化了局部和分布式训练。
+- **稳定性**
+  有了 PaddlePaddle,使得利用各种CPU/GPU和机器来加速训练变得简单。PaddlePaddle 通过优化通信可以实现巨大吞吐量和快速执行。
+- **连接产品**
+  另外,PaddlePaddle 的设计也易于部署。在百度,PaddlePaddle 已经部署到含有巨大用户量的产品和服务上,包括广告点击率(CTR)预测、大规模图像分类、光学字符识别(OCR)、搜索排序,计算机病毒检测、推荐系统等等。PaddlePaddle广泛应用于百度产品中,产生了非常重要的影响。我们希望您也能探索 PaddlePaddle 的能力,为您的产品创造新的影响力和效果。
+
 ## Installation
 It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) on our website.
 
+## 安装
+推荐阅读官网上的[安装说明](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html)
+
 ## Documentation
 We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) and
@@ -99,10 +153,37 @@ We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarte
 We appreciate your contributions!
 
+## 文档
+我们提供[英文](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html)和[中文](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html)文档
+
+- [深度学习101](https://github.com/PaddlePaddle/book):或许您想从这个在线交互式书籍开始,可以在Jupyter Notebook中运行
+- [分布式训练](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html):可以在MPI集群上运行分布式训练任务
+- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html):新的API支持代码更少更简洁的程序
+- [贡献方式](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html):欢迎您的贡献!
+
 ## Ask Questions
 You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).
 
+## 答疑
+欢迎您将问题和bug报告以[Github Issues](https://github.com/PaddlePaddle/Paddle/issues)的形式提交
+
 ## Copyright and License
 PaddlePaddle is provided under the [Apache-2.0 license](LICENSE).
 
+## 版权和许可证
+PaddlePaddle由[Apache-2.0 license](LICENSE)提供

benchmark/fluid/fluid_benchmark.py

@@ -81,9 +81,11 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
     # the role, should be either PSERVER or TRAINER
     training_role = os.getenv("PADDLE_TRAINING_ROLE")
-    config = distribute_transpiler.DistributeTranspilerConfig()
+    config = fluid.DistributeTranspilerConfig()
     config.slice_var_up = not args.no_split_var
+    config.min_block_size = 1048576
     t = distribute_transpiler.DistributeTranspiler(config=config)
     t.transpile(
         trainer_id,
         # NOTE: *MUST* use train_prog, for we are using with guard to
...

cmake/external/brpc.cmake

@@ -14,14 +14,16 @@
 INCLUDE(ExternalProject)
 
-find_library(SSL_LIBRARY NAMES ssl)
+find_package(OpenSSL REQUIRED)
+
+message(STATUS "ssl:" ${OPENSSL_SSL_LIBRARY})
+message(STATUS "crypto:" ${OPENSSL_CRYPTO_LIBRARY})
+
 ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${SSL_LIBRARY})
+SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY})
 
-find_library(CRYPTO_LIBRARY NAMES crypto)
 ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${CRYPTO_LIBRARY})
+SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY})
 
 SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc)
 SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/brpc)
@@ -31,14 +33,15 @@ SET(BRPC_LIBRARIES "${BRPC_INSTALL_DIR}/lib/libbrpc.a" CACHE FILEPATH "brpc libr
 INCLUDE_DIRECTORIES(${BRPC_INCLUDE_DIR})
 
 # Reference https://stackoverflow.com/questions/45414507/pass-a-list-of-prefix-paths-to-externalproject-add-in-cmake-args
-set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib")
+set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/leveldb|${THIRD_PARTY_PATH}/install/snappy|${THIRD_PARTY_PATH}/install/gtest|${THIRD_PARTY_PATH}/install/protobuf|${THIRD_PARTY_PATH}/install/zlib|${THIRD_PARTY_PATH}/install/glog")
 
 # If minimal .a is need, you can set WITH_DEBUG_SYMBOLS=OFF
 ExternalProject_Add(
     extern_brpc
     ${EXTERNAL_PROJECT_LOG_ARGS}
+    # TODO(gongwb): change to de newst repo when they changed.
     GIT_REPOSITORY "https://github.com/gongweibao/brpc"
-    GIT_TAG "7dc04defad1fd4173aae170c3fcbde131b65155a"
+    GIT_TAG "e9b67ec1b7458f2af5fae76451afe1e27e01b4b4"
     PREFIX ${BRPC_SOURCES_DIR}
    UPDATE_COMMAND ""
    CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
@@ -50,7 +53,7 @@ ExternalProject_Add(
    -DCMAKE_POSITION_INDEPENDENT_CODE=ON
    -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
    -DCMAKE_PREFIX_PATH=${prefix_path}
-   -DBRPC_WITH_GLOG=ON
+   -DWITH_GLOG=ON
    -DIOBUF_WITH_HUGE_BLOCK=ON
    -DBRPC_WITH_RDMA=${WITH_BRPC_RDMA}
    ${EXTERNAL_OPTIONAL_ARGS}
@@ -65,5 +68,6 @@ ADD_LIBRARY(brpc STATIC IMPORTED GLOBAL)
 SET_PROPERTY(TARGET brpc PROPERTY IMPORTED_LOCATION ${BRPC_LIBRARIES})
 ADD_DEPENDENCIES(brpc extern_brpc)
+add_definitions(-DBRPC_WITH_GLOG)
 
 LIST(APPEND external_project_dependencies brpc)

cmake/external/gtest.cmake

@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-IF(WITH_TESTING)
-  ENABLE_TESTING()
+#FIXME:(gongwb) Move brpc's gtest dependency.
+IF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))
+  IF(WITH_TESTING)
+    ENABLE_TESTING()
+  ENDIF(WITH_TESTING)
   INCLUDE(ExternalProject)
   SET(GTEST_SOURCES_DIR ${THIRD_PARTY_PATH}/gtest)
@@ -76,4 +80,4 @@ IF(WITH_TESTING)
   ADD_DEPENDENCIES(gtest_main extern_gtest)
   LIST(APPEND external_project_dependencies gtest gtest_main)
-ENDIF(WITH_TESTING)
+ENDIF(WITH_TESTING OR (WITH_DISTRIBUTE AND NOT WITH_GRPC))

cmake/external/leveldb.cmake

@@ -24,8 +24,8 @@ ExternalProject_Add(
     extern_leveldb
     ${EXTERNAL_PROJECT_LOG_ARGS}
     PREFIX ${LEVELDB_SOURCES_DIR}
-    URL "https://github.com/google/leveldb/archive/v1.18.tar.gz"
-    URL_MD5 "73770de34a2a5ab34498d2e05b2b7fa0"
+    GIT_REPOSITORY "https://github.com/google/leveldb"
+    GIT_TAG v1.18
     CONFIGURE_COMMAND ""
     BUILD_COMMAND CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} libleveldb.a
     INSTALL_COMMAND mkdir -p ${LEVELDB_INSTALL_DIR}/lib/
...

paddle/fluid/API.spec

@@ -77,6 +77,8 @@ paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name']
 paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
 paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
 paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
+paddle.fluid.layers.adaptive_pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None))
+paddle.fluid.layers.adaptive_pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None))
 paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False))
 paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
...

paddle/fluid/framework/CMakeLists.txt

@@ -169,9 +169,12 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
 cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
 
 if(WITH_DISTRIBUTE)
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper)
-  set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
-  set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
+    lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper)
+  set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
+  set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
   if(WITH_NGRAPH)
     if(NOT WIN32)
...

paddle/fluid/framework/data_layout_transform.cc

@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
   out->mutable_data(expected_kernel_type.place_, in.type());
 
   framework::VisitDataType(
-      framework::ToDataType(in.type()),
+      in.type(),
       CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
 
   out->set_layout(expected_kernel_type.data_layout_);
@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
     case mkldnn::memory::data_type::f32:
       return platform::to_void_cast(tensor.data<float>());
     case mkldnn::memory::data_type::s8:
-      return platform::to_void_cast(tensor.data<char>());
+      return platform::to_void_cast(tensor.data<int8_t>());
     case mkldnn::memory::data_type::u8:
       return platform::to_void_cast(tensor.data<unsigned char>());
     case mkldnn::memory::data_type::s16:
@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
   memory::data_type in_type = ToMKLDNNDataType(in.type());
   PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
-                 "Input tensor type is not supported: ", in.type().name());
+                 "Input tensor type is not supported: %s", in.type());
   memory::data_type out_type = in_type;
 
   auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
...

paddle/fluid/framework/data_layout_transform.h

@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
   }
 }
 
-inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
-  static const std::map<std::type_index, MKLDNNDataType> dict{
-      {std::type_index(typeid(float)), MKLDNNDataType::f32},  // NOLINT
-      {std::type_index(typeid(char)), MKLDNNDataType::s8},    // NOLINT
-      {std::type_index(typeid(unsigned char)), MKLDNNDataType::u8},
-      {std::type_index(typeid(int16_t)), MKLDNNDataType::s16},
-      {std::type_index(typeid(int32_t)), MKLDNNDataType::s32}};
+inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
+  static std::unordered_map<int, MKLDNNDataType> dict{
+      {DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
+      {DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
+      {DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
+      {DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
+      {DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
 
-  auto iter = dict.find(type);
+  auto iter = dict.find(static_cast<int>(type));
   if (iter != dict.end()) return iter->second;
   return MKLDNNDataType::data_undef;
 }
...
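
One detail worth noting in the new ToMKLDNNDataType: the table is keyed on int with static_cast<int>(type) rather than on the enum itself, presumably because std::hash specializations for enumeration types are only mandated from C++14 onward, so int keys stay portable under C++11. A small standalone illustration of that workaround (all names here are invented for the example):

    #include <cassert>
    #include <unordered_map>

    enum class VarType { FP32, INT8, UINT8 };
    enum class MKLDNNDataType { f32, s8, u8, undef };

    // Key the table on int rather than the enum: std::hash for enumeration
    // types is only guaranteed from C++14, so int keys stay portable in C++11.
    static MKLDNNDataType ToMKLDNN(VarType type) {
      static const std::unordered_map<int, MKLDNNDataType> dict{
          {static_cast<int>(VarType::FP32), MKLDNNDataType::f32},
          {static_cast<int>(VarType::INT8), MKLDNNDataType::s8},
          {static_cast<int>(VarType::UINT8), MKLDNNDataType::u8}};
      auto it = dict.find(static_cast<int>(type));
      return it == dict.end() ? MKLDNNDataType::undef : it->second;
    }

    int main() {
      assert(ToMKLDNN(VarType::INT8) == MKLDNNDataType::s8);
      return 0;
    }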

paddle/fluid/framework/data_type.cc

@@ -26,7 +26,7 @@ struct DataTypeMap {
   std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
   std::unordered_map<int, std::type_index> proto_to_cpp_;
   std::unordered_map<int, std::string> proto_to_str_;
-  std::unordered_map<std::type_index, size_t> cpp_to_size_;
+  std::unordered_map<int, size_t> proto_to_size_;
 };
 
 static DataTypeMap* InitDataTypeMap();
@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
   map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
   map->cpp_to_proto_.emplace(typeid(T), proto_type);
   map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
-  map->cpp_to_size_.emplace(typeid(T), sizeof(T));
+  map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
 }
 
 static DataTypeMap* InitDataTypeMap() {
@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
 #define RegType(cc_type, proto_type) \
   RegisterType<cc_type>(retv, proto_type, #cc_type)
 
-  RegType(float16, proto::VarType::FP16);
-  RegType(float, proto::VarType::FP32);
-  RegType(double, proto::VarType::FP64);
-  RegType(int, proto::VarType::INT32);
-  RegType(int64_t, proto::VarType::INT64);
-  RegType(bool, proto::VarType::BOOL);
-  RegType(size_t, proto::VarType::SIZE_T);
-  RegType(int16_t, proto::VarType::INT16);
-  RegType(uint8_t, proto::VarType::UINT8);
-  RegType(int8_t, proto::VarType::INT8);
+  // NOTE: Add your customize type here.
+  _ForEachDataType_(RegType);
 
 #undef RegType
   return retv;
@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
       static_cast<int>(type));
 }
 
-size_t SizeOfType(std::type_index type) {
-  auto it = gDataTypeMap().cpp_to_size_.find(type);
-  if (it != gDataTypeMap().cpp_to_size_.end()) {
+size_t SizeOfType(proto::VarType::Type type) {
+  auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
+  if (it != gDataTypeMap().proto_to_size_.end()) {
     return it->second;
   }
-  PADDLE_THROW("Not support %s as tensor type", type.name());
+  PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
 }
 
 }  // namespace framework
...

paddle/fluid/framework/data_type.h

@@ -22,46 +22,59 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+template <typename T>
+struct DataTypeTrait {};
+
+// Stub handle for void
+template <>
+struct DataTypeTrait<void> {
+  constexpr static auto DataType = proto::VarType::RAW;
+};
+
+#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
+  callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
+
+#define _ForEachDataType_(callback)                                     \
+  _ForEachDataTypeHelper_(callback, float, FP32);                       \
+  _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
+  _ForEachDataTypeHelper_(callback, double, FP64);                      \
+  _ForEachDataTypeHelper_(callback, int, INT32);                        \
+  _ForEachDataTypeHelper_(callback, int64_t, INT64);                    \
+  _ForEachDataTypeHelper_(callback, bool, BOOL);                        \
+  _ForEachDataTypeHelper_(callback, uint8_t, UINT8);                    \
+  _ForEachDataTypeHelper_(callback, int16_t, INT16);                    \
+  _ForEachDataTypeHelper_(callback, int8_t, INT8)
+
+#define DefineDataTypeTrait(cpp_type, proto_type) \
+  template <>                                     \
+  struct DataTypeTrait<cpp_type> {                \
+    constexpr static auto DataType = proto_type;  \
+  }
+
+_ForEachDataType_(DefineDataTypeTrait);
+
+#undef DefineDataTypeTrait
+
 extern proto::VarType::Type ToDataType(std::type_index type);
 extern std::type_index ToTypeIndex(proto::VarType::Type type);
 
 template <typename Visitor>
 inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
-  switch (type) {
-    case proto::VarType::FP16:
-      visitor.template apply<platform::float16>();
-      break;
-    case proto::VarType::FP32:
-      visitor.template apply<float>();
-      break;
-    case proto::VarType::FP64:
-      visitor.template apply<double>();
-      break;
-    case proto::VarType::INT32:
-      visitor.template apply<int>();
-      break;
-    case proto::VarType::INT64:
-      visitor.template apply<int64_t>();
-      break;
-    case proto::VarType::BOOL:
-      visitor.template apply<bool>();
-      break;
-    case proto::VarType::UINT8:
-      visitor.template apply<uint8_t>();
-      break;
-    case proto::VarType::INT16:
-      visitor.template apply<int16_t>();
-      break;
-    case proto::VarType::INT8:
-      visitor.template apply<int8_t>();
-      break;
-    default:
-      PADDLE_THROW("Not supported %d", type);
-  }
+#define VisitDataTypeCallback(cpp_type, proto_type) \
+  do {                                              \
+    if (type == proto_type) {                       \
+      visitor.template apply<cpp_type>();           \
+      return;                                       \
+    }                                               \
+  } while (0)
+
+  _ForEachDataType_(VisitDataTypeCallback);
+#undef VisitDataTypeCallback
+  PADDLE_THROW("Not supported %d", type);
 }
 
 extern std::string DataTypeToString(const proto::VarType::Type type);
-extern size_t SizeOfType(std::type_index type);
+extern size_t SizeOfType(proto::VarType::Type type);
 inline std::ostream& operator<<(std::ostream& out,
                                 const proto::VarType::Type& type) {
   out << DataTypeToString(type);
...
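
This data_type.h rewrite is the core of the commit's type-system cleanup: every (C++ type, proto enum) pair is listed exactly once in _ForEachDataType_, and both the DataTypeTrait specializations and the VisitDataType dispatch are generated from that single list, so adding a type becomes a one-line change. A minimal self-contained sketch of the same X-macro pattern (simplified enum and names, not Paddle's actual headers):

    #include <cstdint>
    #include <cstdio>

    enum class VarType { FP32, FP64, INT32, INT64 };

    // List every (cpp_type, enum_value) pair exactly once.
    #define FOR_EACH_DATA_TYPE(callback)       \
      callback(float, VarType::FP32);          \
      callback(double, VarType::FP64);         \
      callback(std::int32_t, VarType::INT32);  \
      callback(std::int64_t, VarType::INT64)

    template <typename T>
    struct DataTypeTrait {};

    // One trait specialization is generated per listed pair.
    #define DEFINE_TRAIT(cpp_type, enum_value)          \
      template <>                                       \
      struct DataTypeTrait<cpp_type> {                  \
        constexpr static VarType DataType = enum_value; \
      }
    FOR_EACH_DATA_TYPE(DEFINE_TRAIT);
    #undef DEFINE_TRAIT

    // Runtime enum to compile-time type dispatch, generated from the same list.
    template <typename Visitor>
    void VisitDataType(VarType type, Visitor visitor) {
    #define VISIT_CASE(cpp_type, enum_value)  \
      do {                                    \
        if (type == enum_value) {             \
          visitor.template apply<cpp_type>(); \
          return;                             \
        }                                     \
      } while (0)
      FOR_EACH_DATA_TYPE(VISIT_CASE);
    #undef VISIT_CASE
      std::puts("unsupported type");  // Paddle throws here instead
    }

    struct SizePrinter {
      template <typename T>
      void apply() {
        std::printf("element size = %zu\n", sizeof(T));
      }
    };

    int main() {
      static_assert(DataTypeTrait<float>::DataType == VarType::FP32, "");
      VisitDataType(VarType::FP64, SizePrinter{});  // prints "element size = 8"
    }

The do { ... } while (0) wrapper keeps each generated check a single statement, so the macro expands safely inside the function body, which is also why the real patch uses the same shape.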

paddle/fluid/framework/data_type_test.cc

@@ -26,15 +26,15 @@ TEST(DataType, float16) {
   Tensor tensor;
   CPUPlace cpu;
-  tensor.mutable_data(cpu, f::ToTypeIndex(dtype));
+  tensor.mutable_data(cpu, dtype);
 
   // test fp16 tensor
-  EXPECT_EQ(tensor.type(), std::type_index(typeid(float16)));
+  EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16)));
 
   // test fp16 size
-  EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u);
+  EXPECT_EQ(f::SizeOfType(dtype), 2u);
 
   // test debug info
-  std::string type = "float16";
+  std::string type = "::paddle::platform::float16";
   EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str());
 }

paddle/fluid/framework/details/CMakeLists.txt

@@ -12,12 +12,19 @@ cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc
 cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
 
+if(WITH_DISTRIBUTE)
+  if(NOT WITH_GRPC)
+    set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
+    set_source_files_properties(reduce_op_handle.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+  endif()
+endif()
+
 if(WITH_GPU)
   nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
       dynload_cuda variable_visitor)
   if(WITH_DISTRIBUTE)
     nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
-        ddim dynload_cuda selected_rows_functor sendrecvop_grpc)
+        ddim dynload_cuda selected_rows_functor sendrecvop_rpc)
   else()
     nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
         ddim dynload_cuda selected_rows_functor)
@@ -30,7 +37,7 @@ else()
       variable_visitor)
   if(WITH_DISTRIBUTE)
     cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
-        ddim selected_rows_functor sendrecvop_grpc)
+        ddim selected_rows_functor sendrecvop_rpc)
   else()
     cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
         ddim selected_rows_functor)
...

paddle/fluid/framework/details/all_reduce_op_handle.cc

@@ -127,7 +127,7 @@ void AllReduceOpHandle::RunImpl() {
       // Reduce All Tensor to trg in CPU
       ReduceLoDTensor func(lod_tensors, &trg);
-      VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+      VisitDataType(lod_tensors[0]->type(), func);
 
       for (size_t i = 1; i < local_scopes_.size(); ++i) {
         auto &scope =
...

paddle/fluid/framework/details/fuse_vars_op_handle.h

@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
   FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
                    const platform::Place &place,
                    const std::unordered_map<std::string, int64_t> &inputs_numel,
-                   const std::type_index &var_type)
+                   const proto::VarType::Type var_type)
       : OpHandleBase(node),
         local_scope_(local_scope),
         place_(place),
@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
   Scope *local_scope_;
   const platform::Place place_;
   const std::unordered_map<std::string, int64_t> inputs_numel_;
-  const std::type_index type_;
+  const proto::VarType::Type type_;
   int64_t total_numel_;
 };
 }  // namespace details
...

paddle/fluid/framework/details/reduce_op_handle.cc

@@ -218,18 +218,18 @@ void ReduceOpHandle::RunImpl() {
       }
 
 #if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
-      if (framework::IsType<const float>(in_selected_rows[0]->value().type())) {
+      if (in_selected_rows[0]->value().type() ==
+          framework::proto::VarType::FP32) {
         GatherSelectedRows<platform::CUDADeviceContext, float>(
             in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
             out_var->GetMutable<framework::SelectedRows>());
-      } else if (framework::IsType<const double>(
-                     in_selected_rows[0]->value().type())) {
+      } else if (in_selected_rows[0]->value().type() ==
+                 framework::proto::VarType::FP64) {
         GatherSelectedRows<platform::CUDADeviceContext, double>(
            in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
            out_var->GetMutable<framework::SelectedRows>());
       } else {
-        PADDLE_ENFORCE(false,
-                       "only support double or float when gahter SelectedRows");
+        PADDLE_THROW("only support double or float when gather SelectedRows");
       }
 #endif
     });
@@ -246,7 +246,7 @@ void ReduceOpHandle::RunImpl() {
       if (!FLAGS_cpu_deterministic) {
         ReduceLoDTensor func(lod_tensors,
                              out_var->GetMutable<framework::LoDTensor>());
-        VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+        VisitDataType(lod_tensors[0]->type(), func);
       } else {
         // We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
         // here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
@@ -256,7 +256,7 @@ void ReduceOpHandle::RunImpl() {
                 ->FindVar(out_var_handle->name_)
                 ->GetMutable<framework::LoDTensor>();
         ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
-        VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+        VisitDataType(lod_tensors[0]->type(), func);
 
         auto trg = out_var->GetMutable<framework::LoDTensor>();
         if (reduce_sum_trg.data<void>() != trg->data<void>()) {
...
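
Note the dispatch change above: the old code recovered the element type through std::type_index and an RTTI-style IsType<const float> check, while the new code compares the tensor's stored proto::VarType tag directly, which is a plain integer comparison. A standalone sketch of tag-based dispatch in that spirit (simplified types, not Paddle's):

    #include <stdexcept>
    #include <vector>

    enum class VarType { FP32, FP64 };

    template <typename T>
    double SumAs(const std::vector<double>& values) {
      double total = 0;
      for (double v : values) total += static_cast<T>(v);  // emulate typed math
      return total;
    }

    // Tag-based dispatch: compare the stored enum tag instead of using RTTI.
    double GatherSum(VarType type, const std::vector<double>& values) {
      if (type == VarType::FP32) return SumAs<float>(values);
      if (type == VarType::FP64) return SumAs<double>(values);
      throw std::runtime_error("only support double or float");
    }

    int main() { return GatherSum(VarType::FP32, {1.0, 2.0}) == 3.0 ? 0 : 1; }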
paddle/fluid/framework/dlpack_tensor.cc
...
@@ -13,7 +13,7 @@
 // limitations under the License.
 #include "paddle/fluid/framework/dlpack_tensor.h"
+#include "paddle/fluid/framework/data_type.h"
 namespace paddle {
 namespace framework {
...
@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
   return dtype;
 }

-static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) {
-#define REG_DL_DATA_TYPE(type) \
-  { std::type_index(typeid(type)), GetDLDataTypeCode<type>() }
-  static const std::unordered_map<std::type_index, ::DLDataType>
-      type_to_dtype_map({
-          REG_DL_DATA_TYPE(platform::float16),  // NOLINT
-          REG_DL_DATA_TYPE(float),              // NOLINT
-          REG_DL_DATA_TYPE(double),             // NOLINT
-          REG_DL_DATA_TYPE(int),                // NOLINT
-          REG_DL_DATA_TYPE(int64_t),            // NOLINT
-          REG_DL_DATA_TYPE(bool),               // NOLINT
-          REG_DL_DATA_TYPE(size_t),             // NOLINT
-          REG_DL_DATA_TYPE(int16_t),            // NOLINT
-          REG_DL_DATA_TYPE(uint8_t),            // NOLINT
-          REG_DL_DATA_TYPE(int8_t)              // NOLINT
-      });
+static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
+  static std::unordered_map<int, ::DLDataType> result;
+
+#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
+  result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()
+
+  _ForEachDataType_(REG_DL_DATA_TYPE);
+#undef REG_DL_DATA_TYPE
+  return result;
+}
+
+static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
+  static auto type_to_dtype_map = CreateDLDataTypeMap();
   static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
-  auto it = type_to_dtype_map.find(type);
-  PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %s",
-                 type.name());
+  auto it = type_to_dtype_map.find(static_cast<int>(type));
+  PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d",
+                 type);
   return it->second;
-#undef REG_DL_DATA_TYPE
 }
...
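`_ForEachDataType_(REG_DL_DATA_TYPE)` above is an X-macro: the caller defines a two-argument callback macro and the central list expands it once per `(cpp_type, proto_type)` pair. A short standalone sketch of the pattern (the type list is illustrative, shorter than Paddle's):

#include <cstdio>

enum class VarType { FP32 = 5, FP64 = 6, INT32 = 2 };  // illustrative values

// The X-macro: expands callback(cpp_type, proto_type) for each pair.
#define FOR_EACH_DATA_TYPE(callback) \
  callback(float, VarType::FP32);    \
  callback(double, VarType::FP64);   \
  callback(int, VarType::INT32)

int main() {
  // A registration-style callback, in the spirit of CreateDLDataTypeMap.
#define PRINT_SIZE(cpp_type, proto_type) \
  std::printf("tag %d -> %zu bytes\n",   \
              static_cast<int>(proto_type), sizeof(cpp_type))
  FOR_EACH_DATA_TYPE(PRINT_SIZE);
#undef PRINT_SIZE
  return 0;
}

Keeping one canonical type list and expanding callbacks against it is what lets this commit delete the hand-maintained per-type tables here and in the test file below.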
paddle/fluid/framework/dlpack_tensor_test.cc
...
@@ -91,23 +91,11 @@ void TestMainLoop() {
     }
   }
 }

-#define PADDLE_DLPACK_TEST(type) \
-  TEST(dlpack, test_##type) { TestMainLoop<type>(); }
-
-using float16 = platform::float16;
-PADDLE_DLPACK_TEST(float16);
-PADDLE_DLPACK_TEST(float);
-PADDLE_DLPACK_TEST(double);
-PADDLE_DLPACK_TEST(int);
-PADDLE_DLPACK_TEST(int64_t);
-PADDLE_DLPACK_TEST(bool);
-PADDLE_DLPACK_TEST(size_t);
-PADDLE_DLPACK_TEST(int16_t);
-PADDLE_DLPACK_TEST(uint8_t);
-PADDLE_DLPACK_TEST(int8_t);
-#undef PADDLE_DLPACK_TEST
+TEST(dlpack, test_all) {
+#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
+  _ForEachDataType_(TestCallback);
+}

 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/executor.cc
...
@@ -157,9 +157,9 @@ void Executor::Close() {
 #ifdef PADDLE_WITH_DISTRIBUTE
   // TODO(typhoonzero): complete message will need to use real trainer_id,
   // except 0.
-  ::paddle::operators::distributed::RPCClient::GetInstance<
-      ::paddle::operators::distributed::GRPCClient>(0)
-      ->SendComplete();
+  auto client =
+      paddle::operators::distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
+  client->SendComplete();
 #endif
 }
...
paddle/fluid/framework/executor_thread_worker.cc
...
@@ -139,39 +139,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
   std::cout << sstream.str() << std::endl;
 }

-void print_fetch_var(Scope* scope, std::string var_name) {
-  const LoDTensor& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
-
-  if (std::type_index(tensor.type()) ==
-      std::type_index(typeid(platform::float16))) {
-    print_lod_tensor<platform::float16>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(float))) {
-    print_lod_tensor<float>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(double))) {
-    print_lod_tensor<double>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) {
-    print_lod_tensor<int>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(int64_t))) {
-    print_lod_tensor<int64_t>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(bool))) {
-    print_lod_tensor<bool>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(uint8_t))) {
-    print_lod_tensor<uint8_t>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(int16_t))) {
-    print_lod_tensor<int16_t>(var_name, tensor);
-  } else if (std::type_index(tensor.type()) ==
-             std::type_index(typeid(int8_t))) {
-    print_lod_tensor<int8_t>(var_name, tensor);
-  } else {
-    VLOG(1) << "print_fetch_var: unrecognized data type:"
-            << tensor.type().name();
-  }
-
-  return;
+static void print_fetch_var(Scope* scope, const std::string& var_name) {
+  auto& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
+
+#define PrintLoDTensorCallback(cpp_type, proto_type) \
+  do {                                               \
+    if (tensor.type() == proto_type) {               \
+      print_lod_tensor<cpp_type>(var_name, tensor);  \
+      return;                                        \
+    }                                                \
+  } while (0)

+  _ForEachDataType_(PrintLoDTensorCallback);
+  VLOG(1) << "print_fetch_var: unrecognized data type:" << tensor.type();
 }

 void ExecutorThreadWorker::TrainFiles() {
...
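`PrintLoDTensorCallback` above wraps its body in `do { ... } while (0)`, the standard way to make a multi-statement macro act as a single statement so it stays correct under an un-braced `if`/`else`. A tiny standalone illustration (names invented for the example):

#include <cstdio>

int log_count = 0;

// Without the do/while(0) wrapper, the two statements would break apart
// when the macro is used as the body of an un-braced if.
#define LOG_AND_COUNT(msg) \
  do {                     \
    std::puts(msg);        \
    ++log_count;           \
  } while (0)

int main() {
  bool verbose = true;
  if (verbose)
    LOG_AND_COUNT("hello");  // expands to exactly one statement
  else
    std::puts("quiet");
  return log_count == 1 ? 0 : 1;
}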
paddle/fluid/framework/ir/CMakeLists.txt
...
@@ -42,6 +42,8 @@ pass_library(multi_batch_merge_pass base)
 pass_library(conv_bn_fuse_pass inference)
 pass_library(seqconv_eltadd_relu_fuse_pass inference)
 pass_library(is_test_pass base)
+pass_library(conv_elementwise_add_act_fuse_pass inference)
+pass_library(conv_elementwise_add2_act_fuse_pass inference)
 if(WITH_MKLDNN)
   pass_library(mkldnn_placement_pass base)
   pass_library(depthwise_conv_mkldnn_pass base)
...
paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse.cc
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"

namespace paddle {
namespace framework {
namespace ir {

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES                      \
  GET_IR_NODE(conv_op);                \
  GET_IR_NODE(conv_out);               \
  GET_IR_NODE(conv_filter);            \
  GET_IR_NODE(elementwise_add_op);     \
  GET_IR_NODE(elementwise_add_in_y);   \
  GET_IR_NODE(elementwise_add_out);    \
  GET_IR_NODE(elementwise_add_op_1);   \
  GET_IR_NODE(elementwise_add_in_y_1); \
  GET_IR_NODE(elementwise_add_out_1);  \
  GET_IR_NODE(act_op);                 \
  GET_IR_NODE(act_out);

// Inherit the basic information from `base_desc`, and modify some fields.
framework::proto::OpDesc PrepareOpDesc(
    const framework::proto::OpDesc& base_desc, const std::string& bias,
    const std::string& bias1, const std::string& activation,
    const std::string& output) {
  auto proto = base_desc;
  framework::OpDesc desc(proto, nullptr);
  desc.SetInput("Bias", {bias});
  desc.SetInput("ResidualData", {bias1});
  desc.SetAttr("activation", activation);
  desc.SetOutput("Output", {output});
  desc.SetAttr("is_test", true);
  desc.SetAttr("use_cudnn", false);
  return *desc.Proto();
}

std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  const std::string pattern_name = "conv_elementwise_add_act_fuse";
  FusePassBase::Init(pattern_name, graph.get());

  GraphPatternDetector gpd;
  auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
      "conv2d", "Input");

  patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
  pattern(x);

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;

    auto base_op_desc = *conv_op->Op()->Proto();
    std::string bias_name = elementwise_add_in_y->Name();
    std::string bias1_name = elementwise_add_in_y_1->Name();
    std::string act_op_type = act_op->Op()->Type();
    std::string act_op_out = act_out->Name();

    auto new_op_proto = PrepareOpDesc(base_op_desc, bias_name, bias1_name,
                                      act_op_type, act_op_out);
    framework::OpDesc new_op_desc(new_op_proto, nullptr);

    // Create a new node for the fused op.
    auto new_conv_op = graph->CreateOpNode(&new_op_desc);

    // Link inputs and outputs.
    PADDLE_ENFORCE(subgraph.count(x));
    auto* conv_in_node = subgraph.at(x);

    IR_NODE_LINK_TO(conv_in_node, new_conv_op);            // Input
    IR_NODE_LINK_TO(conv_filter, new_conv_op);             // Filter
    IR_NODE_LINK_TO(elementwise_add_in_y, new_conv_op);    // Bias
    IR_NODE_LINK_TO(elementwise_add_in_y_1, new_conv_op);  // ResidualData
    IR_NODE_LINK_TO(new_conv_op, act_out);                 // Output

    // Delete the unneeded nodes.
    GraphSafeRemoveNodes(graph.get(),
                         {conv_op, elementwise_add_op, elementwise_add_op_1,
                          elementwise_add_out});
  };
  gpd(graph.get(), handler);
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
              paddle::framework::ir::ConvElementwiseAdd2ActFusePass);
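In effect the pass collapses conv2d, two elementwise_add ops, and the activation into one fused op. Schematically, out = act(conv2d(input, filter) + bias + residual); a scalar per-element sketch of that arithmetic (illustrative only, not a Paddle kernel; relu is assumed as the activation and the conv output is taken as precomputed):

#include <algorithm>
#include <cstddef>
#include <vector>

// Per-element view of the fused subgraph matched above.
std::vector<float> FusedAdd2Act(const std::vector<float>& conv_out,
                                const std::vector<float>& bias,
                                const std::vector<float>& residual) {
  std::vector<float> out(conv_out.size());
  for (std::size_t i = 0; i < out.size(); ++i) {
    float v = conv_out[i] + bias[i] + residual[i];  // the two elementwise adds
    out[i] = std::max(v, 0.0f);                     // the activation (relu)
  }
  return out;
}

int main() {
  auto out = FusedAdd2Act({1.0f, -2.0f}, {0.5f, 0.5f}, {0.0f, 1.0f});
  return out[0] > 0 ? 0 : 1;
}

Fusing the chain avoids materializing the intermediate tensors (`elementwise_add_out`, `elementwise_add_out_1`) and the extra kernel launches between them.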
paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"
#include <string>

namespace paddle {
namespace framework {
namespace ir {

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES                      \
  GET_IR_NODE(conv_op);                \
  GET_IR_NODE(conv_out);               \
  GET_IR_NODE(conv_filter);            \
  GET_IR_NODE(elementwise_add_op);     \
  GET_IR_NODE(elementwise_add_in_y);   \
  GET_IR_NODE(elementwise_add_out);    \
  GET_IR_NODE(elementwise_add_op_1);   \
  GET_IR_NODE(elementwise_add_in_y_1); \
  GET_IR_NODE(elementwise_add_out_1);  \
  GET_IR_NODE(act_op);                 \
  GET_IR_NODE(act_out);

// Inherit the basic information from `base_desc`, and modify some fields.
framework::proto::OpDesc PrepareOpDesc(
    const framework::proto::OpDesc& base_desc, const std::string& bias,
    const std::string& bias1, const std::string& activation,
    const std::string& output) {
  auto proto = base_desc;
  framework::OpDesc desc(proto, nullptr);
  desc.SetInput("Bias", {bias});
  desc.SetInput("ResidualData", {bias1});
  desc.SetAttr("activation", activation);
  desc.SetOutput("Output", {output});
  desc.SetAttr("is_test", true);
  return *desc.Proto();
}

std::unique_ptr<ir::Graph> ConvElementwiseAdd2ActFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  const std::string pattern_name = "conv_elementwise_add_act_fuse";
  FusePassBase::Init(pattern_name, graph.get());

  GraphPatternDetector gpd;
  auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
      "conv2d", "Input");

  patterns::ConvElementwiseadd2Act pattern(gpd.mutable_pattern(),
                                           pattern_name);
  pattern(x);

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;

    auto base_op_desc = *conv_op->Op()->Proto();
    std::string bias_name = elementwise_add_in_y->Name();
    std::string bias1_name = elementwise_add_in_y_1->Name();
    std::string act_op_type = act_op->Op()->Type();
    std::string act_op_out = act_out->Name();

    auto new_op_proto = PrepareOpDesc(base_op_desc, bias_name, bias1_name,
                                      act_op_type, act_op_out);
    framework::OpDesc new_op_desc(new_op_proto, nullptr);

    // Create a new node for the fused op.
    graph->CreateOpNode(&new_op_desc);

    // Link inputs and outputs.
    PADDLE_ENFORCE(subgraph.count(x));
    auto* conv_in_node = subgraph.at(x);

    IR_NODE_LINK_TO(conv_in_node, conv_op);          // Input
    IR_NODE_LINK_TO(conv_filter, conv_op);           // Filter
    IR_NODE_LINK_TO(conv_op, conv_out);              // Output
    IR_NODE_LINK_TO(elementwise_add_in_y, conv_op);    // Bias
    IR_NODE_LINK_TO(elementwise_add_in_y_1, conv_op);  // Bias

    // Delete the unneeded nodes.
    GraphSafeRemoveNodes(graph.get(),
                         {conv_op, elementwise_add_op, elementwise_add_op_1,
                          elementwise_add_out});
  };
  gpd(graph.get(), handler);
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(conv_elementwise_add2_act_fuse_pass,
              paddle::framework::ir::ConvElementwiseAdd2ActFusePass);
paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

class ConvElementwiseAdd2ActFusePass : public FusePassBase {
 public:
  virtual ~ConvElementwiseAdd2ActFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"

namespace paddle {
namespace framework {
namespace ir {

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES                    \
  GET_IR_NODE(conv_op);              \
  GET_IR_NODE(conv_out);             \
  GET_IR_NODE(conv_filter);          \
  GET_IR_NODE(elementwise_add_op);   \
  GET_IR_NODE(elementwise_add_in_y); \
  GET_IR_NODE(elementwise_add_out);  \
  GET_IR_NODE(act_op);               \
  GET_IR_NODE(act_out);

// Inherit the basic information from `base_desc`, and modify some fields.
framework::proto::OpDesc PrepareOpDesc(
    const framework::proto::OpDesc& base_desc, const std::string& bias,
    const std::string& activation, const std::string& output) {
  auto proto = base_desc;
  framework::OpDesc desc(proto, nullptr);
  desc.SetType("conv2d_fusion");
  desc.SetInput("Bias", {bias});
  desc.SetInput("ResidualData", {});
  desc.SetAttr("activation", activation);
  desc.SetOutput("Output", {output});
  desc.SetAttr("is_test", true);
  desc.SetAttr("use_cudnn", false);
  desc.Flush();
  return *desc.Proto();
}

std::unique_ptr<ir::Graph> ConvElementwiseAddActFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  const std::string pattern_name = "conv_elementwise_add_act_fuse";
  FusePassBase::Init(pattern_name, graph.get());

  GraphPatternDetector gpd;
  auto* x = gpd.mutable_pattern()
                ->NewNode("x")
                ->assert_is_op_input("conv2d", "Input")
                ->AsInput();

  patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
  pattern(x);

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;

    auto base_op_desc = *conv_op->Op()->Proto();
    std::string bias_name = elementwise_add_in_y->Name();
    std::string act_op_type = act_op->Op()->Type();
    std::string act_op_out = act_out->Name();

    auto new_op_proto =
        PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out);
    framework::OpDesc new_op_desc(new_op_proto, nullptr);

    // Create a new node for the fused op.
    auto* new_conv_op = graph->CreateOpNode(&new_op_desc);

    // Link inputs and outputs.
    PADDLE_ENFORCE(subgraph.count(x));
    auto* conv_in_node = subgraph.at(x);

    IR_NODE_LINK_TO(conv_in_node, new_conv_op);          // Input
    IR_NODE_LINK_TO(conv_filter, new_conv_op);           // Filter
    IR_NODE_LINK_TO(elementwise_add_in_y, new_conv_op);  // Bias
    IR_NODE_LINK_TO(new_conv_op, act_out);               // Output

    // Delete the unneeded nodes.
    GraphSafeRemoveNodes(graph.get(), {conv_op, conv_out, elementwise_add_op,
                                       elementwise_add_out, act_op});
  };

  gpd(graph.get(), handler);
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(conv_elementwise_add_act_fuse_pass,
              paddle::framework::ir::ConvElementwiseAddActFusePass);
paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

class ConvElementwiseAddActFusePass : public FusePassBase {
 public:
  virtual ~ConvElementwiseAddActFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.cc
...
@@ -17,6 +17,7 @@
 #include <string>
 #include <vector>
+#include "graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/graph_traits.h"
...
@@ -25,6 +26,7 @@
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/string/pretty_log.h"
 #include "paddle/fluid/string/printf.h"
 namespace paddle {
 namespace framework {
 namespace ir {
...
@@ -104,7 +106,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
   for (auto &node : GraphTraits::DFS(graph)) {
     for (const auto &pdnode : pattern_.nodes()) {
       if (pdnode->Tell(&node)) {
-        VLOG(4) << "pdnode " << pdnode->name() << " marked";
+        VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name();
         pdnodes2nodes_[pdnode.get()].insert(&node);
       }
     }
   }
...
@@ -1099,6 +1101,115 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
   return out_var;
 }

+std::unordered_set<std::string> conv_act_set({"identity", "sigmoid", "relu",
+                                              "relu6", "relux", "tanh",
+                                              "band_pass"});
+
+PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
+  conv_in->AsInput();
+  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
+  auto conv_out = pattern->NewNode(conv_out_repr())
+                      ->assert_is_op_output("conv2d")
+                      ->assert_is_op_input("elementwise_add", "X")
+                      ->AsIntermediate();
+  auto conv_filter = pattern->NewNode(conv_filter_repr())
+                         ->assert_is_op_input("conv2d", "Filter")
+                         ->AsInput();
+  auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
+                                ->assert_is_op("elementwise_add");
+  auto elementwise_add_in_y = pattern->NewNode(elementwise_add_in_y_repr())
+                                  ->assert_is_op_input("elementwise_add", "Y")
+                                  ->AsInput();
+  auto elementwise_add_out = pattern->NewNode(elementwise_add_out_repr())
+                                 ->assert_is_op_output("elementwise_add")
+                                 ->AsIntermediate();
+
+  auto act_op = pattern->NewNode(act_op_repr())
+                    ->assert_is_op()
+                    ->assert_more([&](Node *node) {
+                      auto op_type = node->Name();
+                      return conv_act_set.count(op_type);
+                    });
+
+  auto act_out = pattern->NewNode(act_out_repr())
+                     ->assert_is_var()
+                     // is activation op's output.
+                     ->assert_more([&](Node *node) {
+                       for (auto *in_op : node->inputs) {
+                         if (conv_act_set.count(in_op->Name())) {
+                           return true;
+                         }
+                       }
+                       return false;
+                     })
+                     ->AsOutput();
+
+  conv_op->LinksFrom({conv_in, conv_filter});
+  conv_out->LinksFrom({conv_op});
+  elementwise_add_op->LinksFrom({conv_out, elementwise_add_in_y})
+      .LinksTo({elementwise_add_out});
+  act_op->LinksFrom({elementwise_add_out}).LinksTo({act_out});
+
+  return act_out;
+}
+
+PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) {
+  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
+  auto conv_filter = pattern->NewNode(conv_filter_repr())
+                         ->assert_is_op_input("conv2d", "Filter")
+                         ->AsInput();
+  auto conv_out = pattern->NewNode(conv_out_repr())
+                      ->assert_is_op_output("conv2d")
+                      ->assert_is_op_input("elementwise_add", "X")
+                      ->AsIntermediate();
+  auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
+                                ->assert_is_op("elementwise_add");
+  auto elementwise_add_in_y = pattern->NewNode(elementwise_add_in_y_repr())
+                                  ->assert_is_op_input("elementwise_add", "Y")
+                                  ->AsInput();
+  auto elementwise_add_out = pattern->NewNode(elementwise_add_out_repr())
+                                 ->assert_is_op_output("elementwise_add")
+                                 ->assert_is_op_input("elementwise_add", "X")
+                                 ->AsIntermediate();
+  auto elementwise_add_op_1 = pattern->NewNode(elementwise_add_op_1_repr())
+                                  ->assert_is_op("elementwise_add");
+  auto elementwise_add_in_y_1 = pattern->NewNode(elementwise_add_in_y_1_repr())
+                                    ->assert_is_op_input("elementwise_add", "Y")
+                                    ->AsInput();
+  auto elementwise_add_out_1 = pattern->NewNode(elementwise_add_out_1_repr())
+                                   ->assert_is_op_output("elementwise_add")
+                                   ->AsIntermediate();
+
+  auto act_op = pattern->NewNode(act_op_repr())
+                    ->assert_is_op()
+                    ->assert_more([&](Node *node) {
+                      auto op_type = node->Name();
+                      return conv_act_set.count(op_type);
+                    });
+  auto act_out = pattern->NewNode(act_out_repr())
+                     ->assert_is_var()
+                     // is activation op's output.
+                     ->assert_more([&](Node *node) {
+                       for (auto *in_op : node->inputs) {
+                         if (conv_act_set.count(in_op->Name())) {
+                           return true;
+                         }
+                       }
+                       return false;
+                     })
+                     ->AsOutput();
+
+  conv_op->LinksFrom({conv_in, conv_filter}).LinksTo({conv_out});
+  elementwise_add_op->LinksFrom({conv_out, elementwise_add_in_y})
+      .LinksTo({elementwise_add_out});
+  elementwise_add_op_1->LinksFrom(
+      {elementwise_add_out, elementwise_add_in_y_1});
+  act_op->LinksFrom({elementwise_add_out_1}).LinksTo({act_out});
+
+  return act_out;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
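The `NewNode(...)->assert_is_op(...)->AsIntermediate()` chains above are a fluent-builder style: each setter mutates the node and returns its own pointer so constraints stack in a single expression. A minimal standalone sketch of the idiom (class and members invented for illustration):

#include <iostream>
#include <string>

class NodeSketch {
 public:
  NodeSketch* assert_is_op(const std::string& op) {
    op_ = op;
    return this;  // returning `this` is what enables the chaining
  }
  NodeSketch* AsInput() {
    role_ = "input";
    return this;
  }
  void Dump() const { std::cout << op_ << " as " << role_ << "\n"; }

 private:
  std::string op_, role_;
};

int main() {
  NodeSketch n;
  n.assert_is_op("conv2d")->AsInput()->Dump();  // prints "conv2d as input"
  return 0;
}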
paddle/fluid/framework/ir/graph_pattern_detector.h
...
@@ -671,6 +671,51 @@ struct ElementwiseAdd : public PatternBase {
   PATTERN_DECL_NODE(elementwise_add_y);
   PATTERN_DECL_NODE(elementwise_add_out);
 };
+
+// Conv + ElementwiseAdd + an activation.
+// This pattern can further fuse the conv-related ops after the conv+bn fusion.
+struct ConvElementwiseaddAct : public PatternBase {
+  ConvElementwiseaddAct(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "conv_elementwiseadd_act") {}
+
+  PDNode *operator()(PDNode *conv_in);
+
+  PATTERN_DECL_NODE(conv_op);
+  PATTERN_DECL_NODE(conv_out);
+  PATTERN_DECL_NODE(conv_filter);
+
+  PATTERN_DECL_NODE(elementwise_add_op);
+  PATTERN_DECL_NODE(elementwise_add_in_y);  // input
+  PATTERN_DECL_NODE(elementwise_add_out);
+
+  PATTERN_DECL_NODE(act_op);
+  PATTERN_DECL_NODE(act_out);
+};
+
+// Conv + ElementwiseAdd + ElementwiseAdd + Activation.
+struct ConvElementwiseadd2Act : public PatternBase {
+  ConvElementwiseadd2Act(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope,
+                    "conv_elementwiseadd2_elementwiseadd_act") {}
+
+  PDNode *operator()(PDNode *conv_in);
+
+  PATTERN_DECL_NODE(conv_op);
+  PATTERN_DECL_NODE(conv_filter);
+  PATTERN_DECL_NODE(conv_out);
+
+  PATTERN_DECL_NODE(elementwise_add_op);
+  PATTERN_DECL_NODE(elementwise_add_in_y);  // input
+  PATTERN_DECL_NODE(elementwise_add_out);
+
+  PATTERN_DECL_NODE(elementwise_add_op_1);
+  PATTERN_DECL_NODE(elementwise_add_in_y_1);  // input
+  PATTERN_DECL_NODE(elementwise_add_out_1);
+
+  PATTERN_DECL_NODE(act_op);
+  PATTERN_DECL_NODE(act_out);
+};

 }  // namespace patterns

 // Link two ir::Nodes from each other.
...
paddle/fluid/framework/lod_tensor.cc
...
@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
   // only print first ten elements
   int64_t size = t.numel() < 10 ? t.numel() : 10;
   for (int64_t i = 0; i < size; ++i) {
-    if (IsType<float>(t.type())) {
+    if (t.type() == proto::VarType::FP32) {
       os << t.data<float>()[i] << " ";
-    } else if (IsType<int64_t>(t.type())) {
+    } else if (t.type() == proto::VarType::INT64) {
       os << t.data<int64_t>()[i] << " ";
     } else {
       PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
...
@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor(
   PADDLE_ENFORCE(!lod_tensors.empty());

   framework::DDim new_dim = lod_tensors[0]->dims();
-  std::type_index new_type = lod_tensors[0]->type();
+  auto new_type = lod_tensors[0]->type();
   framework::DataLayout new_layout = lod_tensors[0]->layout();
   LoD new_lod = lod_tensors[0]->lod();
   for (size_t i = 1; i < lod_tensors.size(); ++i) {
...
paddle/fluid/framework/ngraph_operator.cc
...
@@ -471,27 +471,23 @@ void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const {
     auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
     PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
                    "Ensure ngraph tensor layout align with paddle tensor");
-    if (tensor_pd->type().hash_code() ==
-        typeid(float).hash_code()) {  // NOLINT
+    if (tensor_pd->type() == proto::VarType::FP32) {
       const float* arr = tensor_pd->data<float>();
       ti = backend_->create_tensor(ngraph::element::f32, sp,
                                    const_cast<float*>(arr));
-    } else if (tensor_pd->type().hash_code() ==
-               typeid(int).hash_code()) {  // NOLINT
+    } else if (tensor_pd->type() == proto::VarType::INT32) {
       const int* arr = tensor_pd->data<int>();
       ti = backend_->create_tensor(ngraph::element::i32, sp,
                                    const_cast<int*>(arr));
-    } else if (tensor_pd->type().hash_code() ==
-               typeid(int64_t).hash_code()) {
+    } else if (tensor_pd->type() == proto::VarType::INT64) {
       const int64_t* arr = tensor_pd->data<int64_t>();
       ti = backend_->create_tensor(ngraph::element::i64, sp,
                                    const_cast<int64_t*>(arr));
-    } else if (tensor_pd->type().hash_code() ==
-               typeid(double).hash_code()) {  // NOLINT
+    } else if (tensor_pd->type() == proto::VarType::FP64) {
       const double* arr = tensor_pd->data<double>();
       ti = backend_->create_tensor(ngraph::element::f64, sp,
                                    const_cast<double*>(arr));
-    } else if (tensor_pd->type().hash_code() ==
-               typeid(bool).hash_code()) {  // NOLINT
+    } else if (tensor_pd->type() == proto::VarType::BOOL) {
       const bool* arr = tensor_pd->data<bool>();
       ti = backend_->create_tensor(ngraph::element::boolean, sp,
                                    const_cast<bool*>(arr));
...
paddle/fluid/framework/op_kernel_type_test.cc
...
@@ -34,7 +34,8 @@ TEST(OpKernelType, ToString) {
   OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW,
                                LibraryType::kCUDNN);
   ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2),
-            "data_type[float16]:data_layout[NCHW]:place[CUDAPlace(0)]:library_"
+            "data_type[::paddle::platform::float16]:data_layout[NCHW]:place["
+            "CUDAPlace(0)]:library_"
            "type[CUDNN]");
 }
...
paddle/fluid/framework/operator.cc
...
@@ -45,10 +45,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
 proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
   if (var->IsType<framework::LoDTensor>()) {
-    return framework::ToDataType(var->Get<framework::LoDTensor>().type());
+    return var->Get<framework::LoDTensor>().type();
   } else if (var->IsType<framework::SelectedRows>()) {
-    return framework::ToDataType(
-        var->Get<framework::SelectedRows>().value().type());
+    return var->Get<framework::SelectedRows>().value().type();
   } else {
     PADDLE_THROW("Var should be LoDTensor or SelectedRows");
   }
...
@@ -95,13 +94,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
     if (UNLIKELY(!tensor.IsInitialized())) {
       return "";
     }
-    return DataTypeToString(ToDataType(tensor.type()));
+    return DataTypeToString(tensor.type());
   } else if (var->IsType<SelectedRows>()) {
     auto tensor = var->Get<SelectedRows>().value();
     if (UNLIKELY(!tensor.IsInitialized())) {
       return "uninited";
     } else {
-      return DataTypeToString(ToDataType(tensor.type()));
+      return DataTypeToString(tensor.type());
     }
   } else {
     return "";
...
@@ -688,7 +687,8 @@ static void CheckTensorNANOrInf(const std::string& name,
   if (tensor.memory_size() == 0) {
     return;
   }
-  if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) {
+  if (tensor.type() != proto::VarType::FP32 &&
+      tensor.type() != proto::VarType::FP64) {
     return;
   }
   PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
...
@@ -883,7 +883,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
       if (t != nullptr) {
         PADDLE_ENFORCE(t->IsInitialized(), "Input %s is not initialized: %s",
                        ipt_name, DebugString());
-        int tmp = static_cast<int>(ToDataType(t->type()));
+        int tmp = static_cast<int>(t->type());
         PADDLE_ENFORCE(
             tmp == data_type || data_type == -1,
             "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
...
paddle/fluid/framework/selected_rows.cc
...
@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
     if (index < 0) {
       VLOG(5) << "id " << id << " not in the table, return 0";
       framework::VisitDataType(
-          framework::ToDataType(value_->type()),
+          value_->type(),
           TensorFillVisitor(value, i * value_width, value_width, 0.0));
     } else {
       framework::VisitDataType(
-          framework::ToDataType(value_->type()),
+          value_->type(),
           TensorCopyVisitor(value, i * value_width, *value_.get(),
                             index * value_width, value_width));
     }
...
paddle/fluid/framework/tensor.cc
...
@@ -16,7 +16,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-extern size_t SizeOfType(std::type_index type);
+extern size_t SizeOfType(proto::VarType::Type type);
 void Tensor::check_memory_size() const {
   PADDLE_ENFORCE_NOT_NULL(
       holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
...
@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const {
   return holder_ == nullptr ? 0UL : holder_->size() - offset_;
 }

-void* Tensor::mutable_data(platform::Place place, std::type_index type,
+void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type,
                            memory::Allocator::Attr attr,
                            size_t requested_size) {
   type_ = type;
...
paddle/fluid/framework/tensor.h
...
@@ -19,9 +19,9 @@ limitations under the License. */
 #include <memory>
 #include <typeindex>
 #include <vector>

 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/ddim.h"
+#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/memory/memory.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
...
@@ -67,7 +67,7 @@ class Tensor {
   friend struct EigenVector;

  public:
-  Tensor() : type_(typeid(float)), offset_(0) {}
+  Tensor() : type_(proto::VarType::FP32), offset_(0) {}

   /*! Return a pointer to mutable memory block. */
   template <typename T>
...
@@ -88,7 +88,7 @@ class Tensor {
                  memory::Allocator::Attr attr = memory::Allocator::kDefault,
                  size_t requested_size = 0);

-  void* mutable_data(platform::Place place, std::type_index type,
+  void* mutable_data(platform::Place place, proto::VarType::Type type,
                      memory::Allocator::Attr attr = memory::Allocator::kDefault,
                      size_t requested_size = 0);
...
@@ -138,7 +138,7 @@ class Tensor {
     return holder_->place();
   }

-  std::type_index type() const {
+  proto::VarType::Type type() const {
     PADDLE_ENFORCE_NOT_NULL(
         holder_, "Tensor not initialized yet when Tensor::type() is called.");
     return type_;
...
@@ -165,7 +165,7 @@ class Tensor {
  private:
   /*! holds the memory block if allocated. */
   std::shared_ptr<memory::Allocation> holder_;
-  std::type_index type_;
+  proto::VarType::Type type_;
   /**
    * @brief points to elements dimensions.
    *
...
paddle/fluid/framework/tensor_impl.h
...
@@ -24,9 +24,8 @@ template <typename T>
 inline const T* Tensor::data() const {
   check_memory_size();
-  bool valid =
-      std::is_same<T, void>::value || type_ == std::type_index(typeid(T));
-  PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s",
-                 type_.name());
+  bool valid =
+      std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
+  PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d", type_);
   return reinterpret_cast<const T*>(
       reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
...
@@ -38,9 +37,8 @@ template <typename T>
 inline T* Tensor::data() {
   check_memory_size();
-  bool valid =
-      std::is_same<T, void>::value || type_ == std::type_index(typeid(T));
-  PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s",
-                 type_.name());
+  bool valid =
+      std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
+  PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_);
   return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
                               offset_);
 }
...
@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place,
                                size_t requested_size) {
   static_assert(std::is_pod<T>::value, "T must be POD");
   return reinterpret_cast<T*>(
-      mutable_data(place, typeid(T), attr, requested_size));
+      mutable_data(place, DataTypeTrait<T>::DataType, attr, requested_size));
 }

 inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
...
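`DataTypeTrait<T>::DataType` resolves a C++ element type to its enum tag at compile time, which is what lets `Tensor::data<T>()` compare `type_` against an enum instead of a `std::type_index`. A minimal sketch of such a trait (enum values illustrative, not the generated protobuf ones):

enum class VarType { FP32 = 5, FP64 = 6, INT32 = 2 };  // illustrative

template <typename T>
struct DataTypeTraitSketch;  // primary template left undefined on purpose:
                             // unsupported element types fail to compile

template <>
struct DataTypeTraitSketch<float> {
  static constexpr VarType DataType = VarType::FP32;
};
template <>
struct DataTypeTraitSketch<double> {
  static constexpr VarType DataType = VarType::FP64;
};
template <>
struct DataTypeTraitSketch<int> {
  static constexpr VarType DataType = VarType::INT32;
};

static_assert(DataTypeTraitSketch<float>::DataType == VarType::FP32,
              "resolved entirely at compile time");

int main() { return 0; }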
paddle/fluid/framework/tensor_util.cc
...
@@ -186,8 +186,8 @@ struct AnyDTypeVisitor {
 template <typename Predicate, typename DevCtx>
 inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
                     const DevCtx& ctx, framework::Tensor* out) {
-  VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>(
-                                               predicate, tensor, ctx, out));
+  VisitDataType(tensor.type(), AnyDTypeVisitor<Predicate, DevCtx>(
+                                   predicate, tensor, ctx, out));
 }

 template <typename Predicate>
...
@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
     // int32_t size
     // void* protobuf message
     proto::VarType::TensorDesc desc;
-    desc.set_data_type(framework::ToDataType(tensor.type()));
+    desc.set_data_type(tensor.type());
     auto dims = framework::vectorize(tensor.dims());
     auto* pb_dims = desc.mutable_dims();
     pb_dims->Resize(static_cast<int>(dims.size()), 0);
...
@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
     tensor->Resize(framework::make_ddim(dims));
     void* buf;
     auto ctx = platform::CPUDeviceContext();
-    size_t size =
-        tensor->numel() *
-        framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
+    size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
     if (platform::is_gpu_place(dev_ctx.GetPlace())) {
 #ifdef PADDLE_WITH_CUDA
       Tensor cpu_tensor;
...
paddle/fluid/inference/api/analysis_predictor.cc
...
@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
     auto type = fetch.type();
     auto output = &(outputs->at(i));
     output->name = fetchs_[idx]->Input("X")[0];
-    if (type == typeid(float)) {
+    if (type == framework::proto::VarType::FP32) {
      GetFetchOne<float>(fetch, output);
      output->dtype = PaddleDType::FLOAT32;
-    } else if (type == typeid(int64_t)) {
+    } else if (type == framework::proto::VarType::INT64) {
       GetFetchOne<int64_t>(fetch, output);
       output->dtype = PaddleDType::INT64;
     } else {
...
paddle/fluid/inference/api/analysis_predictor_tester.cc
...
@@ -55,7 +55,12 @@ TEST(AnalysisPredictor, analysis_off) {
 }

 TEST(AnalysisPredictor, analysis_on) {
-  AnalysisConfig config(false);
+#ifdef PADDLE_WITH_CUDA
+  AnalysisConfig config(true);
+  config.fraction_of_gpu_memory = 0.15;
+#else
+  AnalysisConfig config;
+#endif
   config.model_dir = FLAGS_dirname;
   config.enable_ir_optim = true;
...
paddle/fluid/inference/api/api_impl.cc
...
@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
     auto type = fetch.type();
     auto output = &(outputs->at(i));
     output->name = fetchs_[idx]->Input("X")[0];
-    if (type == typeid(float)) {
+    if (type == framework::DataTypeTrait<float>::DataType) {
       GetFetchOne<float>(fetch, output);
       output->dtype = PaddleDType::FLOAT32;
-    } else if (type == typeid(int64_t)) {
+    } else if (type == framework::DataTypeTrait<int64_t>::DataType) {
       GetFetchOne<int64_t>(fetch, output);
       output->dtype = PaddleDType::INT64;
     } else {
...
paddle/fluid/inference/api/api_impl_tester.cc
...
@@ -36,10 +36,10 @@ namespace paddle {
 PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
   PaddleTensor pt;

-  if (t->type() == typeid(int64_t)) {
+  if (t->type() == framework::proto::VarType::INT64) {
     pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t));
     pt.dtype = PaddleDType::INT64;
-  } else if (t->type() == typeid(float)) {
+  } else if (t->type() == framework::proto::VarType::FP32) {
     pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
     pt.dtype = PaddleDType::FLOAT32;
   } else {
...
paddle/fluid/inference/api/paddle_pass_builder.h
...
@@ -118,7 +118,10 @@ class GpuPassStrategy : public PassStrategy {
  public:
   GpuPassStrategy() : PassStrategy({}) {
     passes_.assign({
-        "infer_clean_graph_pass", "conv_bn_fuse_pass",
+        "infer_clean_graph_pass",               //
+        "conv_bn_fuse_pass",                    //
+        "conv_elementwise_add_act_fuse_pass",   //
+        "conv_elementwise_add2_act_fuse_pass",  //
     });
   }
...
paddle/fluid/inference/io.cc
...
@@ -79,7 +79,7 @@ void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
   for (auto* var : global_block.AllVars()) {
     if (IsPersistable(var)) {
-      VLOG(3) << "persistable variable's name: " << var->Name();
+      VLOG(4) << "persistable variable's name: " << var->Name();
       framework::VarDesc* new_var = load_block->Var(var->Name());
       new_var->SetShape(var->GetShape());
...
paddle/fluid/inference/tests/api/tester_helper.h
...
@@ -373,7 +373,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
   }

   for (size_t i = 0; i < a_size; i++) {
-    if (a.type() == typeid(float)) {
+    if (a.type() == framework::proto::VarType::FP32) {
       const auto *a_data = a.data<float>();
       const auto *b_data = b.data<float>();
       if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
...
@@ -382,7 +382,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
                       b_data[i]);
         return false;
       }
-    } else if (a.type() == typeid(int64_t)) {
+    } else if (a.type() == framework::proto::VarType::INT64) {
       const auto *a_data = a.data<int64_t>();
       const auto *b_data = b.data<int64_t>();
       if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
...
paddle/fluid/inference/tests/api/trt_models_tester.cc
...
@@ -78,6 +78,7 @@ void profile(std::string model_dir, bool use_analysis, bool use_tensorrt) {
   std::vector<PaddleTensor> outputs;
   if (use_analysis || use_tensorrt) {
     contrib::AnalysisConfig config(true);
+    config.pass_builder()->TurnOnDebug();
     SetConfig<contrib::AnalysisConfig>(&config, model_dir, true, use_tensorrt,
                                        FLAGS_batch_size);
     TestPrediction(reinterpret_cast<PaddlePredictor::Config*>(&config),
...
@@ -141,9 +142,31 @@ TEST(TensorRT_resnext50, profile) {
   profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt);
 }

+TEST(resnext50, compare_analysis_native) {
+  std::string model_dir = FLAGS_infer_model + "/resnext50";
+  compare(model_dir, false /*use tensorrt*/);
+}
+
 TEST(TensorRT_mobilenet, analysis) {
   std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
-  compare(model_dir, /* use_tensorrt */ false);
+  compare(model_dir, false /* use_tensorrt */);
+}
+
+TEST(AnalysisPredictor, use_gpu) {
+  std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
+  AnalysisConfig config(true);
+  config.model_dir = model_dir;
+  config.fraction_of_gpu_memory = 0.15;
+  config.pass_builder()->TurnOnDebug();
+
+  std::vector<std::vector<PaddleTensor>> inputs_all;
+  auto predictor = CreatePaddlePredictor(config);
+  SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
+
+  std::vector<PaddleTensor> outputs;
+  for (auto& input : inputs_all) {
+    ASSERT_TRUE(predictor->Run(input, &outputs));
+  }
 }

 }  // namespace inference
...
paddle/fluid/operators/affine_grid_op.cc
...
@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
       library = framework::LibraryType::kCUDNN;
     }
 #endif
-    auto data_type = framework::ToDataType(ctx.Input<Tensor>("Theta")->type());
+    auto data_type = ctx.Input<Tensor>("Theta")->type();
     return framework::OpKernelType(data_type, ctx.GetPlace(),
                                    framework::DataLayout::kAnyLayout, library);
   }
...
@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Theta")->type()),
-        ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_);
+    return framework::OpKernelType(ctx.Input<Tensor>("Theta")->type(),
+                                   ctx.GetPlace(),
+                                   framework::DataLayout::kAnyLayout, library_);
   }
 };
...
paddle/fluid/operators/arg_max_op.cc
...
@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
     paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
                                     int32_t>,
     paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
                                     int16_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
-                                    size_t>,
     paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
                                     uint8_t>);
paddle/fluid/operators/arg_max_op.cu
@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
                                     int32_t>,
     paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
                                     int16_t>,
-    paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
-                                    size_t>,
     paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
                                     uint8_t>);
paddle/fluid/operators/arg_min_op.cc
@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
                                     int32_t>,
     paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
                                     int16_t>,
-    paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
-                                    size_t>,
     paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
                                     uint8_t>);
paddle/fluid/operators/arg_min_op.cu
@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
                                     int32_t>,
     paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
                                     int16_t>,
-    paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
-                                    size_t>,
     paddle::operators::ArgMinKernel<paddle::platform::CUDADeviceContext,
                                     uint8_t>);
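The dropped size_t specializations follow from the same migration: kernels are now keyed by the proto dtype enum, and that enum has no member corresponding to size_t, so the specialization can no longer be expressed. A sketch of the idea with stand-in names (the real registry is macro-based, not this trait):

// Sketch of why a size_t kernel cannot survive an enum-keyed registry:
// there is simply no enum member to map size_t onto. Stand-in names only.
#include <cstdint>

enum class VarType { INT16, INT32, INT64, UINT8 };

template <typename T> struct DataTypeTrait;  // deliberately no generic case
template <> struct DataTypeTrait<int16_t> { static constexpr VarType value = VarType::INT16; };
template <> struct DataTypeTrait<int32_t> { static constexpr VarType value = VarType::INT32; };
template <> struct DataTypeTrait<int64_t> { static constexpr VarType value = VarType::INT64; };
template <> struct DataTypeTrait<uint8_t> { static constexpr VarType value = VarType::UINT8; };

template <typename T>
constexpr VarType KernelKey() { return DataTypeTrait<T>::value; }

int main() {
  (void)KernelKey<int32_t>();   // fine: INT32 exists in the enum
  // (void)KernelKey<size_t>(); // would not compile: no trait for size_t
  return 0;
}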
paddle/fluid/operators/array_to_lod_tensor_op.cc
@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
     ArrayToLoDFunctorImpl<DeviceContext> functor;
     functor.dev_ctx_ = dev_ctx;
     functor.prev_functor_ = this;
-    framework::VisitDataType(framework::ToDataType(out->type()), functor);
+    framework::VisitDataType(out->type(), functor);
   }
 };
...
@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
     PADDLE_ENFORCE(!x.empty(), "There's no element in the input array.");
     int rank = x[0].dims().size();
     platform::Place place = x[0].place();
-    std::type_index data_type = x[0].type();
+    auto data_type = x[0].type();
     int64_t batch_size = x[0].dims()[0];
     framework::DDim ins_dims = rank > 1
                                    ? framework::slice_ddim(x[0].dims(), 1, rank)
...
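VisitDataType is the runtime dispatcher these hunks now feed with the enum directly. A minimal stand-in showing the dispatch shape; the real Paddle version covers many more dtypes and invokes the functor's templated operator(), while this sketch uses a named apply() for clarity:

// Minimal stand-in for framework::VisitDataType: switch on the runtime
// dtype enum and call the functor with the matching C++ type.
#include <cstdint>
#include <iostream>

enum class VarType { FP32, FP64, INT32 };

template <typename Visitor>
void VisitDataType(VarType type, Visitor visitor) {
  switch (type) {
    case VarType::FP32: visitor.template apply<float>(); break;
    case VarType::FP64: visitor.template apply<double>(); break;
    case VarType::INT32: visitor.template apply<int32_t>(); break;
  }
}

struct PrintSize {
  template <typename T>
  void apply() const { std::cout << sizeof(T) << " bytes\n"; }
};

int main() {
  VisitDataType(VarType::FP64, PrintSize{});  // prints "8 bytes"
  return 0;
}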
paddle/fluid/operators/attention_lstm_op.cc
@@ -121,9 +121,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+      ctx.Input<framework::LoDTensor>("X")->type(),
       ctx.device_context());
 }
 
 void AttentionLSTMOpMaker::Make() {
...
paddle/fluid/operators/average_accumulates_op.cc
@@ -103,9 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
  protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("param")->type()),
+    return framework::OpKernelType(ctx.Input<Tensor>("param")->type(),
        ctx.GetPlace());
  }
 };
...
paddle/fluid/operators/batch_norm_op.cc
@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
  protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    auto input_data_type = ctx.Input<Tensor>("X")->type();
     // By default, the type of the scale, bias, mean,
     // and var tensors should both be float. (For float or float16 input tensor)
     // or double (For double input tensor).
...
@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
     if (input_data_type == framework::proto::VarType::FP64) {
       bn_param_type = framework::proto::VarType::FP64;
     }
-    PADDLE_ENFORCE_EQ(bn_param_type,
-                      framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
+    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
                       "Scale input should be of float type");
-    PADDLE_ENFORCE_EQ(bn_param_type,
-                      framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
+    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
                       "Bias input should be of float type");
-    PADDLE_ENFORCE_EQ(bn_param_type,
-                      framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
+    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
                       "Mean input should be of float type");
-    PADDLE_ENFORCE_EQ(bn_param_type,
-                      framework::ToDataType(
-                          ctx.Input<Tensor>("Variance")->type()),
+    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
                       "Variance input should be of float type");
 
     // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
...
@@ -413,9 +408,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
     }
 #endif
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
         ctx.GetPlace(),
         layout, library);
   }
 };
...
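The BatchNorm hunks keep the existing parameter-dtype rule while switching to plain enum comparisons: scale, bias, mean and variance stay FP32 for FP32 and FP16 inputs and become FP64 only for FP64 inputs. A sketch of that rule with stand-in names, not the real op code:

// Sketch of the parameter-dtype rule the BatchNorm hunks enforce.
#include <cassert>

enum class VarType { FP16, FP32, FP64 };

VarType BnParamType(VarType input_type) {
  VarType bn_param_type = VarType::FP32;  // default for FP32 and FP16 inputs
  if (input_type == VarType::FP64) {
    bn_param_type = VarType::FP64;        // double inputs keep double params
  }
  return bn_param_type;
}

int main() {
  assert(BnParamType(VarType::FP16) == VarType::FP32);
  assert(BnParamType(VarType::FP64) == VarType::FP64);
  // The real op then checks Scale/Bias/Mean/Variance all match this type.
  return 0;
}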
paddle/fluid/operators/beam_search_decode_op.cc
@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
     LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
 
     framework::VisitDataType(
-        framework::ToDataType(scores->at(0).type()),
+        scores->at(0).type(),
         BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
                                 beam_size, end_id));
   }
...
paddle/fluid/operators/beam_search_op.cc
@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::OpKernelType kt = framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::LoDTensor>("pre_ids")->type()),
+        ctx.Input<framework::LoDTensor>("pre_ids")->type(),
        platform::CPUPlace());
    return kt;
  }
...
paddle/fluid/operators/bpr_loss_op.cc
@@ -47,9 +47,8 @@ class BprLossOp : public framework::OperatorWithKernel {
   // is determined by its input "X".
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
         platform::CPUPlace());
   }
 };
...
@@ -94,9 +93,8 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
   // is determined by its input "X".
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
         platform::CPUPlace());
   }
 };
...
paddle/fluid/operators/controlflow/CMakeLists.txt
 include(operators)
-register_operators()
+register_operators(DEPS naive_executor)
 
 file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
paddle/fluid/operators/controlflow/conditional_block_op.cc
@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase {
     if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
       PADDLE_THROW("should have one initialized input as condition");
     }
-    if (!(framework::IsType<bool>(ips[0]->type()) &&  // NOLINT
-          ips[0]->numel() == 1)) {
-      PADDLE_THROW(
+    PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL &&
+                       ips[0]->numel() == 1,
           "condition input's data type should be bool, "
          "numel should be 1, actual numel is %d",
          ips[0]->numel());
-    }
     bool res = false;
     if (platform::is_gpu_place(ips[0]->place())) {
 #ifdef PADDLE_WITH_CUDA
...
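The hunk above folds a hand-written if + PADDLE_THROW into a single PADDLE_ENFORCE(cond, fmt, ...). A minimal stand-in for that enforce pattern, not the real PADDLE_ENFORCE (which additionally records file and line):

// Stand-in for the enforce pattern: throw a formatted message when the
// condition fails, instead of an explicit if + throw at each call site.
#include <cstdio>
#include <stdexcept>

#define MY_ENFORCE(cond, fmt, ...)                        \
  do {                                                    \
    if (!(cond)) {                                        \
      char buf[256];                                      \
      std::snprintf(buf, sizeof(buf), fmt, __VA_ARGS__);  \
      throw std::runtime_error(buf);                      \
    }                                                     \
  } while (0)

int main() {
  int numel = 3;
  try {
    MY_ENFORCE(numel == 1,
               "condition input's data type should be bool, "
               "numel should be 1, actual numel is %d", numel);
  } catch (const std::exception& e) {
    std::printf("caught: %s\n", e.what());
  }
  return 0;
}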
paddle/fluid/operators/controlflow/while_op.cc
@@ -261,7 +261,7 @@ class WhileGradOp : public framework::OperatorBase {
         if (var->IsType<LoDTensor>()) {
           auto& inside_tensor = var->Get<framework::LoDTensor>();
           framework::AttributeMap attrs;
-          attrs["dtype"] = framework::ToDataType(inside_tensor.type());
+          attrs["dtype"] = inside_tensor.type();
           attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
           attrs["value"] = 0.0f;
...
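With the enum-returning type(), the WhileGrad hunk can store the dtype attribute without any conversion. A sketch of the attribute map being filled for the zero-initializing op, with simplified stand-in types rather than the real framework::Attribute:

// Sketch of the attribute map in the WhileGrad hunk: "dtype" now holds
// the VarType enum value directly. Stand-in types, C++17 std::variant.
#include <map>
#include <string>
#include <variant>
#include <vector>

enum class VarType { FP32, FP64 };
using Attribute = std::variant<int, float, VarType, std::vector<int>>;
using AttributeMap = std::map<std::string, Attribute>;

int main() {
  VarType inside_tensor_type = VarType::FP32;  // would come from tensor.type()
  AttributeMap attrs;
  attrs["dtype"] = inside_tensor_type;          // enum stored directly
  attrs["shape"] = std::vector<int>{4, 8};      // vectorized dims
  attrs["value"] = 0.0f;                        // fill gradient with zeros
  return 0;
}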
paddle/fluid/operators/conv_op.cc
@@ -44,7 +44,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
   std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
 
   PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
-                 "Conv intput should be 4-D or 5-D tensor.");
+                 "Conv intput should be 4-D or 5-D tensor, get %u",
+                 in_dims.size());
+
   PADDLE_ENFORCE_EQ(
       in_dims.size(), filter_dims.size(),
       "Conv input dimension and filter dimension should be the same.");
...
@@ -95,10 +97,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   }
 #endif
 
-  auto input_data_type =
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type());
-  auto filter_data_type =
-      framework::ToDataType(ctx.Input<Tensor>("Filter")->type());
+  auto input_data_type = ctx.Input<Tensor>("Input")->type();
+  auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
   PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
                     "input and filter data type should be consistent");
...
@@ -382,9 +382,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   }
 #endif
   return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
+      ctx.Input<Tensor>("Input")->type(),
       ctx.GetPlace(),
       layout_, library_,
       customized_type_value);
 }
 }  // namespace operators
...
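The second conv hunk keeps the rule that input and filter dtypes must match before one of them becomes the kernel-selection key. A stand-in sketch of that check, not the real op code:

// Sketch of the consistency check the ConvOp hunk preserves: input and
// filter share one dtype, which then picks the registered kernel.
#include <stdexcept>

enum class VarType { FP16, FP32, FP64 };

VarType PickConvKernelKey(VarType input_type, VarType filter_type) {
  if (input_type != filter_type) {
    throw std::runtime_error("input and filter data type should be consistent");
  }
  return input_type;  // becomes the OpKernelType dtype key
}

int main() {
  PickConvKernelKey(VarType::FP32, VarType::FP32);  // ok
  try {
    PickConvKernelKey(VarType::FP32, VarType::FP16);
  } catch (const std::exception&) { /* mismatch rejected */ }
  return 0;
}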
paddle/fluid/operators/conv_transpose_op.cc
@@ -104,9 +104,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
   }
 #endif
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
+  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
       ctx.GetPlace(),
       layout_, library_);
 }
 
 void Conv2DTransposeOpMaker::Make() {
...
@@ -335,9 +334,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
   std::string data_format = ctx.Attr<std::string>("data_format");
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
+  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
       ctx.GetPlace(),
       layout_, library_);
 }
 }  // namespace operators
...
paddle/fluid/operators/crf_decoding_op.cc
@@ -118,9 +118,8 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
  protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
+    return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
        platform::CPUPlace());
  }
 };
 }  // namespace operators
...
paddle/fluid/operators/crop_op.cc
@@ -51,9 +51,8 @@ class CropOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.Input<framework::LoDTensor>("X")->type(),
        ctx.device_context());
  }
 };
...
@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))
-                ->type()),
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
        ctx.device_context());
  }
 };
...
paddle/fluid/operators/cross_entropy_op.cc
@@ -57,9 +57,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
   // is determined by its input "X".
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
         ctx.device_context());
   }
 };
...
@@ -111,9 +110,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
   // is determined by its input "X".
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
         ctx.device_context());
   }
 };
...
paddle/fluid/operators/ctc_align_op.cc
@@ -36,9 +36,8 @@ class CTCAlignOp : public framework::OperatorWithKernel {
  protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
+    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
        ctx.device_context());
  }
 };
...
paddle/fluid/operators/cudnn_lstm_op.cu.cc
@@ -300,9 +300,11 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
     }
     CudnnRNNCache* cudnn_rnn_cache = nullptr;
     if (cache_var->IsInitialized()) {
+      // const_cast is usually bad.
       cudnn_rnn_cache = const_cast<framework::Variable*>(cache_var)
                             ->GetMutable<CudnnRNNCache>();
     } else {
+      // const_cast is usually bad.
       cudnn_rnn_cache = const_cast<framework::Variable*>(cache_var)
                             ->GetMutable<CudnnRNNCache>();
       std::random_device rnd;
...
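The added "// const_cast is usually bad." comments flag a deliberate pattern: the cache variable arrives as a const input, but the cuDNN RNN workspace has to be created lazily on first use. A stand-in sketch of why the cast appears, with simplified types rather than the real framework::Variable:

// Sketch of the lazy-cache pattern the cuDNN LSTM kernel uses: the input
// is const by contract, yet this particular slot is scratch state.
#include <cassert>

struct CudnnRNNCache { bool initialized = false; };

struct Variable {
  CudnnRNNCache cache;
  bool IsInitialized() const { return cache.initialized; }
  CudnnRNNCache* GetMutable() { cache.initialized = true; return &cache; }
};

CudnnRNNCache* GetOrCreateCache(const Variable* cache_var) {
  // const_cast is usually bad; here it is confined to a scratch variable
  // that the framework happens to hand out as const.
  return const_cast<Variable*>(cache_var)->GetMutable();
}

int main() {
  Variable v;
  const Variable* in = &v;
  assert(GetOrCreateCache(in)->initialized);
  return 0;
}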
paddle/fluid/operators/detection/anchor_generator_op.cc
@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
+        ctx.Input<framework::Tensor>("Input")->type(),
        ctx.device_context());
  }
 };
...
paddle/fluid/operators/detection/bipartite_match_op.cc
@@ -45,9 +45,8 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
  protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<LoDTensor>("DistMat")->type()),
+    return framework::OpKernelType(ctx.Input<LoDTensor>("DistMat")->type(),
        platform::CPUPlace());
  }
 };
...
paddle/fluid/operators/detection/density_prior_box_op.cc
@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
+        ctx.Input<framework::Tensor>("Input")->type(),
        ctx.GetPlace());
  }
 };
...
paddle/fluid/operators/detection/generate_proposals_op.cc
@@ -66,9 +66,8 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
  protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("Anchors")->type()),
+    return framework::OpKernelType(ctx.Input<Tensor>("Anchors")->type(),
        ctx.device_context());
  }
 };
...
paddle/fluid/operators/detection/mine_hard_examples_op.cc
@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("ClsLoss")->type()),
+        ctx.Input<framework::Tensor>("ClsLoss")->type(),
        platform::CPUPlace());
  }
 };
...
paddle/fluid/operators/detection/multiclass_nms_op.cc
@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::LoDTensor>("Scores")->type()),
+        ctx.Input<framework::LoDTensor>("Scores")->type(),
        platform::CPUPlace());
  }
 };
...
paddle/fluid/operators/detection/prior_box_op.cc
@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
+        ctx.Input<framework::Tensor>("Input")->type(),
        ctx.device_context());
  }
 };
...
paddle/fluid/operators/detection/roi_perspective_transform_op.cc
@@ -498,9 +498,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
  protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
        ctx.device_context());
  }
 };
...
@@ -519,9 +518,8 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
  protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
        ctx.device_context());
  }
 };
...
paddle/fluid/operators/detection/rpn_target_assign_op.cc
@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::LoDTensor>("Anchor")->type()),
+        ctx.Input<framework::LoDTensor>("Anchor")->type(),
        platform::CPUPlace());
  }
 };
...
paddle/fluid/operators/detection/target_assign_op.cc
@@ -57,9 +57,8 @@ class TargetAssignOp : public framework::OperatorWithKernel {
  protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+    return framework::OpKernelType(
+        ctx.Input<framework::LoDTensor>("X")->type(),
        ctx.device_context());
  }
 };
...
paddle/fluid/operators/detection_map_op.cc
@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::Tensor>("DetectRes")->type()),
+        ctx.Input<framework::Tensor>("DetectRes")->type(),
        platform::CPUPlace());
  }
 };
...
paddle/fluid/operators/distributed/CMakeLists.txt
@@ -12,7 +12,7 @@ configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @O
 set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
 
 if(WITH_GRPC)
-  grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
+  grpc_library(sendrecvop_rpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
         request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc collective_client.cc collective_server.cc
         PROTO send_recv.proto
         DEPS lod_tensor selected_rows_functor memory)
...
@@ -20,36 +20,43 @@ if(WITH_GRPC)
   set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   cc_test(grpc_serde_test SRCS grpc_serde_test.cc
-          DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
+          DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_rpc scope profiler math_function SERIAL)
 
   cc_test(rpc_server_test SRCS rpc_server_test.cc
-          DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
+          DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
   cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
   if(WITH_GPU)
     cc_test(collective_server_test SRCS collective_server_test.cc
-            DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
+            DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor
             selected_rows_functor scope math_function SERIAL)
   endif()
-  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory)
+  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 else()
-  set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
-      brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+  set_source_files_properties(brpc_server.cc parameter_prefetch.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
+      brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc collective_server.cc collective_server_test.cc
+      collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 
-  brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
+  brpc_library(sendrecvop_rpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
       brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
+      collective_client.cc collective_server.cc
       PROTO send_recv.proto
       DEPS lod_tensor selected_rows memory)
 
-  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc memory)
+  cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 
-  set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
+  set(brpc_test_depends sendrecvop_rpc brpc ssl crypto protobuf leveldb gflags glog executor
+      proto_desc lookup_sparse_table_op snappystream snappy zlib)
 
-  cc_test(brpc_server_test SRCS rpc_server_test.cc
+  cc_test(rpc_server_test SRCS rpc_server_test.cc
       DEPS ${brpc_test_depends} SERIAL)
 
   cc_test(brpc_serde_test SRCS brpc_serde_test.cc
       DEPS ${brpc_test_depends} SERIAL)
+
+  if(WITH_GPU)
+    cc_test(collective_server_test SRCS collective_server_test.cc
+        DEPS ${brpc_test_depends} selected_rows_functor scope math_function SERIAL)
+  endif()
 endif()
paddle/fluid/operators/distributed/brpc_client.cc
@@ -14,135 +14,316 @@
 #include "paddle/fluid/operators/distributed/brpc_client.h"
 #include "paddle/fluid/framework/threadpool.h"
+#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace operators {
 namespace distributed {
 
-DEFINE_int32(brpc_channel_num, 24,
-             "Number of channels to send requests connected to one server");
 DEFINE_int32(timeout_ms, 30000, "RPC timeout in milliseconds");
 DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");
 
 BRPCClient::~BRPCClient() { Wait(); }
 
-void HandleSendResponse(brpc::Controller* cntl,
-                        sendrecv::VoidMessage* response) {
+void HandleSendResponse(brpc::Controller* cntl, sendrecv::VoidMessage* response,
+                        VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
+                        ChannelContextPtr ch_ctx, BRPCClient* cls) {
   // std::unique_ptr makes sure cntl/response will be deleted before returning.
   std::unique_ptr<brpc::Controller> cntl_guard(cntl);
   std::unique_ptr<sendrecv::VoidMessage> response_guard(response);
 
+  // this channel can be used by other now.
+  ch_ptr->Push(ch_ctx);
+
   if (cntl->Failed()) {
-    LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
+    LOG(FATAL) << "Fail to send SendVar: " << var_h->name()
+               << ", error text: " << cntl->ErrorText();
+    var_h->Finish(false);
+    cls->DecreaseReqCount();
     return;
   }
-  LOG(INFO) << "Received response from " << cntl->remote_side()
-            << " latency=" << cntl->latency_us() << "us";
+  var_h->Finish(true);
+  cls->DecreaseReqCount();
+
+  VLOG(4) << "HandleSendResponse from: " << cntl->remote_side()
+          << ", varname: " << var_h->name()
+          << ", latency: " << cntl->latency_us() << "us";
+  VLOG(4) << "Finish HandleSendResponse";
 }
 
-bool BRPCClient::AsyncSendVar(const std::string& ep,
-                              const platform::DeviceContext& ctx,
-                              const framework::Scope& scope,
-                              const std::string& var_name, int64_t time_out) {
+VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
+                                      const platform::DeviceContext& ctx,
+                                      const framework::Scope& scope,
+                                      const std::string& var_name,
+                                      int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
   const auto ch_ptr = GetChannel(ep_val);
+  const std::string method = "SendRPC";
+  VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
+
+  framework::AsyncIO([=] {
+    auto ch_ctx = ch_ptr->Pop();
+    brpc::Controller* cntl = new brpc::Controller();
+    sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
+    cntl->set_timeout_ms(time_out);
 
-  framework::AsyncIO(
-      [var_name_val, p_ctx, ep_val, p_scope, time_out, ch_ptr, this] {
-        auto ch_ctx = ch_ptr->Pop();
-        brpc::Controller* cntl = new brpc::Controller();
-        sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
-        cntl->set_timeout_ms(time_out);
+    auto* var = p_scope->FindVar(var_name_val);
+    sendrecv::VariableMessage request;
+    distributed::SerializeToIOBuf(var_name_val, var, *p_ctx, &request,
+                                  &cntl->request_attachment(), "", false,
+                                  trainer_id_);
 
-        google::protobuf::Closure* done =
-            brpc::NewCallback(&HandleSendResponse, cntl, response);
+    google::protobuf::Closure* done = brpc::NewCallback(
+        &HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
 
-        sendrecv::VariableMessage request;
-        ch_ctx->stub->SendVariable(cntl, &request, response, done);
-      });
+    platform::RecordRPCEvent record_event(method, p_ctx);
+
+    ch_ctx->stub->SendVariable(cntl, &request, response, done);
+
+    if (UNLIKELY(platform::IsProfileEnabled())) {
+      var_h->Wait();
+    }
+  });
   req_count_++;
 
-  return true;
+  return var_h;
 }
 
+void HandleFetchBarrierResponse(brpc::Controller* cntl,
+                                sendrecv::VariableMessage* response,
+                                VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
+                                ChannelContextPtr ch_ctx, BRPCClient* cls) {
+  // std::unique_ptr makes sure cntl/response will be deleted before returning.
+  std::unique_ptr<brpc::Controller> cntl_guard(cntl);
+  std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
+
+  // this channel can be used other now.
+  ch_ptr->Push(ch_ctx);
+
+  if (cntl->Failed()) {
+    LOG(FATAL) << "Fail to get HandleFetchBarrierResponse: " << var_h->name()
+               << ", error text: " << cntl->ErrorText();
+    var_h->Finish(false);
+    cls->DecreaseReqCount();
+    return;
+  }
+
+  var_h->Finish(true);
+  cls->DecreaseReqCount();
+
+  VLOG(4) << "HandleFetchBarrierResponse from: " << cntl->remote_side()
+          << ", varname: " << var_h->name()
+          << ", latency: " << cntl->latency_us() << "us";
+  VLOG(4) << "Finish HandleFetchBarrierResponse";
+}
+
 void HandleGetResponse(brpc::Controller* cntl,
-                       sendrecv::VariableMessage* response) {
+                       sendrecv::VariableMessage* response, VarHandlePtr var_h,
+                       ChannelQueuePtr ch_ptr, ChannelContextPtr ch_ctx,
+                       BRPCClient* cls) {
   // std::unique_ptr makes sure cntl/response will be deleted before returning.
   std::unique_ptr<brpc::Controller> cntl_guard(cntl);
   std::unique_ptr<sendrecv::VariableMessage> response_guard(response);
 
+  // this channel can be used other now.
+  ch_ptr->Push(ch_ctx);
+
   if (cntl->Failed()) {
-    LOG(WARNING) << "Fail to send EchoRequest, " << cntl->ErrorText();
+    LOG(FATAL) << "Fail to GetVar: " << var_h->name()
+               << ", error text: " << cntl->ErrorText();
+    cls->DecreaseReqCount();
+    var_h->Finish(false);
     return;
   }
-  LOG(INFO) << "Received response from " << cntl->remote_side()
-            << " latency=" << cntl->latency_us() << "us";
 
-  // framework::Variable* outvar = nullptr;
-  // DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
+  VLOG(4) << "HandleGetResponse from: " << cntl->remote_side()
+          << ", varname: " << var_h->name()
+          << ", latency: " << cntl->latency_us() << "us";
+
+  framework::Variable* outvar = nullptr;
+  int trainer_id;
+  distributed::DeserializeFromIOBuf(*response, cntl->response_attachment(),
+                                    *var_h->ctx(), var_h->scope(), &outvar,
+                                    &trainer_id);
+  VLOG(4) << "Finish HandleGetResponse";
+  cls->DecreaseReqCount();
+  var_h->Finish(true);
 }
 
-bool BRPCClient::AsyncGetVar(const std::string& ep,
-                             const platform::DeviceContext& ctx,
-                             const framework::Scope& scope,
-                             const std::string& var_name, int64_t time_out) {
+VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
+                                      const platform::DeviceContext& ctx,
+                                      const framework::Scope& scope,
+                                      const std::string& var_name,
+                                      const std::string& method_name,
+                                      int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
+  const auto ch_ptr = GetChannel(ep_val);
+  const std::string method = "GetRPC";
+  VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
+
+  framework::AsyncIO([=] {
+    auto ch_ctx = ch_ptr->Pop();
+
+    brpc::Controller* cntl = new brpc::Controller();
+    sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
+    cntl->set_timeout_ms(time_out);
 
-  framework::AsyncIO(
-      [var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {});
+    sendrecv::VariableMessage req;
+    req.set_varname(var_name_val);
+    req.set_trainer_id(trainer_id_);
+
+    google::protobuf::Closure* done = brpc::NewCallback(
+        &HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
+
+    platform::RecordRPCEvent record_event(method, p_ctx);
+
+    if (method_name == "GetMonomerVariable") {
+      ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done);
+    } else {
+      ch_ctx->stub->GetVariable(cntl, &req, response, done);
+    }
+
+    if (UNLIKELY(platform::IsProfileEnabled())) {
+      var_h->Wait();
+    }
+  });
 
   req_count_++;
 
-  return true;
+  return var_h;
+}
+
+VarHandlePtr BRPCClient::AsyncGetMonomerVariable(
+    const std::string& ep, const platform::DeviceContext& ctx,
+    const framework::Scope& scope, const std::string& var_name,
+    int64_t time_out) {
+  return _AsyncGetVar(ep, ctx, scope, var_name, "GetMonomerVariable", time_out);
+}
+
+VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
+                                                const std::string& var_name,
+                                                int64_t time_out) {
+  return AsyncSendMessage(ep, "GetMonomerBarrier", var_name, time_out);
+}
+
+VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
+                                     const platform::DeviceContext& ctx,
+                                     const framework::Scope& scope,
+                                     const std::string& var_name,
+                                     int64_t time_out) {
+  return _AsyncGetVar(ep, ctx, scope, var_name, "GetVariable", time_out);
 }
 
-bool BRPCClient::AsyncPrefetchVar(const std::string& ep,
-                                  const platform::DeviceContext& ctx,
-                                  const framework::Scope& scope,
-                                  const std::string& in_var_name,
-                                  const std::string& out_var_name,
-                                  int64_t time_out) {
+VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
+                                          const platform::DeviceContext& ctx,
+                                          const framework::Scope& scope,
+                                          const std::string& in_var_name,
+                                          const std::string& out_var_name,
+                                          const std::string& table_name,
+                                          int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string in_var_name_val = in_var_name;
   const std::string out_var_name_val = out_var_name;
+  const std::string table_name_val = table_name;
   const framework::Scope* p_scope = &scope;
-  const auto ch = GetChannel(ep_val);
+  const auto ch_ptr = GetChannel(ep_val);
+  const std::string method = "PrefetchRPC";
+  VarHandlePtr var_h(
+      new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
+
+  framework::AsyncIO([=] {
+    auto ch_ctx = ch_ptr->Pop();
+
+    brpc::Controller* cntl = new brpc::Controller();
+    sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
+    cntl->set_timeout_ms(time_out);
+
+    auto* var = p_scope->FindVar(in_var_name_val);
+    sendrecv::VariableMessage req;
+    distributed::SerializeToIOBuf(in_var_name_val, var, *p_ctx, &req,
+                                  &cntl->request_attachment(),
+                                  out_var_name_val, false, 0, table_name_val);
+
+    platform::RecordRPCEvent record_event(method, p_ctx);
+
+    google::protobuf::Closure* done = brpc::NewCallback(
+        &HandleGetResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
 
-  framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
-                      time_out, ch, this] {});
+    ch_ctx->stub->PrefetchVariable(cntl, &req, response, done);
+
+    if (UNLIKELY(platform::IsProfileEnabled())) {
+      var_h->Wait();
+    }
+  });
 
   req_count_++;
-  return true;
+  return var_h;
 }
 
-void BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
-                                       int64_t time_out) {
-  req_count_++;
+VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
+                                               int64_t time_out) {
+  return AsyncSendMessage(ep, "BatchBarrierRPC", BATCH_BARRIER_MESSAGE,
+                          time_out);
 }
 
-void BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
-                                       int64_t time_out) {
+VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
+                                               int64_t time_out) {
+  auto ch_ptr = GetChannel(ep);
+  auto ch_ctx = ch_ptr->Pop();
+
+  brpc::Controller* cntl = new brpc::Controller();
+  sendrecv::VariableMessage* response = new sendrecv::VariableMessage();
+  cntl->set_timeout_ms(time_out);
+
+  sendrecv::VariableMessage req;
+  req.set_varname(FETCH_BARRIER_MESSAGE);
+
+  const std::string method = "FetchBarrierRPC";
+  // var handle
+  VarHandlePtr var_h(
+      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
+
+  platform::RecordRPCEvent record_event(method, nullptr);
+
+  google::protobuf::Closure* done = brpc::NewCallback(
+      &HandleFetchBarrierResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
+
+  ch_ctx->stub->GetVariable(cntl, &req, response, done);
+
   req_count_++;
+
+  if (UNLIKELY(platform::IsProfileEnabled())) {
+    var_h->Wait();
+  }
+
+  return var_h;
 }
 
-void BRPCClient::Wait() {
-  std::unique_lock<std::mutex> lk(sync_mutex_);
-  sync_cond_.wait(lk, [this] { return req_count_ == 0; });
+bool BRPCClient::Wait() {
+  VLOG(9) << "begin to brpcclient wait";
+  {
+    std::unique_lock<std::mutex> lk(sync_mutex_);
+    sync_cond_.wait(lk, [this] { return req_count_ == 0; });
+  }
+  VLOG(9) << "end to brpcclient wait";
+  return true;
 }
 
 ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
+  VLOG(4) << "begin to GetChannel:" << ep;
   {
     std::lock_guard<std::mutex> guard(chan_mutex_);
     auto it = channels_.find(ep);
     if (it != channels_.end()) {
+      VLOG(4) << "end to GetChannel:" << ep;
       return it->second;
     }
   }
...
@@ -150,12 +331,20 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
   ChannelQueuePtr q(new framework::BlockingQueue<ChannelContextPtr>());
 
   brpc::ChannelOptions options;
+#ifdef PADDLE_WITH_BRPC_RDMA
+  options.use_rdma = true;
+#endif
   options.protocol = "baidu_std";
-  options.connection_type = "pooled";
-  options.connect_timeout_ms = 100;
+  // don't use pooled type. the server can't afford that.
+  options.connection_type = "single";
+  options.connect_timeout_ms = 1000;
   options.timeout_ms = FLAGS_timeout_ms /*milliseconds*/;
   options.max_retry = FLAGS_max_retry;
-  for (int i = 0; i < FLAGS_brpc_channel_num; ++i) {
+
+  VLOG(1) << "create " << brpc_channel_num_per_server_
+          << " brpc channels to pserver:" << ep;
+
+  for (int i = 0; i < brpc_channel_num_per_server_; ++i) {
     std::shared_ptr<ChannelContext> c(new ChannelContext());
     if (c->channel.Init(ep.c_str(), &options) != 0) {
       LOG(FATAL) << "Fail to initialize channel";
...
@@ -172,9 +361,75 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
     channels_[ep] = q;
   }
 
+  VLOG(4) << "end to GetChannel:" << ep;
   return q;
 }
 
+VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep,
+                                           int64_t time_out) {
+  return AsyncSendMessage(ep, "SendCompleteRPC", COMPLETE_MESSAGE, time_out);
+}
+
+void BRPCClient::SendComplete() {
+  for (auto& kv : channels_) {
+    AsyncSendComplete(kv.first);
+  }
+}
+
+VarHandlePtr BRPCClient::AsyncSendVarMessage(
+    const std::string& ep, const std::string& method_name,
+    const sendrecv::VariableMessage& req, int64_t time_out) {
+  auto ch_ptr = GetChannel(ep);
+  auto ch_ctx = ch_ptr->Pop();
+
+  brpc::Controller* cntl = new brpc::Controller();
+  sendrecv::VoidMessage* response = new sendrecv::VoidMessage();
+  cntl->set_timeout_ms(time_out);
+
+  platform::RecordRPCEvent record_event(method_name, nullptr);
+
+  VarHandlePtr var_h(
+      new VarHandle(ep, method_name, req.varname(), nullptr, nullptr));
+
+  google::protobuf::Closure* done = brpc::NewCallback(
+      &HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
+
+  if (method_name == "CheckPointNotifyRPC") {
+    ch_ctx->stub->CheckpointNotify(cntl, &req, response, done);
+  } else if (method_name == "GetMonomerBarrier") {
+    ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done);
+  } else {
+    ch_ctx->stub->SendVariable(cntl, &req, response, done);
+  }
+
+  req_count_++;
+
+  if (UNLIKELY(platform::IsProfileEnabled())) {
+    var_h->Wait();
+  }
+
+  return var_h;
+}
+
+VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
+                                          const std::string& method_name,
+                                          const std::string& message,
+                                          int64_t time_out) {
+  sendrecv::VariableMessage req;
+  req.set_varname(message);
+
+  return AsyncSendVarMessage(ep, method_name, req, time_out);
+}
+
+VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
+                                               const std::string& dir,
+                                               int64_t time_out) {
+  sendrecv::VariableMessage req;
+  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
+  req.set_out_varname(dir);
+
+  return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out);
+}
+
 }  // namespace distributed
 }  // namespace operators
 }  // namespace paddle
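The brpc client rewrite above switches every Async* entry point from returning bool to returning a VarHandlePtr that the response callback completes and the caller can Wait() on. A condensed, standalone sketch of that handle contract follows; it is an illustration only, and the real VarHandle also records the endpoint, method, device context and scope:

// Sketch of the VarHandle flow: the async call returns a handle, the RPC
// callback finishes it, and callers block on Wait() when they need the
// result (as the profiling path above does).
#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>

class VarHandle {
 public:
  explicit VarHandle(std::string name) : name_(std::move(name)) {}
  void Finish(bool ok) {
    std::lock_guard<std::mutex> lk(mu_);
    ok_ = ok;
    done_ = true;
    cv_.notify_all();
  }
  bool Wait() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return done_; });
    return ok_;
  }
  const std::string& name() const { return name_; }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool done_ = false, ok_ = false;
  std::string name_;
};
using VarHandlePtr = std::shared_ptr<VarHandle>;

// Stand-in for BRPCClient::AsyncSendVar: start the work, hand back the handle.
VarHandlePtr AsyncSendVar(const std::string& var_name) {
  auto h = std::make_shared<VarHandle>(var_name);
  std::thread([h] {
    // ... serialize and issue the RPC; the response callback would run this:
    h->Finish(true);
  }).detach();
  return h;
}

int main() {
  VarHandlePtr h = AsyncSendVar("w@GRAD");
  std::cout << h->name() << " ok=" << h->Wait() << "\n";
  return 0;
}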
paddle/fluid/operators/distributed/brpc_client.h
浏览文件 @
3bd54ed7
...
@@ -31,6 +31,8 @@ limitations under the License. */
...
@@ -31,6 +31,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
...
@@ -53,33 +55,94 @@ class BRPCClient : public RPCClient {
...
@@ -53,33 +55,94 @@ class BRPCClient : public RPCClient {
   BRPCClient() {}
   virtual ~BRPCClient();

-  bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
-                    const framework::Scope& scope, const std::string& var_name,
-                    int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncSendVar(const std::string& ep,
+                            const platform::DeviceContext& ctx,
+                            const framework::Scope& scope,
+                            const std::string& var_name,
+                            int64_t time_out = FLAGS_rpc_deadline) override;

-  bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
-                   const framework::Scope& scope, const std::string& var_name,
-                   int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncGetVar(const std::string& ep,
+                           const platform::DeviceContext& ctx,
+                           const framework::Scope& scope,
+                           const std::string& var_name,
+                           int64_t time_out = FLAGS_rpc_deadline) override;
+
+  VarHandlePtr AsyncGetMonomerBarrier(
+      const std::string& ep, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) override;
+
+  VarHandlePtr AsyncGetMonomerVariable(
+      const std::string& ep, const platform::DeviceContext& ctx,
+      const framework::Scope& scope, const std::string& var_name,
+      int64_t time_out = FLAGS_rpc_deadline) override;

-  bool AsyncPrefetchVar(const std::string& ep,
-                        const platform::DeviceContext& ctx,
-                        const framework::Scope& scope,
-                        const std::string& in_var_name,
-                        const std::string& out_var_name,
-                        int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncPrefetchVar(const std::string& ep,
+                                const platform::DeviceContext& ctx,
+                                const framework::Scope& scope,
+                                const std::string& in_var_name,
+                                const std::string& out_var_name,
+                                const std::string& table_name = "",
+                                int64_t time_out = FLAGS_rpc_deadline) override;

-  void AsyncSendBatchBarrier(const std::string& ep,
-                             int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncSendBatchBarrier(
+      const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;

-  void AsyncSendFetchBarrier(const std::string& ep,
-                             int64_t time_out = FLAGS_rpc_deadline) override;
+  VarHandlePtr AsyncSendFetchBarrier(
+      const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
+
+  VarHandlePtr AsyncCheckpointNotify(
+      const std::string& ep, const std::string& dir,
+      int64_t time_out = FLAGS_rpc_deadline) override;

-  void Wait() override;
+  bool Wait() override;
+
+  void SendComplete() override;

  private:
+  VarHandlePtr _AsyncGetVar(const std::string& ep,
+                            const platform::DeviceContext& ctx,
+                            const framework::Scope& scope,
+                            const std::string& var_name,
+                            const std::string& method_name,
+                            int64_t time_out = FLAGS_rpc_deadline);
+
   void Proceed();
   ChannelQueuePtr GetChannel(const std::string& ep);

+  VarHandlePtr AsyncSendComplete(const std::string& ep,
+                                 int64_t time_out = FLAGS_rpc_deadline);
+
+  VarHandlePtr AsyncSendMessage(const std::string& ep,
+                                const std::string& method_name,
+                                const std::string& message, int64_t time_out);
+
+  VarHandlePtr AsyncSendVarMessage(const std::string& ep,
+                                   const std::string& method_name,
+                                   const sendrecv::VariableMessage& req,
+                                   int64_t time_out);
+
+  friend void HandleSendResponse(brpc::Controller* cntl,
+                                 sendrecv::VoidMessage* response,
+                                 VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
+                                 ChannelContextPtr ch_ctx, BRPCClient* cls);
+
+  friend void HandleGetResponse(brpc::Controller* cntl,
+                                sendrecv::VariableMessage* response,
+                                VarHandlePtr var_h, ChannelQueuePtr ch_ptr,
+                                ChannelContextPtr ch_ctx, BRPCClient* cls);
+
+  friend void HandleFetchBarrierResponse(brpc::Controller* cntl,
+                                         sendrecv::VariableMessage* response,
+                                         VarHandlePtr var_h,
+                                         ChannelQueuePtr ch_ptr,
+                                         ChannelContextPtr ch_ctx,
+                                         BRPCClient* cls);
+
+  void DecreaseReqCount() {
+    if (--req_count_ <= 0) {
+      sync_cond_.notify_all();
+    }
+  }
+
  private:
   std::unordered_map<std::string, ChannelQueuePtr> channels_;
...
@@ -88,6 +151,8 @@ class BRPCClient : public RPCClient {
   std::condition_variable sync_cond_;
   std::atomic<int64_t> req_count_{0};

+  static constexpr int brpc_channel_num_per_server_ = 4;
+
   // mutex for GetChannel thread safety
   std::mutex chan_mutex_;
   DISABLE_COPY_AND_ASSIGN(BRPCClient);
...
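Note: every async call increments req_count_, each HandleXxxResponse callback runs DecreaseReqCount(), and Wait() blocks on sync_cond_ until the counter reaches zero. Below is a standalone sketch of that wait-for-all-in-flight-requests pattern; the class name RequestTracker is illustrative, not from the Paddle sources.

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <thread>
#include <vector>

class RequestTracker {
 public:
  void Add() { req_count_++; }

  // Called from each completion callback, like DecreaseReqCount().
  void Done() {
    if (--req_count_ <= 0) {
      std::lock_guard<std::mutex> g(mu_);
      cv_.notify_all();
    }
  }

  // Blocks until all outstanding requests finished, like BRPCClient::Wait().
  void Wait() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return req_count_ <= 0; });
  }

 private:
  std::atomic<int64_t> req_count_{0};
  std::mutex mu_;
  std::condition_variable cv_;
};

int main() {
  RequestTracker t;
  std::vector<std::thread> workers;
  for (int i = 0; i < 8; ++i) {
    t.Add();                                   // one "RPC" issued
    workers.emplace_back([&t] { t.Done(); });  // its completion callback
  }
  t.Wait();  // returns only after all eight callbacks ran
  for (auto& w : workers) w.join();
  return 0;
}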
paddle/fluid/operators/distributed/brpc_rdma_pool.cc
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_BRPC_RDMA
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "brpc/channel.h"
#include "brpc/rdma/rdma_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {

RdmaMemPool& RdmaMemPool::Instance() {
  static RdmaMemPool* g_rdma_mem_pool = new RdmaMemPool();
  return *g_rdma_mem_pool;
}

void* RdmaMemPool::Find(const std::string& varname, int64_t size) {
  pthread_rwlock_rdlock(&access_);
  auto it = pool_.find(varname);
  if (it == pool_.end()) {
    pthread_rwlock_unlock(&access_);
    return nullptr;
  }

  auto info = it->second;
  if (info.data_size != size) {
    pthread_rwlock_unlock(&access_);
    PADDLE_ENFORCE(false, "var:%s size:%ld != %ld", varname, size,
                   info.data_size);
    return nullptr;
  }

  pthread_rwlock_unlock(&access_);
  return info.data;
}

void RdmaMemPool::Register(const std::string& varname, void* data,
                           int64_t data_size) {
  void* old = Find(varname, data_size);
  if (old != nullptr) {
    if (data != old) {
      PADDLE_ENFORCE(false, "var:%s data:%ld != %ld", varname, data, old);
    }
    VLOG(7) << "Find on rdma:" << varname << " data:" << data
            << " data_size:" << data_size;
    return;
  }

  VarInfo info;
  info.data = data;
  info.data_size = data_size;

  pthread_rwlock_wrlock(&access_);
  pool_[varname] = info;
  pthread_rwlock_unlock(&access_);

  if (brpc::rdma::RegisterMemoryForRdma(data, data_size)) {
    LOG(FATAL) << "register " << varname << " data:" << data
               << " data_size:" << data_size << " error";
  }

  VLOG(4) << "register on rdma:" << varname << " data:" << data
          << " data_size:" << data_size;
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
#endif
paddle/fluid/operators/distributed/brpc_rdma_pool.h
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_BRPC_RDMA
#include <pthread.h> // NOLINT
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
namespace distributed {

/*
 * This class is used to avoid duplicated registration of brpc::rdma.
 */
class RdmaMemPool {
 public:
  static RdmaMemPool& Instance();
  RdmaMemPool() : access_(PTHREAD_RWLOCK_INITIALIZER) {}

  virtual ~RdmaMemPool() { pthread_rwlock_destroy(&access_); }

  void Register(const std::string& varname, void* data, int64_t size);
  void* Find(const std::string& varname, int64_t size);

 private:
  struct VarInfo {
    void* data;
    int64_t data_size;

    VarInfo() : data(nullptr), data_size(0) {}
  };

 private:
  std::unordered_map<std::string, VarInfo> pool_;
  pthread_rwlock_t access_;
};

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
#endif
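Note: RdmaMemPool memoizes which buffers have already been registered with the RDMA stack, so repeated sends of the same variable skip re-registration. Below is a standalone sketch of the same register-once lookup guarded by a pthread rwlock; FakeRegister stands in for brpc::rdma::RegisterMemoryForRdma, which is unavailable outside that build, and MemoPool is an illustrative name.

#include <pthread.h>  // NOLINT
#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_map>

// Stub for the real RDMA registration call; returns 0 on success.
static int FakeRegister(void* data, int64_t size) {
  (void)data;
  (void)size;
  return 0;
}

class MemoPool {
 public:
  MemoPool() : lock_(PTHREAD_RWLOCK_INITIALIZER) {}
  ~MemoPool() { pthread_rwlock_destroy(&lock_); }

  // Many readers may probe concurrently; registration takes the write lock.
  void* Find(const std::string& name) {
    pthread_rwlock_rdlock(&lock_);
    auto it = pool_.find(name);
    void* p = (it == pool_.end()) ? nullptr : it->second;
    pthread_rwlock_unlock(&lock_);
    return p;
  }

  void Register(const std::string& name, void* data, int64_t size) {
    if (Find(name) != nullptr) return;  // already registered: skip
    pthread_rwlock_wrlock(&lock_);
    pool_[name] = data;
    pthread_rwlock_unlock(&lock_);
    if (FakeRegister(data, size) != 0) {
      std::fprintf(stderr, "register %s failed\n", name.c_str());
    }
  }

 private:
  std::unordered_map<std::string, void*> pool_;
  pthread_rwlock_t lock_;
};

int main() {
  static float buf[1024];
  MemoPool pool;
  pool.Register("w@grad", buf, sizeof(buf));
  pool.Register("w@grad", buf, sizeof(buf));  // second call is a no-op
  return pool.Find("w@grad") == buf ? 0 : 1;
}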
paddle/fluid/operators/distributed/brpc_sendrecvop_utils.cc
0 → 100644
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {

class IOBufWriter {
 public:
  static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) {
    iobuf->append(reinterpret_cast<char*>(&k), 4);
    iobuf->append(reinterpret_cast<char*>(&vlen), 8);

    iobuf->append(v, vlen);
  }

  static void AppendTCPZeroCopy(butil::IOBuf* iobuf, int k, const char* v,
                                int64_t vlen, bool in_cuda_pinned,
                                void (*destroy)(void*), void* user_data) {
    VLOG(7) << "AppendTCPZeroCopy "
            << " k:" << k
            << " data:" << static_cast<void*>(const_cast<char*>(v))
            << " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;

    iobuf->append(reinterpret_cast<char*>(&k), 4);
    iobuf->append(reinterpret_cast<char*>(&vlen), 8);

    // FIXME(gongwb): use append_zerocopy
    /*
    if (in_cuda_pinned) {
      iobuf->append_zerocopy(v, vlen, IOBufWriter::FreeMemory);
    } else {
      iobuf->append_zerocopy(v, vlen, nullptr);
    }
    */
    iobuf->append(v, vlen);
    destroy(user_data);
  }

#ifdef PADDLE_WITH_BRPC_RDMA
  static void AppendRdmaZeroCopy(const std::string varname, butil::IOBuf* iobuf,
                                 int k, const char* v, int64_t vlen,
                                 bool in_cuda_pinned, void (*destroy)(void*),
                                 void* user_data) {
    VLOG(7) << "AppendRdmaZeroCopy varname:" << varname << " k:" << k
            << " data:" << static_cast<void*>(const_cast<char*>(v))
            << " data_size:" << vlen << " in_cuda_pinned:" << in_cuda_pinned;

    iobuf->append(reinterpret_cast<char*>(&k), 4);
    iobuf->append(reinterpret_cast<char*>(&vlen), 8);

    RdmaMemPool::Instance().Register(
        varname, static_cast<void*>(const_cast<char*>(v)), vlen);

    // FIXME(gongwb): use append_zerocopy
    // iobuf->append_zerocopy(v, vlen, nullptr);
    iobuf->append(v, vlen);
    destroy(user_data);
    return;
  }
#endif

  static void AppendZeroCopy(const std::string varname, butil::IOBuf* iobuf,
                             int k, const char* v, int64_t vlen,
                             bool in_cuda_pinned, void (*destroy)(void*),
                             void* user_data) {
#ifdef PADDLE_WITH_BRPC_RDMA
    IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
                                    destroy, user_data);
#else
    IOBufWriter::AppendTCPZeroCopy(iobuf, k, v, vlen, in_cuda_pinned, destroy,
                                   user_data);
#endif
  }
};

void SerializeToIOBuf(const std::string& name, framework::Variable* var,
                      const platform::DeviceContext& ctx, VarMsg* request,
                      butil::IOBuf* iobuf, const std::string& out_varname,
                      bool var_is_not_stable, int trainer_id,
                      const std::string& table_name) {
  std::unique_ptr<TensorPayload> payload;

  request->set_varname(name);
  request->set_trainer_id(trainer_id);
  // Note: normally the profiler is enabled in 1 trainer, hence only
  // 1 trainer returns true for ShouldSendProfileState(). It tells PS
  // servers the trainer's profiling state so that PS can follow the
  // trainer.
  if (platform::ShouldSendProfileState()) {
    if (platform::IsProfileEnabled()) {
      request->set_profile(platform::kEnableProfiler);
    } else {
      request->set_profile(platform::kDisableProfiler);
    }
  }
  if (!out_varname.empty()) {
    request->set_out_varname(out_varname);
  }
  if (!table_name.empty()) {
    request->set_table_name(table_name);
  }
  if (var->IsType<framework::LoDTensor>()) {
    request->set_type(::sendrecv::LOD_TENSOR);
    payload.reset(new TensorPayload(GetTensorPayload(var, ctx, request)));
  } else if (var->IsType<framework::SelectedRows>()) {
    request->set_type(::sendrecv::SELECTED_ROWS);
    payload.reset(new TensorPayload(GetSelectedRowsPayload(var, ctx, request)));
#ifdef PADDLE_WITH_CUDA
  } else if (var->IsType<ncclUniqueId>()) {
    request->set_type(::sendrecv::NCCL_ID);
    const ncclUniqueId& uid = var->Get<ncclUniqueId>();
    // TODO(gongwb): use append_zero to avoid data copy.
    IOBufWriter::Append(iobuf,
                        sendrecv::VariableMessage::kSerializedFieldNumber,
                        uid.internal, NCCL_UNIQUE_ID_BYTES);
    return;
#endif
  } else {
    PADDLE_THROW("Serialize does not support type: %s",
                 typeid(var->Type()).name());
  }

  PADDLE_ENFORCE_NOT_NULL(payload);

  // FIXME(gongwb): it seems that can use zero copy.
  if (var_is_not_stable) {
    IOBufWriter::Append(
        iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
        static_cast<const char*>(payload->ptr()), payload->memory_size());
  } else {
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
      IOBufWriter::AppendZeroCopy(
          name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
          static_cast<const char*>(payload->ptr()), payload->memory_size(),
          true, SerializeDestroyCallback, static_cast<void*>(payload.get()));
      payload.release();
#endif
    } else {
      IOBufWriter::AppendZeroCopy(
          name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
          static_cast<const char*>(payload->ptr()), payload->memory_size(),
          false, SerializeDestroyCallback, static_cast<void*>(payload.get()));
      payload.release();
    }
  }

  if (var->IsType<framework::SelectedRows>()) {
    auto* slr = var->GetMutable<framework::SelectedRows>();
    size_t rows_memory_size =
        slr->rows().size() * framework::SizeOfType(typeid(int64_t));

    IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber,
                        reinterpret_cast<const char*>(slr->rows().data()),
                        static_cast<int64_t>(rows_memory_size));
  }
}

void DeserializeFromIOBuf(const ::sendrecv::VariableMessage& meta,
                          const butil::IOBuf& iobuf,
                          const platform::DeviceContext& ctx,
                          const framework::Scope* scope,
                          framework::Variable** var, int* trainer_id) {
  operators::distributed::BRPCVariableResponse resp(scope, &ctx);
  PADDLE_ENFORCE(resp.Parse(iobuf, meta) == 0, "parse iobuf to tensor error!");
  *var = resp.GetVar();
  *trainer_id = resp.GetTrainerId();
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
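Note: IOBufWriter::Append frames each field as a 4-byte field number, an 8-byte payload length, then the raw bytes; BRPCVariableResponse::Parse (further below) reads the stream back with the matching little-endian reads. Below is a standalone round-trip sketch of that [key | length | value] framing over a std::string buffer. It assumes a little-endian host, as the real code does by copying the integers byte-for-byte; AppendField and ReadField are illustrative names.

#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>

// Append one framed field: 4-byte key, 8-byte length, payload bytes.
void AppendField(std::string* buf, int k, const char* v, int64_t vlen) {
  buf->append(reinterpret_cast<char*>(&k), 4);
  buf->append(reinterpret_cast<char*>(&vlen), 8);
  buf->append(v, vlen);
}

// Read the next field; returns false when the buffer is exhausted.
bool ReadField(const std::string& buf, size_t* pos, int* k, std::string* v) {
  if (*pos + 12 > buf.size()) return false;
  int64_t vlen = 0;
  std::memcpy(k, buf.data() + *pos, 4);
  std::memcpy(&vlen, buf.data() + *pos + 4, 8);
  *pos += 12;
  if (*pos + static_cast<size_t>(vlen) > buf.size()) return false;
  v->assign(buf.data() + *pos, static_cast<size_t>(vlen));
  *pos += static_cast<size_t>(vlen);
  return true;
}

int main() {
  std::string buf;
  const char payload[] = "tensor-bytes";
  AppendField(&buf, /*field number*/ 7, payload, sizeof(payload) - 1);

  size_t pos = 0;
  int key = 0;
  std::string value;
  assert(ReadField(buf, &pos, &key, &value));
  assert(key == 7 && value == "tensor-bytes");
  return 0;
}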
paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h
0 → 100644
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace distributed {

void SerializeToIOBuf(const std::string& name, framework::Variable* var,
                      const platform::DeviceContext& ctx, VarMsg* request,
                      butil::IOBuf* iobuf, const std::string& out_varname,
                      bool var_is_not_stable, const int trainer_id = 0,
                      const std::string& table_name = std::string());

void DeserializeFromIOBuf(const VarMsg& meta, const butil::IOBuf& iobuf,
                          const platform::DeviceContext& ctx,
                          const framework::Scope* scope,
                          framework::Variable** var, int* trainer_id);

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/distributed/brpc_serde_test.cc
0 → 100644
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "brpc/channel.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;

void RunSerdeTestSelectedRows(platform::Place place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& ctx = *pool.Get(place);

  butil::IOBuf iobuf;
  sendrecv::VariableMessage msg;
  int tensor_numel = 564 * 128;

  // serialize var to IOBuf
  {
    framework::Variable var;
    auto* slr = var.GetMutable<framework::SelectedRows>();
    slr->set_height(1000);
    auto* tensor = slr->mutable_value();
    auto* rows = slr->mutable_rows();
    tensor->Resize(framework::make_ddim({564, 128}));
    tensor->mutable_data<float>(place);
    math::set_constant(ctx, tensor, 32.7);
    for (int i = 0; i < 564; ++i) rows->push_back(i);

    operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
                                             "", false);
  }

  // deserialize
  {
    framework::Scope scope;
    scope.Var("myvar");
    operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
    EXPECT_EQ(resp.Parse(iobuf, msg), 0);

    framework::Variable* var2 = resp.GetVar();

    auto* slr2 = var2->GetMutable<framework::SelectedRows>();
    auto* tensor2 = slr2->mutable_value();
    auto* rows2 = slr2->mutable_rows();
    float* tensor_data2 = nullptr;
    framework::Tensor tmp_tensor;

    if (platform::is_gpu_place(ctx.GetPlace())) {
      platform::CPUPlace cpu;
      framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
      tensor_data2 = tmp_tensor.data<float>();
    } else {
      tensor_data2 = const_cast<float*>(tensor2->data<float>());
    }
    const int64_t* rows_data2 = rows2->data();

    for (int i = 0; i < tensor_numel; ++i) {
      EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
    }
    for (size_t i = 0; i < rows2->size(); ++i) {
      EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
    }
    EXPECT_EQ(slr2->height(), 1000);
  }
}

void RunTestLodTensor(platform::Place place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& ctx = *pool.Get(place);

  // serialize var to ByteBuffer
  butil::IOBuf iobuf;
  sendrecv::VariableMessage msg;
  int tensor_numel = 512 * 8 * 4 * 2;

  {
    framework::Variable var;
    auto* tensor = var.GetMutable<framework::LoDTensor>();
    tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
    framework::LoD lod;
    lod.push_back(framework::Vector<size_t>({1, 3, 8}));
    tensor->set_lod(lod);
    tensor->mutable_data<float>(place);
    math::set_constant(ctx, tensor, 31.9);

    operators::distributed::SerializeToIOBuf("myvar", &var, ctx, &msg, &iobuf,
                                             "", false);
  }

  // check sendrecv::VariableMessage meta data
  {
    EXPECT_EQ(msg.varname(), "myvar");
    EXPECT_EQ(msg.type(), 0);
    EXPECT_EQ(msg.dims()[0], 512);
    EXPECT_EQ(msg.dims()[1], 8);
    EXPECT_EQ(msg.dims()[2], 4);
    EXPECT_EQ(msg.dims()[3], 2);
    EXPECT_EQ(msg.lod_level(), 1);
    EXPECT_EQ(msg.lod(0).lod_data(0), 1);
    EXPECT_EQ(msg.lod(0).lod_data(1), 3);
    EXPECT_EQ(msg.lod(0).lod_data(2), 8);
  }

  // deserialize
  {
    framework::Scope scope;
    scope.Var("myvar");
    operators::distributed::BRPCVariableResponse resp(&scope, &ctx);
    EXPECT_EQ(resp.Parse(iobuf, msg), 0);

    framework::Variable* var2 = resp.GetVar();
    auto tensor2 = var2->Get<framework::LoDTensor>();
    float* tensor_data2 = nullptr;
    framework::Tensor tmp_tensor;

    if (platform::is_gpu_place(ctx.GetPlace())) {
      platform::CPUPlace cpu;
      framework::TensorCopy(tensor2, cpu, &tmp_tensor);
      tensor_data2 = tmp_tensor.data<float>();
    } else {
      tensor_data2 = const_cast<float*>(tensor2.data<float>());
    }

    for (int i = 0; i < tensor_numel; ++i)
      EXPECT_FLOAT_EQ(tensor_data2[i], 31.9);
  }
}

TEST(LodTensor, Run) {
  platform::CPUPlace place;
  RunTestLodTensor(place);
#ifdef PADDLE_WITH_CUDA
  platform::CUDAPlace gpu(0);
  RunTestLodTensor(gpu);
#endif
}

TEST(SelectedRows, Run) {
  platform::CPUPlace place;
  RunSerdeTestSelectedRows(place);
#ifdef PADDLE_WITH_CUDA
  platform::CUDAPlace gpu;
  RunSerdeTestSelectedRows(gpu);
#endif
}
paddle/fluid/operators/distributed/brpc_server.cc
...
@@ -13,84 +13,287 @@
 // limitations under the License.

 #include "paddle/fluid/operators/distributed/brpc_server.h"
+#include "paddle/fluid/framework/threadpool.h"
+#include "paddle/fluid/operators/distributed/brpc_sendrecvop_utils.h"
+#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
 #include "paddle/fluid/operators/distributed/request_handler.h"

 namespace sendrecv {

-typedef std::unordered_map<std::string,
-                           paddle::operators::distributed::RequestHandler*>
-    HandlerMap;
+namespace distributed = paddle::operators::distributed;
+
+typedef std::unordered_map<std::string, distributed::RequestHandler*>
+    HandlerMap;

 class BRPCServiceImpl : public SendRecvService {
  public:
-  explicit BRPCServiceImpl(const HandlerMap& rpc_call_map)
-      : request_send_h_(nullptr),
-        request_get_h_(nullptr),
-        request_prefetch_h_(nullptr) {
-    auto it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
+  explicit BRPCServiceImpl(const HandlerMap& rpc_call_map,
+                           distributed::RPCServer* rpc_server)
+      : rpc_server_(rpc_server) {
+    VLOG(3) << "BRPCServiceImpl size: " << rpc_call_map.size();
+    auto it = rpc_call_map.find(distributed::kRequestSend);
     if (it != rpc_call_map.end()) {
       request_send_h_ = it->second;
+      send_threads_.reset(new paddle::framework::ThreadPool(
+          rpc_server_->GetThreadNum(distributed::kRequestSend)));
     }

-    it = rpc_call_map.find(paddle::operators::distributed::kRequestSend);
+    it = rpc_call_map.find(distributed::kRequestGet);
     if (it != rpc_call_map.end()) {
       request_get_h_ = it->second;
+      get_threads_.reset(new paddle::framework::ThreadPool(
+          rpc_server_->GetThreadNum(distributed::kRequestGet)));
     }

-    it = rpc_call_map.find(paddle::operators::distributed::kRequestPrefetch);
+    it = rpc_call_map.find(distributed::kRequestPrefetch);
     if (it != rpc_call_map.end()) {
       request_prefetch_h_ = it->second;
+      prefetch_threads_.reset(new paddle::framework::ThreadPool(
+          rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
     }
+
+    it = rpc_call_map.find(distributed::kRequestCheckpoint);
+    if (it != rpc_call_map.end()) {
+      request_checkpoint_h_ = it->second;
+      checkpoint_notify_threads_.reset(new paddle::framework::ThreadPool(
+          rpc_server_->GetThreadNum(distributed::kRequestPrefetch)));
+    }
+
+    it = rpc_call_map.find(distributed::kRequestGetMonomerVariable);
+    if (it != rpc_call_map.end()) {
+      request_get_monomer_handler_h_ = it->second;
+    }
+
+    it = rpc_call_map.find(distributed::kRequestGetMonomerBarrier);
+    if (it != rpc_call_map.end()) {
+      request_get_monomer_barrier_handler_h_ = it->second;
+    }
   }

   virtual ~BRPCServiceImpl() {}

   void SendVariable(google::protobuf::RpcController* cntl_butil,
                     const VariableMessage* request, VoidMessage* response,
                     google::protobuf::Closure* done) override {
+    send_threads_->Run(
+        [=] { _SendVariable(cntl_butil, request, response, done); });
+  }
+
+  void _SendVariable(google::protobuf::RpcController* cntl_butil,
+                     const VariableMessage* request, VoidMessage* response,
+                     google::protobuf::Closure* done) {
     PADDLE_ENFORCE(request_send_h_ != nullptr,
                    "RequestSend handler should be registed first!");
     brpc::ClosureGuard done_guard(done);
-
-    paddle::framework::Scope* local_scope = request_send_h_->scope();
-    paddle::framework::Variable* outvar = nullptr;
-    paddle::framework::Variable* invar = nullptr;
+    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);

     std::string varname = request->varname();
+    VLOG(3) << "RequestSend var_name:" << varname
+            << ", trainer_id:" << request->trainer_id()
+            << ", from:" << cntl->remote_side();

-    if (!request_send_h_->sync_mode()) {
-      local_scope = &request_send_h_->scope()->NewScope();
-      invar = local_scope->Var(varname);
-    } else {
-      invar = local_scope->FindVar(varname);
-    }
+    distributed::BRPCVariableResponse resp(request_send_h_->scope(),
+                                           request_send_h_->dev_ctx(),
+                                           !request_send_h_->sync_mode());
+    PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
+                   "parse iobuf to tensor error!");

-    request_send_h_->Handle(varname, local_scope, invar, &outvar);
+    auto scope = resp.GetMutableLocalScope();
+    auto invar = resp.GetVar();
+    int trainer_id = request->trainer_id();
+    paddle::framework::Variable* outvar = nullptr;

-    if (!request_send_h_->sync_mode()) {
-      request_send_h_->scope()->DeleteScope(local_scope);
-    }
+    request_send_h_->Handle(varname, scope, invar, &outvar, trainer_id);
   }

   void GetVariable(google::protobuf::RpcController* cntl_butil,
                    const VariableMessage* request, VariableMessage* response,
                    google::protobuf::Closure* done) override {
+    get_threads_->Run(
+        [=] { _GetVariable(cntl_butil, request, response, done); });
+  }
+
+  void _GetVariable(google::protobuf::RpcController* cntl_butil,
+                    const VariableMessage* request, VariableMessage* response,
+                    google::protobuf::Closure* done) {
     PADDLE_ENFORCE(request_get_h_ != nullptr,
                    "RequestGet handler should be registed first!");
-  }
+
+    brpc::ClosureGuard done_guard(done);
+    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
+
+    std::string varname = request->varname();
+    VLOG(3) << "RequestGet varname:" << varname
+            << ", trainer_id:" << request->trainer_id()
+            << ", from:" << cntl->remote_side();
+
+    auto scope = request_get_h_->scope();
+    auto invar = scope->FindVar(varname);
+    int trainer_id = request->trainer_id();
+    paddle::framework::Variable* outvar = nullptr;
+
+    request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id);
+
+    if (outvar) {
+      distributed::SerializeToIOBuf(varname, outvar,
+                                    *request_get_h_->dev_ctx(), response,
+                                    &cntl->response_attachment(), "", false);
+    }
+  }

   void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
                         const VariableMessage* request,
                         VariableMessage* response,
                         google::protobuf::Closure* done) override {
+    prefetch_threads_->Run(
+        [=] { _PrefetchVariable(cntl_butil, request, response, done); });
+  }
+
+  void _PrefetchVariable(google::protobuf::RpcController* cntl_butil,
+                         const VariableMessage* request,
+                         VariableMessage* response,
+                         google::protobuf::Closure* done) {
     PADDLE_ENFORCE(request_prefetch_h_ != nullptr,
                    "kRequestPrefetch handler should be registed first!");
+
+    brpc::ClosureGuard done_guard(done);
+    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
+
+    // prefetch process...
+    std::string in_var_name = request->varname();
+    std::string out_var_name = request->out_varname();
+    VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
+            << ", out_var_name: " << out_var_name
+            << ", trainer_id:" << request->trainer_id()
+            << ", from:" << cntl->remote_side();
+
+    distributed::BRPCVariableResponse resp(
+        request_prefetch_h_->scope(), request_prefetch_h_->dev_ctx(), true);
+
+    PADDLE_ENFORCE(resp.Parse(cntl->request_attachment(), *request) == 0,
+                   "parse iobuf to tensor error!");
+
+    auto scope = resp.GetMutableLocalScope();
+    auto invar = scope->FindVar(in_var_name);
+    std::string table_name = request->table_name();
+    int trainer_id = request->trainer_id();
+
+    paddle::framework::Variable* outvar = scope->Var(out_var_name);
+    request_prefetch_h_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
+                                out_var_name, table_name);
+
+    distributed::SerializeToIOBuf(out_var_name, outvar,
+                                  *request_prefetch_h_->dev_ctx(), response,
+                                  &cntl->response_attachment(), "", true);
+  }
+
+  void CheckpointNotify(google::protobuf::RpcController* cntl_butil,
+                        const VariableMessage* request, VoidMessage* response,
+                        google::protobuf::Closure* done) override {
+    checkpoint_notify_threads_->Run(
+        [=] { _CheckpointNotify(cntl_butil, request, response, done); });
+  }
+
+  void _CheckpointNotify(google::protobuf::RpcController* cntl_butil,
+                         const VariableMessage* request, VoidMessage* response,
+                         google::protobuf::Closure* done) {
+    PADDLE_ENFORCE(
+        request_checkpoint_h_ != nullptr,
+        "kRequestCheckpointNotify handler should be registed first!");
+
+    brpc::ClosureGuard done_guard(done);
+    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
+
+    distributed::BRPCVariableResponse resp(request_checkpoint_h_->scope(),
+                                           request_checkpoint_h_->dev_ctx());
+
+    auto scope = resp.GetMutableLocalScope();
+
+    std::string checkpoint_notify = request->varname();
+    std::string checkpoint_dir = request->out_varname();
+    int trainer_id = request->trainer_id();
+
+    VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
+            << ", dir: " << checkpoint_dir
+            << ", trainer_id:" << request->trainer_id()
+            << ", from:" << cntl->remote_side();
+
+    request_checkpoint_h_->Handle(checkpoint_notify, scope, nullptr, nullptr,
+                                  trainer_id, checkpoint_dir);
+  }
+
+  void GetMonomerVariable(google::protobuf::RpcController* cntl_butil,
+                          const VariableMessage* request,
+                          VariableMessage* response,
+                          google::protobuf::Closure* done) override {
+    PADDLE_ENFORCE(
+        request_get_monomer_handler_h_ != nullptr,
+        "kRequestGetMonomerVariable handler should be registed first!");
+
+    brpc::ClosureGuard done_guard(done);
+    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
+
+    // proc request.
+    std::string varname = request->varname();
+    VLOG(3) << "GetMonomerVariable " << varname
+            << ", trainer_id:" << request->trainer_id()
+            << ", from:" << cntl->remote_side();
+
+    rpc_server_->WaitVarCond(varname);
+    distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
+
+    auto scope = h.scope_;
+    auto invar = scope->FindVar(varname);
+    paddle::framework::Variable* outvar = nullptr;
+
+    request_get_monomer_handler_h_->Handle(varname, scope, invar, &outvar,
+                                           request->trainer_id());
+
+    if (outvar) {
+      distributed::SerializeToIOBuf(varname, outvar, *h.dev_ctx_, response,
+                                    &cntl->response_attachment(), "", false);
+    }
+  }
+
+  void GetMonomerBarrier(google::protobuf::RpcController* cntl_butil,
+                         const VariableMessage* request, VoidMessage* response,
+                         google::protobuf::Closure* done) override {
+    PADDLE_ENFORCE(
+        request_get_monomer_barrier_handler_h_ != nullptr,
+        "RequestGetMonomerBarrier handler should be registed first!");
+
+    brpc::ClosureGuard done_guard(done);
+    brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
+
+    std::string varname = request->varname();
+    VLOG(3) << "RequestGetMonomerBarrier var_name:" << varname
+            << ", trainer_id:" << request->trainer_id()
+            << ", from:" << cntl->remote_side();
+
+    rpc_server_->WaitVarCond(varname);
+    distributed::MonomerHandle h = rpc_server_->GetMonomer(varname);
+
+    paddle::framework::Scope* scope = nullptr;
+    paddle::framework::Variable* invar = nullptr;
+    paddle::framework::Variable* outvar = nullptr;
+
+    request_get_monomer_barrier_handler_h_->Handle(
+        varname, scope, invar, &outvar, request->trainer_id());
+  }

  private:
-  paddle::operators::distributed::RequestHandler* request_send_h_;
-  paddle::operators::distributed::RequestHandler* request_get_h_;
-  paddle::operators::distributed::RequestHandler* request_prefetch_h_;
+  distributed::RequestHandler* request_send_h_{nullptr};
+  distributed::RequestHandler* request_get_h_{nullptr};
+  distributed::RequestHandler* request_prefetch_h_{nullptr};
+  distributed::RequestHandler* request_checkpoint_h_{nullptr};
+  distributed::RequestHandler* request_get_monomer_handler_h_{nullptr};
+  distributed::RequestHandler* request_get_monomer_barrier_handler_h_{nullptr};
+
+  distributed::RPCServer* rpc_server_{nullptr};
+
+  // FIXME(gongwb): brpc should support process one rpc use one threadpool.
+  std::unique_ptr<paddle::framework::ThreadPool> send_threads_;
+  std::unique_ptr<paddle::framework::ThreadPool> get_threads_;
+  std::unique_ptr<paddle::framework::ThreadPool> prefetch_threads_;
+  std::unique_ptr<paddle::framework::ThreadPool> checkpoint_notify_threads_;
 };
 }  // namespace sendrecv
...
@@ -100,7 +303,7 @@ namespace distributed {
 void AsyncBRPCServer::StartServer() {
   // Instance of your service.
-  sendrecv::BRPCServiceImpl service_impl(rpc_call_map_);
+  sendrecv::BRPCServiceImpl service_impl(rpc_call_map_, this);

   // Add the service into server. Notice the second parameter, because the
   // service is put on stack, we don't want server to delete it, otherwise
...
@@ -111,6 +314,9 @@ void AsyncBRPCServer::StartServer() {
   }

   brpc::ServerOptions options;
+#ifdef PADDLE_WITH_BRPC_RDMA
+  options.use_rdma = true;
+#endif
   options.idle_timeout_sec = idle_timeout_s_;
   options.max_concurrency = max_concurrency_;
   if (server_.Start(bind_address_.c_str(), &options) != 0) {
...
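Note: each public RPC entry point in BRPCServiceImpl only enqueues the real body onto a per-method thread pool through a by-value [=] lambda, so the brpc event-loop thread is never blocked by Handle(). Below is a standalone sketch of that dispatch shape using std::async in place of paddle::framework::ThreadPool; the Service and Request names are illustrative, not from the Paddle sources.

#include <future>
#include <iostream>
#include <string>
#include <vector>

struct Request {
  std::string varname;
};

class Service {
 public:
  // Public entry point: capture the argument by value and hand the heavy
  // body to a worker, mirroring send_threads_->Run([=] { _SendVariable(...); });
  std::future<void> SendVariable(Request req) {
    return std::async(std::launch::async, [=] { _SendVariable(req); });
  }

 private:
  void _SendVariable(const Request& req) {
    // ... parse attachment, find the variable, call the handler ...
    std::cout << "handled " << req.varname << "\n";
  }
};

int main() {
  Service svc;
  std::vector<std::future<void>> pending;
  for (int i = 0; i < 3; ++i) {
    pending.push_back(svc.SendVariable(Request{"var" + std::to_string(i)}));
  }
  for (auto& f : pending) f.wait();  // completion, akin to ClosureGuard firing
  return 0;
}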
paddle/fluid/operators/distributed/brpc_variable_response.cc
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "paddle/fluid/operators/distributed/brpc_variable_response.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace distributed {

namespace pb = ::google::protobuf;
using vr = ::sendrecv::VariableMessage;

int BRPCVariableResponse::Parse(Source* source) {
  pb::io::ZeroCopyInputStream* input_stream = source->contents();
  pb::io::CodedInputStream input(input_stream);
  input.SetTotalBytesLimit(INT_MAX, INT_MAX);

  while (1) {
    unsigned int tag = 0;
    if (!input.ReadLittleEndian32(&tag)) {
      break;
    }

    uint64_t num_bytes = 0;
    if (!input.ReadLittleEndian64(&num_bytes)) {
      break;
    }

    int field = static_cast<int>(tag);
    int ret = field == 0 ? -1 : field;
    switch (field) {
      case vr::kSerializedFieldNumber: {
        if (!ProcSerializedField(field, &input, num_bytes)) {
          return ret;
        }
        break;
      }
      case vr::kRowsFieldNumber: {
        PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
                        meta_.type() == sendrecv::LOD_TENSOR) &&
                           meta_.varname() != "",
                       "meta info should be got first!");

        if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
          return ret;
        }
        break;
      }
      default: {
        PADDLE_ENFORCE(false, "not surpported %u fieldnumber", field);
        return ret;
      }
    }
  }

  return 0;
}
}  // namespace distributed
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/distributed/brpc_variable_response.h
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "brpc/channel.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace paddle {
namespace operators {
namespace distributed {

class BRPCSourceWrapper : public Source {
 public:
  explicit BRPCSourceWrapper(const butil::IOBuf& iobuf) : source_(iobuf) {}
  ::google::protobuf::io::ZeroCopyInputStream* contents() override {
    return &source_;
  }

 private:
  butil::IOBufAsZeroCopyInputStream source_;
};

class BRPCVariableResponse : public VariableResponse {
 public:
  BRPCVariableResponse(const framework::Scope* scope,
                       const platform::DeviceContext* dev_ctx,
                       bool create_scope = false)
      : VariableResponse(scope, dev_ctx, create_scope) {}

  virtual ~BRPCVariableResponse() {}

  // parse attachment from iobuf
  int Parse(Source* source) override;

  int Parse(const butil::IOBuf& iobuf, const sendrecv::VariableMessage& meta) {
    BRPCSourceWrapper wrapper(iobuf);
    return VariableResponse::Parse(&wrapper, meta);
  }
};

};  // namespace distributed
};  // namespace operators
};  // namespace paddle
paddle/fluid/operators/distributed/grpc_client.cc
...
@@ -293,8 +293,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
   const auto ch = GetChannel(ep);
   BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
   const std::string method = "SendMonomerFetchBarrierRPC";
-  VarHandlePtr h(
-      new VarHandle(ep, method, var_name, nullptr, nullptr));
+  VarHandlePtr h(
+      new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
   s->Prepare(h, time_out);

   VLOG(30) << s->GetVarHandlePtr()->String() << " begin";
...
paddle/fluid/operators/distributed/grpc_serde.cc
...
@@ -32,13 +32,6 @@ namespace paddle {
 namespace operators {
 namespace distributed {

-static void SerializeDestroyCallback(void* payload) {
-  if (payload != nullptr) {
-    auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
-    delete shared_payload;
-  }
-}
-
 void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                            const platform::DeviceContext& ctx,
                            ::grpc::ByteBuffer* msg, const std::string& out_name,
...
@@ -122,8 +115,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   if (var->IsType<framework::SelectedRows>()) {
     auto* slr = var->GetMutable<framework::SelectedRows>();
     ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
-    size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
+    size_t rows_memory_size =
+        slr->rows().size() * framework::SizeOfType(typeid(int64_t));
     e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
     slices[2] = ::grpc::Slice(e2.size());
     memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
...
paddle/fluid/operators/distributed/rpc_server.h
...
@@ -75,6 +75,10 @@ class RPCServer {
   void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
                    int thread_num = 5);

+  int GetThreadNum(const std::string& rpc_name) {
+    return rpc_thread_num_[rpc_name];
+  }
+
   // Wait util all the clients have reached the barrier for one
   // rpc method. This function should be called in the
   // RequestHandler if you want to run the server/client in a
...
paddle/fluid/operators/distributed/sendrecvop_utils.cc
...
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <thread>  // NOLINT

 #include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/operators/distributed/brpc_rdma_pool.h"
 #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
 #include "paddle/fluid/operators/distributed/variable_response.h"
 #include "paddle/fluid/platform/port.h"
...
@@ -45,7 +46,6 @@ static TensorPayload GetCommunicationAllocationFromTensor(
     memory::Copy(cuda_pinned, result->ptr(),
                  boost::get<platform::CUDAPlace>(tensor.place()),
                  tensor.data<void>(), copy_size, gpu_dev_ctx.stream());

     ctx.Wait();
     return TensorPayload(result);
 #else
...
@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var,
   auto tensor = var->Get<framework::LoDTensor>();
   // FIXME(wuyi): data types in send_recv.proto is copied from
   // framework.proto
-  request->set_data_type(
-      static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
+  request->set_data_type(static_cast<VarMsg::Type>(tensor.type()));
   for (auto& dim : framework::vectorize(tensor.dims())) {
     request->add_dims(dim);
   }
...
@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
                                      const platform::DeviceContext& ctx,
                                      VarMsg* request) {
   auto* slr = var->GetMutable<framework::SelectedRows>();
-  request->set_data_type(
-      static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
+  request->set_data_type(static_cast<VarMsg::Type>(slr->value().type()));
   request->set_lod_level(0);
   request->set_slr_height(slr->height());
...
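Note: the recurring "-framework::ToDataType(x.type())" to "+x.type()" edits in this and the operator files below follow from the framework-side change in this merge: the tensor now reports its data type as the proto enum directly instead of a std::type_index that had to be mapped at every call site. Below is a tiny standalone sketch of that before/after shape; Type, OldTensor, NewTensor, and ToDataType are illustrative stand-ins, not the Paddle definitions.

#include <cassert>
#include <typeindex>

enum class Type { FP32, INT64 };

// Old shape: the tensor exposed a std::type_index, so every call site
// translated it through a helper like ToDataType().
Type ToDataType(std::type_index t) {
  return t == typeid(float) ? Type::FP32 : Type::INT64;
}

struct OldTensor {
  std::type_index type() const { return typeid(float); }
};

// New shape: the tensor stores the enum itself; call sites use it directly.
struct NewTensor {
  Type type() const { return Type::FP32; }
};

int main() {
  OldTensor a;
  NewTensor b;
  assert(ToDataType(a.type()) == b.type());  // same information, one hop less
  return 0;
}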
paddle/fluid/operators/distributed/sendrecvop_utils.h
...
@@ -50,6 +50,13 @@ class TensorPayload final {
   size_t memory_size_;
 };

+inline void SerializeDestroyCallback(void* payload) {
+  if (payload != nullptr) {
+    auto* shared_payload = reinterpret_cast<TensorPayload*>(payload);
+    delete shared_payload;
+  }
+}
+
 TensorPayload GetTensorPayload(framework::Variable* var,
                                const platform::DeviceContext& ctx,
                                VarMsg* request);
...
@@ -58,18 +65,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
                                      const platform::DeviceContext& ctx,
                                      VarMsg* request);

-inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
+inline framework::proto::VarType::Type ToVarType(
+    sendrecv::VariableMessage::Type type) {
   switch (type) {
     case sendrecv::VariableMessage::FP32:
-      return typeid(float);  // NOLINT
+      return framework::proto::VarType::FP32;  // NOLINT
     case sendrecv::VariableMessage::FP64:
-      return typeid(double);  // NOLINT
+      return framework::proto::VarType::FP64;  // NOLINT
     case sendrecv::VariableMessage::INT32:
-      return typeid(int);  // NOLINT
+      return framework::proto::VarType::INT32;  // NOLINT
     case sendrecv::VariableMessage::INT64:
-      return typeid(int64_t);  // NOLINT
+      return framework::proto::VarType::INT64;  // NOLINT
     case sendrecv::VariableMessage::BOOL:
-      return typeid(bool);  // NOLINT
+      return framework::proto::VarType::BOOL;  // NOLINT
     default:
       PADDLE_THROW("Not support type %d", type);
   }
 }
...
paddle/fluid/operators/distributed/variable_response.cc
...
@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
   tensor->set_lod(lod);

   void* tensor_data =
-      tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
+      tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type()));

   VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
           << ", Buffer Size = " << length;
...
@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData(
   slr->set_height(meta_.slr_height());
   auto* tensor = slr->mutable_value();
   tensor->Resize(dims);
-  PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()),
-                    length / framework::SizeOfType(
-                                 paddle::operators::distributed::ToTypeIndex(
-                                     meta_.data_type())));
+  PADDLE_ENFORCE_EQ(
+      static_cast<size_t>(tensor->numel()),
+      length / framework::SizeOfType(
+                   paddle::operators::distributed::ToVarType(
+                       meta_.data_type())));
   void* tensor_data = tensor->mutable_data(
       ctx.GetPlace(),
-      paddle::operators::distributed::ToTypeIndex(meta_.data_type()));
+      paddle::operators::distributed::ToVarType(meta_.data_type()));

   if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
     return false;
...
@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData(
     const platform::DeviceContext& ctx, int length) {
   auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
   slr->mutable_rows()->clear();
-  slr->mutable_rows()->resize(length / sizeof(int64_t));  // int64
+  slr->mutable_rows()->resize(length /
+                              framework::SizeOfType(typeid(int64_t)));  // int64
   int64_t* rows_data = slr->mutable_rows()->data();

   // copy rows CPU data, GPU data will be copied lazily.
...
paddle/fluid/operators/distributed_ops/CMakeLists.txt
...
@@ -2,9 +2,9 @@ include(operators)
 set(DISTRIBUTE_DEPS "")
 if(WITH_GRPC)
-    set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
+    set(DISTRIBUTE_DEPS sendrecvop_rpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
 else()
-    set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
+    set(DISTRIBUTE_DEPS sendrecvop_rpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
     if(WITH_BRPC_RDMA)
         find_library(IBVERBS_LIBRARY NAMES ibverbs)
         ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
...
@@ -26,10 +26,11 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32
(
rpc_send_thread_num
,
5
,
"number of threads for rpc send"
);
DEFINE_int32
(
rpc_send_thread_num
,
12
,
"number of threads for rpc send"
);
DEFINE_int32
(
rpc_get_thread_num
,
5
,
"number of threads for rpc get"
);
DEFINE_int32
(
rpc_get_thread_num
,
12
,
"number of threads for rpc get"
);
DEFINE_int32
(
rpc_prefetch_thread_num
,
5
,
"number of threads for rpc prefetch"
);
DEFINE_int32
(
rpc_prefetch_thread_num
,
12
,
"number of threads for rpc prefetch"
);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
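These are gflags definitions, so the change raises the default worker-thread count of each RPC service from 5 to 12 while still letting a deployment override it at launch. A minimal standalone example of the same DEFINE_int32 pattern; the flag name mirrors the one above, but the surrounding program is hypothetical:

    #include <cstdio>
    #include <gflags/gflags.h>

    // Same pattern as listen_and_serv_op.cc: name, default value, help text.
    DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send");

    int main(int argc, char* argv[]) {
      // Accepts e.g. --rpc_send_thread_num=24 and updates the FLAGS_ global.
      gflags::ParseCommandLineFlags(&argc, &argv, true);
      std::printf("rpc send threads: %d\n", FLAGS_rpc_send_thread_num);
      return 0;
    }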
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
...
@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.MultiInput<framework::Tensor>("X").front()->type()),
-        ctx.GetPlace());
+        ctx.MultiInput<framework::Tensor>("X").front()->type(), ctx.GetPlace());
   }
 };
...
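This hunk and the next one apply the same mechanical cleanup: Tensor::type() now returns the proto::VarType::Type enum directly, so the framework::ToDataType(...) conversion wrapper is dropped when building the kernel key. A toy before/after model follows; all names here are illustrative stand-ins, not Paddle's real classes.

    #include <cstdio>

    enum class VarType { FP32, INT64 };  // stand-in for proto::VarType::Type

    struct Tensor {
      VarType dtype;
      // New style: the tensor reports its dtype as the enum itself, so no
      // ToDataType(std::type_index) translation step is needed.
      VarType type() const { return dtype; }
    };

    int main() {
      Tensor x{VarType::FP32};
      VarType kernel_key = x.type();  // old style: ToDataType(x.type())
      std::printf("kernel key = %d\n", static_cast<int>(kernel_key));
      return 0;
    }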
paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc
...
@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.MultiInput<framework::Tensor>("X")[0]->type()),
-        ctx.GetPlace());
+        ctx.MultiInput<framework::Tensor>("X")[0]->type(), ctx.GetPlace());
   }
 };
...
paddle/fluid/operators/distributed_ops/send_op.cc
...
@@ -58,7 +58,9 @@ class SendOp : public framework::OperatorBase {
     }
     if (sync_send) {
       for (size_t i = 0; i < rets.size(); i++) {
+        VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
         PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
+        VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
       }
     }
   }
...
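The added traces use glog's VLOG, which prints only when the active verbosity is at least the given level, so these level-7 messages are silent by default and cheap in production. A minimal standalone example, assuming glog is available; run with GLOG_v=7 to see the output:

    #include <glog/logging.h>

    int main(int argc, char* argv[]) {
      google::InitGoogleLogging(argv[0]);
      FLAGS_logtostderr = true;  // send log output to stderr, not files
      // Emitted only when verbosity >= 7 (e.g. GLOG_v=7 ./a.out).
      VLOG(7) << "before sync_send";
      VLOG(7) << "after sync_send";
      return 0;
    }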
The diffs for the remaining files changed in this commit are collapsed in this view:

paddle/fluid/operators/elementwise/elementwise_op.h
paddle/fluid/operators/fake_quantize_op.cc
paddle/fluid/operators/fc_op.cc
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_op.cc
paddle/fluid/operators/fused/fused_elemwise_activation_op.cc
paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
paddle/fluid/operators/fused/fusion_gru_op.cc
paddle/fluid/operators/fused/fusion_lstm_op.cc
paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
paddle/fluid/operators/gather_op.cc
paddle/fluid/operators/grid_sampler_op.cc
paddle/fluid/operators/group_norm_op.cc
paddle/fluid/operators/hierarchical_sigmoid_op.cc
paddle/fluid/operators/interpolate_op.cc
paddle/fluid/operators/is_empty_op.cc
paddle/fluid/operators/isfinite_op.cc
paddle/fluid/operators/layer_norm_op.cc
paddle/fluid/operators/linear_chain_crf_op.cc
paddle/fluid/operators/load_combine_op.cc
paddle/fluid/operators/load_op.cc
paddle/fluid/operators/lod_reset_op.cc
paddle/fluid/operators/lod_tensor_to_array_op.cc
paddle/fluid/operators/lookup_sparse_table_op.cc
paddle/fluid/operators/lrn_op.cc
paddle/fluid/operators/lstm_op.cc
paddle/fluid/operators/lstmp_op.cc
paddle/fluid/operators/math/math_function.cc
paddle/fluid/operators/math/math_function.cu
paddle/fluid/operators/math/pooling.cc
paddle/fluid/operators/math/pooling.cu
paddle/fluid/operators/math/pooling.h
paddle/fluid/operators/mean_iou_op.cc
paddle/fluid/operators/mean_op.cc
paddle/fluid/operators/merge_lod_tensor_op.cc
paddle/fluid/operators/metrics/accuracy_op.cc
paddle/fluid/operators/metrics/auc_op.cc
paddle/fluid/operators/metrics/precision_recall_op.cc
paddle/fluid/operators/multiplex_op.cc
paddle/fluid/operators/nce_op.cc
paddle/fluid/operators/optimizers/adadelta_op.cc
paddle/fluid/operators/optimizers/adagrad_op.cc
paddle/fluid/operators/optimizers/adam_op.cc
paddle/fluid/operators/optimizers/adamax_op.cc
paddle/fluid/operators/optimizers/decayed_adagrad_op.cc
paddle/fluid/operators/optimizers/ftrl_op.cc
paddle/fluid/operators/optimizers/proximal_adagrad_op.cc
paddle/fluid/operators/optimizers/proximal_gd_op.cc
paddle/fluid/operators/pad2d_op.cc
paddle/fluid/operators/pad_constant_like_op.cc
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.h
paddle/fluid/operators/pool_with_index_op.cc
paddle/fluid/operators/pool_with_index_op.h
paddle/fluid/operators/positive_negative_pair_op.cc
paddle/fluid/operators/prelu_op.cc
paddle/fluid/operators/print_op.cc
paddle/fluid/operators/psroi_pool_op.cc
paddle/fluid/operators/random_crop_op.cc
paddle/fluid/operators/reader/create_batch_reader_op.cc
paddle/fluid/operators/recurrent_op.cc
paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/rnn_memory_helper_op.cc
paddle/fluid/operators/roi_align_op.cc
paddle/fluid/operators/roi_pool_op.cc
paddle/fluid/operators/save_combine_op.cc
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/scatter_op.cc
paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
paddle/fluid/operators/sequence_ops/sequence_slice_op.cc
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
paddle/fluid/operators/similarity_focus_op.cc
paddle/fluid/operators/slice_op.cc
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
paddle/fluid/operators/spp_op.h
paddle/fluid/operators/sum_op.cc
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/transpose_op.cc
paddle/fluid/operators/unpool_op.cc
paddle/fluid/operators/warpctc_op.cc
paddle/fluid/operators/yolov3_loss_op.cc
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/nccl_helper.h
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/tensor_py.h
paddle/scripts/paddle_build.sh
python/paddle/fluid/__init__.py
python/paddle/fluid/layers/nn.py
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/ngraph/CMakeLists.txt (new file, 0 → 100644)
python/paddle/fluid/tests/unittests/ngraph/__init__.py (new file, 0 → 100644)
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_pool2d_op.py
python/paddle/fluid/tests/unittests/test_pool3d_op.py
python/paddle/fluid/tests/unittests/test_pool_max_op.py