Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
e6c3f64f
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e6c3f64f
编写于
12月 25, 2021
作者:
C
Chen Weihang
提交者:
GitHub
12月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix renorm op include error and format error (#38451)
* remove needless header * remove needless header * adjust header order
上级
bbe879fc
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
15 additions
and
14 deletions
+15
-14
paddle/fluid/operators/renorm_op.cu
paddle/fluid/operators/renorm_op.cu
+15
-14
未找到文件。
paddle/fluid/operators/renorm_op.cu
浏览文件 @
e6c3f64f
...
...
@@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/renorm_op.h"
#include <algorithm>
#include <cstdio>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/renorm_op.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "stdio.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -60,7 +60,7 @@ __global__ void RenormKernelFunc3(int64_t size, T* dim_value, float p,
}
template
<
typename
T
>
__global__
void
RenormKernelFunc4
(
T
*
x_data
,
T
*
out_data
,
int64_t
size
,
__global__
void
RenormKernelFunc4
(
const
T
*
x_data
,
T
*
out_data
,
int64_t
size
,
T
*
dim_value
,
int64_t
dimension_each
,
int64_t
dim_divisor
)
{
int64_t
i
=
((
int64_t
)
blockIdx
.
x
)
*
blockDim
.
x
+
threadIdx
.
x
;
...
...
@@ -74,8 +74,8 @@ __global__ void RenormKernelFunc4(T* x_data, T* out_data, int64_t size,
}
template
<
typename
T
>
__global__
void
RenormGradKernelFunc1
(
T
*
x_data
,
T
*
dout_data
,
T
*
pow_value
,
T
*
mul_value
,
int64_t
size
,
__global__
void
RenormGradKernelFunc1
(
const
T
*
x_data
,
const
T
*
dout_data
,
T
*
pow_value
,
T
*
mul_value
,
int64_t
size
,
int64_t
dimension_each
,
float
p
,
int64_t
dim_divisor
)
{
int64_t
i
=
((
int64_t
)
blockIdx
.
x
)
*
blockDim
.
x
+
threadIdx
.
x
;
...
...
@@ -87,8 +87,8 @@ __global__ void RenormGradKernelFunc1(T* x_data, T* dout_data, T* pow_value,
}
template
<
typename
T
>
__global__
void
RenormGradKernelFunc2
(
T
*
x_data
,
T
*
dout_data
,
T
*
dx
_data
,
int64_t
size
,
T
*
dim_value
,
__global__
void
RenormGradKernelFunc2
(
const
T
*
x_data
,
const
T
*
dout
_data
,
T
*
dx_data
,
int64_t
size
,
T
*
dim_value
,
T
*
dim_power_sum
,
T
*
weight_derivative
,
int64_t
dimension_each
,
float
p
,
float
max_norm
,
int64_t
dim_divisor
)
{
...
...
@@ -100,8 +100,9 @@ __global__ void RenormGradKernelFunc2(T* x_data, T* dout_data, T* dx_data,
if
(
temp
>
max_norm
)
{
dim_power_sum
[
i
]
=
pow
(
dim_value
[
i
],
(
T
)(
-
1.0
-
1.0
/
p
))
*
-
1
*
max_norm
;
dim_value
[
i
]
=
max_norm
/
temp
;
}
else
}
else
{
dim_value
[
i
]
=
1.0
;
}
}
__syncthreads
();
if
(
i
<
size
)
{
...
...
@@ -120,7 +121,7 @@ class CUDARenormKernel : public framework::OpKernel<T> {
const
Tensor
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
Tensor
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
numel
=
x
->
numel
();
T
*
x_data
=
(
T
*
)
x
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
auto
input_dims
=
x
->
dims
();
float
max_norm
=
context
.
Attr
<
float
>
(
"max_norm"
);
float
p
=
context
.
Attr
<
float
>
(
"p"
);
...
...
@@ -176,8 +177,8 @@ class CUDAGradRenormKernel : public framework::OpKernel<T> {
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
numel
=
d_out
->
numel
();
T
*
dout_data
=
(
T
*
)
d_out
->
data
<
T
>
();
T
*
x_data
=
(
T
*
)
x
->
data
<
T
>
();
const
T
*
dout_data
=
d_out
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
auto
input_dims
=
x
->
dims
();
float
max_norm
=
ctx
.
Attr
<
float
>
(
"max_norm"
);
float
p
=
ctx
.
Attr
<
float
>
(
"p"
);
...
...
@@ -234,4 +235,4 @@ REGISTER_OP_CUDA_KERNEL(renorm, ops::CUDARenormKernel<float>,
ops
::
CUDARenormKernel
<
double
>
);
REGISTER_OP_CUDA_KERNEL
(
renorm_grad
,
ops
::
CUDAGradRenormKernel
<
float
>
,
ops
::
CUDAGradRenormKernel
<
double
>
);
\ No newline at end of file
ops
::
CUDAGradRenormKernel
<
double
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录