Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4a08c781
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4a08c781
编写于
6月 18, 2022
作者:
zhouweiwei2014
提交者:
GitHub
6月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove unuse cuSparse function (#43626)
上级
03517d8a
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
34 addition
and
592 deletion
+34
-592
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+0
-11
paddle/fluid/operators/math/cusparse_conversion_api_test.cc
paddle/fluid/operators/math/cusparse_conversion_api_test.cc
+0
-190
paddle/fluid/operators/math/sparse.h
paddle/fluid/operators/math/sparse.h
+0
-113
paddle/fluid/operators/math/sparse_impl.cu.h
paddle/fluid/operators/math/sparse_impl.cu.h
+0
-230
paddle/fluid/platform/dynload/cusparse.h
paddle/fluid/platform/dynload/cusparse.h
+17
-24
paddle/phi/backends/dynload/cusparse.h
paddle/phi/backends/dynload/cusparse.h
+17
-24
未找到文件。
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
4a08c781
...
...
@@ -97,17 +97,6 @@ cc_test(
SRCS concat_test.cc
DEPS concat_and_split
)
if
(
WITH_GPU
AND
(
NOT WITH_ROCM
))
#currenty not yet support ROCM
#the generic conversion APIs of dense and sparse are only supported after cuda11.2
if
((
NOT
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_LESS 11.2
))
cc_test
(
cusparse_conversion_api_test
SRCS cusparse_conversion_api_test.cc
DEPS tensor
)
endif
()
endif
()
if
(
WITH_TESTING AND TEST im2col_test
)
set_tests_properties
(
im2col_test PROPERTIES TIMEOUT 120
)
endif
()
paddle/fluid/operators/math/cusparse_conversion_api_test.cc
已删除
100644 → 0
浏览文件 @
03517d8a
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/sparse.h"
template
<
typename
T
>
void
TestNNZ
(
const
std
::
vector
<
T
>&
dense_data
,
const
int
correct_nnz
,
const
int
rows
,
const
int
cols
)
{
paddle
::
platform
::
CUDADeviceContext
*
context
=
new
paddle
::
platform
::
CUDADeviceContext
(
paddle
::
platform
::
CUDAPlace
());
context
->
SetAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
paddle
::
platform
::
CUDAPlace
(),
context
->
stream
())
.
get
());
context
->
PartialInitWithAllocator
();
auto
sparse
=
paddle
::
operators
::
math
::
GetSparse
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
(
*
context
);
paddle
::
framework
::
Tensor
dense
,
nnz_tensor
;
auto
dense_dims
=
phi
::
make_ddim
({
rows
,
cols
});
auto
nnz_dims
=
phi
::
make_ddim
({
dense_dims
[
0
]
+
1
});
dense
.
mutable_data
<
T
>
(
dense_dims
,
paddle
::
platform
::
CUDAPlace
());
paddle
::
framework
::
TensorFromVector
<
T
>
(
dense_data
,
*
context
,
&
dense
);
int32_t
*
nnz_ptr
=
nnz_tensor
.
mutable_data
<
int32_t
>
(
nnz_dims
,
paddle
::
platform
::
CUDAPlace
());
sparse
.
nnz
(
rows
,
cols
,
dense
.
data
<
T
>
(),
nnz_ptr
,
nnz_ptr
+
1
);
std
::
vector
<
int32_t
>
nnz_vec
(
dense_dims
[
0
]
+
1
);
paddle
::
framework
::
TensorToVector
<
int32_t
>
(
nnz_tensor
,
*
context
,
&
nnz_vec
);
delete
context
;
CHECK_EQ
(
correct_nnz
,
nnz_vec
[
0
]);
}
TEST
(
sparse
,
nnz
)
{
std
::
vector
<
float
>
dense_data
=
{
0.0
,
1.0
,
0.0
,
2.0
,
0.0
,
3.0
,
3.2
,
0.0
,
0.0
};
TestNNZ
<
float
>
(
dense_data
,
4
,
3
,
3
);
}
TEST
(
sparse
,
nnz_double
)
{
std
::
vector
<
double
>
dense_data
=
{
0.0
,
1.0
,
0.0
,
2.0
,
0.0
,
3.0
,
3.2
,
0.0
};
TestNNZ
<
double
>
(
dense_data
,
4
,
4
,
2
);
}
template
<
typename
T
>
void
TestDenseToSparse
(
const
std
::
vector
<
T
>&
correct_dense_data
,
const
std
::
vector
<
int64_t
>&
correct_rows
,
const
std
::
vector
<
int64_t
>&
correct_cols
,
const
std
::
vector
<
T
>&
correct_values
,
const
int
correct_nnz
,
const
int
rows
,
const
int
cols
,
const
std
::
string
&
mode
)
{
paddle
::
platform
::
CUDADeviceContext
*
context
=
new
paddle
::
platform
::
CUDADeviceContext
(
paddle
::
platform
::
CUDAPlace
());
context
->
SetAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
paddle
::
platform
::
CUDAPlace
(),
context
->
stream
())
.
get
());
context
->
PartialInitWithAllocator
();
// get sparse
auto
sparse
=
paddle
::
operators
::
math
::
GetSparse
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
(
*
context
);
// create tensor and copy vector to tensor
paddle
::
framework
::
Tensor
dense_tensor
,
rows_tensor
,
cols_tensor
,
values_tensor
,
actual_dense_tensor
;
auto
dense_dims
=
phi
::
make_ddim
({
rows
,
cols
});
T
*
dense_data
=
dense_tensor
.
mutable_data
<
T
>
(
dense_dims
,
paddle
::
platform
::
CUDAPlace
());
T
*
actual_dense_data
=
actual_dense_tensor
.
mutable_data
<
T
>
(
dense_dims
,
paddle
::
platform
::
CUDAPlace
());
paddle
::
framework
::
TensorFromVector
<
T
>
(
correct_dense_data
,
*
context
,
&
dense_tensor
);
auto
nnz_dims
=
phi
::
make_ddim
({
correct_nnz
});
auto
crows_dims
=
phi
::
make_ddim
({
rows
+
1
});
int64_t
*
rows_data
=
nullptr
;
if
(
mode
==
"COO"
)
{
rows_data
=
rows_tensor
.
mutable_data
<
int64_t
>
(
nnz_dims
,
paddle
::
platform
::
CUDAPlace
());
}
else
{
rows_data
=
rows_tensor
.
mutable_data
<
int64_t
>
(
crows_dims
,
paddle
::
platform
::
CUDAPlace
());
}
int64_t
*
cols_data
=
cols_tensor
.
mutable_data
<
int64_t
>
(
nnz_dims
,
paddle
::
platform
::
CUDAPlace
());
T
*
values_data
=
values_tensor
.
mutable_data
<
T
>
(
nnz_dims
,
paddle
::
platform
::
CUDAPlace
());
// test dense_to_sparse
if
(
mode
==
"COO"
)
{
sparse
.
DenseToSparseCoo
(
rows
,
cols
,
dense_data
,
rows_data
,
cols_data
,
values_data
);
}
else
{
sparse
.
DenseToSparseCsr
(
rows
,
cols
,
dense_data
,
rows_data
,
cols_data
,
values_data
);
}
std
::
vector
<
int64_t
>
actual_rows
(
correct_nnz
),
actual_crows
(
rows
+
1
),
actual_cols
(
correct_nnz
);
std
::
vector
<
T
>
actual_values
(
correct_nnz
),
actual_dense_vec
(
rows
*
cols
);
if
(
mode
==
"COO"
)
{
paddle
::
framework
::
TensorToVector
<
int64_t
>
(
rows_tensor
,
*
context
,
&
actual_rows
);
}
else
{
paddle
::
framework
::
TensorToVector
<
int64_t
>
(
rows_tensor
,
*
context
,
&
actual_crows
);
}
paddle
::
framework
::
TensorToVector
<
int64_t
>
(
cols_tensor
,
*
context
,
&
actual_cols
);
paddle
::
framework
::
TensorToVector
<
T
>
(
values_tensor
,
*
context
,
&
actual_values
);
for
(
int
i
=
0
;
i
<
correct_nnz
;
i
++
)
{
if
(
mode
==
"COO"
)
{
CHECK_EQ
(
correct_rows
[
i
],
actual_rows
[
i
]);
}
CHECK_EQ
(
correct_cols
[
i
],
actual_cols
[
i
]);
CHECK_EQ
(
correct_values
[
i
],
actual_values
[
i
]);
}
if
(
mode
==
"CSR"
)
{
for
(
int
i
=
0
;
i
<
rows
+
1
;
i
++
)
{
CHECK_EQ
(
correct_rows
[
i
],
actual_crows
[
i
]);
}
}
// test sparse_to_dense
if
(
mode
==
"COO"
)
{
sparse
.
SparseCooToDense
(
rows
,
cols
,
correct_nnz
,
rows_data
,
cols_data
,
values_data
,
actual_dense_data
);
}
else
{
sparse
.
SparseCsrToDense
(
rows
,
cols
,
correct_nnz
,
rows_data
,
cols_data
,
values_data
,
actual_dense_data
);
}
paddle
::
framework
::
TensorToVector
<
T
>
(
actual_dense_tensor
,
*
context
,
&
actual_dense_vec
);
for
(
uint64_t
i
=
0
;
i
<
correct_dense_data
.
size
();
i
++
)
{
CHECK_EQ
(
correct_dense_data
[
i
],
actual_dense_vec
[
i
]);
}
delete
context
;
}
TEST
(
sparse
,
dense_to_sparse
)
{
std
::
vector
<
float
>
dense_data
=
{
0.0
,
1.0
,
0.0
,
2.0
,
0.0
,
3.0
,
3.2
,
0.0
,
0.0
};
std
::
vector
<
float
>
values
=
{
1.0
,
2.0
,
3.0
,
3.2
};
std
::
vector
<
int64_t
>
rows
=
{
0
,
1
,
1
,
2
};
std
::
vector
<
int64_t
>
crows
=
{
0
,
1
,
3
,
4
};
std
::
vector
<
int64_t
>
cols
=
{
1
,
0
,
2
,
0
};
TestDenseToSparse
<
float
>
(
dense_data
,
rows
,
cols
,
values
,
4
,
3
,
3
,
"COO"
);
TestDenseToSparse
<
float
>
(
dense_data
,
crows
,
cols
,
values
,
4
,
3
,
3
,
"CSR"
);
}
TEST
(
sparse
,
dense_to_sparse_double
)
{
std
::
vector
<
double
>
dense_data
=
{
0.0
,
1.0
,
0.0
,
2.0
,
0.0
,
3.0
,
3.2
,
0.0
};
std
::
vector
<
double
>
values
=
{
1.0
,
2.0
,
3.0
,
3.2
};
std
::
vector
<
int64_t
>
rows
=
{
0
,
1
,
2
,
3
};
std
::
vector
<
int64_t
>
crows
=
{
0
,
1
,
2
,
3
,
4
};
std
::
vector
<
int64_t
>
cols
=
{
1
,
1
,
1
,
0
};
TestDenseToSparse
<
double
>
(
dense_data
,
rows
,
cols
,
values
,
4
,
4
,
2
,
"COO"
);
TestDenseToSparse
<
double
>
(
dense_data
,
crows
,
cols
,
values
,
4
,
4
,
2
,
"CSR"
);
}
TEST
(
sparse
,
dense_to_sparse_fp16
)
{
using
float16
=
paddle
::
platform
::
float16
;
std
::
vector
<
float16
>
dense_data
=
{
float16
(
0.0
),
float16
(
1.0
),
float16
(
0.0
),
float16
(
2.0
),
float16
(
0.0
),
float16
(
3.0
),
float16
(
3.2
),
float16
(
0.0
)};
std
::
vector
<
float16
>
values
=
{
float16
(
1.0
),
float16
(
2.0
),
float16
(
3.0
),
float16
(
3.2
)};
std
::
vector
<
int64_t
>
rows
=
{
0
,
1
,
2
,
3
};
std
::
vector
<
int64_t
>
crows
=
{
0
,
1
,
2
,
3
,
4
};
std
::
vector
<
int64_t
>
cols
=
{
1
,
1
,
1
,
0
};
TestDenseToSparse
<
float16
>
(
dense_data
,
rows
,
cols
,
values
,
4
,
4
,
2
,
"COO"
);
TestDenseToSparse
<
float16
>
(
dense_data
,
crows
,
cols
,
values
,
4
,
4
,
2
,
"CSR"
);
}
paddle/fluid/operators/math/sparse.h
已删除
100644 → 0
浏览文件 @
03517d8a
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
framework
{
class
ExecutionContext
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
DeviceContext
>
class
Sparse
{
public:
explicit
Sparse
(
const
DeviceContext
&
context
)
:
context_
(
context
)
{}
template
<
typename
T
>
void
nnz
(
const
int
M
,
const
int
N
,
const
T
*
dense
,
int
*
nnz
,
int
*
nnzPerRowColumn
)
const
;
template
<
typename
T
>
void
DenseToSparseCoo
(
const
int
M
,
const
int
N
,
const
T
*
dense
,
int64_t
*
rows
,
int64_t
*
cols
,
T
*
values
)
const
;
template
<
typename
T
>
void
DenseToSparseCsr
(
const
int
M
,
const
int
N
,
const
T
*
dense
,
int64_t
*
crows
,
int64_t
*
cols
,
T
*
values
)
const
;
template
<
typename
T
>
void
SparseCooToDense
(
const
int64_t
M
,
const
int64_t
N
,
const
int64_t
nnz
,
const
int64_t
*
rows
,
const
int64_t
*
cols
,
const
T
*
values
,
T
*
dense
)
const
;
template
<
typename
T
>
void
SparseCsrToDense
(
const
int64_t
M
,
const
int64_t
N
,
const
int64_t
nnz
,
const
int64_t
*
crows
,
const
int64_t
*
cols
,
const
T
*
values
,
T
*
dense
)
const
;
private:
const
DeviceContext
&
context_
;
};
template
<
typename
DeviceContext
,
typename
T
>
class
SparseT
:
private
Sparse
<
DeviceContext
>
{
public:
using
Sparse
<
DeviceContext
>::
Sparse
;
template
<
typename
...
ARGS
>
void
nnz
(
ARGS
...
args
)
const
{
Base
()
->
template
nnz
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
DenseToSparseCoo
(
ARGS
...
args
)
const
{
Base
()
->
template
DenseToSparseCoo
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
DenseToSparseCsr
(
ARGS
...
args
)
const
{
Base
()
->
template
DenseToSparseCsr
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
SparseCooToDense
(
ARGS
...
args
)
const
{
Base
()
->
template
SparseCooToDense
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
SparseCsrToDense
(
ARGS
...
args
)
const
{
Base
()
->
template
SparseCsrToDense
<
T
>(
args
...);
}
private:
const
Sparse
<
DeviceContext
>*
Base
()
const
{
return
static_cast
<
const
Sparse
<
DeviceContext
>*>
(
this
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
inline
SparseT
<
DeviceContext
,
T
>
GetSparse
(
const
framework
::
ExecutionContext
&
exe_ctx
)
{
return
SparseT
<
DeviceContext
,
T
>
(
exe_ctx
.
template
device_context
<
DeviceContext
>());
}
template
<
typename
DeviceContext
,
typename
T
>
inline
SparseT
<
DeviceContext
,
T
>
GetSparse
(
const
DeviceContext
&
dev_ctx
)
{
return
SparseT
<
DeviceContext
,
T
>
(
dev_ctx
);
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
#if defined(PADDLE_WITH_CUDA)
#if CUDA_VERSION >= 11020
#include "paddle/fluid/operators/math/sparse_impl.cu.h"
#endif
#endif
paddle/fluid/operators/math/sparse_impl.cu.h
已删除
100644 → 0
浏览文件 @
03517d8a
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/cusparse.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
cudaDataType_t
GetGpuDataType
()
{
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
return
CUDA_R_32F
;
}
else
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
return
CUDA_R_64F
;
}
else
if
(
std
::
is_same
<
T
,
platform
::
float16
>::
value
)
{
return
CUDA_R_16F
;
}
}
template
<
>
template
<
typename
T
>
void
Sparse
<
platform
::
CUDADeviceContext
>::
nnz
(
const
int
M
,
const
int
N
,
const
T
*
dense
,
int
*
nnz
,
int
*
nnzPerRowColumn
)
const
{}
template
<
>
template
<
>
void
Sparse
<
platform
::
CUDADeviceContext
>::
nnz
(
const
int
M
,
const
int
N
,
const
float
*
dense
,
int
*
nnz
,
int
*
nnzPerRowColumn
)
const
{
cusparseMatDescr_t
descr
=
0
;
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cusparseCreateMatDescr
(
&
descr
));
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cusparseSetMatType
(
descr
,
CUSPARSE_MATRIX_TYPE_GENERAL
));
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cusparseSetMatIndexBase
(
descr
,
CUSPARSE_INDEX_BASE_ZERO
));
context_
.
CusparseCall
([
&
](
cusparseHandle_t
handle
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cusparseSnnz
(
handle
,
CUSPARSE_DIRECTION_ROW
,
M
,
N
,
descr
,
dense
,
M
,
nnzPerRowColumn
,
nnz
));
});
}
template
<
>
template
<
>
void
Sparse
<
platform
::
CUDADeviceContext
>::
nnz
(
const
int
M
,
const
int
N
,
const
double
*
dense
,
int
*
nnz
,
int
*
nnzPerRowColumn
)
const
{
cusparseMatDescr_t
descr
=
0
;
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cusparseCreateMatDescr
(
&
descr
));
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cusparseSetMatType
(
descr
,
CUSPARSE_MATRIX_TYPE_GENERAL
));
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cusparseSetMatIndexBase
(
descr
,
CUSPARSE_INDEX_BASE_ZERO
));
context_
.
CusparseCall
([
&
](
cusparseHandle_t
handle
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
paddle
::
platform
::
dynload
::
cusparseDnnz
(
handle
,
CUSPARSE_DIRECTION_ROW
,
M
,
N
,
descr
,
dense
,
M
,
nnzPerRowColumn
,
nnz
));
});
}
template
<
typename
T
>
inline
void
DenseToSparse
(
const
platform
::
CUDADeviceContext
&
context
,
const
int
M
,
const
int
N
,
const
T
*
dense
,
int64_t
*
rows
,
int64_t
*
cols
,
T
*
values
,
const
cusparseFormat_t
format
)
{
cusparseSpMatDescr_t
matB
;
cusparseDnMatDescr_t
matA
;
cudaDataType_t
dtype
=
GetGpuDataType
<
T
>
();
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseCreateDnMat
(
&
matA
,
M
,
N
,
N
,
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
dense
)),
dtype
,
CUSPARSE_ORDER_ROW
));
if
(
format
==
CUSPARSE_FORMAT_COO
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseCreateCoo
(
&
matB
,
M
,
N
,
0
,
nullptr
,
nullptr
,
nullptr
,
CUSPARSE_INDEX_64I
,
CUSPARSE_INDEX_BASE_ZERO
,
dtype
));
}
else
if
(
format
==
CUSPARSE_FORMAT_CSR
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseCreateCsr
(
&
matB
,
M
,
N
,
0
,
rows
,
nullptr
,
nullptr
,
CUSPARSE_INDEX_64I
,
CUSPARSE_INDEX_64I
,
CUSPARSE_INDEX_BASE_ZERO
,
dtype
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"the sparse format [%s] is not supported"
,
format
));
}
size_t
buffer_size
=
0
;
context
.
CusparseCall
([
&
](
cusparseHandle_t
handle
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseDenseToSparse_bufferSize
(
handle
,
matA
,
matB
,
CUSPARSE_DENSETOSPARSE_ALG_DEFAULT
,
&
buffer_size
));
});
framework
::
Tensor
buffer
;
float
*
buffer_data
=
buffer
.
mutable_data
<
float
>
(
{
static_cast
<
int64_t
>
(
buffer_size
)},
context
.
GetPlace
());
context
.
CusparseCall
([
&
](
cusparseHandle_t
handle
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseDenseToSparse_analysis
(
handle
,
matA
,
matB
,
CUSPARSE_DENSETOSPARSE_ALG_DEFAULT
,
buffer_data
));
});
if
(
format
==
CUSPARSE_FORMAT_COO
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseCooSetPointers
(
matB
,
rows
,
cols
,
reinterpret_cast
<
void
*>
(
values
)));
}
else
if
(
format
==
CUSPARSE_FORMAT_CSR
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseCsrSetPointers
(
matB
,
rows
,
cols
,
reinterpret_cast
<
void
*>
(
values
)));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"the sparse format [%s] is not supported"
,
format
));
}
context
.
CusparseCall
([
&
](
cusparseHandle_t
handle
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseDenseToSparse_convert
(
handle
,
matA
,
matB
,
CUSPARSE_DENSETOSPARSE_ALG_DEFAULT
,
buffer_data
));
});
}
template
<
>
template
<
typename
T
>
void
Sparse
<
platform
::
CUDADeviceContext
>::
DenseToSparseCoo
(
const
int
M
,
const
int
N
,
const
T
*
dense
,
int64_t
*
rows
,
int64_t
*
cols
,
T
*
values
)
const
{
DenseToSparse
<
T
>
(
context_
,
M
,
N
,
dense
,
rows
,
cols
,
values
,
CUSPARSE_FORMAT_COO
);
}
template
<
>
template
<
typename
T
>
void
Sparse
<
platform
::
CUDADeviceContext
>::
DenseToSparseCsr
(
const
int
M
,
const
int
N
,
const
T
*
dense
,
int64_t
*
crows
,
int64_t
*
cols
,
T
*
values
)
const
{
DenseToSparse
<
T
>
(
context_
,
M
,
N
,
dense
,
crows
,
cols
,
values
,
CUSPARSE_FORMAT_CSR
);
}
template
<
typename
T
>
void
SparseToDense
(
const
platform
::
CUDADeviceContext
&
context
,
const
int64_t
M
,
const
int64_t
N
,
const
int64_t
nnz
,
const
int64_t
*
rows
,
const
int64_t
*
cols
,
const
T
*
values
,
T
*
dense
,
const
cusparseFormat_t
format
)
{
cusparseSpMatDescr_t
matA
;
cusparseDnMatDescr_t
matB
;
cudaDataType_t
dtype
=
GetGpuDataType
<
T
>
();
if
(
format
==
CUSPARSE_FORMAT_COO
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseCreateCoo
(
&
matA
,
M
,
N
,
nnz
,
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
rows
)),
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
cols
)),
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
values
)),
CUSPARSE_INDEX_64I
,
CUSPARSE_INDEX_BASE_ZERO
,
dtype
));
}
else
if
(
format
==
CUSPARSE_FORMAT_CSR
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseCreateCsr
(
&
matA
,
M
,
N
,
nnz
,
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
rows
)),
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
cols
)),
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
values
)),
CUSPARSE_INDEX_64I
,
CUSPARSE_INDEX_64I
,
CUSPARSE_INDEX_BASE_ZERO
,
dtype
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"the sparse format [%s] is not supported"
,
format
));
}
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseCreateDnMat
(
&
matB
,
M
,
N
,
N
,
reinterpret_cast
<
void
*>
(
dense
),
dtype
,
CUSPARSE_ORDER_ROW
));
size_t
buffer_size
=
0
;
context
.
CusparseCall
([
&
](
cusparseHandle_t
handle
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseSparseToDense_bufferSize
(
handle
,
matA
,
matB
,
CUSPARSE_SPARSETODENSE_ALG_DEFAULT
,
&
buffer_size
));
});
framework
::
Tensor
buffer
;
float
*
buffer_data
=
buffer
.
mutable_data
<
float
>
(
{
static_cast
<
int64_t
>
(
buffer_size
)},
context
.
GetPlace
());
context
.
CusparseCall
([
&
](
cusparseHandle_t
handle
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
cusparseSparseToDense
(
handle
,
matA
,
matB
,
CUSPARSE_SPARSETODENSE_ALG_DEFAULT
,
buffer_data
));
});
}
template
<
>
template
<
typename
T
>
void
Sparse
<
platform
::
CUDADeviceContext
>::
SparseCooToDense
(
const
int64_t
M
,
const
int64_t
N
,
const
int64_t
nnz
,
const
int64_t
*
rows
,
const
int64_t
*
cols
,
const
T
*
values
,
T
*
dense
)
const
{
SparseToDense
<
T
>
(
context_
,
M
,
N
,
nnz
,
rows
,
cols
,
values
,
dense
,
CUSPARSE_FORMAT_COO
);
}
template
<
>
template
<
typename
T
>
void
Sparse
<
platform
::
CUDADeviceContext
>::
SparseCsrToDense
(
const
int64_t
M
,
const
int64_t
N
,
const
int64_t
nnz
,
const
int64_t
*
crows
,
const
int64_t
*
cols
,
const
T
*
values
,
T
*
dense
)
const
{
SparseToDense
<
T
>
(
context_
,
M
,
N
,
nnz
,
crows
,
cols
,
values
,
dense
,
CUSPARSE_FORMAT_CSR
);
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/platform/dynload/cusparse.h
浏览文件 @
4a08c781
...
...
@@ -31,30 +31,23 @@ namespace dynload {
#if defined(PADDLE_WITH_CUDA)
// APIs available after CUDA 11.0
#if CUDA_VERSION >= 11000
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseCooSetPointers); \
__macro(cusparseCsrSetPointers); \
__macro(cusparseDenseToSparse_bufferSize); \
__macro(cusparseDenseToSparse_analysis); \
__macro(cusparseDenseToSparse_convert); \
__macro(cusparseSparseToDense_bufferSize); \
__macro(cusparseSparseToDense); \
__macro(cusparseDnMatSetStridedBatch); \
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseDnMatSetStridedBatch); \
__macro(cusparseCsrSetStridedBatch);
CUSPARSE_ROUTINE_EACH
(
PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
)
...
...
paddle/phi/backends/dynload/cusparse.h
浏览文件 @
4a08c781
...
...
@@ -43,30 +43,23 @@ extern void *cusparse_dso_handle;
#if defined(PADDLE_WITH_CUDA)
// APIs available after CUDA 11.0
#if CUDA_VERSION >= 11000
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseCooSetPointers); \
__macro(cusparseCsrSetPointers); \
__macro(cusparseDenseToSparse_bufferSize); \
__macro(cusparseDenseToSparse_analysis); \
__macro(cusparseDenseToSparse_convert); \
__macro(cusparseSparseToDense_bufferSize); \
__macro(cusparseSparseToDense); \
__macro(cusparseDnMatSetStridedBatch); \
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseDnMatSetStridedBatch); \
__macro(cusparseCsrSetStridedBatch);
CUSPARSE_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录