diff --git a/paddle/phi/kernels/cpu/sgd_kernel.cc b/paddle/phi/kernels/cpu/sgd_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..91b39292612b52ee47fc2bc77c7f205158bdc29c --- /dev/null +++ b/paddle/phi/kernels/cpu/sgd_kernel.cc @@ -0,0 +1,185 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sgd_kernel.h" +#include "paddle/fluid/operators/jit/kernels.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void sgd_dense_param_dense_grad_impl(const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + DenseTensor* param_out) { + const auto sz = param_out->numel(); + paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1); + const T* lr = learning_rate.data(); + const T* param_data = param.data(); + const T* grad_data = grad.data(); + int64_t rows_idx = 0; + T* out_data = param_out->data(); + + auto sgd = + paddle::operators::jit::KernelFuncs, + phi::CPUPlace>::Cache() + .At(attr); + sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr); +} + +template <> +void sgd_dense_param_dense_grad_impl( + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + DenseTensor* param_out) { + auto p = EigenVector::Flatten(param); + auto g = EigenVector::Flatten(grad); + auto o = EigenVector::Flatten(*param_out); + const auto* lr = learning_rate.data(); + + o = p - lr[0] * g; +} + +template +void sgd_dense_param_sparse_grad_impl(const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + DenseTensor* param_out) { + const auto& grad_value = grad.value(); + const auto& grad_rows = grad.rows(); + const T* param_data = param.data(); + const T* grad_data = grad_value.data(); + const T* lr = learning_rate.data(); + const int64_t* rows_data = grad_rows.data(); + T* out_data = param_out->data(); + + paddle::operators::jit::sgd_attr_t attr; + attr.param_height = param_out->dims()[0]; + attr.param_width = param_out->numel() / attr.param_height; + attr.grad_height = grad_rows.size(); // note: it is not grad->height() + attr.grad_width = grad_value.numel() / attr.grad_height; + attr.selected_rows_size = grad_rows.size(); + + auto sgd = + paddle::operators::jit::KernelFuncs, + phi::CPUPlace>::Cache() + .At(attr); + sgd(lr, param_data, grad_data, rows_data, out_data, &attr); +} + +template <> +void sgd_dense_param_sparse_grad_impl( + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + DenseTensor* param_out) { + const auto& grad_value = grad.value(); + const auto& grad_rows = grad.rows(); + const auto grad_height = grad.height(); + const int64_t grad_val_height = static_cast(grad_rows.size()); + const auto grad_width = grad_value.numel() / grad_val_height; + + const auto* grad_data = grad_value.data(); + auto* out_data = param_out->data(); + const auto* lr = learning_rate.data(); + + for (size_t i = 0; i < grad_rows.size(); ++i) { + PADDLE_ENFORCE_LT( + grad_rows[i], + grad_height, + phi::errors::OutOfRange( + "Grad rows index value should be less than grad height." + "Got [%s], but expected less than [%s]", + grad_rows[i], + grad_height)); + const int64_t row = grad_rows[i]; + for (int64_t j = 0; j < grad_width; ++j) { + out_data[row * grad_width + j] -= lr[0] * grad_data[i * grad_width + j]; + } + } +} + +template +void SGDKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + const DenseTensor& master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + dev_ctx.template Alloc(param_out); + sgd_dense_param_dense_grad_impl(param, learning_rate, grad, param_out); +} + +template +void SGDKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + const DenseTensor& master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + dev_ctx.template Alloc(param_out); + sgd_dense_param_sparse_grad_impl(param, learning_rate, grad, param_out); +} + +template +void SGDKernel(const Context& dev_ctx, + const SelectedRows& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + const SelectedRows& master_param, + bool multi_precision, + SelectedRows* param_out, + SelectedRows* master_param_out) { + // for distributed training, a sparse var may be empty, + // just skip updating. + if (grad.rows().size() == 0) { + return; + } + + auto param_row_width = param.value().dims()[1]; + auto grad_row_width = grad.value().dims()[1]; + PADDLE_ENFORCE_EQ( + param_row_width, + grad_row_width, + phi::errors::InvalidArgument( + "The param_row in SgdOP should have the same size with grad_row. " + "But received param_row's width is [%s], and grad_row's width is " + "[%s]", + param_row_width, + grad_row_width)); + + const auto* lr = learning_rate.data(); + const auto* grad_data = grad.value().data(); + auto* out_data = param_out->mutable_value()->data(); + for (size_t i = 0; i < grad.rows().size(); i++) { + int64_t id_index = param_out->AutoGrownIndex(grad.rows()[i], false); + PADDLE_ENFORCE_GE( + id_index, + static_cast(0), + phi::errors::InvalidArgument( + "The id in SgdOp should be >= 0. But recevied id_index is [%s]", + id_index)); + for (int64_t j = 0; j < grad_row_width; j++) { + out_data[id_index * grad_row_width + j] -= + lr[0] * grad_data[i * grad_row_width + j]; + } + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/sgd_kernel.cu b/paddle/phi/kernels/gpu/sgd_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..076bd0a7ad1e4ac1a70907823f7b23fb8bc620da --- /dev/null +++ b/paddle/phi/kernels/gpu/sgd_kernel.cu @@ -0,0 +1,167 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/sgd_kernel.h" + +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_helper.h" + +namespace phi { + +template +__global__ void SGDKernelMT(const T* param, + const T* grad, + const T* learning_rate, + const int num, + T* param_out, + const MT* master_param, + MT* master_param_out) { + MT lr = static_cast(learning_rate[0]); + CUDA_KERNEL_LOOP(i, num) { + MT p_data = master_param ? master_param[i] : static_cast(param[i]); + MT g_data = static_cast(grad[i]); + p_data = p_data - lr * g_data; + param_out[i] = static_cast(p_data); + if (master_param_out) { + master_param_out[i] = p_data; + } + } +} + +template +__global__ void SparseSGDFunctorKernel(const T* selected_rows, + const int64_t* rows, + const T* learning_rate, + T* tensor_out, + int64_t row_numel, + int64_t limit) { + for (int64_t i = blockIdx.x; i < limit; i += gridDim.x) { + const T* selected_rows_ptr = selected_rows + i * row_numel; + T* tensor_out_ptr = tensor_out + rows[i] * row_numel; + for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) { + // Since index in rows of SelectedRows can be duplicate, we have to use + // Atomic Operation to avoid concurrent write error. + paddle::platform::CudaAtomicAdd( + tensor_out_ptr + index, + -static_cast(1.0) * learning_rate[0] * selected_rows_ptr[index]); + } + } +} + +template +void SGDKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + const DenseTensor& master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + using MPDType = typename paddle::operators::details::MPTypeTrait::Type; + // do check here + // if (multi_precision) { + // bool has_master = + // ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut"); + + // } + const MPDType* master_in_data = + multi_precision ? master_param.data() : nullptr; + MPDType* master_out_data = + multi_precision + ? master_param_out->mutable_data(dev_ctx.GetPlace()) + : nullptr; + + int block = 512; + int grid = (param.numel() + block - 1) / block; + + SGDKernelMT<<>>( + param.data(), + grad.data(), + learning_rate.data(), + param.numel(), + param_out->mutable_data(ctx.GetPlace()), + master_in_data, + master_out_data); +} + +template +void SGDKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + const DenseTensor& master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + using MPDType = typename paddle::operators::details::MPTypeTrait::Type; + // do some check here + // if (multi_precision) { + // bool has_master = + // ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut"); + + // } + const MPDType* master_in_data = + multi_precision ? master_param.data() : nullptr; + MPDType* master_out_data = + multi_precision + ? master_param_out->mutable_data(dev_ctx.GetPlace()) + : nullptr; + + PADDLE_ENFORCE_EQ( + ¶m, + param_out, + phi::errors::InvalidArgument( + "The input tensor Param of SgdOp should be equal with ParamOut " + "if variable's type is SelectedRows.")); + + auto in_height = grad.height(); + auto out_dims = param_out->dims(); + PADDLE_ENFORCE_EQ(in_height, + out_dims[0], + phi::errors::InvalidArgument( + "The input tensor Grad's height of SgdOp should be " + "equal with ParamOut's dims. But received Grad's " + "height [%s] and ParamOut's dims [%s]", + in_height, + out_dims[0])); + + auto& in_value = grad.value(); + auto& in_rows = grad.rows(); + + int64_t in_row_numel = in_value.numel() / in_rows.size(); + PADDLE_ENFORCE_EQ(in_row_numel, + param_out->numel() / in_height, + phi::errors::InvalidArgument( + "The in_row_numel of SgdOp should be equal with " + "param_out's numel / in_height.")); + + auto* in_data = in_value.data(); + auto* out_data = param_out->data(); + + const int kThreadsPerBlock = 256; + int thread_x = kThreadsPerBlock; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); + paddle::framework::MixVector mixv_in_rows(&in_rows); + SparseSGDFunctorKernel<<>>( + in_data, + mixv_in_rows.CUDAData(dev_ctx.GetPlace()), + learning_rate.data(), + out_data, + in_row_numel, + in_rows.size()); +} + +} // namespace phi diff --git a/paddle/phi/kernels/sgd_kernel.h b/paddle/phi/kernels/sgd_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9490940325e91efddc861561ab9e5d233c34eed2 --- /dev/null +++ b/paddle/phi/kernels/sgd_kernel.h @@ -0,0 +1,52 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { + +template +void SGDKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + const DenseTensor& master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out); + +template +void SGDKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + const DenseTensor& master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out); + +template +void SGDKernel(const Context& dev_ctx, + const SelectedRows& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + const SelectedRows& master_param, + bool multi_precision, + SelectedRows* param_out, + SelectedRows* master_param_out); + +} // namespace phi