Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
348565b0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
348565b0
编写于
5月 12, 2023
作者:
H
huangjiyi
提交者:
GitHub
5月 12, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move pow2_decay_with_linear_warmup kernel to phi (#53741)
* update * update
上级
4e416c99
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
277 addition
and
100 deletion
+277
-100
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
.../operators/optimizers/pow2_decay_with_linear_warmup_op.cc
+0
-7
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op_xpu.cc
...rators/optimizers/pow2_decay_with_linear_warmup_op_xpu.cc
+0
-84
paddle/phi/kernels/cpu/pow2_decay_with_linear_warmup_kernel.cc
...e/phi/kernels/cpu/pow2_decay_with_linear_warmup_kernel.cc
+10
-9
paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu
...e/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu
+25
-0
paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h
.../kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h
+110
-0
paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h
paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h
+31
-0
paddle/phi/kernels/xpu/pow2_decay_with_linear_warmup_kernel.cc
...e/phi/kernels/xpu/pow2_decay_with_linear_warmup_kernel.cc
+71
-0
paddle/phi/ops/compat/pow2_decay_with_linear_warmup_sig.cc
paddle/phi/ops/compat/pow2_decay_with_linear_warmup_sig.cc
+30
-0
未找到文件。
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cc
浏览文件 @
348565b0
...
@@ -12,8 +12,6 @@
...
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
...
@@ -78,12 +76,7 @@ When step_num > total_steps, lr = end_lr
...
@@ -78,12 +76,7 @@ When step_num > total_steps, lr = end_lr
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_WITHOUT_GRADIENT
(
pow2_decay_with_linear_warmup
,
REGISTER_OP_WITHOUT_GRADIENT
(
pow2_decay_with_linear_warmup
,
ops
::
Pow2DecayWithLinearWarmupOp
,
ops
::
Pow2DecayWithLinearWarmupOp
,
ops
::
Pow2DecayWithLinearWarmupOpMaker
);
ops
::
Pow2DecayWithLinearWarmupOpMaker
);
REGISTER_OP_CPU_KERNEL
(
pow2_decay_with_linear_warmup
,
ops
::
Pow2DecayWithLinearWarmupOpKernel
<
phi
::
CPUContext
,
double
>
,
ops
::
Pow2DecayWithLinearWarmupOpKernel
<
phi
::
CPUContext
,
float
>
);
paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op_xpu.cc
已删除
100644 → 0
浏览文件 @
4e416c99
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
Pow2DecayWithLinearWarmupXPUOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
auto
*
lr
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"LearningRate"
);
const
auto
*
step
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Step"
);
auto
*
lr_out
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"LearningRateOut"
);
auto
*
step_out
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"StepOut"
);
PADDLE_ENFORCE_EQ
(
lr
,
lr_out
,
platform
::
errors
::
InvalidArgument
(
"Input(LearningRate) and "
"Output(LearningRateOut) "
"must be the same."
));
PADDLE_ENFORCE_NOT_NULL
(
lr
,
platform
::
errors
::
InvalidArgument
(
"Input(LearingRate) should not be nullptr."
));
PADDLE_ENFORCE_EQ
(
step
,
step_out
,
platform
::
errors
::
InvalidArgument
(
"Input(Step) and Output(StepOut) must be the same."
));
PADDLE_ENFORCE_NOT_NULL
(
step
,
platform
::
errors
::
InvalidArgument
(
"Input(Step) should not be nullptr."
));
PADDLE_ENFORCE_EQ
(
step
->
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Step) must be initialized."
));
auto
warmup_steps
=
static_cast
<
size_t
>
(
ctx
.
Attr
<
int64_t
>
(
"warmup_steps"
));
auto
total_steps
=
static_cast
<
size_t
>
(
ctx
.
Attr
<
int64_t
>
(
"total_steps"
));
PADDLE_ENFORCE_LE
(
warmup_steps
,
total_steps
,
platform
::
errors
::
InvalidArgument
(
"warmup_steps must not be larger than total_steps."
));
auto
base_lr
=
ctx
.
Attr
<
float
>
(
"base_lr"
);
auto
end_lr
=
ctx
.
Attr
<
float
>
(
"end_lr"
);
auto
*
lr_data
=
lr_out
->
data
<
T
>
();
auto
*
step_data
=
step_out
->
data
<
int64_t
>
();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
XPUDeviceContext
>();
int
r
=
xpu
::
pow2_decay_with_linear_warmup
(
dev_ctx
.
x_context
(),
lr_data
,
step_data
,
warmup_steps
,
total_steps
,
base_lr
,
end_lr
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"pow2_decay_with_linear_warmup"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_XPU_KERNEL
(
pow2_decay_with_linear_warmup
,
ops
::
Pow2DecayWithLinearWarmupXPUOpKernel
<
float
>
);
#endif
paddle/
fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.cu
→
paddle/
phi/kernels/cpu/pow2_decay_with_linear_warmup_kernel.cc
浏览文件 @
348565b0
// Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 202
3
PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// you may not use this file except in compliance with the License.
...
@@ -12,13 +12,14 @@
...
@@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
#include "paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
ops
=
paddle
::
operators
;
#include "paddle/phi/core/kernel_registry.h"
namespace
plat
=
paddle
::
platform
;
#include "paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h"
REGISTER_OP_CUDA_KERNEL
(
PD_REGISTER_KERNEL
(
pow2_decay_with_linear_warmup
,
pow2_decay_with_linear_warmup
,
CPU
,
ops
::
Pow2DecayWithLinearWarmupOpKernel
<
phi
::
GPUContext
,
double
>
,
ALL_LAYOUT
,
ops
::
Pow2DecayWithLinearWarmupOpKernel
<
phi
::
GPUContext
,
float
>
);
phi
::
Pow2DecayWithLinearWarmupKernel
,
float
,
double
)
{}
paddle/phi/kernels/gpu/pow2_decay_with_linear_warmup_kernel.cu
0 → 100644
浏览文件 @
348565b0
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h"
PD_REGISTER_KERNEL
(
pow2_decay_with_linear_warmup
,
GPU
,
ALL_LAYOUT
,
phi
::
Pow2DecayWithLinearWarmupKernel
,
float
,
double
)
{}
paddle/
fluid/operators/optimizers/pow2_decay_with_linear_warmup_op
.h
→
paddle/
phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl
.h
浏览文件 @
348565b0
// Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 202
3
PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// you may not use this file except in compliance with the License.
...
@@ -14,18 +14,16 @@
...
@@ -14,18 +14,16 @@
#pragma once
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace
paddle
{
namespace
phi
{
namespace
operators
{
template
<
typename
T
,
typename
AttrT
>
template
<
typename
T
,
typename
AttrT
>
struct
Pow2DecayWithLinearWarmupFunctor
{
struct
Pow2DecayWithLinearWarmupFunctor
{
template
<
typename
U
>
template
<
typename
U
>
using
RestrictPtr
=
U
*
PADDLE_RESTRICT
;
using
RestrictPtr
=
U
*
PADDLE_RESTRICT
;
public:
public:
HOSTDEVICE
Pow2DecayWithLinearWarmupFunctor
(
RestrictPtr
<
T
>
lr
,
HOSTDEVICE
Pow2DecayWithLinearWarmupFunctor
(
RestrictPtr
<
T
>
lr
,
...
@@ -67,59 +65,46 @@ struct Pow2DecayWithLinearWarmupFunctor {
...
@@ -67,59 +65,46 @@ struct Pow2DecayWithLinearWarmupFunctor {
AttrT
end_lr_
;
AttrT
end_lr_
;
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
T
,
typename
Context
>
class
Pow2DecayWithLinearWarmupOpKernel
:
public
framework
::
OpKernel
<
T
>
{
void
Pow2DecayWithLinearWarmupKernel
(
const
Context
&
dev_ctx
,
public:
const
DenseTensor
&
lr
,
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
DenseTensor
&
step
,
const
auto
*
lr
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"LearningRate"
);
int64_t
warmup_steps
,
const
auto
*
step
=
ctx
.
Input
<
phi
::
DenseTensor
>
(
"Step"
);
int64_t
total_steps
,
auto
*
lr_out
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"LearningRateOut"
);
float
base_lr
,
auto
*
step_out
=
ctx
.
Output
<
phi
::
DenseTensor
>
(
"StepOut"
);
float
end_lr
,
PADDLE_ENFORCE_EQ
(
DenseTensor
*
lr_out
,
lr
,
DenseTensor
*
step_out
)
{
lr_out
,
PADDLE_ENFORCE_EQ
(
&
lr
,
platform
::
errors
::
InvalidArgument
(
"Input(LearningRate) and "
lr_out
,
"Output(LearningRateOut) "
phi
::
errors
::
InvalidArgument
(
"Input(LearningRate) and "
"must be the same."
));
"Output(LearningRateOut) "
PADDLE_ENFORCE_NOT_NULL
(
lr
,
"must be the same."
));
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_EQ
(
&
step
,
"Input(LearingRate) should not be nullptr."
));
step_out
,
PADDLE_ENFORCE_EQ
(
step
,
phi
::
errors
::
InvalidArgument
(
step_out
,
"Input(Step) and Output(StepOut) must be the same."
));
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_EQ
(
"Input(Step) and Output(StepOut) must be the same."
));
step
.
IsInitialized
(),
PADDLE_ENFORCE_NOT_NULL
(
step
,
true
,
platform
::
errors
::
InvalidArgument
(
phi
::
errors
::
InvalidArgument
(
"Input(Step) must be initialized."
));
"Input(Step) should not be nullptr."
));
PADDLE_ENFORCE_EQ
(
step
->
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Step) must be initialized."
));
auto
warmup_steps
=
static_cast
<
size_t
>
(
ctx
.
Attr
<
int64_t
>
(
"warmup_steps"
));
auto
total_steps
=
static_cast
<
size_t
>
(
ctx
.
Attr
<
int64_t
>
(
"total_steps"
));
PADDLE_ENFORCE_LE
(
warmup_steps
,
total_steps
,
platform
::
errors
::
InvalidArgument
(
"warmup_steps must not be larger than total_steps."
));
auto
base_lr
=
ctx
.
Attr
<
float
>
(
"base_lr"
);
auto
end_lr
=
ctx
.
Attr
<
float
>
(
"end_lr"
);
auto
*
lr_data
=
lr_out
->
data
<
T
>
();
PADDLE_ENFORCE_LE
(
warmup_steps
,
auto
*
step_data
=
step_out
->
data
<
int64_t
>
();
total_steps
,
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
phi
::
errors
::
InvalidArgument
(
platform
::
ForRange
<
DeviceContext
>
for_range
(
dev_ctx
,
1
);
"warmup_steps must not be larger than total_steps."
));
using
AttrT
=
double
;
Pow2DecayWithLinearWarmupFunctor
<
T
,
AttrT
>
functor
(
lr_data
,
step_data
,
warmup_steps
,
total_steps
,
static_cast
<
AttrT
>
(
base_lr
),
static_cast
<
AttrT
>
(
end_lr
));
for_range
(
functor
);
}
};
}
// namespace operators
auto
*
lr_data
=
lr_out
->
data
<
T
>
();
}
// namespace paddle
auto
*
step_data
=
step_out
->
data
<
int64_t
>
();
phi
::
funcs
::
ForRange
<
Context
>
for_range
(
dev_ctx
,
1
);
using
AttrT
=
double
;
Pow2DecayWithLinearWarmupFunctor
<
T
,
AttrT
>
functor
(
lr_data
,
step_data
,
static_cast
<
size_t
>
(
warmup_steps
),
static_cast
<
size_t
>
(
total_steps
),
static_cast
<
AttrT
>
(
base_lr
),
static_cast
<
AttrT
>
(
end_lr
));
for_range
(
functor
);
}
}
// namespace phi
paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h
0 → 100644
浏览文件 @
348565b0
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
Pow2DecayWithLinearWarmupKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
lr
,
const
DenseTensor
&
step
,
int64_t
warmup_steps
,
int64_t
total_steps
,
float
base_lr
,
float
end_lr
,
DenseTensor
*
lr_out
,
DenseTensor
*
step_out
);
}
// namespace phi
paddle/phi/kernels/xpu/pow2_decay_with_linear_warmup_kernel.cc
0 → 100644
浏览文件 @
348565b0
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/macros.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
Pow2DecayWithLinearWarmupKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
lr
,
const
DenseTensor
&
step
,
int64_t
warmup_steps
,
int64_t
total_steps
,
float
base_lr
,
float
end_lr
,
DenseTensor
*
lr_out
,
DenseTensor
*
step_out
)
{
PADDLE_ENFORCE_EQ
(
&
lr
,
lr_out
,
phi
::
errors
::
InvalidArgument
(
"Input(LearningRate) and "
"Output(LearningRateOut) "
"must be the same."
));
PADDLE_ENFORCE_EQ
(
&
step
,
step_out
,
phi
::
errors
::
InvalidArgument
(
"Input(Step) and Output(StepOut) must be the same."
));
PADDLE_ENFORCE_EQ
(
step
.
IsInitialized
(),
true
,
phi
::
errors
::
InvalidArgument
(
"Input(Step) must be initialized."
));
PADDLE_ENFORCE_LE
(
warmup_steps
,
total_steps
,
phi
::
errors
::
InvalidArgument
(
"warmup_steps must not be larger than total_steps."
));
auto
*
lr_data
=
lr_out
->
data
<
T
>
();
auto
*
step_data
=
step_out
->
data
<
int64_t
>
();
int
r
=
xpu
::
pow2_decay_with_linear_warmup
(
dev_ctx
.
x_context
(),
lr_data
,
step_data
,
static_cast
<
size_t
>
(
warmup_steps
),
static_cast
<
size_t
>
(
total_steps
),
base_lr
,
end_lr
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"pow2_decay_with_linear_warmup"
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
pow2_decay_with_linear_warmup
,
XPU
,
ALL_LAYOUT
,
phi
::
Pow2DecayWithLinearWarmupKernel
,
float
)
{}
paddle/phi/ops/compat/pow2_decay_with_linear_warmup_sig.cc
0 → 100644
浏览文件 @
348565b0
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
Pow2DecayWithLinearWarmupOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"pow2_decay_with_linear_warmup"
,
{
"LearningRate"
,
"Step"
},
{
"warmup_steps"
,
"total_steps"
,
"base_lr"
,
"end_lr"
},
{
"LearningRateOut"
,
"StepOut"
});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
pow2_decay_with_linear_warmup
,
phi
::
Pow2DecayWithLinearWarmupOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录