Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
39903f72
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
39903f72
编写于
6月 06, 2022
作者:
N
niuliling123
提交者:
GitHub
6月 06, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace ReduceAmax/Amax.part.cu with KP (#43202)
上级
2a17e3c1
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
150 addition
and
25 deletion
+150
-25
paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu
paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu
+8
-11
paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu
paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu
+8
-11
paddle/fluid/operators/reduce_ops/reduce_op.h
paddle/fluid/operators/reduce_ops/reduce_op.h
+95
-1
paddle/phi/kernels/funcs/broadcast_function.h
paddle/phi/kernels/funcs/broadcast_function.h
+18
-1
paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
+21
-1
未找到文件。
paddle/fluid/operators/reduce_ops/reduce_amax_op.part.cu
浏览文件 @
39903f72
...
@@ -12,15 +12,12 @@
...
@@ -12,15 +12,12 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_
min_max_
op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
REGISTER_OP_CUDA_KERNEL
(
template
<
typename
T
>
reduce_amax_grad
,
using
CUDAReduceMaxGradKernel
=
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
ops
::
ReduceCudaAMaxAMinGradKernel
<
T
,
kps
::
IdentityFunctor
>
;
ops
::
AMaxOrAMinGradFunctor
>
,
REGISTER_OP_CUDA_KERNEL
(
reduce_amax_grad
,
CUDAReduceMaxGradKernel
<
int
>
,
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
CUDAReduceMaxGradKernel
<
int64_t
>
,
ops
::
AMaxOrAMinGradFunctor
>
,
CUDAReduceMaxGradKernel
<
float
>
,
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
,
CUDAReduceMaxGradKernel
<
double
>
);
ops
::
AMaxOrAMinGradFunctor
>
,
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
,
ops
::
AMaxOrAMinGradFunctor
>
);
paddle/fluid/operators/reduce_ops/reduce_amin_op.part.cu
浏览文件 @
39903f72
...
@@ -12,15 +12,12 @@
...
@@ -12,15 +12,12 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_
min_max_
op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
REGISTER_OP_CUDA_KERNEL
(
template
<
typename
T
>
reduce_amin_grad
,
using
CUDAReduceMinGradKernel
=
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
ops
::
ReduceCudaAMaxAMinGradKernel
<
T
,
kps
::
IdentityFunctor
>
;
ops
::
AMaxOrAMinGradFunctor
>
,
REGISTER_OP_CUDA_KERNEL
(
reduce_amin_grad
,
CUDAReduceMinGradKernel
<
int
>
,
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
CUDAReduceMinGradKernel
<
int64_t
>
,
ops
::
AMaxOrAMinGradFunctor
>
,
CUDAReduceMinGradKernel
<
float
>
,
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
,
CUDAReduceMinGradKernel
<
double
>
);
ops
::
AMaxOrAMinGradFunctor
>
,
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
,
ops
::
AMaxOrAMinGradFunctor
>
);
paddle/fluid/operators/reduce_ops/reduce_op.h
浏览文件 @
39903f72
...
@@ -24,7 +24,6 @@ limitations under the License. */
...
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
// only can include the headers in paddle/phi/api dirs
// only can include the headers in paddle/phi/api dirs
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h"
...
@@ -655,6 +654,7 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
...
@@ -655,6 +654,7 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
bool
reduce_all
=
context
.
Attr
<
bool
>
(
"reduce_all"
);
bool
reduce_all
=
context
.
Attr
<
bool
>
(
"reduce_all"
);
std
::
vector
<
int
>
dims
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dim"
);
std
::
vector
<
int
>
dims
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dim"
);
auto
*
in_x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
in_x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
d_out
=
auto
*
d_out
=
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_x
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
...
@@ -685,12 +685,106 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
...
@@ -685,12 +685,106 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
if
(
out_dtype
<=
0
)
{
if
(
out_dtype
<=
0
)
{
pt_out_dtype
=
d_out
->
dtype
();
pt_out_dtype
=
d_out
->
dtype
();
}
}
using
MPType
=
typename
kps
::
details
::
MPTypeTrait
<
T
>::
Type
;
using
MPType
=
typename
kps
::
details
::
MPTypeTrait
<
T
>::
Type
;
phi
::
ReduceGrad
<
T
,
TransformOp
<
T
,
MPType
>>
(
phi
::
ReduceGrad
<
T
,
TransformOp
<
T
,
MPType
>>
(
dev_ctx
,
pt_d_out
.
get
(),
pt_d_x
.
get
(),
pt_out_dtype
,
dev_ctx
,
pt_d_out
.
get
(),
pt_d_x
.
get
(),
pt_out_dtype
,
TransformOp
<
T
,
MPType
>
(
reduce_num
));
TransformOp
<
T
,
MPType
>
(
reduce_num
));
}
}
};
};
template
<
typename
T
>
struct
EqualFunctor
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
0.0
f
);
}
inline
HOSTDEVICE
T
operator
()(
const
T
a
,
const
T
b
)
const
{
return
static_cast
<
T
>
(
a
==
b
);
}
};
template
<
typename
T
,
typename
Enable
=
void
>
struct
DivideFunctor
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
1.0
f
);
}
inline
HOSTDEVICE
T
operator
()(
const
T
a
,
const
T
b
)
const
{
return
a
/
b
;
}
};
template
<
typename
T
,
template
<
typename
,
typename
>
class
TransformOp
>
class
ReduceCudaAMaxAMinGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
bool
reduce_all
=
context
.
Attr
<
bool
>
(
"reduce_all"
);
std
::
vector
<
int
>
dims
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dim"
);
auto
*
in_x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
out_y
=
context
.
Input
<
Tensor
>
(
"Out"
);
auto
*
d_out
=
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
out_dtype
=
context
.
Attr
<
int
>
(
"in_dtype"
);
auto
pt_out_dtype
=
framework
::
TransToPhiDataType
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
out_dtype
));
// get reduce_dim and reduce_num for reduce_mean_grad
int
dim_size
=
in_x
->
dims
().
size
();
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
dims
,
dim_size
,
reduce_all
);
auto
update_dims
=
vectorize
(
d_x
->
dims
());
int
reduce_num
=
1
;
for
(
auto
i
:
reduce_dims
)
{
reduce_num
*=
(
in_x
->
dims
())[
i
];
update_dims
[
i
]
=
1
;
}
auto
&
dev_ctx
=
context
.
cuda_device_context
();
// make new tensor reduce_out
phi
::
DenseTensor
new_y
(
out_y
->
type
());
new_y
.
ShareDataWith
(
*
out_y
);
new_y
.
Resize
(
phi
::
make_ddim
(
update_dims
));
// make new tensor d_out
phi
::
DenseTensor
new_dout
(
d_out
->
type
());
new_dout
.
ShareDataWith
(
*
d_out
);
new_dout
.
Resize
(
phi
::
make_ddim
(
update_dims
));
d_x
->
mutable_data
(
dev_ctx
.
GetPlace
(),
d_out
->
dtype
());
auto
new_in
=
paddle
::
experimental
::
MakePhiDenseTensor
(
*
in_x
);
auto
new_in_tensor
=
new_in
.
get
();
auto
new_dx
=
paddle
::
experimental
::
MakePhiDenseTensor
(
*
d_x
);
auto
new_dx_tensor
=
new_dx
.
get
();
// make equal_out
phi
::
DenseTensor
*
equal_out
=
new
phi
::
DenseTensor
();
equal_out
->
Resize
(
in_x
->
dims
());
dev_ctx
.
template
Alloc
<
T
>(
equal_out
);
auto
equal_out_tensor
=
*
equal_out
;
// make new tensor equal_count
phi
::
DenseTensor
*
equal_count
=
new
phi
::
DenseTensor
();
equal_count
->
Resize
(
phi
::
make_ddim
(
update_dims
));
dev_ctx
.
template
Alloc
<
T
>(
equal_count
);
// compute
// 1. equal_out = Equal(x, y)
std
::
vector
<
const
phi
::
DenseTensor
*>
equal_inputs
=
{
&
new_y
,
new_in_tensor
};
std
::
vector
<
phi
::
DenseTensor
*>
equal_outputs
=
{
&
equal_out_tensor
};
phi
::
funcs
::
BroadcastKernel
<
phi
::
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx
,
equal_inputs
,
&
equal_outputs
,
0
,
EqualFunctor
<
T
>
());
// 2. equal_count = reduceSum(equal_out)
using
MPType
=
typename
kps
::
details
::
MPTypeTrait
<
T
>::
Type
;
phi
::
funcs
::
ReduceKernel
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
,
MPType
>>
(
dev_ctx
,
equal_out_tensor
,
equal_count
,
kps
::
IdentityFunctor
<
T
,
MPType
>
(),
reduce_dims
,
false
);
// 3. dx = Div(dout, equal_out)
std
::
vector
<
const
phi
::
DenseTensor
*>
grad_inputs
=
{
&
equal_out_tensor
,
equal_count
};
std
::
vector
<
phi
::
DenseTensor
*>
grad_outputs
=
{
new_dx_tensor
};
phi
::
funcs
::
BroadcastKernel
<
phi
::
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx
,
grad_inputs
,
&
grad_outputs
,
0
,
DivideFunctor
<
T
>
());
delete
equal_out
;
delete
equal_count
;
}
};
#endif
#endif
#endif
#endif
...
...
paddle/phi/kernels/funcs/broadcast_function.h
浏览文件 @
39903f72
...
@@ -605,7 +605,22 @@ void ElementwiseCompute(const GPUContext &dev_ctx,
...
@@ -605,7 +605,22 @@ void ElementwiseCompute(const GPUContext &dev_ctx,
dev_ctx
,
ins
,
&
outs
,
axis
,
func
);
dev_ctx
,
ins
,
&
outs
,
axis
,
func
);
}
}
#endif
template
<
typename
DeviceContext
,
typename
T
,
typename
Functor
,
typename
InverseFunctor
>
void
DefaultElementwiseOperator
(
const
DeviceContext
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
DenseTensor
*
z
,
int
axis
=
-
1
)
{
auto
x_dims
=
x
.
dims
();
auto
y_dims
=
y
.
dims
();
dev_ctx
.
template
Alloc
<
T
>(
z
);
funcs
::
ElementwiseCompute
<
Functor
,
T
>
(
dev_ctx
,
x
,
y
,
axis
,
Functor
(),
z
);
}
#else
template
<
typename
DeviceContext
,
template
<
typename
DeviceContext
,
typename
T
,
typename
T
,
...
@@ -627,5 +642,7 @@ void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
...
@@ -627,5 +642,7 @@ void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
}
}
}
}
#endif
}
// namespace funcs
}
// namespace funcs
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/gpu/frobenius_norm_kernel.cu
浏览文件 @
39903f72
...
@@ -14,7 +14,27 @@
...
@@ -14,7 +14,27 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/frobenius_norm_kernel.h"
#include "paddle/phi/kernels/frobenius_norm_kernel.h"
#include "paddle/phi/kernels/impl/frobenius_norm_kernel_impl.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/gpu/reduce.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
FrobeniusNormKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
std
::
vector
<
int64_t
>&
dims
,
bool
keep_dim
,
bool
reduce_all
,
DenseTensor
*
out
)
{
auto
out_dtype
=
x
.
dtype
();
phi
::
Reduce
<
T
,
kps
::
AddFunctor
,
kps
::
SquareFunctor
>
(
dev_ctx
,
x
,
reduce_all
,
dims
,
keep_dim
,
out_dtype
,
out
);
std
::
vector
<
const
DenseTensor
*>
ins
=
{
out
};
std
::
vector
<
DenseTensor
*>
outs
=
{
out
};
auto
functor
=
funcs
::
CudaSqrtFunctor
<
T
>
();
funcs
::
ElementwiseKernel
<
T
>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
PD_REGISTER_KERNEL
(
frobenius_norm
,
GPU
,
ALL_LAYOUT
,
phi
::
FrobeniusNormKernel
,
float
,
double
)
{}
frobenius_norm
,
GPU
,
ALL_LAYOUT
,
phi
::
FrobeniusNormKernel
,
float
,
double
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录