未验证 提交 32b3469a 编写于 作者: C Chen Weihang 提交者: GitHub

fix tensor stream error in custom op (#44500)

上级 32c97a9d
......@@ -4,17 +4,17 @@ if(WITH_GPU)
nv_library(
phi_tensor_raw
SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce)
DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool)
elseif(WITH_ROCM)
hip_library(
phi_tensor_raw
SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce)
DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool)
else()
cc_library(
phi_tensor_raw
SRCS tensor.cc
DEPS tensor_base dense_tensor phi_api_utils phi_enforce)
DEPS tensor_base dense_tensor phi_api_utils phi_enforce context_pool)
endif()
set(api_gen_base ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/api_base.py)
......
......@@ -21,7 +21,9 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
......@@ -33,8 +35,6 @@ limitations under the License. */
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
// clang-format off
namespace paddle {
......@@ -311,7 +311,10 @@ void Tensor::set_impl(std::shared_ptr<phi::TensorBase> &&impl) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gpuStream_t Tensor::stream() const {
return platform::stream::get_current_stream(-1)->raw_stream();
int device_id = phi::backends::gpu::GetCurrentDeviceId();
auto* gpu_context = DeviceContextPool::Instance()
.Get<AllocationType::GPU>(GPUPlace(device_id));
return gpu_context->stream();
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册