From f3d54e2eaa668a04c230cab2291e4b222daed4b9 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Wed, 2 Mar 2022 20:49:28 +0800 Subject: [PATCH] Move sgd to phi (#40045) * move sgd to phi; test=develop * update * add sgd kernel; test=develop --- paddle/fluid/framework/operator.cc | 6 +- .../operators/optimizers/dgc_momentum_op.h | 65 +++++- paddle/fluid/operators/optimizers/sgd_op.cc | 5 - paddle/fluid/operators/optimizers/sgd_op.cu | 7 - paddle/phi/core/kernel_registry.h | 6 + paddle/phi/core/kernel_utils.h | 1 + paddle/phi/kernels/cpu/sgd_kernel.cc | 213 ++++++++++++++++++ paddle/phi/kernels/gpu/sgd_kernel.cu | 209 +++++++++++++++++ paddle/phi/kernels/sgd_kernel.h | 54 +++++ paddle/phi/ops/compat/sgd_sig.cc | 44 ++++ 10 files changed, 592 insertions(+), 18 deletions(-) create mode 100644 paddle/phi/kernels/cpu/sgd_kernel.cc create mode 100644 paddle/phi/kernels/gpu/sgd_kernel.cu create mode 100644 paddle/phi/kernels/sgd_kernel.h create mode 100644 paddle/phi/ops/compat/sgd_sig.cc diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index ffdc3e6d3c2..6414dd455db 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2051,7 +2051,11 @@ void OperatorWithKernel::BuildPhiKernelContext( // deal with optional here if ((it == ctx.inputs.end() || it->second.size() == 0) && (input_defs[i].type_index == - std::type_index(typeid(paddle::optional)))) { + std::type_index( + typeid(paddle::optional)) || + input_defs[i].type_index == + std::type_index( + typeid(paddle::optional)))) { pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr); auto end_idx = start_idx + 1; pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), diff --git a/paddle/fluid/operators/optimizers/dgc_momentum_op.h b/paddle/fluid/operators/optimizers/dgc_momentum_op.h index bea019f1f36..c86f544ed77 100644 --- a/paddle/fluid/operators/optimizers/dgc_momentum_op.h +++ b/paddle/fluid/operators/optimizers/dgc_momentum_op.h @@ -17,7 +17,7 @@ #include #include "paddle/fluid/operators/optimizers/momentum_op.h" -#include "paddle/fluid/operators/optimizers/sgd_op.h" +#include "paddle/phi/kernels/sgd_kernel.h" namespace paddle { namespace operators { @@ -26,8 +26,7 @@ template class DGCMomentumKernel : public framework::OpKernel { public: DGCMomentumKernel() - : _momentum_op_kernel(new MomentumOpKernel()), - _sgd_op_kernel(new SGDOpKernel()) {} + : _momentum_op_kernel(new MomentumOpKernel()) {} void Compute(const framework::ExecutionContext& context) const override { auto rampup_begin_step = context.Attr("rampup_begin_step"); @@ -67,12 +66,68 @@ class DGCMomentumKernel : public framework::OpKernel { } VLOG(10) << " so use sgd optimizer"; - return _sgd_op_kernel->Compute(context); + + const auto* param_var = context.InputVar("Param"); + const auto* grad_var = context.InputVar("Grad"); + auto* learning_rate = context.Input("LearningRate"); + bool multi_precision = context.Attr("multi_precision"); + if (param_var->IsType()) { + auto* param = context.Input("Param"); + auto* param_out = context.Output("ParamOut"); + auto* master_param_out = + context.Output("MasterParamOut"); + paddle::optional master_param_opt = + paddle::none; + if (multi_precision) { + auto* master_param = context.Input("MasterParam"); + master_param_opt = *master_param; + } + + if (grad_var->IsType()) { + // sgd_dense + auto* grad = context.Input("Grad"); + phi::SGDDenseKernel( + static_cast::TYPE&>(dev_ctx), + *param, *learning_rate, *grad, 
master_param_opt, multi_precision, + param_out, master_param_out); + } else { + // sgd dense param sparse grad + auto* grad = context.Input("Grad"); + phi::SGDDenseParamSparseGradKernel( + static_cast::TYPE&>(dev_ctx), + *param, *learning_rate, *grad, master_param_opt, multi_precision, + param_out, master_param_out); + } + } else if (param_var->IsType() && + grad_var->IsType() && + platform::is_cpu_place(context.GetPlace())) { + // sgd sparse param sparse grad + auto* param = context.Input("Param"); + auto* param_out = context.Output("ParamOut"); + auto* master_param_out = + context.Output("MasterParamOut"); + paddle::optional master_param_opt = + paddle::none; + if (multi_precision) { + auto* master_param = context.Input("MasterParam"); + master_param_opt = *master_param; + } + auto* grad = context.Input("Grad"); + phi::SGDSparseParamSparseGradKernel( + static_cast::TYPE&>(dev_ctx), + *param, *learning_rate, *grad, master_param_opt, multi_precision, + param_out, master_param_out); + + } else { + PADDLE_THROW("gdc not support yet"); + } } private: std::unique_ptr> _momentum_op_kernel; - std::unique_ptr> _sgd_op_kernel; }; } // namespace operators diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc index 529d60a2820..0e3f895d276 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.cc +++ b/paddle/fluid/operators/optimizers/sgd_op.cc @@ -166,8 +166,3 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, ops::SGDOpInferVarType); -REGISTER_OP_CPU_KERNEL( - sgd, ops::SGDOpKernel, - ops::SGDOpKernel, - ops::SGDOpKernel); diff --git a/paddle/fluid/operators/optimizers/sgd_op.cu b/paddle/fluid/operators/optimizers/sgd_op.cu index 3149f5f56ed..222244a2fd1 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.cu +++ b/paddle/fluid/operators/optimizers/sgd_op.cu @@ -166,10 +166,3 @@ class SGDOpKernel }; } // namespace operators } // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - sgd, ops::SGDOpKernel, - ops::SGDOpKernel, - ops::SGDOpKernel); diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index 7a05452cbeb..2b04d173af0 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -81,6 +81,12 @@ struct KernelArgsParseFunctor { default_tensor_layout, default_key.dtype(), arg_type); + } else if (arg_type == std::type_index(typeid( + paddle::optional))) { + args_def->AppendInput(default_key.backend(), + default_tensor_layout, + default_key.dtype(), + arg_type); } else if (arg_type == std::type_index(typeid(const std::vector&))) { args_def->AppendInput(default_key.backend(), diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index e5de5e2b49e..b582375155a 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -219,6 +219,7 @@ struct KernelImpl { PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor); + PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows); PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows); diff --git a/paddle/phi/kernels/cpu/sgd_kernel.cc b/paddle/phi/kernels/cpu/sgd_kernel.cc new file mode 100644 index 00000000000..c7b4074c70a --- /dev/null +++ b/paddle/phi/kernels/cpu/sgd_kernel.cc @@ -0,0 +1,213 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sgd_kernel.h" +#include "paddle/fluid/operators/jit/kernels.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void sgd_dense_param_dense_grad_impl(const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + DenseTensor* param_out) { + const auto sz = param_out->numel(); + paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1); + const T* lr = learning_rate.data(); + const T* param_data = param.data(); + const T* grad_data = grad.data(); + int64_t rows_idx = 0; + T* out_data = param_out->data(); + + auto sgd = + paddle::operators::jit::KernelFuncs, + phi::CPUPlace>::Cache() + .At(attr); + sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr); +} + +template <> +void sgd_dense_param_dense_grad_impl( + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + DenseTensor* param_out) { + auto p = EigenVector::Flatten(param); + auto g = EigenVector::Flatten(grad); + auto o = EigenVector::Flatten(*param_out); + const auto* lr = learning_rate.data(); + + o = p - lr[0] * g; +} + +template +void sgd_dense_param_sparse_grad_impl(const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + DenseTensor* param_out) { + const auto& grad_value = grad.value(); + const auto& grad_rows = grad.rows(); + const T* param_data = param.data(); + const T* grad_data = grad_value.data(); + const T* lr = learning_rate.data(); + const int64_t* rows_data = grad_rows.data(); + T* out_data = param_out->data(); + + paddle::operators::jit::sgd_attr_t attr; + attr.param_height = param_out->dims()[0]; + attr.param_width = param_out->numel() / attr.param_height; + attr.grad_height = grad_rows.size(); // note: it is not grad->height() + attr.grad_width = grad_value.numel() / attr.grad_height; + attr.selected_rows_size = grad_rows.size(); + + auto sgd = + paddle::operators::jit::KernelFuncs, + phi::CPUPlace>::Cache() + .At(attr); + sgd(lr, param_data, grad_data, rows_data, out_data, &attr); +} + +template <> +void sgd_dense_param_sparse_grad_impl( + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + DenseTensor* param_out) { + const auto& grad_value = grad.value(); + const auto& grad_rows = grad.rows(); + const auto grad_height = grad.height(); + const int64_t grad_val_height = static_cast(grad_rows.size()); + const auto grad_width = grad_value.numel() / grad_val_height; + + const auto* grad_data = grad_value.data(); + auto* out_data = param_out->data(); + const auto* lr = learning_rate.data(); + + for (size_t i = 0; i < grad_rows.size(); ++i) { + PADDLE_ENFORCE_LT( + grad_rows[i], + grad_height, + phi::errors::OutOfRange( + "Grad rows index value should be less than grad height." 
+ "Got [%s], but expected less than [%s]", + grad_rows[i], + grad_height)); + const int64_t row = grad_rows[i]; + for (int64_t j = 0; j < grad_width; ++j) { + out_data[row * grad_width + j] -= lr[0] * grad_data[i * grad_width + j]; + } + } +} + +template +void SGDDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + paddle::optional master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + dev_ctx.template Alloc(param_out); + sgd_dense_param_dense_grad_impl(param, learning_rate, grad, param_out); +} + +template +void SGDDenseParamSparseGradKernel( + const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + paddle::optional master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + dev_ctx.template Alloc(param_out); + sgd_dense_param_sparse_grad_impl(param, learning_rate, grad, param_out); +} + +template +void SGDSparseParamSparseGradKernel( + const Context& dev_ctx, + const SelectedRows& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + paddle::optional master_param, + bool multi_precision, + SelectedRows* param_out, + SelectedRows* master_param_out) { + // for distributed training, a sparse var may be empty, + // just skip updating. + if (grad.rows().size() == 0) { + return; + } + + auto param_row_width = param.value().dims()[1]; + auto grad_row_width = grad.value().dims()[1]; + PADDLE_ENFORCE_EQ( + param_row_width, + grad_row_width, + phi::errors::InvalidArgument( + "The param_row in SgdOP should have the same size with grad_row. " + "But received param_row's width is [%s], and grad_row's width is " + "[%s]", + param_row_width, + grad_row_width)); + + const auto* lr = learning_rate.data(); + const auto* grad_data = grad.value().data(); + auto* out_data = param_out->mutable_value()->data(); + for (size_t i = 0; i < grad.rows().size(); i++) { + int64_t id_index = param_out->AutoGrownIndex(grad.rows()[i], false); + PADDLE_ENFORCE_GE( + id_index, + static_cast(0), + phi::errors::InvalidArgument( + "The id in SgdOp should be >= 0. But recevied id_index is [%s]", + id_index)); + for (int64_t j = 0; j < grad_row_width; j++) { + out_data[id_index * grad_row_width + j] -= + lr[0] * grad_data[i * grad_row_width + j]; + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(sgd, + CPU, + ALL_LAYOUT, + phi::SGDDenseKernel, + phi::dtype::bfloat16, + float, + double) {} + +PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad, + CPU, + ALL_LAYOUT, + phi::SGDDenseParamSparseGradKernel, + phi::dtype::bfloat16, + float, + double) {} + +PD_REGISTER_KERNEL(sgd_sparse_param_sparse_grad, + CPU, + ALL_LAYOUT, + phi::SGDSparseParamSparseGradKernel, + phi::dtype::bfloat16, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/sgd_kernel.cu b/paddle/phi/kernels/gpu/sgd_kernel.cu new file mode 100644 index 00000000000..7dd5a03383f --- /dev/null +++ b/paddle/phi/kernels/gpu/sgd_kernel.cu @@ -0,0 +1,209 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sgd_kernel.h" + +#include "paddle/fluid/framework/mixed_vector.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_helper.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +__global__ void SGDKernelMT(const T* param, + const T* grad, + const T* learning_rate, + const int num, + T* param_out, + const MT* master_param, + MT* master_param_out) { + MT lr = static_cast(learning_rate[0]); + CUDA_KERNEL_LOOP(i, num) { + MT p_data = master_param ? master_param[i] : static_cast(param[i]); + MT g_data = static_cast(grad[i]); + p_data = p_data - lr * g_data; + param_out[i] = static_cast(p_data); + if (master_param_out) { + master_param_out[i] = p_data; + } + } +} + +template +__global__ void SparseSGDFunctorKernel(const T* selected_rows, + const int64_t* rows, + const T* learning_rate, + T* tensor_out, + int64_t row_numel, + int64_t limit) { + for (int64_t i = blockIdx.x; i < limit; i += gridDim.x) { + const T* selected_rows_ptr = selected_rows + i * row_numel; + T* tensor_out_ptr = tensor_out + rows[i] * row_numel; + for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) { + // Since index in rows of SelectedRows can be duplicate, we have to use + // Atomic Operation to avoid concurrent write error. + paddle::platform::CudaAtomicAdd( + tensor_out_ptr + index, + -static_cast(1.0) * learning_rate[0] * selected_rows_ptr[index]); + } + } +} + +template +void SGDDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + paddle::optional master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + using MPDType = typename paddle::operators::details::MPTypeTrait::Type; + // do check here + // if (multi_precision) { + // bool has_master = + // ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut"); + + // } + const MPDType* master_in_data = + multi_precision ? master_param->data() : nullptr; + MPDType* master_out_data = + multi_precision + ? master_param_out->mutable_data(dev_ctx.GetPlace()) + : nullptr; + + int block = 512; + int grid = (param.numel() + block - 1) / block; + + SGDKernelMT<<>>( + param.data(), + grad.data(), + learning_rate.data(), + param.numel(), + param_out->mutable_data(dev_ctx.GetPlace()), + master_in_data, + master_out_data); +} + +template +void SGDDenseParamSparseGradKernel( + const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + paddle::optional master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + using MPDType = typename paddle::operators::details::MPTypeTrait::Type; + // do some check here + // if (multi_precision) { + // bool has_master = + // ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut"); + + // } + const MPDType* master_in_data = + multi_precision ? 
master_param->data() : nullptr; + MPDType* master_out_data = + multi_precision + ? master_param_out->mutable_data(dev_ctx.GetPlace()) + : nullptr; + + PADDLE_ENFORCE_EQ( + ¶m, + param_out, + phi::errors::InvalidArgument( + "The input tensor Param of SgdOp should be equal with ParamOut " + "if variable's type is SelectedRows.")); + + auto in_height = grad.height(); + auto out_dims = param_out->dims(); + PADDLE_ENFORCE_EQ(in_height, + out_dims[0], + phi::errors::InvalidArgument( + "The input tensor Grad's height of SgdOp should be " + "equal with ParamOut's dims. But received Grad's " + "height [%s] and ParamOut's dims [%s]", + in_height, + out_dims[0])); + + auto& in_value = grad.value(); + auto& in_rows = grad.rows(); + + int64_t in_row_numel = in_value.numel() / in_rows.size(); + PADDLE_ENFORCE_EQ(in_row_numel, + param_out->numel() / in_height, + phi::errors::InvalidArgument( + "The in_row_numel of SgdOp should be equal with " + "param_out's numel / in_height.")); + + auto* in_data = in_value.data(); + auto* out_data = param_out->data(); + + const int kThreadsPerBlock = 256; + int thread_x = kThreadsPerBlock; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); + paddle::framework::MixVector mixv_in_rows(&in_rows); + SparseSGDFunctorKernel<<>>( + in_data, + mixv_in_rows.CUDAData(dev_ctx.GetPlace()), + learning_rate.data(), + out_data, + in_row_numel, + in_rows.size()); +} + +template +void SGDSparseParamSparseGradKernel( + const Context& dev_ctx, + const SelectedRows& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + paddle::optional master_param, + bool multi_precision, + SelectedRows* param_out, + SelectedRows* master_param_out) { + PADDLE_THROW("not impl"); +} + +} // namespace phi + +PD_REGISTER_KERNEL(sgd, + GPU, + ALL_LAYOUT, + phi::SGDDenseKernel, + phi::dtype::float16, + float, + double) {} + +PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad, + GPU, + ALL_LAYOUT, + phi::SGDDenseParamSparseGradKernel, + phi::dtype::float16, + float, + double) {} + +PD_REGISTER_KERNEL(sgd_sparse_param_sparse_grad, + GPU, + ALL_LAYOUT, + phi::SGDSparseParamSparseGradKernel, + phi::dtype::float16, + float, + double) {} diff --git a/paddle/phi/kernels/sgd_kernel.h b/paddle/phi/kernels/sgd_kernel.h new file mode 100644 index 00000000000..12361c738e2 --- /dev/null +++ b/paddle/phi/kernels/sgd_kernel.h @@ -0,0 +1,54 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/selected_rows.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SGDDenseKernel(const Context& dev_ctx,
+                    const DenseTensor& param,
+                    const DenseTensor& learning_rate,
+                    const DenseTensor& grad,
+                    paddle::optional<const DenseTensor&> master_param,
+                    bool multi_precision,
+                    DenseTensor* param_out,
+                    DenseTensor* master_param_out);
+
+template <typename T, typename Context>
+void SGDDenseParamSparseGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    paddle::optional<const DenseTensor&> master_param,
+    bool multi_precision,
+    DenseTensor* param_out,
+    DenseTensor* master_param_out);
+
+template <typename T, typename Context>
+void SGDSparseParamSparseGradKernel(
+    const Context& dev_ctx,
+    const SelectedRows& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    paddle::optional<const SelectedRows&> master_param,
+    bool multi_precision,
+    SelectedRows* param_out,
+    SelectedRows* master_param_out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/sgd_sig.cc b/paddle/phi/ops/compat/sgd_sig.cc
new file mode 100644
index 00000000000..cdf1a221f7e
--- /dev/null
+++ b/paddle/phi/ops/compat/sgd_sig.cc
@@ -0,0 +1,44 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature SGDOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.IsDenseTensorInput("Grad")) {
+    return KernelSignature("sgd",
+                           {"Param", "LearningRate", "Grad", "MasterParam"},
+                           {"multi_precision"},
+                           {"ParamOut", "MasterParamOut"});
+  } else if (ctx.IsSelectedRowsInput("Grad")) {
+    if (ctx.IsDenseTensorInput("Param")) {
+      return KernelSignature("sgd_dense_param_sparse_grad",
+                             {"Param", "LearningRate", "Grad", "MasterParam"},
+                             {"multi_precision"},
+                             {"ParamOut", "MasterParamOut"});
+    } else {
+      return KernelSignature("sgd_sparse_param_sparse_grad",
+                             {"Param", "LearningRate", "Grad", "MasterParam"},
+                             {"multi_precision"},
+                             {"ParamOut", "MasterParamOut"});
+    }
+  }
+
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(sgd, phi::SGDOpArgumentMapping);
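For quick reference, the update rules implemented by the three kernels in this patch can be summarized in a standalone C++ sketch. This is not Paddle code: plain std::vector stands in for DenseTensor/SelectedRows, and the function names sgd_dense_update / sgd_sparse_rows_update are illustrative only.

// Standalone sketch of the SGD update semantics, assuming row-major storage.
#include <cstdint>
#include <vector>

// Dense param, dense grad: param_out[i] = param[i] - lr * grad[i].
void sgd_dense_update(const std::vector<float>& param,
                      float lr,
                      const std::vector<float>& grad,
                      std::vector<float>* param_out) {
  param_out->resize(param.size());
  for (size_t i = 0; i < param.size(); ++i) {
    (*param_out)[i] = param[i] - lr * grad[i];
  }
}

// Dense param, row-sparse grad (SelectedRows-like): grad_value holds one dense
// row of width row_width per entry in rows; only those rows of param change.
void sgd_sparse_rows_update(std::vector<float>* param,  // updated in place
                            int64_t row_width,
                            float lr,
                            const std::vector<int64_t>& rows,
                            const std::vector<float>& grad_value) {
  for (size_t i = 0; i < rows.size(); ++i) {
    const int64_t row = rows[i];
    for (int64_t j = 0; j < row_width; ++j) {
      (*param)[row * row_width + j] -= lr * grad_value[i * row_width + j];
    }
  }
}

When multi_precision is enabled, the GPU dense kernel in this patch additionally maintains a higher-precision master copy: it reads master_param, applies the same rule in MPDType, and writes both the cast-back param_out and master_param_out.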