From 246ac9764de419985be99d00dc89eaeff1aca322 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Thu, 14 Jul 2022 20:44:04 +0800 Subject: [PATCH] [operator migration] Migrate infer shape for merged momentum (#44338) --- .../final_state_generator/codegen_utils.py | 1 + .../operators/optimizers/merged_momentum_op.cc | 12 +++++++++--- paddle/phi/api/lib/data_transform.cc | 10 ++++++++++ paddle/phi/api/lib/data_transform.h | 5 +++++ paddle/phi/api/yaml/generator/api_base.py | 2 ++ paddle/phi/infermeta/multiary.cc | 16 ++++++++++++++++ paddle/phi/infermeta/multiary.h | 16 ++++++++++++++++ 7 files changed, 59 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py index cd5805740b..79f5da4bec 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py @@ -45,6 +45,7 @@ yaml_types_mapping = { 'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \ 'float' : 'float', 'double' : 'double', 'bool' : 'bool', \ 'str' : 'std::string', \ + 'str[]' : 'std::vector', 'float[]' : 'std::vector', \ 'Place' : 'paddle::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \ 'int64_t[]' : 'std::vector', 'int[]' : 'std::vector', 'Tensor' : 'Tensor', diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.cc b/paddle/fluid/operators/optimizers/merged_momentum_op.cc index 220c0be9dd..85b2f818fe 100644 --- a/paddle/fluid/operators/optimizers/merged_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/merged_momentum_op.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { @@ -22,8 +25,6 @@ class MergedMomentumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext *ctx) const override {} - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto param_dtype = @@ -100,6 +101,11 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; namespace plat = paddle::platform; +DECLARE_INFER_SHAPE_FUNCTOR(merged_momentum, + MergedMomentumInferShapeFunctor, + PD_INFER_META(phi::MergedMomentumInferMeta)); + REGISTER_OP_WITHOUT_GRADIENT(merged_momentum, ops::MergedMomentumOp, - ops::MergedMomentumOpMaker); + ops::MergedMomentumOpMaker, + MergedMomentumInferShapeFunctor); diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 4dafc7a7ee..58795c0f06 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -284,5 +284,15 @@ std::unique_ptr> PrepareData( return pt_tensors; } +paddle::optional> PrepareData( + const paddle::optional>& inputs, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag) { + if (inputs) { + return {*PrepareData(*inputs, target_args_def, transform_flag)}; + } + return paddle::none; +} + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 4d70078ef3..3feba2465f 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -76,5 +76,10 @@ std::unique_ptr> PrepareData( const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag); +paddle::optional> 
PrepareData( + const paddle::optional>& inputs, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag); + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index aacb4ce55b..2659d80615 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -131,9 +131,11 @@ class BaseAPI(object): 'long': 'long', 'size_t': 'size_t', 'float': 'float', + 'float[]': 'const std::vector&', 'double': 'double', 'bool': 'bool', 'str': 'const std::string&', + 'str[]': 'const std::vector&', 'Place': 'const Place&', 'DataLayout': 'DataLayout', 'DataType': 'DataType', diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 575e60923c..3369b0c392 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1549,6 +1549,22 @@ void MergedAdamInferMeta( std::vector beta2_pow_out, std::vector master_param_out) {} +void MergedMomentumInferMeta( + const std::vector& param, + const std::vector& grad, + const std::vector& velocity, + const std::vector& learning_rate, + const paddle::optional>& master_param, + float mu, + bool use_nesterov, + const std::vector& regularization_method, + const std::vector& regularization_coeff, + bool multi_precision, + float rescale_grad, + std::vector param_out, + std::vector velocity_out, + std::vector master_param_out) {} + void MeshgridInferMeta(const std::vector& inputs, std::vector outputs) { const size_t inputs_num = inputs.size(); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index c0972816f3..0ec71e8689 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -255,6 +255,22 @@ void MergedAdamInferMeta( std::vector beta2_pow_out, std::vector master_param_out); +void MergedMomentumInferMeta( + const std::vector& param, + const std::vector& grad, + const std::vector& 
velocity, + const std::vector& learning_rate, + const paddle::optional>& master_param, + float mu, + bool use_nesterov, + const std::vector& regularization_method, + const std::vector& regularization_coeff, + bool multi_precision, + float rescale_grad, + std::vector param_out, + std::vector velocity_out, + std::vector master_param_out); + void MeshgridInferMeta(const std::vector& inputs, std::vector outputs); -- GitLab