Unverified commit b25f25d0, authored by pangyoki, committed by GitHub

fix device_id bug for final_state op in multiprocess testcase (#41407)

* support final_state in multiprocess

* fix no place.device

* set device_id in eager_gen
Parent 64f769d4
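Background for the fix: in eager "final state" mode, the generated forward functions did not pin the current CUDA device, so an op launched from a spawned worker process could run its kernel on the wrong GPU. The sketch below illustrates the affected scenario using the public paddle.distributed.spawn API; the worker body is hypothetical and only shows where the per-process device id matters.

```python
# A minimal sketch of the multiprocess scenario this commit fixes, assuming
# the public paddle.distributed.spawn API; the worker body is illustrative.
import paddle
import paddle.distributed as dist

def worker():
    # Each spawned process is assigned its own GPU (rank 0 -> GPU 0, etc.).
    dist.init_parallel_env()
    x = paddle.ones([2, 2]) * dist.get_rank()
    # In eager mode, `+` dispatches to the final_state_add op. Before this
    # fix, the generated forward function never called SetDeviceId, so the
    # kernel could launch on a device other than the worker's own.
    y = x + x
    print(dist.get_rank(), y.numpy())

if __name__ == '__main__':
    dist.spawn(worker, nprocs=2)
```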
@@ -194,6 +194,16 @@ FORWARD_FUNCTION_TEMPLATE = \
// Get Input AutoGradMeta
{}
// Set Device Id
auto place = egr::Controller::Instance().GetExpectedPlace();
if (paddle::platform::is_gpu_place(place)) {{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::backends::gpu::SetDeviceId(place.device);
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
}}
// Forward API Call
{}
// Get Outputs
@@ -284,6 +294,7 @@ FORWARD_CC_FILE_TEMPLATE = \
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
{}
{}
...
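Note that the doubled braces in the template above are not C++: FORWARD_FUNCTION_TEMPLATE is a Python str.format string, so literal C++ braces must be escaped as `{{` and `}}`, while bare `{}` pairs remain substitution slots. A toy illustration of the mechanism (the placeholder text is hypothetical, not the actual generator code):

```python
# Literal braces in a str.format template are escaped by doubling them;
# a bare {} is a substitution slot filled in by .format().
template = (
    "if (paddle::platform::is_gpu_place(place)) {{\n"
    "  // device guard goes here\n"
    "}}\n"
    "{}")
print(template.format("// Forward API Call"))
```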
@@ -270,6 +270,9 @@ def monkey_patch_math_varbase():
        # 4. calculation
        axis = -1
        if framework._in_eager_mode_ and op_type == 'elementwise_add':
            math_op = getattr(_C_ops, 'final_state_add')
        else:
            math_op = getattr(_C_ops, op_type)
        return math_op(self, other_var, 'axis', axis)
...
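With this patch, the eager-mode `+` operator routes elementwise_add through the final-state op. A hedged usage sketch of the behavior this enables (the tensor values are illustrative):

```python
# In eager mode (the default in recent Paddle), `a + b` goes through
# monkey_patch_math_varbase, which now picks _C_ops.final_state_add for
# elementwise_add instead of the legacy op.
import paddle

a = paddle.ones([2, 2])
b = paddle.full([2, 2], 2.0)
c = a + b  # dispatched via final_state_add under eager mode
print(c.numpy())  # [[3. 3.], [3. 3.]]
```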
@@ -103,9 +103,7 @@ class TestInplace(unittest.TestCase):
        var_b[1:2] = 3  # var_b is modified inplace before using it
        var_c = var_b + var_b  # Here, the grad op of sum doesn't use the value of var_b
        loss = var_c.sum()
        var_b[1:2] = 3  # var_b is modified inplace after using it
...
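The test now exercises the patched `+` operator instead of calling paddle.add directly, so the inplace check runs through the final-state dispatch path. A condensed, hedged recreation of the test's intent (shapes and values are illustrative):

```python
# Inplace writes to var_b are legal here because the grad of add does not
# need var_b's value; this sketch mirrors the test's intent, not its exact code.
import paddle

var_a = paddle.ones([4, 4])
var_a.stop_gradient = False
var_b = var_a ** 2
var_b[1:2] = 3          # inplace write before var_b is used
var_c = var_b + var_b   # grad of add doesn't need var_b's value
loss = var_c.sum()
var_b[1:2] = 3          # inplace write after use is also fine for add
loss.backward()         # succeeds: no grad op reads the overwritten value
```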