未验证 提交 9f3d9381 编写于 作者: A Aurelius84 提交者: GitHub

[Eager]Fix SetDeviceId in eager_final_state_api from python_c_gen.py (#42025) (#42067)

上级 b3d608e2
......@@ -100,6 +100,9 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
{}
tstate = PyEval_SaveThread();
// Set Device ID
{}
auto out = {}({});
......@@ -118,6 +121,19 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
"""
FUNCTION_SET_DEVICE_TEMPLATE = \
"""
{}
if (paddle::platform::is_gpu_place(place)) {{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::backends::gpu::SetDeviceId(place.device);
VLOG(1) <<"CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId() << " from " << (int)place.device;
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
}}
"""
FUNCTION_NAME_TEMPLATE = \
"{}{}{}"
......@@ -293,14 +309,23 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
"false")
parse_attributes_str = ""
expected_place_str = "auto place = egr::Controller::Instance().GetExpectedPlace();\n"
# Generate Python-C Attributes Parsing Logic
for name, atype, _, pos in orig_forward_attrs_list:
parsing_function_name = FindParsingFunctionFromAttributeType(atype)
# Used input argument place if specified from Python frontend.
if len(expected_place_str
) != 0 and parsing_function_name == "CastPyArg2Place":
expected_place_str = ""
assert name == "place", "Only support 'place' as template argument name in FUNCTION_SET_DEVICE_TEMPLATE."
parse_attributes_str += PARSE_PYTHON_C_ARGS_TEMPLATE.format(
name, pos, atype, name, parsing_function_name, name,
forward_api_name, pos)
set_device_str = FUNCTION_SET_DEVICE_TEMPLATE.format(expected_place_str)
# Generate Dygraph Function Call Logic
num_args = len(forward_inputs_position_map.keys()) + len(
orig_forward_attrs_list)
......@@ -326,8 +351,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
"pythonc_record_event", forward_api_name, "pybind_imperative_func")
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, fwd_function_name,
dygraph_function_call_str, return_str)
get_eager_tensor_str, parse_attributes_str, set_device_str,
fwd_function_name, dygraph_function_call_str, return_str)
# Set prefix of forward_api_name to avoid conflicts
prefix = self.namespace.strip("::")
......@@ -361,8 +386,9 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, inplaced_fwd_function_name,
dygraph_function_call_str, return_str)
parse_attributes_str, set_device_str,
inplaced_fwd_function_name, dygraph_function_call_str,
return_str)
# Generate Python-C Function Registration
self.python_c_function_reg_str += "\n," + PYTHON_C_FUNCTION_REG_TEMPLATE.format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册