diff --git a/paddle/fluid/operators/optimizers/multi_tensor_adam_op.cc b/paddle/fluid/operators/optimizers/multi_tensor_adam_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..55a84dc0dbc8afd4161dd9aa920b6c4c5da55d66 --- /dev/null +++ b/paddle/fluid/operators/optimizers/multi_tensor_adam_op.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" + +namespace paddle { +namespace operators { + +using Tensor = phi::DenseTensor; + +class MultiTensorAdamOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + phi::KernelKey GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto param_dtype = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "Params"); + return phi::KernelKey(param_dtype, ctx.GetPlace()); + } +}; + +class MultiTensorAdamOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Params", "(Tensor) Input parameters").AsDuplicable(); + AddInput("Grads", "(Tensor) Input gradients").AsDuplicable(); + AddInput("LearningRate", "(Tensor, default Tensor) Learning rate"); + AddInput("Moments1", "(Tensor) Input first moments").AsDuplicable(); + AddInput("Moments2", "(Tensor) Input second moments").AsDuplicable(); + AddInput("Beta1Pows", + "(Tensor, default Tensor) Input beta1 power accumulator") + .AsDuplicable(); + AddInput("Beta2Pows", + "(Tensor, default Tensor) Input beta2 power accumulator") + .AsDuplicable(); + AddInput("MasterParams", "FP32 master weight for AMP.") + .AsDispensable() + .AsDuplicable(); + AddInput("SkipUpdate", "(Tensor, optional), Skip the update or not.") + .AsDispensable(); + + AddOutput("ParamsOut", "(Tensor) Output parameters").AsDuplicable(); + AddOutput("Moments1Out", "(Tensor) Output first moments").AsDuplicable(); + AddOutput("Moments2Out", "(Tensor) Output second moments").AsDuplicable(); + AddOutput("Beta1PowsOut", "(Tensor) Output beta1 power accumulator") + .AsDuplicable(); + AddOutput("Beta2PowsOut", "(Tensor) Output beta2 power accumulator") + .AsDuplicable(); + AddOutput("MasterParamsOut", + "The updated FP32 master weight for AMP. 
" + "It shared memory with Input(MasterParams).") + .AsDispensable() + .AsDuplicable(); + + AddAttr("beta1", + "(float, default 0.9) " + "Exponential decay rate for the " + "first moment estimates.") + .SetDefault(0.9f) + .SupportTensor(); + AddAttr("beta2", + "(float, default 0.999) " + "exponential decay rate for the " + "second moment estimates.") + .SetDefault(0.999f) + .SupportTensor(); + AddAttr("epsilon", + "(float, default 1.0e-8) " + "Constant for numerical stability") + .SetDefault(1.0e-8f) + .SupportTensor(); + + AddAttr("chunk_size", "ChunkSize for blocks computing"); + + AddAttr("weight_decay", + "(float, default 0) " + "weight decay (L2 penalty)") + .SetDefault(0); + AddAttr("use_adamw", + "(bool, default False) " + "Whether to use AdamW" + "True for decoupled weight decay") + .SetDefault(false); + AddAttr("multi_precision", + "(bool, default false) " + "Whether to use multi-precision during weight updating.") + .SetDefault(false); + // TODO(zhiqiu): We could set Beta1PowOut and Beta2PowOut + // as dispensable since they are not used when use_global_beta_pow is true. + AddAttr("use_global_beta_pow", + "(bool, default false) " + "Whether to use global beta_pow for whole model instead of " + "creating beta_pow for each parameter.") + .SetDefault(false); + + AddComment(R"DOC( +Adam Optimizer. + +This implements the Adam optimizer from Section 2 of the Adam +paper : https://arxiv.org/abs/1412.6980. +Adam is a first-order gradient-based optimization method based on +adaptive estimates of lower-order moments. + +Adam updates: + +$$ +moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\ +moment\_2_\out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\ +learning\_rate = learning\_rate * + \frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\ +param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon} +$$ + +AdamW updates: + +$$ +moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\ +moment\_2_\out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\ +learning\_rate = learning\_rate * + \frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\ +param\_out & = param - learning\_rate * (\frac{moment\_1}{\sqrt{moment\_2} + \epsilon} + \lambda * param) +$$ + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(multi_tensor_adam, + MultiTensorAdamInferShapeFunctor, + PD_INFER_META(phi::MultiTensorAdamInferMeta)); +REGISTER_OPERATOR( + multi_tensor_adam, + ops::MultiTensorAdamOp, + ops::MultiTensorAdamOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + MultiTensorAdamInferShapeFunctor); diff --git a/paddle/fluid/pybind/eager_generator.h b/paddle/fluid/pybind/eager_generator.h index 68ed995403aebfd43033ff618a3a4969e1638e03..05263cb27e4d6975ade4e6ed1883cceb07a73dc9 100644 --- a/paddle/fluid/pybind/eager_generator.h +++ b/paddle/fluid/pybind/eager_generator.h @@ -165,6 +165,16 @@ std::map> op_ins_map = { "Beta1Pow", "Beta2Pow", "MasterParam"}}, + {"multi_tensor_adam", + {"Params", + "Grads", + "LearningRate", + "Moments1", + "Moments2", + "Beta1Pows", + "Beta2Pows", + "MasterParams", + "SkipUpdate"}}, {"adamw", {"Param", "Grad", @@ -321,6 +331,13 @@ std::map> op_outs_map = { "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, + {"multi_tensor_adam", + {"ParamsOut", + "Moments1Out", + "Moments2Out", + "Beta1PowsOut", + "Beta2PowsOut", + "MasterParamsOut"}}, {"adamw", {"ParamOut", "Moment1Out", @@ -382,6 +399,13 @@ std::map> 
op_passing_outs_map = { "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, + {"multi_tensor_adam", + {"ParamsOut", + "Moments1Out", + "Moments2Out", + "Beta1PowsOut", + "Beta2PowsOut", + "MasterParamsOut"}}, {"adamw", {"ParamOut", "Moment1Out", diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 94bcff5b8986e98062bc65dbca261a17e06e912d..f0de136a56eb14dbeaca414d355122fc0950273a 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1225,6 +1225,17 @@ optional : master_param inplace : (param -> param_out), (velocity -> velocity_out), (master_param -> master_param_out) +- op : multi_tensor_adam_ + args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow) + output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()} + infer_meta : + func : MultiTensorAdamInferMeta + kernel : + func : multi_tensor_adam + data_type : params + optional : skip_update, master_params + inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out) + - op : multiclass_nms3 args : (Tensor bboxes, Tensor scores, Tensor rois_num, float score_threshold, int nms_top_k, int keep_top_k, float nms_threshold=0.3, bool normalized=true, float nms_eta=1.0, int background_label=0) output : Tensor(out), Tensor(index), Tensor(nms_rois_num) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 545b3c6f52354db9bd81440392112fcad9d24655..cc1b02696144f052e6b4836eba4ab7b7b325c3ed 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2981,6 +2981,49 @@ void YoloLossInferMeta(const MetaTensor& x, gt_match_mask->set_dtype(x.dtype()); } +void MultiTensorAdamInferMeta( + const std::vector& params, + const std::vector& grads, + const MetaTensor& learning_rate, + const std::vector& moments1, + const std::vector& moments2, + const std::vector& beta1_pows, + const std::vector& beta2_pows, + const paddle::optional>& master_params, + const MetaTensor& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + int chunk_size, + float weight_decay, + bool use_adamw, + bool multi_precision, + bool use_global_beta_pow, + std::vector params_out, + std::vector moments1_out, + std::vector moments2_out, + std::vector beta1_pows_out, + std::vector beta2_pows_out, + std::vector master_params_out) { + size_t in_size = params.size(); + for (size_t i = 0; i < in_size; i++) { + params_out[i]->set_dims(params[i]->dims()); + params_out[i]->set_dtype(params[i]->dtype()); + moments1_out[i]->set_dims(moments1[i]->dims()); + moments1_out[i]->set_dtype(moments1[i]->dtype()); + moments2_out[i]->set_dims(moments2[i]->dims()); + moments2_out[i]->set_dtype(moments2[i]->dtype()); + beta1_pows_out[i]->set_dims(beta1_pows[i]->dims()); + beta1_pows_out[i]->set_dtype(beta1_pows[i]->dtype()); + beta2_pows_out[i]->set_dims(beta2_pows[i]->dims()); + beta2_pows_out[i]->set_dtype(beta2_pows[i]->dtype()); + if (master_params && 
!master_params_out.empty()) { + master_params_out[i]->set_dims(master_params.get()[i]->dims()); + master_params_out[i]->set_dtype(master_params.get()[i]->dtype()); + } + } +} + void MoeInferMeta(const MetaTensor& x, const MetaTensor& gate, const MetaTensor& bmm0, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 1d7cf7d1c27be25bec854289e03793957234eb89..c0abeb222cbbafb59318380fc372ee8366559cad 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -532,6 +532,31 @@ void YoloLossInferMeta(const MetaTensor& x, MetaTensor* objectness_mask, MetaTensor* gt_match_mask); +void MultiTensorAdamInferMeta( + const std::vector& params, + const std::vector& grads, + const MetaTensor& learning_rate, + const std::vector& moments1, + const std::vector& moments2, + const std::vector& beta1_pows, + const std::vector& beta2_pows, + const paddle::optional>& master_params, + const MetaTensor& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + int chunk_size, + float weight_decay, + bool use_adamw, + bool multi_precision, + bool use_global_beta_pow, + std::vector params_out, + std::vector moments1_out, + std::vector moments2_out, + std::vector beta1_pows_out, + std::vector beta2_pows_out, + std::vector master_params_out); + void MoeInferMeta(const MetaTensor& x, const MetaTensor& gate, const MetaTensor& bmm0, diff --git a/paddle/phi/kernels/cpu/multi_tensor_adam_kernel.cc b/paddle/phi/kernels/cpu/multi_tensor_adam_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..1ca5641a2406a98008cced3f1d2b7cd31ecce660 --- /dev/null +++ b/paddle/phi/kernels/cpu/multi_tensor_adam_kernel.cc @@ -0,0 +1,165 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/multi_tensor_adam_kernel.h" +#include + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +#include "paddle/phi/kernels/adam_kernel.h" +#include "paddle/phi/kernels/adamw_kernel.h" + +namespace phi { + +static paddle::optional TensorPtrToOptionalTensor( + const paddle::optional>& t, size_t idx) { + return t ? 
paddle::optional(*(t.get()[idx])) : paddle::none; +} + +template +void MultiTensorAdamKernel( + const Context& dev_ctx, + const std::vector& params, + const std::vector& grads, + const DenseTensor& learning_rate, + const std::vector& moments1, + const std::vector& moments2, + const std::vector& beta1_pows, + const std::vector& beta2_pows, + const paddle::optional>& master_params, + const paddle::optional& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + int chunk_size, + float weight_decay, + bool use_adamw, + bool multi_precision, + bool use_global_beta_pow, + std::vector params_out, + std::vector moments1_out, + std::vector moments2_out, + std::vector beta1_pows_out, + std::vector beta2_pows_out, + std::vector master_params_out) { + size_t params_num = params.size(); + PADDLE_ENFORCE_EQ( + params_num, + grads.size(), + errors::InvalidArgument("The size of Input(grads) must be equal to " + "Input(params), but got the size of Input(grads) " + "is %d, the size of Input(params) is %d.", + grads.size(), + params_num)); + PADDLE_ENFORCE_EQ(params_num, + moments1.size(), + errors::InvalidArgument( + "The size of Input(moments1) must be equal to " + "Input(params), but got the size of Input(moments1) " + "is %d, the size of Input(params) is %d.", + moments1.size(), + params_num)); + PADDLE_ENFORCE_EQ(params_num, + moments2.size(), + errors::InvalidArgument( + "The size of Input(moments2) must be equal to " + "Input(params), but got the size of Input(moments2) " + "is %d, the size of Input(params) is %d.", + moments2.size(), + params_num)); + PADDLE_ENFORCE_EQ(params_num, + beta1_pows.size(), + errors::InvalidArgument( + "The size of Input(beta1_pows) must be equal to " + "Input(params), but got the size of Input(beta1_pows) " + "is %d, the size of Input(params) is %d.", + beta1_pows.size(), + params_num)); + PADDLE_ENFORCE_EQ(params_num, + beta2_pows.size(), + errors::InvalidArgument( + "The size of Input(beta2_pows) must be equal to " + "Input(params), but got the size of Input(beta2_pows) " + "is %d, the size of Input(params) is %d.", + beta2_pows.size(), + params_num)); + + for (size_t idx = 0; idx < params_num; idx++) { + auto master_params_tmp = TensorPtrToOptionalTensor(master_params, idx); + if (!use_adamw) { + AdamDenseKernel( + dev_ctx, + *params[idx], + *grads[idx], + learning_rate, + *moments1[idx], + *moments2[idx], + *beta1_pows[idx], + *beta2_pows[idx], + master_params_tmp, + skip_update, + beta1, + beta2, + epsilon, + false, + 1000, + multi_precision, + use_global_beta_pow, + params_out[idx], + moments1_out[idx], + moments2_out[idx], + beta1_pows_out[idx], + beta2_pows_out[idx], + master_params_out.empty() ? nullptr : master_params_out[idx]); + } else { + AdamwDenseKernel( + dev_ctx, + *params[idx], + *grads[idx], + learning_rate, + *moments1[idx], + *moments2[idx], + *beta1_pows[idx], + *beta2_pows[idx], + master_params_tmp, + skip_update, + beta1, + beta2, + epsilon, + 1.0, + weight_decay, + use_adamw, + false, + 1000, + multi_precision, + use_global_beta_pow, + params_out[idx], + moments1_out[idx], + moments2_out[idx], + beta1_pows_out[idx], + beta2_pows_out[idx], + master_params_out.empty() ? 
nullptr : master_params_out[idx]); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(multi_tensor_adam, + CPU, + ALL_LAYOUT, + phi::MultiTensorAdamKernel, + float, + double) {} diff --git a/paddle/phi/kernels/funcs/multi_tensor_apply.h b/paddle/phi/kernels/funcs/multi_tensor_apply.h new file mode 100644 index 0000000000000000000000000000000000000000..5be64dcab2ef107d49253ac4953a7480b7115b14 --- /dev/null +++ b/paddle/phi/kernels/funcs/multi_tensor_apply.h @@ -0,0 +1,179 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/tensor_utils.h" + +namespace phi { +namespace funcs { + +// This code is referenced from apex's multi_tensor_apply.cuh. +// https://github.com/NVIDIA/apex + +template +struct TensorAndBlockInfo { + void *tensor_addrs[N - 1][MaxTensorSize]; + const void *grads[MaxTensorSize]; + int sizes[MaxTensorSize]; + uint8_t tensor_ids[MaxBlockSize]; + // int16 + uint16_t chunk_ids[MaxBlockSize]; + int start_chunk_id; + + DEVICE void GetChunkIdAndTensorId(int *chunk_id, int *tensor_id) const { + int block_id = blockIdx.x; + int tmp_tensor_id = tensor_ids[block_id]; + *chunk_id = static_cast(chunk_ids[block_id]) + + (tmp_tensor_id == 0) * start_chunk_id; + *tensor_id = tmp_tensor_id; + } +}; + +template +__global__ void MultiTensorApplyCudaKernel( + int chunk_size, + TensorAndBlockInfo t_info, + Functor functor, + ArgTypes... args) { + functor(chunk_size, t_info, args...); +} + +template +void LaunchMultiTensorApplyKernel( + const Context &dev_ctx, + int block_size, + int chunk_size, + const std::vector> &input_vector, + const std::vector &grads, + Functor functor, + ArgTypes... 
args) {
+  PADDLE_ENFORCE_EQ(
+      input_vector.size(),
+      InputNum - 1,
+      errors::InvalidArgument(
+          "input_vector.size() must be equal to InputNum - 1; please check "
+          "grads, params, moments1, moments2 and master_params."));
+  size_t length = input_vector[0].size();
+  PADDLE_ENFORCE_GT(
+      length,
+      0,
+      errors::InvalidArgument(
+          "input_vector[0].size() must be greater than 0; please check "
+          "params."));
+  auto place = input_vector[0][0]->place();
+  PADDLE_ENFORCE_EQ(
+      place,
+      GPUPlace(),
+      errors::InvalidArgument(
+          "Expected the inputs to be on GPU, but the input is on CPU."));
+  for (size_t i = 0; i < input_vector.size(); i++) {
+    PADDLE_ENFORCE_EQ(
+        input_vector[i].size(),
+        length,
+        errors::InvalidArgument(
+            "The size of some input vectors does not match that of the "
+            "first input vector."));
+    for (size_t j = 0; j < input_vector[i].size(); j++) {
+      PADDLE_ENFORCE_EQ(
+          input_vector[i][j]->place(),
+          place,
+          errors::InvalidArgument(
+              "A tensor is not on the same device as the first tensor."));
+      PADDLE_ENFORCE_EQ(input_vector[i][j]->numel(),
+                        input_vector[0][j]->numel(),
+                        errors::InvalidArgument(
+                            "The number of elements of the inputs must be "
+                            "equal."));
+    }
+  }
+
+  size_t tensors_size = input_vector[0].size();
+
+  TensorAndBlockInfo<InputNum, MaxTensorSize, MaxBlockSize> t_info;
+  t_info.start_chunk_id = 0;
+
+  auto stream = dev_ctx.stream();
+  int block_id = 0;
+  int tensor_id = 0;
+  for (int t = 0; t < tensors_size; t++) {
+    t_info.sizes[tensor_id] = input_vector[0][t]->numel();
+    t_info.grads[tensor_id] = grads[t]->data();
+    for (int d = 0; d < InputNum - 1; d++) {
+      t_info.tensor_addrs[d][tensor_id] = input_vector[d][t]->data();
+    }
+    tensor_id++;
+    int chunks_this_tensor =
+        (input_vector[0][t]->numel() + chunk_size - 1) / chunk_size;
+
+    constexpr auto kMaxChunkId = std::numeric_limits<uint16_t>::max();
+    for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
+      t_info.tensor_ids[block_id] = tensor_id - 1;
+      auto saved_chunk_id =
+          (tensor_id == 1 ? chunk - t_info.start_chunk_id : chunk);
+      PADDLE_ENFORCE_GE(saved_chunk_id,
+                        0,
+                        errors::InvalidArgument(
+                            "The chunk id is less than 0 in "
+                            "MultiTensorApplyKernel. This may be a bug."));
+      PADDLE_ENFORCE_LE(
+          saved_chunk_id,
+          kMaxChunkId,
+          errors::InvalidArgument(
+              "The chunk id exceeds maximum value %d. 
This may be a bug.", + kMaxChunkId)); + t_info.chunk_ids[block_id] = saved_chunk_id; + block_id++; + bool reach_tensors_limit = + (tensor_id == MaxTensorSize && chunk == chunks_this_tensor - 1); + bool reach_blocks_limit = (block_id == MaxBlockSize); + bool finish_compute = + (t == tensors_size - 1 && chunk == chunks_this_tensor - 1); + if (reach_tensors_limit || reach_blocks_limit || finish_compute) { + MultiTensorApplyCudaKernel + <<>>( + chunk_size, t_info, functor, args...); + + block_id = 0; + if (chunk == chunks_this_tensor - 1) { + tensor_id = 0; + t_info.start_chunk_id = 0; + } else { + t_info.sizes[0] = t_info.sizes[tensor_id - 1]; + t_info.grads[0] = t_info.grads[tensor_id - 1]; + for (int d = 0; d < InputNum - 1; d++) { + t_info.tensor_addrs[d][0] = t_info.tensor_addrs[d][tensor_id - 1]; + } + tensor_id = 1; + t_info.start_chunk_id = chunk + 1; + } + } + } + } +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/multi_tensor_adam_kernel.cu b/paddle/phi/kernels/gpu/multi_tensor_adam_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..176b453596e3f2677bcd9dc23bf0bdf12cd243f2 --- /dev/null +++ b/paddle/phi/kernels/gpu/multi_tensor_adam_kernel.cu @@ -0,0 +1,501 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/multi_tensor_adam_kernel.h" +#include +#include "glog/logging.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/multi_tensor_apply.h" + +namespace phi { + +// This code is referenced from apex's multi_tensor_adam.cu. 
+// https://github.com/NVIDIA/apex + +template +struct MultiTensorAdamBetaPowInfo { + using MPDType = typename phi::dtype::MPTypeTrait::Type; + MultiTensorAdamBetaPowInfo(const MPDType* beta1pow, const MPDType* beta2pow) { + beta1pow_ = *beta1pow; + beta2pow_ = *beta2pow; + } + + DEVICE MPDType GetBeta1PowValue() const { return beta1pow_; } + + DEVICE MPDType GetBeta2PowValue() const { return beta2pow_; } + + private: + MPDType beta1pow_; + MPDType beta2pow_; +}; + +template +struct MultiTensorAdamBetaPowInfo { + using MPDType = typename phi::dtype::MPTypeTrait::Type; + MultiTensorAdamBetaPowInfo(const MPDType* beta1pow, const MPDType* beta2pow) { + beta1pow_ = beta1pow; + beta2pow_ = beta2pow; + } + + DEVICE MPDType GetBeta1PowValue() const { return *beta1pow_; } + + DEVICE MPDType GetBeta2PowValue() const { return *beta2pow_; } + + private: + const MPDType* __restrict__ beta1pow_; + const MPDType* __restrict__ beta2pow_; +}; + +template +struct MultiTensorAdamFunctor { + __device__ __forceinline__ void operator()( + int chunk_size, + const funcs::TensorAndBlockInfo& t_info, + MT beta1, + MT beta2, + MultiTensorAdamBetaPowInfo beta_pow, + MT epsilon, + const MT* learning_rate, + MT decay) const { + MT lr = *learning_rate; + MT beta1_pow = beta_pow.GetBeta1PowValue(); + MT beta2_pow = beta_pow.GetBeta2PowValue(); + T* __restrict__ p_ptr; + const T* __restrict__ g_ptr; + MT* __restrict__ mom1_ptr, * __restrict__ mom2_ptr; + MT* __restrict__ mp_ptr; + int n; + + { + int chunk_id, tensor_id; + t_info.GetChunkIdAndTensorId(&chunk_id, &tensor_id); + + n = t_info.sizes[tensor_id]; + int offset = chunk_id * chunk_size; + g_ptr = static_cast(t_info.grads[tensor_id]) + offset; + p_ptr = static_cast(t_info.tensor_addrs[0][tensor_id]) + offset; + mom1_ptr = static_cast(t_info.tensor_addrs[1][tensor_id]) + offset; + mom2_ptr = static_cast(t_info.tensor_addrs[2][tensor_id]) + offset; + mp_ptr = + IsMultiPrecision + ? static_cast(t_info.tensor_addrs[3][tensor_id]) + offset + : nullptr; + + n -= offset; + if (n > chunk_size) { + n = chunk_size; + } + } + + int stride = blockDim.x * VecSize; + int idx = threadIdx.x * VecSize; + + for (; idx < n; idx += stride) { + phi::AlignedVector g_vec; + phi::AlignedVector p_vec; + phi::AlignedVector mp_vec; + phi::AlignedVector mom1_vec; + phi::AlignedVector mom2_vec; + if (idx <= n - VecSize) { + if (IsMultiPrecision) { + phi::Load(mp_ptr + idx, &mp_vec); + } else { + phi::Load(p_ptr + idx, &p_vec); + } + phi::Load(g_ptr + idx, &g_vec); + phi::Load(mom1_ptr + idx, &mom1_vec); + phi::Load(mom2_ptr + idx, &mom2_vec); + } else { + int size = n - idx; + for (int j = 0; j < size; j++) { + if (IsMultiPrecision) { + mp_vec[j] = mp_ptr[idx + j]; + } else { + p_vec[j] = p_ptr[idx + j]; + } + g_vec[j] = g_ptr[idx + j]; + mom1_vec[j] = static_cast(mom1_ptr[idx + j]); + mom2_vec[j] = static_cast(mom2_ptr[idx + j]); + } +#pragma unroll + for (int j = size; j < VecSize; j++) { + g_vec[j] = T(0); + p_vec[j] = T(0); + mp_vec[j] = MT(0); + mom1_vec[j] = MT(0); + mom2_vec[j] = MT(0); + } + } + +#pragma unroll + for (int j = 0; j < VecSize; j++) { + MT p = IsMultiPrecision ? 
mp_vec[j] : static_cast(p_vec[j]); + UpdateMoments(&mom1_vec[j], + &mom2_vec[j], + static_cast(g_vec[j]), + beta1, + beta2); + mp_vec[j] = UpdateParameter(p, + mom1_vec[j], + mom2_vec[j], + beta1_pow, + beta2_pow, + lr, + epsilon, + decay); + } + + if (idx <= n - VecSize) { + phi::Store(mom1_vec, mom1_ptr + idx); + phi::Store(mom2_vec, mom2_ptr + idx); + if (IsMultiPrecision) { + phi::Store(mp_vec, mp_ptr + idx); + } + for (int j = 0; j < VecSize; j++) { + p_ptr[idx + j] = static_cast(mp_vec[j]); + } + } else { + int size = n - idx; + for (int j = 0; j < size; j++) { + if (IsMultiPrecision) { + mp_ptr[idx + j] = mp_vec[j]; + } + p_ptr[idx + j] = static_cast(mp_vec[j]); + mom1_ptr[idx + j] = mom1_vec[j]; + mom2_ptr[idx + j] = mom2_vec[j]; + } + } + } + } + + private: + static __device__ __forceinline__ void UpdateMoments( + MT* __restrict__ mom1_ptr, + MT* __restrict__ mom2_ptr, + MT g, + MT beta1, + MT beta2) { + MT mom1 = static_cast(mom1_ptr[0]); + MT mom2 = static_cast(mom2_ptr[0]); + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; + mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + + mom1_ptr[0] = mom1; + mom2_ptr[0] = mom2; + } + + static __device__ __forceinline__ MT UpdateParameter(MT p, + MT mom1, + MT mom2, + MT beta1_pow, + MT beta2_pow, + MT lr, + MT epsilon, + MT decay) { + if (UseAdamW) { + p *= (static_cast(1.0) - lr * decay); + } + MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); + return p; + } +}; + +template +__global__ void UpdateBetaPowGroup( + Array beta1_pow, Array beta2_pow, T beta1, T beta2, int n) { + auto idx = threadIdx.x; + if (idx < n) { + beta1_pow[idx][0] *= beta1; + beta2_pow[idx][0] *= beta2; + } +} + +template +static void CopyTensorIfDifferent(const Context& dev_ctx, + const std::vector& src, + const std::vector& dst, + bool use_src_place = false) { + for (size_t i = 0; i < src.size(); ++i) { + if (src[i] != dst[i]) { + VLOG(10) << "Copy Tensor " << i; + phi::Place place = (use_src_place ? 
src[i]->place() : dev_ctx.GetPlace()); + phi::Copy(dev_ctx, *(src[i]), place, false, dst[i]); + } + } +} + +template +static int GetVecSizeFromTensors(const std::vector& tensors, + int vec_size = 4) { + for (const auto* t : tensors) { + vec_size = min(vec_size, GetVectorizedSize(t->template data())); + } + return vec_size; +} + +template +void MultiTensorAdamKernel( + const Context& dev_ctx, + const std::vector& params, + const std::vector& grads, + const DenseTensor& learning_rate, + const std::vector& moments1, + const std::vector& moments2, + const std::vector& beta1_pows, + const std::vector& beta2_pows, + const paddle::optional>& master_params, + const paddle::optional& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + int chunk_size, + float weight_decay, + bool use_adamw, + bool multi_precision, + bool use_global_beta_pow, + std::vector params_out, + std::vector moments1_out, + std::vector moments2_out, + std::vector beta1_pows_out, + std::vector beta2_pows_out, + std::vector master_params_out) { + using MPDType = typename phi::dtype::MPTypeTrait::Type; + + auto n = params.size(); + auto beta1_pow_first = beta1_pows[0]; + auto beta2_pow_first = beta2_pows[0]; + + for (int i = 1; i < beta1_pows.size(); i++) { + PADDLE_ENFORCE_EQ(beta1_pow_first->place(), + beta1_pows[i]->place(), + phi::errors::InvalidArgument( + "All Beta1Pow must be in the same place.")); + PADDLE_ENFORCE_EQ(beta2_pow_first->place(), + beta2_pows[i]->place(), + phi::errors::InvalidArgument( + "All Beta2Pow must be in the same place.")); + } + + PADDLE_ENFORCE_EQ( + beta1_pow_first->place(), + beta2_pow_first->place(), + phi::errors::InvalidArgument( + "Input(Beta1Pows) and Input(Beta2Pows) must be in the same place.")); + + bool is_cpu_betapow = (beta1_pow_first->place() == CPUPlace()); + + VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; + + CopyTensorIfDifferent(dev_ctx, params, params_out); + CopyTensorIfDifferent(dev_ctx, moments1, moments1_out); + CopyTensorIfDifferent(dev_ctx, moments2, moments2_out); + CopyTensorIfDifferent(dev_ctx, beta1_pows, beta1_pows_out, true); + CopyTensorIfDifferent(dev_ctx, beta2_pows, beta2_pows_out, true); + if (master_params) { + CopyTensorIfDifferent(dev_ctx, master_params.get(), master_params_out); + } + + bool skip_update_value = false; + if (skip_update.is_initialized()) { + PADDLE_ENFORCE_EQ( + skip_update->numel(), + 1, + errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d", + skip_update->numel())); + DenseTensor skip_update_tensor; + phi::Copy( + dev_ctx, skip_update.get(), CPUPlace(), false, &skip_update_tensor); + skip_update_value = skip_update_tensor.data()[0]; + VLOG(4) << "skip_update_value:" << skip_update_value; + } + + // skip_update=true + if (skip_update_value) { + VLOG(4) << "Adam skip update"; + return; + } + + MPDType beta1_tmp = beta1.to(); + MPDType beta2_tmp = beta2.to(); + + std::vector> input_vector; + input_vector.reserve(4); + + input_vector.push_back(params_out); + input_vector.push_back(moments1_out); + input_vector.push_back(moments2_out); + if (multi_precision) { + input_vector.push_back(master_params_out); + } + + VLOG(4) << "use_adamw: " << use_adamw; + VLOG(4) << "multi_precision: " << multi_precision; + +#define PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + __multi_precision, __is_cpu_betapow, __use_adamw, __vec_size) \ + do { \ + constexpr int kInputNum = __multi_precision ? 5 : 4; \ + constexpr int kMaxTensorSize = __multi_precision ? 
48 : 60; \ + constexpr int kMaxBlockSize = __multi_precision ? 320 : 320; \ + constexpr int kBlockSize = 512; \ + MultiTensorAdamBetaPowInfo beta_pow_info( \ + beta1_pow_first->data(), beta2_pow_first->data()); \ + MultiTensorAdamFunctor \ + functor; \ + funcs::LaunchMultiTensorApplyKernel( \ + dev_ctx, \ + kBlockSize, \ + ((chunk_size + __vec_size - 1) / __vec_size) * __vec_size, \ + input_vector, \ + grads, \ + functor, \ + beta1_tmp, \ + beta2_tmp, \ + beta_pow_info, \ + epsilon.to(), \ + learning_rate.data(), \ + static_cast(weight_decay)); \ + } while (0) + +#define PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL(__vec_size) \ + case __vec_size: { \ + if (multi_precision) { \ + if (is_cpu_betapow) { \ + if (use_adamw) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, true, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, true, false, __vec_size); \ + } \ + } else { \ + if (use_adamw) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, false, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, false, false, __vec_size); \ + } \ + } \ + } else { \ + if (is_cpu_betapow) { \ + if (use_adamw) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, true, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, true, false, __vec_size); \ + } \ + } else { \ + if (use_adamw) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, false, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, false, false, __vec_size); \ + } \ + } \ + } \ + } break + + int vec_size = GetVecSizeFromTensors(params_out); + vec_size = GetVecSizeFromTensors(moments1_out, vec_size); + vec_size = GetVecSizeFromTensors(moments2_out, vec_size); + if (master_params) { + vec_size = GetVecSizeFromTensors(master_params_out, vec_size); + } + + switch (vec_size) { + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL(4); + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL(2); + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL(1); + default: + PADDLE_THROW( + errors::InvalidArgument("Unsupported vectorized size %d", vec_size)); + break; + } + + if (!use_global_beta_pow) { + if (is_cpu_betapow) { + for (size_t i = 0; i < n; i++) { + VLOG(10) << "CPU Update BetaPow here..."; + auto* beta1_ptr = + dev_ctx.template HostAlloc(beta1_pows_out[i]); + (*beta1_ptr) *= beta1_tmp; + + auto* beta2_ptr = + dev_ctx.template HostAlloc(beta2_pows_out[i]); + (*beta2_ptr) *= beta2_tmp; + } + } else { + constexpr size_t kGroupSize = 32; + auto group_num = (n + kGroupSize - 1) / kGroupSize; + VLOG(10) << "GPU Update BetaPow here..."; + for (size_t i = 0; i < group_num; ++i) { + size_t start = i * kGroupSize; + size_t end = std::min((i + 1) * kGroupSize, n); + Array beta1_ptrs, beta2_ptrs; + for (size_t j = start; j < end; ++j) { + size_t idx = j - start; + beta1_ptrs[idx] = dev_ctx.template Alloc(beta1_pows_out[j]); + beta2_ptrs[idx] = dev_ctx.template Alloc(beta2_pows_out[j]); + } + UpdateBetaPowGroup + <<<1, kGroupSize, 0, dev_ctx.stream()>>>( + beta1_ptrs, beta2_ptrs, beta1_tmp, beta2_tmp, end - start); + } + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(multi_tensor_adam, + GPU, + ALL_LAYOUT, + phi::MultiTensorAdamKernel, + phi::dtype::float16, + float, + double) { + // Skip beta1_pow, beta2_pow, skip_update data transform + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); + 
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/multi_tensor_adam_kernel.h b/paddle/phi/kernels/multi_tensor_adam_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..5bc6399a7c10e6eeceb1a9ffe1a372f3e320955a --- /dev/null +++ b/paddle/phi/kernels/multi_tensor_adam_kernel.h @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void MultiTensorAdamKernel( + const Context &dev_ctx, + const std::vector ¶ms, + const std::vector &grads, + const DenseTensor &learning_rate, + const std::vector &moments1, + const std::vector &moments2, + const std::vector &beta1_pows, + const std::vector &beta2_pows, + const paddle::optional> &master_params, + const paddle::optional &skip_update, + const Scalar &beta1, + const Scalar &beta2, + const Scalar &epsilon, + int chunk_size, + float weight_decay, + bool use_adamw, + bool multi_precision, + bool use_global_beta_pow, + std::vector params_out, + std::vector moments1_out, + std::vector moments2_out, + std::vector beta1_pows_out, + std::vector beta2_pows_out, + std::vector master_params_out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/multi_tensor_adam_sig.cc b/paddle/phi/ops/compat/multi_tensor_adam_sig.cc new file mode 100755 index 0000000000000000000000000000000000000000..d7d901a437b3c4307b7ccef5f28f9872d5c118b2 --- /dev/null +++ b/paddle/phi/ops/compat/multi_tensor_adam_sig.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include + +#include "paddle/phi/core/compat/op_utils.h" +#include "paddle/utils/small_vector.h" + +namespace phi { + +KernelSignature MultiTensorAdamOpArgumentMapping( + const ArgumentMappingContext& ctx) { + paddle::small_vector in_names = {"Params", + "Grads", + "LearningRate", + "Moments1", + "Moments2", + "Beta1Pow", + "Beta2Pow", + "MasterParams", + "SkipUpdate"}; + paddle::small_vector out_names = {"ParamsOut", + "Moments1Out", + "Moments2Out", + "Beta1PowOut", + "Beta2PowOut", + "MasterParamsOut"}; + paddle::small_vector attr_names = {"beta1", + "beta2", + "epsilon", + "chunk_size", + "weight_decay", + "use_adamw", + "multi_precision", + "use_global_beta_pow"}; + + return KernelSignature("multi_tensor_adam", + std::move(in_names), + std::move(attr_names), + std::move(out_names)); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(multi_tensor_adam, + phi::MultiTensorAdamOpArgumentMapping); diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index 1425ea0361cfec78de293f8bb2efe8a449419dda..646c2b36798df04a67c252d4b9def9574a2f0f26 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -76,6 +76,10 @@ if(WITH_GPU) test_auto_tune SRCS test_auto_tune.cu DEPS gtest) + cc_test( + test_multi_tensor_adam_kernel + SRCS test_multi_tensor_adam_kernel.cc + DEPS gtest phi phi_api_utils) elseif(WITH_ROCM) hip_test( test_gpu_timer diff --git a/paddle/phi/tests/kernels/test_multi_tensor_adam_kernel.cc b/paddle/phi/tests/kernels/test_multi_tensor_adam_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..68f4d8e5c14ce3c4d5ee6ebe80b945a830661485 --- /dev/null +++ b/paddle/phi/tests/kernels/test_multi_tensor_adam_kernel.cc @@ -0,0 +1,478 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/generator.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/phi/backends/gpu/gpu_context.h" +#endif + +#include "gtest/gtest.h" + +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/infermeta/multiary.h" +#include "paddle/phi/kernels/abs_kernel.h" +#include "paddle/phi/kernels/adam_kernel.h" +#include "paddle/phi/kernels/adamw_kernel.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/elementwise_subtract_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/gaussian_kernel.h" +#include "paddle/phi/kernels/multi_tensor_adam_kernel.h" +#include "paddle/phi/kernels/reduce_max_kernel.h" + +namespace phi { + +template +auto GenerateRandomTensorVectors( + const Context &ctx, const std::vector> &shapes) { + size_t n = shapes.size(); + std::vector tensors(n); + for (size_t i = 0; i < n; ++i) { + GaussianKernel( + ctx, + shapes[i], + 0.0f, + 1.0f, + 0, + paddle::experimental::CppTypeToDataType::Type(), + &tensors[i]); + } + return tensors; +} + +template +auto GenerateConstantTensorVectors( + const Context &ctx, + const std::vector> &shapes, + T value) { + size_t n = shapes.size(); + std::vector tensors(n); + for (size_t i = 0; i < n; ++i) { + FullKernel(ctx, + shapes[i], + value, + paddle::experimental::CppTypeToDataType::Type(), + &tensors[i]); + } + return tensors; +} + +static auto ToConstTensorPtrVector(const std::vector &tensors) { + std::vector results; + for (const auto &t : tensors) { + results.push_back(&t); + } + return results; +} + +static auto ToMutableTensorPtrVector( + std::vector &tensors) { // NOLINT + std::vector results; + for (auto &t : tensors) { + results.push_back(&t); + } + return results; +} + +static auto ToMetaTensorVector(const std::vector &tensors) { + std::vector results; + for (auto &t : tensors) { + results.push_back(t); + } + return results; +} + +static auto ToConstMetaTensorPtrVector( + const std::vector &meta_tensors) { + std::vector results; + for (auto &t : meta_tensors) { + results.push_back(&t); + } + return results; +} + +static auto ToMutableMetaTensorPtrVector( + std::vector &meta_tensors) { // NOLINT + std::vector results; + for (auto &t : meta_tensors) { + results.push_back(&t); + } + return results; +} + +template +struct AdamInfo { + const Context *ctx; + std::vector> shapes; + + std::vector params; + std::vector master_params; + std::vector moment1s; + std::vector moment2s; + std::vector beta1_pows; + std::vector beta2_pows; + DenseTensor learning_rate; + float beta1; + float beta2; + float weight_decay; + float epsilon = 1e-6; + bool multi_precision; + bool use_adamw; + int chunk_size = 4096; + + using MT = typename phi::dtype::MPTypeTrait::Type; + + AdamInfo(const Context &ctx_ref, + const std::vector> &shapes, + float beta1, + float beta2, + float weight_decay, + bool multi_precision, + bool use_adamw) + : ctx(&ctx_ref), + shapes(shapes), + beta1(beta1), + beta2(beta2), + weight_decay(weight_decay), + multi_precision(multi_precision), + use_adamw(use_adamw) { + std::vector> one_shapes(shapes.size(), + std::vector(1, 1)); + std::vector> learning_rate_shapes( + one_shapes.begin(), one_shapes.begin() + 1); + + params = GenerateRandomTensorVectors(*ctx, shapes); + learning_rate = GenerateConstantTensorVectors( + *ctx, learning_rate_shapes, 1e-3)[0]; + moment1s = GenerateConstantTensorVectors(*ctx, 
shapes, 0); + moment2s = GenerateConstantTensorVectors(*ctx, shapes, 0); + + if (multi_precision) { + master_params.resize(shapes.size()); + for (size_t i = 0; i < shapes.size(); ++i) { + master_params[i] = Cast( + *ctx, + params[i], + paddle::experimental::CppTypeToDataType::Type()); + } + } + + beta1_pows = + GenerateConstantTensorVectors(*ctx, one_shapes, beta1); + beta2_pows = + GenerateConstantTensorVectors(*ctx, one_shapes, beta2); + } + + void Update(bool use_multi_tensor, const std::vector &grads) { + if (use_multi_tensor) { + UpdateWithMultiTensorAdam(grads); + } else { + for (size_t j = 0; j < params.size(); ++j) { + if (use_adamw) { + UpdateWithAdamWBaseline(grads, j); + } else { + UpdateWithAdamBaseline(grads, j); + } + } + } + } + + static AdamInfo DeepCopy(const AdamInfo &other) { + AdamInfo copied(*other.ctx, + other.shapes, + other.beta1, + other.beta2, + other.weight_decay, + other.multi_precision, + other.use_adamw); + auto copy_tensor = [&other](const DenseTensor &x, DenseTensor *y) { + Copy(*other.ctx, x, x.place(), false, y); + }; + + auto copy_tensors = [&other](const std::vector &xs, + std::vector *ys) { + for (size_t i = 0; i < xs.size(); ++i) { + Copy(*other.ctx, xs[i], xs[i].place(), false, &((*ys)[i])); + } + }; + + copy_tensors(other.params, &copied.params); + copy_tensors(other.master_params, &copied.master_params); + copy_tensors(other.moment1s, &copied.moment1s); + copy_tensors(other.moment2s, &copied.moment2s); + copy_tensors(other.beta1_pows, &copied.beta1_pows); + copy_tensors(other.beta2_pows, &copied.beta2_pows); + copy_tensor(other.learning_rate, &copied.learning_rate); + copied.epsilon = other.epsilon; + copied.chunk_size = other.chunk_size; + other.ctx->Wait(); + return copied; + } + + private: + void UpdateWithMultiTensorAdam(const std::vector &grads) { + auto param_metas = ToMetaTensorVector(params); + auto grad_metas = ToMetaTensorVector(grads); + auto master_param_metas = ToMetaTensorVector(master_params); + auto moment1_metas = ToMetaTensorVector(moment1s); + auto moment2_metas = ToMetaTensorVector(moment2s); + auto beta1_pow_metas = ToMetaTensorVector(beta1_pows); + auto beta2_pow_metas = ToMetaTensorVector(beta2_pows); + + MultiTensorAdamInferMeta( + ToConstMetaTensorPtrVector(param_metas), + ToConstMetaTensorPtrVector(grad_metas), + learning_rate, + ToConstMetaTensorPtrVector(moment1_metas), + ToConstMetaTensorPtrVector(moment2_metas), + ToConstMetaTensorPtrVector(beta1_pow_metas), + ToConstMetaTensorPtrVector(beta2_pow_metas), + multi_precision ? paddle::make_optional( + ToConstMetaTensorPtrVector(master_param_metas)) + : paddle::none, + MetaTensor(), + beta1, + beta2, + epsilon, + chunk_size, + weight_decay, + use_adamw, + multi_precision, + false, + ToMutableMetaTensorPtrVector(param_metas), + ToMutableMetaTensorPtrVector(moment1_metas), + ToMutableMetaTensorPtrVector(moment2_metas), + ToMutableMetaTensorPtrVector(beta1_pow_metas), + ToMutableMetaTensorPtrVector(beta2_pow_metas), + ToMutableMetaTensorPtrVector(master_param_metas)); + + MultiTensorAdamKernel( + *ctx, + ToConstTensorPtrVector(params), + ToConstTensorPtrVector(grads), + learning_rate, + ToConstTensorPtrVector(moment1s), + ToConstTensorPtrVector(moment2s), + ToConstTensorPtrVector(beta1_pows), + ToConstTensorPtrVector(beta2_pows), + multi_precision + ? 
paddle::make_optional(ToConstTensorPtrVector(master_params)) + : paddle::none, + paddle::none, + beta1, + beta2, + epsilon, + chunk_size, + weight_decay, + use_adamw, + multi_precision, + false, + ToMutableTensorPtrVector(params), + ToMutableTensorPtrVector(moment1s), + ToMutableTensorPtrVector(moment2s), + ToMutableTensorPtrVector(beta1_pows), + ToMutableTensorPtrVector(beta2_pows), + ToMutableTensorPtrVector(master_params)); + } + + void UpdateWithAdamWBaseline(const std::vector &grads, + size_t idx) { + AdamwDenseKernel( + *ctx, + params[idx], + grads[idx], + learning_rate, + moment1s[idx], + moment2s[idx], + beta1_pows[idx], + beta2_pows[idx], + multi_precision ? paddle::make_optional(master_params[idx]) + : paddle::none, + paddle::none, + beta1, + beta2, + epsilon, + 1.0, + weight_decay, + true, + false, + 1000, + multi_precision, + false, + ¶ms[idx], + &moment1s[idx], + &moment2s[idx], + &beta1_pows[idx], + &beta2_pows[idx], + multi_precision ? &master_params[idx] : nullptr); + } + + void UpdateWithAdamBaseline(const std::vector &grads, + size_t idx) { + AdamDenseKernel( + *ctx, + params[idx], + grads[idx], + learning_rate, + moment1s[idx], + moment2s[idx], + beta1_pows[idx], + beta2_pows[idx], + multi_precision ? paddle::make_optional(master_params[idx]) + : paddle::none, + paddle::none, + beta1, + beta2, + epsilon, + false, + 1000, + multi_precision, + false, + ¶ms[idx], + &moment1s[idx], + &moment2s[idx], + &beta1_pows[idx], + &beta2_pows[idx], + multi_precision ? &master_params[idx] : nullptr); + } +}; + +template +auto MaxDiff(const Context &ctx, + const DenseTensor &x_t, + const DenseTensor &y_t) { + using MT = typename AdamInfo::MT; + auto mp_dtype = paddle::experimental::CppTypeToDataType::Type(); + auto x = Cast(ctx, x_t, mp_dtype); + auto y = Cast(ctx, y_t, mp_dtype); + + EXPECT_EQ(x.dims(), y.dims()); + DenseTensor diff, diff_reduced, diff_reduced_cpu; + + diff.Resize(x.dims()); + ctx.template Alloc(&diff); + SubtractKernel(ctx, x, y, &diff); + AbsKernel(ctx, diff, &diff); + + diff_reduced.Resize({1}); + ctx.template Alloc(&diff_reduced); + MaxRawKernel( + ctx, diff, vectorize(x.dims()), false, true, &diff_reduced); + + diff_reduced_cpu.Resize(diff_reduced.dims()); + ctx.template HostAlloc(&diff_reduced_cpu); + Copy(ctx, diff_reduced, CPUPlace(), true, &diff_reduced_cpu); + EXPECT_EQ(diff_reduced_cpu.place(), CPUPlace()); + return diff_reduced_cpu.data()[0]; +} + +template +auto MaxDiff(const Context &ctx, + const std::vector &xs, + const std::vector &ys) { + using MT = typename AdamInfo::MT; + MT diff = 0; + for (size_t i = 0; i < xs.size(); ++i) { + diff = std::max(diff, MaxDiff(ctx, xs[i], ys[i])); + } + return diff; +} + +template +void TestMultiTensorAdamBase(const std::vector> &shapes, + float atol, + bool use_adamw, + bool multi_precision = false, + float beta1 = 0.9, + float beta2 = 0.99, + float weight_decay = 0.1, + size_t steps = 5, + uint64_t seed = 10) { + const auto &ctx = + *paddle::platform::DeviceContextPool::Instance().GetByPlace(PlaceType()); + using Context = typename std::remove_const< + typename std::remove_pointer::type>::type; + ctx.GetGenerator()->SetCurrentSeed(seed); + AdamInfo info1( + ctx, shapes, beta1, beta2, weight_decay, multi_precision, use_adamw); + auto info2 = AdamInfo::DeepCopy(info1); + + for (size_t i = 0; i < steps; ++i) { + auto grads = GenerateRandomTensorVectors(ctx, shapes); + info1.Update(false, grads); + info2.Update(true, grads); + } + + using MT = typename decltype(info1)::MT; + +#define PD_ADAM_TEST_COMP(__field, __dtype) \ 
+ do { \ + MT __diff = MaxDiff<__dtype>(ctx, info1.__field, info2.__field); \ + EXPECT_LE(__diff, static_cast(atol)) \ + << #__field << " has diff when use_adamw = " << use_adamw \ + << " , multi_precision = " << multi_precision; \ + } while (0) + + PD_ADAM_TEST_COMP(beta1_pows, MT); + PD_ADAM_TEST_COMP(beta2_pows, MT); + PD_ADAM_TEST_COMP(params, T); + PD_ADAM_TEST_COMP(master_params, MT); + PD_ADAM_TEST_COMP(moment1s, MT); + PD_ADAM_TEST_COMP(moment2s, MT); +} + +static auto GenerateRandomShapes(size_t n, uint64_t low, uint64_t high) { + std::random_device device; + std::default_random_engine engine(device()); + std::uniform_int_distribution dist(low, high); + std::vector> shapes(n); + for (size_t i = 0; i < n; ++i) { + shapes[i].push_back(dist(engine)); + } + return shapes; +} + +TEST(multi_tensor_adam, test_fp32_cpu) { + auto shapes = GenerateRandomShapes(30, 10, 20); + float atol = 0.0f; + for (auto use_adamw : {false, true}) { + TestMultiTensorAdamBase(shapes, atol, use_adamw); + } +} + +#ifdef PADDLE_WITH_CUDA +TEST(multi_tensor_adam, test_fp32_gpu) { + auto shapes = GenerateRandomShapes(40, 0, 2 << 18); + float atol = 0.0f; + for (auto use_adamw : {false, true}) { + TestMultiTensorAdamBase(shapes, atol, use_adamw); + } +} + +TEST(multi_tensor_adam, test_fp16_gpu) { + auto shapes = GenerateRandomShapes(40, 0, 2 << 18); + float atol = 5e-3f; + for (auto use_adamw : {false, true}) { + TestMultiTensorAdamBase( + shapes, atol, use_adamw, true); + } +} +#endif + +} // namespace phi
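
// The following is a minimal host-side reference sketch (illustrative only,
// not part of the patch above) of the per-element Adam/AdamW update that
// multi_tensor_adam applies to every chunk of every parameter. The helper
// names here (AdamState, AdamReferenceUpdate) are hypothetical; the math
// mirrors the formulas in the operator's AddComment block and the
// UpdateMoments/UpdateParameter device functions in the CUDA kernel.
#include <cmath>
#include <cstdio>
#include <vector>

struct AdamState {
  float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1e-8f;
  float beta1_pow = 0.9f, beta2_pow = 0.999f;  // running beta^t accumulators
  float lr = 1e-3f, weight_decay = 0.01f;
  bool use_adamw = false;
};

// Updates one parameter tensor (and its moments) in place.
void AdamReferenceUpdate(std::vector<float>* param,
                         const std::vector<float>& grad,
                         std::vector<float>* mom1,
                         std::vector<float>* mom2,
                         const AdamState& s) {
  for (size_t i = 0; i < param->size(); ++i) {
    float p = (*param)[i];
    float g = grad[i];
    if (s.use_adamw) {
      // Decoupled weight decay: shrink the parameter before the Adam step.
      p *= (1.0f - s.lr * s.weight_decay);
    }
    float m1 = s.beta1 * (*mom1)[i] + (1.0f - s.beta1) * g;
    float m2 = s.beta2 * (*mom2)[i] + (1.0f - s.beta2) * g * g;
    // Bias-corrected step, written in the same form as UpdateParameter above.
    float denom = std::sqrt(m2) / std::sqrt(1.0f - s.beta2_pow) + s.epsilon;
    p -= (s.lr / (1.0f - s.beta1_pow)) * (m1 / denom);
    (*mom1)[i] = m1;
    (*mom2)[i] = m2;
    (*param)[i] = p;
  }
}

int main() {
  AdamState s;
  std::vector<float> p = {1.0f, -2.0f};
  std::vector<float> g = {0.1f, -0.3f};
  std::vector<float> m1(2, 0.0f), m2(2, 0.0f);
  AdamReferenceUpdate(&p, g, &m1, &m2, s);
  std::printf("p[0] = %f, p[1] = %f\n", p[0], p[1]);
  return 0;
}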