Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3825b40f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3825b40f
编写于
1月 25, 2022
作者:
N
Noel
提交者:
GitHub
1月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[pnorm] fix bug in fp16 & optimize memory (#39011)
上级
c1e5a393
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
123 addition
and
97 deletion
+123
-97
paddle/fluid/operators/p_norm_op.cu
paddle/fluid/operators/p_norm_op.cu
+26
-66
paddle/fluid/operators/reduce_ops/logsumexp_op.h
paddle/fluid/operators/reduce_ops/logsumexp_op.h
+5
-4
paddle/fluid/operators/reduce_ops/reduce_op.h
paddle/fluid/operators/reduce_ops/reduce_op.h
+15
-14
paddle/fluid/operators/reduce_ops/reduce_op_function.h
paddle/fluid/operators/reduce_ops/reduce_op_function.h
+1
-2
python/paddle/fluid/tests/unittests/test_norm_all.py
python/paddle/fluid/tests/unittests/test_norm_all.py
+76
-11
未找到文件。
paddle/fluid/operators/p_norm_op.cu
浏览文件 @
3825b40f
...
@@ -76,22 +76,13 @@ struct AbsFunctor {
...
@@ -76,22 +76,13 @@ struct AbsFunctor {
}
}
};
};
template
<
typename
T
x
,
typename
Ty
=
Tx
>
template
<
typename
T
>
struct
UnsignedPowFunctor
{
struct
UnsignedPowFunctor
{
HOSTDEVICE
explicit
inline
UnsignedPowFunctor
(
float
porder
)
{
HOSTDEVICE
explicit
inline
UnsignedPowFunctor
(
float
porder
)
{
this
->
porder
=
porder
;
this
->
porder
=
porder
;
}
}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
x
)
const
{
HOSTDEVICE
inline
T
operator
()(
const
T
x
)
const
{
return
static_cast
<
Ty
>
(
inline_pow
(
inline_abs
(
x
),
static_cast
<
Tx
>
(
porder
)));
return
static_cast
<
T
>
(
inline_pow
(
inline_abs
(
x
),
static_cast
<
T
>
(
porder
)));
}
float
porder
;
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
PowFunctor
{
HOSTDEVICE
explicit
inline
PowFunctor
(
float
porder
)
{
this
->
porder
=
porder
;
}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
x
)
const
{
return
static_cast
<
Ty
>
(
inline_pow
(
x
,
static_cast
<
Tx
>
(
porder
)));
}
}
float
porder
;
float
porder
;
};
};
...
@@ -105,13 +96,11 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
...
@@ -105,13 +96,11 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
const
T
*
x
=
in_x
->
data
<
T
>
();
const
T
*
x
=
in_x
->
data
<
T
>
();
T
*
norm
=
out_norm
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
norm
=
out_norm
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
xdim
=
in_x
->
dims
();
auto
xdim
=
in_x
->
dims
();
auto
ndim
=
out_norm
->
dims
();
float
porder
=
ctx
.
Attr
<
float
>
(
"porder"
);
float
porder
=
ctx
.
Attr
<
float
>
(
"porder"
);
bool
asvector
=
ctx
.
Attr
<
bool
>
(
"asvector"
);
bool
asvector
=
ctx
.
Attr
<
bool
>
(
"asvector"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
std
::
vector
<
int
>
reduce_axis
=
{
axis
};
std
::
vector
<
int
>
reduce_axis
=
{
axis
};
reduce_axis
=
GetReduceDim
(
reduce_axis
,
xdim
.
size
(),
asvector
);
reduce_axis
=
GetReduceDim
(
reduce_axis
,
xdim
.
size
(),
asvector
);
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
...
@@ -125,29 +114,17 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
...
@@ -125,29 +114,17 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
MinFunctor
,
AbsFunctor
<
T
>>
(
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
MinFunctor
,
AbsFunctor
<
T
>>
(
*
in_x
,
out_norm
,
AbsFunctor
<
T
>
(),
reduce_axis
,
stream
);
*
in_x
,
out_norm
,
AbsFunctor
<
T
>
(),
reduce_axis
,
stream
);
}
else
{
}
else
{
framework
::
Tensor
tmp_x
;
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
UnsignedPowFunctor
<
T
>>
(
tmp_x
.
mutable_data
<
T
>
(
xdim
,
ctx
.
GetPlace
());
*
in_x
,
out_norm
,
UnsignedPowFunctor
<
T
>
(
porder
),
reduce_axis
,
stream
);
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
in_x
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
&
tmp_x
};
const
framework
::
Tensor
*
tmp_norm
=
out_norm
;
auto
func
=
UnsignedPowFunctor
<
MT
,
T
>
(
porder
);
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
tmp_norm
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
out_norm
};
const
auto
&
cuda_ctx
=
const
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
paddle
::
operators
::
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
MT
,
T
,
UnsignedPowFunctor
<
MT
,
T
>>
(
cuda_ctx
,
ins
,
&
outs
,
func
);
framework
::
Tensor
tmp_y
;
tmp_y
.
mutable_data
<
T
>
(
ndim
,
ctx
.
GetPlace
());
TensorReduceFunctorImpl
<
T
,
T
,
kps
::
AddFunctor
,
kps
::
IdentityFunctor
<
T
>>
(
tmp_x
,
&
tmp_y
,
kps
::
IdentityFunctor
<
T
>
(),
reduce_axis
,
stream
);
const
framework
::
Tensor
*
tmp_norm
=
&
tmp_y
;
ins
=
{
tmp_norm
};
outs
=
{
out_norm
};
auto
func_inverse
=
UnsignedPowFunctor
<
MT
,
T
>
(
1.
/
porder
);
paddle
::
operators
::
LaunchSameDimsElementwiseCudaKernel
<
paddle
::
operators
::
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
MT
,
T
,
UnsignedPowFunctor
<
MT
,
T
>>
(
ElementwiseType
::
kUnary
,
T
,
T
,
UnsignedPowFunctor
<
T
>>
(
cuda_ctx
,
ins
,
&
outs
,
func_inverse
);
cuda_ctx
,
ins
,
&
outs
,
UnsignedPowFunctor
<
T
>
(
1.
/
porder
)
);
}
}
}
}
};
};
...
@@ -158,29 +135,25 @@ struct AbsMaxAndMinGradFunctor {
...
@@ -158,29 +135,25 @@ struct AbsMaxAndMinGradFunctor {
typename
DY
,
typename
Dim
>
typename
DY
,
typename
Dim
>
void
operator
()(
const
DeviceContext
&
place
,
X
*
x
,
Y
*
y
,
DX
*
dx
,
DY
*
dy
,
void
operator
()(
const
DeviceContext
&
place
,
X
*
x
,
Y
*
y
,
DX
*
dx
,
DY
*
dy
,
const
Dim
&
dim
,
int
size
)
{
const
Dim
&
dim
,
int
size
)
{
auto
equals
=
((
*
x
).
abs
()
==
y
->
broadcast
(
dim
));
dx
->
device
(
place
)
=
dy
->
broadcast
(
dim
)
*
(
*
x
).
sign
()
*
auto
ones
=
dx
->
constant
(
static_cast
<
T
>
(
1.
));
((
*
x
).
abs
()
==
y
->
broadcast
(
dim
)).
template
cast
<
T
>();
auto
negs
=
dx
->
constant
(
static_cast
<
T
>
(
-
1.
));
auto
zeros
=
dx
->
constant
(
static_cast
<
T
>
(
0.
));
auto
positives
=
(
*
x
)
>
zeros
;
dx
->
device
(
place
)
=
dy
->
broadcast
(
dim
)
*
equals
.
select
(
ones
,
zeros
)
*
positives
.
select
(
ones
,
negs
);
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
PNormPostGradFunctor
{
struct
PNormGradFunctor
{
HOSTDEVICE
explicit
inline
PNormGradFunctor
(
float
porder
)
{
this
->
porder
=
static_cast
<
T
>
(
porder
-
1.
);
}
template
<
typename
DeviceContext
,
typename
X
,
typename
Y
,
typename
DX
,
template
<
typename
DeviceContext
,
typename
X
,
typename
Y
,
typename
DX
,
typename
DY
,
typename
Dim
>
typename
DY
,
typename
Dim
>
void
operator
()(
const
DeviceContext
&
place
,
X
*
x
,
Y
*
y
,
DX
*
dx
,
DY
*
dy
,
void
operator
()(
const
DeviceContext
&
place
,
X
*
x
,
Y
*
y
,
DX
*
dx
,
DY
*
dy
,
const
Dim
&
dim
,
int
size
)
{
const
Dim
&
dim
,
int
size
)
{
auto
ones
=
dx
->
constant
(
static_cast
<
T
>
(
1.
));
dx
->
device
(
place
)
=
(
*
x
).
abs
().
pow
(
this
->
porder
)
*
(
*
x
).
sign
()
*
auto
negs
=
dx
->
constant
(
static_cast
<
T
>
(
-
1.
));
dy
->
broadcast
(
dim
)
*
auto
zeros
=
dx
->
constant
(
static_cast
<
T
>
(
0.
));
(
*
y
).
pow
(
-
this
->
porder
).
broadcast
(
dim
);
auto
positives
=
(
*
x
)
>
zeros
;
dx
->
device
(
place
)
=
(
*
dx
)
*
dy
->
broadcast
(
dim
)
*
y
->
broadcast
(
dim
)
*
positives
.
select
(
ones
,
negs
);
}
}
T
porder
;
};
};
template
<
typename
DeviceContext
,
typename
T
,
typename
AttrType
=
T
>
template
<
typename
DeviceContext
,
typename
T
,
typename
AttrType
=
T
>
...
@@ -207,26 +180,13 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -207,26 +180,13 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
set_zero
(
cuda_ctx
,
out_dx
,
static_cast
<
T
>
(
0
));
set_zero
(
cuda_ctx
,
out_dx
,
static_cast
<
T
>
(
0
));
}
else
if
(
porder
==
INFINITY
||
porder
==
-
INFINITY
)
{
}
else
if
(
porder
==
INFINITY
||
porder
==
-
INFINITY
)
{
AbsMaxAndMinGradFunctor
<
T
>
functor
;
LaunchReduceGradKernel
<
DeviceContext
,
T
,
AbsMaxAndMinGradFunctor
<
T
>>
(
LaunchReduceGradKernel
<
DeviceContext
,
T
,
AbsMaxAndMinGradFunctor
<
T
>>
(
ctx
,
in_x
,
in_norm
,
in_norm_dy
,
out_dx
,
dims
,
reduce_all
);
ctx
,
in_x
,
in_norm
,
in_norm_dy
,
out_dx
,
functor
,
dims
,
reduce_all
);
}
else
{
}
else
{
framework
::
Tensor
tmp_norm
;
auto
functor
=
PNormGradFunctor
<
T
>
(
porder
);
tmp_norm
.
mutable_data
<
T
>
(
in_norm
->
dims
(),
ctx
.
GetPlace
());
LaunchReduceGradKernel
<
DeviceContext
,
T
,
PNormGradFunctor
<
T
>>
(
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
in_norm
};
ctx
,
in_x
,
in_norm
,
in_norm_dy
,
out_dx
,
functor
,
dims
,
reduce_all
);
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
&
tmp_norm
};
auto
pow_functor
=
PowFunctor
<
T
>
(
1.
-
porder
);
paddle
::
operators
::
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
,
T
,
PowFunctor
<
T
>>
(
cuda_ctx
,
ins
,
&
outs
,
pow_functor
);
ins
=
{
in_x
};
outs
=
{
out_dx
};
auto
unsigned_pow
=
UnsignedPowFunctor
<
T
>
(
porder
-
1.
);
paddle
::
operators
::
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
,
T
,
UnsignedPowFunctor
<
T
>>
(
cuda_ctx
,
ins
,
&
outs
,
unsigned_pow
);
const
framework
::
Tensor
*
tmp_norm_const
=
&
tmp_norm
;
LaunchReduceGradKernel
<
DeviceContext
,
T
,
PNormPostGradFunctor
<
T
>>
(
ctx
,
in_x
,
tmp_norm_const
,
in_norm_dy
,
out_dx
,
dims
,
reduce_all
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/reduce_ops/logsumexp_op.h
浏览文件 @
3825b40f
...
@@ -139,26 +139,27 @@ class LogsumexpGradKernel : public framework::OpKernel<T> {
...
@@ -139,26 +139,27 @@ class LogsumexpGradKernel : public framework::OpKernel<T> {
broadcast_dim
[
0
]);
broadcast_dim
[
0
]);
}
else
{
}
else
{
int
rank
=
input
->
dims
().
size
();
int
rank
=
input
->
dims
().
size
();
LogsumexpGradFunctor
functor
;
switch
(
rank
)
{
switch
(
rank
)
{
case
1
:
case
1
:
ReduceGradFunctor
<
DeviceContext
,
T
,
1
,
LogsumexpGradFunctor
>
(
ReduceGradFunctor
<
DeviceContext
,
T
,
1
,
LogsumexpGradFunctor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
*
output_grad
,
input_grad
,
axis
);
*
output_grad
,
input_grad
,
functor
,
axis
);
break
;
break
;
case
2
:
case
2
:
ReduceGradFunctor
<
DeviceContext
,
T
,
2
,
LogsumexpGradFunctor
>
(
ReduceGradFunctor
<
DeviceContext
,
T
,
2
,
LogsumexpGradFunctor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
*
output_grad
,
input_grad
,
axis
);
*
output_grad
,
input_grad
,
functor
,
axis
);
break
;
break
;
case
3
:
case
3
:
ReduceGradFunctor
<
DeviceContext
,
T
,
3
,
LogsumexpGradFunctor
>
(
ReduceGradFunctor
<
DeviceContext
,
T
,
3
,
LogsumexpGradFunctor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
*
output_grad
,
input_grad
,
axis
);
*
output_grad
,
input_grad
,
functor
,
axis
);
break
;
break
;
case
4
:
case
4
:
ReduceGradFunctor
<
DeviceContext
,
T
,
4
,
LogsumexpGradFunctor
>
(
ReduceGradFunctor
<
DeviceContext
,
T
,
4
,
LogsumexpGradFunctor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
*
output_grad
,
input_grad
,
axis
);
*
output_grad
,
input_grad
,
functor
,
axis
);
break
;
break
;
}
}
}
}
...
...
paddle/fluid/operators/reduce_ops/reduce_op.h
浏览文件 @
3825b40f
...
@@ -143,7 +143,7 @@ void HandleLargeDimGrad(const framework::ExecutionContext& context,
...
@@ -143,7 +143,7 @@ void HandleLargeDimGrad(const framework::ExecutionContext& context,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
const
framework
::
Tensor
*
dout
,
framework
::
Tensor
*
dx
,
const
std
::
vector
<
int
>&
dims
)
{
Functor
functor
,
const
std
::
vector
<
int
>&
dims
)
{
const
int64_t
unreduced
=
out
->
numel
();
const
int64_t
unreduced
=
out
->
numel
();
const
int64_t
reduced
=
x
->
numel
()
/
unreduced
;
const
int64_t
reduced
=
x
->
numel
()
/
unreduced
;
DDim
out_dim
(
out
->
dims
());
DDim
out_dim
(
out
->
dims
());
...
@@ -157,7 +157,7 @@ void HandleLargeDimGrad(const framework::ExecutionContext& context,
...
@@ -157,7 +157,7 @@ void HandleLargeDimGrad(const framework::ExecutionContext& context,
dx
->
Resize
({
unreduced
,
reduced
});
dx
->
Resize
({
unreduced
,
reduced
});
ReduceGradFunctor
<
DeviceContext
,
T
,
2
,
Functor
>
(
ReduceGradFunctor
<
DeviceContext
,
T
,
2
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
shuffled_x
,
*
out
,
*
dout
,
context
.
template
device_context
<
DeviceContext
>(),
shuffled_x
,
*
out
,
*
dout
,
dx
,
{
1
});
dx
,
functor
,
{
1
});
// transpose dX
// transpose dX
std
::
vector
<
int
>
origin_axis
(
x_dim
.
size
());
std
::
vector
<
int
>
origin_axis
(
x_dim
.
size
());
GetOriginDimFromShuffled
(
x_dim
,
dims
,
&
origin_axis
);
GetOriginDimFromShuffled
(
x_dim
,
dims
,
&
origin_axis
);
...
@@ -333,7 +333,7 @@ void LaunchReduceGradKernel(const framework::ExecutionContext& context,
...
@@ -333,7 +333,7 @@ void LaunchReduceGradKernel(const framework::ExecutionContext& context,
const
framework
::
Tensor
*
input0
,
const
framework
::
Tensor
*
input0
,
const
framework
::
Tensor
*
input1
,
const
framework
::
Tensor
*
input1
,
const
framework
::
Tensor
*
input2
,
const
framework
::
Tensor
*
input2
,
paddle
::
framework
::
Tensor
*
output
,
paddle
::
framework
::
Tensor
*
output
,
Functor
functor
,
const
std
::
vector
<
int
>&
dims
,
const
std
::
vector
<
int
>&
dims
,
bool
reduce_all
=
false
)
{
bool
reduce_all
=
false
)
{
if
(
reduce_all
)
{
if
(
reduce_all
)
{
...
@@ -345,7 +345,6 @@ void LaunchReduceGradKernel(const framework::ExecutionContext& context,
...
@@ -345,7 +345,6 @@ void LaunchReduceGradKernel(const framework::ExecutionContext& context,
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
broadcast_dim
=
auto
broadcast_dim
=
Eigen
::
array
<
int
,
1
>
({{
static_cast
<
int
>
(
input0
->
numel
())}});
Eigen
::
array
<
int
,
1
>
({{
static_cast
<
int
>
(
input0
->
numel
())}});
Functor
functor
;
functor
(
place
,
&
x
,
&
x_reduce
,
&
x_grad
,
&
x_reduce_grad
,
broadcast_dim
,
functor
(
place
,
&
x
,
&
x_reduce
,
&
x_grad
,
&
x_reduce_grad
,
broadcast_dim
,
broadcast_dim
[
0
]);
broadcast_dim
[
0
]);
}
else
{
}
else
{
...
@@ -354,36 +353,36 @@ void LaunchReduceGradKernel(const framework::ExecutionContext& context,
...
@@ -354,36 +353,36 @@ void LaunchReduceGradKernel(const framework::ExecutionContext& context,
case
1
:
case
1
:
ReduceGradFunctor
<
DeviceContext
,
T
,
1
,
Functor
>
(
ReduceGradFunctor
<
DeviceContext
,
T
,
1
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
*
input2
,
output
,
functor
,
dims
);
break
;
break
;
case
2
:
case
2
:
ReduceGradFunctor
<
DeviceContext
,
T
,
2
,
Functor
>
(
ReduceGradFunctor
<
DeviceContext
,
T
,
2
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
*
input2
,
output
,
functor
,
dims
);
break
;
break
;
case
3
:
case
3
:
ReduceGradFunctor
<
DeviceContext
,
T
,
3
,
Functor
>
(
ReduceGradFunctor
<
DeviceContext
,
T
,
3
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
*
input2
,
output
,
functor
,
dims
);
break
;
break
;
case
4
:
case
4
:
ReduceGradFunctor
<
DeviceContext
,
T
,
4
,
Functor
>
(
ReduceGradFunctor
<
DeviceContext
,
T
,
4
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
*
input2
,
output
,
functor
,
dims
);
break
;
break
;
case
5
:
case
5
:
ReduceGradFunctor
<
DeviceContext
,
T
,
5
,
Functor
>
(
ReduceGradFunctor
<
DeviceContext
,
T
,
5
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
*
input2
,
output
,
functor
,
dims
);
break
;
break
;
case
6
:
case
6
:
ReduceGradFunctor
<
DeviceContext
,
T
,
6
,
Functor
>
(
ReduceGradFunctor
<
DeviceContext
,
T
,
6
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
*
input2
,
output
,
functor
,
dims
);
break
;
break
;
default:
default:
HandleLargeDimGrad
<
DeviceContext
,
T
,
Functor
>
(
context
,
input0
,
input1
,
HandleLargeDimGrad
<
DeviceContext
,
T
,
Functor
>
(
input2
,
output
,
dims
);
context
,
input0
,
input1
,
input2
,
output
,
functor
,
dims
);
break
;
break
;
}
}
}
}
...
@@ -430,8 +429,10 @@ class ReduceGradKernel : public framework::OpKernel<T> {
...
@@ -430,8 +429,10 @@ class ReduceGradKernel : public framework::OpKernel<T> {
// NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and
// NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and
// not be set as Input in grad Maker, use Out_grad to replace here
// not be set as Input in grad Maker, use Out_grad to replace here
if
(
!
input1
)
input1
=
input2
;
if
(
!
input1
)
input1
=
input2
;
LaunchReduceGradKernel
<
DeviceContext
,
T
,
Functor
>
(
Functor
functor
;
context
,
input0
,
input1
,
input2
,
output
,
const_dims
,
reduce_all
);
LaunchReduceGradKernel
<
DeviceContext
,
T
,
Functor
>
(
context
,
input0
,
input1
,
input2
,
output
,
functor
,
const_dims
,
reduce_all
);
}
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
...
...
paddle/fluid/operators/reduce_ops/reduce_op_function.h
浏览文件 @
3825b40f
...
@@ -74,7 +74,7 @@ void ReduceGradFunctor(const DeviceContext& context,
...
@@ -74,7 +74,7 @@ void ReduceGradFunctor(const DeviceContext& context,
const
framework
::
Tensor
&
input0
,
const
framework
::
Tensor
&
input0
,
const
framework
::
Tensor
&
input1
,
const
framework
::
Tensor
&
input1
,
const
framework
::
Tensor
&
input2
,
const
framework
::
Tensor
&
input2
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
output
,
Functor
functor
,
const
std
::
vector
<
int
>&
dims
)
{
const
std
::
vector
<
int
>&
dims
)
{
auto
x
=
EigenTensor
<
T
,
D
>::
From
(
input0
);
auto
x
=
EigenTensor
<
T
,
D
>::
From
(
input0
);
auto
x_grad
=
EigenTensor
<
T
,
D
>::
From
(
*
output
);
auto
x_grad
=
EigenTensor
<
T
,
D
>::
From
(
*
output
);
...
@@ -100,7 +100,6 @@ void ReduceGradFunctor(const DeviceContext& context,
...
@@ -100,7 +100,6 @@ void ReduceGradFunctor(const DeviceContext& context,
auto
&
place
=
*
context
.
eigen_device
();
auto
&
place
=
*
context
.
eigen_device
();
Functor
functor
;
functor
(
place
,
&
x
,
&
x_reduce
,
&
x_grad
,
&
x_reduce_grad
,
broadcast_dim
,
functor
(
place
,
&
x
,
&
x_reduce
,
&
x_grad
,
&
x_reduce_grad
,
broadcast_dim
,
broad_cats_times
);
broad_cats_times
);
}
}
...
...
python/paddle/fluid/tests/unittests/test_norm_all.py
浏览文件 @
3825b40f
...
@@ -19,11 +19,12 @@ import numpy as np
...
@@ -19,11 +19,12 @@ import numpy as np
from
op_test
import
OpTest
from
op_test
import
OpTest
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
def
p_norm
(
x
,
axis
,
porder
,
keepdims
=
False
):
def
p_norm
(
x
,
axis
,
porder
,
keepdims
=
False
,
reduce_all
=
False
):
r
=
[]
r
=
[]
if
axis
is
None
:
if
axis
is
None
or
reduce_all
:
x
=
x
.
flatten
()
x
=
x
.
flatten
()
if
porder
==
np
.
inf
:
if
porder
==
np
.
inf
:
r
=
np
.
amax
(
np
.
abs
(
x
),
keepdims
=
keepdims
)
r
=
np
.
amax
(
np
.
abs
(
x
),
keepdims
=
keepdims
)
...
@@ -53,8 +54,8 @@ def p_norm(x, axis, porder, keepdims=False):
...
@@ -53,8 +54,8 @@ def p_norm(x, axis, porder, keepdims=False):
else
:
else
:
if
isinstance
(
axis
,
list
):
if
isinstance
(
axis
,
list
):
axis
=
tuple
(
axis
)
axis
=
tuple
(
axis
)
r
=
np
.
linalg
.
norm
(
r
=
np
.
linalg
.
norm
(
x
,
ord
=
porder
,
axis
=
axis
,
keepdims
=
keepdims
)
x
,
ord
=
porder
,
axis
=
axis
,
keepdims
=
keepdims
)
.
astype
(
x
.
dtype
)
r
=
r
.
astype
(
x
.
dtype
)
return
r
return
r
...
@@ -111,13 +112,14 @@ class TestPnormOp(OpTest):
...
@@ -111,13 +112,14 @@ class TestPnormOp(OpTest):
self
.
op_type
=
"p_norm"
self
.
op_type
=
"p_norm"
self
.
init_test_case
()
self
.
init_test_case
()
x
=
(
np
.
random
.
random
(
self
.
shape
)
+
0.5
).
astype
(
self
.
dtype
)
x
=
(
np
.
random
.
random
(
self
.
shape
)
+
0.5
).
astype
(
self
.
dtype
)
norm
=
p_norm
(
x
,
self
.
axis
,
self
.
porder
,
self
.
keepdim
)
norm
=
p_norm
(
x
,
self
.
axis
,
self
.
porder
,
self
.
keepdim
,
self
.
asvector
)
self
.
inputs
=
{
'X'
:
x
}
self
.
inputs
=
{
'X'
:
x
}
self
.
attrs
=
{
self
.
attrs
=
{
'epsilon'
:
self
.
epsilon
,
'epsilon'
:
self
.
epsilon
,
'axis'
:
self
.
axis
,
'axis'
:
self
.
axis
,
'keepdim'
:
self
.
keepdim
,
'keepdim'
:
self
.
keepdim
,
'porder'
:
float
(
self
.
porder
)
'porder'
:
float
(
self
.
porder
),
'asvector'
:
self
.
asvector
}
}
self
.
outputs
=
{
'Out'
:
norm
}
self
.
outputs
=
{
'Out'
:
norm
}
self
.
gradient
=
self
.
calc_gradient
()
self
.
gradient
=
self
.
calc_gradient
()
...
@@ -135,34 +137,42 @@ class TestPnormOp(OpTest):
...
@@ -135,34 +137,42 @@ class TestPnormOp(OpTest):
self
.
porder
=
2.0
self
.
porder
=
2.0
self
.
keepdim
=
False
self
.
keepdim
=
False
self
.
dtype
=
"float64"
self
.
dtype
=
"float64"
self
.
asvector
=
False
def
calc_gradient
(
self
):
def
calc_gradient
(
self
):
self
.
attrs
=
{
self
.
attrs
=
{
'epsilon'
:
self
.
epsilon
,
'epsilon'
:
self
.
epsilon
,
'axis'
:
self
.
axis
,
'axis'
:
self
.
axis
,
'keepdim'
:
self
.
keepdim
,
'keepdim'
:
self
.
keepdim
,
'porder'
:
float
(
self
.
porder
)
'porder'
:
float
(
self
.
porder
),
'asvector'
:
self
.
asvector
}
}
x
=
self
.
inputs
[
"X"
]
x
=
self
.
inputs
[
"X"
]
porder
=
self
.
attrs
[
"porder"
]
porder
=
self
.
attrs
[
"porder"
]
axis
=
self
.
attrs
[
"axis"
]
axis
=
self
.
attrs
[
"axis"
]
asvector
=
self
.
attrs
[
"asvector"
]
x_dtype
=
x
.
dtype
x
=
x
.
astype
(
np
.
float32
)
if
x
.
dtype
==
np
.
float16
else
x
if
porder
==
0
:
if
porder
==
0
:
grad
=
np
.
zeros
(
x
.
shape
).
astype
(
x
.
dtype
)
grad
=
np
.
zeros
(
x
.
shape
).
astype
(
x
.
dtype
)
elif
porder
in
[
float
(
"inf"
),
float
(
"-inf"
)]:
elif
porder
in
[
float
(
"inf"
),
float
(
"-inf"
)]:
norm
=
p_norm
(
x
,
axis
=
axis
,
porder
=
porder
,
keepdims
=
True
)
norm
=
p_norm
(
x
,
axis
=
axis
,
porder
=
porder
,
keepdims
=
True
,
reduce_all
=
asvector
)
x_abs
=
np
.
abs
(
x
)
x_abs
=
np
.
abs
(
x
)
grad
=
np
.
sign
(
x
)
grad
=
np
.
sign
(
x
)
grad
[
x_abs
!=
norm
]
=
0.0
grad
[
x_abs
!=
norm
]
=
0.0
else
:
else
:
norm
=
p_norm
(
x
,
axis
=
axis
,
porder
=
porder
,
keepdims
=
True
)
norm
=
p_norm
(
x
,
axis
=
axis
,
porder
=
porder
,
keepdims
=
True
,
reduce_all
=
asvector
)
grad
=
np
.
power
(
norm
,
1
-
porder
)
*
np
.
power
(
grad
=
np
.
power
(
norm
,
1
-
porder
)
*
np
.
power
(
np
.
abs
(
x
),
porder
-
1
)
*
np
.
sign
(
x
)
np
.
abs
(
x
),
porder
-
1
)
*
np
.
sign
(
x
)
numel
=
1
numel
=
1
for
s
in
x
.
shape
:
for
s
in
x
.
shape
:
numel
*=
s
numel
*=
s
numel
/=
x
.
shape
[
axis
]
divisor
=
numel
if
asvector
else
x
.
shape
[
axis
]
return
[
grad
.
astype
(
x
.
dtype
)
*
1
/
numel
]
numel
/=
divisor
return
[
grad
.
astype
(
x_dtype
)
*
1
/
numel
]
class
TestPnormOp2
(
TestPnormOp
):
class
TestPnormOp2
(
TestPnormOp
):
...
@@ -173,6 +183,7 @@ class TestPnormOp2(TestPnormOp):
...
@@ -173,6 +183,7 @@ class TestPnormOp2(TestPnormOp):
self
.
porder
=
2.0
self
.
porder
=
2.0
self
.
keepdim
=
True
self
.
keepdim
=
True
self
.
dtype
=
"float32"
self
.
dtype
=
"float32"
self
.
asvector
=
False
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
)
self
.
check_grad
([
'X'
],
'Out'
)
...
@@ -186,6 +197,7 @@ class TestPnormOp3(TestPnormOp):
...
@@ -186,6 +197,7 @@ class TestPnormOp3(TestPnormOp):
self
.
porder
=
np
.
inf
self
.
porder
=
np
.
inf
self
.
keepdim
=
True
self
.
keepdim
=
True
self
.
dtype
=
"float32"
self
.
dtype
=
"float32"
self
.
asvector
=
False
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
,
user_defined_grads
=
self
.
gradient
)
self
.
check_grad
([
'X'
],
'Out'
,
user_defined_grads
=
self
.
gradient
)
...
@@ -199,6 +211,7 @@ class TestPnormOp4(TestPnormOp):
...
@@ -199,6 +211,7 @@ class TestPnormOp4(TestPnormOp):
self
.
porder
=
-
np
.
inf
self
.
porder
=
-
np
.
inf
self
.
keepdim
=
True
self
.
keepdim
=
True
self
.
dtype
=
"float32"
self
.
dtype
=
"float32"
self
.
asvector
=
False
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
,
user_defined_grads
=
self
.
gradient
)
self
.
check_grad
([
'X'
],
'Out'
,
user_defined_grads
=
self
.
gradient
)
...
@@ -212,11 +225,63 @@ class TestPnormOp5(TestPnormOp):
...
@@ -212,11 +225,63 @@ class TestPnormOp5(TestPnormOp):
self
.
porder
=
0
self
.
porder
=
0
self
.
keepdim
=
True
self
.
keepdim
=
True
self
.
dtype
=
"float32"
self
.
dtype
=
"float32"
self
.
asvector
=
False
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
,
user_defined_grads
=
self
.
gradient
)
self
.
check_grad
([
'X'
],
'Out'
,
user_defined_grads
=
self
.
gradient
)
class
TestPnormOp6
(
TestPnormOp
):
def
init_test_case
(
self
):
self
.
shape
=
[
3
,
20
,
3
]
self
.
axis
=
-
1
self
.
epsilon
=
1e-12
self
.
porder
=
2
self
.
keepdim
=
False
self
.
dtype
=
"float32"
self
.
asvector
=
True
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
,
user_defined_grads
=
self
.
gradient
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestPnormOpFP16
(
TestPnormOp
):
def
init_test_case
(
self
):
self
.
shape
=
[
2
,
3
,
4
,
5
]
self
.
axis
=
1
self
.
epsilon
=
1e-12
self
.
porder
=
2.0
self
.
keepdim
=
False
self
.
dtype
=
"float16"
self
.
asvector
=
False
def
test_check_output
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
def
test_check_grad
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_grad_with_place
(
place
,
[
'X'
],
'Out'
,
user_defined_grads
=
self
.
gradient
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestPnormOpFP161
(
TestPnormOpFP16
):
def
init_test_case
(
self
):
self
.
shape
=
[
2
,
3
,
4
,
5
]
self
.
axis
=
-
1
self
.
epsilon
=
1e-12
self
.
porder
=
2.0
self
.
keepdim
=
False
self
.
dtype
=
"float16"
self
.
asvector
=
True
def
run_fro
(
self
,
p
,
axis
,
shape_x
,
dtype
,
keep_dim
,
check_dim
=
False
):
def
run_fro
(
self
,
p
,
axis
,
shape_x
,
dtype
,
keep_dim
,
check_dim
=
False
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
with
fluid
.
program_guard
(
fluid
.
Program
()):
data
=
fluid
.
data
(
name
=
"X"
,
shape
=
shape_x
,
dtype
=
dtype
)
data
=
fluid
.
data
(
name
=
"X"
,
shape
=
shape_x
,
dtype
=
dtype
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录