机器未来 / Paddle (forked from PaddlePaddle / Paddle)

Commit 9a8a4c77 (unverified)
Authored Dec 17, 2021 by niuliling123; committed by GitHub on Dec 17, 2021.
Delete cub_reduce.h and modified the TensorReduce to TensorReduceFunctorImpl (#38197)
Parent: 431a2d6a

Showing 10 changed files with 29 additions and 544 deletions (+29 −544)
paddle/fluid/operators/broadcast_tensors_op.cu  +4 −14
paddle/fluid/operators/controlflow/compare_all_op.cu  +6 −10
paddle/fluid/operators/fused/attn_bias_add.cu.h  +1 −3
paddle/fluid/operators/kron_op.h  +5 −16
paddle/fluid/operators/matmul_v2_op.h  +3 −13
paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu  +5 −5
paddle/fluid/operators/reduce_ops/cub_reduce.h  +0 −468
paddle/fluid/operators/reduce_ops/frobenius_norm_op.cu  +1 −1
paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu  +1 −1
paddle/fluid/operators/solve_op.h  +3 −13
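The call sites below all change in the same way: the old `TensorReduce` from `cub_reduce.h` took an explicit init value, a `cub` binary reducer, and a per-file `IdentityFunctor`, while the new `TensorReduceFunctorImpl` from `reduce_op.cu.h` takes `kps` (kernel_primitives) functors and no init argument, the reduce functor apparently supplying its own initial value (presumably the reason `BitwiseAdd` gains an `initial()` in this commit). Schematically, using the kron_op.h gradient reduction below as the pattern:

```cpp
// Before: cub_reduce.h interface (init value and reducer passed at the call site).
TensorReduce<T, T, cub::Sum, IdentityFunctor>(
    dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor(),
    stream);

// After: reduce_op.cu.h interface (kps functors; no explicit init argument).
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
    dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream);
```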
paddle/fluid/operators/broadcast_tensors_op.cu

```diff
@@ -20,7 +20,7 @@ limitations under the License. */
 #include <unordered_map>
 #include <vector>
 
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 
 namespace paddle {
 namespace operators {
@@ -28,16 +28,6 @@ namespace operators {
 using framework::Tensor;
 using framework::DDim;
 
-template <typename Tout>
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-
-  template <typename U>
-  HOSTDEVICE inline Tout operator()(const U& x) const {
-    return static_cast<Tout>(x);
-  }
-};
-
 template <typename T>
 class CUDABroadcastTensorsGradOpKernel : public framework::OpKernel<T> {
  public:
@@ -99,9 +89,9 @@ class CUDABroadcastTensorsGradOpKernel : public framework::OpKernel<T> {
       } else {
         // reduce_sum implementation on CUDA
         auto stream = context.cuda_device_context().stream();
-        TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-            *input_tensor, output_tensor, reduce_dims_vec, static_cast<T>(0),
-            cub::Sum(), IdentityFunctor<T>(), stream);
+        TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+            *input_tensor, output_tensor, kps::IdentityFunctor<T>(),
+            reduce_dims_vec, stream);
       }
     }
   }
```
paddle/fluid/operators/controlflow/compare_all_op.cu

```diff
@@ -15,20 +15,16 @@ limitations under the License. */
 #include <thrust/fill.h>
 #include "paddle/fluid/operators/controlflow/compare_all_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename T>
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-  HOSTDEVICE inline T operator()(const T& x) const { return x; }
-};
-
 struct BitwiseAdd {
   // Bitwise add operator, returns <tt>a + b</tt>
+  template <typename T>
+  inline T initial() { return static_cast<T>(true); }
+
   __host__ __device__ __forceinline__ T operator()(const T& a,
                                                    const T& b) const {
     return a & b;
@@ -67,9 +63,9 @@ class CompareReduceOpKernel
       reduce_dims.resize(tmp.dims().size());
       for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
       auto stream = context.cuda_device_context().stream();
-      TensorReduce<bool, bool, BitwiseAdd, IdentityFunctor<bool>>(
-          tmp, z, reduce_dims, true, BitwiseAdd(), IdentityFunctor<bool>(),
-          stream);
+      TensorReduceFunctorImpl<bool, bool, BitwiseAdd,
+                              kps::IdentityFunctor<bool>>(
+          tmp, z, kps::IdentityFunctor<bool>(), reduce_dims, stream);
     }
   }
 };
```
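The `initial()` added to `BitwiseAdd` is the identity element of the logical-AND reduction that `CompareReduceOpKernel` runs over the element-wise comparison results. A host-side sketch of the same semantics (illustration only, not Paddle code; the member-template form and the use of `std::accumulate` are assumptions made for the example):

```cpp
#include <cstdio>
#include <iterator>
#include <numeric>

// Mirrors the BitwiseAdd functor from the diff above: "all elements equal"
// is a logical-AND reduction whose identity element is true.
struct BitwiseAdd {
  template <typename T>
  T initial() const {
    return static_cast<T>(true);
  }

  template <typename T>
  T operator()(const T& a, const T& b) const {
    return a & b;
  }
};

int main() {
  BitwiseAdd reducer;
  // Element-wise comparison results for two tensors; one mismatch.
  const bool elementwise_equal[] = {true, true, false, true};
  const bool all_equal = std::accumulate(
      std::begin(elementwise_equal), std::end(elementwise_equal),
      reducer.initial<bool>(), reducer);
  std::printf("all equal: %s\n", all_equal ? "yes" : "no");  // prints "no"
  return 0;
}
```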
paddle/fluid/operators/fused/attn_bias_add.cu.h

```diff
@@ -33,7 +33,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/operators/elementwise/elementwise_functor.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/fast_divmod.h"
 
 namespace paddle {
@@ -41,8 +41,6 @@ namespace operators {
 #define MAX_INPUT_NUM 2
 
 namespace kps = paddle::operators::kernel_primitives;
 
-template <typename T>
-using CudnnDataType = platform::CudnnDataType<T>;
-
 template <typename T>
```
paddle/fluid/operators/kron_op.h

```diff
@@ -19,7 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/for_range.h"
 #if defined(__NVCC__) || defined(__HIPCC__)
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "thrust/device_vector.h"
 #endif
@@ -237,15 +237,6 @@ struct KronGradElemFunctor<platform::complex<T>> {
   const int ndims_;
 };
 
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-
-  template <typename U>
-  HOSTDEVICE inline U operator()(const U& x) const {
-    return x;
-  }
-};
-
 template <typename DeviceContext, typename T>
 struct KronGradOpFunctor {
   void operator()(const DeviceContext& dev_ctx, const framework::Tensor& dout,
@@ -314,14 +305,12 @@ struct KronGradOpFunctor {
 #if defined(__NVCC__) || defined(__HIPCC__)
     auto stream = dev_ctx.stream();  // it is a cuda device_context
     if (dx) {
-      TensorReduce<T, T, cub::Sum, IdentityFunctor>(
-          dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor(),
-          stream);
+      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream);
     }
     if (dy) {
-      TensorReduce<T, T, cub::Sum, IdentityFunctor>(
-          dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor(),
-          stream);
+      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          dout_y, dy, kps::IdentityFunctor<T>(), {1}, stream);
     }
 #else
     auto* place = dev_ctx.eigen_device();
```
paddle/fluid/operators/matmul_v2_op.h

```diff
@@ -31,7 +31,7 @@ limitations under the License. */
 #include "paddle/pten/include/linalg.h"
 
 #if defined(__NVCC__) || defined(__HIPCC__)
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #endif
 
 namespace paddle {
@@ -39,24 +39,14 @@ namespace operators {
 
 using framework::Tensor;
 
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-
-  template <typename U>
-  HOSTDEVICE inline U operator()(const U& x) const {
-    return x;
-  }
-};
-
 template <typename DeviceContext, typename T>
 void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
                             const std::vector<int>& reduce_dims,
                             const paddle::framework::ExecutionContext& ctx) {
 #if defined(__NVCC__) || defined(__HIPCC__)
   auto stream = ctx.cuda_device_context().stream();
-  TensorReduce<T, T, cub::Sum, IdentityFunctor>(
-      *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
-      IdentityFunctor(), stream);
+  TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      *input, output, kps::IdentityFunctor<T>(), reduce_dims, stream);
 #else
   ReduceKernelFunctor<DeviceContext, T, ops::SumFunctor>(
       input, output, reduce_dims, true, false, ctx)
```
paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu

```diff
@@ -13,11 +13,11 @@
 // limitations under the License.
 
 #include "gtest/gtest.h"
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 
 namespace paddle {
 namespace operators {
-namespace detail {
+namespace details {
 
 TEST(test_reduce_rank_check, all) {
   using EnforceNotMet = paddle::platform::EnforceNotMet;
@@ -39,15 +39,15 @@ TEST(test_reduce_rank_check, all) {
       }
 
      if (is_valid) {
-        CheckReduceRankIsValid(reduce_rank, rank);
+        CheckReduceRank(reduce_rank, rank);
      } else {
-        ASSERT_THROW(CheckReduceRankIsValid(reduce_rank, rank),
+        ASSERT_THROW(CheckReduceRank(reduce_rank, rank),
                      paddle::platform::EnforceNotMet);
      }
    }
  }
 }
 
-}  // namespace detail
+}  // namespace details
 }  // namespace operators
 }  // namespace paddle
```
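The rule this test exercises (now called `CheckReduceRank` in `reduce_op.cu.h`) is the one documented inside the deleted `cub_reduce.h` below: after adjacent dimensions are merged, reduced and kept axes interleave, so for an even rank only `rank / 2` is a valid reduce rank, and for an odd rank only `(rank - 1) / 2` or `(rank + 1) / 2`. A minimal sketch of the same check (the helper name is hypothetical, for illustration only):

```cpp
// Returns true when reduce_rank is consistent with the interleaved layout
// produced by dimension merging (see CheckReduceRankIsValid in cub_reduce.h).
bool IsValidReduceRank(int reduce_rank, int rank) {
  if (rank % 2 == 0) return reduce_rank == rank / 2;
  return reduce_rank == (rank - 1) / 2 || reduce_rank == (rank + 1) / 2;
}
// e.g. IsValidReduceRank(2, 4) == true,  IsValidReduceRank(1, 4) == false,
//      IsValidReduceRank(2, 5) == true,  IsValidReduceRank(3, 5) == true.
```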
paddle/fluid/operators/reduce_ops/cub_reduce.h (deleted, file mode 100644 → 0)

The entire file is removed; its previous contents (at parent 431a2d6a) were:

```cpp
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <vector>

#ifdef __NVCC__
#include "cub/cub.cuh"  // NOLINT
#endif

#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"

namespace paddle {
namespace operators {
namespace detail {

template <typename T, size_t ElementCount>
struct Array {
 public:
  HOSTDEVICE inline Array() {}

  HOSTDEVICE inline T& operator[](size_t index) { return data_[index]; }

  HOSTDEVICE inline const T& operator[](size_t index) const {
    return data_[index];
  }

  HOSTDEVICE constexpr inline size_t size() const { return ElementCount; }

  template <typename VectorLikeType>
  static inline Array<T, ElementCount> From(const VectorLikeType& vec) {
    PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
                      platform::errors::InvalidArgument(
                          "Cub reduce Array: size not match. Received "
                          "vec.size() %d != ElementCount %d.",
                          vec.size(), ElementCount));
    size_t n = static_cast<size_t>(vec.size());
    Array<T, ElementCount> ret;
    for (size_t i = 0; i < n; ++i) ret[i] = vec[i];
    return ret;
  }

 private:
  T data_[ElementCount];
};

// reduce the 1d array to one element
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
          typename TransformOp, int BlockDim>
__global__ void ReduceKernel1D(const Tx* x, Ty* y, ReduceOp reducer,
                               TransformOp transformer, MPType init,
                               int reduce_num) {
  int thread_id = blockIdx.x * blockDim.x + threadIdx.x;

  typedef cub::BlockReduce<MPType, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  MPType local_data = init;
  for (int i = thread_id; i < reduce_num; i += gridDim.x * blockDim.x) {
    local_data = static_cast<MPType>(
        reducer(local_data, static_cast<MPType>(transformer(x[i]))));
  }
  __syncthreads();

  local_data = BlockReduce(temp_storage).Reduce(local_data, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = static_cast<Ty>(local_data);
  }
}

// reduce the last axis of 2d array
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
          typename TransformOp, int BlockDim>
__global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
                               TransformOp transformer, MPType init,
                               int reduce_num) {
  __shared__
      typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
  int idx_x = blockIdx.x * reduce_num;
  int idx_y = threadIdx.x;
  MPType reduce_var = init;
  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
    reduce_var =
        reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x + idx_y])));
  __syncthreads();

  reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
                   .Reduce(reduce_var, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = static_cast<Ty>(reduce_var);
  }
}

template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
          typename TransformOp, int BlockDim, int Rank, int ReduceRank>
__global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
                             TransformOp transformer, MPType init,
                             int reduce_num, Array<int, Rank> x_strides,
                             Array<int, ReduceRank> reduce_dim,
                             Array<int, ReduceRank> reduce_strides,
                             Array<int, Rank - ReduceRank> left_dim,
                             Array<int, Rank - ReduceRank> left_strides) {
  __shared__
      typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
  Array<int, Rank> sub_index;
  int left_idx = blockIdx.x;
  for (int i = 0; i < Rank - ReduceRank; ++i) {
    sub_index[left_dim[i]] = left_idx / left_strides[i];
    left_idx %= left_strides[i];
  }

  int reduce_idx = threadIdx.x;
  for (int j = 0; j < ReduceRank; ++j) {
    sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
    reduce_idx %= reduce_strides[j];
  }

  int idx_x = 0;
  for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
  MPType reduce_var = static_cast<MPType>(transformer(x[idx_x]));

  for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
    int reduce_idx = i;
    for (int j = 0; j < ReduceRank; ++j) {
      sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
      reduce_idx %= reduce_strides[j];
    }

    int idx_x = 0;
    for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
    reduce_var = static_cast<MPType>(
        reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x]))));
  }
  __syncthreads();

  reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
                   .Reduce(reduce_var, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = static_cast<Ty>(reduce_var);
  }
}

static inline std::vector<int> GetStrides(const std::vector<int>& dims) {
  int n = static_cast<int>(dims.size());
  if (n == 0) return std::vector<int>();
  std::vector<int> strides(n);
  strides.back() = 1;
  for (int i = n - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i + 1];
  }
  return strides;
}

static inline std::vector<int> GetStrides(const std::vector<int>& dims,
                                          const std::vector<int>& idx) {
  int n = static_cast<int>(idx.size());
  if (n == 0) return std::vector<int>();
  std::vector<int> strides(n);
  strides.back() = 1;
  for (int i = n - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[idx[i + 1]];
  }
  return strides;
}

#ifdef __HIPCC__
constexpr int kMaxBlockDim = 256;
#else
constexpr int kMaxBlockDim = 512;
#endif

static inline int GetDesiredBlockDim(int block_dim) {
  return block_dim >= kMaxBlockDim
             ? kMaxBlockDim
             : (1 << static_cast<int>(std::log2(block_dim)));
}

static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
  if (rank % 2 == 0) {
    PADDLE_ENFORCE_EQ(reduce_rank, rank / 2,
                      platform::errors::InvalidArgument(
                          "ReduceOp: invalid reduce rank. When rank = %d, "
                          "reduce_rank must be %d, but got %d.",
                          rank, rank / 2, reduce_rank));
  } else {
    auto lower_rank = (rank - 1) / 2;
    auto upper_rank = (rank + 1) / 2;
    PADDLE_ENFORCE_EQ(
        reduce_rank == lower_rank || reduce_rank == upper_rank, true,
        platform::errors::InvalidArgument(
            "ReduceOp: invalid reduce rank. When rank = %d, reduce_rank "
            "must be %d or %d, but got %d.",
            rank, lower_rank, upper_rank, reduce_rank));
  }
}

template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
          typename TransformOp, int BlockDim>
typename std::enable_if<!std::is_same<Tx, paddle::platform::float16>::value,
                        void>::type
LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
                      const platform::Place& place, const ReduceOp& reducer,
                      const TransformOp& transformer, const MPType& init,
                      int reduce_num, gpuStream_t stream) {
  cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
                                                                  transformer);
  size_t temp_storage_bytes = 0;
  cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
                            reduce_num, reducer, init, stream);
  framework::Tensor tmp;
  auto* temp_storage = tmp.mutable_data<uint8_t>(
      framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
  cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
                            reduce_num, reducer, init, stream);
}

template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
          typename TransformOp, int BlockDim>
typename std::enable_if<std::is_same<Tx, paddle::platform::float16>::value,
                        void>::type
LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
                      const platform::Place& place, const ReduceOp& reducer,
                      const TransformOp& transformer, const MPType& init,
                      int reduce_num, gpuStream_t stream) {
  int element_per_block = BlockDim * 10;
  int block_per_grid = (reduce_num + element_per_block - 1) / element_per_block;

  framework::Tensor tmp;
  auto* temp_storage = tmp.mutable_data<MPType>(
      framework::make_ddim(
          {static_cast<int64_t>(block_per_grid * sizeof(MPType))}),
      place);

  // each block reduce number to interim result
  ReduceKernel1D<Tx, MPType, MPType, ReduceOp, TransformOp,
                 BlockDim><<<block_per_grid, BlockDim, 0, stream>>>(
      x_data, temp_storage, reducer, transformer, init, reduce_num);
  // reduce all number to final result
  ReduceKernel1D<MPType, MPType, Ty, ReduceOp, TransformOp,
                 BlockDim><<<1, BlockDim, 0, stream>>>(
      temp_storage, y_data, reducer, transformer, init, block_per_grid);
}

template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
          typename TransformOp>
static void TensorReduceImpl(
    const Tx* x_data, Ty* y_data, const platform::Place& place,
    const ReduceOp& reducer, const TransformOp& transformer, const Ty& init,
    int left_num, int reduce_num, const std::vector<int>& x_strides,
    const std::vector<int>& reduce_dim, const std::vector<int>& reduce_strides,
    const std::vector<int>& left_dim, const std::vector<int>& left_strides,
    gpuStream_t stream) {
  using MPType = typename details::MPTypeTrait<Ty>::Type;
  MPType init_mp = static_cast<MPType>(init);

#define CUB_RANK_CASE(i, ...)             \
  case i: {                               \
    constexpr auto kRank = i;             \
    switch (reduce_rank) { __VA_ARGS__; } \
  } break

#define CUB_REDUCE_RANK_CASE(i, ...)                                     \
  case i: {                                                              \
    constexpr auto kReduceRank = i;                                      \
    ReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim, kRank, \
                 kReduceRank><<<left_num, BlockDim, 0, stream>>>(        \
        x_data, y_data, reducer, transformer, init_mp, reduce_num,       \
        Array<int, kRank>::From(x_strides),                              \
        Array<int, kReduceRank>::From(reduce_dim),                       \
        Array<int, kReduceRank>::From(reduce_strides),                   \
        Array<int, kRank - kReduceRank>::From(left_dim),                 \
        Array<int, kRank - kReduceRank>::From(left_strides));            \
  } break

  int rank = x_strides.size();
  int reduce_rank = reduce_strides.size();
  if (rank == reduce_rank) {
    LaunchCubReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim>(
        x_data, y_data, place, reducer, transformer, init_mp, reduce_num,
        stream);
    return;
  }
  if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
    ReduceKernel2D<Tx, MPType, Ty, ReduceOp, TransformOp,
                   BlockDim><<<left_num, BlockDim, 0, stream>>>(
        x_data, y_data, reducer, transformer, init_mp, reduce_num);
    return;
  }
  /*
  if (rank == 3 && reduce_rank == 1 && reduce_dim[0] == 1) {
    // TODO(liangdun): we can optimize 3d case which the 2nd axis is reduced.
    // Currently, it is handled by code below, but inefficient
    return;
  }
  */

  /**
   * Since we have combined the adjacent reduce dimensions inside TensorReduce,
   * The reduce ranks and non-reduce ranks must be interleaving. That is to say,
   * the rank of Tensor must be `1010...` or `0101...` where 1 represents that
   * the dimension is about to be reduced.
   *
   * Therefore,
   * If rank is odd, only need to switch-case (rank - 1)/2 and (rank + 1)/2.
   * If rank is even, only need to switch-case rank/2.
   *
   * The total switch-case numbers reduce from 1+2+3+...+8=36 to (1+2)*4=12,
   * it would speed up compiling and make the binary size lower.
   */
  CheckReduceRankIsValid(reduce_rank, rank);
  switch (rank) {
    CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););

    CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););

    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););

    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););

    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););

    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););

    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););

    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
  }

#undef CUB_REDUCE_RANK_CASE
#undef CUB_RANK_CASE
}

}  // namespace detail

template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
void TensorReduce(const framework::Tensor& x, framework::Tensor* y,
                  std::vector<int> origin_reduce_dims, const Ty& init,
                  const ReduceOp& reducer, const TransformOp& transformer,
                  gpuStream_t stream) {
  auto x_dim = framework::vectorize<int>(x.dims());
  std::vector<int> new_x_dim, new_reduce_dims;
  int is_reduced = 0;
  for (auto e : origin_reduce_dims) {
    auto pos = e >= 0 ? e : e + x_dim.size();
    is_reduced |= 1 << e;
  }
  for (int i = 0; i < x_dim.size(); i++) {
    if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
      new_x_dim.push_back(x_dim[i]);
      if ((is_reduced >> i) & 1)
        new_reduce_dims.push_back(new_x_dim.size() - 1);
    } else {
      new_x_dim[new_x_dim.size() - 1] *= x_dim[i];
    }
  }
  x_dim = new_x_dim;
  origin_reduce_dims = new_reduce_dims;
  int x_rank = static_cast<int>(x_dim.size());
  std::set<int> left_set, reduce_set;
  for (int i = 0; i < x_rank; ++i) left_set.insert(i);

  for (auto e : origin_reduce_dims) {
    left_set.erase(e);
    reduce_set.insert(e);
  }

  std::vector<int> reduce_dim(reduce_set.begin(), reduce_set.end());
  std::vector<int> left_dim(left_set.begin(), left_set.end());

  std::vector<int> x_strides = detail::GetStrides(x_dim);
  std::vector<int> reduce_strides = detail::GetStrides(x_dim, reduce_dim);
  std::vector<int> left_strides = detail::GetStrides(x_dim, left_dim);
  int reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];
  int left_num = 1;
  if (left_dim.size()) left_num = left_strides[0] * x_dim[left_dim[0]];

  std::vector<int> y_dim(left_dim.size());
  for (int i = 0; i < left_dim.size(); ++i) {
    y_dim[i] = x_dim[left_dim[i]];
  }
  auto x_data = x.data<Tx>();
  auto y_data = y->mutable_data<Ty>(x.place());
  if (reduce_num == 1) {
    auto out_dims = y->dims();
    framework::TensorCopy(x, y->place(), y);
    y->Resize(out_dims);
    return;
  }

#define CUB_BLOCK_DIM_CASE(block_dim)                                    \
  case block_dim: {                                                      \
    constexpr auto kBlockDim = block_dim;                                \
    detail::TensorReduceImpl<Tx, Ty, block_dim, ReduceOp, TransformOp>(  \
        x_data, y_data, x.place(), reducer, transformer, init, left_num, \
        reduce_num, x_strides, reduce_dim, reduce_strides, left_dim,     \
        left_strides, stream);                                           \
  } break

  switch (detail::GetDesiredBlockDim(reduce_num)) {
    CUB_BLOCK_DIM_CASE(512);
    CUB_BLOCK_DIM_CASE(256);
    CUB_BLOCK_DIM_CASE(128);
    CUB_BLOCK_DIM_CASE(64);
    CUB_BLOCK_DIM_CASE(32);
    CUB_BLOCK_DIM_CASE(16);
    CUB_BLOCK_DIM_CASE(8);
    CUB_BLOCK_DIM_CASE(4);
    CUB_BLOCK_DIM_CASE(2);
  }
#undef CUB_BLOCK_DIM_CASE
}

template <typename Tx, typename ReduceOp,
          template <typename> class TransformOp>
struct TensorReduceFunctor {
  const framework::Tensor& x;
  framework::Tensor* y;
  std::vector<int> origin_reduce_dims;
  const double& init;
  const ReduceOp& reducer;
  gpuStream_t stream;
  TensorReduceFunctor(const framework::Tensor& x, framework::Tensor* y,
                      std::vector<int> origin_reduce_dims, const double& init,
                      const ReduceOp& reducer, gpuStream_t stream)
      : x(x),
        y(y),
        origin_reduce_dims(origin_reduce_dims),
        init(init),
        reducer(reducer),
        stream(stream) {}

  template <typename Ty>
  void apply() const {
    const Ty& init_cast = static_cast<Ty>(init);
    TensorReduce<Tx, Ty, ReduceOp, TransformOp<Ty>>(
        x, y, origin_reduce_dims, init_cast, reducer, TransformOp<Ty>(),
        stream);
  }
};

}  // namespace operators
}  // namespace paddle
```
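The rank/reduce-rank dispatch in the file above relies on a preprocessing step in `TensorReduce`: adjacent dimensions that are all reduced, or all kept, are merged first, which is what produces the interleaved `1010...` pattern described in the comment and keeps the switch-case count at 12. A standalone sketch of that merging step (the function name and the concrete shape are illustrative, not from the Paddle tree):

```cpp
#include <cstdio>
#include <vector>

// Collapse runs of adjacent dimensions that share the same reduced/kept
// status, the way the deleted TensorReduce did before dispatching.
void CombineAdjacentDims(const std::vector<int>& dims,
                         const std::vector<int>& reduce_dims,
                         std::vector<int>* new_dims,
                         std::vector<int>* new_reduce_dims) {
  int is_reduced = 0;
  for (int e : reduce_dims) is_reduced |= 1 << e;  // bitmask of reduced axes
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    // Start a new output dimension whenever the reduced/kept flag flips.
    if (i == 0 || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
      new_dims->push_back(dims[i]);
      if ((is_reduced >> i) & 1)
        new_reduce_dims->push_back(static_cast<int>(new_dims->size()) - 1);
    } else {
      new_dims->back() *= dims[i];  // same flag as the previous dim: merge
    }
  }
}

int main() {
  // {8, 3, 4, 5} reduced over axes {1, 2} collapses to {8, 12, 5} with a
  // single reduce axis {1}: reduced and kept axes now interleave.
  std::vector<int> new_dims, new_reduce_dims;
  CombineAdjacentDims({8, 3, 4, 5}, {1, 2}, &new_dims, &new_reduce_dims);
  for (int d : new_dims) std::printf("%d ", d);
  std::printf("| reduce axes:");
  for (int d : new_reduce_dims) std::printf(" %d", d);
  std::printf("\n");  // prints: 8 12 5 | reduce axes: 1
  return 0;
}
```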
paddle/fluid/operators/reduce_ops/frobenius_norm_op.cu

```diff
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
 #include "paddle/fluid/operators/reduce_ops/frobenius_norm_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 
 template <typename T>
 using CUDAFrobeniusNormKernel =
```
paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu

```diff
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
 
 template <typename T>
```
paddle/fluid/operators/solve_op.h

```diff
@@ -26,7 +26,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
 #include "paddle/fluid/operators/squeeze_op.h"
 #if defined(__NVCC__) || defined(__HIPCC__)
-#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #endif
 
 #define MAX_RANK_SUPPORTED 6
@@ -39,24 +39,14 @@ using framework::To32BitIndex;
 
 constexpr int kMULMKLDNNINT8 = 1;
 
-struct IdentityFunctor {
-  HOSTDEVICE explicit inline IdentityFunctor() {}
-
-  template <typename U>
-  HOSTDEVICE inline U operator()(const U& x) const {
-    return x;
-  }
-};
-
 template <typename DeviceContext, typename T>
 void ReduceSumForSolve(const Tensor* input, Tensor* output,
                        const std::vector<int>& reduce_dims, bool keep_dim,
                        const paddle::framework::ExecutionContext& ctx) {
 #if defined(__NVCC__) || defined(__HIPCC__)
   auto stream = ctx.cuda_device_context().stream();
-  TensorReduce<T, T, cub::Sum, IdentityFunctor>(
-      *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
-      IdentityFunctor(), stream);
+  TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      *input, output, kps::IdentityFunctor<T>(), reduce_dims, stream);
 #else
   ReduceKernelFunctor<DeviceContext, T, ops::SumFunctor>(
       input, output, reduce_dims, keep_dim, false, ctx)
```