diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index c4e42467cdef3c9861f2ae5bb4d0fa14a2e9418b..f77287826ffb3572de3e1ce7fd35e99c981c474f 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -280,6 +280,46 @@ struct SelectedRowsAddToTensor { } }; +template +struct SelectedRowsAddToTensor { + void operator()(const phi::CPUContext& context, + const phi::SelectedRows& input1, framework::Tensor* input2) { + if (UNLIKELY(input1.rows().size() == 0)) { + LOG(WARNING) << "input selected rows is empty!"; + return; + } + auto in1_height = input1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ( + in1_height, in2_dims[0], + platform::errors::InvalidArgument("The two inputs height must be equal." + "But recieved first input height = " + "[%d], second input height = [%d]", + in1_height, in2_dims[0])); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ( + in1_row_numel, input2->numel() / in1_height, + platform::errors::InvalidArgument( + "The two inputs width must be equal." 
+ "But recieved first input width = [%d], second input width = [%d]", + in1_row_numel, input2->numel() / in1_height)); + + auto* in1_data = in1_value.data(); + auto* input2_data = input2->data(); + + for (size_t i = 0; i < in1_rows.size(); i++) { + for (int64_t j = 0; j < in1_row_numel; j++) { + input2_data[in1_rows[i] * in1_row_numel + j] += + in1_data[i * in1_row_numel + j]; + } + } + } +}; + template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; @@ -287,6 +327,11 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; // This is a separated namespace for manipulate SelectedRows typed // data. Like merge duplicated rows, adding two SelectedRows etc. // diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index 16ef013f689c4fc438f0dc7d8e5a9bd529001ea2..542d4c9784352e98d0c033fd4a25d8b7af58d4ab 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -174,12 +174,77 @@ struct SelectedRowsAddTensor { } }; +template +struct SelectedRowsAddTensor { + void operator()(const phi::GPUContext& context, + const phi::SelectedRows& input1, + const framework::Tensor& input2, framework::Tensor* output) { + auto in1_height = input1.height(); + auto in2_dims = input2.dims(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_EQ( + in1_height, in2_dims[0], + platform::errors::InvalidArgument( + "The two inputs height must be equal." 
+ "But recieved first input height = [%d], first input height = [%d]", + in1_height, in2_dims[0])); + PADDLE_ENFORCE_EQ( + in1_height, out_dims[0], + platform::errors::InvalidArgument( + "The input and output height must be equal." + "But recieved input height = [%d], output height = [%d]", + in1_height, out_dims[0])); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ( + in1_row_numel, input2.numel() / in1_height, + platform::errors::InvalidArgument( + "The two inputs width must be equal." + "But recieved first input width = [%d], second input width = [%d]", + in1_row_numel, input2.numel() / in1_height)); + PADDLE_ENFORCE_EQ( + in1_row_numel, output->numel() / in1_height, + platform::errors::InvalidArgument( + "The input and output width must be equal." + "But recieved input width = [%d], output width = [%d]", + in1_row_numel, output->numel() / in1_height)); + + auto* in1_data = in1_value.data(); + auto* in2_data = input2.data(); + auto* out_data = output->data(); + + phi::funcs::SetConstant functor; + functor(context, output, static_cast(0)); + + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(in1_rows.size(), 1); + paddle::framework::MixVector mixv_in1_rows(&in1_rows); + SelectedRowsAddTensorKernel< + T, block_size><<>>( + in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), out_data, + in1_row_numel); + + auto out_eigen = framework::EigenVector::Flatten(*output); + auto in2_eigen = framework::EigenVector::Flatten(input2); + out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen; + } +}; + template struct SelectedRowsAddTensor; template struct SelectedRowsAddTensor; template struct SelectedRowsAdd; template struct SelectedRowsAddTensor; +template struct SelectedRowsAddTensor; +template struct SelectedRowsAddTensor; +template struct SelectedRowsAdd; +template struct SelectedRowsAddTensor; + template struct SelectedRowsAddTo 
{ void operator()(const platform::CUDADeviceContext& context, @@ -285,12 +350,54 @@ struct SelectedRowsAddToTensor { } }; +template +struct SelectedRowsAddToTensor { + void operator()(const phi::GPUContext& context, + const phi::SelectedRows& input1, framework::Tensor* input2) { + auto in1_height = input1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ( + in1_height, in2_dims[0], + platform::errors::InvalidArgument("The two inputs height must be equal." + "But recieved first input height = " + "[%d], second input height = [%d]", + in1_height, in2_dims[0])); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ( + in1_row_numel, input2->numel() / in1_height, + platform::errors::InvalidArgument( + "The two inputs width must be equal." + "But recieved first input width = [%d], second input width = [%d]", + in1_row_numel, input2->numel() / in1_height)); + + auto* in1_data = in1_value.data(); + auto* in2_data = input2->data(); + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(in1_rows.size(), 1); + paddle::framework::MixVector mixv_in1_rows(&in1_rows); + SelectedRowsAddToTensorKernel< + T, block_size><<>>( + in1_data, mixv_in1_rows.CUDAData(context.GetPlace()), in2_data, + in1_row_numel); + } +}; + template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; +template struct SelectedRowsAddToTensor; namespace scatter { diff --git a/paddle/fluid/operators/meshgrid_op.cc b/paddle/fluid/operators/meshgrid_op.cc index 741c4bb65d80750e81cb575495fb83f35b27c455..103169fedb90e67c9a3afd4b7a95fdb82cde8bb0 100644 --- 
a/paddle/fluid/operators/meshgrid_op.cc +++ b/paddle/fluid/operators/meshgrid_op.cc @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/meshgrid_op.h" - #include #include #include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + namespace paddle { namespace operators { @@ -145,29 +146,3 @@ REGISTER_OPERATOR(meshgrid, ops::MeshgridOp, ops::MeshgridOpMaker, ops::MeshgridGradOpMaker, ops::MeshgridGradOpMaker); REGISTER_OPERATOR(meshgrid_grad, ops::MeshgridGradOp); -REGISTER_OP_CPU_KERNEL( - meshgrid, ops::MeshgridKernel, - ops::MeshgridKernel, - ops::MeshgridKernel, - ops::MeshgridKernel); - -REGISTER_OP_CPU_KERNEL( - meshgrid_grad, - ops::MeshgridGradKernel, - ops::MeshgridGradKernel, - ops::MeshgridGradKernel, - ops::MeshgridGradKernel); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -REGISTER_OP_CUDA_KERNEL( - meshgrid, ops::MeshgridKernel, - ops::MeshgridKernel, - ops::MeshgridKernel, - ops::MeshgridKernel, - ops::MeshgridKernel); -REGISTER_OP_CUDA_KERNEL( - meshgrid_grad, - ops::MeshgridGradKernel, - ops::MeshgridGradKernel, - ops::MeshgridGradKernel, - ops::MeshgridGradKernel); -#endif diff --git a/paddle/fluid/operators/meshgrid_op.h b/paddle/fluid/operators/meshgrid_op.h deleted file mode 100644 index 4fef0797099c4c8220f08eddbb1a234e7cec70fd..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/meshgrid_op.h +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/platform/errors.h" - -#define MAX_RANK_SUPPORTED 6 - -namespace paddle { -namespace operators { - -template -class MeshgridKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto ins = context.MultiInput("X"); - auto rank = ins.size(); - switch (rank) { - case 1: - MeshgridForward<1>(context); - break; - case 2: - MeshgridForward<2>(context); - break; - case 3: - MeshgridForward<3>(context); - break; - case 4: - MeshgridForward<4>(context); - break; - case 5: - MeshgridForward<5>(context); - break; - case 6: - MeshgridForward<6>(context); - break; - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Excepted Tensor numbers between 1 and 6, but only received d% .", - rank)); - } - } - - protected: - template - void MeshgridForward(const framework::ExecutionContext& context) const { - auto ins = context.MultiInput("X"); - auto outs = context.MultiOutput("Out"); - PADDLE_ENFORCE_EQ( - ins.size() > 1, true, - platform::errors::InvalidArgument( - "Expected at least 2 input tensors, but only received d%.", - ins.size())); - - int64_t size = ins.size(); - std::vector shape(size); - - for (int64_t i = 0; i < size; i++) { - switch (ins[i]->dims().size()) { - case 0: - shape[i] = 1; - break; - case 1: - shape[i] = 
ins[i]->dims()[0]; - break; - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Expected scalar or 1D tensor in the tensor list but got tensor " - "%d: ", - i)); - } - } - - for (int64_t i = 0; i < size; i++) { - std::vector view_shape(size, 1); - view_shape[i] = shape[i]; - - framework::Tensor reshape_ins_tensor; - paddle::framework::TensorCopy(*ins[i], context.GetPlace(), - context.device_context(), - &reshape_ins_tensor); - framework::DDim out_dims_reshape = phi::make_ddim(view_shape); - reshape_ins_tensor.Resize(out_dims_reshape); - framework::DDim out_dims = phi::make_ddim(shape); - - Eigen::DSizes bcast_dims; - for (int64_t j = 0; j < size; j++) { - bcast_dims[j] = shape[j]; - } - bcast_dims[i] = 1; - - outs[i]->Resize(out_dims); - auto x = framework::EigenTensor::From( - static_cast(reshape_ins_tensor)); - outs[i]->mutable_data(context.GetPlace()); - auto y = framework::EigenTensor::From(*outs[i]); - auto& place = - *context.template device_context().eigen_device(); - EigenBroadcast, T, Rank>::Eval(place, y, x, - bcast_dims); - } - } -}; - -template -class MeshgridGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto out_grad = - context.MultiInput(framework::GradVarName("Out")); - int n = out_grad.size(); - switch (n) { - case 1: - MeshgridBackward<1>(context); - break; - case 2: - MeshgridBackward<2>(context); - break; - case 3: - MeshgridBackward<3>(context); - break; - case 4: - MeshgridBackward<4>(context); - break; - case 5: - MeshgridBackward<5>(context); - break; - case 6: - MeshgridBackward<6>(context); - break; - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Excepted Tensor numbers between 1 and 6, but only received d% .", - n)); - } - } - - protected: - template - void MeshgridBackward(const framework::ExecutionContext& context) const { - auto out_grad = - context.MultiInput(framework::GradVarName("Out")); - auto ins = 
context.MultiInput("X"); - auto outs = - context.MultiOutput(framework::GradVarName("X")); - - int n = out_grad.size(); - auto out_dims = out_grad[0]->dims(); - - for (int i = 0; i < n; i++) { - outs[i]->mutable_data(context.GetPlace()); - auto out_grad_tmp = framework::EigenVector::Flatten(*out_grad[i]); - auto in_grad = framework::EigenVector::Flatten(*outs[i]); - - std::vector reduce_dims_vec; - std::vector reshape_dims_vec; - for (int j = 0; j < n; j++) { - reduce_dims_vec.push_back(reshape_dims_vec.size()); - if (j == i) { - reshape_dims_vec.push_back(1); - reshape_dims_vec.push_back(out_dims[j]); - } else { - reshape_dims_vec.push_back(out_dims[j]); - reshape_dims_vec.push_back(1); - } - } - - Eigen::DSizes reduce_dims; - for (int k = 0; k < n; k++) { - reduce_dims[k] = reduce_dims_vec[k]; - } - - Eigen::DSizes reshape_dims; - for (int k = 0; k < n * 2; k++) { - reshape_dims[k] = reshape_dims_vec[k]; - } - - auto& place = - *context.template device_context().eigen_device(); - EigenBroadcastGrad, T, Rank>::Eval( - place, in_grad, out_grad_tmp, reduce_dims, reshape_dims); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/meshgrid_op_npu.cc b/paddle/fluid/operators/meshgrid_op_npu.cc index c73db5e940df793129e69688dff0577951606f0a..4b6fccd14d7e964522fde0f00e8c03c6ad3cd89a 100644 --- a/paddle/fluid/operators/meshgrid_op_npu.cc +++ b/paddle/fluid/operators/meshgrid_op_npu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/operators/meshgrid_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cc b/paddle/fluid/operators/optimizers/adagrad_op.cc index 1d73c7a6db561f94bf5a8758e7c722cefe737740..33c4cf94cf25a07675b38f914b96a9cd8b1c57d8 100644 --- a/paddle/fluid/operators/optimizers/adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/adagrad_op.cc @@ -12,11 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/optimizers/adagrad_op.h" -#include - #include +#include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/phi/kernels/funcs/math_function.h" @@ -102,54 +101,8 @@ for numerical stability to avoid the division by zero error. } }; -namespace { -size_t FindPos(const std::vector& rows, int64_t value) { - return std::find(rows.begin(), rows.end(), value) - rows.begin(); -} -} // namespace - -template -struct SparseAdagradFunctor { - void operator()(const platform::CPUDeviceContext& context, - const phi::SelectedRows& grad, - const framework::Tensor& learning_rate, T epsilon, - framework::Tensor* moment, framework::Tensor* param) { - // 1. g_m.rows = set(g.rows) - auto grad_width = grad.value().dims()[1]; - math::scatter::MergeAdd merge_func; - auto grad_merge = merge_func(context, grad); - auto& merge_rows = grad_merge.rows(); - auto* grad_merge_data = grad_merge.mutable_value()->template data(); - - // 2. m += g_m * g_m - auto grad_square = - SquareSelectedRows(context, grad_merge); - - math::SelectedRowsAddToTensor functor; - functor(context, grad_square, moment); - - // 3. 
update parameter - auto* lr = learning_rate.data(); - auto* param_data = param->data(); - auto* moment_data = moment->data(); - - for (size_t i = 0; i < merge_rows.size(); i++) { - for (int64_t j = 0; j < grad_width; j++) { - param_data[merge_rows[i] * grad_width + j] -= - lr[0] * grad_merge_data[i * grad_width + j] / - (std::sqrt(moment_data[merge_rows[i] * grad_width + j]) + epsilon); - } - } - } -}; - -template struct SparseAdagradFunctor; -template struct SparseAdagradFunctor; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker); -REGISTER_OP_CPU_KERNEL( - adagrad, ops::AdagradOpKernel, - ops::AdagradOpKernel); diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cu b/paddle/fluid/operators/optimizers/adagrad_op.cu deleted file mode 100644 index 3b8ef9056946a1f84d98621442394dbf3e806576..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/adagrad_op.cu +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ -#include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/operators/optimizers/adagrad_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -namespace { - -template -__global__ void MergeGradKernel(const T* grad, const int64_t* grad_rows, - T* grad_merge, const int64_t* grad_merge_rows, - size_t grad_merge_rows_size, - int64_t row_numel) { - const int ty = blockIdx.y; - int tid = threadIdx.x; - __shared__ size_t grad_merge_idx; - - if (tid == 0) { - for (size_t i = 0; i < grad_merge_rows_size; i++) { - if (grad_rows[ty] == grad_merge_rows[i]) { - grad_merge_idx = i; - } - } - } - - __syncthreads(); - - grad += ty * row_numel; - grad_merge += grad_merge_idx * row_numel; - for (int index = tid; index < row_numel; index += block_size) { - paddle::platform::CudaAtomicAdd(grad_merge + index, grad[index]); - } -} - -template -__global__ void SparseAdagradFunctorKernel(const T* grad, const int64_t* rows, - const T* learning_rate, T* param, - T* moment, int64_t row_numel, - T epsilon) { - const int ty = blockIdx.y; - int tid = threadIdx.x; - - grad += ty * row_numel; - param += rows[ty] * row_numel; - moment += rows[ty] * row_numel; - - for (int index = tid; index < row_numel; index += block_size) { - // Since index in rows of SelectedRows can be duplicate, we have to use - // Atomic Operation to avoid concurrent write error. - paddle::platform::CudaAtomicAdd(param + index, - -1.0 * learning_rate[0] * grad[index] / - (sqrt(moment[index]) + epsilon)); - } -} -} // namespace - -template -struct SparseAdagradFunctor { - void operator()(const platform::CUDADeviceContext& context, - const phi::SelectedRows& grad, - const framework::Tensor& learning_rate, T epsilon, - framework::Tensor* moment, framework::Tensor* param) { - // 1. 
g_m.rows = set(g.rows) - auto grad_width = grad.value().dims()[1]; - math::scatter::MergeAdd merge_func; - auto grad_merge = merge_func(context, grad); - auto* grad_merge_data = grad_merge.mutable_value()->template data(); - framework::Vector merge_rows(grad_merge.rows()); - // 2. m += g_m * g_m - auto grad_square = - SquareSelectedRows(context, grad_merge); - - math::SelectedRowsAddToTensor functor; - functor(context, grad_square, moment); - - // 3. update parameter - auto* lr = learning_rate.data(); - auto* param_data = param->data(); - auto* moment_data = moment->data(); - - const int block_size = 256; - dim3 threads(block_size, 1); - dim3 grid2(1, merge_rows.size()); - paddle::framework::MixVector mixv_merge_rows(&merge_rows); - SparseAdagradFunctorKernel< - T, 256><<(context) - .stream()>>>( - grad_merge_data, mixv_merge_rows.CUDAMutableData(context.GetPlace()), - lr, param_data, moment_data, grad_width, epsilon); - mixv_merge_rows.CopyToCPU(); - } -}; - -template struct SparseAdagradFunctor; -template struct SparseAdagradFunctor; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - adagrad, ops::AdagradOpKernel, - ops::AdagradOpKernel); diff --git a/paddle/fluid/operators/optimizers/adagrad_op.h b/paddle/fluid/operators/optimizers/adagrad_op.h deleted file mode 100644 index 63f4f4e0bb03115b4bc9095752a3e18ebf325ccd..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/adagrad_op.h +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -struct SparseAdagradFunctor { - void operator()(const DeviceContext &context, const phi::SelectedRows &grad, - const framework::Tensor &learning_rate, T epsilon, - framework::Tensor *moment, framework::Tensor *param); -}; - -template -phi::SelectedRows SquareSelectedRows(const DeviceContext &context, - const phi::SelectedRows &input) { - phi::SelectedRows out; - out.set_rows(input.rows()); - out.set_height(input.height()); - out.mutable_value()->mutable_data(input.value().dims(), - context.GetPlace()); - auto e_out = framework::EigenVector::Flatten(*(out.mutable_value())); - auto e_in = framework::EigenVector::Flatten(input.value()); - e_out.device(*context.eigen_device()) = e_in.square(); - return out; -} - -template -class AdagradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - const auto *param_var = ctx.InputVar("Param"); - PADDLE_ENFORCE_EQ(param_var->IsType(), true, - platform::errors::InvalidArgument( - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.InputNames("Param").front(), - framework::ToTypeName(param_var->Type()))); - - auto *param_out_tensor = ctx.Output("ParamOut"); - auto *moment_out_tensor = ctx.Output("MomentOut"); - - param_out_tensor->mutable_data(ctx.GetPlace()); - moment_out_tensor->mutable_data(ctx.GetPlace()); - - T epsilon = static_cast(ctx.Attr("epsilon")); - - auto 
*grad_var = ctx.InputVar("Grad"); - if (grad_var->IsType()) { - auto param = framework::EigenVector::Flatten( - *ctx.Input("Param")); - auto grad = framework::EigenVector::Flatten( - *ctx.Input("Grad")); - auto moment = framework::EigenVector::Flatten( - *ctx.Input("Moment")); - auto *learning_rate = ctx.Input("LearningRate"); - - auto param_out = framework::EigenVector::Flatten(*param_out_tensor); - auto moment_out = framework::EigenVector::Flatten(*moment_out_tensor); - auto *place = ctx.template device_context().eigen_device(); - - moment_out.device(*place) = moment + grad * grad; - Eigen::DSizes m_dsize(moment_out_tensor->numel()); - if (platform::is_cpu_place(ctx.GetPlace())) { - auto *lr = learning_rate->data(); - param_out.device(*place) = - param - lr[0] * grad / (moment_out.sqrt() + epsilon); - } else { - auto lr = framework::EigenVector::Flatten(*learning_rate); - param_out.device(*place) = - param - - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon); - } - } else if (grad_var->IsType()) { - auto *param_tensor = ctx.Input("Param"); - PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor, - platform::errors::InvalidArgument( - "the input tensor not euqal with output tensor")); - - auto *moment_tensor = ctx.Input("Moment"); - PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor, - platform::errors::InvalidArgument( - "the input moment not eual with output moment")); - - SparseAdagradFunctor functor; - functor(ctx.template device_context(), - *ctx.Input("Grad"), - *ctx.Input("LearningRate"), epsilon, - moment_out_tensor, param_out_tensor); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported Variable Type of Grad")); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/optimizers/dgc_momentum_op.h b/paddle/fluid/operators/optimizers/dgc_momentum_op.h index c86f544ed77ff13cc59735971cf856f66bc12202..fc954e60a8c3e9dd2a1f5078586223407a586f4b 100644 --- 
a/paddle/fluid/operators/optimizers/dgc_momentum_op.h +++ b/paddle/fluid/operators/optimizers/dgc_momentum_op.h @@ -17,6 +17,7 @@ #include #include "paddle/fluid/operators/optimizers/momentum_op.h" +#include "paddle/phi/kernels/momentum_kernel.h" #include "paddle/phi/kernels/sgd_kernel.h" namespace paddle { @@ -25,8 +26,7 @@ namespace operators { template class DGCMomentumKernel : public framework::OpKernel { public: - DGCMomentumKernel() - : _momentum_op_kernel(new MomentumOpKernel()) {} + DGCMomentumKernel() {} void Compute(const framework::ExecutionContext& context) const override { auto rampup_begin_step = context.Attr("rampup_begin_step"); @@ -60,15 +60,56 @@ class DGCMomentumKernel : public framework::OpKernel { VLOG(10) << "current_step:" << *current_step << ", rampup_begin_step:" << rampup_begin_step; + const auto* grad_var = context.InputVar("Grad"); if (static_cast(*current_step) < static_cast(rampup_begin_step)) { VLOG(10) << " so use momentum optimizer"; - return _momentum_op_kernel->Compute(context); + auto* learning_rate = context.Input("LearningRate"); + bool multi_precision = context.Attr("multi_precision"); + + auto* param = context.Input("Param"); + auto* velocity = context.Input("Velocity"); + auto* param_out = context.Output("ParamOut"); + auto* velocity_out = context.Output("VelocityOut"); + auto* master_param_out = + context.Output("MasterParamOut"); + paddle::optional master_param_opt = + paddle::none; + float mu = context.Attr("mu"); + bool use_nesterov = context.Attr("use_nesterov"); + std::string regularization_method = + context.Attr("regularization_method"); + float regularization_coeff = context.Attr("regularization_coeff"); + float rescale_grad = context.Attr("rescale_grad"); + + if (grad_var->IsType()) { + // sgd_dense + auto* grad = context.Input("Grad"); + phi::MomentumDenseKernel( + static_cast::TYPE&>(dev_ctx), + *param, *grad, *velocity, *learning_rate, master_param_opt, mu, + use_nesterov, regularization_method, 
regularization_coeff, + multi_precision, rescale_grad, param_out, velocity_out, + master_param_out); + } else { + // sgd dense param sparse grad + auto* grad = context.Input("Grad"); + phi::MomentumSparseKernel( + static_cast::TYPE&>(dev_ctx), + *param, *grad, *velocity, *learning_rate, master_param_opt, mu, + use_nesterov, regularization_method, regularization_coeff, + multi_precision, rescale_grad, param_out, velocity_out, + master_param_out); + } + + return; } VLOG(10) << " so use sgd optimizer"; const auto* param_var = context.InputVar("Param"); - const auto* grad_var = context.InputVar("Grad"); + auto* learning_rate = context.Input("LearningRate"); bool multi_precision = context.Attr("multi_precision"); if (param_var->IsType()) { @@ -125,9 +166,6 @@ class DGCMomentumKernel : public framework::OpKernel { PADDLE_THROW("gdc not support yet"); } } - - private: - std::unique_ptr> _momentum_op_kernel; }; } // namespace operators diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.h b/paddle/fluid/operators/optimizers/merged_momentum_op.h index c1ac2e366f4b4f2e1dd623005100aa8632919213..ed9b32c78e72c920b9a09c827823af2b3e681449 100644 --- a/paddle/fluid/operators/optimizers/merged_momentum_op.h +++ b/paddle/fluid/operators/optimizers/merged_momentum_op.h @@ -18,13 +18,16 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/optimizers/momentum_op.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/macros.h" +#include "paddle/phi/kernels/impl/momentum_kernel_impl.h" namespace paddle { namespace operators { +template +using MultiPrecisionType = typename details::MPTypeTrait::Type; + template struct MergedMomentumMasterParams { MT *PADDLE_RESTRICT master_params[kParamNum]; @@ -259,11 +262,11 @@ class MergedMomentumOpKernel : public framework::OpKernel { #undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL } else { 
for (size_t idx = 0; idx < n; idx++) { - RegularizationType regularization_flag = + phi::RegularizationType regularization_flag = regularization_methods.size() > 0 && regularization_methods[idx] == "l2_decay" - ? RegularizationType::kL2DECAY - : RegularizationType::kNONE; + ? phi::RegularizationType::kL2DECAY + : phi::RegularizationType::kNONE; MT regularization_coeff = static_cast(0.0); if (regularization_coeffs.size() != 0) { @@ -276,7 +279,7 @@ class MergedMomentumOpKernel : public framework::OpKernel { MT *master_out_data = multi_precision ? master_params_out[idx]->data() : nullptr; if (platform::is_cpu_place(ctx.GetPlace())) { - CPUDenseMomentumFunctor functor; + phi::CPUDenseMomentumFunctor functor; functor(params[idx], grads[idx], velocitys[idx], lr_temp, static_cast(mu), use_nesterov, regularization_flag, regularization_coeff, params_out[idx], velocitys_out[idx]); @@ -286,7 +289,7 @@ class MergedMomentumOpKernel : public framework::OpKernel { static_cast(ctx.device_context()), params[idx]->numel()); #define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \ - DenseMomentumFunctor functor( \ + phi::DenseMomentumFunctor functor( \ params[idx]->data(), grads[idx]->data(), \ velocitys[idx]->data(), lr_temp->data(), master_in_data, \ static_cast(mu), static_cast(rescale_grad), \ @@ -294,26 +297,26 @@ class MergedMomentumOpKernel : public framework::OpKernel { velocitys_out[idx]->data(), master_out_data); \ for_range(functor); if (use_nesterov) { - if (regularization_flag == RegularizationType::kL2DECAY) { + if (regularization_flag == phi::RegularizationType::kL2DECAY) { PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( - UseNesterov, RegularizationType::kL2DECAY); + phi::UseNesterov, phi::RegularizationType::kL2DECAY); VLOG(10) << "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY."; } else { - PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(UseNesterov, - RegularizationType::kNONE); + PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( + phi::UseNesterov, 
phi::RegularizationType::kNONE); VLOG(10) << "Launch MergedMomentum gpu kernel use_nesterov kNONE."; } } else { - if (regularization_flag == RegularizationType::kL2DECAY) { + if (regularization_flag == phi::RegularizationType::kL2DECAY) { PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( - NoNesterov, RegularizationType::kL2DECAY); + phi::NoNesterov, phi::RegularizationType::kL2DECAY); VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY."; } else { - PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(NoNesterov, - RegularizationType::kNONE); + PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( + phi::NoNesterov, phi::RegularizationType::kNONE); VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE."; } } diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc b/paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc index f29a42be9d9a87bd43abd01ccddda3c1768c2f06..5fad5eca9affc4d44e2fc196616c776df6d8702b 100644 --- a/paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc +++ b/paddle/fluid/operators/optimizers/merged_momentum_op_npu.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/operators/optimizers/merged_momentum_op.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" +#include "paddle/phi/kernels/impl/momentum_kernel_impl.h" namespace paddle { namespace operators { @@ -118,11 +119,11 @@ class NPUMergedMomentumOpKernel : public framework::OpKernel { FillNpuTensorWithConstant(&mu_tensor, mu); for (size_t idx = 0; idx < n; ++idx) { - RegularizationType regularization_flag = + phi::RegularizationType regularization_flag = regularization_methods.size() > 0 && regularization_methods[idx] == "l2_decay" - ? RegularizationType::kL2DECAY - : RegularizationType::kNONE; + ? 
phi::RegularizationType::kL2DECAY + : phi::RegularizationType::kNONE; float regularization_coeff = 0.0; if (regularization_coeffs.size() != 0) { regularization_coeff = regularization_coeffs[idx]; @@ -136,7 +137,7 @@ class NPUMergedMomentumOpKernel : public framework::OpKernel { auto grad = grads[idx]; Tensor regularized_grad; - if (regularization_flag == RegularizationType::kL2DECAY) { + if (regularization_flag == phi::RegularizationType::kL2DECAY) { regularized_grad.mutable_data(grad->dims(), ctx.GetPlace()); const auto& runner1 = NpuOpRunner("Muls", {*param}, {regularized_grad}, {{"value", regularization_coeff}}); diff --git a/paddle/fluid/operators/optimizers/momentum_op.cc b/paddle/fluid/operators/optimizers/momentum_op.cc index bf30d8512addb44574006c5b19ced17be12ba637..50d2c946f3afee12632735f207ceabfee91cd6fc 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.cc +++ b/paddle/fluid/operators/optimizers/momentum_op.cc @@ -108,9 +108,6 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, ops::MomentumOpInferVarType); -REGISTER_OP_CPU_KERNEL( - momentum, ops::MomentumOpKernel, - ops::MomentumOpKernel); REGISTER_OP_VERSION(momentum) .AddCheckpoint( diff --git a/paddle/fluid/operators/optimizers/momentum_op.cu b/paddle/fluid/operators/optimizers/momentum_op.cu deleted file mode 100644 index 7f9e7246401bc3c765e539ac4395c4feef3c9508..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/momentum_op.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/optimizers/momentum_op.h" -#include "paddle/fluid/platform/float16.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - momentum, ops::MomentumOpKernel, - ops::MomentumOpKernel, - ops::MomentumOpKernel); diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h index e271755b740ce33369348ca6f415af958a43616d..017f33d7458fcd5552e540d944775a38c78b06b8 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.h +++ b/paddle/fluid/operators/optimizers/momentum_op.h @@ -26,44 +26,6 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -using framework::Tensor; -using phi::SelectedRows; -struct NoNesterov; -struct UseNesterov; - -namespace details { - -template -struct CPUDenseUpdater { - template - void operator()(const Tensor& param, const Tensor& velocity, const T& mu, - const T& lr, const bool use_nesterov, G&& grad, - Tensor* param_out, Tensor* velocity_out) const { - auto param_out_vec = framework::EigenVector::Flatten(*param_out); - auto velocity_out_vec = framework::EigenVector::Flatten(*velocity_out); - - auto param_vec = framework::EigenVector::Flatten(param); - auto velocity_vec = framework::EigenVector::Flatten(velocity); - velocity_out_vec = velocity_vec * mu + grad; - if (use_nesterov) { - param_out_vec = param_vec - (grad + velocity_out_vec * mu) * lr; - } else { - param_out_vec = param_vec - lr * velocity_out_vec; - } - } -}; - -} // namespace details - -template -using MultiPrecisionType = typename details::MPTypeTrait::Type; - -enum class RegularizationType { - kNONE = 0, - kL1DECAY = 1, // do not need support right now - kL2DECAY = 2, -}; - class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override; @@ -148,460 +110,5 @@ class MomentumOp : public framework::OperatorWithKernel { } }; -template -class CPUDenseMomentumFunctor { - public: - void operator()(const Tensor* param, const Tensor* grad, - const Tensor* velocity, const Tensor* learning_rate, - const T mu, const bool use_nesterov, - const RegularizationType regularization_flag, - const T regularization_coeff, Tensor* param_out, - Tensor* velocity_out) { - auto grad_vec = framework::EigenVector::Flatten(*grad); - auto* lr = learning_rate->data>(); - - details::CPUDenseUpdater updater; - if (regularization_flag == RegularizationType::kL2DECAY) { - auto param_vec = framework::EigenVector::Flatten(*param); - updater(*param, *velocity, mu, static_cast(lr[0]), use_nesterov, - param_vec * regularization_coeff + grad_vec, param_out, - velocity_out); - 
} else { - updater(*param, *velocity, mu, static_cast(lr[0]), use_nesterov, - grad_vec, param_out, velocity_out); - } - } -}; - -template -class DenseMomentumFunctor; - -// NOTE(dzh) for performance. -// avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two -// functor. -template -class DenseMomentumFunctor { - private: - const T* param_; - const T* grad_; - const MT* velocity_; - const MultiPrecisionType* lr_; - const MT* master_param_; - const MT mu_; - const MT rescale_grad_; - const int64_t num_; - T* param_out_; - MT* velocity_out_; - MT* master_param_out_; - const MT regularization_coeff_; - - public: - DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity, - const MultiPrecisionType* learning_rate, - const MT* master_param, const MT mu, - const MT rescale_grad, const int64_t num, - const MT regularization_coeff, T* param_out, - MT* velocity_out, MT* master_param_out) - : param_(param), - grad_(grad), - velocity_(velocity), - lr_(learning_rate), - master_param_(master_param), - mu_(mu), - rescale_grad_(rescale_grad), - num_(num), - param_out_(param_out), - velocity_out_(velocity_out), - master_param_out_(master_param_out), - regularization_coeff_(regularization_coeff) {} - inline HOSTDEVICE void operator()(size_t i) const { - // put memory access in register - const MT param = - master_param_ ? 
master_param_[i] : static_cast(param_[i]); - MT grad = static_cast(grad_[i]) * rescale_grad_; - const MT lr = static_cast(lr_[0]); - const MT velocity = velocity_[i]; - - if (kRegType == RegularizationType::kL2DECAY) { - grad += regularization_coeff_ * param; - } - - MT velocity_out = velocity * mu_ + grad; - MT param_out = param - (grad + velocity_out * mu_) * lr; - // write reigster to memory - velocity_out_[i] = velocity_out; - param_out_[i] = static_cast(param_out); - if (master_param_out_) { - master_param_out_[i] = param_out; - } - } -}; - -template -class DenseMomentumFunctor { - private: - const T* param_; - const T* grad_; - const MT* velocity_; - const MultiPrecisionType* lr_; - const MT* master_param_; - const MT mu_; - const MT rescale_grad_; - const int64_t num_; - T* param_out_; - MT* velocity_out_; - MT* master_param_out_; - const MT regularization_coeff_; - - public: - DenseMomentumFunctor(const T* param, const T* grad, const MT* velocity, - const MultiPrecisionType* learning_rate, - const MT* master_param, const MT mu, - const MT rescale_grad, const int64_t num, - const MT regularization_coeff, T* param_out, - MT* velocity_out, MT* master_param_out) - : param_(param), - grad_(grad), - velocity_(velocity), - lr_(learning_rate), - master_param_(master_param), - mu_(mu), - rescale_grad_(rescale_grad), - num_(num), - param_out_(param_out), - velocity_out_(velocity_out), - master_param_out_(master_param_out), - regularization_coeff_(regularization_coeff) {} - inline HOSTDEVICE void operator()(size_t i) const { - // put memory access in register - const MT param = - master_param_ ? 
master_param_[i] : static_cast(param_[i]); - MT grad = static_cast(grad_[i]) * rescale_grad_; - const MT lr = static_cast(lr_[0]); - const MT velocity = velocity_[i]; - - if (kRegType == RegularizationType::kL2DECAY) { - grad += regularization_coeff_ * param; - } - - MT velocity_out = velocity * mu_ + grad; - MT param_out = param - lr * velocity_out; - // write reigster to memory - velocity_out_[i] = velocity_out; - param_out_[i] = static_cast(param_out); - if (master_param_out_) { - master_param_out_[i] = param_out; - } - } -}; - -template -class SparseMomentumFunctor; - -template -class SparseMomentumFunctor { - private: - const T* param_; - const T* grad_; - const MT* velocity_; - const MultiPrecisionType* lr_; - const MT* master_param_; - const MT mu_; - const MT rescale_grad_; - const int64_t* rows_; - const int64_t row_numel_; - const int64_t row_height_; - T* param_out_; - MT* velocity_out_; - MT* master_param_out_; - const RegularizationType regularization_flag_; - const MT regularization_coeff_; - - public: - SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity, - const MultiPrecisionType* lr, - const MT* master_param, const MT mu, - const MT rescale_grad, const int64_t* rows, - int64_t row_numel, int64_t row_height, - const RegularizationType regularization_flag, - const MT regularization_coeff, T* param_out, - MT* velocity_out, MT* master_param_out) - : param_(param), - grad_(grad), - velocity_(velocity), - lr_(lr), - master_param_(master_param), - mu_(mu), - rescale_grad_(rescale_grad), - rows_(rows), - row_numel_(row_numel), - row_height_(row_height), - param_out_(param_out), - velocity_out_(velocity_out), - master_param_out_(master_param_out), - regularization_flag_(regularization_flag), - regularization_coeff_(regularization_coeff) {} - - inline HOSTDEVICE void operator()(size_t i) { - auto row_idx = - phi::funcs::BinarySearch(rows_, row_height_, i / row_numel_); - MT grad = - row_idx >= 0 - ? 
static_cast(grad_[row_idx * row_numel_ + i % row_numel_]) * - rescale_grad_ - : static_cast(0); - // put memory access in register - const MT param = - master_param_ ? master_param_[i] : static_cast(param_[i]); - const MT lr = static_cast(lr_[0]); - const MT velocity = velocity_[i]; - - grad = regularization_flag_ == RegularizationType::kL2DECAY - ? grad + regularization_coeff_ * param - : grad; - - MT velocity_out = velocity * mu_ + grad; - MT param_out = param - (grad + velocity_out * mu_) * lr; - // write reigster to memory - velocity_out_[i] = velocity_out; - param_out_[i] = static_cast(param_out); - if (master_param_out_) { - master_param_out_[i] = param_out; - } - } -}; - -template -class SparseMomentumFunctor { - private: - const T* param_; - const T* grad_; - const MT* velocity_; - const MultiPrecisionType* lr_; - const MT* master_param_; - const MT mu_; - const MT rescale_grad_; - const int64_t* rows_; - const int64_t row_numel_; - const int64_t row_height_; - T* param_out_; - MT* velocity_out_; - MT* master_param_out_; - const RegularizationType regularization_flag_; - const MT regularization_coeff_; - - public: - SparseMomentumFunctor(const T* param, const T* grad, const MT* velocity, - const MultiPrecisionType* lr, - const MT* master_param, const MT mu, - const MT rescale_grad, const int64_t* rows, - int64_t row_numel, int64_t row_height, - const RegularizationType regularization_flag, - const MT regularization_coeff, T* param_out, - MT* velocity_out, MT* master_param_out) - : param_(param), - grad_(grad), - velocity_(velocity), - lr_(lr), - master_param_(master_param), - mu_(mu), - rescale_grad_(rescale_grad), - rows_(rows), - row_numel_(row_numel), - row_height_(row_height), - param_out_(param_out), - velocity_out_(velocity_out), - master_param_out_(master_param_out), - regularization_flag_(regularization_flag), - regularization_coeff_(regularization_coeff) {} - - inline HOSTDEVICE void operator()(size_t i) { - auto row_idx = - 
phi::funcs::BinarySearch(rows_, row_height_, i / row_numel_); - MT grad = - row_idx >= 0 - ? static_cast(grad_[row_idx * row_numel_ + i % row_numel_]) * - rescale_grad_ - : static_cast(0); - // put memory access in register - const MT param = - master_param_ ? master_param_[i] : static_cast(param_[i]); - const MT lr = static_cast(lr_[0]); - const MT velocity = velocity_[i]; - - grad = regularization_flag_ == RegularizationType::kL2DECAY - ? grad + regularization_coeff_ * param - : grad; - - MT velocity_out = velocity * mu_ + grad; - MT param_out = param - velocity_out * lr; - // write reigster to memory - velocity_out_[i] = velocity_out; - param_out_[i] = static_cast(param_out); - if (master_param_out_) { - master_param_out_[i] = param_out; - } - } -}; - -template -class MomentumOpKernel : public framework::OpKernel { - using MPDType = MultiPrecisionType; - - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const bool multi_precision = ctx.Attr("multi_precision"); - if (multi_precision) { - InnerCompute(ctx, multi_precision); - } else { - InnerCompute(ctx, multi_precision); - } - } - - private: - template - void InnerCompute(const framework::ExecutionContext& ctx, - const bool multi_precision) const { - std::string regularization_method = - ctx.Attr("regularization_method"); - MT regularization_coeff = - static_cast(ctx.Attr("regularization_coeff")); - RegularizationType regularization_flag{ - RegularizationType::kNONE}; // disable regularization - if (regularization_method == "l2_decay") { - regularization_flag = RegularizationType::kL2DECAY; - } - - MT mu = static_cast(ctx.Attr("mu")); - MT rescale_grad = static_cast(ctx.Attr("rescale_grad")); - bool use_nesterov = ctx.Attr("use_nesterov"); - - auto learning_rate = ctx.Input("LearningRate"); - auto param = ctx.Input("Param"); - auto param_out = ctx.Output("ParamOut"); - auto velocity = ctx.Input("Velocity"); - auto velocity_out = ctx.Output("VelocityOut"); - - const 
framework::Tensor* master_param = nullptr; - framework::Tensor* master_param_out = nullptr; - if (multi_precision) { - bool has_master = - ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut"); - PADDLE_ENFORCE_EQ(has_master, true, - platform::errors::InvalidArgument( - "The Input(MasterParam) and Output(MasterParamOut) " - "should not be null when " - "the attr `multi_precision` is true")); - master_param = ctx.Input("MasterParam"); - master_param_out = ctx.Output("MasterParamOut"); - } - - param_out->mutable_data(ctx.GetPlace()); - velocity_out->mutable_data(ctx.GetPlace()); - const MT* master_in_data = - multi_precision ? master_param->data() : nullptr; - MT* master_out_data = - multi_precision ? master_param_out->mutable_data(ctx.GetPlace()) - : nullptr; - - auto* grad_var = ctx.InputVar("Grad"); - if (grad_var->IsType()) { - auto grad = ctx.Input("Grad"); - if (platform::is_cpu_place(ctx.GetPlace())) { - CPUDenseMomentumFunctor functor; - functor(param, grad, velocity, learning_rate, mu, use_nesterov, - regularization_flag, regularization_coeff, param_out, - velocity_out); - } else if (platform::is_gpu_place(ctx.GetPlace())) { - platform::ForRange for_range( - static_cast(ctx.device_context()), - param->numel()); -#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \ - DenseMomentumFunctor functor( \ - param->data(), grad->data(), velocity->data(), \ - learning_rate->data(), master_in_data, mu, rescale_grad, \ - param->numel(), regularization_coeff, \ - param_out->mutable_data(ctx.GetPlace()), \ - velocity_out->mutable_data(ctx.GetPlace()), master_out_data); \ - for_range(functor); - - if (use_nesterov) { - if (regularization_flag == RegularizationType::kL2DECAY) { - PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov, - RegularizationType::kL2DECAY); - } else { - PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov, - RegularizationType::kNONE); - } - } else { - if (regularization_flag == RegularizationType::kL2DECAY) { - 
PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov, - RegularizationType::kL2DECAY); - } else { - PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov, - RegularizationType::kNONE); - } - } - } - - } else if (grad_var->IsType()) { - // sparse update embedding with selectedrows - auto grad = ctx.Input("Grad"); - - // sparse update maybe empty. - if (grad->rows().size() == 0) { - VLOG(3) << "Grad SelectedRows contains no data!"; - return; - } - - phi::SelectedRows tmp_merged_grad; - phi::SelectedRows* merged_grad = &tmp_merged_grad; - math::scatter::MergeAdd merge_func; - merge_func(ctx.template device_context(), *grad, - merged_grad); - - auto* grad_merge_rows = merged_grad->mutable_rows(); - paddle::framework::MixVector mixv_grad_merge_rows( - grad_merge_rows); - const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace()); - int64_t row_numel = - merged_grad->value().numel() / merged_grad->rows().size(); - platform::ForRange for_range( - static_cast(ctx.device_context()), - param->numel()); - if (use_nesterov) { - SparseMomentumFunctor functor( - param->data(), merged_grad->value().data(), - velocity->data(), learning_rate->data(), - master_in_data, mu, rescale_grad, rows, row_numel, - static_cast(merged_grad->rows().size()), - regularization_flag, regularization_coeff, - param_out->mutable_data(ctx.GetPlace()), - velocity_out->mutable_data(ctx.GetPlace()), master_out_data); - for_range(functor); - - } else { - SparseMomentumFunctor functor( - param->data(), merged_grad->value().data(), - velocity->data(), learning_rate->data(), - master_in_data, mu, rescale_grad, rows, row_numel, - static_cast(merged_grad->rows().size()), - regularization_flag, regularization_coeff, - param_out->mutable_data(ctx.GetPlace()), - velocity_out->mutable_data(ctx.GetPlace()), master_out_data); - for_range(functor); - } - } else { - PADDLE_ENFORCE_EQ(false, true, - platform::errors::PermissionDenied( - "Unsupported Variable Type of Grad " - "in MomentumOp. 
Excepted LodTensor " - "or SelectedRows, But received [%s]", - paddle::framework::ToTypeName(grad_var->Type()))); - } - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/optimizers/momentum_op_npu.cc b/paddle/fluid/operators/optimizers/momentum_op_npu.cc index 6853b2dac8868b846dc0cb666314e4683de797c8..2d73766b9736429bbe6b2de77363961d0c977cbb 100644 --- a/paddle/fluid/operators/optimizers/momentum_op_npu.cc +++ b/paddle/fluid/operators/optimizers/momentum_op_npu.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/optimizers/sgd_op.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" +#include "paddle/phi/kernels/impl/momentum_kernel_impl.h" namespace paddle { namespace operators { @@ -28,10 +29,10 @@ class NPUMomentumOpKernel : public framework::OpKernel { std::string regularization_method = ctx.Attr("regularization_method"); auto regularization_coeff = ctx.Attr("regularization_coeff"); - RegularizationType regularization_flag{ - RegularizationType::kNONE}; // disable regularization + phi::RegularizationType regularization_flag{ + phi::RegularizationType::kNONE}; // disable regularization if (regularization_method == "l2_decay") { - regularization_flag = RegularizationType::kL2DECAY; + regularization_flag = phi::RegularizationType::kL2DECAY; } T mu = static_cast(ctx.Attr("mu")); @@ -55,7 +56,7 @@ class NPUMomentumOpKernel : public framework::OpKernel { FillNpuTensorWithConstant(&mu_tensor, mu); Tensor regularized_grad; - if (regularization_flag == RegularizationType::kL2DECAY) { + if (regularization_flag == phi::RegularizationType::kL2DECAY) { regularized_grad.mutable_data(grad->dims(), ctx.GetPlace()); const auto& runner1 = NpuOpRunner("Muls", {*param}, {regularized_grad}, {{"value", regularization_coeff}}); diff --git a/paddle/fluid/operators/optimizers/rmsprop_op.cc b/paddle/fluid/operators/optimizers/rmsprop_op.cc index 
652a343abf3c8993700224d8797d7e1bc200db5a..cd6fdcf34e95f2689fb5181451099dd9f6ce380e 100644 --- a/paddle/fluid/operators/optimizers/rmsprop_op.cc +++ b/paddle/fluid/operators/optimizers/rmsprop_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/optimizers/rmsprop_op.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -170,6 +170,3 @@ http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker); -REGISTER_OP_CPU_KERNEL( - rmsprop, ops::RmspropOpKernel, - ops::RmspropOpKernel); diff --git a/paddle/fluid/operators/optimizers/rmsprop_op.cu b/paddle/fluid/operators/optimizers/rmsprop_op.cu deleted file mode 100644 index bf11ee686757c6c5e54e05f055eaa19f6553f915..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/rmsprop_op.cu +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ -#include "paddle/fluid/operators/optimizers/rmsprop_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - rmsprop, ops::RmspropOpKernel, - ops::RmspropOpKernel); diff --git a/paddle/fluid/operators/optimizers/rmsprop_op.h b/paddle/fluid/operators/optimizers/rmsprop_op.h deleted file mode 100644 index 71decd27d0d7822c67ba4a2782c1ec2461e67911..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/rmsprop_op.h +++ /dev/null @@ -1,273 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/algorithm.h" - -namespace paddle { -namespace operators { - -template -struct DenseRmspropGradFunctor { - inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {} - - HOSTDEVICE inline T operator()(int64_t idx) const { return grad_[idx]; } - - const T *grad_; -}; - -template -struct SparseRmspropGradFunctor { - inline SparseRmspropGradFunctor(const T *grad, const int64_t *rows, - int64_t row_numel, int64_t row_count) - : grad_(grad), - rows_(rows), - row_numel_(row_numel), - row_count_(row_count) {} - - HOSTDEVICE inline T operator()(int64_t idx) const { - auto row_idx = - phi::funcs::BinarySearch(rows_, row_count_, idx / row_numel_); - return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0; - } - - const T *grad_; - const int64_t *rows_; - int64_t row_numel_; - int64_t row_count_; -}; - -template -struct UncenteredRmspropFunctor { - UncenteredRmspropFunctor(T *param, T *ms, T *mom, const T *lr, T rho, - T epsilon, T momentum, - const GradFunctor &grad_functor) - : param_(param), - ms_(ms), - mom_(mom), - lr_(lr), - rho_(rho), - epsilon_(epsilon), - momentum_(momentum), - grad_functor_(grad_functor) {} - - HOSTDEVICE inline void operator()(int64_t idx) const { - T g = grad_functor_(idx); - T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g; - T mom_out = momentum_ * mom_[idx] + lr_[0] * g / sqrt(ms_out + epsilon_); - param_[idx] -= mom_out; - ms_[idx] = ms_out; - mom_[idx] = mom_out; - } - - T *param_; - T *ms_; - T *mom_; - const T *lr_; - T rho_; - T epsilon_; - T momentum_; - GradFunctor grad_functor_; -}; - -template -struct CenteredRmspropFunctor { - CenteredRmspropFunctor(T *param, T *ms, T *mom, T *mean_grad, const T *lr, - T rho, T epsilon, T momentum, - const 
GradFunctor &grad_functor) - : param_(param), - ms_(ms), - mom_(mom), - mean_grad_(mean_grad), - lr_(lr), - rho_(rho), - epsilon_(epsilon), - momentum_(momentum), - grad_functor_(grad_functor) {} - - HOSTDEVICE inline void operator()(int64_t idx) const { - T g = grad_functor_(idx); - T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g; - T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g; - T mom_out = momentum_ * mom_[idx] + - lr_[0] * g / sqrt(ms_out - mg_out * mg_out + epsilon_); - param_[idx] -= mom_out; - ms_[idx] = ms_out; - mom_[idx] = mom_out; - mean_grad_[idx] = mg_out; - } - - T *param_; - T *ms_; - T *mom_; - T *mean_grad_; - const T *lr_; - T rho_; - T epsilon_; - T momentum_; - GradFunctor grad_functor_; -}; - -template -class RmspropOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - using LoDTensor = framework::LoDTensor; - auto *grad_var = ctx.InputVar("Grad"); - auto *param_out = ctx.Output("ParamOut"); - auto *moment_out = ctx.Output("MomentOut"); - auto *mean_square_out = ctx.Output("MeanSquareOut"); - - auto epsilon = static_cast(ctx.Attr("epsilon")); - auto rho = static_cast(ctx.Attr("decay")); - auto momentum = static_cast(ctx.Attr("momentum")); - bool centered = ctx.Attr("centered"); - - auto &p_tensor = *ctx.Input("Param"); - auto &ms_tensor = *ctx.Input("MeanSquare"); - auto &lr_tensor = *ctx.Input("LearningRate"); - auto &mom_tensor = *ctx.Input("Moment"); - - PADDLE_ENFORCE_EQ(p_tensor.IsSharedBufferWith(*param_out), true, - platform::errors::InvalidArgument( - "Param and ParamOut must be the same Tensor")); - PADDLE_ENFORCE_EQ(mom_tensor.IsSharedBufferWith(*moment_out), true, - platform::errors::InvalidArgument( - "Moment and MomentOut must be the same Tensor")); - PADDLE_ENFORCE_EQ( - ms_tensor.IsSharedBufferWith(*mean_square_out), true, - platform::errors::InvalidArgument( - "MeanSquare and MeanSquareOut must be the same Tensor")); - - auto &dev_ctx = ctx.template 
device_context(); - size_t limit = static_cast(ms_tensor.numel()); - - if (grad_var->IsType()) { - auto &grad_tensor = grad_var->Get(); - - if (std::is_same::value) { - auto &place = - *ctx.template device_context().eigen_device(); - auto lr_value = lr_tensor.data()[0]; - - auto p = framework::EigenVector::Flatten(p_tensor); - auto ms = framework::EigenVector::Flatten(ms_tensor); - auto g = framework::EigenVector::Flatten(grad_tensor); - auto mom = framework::EigenVector::Flatten(mom_tensor); - - auto p_out = framework::EigenVector::Flatten(*param_out); - auto mom_out = framework::EigenVector::Flatten(*moment_out); - auto ms_out = framework::EigenVector::Flatten(*mean_square_out); - - ms_out.device(place) = rho * ms + (1 - rho) * g * g; - if (centered) { - auto &mg_tensor = *ctx.Input("MeanGrad"); - auto mg = framework::EigenVector::Flatten(mg_tensor); - auto *mean_grad_out = ctx.Output("MeanGradOut"); - PADDLE_ENFORCE_EQ( - &mg_tensor, mean_grad_out, - platform::errors::InvalidArgument( - "MeanGrad and MeanGradOut must be the same Tensor")); - auto mg_out = framework::EigenVector::Flatten(*mean_grad_out); - - mg_out.device(place) = rho * mg + (1 - rho) * g; - mom_out.device(place) = - momentum * mom + - lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt(); - } else { - mom_out.device(place) = - momentum * mom + lr_value * g / (ms_out + epsilon).sqrt(); - } - p_out.device(place) = p - mom_out; - } else { - DenseRmspropGradFunctor grad_func(grad_tensor.data()); - platform::ForRange for_range(dev_ctx, limit); - if (centered) { - auto &mg_tensor = *ctx.Input("MeanGrad"); - auto *mean_grad_out = ctx.Output("MeanGradOut"); - PADDLE_ENFORCE_EQ( - &mg_tensor, mean_grad_out, - platform::errors::InvalidArgument( - "MeanGrad and MeanGradOut must be the same Tensor")); - for_range(CenteredRmspropFunctor>( - param_out->mutable_data(ctx.GetPlace()), - mean_square_out->mutable_data(ctx.GetPlace()), - moment_out->mutable_data(ctx.GetPlace()), - 
mean_grad_out->mutable_data(ctx.GetPlace()), - lr_tensor.data(), rho, epsilon, momentum, grad_func)); - } else { - for_range(UncenteredRmspropFunctor>( - param_out->mutable_data(ctx.GetPlace()), - mean_square_out->mutable_data(ctx.GetPlace()), - moment_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), - rho, epsilon, momentum, grad_func)); - } - } - } else if (grad_var->IsType()) { - auto &grad = grad_var->Get(); - phi::SelectedRows tmp_merged_grad; - phi::SelectedRows *merged_grad = &tmp_merged_grad; - math::scatter::MergeAdd merge_func; - merge_func(dev_ctx, grad, merged_grad); - - platform::ForRange for_range(dev_ctx, limit); - auto &grad_merge_rows = merged_grad->rows(); - paddle::framework::MixVector mixv_grad_merge_rows( - &grad_merge_rows); - const int64_t *rows = mixv_grad_merge_rows.Data(ctx.GetPlace()); - - auto &merged_tensor = merged_grad->value(); - int64_t row_count = merged_grad->rows().size(); - int64_t row_numel = merged_tensor.numel() / row_count; - SparseRmspropGradFunctor grad_func(merged_tensor.data(), rows, - row_numel, row_count); - - if (centered) { - auto &mg_tensor = *ctx.Input("MeanGrad"); - auto *mean_grad_out = ctx.Output("MeanGradOut"); - PADDLE_ENFORCE_EQ( - &mg_tensor, mean_grad_out, - platform::errors::InvalidArgument( - "MeanGrad and MeanGradOut must be the same Tensor")); - for_range(CenteredRmspropFunctor>( - param_out->mutable_data(ctx.GetPlace()), - mean_square_out->mutable_data(ctx.GetPlace()), - moment_out->mutable_data(ctx.GetPlace()), - mean_grad_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), - rho, epsilon, momentum, grad_func)); - } else { - for_range(UncenteredRmspropFunctor>( - param_out->mutable_data(ctx.GetPlace()), - mean_square_out->mutable_data(ctx.GetPlace()), - moment_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), - rho, epsilon, momentum, grad_func)); - } - } else { - PADDLE_ENFORCE_EQ(false, true, - platform::errors::PermissionDenied( - "Unsupported Variable Type of Grad " - "in RmspropOp. 
Excepted LodTensor " - "or SelectedRows, But received [%s]", - paddle::framework::ToTypeName(grad_var->Type()))); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/optimizers/rmsprop_op_npu.cc b/paddle/fluid/operators/optimizers/rmsprop_op_npu.cc index 12aa56ebb5c7cdf162aee3af921c6f7a26d01503..111151f2356da1f3ce8f5e646017adf527e1513c 100644 --- a/paddle/fluid/operators/optimizers/rmsprop_op_npu.cc +++ b/paddle/fluid/operators/optimizers/rmsprop_op_npu.cc @@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/optimizers/rmsprop_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc b/paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc index 6a962b241fafb5fd4f1571ab79d7c20f44a9d5bc..85c2d42c841f020e44994546ea3dafb86de0c8f8 100644 --- a/paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc +++ b/paddle/fluid/operators/optimizers/rmsprop_op_xpu.cc @@ -14,9 +14,9 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/optimizers/rmsprop_op.h" #include #include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index b652b1c5fae94a6215040c765407d0c9ee8ab9b7..2247ec9867e9fa886fa18f6130d6305146712433 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -11,7 +11,7 @@ set_property(GLOBAL PROPERTY PHI_KERNELS "") # [ 1. 
Common kernel compilation dependencies ] set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel) -set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor) +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor ) # remove this dep after removing fluid deps on tensor creation set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) diff --git a/paddle/phi/kernels/adagrad_kernel.h b/paddle/phi/kernels/adagrad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..cac662fddf264657a0dcdab3271e83674c3f9231 --- /dev/null +++ b/paddle/phi/kernels/adagrad_kernel.h @@ -0,0 +1,42 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { + +template +void AdagradDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& moment, + const DenseTensor& learning_rate, + float epsilon, + DenseTensor* param_out, + DenseTensor* moment_out); + +template +void AdagradSparseKernel(const Context& dev_ctx, + const DenseTensor& param, + const SelectedRows& grad, + const DenseTensor& moment, + const DenseTensor& learning_rate, + float epsilon, + DenseTensor* param_out, + DenseTensor* moment_out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/adagrad_kernel.cc b/paddle/phi/kernels/cpu/adagrad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..fcd89caf7fa29d404a427fc6b445a790f7fee2ec --- /dev/null +++ b/paddle/phi/kernels/cpu/adagrad_kernel.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/adagrad_kernel.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/impl/adagrad_kernel_impl.h" + +namespace phi { + +namespace { +size_t FindPos(const std::vector& rows, int64_t value) { + return std::find(rows.begin(), rows.end(), value) - rows.begin(); +} +} // namespace + +template +struct SparseAdagradFunctor { + void operator()(const phi::CPUContext& context, + const phi::SelectedRows& grad, + const DenseTensor& learning_rate, + T epsilon, + DenseTensor* moment, + DenseTensor* param) { + // 1. g_m.rows = set(g.rows) + auto grad_width = grad.value().dims()[1]; + paddle::operators::math::scatter::MergeAdd merge_func; + auto grad_merge = merge_func(context, grad); + auto& merge_rows = grad_merge.rows(); + auto* grad_merge_data = grad_merge.mutable_value()->template data(); + + // 2. m += g_m * g_m + auto grad_square = + SquareSelectedRows(context, grad_merge); + + paddle::operators::math::SelectedRowsAddToTensor + functor; + functor(context, grad_square, moment); + + // 3. 
update parameter + auto* lr = learning_rate.data(); + auto* param_data = param->data(); + auto* moment_data = moment->data(); + + for (size_t i = 0; i < merge_rows.size(); i++) { + for (int64_t j = 0; j < grad_width; j++) { + param_data[merge_rows[i] * grad_width + j] -= + lr[0] * grad_merge_data[i * grad_width + j] / + (std::sqrt(moment_data[merge_rows[i] * grad_width + j]) + epsilon); + } + } + } +}; + +template struct SparseAdagradFunctor; +template struct SparseAdagradFunctor; + +} // namespace phi + +PD_REGISTER_KERNEL( + adagrad, CPU, ALL_LAYOUT, phi::AdagradDenseKernel, float, double) {} + +PD_REGISTER_KERNEL(adagrad_dense_param_sparse_grad, + CPU, + ALL_LAYOUT, + phi::AdagradSparseKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/meshgrid_grad_kernel.cc b/paddle/phi/kernels/cpu/meshgrid_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..159d109255381bf80a6129ed8df9ea24ffbe74f6 --- /dev/null +++ b/paddle/phi/kernels/cpu/meshgrid_grad_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/meshgrid_grad_kernel.h" +#include "paddle/phi/kernels/impl/meshgrid_grad_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(meshgrid_grad, + CPU, + ALL_LAYOUT, + phi::MeshgridGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/meshgrid_kernel.cc b/paddle/phi/kernels/cpu/meshgrid_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..c201103b3dac4a2304a18ffb17dd0bce16236d64 --- /dev/null +++ b/paddle/phi/kernels/cpu/meshgrid_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/meshgrid_kernel.h" +#include "paddle/phi/kernels/impl/meshgrid_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(meshgrid, + CPU, + ALL_LAYOUT, + phi::MeshgridKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/momentum_kernel.cc b/paddle/phi/kernels/cpu/momentum_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..63cc5592ef42200833001701e11422011ecef5d8 --- /dev/null +++ b/paddle/phi/kernels/cpu/momentum_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/momentum_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/momentum_kernel_impl.h" + +PD_REGISTER_KERNEL( + momentum, CPU, ALL_LAYOUT, phi::MomentumDenseKernel, float, double) {} + +PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, + CPU, + ALL_LAYOUT, + phi::MomentumSparseKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/rmsprop_kernel.cc b/paddle/phi/kernels/cpu/rmsprop_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..fa1e1a2eed345b833f07d3738530af43f9c57bb2 --- /dev/null +++ b/paddle/phi/kernels/cpu/rmsprop_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/rmsprop_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/rmsprop_kernel_impl.h" + +PD_REGISTER_KERNEL( + rmsprop, CPU, ALL_LAYOUT, phi::RmspropDenseKernel, float, double) {} + +PD_REGISTER_KERNEL(rmsprop_dense_param_sparse_grad, + CPU, + ALL_LAYOUT, + phi::RmspropSparseKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/adagrad_kernel.cu b/paddle/phi/kernels/gpu/adagrad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..0e037eb808ceb047f5225eb5baced092c2734986 --- /dev/null +++ b/paddle/phi/kernels/gpu/adagrad_kernel.cu @@ -0,0 +1,139 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/adagrad_kernel.h" + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/impl/adagrad_kernel_impl.h" + +namespace phi { + +template +__global__ void MergeGradKernel(const T* grad, + const int64_t* grad_rows, + T* grad_merge, + const int64_t* grad_merge_rows, + size_t grad_merge_rows_size, + int64_t row_numel) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + __shared__ size_t grad_merge_idx; + + if (tid == 0) { + for (size_t i = 0; i < grad_merge_rows_size; i++) { + if (grad_rows[ty] == grad_merge_rows[i]) { + grad_merge_idx = i; + } + } + } + + __syncthreads(); + + grad += ty * row_numel; + grad_merge += grad_merge_idx * row_numel; + for (int index = tid; index < row_numel; index += block_size) { + paddle::platform::CudaAtomicAdd(grad_merge + index, grad[index]); + } +} + +template +__global__ void SparseAdagradFunctorKernel(const T* grad, + const int64_t* rows, + const T* learning_rate, + T* param, + T* moment, + int64_t row_numel, + T epsilon) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + + grad += ty * row_numel; + param += rows[ty] * row_numel; + moment += rows[ty] * row_numel; + + for (int index = tid; index < row_numel; index += block_size) { + // Since index in rows of SelectedRows can be duplicate, we have to use + // Atomic Operation to avoid concurrent write error. + paddle::platform::CudaAtomicAdd(param + index, + -1.0 * learning_rate[0] * grad[index] / + (sqrt(moment[index]) + epsilon)); + } +} + +template +struct SparseAdagradFunctor { + void operator()(const phi::GPUContext& context, + const phi::SelectedRows& grad, + const DenseTensor& learning_rate, + T epsilon, + DenseTensor* moment, + DenseTensor* param) { + // 1. 
g_m.rows = set(g.rows) + auto grad_width = grad.value().dims()[1]; + paddle::operators::math::scatter::MergeAdd merge_func; + auto grad_merge = merge_func(context, grad); + auto* grad_merge_data = grad_merge.mutable_value()->template data(); + paddle::framework::Vector merge_rows(grad_merge.rows()); + // 2. m += g_m * g_m + auto grad_square = + SquareSelectedRows(context, grad_merge); + + paddle::operators::math::SelectedRowsAddToTensor + functor; + functor(context, grad_square, moment); + + // 3. update parameter + auto* lr = learning_rate.data(); + auto* param_data = param->data(); + auto* moment_data = moment->data(); + + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid2(1, merge_rows.size()); + paddle::framework::MixVector mixv_merge_rows(&merge_rows); + SparseAdagradFunctorKernel< + T, + 256><<(context).stream()>>>( + grad_merge_data, + mixv_merge_rows.CUDAMutableData(context.GetPlace()), + lr, + param_data, + moment_data, + grad_width, + epsilon); + mixv_merge_rows.CopyToCPU(); + } +}; + +template struct SparseAdagradFunctor; +template struct SparseAdagradFunctor; + +} // namespace phi + +PD_REGISTER_KERNEL( + adagrad, GPU, ALL_LAYOUT, phi::AdagradDenseKernel, float, double) {} + +PD_REGISTER_KERNEL(adagrad_dense_param_sparse_grad, + GPU, + ALL_LAYOUT, + phi::AdagradSparseKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc b/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..37f2c40143b65911ed827cbfadbde93e8caa822c --- /dev/null +++ b/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/meshgrid_grad_kernel.h" +#include "paddle/phi/kernels/impl/meshgrid_grad_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(meshgrid_grad, + GPU, + ALL_LAYOUT, + phi::MeshgridGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc b/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..9d52d1e115de96d7dbcbd1d4c7e693fb31d075b8 --- /dev/null +++ b/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/meshgrid_kernel.h" +#include "paddle/phi/kernels/impl/meshgrid_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(meshgrid, + GPU, + ALL_LAYOUT, + phi::MeshgridKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/momentum_kernel.cu b/paddle/phi/kernels/gpu/momentum_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..5a4f5d33e6165370afb67960b8eb61200f034229 --- /dev/null +++ b/paddle/phi/kernels/gpu/momentum_kernel.cu @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/momentum_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/momentum_kernel_impl.h" + +PD_REGISTER_KERNEL(momentum, + GPU, + ALL_LAYOUT, + phi::MomentumDenseKernel, + float, + double, + phi::dtype::float16) {} + +PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, + GPU, + ALL_LAYOUT, + phi::MomentumSparseKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/rmsprop_kernel.cu b/paddle/phi/kernels/gpu/rmsprop_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..071c09ea675788bbae7741c5131fd33c6529684b --- /dev/null +++ b/paddle/phi/kernels/gpu/rmsprop_kernel.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/rmsprop_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/rmsprop_kernel_impl.h" + +PD_REGISTER_KERNEL( + rmsprop, GPU, ALL_LAYOUT, phi::RmspropDenseKernel, float, double) {} + +PD_REGISTER_KERNEL(rmsprop_dense_param_sparse_grad, + GPU, + ALL_LAYOUT, + phi::RmspropSparseKernel, + float, + double) {} diff --git a/paddle/phi/kernels/impl/adagrad_kernel_impl.h b/paddle/phi/kernels/impl/adagrad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..ca9fedaf158d6cd1bf1623456f6334f96f04a598 --- /dev/null +++ b/paddle/phi/kernels/impl/adagrad_kernel_impl.h @@ -0,0 +1,120 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/kernels/adagrad_kernel.h" + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +struct SparseAdagradFunctor { + void operator()(const DeviceContext& context, + const phi::SelectedRows& grad, + const DenseTensor& learning_rate, + T epsilon, + DenseTensor* moment, + DenseTensor* param); +}; + +template +phi::SelectedRows SquareSelectedRows(const DeviceContext& context, + const phi::SelectedRows& input) { + phi::SelectedRows out; + out.set_rows(input.rows()); + out.set_height(input.height()); + out.mutable_value()->Resize(input.value().dims()); + context.template Alloc(out.mutable_value()); + auto e_out = EigenVector::Flatten(*(out.mutable_value())); + auto e_in = EigenVector::Flatten(input.value()); + e_out.device(*context.eigen_device()) = e_in.square(); + return out; +} + +template +void AdagradDenseKernel(const Context& ctx, + const DenseTensor& param_t, + const DenseTensor& grad_t, + const DenseTensor& moment_t, + const DenseTensor& learning_rate, + float epsilon_t, + DenseTensor* param_out_tensor, + DenseTensor* moment_out_tensor) { + ctx.template Alloc(param_out_tensor); + ctx.template Alloc(moment_out_tensor); + + T epsilon = static_cast(epsilon_t); + + auto param = EigenVector::Flatten(param_t); + + auto grad = EigenVector::Flatten(grad_t); + + auto moment = EigenVector::Flatten(moment_t); + + auto param_out = EigenVector::Flatten(*param_out_tensor); + auto moment_out = EigenVector::Flatten(*moment_out_tensor); + auto place = *ctx.eigen_device(); + + moment_out.device(place) = moment + grad * grad; + Eigen::DSizes m_dsize(moment_out_tensor->numel()); + if (paddle::platform::is_cpu_place(ctx.GetPlace())) { + auto* lr = learning_rate.data(); + param_out.device(place) = + param - lr[0] * grad / (moment_out.sqrt() + epsilon); + } else { + auto lr = 
EigenVector::Flatten(learning_rate); + param_out.device(place) = + param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon); + } +} + +template +void AdagradSparseKernel(const Context& ctx, + const DenseTensor& param_t, + const SelectedRows& grad_t, + const DenseTensor& moment_t, + const DenseTensor& learning_rate, + float epsilon_t, + DenseTensor* param_out, + DenseTensor* moment_out) { + auto* param_out_tensor = param_out; + auto* moment_out_tensor = moment_out; + + ctx.template Alloc(param_out_tensor); + ctx.template Alloc(moment_out_tensor); + + T epsilon = static_cast(epsilon_t); + + auto* param_tensor = ¶m_t; + PADDLE_ENFORCE_EQ(param_tensor, + param_out_tensor, + phi::errors::InvalidArgument( + "the input tensor not euqal with output tensor")); + + auto* moment_tensor = &moment_t; + PADDLE_ENFORCE_EQ(moment_tensor, + moment_out_tensor, + phi::errors::InvalidArgument( + "the input moment not eual with output moment")); + + SparseAdagradFunctor functor; + functor( + ctx, grad_t, learning_rate, epsilon, moment_out_tensor, param_out_tensor); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/meshgrid_grad_kernel_impl.h b/paddle/phi/kernels/impl/meshgrid_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..b31fc5ac348fbe70695cd40ecc00ccb1622a6442 --- /dev/null +++ b/paddle/phi/kernels/impl/meshgrid_grad_kernel_impl.h @@ -0,0 +1,99 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/meshgrid_grad_kernel.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +void MeshgridBackward(const Context& ctx, + const std::vector& ins, + const std::vector& out_grad, + std::vector outs) { + int n = out_grad.size(); + auto out_dims = out_grad[0]->dims(); + + for (int i = 0; i < n; i++) { + ctx.template Alloc(outs[i]); + auto out_grad_tmp = EigenVector::Flatten(*out_grad[i]); + auto in_grad = EigenVector::Flatten(*outs[i]); + + std::vector reduce_dims_vec; + std::vector reshape_dims_vec; + for (int j = 0; j < n; j++) { + reduce_dims_vec.push_back(reshape_dims_vec.size()); + if (j == i) { + reshape_dims_vec.push_back(1); + reshape_dims_vec.push_back(out_dims[j]); + } else { + reshape_dims_vec.push_back(out_dims[j]); + reshape_dims_vec.push_back(1); + } + } + + Eigen::DSizes reduce_dims; + for (int k = 0; k < n; k++) { + reduce_dims[k] = reduce_dims_vec[k]; + } + + Eigen::DSizes reshape_dims; + for (int k = 0; k < n * 2; k++) { + reshape_dims[k] = reshape_dims_vec[k]; + } + + auto& place = *ctx.eigen_device(); + funcs::EigenBroadcastGrad, T, Rank>::Eval( + place, in_grad, out_grad_tmp, reduce_dims, reshape_dims); + } +} + +template +void MeshgridGradKernel(const Context& ctx, + const std::vector& inputs, + const std::vector& outputs_grad, + std::vector inputs_grad) { + int n = outputs_grad.size(); + switch (n) { + case 1: + MeshgridBackward(ctx, inputs, outputs_grad, inputs_grad); + break; + case 2: + MeshgridBackward(ctx, inputs, outputs_grad, inputs_grad); + break; + case 3: + MeshgridBackward(ctx, inputs, outputs_grad, inputs_grad); + break; + case 4: + MeshgridBackward(ctx, inputs, outputs_grad, inputs_grad); + break; + case 5: + MeshgridBackward(ctx, inputs, 
outputs_grad, inputs_grad); + break; + case 6: + MeshgridBackward(ctx, inputs, outputs_grad, inputs_grad); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "Excepted Tensor numbers between 1 and 6, but only received d% .", + n)); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/meshgrid_kernel_impl.h b/paddle/phi/kernels/impl/meshgrid_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..9167cab978a19887295e8c3c707e818f24135186 --- /dev/null +++ b/paddle/phi/kernels/impl/meshgrid_kernel_impl.h @@ -0,0 +1,115 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/kernels/meshgrid_kernel.h" + +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +void MeshgridForward(const Context& ctx, + const std::vector& ins, + std::vector outs) { + PADDLE_ENFORCE_EQ( + ins.size() > 1, + true, + phi::errors::InvalidArgument( + "Expected at least 2 input tensors, but only received d%.", + ins.size())); + + int64_t size = ins.size(); + std::vector shape(size); + + for (int64_t i = 0; i < size; i++) { + switch (ins[i]->dims().size()) { + case 0: + shape[i] = 1; + break; + case 1: + shape[i] = ins[i]->dims()[0]; + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "Expected scalar or 1D tensor in the tensor list but got tensor " + "%d: ", + i)); + } + } + + for (int64_t i = 0; i < size; i++) { + std::vector view_shape(size, 1); + view_shape[i] = shape[i]; + + DenseTensor reshape_ins_tensor; + paddle::framework::TensorCopy( + *ins[i], ctx.GetPlace(), ctx, &reshape_ins_tensor); + DDim out_dims_reshape = phi::make_ddim(view_shape); + reshape_ins_tensor.Resize(out_dims_reshape); + DDim out_dims = phi::make_ddim(shape); + + Eigen::DSizes bcast_dims; + for (int64_t j = 0; j < size; j++) { + bcast_dims[j] = shape[j]; + } + bcast_dims[i] = 1; + + outs[i]->Resize(out_dims); + auto x = EigenTensor::From( + static_cast(reshape_ins_tensor)); + ctx.template Alloc(outs[i]); + auto y = EigenTensor::From(*outs[i]); + auto& place = *ctx.eigen_device(); + funcs::EigenBroadcast, T, Rank>::Eval( + place, y, x, bcast_dims); + } +} + +template +void MeshgridKernel(const Context& ctx, + const std::vector& inputs, + std::vector outputs) { + int rank = inputs.size(); + switch (rank) { + case 1: + MeshgridForward(ctx, inputs, outputs); + break; + case 2: + MeshgridForward(ctx, inputs, 
outputs); + break; + case 3: + MeshgridForward(ctx, inputs, outputs); + break; + case 4: + MeshgridForward(ctx, inputs, outputs); + break; + case 5: + MeshgridForward(ctx, inputs, outputs); + break; + case 6: + MeshgridForward(ctx, inputs, outputs); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "Excepted Tensor numbers between 1 and 6, but only received d% .", + rank)); + } +} +} // namespace phi diff --git a/paddle/phi/kernels/impl/momentum_kernel_impl.h b/paddle/phi/kernels/impl/momentum_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..3aca225ad403be74b51b66409224b26670f12b52 --- /dev/null +++ b/paddle/phi/kernels/impl/momentum_kernel_impl.h @@ -0,0 +1,703 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/kernels/momentum_kernel.h" + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/kernels/funcs/algorithm.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { +/* Alias for the Type member of phi::dtype::MPTypeTrait. */ +template +using MultiPrecisionType = typename phi::dtype::MPTypeTrait::Type; +/* Dense momentum step written with Eigen expressions over flattened tensors: velocity_out = velocity * mu + grad, then param_out = param - (grad + velocity_out * mu) * lr when use_nesterov is set and param - lr * velocity_out otherwise. */ +template +struct CPUDenseUpdater { + template + void operator()(const DenseTensor& param, + const DenseTensor& velocity, + const T& mu, + const T& lr, + const bool use_nesterov, + G&& grad, + DenseTensor* param_out, + DenseTensor* velocity_out) const { + auto param_out_vec = EigenVector::Flatten(*param_out); + auto velocity_out_vec = EigenVector::Flatten(*velocity_out); + + auto param_vec = EigenVector::Flatten(param); + auto velocity_vec = EigenVector::Flatten(velocity); + velocity_out_vec = velocity_vec * mu + grad; + if (use_nesterov) { + param_out_vec = param_vec - (grad + velocity_out_vec * mu) * lr; + } else { + param_out_vec = param_vec - lr * velocity_out_vec; + } + } +}; +/* Forward-declared tag types named after the two momentum variants. */ +struct NoNesterov; +struct UseNesterov; +/* Regularization kinds; the functors in this file only act on kL2DECAY. */ +enum class RegularizationType { + kNONE = 0, + kL1DECAY = 1, // do not need support right now + kL2DECAY = 2, +}; +/* CPU dense momentum functor: reads element 0 of learning_rate and calls CPUDenseUpdater, passing param * regularization_coeff + grad as the gradient when regularization_flag is kL2DECAY and grad unchanged otherwise. */ +template +class CPUDenseMomentumFunctor { + public: + void operator()(const DenseTensor* param, + const DenseTensor* grad, + const DenseTensor* velocity, + const DenseTensor* learning_rate, + const T mu, + const bool use_nesterov, + const RegularizationType regularization_flag, + const T regularization_coeff, + DenseTensor* param_out, + DenseTensor* velocity_out) { + auto grad_vec = EigenVector::Flatten(*grad); + auto* lr = learning_rate->data>(); + + CPUDenseUpdater updater; + if (regularization_flag == RegularizationType::kL2DECAY) { + auto param_vec = EigenVector::Flatten(*param); + updater(*param, + *velocity, + mu, + static_cast(lr[0]), + use_nesterov, + 
param_vec * regularization_coeff + grad_vec, + param_out, + velocity_out); + } else { + updater(*param, + *velocity, + mu, + static_cast(lr[0]), + use_nesterov, + grad_vec, + param_out, + velocity_out); + } + } +}; + +template +class DenseMomentumFunctor; + +// NOTE(dzh) for performance. +// avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two +// functor. +template +class DenseMomentumFunctor { /* Nesterov form: param_out = param - (grad + velocity_out * mu) * lr. Per element it reads master_param when that pointer is non-null (param otherwise), scales grad by rescale_grad and adds regularization_coeff * param when kRegType is kL2DECAY. */ + private: + const T* param_; + const T* grad_; + const MT* velocity_; + const MultiPrecisionType* lr_; + const MT* master_param_; + const MT mu_; + const MT rescale_grad_; + const int64_t num_; + T* param_out_; + MT* velocity_out_; + MT* master_param_out_; + const MT regularization_coeff_; + + public: + DenseMomentumFunctor(const T* param, + const T* grad, + const MT* velocity, + const MultiPrecisionType* learning_rate, + const MT* master_param, + const MT mu, + const MT rescale_grad, + const int64_t num, + const MT regularization_coeff, + T* param_out, + MT* velocity_out, + MT* master_param_out) + : param_(param), + grad_(grad), + velocity_(velocity), + lr_(learning_rate), + master_param_(master_param), + mu_(mu), + rescale_grad_(rescale_grad), + num_(num), + param_out_(param_out), + velocity_out_(velocity_out), + master_param_out_(master_param_out), + regularization_coeff_(regularization_coeff) {} + inline HOSTDEVICE void operator()(size_t i) const { + // put memory access in register + const MT param = + master_param_ ? 
master_param_[i] : static_cast(param_[i]); + MT grad = static_cast(grad_[i]) * rescale_grad_; + const MT lr = static_cast(lr_[0]); + const MT velocity = velocity_[i]; + + if (kRegType == RegularizationType::kL2DECAY) { + grad += regularization_coeff_ * param; + } + + MT velocity_out = velocity * mu_ + grad; + MT param_out = param - (grad + velocity_out * mu_) * lr; + // write register to memory + velocity_out_[i] = velocity_out; + param_out_[i] = static_cast(param_out); + if (master_param_out_) { + master_param_out_[i] = param_out; + } + } +}; + +template +class DenseMomentumFunctor { /* Plain momentum form: param_out = param - lr * velocity_out; everything else matches the specialization above. */ + private: + const T* param_; + const T* grad_; + const MT* velocity_; + const MultiPrecisionType* lr_; + const MT* master_param_; + const MT mu_; + const MT rescale_grad_; + const int64_t num_; + T* param_out_; + MT* velocity_out_; + MT* master_param_out_; + const MT regularization_coeff_; + + public: + DenseMomentumFunctor(const T* param, + const T* grad, + const MT* velocity, + const MultiPrecisionType* learning_rate, + const MT* master_param, + const MT mu, + const MT rescale_grad, + const int64_t num, + const MT regularization_coeff, + T* param_out, + MT* velocity_out, + MT* master_param_out) + : param_(param), + grad_(grad), + velocity_(velocity), + lr_(learning_rate), + master_param_(master_param), + mu_(mu), + rescale_grad_(rescale_grad), + num_(num), + param_out_(param_out), + velocity_out_(velocity_out), + master_param_out_(master_param_out), + regularization_coeff_(regularization_coeff) {} + inline HOSTDEVICE void operator()(size_t i) const { + // put memory access in register + const MT param = + master_param_ ? 
master_param_[i] : static_cast(param_[i]); + MT grad = static_cast(grad_[i]) * rescale_grad_; + const MT lr = static_cast(lr_[0]); + const MT velocity = velocity_[i]; + + if (kRegType == RegularizationType::kL2DECAY) { + grad += regularization_coeff_ * param; + } + + MT velocity_out = velocity * mu_ + grad; + MT param_out = param - lr * velocity_out; + // write register to memory + velocity_out_[i] = velocity_out; + param_out_[i] = static_cast(param_out); + if (master_param_out_) { + master_param_out_[i] = param_out; + } + } +}; + +template +class SparseMomentumFunctor; +/* Sparse-gradient momentum, Nesterov form. For flat index i it looks up row i / row_numel_ in rows_ with phi::funcs::BinarySearch; a hit (row_idx >= 0) supplies the gradient element scaled by rescale_grad_, a miss supplies zero. */ +template +class SparseMomentumFunctor { + private: + const T* param_; + const T* grad_; + const MT* velocity_; + const MultiPrecisionType* lr_; + const MT* master_param_; + const MT mu_; + const MT rescale_grad_; + const int64_t* rows_; + const int64_t row_numel_; + const int64_t row_height_; + T* param_out_; + MT* velocity_out_; + MT* master_param_out_; + const RegularizationType regularization_flag_; + const MT regularization_coeff_; + + public: + SparseMomentumFunctor(const T* param, + const T* grad, + const MT* velocity, + const MultiPrecisionType* lr, + const MT* master_param, + const MT mu, + const MT rescale_grad, + const int64_t* rows, + int64_t row_numel, + int64_t row_height, + const RegularizationType regularization_flag, + const MT regularization_coeff, + T* param_out, + MT* velocity_out, + MT* master_param_out) + : param_(param), + grad_(grad), + velocity_(velocity), + lr_(lr), + master_param_(master_param), + mu_(mu), + rescale_grad_(rescale_grad), + rows_(rows), + row_numel_(row_numel), + row_height_(row_height), + param_out_(param_out), + velocity_out_(velocity_out), + master_param_out_(master_param_out), + regularization_flag_(regularization_flag), + regularization_coeff_(regularization_coeff) {} + + inline HOSTDEVICE void operator()(size_t i) { + auto row_idx = + phi::funcs::BinarySearch(rows_, row_height_, i / row_numel_); + MT grad = + row_idx >= 0 + ? 
static_cast(grad_[row_idx * row_numel_ + i % row_numel_]) * + rescale_grad_ + : static_cast(0); + // put memory access in register + const MT param = + master_param_ ? master_param_[i] : static_cast(param_[i]); + const MT lr = static_cast(lr_[0]); + const MT velocity = velocity_[i]; + + grad = regularization_flag_ == RegularizationType::kL2DECAY + ? grad + regularization_coeff_ * param + : grad; + + MT velocity_out = velocity * mu_ + grad; + MT param_out = param - (grad + velocity_out * mu_) * lr; + // write register to memory + velocity_out_[i] = velocity_out; + param_out_[i] = static_cast(param_out); + if (master_param_out_) { + master_param_out_[i] = param_out; + } + } +}; +/* Sparse-gradient momentum, plain form: param_out = param - velocity_out * lr; the row lookup is the same as in the specialization above. */ +template +class SparseMomentumFunctor { + private: + const T* param_; + const T* grad_; + const MT* velocity_; + const MultiPrecisionType* lr_; + const MT* master_param_; + const MT mu_; + const MT rescale_grad_; + const int64_t* rows_; + const int64_t row_numel_; + const int64_t row_height_; + T* param_out_; + MT* velocity_out_; + MT* master_param_out_; + const RegularizationType regularization_flag_; + const MT regularization_coeff_; + + public: + SparseMomentumFunctor(const T* param, + const T* grad, + const MT* velocity, + const MultiPrecisionType* lr, + const MT* master_param, + const MT mu, + const MT rescale_grad, + const int64_t* rows, + int64_t row_numel, + int64_t row_height, + const RegularizationType regularization_flag, + const MT regularization_coeff, + T* param_out, + MT* velocity_out, + MT* master_param_out) + : param_(param), + grad_(grad), + velocity_(velocity), + lr_(lr), + master_param_(master_param), + mu_(mu), + rescale_grad_(rescale_grad), + rows_(rows), + row_numel_(row_numel), + row_height_(row_height), + param_out_(param_out), + velocity_out_(velocity_out), + master_param_out_(master_param_out), + regularization_flag_(regularization_flag), + regularization_coeff_(regularization_coeff) {} + + inline HOSTDEVICE void operator()(size_t i) { + auto row_idx = + 
phi::funcs::BinarySearch(rows_, row_height_, i / row_numel_); + MT grad = + row_idx >= 0 + ? static_cast(grad_[row_idx * row_numel_ + i % row_numel_]) * + rescale_grad_ + : static_cast(0); + // put memory access in register + const MT param = + master_param_ ? master_param_[i] : static_cast(param_[i]); + const MT lr = static_cast(lr_[0]); + const MT velocity = velocity_[i]; + + grad = regularization_flag_ == RegularizationType::kL2DECAY + ? grad + regularization_coeff_ * param + : grad; + + MT velocity_out = velocity * mu_ + grad; + MT param_out = param - velocity_out * lr; + // write register to memory + velocity_out_[i] = velocity_out; + param_out_[i] = static_cast(param_out); + if (master_param_out_) { + master_param_out_[i] = param_out; + } + } +}; +/* Dense-gradient momentum implementation. Maps regularization_method l2_decay to kL2DECAY, requires master_param and master_param_out when multi_precision is set, allocates the outputs, then runs CPUDenseMomentumFunctor on CPU places or a DenseMomentumFunctor through funcs::ForRange on GPU places. */ +template +void MomentumDenseImpl(const Context& ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& velocity, + const DenseTensor& learning_rate, + paddle::optional master_param_opt, + float mu_t, + bool use_nesterov, + const std::string& regularization_method, + float regularization_coeff_t, + bool multi_precision, + float rescale_grad_t, + DenseTensor* param_out, + DenseTensor* velocity_out, + DenseTensor* master_param_out) { + MT regularization_coeff = static_cast(regularization_coeff_t); + RegularizationType regularization_flag{ + RegularizationType::kNONE}; // disable regularization + if (regularization_method == "l2_decay") { + regularization_flag = RegularizationType::kL2DECAY; + } + MT mu = static_cast(mu_t); + MT rescale_grad = static_cast(rescale_grad_t); + auto master_param = master_param_opt.get_ptr(); + if (multi_precision) { + bool has_master = ((master_param_opt.get_ptr() != nullptr) && + (master_param_out != nullptr)); + PADDLE_ENFORCE_EQ(has_master, + true, + phi::errors::InvalidArgument( + "The Input(MasterParam) and Output(MasterParamOut) " + "should not be null when " + "the attr `multi_precision` is true")); + } + + ctx.template Alloc(param_out); + ctx.template 
Alloc(velocity_out); + const MT* master_in_data = + multi_precision ? master_param->data() : nullptr; + MT* master_out_data = + multi_precision ? ctx.template Alloc(master_param_out) : nullptr; + if (paddle::platform::is_cpu_place(ctx.GetPlace())) { + CPUDenseMomentumFunctor functor; /* NOTE(review): the call below is passed neither rescale_grad nor the master param buffers. */ + functor(¶m, + &grad, + &velocity, + &learning_rate, + mu, + use_nesterov, + regularization_flag, + regularization_coeff, + param_out, + velocity_out); + } else if (paddle::platform::is_gpu_place(ctx.GetPlace())) { + funcs::ForRange for_range(ctx, param.numel()); +#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \ + DenseMomentumFunctor functor( \ + param.data(), \ + grad.data(), \ + velocity.data(), \ + learning_rate.data>(), \ + master_in_data, \ + mu, \ + rescale_grad, \ + param.numel(), \ + regularization_coeff, \ + ctx.template Alloc(param_out), \ + ctx.template Alloc(velocity_out), \ + master_out_data); \ + for_range(functor); + + if (use_nesterov) { + if (regularization_flag == RegularizationType::kL2DECAY) { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov, + RegularizationType::kL2DECAY); + } else { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(UseNesterov, + RegularizationType::kNONE); + } + } else { + if (regularization_flag == RegularizationType::kL2DECAY) { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov, + RegularizationType::kL2DECAY); + } else { + PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(NoNesterov, + RegularizationType::kNONE); + } + } + } +} +/* Sparse-gradient (SelectedRows) momentum implementation. Performs the same argument checks and output allocation as the dense path, returns early when grad has no rows, otherwise runs scatter::MergeAdd on grad and applies a SparseMomentumFunctor over param.numel() elements. */ +template +void MomentumSparseImpl(const Context& ctx, + const DenseTensor& param, + const SelectedRows& grad, + const DenseTensor& velocity, + const DenseTensor& learning_rate, + paddle::optional master_param_opt, + float mu_t, + bool use_nesterov, + const std::string& regularization_method, + float regularization_coeff_t, + bool multi_precision, + float rescale_grad_t, + DenseTensor* param_out, + DenseTensor* velocity_out, + DenseTensor* master_param_out) { + MT regularization_coeff = 
static_cast(regularization_coeff_t); + RegularizationType regularization_flag{ + RegularizationType::kNONE}; // disable regularization + if (regularization_method == "l2_decay") { + regularization_flag = RegularizationType::kL2DECAY; + } + + MT mu = static_cast(mu_t); + MT rescale_grad = static_cast(rescale_grad_t); + + auto master_param = master_param_opt.get_ptr(); + if (multi_precision) { + bool has_master = ((master_param_opt.get_ptr() != nullptr) && + (master_param_out != nullptr)); + PADDLE_ENFORCE_EQ(has_master, + true, + phi::errors::InvalidArgument( + "The Input(MasterParam) and Output(MasterParamOut) " + "should not be null when " + "the attr `multi_precision` is true")); + } + + ctx.template Alloc(param_out); + ctx.template Alloc(velocity_out); + + const MT* master_in_data = + multi_precision ? master_param->data() : nullptr; + MT* master_out_data = + multi_precision ? ctx.template Alloc(master_param_out) : nullptr; + + // sparse update may be empty. + if (grad.rows().size() == 0) { + VLOG(3) << "Grad SelectedRows contains no data!"; + return; + } + + phi::SelectedRows tmp_merged_grad; + phi::SelectedRows* merged_grad = &tmp_merged_grad; + paddle::operators::math::scatter::MergeAdd merge_func; + merge_func(ctx, grad, merged_grad); + + auto* grad_merge_rows = merged_grad->mutable_rows(); + paddle::framework::MixVector mixv_grad_merge_rows(grad_merge_rows); + const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace()); + int64_t row_numel = merged_grad->value().numel() / merged_grad->rows().size(); + funcs::ForRange for_range(ctx, param.numel()); + if (use_nesterov) { + SparseMomentumFunctor functor( + param.data(), + merged_grad->value().data(), + velocity.data(), + learning_rate.data>(), + master_in_data, + mu, + rescale_grad, + rows, + row_numel, + static_cast(merged_grad->rows().size()), + regularization_flag, + regularization_coeff, + ctx.template Alloc(param_out), + ctx.template Alloc(velocity_out), + master_out_data); + for_range(functor); + + } 
else { + SparseMomentumFunctor functor( + param.data(), + merged_grad->value().data(), + velocity.data(), + learning_rate.data>(), + master_in_data, + mu, + rescale_grad, + rows, + row_numel, + static_cast(merged_grad->rows().size()), + regularization_flag, + regularization_coeff, + ctx.template Alloc(param_out), + ctx.template Alloc(velocity_out), + master_out_data); + for_range(functor); + } +} +/* Kernel entry point for a dense grad: forwards every argument to MomentumDenseImpl, with the instantiation chosen by multi_precision. */ +template +void MomentumDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& velocity, + const DenseTensor& learning_rate, + paddle::optional master_param, + float mu, + bool use_nesterov, + const std::string& regularization_method, + float regularization_coeff, + bool multi_precision, + float rescale_grad, + DenseTensor* param_out, + DenseTensor* velocity_out, + DenseTensor* master_param_out) { + using MT = typename phi::dtype::MPTypeTrait::Type; + if (multi_precision) { + MomentumDenseImpl(dev_ctx, + param, + grad, + velocity, + learning_rate, + master_param, + mu, + use_nesterov, + regularization_method, + regularization_coeff, + multi_precision, + rescale_grad, + param_out, + velocity_out, + master_param_out); + } else { + MomentumDenseImpl(dev_ctx, + param, + grad, + velocity, + learning_rate, + master_param, + mu, + use_nesterov, + regularization_method, + regularization_coeff, + multi_precision, + rescale_grad, + param_out, + velocity_out, + master_param_out); + } +} +/* Kernel entry point for a SelectedRows grad: forwards every argument to MomentumSparseImpl, with the instantiation chosen by multi_precision. */ +template +void MomentumSparseKernel(const Context& dev_ctx, + const DenseTensor& param, + const SelectedRows& grad, + const DenseTensor& velocity, + const DenseTensor& learning_rate, + paddle::optional master_param, + float mu, + bool use_nesterov, + const std::string& regularization_method, + float regularization_coeff, + bool multi_precision, + float rescale_grad, + DenseTensor* param_out, + DenseTensor* velocity_out, + DenseTensor* master_param_out) { + using MT = typename phi::dtype::MPTypeTrait::Type; + if (multi_precision) { + 
MomentumSparseImpl(dev_ctx, + param, + grad, + velocity, + learning_rate, + master_param, + mu, + use_nesterov, + regularization_method, + regularization_coeff, + multi_precision, + rescale_grad, + param_out, + velocity_out, + master_param_out); + } else { + MomentumSparseImpl(dev_ctx, + param, + grad, + velocity, + learning_rate, + master_param, + mu, + use_nesterov, + regularization_method, + regularization_coeff, + multi_precision, + rescale_grad, + param_out, + velocity_out, + master_param_out); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/rmsprop_kernel_impl.h b/paddle/phi/kernels/impl/rmsprop_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..64b12837074dd0d9280caf576750dfbb3f7d735f --- /dev/null +++ b/paddle/phi/kernels/impl/rmsprop_kernel_impl.h @@ -0,0 +1,336 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include "paddle/phi/kernels/rmsprop_kernel.h" + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/kernels/funcs/algorithm.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { +/* Gradient accessor for a dense grad buffer: returns grad_[idx]. */ +template +struct DenseRmspropGradFunctor { + inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {} + + HOSTDEVICE inline T operator()(int64_t idx) const { return grad_[idx]; } + + const T *grad_; +}; +/* Gradient accessor for sparse rows: looks up row idx / row_numel_ in rows_ with phi::funcs::BinarySearch and returns the matching grad element, or 0 when the row is not found. */ +template +struct SparseRmspropGradFunctor { + inline SparseRmspropGradFunctor(const T *grad, + const int64_t *rows, + int64_t row_numel, + int64_t row_count) + : grad_(grad), + rows_(rows), + row_numel_(row_numel), + row_count_(row_count) {} + + HOSTDEVICE inline T operator()(int64_t idx) const { + auto row_idx = + phi::funcs::BinarySearch(rows_, row_count_, idx / row_numel_); + return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0; + } + + const T *grad_; + const int64_t *rows_; + int64_t row_numel_; + int64_t row_count_; +}; +/* Per-element RMSProp step without mean-gradient centering, updating the buffers in place: ms = rho * ms + (1 - rho) * g * g; mom = momentum * mom + lr[0] * g / sqrt(ms + epsilon); param -= mom. */ +template +struct UncenteredRmspropFunctor { + UncenteredRmspropFunctor(T *param, + T *ms, + T *mom, + const T *lr, + T rho, + T epsilon, + T momentum, + const GradFunctor &grad_functor) + : param_(param), + ms_(ms), + mom_(mom), + lr_(lr), + rho_(rho), + epsilon_(epsilon), + momentum_(momentum), + grad_functor_(grad_functor) {} + + HOSTDEVICE inline void operator()(int64_t idx) const { + T g = grad_functor_(idx); + T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g; + T mom_out = momentum_ * mom_[idx] + lr_[0] * g / sqrt(ms_out + epsilon_); + param_[idx] -= mom_out; + ms_[idx] = ms_out; + mom_[idx] = mom_out; + } + + T *param_; + T *ms_; + T *mom_; + const T *lr_; + T rho_; + T epsilon_; + T momentum_; + GradFunctor grad_functor_; +}; +/* Per-element centered RMSProp step, updating the buffers in place: additionally tracks mg = rho * mg + (1 - rho) * g and divides by sqrt(ms - mg * mg + epsilon). */ +template +struct CenteredRmspropFunctor { + CenteredRmspropFunctor(T *param, + T *ms, + T *mom, + T *mean_grad, + const T *lr, + T rho, + T epsilon, + T momentum, 
+ const GradFunctor &grad_functor) + : param_(param), + ms_(ms), + mom_(mom), + mean_grad_(mean_grad), + lr_(lr), + rho_(rho), + epsilon_(epsilon), + momentum_(momentum), + grad_functor_(grad_functor) {} + + HOSTDEVICE inline void operator()(int64_t idx) const { + T g = grad_functor_(idx); + T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g; + T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g; + T mom_out = momentum_ * mom_[idx] + + lr_[0] * g / sqrt(ms_out - mg_out * mg_out + epsilon_); + param_[idx] -= mom_out; + ms_[idx] = ms_out; + mom_[idx] = mom_out; + mean_grad_[idx] = mg_out; + } + + T *param_; + T *ms_; + T *mom_; + T *mean_grad_; + const T *lr_; + T rho_; + T epsilon_; + T momentum_; + GradFunctor grad_functor_; +}; +/* RMSProp for a dense grad. Enforces that param, moment and mean_square share buffers with their outputs, then evaluates Eigen expressions on CPU places and runs the centered or uncentered functor through funcs::ForRange otherwise. */ +template +void RmspropDenseKernel(const Context &ctx, + const DenseTensor ¶m, + const DenseTensor &mean_square, + const DenseTensor &grad, + const DenseTensor &moment, + const DenseTensor &learning_rate, + paddle::optional mean_grad_opt, + float epsilon_t, + float decay_t, + float momentum_t, + bool centered, + DenseTensor *param_out, + DenseTensor *moment_out, + DenseTensor *mean_square_out, + DenseTensor *mean_grad_out) { + auto epsilon = static_cast(epsilon_t); + auto rho = static_cast(decay_t); + auto momentum = static_cast(momentum_t); + + auto &p_tensor = param; + auto &ms_tensor = mean_square; + auto &lr_tensor = learning_rate; + auto &mom_tensor = moment; + + PADDLE_ENFORCE_EQ(p_tensor.IsSharedBufferWith(*param_out), + true, + phi::errors::InvalidArgument( + "Param and ParamOut must be the same Tensor")); + PADDLE_ENFORCE_EQ(mom_tensor.IsSharedBufferWith(*moment_out), + true, + phi::errors::InvalidArgument( + "Moment and MomentOut must be the same Tensor")); + PADDLE_ENFORCE_EQ( + ms_tensor.IsSharedBufferWith(*mean_square_out), + true, + phi::errors::InvalidArgument( + "MeanSquare and MeanSquareOut must be the same Tensor")); + size_t limit = static_cast(ms_tensor.numel()); + auto &grad_tensor = grad; + if 
(paddle::platform::is_cpu_place(ctx.GetPlace())) { + auto &place = *ctx.eigen_device(); + auto lr_value = lr_tensor.data()[0]; + + auto p = EigenVector::Flatten(p_tensor); + auto ms = EigenVector::Flatten(ms_tensor); + auto g = EigenVector::Flatten(grad_tensor); + auto mom = EigenVector::Flatten(mom_tensor); + + auto p_out = EigenVector::Flatten(*param_out); + auto mom_out = EigenVector::Flatten(*moment_out); + auto ms_out = EigenVector::Flatten(*mean_square_out); + + ms_out.device(place) = rho * ms + (1 - rho) * g * g; + if (centered) { + auto mg_tensor = mean_grad_opt.get_ptr(); + auto mg = EigenVector::Flatten(*mg_tensor); /* NOTE(review): mg_tensor is dereferenced before the check below, and that check only compares the two pointers. */ + PADDLE_ENFORCE_EQ( + mg_tensor, + mean_grad_out, + phi::errors::InvalidArgument( + "MeanGrad and MeanGradOut must be the same Tensor")); + auto mg_out = EigenVector::Flatten(*mean_grad_out); + + mg_out.device(place) = rho * mg + (1 - rho) * g; + mom_out.device(place) = + momentum * mom + + lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt(); + } else { + mom_out.device(place) = + momentum * mom + lr_value * g / (ms_out + epsilon).sqrt(); + } + p_out.device(place) = p - mom_out; + } else { + DenseRmspropGradFunctor grad_func(grad_tensor.data()); + funcs::ForRange for_range(ctx, limit); + if (centered) { + auto mg_tensor = mean_grad_opt.get_ptr(); + + PADDLE_ENFORCE_EQ( + mg_tensor, + mean_grad_out, + phi::errors::InvalidArgument( + "MeanGrad and MeanGradOut must be the same Tensor")); + for_range(CenteredRmspropFunctor>( + ctx.template Alloc(param_out), + ctx.template Alloc(mean_square_out), + ctx.template Alloc(moment_out), + ctx.template Alloc(mean_grad_out), + lr_tensor.data(), + rho, + epsilon, + momentum, + grad_func)); + } else { + for_range(UncenteredRmspropFunctor>( + ctx.template Alloc(param_out), + ctx.template Alloc(mean_square_out), + ctx.template Alloc(moment_out), + lr_tensor.data(), + rho, + epsilon, + momentum, + grad_func)); + } + } +} +/* RMSProp for a SelectedRows grad: performs the same shared-buffer checks as the dense kernel, runs scatter::MergeAdd on grad, wraps the merged rows in SparseRmspropGradFunctor and applies the centered or uncentered functor over mean_square.numel() elements. */ +template +void RmspropSparseKernel(const Context &ctx, + const DenseTensor ¶m, + 
const DenseTensor &mean_square, + const SelectedRows &grad, + const DenseTensor &moment, + const DenseTensor &learning_rate, + paddle::optional mean_grad_opt, + float epsilon_t, + float decay_t, + float momentum_t, + bool centered, + DenseTensor *param_out, + DenseTensor *moment_out, + DenseTensor *mean_square_out, + DenseTensor *mean_grad_out) { + auto epsilon = static_cast(epsilon_t); + auto rho = static_cast(decay_t); + auto momentum = static_cast(momentum_t); + + auto &p_tensor = param; + auto &ms_tensor = mean_square; + auto &lr_tensor = learning_rate; + auto &mom_tensor = moment; + + PADDLE_ENFORCE_EQ(p_tensor.IsSharedBufferWith(*param_out), + true, + phi::errors::InvalidArgument( + "Param and ParamOut must be the same Tensor")); + PADDLE_ENFORCE_EQ(mom_tensor.IsSharedBufferWith(*moment_out), + true, + phi::errors::InvalidArgument( + "Moment and MomentOut must be the same Tensor")); + PADDLE_ENFORCE_EQ( + ms_tensor.IsSharedBufferWith(*mean_square_out), + true, + phi::errors::InvalidArgument( + "MeanSquare and MeanSquareOut must be the same Tensor")); + size_t limit = static_cast(ms_tensor.numel()); + + phi::SelectedRows tmp_merged_grad; + phi::SelectedRows *merged_grad = &tmp_merged_grad; + paddle::operators::math::scatter::MergeAdd merge_func; + merge_func(ctx, grad, merged_grad); + + funcs::ForRange for_range(ctx, limit); + auto &grad_merge_rows = merged_grad->rows(); + paddle::framework::MixVector mixv_grad_merge_rows(&grad_merge_rows); + const int64_t *rows = mixv_grad_merge_rows.Data(ctx.GetPlace()); + + auto &merged_tensor = merged_grad->value(); + int64_t row_count = merged_grad->rows().size(); + int64_t row_numel = merged_tensor.numel() / row_count; /* NOTE(review): no guard against row_count == 0 is visible here; confirm the merged rows cannot be empty. */ + SparseRmspropGradFunctor grad_func( + merged_tensor.data(), rows, row_numel, row_count); + + if (centered) { + auto mg_tensor = mean_grad_opt.get_ptr(); + + PADDLE_ENFORCE_EQ(mg_tensor, + mean_grad_out, + phi::errors::InvalidArgument( + "MeanGrad and MeanGradOut must be the same Tensor")); + 
for_range(CenteredRmspropFunctor>( + ctx.template Alloc(param_out), + ctx.template Alloc(mean_square_out), + ctx.template Alloc(moment_out), + ctx.template Alloc(mean_grad_out), + lr_tensor.data(), + rho, + epsilon, + momentum, + grad_func)); + } else { + for_range(UncenteredRmspropFunctor>( + ctx.template Alloc(param_out), + ctx.template Alloc(mean_square_out), + ctx.template Alloc(moment_out), + lr_tensor.data(), + rho, + epsilon, + momentum, + grad_func)); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/meshgrid_grad_kernel.h b/paddle/phi/kernels/meshgrid_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9ce98db63cb5d599fb5cc3b851a3bbb3fe78ad46 --- /dev/null +++ b/paddle/phi/kernels/meshgrid_grad_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +/* Declaration of the meshgrid gradient kernel; this header provides no definition. */ +template +void MeshgridGradKernel(const Context& ctx, + const std::vector& inputs, + const std::vector& outputs_grad, + std::vector inputs_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/meshgrid_kernel.h b/paddle/phi/kernels/meshgrid_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d468c7c1398aa0dad86f564bb366725ed3d8d6d4 --- /dev/null +++ b/paddle/phi/kernels/meshgrid_kernel.h @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +/* Declaration of the meshgrid forward kernel taking a vector of inputs and a vector of outputs; this header provides no definition. */ +template +void MeshgridKernel(const Context& ctx, + const std::vector& inputs, + std::vector outputs); + +} // namespace phi diff --git a/paddle/phi/kernels/momentum_kernel.h b/paddle/phi/kernels/momentum_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..b4ba449aaf3a5337fae0c3d16bcdf0d176b060ee --- /dev/null +++ b/paddle/phi/kernels/momentum_kernel.h @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { +/* Declaration of the momentum kernel whose grad argument is a DenseTensor. */ +template +void MomentumDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& velocity, + const DenseTensor& learning_rate, + paddle::optional master_param, + float mu, + bool use_nesterov, + const std::string& regularization_method, + float regularization_coeff, + bool multi_precision, + float rescale_grad, + DenseTensor* param_out, + DenseTensor* velocity_out, + DenseTensor* master_param_out); +/* Declaration of the momentum kernel whose grad argument is a SelectedRows; the remaining parameters match the dense declaration. */ +template +void MomentumSparseKernel(const Context& dev_ctx, + const DenseTensor& param, + const SelectedRows& grad, + const DenseTensor& velocity, + const DenseTensor& learning_rate, + paddle::optional master_param, + float mu, + bool use_nesterov, + const std::string& regularization_method, + float regularization_coeff, + bool multi_precision, + float rescale_grad, + DenseTensor* param_out, + DenseTensor* velocity_out, + DenseTensor* master_param_out); + +} // namespace phi diff --git a/paddle/phi/kernels/rmsprop_kernel.h b/paddle/phi/kernels/rmsprop_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..4c3c9aa822115572c642571b80a9b4958bf84f7e --- /dev/null +++ b/paddle/phi/kernels/rmsprop_kernel.h @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { +/* Declaration of the RMSProp kernel whose grad argument is a DenseTensor. */ +template +void RmspropDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& mean_square, + const DenseTensor& grad, + const DenseTensor& moment, + const DenseTensor& learning_rate, + paddle::optional mean_grad, + float epsilon, + float decay, + float momentum, + bool centered, + DenseTensor* param_out, + DenseTensor* moment_out, + DenseTensor* mean_square_out, + DenseTensor* mean_grad_out); +/* Declaration of the RMSProp kernel whose grad argument is a SelectedRows; the remaining parameters match the dense declaration. */ +template +void RmspropSparseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& mean_square, + const SelectedRows& grad, + const DenseTensor& moment, + const DenseTensor& learning_rate, + paddle::optional mean_grad, + float epsilon, + float decay, + float momentum, + bool centered, + DenseTensor* param_out, + DenseTensor* moment_out, + DenseTensor* mean_square_out, + DenseTensor* mean_grad_out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/adagrad_sig.cc b/paddle/phi/ops/compat/adagrad_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..4d9a8a65d789121e681be24e7e1f47340ee6b1d4 --- /dev/null +++ b/paddle/phi/ops/compat/adagrad_sig.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+// Builds the kernel signature for the `adagrad` op.
+//
+// The kernel name is chosen from the storage type of the "Grad" input:
+// a dense tensor selects "adagrad", selected rows select
+// "adagrad_dense_param_sparse_grad". Any other input type yields the
+// "unregistered" signature with empty argument lists.
+KernelSignature AdagradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  // Both supported gradient types share the same inputs, attributes and
+  // outputs, so only the name has to be decided up front.
+  const char* kernel_name = nullptr;
+  if (ctx.IsDenseTensorInput("Grad")) {
+    kernel_name = "adagrad";
+  } else if (ctx.IsSelectedRowsInput("Grad")) {
+    kernel_name = "adagrad_dense_param_sparse_grad";
+  }
+
+  if (kernel_name == nullptr) {
+    return KernelSignature("unregistered", {}, {}, {});
+  }
+
+  return KernelSignature(kernel_name,
+                         {"Param", "Grad", "Moment", "LearningRate"},
+                         {"epsilon"},
+                         {"ParamOut", "MomentOut"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(adagrad, phi::AdagradOpArgumentMapping);
diff --git a/paddle/phi/ops/compat/meshgrid_sig.cc b/paddle/phi/ops/compat/meshgrid_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..44671c84e7afb5ec781cf1c103c18b0c3886c8be --- /dev/null +++ b/paddle/phi/ops/compat/meshgrid_sig.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+// Signature for the forward `meshgrid` op: input "X", no attributes,
+// output "Out". `ctx` is not consulted.
+KernelSignature MeshgridOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  KernelSignature forward_signature("meshgrid", {"X"}, {}, {"Out"});
+  return forward_signature;
+}
+
+// Signature for `meshgrid_grad`: inputs "X" and the gradient variable of
+// "Out", no attributes, output the gradient variable of "X". `ctx` is not
+// consulted.
+KernelSignature MeshgridGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  const auto out_grad_name = GradVarName("Out");
+  const auto x_grad_name = GradVarName("X");
+  return KernelSignature(
+      "meshgrid_grad", {"X", out_grad_name}, {}, {x_grad_name});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(meshgrid, phi::MeshgridOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(meshgrid_grad, phi::MeshgridGradOpArgumentMapping);
diff --git a/paddle/phi/ops/compat/momentum_sig.cc b/paddle/phi/ops/compat/momentum_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..3511ddc63c891c9726d71945ed5719d1121aba72 --- /dev/null +++ b/paddle/phi/ops/compat/momentum_sig.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+// Builds the kernel signature for the `momentum` op.
+//
+// The kernel name is chosen from the storage type of the "Grad" input:
+// a dense tensor selects "momentum", selected rows select
+// "momentum_dense_param_sparse_grad". Any other input type yields the
+// "unregistered" signature with empty argument lists.
+KernelSignature MomentumOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  // Both supported gradient types share the same inputs, attributes and
+  // outputs, so only the name has to be decided up front.
+  const char* kernel_name = nullptr;
+  if (ctx.IsDenseTensorInput("Grad")) {
+    kernel_name = "momentum";
+  } else if (ctx.IsSelectedRowsInput("Grad")) {
+    kernel_name = "momentum_dense_param_sparse_grad";
+  }
+
+  if (kernel_name == nullptr) {
+    return KernelSignature("unregistered", {}, {}, {});
+  }
+
+  return KernelSignature(
+      kernel_name,
+      {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"},
+      {"mu",
+       "use_nesterov",
+       "regularization_method",
+       "regularization_coeff",
+       "multi_precision",
+       "rescale_grad"},
+      {"ParamOut", "VelocityOut", "MasterParamOut"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(momentum, phi::MomentumOpArgumentMapping);
diff --git a/paddle/phi/ops/compat/rmsprop_sig.cc b/paddle/phi/ops/compat/rmsprop_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..74def7d0b6a5c75faba9b56648bac79451dd785e --- /dev/null +++ b/paddle/phi/ops/compat/rmsprop_sig.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+// Builds the kernel signature for the `rmsprop` op.
+//
+// The kernel name is chosen from the storage type of the "Grad" input:
+// a dense tensor selects "rmsprop", selected rows select
+// "rmsprop_dense_param_sparse_grad". Any other input type yields the
+// "unregistered" signature with empty argument lists.
+KernelSignature RmspropOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  // Both supported gradient types share the same inputs, attributes and
+  // outputs, so only the name has to be decided up front.
+  const char* kernel_name = nullptr;
+  if (ctx.IsDenseTensorInput("Grad")) {
+    kernel_name = "rmsprop";
+  } else if (ctx.IsSelectedRowsInput("Grad")) {
+    kernel_name = "rmsprop_dense_param_sparse_grad";
+  }
+
+  if (kernel_name == nullptr) {
+    return KernelSignature("unregistered", {}, {}, {});
+  }
+
+  return KernelSignature(
+      kernel_name,
+      {"Param", "MeanSquare", "Grad", "Moment", "LearningRate", "MeanGrad"},
+      {"epsilon", "decay", "momentum", "centered"},
+      {"ParamOut", "MomentOut", "MeanSquareOut", "MeanGradOut"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(rmsprop, phi::RmspropOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_adagrad_op.py b/python/paddle/fluid/tests/unittests/test_adagrad_op.py index fc3b7ce2fd87afc22030bcca55236fb949c1f129..ae047e602d15a2869b721c6ba0f6b8d8e80fb9f7 100644 --- a/python/paddle/fluid/tests/unittests/test_adagrad_op.py +++ b/python/paddle/fluid/tests/unittests/test_adagrad_op.py @@ -20,6 +20,7 @@ import paddle.fluid.core as core from paddle.fluid.op import Operator from op_test import OpTest import math +import paddle class TestAdagradOp1(OpTest): @@ -189,4 +190,5 @@ class TestSparseAdagradOp(unittest.TestCase): if __name__ == "__main__": + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py b/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py index 9bc3bb7ad341f06dc4609d1d744eb911b1878685..c38dea8bc3942eae392c6f1fef1e55791805a6e5 100644 --- a/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_merged_momentum_op.py @@ -258,6 +258,7 @@ class TestMergedMomentum(unittest.TestCase): def setUp(self): paddle.enable_static()
self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]] + self.seed = 10 def gen_rand_data(self, shapes, dtype): @@ -391,4 +392,5 @@ class TestMergedMomentum2(unittest.TestCase): if __name__ == "__main__": + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py index 10058ddae9b10a391f7eedfb156b161df3bcea0a..2cb83eba3767c92f0c6a43e6480e7825d1302051 100644 --- a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py +++ b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py @@ -84,7 +84,6 @@ class TestMeshgridOp3(unittest.TestCase): feed={'x': input_1, 'y': input_2}, fetch_list=[grid_x, grid_y]) - assert np.array_equal(res_1, out_1) assert np.array_equal(res_2, out_2) @@ -180,4 +179,5 @@ class TestMeshgridOp8(unittest.TestCase): if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index a59b355b4a70e4cf9e4a716438ac97f0e76114e1..7f3690cff60f564b0e530199bb2c1e3c7e571fe1 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -872,6 +872,7 @@ class TestMultiTensorMomentumDygraph(unittest.TestCase): place=place, use_amp=use_amp, use_multi_tensor=True) output2, params2 = self._momentum_optimize_dygraph( place=place, use_amp=use_amp, use_multi_tensor=False) + self.assertEqual(np.allclose(output1, output2, rtol=1e-05), True) for idx in range(len(params1)): self.assertEqual( @@ -991,4 +992,5 @@ class TestMultiTensorMomentumStatic(unittest.TestCase): if __name__ == "__main__": + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py index 08ab2e18c733a6ba4bad904f10abce2baf9517ed..62839d3a960f1372557ed149eb9cc3ea67971be0 100644 --- 
a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py +++ b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py @@ -316,4 +316,5 @@ class TestRMSPropV2Group(TestRMSPropV2): if __name__ == "__main__": + paddle.enable_static() unittest.main()