未验证 提交 f3d54e2e 编写于 作者: H hong 提交者: GitHub

Move sgd to phi (#40045)

* move sgd to phi; test=develop

* update

* add sgd kernel; test=develop
上级 3fc698fb
...@@ -2051,7 +2051,11 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2051,7 +2051,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
// deal with optional here // deal with optional here
if ((it == ctx.inputs.end() || it->second.size() == 0) && if ((it == ctx.inputs.end() || it->second.size() == 0) &&
(input_defs[i].type_index == (input_defs[i].type_index ==
std::type_index(typeid(paddle::optional<const phi::DenseTensor&>)))) { std::type_index(
typeid(paddle::optional<const phi::DenseTensor&>)) ||
input_defs[i].type_index ==
std::type_index(
typeid(paddle::optional<const phi::SelectedRows&>)))) {
pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr); pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr);
auto end_idx = start_idx + 1; auto end_idx = start_idx + 1;
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx),
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <memory> #include <memory>
#include "paddle/fluid/operators/optimizers/momentum_op.h" #include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/operators/optimizers/sgd_op.h" #include "paddle/phi/kernels/sgd_kernel.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,8 +26,7 @@ template <typename DeviceContext, typename T> ...@@ -26,8 +26,7 @@ template <typename DeviceContext, typename T>
class DGCMomentumKernel : public framework::OpKernel<T> { class DGCMomentumKernel : public framework::OpKernel<T> {
public: public:
DGCMomentumKernel() DGCMomentumKernel()
: _momentum_op_kernel(new MomentumOpKernel<DeviceContext, T>()), : _momentum_op_kernel(new MomentumOpKernel<DeviceContext, T>()) {}
_sgd_op_kernel(new SGDOpKernel<DeviceContext, T>()) {}
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto rampup_begin_step = context.Attr<float>("rampup_begin_step"); auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
...@@ -67,12 +66,68 @@ class DGCMomentumKernel : public framework::OpKernel<T> { ...@@ -67,12 +66,68 @@ class DGCMomentumKernel : public framework::OpKernel<T> {
} }
VLOG(10) << " so use sgd optimizer"; VLOG(10) << " so use sgd optimizer";
return _sgd_op_kernel->Compute(context);
const auto* param_var = context.InputVar("Param");
const auto* grad_var = context.InputVar("Grad");
auto* learning_rate = context.Input<framework::Tensor>("LearningRate");
bool multi_precision = context.Attr<bool>("multi_precision");
if (param_var->IsType<framework::LoDTensor>()) {
auto* param = context.Input<framework::Tensor>("Param");
auto* param_out = context.Output<framework::Tensor>("ParamOut");
auto* master_param_out =
context.Output<framework::Tensor>("MasterParamOut");
paddle::optional<const framework::Tensor&> master_param_opt =
paddle::none;
if (multi_precision) {
auto* master_param = context.Input<framework::Tensor>("MasterParam");
master_param_opt = *master_param;
}
if (grad_var->IsType<framework::Tensor>()) {
// sgd_dense
auto* grad = context.Input<framework::Tensor>("Grad");
phi::SGDDenseKernel<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*param, *learning_rate, *grad, master_param_opt, multi_precision,
param_out, master_param_out);
} else {
// sgd dense param sparse grad
auto* grad = context.Input<phi::SelectedRows>("Grad");
phi::SGDDenseParamSparseGradKernel<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*param, *learning_rate, *grad, master_param_opt, multi_precision,
param_out, master_param_out);
}
} else if (param_var->IsType<phi::SelectedRows>() &&
grad_var->IsType<phi::SelectedRows>() &&
platform::is_cpu_place(context.GetPlace())) {
// sgd sparse param sparse grad
auto* param = context.Input<phi::SelectedRows>("Param");
auto* param_out = context.Output<phi::SelectedRows>("ParamOut");
auto* master_param_out =
context.Output<phi::SelectedRows>("MasterParamOut");
paddle::optional<const phi::SelectedRows&> master_param_opt =
paddle::none;
if (multi_precision) {
auto* master_param = context.Input<phi::SelectedRows>("MasterParam");
master_param_opt = *master_param;
}
auto* grad = context.Input<phi::SelectedRows>("Grad");
phi::SGDSparseParamSparseGradKernel<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*param, *learning_rate, *grad, master_param_opt, multi_precision,
param_out, master_param_out);
} else {
PADDLE_THROW("gdc not support yet");
}
} }
private: private:
std::unique_ptr<MomentumOpKernel<DeviceContext, T>> _momentum_op_kernel; std::unique_ptr<MomentumOpKernel<DeviceContext, T>> _momentum_op_kernel;
std::unique_ptr<SGDOpKernel<DeviceContext, T>> _sgd_op_kernel;
}; };
} // namespace operators } // namespace operators
......
...@@ -166,8 +166,3 @@ REGISTER_OPERATOR( ...@@ -166,8 +166,3 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::SGDOpInferVarType); ops::SGDOpInferVarType);
REGISTER_OP_CPU_KERNEL(
sgd, ops::SGDOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SGDOpKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::SGDOpKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -166,10 +166,3 @@ class SGDOpKernel<platform::CUDADeviceContext, T> ...@@ -166,10 +166,3 @@ class SGDOpKernel<platform::CUDADeviceContext, T>
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
sgd, ops::SGDOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SGDOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SGDOpKernel<paddle::platform::CUDADeviceContext, plat::float16>);
...@@ -81,6 +81,12 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -81,6 +81,12 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
} else if (arg_type == std::type_index(typeid(
paddle::optional<const SelectedRows&>))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == } else if (arg_type ==
std::type_index(typeid(const std::vector<DenseTensor>&))) { std::type_index(typeid(const std::vector<DenseTensor>&))) {
args_def->AppendInput(default_key.backend(), args_def->AppendInput(default_key.backend(),
......
...@@ -219,6 +219,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> { ...@@ -219,6 +219,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows); PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sgd_kernel.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T>
void sgd_dense_param_dense_grad_impl(const DenseTensor& param,
const DenseTensor& learning_rate,
const DenseTensor& grad,
DenseTensor* param_out) {
const auto sz = param_out->numel();
paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
const T* lr = learning_rate.data<T>();
const T* param_data = param.data<T>();
const T* grad_data = grad.data<T>();
int64_t rows_idx = 0;
T* out_data = param_out->data<T>();
auto sgd =
paddle::operators::jit::KernelFuncs<paddle::operators::jit::SgdTuple<T>,
phi::CPUPlace>::Cache()
.At(attr);
sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
}
template <>
void sgd_dense_param_dense_grad_impl<phi::dtype::bfloat16>(
const DenseTensor& param,
const DenseTensor& learning_rate,
const DenseTensor& grad,
DenseTensor* param_out) {
auto p = EigenVector<phi::dtype::bfloat16>::Flatten(param);
auto g = EigenVector<phi::dtype::bfloat16>::Flatten(grad);
auto o = EigenVector<phi::dtype::bfloat16>::Flatten(*param_out);
const auto* lr = learning_rate.data<phi::dtype::bfloat16>();
o = p - lr[0] * g;
}
template <typename T>
void sgd_dense_param_sparse_grad_impl(const DenseTensor& param,
const DenseTensor& learning_rate,
const SelectedRows& grad,
DenseTensor* param_out) {
const auto& grad_value = grad.value();
const auto& grad_rows = grad.rows();
const T* param_data = param.data<T>();
const T* grad_data = grad_value.data<T>();
const T* lr = learning_rate.data<T>();
const int64_t* rows_data = grad_rows.data();
T* out_data = param_out->data<T>();
paddle::operators::jit::sgd_attr_t attr;
attr.param_height = param_out->dims()[0];
attr.param_width = param_out->numel() / attr.param_height;
attr.grad_height = grad_rows.size(); // note: it is not grad->height()
attr.grad_width = grad_value.numel() / attr.grad_height;
attr.selected_rows_size = grad_rows.size();
auto sgd =
paddle::operators::jit::KernelFuncs<paddle::operators::jit::SgdTuple<T>,
phi::CPUPlace>::Cache()
.At(attr);
sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
}
template <>
void sgd_dense_param_sparse_grad_impl<phi::dtype::bfloat16>(
const DenseTensor& param,
const DenseTensor& learning_rate,
const SelectedRows& grad,
DenseTensor* param_out) {
const auto& grad_value = grad.value();
const auto& grad_rows = grad.rows();
const auto grad_height = grad.height();
const int64_t grad_val_height = static_cast<int64_t>(grad_rows.size());
const auto grad_width = grad_value.numel() / grad_val_height;
const auto* grad_data = grad_value.data<phi::dtype::bfloat16>();
auto* out_data = param_out->data<phi::dtype::bfloat16>();
const auto* lr = learning_rate.data<phi::dtype::bfloat16>();
for (size_t i = 0; i < grad_rows.size(); ++i) {
PADDLE_ENFORCE_LT(
grad_rows[i],
grad_height,
phi::errors::OutOfRange(
"Grad rows index value should be less than grad height."
"Got [%s], but expected less than [%s]",
grad_rows[i],
grad_height));
const int64_t row = grad_rows[i];
for (int64_t j = 0; j < grad_width; ++j) {
out_data[row * grad_width + j] -= lr[0] * grad_data[i * grad_width + j];
}
}
}
template <typename T, typename Context>
void SGDDenseKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& learning_rate,
const DenseTensor& grad,
paddle::optional<const DenseTensor&> master_param,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* master_param_out) {
dev_ctx.template Alloc<T>(param_out);
sgd_dense_param_dense_grad_impl<T>(param, learning_rate, grad, param_out);
}
template <typename T, typename Context>
void SGDDenseParamSparseGradKernel(
const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& learning_rate,
const SelectedRows& grad,
paddle::optional<const DenseTensor&> master_param,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* master_param_out) {
dev_ctx.template Alloc<T>(param_out);
sgd_dense_param_sparse_grad_impl<T>(param, learning_rate, grad, param_out);
}
template <typename T, typename Context>
void SGDSparseParamSparseGradKernel(
const Context& dev_ctx,
const SelectedRows& param,
const DenseTensor& learning_rate,
const SelectedRows& grad,
paddle::optional<const SelectedRows&> master_param,
bool multi_precision,
SelectedRows* param_out,
SelectedRows* master_param_out) {
// for distributed training, a sparse var may be empty,
// just skip updating.
if (grad.rows().size() == 0) {
return;
}
auto param_row_width = param.value().dims()[1];
auto grad_row_width = grad.value().dims()[1];
PADDLE_ENFORCE_EQ(
param_row_width,
grad_row_width,
phi::errors::InvalidArgument(
"The param_row in SgdOP should have the same size with grad_row. "
"But received param_row's width is [%s], and grad_row's width is "
"[%s]",
param_row_width,
grad_row_width));
const auto* lr = learning_rate.data<T>();
const auto* grad_data = grad.value().data<T>();
auto* out_data = param_out->mutable_value()->data<T>();
for (size_t i = 0; i < grad.rows().size(); i++) {
int64_t id_index = param_out->AutoGrownIndex(grad.rows()[i], false);
PADDLE_ENFORCE_GE(
id_index,
static_cast<int64_t>(0),
phi::errors::InvalidArgument(
"The id in SgdOp should be >= 0. But recevied id_index is [%s]",
id_index));
for (int64_t j = 0; j < grad_row_width; j++) {
out_data[id_index * grad_row_width + j] -=
lr[0] * grad_data[i * grad_row_width + j];
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(sgd,
CPU,
ALL_LAYOUT,
phi::SGDDenseKernel,
phi::dtype::bfloat16,
float,
double) {}
PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
CPU,
ALL_LAYOUT,
phi::SGDDenseParamSparseGradKernel,
phi::dtype::bfloat16,
float,
double) {}
PD_REGISTER_KERNEL(sgd_sparse_param_sparse_grad,
CPU,
ALL_LAYOUT,
phi::SGDSparseParamSparseGradKernel,
phi::dtype::bfloat16,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sgd_kernel.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_helper.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename MT>
__global__ void SGDKernelMT(const T* param,
const T* grad,
const T* learning_rate,
const int num,
T* param_out,
const MT* master_param,
MT* master_param_out) {
MT lr = static_cast<MT>(learning_rate[0]);
CUDA_KERNEL_LOOP(i, num) {
MT p_data = master_param ? master_param[i] : static_cast<MT>(param[i]);
MT g_data = static_cast<MT>(grad[i]);
p_data = p_data - lr * g_data;
param_out[i] = static_cast<T>(p_data);
if (master_param_out) {
master_param_out[i] = p_data;
}
}
}
template <typename T>
__global__ void SparseSGDFunctorKernel(const T* selected_rows,
const int64_t* rows,
const T* learning_rate,
T* tensor_out,
int64_t row_numel,
int64_t limit) {
for (int64_t i = blockIdx.x; i < limit; i += gridDim.x) {
const T* selected_rows_ptr = selected_rows + i * row_numel;
T* tensor_out_ptr = tensor_out + rows[i] * row_numel;
for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) {
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle::platform::CudaAtomicAdd(
tensor_out_ptr + index,
-static_cast<T>(1.0) * learning_rate[0] * selected_rows_ptr[index]);
}
}
}
template <typename T, typename Context>
void SGDDenseKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& learning_rate,
const DenseTensor& grad,
paddle::optional<const DenseTensor&> master_param,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* master_param_out) {
using MPDType = typename paddle::operators::details::MPTypeTrait<T>::Type;
// do check here
// if (multi_precision) {
// bool has_master =
// ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
// }
const MPDType* master_in_data =
multi_precision ? master_param->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision
? master_param_out->mutable_data<MPDType>(dev_ctx.GetPlace())
: nullptr;
int block = 512;
int grid = (param.numel() + block - 1) / block;
SGDKernelMT<T, MPDType><<<grid, block, 0, dev_ctx.stream()>>>(
param.data<T>(),
grad.data<T>(),
learning_rate.data<T>(),
param.numel(),
param_out->mutable_data<T>(dev_ctx.GetPlace()),
master_in_data,
master_out_data);
}
template <typename T, typename Context>
void SGDDenseParamSparseGradKernel(
const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& learning_rate,
const SelectedRows& grad,
paddle::optional<const DenseTensor&> master_param,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* master_param_out) {
using MPDType = typename paddle::operators::details::MPTypeTrait<T>::Type;
// do some check here
// if (multi_precision) {
// bool has_master =
// ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
// }
const MPDType* master_in_data =
multi_precision ? master_param->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision
? master_param_out->mutable_data<MPDType>(dev_ctx.GetPlace())
: nullptr;
PADDLE_ENFORCE_EQ(
&param,
param_out,
phi::errors::InvalidArgument(
"The input tensor Param of SgdOp should be equal with ParamOut "
"if variable's type is SelectedRows."));
auto in_height = grad.height();
auto out_dims = param_out->dims();
PADDLE_ENFORCE_EQ(in_height,
out_dims[0],
phi::errors::InvalidArgument(
"The input tensor Grad's height of SgdOp should be "
"equal with ParamOut's dims. But received Grad's "
"height [%s] and ParamOut's dims [%s]",
in_height,
out_dims[0]));
auto& in_value = grad.value();
auto& in_rows = grad.rows();
int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel,
param_out->numel() / in_height,
phi::errors::InvalidArgument(
"The in_row_numel of SgdOp should be equal with "
"param_out's numel / in_height."));
auto* in_data = in_value.data<T>();
auto* out_data = param_out->data<T>();
const int kThreadsPerBlock = 256;
int thread_x = kThreadsPerBlock;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
paddle::framework::MixVector<int64_t> mixv_in_rows(&in_rows);
SparseSGDFunctorKernel<<<max_blocks, thread_x, 0, dev_ctx.stream()>>>(
in_data,
mixv_in_rows.CUDAData(dev_ctx.GetPlace()),
learning_rate.data<T>(),
out_data,
in_row_numel,
in_rows.size());
}
template <typename T, typename Context>
void SGDSparseParamSparseGradKernel(
const Context& dev_ctx,
const SelectedRows& param,
const DenseTensor& learning_rate,
const SelectedRows& grad,
paddle::optional<const SelectedRows&> master_param,
bool multi_precision,
SelectedRows* param_out,
SelectedRows* master_param_out) {
PADDLE_THROW("not impl");
}
} // namespace phi
PD_REGISTER_KERNEL(sgd,
GPU,
ALL_LAYOUT,
phi::SGDDenseKernel,
phi::dtype::float16,
float,
double) {}
PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
GPU,
ALL_LAYOUT,
phi::SGDDenseParamSparseGradKernel,
phi::dtype::float16,
float,
double) {}
PD_REGISTER_KERNEL(sgd_sparse_param_sparse_grad,
GPU,
ALL_LAYOUT,
phi::SGDSparseParamSparseGradKernel,
phi::dtype::float16,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
template <typename T, typename Context>
void SGDDenseKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& learning_rate,
const DenseTensor& grad,
paddle::optional<const DenseTensor&> master_param,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* master_param_out);
template <typename T, typename Context>
void SGDDenseParamSparseGradKernel(
const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& learning_rate,
const SelectedRows& grad,
paddle::optional<const DenseTensor&> master_param,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* master_param_out);
template <typename T, typename Context>
void SGDSparseParamSparseGradKernel(
const Context& dev_ctx,
const SelectedRows& param,
const DenseTensor& learning_rate,
const SelectedRows& grad,
paddle::optional<const SelectedRows&> master_param,
bool multi_precision,
SelectedRows* param_out,
SelectedRows* master_param_out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SGDOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("Grad")) {
return KernelSignature("sgd",
{"Param", "LearningRate", "Grad", "MasterParam"},
{"multi_precision"},
{"ParamOut", "MasterParamOut"});
} else if (ctx.IsSelectedRowsInput("Grad")) {
if (ctx.IsDenseTensorInput("Param")) {
return KernelSignature("sgd_dense_param_sparse_grad",
{"Param", "LearningRate", "Grad", "MasterParam"},
{"multi_precision"},
{"ParamOut", "MasterParamOut"});
} else {
return KernelSignature("sgd_sparse_param_sparse_grad",
{"Param", "LearningRate", "Grad", "MasterParam"},
{"multi_precision"},
{"ParamOut", "MasterParamOut"});
}
}
return KernelSignature("unregistered", {}, {}, {});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(sgd, phi::SGDOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册