diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cc b/paddle/fluid/operators/optimizers/adadelta_op.cc index ad7f93d73e902bbac684832d3a77ba83b517daf6..3cafbce04d3333a29f978f94e286686405e50e7e 100644 --- a/paddle/fluid/operators/optimizers/adadelta_op.cc +++ b/paddle/fluid/operators/optimizers/adadelta_op.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/optimizers/adadelta_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { @@ -23,77 +26,6 @@ class AdadeltaOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true, - platform::errors::InvalidArgument( - "Input(Param) of AdadeltaOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true, - platform::errors::InvalidArgument( - "Input(Grad) of AdadeltaOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("AvgSquaredGrad"), true, - platform::errors::InvalidArgument( - "Input(AvgSquaredGrad) of AdadeltaOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("AvgSquaredUpdate"), true, - platform::errors::InvalidArgument( - "Input(AvgSquaredUpdate) of AdadeltaOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->GetInputsVarType("Param").front() == - framework::proto::VarType::LOD_TENSOR, - true, - platform::errors::InvalidArgument( - "The input var's type should be LoDTensor, but the received is %s", - ctx->Inputs("Param").front(), - ctx->GetInputsVarType("Param").front())); - PADDLE_ENFORCE_EQ( - ctx->GetInputsVarType("Grad").front() == - framework::proto::VarType::LOD_TENSOR, - true, - platform::errors::InvalidArgument( - "The input var's type should be LoDTensor, but the received is %s", - ctx->Inputs("Grad").front(), - ctx->GetInputsVarType("Grad").front())); - - PADDLE_ENFORCE_EQ( - ctx->HasOutput("ParamOut"), true, - platform::errors::InvalidArgument( - "Output(ParamOut) of AdadeltaOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("AvgSquaredGradOut"), true, - platform::errors::InvalidArgument( - "Output(AvgSquaredGradOut) of AdadeltaOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("AvgSquaredUpdateOut"), true, - platform::errors::InvalidArgument( - "Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null.")); - - auto param_dim = ctx->GetInputDim("Param"); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Grad"), - platform::errors::InvalidArgument( - "Param and grad input of AdadeltaOp should have same dimension.")); - PADDLE_ENFORCE_NE( - phi::product(ctx->GetInputDim("AvgSquaredGrad")), 0, - platform::errors::InvalidArgument( - "Maybe the Input variable AvgSquaredGrad has not " - "been initialized. 
You may need to confirm if you put " - "exe.run(startup_program) after optimizer.minimize " - "function.")); - PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredGrad"), - platform::errors::InvalidArgument( - "Param and AvgSquaredGrad input of AdadeltaOp " - "should have same dimension")); - PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredUpdate"), - platform::errors::InvalidArgument( - "Param and AvgSquaredUpdate input of AdadeltaOp " - "should have same dimension")); - - ctx->SetOutputDim("ParamOut", param_dim); - ctx->SetOutputDim("AvgSquaredGradOut", param_dim); - ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim); - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( @@ -149,7 +81,11 @@ $$ } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker); -REGISTER_OP_CPU_KERNEL( - adadelta, ops::AdadeltaOpKernel, - ops::AdadeltaOpKernel); +namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(adadelta, AdadeltaInferMetaFunctor, + PT_INFER_META(phi::AdadeltaInferMeta)); +REGISTER_OPERATOR( + adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + AdadeltaInferMetaFunctor); diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cu b/paddle/fluid/operators/optimizers/adadelta_op.cu deleted file mode 100644 index 562a157f063b44d65254d556d44439eee3636c4c..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/adadelta_op.cu +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include "paddle/fluid/operators/optimizers/adadelta_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - adadelta, ops::AdadeltaOpKernel, - ops::AdadeltaOpKernel); diff --git a/paddle/fluid/operators/optimizers/adadelta_op.h b/paddle/fluid/operators/optimizers/adadelta_op.h deleted file mode 100644 index 85cfad35858bbe6b112169f196c0711d981e9446..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/adadelta_op.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class AdadeltaOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto* param_var = ctx.InputVar("Param"); - PADDLE_ENFORCE_EQ(param_var->IsType(), true, - platform::errors::InvalidArgument( - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.InputNames("Param").front(), - framework::ToTypeName(param_var->Type()))); - const auto* grad_var = ctx.InputVar("Grad"); - PADDLE_ENFORCE_EQ(grad_var->IsType(), true, - platform::errors::InvalidArgument( - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.InputNames("Grad").front(), - framework::ToTypeName(grad_var->Type()))); - - auto param_out_tensor = ctx.Output("ParamOut"); - auto avg_squared_grad_out_tensor = - ctx.Output("AvgSquaredGradOut"); - auto avg_squared_update_out_tensor = - ctx.Output("AvgSquaredUpdateOut"); - - param_out_tensor->mutable_data(ctx.GetPlace()); - avg_squared_grad_out_tensor->mutable_data(ctx.GetPlace()); - avg_squared_update_out_tensor->mutable_data(ctx.GetPlace()); - - T rho = static_cast(ctx.Attr("rho")); - T epsilon = static_cast(ctx.Attr("epsilon")); - - auto param = framework::EigenVector::Flatten( - *ctx.Input("Param")); - auto grad = framework::EigenVector::Flatten( - *ctx.Input("Grad")); - // Squared gradient accumulator - auto avg_squared_grad = framework::EigenVector::Flatten( - *ctx.Input("AvgSquaredGrad")); - // Squared updates accumulator - auto avg_squared_update = framework::EigenVector::Flatten( - *ctx.Input("AvgSquaredUpdate")); - auto param_out = framework::EigenVector::Flatten(*param_out_tensor); - auto avg_squared_grad_out = - framework::EigenVector::Flatten(*avg_squared_grad_out_tensor); - auto avg_squared_update_out = - framework::EigenVector::Flatten(*avg_squared_update_out_tensor); - auto& place = *ctx.template device_context().eigen_device(); - - avg_squared_grad_out.device(place) = - rho * avg_squared_grad + (1 - rho) * grad.square(); - auto update = - -((avg_squared_update + epsilon) / (avg_squared_grad_out + epsilon)) - .sqrt() * - grad; - avg_squared_update_out.device(place) = - rho * avg_squared_update + (1 - rho) * update.square(); - param_out.device(place) = param + update; - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/optimizers/adamax_op.cc b/paddle/fluid/operators/optimizers/adamax_op.cc index a95a37c980c8c9d41dc9fd352e3dace787a7c4e9..29f3d3b09decc3911543cb8a13df35fc1f7174ae 100644 --- a/paddle/fluid/operators/optimizers/adamax_op.cc +++ b/paddle/fluid/operators/optimizers/adamax_op.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/operators/optimizers/adamax_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { @@ -22,67 +25,6 @@ class AdamaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "Adamax"); - OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "Adamax"); - OP_INOUT_CHECK(ctx->HasInput("Moment"), "Input", "Moment", "Adamax"); - OP_INOUT_CHECK(ctx->HasInput("InfNorm"), "Input", "InfNorm", "Adamax"); - OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate", - "Adamax"); - OP_INOUT_CHECK(ctx->HasInput("Beta1Pow"), "Input", "Beta1Pow", "Adamax"); - PADDLE_ENFORCE_EQ( - ctx->GetInputsVarType("Param").front(), - framework::proto::VarType::LOD_TENSOR, - platform::errors::InvalidArgument( - "The input var's type should be LoDTensor, but the received is %s", - ctx->Inputs("Param").front(), - ctx->GetInputsVarType("Param").front())); - PADDLE_ENFORCE_EQ( - ctx->GetInputsVarType("Grad").front(), - framework::proto::VarType::LOD_TENSOR, - platform::errors::InvalidArgument( - "The input var's type should be LoDTensor, but the received is %s", - ctx->Inputs("Grad").front(), - ctx->GetInputsVarType("Grad").front())); - - OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut", "Adamax"); - OP_INOUT_CHECK(ctx->HasOutput("MomentOut"), "Output", "MomentOut", - "Adamax"); - OP_INOUT_CHECK(ctx->HasOutput("InfNormOut"), "Output", "InfNormOut", - "Adamax"); - - auto lr_dims = ctx->GetInputDim("LearningRate"); - PADDLE_ENFORCE_NE(phi::product(lr_dims), 0, - platform::errors::InvalidArgument( - "Maybe the Input variable LearningRate has not " - "been initialized. You may need to confirm " - "if you put exe.run(startup_program) " - "after optimizer.minimize function.")); - PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1, - platform::errors::InvalidArgument( - "Learning rate should have 1 dimension")); - auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow"); - PADDLE_ENFORCE_EQ(phi::product(beta1_pow_dims), 1, - platform::errors::InvalidArgument( - "Beta1 power accumulator should have 1 dimension")); - auto param_dims = ctx->GetInputDim("Param"); - PADDLE_ENFORCE_EQ( - param_dims, ctx->GetInputDim("Grad"), - platform::errors::InvalidArgument( - "Param and Grad input of AdamaxOp should have same dimension")); - PADDLE_ENFORCE_EQ( - param_dims, ctx->GetInputDim("Moment"), - platform::errors::InvalidArgument( - "Param and Moment input of AdamaxOp should have same dimension")); - PADDLE_ENFORCE_EQ( - param_dims, ctx->GetInputDim("InfNorm"), - platform::errors::InvalidArgument( - "Param and InfNorm input of AdamaxOp should have same dimension")); - - ctx->SetOutputDim("ParamOut", param_dims); - ctx->SetOutputDim("MomentOut", param_dims); - ctx->SetOutputDim("InfNormOut", param_dims); - } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( @@ -150,7 +92,11 @@ division by 0 error. 
} // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(adamax, ops::AdamaxOp, ops::AdamaxOpMaker); -REGISTER_OP_CPU_KERNEL( - adamax, ops::AdamaxOpKernel, - ops::AdamaxOpKernel); +DELCARE_INFER_SHAPE_FUNCTOR(adamax, AdamaxInferMetaFunctor, + PT_INFER_META(phi::AdamaxInferMeta)); + +REGISTER_OPERATOR( + adamax, ops::AdamaxOp, ops::AdamaxOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + AdamaxInferMetaFunctor); diff --git a/paddle/fluid/operators/optimizers/adamax_op.cu b/paddle/fluid/operators/optimizers/adamax_op.cu deleted file mode 100644 index 80e0219d4414db2909b5babc22599d8c0d906c7d..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/adamax_op.cu +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include "paddle/fluid/operators/optimizers/adamax_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - adamax, ops::AdamaxOpKernel, - ops::AdamaxOpKernel); diff --git a/paddle/fluid/operators/optimizers/adamax_op.h b/paddle/fluid/operators/optimizers/adamax_op.h deleted file mode 100644 index df0112448b1cbc82d699dc1ee6f3444bda3b142b..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/adamax_op.h +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class AdamaxOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto* param_var = ctx.InputVar("Param"); - PADDLE_ENFORCE_EQ(param_var->IsType(), true, - platform::errors::InvalidArgument( - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.InputNames("Param").front(), - framework::ToTypeName(param_var->Type()))); - const auto* grad_var = ctx.InputVar("Grad"); - PADDLE_ENFORCE_EQ(grad_var->IsType(), true, - platform::errors::InvalidArgument( - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.InputNames("Grad").front(), - framework::ToTypeName(grad_var->Type()))); - - auto param_out_tensor = ctx.Output("ParamOut"); - auto moment_out_tensor = ctx.Output("MomentOut"); - auto inf_norm_out_tensor = ctx.Output("InfNormOut"); - - param_out_tensor->mutable_data(ctx.GetPlace()); - moment_out_tensor->mutable_data(ctx.GetPlace()); - inf_norm_out_tensor->mutable_data(ctx.GetPlace()); - - T beta1 = static_cast(ctx.Attr("beta1")); - T beta2 = static_cast(ctx.Attr("beta2")); - T epsilon = static_cast(ctx.Attr("epsilon")); - - auto param = framework::EigenVector::Flatten( - *ctx.Input("Param")); - auto grad = framework::EigenVector::Flatten( - *ctx.Input("Grad")); - auto moment = framework::EigenVector::Flatten( - *ctx.Input("Moment")); - auto inf_norm = framework::EigenVector::Flatten( - *ctx.Input("InfNorm")); - auto lr = framework::EigenVector::Flatten( - *ctx.Input("LearningRate")); - auto beta1_pow = framework::EigenVector::Flatten( - *ctx.Input("Beta1Pow")); - auto param_out = framework::EigenVector::Flatten(*param_out_tensor); - auto moment_out = framework::EigenVector::Flatten(*moment_out_tensor); - auto inf_norm_out = - framework::EigenVector::Flatten(*inf_norm_out_tensor); - auto* place = ctx.template device_context().eigen_device(); - - moment_out.device(*place) = beta1 * moment + (1 - beta1) * grad; - inf_norm_out.device(*place) = - grad.abs().cwiseMax((beta2 * inf_norm) + epsilon); - auto lr_t = lr / (1 - beta1_pow); - Eigen::DSizes m_dsize(moment_out_tensor->numel()); - param_out.device(*place) = - param - lr_t.broadcast(m_dsize) * (moment_out / inf_norm_out); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index dc5478e8afb981defa9bc493cb440cead4f5965f..a21f077c09f09dfa493a8f5e6eceb5a71cf99e0a 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -28,6 +28,98 @@ std::vector GetMetaTensorsDim(const std::vector& tensors) { return dims; } +void AdamaxInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + const MetaTensor& moment, + const MetaTensor& inf_norm, + const MetaTensor& beta1_pow, + float beta1, + float beta2, + float epsilon, + MetaTensor* param_out, + MetaTensor* moment_out, + MetaTensor* inf_norm_out) { + auto lr_dims = learning_rate.dims(); + PADDLE_ENFORCE_NE( + product(lr_dims), + 0, + errors::InvalidArgument("Maybe the Input variable LearningRate has not " + "been initialized. 
You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.")); + PADDLE_ENFORCE_EQ( + product(lr_dims), + 1, + errors::InvalidArgument("Learning rate should have 1 dimension")); + auto beta1_pow_dims = beta1_pow.dims(); + PADDLE_ENFORCE_EQ(product(beta1_pow_dims), + 1, + errors::InvalidArgument( + "Beta1 power accumulator should have 1 dimension")); + auto param_dims = param.dims(); + PADDLE_ENFORCE_EQ( + param_dims, + grad.dims(), + errors::InvalidArgument( + "Param and Grad input of AdamaxOp should have same dimension")); + PADDLE_ENFORCE_EQ( + param_dims, + moment.dims(), + errors::InvalidArgument( + "Param and Moment input of AdamaxOp should have same dimension")); + PADDLE_ENFORCE_EQ( + param_dims, + inf_norm.dims(), + errors::InvalidArgument( + "Param and InfNorm input of AdamaxOp should have same dimension")); + + param_out->set_dims(param_dims); + param_out->set_dtype(param.dtype()); + + moment_out->set_dims(param_dims); + moment_out->set_dtype(moment.dtype()); + + inf_norm_out->set_dims(param_dims); + inf_norm_out->set_dtype(inf_norm.dtype()); +} + +void AdadeltaInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& avg_squared_grad, + const MetaTensor& avg_squared_update, + float rho, + float epsilon, + MetaTensor* param_out, + MetaTensor* avg_squared_grad_out, + MetaTensor* avg_squared_update_out) { + auto param_dims = param.dims(); + PADDLE_ENFORCE_EQ( + param_dims, + grad.dims(), + errors::InvalidArgument( + "Param and grad input of AdadeltaOp should have same dimension.")); + PADDLE_ENFORCE_EQ( + param_dims, + avg_squared_grad.dims(), + errors::InvalidArgument("Param and AvgSquaredGrad input of AdadeltaOp " + "should have same dimension")); + PADDLE_ENFORCE_EQ( + param_dims, + avg_squared_update.dims(), + errors::InvalidArgument("Param and AvgSquaredUpdate input of AdadeltaOp " + "should have same dimension")); + + param_out->set_dims(param_dims); + param_out->set_dtype(param.dtype()); + + avg_squared_grad_out->set_dims(param_dims); + avg_squared_grad_out->set_dtype(avg_squared_grad.dtype()); + + avg_squared_update_out->set_dims(param_dims); + avg_squared_update_out->set_dtype(avg_squared_update.dtype()); +} + void BilinearTensorProductInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 51738c5e08e9842c7cffcdd1a2ce7ee3764d6412..8cb6f70481de3160c7eb0a7d6633ba76905a5c41 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -39,4 +39,28 @@ void WhereInferMeta(const MetaTensor& condition, const MetaTensor& x, const MetaTensor& y, MetaTensor* out); + +void AdamaxInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + const MetaTensor& moment, + const MetaTensor& inf_norm, + const MetaTensor& beta1_pow, + float beta1, + float beta2, + float epsilon, + MetaTensor* param_out, + MetaTensor* moment_out, + MetaTensor* inf_norm_out); + +void AdadeltaInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& avg_squared_grad, + const MetaTensor& avg_squared_update, + float rho, + float epsilon, + MetaTensor* param_out, + MetaTensor* avg_squared_grad_out, + MetaTensor* avg_squared_update_out); + } // namespace phi diff --git a/paddle/phi/kernels/adadelta_kernel.h b/paddle/phi/kernels/adadelta_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..65a6aad415193be62424a6c6ac19c1aec6927e8b --- 
/dev/null +++ b/paddle/phi/kernels/adadelta_kernel.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void AdadeltaKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& avg_squared_grad, + const DenseTensor& avg_squared_update, + float rho, + float epsilon, + DenseTensor* param_out, + DenseTensor* avg_squared_grad_out, + DenseTensor* avg_squared_update_out); + +} // namespace phi diff --git a/paddle/phi/kernels/adamax_kernel.h b/paddle/phi/kernels/adamax_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..feaf996f16266abbada655907fee68f4ab25bad3 --- /dev/null +++ b/paddle/phi/kernels/adamax_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void AdamaxKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& learning_rate, + const DenseTensor& moment, + const DenseTensor& inf_norm, + const DenseTensor& beta1_pow, + float beta1, + float beta2, + float epsilon, + DenseTensor* param_out, + DenseTensor* moment_out, + DenseTensor* inf_norm_out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/adadelta_kernel.cc b/paddle/phi/kernels/cpu/adadelta_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..e9b5397b616d7208969636d2f3114ecc46611d7b --- /dev/null +++ b/paddle/phi/kernels/cpu/adadelta_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/adadelta_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/adadelta_kernel_impl.h" + +PD_REGISTER_KERNEL( + adadelta, CPU, ALL_LAYOUT, phi::AdadeltaKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/adamax_kernel.cc b/paddle/phi/kernels/cpu/adamax_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..867c900e70b687995930de697e1a9ee4c426e255 --- /dev/null +++ b/paddle/phi/kernels/cpu/adamax_kernel.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/adamax_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/adamax_kernel_impl.h" + +PD_REGISTER_KERNEL(adamax, CPU, ALL_LAYOUT, phi::AdamaxKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/adadelta_kernel.cu b/paddle/phi/kernels/gpu/adadelta_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..7516a277a746f12ec7c6326b7ebd8f64c789fc31 --- /dev/null +++ b/paddle/phi/kernels/gpu/adadelta_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/adadelta_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/adadelta_kernel_impl.h" + +PD_REGISTER_KERNEL( + adadelta, GPU, ALL_LAYOUT, phi::AdadeltaKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/adamax_kernel.cu b/paddle/phi/kernels/gpu/adamax_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..0817c531318c390ead201b1fda18da511ea5569f --- /dev/null +++ b/paddle/phi/kernels/gpu/adamax_kernel.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/adamax_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/adamax_kernel_impl.h"
+
+PD_REGISTER_KERNEL(adamax, GPU, ALL_LAYOUT, phi::AdamaxKernel, float, double) {}
diff --git a/paddle/phi/kernels/impl/adadelta_kernel_impl.h b/paddle/phi/kernels/impl/adadelta_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..3fbdf435bab39fbded100656a591fc76ea2ca69b
--- /dev/null
+++ b/paddle/phi/kernels/impl/adadelta_kernel_impl.h
@@ -0,0 +1,65 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/adadelta_kernel.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void AdadeltaKernel(const Context& dev_ctx,
+                    const DenseTensor& param,
+                    const DenseTensor& grad,
+                    const DenseTensor& avg_squared_grad,
+                    const DenseTensor& avg_squared_update,
+                    float rho,
+                    float epsilon,
+                    DenseTensor* param_out,
+                    DenseTensor* avg_squared_grad_out,
+                    DenseTensor* avg_squared_update_out) {
+  dev_ctx.template Alloc<T>(param_out);
+  dev_ctx.template Alloc<T>(avg_squared_grad_out);
+  dev_ctx.template Alloc<T>(avg_squared_update_out);
+
+  T rho_ = static_cast<T>(rho);
+  T epsilon_ = static_cast<T>(epsilon);
+
+  auto eigen_param = EigenVector<T>::Flatten(param);
+  auto eigen_grad = EigenVector<T>::Flatten(grad);
+  // Squared gradient accumulator
+  auto eigen_avg_squared_grad = EigenVector<T>::Flatten(avg_squared_grad);
+  // Squared updates accumulator
+  auto eigen_avg_squared_update = EigenVector<T>::Flatten(avg_squared_update);
+  auto eigen_param_out = EigenVector<T>::Flatten(*param_out);
+  auto eigen_avg_squared_grad_out =
+      EigenVector<T>::Flatten(*avg_squared_grad_out);
+  auto eigen_avg_squared_update_out =
+      EigenVector<T>::Flatten(*avg_squared_update_out);
+  auto& place = *dev_ctx.eigen_device();
+
+  eigen_avg_squared_grad_out.device(place) =
+      rho_ * eigen_avg_squared_grad + (1 - rho_) * eigen_grad.square();
+  auto update = -((eigen_avg_squared_update + epsilon_) /
+                  (eigen_avg_squared_grad_out + epsilon_))
+                     .sqrt() *
+                eigen_grad;
+  eigen_avg_squared_update_out.device(place) =
+      rho_ * eigen_avg_squared_update + (1 - rho_) * update.square();
+  eigen_param_out.device(place) = eigen_param + update;
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/adamax_kernel_impl.h b/paddle/phi/kernels/impl/adamax_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..bff553319a2b98814d31a3f992b98541498de149
--- /dev/null
+++ b/paddle/phi/kernels/impl/adamax_kernel_impl.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/adamax_kernel.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void AdamaxKernel(const Context& dev_ctx,
+                  const DenseTensor& param,
+                  const DenseTensor& grad,
+                  const DenseTensor& learning_rate,
+                  const DenseTensor& moment,
+                  const DenseTensor& inf_norm,
+                  const DenseTensor& beta1_pow,
+                  float beta1,
+                  float beta2,
+                  float epsilon,
+                  DenseTensor* param_out,
+                  DenseTensor* moment_out,
+                  DenseTensor* inf_norm_out) {
+  dev_ctx.template Alloc<T>(param_out);
+  dev_ctx.template Alloc<T>(moment_out);
+  dev_ctx.template Alloc<T>(inf_norm_out);
+
+  T beta1_ = static_cast<T>(beta1);
+  T beta2_ = static_cast<T>(beta2);
+  T epsilon_ = static_cast<T>(epsilon);
+
+  auto eigen_param = EigenVector<T>::Flatten(param);
+  auto eigen_grad = EigenVector<T>::Flatten(grad);
+  auto eigen_moment = EigenVector<T>::Flatten(moment);
+  auto eigen_inf_norm = EigenVector<T>::Flatten(inf_norm);
+  auto eigen_lr = EigenVector<T>::Flatten(learning_rate);
+  auto eigen_beta1_pow = EigenVector<T>::Flatten(beta1_pow);
+
+  auto eigen_param_out = EigenVector<T>::Flatten(*param_out);
+  auto eigen_moment_out = EigenVector<T>::Flatten(*moment_out);
+  auto eigen_inf_norm_out = EigenVector<T>::Flatten(*inf_norm_out);
+
+  auto& place = *dev_ctx.eigen_device();
+
+  eigen_moment_out.device(place) =
+      beta1_ * eigen_moment + (1 - beta1_) * eigen_grad;
+  eigen_inf_norm_out.device(place) =
+      eigen_grad.abs().cwiseMax((beta2_ * eigen_inf_norm) + epsilon_);
+  auto lr_t = eigen_lr / (1 - eigen_beta1_pow);
+  Eigen::DSizes<int, 1> m_dsize(moment_out->numel());
+  eigen_param_out.device(place) =
+      eigen_param -
+      lr_t.broadcast(m_dsize) * (eigen_moment_out / eigen_inf_norm_out);
+}
+
+}  // namespace phi
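For reference, the update rules implemented by the two phi kernel bodies above (identical in math to the deleted fluid kernels they replace) can be written, using the same variable names as the code, as:

Adadelta:
$$
avg\_squared\_grad\_out = \rho \cdot avg\_squared\_grad + (1 - \rho) \cdot grad^2 \\
update = -\sqrt{\frac{avg\_squared\_update + \epsilon}{avg\_squared\_grad\_out + \epsilon}} \cdot grad \\
avg\_squared\_update\_out = \rho \cdot avg\_squared\_update + (1 - \rho) \cdot update^2 \\
param\_out = param + update
$$

Adamax:
$$
moment\_out = \beta_1 \cdot moment + (1 - \beta_1) \cdot grad \\
inf\_norm\_out = \max(\beta_2 \cdot inf\_norm + \epsilon, |grad|) \\
lr_t = \frac{learning\_rate}{1 - beta1\_pow} \\
param\_out = param - lr_t \cdot \frac{moment\_out}{inf\_norm\_out}
$$

The new AdadeltaInferMeta/AdamaxInferMeta functions only reproduce the old InferShape checks and set output dims and dtypes, so the numerical behavior of both optimizers should be unchanged by this refactor.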