diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/data_transfer.cc index 1d0727b80baf7558eb5f391257248a004b1d9f61..d0e5565139c54cdd96a02cf3eeff331b4b4c7762 100644 --- a/paddle/fluid/framework/new_executor/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/data_transfer.cc @@ -319,6 +319,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, } } + bool transfered = false; DataTranferHelper data_transfer_helper(place, var_scope); for (auto& var_name_item : *ins_map_temp) { bool should_skip_input = @@ -334,6 +335,9 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, if (var->IsType() || var->IsType()) { tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var); } else if (var->IsType()) { + if (var->Get().size() == 0) { + continue; + } tensor_in = static_cast(&(var->Get()[0])); } else { @@ -389,6 +393,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, } if (is_transferred) { + transfered = true; // update RuntimeContext.inputs and original op_func_node inputs op_func_node->input_index[var_name_item.first][i] = var_scope->VarId(new_var_name); @@ -426,11 +431,13 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, } } - // NOTE(zhiqiu): UPDATE the corresponding OeratorBase to make it consistent - // with instruction. (hot fix, it is not good design here) - op_func_node->operator_base_ = - std::shared_ptr(framework::OpRegistry::CreateOp( - op_base->Type(), new_ins, new_outs, op_base->Attrs())); + if (transfered) { + // NOTE(zhiqiu): UPDATE the corresponding OeratorBase to make it consistent + // with instruction. (hot fix, it is not good design here) + op_func_node->operator_base_ = + std::shared_ptr(framework::OpRegistry::CreateOp( + op_base->Type(), new_ins, new_outs, op_base->Attrs())); + } op_func_node->no_data_transform_index = std::move(no_data_transform_index); } diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index cf0b64cbc3a705b46061f1280a1b2483006bbc0d..29aa7b13a270eb348aa8f603959fb52f4eef677a 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -300,8 +300,16 @@ void InterpreterCore::Convert( gc_event_.emplace_back(vec_instruction_[i].DeviceContext().GetPlace(), platform::GenerateDeviceEventFlag()); } + bool inplaced = false; + for (auto inst : vec_instruction_) { + if (inst.OpBase()->Type() == "share_buffer" || + inst.OpBase()->Type() == "share_data") { + VLOG(4) << "Already inplaced, skip inplace now."; + inplaced = true; + } + } - if (FLAGS_new_executor_use_inplace) { + if (FLAGS_new_executor_use_inplace && !inplaced) { BuildInplace(); } @@ -565,12 +573,11 @@ void InterpreterCore::RunNextInstructions( const Instruction& instr, std::queue* reserved_next_ops, std::vector>* atomic_deps, std::vector>* atomic_var_ref) { - VLOG(4) << "atomic 1:" << atomic_deps; auto& next_instr = instr.NextInstructions(); auto IsReady = [atomic_deps](size_t next_id) { - VLOG(4) << "atomic:" << atomic_deps << " " << &(*atomic_deps)[next_id] - << " " << next_id; + VLOG(4) << "atomic:" << atomic_deps << " op_id: " << next_id + << ", remain deps: " << (*atomic_deps)[next_id]; return (*atomic_deps)[next_id].fetch_sub(1, std::memory_order_relaxed) == 1; }; diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 360e0222a516c6220569467f04e62ad3c0d4e41b..a704411f3bb713421dc23903112eacd0de363b57 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -428,19 +428,19 @@ void build_op_func_list(const platform::Place& place, op_func_node.dev_ctx_ = dev_ctx; VLOG(3) << op_with_kernel->Type() << " : expected_kernel_key : " << expected_kernel_key; - auto exec_ctx = - ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context); // see OperatorWithKernel::RunImpl in operator.cc for why if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) && op->Attr(kAllKernelsMustComputeRuntimeShape))) { InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); // TODO(Aurelius84): In case of control flow ops, they are NOT - // inheritted - // from OperatorWithKernel. + // inheritted from OperatorWithKernel. op_with_kernel->Info().infer_shape_(&infer_shape_ctx); } + auto exec_ctx = + ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context); + auto run_phi_kernel = false; if (phi::KernelFactory::Instance().HasCompatiblePhiKernel( op_with_kernel->Type())) { @@ -476,7 +476,6 @@ void build_op_func_list(const platform::Place& place, op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx, &pt_kernel_context); op_func_node.pt_kernel_ = op_with_kernel->PhiKernel(); - (*op_func_node.pt_kernel_)(&pt_kernel_context); } else { auto kernels_iter = all_op_kernels.find(op->Type()); @@ -711,6 +710,7 @@ std::map> build_op_downstream_map( const std::set random_op_set = { "bernoulli", "poisson", "multinomial", "gaussian_random", "uniform_random", "randint", "randperm", "exponential"}; + int dependence_op_idx = -1; for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) { @@ -721,6 +721,147 @@ std::map> build_op_downstream_map( } } + // add dependency for communication op + const std::string communication_op_prefix = "c_"; + dependence_op_idx = -1; + for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { + if (vec_instruction[op_idx].OpBase()->Type().find( + communication_op_prefix) != std::string::npos) { + if (dependence_op_idx != -1) { + op2dependences[op_idx].insert(dependence_op_idx); + } + dependence_op_idx = op_idx; + } + } + + // TODO(zhiqiu): there still some cases not handled + // add dependency for c_sync_comm_stream + + // in program, we can add only one c_sync_comm_stream to sync all + // communication ops. + // c_allreduce_sum(a) + // c_allreduce_sum(b) + // c_allreduce_sum(c) + // c_sync_comm_stream(a) + const std::string kSyncComm = "c_sync_comm_stream"; + dependence_op_idx = -1; + for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { + if (vec_instruction[op_idx].OpBase()->Type() == kSyncComm) { + dependence_op_idx = op_idx; + } else { + if (dependence_op_idx != -1) { + VLOG(4) << "Add depend from " + << vec_instruction[dependence_op_idx].OpBase()->Type() << " to " + << vec_instruction[op_idx].OpBase()->Type(); + op2dependences[op_idx].insert(dependence_op_idx); + } + } + } + + // add dependency for coalesce_tensor + const std::string kCoalesceTensor = "coalesce_tensor"; + for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { + if (vec_instruction[op_idx].OpBase()->Type() == kCoalesceTensor) { + VLOG(4) << "Add depend for " << kCoalesceTensor << " " << op_idx; + auto fused_out = vec_instruction[op_idx].Outputs().at("FusedOutput")[0]; + auto outputs = vec_instruction[op_idx].Outputs().at("Output"); + + auto is_read = [](const Instruction& inst, int var_id) -> bool { + for (auto pair : inst.Inputs()) { + for (auto item : pair.second) { + if (item == var_id) { + return true; + } + } + } + return false; + }; + + auto is_write = [](const Instruction& inst, int var_id) -> bool { + for (auto pair : inst.Outputs()) { + for (auto item : pair.second) { + if (item == var_id) { + return true; + } + } + } + return false; + }; + + // find first op that reads fused_out + auto first_read_fused_out_op = -1; + for (auto j = op_idx + 1; j < vec_instruction.size(); ++j) { + if (is_read(vec_instruction[j], fused_out)) { + first_read_fused_out_op = j; + break; + } + } + + if (UNLIKELY(first_read_fused_out_op == -1)) { + VLOG(4) << "No op read FusedOutput"; + continue; + } + + // find ops that write 'outputs' between (op_index, + // first_read_fused_out_op) + // add depend: them->first_read_fused_out_op + for (auto j = op_idx + 1; + j < static_cast(first_read_fused_out_op); ++j) { + for (auto var_id : outputs) { + if (is_write(vec_instruction[j], var_id)) { + op2dependences[first_read_fused_out_op].insert(j); + VLOG(4) << j << " -> " << first_read_fused_out_op; + VLOG(4) + << "Add depend from " << vec_instruction[j].OpBase()->Type() + << " to " + << vec_instruction[first_read_fused_out_op].OpBase()->Type(); + } + } + } + + // find first op read 'outputs' between (first_read_fused_out_op, end) + // add depned: first_read_fused_out_op -> first op that reads 'outputs' + + // special case for consecutive communication ops, for example, + // FusedOutput = c_sync_calc_stream(FusedOutput) + // FusedOutput= c_allreduce_sum(FusedOutput) + // FusedOutput = c_sync_comm_stream(FusedOutput) + // we should take the last one to add depned instead of + // 'first_read_fused_out_op' + size_t target = first_read_fused_out_op; + for (size_t j = first_read_fused_out_op + 1; j < vec_instruction.size(); + ++j) { + if (j == target + 1 && + vec_instruction[target].OpBase()->Type().find( + communication_op_prefix) != std::string::npos && + vec_instruction[j].OpBase()->Type().find(communication_op_prefix) != + std::string::npos) { + VLOG(4) << "Found consecutive communication ops, " + << vec_instruction[target].OpBase()->Type() << " -> " + << vec_instruction[j].OpBase()->Type(); + target = j; + continue; + } + + for (auto var_id : outputs) { + if (is_read(vec_instruction[j], var_id)) { + op2dependences[j].insert(target); + VLOG(4) << target << " -> " << j; + VLOG(4) << "Add depend from " + << vec_instruction[target].OpBase()->Type() << " to " + << vec_instruction[j].OpBase()->Type(); + } + } + } + } + } + for (auto pair : op2dependences) { + VLOG(10) << pair.first << " Depends on " << pair.second.size(); + std::ostringstream oss; + std::copy(pair.second.begin(), pair.second.end(), + std::ostream_iterator(oss, " ")); + VLOG(10) << oss.str(); + } return std::move(get_downstream_map(op2dependences)); } diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index f4dfb76884f17593daefbb4842a8655131a9b8fd..e2730a1b825e9a277b72b9e5a28f503cc645953a 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -85,7 +85,7 @@ PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, false, // NOTE(Ruibiao): This FLAGS is just to be compatibled with // the old single-stream CUDA allocator. It will be removed // after StreamSafeCudaAllocator has been fully tested. -PADDLE_DEFINE_EXPORTED_bool(use_stream_safe_cuda_allocator, false, +PADDLE_DEFINE_EXPORTED_bool(use_stream_safe_cuda_allocator, true, "Enable StreamSafeCUDAAllocator"); PADDLE_DEFINE_EXPORTED_bool(use_cuda_managed_memory, false, diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index eb833428afa429b35bc99ebbde9bd231ab07b722..935f7b53eba5741e1896725d2d8c20c7eeb1a4e8 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -394,19 +394,10 @@ def _is_enable_standalone_executor(): Whether to use experimental executor `StandaloneExecutor`. """ flag = False - # NOTE(zhiqiu): enable STANDALONE_EXECUTOR on windows platform by default - # It should be enabled on all platform in the future. - - import platform - sysstr = platform.system().lower() - if sysstr == 'windows': - env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', 1) - else: - env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', None) + env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', '1') if env_val in [1, '1', True, 'True', 'true']: flag = True - warnings.warn("STANDALONE_EXECUTOR is enabled.") return flag @@ -1386,6 +1377,10 @@ class Executor(object): program = pruned_program def _can_use_interpreter_core(program, place): + if core.is_compiled_with_npu() or core.is_compiled_with_xpu( + ) or core.is_compiled_with_mlu() or core.is_compiled_with_ipu(): + return False + compiled = isinstance(program, compiler.CompiledProgram) # NOTE(zhiqiu): do not support compiled program now if compiled: @@ -1396,6 +1391,8 @@ class Executor(object): # else: # return False else: + if isinstance(program._graph, compiler.CompiledProgram): + return False assert isinstance(program, Program) return True diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index ac3c708cc001ec9b4574bc58af3cf6271948fd9d..8b84a9c524adf583b1c02496f8bec09c32b10678 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -951,7 +951,7 @@ endif() if (WITH_DISTRIBUTE AND NOT APPLE) if(WITH_GPU OR WITH_ROCM) set_tests_properties(test_c_comm_init_op PROPERTIES TIMEOUT 120) - set_tests_properties(test_dist_mnist_gradient_merge PROPERTIES TIMEOUT 160) + set_tests_properties(test_dist_mnist_gradient_merge PROPERTIES TIMEOUT 360) endif() endif() @@ -1033,7 +1033,7 @@ set_tests_properties(test_parallel_executor_seresnext_with_reduce_gpu PROPERTIES set_tests_properties(test_dropout_op PROPERTIES TIMEOUT 120) set_tests_properties(test_argsort_op PROPERTIES TIMEOUT 120) set_tests_properties(test_gather_nd_op PROPERTIES TIMEOUT 120) -set_tests_properties(test_nn_grad PROPERTIES TIMEOUT 120) +set_tests_properties(test_nn_grad PROPERTIES TIMEOUT 120 ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0) set_tests_properties(test_elementwise_sub_op PROPERTIES TIMEOUT 120) set_tests_properties(test_row_conv_op PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_executor_seresnext_with_fuse_all_reduce_gpu PROPERTIES TIMEOUT 120) @@ -1072,7 +1072,7 @@ set_tests_properties(test_space_to_depth_op PROPERTIES TIMEOUT 200) set_tests_properties(test_dyn_rnn PROPERTIES TIMEOUT 120) set_tests_properties(test_sgd_op PROPERTIES TIMEOUT 250) set_tests_properties(test_parallel_executor_seresnext_base_gpu PROPERTIES TIMEOUT 120) -set_tests_properties(test_norm_nn_grad PROPERTIES TIMEOUT 120) +set_tests_properties(test_norm_nn_grad PROPERTIES TIMEOUT 120 ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0) set_tests_properties(test_matrix_nms_op PROPERTIES TIMEOUT 120) set_tests_properties(test_generator_dataloader PROPERTIES TIMEOUT 120) set_tests_properties(test_partial_concat_op PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/collective_reducescatter.py b/python/paddle/fluid/tests/unittests/collective_reducescatter.py index 8b989c73d4deb69e85b821ef0b2091ef0af7a0c4..00d4a1c4cf6bde4491a5ed86968b3fbbe44b9009 100644 --- a/python/paddle/fluid/tests/unittests/collective_reducescatter.py +++ b/python/paddle/fluid/tests/unittests/collective_reducescatter.py @@ -48,6 +48,7 @@ class TestCollectiveReduceScatter(TestCollectiveRunnerBase): tindata = layers.data( name="tindata", shape=[10, 1000], dtype='float32') toutdata = fluid.layers.collective._c_reducescatter(tindata, nranks) + toutdata = fluid.layers.collective._c_sync_comm_stream(toutdata, 0) return toutdata diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py index 488e7c809fc3900710e20bb80ea58385c71e994c..f0ed2cdc049500e53bae28b5d0503e50ac89c60d 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py @@ -32,7 +32,7 @@ def prepare_python_path_and_return_module(path): assert filename.endswith(py_suffix), filename env_name = 'PYTHONPATH' - python_path = env_name + python_path = os.environ.get(env_name, '') if python_path: paths = [p for p in python_path.split(":") if p] if dirname not in paths: @@ -41,6 +41,7 @@ def prepare_python_path_and_return_module(path): else: python_path = path os.environ[env_name] = python_path + print('GLOG_v=', os.environ.get('GLOG_v', None), flush=1) return filename[:-len(py_suffix)] diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 5be531258edacbabf98bf6c0dd9405cbdc68561d..808821f06cbae14018ef926347f95e530f03d114 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -91,7 +91,7 @@ set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100) set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_nearest_interp_v2_op PROPERTIES TIMEOUT 30) -set_tests_properties(test_trt_multiclass_nms3_op PROPERTIES TIMEOUT 60) +set_tests_properties(test_trt_multiclass_nms3_op PROPERTIES TIMEOUT 60 ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0) if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU) set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py index 722926b0d77f747b5793f417833f0b99069aa11a..55f87540c1b8a6c0bbf0d10846bdf44357882c7f 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -27,6 +27,7 @@ paddle.enable_static() class TestSliceOpDoubleGradCheck(unittest.TestCase): + @prog_scope() def func(self, place): self.config() diff --git a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py index 47a6d2b811552763506e4b213894eead7c992e2d..6bf811be2ad0d8e92c5664e453d5d9e10d5730d0 100644 --- a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py @@ -30,6 +30,7 @@ from paddle.fluid import compiler from paddle.fluid import Program, program_guard from op_test import OpTest, _set_use_system_allocator +from decorator_helper import prog_scope _set_use_system_allocator(True) @@ -105,6 +106,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase): sgd_opt.backward(out) return main, startup, [out, conv, bn] + @prog_scope() def _compare(self, place, layout, only_forward): """Compare results.""" seed = 10