未验证 提交 0f1fde51 编写于 作者: L Leo Chen 提交者: GitHub

fix the modification of set_expected_place (#31177)

* revert the modification of set_expected_place

* set device before op run

* add ut
上级 dc8dfba3
...@@ -72,6 +72,13 @@ TEST(test_tracer, test_trace_op) { ...@@ -72,6 +72,13 @@ TEST(test_tracer, test_trace_op) {
framework::AttributeMap mul_attr_map; framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false; mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
#ifndef PADDLE_WITH_XPU
ASSERT_THROW(tracer.TraceOp("mul", ins, outs, mul_attr_map,
platform::XPUPlace(0), true);
, platform::EnforceNotMet);
#endif
const auto& out_tensor = vout->Var().Get<framework::LoDTensor>(); const auto& out_tensor = vout->Var().Get<framework::LoDTensor>();
for (int i = 0; i < vout->Var().Get<framework::LoDTensor>().numel(); i++) { for (int i = 0; i < vout->Var().Get<framework::LoDTensor>().numel(); i++) {
ASSERT_EQ(out_tensor.data<float>()[i], 20.0); ASSERT_EQ(out_tensor.data<float>()[i], 20.0);
...@@ -311,10 +318,6 @@ TEST(test_tracer, test_expected_place) { ...@@ -311,10 +318,6 @@ TEST(test_tracer, test_expected_place) {
platform::CUDAPlace gpu_place(0); platform::CUDAPlace gpu_place(0);
tracer.SetExpectedPlace(gpu_place); tracer.SetExpectedPlace(gpu_place);
ASSERT_EQ(platform::is_gpu_place(tracer.ExpectedPlace()), true); ASSERT_EQ(platform::is_gpu_place(tracer.ExpectedPlace()), true);
// assert throw
platform::XPUPlace xpu_place(0);
ASSERT_THROW(tracer.SetExpectedPlace(xpu_place), platform::EnforceNotMet);
#endif #endif
} }
{ {
...@@ -323,10 +326,6 @@ TEST(test_tracer, test_expected_place) { ...@@ -323,10 +326,6 @@ TEST(test_tracer, test_expected_place) {
platform::XPUPlace xpu_place(0); platform::XPUPlace xpu_place(0);
tracer.SetExpectedPlace(xpu_place); tracer.SetExpectedPlace(xpu_place);
ASSERT_EQ(platform::is_xpu_place(tracer.ExpectedPlace()), true); ASSERT_EQ(platform::is_xpu_place(tracer.ExpectedPlace()), true);
// assert throw
platform::CUDAPlace cuda_place(0);
ASSERT_THROW(tracer.SetExpectedPlace(cuda_place), platform::EnforceNotMet);
#endif #endif
} }
} }
......
...@@ -162,6 +162,23 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -162,6 +162,23 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
} }
try { try {
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::SetDeviceId(BOOST_GET_CONST(platform::CUDAPlace, place).device);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
platform::SetXPUDeviceId(
BOOST_GET_CONST(platform::XPUPlace, place).device);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with XPU if use XPUPlace."));
#endif
}
OpBase::Run(*op, new_ins, outs, attrs, place); OpBase::Run(*op, new_ins, outs, attrs, place);
} catch (platform::EnforceNotMet& exception) { } catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(type, &exception); framework::AppendErrorOpHint(type, &exception);
...@@ -199,22 +216,6 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -199,22 +216,6 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
} }
void Tracer::SetExpectedPlace(platform::Place place) { void Tracer::SetExpectedPlace(platform::Place place) {
// NOTE(wangxi): set device id before launch device kernel
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::SetDeviceId(BOOST_GET_CONST(platform::CUDAPlace, place).device);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU if use CUDAPlace."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
platform::SetXPUDeviceId(BOOST_GET_CONST(platform::XPUPlace, place).device);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with XPU if use XPUPlace."));
#endif
}
expected_place_ = place; expected_place_ = place;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册