diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index ae51a53a7197950338ef773d63103fa13bf0a5f5..ba8eccf82042b679f69a32f9d053f05ac8fb9a99 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -107,7 +107,7 @@ struct SelectedRowsAddTensor { PADDLE_ENFORCE_EQ(in1_height, out_dims[0]); auto& in1_value = input1.value(); - framework::Vector in1_rows(input1.rows()); + auto& in1_rows = input1.rows(); int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height); @@ -206,7 +206,7 @@ struct SelectedRowsAddToTensor { PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); auto& in1_value = input1.value(); - framework::Vector in1_rows(input1.rows()); + auto& in1_rows = input1.rows(); int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height); diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cu b/paddle/fluid/operators/math/selected_rows_functor_test.cu index e89b27855bdeba3a5189feff94eb063ddfb9b9b8..5fc50aba25d8e69480a17f0f80877b0d03e17276 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cu +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cu @@ -20,7 +20,9 @@ limitations under the License. */ TEST(selected_rows_functor, gpu_add) { paddle::platform::CUDAPlace gpu_place(0); paddle::platform::CPUPlace cpu_place; - paddle::platform::CUDADeviceContext ctx(gpu_place); + paddle::platform::CUDADeviceContext& ctx = + *reinterpret_cast( + paddle::platform::DeviceContextPool::Instance().Get(gpu_place)); paddle::operators::math::SetConstant functor; @@ -132,7 +134,9 @@ TEST(selected_rows_functor, gpu_add) { TEST(selected_rows_functor, gpu_add_to) { paddle::platform::CUDAPlace gpu_place(0); paddle::platform::CPUPlace cpu_place; - paddle::platform::CUDADeviceContext ctx(gpu_place); + paddle::platform::CUDADeviceContext& ctx = + *reinterpret_cast( + paddle::platform::DeviceContextPool::Instance().Get(gpu_place)); paddle::operators::math::SetConstant functor;