未验证 提交 348565b0 编写于 作者: H huangjiyi 提交者: GitHub

move pow2_decay_with_linear_warmup kernel to phi (#53741)

* update

* update
上级 4e416c99
......@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
......@@ -78,12 +76,7 @@ When step_num > total_steps, lr = end_lr
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOp,
ops::Pow2DecayWithLinearWarmupOpMaker);
REGISTER_OP_CPU_KERNEL(
pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOpKernel<phi::CPUContext, double>,
ops::Pow2DecayWithLinearWarmupOpKernel<phi::CPUContext, float>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
namespace paddle {
namespace operators {
template <typename T>
class Pow2DecayWithLinearWarmupXPUOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const {
const auto *lr = ctx.Input<phi::DenseTensor>("LearningRate");
const auto *step = ctx.Input<phi::DenseTensor>("Step");
auto *lr_out = ctx.Output<phi::DenseTensor>("LearningRateOut");
auto *step_out = ctx.Output<phi::DenseTensor>("StepOut");
PADDLE_ENFORCE_EQ(
lr,
lr_out,
platform::errors::InvalidArgument("Input(LearningRate) and "
"Output(LearningRateOut) "
"must be the same."));
PADDLE_ENFORCE_NOT_NULL(lr,
platform::errors::InvalidArgument(
"Input(LearingRate) should not be nullptr."));
PADDLE_ENFORCE_EQ(step,
step_out,
platform::errors::InvalidArgument(
"Input(Step) and Output(StepOut) must be the same."));
PADDLE_ENFORCE_NOT_NULL(step,
platform::errors::InvalidArgument(
"Input(Step) should not be nullptr."));
PADDLE_ENFORCE_EQ(
step->IsInitialized(),
true,
platform::errors::InvalidArgument("Input(Step) must be initialized."));
auto warmup_steps = static_cast<size_t>(ctx.Attr<int64_t>("warmup_steps"));
auto total_steps = static_cast<size_t>(ctx.Attr<int64_t>("total_steps"));
PADDLE_ENFORCE_LE(warmup_steps,
total_steps,
platform::errors::InvalidArgument(
"warmup_steps must not be larger than total_steps."));
auto base_lr = ctx.Attr<float>("base_lr");
auto end_lr = ctx.Attr<float>("end_lr");
auto *lr_data = lr_out->data<T>();
auto *step_data = step_out->data<int64_t>();
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
int r = xpu::pow2_decay_with_linear_warmup(dev_ctx.x_context(),
lr_data,
step_data,
warmup_steps,
total_steps,
base_lr,
end_lr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pow2_decay_with_linear_warmup");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupXPUOpKernel<float>);
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/pow2_decay_with_linear_warmup_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h"
REGISTER_OP_CUDA_KERNEL(
pow2_decay_with_linear_warmup,
ops::Pow2DecayWithLinearWarmupOpKernel<phi::GPUContext, double>,
ops::Pow2DecayWithLinearWarmupOpKernel<phi::GPUContext, float>);
PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
CPU,
ALL_LAYOUT,
phi::Pow2DecayWithLinearWarmupKernel,
float,
double) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pow2_decay_with_linear_warmup_kernel_impl.h"
PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
GPU,
ALL_LAYOUT,
phi::Pow2DecayWithLinearWarmupKernel,
float,
double) {}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -14,18 +14,16 @@
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace paddle {
namespace operators {
namespace phi {
template <typename T, typename AttrT>
struct Pow2DecayWithLinearWarmupFunctor {
template <typename U>
using RestrictPtr = U *PADDLE_RESTRICT;
using RestrictPtr = U* PADDLE_RESTRICT;
public:
HOSTDEVICE Pow2DecayWithLinearWarmupFunctor(RestrictPtr<T> lr,
......@@ -67,59 +65,46 @@ struct Pow2DecayWithLinearWarmupFunctor {
AttrT end_lr_;
};
template <typename DeviceContext, typename T>
class Pow2DecayWithLinearWarmupOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const {
const auto *lr = ctx.Input<phi::DenseTensor>("LearningRate");
const auto *step = ctx.Input<phi::DenseTensor>("Step");
auto *lr_out = ctx.Output<phi::DenseTensor>("LearningRateOut");
auto *step_out = ctx.Output<phi::DenseTensor>("StepOut");
PADDLE_ENFORCE_EQ(
lr,
lr_out,
platform::errors::InvalidArgument("Input(LearningRate) and "
"Output(LearningRateOut) "
"must be the same."));
PADDLE_ENFORCE_NOT_NULL(lr,
platform::errors::InvalidArgument(
"Input(LearingRate) should not be nullptr."));
PADDLE_ENFORCE_EQ(step,
step_out,
platform::errors::InvalidArgument(
"Input(Step) and Output(StepOut) must be the same."));
PADDLE_ENFORCE_NOT_NULL(step,
platform::errors::InvalidArgument(
"Input(Step) should not be nullptr."));
PADDLE_ENFORCE_EQ(
step->IsInitialized(),
true,
platform::errors::InvalidArgument("Input(Step) must be initialized."));
auto warmup_steps = static_cast<size_t>(ctx.Attr<int64_t>("warmup_steps"));
auto total_steps = static_cast<size_t>(ctx.Attr<int64_t>("total_steps"));
PADDLE_ENFORCE_LE(warmup_steps,
total_steps,
platform::errors::InvalidArgument(
"warmup_steps must not be larger than total_steps."));
auto base_lr = ctx.Attr<float>("base_lr");
auto end_lr = ctx.Attr<float>("end_lr");
template <typename T, typename Context>
void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
const DenseTensor& lr,
const DenseTensor& step,
int64_t warmup_steps,
int64_t total_steps,
float base_lr,
float end_lr,
DenseTensor* lr_out,
DenseTensor* step_out) {
PADDLE_ENFORCE_EQ(&lr,
lr_out,
phi::errors::InvalidArgument("Input(LearningRate) and "
"Output(LearningRateOut) "
"must be the same."));
PADDLE_ENFORCE_EQ(&step,
step_out,
phi::errors::InvalidArgument(
"Input(Step) and Output(StepOut) must be the same."));
PADDLE_ENFORCE_EQ(
step.IsInitialized(),
true,
phi::errors::InvalidArgument("Input(Step) must be initialized."));
auto *lr_data = lr_out->data<T>();
auto *step_data = step_out->data<int64_t>();
auto &dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, 1);
using AttrT = double;
Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
lr_data,
step_data,
warmup_steps,
total_steps,
static_cast<AttrT>(base_lr),
static_cast<AttrT>(end_lr));
for_range(functor);
}
};
PADDLE_ENFORCE_LE(warmup_steps,
total_steps,
phi::errors::InvalidArgument(
"warmup_steps must not be larger than total_steps."));
} // namespace operators
} // namespace paddle
auto* lr_data = lr_out->data<T>();
auto* step_data = step_out->data<int64_t>();
phi::funcs::ForRange<Context> for_range(dev_ctx, 1);
using AttrT = double;
Pow2DecayWithLinearWarmupFunctor<T, AttrT> functor(
lr_data,
step_data,
static_cast<size_t>(warmup_steps),
static_cast<size_t>(total_steps),
static_cast<AttrT>(base_lr),
static_cast<AttrT>(end_lr));
for_range(functor);
}
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
const DenseTensor& lr,
const DenseTensor& step,
int64_t warmup_steps,
int64_t total_steps,
float base_lr,
float end_lr,
DenseTensor* lr_out,
DenseTensor* step_out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pow2_decay_with_linear_warmup_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/macros.h"
namespace phi {
template <typename T, typename Context>
void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
const DenseTensor& lr,
const DenseTensor& step,
int64_t warmup_steps,
int64_t total_steps,
float base_lr,
float end_lr,
DenseTensor* lr_out,
DenseTensor* step_out) {
PADDLE_ENFORCE_EQ(&lr,
lr_out,
phi::errors::InvalidArgument("Input(LearningRate) and "
"Output(LearningRateOut) "
"must be the same."));
PADDLE_ENFORCE_EQ(&step,
step_out,
phi::errors::InvalidArgument(
"Input(Step) and Output(StepOut) must be the same."));
PADDLE_ENFORCE_EQ(
step.IsInitialized(),
true,
phi::errors::InvalidArgument("Input(Step) must be initialized."));
PADDLE_ENFORCE_LE(warmup_steps,
total_steps,
phi::errors::InvalidArgument(
"warmup_steps must not be larger than total_steps."));
auto* lr_data = lr_out->data<T>();
auto* step_data = step_out->data<int64_t>();
int r = xpu::pow2_decay_with_linear_warmup(dev_ctx.x_context(),
lr_data,
step_data,
static_cast<size_t>(warmup_steps),
static_cast<size_t>(total_steps),
base_lr,
end_lr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pow2_decay_with_linear_warmup");
}
} // namespace phi
PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
XPU,
ALL_LAYOUT,
phi::Pow2DecayWithLinearWarmupKernel,
float) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature Pow2DecayWithLinearWarmupOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("pow2_decay_with_linear_warmup",
{"LearningRate", "Step"},
{"warmup_steps", "total_steps", "base_lr", "end_lr"},
{"LearningRateOut", "StepOut"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(pow2_decay_with_linear_warmup,
phi::Pow2DecayWithLinearWarmupOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册