PaddlePaddle / Paddle
Commit 281ea2f4 (unverified)
Authored by umiswing on Apr 14, 2023; committed via GitHub on Apr 14, 2023.
[Dcu]: Add rocsparse_spmm for dcu. (#52200)
Parent: 5fbcf37d
Showing 12 changed files with 772 additions and 19 deletions (+772 -19)
paddle/fluid/platform/dynload/rocsparse.cc              (+37  -0)
paddle/fluid/platform/dynload/rocsparse.h               (+75  -0)
paddle/phi/backends/dynload/CMakeLists.txt              (+8   -1)
paddle/phi/backends/dynload/dynamic_loader.cc           (+2   -0)
paddle/phi/backends/dynload/rocsparse.cc                (+37  -0)
paddle/phi/backends/dynload/rocsparse.h                 (+86  -0)
paddle/phi/backends/gpu/gpu_resources.cc                (+12  -0)
paddle/phi/kernels/funcs/sparse/sparse_blas.h           (+3   -0)
paddle/phi/kernels/funcs/sparse/sparse_blas_impl.hip.h  (+405 -0)
paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu     (+37  -2)
paddle/phi/kernels/sparse/gpu/matmul_kernel.cu          (+13  -1)
paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu    (+57  -15)
paddle/fluid/platform/dynload/rocsparse.cc (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/dynload/rocsparse.h"
namespace paddle {
namespace platform {
namespace dynload {

#define DEFINE_WRAP(__name) DynLoad__##__name __name

#ifdef ROCSPARSE_ROUTINE_EACH
ROCSPARSE_ROUTINE_EACH(DEFINE_WRAP);
#endif

#ifdef ROCSPARSE_ROUTINE_EACH_R2
ROCSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP);
#endif

#ifdef ROCSPARSE_ROUTINE_EACH_R3
ROCSPARSE_ROUTINE_EACH_R3(DEFINE_WRAP);
#endif

}  // namespace dynload
}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/dynload/rocsparse.h (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <hip/hip_runtime.h>
#include <rocsparse.h>
#include <mutex> // NOLINT
#include <type_traits>
#include "paddle/phi/backends/dynload/rocsparse.h"
namespace paddle {
namespace platform {
namespace dynload {

/**
 * The following macro definition can generate structs
 * (for each function) to dynamically load rocsparse routines
 * via operator overloading.
 *
 * note: default dynamic linked libs
 */
#define PLATFORM_DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP(__name)    \
  using DynLoad__##__name = phi::dynload::DynLoad__##__name;    \
  extern DynLoad__##__name __name

#if defined(PADDLE_WITH_HIP)
#define ROCSPARSE_ROUTINE_EACH(__macro) \
  __macro(rocsparse_create_handle);     \
  __macro(rocsparse_destroy_handle);    \
  __macro(rocsparse_set_stream);        \
  __macro(rocsparse_csr2coo);

ROCSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP)

#if HIP_VERSION >= 402
#define ROCSPARSE_ROUTINE_EACH_R2(__macro) \
  __macro(rocsparse_create_coo_descr);     \
  __macro(rocsparse_create_csr_descr);     \
  __macro(rocsparse_destroy_spmat_descr);  \
  __macro(rocsparse_create_dnmat_descr);   \
  __macro(rocsparse_destroy_dnmat_descr);  \
  __macro(rocsparse_spmm);

ROCSPARSE_ROUTINE_EACH_R2(PLATFORM_DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP)
#endif

#if HIP_VERSION >= 403
#define ROCSPARSE_ROUTINE_EACH_R3(__macro) \
  __macro(rocsparse_sddmm_buffer_size);    \
  __macro(rocsparse_sddmm_preprocess);     \
  __macro(rocsparse_sddmm);

ROCSPARSE_ROUTINE_EACH_R3(PLATFORM_DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP)
#endif
#endif  // PADDLE_WITH_HIP

#undef PLATFORM_DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP

}  // namespace dynload
}  // namespace platform
}  // namespace paddle
paddle/phi/backends/dynload/CMakeLists.txt
...
@@ -20,7 +20,14 @@ if(NOT WITH_NV_JETSON)
 endif()

 if(WITH_ROCM)
-  list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc hipfft.cc)
+  list(
+    APPEND
+    HIP_SRCS
+    rocblas.cc
+    miopen.cc
+    hiprand.cc
+    hipfft.cc
+    rocsparse.cc)
 endif()

 # There is no macOS version of NCCL.
...
paddle/phi/backends/dynload/dynamic_loader.cc
...
@@ -427,6 +427,8 @@ void* GetCusparseDsoHandle() {
 #elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
   return GetDsoHandleFromSearchPath(
       FLAGS_cuda_dir, win_cusparse_lib, true, {cuda_lib_path});
+#elif defined(PADDLE_WITH_HIP)
+  return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "librocsparse.so");
 #else
   return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusparse.so");
 #endif
...
paddle/phi/backends/dynload/rocsparse.cc (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/dynload/rocsparse.h"
namespace phi {
namespace dynload {

std::once_flag rocsparse_dso_flag;
void* rocsparse_dso_handle = nullptr;

#define DEFINE_WRAP(__name) DynLoad__##__name __name

#ifdef ROCSPARSE_ROUTINE_EACH
ROCSPARSE_ROUTINE_EACH(DEFINE_WRAP)
#endif

#ifdef ROCSPARSE_ROUTINE_EACH_R2
ROCSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP);
#endif

#ifdef ROCSPARSE_ROUTINE_EACH_R3
ROCSPARSE_ROUTINE_EACH_R3(DEFINE_WRAP);
#endif

}  // namespace dynload
}  // namespace phi
paddle/phi/backends/dynload/rocsparse.h (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <hip/hip_runtime.h>
#include <rocsparse.h>
#include <mutex> // NOLINT
#include <type_traits>
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"
namespace phi {
namespace dynload {

extern std::once_flag rocsparse_dso_flag;
extern void* rocsparse_dso_handle;

/**
 * The following macro definition can generate structs
 * (for each function) to dynamically load rocsparse routines
 * via operator overloading.
 *
 * note: default dynamic linked libs
 */
#define DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP(__name)                   \
  struct DynLoad__##__name {                                          \
    template <typename... Args>                                       \
    rocsparse_status operator()(Args... args) {                       \
      using rocsparse_func = decltype(&::__name);                     \
      std::call_once(rocsparse_dso_flag, []() {                       \
        rocsparse_dso_handle = phi::dynload::GetCusparseDsoHandle();  \
      });                                                             \
      static void *p_##__name = dlsym(rocsparse_dso_handle, #__name); \
      return reinterpret_cast<rocsparse_func>(p_##__name)(args...);   \
    }                                                                 \
  };                                                                  \
  extern DynLoad__##__name __name

#if defined(PADDLE_WITH_HIP)
#define ROCSPARSE_ROUTINE_EACH(__macro) \
  __macro(rocsparse_create_handle);     \
  __macro(rocsparse_destroy_handle);    \
  __macro(rocsparse_set_stream);        \
  __macro(rocsparse_csr2coo);

ROCSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP)

#if HIP_VERSION >= 402
#define ROCSPARSE_ROUTINE_EACH_R2(__macro) \
  __macro(rocsparse_create_coo_descr);     \
  __macro(rocsparse_create_csr_descr);     \
  __macro(rocsparse_destroy_spmat_descr);  \
  __macro(rocsparse_create_dnmat_descr);   \
  __macro(rocsparse_destroy_dnmat_descr);  \
  __macro(rocsparse_spmm);

ROCSPARSE_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP)
#endif

#if HIP_VERSION >= 403
#define ROCSPARSE_ROUTINE_EACH_R3(__macro) \
  __macro(rocsparse_sddmm_buffer_size);    \
  __macro(rocsparse_sddmm_preprocess);     \
  __macro(rocsparse_sddmm);

ROCSPARSE_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP)
#endif
#endif  // PADDLE_WITH_HIP

#undef DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP

}  // namespace dynload
}  // namespace phi
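For readers unfamiliar with Paddle's dynload pattern, the wrapper that DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP generates expands, for a single routine, to roughly the sketch below (an illustrative expansion of the macro above, not an additional file in this patch):

// Illustrative expansion of DECLARE_DYNAMIC_LOAD_ROCSPARSE_WRAP(rocsparse_spmm).
struct DynLoad__rocsparse_spmm {
  template <typename... Args>
  rocsparse_status operator()(Args... args) {
    // Resolve librocsparse.so once, on first use; on HIP builds
    // GetCusparseDsoHandle() returns the rocSPARSE library handle
    // (see the dynamic_loader.cc hunk earlier in this diff).
    using rocsparse_func = decltype(&::rocsparse_spmm);
    std::call_once(rocsparse_dso_flag, []() {
      rocsparse_dso_handle = phi::dynload::GetCusparseDsoHandle();
    });
    // Look the symbol up once and cache the function pointer.
    static void* p_rocsparse_spmm =
        dlsym(rocsparse_dso_handle, "rocsparse_spmm");
    // Forward all arguments to the dynamically resolved routine.
    return reinterpret_cast<rocsparse_func>(p_rocsparse_spmm)(args...);
  }
};
extern DynLoad__rocsparse_spmm rocsparse_spmm;

Call sites such as phi::dynload::rocsparse_spmm(...) therefore read like direct rocSPARSE calls, but the symbol is resolved lazily from the shared library the first time it is used.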
paddle/phi/backends/gpu/gpu_resources.cc
...
@@ -33,6 +33,10 @@
 #endif  // !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
 #endif  // PADDLE_WITH_CUDA

+#ifdef PADDLE_WITH_HIP
+#include "paddle/phi/backends/dynload/rocsparse.h"
+#endif
+
 #include "glog/logging.h"
 #include "unsupported/Eigen/CXX11/Tensor"
...
@@ -295,6 +299,9 @@ void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream) {
   PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(handle));
   PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseSetStream(*handle, stream));
 #endif
+#elif defined(PADDLE_WITH_HIP)
+  phi::dynload::rocsparse_create_handle(handle);
+  phi::dynload::rocsparse_set_stream(*handle, stream);
 #endif
 }
...
@@ -306,6 +313,11 @@ void DestroySparseHandle(sparseHandle_t handle) {
     handle = nullptr;
   }
 #endif
+#elif defined(PADDLE_WITH_HIP)
+  if (handle != nullptr) {
+    phi::dynload::rocsparse_destroy_handle(handle);
+    handle = nullptr;
+  }
 #endif
 }
...
paddle/phi/kernels/funcs/sparse/sparse_blas.h
...
@@ -97,3 +97,6 @@ inline SparseBlasT<DeviceContext, T> GetSparseBlas(
 #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000
 #include "paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h"
 #endif
+#if defined(PADDLE_WITH_HIP) && HIP_VERSION >= 402
+#include "paddle/phi/kernels/funcs/sparse/sparse_blas_impl.hip.h"
+#endif
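For context on why this conditional include is added: kernels obtain the backend-specific implementation through GetSparseBlas and then call SPMM or SDDMM on it. A minimal sketch of such a call site, mirroring the use in matmul_kernel.cu later in this diff (the tensor setup is assumed and elided):

// Sketch only: assumes this runs inside a phi GPU kernel with x, y and out
// already constructed. Which sparse_blas_impl.*.h backs the call depends on
// the conditional includes above (cuSPARSE on CUDA, rocSPARSE on HIP).
template <typename T, typename Context>
void SpmmSketch(const Context& dev_ctx,
                const phi::SparseCsrTensor& x,
                const phi::DenseTensor& y,
                phi::DenseTensor* out) {
  auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
  // out = 1 * x * y + 0 * out, with neither operand transposed.
  sparse_blas.SPMM(
      false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
}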
paddle/phi/kernels/funcs/sparse/sparse_blas_impl.hip.h (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/dynload/rocsparse.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/visit_type.h"
namespace phi {
namespace funcs {
namespace sparse {

template <typename IntT>
rocsparse_indextype GetGpuIndexType() {
  if (std::is_same<IntT, int32_t>::value) {
    return rocsparse_indextype_i32;
  } else if (std::is_same<IntT, int64_t>::value) {
    return rocsparse_indextype_i64;
  }
}

template <typename T>
rocsparse_datatype GetGpuDataType() {
  if (std::is_same<T, float>::value) {
    return rocsparse_datatype_f32_r;
  } else if (std::is_same<T, double>::value) {
    return rocsparse_datatype_f64_r;
  }
}

inline rocsparse_operation GetTransposeOperation(const bool trans) {
  if (trans) {
    return rocsparse_operation_transpose;
  } else {
    return rocsparse_operation_none;
  }
}

template <typename TensorType>
inline rocsparse_spmm_alg GetSpMMAlgorithm(const TensorType& x) {
  return rocsparse_spmm_alg_default;
}

/************* SPARSE MATRIX DESCRIPTOR (COO/CSR) ************/
template <typename T, typename IntT>
inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,
                                const phi::GPUContext& dev_ctx,
                                rocsparse_spmat_descr* descriptor) {
  std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
  auto x_ndims = xdim_vec.size();
  PADDLE_ENFORCE_GE(
      x_ndims,
      2,
      phi::errors::InvalidArgument("the dim size of SparseCsrTensor must be "
                                   "greater than or equal to 2."));
  int64_t M = xdim_vec[x_ndims - 2];
  int64_t N = xdim_vec[x_ndims - 1];
  int batch_size = 1;
  for (int i = 0; i < x_ndims - 2; i++) {
    batch_size *= xdim_vec[i];
  }
  PADDLE_ENFORCE_EQ(
      x.non_zero_crows().numel(),
      batch_size * (M + 1),
      phi::errors::PreconditionNotMet(
          "the length of SparseCsrTensor crows is not right."));

  const IntT* crows_data = x.non_zero_crows().data<IntT>();
  const IntT* cols_data = x.non_zero_cols().data<IntT>();
  const T* values_data = x.non_zero_elements().data<T>();

  int64_t batch_nnz = x.nnz() / batch_size;
  rocsparse_indextype itype = GetGpuIndexType<int64_t>();
  rocsparse_indextype jtype = GetGpuIndexType<int64_t>();
  rocsparse_datatype ttype = GetGpuDataType<T>();
  dev_ctx.CusparseCall([&](rocsparse_handle handle) {
    phi::dynload::rocsparse_create_csr_descr(descriptor,
                                             M,
                                             N,
                                             batch_nnz,
                                             const_cast<IntT*>(crows_data),
                                             const_cast<IntT*>(cols_data),
                                             const_cast<T*>(values_data),
                                             itype,
                                             jtype,
                                             rocsparse_index_base_zero,
                                             ttype);
  });

  if (batch_size > 1) {
    // TODO(umiswing): Add batch sparse matmul support for ROCM after 5.2.0
    PADDLE_THROW(phi::errors::Unimplemented(
        "Batch Sparse matmul use 'rocsparse_coo_set_strided_batch', which is "
        "supported from ROCM 5.2.0"));
  }
}

template <typename T, typename IntT>
inline void CreateCooDescriptor(const phi::SparseCooTensor& x,
                                const phi::GPUContext& dev_ctx,
                                rocsparse_spmat_descr* descriptor) {
  std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
  auto x_ndims = xdim_vec.size();
  PADDLE_ENFORCE_GE(
      x_ndims,
      2,
      phi::errors::InvalidArgument("the dim size of SparseCooTensor must be "
                                   "greater than or equal to 2."));
  int64_t M = xdim_vec[x_ndims - 2];
  int64_t N = xdim_vec[x_ndims - 1];
  int batch_size = 1;
  for (int i = 0; i < x_ndims - 2; i++) {
    batch_size *= xdim_vec[i];
  }
  int64_t nnz = x.nnz();

  const IntT* indices_data = x.non_zero_indices().data<IntT>();
  const T* values_data = x.non_zero_elements().data<T>();
  auto rows_data = indices_data + (x_ndims - 2) * nnz;
  auto cols_data = indices_data + (x_ndims - 1) * nnz;

  int64_t batch_nnz = nnz / batch_size;
  rocsparse_indextype itype = GetGpuIndexType<int64_t>();
  rocsparse_datatype ttype = GetGpuDataType<T>();
  dev_ctx.CusparseCall([&](rocsparse_handle handle) {
    phi::dynload::rocsparse_create_coo_descr(descriptor,
                                             M,
                                             N,
                                             batch_nnz,
                                             const_cast<IntT*>(rows_data),
                                             const_cast<IntT*>(cols_data),
                                             const_cast<T*>(values_data),
                                             itype,
                                             rocsparse_index_base_zero,
                                             ttype);
  });

  if (batch_size > 1) {
    // TODO(umiswing): Add batch sparse matmul support for ROCM after 5.2.0
    PADDLE_THROW(phi::errors::Unimplemented(
        "Batch Sparse matmul use 'rocsparse_coo_set_strided_batch', which is "
        "supported from ROCM 5.2.0"));
  }
}

template <typename T>
class RocSparseSpMatDescriptor {
 public:
  explicit RocSparseSpMatDescriptor(const phi::SparseCsrTensor& x,
                                    const phi::GPUContext& dev_ctx)
      : dev_ctx_(dev_ctx) {
    PD_VISIT_BASE_INTEGRAL_TYPES(
        x.non_zero_crows().dtype(), "Csr RocSparseSpMatDescriptor", ([&] {
          CreateCsrDescriptor<T, data_t>(x, dev_ctx_, &descriptor_);
        }));
    VLOG(6) << "Create csr rocsparse_spmat_descr " << &descriptor_;
  }

  explicit RocSparseSpMatDescriptor(const phi::SparseCooTensor& x,
                                    const phi::GPUContext& dev_ctx)
      : dev_ctx_(dev_ctx) {
    PD_VISIT_BASE_INTEGRAL_TYPES(
        x.non_zero_indices().dtype(), "Coo RocSparseSpMatDescriptor", ([&] {
          CreateCooDescriptor<T, data_t>(x, dev_ctx_, &descriptor_);
        }));
    VLOG(6) << "Create coo rocsparse_spmat_descr " << &descriptor_;
  }

  ~RocSparseSpMatDescriptor() {
    dev_ctx_.CusparseCall([&](rocsparse_handle handle) {
      phi::dynload::rocsparse_destroy_spmat_descr(descriptor_);
    });
    VLOG(6) << "Destroy rocsparse_spmat_descr " << &descriptor_;
  }

  const rocsparse_spmat_descr& descriptor() const { return descriptor_; }

 private:
  const phi::GPUContext& dev_ctx_;
  rocsparse_spmat_descr descriptor_;
};

/************* DENSE MATRIX DESCRIPTOR ************/
template <typename T>
class RocSparseDnMatDescriptor {
 public:
  explicit RocSparseDnMatDescriptor(const phi::DenseTensor& x,
                                    const phi::GPUContext& dev_ctx)
      : dev_ctx_(dev_ctx) {
    std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
    auto x_ndims = xdim_vec.size();
    PADDLE_ENFORCE_GE(
        x_ndims,
        2,
        phi::errors::InvalidArgument("the dim size of DenseTensor must be "
                                     "greater than or equal to 2."));
    int64_t M = xdim_vec[x_ndims - 2];
    int64_t N = xdim_vec[x_ndims - 1];
    int batch_size = 1;
    for (int i = 0; i < x_ndims - 2; i++) {
      batch_size *= xdim_vec[i];
    }

    const T* x_data = x.data<T>();
    rocsparse_datatype ttype = GetGpuDataType<T>();
    dev_ctx.CusparseCall([&](rocsparse_handle handle) {
      phi::dynload::rocsparse_create_dnmat_descr(&descriptor_,
                                                 M,
                                                 N,
                                                 N,
                                                 const_cast<T*>(x_data),
                                                 ttype,
                                                 rocsparse_order_row);
    });

    PADDLE_ENFORCE_EQ(
        x.numel(),
        batch_size * M * N,
        phi::errors::InvalidArgument("The number of elements in DenseTensor "
                                     "must equal batch_size * M * N."));
    if (batch_size > 1) {
      // TODO(umiswing): Add batch sparse matmul support for ROCM after 5.2.0
      PADDLE_THROW(phi::errors::Unimplemented(
          "Batch Sparse matmul use 'rocsparse_dnmat_set_strided_batch', which "
          "is supported from ROCM 5.2.0"));
    }
    VLOG(6) << "Create cusparseDnMatDescr_t " << &descriptor_;
  }

  ~RocSparseDnMatDescriptor() {
    dev_ctx_.CusparseCall([&](rocsparse_handle handle) {
      phi::dynload::rocsparse_destroy_dnmat_descr(descriptor_);
    });
    VLOG(6) << "Destroy rocsparse_dnmat_descr " << &descriptor_;
  }

  const rocsparse_dnmat_descr& descriptor() const { return descriptor_; }

 private:
  const phi::GPUContext& dev_ctx_;
  rocsparse_dnmat_descr descriptor_;
};

/************* SPARSE*DENSE->DENSE MATMUL ************/
template <>
template <typename T, typename TensorType>
void SparseBlas<phi::GPUContext>::SPMM(bool transa,
                                       bool transb,
                                       T alpha,
                                       const TensorType& mat_a,
                                       const phi::DenseTensor& mat_b,
                                       T beta,
                                       phi::DenseTensor* mat_out) const {
  auto a_descriptor = RocSparseSpMatDescriptor<T>(mat_a, dev_ctx_);
  auto b_descriptor = RocSparseDnMatDescriptor<T>(mat_b, dev_ctx_);
  auto out_descriptor = RocSparseDnMatDescriptor<T>(*mat_out, dev_ctx_);

  rocsparse_datatype ttype = GetGpuDataType<T>();
  size_t buffer_size = 0;

  // Query SpMM buffer
  dev_ctx_.CusparseCall([&](rocsparse_handle handle) {
    phi::dynload::rocsparse_spmm(handle,
                                 GetTransposeOperation(transa),
                                 GetTransposeOperation(transb),
                                 &alpha,
                                 a_descriptor.descriptor(),
                                 b_descriptor.descriptor(),
                                 &beta,
                                 out_descriptor.descriptor(),
                                 ttype,
                                 GetSpMMAlgorithm(mat_a),
                                 rocsparse_spmm_stage_buffer_size,
                                 &buffer_size,
                                 nullptr);
  });

  // Allocate buffer
  phi::Allocator::AllocationPtr tmp_buffer = phi::memory_utils::Alloc(
      dev_ctx_.GetPlace(),
      buffer_size,
      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx_.stream())));
  void* tmp_buffer_ptr = tmp_buffer->ptr();

  // Preprocess data
  dev_ctx_.CusparseCall([&](rocsparse_handle handle) {
    phi::dynload::rocsparse_spmm(handle,
                                 GetTransposeOperation(transa),
                                 GetTransposeOperation(transb),
                                 &alpha,
                                 a_descriptor.descriptor(),
                                 b_descriptor.descriptor(),
                                 &beta,
                                 out_descriptor.descriptor(),
                                 ttype,
                                 GetSpMMAlgorithm(mat_a),
                                 rocsparse_spmm_stage_preprocess,
                                 &buffer_size,
                                 tmp_buffer_ptr);
  });

  // Performs the actual SpMM computation
  dev_ctx_.CusparseCall([&](rocsparse_handle handle) {
    phi::dynload::rocsparse_spmm(handle,
                                 GetTransposeOperation(transa),
                                 GetTransposeOperation(transb),
                                 &alpha,
                                 a_descriptor.descriptor(),
                                 b_descriptor.descriptor(),
                                 &beta,
                                 out_descriptor.descriptor(),
                                 ttype,
                                 GetSpMMAlgorithm(mat_a),
                                 rocsparse_spmm_stage_compute,
                                 &buffer_size,
                                 tmp_buffer_ptr);
  });
}

/************* DENSE*DENSE->SPARSE MATMUL ************/
#if HIP_VERSION >= 403
template <>
template <typename T, typename TensorType>
void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
                                        bool transb,
                                        T alpha,
                                        const phi::DenseTensor& mat_a,
                                        const phi::DenseTensor& mat_b,
                                        T beta,
                                        TensorType* mat_out) const {
  auto a_descriptor = RocSparseDnMatDescriptor<T>(mat_a, dev_ctx_);
  auto b_descriptor = RocSparseDnMatDescriptor<T>(mat_b, dev_ctx_);
  auto out_descriptor = RocSparseSpMatDescriptor<T>(*mat_out, dev_ctx_);

  rocsparse_datatype gpu_type = GetGpuDataType<T>();
  size_t buffer_size = 0;
  dev_ctx_.CusparseCall([&](rocsparse_handle handle) {
    phi::dynload::rocsparse_sddmm_buffer_size(handle,
                                              GetTransposeOperation(transa),
                                              GetTransposeOperation(transb),
                                              &alpha,
                                              a_descriptor.descriptor(),
                                              b_descriptor.descriptor(),
                                              &beta,
                                              out_descriptor.descriptor(),
                                              gpu_type,
                                              rocsparse_sddmm_alg_default,
                                              &buffer_size);
  });

  phi::Allocator::AllocationPtr tmp_buffer = phi::memory_utils::Alloc(
      dev_ctx_.GetPlace(),
      buffer_size,
      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx_.stream())));
  void* tmp_buffer_ptr = tmp_buffer->ptr();

  dev_ctx_.CusparseCall([&](rocsparse_handle handle) {
    phi::dynload::rocsparse_sddmm_preprocess(handle,
                                             GetTransposeOperation(transa),
                                             GetTransposeOperation(transb),
                                             &alpha,
                                             a_descriptor.descriptor(),
                                             b_descriptor.descriptor(),
                                             &beta,
                                             out_descriptor.descriptor(),
                                             gpu_type,
                                             rocsparse_sddmm_alg_default,
                                             tmp_buffer_ptr);
  });

  dev_ctx_.CusparseCall([&](rocsparse_handle handle) {
    phi::dynload::rocsparse_sddmm(handle,
                                  GetTransposeOperation(transa),
                                  GetTransposeOperation(transb),
                                  &alpha,
                                  a_descriptor.descriptor(),
                                  b_descriptor.descriptor(),
                                  &beta,
                                  out_descriptor.descriptor(),
                                  gpu_type,
                                  rocsparse_sddmm_alg_default,
                                  tmp_buffer_ptr);
  });
}
#endif

}  // namespace sparse
}  // namespace funcs
}  // namespace phi
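The SPMM specialization above follows rocSPARSE's staged SpMM protocol. Stripped of Paddle's descriptor and allocator helpers, the underlying call sequence is roughly the sketch below (a sketch only; it assumes a valid rocsparse_handle and already-created matA/matB/matC descriptors for float data, and return-code checking is omitted):

#include <hip/hip_runtime.h>
#include <rocsparse.h>

// Sketch of the three-stage SpMM protocol used above:
//   1) query the workspace size, 2) preprocess, 3) compute.
void SpmmStagesSketch(rocsparse_handle handle,
                      rocsparse_spmat_descr matA,  // sparse A (CSR or COO)
                      rocsparse_dnmat_descr matB,  // dense B
                      rocsparse_dnmat_descr matC,  // dense C (output)
                      float alpha, float beta) {
  size_t buffer_size = 0;
  void* buffer = nullptr;

  // Stage 1: ask rocSPARSE how much temporary device memory it needs.
  rocsparse_spmm(handle, rocsparse_operation_none, rocsparse_operation_none,
                 &alpha, matA, matB, &beta, matC, rocsparse_datatype_f32_r,
                 rocsparse_spmm_alg_default, rocsparse_spmm_stage_buffer_size,
                 &buffer_size, nullptr);
  hipMalloc(&buffer, buffer_size);

  // Stage 2: preprocess (analyse the sparsity pattern into the workspace).
  rocsparse_spmm(handle, rocsparse_operation_none, rocsparse_operation_none,
                 &alpha, matA, matB, &beta, matC, rocsparse_datatype_f32_r,
                 rocsparse_spmm_alg_default, rocsparse_spmm_stage_preprocess,
                 &buffer_size, buffer);

  // Stage 3: compute C = alpha * A * B + beta * C.
  rocsparse_spmm(handle, rocsparse_operation_none, rocsparse_operation_none,
                 &alpha, matA, matB, &beta, matC, rocsparse_datatype_f32_r,
                 rocsparse_spmm_alg_default, rocsparse_spmm_stage_compute,
                 &buffer_size, buffer);

  hipFree(buffer);
}

The implementation above differs only in that the handle, the temporary buffer, and the data type come from the GPUContext, Paddle's memory utilities, and the T template parameter.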
paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu
...
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/funcs/math_function_impl.h"
 #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
 #include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
...
@@ -35,7 +36,7 @@ void MatmulCooDenseGradKernel(const Context& dev_ctx,
                               const DenseTensor& dout,
                               SparseCooTensor* dx,
                               DenseTensor* dy) {
-#if CUDA_VERSION >= 11030
+#if CUDA_VERSION >= 11030 || HIP_VERSION >= 403
   auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);

   // dx{SparseCoo} = dout{Dense} * y'{Dense}
...
@@ -44,8 +45,13 @@ void MatmulCooDenseGradKernel(const Context& dev_ctx,
     // which will increase some expenses.
     EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
     SparseCsrTensor dx_csr = CooToCsr<T, Context>(dev_ctx, *dx);
+#ifdef PADDLE_WITH_HIP
+    phi::funcs::SetConstant<Context, T> set_zero;
+    set_zero(dev_ctx, dx_csr.mutable_non_zero_elements(), static_cast<T>(0.0f));
+#endif
     sparse_blas.SDDMM(
         false, true, static_cast<T>(1), dout, y, static_cast<T>(0), &dx_csr);
     CsrToCooKernel<T, Context>(dev_ctx, dx_csr, dx);
   }
...
@@ -56,13 +62,29 @@ void MatmulCooDenseGradKernel(const Context& dev_ctx,
     meta_dy.set_dtype(y.dtype());
     dev_ctx.template Alloc<T>(dy);

+#ifdef PADDLE_WITH_HIP
+    SparseCsrTensor x_csr = CooToCsr<T, Context>(dev_ctx, x);
+    phi::funcs::SetConstant<Context, T> set_zero;
+    set_zero(dev_ctx, dy, static_cast<T>(0.0f));
+    sparse_blas.SPMM(
+        true, false, static_cast<T>(1), x_csr, dout, static_cast<T>(0), dy);
+#elif defined(PADDLE_WITH_CUDA)
     sparse_blas.SPMM(
         true, false, static_cast<T>(1), x, dout, static_cast<T>(0), dy);
+#endif
   }
 #else
+#ifdef PADDLE_WITH_CUDA
   PADDLE_THROW(phi::errors::Unimplemented(
       "backward of 'sparse.matmul' use cusparseSDDMM, which is supported from "
       "CUDA 11.3"));
+#elif defined(PADDLE_WITH_HIP)
+  PADDLE_THROW(
+      phi::errors::Unimplemented("backward of 'sparse.matmul' use "
+                                 "rocsparse_sddmm with transpose, which is "
+                                 "supported from "
+                                 "ROCM 4.3.0"));
+#endif
 #endif
 }
...
@@ -73,7 +95,7 @@ void MatmulCsrDenseGradKernel(const Context& dev_ctx,
                               const DenseTensor& dout,
                               SparseCsrTensor* dx,
                               DenseTensor* dy) {
-#if CUDA_VERSION >= 11030
+#if CUDA_VERSION >= 11030 || HIP_VERSION >= 403
   auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);

   // dx{SparseCsr} = dout{Dense} * y'{Dense}
...
@@ -94,13 +116,26 @@ void MatmulCsrDenseGradKernel(const Context& dev_ctx,
     dev_ctx.template Alloc<T>(dy);

+#ifdef PADDLE_WITH_HIP
+    phi::funcs::SetConstant<Context, T> set_zero;
+    set_zero(dev_ctx, dy, static_cast<T>(0.0f));
+#endif
     sparse_blas.SPMM(
         true, false, static_cast<T>(1), x, dout, static_cast<T>(0), dy);
   }
 #else
+#ifdef PADDLE_WITH_CUDA
   PADDLE_THROW(phi::errors::Unimplemented(
       "backward of 'sparse.matmul' use cusparseSDDMM, which is supported from "
       "CUDA 11.3"));
+#elif defined(PADDLE_WITH_HIP)
+  PADDLE_THROW(
+      phi::errors::Unimplemented("backward of 'sparse.matmul' use "
+                                 "rocsparse_sddmm with transpose, which is "
+                                 "supported from "
+                                 "ROCM 4.3.0"));
+#endif
 #endif
 }
...
paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
...
@@ -25,6 +25,7 @@ limitations under the License. */
 #include "paddle/phi/core/sparse_csr_tensor.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/funcs/math_function_impl.h"
 #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
...
@@ -36,7 +37,7 @@ void MatmulKernelImpl(const Context& dev_ctx,
                       const TensorType& x,
                       const DenseTensor& y,
                       DenseTensor* out) {
-#if CUDA_VERSION >= 11000
+#if CUDA_VERSION >= 11000 || HIP_VERSION >= 402
   std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
   std::vector<int64_t> ydim_vec = phi::vectorize(y.dims());
   auto x_ndims = xdim_vec.size();
...
@@ -80,13 +81,24 @@ void MatmulKernelImpl(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);

+#ifdef PADDLE_WITH_HIP
+  phi::funcs::SetConstant<Context, T> set_zero;
+  set_zero(dev_ctx, out, static_cast<T>(0.0f));
+#endif
+
   auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
   sparse_blas.SPMM(
       false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
 #else
+#ifdef PADDLE_WITH_CUDA
   PADDLE_THROW(
       phi::errors::Unimplemented("forward of 'sparse.matmul' use cusparseSpMM, "
                                  "which is supported from CUDA 11.0"));
+#elif defined(PADDLE_WITH_HIP)
+  PADDLE_THROW(
+      phi::errors::Unimplemented("forward of 'sparse.matmul' use rocsparse_spmm, "
+                                 "which is supported from ROCM 4.2.0"));
+#endif
 #endif
 }
...
paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
...
@@ -17,12 +17,16 @@ limitations under the License. */
 #include <thrust/execution_policy.h>
 #include <thrust/remove.h>

+#ifdef PADDLE_WITH_HIP
+#include "paddle/phi/backends/dynload/rocsparse.h"
+#endif
+
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_meta.h"
 #include "paddle/phi/core/visit_type.h"
+#include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/sparse/common_shape.h"
...
@@ -214,55 +218,88 @@ void CsrToCooGPUKernel(const GPUContext& dev_ctx,
                        SparseCooTensor* out) {
   const DDim& x_dims = x.dims();
   const int64_t non_zero_num = x.cols().numel();

+// rocsparse_csr2coo only support index with type 'rocsparse_int' (aka 'int')
+// now
+#ifdef PADDLE_WITH_HIP
+  const auto& csr_crows = Cast<IntT>(dev_ctx, x.crows(), DataType::INT32);
+  const auto& csr_cols = Cast<IntT>(dev_ctx, x.cols(), DataType::INT32);
+  const int* csr_crows_data = csr_crows.template data<int>();
+  const int* csr_cols_data = csr_cols.template data<int>();
+#else
   const auto& csr_crows = x.crows();
   const auto& csr_cols = x.cols();
-  const auto& csr_values = x.values();
   const IntT* csr_crows_data = csr_crows.data<IntT>();
   const IntT* csr_cols_data = csr_cols.data<IntT>();
+#endif
+  const auto& csr_values = x.values();
   const T* csr_values_data = csr_values.data<T>();

   int64_t sparse_dim = 2;
   if (x_dims.size() == 3) {
     sparse_dim = 3;
   }
-  int batchs = x_dims.size() == 2 ? 1 : x_dims[0];
+  int batches = x_dims.size() == 2 ? 1 : x_dims[0];
   int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];

+#ifdef PADDLE_WITH_HIP
+  DenseTensor indices = phi::Empty<int>(dev_ctx, {sparse_dim, non_zero_num});
+  int* coo_indices = indices.data<int>();
+  int* coo_rows_data = coo_indices;
+  int* coo_cols_data = coo_rows_data + non_zero_num;
+#else
   DenseTensor indices = phi::Empty<IntT>(dev_ctx, {sparse_dim, non_zero_num});
-  DenseTensor values = phi::EmptyLike<T, GPUContext>(dev_ctx, csr_values);
-  DenseTensor offsets = phi::Empty<IntT>(dev_ctx, {batchs});
+  DenseTensor offsets = phi::Empty<IntT>(dev_ctx, {batches});
   IntT* coo_indices = indices.data<IntT>();
   IntT* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
   IntT* coo_rows_data =
       x_dims.size() == 2 ? coo_indices : batch_ptr + non_zero_num;
   IntT* coo_cols_data = coo_rows_data + non_zero_num;
-  IntT* offsets_ptr = batchs == 1 ? nullptr : offsets.data<IntT>();
+  IntT* offsets_ptr = batches == 1 ? nullptr : offsets.data<IntT>();
+#endif
+  DenseTensor values = phi::EmptyLike<T, GPUContext>(dev_ctx, csr_values);
   T* coo_values_data = values.data<T>();

-  if (batchs > 1) {
-    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1);
-    GetBatchSizes<IntT><<<config.block_per_grid.x, config.thread_per_block.x>>>(
-        csr_crows_data, rows, batchs, offsets_ptr);
+  if (batches > 1) {
 #ifdef PADDLE_WITH_HIP
-    thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
+    PADDLE_THROW(
+        phi::errors::Unimplemented("'rocsparse_csr2coo' only supports batches "
+                                   "with a value of 1 currently."));
 #else
+    auto config =
+        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batches, 1);
+    GetBatchSizes<IntT><<<config.block_per_grid.x, config.thread_per_block.x>>>(
+        csr_crows_data, rows, batches, offsets_ptr);
     thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
-#endif
                            offsets_ptr,
-                           offsets_ptr + batchs,
+                           offsets_ptr + batches,
                            offsets_ptr);
+#endif
   }

+#ifdef PADDLE_WITH_HIP
+  dev_ctx.CusparseCall([&](rocsparse_handle handle) {
+    phi::dynload::rocsparse_csr2coo(handle,
+                                    csr_crows_data,
+                                    non_zero_num,
+                                    rows,
+                                    coo_rows_data,
+                                    rocsparse_index_base_zero);
+  });
+#else
   auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1);
-  config.block_per_grid.y = batchs;
+  config.block_per_grid.y = batches;
   ConvertCsrCrowsToCooRows<IntT>
       <<<config.block_per_grid, config.thread_per_block.x>>>(
           csr_crows_data, offsets_ptr, coo_rows_data, batch_ptr, rows);
+#endif
   phi::backends::gpu::GpuMemcpyAsync(coo_cols_data,
                                      csr_cols_data,
+#ifdef PADDLE_WITH_HIP
+                                     sizeof(int) * non_zero_num,
+#else
                                      sizeof(IntT) * non_zero_num,
+#endif
                                      gpuMemcpyDeviceToDevice,
                                      dev_ctx.stream());
   phi::backends::gpu::GpuMemcpyAsync(coo_values_data,
...
@@ -271,6 +308,11 @@ void CsrToCooGPUKernel(const GPUContext& dev_ctx,
                                      gpuMemcpyDeviceToDevice,
                                      dev_ctx.stream());

+#ifdef PADDLE_WITH_HIP
+  if (std::is_same<IntT, int64_t>::value)
+    indices = Cast<int>(dev_ctx, indices, DataType::INT64);
+#endif
+
   out->SetMember(indices, values, x_dims, true);
 }
...
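The INT32 casts in the hunk above exist because rocsparse_csr2coo only accepts 32-bit rocsparse_int indices. A minimal sketch of the conversion it performs, assuming the device buffers are already populated (error handling omitted):

#include <rocsparse.h>

// Expand a CSR row-pointer array (m + 1 entries, device memory) into
// per-nonzero COO row indices (nnz entries, device memory). Because the index
// type is fixed to rocsparse_int, the kernel above first casts 64-bit
// crows/cols tensors down to INT32 and casts the resulting indices back to
// INT64 afterwards.
void Csr2CooSketch(rocsparse_handle handle,
                   const rocsparse_int* csr_row_ptr,
                   rocsparse_int nnz,
                   rocsparse_int m,
                   rocsparse_int* coo_row_ind) {
  rocsparse_csr2coo(handle, csr_row_ptr, nnz, m, coo_row_ind,
                    rocsparse_index_base_zero);
}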