Unverified commit d15b490a, authored by Yuang Liu, committed by GitHub

[operator migration] Migrate merged momentum cpu/gpu kernels (#44300)

Parent 84b72c5f
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 namespace paddle {
 namespace operators {
@@ -103,7 +103,3 @@ namespace plat = paddle::platform;
 REGISTER_OP_WITHOUT_GRADIENT(merged_momentum,
                              ops::MergedMomentumOp,
                              ops::MergedMomentumOpMaker);
-REGISTER_OP_CPU_KERNEL(merged_momentum,
-                       ops::MergedMomentumOpKernel<phi::CPUContext, float>,
-                       ops::MergedMomentumOpKernel<phi::CPUContext, double>);
@@ -12,8 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
 namespace paddle {
 namespace operators {
...
@@ -12,8 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/fluid/platform/macros.h"
 #include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
 namespace paddle {
...
@@ -17,7 +17,7 @@
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/for_range.h"
-#include "paddle/fluid/platform/macros.h"
+#include "paddle/phi/core/macros.h"
 namespace paddle {
 namespace operators {
...
@@ -29,9 +29,3 @@ limitations under the License. */
 #define FLT_MAX __FLT_MAX__
 #endif  // __FLT_MAX__
 #endif  // PADDLE_WITH_MUSL
-#if defined(__NVCC__) || defined(__HIPCC__)
-#define PADDLE_RESTRICT __restrict__
-#else
-#define PADDLE_RESTRICT
-#endif
@@ -53,4 +53,10 @@ namespace phi {
 #define PD_CONCATENATE2(arg1, arg2) arg1##arg2
 #define PD_EXPAND(x) x
+#if defined(__NVCC__) || defined(__HIPCC__)
+#define PADDLE_RESTRICT __restrict__
+#else
+#define PADDLE_RESTRICT
+#endif
 }  // namespace phi
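For context, PADDLE_RESTRICT now lives in paddle/phi/core/macros.h: under NVCC or HIPCC it expands to __restrict__, a no-aliasing hint for pointer parameters, and to nothing otherwise. Below is a minimal sketch of how a kernel helper might use it; the function and names are hypothetical and not part of this commit:

#include <cstdint>
#include "paddle/phi/core/macros.h"

namespace example {

// Hypothetical axpy-style helper: with PADDLE_RESTRICT the compiler may assume
// grad and velocity do not alias, which helps vectorization on device builds.
template <typename T>
void MomentumStep(const T* PADDLE_RESTRICT grad,
                  T* PADDLE_RESTRICT velocity,
                  T mu,
                  int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    velocity[i] = mu * velocity[i] + grad[i];
  }
}

}  // namespace example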
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,13 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    merged_momentum,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, float>,
-    ops::MergedMomentumOpKernel<plat::CUDADeviceContext, double>);
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/merged_momentum_impl.h"
+PD_REGISTER_KERNEL(merged_momentum,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MergedMomentumKernel,
+                   float,
+                   double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/merged_momentum_impl.h"
PD_REGISTER_KERNEL(merged_momentum,
                   GPU,
                   ALL_LAYOUT,
                   phi::MergedMomentumKernel,
                   phi::dtype::float16,
                   float,
                   double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MergedMomentumKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& param,
const std::vector<const DenseTensor*>& grad,
const std::vector<const DenseTensor*>& velocity,
const std::vector<const DenseTensor*>& learning_rate,
const paddle::optional<std::vector<const DenseTensor*>>& master_param,
float mu,
bool use_nesterov,
const std::vector<std::string>& regularization_method,
const std::vector<float>& regularization_coeff,
bool multi_precision,
float rescale_grad,
std::vector<DenseTensor*> param_out,
std::vector<DenseTensor*> velocity_out,
std::vector<DenseTensor*> master_param_out);
} // namespace phi
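For reference, here is a self-contained host-side sketch of the per-parameter update that a merged momentum kernel applies to each (param, grad, velocity, learning rate) group, assuming one learning rate per group. It shows only the standard and Nesterov momentum formulas; the actual implementation in merged_momentum_impl.h additionally handles regularization, multi-precision master parameters, and rescale_grad. All names below are illustrative:

#include <cstddef>
#include <vector>

// velocity_out = mu * velocity + grad
// param_out    = param - lr * velocity_out                 (plain momentum)
// param_out    = param - lr * (grad + mu * velocity_out)   (Nesterov)
void MergedMomentumReference(std::vector<std::vector<float>>* params,
                             const std::vector<std::vector<float>>& grads,
                             std::vector<std::vector<float>>* velocities,
                             const std::vector<float>& learning_rates,
                             float mu,
                             bool use_nesterov) {
  for (std::size_t p = 0; p < params->size(); ++p) {
    const float lr = learning_rates[p];
    for (std::size_t i = 0; i < (*params)[p].size(); ++i) {
      const float g = grads[p][i];
      const float v = mu * (*velocities)[p][i] + g;
      (*velocities)[p][i] = v;
      (*params)[p][i] -= use_nesterov ? lr * (g + mu * v) : lr * v;
    }
  }
}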
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MergedMomentumOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"merged_momentum",
{"Param", "Grad", "Velocity", "LearningRate", "MasterParam"},
{"mu",
"use_nesterov",
"regularization_method",
"regularization_coeff",
"multi_precision",
"rescale_grad"},
{
"ParamOut",
"VelocityOut",
"MasterParamOut",
});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(merged_momentum,
phi::MergedMomentumOpArgumentMapping);