Commit 5ad020e2 (parent 2bb5aae8)
Authored Feb 28, 2022 by phlrain

move sgd to phi; test=develop
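All three kernels implement the plain SGD update. Written out (a summary of what the code below computes, not text from the commit): with scalar learning rate $\eta$, a dense gradient $g$ updates every element,

$$\theta_{\text{out}} = \theta - \eta\, g,$$

while a SelectedRows (sparse) gradient touches only the selected rows,

$$\theta_{r_i} \leftarrow \theta_{r_i} - \eta\, g_i \quad \text{for each selected row index } r_i.$$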
Showing 3 changed files with 404 additions and 0 deletions (+404 −0)
paddle/phi/kernels/cpu/sgd_kernel.cc  +185 −0
paddle/phi/kernels/gpu/sgd_kernel.cu  +167 −0
paddle/phi/kernels/sgd_kernel.h        +52 −0
paddle/phi/kernels/cpu/sgd_kernel.cc (new file, mode 100644)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sgd_kernel.h"

#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"

namespace phi {

template <typename T>
void sgd_dense_param_dense_grad_impl(const DenseTensor& param,
                                     const DenseTensor& learning_rate,
                                     const DenseTensor& grad,
                                     DenseTensor* param_out) {
  const auto sz = param_out->numel();
  paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
  const T* lr = learning_rate.data<T>();
  const T* param_data = param.data<T>();
  const T* grad_data = grad.data<T>();
  int64_t rows_idx = 0;
  T* out_data = param_out->data<T>();

  auto sgd = paddle::operators::jit::KernelFuncs<
      paddle::operators::jit::SgdTuple<T>,
      phi::CPUPlace>::Cache()
                 .At(attr);
  sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
}

template <>
void sgd_dense_param_dense_grad_impl<phi::dtype::bfloat16>(
    const DenseTensor& param,
    const DenseTensor& learning_rate,
    const DenseTensor& grad,
    DenseTensor* param_out) {
  auto p = EigenVector<phi::dtype::bfloat16>::Flatten(param);
  auto g = EigenVector<phi::dtype::bfloat16>::Flatten(grad);
  auto o = EigenVector<phi::dtype::bfloat16>::Flatten(*param_out);
  const auto* lr = learning_rate.data<phi::dtype::bfloat16>();

  o = p - lr[0] * g;
}

template <typename T>
void sgd_dense_param_sparse_grad_impl(const DenseTensor& param,
                                      const DenseTensor& learning_rate,
                                      const SelectedRows& grad,
                                      DenseTensor* param_out) {
  const auto& grad_value = grad.value();
  const auto& grad_rows = grad.rows();
  const T* param_data = param.data<T>();
  const T* grad_data = grad_value.data<T>();
  const T* lr = learning_rate.data<T>();
  const int64_t* rows_data = grad_rows.data();
  T* out_data = param_out->data<T>();

  paddle::operators::jit::sgd_attr_t attr;
  attr.param_height = param_out->dims()[0];
  attr.param_width = param_out->numel() / attr.param_height;
  attr.grad_height = grad_rows.size();  // note: it is not grad->height()
  attr.grad_width = grad_value.numel() / attr.grad_height;
  attr.selected_rows_size = grad_rows.size();

  auto sgd = paddle::operators::jit::KernelFuncs<
      paddle::operators::jit::SgdTuple<T>,
      phi::CPUPlace>::Cache()
                 .At(attr);
  sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
}

template <>
void sgd_dense_param_sparse_grad_impl<phi::dtype::bfloat16>(
    const DenseTensor& param,
    const DenseTensor& learning_rate,
    const SelectedRows& grad,
    DenseTensor* param_out) {
  const auto& grad_value = grad.value();
  const auto& grad_rows = grad.rows();
  const auto grad_height = grad.height();
  const int64_t grad_val_height = static_cast<int64_t>(grad_rows.size());
  const auto grad_width = grad_value.numel() / grad_val_height;

  const auto* grad_data = grad_value.data<phi::dtype::bfloat16>();
  auto* out_data = param_out->data<phi::dtype::bfloat16>();
  const auto* lr = learning_rate.data<phi::dtype::bfloat16>();

  for (size_t i = 0; i < grad_rows.size(); ++i) {
    PADDLE_ENFORCE_LT(
        grad_rows[i],
        grad_height,
        phi::errors::OutOfRange(
            "Grad rows index value should be less than grad height. "
            "Got [%s], but expected less than [%s]",
            grad_rows[i],
            grad_height));
    const int64_t row = grad_rows[i];
    for (int64_t j = 0; j < grad_width; ++j) {
      out_data[row * grad_width + j] -= lr[0] * grad_data[i * grad_width + j];
    }
  }
}

template <typename T, typename Context>
void SGDKernel(const Context& dev_ctx,
               const DenseTensor& param,
               const DenseTensor& learning_rate,
               const DenseTensor& grad,
               const DenseTensor& master_param,
               bool multi_precision,
               DenseTensor* param_out,
               DenseTensor* master_param_out) {
  dev_ctx.template Alloc<T>(param_out);
  sgd_dense_param_dense_grad_impl<T>(param, learning_rate, grad, param_out);
}

template <typename T, typename Context>
void SGDKernel(const Context& dev_ctx,
               const DenseTensor& param,
               const DenseTensor& learning_rate,
               const SelectedRows& grad,
               const DenseTensor& master_param,
               bool multi_precision,
               DenseTensor* param_out,
               DenseTensor* master_param_out) {
  dev_ctx.template Alloc<T>(param_out);
  sgd_dense_param_sparse_grad_impl<T>(param, learning_rate, grad, param_out);
}

template <typename T, typename Context>
void SGDKernel(const Context& dev_ctx,
               const SelectedRows& param,
               const DenseTensor& learning_rate,
               const SelectedRows& grad,
               const SelectedRows& master_param,
               bool multi_precision,
               SelectedRows* param_out,
               SelectedRows* master_param_out) {
  // for distributed training, a sparse var may be empty,
  // just skip updating.
  if (grad.rows().size() == 0) {
    return;
  }

  auto param_row_width = param.value().dims()[1];
  auto grad_row_width = grad.value().dims()[1];
  PADDLE_ENFORCE_EQ(
      param_row_width,
      grad_row_width,
      phi::errors::InvalidArgument(
          "The param_row in SgdOP should have the same size with grad_row. "
          "But received param_row's width is [%s], and grad_row's width is "
          "[%s]",
          param_row_width,
          grad_row_width));

  const auto* lr = learning_rate.data<T>();
  const auto* grad_data = grad.value().data<T>();
  auto* out_data = param_out->mutable_value()->data<T>();
  for (size_t i = 0; i < grad.rows().size(); i++) {
    int64_t id_index = param_out->AutoGrownIndex(grad.rows()[i], false);
    PADDLE_ENFORCE_GE(
        id_index,
        static_cast<int64_t>(0),
        phi::errors::InvalidArgument(
            "The id in SgdOp should be >= 0. But received id_index is [%s]",
            id_index));
    for (int64_t j = 0; j < grad_row_width; j++) {
      out_data[id_index * grad_row_width + j] -=
          lr[0] * grad_data[i * grad_row_width + j];
    }
  }
}

}  // namespace phi
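For reference, a minimal standalone sketch of the dense-param/dense-grad update that the JIT kernel above computes (a hypothetical illustration, independent of Paddle; only the C++ standard library is assumed):

#include <cstdio>
#include <vector>

// Hypothetical illustration: the same out = param - lr * grad update that
// sgd_dense_param_dense_grad_impl delegates to the JIT SGD kernel.
int main() {
  const float lr = 0.1f;
  std::vector<float> param = {1.0f, 2.0f, 3.0f};
  std::vector<float> grad = {10.0f, 20.0f, 30.0f};
  std::vector<float> out(param.size());
  for (std::size_t i = 0; i < param.size(); ++i) {
    out[i] = param[i] - lr * grad[i];
  }
  for (float v : out) std::printf("%g ", v);  // prints: 0 0 0
  std::printf("\n");
  return 0;
}

Note that the bfloat16 specializations above bypass the JIT path and do the arithmetic directly (via Eigen or explicit loops), presumably because the JIT SGD kernels are not generated for that type.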
paddle/phi/kernels/gpu/sgd_kernel.cu (new file, mode 100644)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sgd_kernel.h"

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_helper.h"

namespace phi {

template <typename T, typename MT>
__global__ void SGDKernelMT(const T* param,
                            const T* grad,
                            const T* learning_rate,
                            const int num,
                            T* param_out,
                            const MT* master_param,
                            MT* master_param_out) {
  MT lr = static_cast<MT>(learning_rate[0]);
  CUDA_KERNEL_LOOP(i, num) {
    MT p_data = master_param ? master_param[i] : static_cast<MT>(param[i]);
    MT g_data = static_cast<MT>(grad[i]);
    p_data = p_data - lr * g_data;
    param_out[i] = static_cast<T>(p_data);
    if (master_param_out) {
      master_param_out[i] = p_data;
    }
  }
}

template <typename T>
__global__ void SparseSGDFunctorKernel(const T* selected_rows,
                                       const int64_t* rows,
                                       const T* learning_rate,
                                       T* tensor_out,
                                       int64_t row_numel,
                                       int64_t limit) {
  for (int64_t i = blockIdx.x; i < limit; i += gridDim.x) {
    const T* selected_rows_ptr = selected_rows + i * row_numel;
    T* tensor_out_ptr = tensor_out + rows[i] * row_numel;
    for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) {
      // Since index in rows of SelectedRows can be duplicate, we have to use
      // Atomic Operation to avoid concurrent write error.
      paddle::platform::CudaAtomicAdd(
          tensor_out_ptr + index,
          -static_cast<T>(1.0) * learning_rate[0] * selected_rows_ptr[index]);
    }
  }
}

template <typename T, typename Context>
void SGDKernel(const Context& dev_ctx,
               const DenseTensor& param,
               const DenseTensor& learning_rate,
               const DenseTensor& grad,
               const DenseTensor& master_param,
               bool multi_precision,
               DenseTensor* param_out,
               DenseTensor* master_param_out) {
  using MPDType = typename paddle::operators::details::MPTypeTrait<T>::Type;
  // do check here
  // if (multi_precision) {
  //   bool has_master =
  //       ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
  // }
  const MPDType* master_in_data =
      multi_precision ? master_param.data<MPDType>() : nullptr;
  MPDType* master_out_data =
      multi_precision
          ? master_param_out->mutable_data<MPDType>(dev_ctx.GetPlace())
          : nullptr;

  int block = 512;
  int grid = (param.numel() + block - 1) / block;

  SGDKernelMT<T, MPDType><<<grid, block, 0, dev_ctx.stream()>>>(
      param.data<T>(),
      grad.data<T>(),
      learning_rate.data<T>(),
      param.numel(),
      param_out->mutable_data<T>(dev_ctx.GetPlace()),
      master_in_data,
      master_out_data);
}

template <typename T, typename Context>
void SGDKernel(const Context& dev_ctx,
               const DenseTensor& param,
               const DenseTensor& learning_rate,
               const SelectedRows& grad,
               const DenseTensor& master_param,
               bool multi_precision,
               DenseTensor* param_out,
               DenseTensor* master_param_out) {
  using MPDType = typename paddle::operators::details::MPTypeTrait<T>::Type;
  // do some check here
  // if (multi_precision) {
  //   bool has_master =
  //       ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
  // }
  const MPDType* master_in_data =
      multi_precision ? master_param.data<MPDType>() : nullptr;
  MPDType* master_out_data =
      multi_precision
          ? master_param_out->mutable_data<MPDType>(dev_ctx.GetPlace())
          : nullptr;

  PADDLE_ENFORCE_EQ(
      &param,
      param_out,
      phi::errors::InvalidArgument(
          "The input tensor Param of SgdOp should be equal with ParamOut "
          "if variable's type is SelectedRows."));

  auto in_height = grad.height();
  auto out_dims = param_out->dims();
  PADDLE_ENFORCE_EQ(in_height,
                    out_dims[0],
                    phi::errors::InvalidArgument(
                        "The input tensor Grad's height of SgdOp should be "
                        "equal with ParamOut's dims. But received Grad's "
                        "height [%s] and ParamOut's dims [%s]",
                        in_height,
                        out_dims[0]));

  auto& in_value = grad.value();
  auto& in_rows = grad.rows();

  int64_t in_row_numel = in_value.numel() / in_rows.size();
  PADDLE_ENFORCE_EQ(in_row_numel,
                    param_out->numel() / in_height,
                    phi::errors::InvalidArgument(
                        "The in_row_numel of SgdOp should be equal with "
                        "param_out's numel / in_height."));

  auto* in_data = in_value.data<T>();
  auto* out_data = param_out->data<T>();

  const int kThreadsPerBlock = 256;
  int thread_x = kThreadsPerBlock;
  int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
  paddle::framework::MixVector<int64_t> mixv_in_rows(&in_rows);
  SparseSGDFunctorKernel<<<max_blocks, thread_x, 0, dev_ctx.stream()>>>(
      in_data,
      mixv_in_rows.CUDAData(dev_ctx.GetPlace()),
      learning_rate.data<T>(),
      out_data,
      in_row_numel,
      in_rows.size());
}

}  // namespace phi
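Two details in the GPU file are worth noting. SGDKernelMT carries an optional higher-precision master copy of the parameters (type MT, obtained via MPTypeTrait): when multi_precision is set, the arithmetic is done on the master copy and the low-precision param_out is written back as a cast. SparseSGDFunctorKernel assigns selected rows to blocks (a grid-stride loop over rows) and row elements to threads, and must use CudaAtomicAdd because the same row index can occur more than once in a SelectedRows gradient. A standalone sketch of those duplicate-row semantics (hypothetical, CPU-side, independent of Paddle):

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical illustration: duplicate entries in `rows` must accumulate
// into the same output row, which is why the GPU kernel needs atomics.
int main() {
  const float lr = 0.5f;
  std::vector<int64_t> rows = {0, 2, 0};  // row 0 is selected twice
  std::vector<float> grad = {1.f, 1.f, 2.f, 2.f, 4.f, 4.f};       // 3 rows x 2 cols
  std::vector<float> out = {10.f, 10.f, 10.f, 10.f, 10.f, 10.f};  // 3 rows x 2 cols
  const int64_t row_numel = 2;
  for (std::size_t i = 0; i < rows.size(); ++i) {
    for (int64_t j = 0; j < row_numel; ++j) {
      out[rows[i] * row_numel + j] -= lr * grad[i * row_numel + j];
    }
  }
  // Row 0 receives both the first and the third gradient row:
  for (float v : out) std::printf("%g ", v);  // prints: 7.5 7.5 10 10 9 9
  std::printf("\n");
  return 0;
}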
paddle/phi/kernels/sgd_kernel.h (new file, mode 100644)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"

namespace phi {

template <typename T, typename Context>
void SGDKernel(const Context& dev_ctx,
               const DenseTensor& param,
               const DenseTensor& learning_rate,
               const DenseTensor& grad,
               const DenseTensor& master_param,
               bool multi_precision,
               DenseTensor* param_out,
               DenseTensor* master_param_out);

template <typename T, typename Context>
void SGDKernel(const Context& dev_ctx,
               const DenseTensor& param,
               const DenseTensor& learning_rate,
               const SelectedRows& grad,
               const DenseTensor& master_param,
               bool multi_precision,
               DenseTensor* param_out,
               DenseTensor* master_param_out);

template <typename T, typename Context>
void SGDKernel(const Context& dev_ctx,
               const SelectedRows& param,
               const DenseTensor& learning_rate,
               const SelectedRows& grad,
               const SelectedRows& master_param,
               bool multi_precision,
               SelectedRows* param_out,
               SelectedRows* master_param_out);

}  // namespace phi
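The header declares a single SGDKernel name with three overloads, dispatched on the tensor types of param and grad: dense param / dense grad, dense param / SelectedRows grad, and SelectedRows param / SelectedRows grad. The master_param and master_param_out arguments, together with the multi_precision flag, exist for mixed-precision training, where a higher-precision master copy of the parameters is kept alongside the low-precision working copy (see the MPTypeTrait usage in the GPU file).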