Commit 4d256991 authored by Haihao Shen

Merge from paddle public dev and fix a bug for weight channel 0 in bias

@@ -50,11 +50,7 @@ if(NOT WITH_PROFILER)
 endif(NOT WITH_PROFILER)
 if(NOT CMAKE_CROSSCOMPILING)
-  if(WITH_AVX AND AVX512F_FOUND)
-    set(SIMD_FLAG ${AVX512F_FLAG})
-  elseif(WITH_AVX AND AVX2_FOUND)
-    set(SIMD_FLAG ${AVX2_FLAG})
-  elseif(WITH_AVX AND AVX_FOUND)
+  if(WITH_AVX AND AVX_FOUND)
     set(SIMD_FLAG ${AVX_FLAG})
   elseif(SSE3_FOUND)
     set(SIMD_FLAG ${SSE3_FLAG})
......
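With the build-wide SIMD flag capped at plain AVX, any AVX-512 kernels have to be gated at runtime instead of at configure time. A minimal sketch of such a gate, assuming GCC/Clang's `__builtin_cpu_supports` builtin (illustrative only, not Paddle's actual dispatch mechanism):

```cpp
#include <cstdio>

// Stand-ins for real kernels. In practice the AVX-512 body would live in a
// separate translation unit compiled with -mavx512f.
static void run_avx512_kernel() { std::puts("AVX-512 path"); }
static void run_generic_kernel() { std::puts("generic path"); }

int main() {
  // __builtin_cpu_supports queries CPUID at runtime (GCC/Clang builtin),
  // so the binary stays safe on CPUs that lack AVX512F.
  if (__builtin_cpu_supports("avx512f")) {
    run_avx512_kernel();
  } else {
    run_generic_kernel();
  }
  return 0;
}
```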
@@ -45,7 +45,7 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
 ELSE()
   MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN")
 ENDIF()
-SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result")
+SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds")
 SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
 SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
 SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
@@ -54,7 +54,7 @@ ExternalProject_Add(
     ${EXTERNAL_PROJECT_LOG_ARGS}
     DEPENDS ${MKLDNN_DEPENDS}
     GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
-    GIT_TAG "64e03a1939e0d526aa8e9f2e3f7dc0ad8d372944"
+    GIT_TAG "21fb5f2af1dd14e132af4f1b79160977ee487818"
     PREFIX ${MKLDNN_SOURCES_DIR}
     UPDATE_COMMAND ""
     CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
@@ -89,7 +89,9 @@ CHECK_CXX_SOURCE_RUNS("
 #include <immintrin.h>
 int main()
 {
-    __m512i a = _mm512_undefined_epi32();
+    __m512i a = _mm512_set_epi32 (-1, 2, -3, 4, -1, 2, -3, 4,
+                                  13, -5, 6, -7, 9, 2, -6, 3);
+    __m512i result = _mm512_abs_epi32 (a);
     return 0;
 }" AVX512F_FOUND)
......
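The probe now executes a real AVX-512F instruction (`_mm512_abs_epi32`) instead of the compile-time-only `_mm512_undefined_epi32`, so `AVX512F_FOUND` is set only when the instruction actually runs on the build machine. A standalone version of the same probe (compile with `-mavx512f`; on a CPU without AVX512F it dies with SIGILL, which is exactly the failure `CHECK_CXX_SOURCE_RUNS` detects):

```cpp
#include <immintrin.h>

int main() {
  // 16 lanes of 32-bit integers; _mm512_abs_epi32 is a genuine AVX-512F op.
  __m512i a = _mm512_set_epi32(-1, 2, -3, 4, -1, 2, -3, 4,
                               13, -5, 6, -7, 9, 2, -6, 3);
  __m512i result = _mm512_abs_epi32(a);
  (void)result;  // silence the unused-variable warning
  return 0;
}
```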
@@ -118,9 +118,10 @@ paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon',
 paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0))
 paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None))
 paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,))
-paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR'))
+paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None))
 paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))
-paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
+paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None))
+paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None))
 paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
@@ -174,9 +175,11 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
 paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
 paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
+paddle.fluid.layers.similarity_focus ArgSpec(args=['input', 'axis', 'indexes', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
 paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
@@ -189,6 +192,7 @@ paddle.fluid.layers.batch ArgSpec(args=['reader', 'batch_size'], varargs=None, k
 paddle.fluid.layers.double_buffer ArgSpec(args=['reader', 'place', 'name'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.random_data_generator ArgSpec(args=['low', 'high', 'shapes', 'lod_levels', 'for_parallel'], varargs=None, keywords=None, defaults=(True,))
 paddle.fluid.layers.py_reader ArgSpec(args=['capacity', 'shapes', 'dtypes', 'lod_levels', 'name', 'use_double_buffer'], varargs=None, keywords=None, defaults=(None, None, True))
+paddle.fluid.layers.create_py_reader_by_data ArgSpec(args=['capacity', 'feed_list', 'name', 'use_double_buffer'], varargs=None, keywords=None, defaults=(None, True))
 paddle.fluid.layers.Preprocessor.__init__ ArgSpec(args=['self', 'reader', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.Preprocessor.block ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
 paddle.fluid.layers.Preprocessor.inputs ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
@@ -198,6 +202,7 @@ paddle.fluid.layers.create_tensor ArgSpec(args=['dtype', 'name', 'persistable'],
 paddle.fluid.layers.create_parameter ArgSpec(args=['shape', 'dtype', 'name', 'attr', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(None, None, False, None))
 paddle.fluid.layers.create_global_var ArgSpec(args=['shape', 'value', 'dtype', 'persistable', 'force_cpu', 'name'], varargs=None, keywords=None, defaults=(False, False, None))
 paddle.fluid.layers.cast ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.layers.tensor_array_to_tensor ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
 paddle.fluid.layers.concat ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(0, None))
 paddle.fluid.layers.sums ArgSpec(args=['input', 'out'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.assign ArgSpec(args=['input', 'output'], varargs=None, keywords=None, defaults=(None,))
......
@@ -18,8 +18,8 @@ namespace framework {
 void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
                      Tensor *out) {
-  VLOG(3) << "DeviceTransform in, src_place " << in.place()
-          << " dst_place: " << dst_place;
+  VLOG(30) << "DeviceTransform in, src_place " << in.place()
+           << " dst_place: " << dst_place;
   PADDLE_ENFORCE_NE(
       in.place().which(), dst_place.which(),
......
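This hunk and several below multiply VLOG levels by ten (3 becomes 30, 10 becomes 100), moving framework messages into coarser verbosity bands. `VLOG` is glog's verbose-logging macro; a minimal sketch of how the threshold behaves, using the standard glog API outside Paddle:

```cpp
#include <glog/logging.h>

int main(int argc, char* argv[]) {
  google::InitGoogleLogging(argv[0]);
  // VLOG(n) prints only when the verbosity is at least n,
  // e.g. run with --v=30 or GLOG_v=30.
  VLOG(3) << "was visible with --v=3";
  VLOG(30) << "after this commit the same message needs --v=30";
  return 0;
}
```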
@@ -49,10 +49,10 @@ class TestOpWithKernel : public OperatorWithKernel {
   OpKernelType GetExpectedKernelType(
       const ExecutionContext& ctx) const override {
     if (Attr<bool>("use_gpu")) {
-      VLOG(3) << "force use gpu kernel";
+      VLOG(30) << "force use gpu kernel";
       return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0));
     } else {
-      VLOG(3) << "use default kernel";
+      VLOG(30) << "use default kernel";
       return OpKernelType(proto::VarType::FP32,
                           ctx.Input<Tensor>("input")->place());
     }
@@ -148,7 +148,7 @@ TEST(Operator, CPUtoGPU) {
   // get output
   auto* output2 = scope.Var("OUT2");
   gpu_op->Run(scope, cuda_place);
-  VLOG(3) << "after gpu_op run";
+  VLOG(30) << "after gpu_op run";
   // auto* output2_ptr = output2->Get<LoDTensor>().data<float>();
   paddle::platform::DeviceContextPool& pool =
......
 cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
 cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
+cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
 cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
@@ -30,7 +31,9 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
-if(WITH_GPU)
+cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
+
+if (WITH_GPU)
   cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
       all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
 endif()
@@ -40,12 +43,13 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
     scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
-if(WITH_GPU)
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass sequential_execution_pass)
-else()
-  cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto sequential_execution_pass)
+set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass)
+if (WITH_GPU)
+  list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
 endif()
+
+cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
+
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
     simple_threadpool device_context)
......
@@ -60,7 +60,7 @@ void BroadcastOpHandle::BroadcastOneVar(
   PADDLE_ENFORCE_NOT_NULL(in_var);
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
   if (UNLIKELY(!in_tensor.IsInitialized())) {
-    VLOG(3) << "in var " << in_var_handle.name_ << "not inited, return!";
+    VLOG(30) << "in var " << in_var_handle.name_ << "not inited, return!";
     return;
   }
......
@@ -37,8 +37,9 @@ struct TestBroadcastOpHandle {
   std::vector<Scope*> local_scopes_;
   std::vector<Scope*> param_scopes_;
   Scope g_scope_;
-  std::unique_ptr<OpHandleBase> op_handle_;
-  std::vector<std::unique_ptr<VarHandleBase>> vars_;
+  OpHandleBase* op_handle_;
+  std::vector<VarHandleBase*> vars_;
+  std::vector<std::unique_ptr<ir::Node>> nodes_;
   std::vector<p::Place> place_list_;
   bool use_gpu_;
 #ifdef PADDLE_WITH_CUDA
@@ -90,6 +91,7 @@ struct TestBroadcastOpHandle {
   }

   void InitBroadcastOp(size_t input_scope_idx) {
+    nodes_.clear();
     for (size_t j = 0; j < place_list_.size(); ++j) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
       Scope& local_scope = local_scopes_.back()->NewScope();
@@ -101,39 +103,39 @@ struct TestBroadcastOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("input");

-    std::unique_ptr<ir::Node> n =
-        ir::CreateNodeForTest("node0", ir::Node::Type::kOperation);
+    nodes_.emplace_back(
+        ir::CreateNodeForTest("node0", ir::Node::Type::kOperation));
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_,
-                                             place_list_, nccl_ctxs_.get()));
+      op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
+                                         place_list_, nccl_ctxs_.get());
 #else
       PADDLE_THROW("CUDA is not support.");
 #endif
     } else {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_,
-                                             place_list_, nccl_ctxs_.get()));
+      op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
+                                         place_list_, nccl_ctxs_.get());
 #else
-      op_handle_.reset(
-          new BroadcastOpHandle(n.get(), local_scopes_, place_list_));
+      op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
+                                         place_list_);
 #endif
     }

-    std::unique_ptr<ir::Node> v =
-        ir::CreateNodeForTest("node1", ir::Node::Type::kVariable);
-    auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
-                                        place_list_[input_scope_idx]);
+    nodes_.emplace_back(
+        ir::CreateNodeForTest("node1", ir::Node::Type::kVariable));
+    auto* in_var_handle = new VarHandle(nodes_.back().get(), 1, input_scope_idx,
+                                        "input", place_list_[input_scope_idx]);
     vars_.emplace_back(in_var_handle);
     op_handle_->AddInput(in_var_handle);

     // add dummy var
-    std::unique_ptr<ir::Node> v2 =
-        ir::CreateNodeForTest("node2", ir::Node::Type::kVariable);
-    vars_.emplace_back(new DummyVarHandle(v2.get()));
+    nodes_.emplace_back(
+        ir::CreateNodeForTest("node2", ir::Node::Type::kVariable));
+    vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
     DummyVarHandle* dummy_var_handle =
-        static_cast<DummyVarHandle*>(vars_.back().get());
+        static_cast<DummyVarHandle*>(vars_.back());
     dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddInput(dummy_var_handle);
@@ -141,20 +143,20 @@ struct TestBroadcastOpHandle {
       if (!use_gpu_) {
         op_handle_->SetDeviceContext(place_list_[j], ctxs_[j].get());
       }
-      std::unique_ptr<ir::Node> v3 =
-          ir::CreateNodeForTest("node3", ir::Node::Type::kVariable);
+      nodes_.emplace_back(
+          ir::CreateNodeForTest("node3", ir::Node::Type::kVariable));
       VarHandle* out_var_handle =
-          new VarHandle(v3.get(), 2, j, "out", place_list_[j]);
+          new VarHandle(nodes_.back().get(), 2, j, "out", place_list_[j]);
       vars_.emplace_back(out_var_handle);
       op_handle_->AddOutput(out_var_handle);
     }

     // add dummy var
-    std::unique_ptr<ir::Node> v4 =
-        ir::CreateNodeForTest("node4", ir::Node::Type::kVariable);
-    vars_.emplace_back(new DummyVarHandle(v4.get()));
+    nodes_.emplace_back(
+        ir::CreateNodeForTest("node4", ir::Node::Type::kVariable));
+    vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
     DummyVarHandle* out_dummy_var_handle =
-        static_cast<DummyVarHandle*>(vars_.back().get());
+        static_cast<DummyVarHandle*>(vars_.back());
     out_dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddOutput(out_dummy_var_handle);
   }
......
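The fixture changes above all follow one pattern: the tests previously held `op_handle_` and `vars_` in `unique_ptr`s, but ownership now belongs to the `ir::Node` objects kept alive in `nodes_`, so those members become plain raw pointers. A generic, self-contained sketch of the idiom (illustrative names, not Paddle's types):

```cpp
#include <memory>
#include <string>
#include <vector>

struct Node {
  explicit Node(std::string n) : name(std::move(n)) {}
  std::string name;
};

struct Fixture {
  std::vector<std::unique_ptr<Node>> nodes_;  // single owner of every node
  std::vector<Node*> handles_;                // non-owning views, like vars_

  Node* MakeNode(const std::string& name) {
    nodes_.emplace_back(new Node(name));
    Node* raw = nodes_.back().get();  // the nodes_.back().get() idiom
    handles_.push_back(raw);          // safe: lifetime bounded by nodes_
    return raw;
  }
};

int main() {
  Fixture f;
  f.MakeNode("node0");
  f.MakeNode("node1");
  return 0;  // nodes_ frees everything exactly once, no double delete
}
```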
@@ -69,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     // Verify that the graph is correct for multi-device executor.
     AppendPass("multi_devices_check_pass");
+
+    if (strategy_.remove_unnecessary_lock_) {
+      AppendPass("modify_op_lock_and_record_event_pass");
+    }
   }

  private:
@@ -136,3 +140,4 @@ USE_PASS(multi_devices_pass);
 USE_PASS(multi_devices_check_pass);
 USE_PASS(multi_devices_print_pass);
 USE_PASS(sequential_execution_pass);
+USE_PASS(modify_op_lock_and_record_event_pass);
@@ -73,6 +73,8 @@ struct BuildStrategy {
   bool fuse_broadcast_op_{false};

+  bool remove_unnecessary_lock_{false};
+
   // User normally doesn't need to call this API.
   // The PassBuilder allows for more customized insert, remove of passes
   // from python side.
......
@@ -29,9 +29,15 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
 void ComputationOpHandle::RunImpl() {
   WaitInputVarGenerated(place_);

-  this->RunAndRecordEvent([this] {
-    op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
-  });
+  auto run_func = [this]() {
+    op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
+  };
+
+  if (is_lock_and_record_event_free_) {
+    run_func();
+  } else {
+    this->RunAndRecordEvent(run_func);
+  }
 }

 bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
......
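`RunImpl` now names the op body once as `run_func` and wraps it in `RunAndRecordEvent` only when the device lock and CUDA event are still needed. A hedged reduction of that control flow (`RecordEvent` below stands in for Paddle's `RunAndRecordEvent`, which additionally manages CUDA events and the device lock):

```cpp
#include <functional>
#include <iostream>

// Stand-in for RunAndRecordEvent: run the body under lock/event bookkeeping.
static void RecordEvent(const std::function<void()>& fn) {
  std::cout << "acquire lock, record event\n";
  fn();
  std::cout << "release\n";
}

static void RunOp(bool lock_and_record_event_free) {
  auto run_func = [] { std::cout << "op body\n"; };
  if (lock_and_record_event_free) {
    run_func();             // fast path enabled by the new pass
  } else {
    RecordEvent(run_func);  // original, conservative path
  }
}

int main() {
  RunOp(true);
  RunOp(false);
  return 0;
}
```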
@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
   const platform::Place &GetPlace() const { return place_; }

+  void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
+
  protected:
   void RunImpl() override;
@@ -45,6 +47,7 @@ struct ComputationOpHandle : public OpHandleBase {
   std::unique_ptr<OperatorBase> op_;
   Scope *scope_;
   platform::Place place_;
+  bool is_lock_and_record_event_free_{false};
 };

 }  // namespace details
 }  // namespace framework
......
@@ -13,6 +13,7 @@
 // limitations under the License.

 #pragma once
+#include <cstddef>  // for size_t

 namespace paddle {
 namespace framework {
@@ -26,6 +27,7 @@ struct ExecutionStrategy {
   bool allow_op_delay_{false};
   size_t num_iteration_per_drop_scope_{100};
   ExecutorType type_{kDefault};
+  bool dry_run_{false};
 };

 }  // namespace details
......
@@ -16,6 +16,7 @@
 #include <vector>
 #include "paddle/fluid/framework/details/fetch_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"

 namespace paddle {
 namespace framework {
@@ -32,13 +33,11 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
       pool_(strategy.num_threads_ +
             1),  // add one more thread for generate op_deps
       fetch_ctxs_(places) {
-  auto &ops = graph_->Get<details::GraphOps>("ops");
-  for (auto &op : ops) {
+  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
     int dep = static_cast<int>(op->NotReadyInputSize());
-    op_deps_.emplace(op.get(), dep);
+    op_deps_.emplace(op, dep);
     if (dep == 0) {
-      bootstrap_ops_.emplace_back(op.get());
+      bootstrap_ops_.emplace_back(op);
     }
   }
@@ -54,13 +53,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
   paddle::framework::FeedFetchList fetches;
   fetches.resize(fetch_tensors.size());
   std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
-  std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
+  std::vector<FetchOpHandle *> fetch_ops;

   for (auto &fetch_var_name : fetch_tensors) {
     for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
       auto it = var_map.find(fetch_var_name);
       if (it != var_map.end()) {
-        fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
+        fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
       }
     }
   }
@@ -110,7 +109,10 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
           complete_q->Pop();
         }
       }
-      exception_.ReThrow();
+      if (exception_.IsCaught()) {
+        ClearFetchOp(graph_.get(), &fetch_ops);
+        exception_.ReThrow();
+      }
     }
     num_complete += num_comp;
   }
@@ -128,7 +130,9 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
     size_t complete = 0;
     while (op_to_run != nullptr) {
       try {
-        op_to_run->Run(strategy_.use_cuda_);
+        if (LIKELY(!strategy_.dry_run_)) {
+          op_to_run->Run(strategy_.use_cuda_);
+        }
         ++complete;
       } catch (...) {
         exception_.Catch(std::current_exception());
......
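Two behavioral fixes land in the fast executor: op execution is skipped entirely when the new `dry_run_` strategy flag is set, and a pending exception is rethrown only if one was actually caught, after the fetch ops are cleaned up. The rethrow side relies on a holder that captures `std::current_exception()` from worker threads; a hedged sketch of that pattern (`ExceptionHolder` here is illustrative, not Paddle's exact class):

```cpp
#include <exception>
#include <iostream>
#include <mutex>
#include <stdexcept>

class ExceptionHolder {
 public:
  void Catch(std::exception_ptr e) {
    std::lock_guard<std::mutex> g(mu_);
    if (!eptr_) eptr_ = e;  // keep the first exception from any worker
  }
  bool IsCaught() const { return eptr_ != nullptr; }
  void ReThrow() { std::rethrow_exception(eptr_); }

 private:
  std::exception_ptr eptr_;
  std::mutex mu_;
};

int main() {
  ExceptionHolder holder;
  try {
    throw std::runtime_error("op failed");
  } catch (...) {
    holder.Catch(std::current_exception());
  }
  // Mirrors the fixed executor: clean up first, rethrow only when needed.
  if (holder.IsCaught()) {
    try {
      holder.ReThrow();
    } catch (const std::exception& e) {
      std::cout << "rethrown: " << e.what() << "\n";
    }
  }
  return 0;
}
```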
@@ -28,11 +28,7 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
       offset_(offset),
       local_scopes_(local_scopes) {}

-FetchOpHandle::~FetchOpHandle() {
-  for (auto *input_var : inputs_) {
-    input_var->RemoveOutput(this, this->Node());
-  }
-}
+FetchOpHandle::~FetchOpHandle() {}

 void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
   PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
......
@@ -22,8 +22,10 @@ namespace details {

 struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
   std::vector<std::string> out_varnames_;
+  std::vector<std::unique_ptr<ir::Node>> nodes_;

   void InitFusedBroadcastOp(std::vector<size_t> input_scope_idxes) {
+    nodes_.clear();
     // initialize scope and var
     for (size_t i = 0; i < place_list_.size(); ++i) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
@@ -39,41 +41,41 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
     }

     // create op handle node
-    std::unique_ptr<ir::Node> n =
-        ir::CreateNodeForTest("fused_broadcast", ir::Node::Type::kOperation);
+    nodes_.emplace_back(
+        ir::CreateNodeForTest("fused_broadcast", ir::Node::Type::kOperation));
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(new FusedBroadcastOpHandle(
-          n.get(), local_scopes_, place_list_, nccl_ctxs_.get()));
+      op_handle_ = new FusedBroadcastOpHandle(
+          nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
 #else
       PADDLE_THROW("CUDA is not supported.");
 #endif
     } else {
 #ifdef PADDLE_WITH_CUDA
-      op_handle_.reset(new FusedBroadcastOpHandle(
-          n.get(), local_scopes_, place_list_, nccl_ctxs_.get()));
+      op_handle_ = new FusedBroadcastOpHandle(
+          nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
 #else
-      op_handle_.reset(
-          new FusedBroadcastOpHandle(n.get(), local_scopes_, place_list_));
+      op_handle_ = new FusedBroadcastOpHandle(nodes_.back().get(),
+                                              local_scopes_, place_list_);
 #endif
     }

     for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
       // add input var handle
-      std::unique_ptr<ir::Node> in_node =
-          ir::CreateNodeForTest("in_node" + i, ir::Node::Type::kVariable);
-      VarHandle* in_var_handle =
-          new VarHandle(in_node.get(), 1, input_scope_idxes[i], "in_var" + i,
-                        place_list_[input_scope_idxes[i]]);
+      nodes_.emplace_back(
+          ir::CreateNodeForTest("in_node" + i, ir::Node::Type::kVariable));
+      VarHandle* in_var_handle =
+          new VarHandle(nodes_.back().get(), 1, input_scope_idxes[i],
+                        "in_var" + i, place_list_[input_scope_idxes[i]]);
       vars_.emplace_back(in_var_handle);
       op_handle_->AddInput(in_var_handle);

       // add output var handle
       for (size_t j = 0; j < place_list_.size(); ++j) {
-        std::unique_ptr<ir::Node> out_node =
-            ir::CreateNodeForTest("out_node" + i, ir::Node::Type::kVariable);
-        VarHandle* out_var_handle =
-            new VarHandle(out_node.get(), 2, j, "out_var" + i, place_list_[j]);
+        nodes_.emplace_back(
+            ir::CreateNodeForTest("out_node" + i, ir::Node::Type::kVariable));
+        VarHandle* out_var_handle = new VarHandle(
+            nodes_.back().get(), 2, j, "out_var" + i, place_list_[j]);
         vars_.emplace_back(out_var_handle);
         op_handle_->AddOutput(out_var_handle);
       }
......
@@ -31,9 +31,10 @@ struct TestGatherOpHandle {
   std::vector<Scope*> local_scopes_;
   std::vector<Scope*> param_scopes_;
   Scope g_scope_;
-  std::unique_ptr<OpHandleBase> op_handle_;
-  std::vector<std::unique_ptr<VarHandleBase>> vars_;
+  OpHandleBase* op_handle_;
+  std::vector<VarHandleBase*> vars_;
   std::vector<p::Place> gpu_list_;
+  std::vector<std::unique_ptr<ir::Node>> nodes_;

   void WaitAll() {
     for (size_t j = 0; j < ctxs_.size(); ++j) {
@@ -70,7 +71,7 @@ struct TestGatherOpHandle {
   }

   void InitGatherOp(size_t input_scope_idx) {
-    std::vector<std::unique_ptr<ir::Node>> nodes;
+    nodes_.clear();
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
       Scope& local_scope = local_scopes_.back()->NewScope();
@@ -82,44 +83,45 @@ struct TestGatherOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("out");

-    nodes.emplace_back(
+    nodes_.emplace_back(
         ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
-    op_handle_.reset(
-        new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
+    op_handle_ =
+        new GatherOpHandle(nodes_.back().get(), local_scopes_, gpu_list_);

     // add input
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
-      nodes.emplace_back(
+      nodes_.emplace_back(
           ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release());
       auto* in_var_handle =
-          new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
+          new VarHandle(nodes_.back().get(), 1, j, "input", gpu_list_[j]);
       vars_.emplace_back(in_var_handle);
       op_handle_->AddInput(in_var_handle);
     }

     // add dummy var
-    nodes.emplace_back(
+    nodes_.emplace_back(
         ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release());
-    vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
+    vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
     DummyVarHandle* in_dummy_var_handle =
-        static_cast<DummyVarHandle*>(vars_.back().get());
+        static_cast<DummyVarHandle*>(vars_.back());
     in_dummy_var_handle->ClearGeneratedOp();
     op_handle_->AddInput(in_dummy_var_handle);

     // add output
-    nodes.emplace_back(
+    nodes_.emplace_back(
         ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release());
-    auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx,
-                                         "out", gpu_list_[input_scope_idx]);
+    auto* out_var_handle =
+        new VarHandle(nodes_.back().get(), 2, input_scope_idx, "out",
+                      gpu_list_[input_scope_idx]);
     vars_.emplace_back(out_var_handle);
     op_handle_->AddOutput(out_var_handle);

     // add dummy var
-    nodes.emplace_back(
+    nodes_.emplace_back(
         ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release());
-    vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
+    vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
     DummyVarHandle* dummy_var_handle =
-        static_cast<DummyVarHandle*>(vars_.back().get());
+        static_cast<DummyVarHandle*>(vars_.back());
     op_handle_->AddOutput(dummy_var_handle);
   }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
static bool IsLockAndRecordEventFreeComputationOpHandle(
ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace())) return false;
for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
return false;
}
}
return true;
}
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
if (is_lock_and_record_event_free) {
VLOG(100) << "Set is_lock_and_record_event_free be true in op "
<< compute_op->DebugString();
}
}
return ir_graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(modify_op_lock_and_record_event_pass,
paddle::framework::details::ModifyOpLockAndRecordEventPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
 #include <string>
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"

 namespace paddle {
 namespace framework {
@@ -36,20 +37,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
   for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
     for (auto &name_pair : var_map) {
       for (auto &version_pair : name_pair.second) {
-        insert_pending_var(version_pair.get());
+        insert_pending_var(version_pair);
       }
     }
   }

   for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
-    insert_pending_var(var.get());
+    insert_pending_var(var);
   }

-  for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
+  for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
     if (op->Inputs().empty()) {
-      ready_ops.insert(op.get());
+      ready_ops.insert(op);
     } else {
-      pending_ops.insert({op.get(), op.get()->NoDupInputSize()});
+      pending_ops.insert({op, op->NoDupInputSize()});
     }
   }
@@ -89,6 +90,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
 REGISTER_PASS(multi_devices_check_pass,
               paddle::framework::details::SSAGraghBuilderWithChecker)
     .RequireGraphAttr(paddle::framework::details::kGraphVars)
-    .RequireGraphAttr(paddle::framework::details::kGraphDepVars)
-    .RequireGraphAttr(paddle::framework::details::kGraphOps)
-    .RequireGraphAttr(paddle::framework::details::kShardedVarDevice);
+    .RequireGraphAttr(paddle::framework::details::kGraphDepVars);
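Here and in the executors above, passes stop reading op handles out of the `kGraphOps` graph attribute and instead recover them from the graph's nodes with `ir::FilterByNodeWrapper<OpHandleBase>`. Roughly, the helper walks all nodes and collects the handles they wrap; a simplified sketch under assumed types (Paddle's real `ir::Node` tracks its wrapper differently):

```cpp
#include <vector>

struct OpHandleBase { virtual ~OpHandleBase() = default; };
struct ComputationOp : OpHandleBase {};

struct Node {
  OpHandleBase* wrapped = nullptr;  // set when a handle is attached
};

struct Graph {
  std::vector<Node> nodes;
};

// Collect every wrapper of (sub)type T reachable from the graph's nodes.
template <typename T>
std::vector<T*> FilterByNodeWrapper(Graph& g) {
  std::vector<T*> out;
  for (Node& n : g.nodes) {
    if (T* t = dynamic_cast<T*>(n.wrapped)) out.push_back(t);
  }
  return out;
}

int main() {
  Graph g;
  ComputationOp op;
  g.nodes.push_back(Node{&op});
  g.nodes.push_back(Node{});  // node with no wrapper is skipped
  return FilterByNodeWrapper<ComputationOp>(g).size() == 1 ? 0 : 1;
}
```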
@@ -34,7 +34,14 @@
 namespace paddle {
 namespace framework {
 namespace details {
 namespace {
+// TODO(panyx0718): Clean this up as well.
+// all operators. NOTE that even we use a vector here, the operators is
+// unordered.
+typedef std::vector<OpHandleBase *> GraphOps;
+const char kGraphOps[] = "ops";
+
 void PolishGraphToSupportDataHazards(ir::Graph *graph) {
   for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
     for (auto &name_pair : var_map) {
@@ -92,7 +99,7 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
     }
     var_holder.emplace_back(var);
   } else {
-    var = var_holder.rbegin()->get();
+    var = *var_holder.rbegin();
   }
   return var;
 }
@@ -154,7 +161,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
                                                 ir::Node *node,
                                                 size_t place_id) const {
   auto p = places_[place_id];
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
   op_handle->SetDeviceContext(p,
                               platform::DeviceContextPool::Instance().Get(p));
@@ -303,7 +310,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   result.Set(kGraphVars, new GraphVars(places_.size()));
   result.Set(kGraphDepVars, new GraphDepVars);
   result.Set(kGraphOps, new GraphOps);
-  result.Set(kShardedVarDevice, new ShardedVarDevice);

   // find send/recv vars so that we can place the distributed training
   // related op in the place 0
@@ -317,11 +323,13 @@
   bool is_forwarding = true;
   bool is_dist_train = false;

+  std::unordered_map<std::string, int> sharded_var_device;
+
   for (ir::Node *node : sorted_ops) {
     if (boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
-      int op_dev_id = CreateRPCOp(&result, node);
+      int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device);
       PADDLE_ENFORCE(op_dev_id != -1,
                      "Can not schedule the RPC operator to the right place.");
       if (node->Op()->Type() == "recv") {
@@ -337,7 +345,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
     } else if (boost::get<int>(node->Op()->GetAttr(
                    OpProtoAndCheckerMaker::OpRoleAttrName())) ==
                static_cast<int>(OpRole::kDist)) {
-      int op_dev_id = CreateDistTrainOp(&result, node);
+      int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device);
       if (node->Op()->Type() == "concat") {
         auto origin_param_name = node->Op()->OutputArgumentNames()[0];
         bcast_var_name_set[op_dev_id].emplace(origin_param_name);
@@ -356,12 +364,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       // the block.
       is_forwarding = false;
     } else {
-      int op_dev_id = GetOpDeviceID(result, node);
+      int op_dev_id = GetOpDeviceID(result, node, sharded_var_device);
       if (op_dev_id != -1) {  // This op only runs on one specific device.
         CreateComputationalOp(&result, node, op_dev_id);
         for (ir::Node *n : node->outputs) {
-          graph->Get<ShardedVarDevice>(kShardedVarDevice)
-              .emplace(n->Name(), op_dev_id);
+          sharded_var_device.emplace(n->Name(), op_dev_id);
         }
       } else {
         // This op runs on all devices, and its output may have parameter's
@@ -392,14 +399,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
         for (size_t i = 0; i < backward_vars.size(); i += 2) {
           auto &p_name = backward_vars[i];
           auto &g_name = backward_vars[i + 1];
-          VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
+          VLOG(100) << "Bcast " << g_name << " for parameter " << p_name;

           switch (strategy_.reduce_) {
             case BuildStrategy::ReduceStrategy::kReduce:
               cur_device_id = GetAppropriateDeviceID({g_name});
               CreateReduceOp(&result, g_name, cur_device_id);
-              graph->Get<ShardedVarDevice>(kShardedVarDevice)
-                  .emplace(g_name, cur_device_id);
+              sharded_var_device.emplace(g_name, cur_device_id);
               if (!is_dist_train) {
                 bcast_var_name_set[cur_device_id].emplace(p_name);
               }
@@ -458,7 +464,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
    * Only variables should be the leaves of graph.
    */
   AddOutputToLeafOps(&result);
-  PADDLE_ENFORCE(!ir::HasCircle(result));
+  result.Erase<GraphOps>(kGraphOps);
   return graph;
 }
@@ -498,7 +504,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
   result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);

   auto *in =
-      result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back().get();
+      result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back();
   op_handle->AddInput(in);

   for (size_t i = 0; i < places_.size(); ++i) {
@@ -535,7 +541,7 @@ void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
   for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) {
     for (auto &p_name : bcast_varnames[dev_id]) {
       auto *in =
-          result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back().get();
+          result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back();
       op_handle->AddInput(in);
       for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) {
         auto &p = places_[out_dev_id];
@@ -571,7 +577,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
       result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
       local_scopes_, places_));
 #endif
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();

   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
@@ -579,7 +585,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
     auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
-    op_handle->AddInput(prev_grad.get());
+    op_handle->AddInput(prev_grad);

     auto var =
         new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
@@ -600,14 +606,14 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
       result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation),
       local_scopes_, places_));
 #endif
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get();
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     for (const std::string &d_name : datas) {
       auto &vars = result->Get<GraphVars>(kGraphVars)[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
-      op_handle->AddInput(vars.back().get());
+      op_handle->AddInput(vars.back());
       auto var = new VarHandle(
           result->CreateEmptyNode(d_name, ir::Node::Type::kVariable),
           vars.size(), i, d_name, p);
...@@ -617,8 +623,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( ...@@ -617,8 +623,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
} }
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph, int MultiDevSSAGraphBuilder::GetOpDeviceID(
ir::Node *node) const { const ir::Graph &graph, ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1; return -1;
} }
...@@ -631,16 +638,22 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph, ...@@ -631,16 +638,22 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U); PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(graph, param_grad[1]); int dev_id = GetVarDeviceID(graph, param_grad[1], sharded_var_device);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
node->Op()->Type(), param_grad[0], param_grad[1]); node->Op()->Type(), param_grad[0], param_grad[1]);
return dev_id; return dev_id;
} }
int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, int MultiDevSSAGraphBuilder::GetVarDeviceID(
const std::string &varname) const { const ir::Graph &graph, const std::string &varname,
auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice); const std::unordered_map<std::string, int> &sharded_var_device) const {
auto got = sharded_var_device.find(varname); auto got = sharded_var_device.find(varname);
if (got == sharded_var_device.end()) {
auto pos = varname.find(framework::kNewGradSuffix);
if (pos != std::string::npos) {
got = sharded_var_device.find(varname.substr(0, pos));
}
}
return got == sharded_var_device.end() ? -1 : got->second; return got == sharded_var_device.end() ? -1 : got->second;
} }
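
The fallback added above deserves a standalone illustration: when a gradient has been renamed with framework::kNewGradSuffix, its device placement is still found under the original name. A minimal runnable sketch, assuming a suffix value of "@NEWGRAD@" purely for illustration (the real constant is defined by the framework):

#include <iostream>
#include <string>
#include <unordered_map>

// Assumed value, for illustration only; the real constant lives in the framework.
static const std::string kNewGradSuffix = "@NEWGRAD@";

int LookupDeviceId(const std::unordered_map<std::string, int> &sharded_var_device,
                   const std::string &varname) {
  auto got = sharded_var_device.find(varname);
  if (got == sharded_var_device.end()) {
    // Renamed gradients keep the placement of the variable they came from:
    // strip the suffix and retry the lookup.
    auto pos = varname.find(kNewGradSuffix);
    if (pos != std::string::npos) {
      got = sharded_var_device.find(varname.substr(0, pos));
    }
  }
  return got == sharded_var_device.end() ? -1 : got->second;
}

int main() {
  std::unordered_map<std::string, int> sharded{{"fc_0.w_0@GRAD", 1}};
  std::cout << LookupDeviceId(sharded, "fc_0.w_0@GRAD" + kNewGradSuffix) << "\n";  // prints 1
  std::cout << LookupDeviceId(sharded, "unknown") << "\n";                         // prints -1
}
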
...@@ -690,7 +703,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -690,7 +703,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
local_scopes_, places_)); local_scopes_, places_));
#endif #endif
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
...@@ -698,7 +711,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -698,7 +711,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
auto &vars = result->Get<GraphVars>(kGraphVars)[i][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
PADDLE_ENFORCE(!vars.empty()); PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back(); auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad);
} }
auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og]; auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
auto var = auto var =
...@@ -709,8 +722,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -709,8 +722,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return var; return var;
} }
int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, int MultiDevSSAGraphBuilder::CreateDistTrainOp(
ir::Node *node) const { ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const {
int op_dev_id = -1; int op_dev_id = -1;
std::vector<std::string> input_var_names; std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names; std::vector<std::string> output_var_names;
...@@ -725,23 +739,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -725,23 +739,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
node->Op()->Type() == "split_selected_rows" || node->Op()->Type() == "split_selected_rows" ||
node->Op()->Type() == "split_ids") { node->Op()->Type() == "split_ids") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, input_var_names[0]); op_dev_id =
GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names); op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) { for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} }
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} else if (node->Op()->Type() == "concat") { } else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(*result, input_var_names[0]); op_dev_id =
GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} else { } else {
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type(); LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
...@@ -759,14 +772,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -759,14 +772,14 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
} }
void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
for (ir::Node *input : node->inputs) { for (ir::Node *input : node->inputs) {
VarHandle *var = nullptr; VarHandle *var = nullptr;
for (int place_offset = 0; place_offset < num_places; ++place_offset) { for (int place_offset = 0; place_offset < num_places; ++place_offset) {
auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset]; auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
auto &var_holder = var_holders[input->Name()]; auto &var_holder = var_holders[input->Name()];
if (!var_holder.empty()) { if (!var_holder.empty()) {
var = var_holder.rbegin()->get(); var = *var_holder.rbegin();
op_handle->AddInput(var); op_handle->AddInput(var);
} }
} }
...@@ -774,12 +787,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { ...@@ -774,12 +787,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
} }
// Create RPC related op handles that connect their in ops and out ops. // Create RPC related op handles that connect their in ops and out ops.
int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, int MultiDevSSAGraphBuilder::CreateRPCOp(
ir::Node *node) const { ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const {
int op_dev_id = -1; int op_dev_id = -1;
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name()); op_dev_id =
GetVarDeviceID(*result, node->inputs[0]->Name(), *sharded_var_device);
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
"This hack no longer holds, please fix."); "This hack no longer holds, please fix.");
// the variable name which contains .block means it was split by // the variable name which contains .block means it was split by
...@@ -794,14 +809,12 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -794,14 +809,12 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U); PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U);
op_dev_id = GetAppropriateDeviceID({send_param_grad[1]}); op_dev_id = GetAppropriateDeviceID({send_param_grad[1]});
VLOG(10) << "send grad " << input_var_names[0] << " origin " VLOG(100) << "send grad " << input_var_names[0] << " origin "
<< send_param_grad[1] << " place: " << op_dev_id; << send_param_grad[1] << " place: " << op_dev_id;
for (auto &varname : input_var_names) { for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(send_param_grad[1], op_dev_id);
.emplace(send_param_grad[1], op_dev_id);
} }
} else if (node->Op()->Type() == "recv") { } else if (node->Op()->Type() == "recv") {
std::vector<std::string> output_var_names; std::vector<std::string> output_var_names;
...@@ -811,16 +824,16 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -811,16 +824,16 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
auto recv_param_grad = boost::get<std::vector<std::string>>( auto recv_param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
if (recv_param_grad.size() == 2U) { if (recv_param_grad.size() == 2U) {
op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]); op_dev_id =
VLOG(10) << "recv param " << recv_param_grad[0] GetVarDeviceID(*result, recv_param_grad[1], *sharded_var_device);
<< " get grad place: " << recv_param_grad[1] VLOG(100) << "recv param " << recv_param_grad[0]
<< " place: " << op_dev_id; << " get grad place: " << recv_param_grad[1]
<< " place: " << op_dev_id;
} else { } else {
op_dev_id = GetAppropriateDeviceID(output_var_names); op_dev_id = GetAppropriateDeviceID(output_var_names);
} }
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} else { } else {
// send_barrier, fetch_barrier will run on place 0; // send_barrier, fetch_barrier will run on place 0;
...@@ -839,7 +852,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -839,7 +852,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
// send_barrier, recv, fetch_barrier's inputs are dep vars; get them from // send_barrier, recv, fetch_barrier's inputs are dep vars; get them from
// all places // all places
auto p = places_[op_dev_id]; auto p = places_[op_dev_id];
auto *op_handle = result->Get<GraphOps>(kGraphOps).back().get(); auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
op_handle->SetDeviceContext(p, op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p)); platform::DeviceContextPool::Instance().Get(p));
...@@ -847,7 +860,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -847,7 +860,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
for (ir::Node *output : node->outputs) { for (ir::Node *output : node->outputs) {
int outvar_dev_id = op_dev_id; int outvar_dev_id = op_dev_id;
if (node->Op()->Type() == "fetch_barrier") { if (node->Op()->Type() == "fetch_barrier") {
outvar_dev_id = GetVarDeviceID(*result, output->Name()); outvar_dev_id =
GetVarDeviceID(*result, output->Name(), *sharded_var_device);
PADDLE_ENFORCE_NE(outvar_dev_id, -1); PADDLE_ENFORCE_NE(outvar_dev_id, -1);
} }
p = places_[outvar_dev_id]; p = places_[outvar_dev_id];
......
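
A pattern running through the hunks above: sharded_var_device is no longer stashed in the graph as the kShardedVarDevice attribute but passed explicitly to each helper, so placement state cannot leak from one pass invocation into the next. A minimal sketch of the pattern, with stand-in names:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

using ShardedVarDevice = std::unordered_map<std::string, int>;

// Callees receive the map explicitly instead of fetching it from the graph.
void PlaceOutputs(const std::vector<std::string> &outputs, int dev_id,
                  ShardedVarDevice *sharded_var_device) {
  for (const auto &name : outputs) {
    sharded_var_device->emplace(name, dev_id);
  }
}

int main() {
  ShardedVarDevice sharded_var_device;  // local to one pass invocation
  PlaceOutputs({"w@GRAD", "b@GRAD"}, 0, &sharded_var_device);
  std::cout << sharded_var_device.at("w@GRAD") << "\n";  // prints 0
}
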
...@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
mutable platform::NCCLContextMap *nccl_ctxs_; mutable platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const; int GetVarDeviceID(
const ir::Graph &graph, const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const;
bool IsScaleLossOp(ir::Node *node) const; bool IsScaleLossOp(ir::Node *node) const;
int CreateRPCOp(ir::Graph *result, ir::Node *node) const; int CreateRPCOp(
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const; ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
int CreateDistTrainOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
std::vector<std::string> FindDistTrainSendVars( std::vector<std::string> FindDistTrainSendVars(
const std::vector<ir::Node *> &nodes) const; const std::vector<ir::Node *> &nodes) const;
...@@ -69,7 +75,9 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -69,7 +75,9 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateComputationalOp(ir::Graph *result, ir::Node *node, void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const; int dev_id) const;
int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const; int GetOpDeviceID(
const ir::Graph &graph, ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const; void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, ...@@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}); });
size_t op_id = 0; size_t op_id = 0;
for (auto &op : graph.Get<GraphOps>(kGraphOps)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
std::string op_name = "op_" + std::to_string(op_id++); std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl; << std::endl;
......
...@@ -35,23 +35,14 @@ namespace details { ...@@ -35,23 +35,14 @@ namespace details {
// The outside vector is the device vector. Each element of this vector is a // The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. Variables that have the same name // map from variable name to variables. Variables that have the same name
// will have different versions. The offset in the // will have different versions. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of variables. // `std::vector<VarHandle*>` is the version of variables.
typedef std::vector< typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>>
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars; GraphVars;
const char kGraphVars[] = "vars"; const char kGraphVars[] = "vars";
// Auxiliary variables representing dependencies; useful for resolving data hazards. // Auxiliary variables representing dependencies; useful for resolving data hazards.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars; typedef std::unordered_set<VarHandleBase*> GraphDepVars;
const char kGraphDepVars[] = "dep_vars"; const char kGraphDepVars[] = "dep_vars";
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
const char kGraphOps[] = "ops";
typedef std::unordered_map<std::string, int> ShardedVarDevice;
const char kShardedVarDevice[] = "sharded_var_device";
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
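
With the unique_ptr wrappers gone, GraphVars is a plain [device][var name][version] index of non-owning pointers; ownership moves to the ir::Node wrappers (see the WrappedBy constructors later in this change). A runnable sketch of the indexing, using a stand-in VarHandle:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct VarHandle {  // stand-in; the real handle is owned by its ir::Node
  std::string name;
  size_t version;
};

using GraphVars =
    std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>>;

int main() {
  GraphVars vars(2);  // one map per device
  VarHandle v0{"w", 0}, v1{"w", 1};
  vars[0]["w"] = {&v0, &v1};  // position in the vector == version
  std::cout << "latest version of w on device 0: "
            << vars[0]["w"].back()->version << "\n";  // prints 1
}
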
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_graph_view.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
for (auto &op : ops) {
preceding_ops_[op];
pending_ops_[op];
for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op);
pending_ops_[op].insert(pending_op);
}
}
}
PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph.");
}
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
return ret;
}
bool OpGraphView::HasOp(OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
namespace paddle {
namespace framework {
namespace details {
class OpGraphView {
public:
explicit OpGraphView(const std::vector<OpHandleBase *> &ops);
std::unordered_set<OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const;
private:
void Build(const std::vector<OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
preceding_ops_;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
pending_ops_;
};
} // namespace details
} // namespace framework
} // namespace paddle
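
OpGraphView precomputes, for every op, which ops precede it and which are pending on it, by walking each op's output vars once. A toy mirror of Build() with stand-in Op/Var types shows the adjacency it stores:

#include <initializer_list>
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Op;
struct Var {  // stand-in for VarHandleBase
  std::vector<Op *> pending_ops;
};
struct Op {  // stand-in for OpHandleBase
  const char *name;
  std::vector<Var *> outputs;
};

int main() {
  Var v;
  Op a{"a", {&v}}, b{"b", {}};
  v.pending_ops.push_back(&b);  // b consumes a's output

  std::unordered_map<Op *, std::unordered_set<Op *>> preceding_ops, pending_ops;
  for (Op *op : {&a, &b}) {
    preceding_ops[op];  // ensure every op has an entry, even isolated ones
    pending_ops[op];
    for (Var *var : op->outputs) {
      for (Op *pending : var->pending_ops) {
        preceding_ops[pending].insert(op);
        pending_ops[op].insert(pending);
      }
    }
  }
  std::cout << pending_ops[&a].count(&b) << "\n";  // prints 1: b waits on a
}
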
...@@ -31,7 +31,10 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@"; ...@@ -31,7 +31,10 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
// It's responsible for populating necessary fields of ir::Node. // It's responsible for populating necessary fields of ir::Node.
class OpHandleBase { class OpHandleBase {
public: public:
explicit OpHandleBase(ir::Node *node) : node_(node) {} // Owned by `node`. No need to be deleted explicitly.
explicit OpHandleBase(ir::Node *node) : node_(node) {
node_->WrappedBy(this);
}
virtual ~OpHandleBase(); virtual ~OpHandleBase();
......
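
The one-line constructor change above carries the new ownership model: a handle registers itself with its ir::Node, and the node frees it. A sketch of the assumed WrappedBy mechanics with stand-in types (the real ir::Node stores its wrapper type-erased):

#include <functional>
#include <iostream>

struct Node {  // stand-in for ir::Node
  void *wrapper = nullptr;
  std::function<void()> wrapper_deleter;

  template <typename T>
  void WrappedBy(T *handle) {
    wrapper = handle;
    wrapper_deleter = [handle] { delete handle; };  // the node owns the handle now
  }
  ~Node() {
    if (wrapper_deleter) wrapper_deleter();
  }
};

struct OpHandleBase {  // stand-in
  explicit OpHandleBase(Node *node) { node->WrappedBy(this); }
  ~OpHandleBase() { std::cout << "handle freed by its node\n"; }
};

int main() {
  Node node;
  new OpHandleBase(&node);  // no matching delete: ~Node() frees the handle
}
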
...@@ -30,8 +30,8 @@ struct TestReduceOpHandle { ...@@ -30,8 +30,8 @@ struct TestReduceOpHandle {
Scope g_scope_; Scope g_scope_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<Scope *> param_scopes_; std::vector<Scope *> param_scopes_;
std::unique_ptr<OpHandleBase> op_handle_; OpHandleBase *op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_; std::vector<VarHandleBase *> vars_;
std::vector<p::Place> gpu_list_; std::vector<p::Place> gpu_list_;
std::vector<std::unique_ptr<p::DeviceContext>> ctxs_; std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
......
...@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase { ...@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>( dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
if (IsStreamGarabageCollector()) { if (IsStreamGarabageCollector()) {
PADDLE_ENFORCE(cudaSetDevice(place.device)); platform::SetDeviceId(place.device);
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
} }
...@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase { ...@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
~ReferenceCountOpHandle() { ~ReferenceCountOpHandle() {
if (IsStreamGarabageCollector()) { if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace()); auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device)); platform::SetDeviceId(gpu_place.device);
PADDLE_ENFORCE(cudaEventDestroy(event_)); PADDLE_ENFORCE(cudaEventDestroy(event_));
} }
} }
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h" #include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -43,6 +44,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) { ...@@ -43,6 +44,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
return nullptr; return nullptr;
} }
static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
ir::Graph *graph) {
auto it = std::find_if(
in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != in->Outputs().end()) {
out->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
in->AddOutput(dep_var);
out->AddInput(dep_var);
}
}
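
AddDependencyBetween factors out the dummy-var plumbing that the call sites below used to duplicate, with one refinement: if `in` already exposes a DummyVarHandle output, further dependants reuse it rather than materializing a fresh control-dep var per edge. A toy demonstration of the reuse (simplified: it assumes `in` has no other outputs, whereas the real helper searches for a dummy one among them):

#include <iostream>
#include <vector>

struct DummyVar {};  // stand-in for DummyVarHandle
struct Op {          // stand-in for OpHandleBase
  std::vector<DummyVar *> outputs;
  std::vector<DummyVar *> inputs;
};

void Connect(Op *in, Op *out, std::vector<DummyVar *> *graph_dep_vars) {
  if (in->outputs.empty()) {
    // First edge out of `in`: materialize one shared dummy var.
    auto *dep = new DummyVar;
    graph_dep_vars->push_back(dep);
    in->outputs.push_back(dep);
  }
  out->inputs.push_back(in->outputs.front());  // later edges reuse it
}

int main() {
  Op a, b, c;
  std::vector<DummyVar *> dep_vars;
  Connect(&a, &b, &dep_vars);
  Connect(&a, &c, &dep_vars);
  std::cout << dep_vars.size() << "\n";  // prints 1: b and c share the dummy var
  for (auto *v : dep_vars) delete v;
}
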
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount); auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
...@@ -54,14 +72,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -54,14 +72,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Step 2: Find all variables in non-computation ops which refer to variables // Step 2: Find all variables in non-computation ops which refer to variables
// in computation ops // in computation ops
std::unordered_set<std::string> names; std::unordered_set<std::string> names;
std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>> std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map; compute_ref_cnt_map;
auto get_ref_cnts_from_compute_op = [&]( auto get_ref_cnts_from_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op, OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op; std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr || if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace())) !platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op; return var_names_in_op;
...@@ -104,9 +121,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -104,9 +121,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
}; };
auto update_ref_cnts_from_non_compute_op = [&]( auto update_ref_cnts_from_non_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op, OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
const std::vector<VarHandleBase *> &vars) { if (dynamic_cast<ComputationOpHandle *>(op) != nullptr) return;
if (dynamic_cast<ComputationOpHandle *>(op.get()) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) { for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base); auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
...@@ -124,8 +140,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -124,8 +140,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
if (next_compute_op != nullptr) { if (next_compute_op != nullptr) {
if (compute_ref_cnt_map.count(next_compute_op)) { if (compute_ref_cnt_map.count(next_compute_op)) {
compute_ref_cnt_map[next_compute_op]->AddVar(var_name); compute_ref_cnt_map[next_compute_op]->AddVar(var_name);
VLOG(5) << "Add reference count of " << var_name << " to Operator " VLOG(50) << "Add reference count of " << var_name << " to Operator "
<< next_compute_op->Name(); << next_compute_op->Name();
} else { } else {
// Create new reference_count_op_handle // Create new reference_count_op_handle
ir::Node *ref_cnt_node = graph->CreateEmptyNode( ir::Node *ref_cnt_node = graph->CreateEmptyNode(
...@@ -133,40 +149,30 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -133,40 +149,30 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto *ref_cnt_handle = new ReferenceCountOpHandle( auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (next_compute_op->Outputs().empty()) { AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); compute_ref_cnt_map[next_compute_op] = ref_cnt_handle;
next_compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
ref_cnt_handle->AddInput(next_compute_op->Outputs().front());
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
} }
} }
} }
} }
}; };
auto &all_ops = graph->Get<GraphOps>(kGraphOps); auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : all_ops) { for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
if (in_var_names.empty() && out_var_names.empty()) continue; if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(), in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end()); out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace()); auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node = ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation); graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle( auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names, ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (compute_op->Outputs().empty()) { AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); compute_ref_cnt_map[compute_op] = ref_cnt_handle;
compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
ref_cnt_handle->AddInput(compute_op->Outputs().front());
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
} }
for (auto &op : all_ops) { for (auto &op : all_ops) {
...@@ -174,11 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -174,11 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
update_ref_cnts_from_non_compute_op(op, op->Outputs()); update_ref_cnts_from_non_compute_op(op, op->Outputs());
} }
std::vector<std::unique_ptr<OpHandleBase>> new_all_ops; std::vector<OpHandleBase *> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size()); new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) { for (auto &op : all_ops) {
new_all_ops.emplace_back(std::move(op)); new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back().get()); auto it = compute_ref_cnt_map.find(new_all_ops.back());
if (it != compute_ref_cnt_map.end()) { if (it != compute_ref_cnt_map.end()) {
// Add LeafNode to ReferenceCountOpHandle // Add LeafNode to ReferenceCountOpHandle
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
......
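
With GraphOps removed, passes now recover op handles straight from the graph via ir::FilterByNodeWrapper<OpHandleBase>. The real template is not reproduced here; a plausible shape, assuming each node records its wrapper and the wrapper's type, is:

#include <typeindex>
#include <typeinfo>
#include <vector>

struct Node {  // simplified stand-in for ir::Node's type-erased wrapper storage
  void *wrapper = nullptr;
  std::type_index wrapper_type = typeid(void);

  template <typename T>
  bool IsWrappedBy() const {
    return wrapper_type == std::type_index(typeid(T));
  }
  template <typename T>
  T &Wrapper() const {
    return *static_cast<T *>(wrapper);
  }
};

// Collect every node wrapper of type T from the graph's node set.
template <typename T>
std::vector<T *> FilterByNodeWrapper(const std::vector<Node *> &nodes) {
  std::vector<T *> ret;
  for (Node *n : nodes) {
    if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
  }
  return ret;
}

struct OpHandleBase {};  // stand-in

int main() {
  OpHandleBase handle;
  Node node;
  node.wrapper = &handle;
  node.wrapper_type = typeid(OpHandleBase);
  return FilterByNodeWrapper<OpHandleBase>({&node}).size() == 1 ? 0 : 1;
}
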
...@@ -29,22 +29,19 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc, ...@@ -29,22 +29,19 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
place_(place) {} place_(place) {}
void RPCOpHandle::RunImpl() { void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_; auto &p = static_cast<VarHandle *>(in)->place_;
// FIXME(Yancey1989): need a better solution instead of use DebugString() if (ir::IsControlDepVar(*in->Node())) {
if (ir::IsControlDepVar(*in->Node())) { // HACK
continue; continue;
} }
if (in->GeneratedOp()) { if (in->GeneratedOp()) {
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(p)); in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(p));
} }
} }
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); this->RunAndRecordEvent([this] {
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead op_->Run(*local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(),
// lock. place_);
op_->Run(*tmp_scope, place_); });
} }
std::string RPCOpHandle::Name() const { return name_; } std::string RPCOpHandle::Name() const { return name_; }
......
...@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() { ...@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() {
->stream(); ->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp, memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream); platform::CPUPlace(), &coeff_, sizeof(float), stream);
VLOG(10) << place_ << "RUN Scale loss grad op"; VLOG(100) << place_ << "RUN Scale loss grad op";
}); });
#endif #endif
} }
......
...@@ -94,8 +94,8 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl( ...@@ -94,8 +94,8 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
op_node_list[i - 1]->outputs.push_back(dep_var); op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]); dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]); dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name() VLOG(100) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name(); << " and " << op_node_list[i]->Name();
} }
return graph; return graph;
} }
......
...@@ -19,14 +19,16 @@ namespace framework { ...@@ -19,14 +19,16 @@ namespace framework {
namespace details { namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {} SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph, void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops) {
if (fetch_ops->empty()) return; if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) { for (auto& op : *fetch_ops) {
for (auto& out_var : op->Node()->outputs) { for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var); graph->RemoveNode(out_var);
} }
for (auto& in_var : op->Inputs()) {
in_var->RemoveOutput(op, op->Node());
}
graph->RemoveNode(op->Node()); graph->RemoveNode(op->Node());
} }
fetch_ops->clear(); fetch_ops->clear();
......
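
The loop added above is the substance of this hunk: fetch ops are created per Run() and torn down afterwards, while their input vars persist, so the var-side back-pointers must be severed or they dangle. A toy model of the unlink with stand-in types (the real RemoveOutput also takes the op's ir::Node):

#include <algorithm>
#include <cassert>
#include <vector>

struct Op;
struct Var {  // stand-in for VarHandleBase
  std::vector<Op *> pending_ops;
  void RemoveOutput(Op *op) {
    pending_ops.erase(std::remove(pending_ops.begin(), pending_ops.end(), op),
                      pending_ops.end());
  }
};
struct Op {  // stand-in for FetchOpHandle
  std::vector<Var *> inputs;
};

int main() {
  Var persistent_var;
  Op fetch{{&persistent_var}};
  persistent_var.pending_ops.push_back(&fetch);

  for (Var *in : fetch.inputs) in->RemoveOutput(&fetch);  // unlink before freeing
  assert(persistent_var.pending_ops.empty());             // no dangling pointer left
}
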
...@@ -38,8 +38,7 @@ class SSAGraphExecutor { ...@@ -38,8 +38,7 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0; virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
}; };
void ClearFetchOp(ir::Graph* graph, void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops);
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops);
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -51,25 +52,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -51,25 +52,25 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, ready_vars.get(), version_pair.get()); InsertPendingVar(&pending_vars, ready_vars.get(), version_pair);
} }
} }
} }
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) { for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, ready_vars.get(), var.get()); InsertPendingVar(&pending_vars, ready_vars.get(), var);
} }
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get()); ready_ops.insert(op);
} else { } else {
InsertPendingOp(&pending_ops, op.get()); InsertPendingOp(&pending_ops, op);
} }
} }
// Step 2. Insert FetchOps // Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies; std::unordered_set<VarHandleBase *> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size()); FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
...@@ -109,6 +110,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -109,6 +110,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &run_op_future : run_op_futures_) { for (auto &run_op_future : run_op_futures_) {
run_op_future.wait(); run_op_future.wait();
} }
ClearFetchOp(graph_.get(), &fetch_ops);
exception_holder_.ReThrow(); exception_holder_.ReThrow();
} else { } else {
continue; continue;
...@@ -140,8 +142,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -140,8 +142,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::vector<FetchOpHandle *> *fetch_ops,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars, std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) { BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
...@@ -151,7 +153,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -151,7 +153,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
} }
} }
} }
...@@ -208,14 +210,16 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -208,14 +210,16 @@ void ThreadedSSAGraphExecutor::RunOp(
details::OpHandleBase *op) { details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
try { try {
if (VLOG_IS_ON(10)) { if (VLOG_IS_ON(100)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); VLOG(100) << op << " " << op->Name() << " : " << op->DebugString();
} }
op->Run(strategy_.use_cuda_); if (LIKELY(!strategy_.dry_run_)) {
VLOG(10) << op << " " << op->Name() << " Done "; op->Run(strategy_.use_cuda_);
}
VLOG(100) << op << " " << op->Name() << " Done ";
running_ops_--; running_ops_--;
ready_var_q->Extend(op->Outputs()); ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << "Signal posted"; VLOG(100) << op << " " << op->Name() << "Signal posted";
} catch (...) { } catch (...) {
exception_holder_.Catch(std::current_exception()); exception_holder_.Catch(std::current_exception());
} }
......
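
strategy_.dry_run_ is new in this hunk: when set, the executor still walks the whole dependency graph but skips kernel execution, presumably to exercise scheduling in isolation. A sketch of the guard, assuming LIKELY wraps __builtin_expect:

#include <iostream>

#if defined(__GNUC__)
#define LIKELY(x) __builtin_expect(!!(x), 1)
#else
#define LIKELY(x) (x)
#endif

struct ExecutionStrategy {  // stand-in
  bool dry_run_ = false;
};

template <typename OpT>
void MaybeRun(OpT &op, const ExecutionStrategy &strategy, bool use_cuda) {
  if (LIKELY(!strategy.dry_run_)) {
    op.Run(use_cuda);  // kernels execute only outside dry runs
  }
  // Bookkeeping (ready queues, counters) would still happen here either way.
}

struct FakeOp {
  void Run(bool) { std::cout << "ran\n"; }
};

int main() {
  FakeOp op;
  ExecutionStrategy normal, dry;
  dry.dry_run_ = true;
  MaybeRun(op, normal, /*use_cuda=*/false);  // prints "ran"
  MaybeRun(op, dry, /*use_cuda=*/false);     // skipped
}
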
...@@ -48,7 +48,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -48,7 +48,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
// Use topological sort algorithm // Use topological sort algorithm
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override; FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
~ThreadedSSAGraphExecutor() {} ~ThreadedSSAGraphExecutor() final = default;
private: private:
void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q, void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
...@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
BlockingQueue<VarHandleBase *> *ready_vars, BlockingQueue<VarHandleBase *> *ready_vars,
VarHandleBase *var) const; VarHandleBase *var) const;
void InsertFetchOps( void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
const std::vector<std::string> &fetch_tensors, std::vector<FetchOpHandle *> *fetch_ops,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_set<VarHandleBase *> *pending_vars,
std::unordered_set<VarHandleBase *> *pending_vars, BlockingQueue<VarHandleBase *> *ready_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data); FeedFetchList *fetch_data);
private: private:
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
......
...@@ -20,6 +20,8 @@ namespace details { ...@@ -20,6 +20,8 @@ namespace details {
VarHandleBase::~VarHandleBase() {} VarHandleBase::~VarHandleBase() {}
VarHandle::~VarHandle() { VLOG(4) << "deleting var handle " << DebugString(); }
std::string VarHandle::DebugString() const { std::string VarHandle::DebugString() const {
std::stringstream ss; std::stringstream ss;
ss << name_ << ":" << place_; ss << name_ << ":" << place_;
...@@ -27,6 +29,10 @@ std::string VarHandle::DebugString() const { ...@@ -27,6 +29,10 @@ std::string VarHandle::DebugString() const {
} }
std::string DummyVarHandle::DebugString() const { return node_->Name(); } std::string DummyVarHandle::DebugString() const { return node_->Name(); }
DummyVarHandle::~DummyVarHandle() {
VLOG(4) << "deleting dummy var handle " << DebugString();
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -35,7 +35,10 @@ class OpHandleBase; ...@@ -35,7 +35,10 @@ class OpHandleBase;
// A variable can only be generated by a single operator. i.e. // A variable can only be generated by a single operator. i.e.
// This is a single assignment graph. // This is a single assignment graph.
struct VarHandleBase { struct VarHandleBase {
explicit VarHandleBase(ir::Node* node) : node_(node) {} // Owned by `node`. No need to be deleted explicitly.
explicit VarHandleBase(ir::Node* node) : node_(node) {
node_->WrappedBy(this);
}
virtual ~VarHandleBase(); virtual ~VarHandleBase();
...@@ -94,6 +97,8 @@ struct VarHandleBase { ...@@ -94,6 +97,8 @@ struct VarHandleBase {
struct VarHandle : public VarHandleBase { struct VarHandle : public VarHandleBase {
explicit VarHandle(ir::Node* node) : VarHandleBase(node) {} explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
virtual ~VarHandle();
std::string DebugString() const override; std::string DebugString() const override;
VarHandle(ir::Node* node, size_t version, size_t scope_index, VarHandle(ir::Node* node, size_t version, size_t scope_index,
...@@ -121,6 +126,8 @@ struct VarHandle : public VarHandleBase { ...@@ -121,6 +126,8 @@ struct VarHandle : public VarHandleBase {
struct DummyVarHandle : public VarHandleBase { struct DummyVarHandle : public VarHandleBase {
explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {} explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {}
virtual ~DummyVarHandle();
std::string DebugString() const override; std::string DebugString() const override;
}; };
......
...@@ -43,7 +43,7 @@ ExecutorPrepareContext::ExecutorPrepareContext( ...@@ -43,7 +43,7 @@ ExecutorPrepareContext::ExecutorPrepareContext(
} }
ExecutorPrepareContext::~ExecutorPrepareContext() { ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext"; VLOG(50) << "destroy ExecutorPrepareContext";
} }
template <typename RefCntMap> template <typename RefCntMap>
...@@ -60,7 +60,7 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, ...@@ -60,7 +60,7 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
if ((it->second)-- == 1) { if ((it->second)-- == 1) {
auto* var = scope.FindVar(name); auto* var = scope.FindVar(name);
if (var != nullptr) { if (var != nullptr) {
VLOG(10) << "Erase tensor \'" << name << "\'"; VLOG(100) << "Erase tensor \'" << name << "\'";
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
erase_tensors.insert(var->GetMutable<LoDTensor>()); erase_tensors.insert(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
...@@ -85,8 +85,10 @@ Executor::Executor(const platform::Place& place) : place_(place) {} ...@@ -85,8 +85,10 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
void Executor::Close() { void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
// TODO(typhoonzero): complete message will need to use real trainer_id,
// except 0.
::paddle::operators::distributed::RPCClient::GetInstance< ::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>() ::paddle::operators::distributed::GRPCClient>(0)
->SendComplete(); ->SendComplete();
#endif #endif
} }
...@@ -139,21 +141,21 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, ...@@ -139,21 +141,21 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
if (var->Persistable()) { if (var->Persistable()) {
auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name()); auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name() VLOG(30) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr; << " global, which pointer is " << ptr;
} else { } else {
auto* ptr = scope->Var(var->Name()); auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name() VLOG(30) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr; << " locally, which pointer is " << ptr;
} }
} }
} else { } else {
for (auto& var : global_block.AllVars()) { for (auto& var : global_block.AllVars()) {
auto* ptr = scope->Var(var->Name()); auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is " VLOG(30) << "Create variable " << var->Name() << ", which pointer is "
<< ptr; << ptr;
} }
} }
} }
...@@ -284,7 +286,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -284,7 +286,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
int i = 0; int i = 0;
for (auto& feed_target : (*feed_targets)) { for (auto& feed_target : (*feed_targets)) {
std::string var_name = feed_target.first; std::string var_name = feed_target.first;
VLOG(3) << "feed target's name: " << var_name; VLOG(30) << "feed target's name: " << var_name;
// prepend feed op // prepend feed op
auto* op = global_block->PrependOp(); auto* op = global_block->PrependOp();
...@@ -307,7 +309,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -307,7 +309,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
int i = 0; int i = 0;
for (auto& fetch_target : (*fetch_targets)) { for (auto& fetch_target : (*fetch_targets)) {
std::string var_name = fetch_target.first; std::string var_name = fetch_target.first;
VLOG(3) << "fetch target's name: " << var_name; VLOG(30) << "fetch target's name: " << var_name;
// append fetch op // append fetch op
auto* op = global_block->AppendOp(); auto* op = global_block->AppendOp();
...@@ -396,8 +398,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -396,8 +398,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} }
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: " VLOG(20) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_); << memory::memory_usage(place_);
} }
} }
...@@ -422,10 +424,10 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -422,10 +424,10 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} }
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "-------------------------------------------------------"; VLOG(20) << "-------------------------------------------------------";
VLOG(2) << "Memory used after deleting local scope: " VLOG(20) << "Memory used after deleting local scope: "
<< memory::memory_usage(place_); << memory::memory_usage(place_);
VLOG(2) << "-------------------------------------------------------"; VLOG(20) << "-------------------------------------------------------";
} }
} }
...@@ -469,7 +471,7 @@ void Executor::RunPreparedContext( ...@@ -469,7 +471,7 @@ void Executor::RunPreparedContext(
void Executor::EnableMKLDNN(const ProgramDesc& program) { void Executor::EnableMKLDNN(const ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
VLOG(3) << "use_mkldnn=True"; VLOG(30) << "use_mkldnn=True";
for (size_t bid = 0; bid < program.Size(); ++bid) { for (size_t bid = 0; bid < program.Size(); ++bid) {
auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid); auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
for (auto* op : block->AllOps()) { for (auto* op : block->AllOps()) {
......
...@@ -25,7 +25,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input, ...@@ -25,7 +25,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input,
const std::string& var_name, size_t index) { const std::string& var_name, size_t index) {
// If var_name Variable is not found in GlobalScope, a new variable will // If var_name Variable is not found in GlobalScope, a new variable will
// be created. // be created.
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index; VLOG(30) << "SetFeedVariable name=" << var_name << " index=" << index;
Variable* g_feed_value = scope->Var(var_name); Variable* g_feed_value = scope->Var(var_name);
auto& feed_inputs = *(g_feed_value->GetMutable<FeedFetchList>()); auto& feed_inputs = *(g_feed_value->GetMutable<FeedFetchList>());
if (index >= feed_inputs.size()) { if (index >= feed_inputs.size()) {
...@@ -47,8 +47,8 @@ LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name, ...@@ -47,8 +47,8 @@ LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
typeid(FeedFetchList).name()); typeid(FeedFetchList).name());
auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>(); auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>();
auto& tensor = fetch_outputs[index]; auto& tensor = fetch_outputs[index];
VLOG(3) << "Fetch " << var_name << " with index " << index VLOG(30) << "Fetch " << var_name << " with index " << index
<< " shape= " << tensor.dims(); << " shape= " << tensor.dims();
PADDLE_ENFORCE_LT(index, fetch_outputs.size()); PADDLE_ENFORCE_LT(index, fetch_outputs.size());
return tensor; return tensor;
} }
......
...@@ -56,6 +56,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") ...@@ -56,6 +56,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library(pass_builder SRCS pass_builder.cc DEPS pass) cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
......
...@@ -147,19 +147,19 @@ void PrepareParameters(Graph* graph, const Param& param) { ...@@ -147,19 +147,19 @@ void PrepareParameters(Graph* graph, const Param& param) {
scope->Var(param.LSTMX)->GetMutable<LoDTensor>(); scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>(); scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
#define GATE_W(name__) \ #define GATE_W(name__) \
auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \ auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \
auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \ auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \
auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \ auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \
CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \ CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \
VLOG(4) << #name__ "_w0" \ VLOG(40) << #name__ "_w0" \
<< " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \ << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_w1" \ VLOG(40) << #name__ "_w1" \
<< " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \ << " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_b0" \ VLOG(40) << #name__ "_b0" \
<< " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \ << " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \
auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>(); \ auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>(); \
auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>(); \ auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>(); \
auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>(); auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();
GATE_W(forget); GATE_W(forget);
...@@ -208,7 +208,7 @@ void PrepareLSTMWeight(const LoDTensor& W_forget_w0, ...@@ -208,7 +208,7 @@ void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
int D = W_forget_w0.dims()[0]; int D = W_forget_w0.dims()[0];
int M = W_forget_w1.dims()[0]; int M = W_forget_w1.dims()[0];
out->Resize(make_ddim({D + M, 4 * D})); out->Resize(make_ddim({D + M, 4 * D}));
VLOG(3) << "LSTMWeight resized to " << out->dims(); VLOG(30) << "LSTMWeight resized to " << out->dims();
float* out_data = out->mutable_data<float>(platform::CPUPlace()); float* out_data = out->mutable_data<float>(platform::CPUPlace());
std::array<const float*, 4> tensors( std::array<const float*, 4> tensors(
......
...@@ -57,7 +57,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl( ...@@ -57,7 +57,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
int found_conv_bias_count = 0; int found_conv_bias_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "handle ConvBias fuse"; VLOG(40) << "handle ConvBias fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_bias_pattern); // Filter conv_bias_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern); // tmp GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern); // tmp
...@@ -74,7 +74,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl( ...@@ -74,7 +74,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
// check if fuse can be done and if MKL-DNN should be used // check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *eltwise); FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) { if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
VLOG(3) << "do not perform conv+bias fuse"; VLOG(30) << "do not perform conv+bias fuse";
return; return;
} }
......
@@ -121,7 +121,7 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
-    VLOG(4) << "handle ConvBN fuse";
+    VLOG(40) << "handle ConvBN fuse";
    // conv, batch_norm,
    // conv_weight, conv_out,
@@ -133,7 +133,7 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
    // check if fuse can be done and if MKL-DNN should be used
    FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
    if (fuse_option == DO_NOT_FUSE) {
-      VLOG(3) << "do not perform conv+bn fuse";
+      VLOG(30) << "do not perform conv+bn fuse";
      return;
    }
@@ -241,7 +241,7 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
-    VLOG(4) << "handle ConvBN fuse";
+    VLOG(40) << "handle ConvBN fuse";
    // conv, batch_norm,
    // conv_weight, conv_out,
...
@@ -38,7 +38,7 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
  int found_conv_relu_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
-    VLOG(4) << "handle ConvReLU fuse";
+    VLOG(40) << "handle ConvReLU fuse";
    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                              conv_relu_pattern);                      // Filter
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern);  // tmp
@@ -48,7 +48,7 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
    FuseOptions fuse_option = FindFuseOption(*conv, *relu);
    if (fuse_option == DO_NOT_FUSE) {
-      VLOG(3) << "do not perform conv+relu fuse";
+      VLOG(30) << "do not perform conv+relu fuse";
      return;
    }
...
@@ -39,7 +39,7 @@ std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
  int found_depthwise_conv_mkldnn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
-    VLOG(3) << "handle DepthwiseConvMKLDNN fuse";
+    VLOG(30) << "handle DepthwiseConvMKLDNN fuse";
    GET_NODE(depthwise_conv, (*pattern));
    depthwise_conv->Op()->SetType("conv2d");
    found_depthwise_conv_mkldnn_count++;
...
@@ -39,7 +39,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
  int found_fc_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
-    VLOG(4) << "handle FC fuse";
+    VLOG(40) << "handle FC fuse";
    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
...
@@ -61,7 +61,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
-    VLOG(4) << "handle FuseElewiseAddAct fuse";
+    VLOG(40) << "handle FuseElewiseAddAct fuse";
    GET_IR_NODE_FROM_SUBGRAPH(ele_y, ele_y, elewise_add_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out,
                              elewise_add_act_pattern);
@@ -77,10 +77,10 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddAct(
    Node *elewise_add_act_node = CreateFuseElewiseAddActNode(
        g, act, ele_add, ele_x_n, ele_y_n, ele_out_n, act_out_n);
-    VLOG(4) << "\n\t " << ele_x_n << " and " << ele_y_n << " -> "
-            << ele_add->Name() << " -> " << ele_out_n << "\n"
-            << "\t " << ele_out_n << " -> " << act->Name() << " -> "
-            << act_out_n;
+    VLOG(40) << "\n\t " << ele_x_n << " and " << ele_y_n << " -> "
+             << ele_add->Name() << " -> " << ele_out_n << "\n"
+             << "\t " << ele_out_n << " -> " << act->Name() << " -> "
+             << act_out_n;
    ReLinkNodes(g, ele_out, ele_add, act, elewise_add_act_node);
    found_elewise_add_act_count++;
@@ -113,7 +113,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
-    VLOG(4) << "handle FuseElewiseAddAct fuse";
+    VLOG(40) << "handle FuseElewiseAddAct fuse";
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, act_elewise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ele_x, ele_x, act_elewise_add_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out,
@@ -129,9 +129,9 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseActElewiseAdd(
    Node *elewise_add_act_node = CreateFuseElewiseAddActNode(
        g, ele_add, act, elewise_add_x_n, act_i_n, act_o_n, elewise_add_out_n);
-    VLOG(4) << "\n\t " << act_i_n << " -> " << act->Name() << " -> " << act_o_n
-            << "\n\t " << act_o_n << " and " << elewise_add_x_n << " -> "
-            << ele_add->Name() << " -> " << elewise_add_out_n;
+    VLOG(40) << "\n\t " << act_i_n << " -> " << act->Name() << " -> " << act_o_n
+             << "\n\t " << act_o_n << " and " << elewise_add_x_n << " -> "
+             << ele_add->Name() << " -> " << elewise_add_out_n;
    ReLinkNodes(g, act_out, act, ele_add, elewise_add_act_node);
    found_elewise_add_act_count++;
@@ -165,7 +165,7 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
-    VLOG(4) << "handle FuseElewiseAddActGrad1 fuse";
+    VLOG(40) << "handle FuseElewiseAddActGrad1 fuse";
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, elewise_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, elewise_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(d_itermediate_out, d_itermediate_out,
@@ -208,10 +208,10 @@ std::unique_ptr<ir::Graph> FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
    auto fused_node = g->CreateOpNode(&desc);
-    VLOG(4) << "\n\t " << d_act_out_n << " and " << act_out_n << " -> "
-            << act_grad->Name() << " -> " << d_itermediate_out_n << "\n\t "
-            << d_itermediate_out_n << " and " << act_out_n << " -> "
-            << ele_add_grad->Name() << " -> " << d_itermediate_out_n;
+    VLOG(40) << "\n\t " << d_act_out_n << " and " << act_out_n << " -> "
+             << act_grad->Name() << " -> " << d_itermediate_out_n << "\n\t "
+             << d_itermediate_out_n << " and " << act_out_n << " -> "
+             << ele_add_grad->Name() << " -> " << d_itermediate_out_n;
    ReLinkNodes(g, d_itermediate_out, act_grad, ele_add_grad, fused_node);
    found_elewise_add_act_count++;
...
@@ -26,59 +26,58 @@ namespace ir {
namespace {
void CheckProgram(const ProgramDesc &program) {
-  std::map<int, bool> visit;
#define _INT(role) static_cast<int>(role)
-  for (size_t i = 0; i < program.Size(); ++i) {
-    for (OpDesc *op : program.Block(i).AllOps()) {
+  std::map<int, bool> visit;
+  for (OpDesc *op : program.Block(0).AllOps()) {
    // For backward compatibility, some program doesn't have role added.
    if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
-    int role_id = boost::get<int>(
-        op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
+    int role_id =
+        boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
    visit[role_id] = true;
    switch (role_id) {
      case _INT(OpRole::kForward):
        if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
          LOG(ERROR)
              << "Cannot add backward operator before forward operator %s."
              << op->Type();
        }
        break;
      case _INT(OpRole::kBackward):
      case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
        PADDLE_ENFORCE(
            visit.find(_INT(OpRole::kOptimize)) == visit.end(),
-            "Cannot add backward operator %s before optimize operator.",
+            "Cannot add backward operator %s after optimize operator.",
            op->Type());
        break;
      case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
        PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
                                  _INT(OpRole::kLoss)) == visit.end(),
                       "Cannot add backward|loss operator before "
                       "forward|loss operator %s.",
                       op->Type());
        PADDLE_ENFORCE(
            visit.find(_INT(OpRole::kOptimize)) == visit.end(),
            "Cannot add forward|loss operator %s after optimize operator.",
            op->Type());
        break;
      case _INT(OpRole::kOptimize):
      case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
        PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
                       "Optimize operators %s must follow backward operator.",
                       op->Type());
        break;
      case _INT(OpRole::kLRSched):
      case _INT(OpRole::kDist):
      case _INT(OpRole::kRPC):
      case _INT(OpRole::kNotSpecified):
        break;
      default:
        LOG(FATAL) << "Unknown operator role. Don't add new role because "
                      "you don't know what you are doing.";
    }
-    }
  }
#undef _INT
}
}  // namespace
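CheckProgram depends on OpRole values being OR-able bits, so a single role_id can carry both kBackward and kLoss and match the combined case label. A self-contained sketch of that idiom; the numeric values are illustrative, not Paddle's definitions:

#include <cstdint>
#include <iostream>

// Illustrative role bits (Paddle's real values differ).
enum class OpRole : std::uint32_t {
  kForward = 0x0001,
  kBackward = 0x0002,
  kOptimize = 0x0004,
  kLoss = 0x0100,
};

int main() {
  int role_id = static_cast<int>(OpRole::kBackward) |
                static_cast<int>(OpRole::kLoss);
  // The switch above matches this combined value as its own case:
  // case _INT(OpRole::kBackward) | _INT(OpRole::kLoss).
  std::cout << "backward|loss role_id = " << role_id << "\n";
  return 0;
}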
@@ -93,7 +92,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
    const ProgramDesc &program) {
-  VLOG(3) << "block in program:" << program_.Size();
+  VLOG(30) << "block in program:" << program_.Size();
  std::unordered_map<std::string, VarDesc *> all_vars;
  // var nodes for each var name, will have multiple versions in SSA
  std::map<std::string, std::vector<ir::Node *>> var_nodes;
@@ -161,7 +160,7 @@ void Graph::ResolveHazard(
    auto it_old = versions.rbegin();
    ++it_old;
    for (; it_old != versions.rend(); it_new = it_old, ++it_old) {
-      VLOG(3) << "deal with var: " << (*it_new)->Name();
+      VLOG(30) << "deal with var: " << (*it_new)->Name();
      ir::Node *write_op =
          (*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0];
      const auto &read_ops = (*it_old)->outputs;
...
@@ -89,7 +89,7 @@ class Graph {
                   attr_name);
    attrs_[attr_name] = attr;
    attr_dels_[attr_name] = [attr, attr_name]() {
-      VLOG(3) << "deleting " << attr_name;
+      VLOG(30) << "deleting " << attr_name;
      delete attr;
    };
  }
@@ -102,6 +102,15 @@ class Graph {
    attr_dels_[attr_name] = []() {};
  }

+  template <typename AttrType>
+  void Erase(const std::string &attr_name) {
+    PADDLE_ENFORCE(attrs_.count(attr_name) != 0, "%s not set in the graph",
+                   attr_name);
+    attr_dels_[attr_name]();
+    attrs_.erase(attr_name);
+    attr_dels_.erase(attr_name);
+  }
+
  const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }

  // Create a normal variable with non-null VarDesc.
...
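The new Erase completes the attribute lifecycle: Set transfers ownership to the graph, Get borrows, Erase runs the stored deleter and drops the entry. A self-contained sketch of that type-erased map-plus-deleters structure; this mirrors the shape of the code in the diff and is not Paddle code:

#include <cassert>
#include <functional>
#include <map>
#include <string>

class AttrHolder {
 public:
  template <typename T>
  void Set(const std::string& name, T* attr) {
    attrs_[name] = attr;                          // holder now owns attr
    deleters_[name] = [attr]() { delete attr; };  // typed deleter, erased
  }
  template <typename T>
  T& Get(const std::string& name) {
    return *static_cast<T*>(attrs_.at(name));
  }
  void Erase(const std::string& name) {
    assert(attrs_.count(name) != 0 && "attribute not set");
    deleters_[name]();  // run stored deleter
    attrs_.erase(name);
    deleters_.erase(name);
  }
  ~AttrHolder() {
    for (auto& kv : deleters_) kv.second();  // release remaining attrs
  }

 private:
  std::map<std::string, void*> attrs_;
  std::map<std::string, std::function<void()>> deleters_;
};

int main() {
  AttrHolder g;
  g.Set("counter", new int(0));
  ++g.Get<int>("counter");
  g.Erase("counter");  // deleter fires, entry removed
  return 0;
}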
@@ -33,8 +33,9 @@ void SortHelper(
    }
  }

-  VLOG(3) << "topology sort insert: " << node->Name()
-          << reinterpret_cast<void *>(node) << " input " << node->inputs.size();
+  VLOG(30) << "topology sort insert: " << node->Name()
+           << reinterpret_cast<void *>(node) << " input "
+           << node->inputs.size();
  ret->push_back(node);
}
@@ -103,9 +104,9 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
    for (auto &var : n->inputs) {
      for (auto &adj_n : var->inputs) {
        PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
-        VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
-                << " -> " << n->Name() << reinterpret_cast<void *>(n)
-                << " via " << var->Name() << reinterpret_cast<void *>(var);
+        VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
+                 << " -> " << n->Name() << reinterpret_cast<void *>(n)
+                 << " via " << var->Name() << reinterpret_cast<void *>(var);
        adj_list[n].insert(adj_n);
      }
    }
@@ -163,10 +164,10 @@ size_t GraphNum(const Graph &graph) {
    graph_nodes.emplace_back(g_nodes);
  }
-  if (VLOG_IS_ON(10)) {
-    VLOG(10) << "graph_num: " << graph_nodes.size();
+  if (VLOG_IS_ON(100)) {
+    VLOG(100) << "graph_num: " << graph_nodes.size();
    for (auto &g_n : graph_nodes) {
-      VLOG(10) << "graph_nodes: " << g_n.size();
+      VLOG(100) << "graph_nodes: " << g_n.size();
      if (g_n.size() < 10) {
        std::stringstream out;
        for (auto &node : g_n) {
@@ -180,7 +181,7 @@ size_t GraphNum(const Graph &graph) {
        }
        out << "]";
      }
-      VLOG(10) << out.str();
+      VLOG(100) << out.str();
    }
  }
}
...
@@ -37,6 +37,15 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
    const Graph &graph);

+template <typename T>
+std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
+  std::vector<T *> ret;
+  for (ir::Node *n : graph.Nodes()) {
+    if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
+  }
+  return ret;
+}
+
}  // namespace ir
}  // namespace framework
}  // namespace paddle
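FilterByNodeWrapper leans on the wrapper API added to ir::Node later in this diff (WrappedBy/Wrapper/IsWrappedBy): each node type-erases at most one wrapper object and owns it. A self-contained sketch of that idiom using std::any in place of boost::any; illustrative, not Paddle code:

#include <any>
#include <functional>
#include <iostream>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <vector>

class MiniNode {
 public:
  ~MiniNode() {
    if (wrapper_.has_value()) wrapper_deleter_();  // node owns its wrapper
  }
  template <typename T>
  void WrappedBy(T* wrapper) {
    if (wrapper_.has_value()) wrapper_deleter_();  // replace previous wrapper
    wrapper_ = wrapper;
    wrapper_deleter_ = [wrapper]() { delete wrapper; };
    wrapper_type_ = std::type_index(typeid(T));
  }
  template <typename T>
  T& Wrapper() { return *std::any_cast<T*>(wrapper_); }
  template <typename T>
  bool IsWrappedBy() const {
    return std::type_index(typeid(T)) == wrapper_type_;
  }

 private:
  std::any wrapper_;
  std::function<void()> wrapper_deleter_;
  std::type_index wrapper_type_ = std::type_index(typeid(void));
};

struct RunnableOp { int priority = 0; };

// Analogue of FilterByNodeWrapper<T>: pick out the wrappers of type T.
template <typename T>
std::vector<T*> FilterByWrapper(
    const std::vector<std::unique_ptr<MiniNode>>& nodes) {
  std::vector<T*> ret;
  for (auto& n : nodes)
    if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
  return ret;
}

int main() {
  std::vector<std::unique_ptr<MiniNode>> nodes;
  nodes.emplace_back(new MiniNode);
  nodes.back()->WrappedBy(new RunnableOp{7});  // node now owns the wrapper
  std::cout << FilterByWrapper<RunnableOp>(nodes).size() << "\n";  // prints 1
  return 0;
}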
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

+#include <algorithm>
#include <array>
#include <string>
#include <vector>
@@ -91,19 +92,19 @@ void GraphPatternDetector::operator()(Graph *graph,
  PrettyLogEndl(Style::detail(), "---  detect %d subgraphs", subgraphs.size());
  int id = 0;
  for (auto &g : subgraphs) {
-    VLOG(3) << "optimizing #" << id++ << " subgraph";
+    VLOG(30) << "optimizing #" << id++ << " subgraph";
    handler(g, graph);
  }
}

bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
-  VLOG(3) << "mark pdnodes in graph";
+  VLOG(30) << "mark pdnodes in graph";
  if (graph.Nodes().empty()) return false;

  for (auto &node : GraphTraits::DFS(graph)) {
    for (const auto &pdnode : pattern_.nodes()) {
      if (pdnode->Tell(&node)) {
-        VLOG(4) << "pdnode " << pdnode->name() << " marked";
+        VLOG(40) << "pdnode " << pdnode->name() << " marked";
        pdnodes2nodes_[pdnode.get()].insert(&node);
      }
    }
@@ -111,7 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
  // Check to early stop if some PDNode can't find matched Node.
  for (auto &pdnode : pattern_.nodes()) {
    if (!pdnodes2nodes_.count(pdnode.get())) {
-      VLOG(4) << pdnode->name() << " can't find matched Node, early stop";
+      VLOG(40) << pdnode->name() << " can't find matched Node, early stop";
      // return false;
    }
  }
@@ -120,7 +121,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
      GetMarkedNodes(const_cast<Graph *>(&graph)).insert(n);
    }
  }
-  VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
+  VLOG(30) << pdnodes2nodes_.size() << " nodes marked";

  return !pdnodes2nodes_.empty();
}
@@ -213,7 +214,7 @@ GraphPatternDetector::DetectPatterns() {
  // Extend a PDNode to subgraphs by deducing the connection relations defined
  // in edges of PDNodes.
  for (const auto &edge : pattern_.edges()) {
-    VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name();
+    VLOG(40) << "check " << edge.first->name() << " -> " << edge.second->name();
    // TODO(Superjomn) Fix bug here, the groups might be duplicate here.
    // Each role has two PDNodes, which indicates two roles.
    // Detect two Nodes that can match these two roles and they are connected.
@@ -224,7 +225,7 @@ GraphPatternDetector::DetectPatterns() {
    // source -> target
    for (Node *source : pdnodes2nodes_[edge.first]) {
      for (Node *target : pdnodes2nodes_[edge.second]) {
-        VLOG(8) << "check " << source->id() << " -- " << target->id();
+        VLOG(80) << "check " << source->id() << " -- " << target->id();
        // TODO(Superjomn) add some prune strategies.
        for (const auto &group : pre_groups) {
          HitGroup new_group = group;
@@ -240,12 +241,13 @@ GraphPatternDetector::DetectPatterns() {
        }
      }
    }
-    VLOG(3) << "step " << step << " get records: " << cur_groups.size();
+    VLOG(30) << "step " << step << " get records: " << cur_groups.size();
    for (auto &group : cur_groups) {
      for (auto &item : group.roles) {
-        VLOG(4) << "node " << item.second->id() << " as " << item.first->name();
+        VLOG(40) << "node " << item.second->id() << " as "
+                 << item.first->name();
      }
-      VLOG(4) << "=========================================================";
+      VLOG(40) << "=========================================================";
    }
  }
...
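All of the fuse passes in this commit follow the same protocol: the detector enumerates pattern matches and invokes a caller-supplied handler once per match, while the pass counts how many rewrites it performed. A self-contained sketch of that detector/handler shape with toy stand-in types, not Paddle's:

#include <functional>
#include <iostream>
#include <vector>

struct Subgraph { int id; };
struct Graph {};

using Handler = std::function<void(const Subgraph&, Graph*)>;

struct Detector {
  std::vector<Subgraph> matches{{0}, {1}};  // pretend two patterns matched
  void operator()(Graph* g, const Handler& handler) {
    for (const auto& m : matches) handler(m, g);  // one call per match
  }
};

int main() {
  Graph graph;
  Detector detector;
  int found_count = 0;  // passes track fusions with counters like this
  detector(&graph, [&](const Subgraph& sg, Graph* g) {
    std::cout << "handle matched subgraph #" << sg.id << "\n";
    ++found_count;  // rewrite of g would happen here
  });
  std::cout << found_count << " matches\n";
  return 0;
}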
@@ -41,7 +41,7 @@ std::string FormatName(const Node* node) {
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
-  VLOG(3) << "draw IR graph viz to " << graph_viz_path;
+  VLOG(30) << "draw IR graph viz to " << graph_viz_path;
  std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
  PADDLE_ENFORCE(fout->good());
  std::ostream& sout = *fout;
...
@@ -20,7 +20,7 @@ namespace ir {
std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
-  VLOG(3) << "Aplies MKL-DNN placement strategy.";
+  VLOG(30) << "Aplies MKL-DNN placement strategy.";
  for (const Node* n : graph->Nodes()) {
    if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) {
      n->Op()->SetAttr("use_mkldnn", true);
...
@@ -62,7 +62,7 @@ VarDesc UpdateGradVarDesc(
        string::Sprintf("%s.repeat.%d", var_desc->Name(), repeat);
    VarDesc repeated_var = CopyVarDesc(var_desc);
    repeated_var.SetName(new_gname);
-    VLOG(3) << "update " << var_desc->Name() << " to repeat " << repeat;
+    VLOG(30) << "update " << var_desc->Name() << " to repeat " << repeat;
    return repeated_var;
  }
  return *var_desc;
@@ -78,7 +78,7 @@ std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
  std::vector<ir::Node*> nodes = TopologySortOperations(*graph);
  auto origin_nodes = graph->ReleaseNodes();
-  VLOG(3) << "origin nodes count: " << origin_nodes.size();
+  VLOG(30) << "origin nodes count: " << origin_nodes.size();
  ir::Graph& result = *graph;

  // 1. record op nodes of different roles
@@ -137,8 +137,8 @@ std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
            "%s.repeat.%d", repeated_op.Input("Variance")[0], i);
        bn_vars_need_rename.insert(repeated_op.Input("Mean")[0]);
        bn_vars_need_rename.insert(repeated_op.Input("Variance")[0]);
-        VLOG(3) << "renaming " << repeated_op.Input("Mean")[0] << " to "
-                << new_mean_name;
+        VLOG(30) << "renaming " << repeated_op.Input("Mean")[0] << " to "
+                 << new_mean_name;
        repeated_op.RenameInput(repeated_op.Input("Mean")[0], new_mean_name);
        repeated_op.RenameInput(repeated_op.Input("Variance")[0], new_var_name);
        repeated_op.RenameOutput(repeated_op.Output("MeanOut")[0],
...
@@ -15,7 +15,10 @@ limitations under the License. */
#pragma once

#include <string>
+#include <typeindex>
+#include <typeinfo>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h"
@@ -24,9 +27,33 @@ namespace paddle {
namespace framework {
namespace ir {

-// Node should normally created by Graph::CreateXXXNode().
+// Node should only created by Graph::CreateXXXNode().
+// 1. Every Node should be part of a graph. No dangling Node exists.
+// 2. Node only contains members necessary for building graph structure.
+//    It doesn't contain other unrelated members, such as device, etc.
+//
+// Sometimes, for specific usages, Node needs to have additional members,
+// such as device_placement, version in order to be executed. It is suggested
+// to use composition pattern.
+//
+// class RunnableOp {
+//    RunnableOp(ir::Node* n) : n_(n) { n_.WrappedBy(this); }
+//
+//    int any_thing_;
+// }
+//
+// RunnableOp is owned by the ir::Node that composes it. In other words.
+// ir::Node will be responsible for deleting RunnableOp, say, when ir::Node
+// is deleted from the graph.
class Node {
 public:
+  virtual ~Node() {
+    if (!wrapper_.empty()) {
+      VLOG(4) << "ir::Node deleting a wrapper node " << Name();
+      wrapper_deleter_();
+    }
+  }
+
  enum class Type { kOperation, kVariable };
  static constexpr char kControlDepVarName[] = "__control_var";
@@ -44,6 +71,29 @@ class Node {
    return op_desc_.get();
  }

+  // Set the `wrapper` that wraps the Node. `wrapper` is owned by Node.
+  template <typename T>
+  void WrappedBy(T* wrapper) {
+    if (!wrapper_.empty()) {
+      wrapper_deleter_();
+    }
+    wrapper_ = wrapper;
+    wrapper_deleter_ = [wrapper]() { delete wrapper; };
+    wrapper_type_ = std::type_index(typeid(T));
+  }
+
+  // Return a reference to the `wrapper`.
+  template <typename T>
+  T& Wrapper() {
+    return *boost::any_cast<T*>(wrapper_);
+  }
+
+  // Test if the Node is wrapped by type T.
+  template <typename T>
+  bool IsWrappedBy() {
+    return std::type_index(typeid(T)) == wrapper_type_;
+  }
+
  // Please don't use this API!
  int id() const { return id_; }
@@ -95,6 +145,11 @@ class Node {
  static int count_;
  // Please don't use this API or make this public.
  static void ResetId() { count_ = 0; }
+
+  boost::any wrapper_;
+  std::function<void(void)> wrapper_deleter_;
+  std::type_index wrapper_type_ = std::type_index(typeid(void));
+
  DISABLE_COPY_AND_ASSIGN(Node);
};
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class RunnableOp {
public:
RunnableOp(Node* node, bool* alive) : node_(node), alive_(alive) {
node_->WrappedBy(this);
}
virtual ~RunnableOp() { *alive_ = false; }
private:
Node* node_;
bool* alive_;
};
class RunnableOp2 {
public:
RunnableOp2(Node* node, bool* alive) : node_(node), alive_(alive) {
node_->WrappedBy(this);
}
virtual ~RunnableOp2() { *alive_ = false; }
private:
Node* node_;
bool* alive_;
};
TEST(NodeTest, Basic) {
bool alive1 = true;
bool alive2 = true;
std::unique_ptr<Node> n1(CreateNodeForTest("n1", Node::Type::kVariable));
std::unique_ptr<Node> n2(CreateNodeForTest("n2", Node::Type::kVariable));
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp2>());
new RunnableOp(n1.get(), &alive1);
new RunnableOp2(n2.get(), &alive2);
EXPECT_TRUE(n1->IsWrappedBy<RunnableOp>());
EXPECT_FALSE(n1->IsWrappedBy<RunnableOp2>());
EXPECT_FALSE(n2->IsWrappedBy<RunnableOp>());
EXPECT_TRUE(n2->IsWrappedBy<RunnableOp2>());
EXPECT_TRUE(alive1);
EXPECT_TRUE(alive2);
n1.reset(nullptr);
n2.reset(nullptr);
EXPECT_FALSE(alive1);
EXPECT_FALSE(alive2);
}
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -76,7 +76,7 @@ class Pass {
                   attr_name);
    attrs_[attr_name] = attr;
    attr_dels_[attr_name] = [attr, attr_name]() {
-      VLOG(3) << "deleting " << attr_name;
+      VLOG(30) << "deleting " << attr_name;
      delete attr;
    };
  }
...
@@ -12,10 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
+#include <set>
+#include <string>
+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
+#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"

namespace paddle {
@@ -159,10 +162,7 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
  std::set<std::string> acts({"sigmoid", "tanh", "relu", "identity"});
  PDNode* act = pattern->NewNode(
-      [=](Node* x) {
-        return x && x->IsOp() && acts.count(x->Op()->Type());
-      },
+      [=](Node* x) { return x && x->IsOp() && acts.count(x->Op()->Type()); },
      "act");

  PDNode* fc_out = pattern->NewNode(
@@ -196,7 +196,7 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
  detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
                            Graph* graph) {
-    VLOG(4) << "get one concat pattern";
+    VLOG(40) << "get one concat pattern";
    // fc
    GET_NODE(fc_w, detector.pattern());
    GET_NODE(fc_bias, detector.pattern());
...
@@ -60,7 +60,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
-    VLOG(4) << "handle SeqConv EltAdd Relu fuse";
+    VLOG(40) << "handle SeqConv EltAdd Relu fuse";
    GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(seqconv_out, seqconv_out, fuse_pattern);
...
@@ -31,7 +31,7 @@ void LoDRankTable::Reset(const LoD& lod, size_t level) {
    TableItem item;
    item.index = i;
    item.length = vec[i + 1] - vec[i];
-    VLOG(10) << "Add item to rank table " << item.index << " " << item.length;
+    VLOG(100) << "Add item to rank table " << item.index << " " << item.length;
    items_.emplace_back(item);
  }
  // NOTE(yuyang18):
...
@@ -51,7 +51,7 @@ TEST(mixed_vector, InitWithCount) {
TEST(mixed_vector, ForEach) {
  vec<int> tmp;
  for (auto& v : tmp) {
-    VLOG(3) << v;
+    VLOG(30) << v;
  }
}
...
@@ -71,7 +71,7 @@ void NaiveExecutor::Prepare(Scope *parent_scope,
void NaiveExecutor::Run() {
  for (auto &op : ops_) {
-    VLOG(4) << "run " << op->Type();
+    VLOG(40) << "run " << op->Type();
    op->Run(*scope_, place_);
  }
}
@@ -95,21 +95,21 @@ void NaiveExecutor::CreateVariables(const ProgramDesc &desc, Scope *scope,
      if (var->Persistable()) {
        auto *ptr = const_cast<Scope *>(ancestor_scope)->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
-        VLOG(3) << "Create Variable " << var->Name()
-                << " global, which pointer is " << ptr;
+        VLOG(30) << "Create Variable " << var->Name()
+                 << " global, which pointer is " << ptr;
      } else {  // Create temporary variables in local scope.
        auto *ptr = scope->Var(var->Name());
        InitializeVariable(ptr, var->GetType());
-        VLOG(3) << "Create Variable " << var->Name()
-                << " locally, which pointer is " << ptr;
+        VLOG(30) << "Create Variable " << var->Name()
+                 << " locally, which pointer is " << ptr;
      }
    }
  } else {
    for (auto &var : global_block.AllVars()) {
      auto *ptr = scope->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
-      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
-              << ptr;
+      VLOG(30) << "Create variable " << var->Name() << ", which pointer is "
+               << ptr;
    }
  }
}
...
@@ -82,7 +82,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
    if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
-      VLOG(3) << "input " << in << " is not LodTensor";
+      VLOG(30) << "input " << in << " is not LodTensor";
      return;
    }
    out_var->SetLoDLevel(in_var->GetLoDLevel());
@@ -241,32 +241,32 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
    const proto::OpProto::Attr &attr = GetProtoAttr(name);
    switch (attr.type()) {
      case proto::AttrType::BOOLEANS: {
-        VLOG(11) << "SetAttr: " << Type() << ", " << name
-                 << " from INTS to BOOLEANS";
+        VLOG(110) << "SetAttr: " << Type() << ", " << name
+                  << " from INTS to BOOLEANS";
        this->attrs_[name] = std::vector<bool>();
        break;
      }
      case proto::AttrType::INTS: {
-        VLOG(11) << "SetAttr: " << Type() << ", " << name
-                 << " from INTS to INTS";
+        VLOG(110) << "SetAttr: " << Type() << ", " << name
+                  << " from INTS to INTS";
        this->attrs_[name] = std::vector<int>();
        break;
      }
      case proto::AttrType::FLOATS: {
-        VLOG(11) << "SetAttr: " << Type() << ", " << name
-                 << " from INTS to FLOATS";
+        VLOG(110) << "SetAttr: " << Type() << ", " << name
+                  << " from INTS to FLOATS";
        this->attrs_[name] = std::vector<float>();
        break;
      }
      case proto::AttrType::STRINGS: {
-        VLOG(11) << "SetAttr: " << Type() << ", " << name
-                 << " from INTS to STRINGS";
+        VLOG(110) << "SetAttr: " << Type() << ", " << name
+                  << " from INTS to STRINGS";
        this->attrs_[name] = std::vector<std::string>();
        break;
      }
      case proto::AttrType::BLOCKS: {
-        VLOG(11) << "SetAttr: " << Type() << ", " << name
-                 << " from INTS to BLOCKS";
+        VLOG(110) << "SetAttr: " << Type() << ", " << name
+                  << " from INTS to BLOCKS";
        this->SetBlocksAttr(name, std::vector<BlockDesc *>());
        return;
      }
@@ -499,13 +499,13 @@ void OpDesc::CheckAttrs() {
}

void OpDesc::InferShape(const BlockDesc &block) const {
-  VLOG(3) << "CompileTime infer shape on " << Type();
+  VLOG(30) << "CompileTime infer shape on " << Type();
  InitInferShapeFuncs();
  auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
  PADDLE_ENFORCE(static_cast<bool>(infer_shape),
                 "%s's infer_shape has not been registered", this->Type());
  CompileTimeInferShapeContext ctx(*this, block);
-  if (VLOG_IS_ON(10)) {
+  if (VLOG_IS_ON(100)) {
    std::ostringstream sout;
    auto inames = this->InputArgumentNames();
    sout << " From [";
@@ -516,7 +516,7 @@ void OpDesc::InferShape(const BlockDesc &block) const {
    std::copy(onames.begin(), onames.end(),
              std::ostream_iterator<std::string>(sout, ", "));
    sout << "]";
-    VLOG(10) << sout.str();
+    VLOG(100) << sout.str();
  }
  infer_shape(&ctx);
}
@@ -607,7 +607,7 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
    auto shape = var->GetShape();
    res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
  } catch (...) {
-    VLOG(5) << "GetDim of variable " << name << " error";
+    VLOG(50) << "GetDim of variable " << name << " error";
    std::rethrow_exception(std::current_exception());
  }
  return res;
@@ -624,7 +624,7 @@ std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
      res.push_back(s.empty() ? make_ddim({0UL}) : make_ddim(s));
    }
  } catch (...) {
-    VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
+    VLOG(50) << "GetRepeatedDim of variable " << name << " error.";
    std::rethrow_exception(std::current_exception());
  }
  return res;
...
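InferShape guards the expensive string assembly behind VLOG_IS_ON so the ostringstream work is skipped entirely when level 100 is disabled. A self-contained glog sketch of that guard-then-build idiom; the function and names are illustrative:

#include <glog/logging.h>
#include <sstream>
#include <string>
#include <vector>

// Build the log line only when level 100 is actually enabled, as
// OpDesc::InferShape does above.
void LogNames(const std::vector<std::string>& names) {
  if (VLOG_IS_ON(100)) {
    std::ostringstream sout;
    sout << "From [";
    for (const auto& n : names) sout << n << ", ";
    sout << "]";
    VLOG(100) << sout.str();
  }
}

int main(int argc, char* argv[]) {
  google::InitGoogleLogging(argv[0]);
  LogNames({"x", "w", "bias"});
  return 0;
}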
@@ -46,9 +46,9 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
    const proto::OpDesc& op_desc) {
-  VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be"
-             "used in unit tests. Use CreateOp(const OpDesc& op_desc) "
-             "instead.";
+  VLOG(10) << "CreateOp directly from OpDesc is deprecated. It should only be"
+              "used in unit tests. Use CreateOp(const OpDesc& op_desc) "
+              "instead.";
  VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
  VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
  AttributeMap attrs;
...
@@ -140,7 +140,7 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}

void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
-  VLOG(4) << place << " " << DebugStringEx(&scope);
+  VLOG(40) << place << " " << DebugStringEx(&scope);
  if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
    PADDLE_THROW("Cannot run operator on place %s", place);
@@ -160,7 +160,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
  } else {
    RunImpl(scope, place);
  }
-  VLOG(3) << place << " " << DebugStringEx(&scope);
+  VLOG(30) << place << " " << DebugStringEx(&scope);
}

bool OperatorBase::HasInputs(const std::string& name) const {
@@ -259,6 +259,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
      if (row_size >= 0) {
        ss << "[row_size=" << row_size << "]";
      }
+      std::string dtype = GetDtype(*scope, output.second[i]);
+      ss << ":" << dtype;
      ss << "[" << GetDims(*scope, var_name, true) << "]";
      ss << "(" << GetLoD(*scope, var_name) << ")";
    }
@@ -358,7 +360,7 @@ static bool VarIsTensor(const Variable& var) {
  return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
}

-const Tensor* GetTensorFromVar(const Variable& var) {
+const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
  if (var.IsType<LoDTensor>()) {
    return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
  } else if (var.IsType<SelectedRows>()) {
@@ -369,7 +371,7 @@ const Tensor* GetTensorFromVar(const Variable& var) {
  }
}

-static Tensor* GetMutableTensorFromVar(Variable* var) {
+Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return var->GetMutable<LoDTensor>();
  } else if (var->IsType<SelectedRows>()) {
@@ -414,8 +416,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {

template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
-  auto* var = InputVar(name);
-  return var == nullptr ? nullptr : GetTensorFromVar(*var);
+  return Input<LoDTensor>(name);
}

template <>
@@ -425,17 +426,21 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
  std::vector<const Tensor*> res;
  res.reserve(names.size());
  std::transform(names.begin(), names.end(), std::back_inserter(res),
-                 [&](const std::string& sub_name) {
+                 [&](const std::string& sub_name) -> const Tensor* {
                   auto var = scope_.FindVar(sub_name);
-                   return var == nullptr ? nullptr : GetTensorFromVar(*var);
+                   if (var == nullptr) return nullptr;
+                   PADDLE_ENFORCE(
+                       var->IsType<LoDTensor>(),
+                       "%s should be LoDTensor, but the received type is %s",
+                       sub_name, var->Type().name());
+                   return &(var->Get<LoDTensor>());
                 });
  return res;
}

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
-  auto var = OutputVar(name);
-  return var == nullptr ? nullptr : GetMutableTensorFromVar(var);
+  return Output<LoDTensor>(name);
}

template <>
@@ -445,10 +450,14 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
  std::vector<Tensor*> res;
  res.reserve(names.size());
  std::transform(names.begin(), names.end(), std::back_inserter(res),
-                 [&](const std::string& sub_name) {
+                 [&](const std::string& sub_name) -> Tensor* {
                   auto var = scope_.FindVar(sub_name);
-                   return var == nullptr ? nullptr
-                                         : GetMutableTensorFromVar(var);
+                   if (var == nullptr) return nullptr;
+                   PADDLE_ENFORCE(
+                       var->IsType<LoDTensor>(),
+                       "%s should be LoDTensor, but the received type is %s",
+                       sub_name, var->Type().name());
+                   return var->GetMutable<LoDTensor>();
                 });
  return res;
}
@@ -708,14 +717,14 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
  auto expected_kernel_key =
      this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
-  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
+  VLOG(30) << "expected_kernel_key:" << expected_kernel_key;

  auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
  // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
  if (kernel_iter == kernels.end() &&
      expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
-    VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
+    VLOG(30) << "missing MKLDNN kernel: fallbacking to PLAIN one";
    expected_kernel_key.library_type_ = LibraryType::kPlain;
    expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
    kernel_iter = kernels.find(expected_kernel_key);
@@ -767,12 +776,14 @@ void OperatorWithKernel::TransferInplaceVarsBack(
    const Scope& scope, const std::vector<std::string>& inplace_vars,
    const Scope& transfer_scope) const {
  for (auto& var_name : inplace_vars) {
-    VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
-    auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name));
+    VLOG(30) << "share inplace var " + var_name +
+                    " back to it's original scope";
+    auto* original_tensor =
+        GetMutableLoDTensorOrSelectedRowsValueFromVar(scope.FindVar(var_name));
    auto* var = transfer_scope.FindVar(var_name);
    PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr",
                   var_name);
-    auto* transformed_tensor = GetTensorFromVar(*var);
+    auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
    original_tensor->ShareDataWith(*transformed_tensor);
  }
}
@@ -789,7 +800,7 @@ Scope* OperatorWithKernel::TryTransferData(
      continue;
    }

-    auto* tensor_in = GetTensorFromVar(*var);
+    auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
    if (!tensor_in->IsInitialized()) {
      continue;
    }
@@ -807,8 +818,8 @@ Scope* OperatorWithKernel::TryTransferData(
      transfered_inplace_vars->emplace_back(var_name);
    }

-    VLOG(3) << "Transform Variable " << var_name << " from "
-            << kernel_type_for_var << " to " << expected_kernel_key;
+    VLOG(30) << "Transform Variable " << var_name << " from "
+             << kernel_type_for_var << " to " << expected_kernel_key;

    if (new_scope == nullptr) {
      new_scope = &scope.NewScope();
...
@@ -54,6 +54,9 @@ constexpr char kGradVarSuffix[] = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";

+/// Variables with this suffix are the new Gradient.
+constexpr char kNewGradSuffix[] = "@NEWGRAD@";
+
// define some kernel priority
/* Define multiple kernel type fallback order*/
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
@@ -63,7 +66,8 @@ inline std::string GradVarName(const std::string& var_name) {
}

proto::VarType::Type GetDataTypeOfVar(const Variable* var);
-const Tensor* GetTensorFromVar(const Variable& var);
+const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
+Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);

class OperatorBase;
class ExecutionContext;
@@ -224,7 +228,7 @@ class ExecutionContext {
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
-                   [&](const std::string& sub_name) {
+                   [&](const std::string& sub_name) -> const T* {
                     auto var = scope_.FindVar(sub_name);
                     return var == nullptr ? nullptr : &var->Get<T>();
                   });
@@ -237,7 +241,7 @@ class ExecutionContext {
    std::vector<T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
-                   [&](const std::string& sub_name) {
+                   [&](const std::string& sub_name) -> T* {
                     auto var = scope_.FindVar(sub_name);
                     return var == nullptr ? nullptr : var->GetMutable<T>();
                   });
...
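These suffix constants encode gradient naming by string convention: a gradient variable is its forward variable's name plus a suffix. A self-contained sketch; the suffix strings come from the header above, while the GradVarName body is an assumption since only its signature is shown:

#include <iostream>
#include <string>

constexpr char kGradVarSuffix[] = "@GRAD";      // from the header above
constexpr char kNewGradSuffix[] = "@NEWGRAD@";  // added by this commit

// Assumed body; the header here only shows the signature.
inline std::string GradVarName(const std::string& var_name) {
  return var_name + kGradVarSuffix;
}

int main() {
  std::cout << GradVarName("fc_0.w_0") << "\n";  // prints fc_0.w_0@GRAD
  return 0;
}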
...@@ -38,9 +38,20 @@ class ParallelExecutorPrivate { ...@@ -38,9 +38,20 @@ class ParallelExecutorPrivate {
explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places) explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
: places_(places) {} : places_(places) {}
~ParallelExecutorPrivate() {
if (own_local_scope_) {
for (size_t i = 1; i < local_scopes_.size(); ++i) {
// Skip the first scope, since it is the global scope.
Scope *local_scope = local_scopes_[i];
if (global_scope_->HasKid(local_scope)) {
global_scope_->DeleteScope(local_scope);
}
}
}
}
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
Scope *global_scope_; Scope *global_scope_; // not owned
std::unique_ptr<details::SSAGraphExecutor> executor_; std::unique_ptr<details::SSAGraphExecutor> executor_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -188,7 +199,7 @@ void ParallelExecutor::BCastParamsToDevices( ...@@ -188,7 +199,7 @@ void ParallelExecutor::BCastParamsToDevices(
auto &main_tensor = main_var->Get<LoDTensor>(); auto &main_tensor = main_var->Get<LoDTensor>();
if (!main_tensor.IsInitialized()) { if (!main_tensor.IsInitialized()) {
VLOG(3) << "one in var not inited, return!"; VLOG(30) << "one in var not inited, return!";
continue; continue;
} }
auto &dims = main_tensor.dims(); auto &dims = main_tensor.dims();
...@@ -306,16 +317,6 @@ ParallelExecutor::~ParallelExecutor() { ...@@ -306,16 +317,6 @@ ParallelExecutor::~ParallelExecutor() {
for (auto &p : member_->places_) { for (auto &p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait(); platform::DeviceContextPool::Instance().Get(p)->Wait();
} }
if (member_->own_local_scope_) {
for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
Scope *local_scope = member_->local_scopes_[i];
if (member_->global_scope_->HasKid(local_scope)) {
member_->global_scope_->DeleteScope(local_scope);
}
}
}
// member_ must be destructed before gcs_ since the destructor of // member_ must be destructed before gcs_ since the destructor of
// ReferenceCountOpHandle use raw pointers of gcs_ inside. // ReferenceCountOpHandle use raw pointers of gcs_ inside.
member_.reset(); member_.reset();
......
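This hunk moves local-scope cleanup from ~ParallelExecutor (removed further down) into ~ParallelExecutorPrivate, so the scopes are released by the object that owns them, and global_scope_ is annotated as not owned. A self-contained sketch of the ownership pattern, with hypothetical stand-ins for Scope and the private struct:

#include <algorithm>
#include <vector>

// Hypothetical stand-in for framework::Scope.
struct Scope {
  std::vector<Scope*> kids;
  ~Scope() {
    for (Scope* k : kids) delete k;
  }
  Scope* NewKid() {
    kids.push_back(new Scope);
    return kids.back();
  }
  bool HasKid(const Scope* s) const {
    return std::find(kids.begin(), kids.end(), s) != kids.end();
  }
  void DeleteScope(Scope* s) {
    kids.erase(std::remove(kids.begin(), kids.end(), s), kids.end());
    delete s;
  }
};

struct ExecutorPrivate {
  Scope* global_scope_ = nullptr;  // not owned
  std::vector<Scope*> local_scopes_;
  bool own_local_scope_ = false;
  ~ExecutorPrivate() {
    if (!own_local_scope_) return;
    // Skip index 0: by convention it aliases the global scope.
    for (size_t i = 1; i < local_scopes_.size(); ++i) {
      if (global_scope_->HasKid(local_scopes_[i])) {
        global_scope_->DeleteScope(local_scopes_[i]);
      }
    }
  }
};

int main() {
  Scope global;
  ExecutorPrivate p;
  p.global_scope_ = &global;
  p.own_local_scope_ = true;
  p.local_scopes_ = {&global, global.NewKid(), global.NewKid()};
  return 0;  // ~ExecutorPrivate releases the two kid scopes
}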
...@@ -149,7 +149,7 @@ Variable* Scope::VarInternal(const std::string& name) { ...@@ -149,7 +149,7 @@ Variable* Scope::VarInternal(const std::string& name) {
v = new Variable(); v = new Variable();
vars_[name].reset(v); vars_[name].reset(v);
VLOG(3) << "Create variable " << name; VLOG(30) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first); v->name_ = &(vars_.find(name)->first);
return v; return v;
} }
......
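From this file onward the commit renumbers glog verbosity levels by a factor of ten (VLOG(3) becomes VLOG(30), VLOG(4) becomes VLOG(40), VLOG(5) becomes VLOG(50)), so a binary now needs --v=30 or GLOG_v=30 to emit what --v=3 used to. A minimal glog sketch of the convention, assuming glog is linked:

#include <glog/logging.h>

int main(int argc, char* argv[]) {
  (void)argc;
  google::InitGoogleLogging(argv[0]);
  FLAGS_v = 30;  // same effect as passing --v=30 or exporting GLOG_v=30
  VLOG(30) << "Create variable x";        // emitted: 30 <= FLAGS_v
  VLOG(40) << "per-node tracing detail";  // suppressed: 40 > FLAGS_v
  return 0;
}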
...@@ -176,7 +176,7 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value, ...@@ -176,7 +176,7 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
PADDLE_ENFORCE(value->IsInitialized(), PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized."); "The value tensor should be initialized.");
if (ids.numel() == 0) { if (ids.numel() == 0) {
VLOG(3) << "keys is empty, please check data!"; VLOG(30) << "keys is empty, please check data!";
} else { } else {
int64_t value_width = value_->numel() / value_->dims()[0]; int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
......
...@@ -22,8 +22,8 @@ namespace framework { ...@@ -22,8 +22,8 @@ namespace framework {
void TensorCopy(const Tensor& src, const platform::Place& dst_place, void TensorCopy(const Tensor& src, const platform::Place& dst_place,
const platform::DeviceContext& ctx, Tensor* dst) { const platform::DeviceContext& ctx, Tensor* dst) {
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " VLOG(30) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place; << dst_place;
src.check_memory_size(); src.check_memory_size();
dst->Resize(src.dims()); dst->Resize(src.dims());
...@@ -37,8 +37,8 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, ...@@ -37,8 +37,8 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
if (src_ptr == dst_ptr) { if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to " VLOG(30) << "Skip copy the same data async from " << src_place << " to "
<< dst_place; << dst_place;
return; return;
} }
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
...@@ -77,8 +77,8 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, ...@@ -77,8 +77,8 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream(); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
if (platform::is_same_place(src_place, dst_place)) { if (platform::is_same_place(src_place, dst_place)) {
if (src_ptr == dst_ptr) { if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to " VLOG(30) << "Skip copy the same data async from " << src_place << " to "
<< dst_place; << dst_place;
return; return;
} }
memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
...@@ -114,8 +114,8 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, ...@@ -114,8 +114,8 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
void TensorCopySync(const Tensor& src, const platform::Place& dst_place, void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
Tensor* dst) { Tensor* dst) {
VLOG(3) << "TensorCopySync " << src.dims() << " from " << src.place() VLOG(30) << "TensorCopySync " << src.dims() << " from " << src.place()
<< " to " << dst_place; << " to " << dst_place;
src.check_memory_size(); src.check_memory_size();
dst->Resize(src.dims()); dst->Resize(src.dims());
dst->set_layout(src.layout()); dst->set_layout(src.layout());
...@@ -125,8 +125,8 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -125,8 +125,8 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
auto size = src.numel() * SizeOfType(src.type()); auto size = src.numel() * SizeOfType(src.type());
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
if (src_ptr == dst_ptr) { if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data from " << src_place << " to " VLOG(30) << "Skip copy the same data from " << src_place << " to "
<< dst_place; << dst_place;
return; return;
} }
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
...@@ -146,8 +146,8 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -146,8 +146,8 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
} else if (platform::is_gpu_place(src_place) && } else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) { platform::is_gpu_place(dst_place)) {
if (src_ptr == dst_ptr && platform::is_same_place(src_place, dst_place)) { if (src_ptr == dst_ptr && platform::is_same_place(src_place, dst_place)) {
VLOG(3) << "Skip copy the same data from " << src_place << " to " VLOG(30) << "Skip copy the same data from " << src_place << " to "
<< dst_place; << dst_place;
return; return;
} }
auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place); auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
......
...@@ -39,7 +39,7 @@ void ThreadPool::Init() { ...@@ -39,7 +39,7 @@ void ThreadPool::Init() {
int num_threads = std::thread::hardware_concurrency(); int num_threads = std::thread::hardware_concurrency();
if (FLAGS_dist_threadpool_size > 0) { if (FLAGS_dist_threadpool_size > 0) {
num_threads = FLAGS_dist_threadpool_size; num_threads = FLAGS_dist_threadpool_size;
VLOG(1) << "set dist_threadpool_size to " << num_threads; VLOG(10) << "set dist_threadpool_size to " << num_threads;
} }
PADDLE_ENFORCE_GT(num_threads, 0); PADDLE_ENFORCE_GT(num_threads, 0);
threadpool_.reset(new ThreadPool(num_threads)); threadpool_.reset(new ThreadPool(num_threads));
...@@ -57,10 +57,10 @@ ThreadPool::ThreadPool(int num_threads) : running_(true) { ...@@ -57,10 +57,10 @@ ThreadPool::ThreadPool(int num_threads) : running_(true) {
ThreadPool::~ThreadPool() { ThreadPool::~ThreadPool() {
{ {
// notify all threads to stop running // notify all threads to stop running
std::lock_guard<std::mutex> l(mutex_); std::unique_lock<std::mutex> l(mutex_);
running_ = false; running_ = false;
scheduled_.notify_all();
} }
scheduled_.notify_all();
for (auto& t : threads_) { for (auto& t : threads_) {
t->join(); t->join();
...@@ -70,19 +70,25 @@ ThreadPool::~ThreadPool() { ...@@ -70,19 +70,25 @@ ThreadPool::~ThreadPool() {
void ThreadPool::TaskLoop() { void ThreadPool::TaskLoop() {
while (true) { while (true) {
std::unique_lock<std::mutex> lock(mutex_); Task task;
scheduled_.wait( {
lock, [this] { return !this->tasks_.empty() || !this->running_; }); std::unique_lock<std::mutex> lock(mutex_);
scheduled_.wait(
lock, [this] { return !this->tasks_.empty() || !this->running_; });
if (!running_ || tasks_.empty()) { if (!running_ && tasks_.empty()) {
return; return;
} }
if (tasks_.empty()) {
PADDLE_THROW("This thread has no task to Run");
}
// pop a task from the task queue // pop a task from the task queue
auto task = std::move(tasks_.front()); task = std::move(tasks_.front());
tasks_.pop(); tasks_.pop();
lock.unlock(); }
// run the task // run the task
task(); task();
......
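Three fixes land in the ThreadPool hunks: the destructor now calls notify_all() after releasing the mutex so woken workers can immediately take it; the exit test changes from "!running_ || tasks_.empty()" to "!running_ && tasks_.empty()", so a stopping pool still drains queued work instead of dropping it; and the popped task runs outside the lock. A condensed, self-contained sketch of the corrected protocol (a hypothetical MiniPool, not the Paddle class):

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

class MiniPool {
 public:
  explicit MiniPool(int n) {
    for (int i = 0; i < n; ++i) workers_.emplace_back([this] { Loop(); });
  }
  ~MiniPool() {
    {
      std::lock_guard<std::mutex> l(mu_);
      running_ = false;
    }
    cv_.notify_all();  // after unlocking, so workers can grab the mutex
    for (auto& t : workers_) t.join();
  }
  void Run(std::function<void()> f) {
    {
      std::lock_guard<std::mutex> l(mu_);
      tasks_.push(std::move(f));
    }
    cv_.notify_one();
  }

 private:
  void Loop() {
    while (true) {
      std::function<void()> task;
      {
        std::unique_lock<std::mutex> l(mu_);
        cv_.wait(l, [this] { return !running_ || !tasks_.empty(); });
        if (!running_ && tasks_.empty()) return;  // stopped AND drained
        task = std::move(tasks_.front());
        tasks_.pop();
      }
      task();  // run outside the lock
    }
  }
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> tasks_;
  bool running_ = true;
  std::vector<std::thread> workers_;
};

int main() {
  MiniPool pool(2);
  for (int i = 0; i < 4; ++i) pool.Run([i] { std::cout << i << "\n"; });
  return 0;  // the destructor drains remaining tasks before joining
}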
...@@ -58,7 +58,7 @@ class ThreadPool { ...@@ -58,7 +58,7 @@ class ThreadPool {
~ThreadPool(); ~ThreadPool();
// Run pushes a function to the task queue and returns a std::future // Run pushes a function to the task queue and returns a std::future
// object. To wait for the completion of the task, call // object. To wait for the completion of the task, call
// std::future::wait(). // std::future::wait().
template <typename Callback> template <typename Callback>
std::future<void> Run(Callback fn) { std::future<void> Run(Callback fn) {
...@@ -69,7 +69,6 @@ class ThreadPool { ...@@ -69,7 +69,6 @@ class ThreadPool {
template <typename Callback> template <typename Callback>
std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException( std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
Callback fn) { Callback fn) {
std::unique_lock<std::mutex> lock(mutex_);
Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> { Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
try { try {
fn(); fn();
...@@ -84,7 +83,13 @@ class ThreadPool { ...@@ -84,7 +83,13 @@ class ThreadPool {
return nullptr; return nullptr;
}); });
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future(); std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
tasks_.push(std::move(task)); {
std::unique_lock<std::mutex> lock(mutex_);
if (!running_) {
PADDLE_THROW("enqueue on stopped ThreadPool");
}
tasks_.push(std::move(task));
}
scheduled_.notify_one(); scheduled_.notify_one();
return f; return f;
} }
......
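The header-side change is the complement: Run/RunAndGetException now construct the packaged task before locking, hold the mutex only around the queue push, and throw if a task is enqueued after shutdown has begun, turning a silently lost task into a loud error.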
...@@ -61,10 +61,10 @@ size_t VarDesc::GetTensorDescNum() const { ...@@ -61,10 +61,10 @@ size_t VarDesc::GetTensorDescNum() const {
void VarDesc::SetShapes( void VarDesc::SetShapes(
const std::vector<std::vector<int64_t>> &multiple_dims) { const std::vector<std::vector<int64_t>> &multiple_dims) {
if (multiple_dims.size() != GetTensorDescNum()) { if (multiple_dims.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size() VLOG(30) << "WARNING: The number of given shapes(" << multiple_dims.size()
<< ") doesn't match the existing tensor number(" << ") doesn't match the existing tensor number("
<< GetTensorDescNum() << GetTensorDescNum()
<< "). The Reader is going to be reinitialized."; << "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_dims.size()); SetTensorDescNum(multiple_dims.size());
} }
std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs(); std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
...@@ -94,11 +94,11 @@ void VarDesc::SetDataType(proto::VarType::Type data_type) { ...@@ -94,11 +94,11 @@ void VarDesc::SetDataType(proto::VarType::Type data_type) {
void VarDesc::SetDataTypes( void VarDesc::SetDataTypes(
const std::vector<proto::VarType::Type> &multiple_data_type) { const std::vector<proto::VarType::Type> &multiple_data_type) {
if (multiple_data_type.size() != GetTensorDescNum()) { if (multiple_data_type.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given data types(" VLOG(30) << "WARNING: The number of given data types("
<< multiple_data_type.size() << multiple_data_type.size()
<< ") doesn't match the existing tensor number(" << ") doesn't match the existing tensor number("
<< GetTensorDescNum() << GetTensorDescNum()
<< "). The Reader is going to be reinitialized."; << "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_data_type.size()); SetTensorDescNum(multiple_data_type.size());
} }
std::vector<proto::VarType::TensorDesc *> tensor_descs = std::vector<proto::VarType::TensorDesc *> tensor_descs =
...@@ -139,11 +139,11 @@ void VarDesc::SetLoDLevel(int32_t lod_level) { ...@@ -139,11 +139,11 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) { void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
if (multiple_lod_level.size() != GetTensorDescNum()) { if (multiple_lod_level.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given lod_levels(" VLOG(30) << "WARNING: The number of given lod_levels("
<< multiple_lod_level.size() << multiple_lod_level.size()
<< ") doesn't match the existing tensor number(" << ") doesn't match the existing tensor number("
<< GetTensorDescNum() << GetTensorDescNum()
<< "). The Reader is going to be reinitialized."; << "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_lod_level.size()); SetTensorDescNum(multiple_lod_level.size());
} }
switch (desc_.type().type()) { switch (desc_.type().type()) {
......
...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
namespace paddle { namespace paddle {
...@@ -24,5 +27,27 @@ class VarTypeInference { ...@@ -24,5 +27,27 @@ class VarTypeInference {
virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0; virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
}; };
class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const final {
auto in_out_var_names = this->GetInputOutputWithSameType();
for (auto& i_o_n : in_out_var_names) {
auto& x_name = op_desc.Input(i_o_n.first).at(0);
auto& out_name = op_desc.Output(i_o_n.second).at(0);
auto& x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
out.SetDataType(x.GetDataType());
}
}
protected:
virtual std::unordered_map<std::string, std::string>
GetInputOutputWithSameType() const = 0;
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
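The new PassInDtypeAndVarTypeToOutput helper lets an operator declare once which input/output pairs must share a variable type and dtype, instead of repeating the copy logic in every op. A self-contained mini-version of the pattern, with hypothetical VarDesc/Block types standing in for the framework classes:

#include <cassert>
#include <map>
#include <string>
#include <unordered_map>

struct VarDesc {
  std::string type, dtype;
};

struct Block {
  std::map<std::string, VarDesc> vars;
  VarDesc& FindRecursiveOrCreateVar(const std::string& n) { return vars[n]; }
};

class PassInDtypeAndVarTypeToOutput {
 public:
  // Copy type and dtype from each declared input to its output.
  void operator()(Block* block) const {
    for (const auto& io : GetInputOutputWithSameType()) {
      const VarDesc& x = block->FindRecursiveOrCreateVar(io.first);
      VarDesc& out = block->FindRecursiveOrCreateVar(io.second);
      out.type = x.type;
      out.dtype = x.dtype;
    }
  }

 protected:
  virtual std::unordered_map<std::string, std::string>
  GetInputOutputWithSameType() const = 0;
};

// An op only states which pairs must agree.
class SigmoidVarTypeInference final : public PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return {{"X", "Out"}};
  }
};

int main() {
  Block block;
  block.FindRecursiveOrCreateVar("X") = {"lod_tensor", "float32"};
  SigmoidVarTypeInference infer;
  infer(&block);
  assert(block.vars["Out"].dtype == "float32");
  return 0;
}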
if(WITH_TESTING) if(WITH_TESTING)
include(test.cmake) # some generic cmake funtion for inference include(tests/test.cmake) # some generic cmake funtion for inference
endif() endif()
# analysis and tensorrt must be added before creating static library, # analysis and tensorrt must be added before creating static library,
# otherwise, there would be undefined reference to them in static library. # otherwise, there would be undefined reference to them in static library.
......
...@@ -60,7 +60,7 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -60,7 +60,7 @@ class DfgPassManagerImpl final : public DfgPassManager {
private: private:
void AddPass(const std::string& name, AnalysisPass* pass) { void AddPass(const std::string& name, AnalysisPass* pass) {
VLOG(3) << "Adding pass " << name; VLOG(30) << "Adding pass " << name;
Register(name, pass); Register(name, pass);
AddGraphvizDebugerPass(pass); AddGraphvizDebugerPass(pass);
} }
...@@ -101,22 +101,25 @@ Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); } ...@@ -101,22 +101,25 @@ Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) { void Analyzer::Run(Argument* argument) {
std::vector<std::string> passes; std::vector<std::string> passes;
passes.push_back("graph_viz_pass"); // add graphviz for debug.
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (use_mkldnn_) { if (use_mkldnn_) {
VLOG(3) << "Adding MKL-DNN placement pass"; VLOG(30) << "Adding MKL-DNN placement pass";
passes.push_back("mkldnn_placement_pass"); passes.push_back("mkldnn_placement_pass");
} }
#endif #endif
// infer_clean_graph_pass should be the first default pass // infer_clean_graph_pass should be the first default pass
// after mkldnn_placement_pass. // after mkldnn_placement_pass.
passes.push_back("infer_clean_graph_pass"); passes.push_back("infer_clean_graph_pass");
passes.push_back("graph_viz_pass"); // add graphviz for debug.
for (auto& pass : ir_passes_) { for (auto& pass : ir_passes_) {
if (!disabled_ir_passes_.count(pass)) { // skip mkldnn pass when use_mkldnn_ = false;
bool skip_pass = (!use_mkldnn_) && pass.find("mkldnn") != std::string::npos;
if (!disabled_ir_passes_.count(pass) && !skip_pass) {
passes.push_back(pass); passes.push_back(pass);
passes.push_back("graph_viz_pass"); // add graphviz for debug. passes.push_back("graph_viz_pass"); // add graphviz for debug.
} }
} }
passes.push_back("graph_viz_pass");
argument->Set(kFluidToIrPassesAttr, new std::vector<std::string>(passes)); argument->Set(kFluidToIrPassesAttr, new std::vector<std::string>(passes));
for (auto& x : data_) { for (auto& x : data_) {
......
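Two behavioral changes hide in this hunk besides the logging: graph_viz_pass is now pushed first, so the initial graph is dumped before any transform runs, and IR passes whose names contain "mkldnn" are skipped entirely when use_mkldnn_ is false rather than depending on the disabled set. A small sketch of that selection logic with made-up pass names:

#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  const bool use_mkldnn = false;
  const std::set<std::string> disabled = {"fc_fuse_pass"};
  const std::vector<std::string> ir_passes = {
      "conv_bn_fuse_pass", "conv_relu_mkldnn_fuse_pass", "fc_fuse_pass"};

  std::vector<std::string> passes = {"graph_viz_pass",  // dump initial graph
                                     "infer_clean_graph_pass"};
  for (const auto& pass : ir_passes) {
    const bool skip =
        !use_mkldnn && pass.find("mkldnn") != std::string::npos;
    if (!disabled.count(pass) && !skip) {
      passes.push_back(pass);
      passes.push_back("graph_viz_pass");  // dump after every kept pass
    }
  }
  for (const auto& p : passes) std::cout << p << "\n";
  return 0;  // keeps only conv_bn_fuse_pass here
}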
...@@ -68,8 +68,8 @@ struct Argument { ...@@ -68,8 +68,8 @@ struct Argument {
key); key);
attrs_[key] = data; attrs_[key] = data;
attr_deleters_[key] = [data, key]() { attr_deleters_[key] = [data, key]() {
VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; VLOG(30) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
VLOG(3) << "argument delete attr: " << key; VLOG(30) << "argument delete attr: " << key;
delete data; delete data;
}; };
} }
......
...@@ -132,7 +132,7 @@ void DataFlowGraph::Build(const framework::ir::Graph &graph) { ...@@ -132,7 +132,7 @@ void DataFlowGraph::Build(const framework::ir::Graph &graph) {
Node *x{nullptr}; Node *x{nullptr};
if (ir_node->IsOp()) { if (ir_node->IsOp()) {
PADDLE_ENFORCE(ir_node->Op()); PADDLE_ENFORCE(ir_node->Op());
VLOG(4) << "get op " << ir_node << " " << ir_node->Name(); VLOG(40) << "get op " << ir_node << " " << ir_node->Name();
x = nodes.Create(Node::Type::kFunction); x = nodes.Create(Node::Type::kFunction);
x->attr("ir_node").Pointer() = ir_node; x->attr("ir_node").Pointer() = ir_node;
PADDLE_ENFORCE(ir_node->Op()->Proto()); PADDLE_ENFORCE(ir_node->Op()->Proto());
...@@ -141,7 +141,7 @@ void DataFlowGraph::Build(const framework::ir::Graph &graph) { ...@@ -141,7 +141,7 @@ void DataFlowGraph::Build(const framework::ir::Graph &graph) {
} else if (ir_node->IsVar()) { } else if (ir_node->IsVar()) {
// Not create a Node for IR ControlDepVar, considering Inference currently // Not create a Node for IR ControlDepVar, considering Inference currently
// just used in single thread scenerio. // just used in single thread scenerio.
VLOG(4) << "get var " << ir_node->Name(); VLOG(40) << "get var " << ir_node->Name();
x = nodes.Create(Node::Type::kValue); x = nodes.Create(Node::Type::kValue);
x->attr("ir_node").Pointer() = ir_node; x->attr("ir_node").Pointer() = ir_node;
x->SetName(ir_node->Name()); x->SetName(ir_node->Name());
...@@ -151,9 +151,9 @@ void DataFlowGraph::Build(const framework::ir::Graph &graph) { ...@@ -151,9 +151,9 @@ void DataFlowGraph::Build(const framework::ir::Graph &graph) {
} }
ir_node_map.emplace(ir_node, x); ir_node_map.emplace(ir_node, x);
} }
VLOG(4) << "finish creating Nodes"; VLOG(40) << "finish creating Nodes";
VLOG(4) << "to create edge"; VLOG(40) << "to create edge";
// Create links // Create links
for (auto *ir_node : graph.Nodes()) { for (auto *ir_node : graph.Nodes()) {
auto it = ir_node_map.find(ir_node); auto it = ir_node_map.find(ir_node);
...@@ -175,7 +175,7 @@ void DataFlowGraph::Build(const framework::ir::Graph &graph) { ...@@ -175,7 +175,7 @@ void DataFlowGraph::Build(const framework::ir::Graph &graph) {
"Can't deduce any inputs from the graph, Is the graph empty?"); "Can't deduce any inputs from the graph, Is the graph empty?");
ir_graph = &graph; ir_graph = &graph;
VLOG(3) << "finished build from IR"; VLOG(30) << "finished build from IR";
} }
void DataFlowGraph::Clean() { void DataFlowGraph::Clean() {
......
...@@ -239,9 +239,10 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) { ...@@ -239,9 +239,10 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
framework::BlockDesc block_desc(nullptr, &proto); framework::BlockDesc block_desc(nullptr, &proto);
block_desc.Proto()->set_parent_idx(-1); block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0); block_desc.Proto()->set_idx(0);
VLOG(4) << "origin variable size: " VLOG(40) << "origin variable size: "
<< argument_->origin_program_desc->blocks(0).vars().size(); << argument_->origin_program_desc->blocks(0).vars().size();
VLOG(4) << "transformed variable size: " << block_desc.Proto()->vars().size(); VLOG(40) << "transformed variable size: "
<< block_desc.Proto()->vars().size();
// copy ops. // copy ops.
for (auto *node : block_node->subgraph) { for (auto *node : block_node->subgraph) {
......
...@@ -29,7 +29,7 @@ void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) { ...@@ -29,7 +29,7 @@ void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) {
auto png_path = dot_path.substr(0, dot_path.size() - 4) + ".png"; auto png_path = dot_path.substr(0, dot_path.size() - 4) + ".png";
std::string message; std::string message;
VLOG(3) << "draw to " << png_path; VLOG(30) << "draw to " << png_path;
ExecShellCommand("dot -Tpng " + dot_path + " -o " + png_path, &message); ExecShellCommand("dot -Tpng " + dot_path + " -o " + png_path, &message);
} }
......
...@@ -29,7 +29,7 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir, ...@@ -29,7 +29,7 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir,
PADDLE_ENFORCE(argument_); PADDLE_ENFORCE(argument_);
argument_->Set(framework::ir::kParamScopeAttr, new framework::Scope); argument_->Set(framework::ir::kParamScopeAttr, new framework::Scope);
// Load parameters. // Load parameters.
VLOG(3) << "Loading parameters from " << model_dir; VLOG(30) << "Loading parameters from " << model_dir;
LoadParams(&argument_->Get<framework::Scope>(framework::ir::kParamScopeAttr), LoadParams(&argument_->Get<framework::Scope>(framework::ir::kParamScopeAttr),
model_dir, prog_file, param_file); model_dir, prog_file, param_file);
} }
......
...@@ -35,21 +35,21 @@ void ModelStorePass::Run(DataFlowGraph *x) { ...@@ -35,21 +35,21 @@ void ModelStorePass::Run(DataFlowGraph *x) {
std::stringstream ss; std::stringstream ss;
// NOTE these commands only works on linux. // NOTE these commands only works on linux.
ss << "mkdir -p " << *argument_->model_output_store_path; ss << "mkdir -p " << *argument_->model_output_store_path;
VLOG(3) << "run command: " << ss.str(); VLOG(30) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0); PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
ss.str(""); ss.str("");
ss << "cp " << *argument_->fluid_model_dir << "/*" ss << "cp " << *argument_->fluid_model_dir << "/*"
<< " " << *argument_->model_output_store_path; << " " << *argument_->model_output_store_path;
VLOG(3) << "run command: " << ss.str(); VLOG(30) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0); PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
// Store program // Store program
PADDLE_ENFORCE_NOT_NULL(argument_->transformed_program_desc, PADDLE_ENFORCE_NOT_NULL(argument_->transformed_program_desc,
"program desc is not transformed, should call " "program desc is not transformed, should call "
"DataFlowGraphToFluidPass first."); "DataFlowGraphToFluidPass first.");
VLOG(3) << "store analyzed program to " VLOG(30) << "store analyzed program to "
<< *argument_->model_output_store_path; << *argument_->model_output_store_path;
const std::string program_output_path = const std::string program_output_path =
*argument_->model_output_store_path + "/__model__"; *argument_->model_output_store_path + "/__model__";
std::ofstream file(program_output_path, std::ios::binary); std::ofstream file(program_output_path, std::ios::binary);
......
...@@ -23,7 +23,7 @@ namespace analysis { ...@@ -23,7 +23,7 @@ namespace analysis {
bool PassManager::Initialize(Argument* argument) { bool PassManager::Initialize(Argument* argument) {
argument_ = argument; argument_ = argument;
for (auto& pass : data_) { for (auto& pass : data_) {
VLOG(3) << "Initializing pass [" << pass->repr() << "]"; VLOG(30) << "Initializing pass [" << pass->repr() << "]";
if (!pass->Initialize(argument)) { if (!pass->Initialize(argument)) {
LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]"; LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]";
return false; return false;
...@@ -34,7 +34,7 @@ bool PassManager::Initialize(Argument* argument) { ...@@ -34,7 +34,7 @@ bool PassManager::Initialize(Argument* argument) {
void DfgPassManager::RunAll() { void DfgPassManager::RunAll() {
PADDLE_ENFORCE(argument_); PADDLE_ENFORCE(argument_);
VLOG(3) << "Total " << data_.size() << " Analysys passes"; VLOG(30) << "Total " << data_.size() << " Analysys passes";
for (auto& pass : data_) { for (auto& pass : data_) {
string::PrettyLogEndl(string::Style::H1(), "* Running Analysis pass [%s]", string::PrettyLogEndl(string::Style::H1(), "* Running Analysis pass [%s]",
pass->repr()); pass->repr());
......
...@@ -232,7 +232,7 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { ...@@ -232,7 +232,7 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
BriefNode *brief_node = itr.second; BriefNode *brief_node = itr.second;
if (!brief_node->node->attr(kMarkerAttrName).Bool()) { if (!brief_node->node->attr(kMarkerAttrName).Bool()) {
VLOG(4) << brief_node->node->id() << " node not a trt candicate."; VLOG(40) << brief_node->node->id() << " node not a trt candicate.";
continue; continue;
} }
......
...@@ -25,9 +25,9 @@ TensorRTSubGraphPass::TensorRTSubGraphPass( ...@@ -25,9 +25,9 @@ TensorRTSubGraphPass::TensorRTSubGraphPass(
void TensorRTSubGraphPass::Run(DataFlowGraph *graph) { void TensorRTSubGraphPass::Run(DataFlowGraph *graph) {
SubGraphFuse(graph, node_inside_subgraph_teller_, argument_)(); SubGraphFuse(graph, node_inside_subgraph_teller_, argument_)();
VLOG(4) << "debug info " VLOG(40) << "debug info "
<< graph->HumanReadableInfo(false /*show_values*/, << graph->HumanReadableInfo(false /*show_values*/,
true /*show_functions*/); true /*show_functions*/);
} }
} // namespace analysis } // namespace analysis
......
...@@ -37,8 +37,8 @@ if(WITH_TESTING) ...@@ -37,8 +37,8 @@ if(WITH_TESTING)
ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${PYTHON_TESTS_DIR}/book) ARGS --word2vec_dirname=${WORD2VEC_MODEL_DIR} --book_dirname=${PYTHON_TESTS_DIR}/book)
set_tests_properties(test_api_impl PROPERTIES DEPENDS test_image_classification) set_tests_properties(test_api_impl PROPERTIES DEPENDS test_image_classification)
endif() endif()
cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor ${inference_deps} paddle_inference_api cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor ${inference_deps}
ARGS --dirname=${PYTHON_TESTS_DIR}/book) ARGS --dirname=${WORD2VEC_MODEL_DIR})
if(WITH_GPU AND TENSORRT_FOUND) if(WITH_GPU AND TENSORRT_FOUND)
cc_library(paddle_inference_tensorrt_subgraph_engine cc_library(paddle_inference_tensorrt_subgraph_engine
......
...@@ -38,7 +38,7 @@ using contrib::AnalysisConfig; ...@@ -38,7 +38,7 @@ using contrib::AnalysisConfig;
bool AnalysisPredictor::Init( bool AnalysisPredictor::Init(
const std::shared_ptr<framework::Scope> &parent_scope, const std::shared_ptr<framework::Scope> &parent_scope,
const std::shared_ptr<framework::ProgramDesc> &program) { const std::shared_ptr<framework::ProgramDesc> &program) {
VLOG(3) << "Predictor::init()"; VLOG(30) << "Predictor::init()";
#if !defined(_WIN32) #if !defined(_WIN32)
if (FLAGS_profile) { if (FLAGS_profile) {
LOG(WARNING) << "Profiler is actived, might affect the performance"; LOG(WARNING) << "Profiler is actived, might affect the performance";
...@@ -89,7 +89,7 @@ bool AnalysisPredictor::Init( ...@@ -89,7 +89,7 @@ bool AnalysisPredictor::Init(
bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs, bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data, std::vector<PaddleTensor> *output_data,
int batch_size) { int batch_size) {
VLOG(3) << "Predictor::predict"; VLOG(30) << "Predictor::predict";
inference::Timer timer; inference::Timer timer;
timer.tic(); timer.tic();
// set feed variable // set feed variable
...@@ -109,7 +109,7 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -109,7 +109,7 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
LOG(ERROR) << "fail to get fetches"; LOG(ERROR) << "fail to get fetches";
return false; return false;
} }
VLOG(3) << "predict cost: " << timer.toc() << "ms"; VLOG(30) << "predict cost: " << timer.toc() << "ms";
// Fix TensorArray reuse not cleaned bug. // Fix TensorArray reuse not cleaned bug.
tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get()); tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get());
...@@ -119,7 +119,7 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -119,7 +119,7 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs, bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
framework::Scope *scope) { framework::Scope *scope) {
VLOG(3) << "Predictor::set_feed"; VLOG(30) << "Predictor::set_feed";
if (inputs.size() != feeds_.size()) { if (inputs.size() != feeds_.size()) {
LOG(ERROR) << "wrong feed input size, need " << feeds_.size() << " but get " LOG(ERROR) << "wrong feed input size, need " << feeds_.size() << " but get "
<< inputs.size(); << inputs.size();
...@@ -184,7 +184,7 @@ void AnalysisPredictor::GetFetchOne(const framework::LoDTensor &fetch, ...@@ -184,7 +184,7 @@ void AnalysisPredictor::GetFetchOne(const framework::LoDTensor &fetch,
bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs, bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
framework::Scope *scope) { framework::Scope *scope) {
VLOG(3) << "Predictor::get_fetch"; VLOG(30) << "Predictor::get_fetch";
outputs->resize(fetchs_.size()); outputs->resize(fetchs_.size());
for (size_t i = 0; i < fetchs_.size(); ++i) { for (size_t i = 0; i < fetchs_.size(); ++i) {
int idx = boost::get<int>(fetchs_[i]->GetAttr("col")); int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
...@@ -246,7 +246,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -246,7 +246,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
} }
CHECK(argument_.transformed_program_desc); CHECK(argument_.transformed_program_desc);
VLOG(5) << "to prepare executor"; VLOG(50) << "to prepare executor";
inference_program_.reset( inference_program_.reset(
new framework::ProgramDesc(*argument_.transformed_program_desc)); new framework::ProgramDesc(*argument_.transformed_program_desc));
if (argument_.Has(framework::ir::kParamScopeAttr)) { if (argument_.Has(framework::ir::kParamScopeAttr)) {
...@@ -260,7 +260,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() { ...@@ -260,7 +260,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
template <> template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor< std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
AnalysisConfig, PaddleEngineKind::kAnalysis>(const AnalysisConfig &config) { AnalysisConfig, PaddleEngineKind::kAnalysis>(const AnalysisConfig &config) {
VLOG(3) << "create AnalysisConfig"; VLOG(30) << "create AnalysisConfig";
if (config.use_gpu) { if (config.use_gpu) {
// 1. GPU memeroy // 1. GPU memeroy
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
...@@ -274,7 +274,7 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor< ...@@ -274,7 +274,7 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
std::string flag = "--fraction_of_gpu_memory_to_use=" + std::string flag = "--fraction_of_gpu_memory_to_use=" +
std::to_string(config.fraction_of_gpu_memory); std::to_string(config.fraction_of_gpu_memory);
flags.push_back(flag); flags.push_back(flag);
VLOG(3) << "set flag: " << flag; VLOG(30) << "set flag: " << flag;
framework::InitGflags(flags); framework::InitGflags(flags);
} }
} }
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <algorithm>
#include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/naive_executor.h"
......
...@@ -24,7 +24,7 @@ using contrib::AnalysisConfig; ...@@ -24,7 +24,7 @@ using contrib::AnalysisConfig;
TEST(AnalysisPredictor, ZeroCopy) { TEST(AnalysisPredictor, ZeroCopy) {
AnalysisConfig config; AnalysisConfig config;
config.model_dir = FLAGS_dirname + "/word2vec.inference.model"; config.model_dir = FLAGS_dirname;
config.use_feed_fetch_ops = false; config.use_feed_fetch_ops = false;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config); auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle_inference_api.h"
namespace paddle { namespace paddle {
......
...@@ -157,7 +157,7 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -157,7 +157,7 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
LOG(ERROR) << "fail to get fetches"; LOG(ERROR) << "fail to get fetches";
return false; return false;
} }
VLOG(3) << "predict cost: " << timer.toc() << "ms"; VLOG(30) << "predict cost: " << timer.toc() << "ms";
// Fix TensorArray reuse not cleaned bug. // Fix TensorArray reuse not cleaned bug.
tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get()); tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get());
......
...@@ -34,7 +34,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { ...@@ -34,7 +34,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) { bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
FLAGS_IA_enable_tensorrt_subgraph_engine = true; FLAGS_IA_enable_tensorrt_subgraph_engine = true;
VLOG(3) << "Predictor::init()"; VLOG(30) << "Predictor::init()";
if (config_.use_gpu) { if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device); place_ = paddle::platform::CUDAPlace(config_.device);
} else { } else {
...@@ -70,7 +70,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { ...@@ -70,7 +70,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
OptimizeInferenceProgram(); OptimizeInferenceProgram();
ctx_ = executor_->Prepare(*inference_program_, 0); ctx_ = executor_->Prepare(*inference_program_, 0);
VLOG(5) << "to create variables"; VLOG(50) << "to create variables";
executor_->CreateVariables(*inference_program_, executor_->CreateVariables(*inference_program_,
sub_scope_ ? sub_scope_ : scope_.get(), 0); sub_scope_ ? sub_scope_ : scope_.get(), 0);
// Get the feed_target_names and fetch_target_names // Get the feed_target_names and fetch_target_names
...@@ -114,9 +114,9 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { ...@@ -114,9 +114,9 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
new ProgramDesc(*inference_program_->Proto())); new ProgramDesc(*inference_program_->Proto()));
Singleton<Analyzer>::Global().Run(&argument); Singleton<Analyzer>::Global().Run(&argument);
CHECK(argument.transformed_program_desc); CHECK(argument.transformed_program_desc);
VLOG(5) << "transformed program:\n" VLOG(50) << "transformed program:\n"
<< argument.transformed_program_desc->SerializeAsString(); << argument.transformed_program_desc->SerializeAsString();
VLOG(5) << "to prepare executor"; VLOG(50) << "to prepare executor";
inference_program_.reset( inference_program_.reset(
new framework::ProgramDesc(*argument.transformed_program_desc)); new framework::ProgramDesc(*argument.transformed_program_desc));
} }
...@@ -129,7 +129,7 @@ template <> ...@@ -129,7 +129,7 @@ template <>
std::unique_ptr<PaddlePredictor> std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<MixedRTConfig, PaddleEngineKind::kAutoMixedTensorRT>( CreatePaddlePredictor<MixedRTConfig, PaddleEngineKind::kAutoMixedTensorRT>(
const MixedRTConfig& config) { const MixedRTConfig& config) {
VLOG(3) << "create TensorRTSubgraphPredictor"; VLOG(30) << "create TensorRTSubgraphPredictor";
if (config.use_gpu) { if (config.use_gpu) {
// 1. GPU memeroy // 1. GPU memeroy
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
...@@ -143,7 +143,7 @@ CreatePaddlePredictor<MixedRTConfig, PaddleEngineKind::kAutoMixedTensorRT>( ...@@ -143,7 +143,7 @@ CreatePaddlePredictor<MixedRTConfig, PaddleEngineKind::kAutoMixedTensorRT>(
std::string flag = "--fraction_of_gpu_memory_to_use=" + std::string flag = "--fraction_of_gpu_memory_to_use=" +
std::to_string(config.fraction_of_gpu_memory); std::to_string(config.fraction_of_gpu_memory);
flags.push_back(flag); flags.push_back(flag);
VLOG(3) << "set flag: " << flag; VLOG(30) << "set flag: " << flag;
framework::InitGflags(flags); framework::InitGflags(flags);
} }
} }
......
...@@ -45,7 +45,7 @@ void Main() { ...@@ -45,7 +45,7 @@ void Main() {
config.fraction_of_gpu_memory = 0.1; // set by yourself config.fraction_of_gpu_memory = 0.1; // set by yourself
predictor = CreatePaddlePredictor<paddle::contrib::MixedRTConfig>(config); predictor = CreatePaddlePredictor<paddle::contrib::MixedRTConfig>(config);
VLOG(3) << "begin to process data"; VLOG(30) << "begin to process data";
// Just a single batch of data. // Just a single batch of data.
std::string line; std::string line;
std::ifstream file(FLAGS_data); std::ifstream file(FLAGS_data);
...@@ -60,13 +60,13 @@ void Main() { ...@@ -60,13 +60,13 @@ void Main() {
PaddleBuf(record.data.data(), record.data.size() * sizeof(float)); PaddleBuf(record.data.data(), record.data.size() * sizeof(float));
input.dtype = PaddleDType::FLOAT32; input.dtype = PaddleDType::FLOAT32;
VLOG(3) << "run executor"; VLOG(30) << "run executor";
std::vector<PaddleTensor> output; std::vector<PaddleTensor> output;
predictor->Run({input}, &output, 1); predictor->Run({input}, &output, 1);
VLOG(3) << "output.size " << output.size(); VLOG(30) << "output.size " << output.size();
auto& tensor = output.front(); auto& tensor = output.front();
VLOG(3) << "output: " << SummaryTensor(tensor); VLOG(30) << "output: " << SummaryTensor(tensor);
// compare with reference result // compare with reference result
CheckOutput(FLAGS_refer, tensor); CheckOutput(FLAGS_refer, tensor);
......
...@@ -47,7 +47,7 @@ static void split(const std::string& str, char sep, ...@@ -47,7 +47,7 @@ static void split(const std::string& str, char sep,
} }
Record ProcessALine(const std::string& line) { Record ProcessALine(const std::string& line) {
VLOG(3) << "process a line"; VLOG(30) << "process a line";
std::vector<std::string> columns; std::vector<std::string> columns;
split(line, '\t', &columns); split(line, '\t', &columns);
CHECK_EQ(columns.size(), 2UL) CHECK_EQ(columns.size(), 2UL)
...@@ -65,8 +65,8 @@ Record ProcessALine(const std::string& line) { ...@@ -65,8 +65,8 @@ Record ProcessALine(const std::string& line) {
for (auto& s : shape_strs) { for (auto& s : shape_strs) {
record.shape.push_back(std::stoi(s)); record.shape.push_back(std::stoi(s));
} }
VLOG(3) << "data size " << record.data.size(); VLOG(30) << "data size " << record.data.size();
VLOG(3) << "data shape size " << record.shape.size(); VLOG(30) << "data shape size " << record.shape.size();
return record; return record;
} }
...@@ -78,8 +78,8 @@ void CheckOutput(const std::string& referfile, const PaddleTensor& output) { ...@@ -78,8 +78,8 @@ void CheckOutput(const std::string& referfile, const PaddleTensor& output) {
file.close(); file.close();
size_t numel = output.data.length() / PaddleDtypeSize(output.dtype); size_t numel = output.data.length() / PaddleDtypeSize(output.dtype);
VLOG(3) << "predictor output numel " << numel; VLOG(30) << "predictor output numel " << numel;
VLOG(3) << "reference output numel " << refer.data.size(); VLOG(30) << "reference output numel " << refer.data.size();
CHECK_EQ(numel, refer.data.size()); CHECK_EQ(numel, refer.data.size());
switch (output.dtype) { switch (output.dtype) {
case PaddleDType::INT64: { case PaddleDType::INT64: {
......
...@@ -49,11 +49,11 @@ void Main(bool use_gpu) { ...@@ -49,11 +49,11 @@ void Main(bool use_gpu) {
config.fraction_of_gpu_memory = 0.1; // set by yourself config.fraction_of_gpu_memory = 0.1; // set by yourself
} }
VLOG(3) << "init predictor"; VLOG(30) << "init predictor";
predictor = CreatePaddlePredictor<NativeConfig>(config); predictor = CreatePaddlePredictor<NativeConfig>(config);
analysis_predictor = CreatePaddlePredictor<AnalysisConfig>(config); analysis_predictor = CreatePaddlePredictor<AnalysisConfig>(config);
VLOG(3) << "begin to process data"; VLOG(30) << "begin to process data";
// Just a single batch of data. // Just a single batch of data.
std::string line; std::string line;
std::ifstream file(FLAGS_data); std::ifstream file(FLAGS_data);
...@@ -68,13 +68,13 @@ void Main(bool use_gpu) { ...@@ -68,13 +68,13 @@ void Main(bool use_gpu) {
PaddleBuf(record.data.data(), record.data.size() * sizeof(float)); PaddleBuf(record.data.data(), record.data.size() * sizeof(float));
input.dtype = PaddleDType::FLOAT32; input.dtype = PaddleDType::FLOAT32;
VLOG(3) << "run executor"; VLOG(30) << "run executor";
std::vector<PaddleTensor> output, analysis_output; std::vector<PaddleTensor> output, analysis_output;
predictor->Run({input}, &output, 1); predictor->Run({input}, &output, 1);
VLOG(3) << "output.size " << output.size(); VLOG(30) << "output.size " << output.size();
auto& tensor = output.front(); auto& tensor = output.front();
VLOG(3) << "output: " << SummaryTensor(tensor); VLOG(30) << "output: " << SummaryTensor(tensor);
// compare with reference result // compare with reference result
CheckOutput(FLAGS_refer, tensor); CheckOutput(FLAGS_refer, tensor);
......
...@@ -26,7 +26,7 @@ void TensorArrayBatchCleaner::CollectTensorArrays(framework::Scope *scope) { ...@@ -26,7 +26,7 @@ void TensorArrayBatchCleaner::CollectTensorArrays(framework::Scope *scope) {
// parameter. // parameter.
if (var_name == "feed" || var_name == "fetch") continue; if (var_name == "feed" || var_name == "fetch") continue;
if (var->Type() == typeid(framework::LoDTensorArray)) { if (var->Type() == typeid(framework::LoDTensorArray)) {
VLOG(4) << "collect " << var_name; VLOG(40) << "collect " << var_name;
arrays_.push_back(var->GetMutable<framework::LoDTensorArray>()); arrays_.push_back(var->GetMutable<framework::LoDTensorArray>());
} }
} }
...@@ -34,7 +34,7 @@ void TensorArrayBatchCleaner::CollectTensorArrays(framework::Scope *scope) { ...@@ -34,7 +34,7 @@ void TensorArrayBatchCleaner::CollectTensorArrays(framework::Scope *scope) {
CollectTensorArrays(kid); CollectTensorArrays(kid);
} }
VLOG(3) << "Collect " << arrays_.size() << " arrays"; VLOG(30) << "Collect " << arrays_.size() << " arrays";
flag_ = false; flag_ = false;
} }
} }
......
(61 additional file diffs were collapsed in the original view and are not shown.)