Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
10d9ab4b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
10d9ab4b
编写于
12月 13, 2021
作者:
N
Noel
提交者:
GitHub
12月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[pnorm] Optimize p_norm op for special cases (#37685)
上级
3a339cc0
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
231 addition
and
228 deletion
+231
-228
paddle/fluid/operators/p_norm_op.cu
paddle/fluid/operators/p_norm_op.cu
+166
-174
paddle/fluid/operators/reduce_ops/reduce_op.h
paddle/fluid/operators/reduce_ops/reduce_op.h
+65
-52
paddle/fluid/operators/unity_build_rule.cmake
paddle/fluid/operators/unity_build_rule.cmake
+0
-2
未找到文件。
paddle/fluid/operators/p_norm_op.cu
浏览文件 @
10d9ab4b
...
...
@@ -21,7 +21,10 @@ limitations under the License. */
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/p_norm_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
...
...
@@ -56,87 +59,94 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) {
return
pow
(
base
,
exponent
);
}
template
<
typename
T
,
int
BlockDim
>
__global__
void
Pnorm
(
const
T
*
x
,
const
int
pre
,
const
int
axis_n
,
// dim in axis
const
int
post
,
float
porder
,
T
*
out_norm
)
{
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
typedef
cub
::
BlockReduce
<
MT
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
int
num
=
pre
*
post
;
auto
porder_t
=
static_cast
<
MT
>
(
porder
);
auto
porder_inv
=
static_cast
<
MT
>
(
1.0
/
porder
);
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
int
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
MT
sum
=
static_cast
<
MT
>
(
0.0
);
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
const
MT
x_ij
=
static_cast
<
MT
>
(
x
[
base
+
j
*
post
]);
sum
+=
inline_pow
(
inline_abs
(
x_ij
),
porder_t
);
struct
IdentityFunctor
{
HOSTDEVICE
explicit
inline
IdentityFunctor
()
{}
HOSTDEVICE
explicit
inline
IdentityFunctor
(
int
n
)
{}
template
<
typename
T
>
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
static_cast
<
T
>
(
x
);
}
MT
reduce_result
=
BlockReduce
(
temp_storage
).
Sum
(
sum
);
if
(
threadIdx
.
x
==
0
)
out_norm
[
i
]
=
static_cast
<
T
>
(
inline_pow
(
reduce_result
,
porder_inv
));
};
struct
NonzeroFunctor
{
HOSTDEVICE
explicit
inline
NonzeroFunctor
()
{}
HOSTDEVICE
explicit
inline
NonzeroFunctor
(
int
n
)
{}
template
<
typename
T
>
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
static_cast
<
T
>
(
static_cast
<
double
>
(
x
)
!=
0
);
}
}
}
;
template
<
typename
T
,
int
BlockDim
>
__global__
void
ZeorNorm
(
const
T
*
x
,
const
int
pre
,
const
int
axis_n
,
// dim in axis
const
int
post
,
T
*
out_norm
)
{
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
typedef
cub
::
BlockReduce
<
MT
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
int
num
=
pre
*
post
;
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
int
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
MT
sum
=
static_cast
<
MT
>
(
0.0
);
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
const
MT
x_ij
=
static_cast
<
MT
>
(
x
[
base
+
j
*
post
]);
sum
+=
static_cast
<
MT
>
(
static_cast
<
double
>
(
x_ij
)
!=
0
);
struct
AbsFunctor
{
HOSTDEVICE
explicit
inline
AbsFunctor
()
{}
HOSTDEVICE
explicit
inline
AbsFunctor
(
int
n
)
{}
template
<
typename
T
>
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
static_cast
<
T
>
(
inline_abs
(
x
));
}
MT
reduce_result
=
BlockReduce
(
temp_storage
).
Sum
(
sum
);
if
(
threadIdx
.
x
==
0
)
out_norm
[
i
]
=
static_cast
<
T
>
(
reduce_result
);
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
UnsignedPowFunctor
{
HOSTDEVICE
explicit
inline
UnsignedPowFunctor
(
float
porder
)
{
this
->
porder
=
porder
;
}
}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
)
const
{
return
static_cast
<
Ty
>
(
inline_pow
(
inline_abs
(
x
),
static_cast
<
Tx
>
(
porder
)));
}
float
porder
;
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
PowFunctor
{
HOSTDEVICE
explicit
inline
PowFunctor
(
float
porder
)
{
this
->
porder
=
porder
;
}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
)
const
{
return
static_cast
<
Ty
>
(
inline_pow
(
x
,
static_cast
<
Tx
>
(
porder
)));
}
float
porder
;
};
template
<
typename
T
,
int
BlockDim
>
__global__
void
InfNorm
(
const
T
*
x
,
const
int
pre
,
const
int
axis_n
,
// dim in axis
const
int
post
,
T
*
out_norm
)
{
typedef
cub
::
BlockReduce
<
T
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
int
num
=
pre
*
post
;
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
int
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
T
cur_max
=
inline_abs
(
x
[
base
]);
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
T
x_ij_abs
=
inline_abs
(
x
[
base
+
j
*
post
]);
if
(
cur_max
<
x_ij_abs
)
cur_max
=
x_ij_abs
;
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
AbsAndMin
{
using
Transformer
=
AbsFunctor
;
using
MT
=
typename
details
::
MPTypeTrait
<
Ty
>::
Type
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
std
::
numeric_limits
<
MT
>::
infinity
());
}
T
reduce_result
=
BlockReduce
(
temp_storage
).
Reduce
(
cur_max
,
cub
::
Max
());
if
(
threadIdx
.
x
==
0
)
out_norm
[
i
]
=
reduce_result
;
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
(
a
<
b
)
?
a
:
b
;
}
}
}
;
template
<
typename
T
,
int
BlockDim
>
__global__
void
NegInfNorm
(
const
T
*
x
,
const
int
pre
,
const
int
axis_n
,
// dim in axis
const
int
post
,
T
*
out_norm
)
{
typedef
cub
::
BlockReduce
<
T
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
int
num
=
pre
*
post
;
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
int
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
T
cur_min
=
inline_abs
(
x
[
base
]);
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
T
x_ij_abs
=
inline_abs
(
x
[
base
+
j
*
post
]);
if
(
cur_min
>
x_ij_abs
)
cur_min
=
x_ij_abs
;
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
AbsAndMax
{
using
Transformer
=
AbsFunctor
;
using
MT
=
typename
details
::
MPTypeTrait
<
Ty
>::
Type
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
-
std
::
numeric_limits
<
MT
>::
infinity
());
}
T
reduce_result
=
BlockReduce
(
temp_storage
).
Reduce
(
cur_min
,
cub
::
Min
());
if
(
threadIdx
.
x
==
0
)
out_norm
[
i
]
=
reduce_result
;
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
(
a
>
b
)
?
a
:
b
;
}
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
NonzeroAndSum
{
using
Transformer
=
NonzeroFunctor
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
+
a
;
}
};
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
IdentityAndSum
{
using
Transformer
=
IdentityFunctor
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
__device__
__forceinline__
Ty
operator
()(
const
Ty
&
a
,
const
Ty
&
b
)
const
{
return
b
+
a
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
PnormCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
...
...
@@ -146,101 +156,83 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
auto
*
out_norm
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
const
T
*
x
=
in_x
->
data
<
T
>
();
T
*
norm
=
out_norm
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
xdim
=
in_x
->
dims
();
auto
ndim
=
out_norm
->
dims
();
float
porder
=
ctx
.
Attr
<
float
>
(
"porder"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
bool
asvector
=
ctx
.
Attr
<
bool
>
(
"asvector"
);
if
(
axis
<
0
)
axis
=
xdim
.
size
()
+
axis
;
int
pre
,
n
,
post
;
GetDims
(
xdim
,
axis
,
&
pre
,
&
n
,
&
post
,
asvector
);
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
std
::
vector
<
int
>
reduce_axis
=
{
axis
};
#ifdef __HIPCC__
const
int
block
=
256
;
#else
const
int
block
=
512
;
#endif
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
int
max_threads
=
dev_ctx
.
GetMaxPhysicalThreadCount
();
const
int
max_blocks
=
std
::
max
(
max_threads
/
block
,
1
);
int
grid
=
std
::
min
(
max_blocks
,
pre
*
post
);
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
if
(
porder
==
0
)
{
ZeorNorm
<
T
,
block
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
x
,
pre
,
n
,
post
,
nor
m
);
TensorReduceFunctorImpl
<
T
,
T
,
NonzeroAndSum
>
(
*
in_x
,
out_norm
,
reduce_axis
,
strea
m
);
}
else
if
(
porder
==
INFINITY
)
{
InfNorm
<
T
,
block
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
x
,
pre
,
n
,
post
,
nor
m
);
TensorReduceFunctorImpl
<
T
,
T
,
AbsAndMax
>
(
*
in_x
,
out_norm
,
reduce_axis
,
strea
m
);
}
else
if
(
porder
==
-
INFINITY
)
{
NegInfNorm
<
T
,
block
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
x
,
pre
,
n
,
post
,
nor
m
);
TensorReduceFunctorImpl
<
T
,
T
,
AbsAndMin
>
(
*
in_x
,
out_norm
,
reduce_axis
,
strea
m
);
}
else
{
Pnorm
<
T
,
block
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
x
,
pre
,
n
,
post
,
porder
,
norm
);
framework
::
Tensor
tmp_x
;
tmp_x
.
mutable_data
<
T
>
(
xdim
,
ctx
.
GetPlace
());
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
in_x
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
&
tmp_x
};
auto
func
=
UnsignedPowFunctor
<
MT
,
T
>
(
porder
);
const
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
MT
,
T
,
UnsignedPowFunctor
<
MT
,
T
>>
(
cuda_ctx
,
ins
,
&
outs
,
func
);
framework
::
Tensor
tmp_y
;
tmp_y
.
mutable_data
<
T
>
(
ndim
,
ctx
.
GetPlace
());
TensorReduceFunctorImpl
<
T
,
T
,
IdentityAndSum
>
(
tmp_x
,
&
tmp_y
,
reduce_axis
,
stream
);
const
framework
::
Tensor
*
tmp_norm
=
&
tmp_y
;
ins
=
{
tmp_norm
};
outs
=
{
out_norm
};
auto
func_inverse
=
UnsignedPowFunctor
<
MT
,
T
>
(
1.
/
porder
);
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
MT
,
T
,
UnsignedPowFunctor
<
MT
,
T
>>
(
cuda_ctx
,
ins
,
&
outs
,
func_inverse
);
}
}
};
template
<
typename
T
,
int
BlockDim
>
__global__
void
PnormGradient
(
const
T
*
x
,
const
T
*
x_norm
,
const
T
*
y_grad
,
const
float
porder
,
const
int
pre
,
const
int
axis_n
,
const
int
post
,
const
T
eps
,
T
*
x_grad
)
{
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
// dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
int
num
=
pre
*
post
;
auto
porder_grad
=
static_cast
<
MT
>
(
porder
-
1.0
f
);
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
__shared__
MT
pnorm_i
;
__shared__
MT
yout_i
;
auto
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
if
(
threadIdx
.
x
==
0
)
{
pnorm_i
=
static_cast
<
MT
>
(
x_norm
[
i
]);
yout_i
=
static_cast
<
MT
>
(
y_grad
[
i
]);
}
__syncthreads
();
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
int
index
=
base
+
j
*
post
;
const
MT
x_ij
=
static_cast
<
MT
>
(
inline_abs
(
x
[
index
]));
x_grad
[
index
]
=
static_cast
<
T
>
(
inline_pow
(
x_ij
,
porder_grad
)
/
(
inline_pow
(
pnorm_i
,
porder_grad
)
+
static_cast
<
MT
>
(
eps
))
*
yout_i
*
static_cast
<
MT
>
(
inline_sign
(
x
[
index
])));
}
}
}
template
<
typename
T
,
int
BlockDim
>
__global__
void
InfNormGradient
(
const
T
*
x
,
const
T
*
x_norm
,
const
T
*
y_grad
,
const
int
pre
,
const
int
axis_n
,
const
int
post
,
T
*
x_grad
)
{
int
num
=
pre
*
post
;
for
(
int
i
=
blockIdx
.
x
;
i
<
num
;
i
+=
gridDim
.
x
)
{
__shared__
T
pnorm_i
;
__shared__
T
yout_i
;
auto
base
=
(
i
/
post
)
*
post
*
axis_n
+
(
i
%
post
);
if
(
threadIdx
.
x
==
0
)
{
pnorm_i
=
x_norm
[
i
];
yout_i
=
y_grad
[
i
];
template
<
typename
T
>
struct
AbsMaxAndMinGradFunctor
{
template
<
typename
DeviceContext
,
typename
X
,
typename
Y
,
typename
DX
,
typename
DY
,
typename
Dim
>
void
operator
()(
const
DeviceContext
&
place
,
X
*
x
,
Y
*
y
,
DX
*
dx
,
DY
*
dy
,
const
Dim
&
dim
,
int
size
)
{
auto
equals
=
((
*
x
).
abs
()
==
y
->
broadcast
(
dim
));
auto
ones
=
dx
->
constant
(
static_cast
<
T
>
(
1.
));
auto
negs
=
dx
->
constant
(
static_cast
<
T
>
(
-
1.
));
auto
zeros
=
dx
->
constant
(
static_cast
<
T
>
(
0.
));
auto
positives
=
(
*
x
)
>
zeros
;
dx
->
device
(
place
)
=
dy
->
broadcast
(
dim
)
*
equals
.
select
(
ones
,
zeros
)
*
positives
.
select
(
ones
,
negs
);
}
__syncthreads
()
;
}
;
for
(
int
j
=
threadIdx
.
x
;
j
<
axis_n
;
j
+=
blockDim
.
x
)
{
int
index
=
base
+
j
*
post
;
const
T
x_ij
=
inline_abs
(
x
[
index
]);
if
(
x_ij
==
pnorm_i
)
{
x_grad
[
index
]
=
static_cast
<
T
>
(
inline_sign
(
x
[
index
]))
*
yout_i
;
}
else
{
x_grad
[
index
]
=
static_cast
<
T
>
(
0
);
}
}
template
<
typename
T
>
struct
PNormPostGradFunctor
{
template
<
typename
DeviceContext
,
typename
X
,
typename
Y
,
typename
DX
,
typename
DY
,
typename
Dim
>
void
operator
()(
const
DeviceContext
&
place
,
X
*
x
,
Y
*
y
,
DX
*
dx
,
DY
*
dy
,
const
Dim
&
dim
,
int
size
)
{
auto
ones
=
dx
->
constant
(
static_cast
<
T
>
(
1.
));
auto
negs
=
dx
->
constant
(
static_cast
<
T
>
(
-
1.
));
auto
zeros
=
dx
->
constant
(
static_cast
<
T
>
(
0.
));
auto
positives
=
(
*
x
)
>
zeros
;
dx
->
device
(
place
)
=
(
*
dx
)
*
dy
->
broadcast
(
dim
)
*
y
->
broadcast
(
dim
)
*
positives
.
select
(
ones
,
negs
);
}
}
}
;
template
<
typename
DeviceContext
,
typename
T
,
typename
AttrType
=
T
>
class
PnormGradCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
...
...
@@ -252,40 +244,40 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
out_dx
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
T
*
dx
=
out_dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
x
=
in_x
->
data
<
T
>
();
const
T
*
x_norm
=
in_norm
->
data
<
T
>
();
const
T
*
norm_dy
=
in_norm_dy
->
data
<
T
>
();
auto
xdim
=
in_x
->
dims
();
float
porder
=
ctx
.
Attr
<
float
>
(
"porder"
);
T
eps
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
bool
asvector
=
ctx
.
Attr
<
bool
>
(
"asvector"
);
bool
reduce_all
=
((
axis
<
0
)
||
(
in_norm
->
numel
()
==
1
)
);
if
(
axis
<
0
)
axis
=
xdim
.
size
()
+
axis
;
int
pre
,
n
,
post
;
GetDims
(
xdim
,
axis
,
&
pre
,
&
n
,
&
post
,
asvector
);
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
const
std
::
vector
<
int
>
dims
=
{
axis
};
#ifdef __HIPCC__
const
int
block
=
256
;
#else
const
int
block
=
512
;
#endif
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
int
max_threads
=
dev_ctx
.
GetMaxPhysicalThreadCount
();
const
int
max_blocks
=
std
::
max
(
max_threads
/
block
,
1
);
int
grid
=
std
::
min
(
max_blocks
,
pre
*
post
);
if
(
porder
==
0
)
{
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
set_zero
(
dev_ctx
,
out_dx
,
static_cast
<
T
>
(
0
));
set_zero
(
cuda_ctx
,
out_dx
,
static_cast
<
T
>
(
0
));
}
else
if
(
porder
==
INFINITY
||
porder
==
-
INFINITY
)
{
InfNormGradient
<
T
,
block
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>
>>
(
x
,
x_norm
,
norm_dy
,
pre
,
n
,
post
,
dx
);
LaunchReduceGradKernel
<
DeviceContext
,
T
,
AbsMaxAndMinGradFunctor
<
T
>>
(
ctx
,
in_x
,
in_norm
,
in_norm_dy
,
out_dx
,
dims
,
reduce_all
);
}
else
{
PnormGradient
<
T
,
block
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
x
,
x_norm
,
norm_dy
,
porder
,
pre
,
n
,
post
,
eps
,
dx
);
framework
::
Tensor
tmp_norm
;
tmp_norm
.
mutable_data
<
T
>
(
in_norm
->
dims
(),
ctx
.
GetPlace
());
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
in_norm
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
&
tmp_norm
};
auto
pow_functor
=
PowFunctor
<
T
>
(
1.
-
porder
);
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
,
T
,
PowFunctor
<
T
>>
(
cuda_ctx
,
ins
,
&
outs
,
pow_functor
);
ins
=
{
in_x
};
outs
=
{
out_dx
};
auto
unsigned_pow
=
UnsignedPowFunctor
<
T
>
(
porder
-
1.
);
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
,
T
,
UnsignedPowFunctor
<
T
>>
(
cuda_ctx
,
ins
,
&
outs
,
unsigned_pow
);
const
framework
::
Tensor
*
tmp_norm_const
=
&
tmp_norm
;
LaunchReduceGradKernel
<
DeviceContext
,
T
,
PNormPostGradFunctor
<
T
>>
(
ctx
,
in_x
,
tmp_norm_const
,
in_norm_dy
,
out_dx
,
dims
,
reduce_all
);
}
}
};
...
...
paddle/fluid/operators/reduce_ops/reduce_op.h
浏览文件 @
10d9ab4b
...
...
@@ -326,46 +326,14 @@ class BoolReduceKernel : public framework::OpKernel<OutT> {
}
};
template
<
typename
DeviceContext
,
typename
T
,
typename
Functor
,
bool
kNoNeedBufferX
=
false
,
bool
kNoNeedBufferY
=
false
>
class
ReduceGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
ComputeFromInput
(
const
Tensor
*
input2
,
const
framework
::
ExecutionContext
&
context
)
const
{
bool
reduce_all
=
context
.
Attr
<
bool
>
(
"reduce_all"
);
auto
dims
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dim"
);
auto
*
input0
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
input1
=
context
.
Input
<
Tensor
>
(
"Out"
);
auto
*
output
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// The dims has full dim, set the reduce_all is True
const
auto
&
input_dim_size
=
context
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
();
std
::
set
<
int
>
dims_set
(
dims
.
begin
(),
dims
.
end
());
bool
full_dim
=
true
;
for
(
auto
i
=
0
;
i
<
input_dim_size
;
i
++
)
{
if
(
dims_set
.
find
(
i
)
==
dims_set
.
end
())
{
full_dim
=
false
;
break
;
}
}
reduce_all
=
(
reduce_all
||
full_dim
);
// NOTE: EigenTensor::From() uses tensor->data()
// if op has NoNeedBufferVarsInferer, the corresponding kNoNeedBufferX or
// kNoNeedBufferY should set true
// and use fake var that has same dims.
if
(
kNoNeedBufferX
)
{
input0
=
output
;
}
if
(
kNoNeedBufferY
)
{
input1
=
input2
;
}
// NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and
// not be set as Input in grad Maker, use Out_grad to replace here
if
(
!
input1
)
input1
=
input2
;
template
<
typename
DeviceContext
,
typename
T
,
typename
Functor
>
void
LaunchReduceGradKernel
(
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Tensor
*
input0
,
const
framework
::
Tensor
*
input1
,
const
framework
::
Tensor
*
input2
,
paddle
::
framework
::
Tensor
*
output
,
const
std
::
vector
<
int
>&
dims
,
bool
reduce_all
=
false
)
{
if
(
reduce_all
)
{
auto
x
=
EigenVector
<
T
>::
Flatten
(
*
input0
);
auto
x_reduce
=
EigenVector
<
T
>::
Flatten
(
*
input1
);
...
...
@@ -383,33 +351,33 @@ class ReduceGradKernel : public framework::OpKernel<T> {
switch
(
rank
)
{
case
1
:
ReduceGradFunctor
<
DeviceContext
,
T
,
1
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
break
;
case
2
:
ReduceGradFunctor
<
DeviceContext
,
T
,
2
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
break
;
case
3
:
ReduceGradFunctor
<
DeviceContext
,
T
,
3
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
break
;
case
4
:
ReduceGradFunctor
<
DeviceContext
,
T
,
4
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
break
;
case
5
:
ReduceGradFunctor
<
DeviceContext
,
T
,
5
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
break
;
case
6
:
ReduceGradFunctor
<
DeviceContext
,
T
,
6
,
Functor
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
context
.
template
device_context
<
DeviceContext
>(),
*
input0
,
*
input1
,
*
input2
,
output
,
dims
);
break
;
default:
HandleLargeDimGrad
<
DeviceContext
,
T
,
Functor
>
(
context
,
input0
,
input1
,
...
...
@@ -417,6 +385,51 @@ class ReduceGradKernel : public framework::OpKernel<T> {
break
;
}
}
}
template
<
typename
DeviceContext
,
typename
T
,
typename
Functor
,
bool
kNoNeedBufferX
=
false
,
bool
kNoNeedBufferY
=
false
>
class
ReduceGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
ComputeFromInput
(
const
Tensor
*
input2
,
const
framework
::
ExecutionContext
&
context
)
const
{
bool
reduce_all
=
context
.
Attr
<
bool
>
(
"reduce_all"
);
auto
dims
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"dim"
);
auto
*
input0
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
input1
=
context
.
Input
<
Tensor
>
(
"Out"
);
auto
*
output
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// The dims has full dim, set the reduce_all is True
const
auto
&
input_dim_size
=
context
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
();
std
::
set
<
int
>
dims_set
(
dims
.
begin
(),
dims
.
end
());
bool
full_dim
=
true
;
for
(
auto
i
=
0
;
i
<
input_dim_size
;
i
++
)
{
if
(
dims_set
.
find
(
i
)
==
dims_set
.
end
())
{
full_dim
=
false
;
break
;
}
}
reduce_all
=
(
reduce_all
||
full_dim
);
// NOTE: EigenTensor::From() uses tensor->data()
// if op has NoNeedBufferVarsInferer, the corresponding kNoNeedBufferX or
// kNoNeedBufferY should set true
// and use fake var that has same dims.
if
(
kNoNeedBufferX
)
{
input0
=
output
;
}
if
(
kNoNeedBufferY
)
{
input1
=
input2
;
}
const
std
::
vector
<
int
>
const_dims
=
dims
;
// NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and
// not be set as Input in grad Maker, use Out_grad to replace here
if
(
!
input1
)
input1
=
input2
;
LaunchReduceGradKernel
<
DeviceContext
,
T
,
Functor
>
(
context
,
input0
,
input1
,
input2
,
output
,
const_dims
,
reduce_all
);
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
...
...
paddle/fluid/operators/unity_build_rule.cmake
浏览文件 @
10d9ab4b
...
...
@@ -186,7 +186,6 @@ register_unity_group(cc
norm_op.cc
one_hot_op.cc
one_hot_v2_op.cc
p_norm_op.cc
pad2d_op.cc
pad3d_op.cc
pad_constant_like_op.cc
...
...
@@ -468,7 +467,6 @@ register_unity_group(cu
nll_loss_op.cu
norm_op.cu
one_hot_op.cu
p_norm_op.cu
pad2d_op.cu
pad3d_op.cu
pad_constant_like_op.cu
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录