未验证 提交 32b3469a 编写于 作者: C Chen Weihang 提交者: GitHub

fix tensor stream error in custom op (#44500)

上级 32c97a9d
...@@ -4,17 +4,17 @@ if(WITH_GPU) ...@@ -4,17 +4,17 @@ if(WITH_GPU)
nv_library( nv_library(
phi_tensor_raw phi_tensor_raw
SRCS tensor.cc SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce) DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool)
elseif(WITH_ROCM) elseif(WITH_ROCM)
hip_library( hip_library(
phi_tensor_raw phi_tensor_raw
SRCS tensor.cc SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce) DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool)
else() else()
cc_library( cc_library(
phi_tensor_raw phi_tensor_raw
SRCS tensor.cc SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce) DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool)
endif() endif()
set(api_gen_base ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/api_base.py) set(api_gen_base ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/api_base.py)
......
...@@ -21,7 +21,9 @@ limitations under the License. */ ...@@ -21,7 +21,9 @@ limitations under the License. */
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
...@@ -33,8 +35,6 @@ limitations under the License. */ ...@@ -33,8 +35,6 @@ limitations under the License. */
#include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
// clang-format off // clang-format off
namespace paddle { namespace paddle {
...@@ -311,7 +311,10 @@ void Tensor::set_impl(std::shared_ptr<phi::TensorBase> &&impl) { ...@@ -311,7 +311,10 @@ void Tensor::set_impl(std::shared_ptr<phi::TensorBase> &&impl) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gpuStream_t Tensor::stream() const { gpuStream_t Tensor::stream() const {
return platform::stream::get_current_stream(-1)->raw_stream(); int device_id = phi::backends::gpu::GetCurrentDeviceId();
auto* gpu_context = DeviceContextPool::Instance()
.Get<AllocationType::GPU>(GPUPlace(device_id));
return gpu_context->stream();
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册