From 94ffda577c9f2cae2bd2460b9fa849d488d118ea Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 21 Apr 2022 09:50:58 +0800 Subject: [PATCH] [Eager]Fix SetDeviceId in eager_final_state_api from python_c_gen.py (#42025) --- .../final_state_generator/python_c_gen.py | 34 ++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index e2bb4104551..7ca5fc833ea 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -100,6 +100,9 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj {} tstate = PyEval_SaveThread(); + + // Set Device ID +{} auto out = {}({}); @@ -118,6 +121,19 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj """ +FUNCTION_SET_DEVICE_TEMPLATE = \ +""" + {} + if (paddle::platform::is_gpu_place(place)) {{ +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + phi::backends::gpu::SetDeviceId(place.device); + VLOG(1) <<"CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId() << " from " << (int)place.device; +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with GPU if use CUDAPlace.")); +#endif + }} +""" FUNCTION_NAME_TEMPLATE = \ "{}{}{}" @@ -293,14 +309,23 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): "false") parse_attributes_str = "" + expected_place_str = "auto place = egr::Controller::Instance().GetExpectedPlace();\n" # Generate Python-C Attributes Parsing Logic for name, atype, _, pos in orig_forward_attrs_list: parsing_function_name = FindParsingFunctionFromAttributeType(atype) + # Used input argument place if specified from Python frontend. + if len(expected_place_str + ) != 0 and parsing_function_name == "CastPyArg2Place": + expected_place_str = "" + assert name == "place", "Only support 'place' as template argument name in FUNCTION_SET_DEVICE_TEMPLATE." + parse_attributes_str += PARSE_PYTHON_C_ARGS_TEMPLATE.format( name, pos, atype, name, parsing_function_name, name, forward_api_name, pos) + set_device_str = FUNCTION_SET_DEVICE_TEMPLATE.format(expected_place_str) + # Generate Dygraph Function Call Logic num_args = len(forward_inputs_position_map.keys()) + len( orig_forward_attrs_list) @@ -326,8 +351,8 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): "pythonc_record_event", forward_api_name, "pybind_imperative_func") self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format( forward_api_name, pythonc_record_event_str, forward_api_name, - get_eager_tensor_str, parse_attributes_str, fwd_function_name, - dygraph_function_call_str, return_str) + get_eager_tensor_str, parse_attributes_str, set_device_str, + fwd_function_name, dygraph_function_call_str, return_str) # Set prefix of forward_api_name to avoid conflicts prefix = self.namespace.strip("::") @@ -361,8 +386,9 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format( inplaced_forward_api_name, pythonc_record_event_str, inplaced_forward_api_name, get_eager_tensor_str, - parse_attributes_str, inplaced_fwd_function_name, - dygraph_function_call_str, return_str) + parse_attributes_str, set_device_str, + inplaced_fwd_function_name, dygraph_function_call_str, + return_str) # Generate Python-C Function Registration self.python_c_function_reg_str += "\n," + PYTHON_C_FUNCTION_REG_TEMPLATE.format( -- GitLab