Unverified · Commit f5ec0314 · authored by Aurelius84, committed by GitHub

[Phi]Migrate Adamax and Adadelta Optimizer Op into Phi (#40173)

* [Phi]Migrate Adamax into phi

* Add adadelta kernel
Parent da3de72d
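In outline, each optimizer follows the same two-step migration (summarized below for Adadelta; the Adamax changes mirror it exactly): the fluid op's InferShape logic moves into a phi InferMeta function that is attached to the operator through an infershape functor, and the per-device kernels move into a single templated phi kernel registered per backend. The fragments below are taken from the diff itself and only restate that pattern.

    // Shape/dtype inference: AdadeltaOp::InferShape -> phi::AdadeltaInferMeta,
    // wired back to the fluid operator via an infershape functor.
    DELCARE_INFER_SHAPE_FUNCTOR(adadelta, AdadeltaInferMetaFunctor,
                                PT_INFER_META(phi::AdadeltaInferMeta));
    REGISTER_OPERATOR(
        adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker,
        paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
        paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
        AdadeltaInferMetaFunctor);

    // Device kernels: ops::AdadeltaOpKernel (REGISTER_OP_CPU_KERNEL /
    // REGISTER_OP_CUDA_KERNEL) -> phi::AdadeltaKernel, shared between CPU and
    // GPU through adadelta_kernel_impl.h and registered per backend.
    PD_REGISTER_KERNEL(
        adadelta, CPU, ALL_LAYOUT, phi::AdadeltaKernel, float, double) {}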
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/adadelta_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
@@ -23,77 +26,6 @@ class AdadeltaOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
platform::errors::InvalidArgument(
"Input(Param) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
platform::errors::InvalidArgument(
"Input(Grad) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("AvgSquaredGrad"), true,
platform::errors::InvalidArgument(
"Input(AvgSquaredGrad) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("AvgSquaredUpdate"), true,
platform::errors::InvalidArgument(
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
true,
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(),
ctx->GetInputsVarType("Param").front()));
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
true,
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(),
ctx->GetInputsVarType("Grad").front()));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("ParamOut"), true,
platform::errors::InvalidArgument(
"Output(ParamOut) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("AvgSquaredGradOut"), true,
platform::errors::InvalidArgument(
"Output(AvgSquaredGradOut) of AdadeltaOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("AvgSquaredUpdateOut"), true,
platform::errors::InvalidArgument(
"Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null."));
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument(
"Param and grad input of AdadeltaOp should have same dimension."));
PADDLE_ENFORCE_NE(
phi::product(ctx->GetInputDim("AvgSquaredGrad")), 0,
platform::errors::InvalidArgument(
"Maybe the Input variable AvgSquaredGrad has not "
"been initialized. You may need to confirm if you put "
"exe.run(startup_program) after optimizer.minimize "
"function."));
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredGrad"),
platform::errors::InvalidArgument(
"Param and AvgSquaredGrad input of AdadeltaOp "
"should have same dimension"));
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredUpdate"),
platform::errors::InvalidArgument(
"Param and AvgSquaredUpdate input of AdadeltaOp "
"should have same dimension"));
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
@@ -149,7 +81,11 @@ $$
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker);
REGISTER_OP_CPU_KERNEL(
adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, double>);
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(adadelta, AdadeltaInferMetaFunctor,
PT_INFER_META(phi::AdadeltaInferMeta));
REGISTER_OPERATOR(
adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
AdadeltaInferMetaFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/adadelta_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
adadelta, ops::AdadeltaOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::AdadeltaOpKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class AdadeltaOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type())));
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto avg_squared_grad_out_tensor =
ctx.Output<framework::Tensor>("AvgSquaredGradOut");
auto avg_squared_update_out_tensor =
ctx.Output<framework::Tensor>("AvgSquaredUpdateOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
avg_squared_grad_out_tensor->mutable_data<T>(ctx.GetPlace());
avg_squared_update_out_tensor->mutable_data<T>(ctx.GetPlace());
T rho = static_cast<T>(ctx.Attr<float>("rho"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
auto grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Grad"));
// Squared gradient accumulator
auto avg_squared_grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("AvgSquaredGrad"));
// Squared updates accumulator
auto avg_squared_update = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("AvgSquaredUpdate"));
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto avg_squared_grad_out =
framework::EigenVector<T>::Flatten(*avg_squared_grad_out_tensor);
auto avg_squared_update_out =
framework::EigenVector<T>::Flatten(*avg_squared_update_out_tensor);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
avg_squared_grad_out.device(place) =
rho * avg_squared_grad + (1 - rho) * grad.square();
auto update =
-((avg_squared_update + epsilon) / (avg_squared_grad_out + epsilon))
.sqrt() *
grad;
avg_squared_update_out.device(place) =
rho * avg_squared_update + (1 - rho) * update.square();
param_out.device(place) = param + update;
}
};
} // namespace operators
} // namespace paddle
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/adamax_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
@@ -22,67 +25,6 @@ class AdamaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "Adamax");
OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "Adamax");
OP_INOUT_CHECK(ctx->HasInput("Moment"), "Input", "Moment", "Adamax");
OP_INOUT_CHECK(ctx->HasInput("InfNorm"), "Input", "InfNorm", "Adamax");
OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
"Adamax");
OP_INOUT_CHECK(ctx->HasInput("Beta1Pow"), "Input", "Beta1Pow", "Adamax");
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(),
ctx->GetInputsVarType("Param").front()));
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Grad").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(),
ctx->GetInputsVarType("Grad").front()));
OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut", "Adamax");
OP_INOUT_CHECK(ctx->HasOutput("MomentOut"), "Output", "MomentOut",
"Adamax");
OP_INOUT_CHECK(ctx->HasOutput("InfNormOut"), "Output", "InfNormOut",
"Adamax");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(phi::product(lr_dims), 0,
platform::errors::InvalidArgument(
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1,
platform::errors::InvalidArgument(
"Learning rate should have 1 dimension"));
auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
PADDLE_ENFORCE_EQ(phi::product(beta1_pow_dims), 1,
platform::errors::InvalidArgument(
"Beta1 power accumulator should have 1 dimension"));
auto param_dims = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument(
"Param and Grad input of AdamaxOp should have same dimension"));
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment"),
platform::errors::InvalidArgument(
"Param and Moment input of AdamaxOp should have same dimension"));
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("InfNorm"),
platform::errors::InvalidArgument(
"Param and InfNorm input of AdamaxOp should have same dimension"));
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims);
ctx->SetOutputDim("InfNormOut", param_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
@@ -150,7 +92,11 @@ division by 0 error.
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adamax, ops::AdamaxOp, ops::AdamaxOpMaker);
REGISTER_OP_CPU_KERNEL(
adamax, ops::AdamaxOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::AdamaxOpKernel<paddle::platform::CPUDeviceContext, double>);
DELCARE_INFER_SHAPE_FUNCTOR(adamax, AdamaxInferMetaFunctor,
PT_INFER_META(phi::AdamaxInferMeta));
REGISTER_OPERATOR(
adamax, ops::AdamaxOp, ops::AdamaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
AdamaxInferMetaFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/adamax_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
adamax, ops::AdamaxOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::AdamaxOpKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class AdamaxOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type())));
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
auto inf_norm_out_tensor = ctx.Output<framework::Tensor>("InfNormOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment_out_tensor->mutable_data<T>(ctx.GetPlace());
inf_norm_out_tensor->mutable_data<T>(ctx.GetPlace());
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
auto grad = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment"));
auto inf_norm = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("InfNorm"));
auto lr = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("LearningRate"));
auto beta1_pow = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Beta1Pow"));
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto inf_norm_out =
framework::EigenVector<T>::Flatten(*inf_norm_out_tensor);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
moment_out.device(*place) = beta1 * moment + (1 - beta1) * grad;
inf_norm_out.device(*place) =
grad.abs().cwiseMax((beta2 * inf_norm) + epsilon);
auto lr_t = lr / (1 - beta1_pow);
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
param_out.device(*place) =
param - lr_t.broadcast(m_dsize) * (moment_out / inf_norm_out);
}
};
} // namespace operators
} // namespace paddle
@@ -28,6 +28,98 @@ std::vector<DDim> GetMetaTensorsDim(const std::vector<MetaTensor*>& tensors) {
return dims;
}
void AdamaxInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& moment,
const MetaTensor& inf_norm,
const MetaTensor& beta1_pow,
float beta1,
float beta2,
float epsilon,
MetaTensor* param_out,
MetaTensor* moment_out,
MetaTensor* inf_norm_out) {
auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_NE(
product(lr_dims),
0,
errors::InvalidArgument("Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(
product(lr_dims),
1,
errors::InvalidArgument("Learning rate should have 1 dimension"));
auto beta1_pow_dims = beta1_pow.dims();
PADDLE_ENFORCE_EQ(product(beta1_pow_dims),
1,
errors::InvalidArgument(
"Beta1 power accumulator should have 1 dimension"));
auto param_dims = param.dims();
PADDLE_ENFORCE_EQ(
param_dims,
grad.dims(),
errors::InvalidArgument(
"Param and Grad input of AdamaxOp should have same dimension"));
PADDLE_ENFORCE_EQ(
param_dims,
moment.dims(),
errors::InvalidArgument(
"Param and Moment input of AdamaxOp should have same dimension"));
PADDLE_ENFORCE_EQ(
param_dims,
inf_norm.dims(),
errors::InvalidArgument(
"Param and InfNorm input of AdamaxOp should have same dimension"));
param_out->set_dims(param_dims);
param_out->set_dtype(param.dtype());
moment_out->set_dims(param_dims);
moment_out->set_dtype(moment.dtype());
inf_norm_out->set_dims(param_dims);
inf_norm_out->set_dtype(inf_norm.dtype());
}
void AdadeltaInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& avg_squared_grad,
const MetaTensor& avg_squared_update,
float rho,
float epsilon,
MetaTensor* param_out,
MetaTensor* avg_squared_grad_out,
MetaTensor* avg_squared_update_out) {
auto param_dims = param.dims();
PADDLE_ENFORCE_EQ(
param_dims,
grad.dims(),
errors::InvalidArgument(
"Param and grad input of AdadeltaOp should have same dimension."));
PADDLE_ENFORCE_EQ(
param_dims,
avg_squared_grad.dims(),
errors::InvalidArgument("Param and AvgSquaredGrad input of AdadeltaOp "
"should have same dimension"));
PADDLE_ENFORCE_EQ(
param_dims,
avg_squared_update.dims(),
errors::InvalidArgument("Param and AvgSquaredUpdate input of AdadeltaOp "
"should have same dimension"));
param_out->set_dims(param_dims);
param_out->set_dtype(param.dtype());
avg_squared_grad_out->set_dims(param_dims);
avg_squared_grad_out->set_dtype(avg_squared_grad.dtype());
avg_squared_update_out->set_dims(param_dims);
avg_squared_update_out->set_dtype(avg_squared_update.dtype());
}
void BilinearTensorProductInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& weight,
@@ -39,4 +39,28 @@ void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out);
void AdamaxInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& moment,
const MetaTensor& inf_norm,
const MetaTensor& beta1_pow,
float beta1,
float beta2,
float epsilon,
MetaTensor* param_out,
MetaTensor* moment_out,
MetaTensor* inf_norm_out);
void AdadeltaInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& avg_squared_grad,
const MetaTensor& avg_squared_update,
float rho,
float epsilon,
MetaTensor* param_out,
MetaTensor* avg_squared_grad_out,
MetaTensor* avg_squared_update_out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void AdadeltaKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& avg_squared_grad,
const DenseTensor& avg_squared_update,
float rho,
float epsilon,
DenseTensor* param_out,
DenseTensor* avg_squared_grad_out,
DenseTensor* avg_squared_update_out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void AdamaxKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& moment,
const DenseTensor& inf_norm,
const DenseTensor& beta1_pow,
float beta1,
float beta2,
float epsilon,
DenseTensor* param_out,
DenseTensor* moment_out,
DenseTensor* inf_norm_out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/adadelta_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/adadelta_kernel_impl.h"
PD_REGISTER_KERNEL(
adadelta, CPU, ALL_LAYOUT, phi::AdadeltaKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/adamax_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/adamax_kernel_impl.h"
PD_REGISTER_KERNEL(adamax, CPU, ALL_LAYOUT, phi::AdamaxKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/adadelta_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/adadelta_kernel_impl.h"
PD_REGISTER_KERNEL(
adadelta, GPU, ALL_LAYOUT, phi::AdadeltaKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/adamax_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/adamax_kernel_impl.h"
PD_REGISTER_KERNEL(adamax, GPU, ALL_LAYOUT, phi::AdamaxKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/adadelta_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T, typename Context>
void AdadeltaKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& avg_squared_grad,
const DenseTensor& avg_squared_update,
float rho,
float epsilon,
DenseTensor* param_out,
DenseTensor* avg_squared_grad_out,
DenseTensor* avg_squared_update_out) {
dev_ctx.template Alloc<T>(param_out);
dev_ctx.template Alloc<T>(avg_squared_grad_out);
dev_ctx.template Alloc<T>(avg_squared_update_out);
T rho_ = static_cast<T>(rho);
T epsilon_ = static_cast<T>(epsilon);
auto eigen_param = EigenVector<T>::Flatten(param);
auto eigen_grad = EigenVector<T>::Flatten(grad);
// Squared gradient accumulator
auto eigen_avg_squared_grad = EigenVector<T>::Flatten(avg_squared_grad);
// Squared updates accumulator
auto eigen_avg_squared_update = EigenVector<T>::Flatten(avg_squared_update);
auto eigen_param_out = EigenVector<T>::Flatten(*param_out);
auto eigen_avg_squared_grad_out =
EigenVector<T>::Flatten(*avg_squared_grad_out);
auto eigen_avg_squared_update_out =
EigenVector<T>::Flatten(*avg_squared_update_out);
auto& place = *dev_ctx.eigen_device();
eigen_avg_squared_grad_out.device(place) =
rho_ * eigen_avg_squared_grad + (1 - rho_) * eigen_grad.square();
auto update = -((eigen_avg_squared_update + epsilon_) /
(eigen_avg_squared_grad_out + epsilon_))
.sqrt() *
eigen_grad;
eigen_avg_squared_update_out.device(place) =
rho_ * eigen_avg_squared_update + (1 - rho_) * update.square();
eigen_param_out.device(place) = eigen_param + update;
}
} // namespace phi
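For reference, the Eigen expressions in phi::AdadeltaKernel above compute the standard Adadelta update. Writing the parameter as θ, the gradient as g, rho as ρ, epsilon as ε, avg_squared_grad as E[g²] and avg_squared_update as E[Δθ²] (notation chosen here, not in the source), this is

$$
\begin{aligned}
E[g^2]_t &= \rho\,E[g^2]_{t-1} + (1-\rho)\,g_t^2, \\
\Delta\theta_t &= -\sqrt{\frac{E[\Delta\theta^2]_{t-1} + \epsilon}{E[g^2]_t + \epsilon}}\;g_t, \\
E[\Delta\theta^2]_t &= \rho\,E[\Delta\theta^2]_{t-1} + (1-\rho)\,\Delta\theta_t^2, \\
\theta_{t+1} &= \theta_t + \Delta\theta_t,
\end{aligned}
$$

which is the same computation the removed fluid AdadeltaOpKernel performed.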
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/adamax_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T, typename Context>
void AdamaxKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& moment,
const DenseTensor& inf_norm,
const DenseTensor& beta1_pow,
float beta1,
float beta2,
float epsilon,
DenseTensor* param_out,
DenseTensor* moment_out,
DenseTensor* inf_norm_out) {
dev_ctx.template Alloc<T>(param_out);
dev_ctx.template Alloc<T>(moment_out);
dev_ctx.template Alloc<T>(inf_norm_out);
T beta1_ = static_cast<T>(beta1);
T beta2_ = static_cast<T>(beta2);
T epsilon_ = static_cast<T>(epsilon);
auto eigen_param = EigenVector<T>::Flatten(param);
auto eigen_grad = EigenVector<T>::Flatten(grad);
auto eigen_moment = EigenVector<T>::Flatten(moment);
auto eigen_inf_norm = EigenVector<T>::Flatten(inf_norm);
auto eigen_lr = EigenVector<T>::Flatten(learning_rate);
auto eigen_beta1_pow = EigenVector<T>::Flatten(beta1_pow);
auto eigen_param_out = EigenVector<T>::Flatten(*param_out);
auto eigen_moment_out = EigenVector<T>::Flatten(*moment_out);
auto eigen_inf_norm_out = EigenVector<T>::Flatten(*inf_norm_out);
auto& place = *dev_ctx.eigen_device();
eigen_moment_out.device(place) =
beta1_ * eigen_moment + (1 - beta1_) * eigen_grad;
eigen_inf_norm_out.device(place) =
eigen_grad.abs().cwiseMax((beta2_ * eigen_inf_norm) + epsilon_);
auto lr_t = eigen_lr / (1 - eigen_beta1_pow);
Eigen::DSizes<int, 1> m_dsize(moment_out->numel());
eigen_param_out.device(place) =
eigen_param -
lr_t.broadcast(m_dsize) * (eigen_moment_out / eigen_inf_norm_out);
}
} // namespace phi
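Likewise, phi::AdamaxKernel reproduces the update of the removed fluid AdamaxOpKernel. With moment m, infinity-norm accumulator u, learning rate η and beta1_pow = β₁ᵗ (again, notation chosen here for clarity), the code above computes

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \\
u_t &= \max\!\big(\beta_2 u_{t-1} + \epsilon,\ |g_t|\big), \\
\hat{\eta}_t &= \frac{\eta}{1-\beta_1^t}, \\
\theta_t &= \theta_{t-1} - \hat{\eta}_t\,\frac{m_t}{u_t},
\end{aligned}
$$

where the ε added inside the max is what the original op documentation refers to when it mentions avoiding a division by 0 error.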