Unverified commit 93ea1297, authored by Leo Chen, committed by GitHub

[new-exec] enable the new standalone executor by default (#41179)

* enable new executor by default

* enable stream safe allocator

* test=document_fix;test=coverage

* do not use scope in op kernel

* fit empty program for new executor

* fix communication depend

* fix test_sync_batch_norm

* skip unsupported place

* refine datatransfer

* fit for distributed program

* fix dependency

* fix some ut
Parent feaa9798
@@ -319,6 +319,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
}
}
bool transfered = false;
DataTranferHelper data_transfer_helper(place, var_scope);
for (auto& var_name_item : *ins_map_temp) {
bool should_skip_input =
@@ -334,6 +335,9 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>()) {
tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
} else if (var->IsType<LoDTensorArray>()) {
if (var->Get<LoDTensorArray>().size() == 0) {
continue;
}
tensor_in =
static_cast<const Tensor*>(&(var->Get<LoDTensorArray>()[0]));
} else {
@@ -389,6 +393,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
}
if (is_transferred) {
transfered = true;
// update RuntimeContext.inputs and original op_func_node inputs
op_func_node->input_index[var_name_item.first][i] =
var_scope->VarId(new_var_name);
@@ -426,11 +431,13 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
}
}
-// NOTE(zhiqiu): UPDATE the corresponding OperatorBase to make it consistent
-// with instruction. (hot fix, it is not good design here)
-op_func_node->operator_base_ =
-    std::shared_ptr<OperatorBase>(framework::OpRegistry::CreateOp(
-        op_base->Type(), new_ins, new_outs, op_base->Attrs()));
+if (transfered) {
+  // NOTE(zhiqiu): UPDATE the corresponding OperatorBase to make it consistent
+  // with instruction. (hot fix, it is not good design here)
+  op_func_node->operator_base_ =
+      std::shared_ptr<OperatorBase>(framework::OpRegistry::CreateOp(
+          op_base->Type(), new_ins, new_outs, op_base->Attrs()));
+}
op_func_node->no_data_transform_index = std::move(no_data_transform_index);
}
......
@@ -300,8 +300,16 @@ void InterpreterCore::Convert(
gc_event_.emplace_back(vec_instruction_[i].DeviceContext().GetPlace(),
platform::GenerateDeviceEventFlag());
}
bool inplaced = false;
for (auto inst : vec_instruction_) {
if (inst.OpBase()->Type() == "share_buffer" ||
inst.OpBase()->Type() == "share_data") {
VLOG(4) << "Already inplaced, skip inplace now.";
inplaced = true;
}
}
-if (FLAGS_new_executor_use_inplace) {
+if (FLAGS_new_executor_use_inplace && !inplaced) {
BuildInplace();
}
@@ -565,12 +573,11 @@ void InterpreterCore::RunNextInstructions(
const Instruction& instr, std::queue<size_t>* reserved_next_ops,
std::vector<std::atomic<size_t>>* atomic_deps,
std::vector<std::atomic<size_t>>* atomic_var_ref) {
VLOG(4) << "atomic 1:" << atomic_deps;
auto& next_instr = instr.NextInstructions();
auto IsReady = [atomic_deps](size_t next_id) {
VLOG(4) << "atomic:" << atomic_deps << " " << &(*atomic_deps)[next_id]
<< " " << next_id;
VLOG(4) << "atomic:" << atomic_deps << " op_id: " << next_id
<< ", remain deps: " << (*atomic_deps)[next_id];
return (*atomic_deps)[next_id].fetch_sub(1, std::memory_order_relaxed) == 1;
};
......
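The `IsReady` change above is worth unpacking: each instruction owns an atomic counter holding the number of unfinished instructions it still waits on, and every finishing instruction decrements the counters of its successors. `fetch_sub` returns the previous value, so a return of 1 means this caller performed the final decrement and the successor can be queued. A minimal single-threaded Python sketch of that countdown scheme (names are illustrative; no atomics are needed in a single thread):

```python
from collections import deque

def run_graph(num_ops, deps, next_instrs):
    """Simulate the interpreter's dependency countdown.

    deps[i]        -- number of not-yet-finished ops that op i waits on
    next_instrs[i] -- ops that may become ready once op i finishes
    """
    remaining = list(deps)
    ready = deque(i for i in range(num_ops) if remaining[i] == 0)
    order = []
    while ready:
        op = ready.popleft()
        order.append(op)              # "run" the instruction
        for nxt in next_instrs[op]:
            remaining[nxt] -= 1       # atomic fetch_sub in the C++ code
            if remaining[nxt] == 0:   # fetch_sub returned 1: last dep done
                ready.append(nxt)
    return order

# op1 and op2 wait on op0; op3 waits on both op1 and op2
print(run_graph(4, [0, 1, 1, 2], [[1, 2], [3], [3], []]))  # [0, 1, 2, 3]
```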
@@ -428,19 +428,19 @@ void build_op_func_list(const platform::Place& place,
op_func_node.dev_ctx_ = dev_ctx;
VLOG(3) << op_with_kernel->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
-auto exec_ctx =
-    ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);
// see OperatorWithKernel::RunImpl in operator.cc for why
if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT
-// inheritted
-// from OperatorWithKernel.
+// inherited from OperatorWithKernel.
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
}
+auto exec_ctx =
+    ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);
auto run_phi_kernel = false;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(
op_with_kernel->Type())) {
@@ -476,7 +476,6 @@ void build_op_func_list(const platform::Place& place,
op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx,
&pt_kernel_context);
op_func_node.pt_kernel_ = op_with_kernel->PhiKernel();
-(*op_func_node.pt_kernel_)(&pt_kernel_context);
} else {
auto kernels_iter = all_op_kernels.find(op->Type());
@@ -711,6 +710,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
const std::set<std::string> random_op_set = {
"bernoulli", "poisson", "multinomial", "gaussian_random",
"uniform_random", "randint", "randperm", "exponential"};
int dependence_op_idx = -1;
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) {
@@ -721,6 +721,147 @@ std::map<int, std::list<int>> build_op_downstream_map(
}
}
// add dependency for communication op
const std::string communication_op_prefix = "c_";
dependence_op_idx = -1;
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
if (vec_instruction[op_idx].OpBase()->Type().find(
communication_op_prefix) != std::string::npos) {
if (dependence_op_idx != -1) {
op2dependences[op_idx].insert(dependence_op_idx);
}
dependence_op_idx = op_idx;
}
}
// TODO(zhiqiu): there are still some cases not handled
// add dependency for c_sync_comm_stream
// in a program, we can add only one c_sync_comm_stream to sync all
// communication ops.
// c_allreduce_sum(a)
// c_allreduce_sum(b)
// c_allreduce_sum(c)
// c_sync_comm_stream(a)
const std::string kSyncComm = "c_sync_comm_stream";
dependence_op_idx = -1;
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
if (vec_instruction[op_idx].OpBase()->Type() == kSyncComm) {
dependence_op_idx = op_idx;
} else {
if (dependence_op_idx != -1) {
VLOG(4) << "Add depend from "
<< vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
<< vec_instruction[op_idx].OpBase()->Type();
op2dependences[op_idx].insert(dependence_op_idx);
}
}
}
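A small Python model of the two loops above, run on the example program from the comment, makes the resulting edges concrete. `op2dependences` maps an op index to the indices it must wait for; this sketch uses `startswith` where the C++ code uses `find`:

```python
# Ops in program order; the "c_" prefix marks communication ops.
ops = ["c_allreduce_sum", "c_allreduce_sum", "c_allreduce_sum",
       "c_sync_comm_stream", "matmul"]

op2dependences = {i: set() for i in range(len(ops))}

# Rule 1: serialize communication ops in program order.
last_comm = -1
for i, op in enumerate(ops):
    if op.startswith("c_"):
        if last_comm != -1:
            op2dependences[i].add(last_comm)
        last_comm = i

# Rule 2: every op after a c_sync_comm_stream depends on it.
last_sync = -1
for i, op in enumerate(ops):
    if op == "c_sync_comm_stream":
        last_sync = i
    elif last_sync != -1:
        op2dependences[i].add(last_sync)

print(op2dependences)  # {0: set(), 1: {0}, 2: {1}, 3: {2}, 4: {3}}
```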
// add dependency for coalesce_tensor
const std::string kCoalesceTensor = "coalesce_tensor";
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
if (vec_instruction[op_idx].OpBase()->Type() == kCoalesceTensor) {
VLOG(4) << "Add depend for " << kCoalesceTensor << " " << op_idx;
auto fused_out = vec_instruction[op_idx].Outputs().at("FusedOutput")[0];
auto outputs = vec_instruction[op_idx].Outputs().at("Output");
auto is_read = [](const Instruction& inst, int var_id) -> bool {
for (auto pair : inst.Inputs()) {
for (auto item : pair.second) {
if (item == var_id) {
return true;
}
}
}
return false;
};
auto is_write = [](const Instruction& inst, int var_id) -> bool {
for (auto pair : inst.Outputs()) {
for (auto item : pair.second) {
if (item == var_id) {
return true;
}
}
}
return false;
};
// find first op that reads fused_out
auto first_read_fused_out_op = -1;
for (auto j = op_idx + 1; j < vec_instruction.size(); ++j) {
if (is_read(vec_instruction[j], fused_out)) {
first_read_fused_out_op = j;
break;
}
}
if (UNLIKELY(first_read_fused_out_op == -1)) {
VLOG(4) << "No op read FusedOutput";
continue;
}
// find ops that write 'outputs' between (op_idx, first_read_fused_out_op)
// and add depend: them -> first_read_fused_out_op
for (auto j = op_idx + 1;
j < static_cast<size_t>(first_read_fused_out_op); ++j) {
for (auto var_id : outputs) {
if (is_write(vec_instruction[j], var_id)) {
op2dependences[first_read_fused_out_op].insert(j);
VLOG(4) << j << " -> " << first_read_fused_out_op;
VLOG(4)
<< "Add depend from " << vec_instruction[j].OpBase()->Type()
<< " to "
<< vec_instruction[first_read_fused_out_op].OpBase()->Type();
}
}
}
// find first op that reads 'outputs' between (first_read_fused_out_op, end)
// add depend: first_read_fused_out_op -> first op that reads 'outputs'
// special case for consecutive communication ops, for example,
// FusedOutput = c_sync_calc_stream(FusedOutput)
// FusedOutput = c_allreduce_sum(FusedOutput)
// FusedOutput = c_sync_comm_stream(FusedOutput)
// we should take the last one to add depend instead of
// 'first_read_fused_out_op'
size_t target = first_read_fused_out_op;
for (size_t j = first_read_fused_out_op + 1; j < vec_instruction.size();
++j) {
if (j == target + 1 &&
vec_instruction[target].OpBase()->Type().find(
communication_op_prefix) != std::string::npos &&
vec_instruction[j].OpBase()->Type().find(communication_op_prefix) !=
std::string::npos) {
VLOG(4) << "Found consecutive communication ops, "
<< vec_instruction[target].OpBase()->Type() << " -> "
<< vec_instruction[j].OpBase()->Type();
target = j;
continue;
}
for (auto var_id : outputs) {
if (is_read(vec_instruction[j], var_id)) {
op2dependences[j].insert(target);
VLOG(4) << target << " -> " << j;
VLOG(4) << "Add depend from "
<< vec_instruction[target].OpBase()->Type() << " to "
<< vec_instruction[j].OpBase()->Type();
}
}
}
}
}
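Condensed, the coalesce_tensor pass enforces two rules: every op that writes one of the fused slices before the first reader of `FusedOutput` must precede that reader, and later readers of a slice must follow it. A toy Python model of just those two rules (it deliberately omits the consecutive-communication target shifting handled above; the op tuples and variable names are made up for illustration):

```python
# Toy instruction stream: F is the fused buffer, a/b are its slices.
ops = [
    ("coalesce_tensor", {"x"}, {"F", "a", "b"}),  # (type, reads, writes)
    ("fill_constant",   set(), {"a"}),
    ("c_allreduce_sum", {"F"}, {"F"}),
    ("scale",           {"a"}, {"a_out"}),
]
deps = {i: set() for i in range(len(ops))}

fused, slices = "F", {"a", "b"}
first_reader = next(i for i in range(1, len(ops)) if fused in ops[i][1])

# Rule 1: writers of any slice before the first FusedOutput reader
# must run before that reader.
for j in range(1, first_reader):
    if ops[j][2] & slices:
        deps[first_reader].add(j)

# Rule 2: later readers of a slice depend on the FusedOutput reader.
for j in range(first_reader + 1, len(ops)):
    if ops[j][1] & slices:
        deps[j].add(first_reader)

print(deps)  # {0: set(), 1: set(), 2: {1}, 3: {2}}
```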
for (auto pair : op2dependences) {
VLOG(10) << pair.first << " Depends on " << pair.second.size();
std::ostringstream oss;
std::copy(pair.second.begin(), pair.second.end(),
std::ostream_iterator<int>(oss, " "));
VLOG(10) << oss.str();
}
return std::move(get_downstream_map(op2dependences));
}
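The function returns via `get_downstream_map`, which is not shown in this diff; presumably it inverts `op2dependences` from "op -> ops it waits on" into "op -> ops waiting on it", the direction the scheduler consumes in `RunNextInstructions`. A sketch under that assumption:

```python
def get_downstream_map(op2dependences):
    # Invert "op -> ops it waits on" into "op -> ops waiting on it".
    downstream = {}
    for op, deps in op2dependences.items():
        for dep in deps:
            downstream.setdefault(dep, []).append(op)
    return downstream

print(get_downstream_map({1: {0}, 2: {0, 1}}))  # {0: [1, 2], 1: [2]}
```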
......
@@ -85,7 +85,7 @@ PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, false,
// NOTE(Ruibiao): This FLAG is just to be compatible with
// the old single-stream CUDA allocator. It will be removed
// after StreamSafeCudaAllocator has been fully tested.
-PADDLE_DEFINE_EXPORTED_bool(use_stream_safe_cuda_allocator, false,
+PADDLE_DEFINE_EXPORTED_bool(use_stream_safe_cuda_allocator, true,
"Enable StreamSafeCUDAAllocator");
PADDLE_DEFINE_EXPORTED_bool(use_cuda_managed_memory, false,
......
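Flipping this default turns StreamSafeCUDAAllocator on for everyone, but the flag stays exported, so a run that hits an allocator problem can still opt out at runtime. A hedged sketch, assuming the flag remains registered and visible to `paddle.set_flags` (as `PADDLE_DEFINE_EXPORTED_bool` flags normally are):

```python
import paddle

# Fall back to the legacy single-stream CUDA allocator for this process.
# Assumes FLAGS_use_stream_safe_cuda_allocator is still an exported flag.
paddle.set_flags({'FLAGS_use_stream_safe_cuda_allocator': False})
```

Setting `FLAGS_use_stream_safe_cuda_allocator=false` in the environment before launch should have the same effect.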
@@ -394,19 +394,10 @@ def _is_enable_standalone_executor():
Whether to use experimental executor `StandaloneExecutor`.
"""
flag = False
-# NOTE(zhiqiu): enable STANDALONE_EXECUTOR on windows platform by default
-# It should be enabled on all platform in the future.
-import platform
-sysstr = platform.system().lower()
-if sysstr == 'windows':
-    env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', 1)
-else:
-    env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', None)
+env_val = os.environ.get('FLAGS_USE_STANDALONE_EXECUTOR', '1')
if env_val in [1, '1', True, 'True', 'true']:
flag = True
warnings.warn("STANDALONE_EXECUTOR is enabled.")
return flag
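With the default now '1', the standalone executor is on everywhere unless the environment says otherwise. Opting out for a single run is a one-liner; the CMake changes below use the same mechanism (`ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0`) to pin individual tests to the old executor:

```python
import os

# Disable the standalone executor for this process: any value outside
# [1, '1', True, 'True', 'true'] leaves the flag off.
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '0'
```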
@@ -1386,6 +1377,10 @@ class Executor(object):
program = pruned_program
def _can_use_interpreter_core(program, place):
if core.is_compiled_with_npu() or core.is_compiled_with_xpu(
) or core.is_compiled_with_mlu() or core.is_compiled_with_ipu():
return False
compiled = isinstance(program, compiler.CompiledProgram)
# NOTE(zhiqiu): do not support compiled program now
if compiled:
@@ -1396,6 +1391,8 @@ class Executor(object):
# else:
# return False
else:
if isinstance(program._graph, compiler.CompiledProgram):
return False
assert isinstance(program, Program)
return True
......
@@ -951,7 +951,7 @@ endif()
if (WITH_DISTRIBUTE AND NOT APPLE)
if(WITH_GPU OR WITH_ROCM)
set_tests_properties(test_c_comm_init_op PROPERTIES TIMEOUT 120)
-set_tests_properties(test_dist_mnist_gradient_merge PROPERTIES TIMEOUT 160)
+set_tests_properties(test_dist_mnist_gradient_merge PROPERTIES TIMEOUT 360)
endif()
endif()
@@ -1033,7 +1033,7 @@ set_tests_properties(test_parallel_executor_seresnext_with_reduce_gpu PROPERTIES
set_tests_properties(test_dropout_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_argsort_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_gather_nd_op PROPERTIES TIMEOUT 120)
-set_tests_properties(test_nn_grad PROPERTIES TIMEOUT 120)
+set_tests_properties(test_nn_grad PROPERTIES TIMEOUT 120 ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0)
set_tests_properties(test_elementwise_sub_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_row_conv_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_executor_seresnext_with_fuse_all_reduce_gpu PROPERTIES TIMEOUT 120)
@@ -1072,7 +1072,7 @@ set_tests_properties(test_space_to_depth_op PROPERTIES TIMEOUT 200)
set_tests_properties(test_dyn_rnn PROPERTIES TIMEOUT 120)
set_tests_properties(test_sgd_op PROPERTIES TIMEOUT 250)
set_tests_properties(test_parallel_executor_seresnext_base_gpu PROPERTIES TIMEOUT 120)
-set_tests_properties(test_norm_nn_grad PROPERTIES TIMEOUT 120)
+set_tests_properties(test_norm_nn_grad PROPERTIES TIMEOUT 120 ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0)
set_tests_properties(test_matrix_nms_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_generator_dataloader PROPERTIES TIMEOUT 120)
set_tests_properties(test_partial_concat_op PROPERTIES TIMEOUT 120)
......
@@ -48,6 +48,7 @@ class TestCollectiveReduceScatter(TestCollectiveRunnerBase):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
toutdata = fluid.layers.collective._c_reducescatter(tindata, nranks)
toutdata = fluid.layers.collective._c_sync_comm_stream(toutdata, 0)
return toutdata
......
@@ -32,7 +32,7 @@ def prepare_python_path_and_return_module(path):
assert filename.endswith(py_suffix), filename
env_name = 'PYTHONPATH'
-python_path = env_name
+python_path = os.environ.get(env_name, '')
if python_path:
paths = [p for p in python_path.split(":") if p]
if dirname not in paths:
@@ -41,6 +41,7 @@ def prepare_python_path_and_return_module(path):
else:
python_path = path
os.environ[env_name] = python_path
print('GLOG_v=', os.environ.get('GLOG_v', None), flush=1)
return filename[:-len(py_suffix)]
......
@@ -91,7 +91,7 @@ set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100)
set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_nearest_interp_v2_op PROPERTIES TIMEOUT 30)
-set_tests_properties(test_trt_multiclass_nms3_op PROPERTIES TIMEOUT 60)
+set_tests_properties(test_trt_multiclass_nms3_op PROPERTIES TIMEOUT 60 ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0)
if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU)
set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT 120)
......
@@ -27,6 +27,7 @@ paddle.enable_static()
class TestSliceOpDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
self.config()
......
@@ -30,6 +30,7 @@ from paddle.fluid import compiler
from paddle.fluid import Program, program_guard
from op_test import OpTest, _set_use_system_allocator
from decorator_helper import prog_scope
_set_use_system_allocator(True)
@@ -105,6 +106,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
sgd_opt.backward(out)
return main, startup, [out, conv, bn]
@prog_scope()
def _compare(self, place, layout, only_forward):
"""Compare results."""
seed = 10
......