Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e4670d80
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e4670d80
编写于
11月 18, 2022
作者:
H
huangjiyi
提交者:
GitHub
11月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rm "paddle/fluid/operators/amp/fp16_type_traits.h" in phi (#48051)
上级
fafc7be2
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
15 addition
and
15 deletion
+15
-15
paddle/fluid/operators/group_norm_op.cu
paddle/fluid/operators/group_norm_op.cu
+1
-1
paddle/fluid/operators/uniform_random_op.h
paddle/fluid/operators/uniform_random_op.h
+1
-1
paddle/phi/kernels/funcs/functors.h
paddle/phi/kernels/funcs/functors.h
+4
-4
paddle/phi/kernels/gpu/norm_grad_kernel.cu
paddle/phi/kernels/gpu/norm_grad_kernel.cu
+2
-2
paddle/phi/kernels/gpu/norm_kernel.cu
paddle/phi/kernels/gpu/norm_kernel.cu
+2
-2
paddle/phi/kernels/gpu/sgd_kernel.cu
paddle/phi/kernels/gpu/sgd_kernel.cu
+3
-3
paddle/phi/kernels/primitive/functor_primitives.h
paddle/phi/kernels/primitive/functor_primitives.h
+2
-2
未找到文件。
paddle/fluid/operators/group_norm_op.cu
浏览文件 @
e4670d80
...
...
@@ -324,7 +324,7 @@ class GroupNormKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
dim3
grid
(
group_size
,
groups
,
x_dims
[
0
]);
dim3
threads
(
block_size
,
1
,
1
);
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
using
AccT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
using
AccT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
constexpr
int
vec_size
=
sizeof
(
float4
)
/
sizeof
(
T
);
int
size
=
group_size
*
imsize
;
const
int
max_num_threads
=
1024
;
...
...
paddle/fluid/operators/uniform_random_op.h
浏览文件 @
e4670d80
...
...
@@ -165,7 +165,7 @@ void UniformRandom(const framework::ExecutionContext& context,
if
(
seed
==
0
)
{
// Use global Generator seed
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
phi
::
funcs
::
uniform_distribution
<
MT
>
dist
;
phi
::
funcs
::
uniform_real_transform
<
MT
>
trans
(
min
,
max
);
phi
::
funcs
::
distribution_and_transform
<
T
>
(
dev_cxt
,
tensor
,
dist
,
trans
);
...
...
paddle/phi/kernels/funcs/functors.h
浏览文件 @
e4670d80
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/
fluid/operators/amp/fp16
_type_traits.h"
#include "paddle/
phi/common/amp
_type_traits.h"
#include "paddle/phi/kernels/funcs/math.h"
namespace
phi
{
...
...
@@ -38,7 +38,7 @@ struct AddGradFunctor {
template
<
typename
T
>
struct
ScaleFunctor
{
using
MT
=
typename
p
addle
::
operators
::
details
::
MPTypeTrait
<
T
>::
Type
;
using
MT
=
typename
p
hi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
explicit
ScaleFunctor
(
const
MT
coeff
)
:
coeff_
(
coeff
)
{}
inline
HOSTDEVICE
T
operator
()(
T
ele
)
{
...
...
@@ -125,7 +125,7 @@ struct SigmoidGradFunctor {
template
<
typename
T
>
struct
GeluFunctor
{
using
MT
=
typename
p
addle
::
operators
::
details
::
MPTypeTrait
<
T
>::
Type
;
using
MT
=
typename
p
hi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
inline
HOSTDEVICE
T
operator
()(
T
x
)
{
// this function is tanh approximation of gelu
// actual gelu is:
...
...
@@ -141,7 +141,7 @@ struct GeluFunctor {
template
<
typename
T
>
struct
GeluGradFunctor
{
using
MT
=
typename
p
addle
::
operators
::
details
::
MPTypeTrait
<
T
>::
Type
;
using
MT
=
typename
p
hi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
inline
HOSTDEVICE
T
UseX
(
T
x
)
{
MT
mx
=
static_cast
<
MT
>
(
x
);
MT
tanh_out
=
...
...
paddle/phi/kernels/gpu/norm_grad_kernel.cu
浏览文件 @
e4670d80
...
...
@@ -22,8 +22,8 @@
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
...
...
@@ -38,7 +38,7 @@ __global__ void NormalizeGradient(const T* x,
const
int
axis_n
,
const
int
post
,
T
*
x_grad
)
{
using
MT
=
typename
p
addle
::
operators
::
details
::
MPTypeTrait
<
T
>::
Type
;
using
MT
=
typename
p
hi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
typedef
cub
::
BlockReduce
<
MT
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage_sum
;
int
num
=
pre
*
post
;
...
...
paddle/phi/kernels/gpu/norm_kernel.cu
浏览文件 @
e4670d80
...
...
@@ -22,8 +22,8 @@
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
...
...
@@ -46,7 +46,7 @@ __global__ void Normalize(const T* x,
const
T
eps
,
T
*
y
,
T
*
out_norm
)
{
using
MT
=
typename
p
addle
::
operators
::
details
::
MPTypeTrait
<
T
>::
Type
;
using
MT
=
typename
p
hi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
typedef
cub
::
BlockReduce
<
MT
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
int
num
=
pre
*
post
;
...
...
paddle/phi/kernels/gpu/sgd_kernel.cu
浏览文件 @
e4670d80
...
...
@@ -15,10 +15,10 @@
#include "paddle/phi/kernels/sgd_kernel.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_helper.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
...
...
@@ -72,7 +72,7 @@ void SGDDenseKernel(const Context& dev_ctx,
bool
multi_precision
,
DenseTensor
*
param_out
,
DenseTensor
*
master_param_out
)
{
using
MPDType
=
typename
p
addle
::
operators
::
details
::
MPTypeTrait
<
T
>::
Type
;
using
MPDType
=
typename
p
hi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
// do check here
// if (multi_precision) {
// bool has_master =
...
...
@@ -109,7 +109,7 @@ void SGDDenseParamSparseGradKernel(
bool
multi_precision
,
DenseTensor
*
param_out
,
DenseTensor
*
master_param_out
)
{
using
MPDType
=
typename
p
addle
::
operators
::
details
::
MPTypeTrait
<
T
>::
Type
;
using
MPDType
=
typename
p
hi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
// do some check here
// if (multi_precision) {
// bool has_master =
...
...
paddle/phi/kernels/primitive/functor_primitives.h
浏览文件 @
e4670d80
...
...
@@ -14,7 +14,7 @@
#pragma once
#include "paddle/
fluid/operators/amp/fp16
_type_traits.h"
#include "paddle/
phi/common/amp
_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
...
...
@@ -79,7 +79,7 @@ struct IdentityFunctor {
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
DivideFunctor
{
private:
using
MPType
=
typename
::
p
addle
::
operators
::
details
::
MPTypeTrait
<
Tx
>::
Type
;
using
MPType
=
typename
::
p
hi
::
dtype
::
MPTypeTrait
<
Tx
>::
Type
;
public:
HOSTDEVICE
inline
DivideFunctor
()
{
n_inv
=
static_cast
<
MPType
>
(
1.0
f
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录