diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py index cd5805740bef04bb9af2db712329922b7e48a206..79f5da4bec79e8f67aed80646540d66721705880 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py @@ -45,6 +45,7 @@ yaml_types_mapping = { 'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \ 'float' : 'float', 'double' : 'double', 'bool' : 'bool', \ 'str' : 'std::string', \ + 'str[]' : 'std::vector', 'float[]' : 'std::vector', \ 'Place' : 'paddle::Place', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \ 'int64_t[]' : 'std::vector', 'int[]' : 'std::vector', 'Tensor' : 'Tensor', diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.cc b/paddle/fluid/operators/optimizers/merged_momentum_op.cc index 220c0be9ddf0fe5de63dbde447c7bc27600e298f..85b2f818fe137ec05159a88af85b37d67a40f4d3 100644 --- a/paddle/fluid/operators/optimizers/merged_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/merged_momentum_op.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { @@ -22,8 +25,6 @@ class MergedMomentumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext *ctx) const override {} - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto param_dtype = @@ -100,6 +101,11 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; namespace plat = paddle::platform; +DECLARE_INFER_SHAPE_FUNCTOR(merged_momentum, + MergedMomentumInferShapeFunctor, + PD_INFER_META(phi::MergedMomentumInferMeta)); + REGISTER_OP_WITHOUT_GRADIENT(merged_momentum, ops::MergedMomentumOp, - ops::MergedMomentumOpMaker); + ops::MergedMomentumOpMaker, + MergedMomentumInferShapeFunctor); diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 4dafc7a7ee57977bd49b22d666776dd06b1b4f8a..58795c0f06381dbe02c589a1edabed2997d8e570 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -284,5 +284,15 @@ std::unique_ptr> PrepareData( return pt_tensors; } +paddle::optional> PrepareData( + const paddle::optional>& inputs, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag) { + if (inputs) { + return {*PrepareData(*inputs, target_args_def, transform_flag)}; + } + return paddle::none; +} + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 4d70078ef3444b07c796ad1129006381b9b97e5b..3feba2465f61bfbc38edd06ed1a84735cd0817fa 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -76,5 +76,10 @@ std::unique_ptr> PrepareData( const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag); +paddle::optional> PrepareData( + const paddle::optional>& inputs, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag); + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index aacb4ce55befa8c0af569f700cbf8ddc29128727..2659d80615f2dc6f44327738ffc750f7e3fdb799 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -131,9 +131,11 @@ class BaseAPI(object): 'long': 'long', 'size_t': 'size_t', 'float': 'float', + 'float[]': 'const std::vector&', 'double': 'double', 'bool': 'bool', 'str': 'const std::string&', + 'str[] ': 'const std::vector&', 'Place': 'const Place&', 'DataLayout': 'DataLayout', 'DataType': 'DataType', diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 575e60923cd2146710982412b1e0f53f08cd919e..3369b0c392ec33f24894427ccc320b4ba6473bfb 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1549,6 +1549,22 @@ void MergedAdamInferMeta( std::vector beta2_pow_out, std::vector master_param_out) {} +void MergedMomentumInferMeta( + const std::vector& param, + const std::vector& grad, + const std::vector& velocity, + const std::vector& learning_rate, + const paddle::optional>& master_param, + float mu, + bool use_nesterov, + const std::vector& regularization_method, + const std::vector& regularization_coeff, + bool multi_precision, + float rescale_grad, + std::vector param_out, + std::vector velocity_out, + std::vector master_param_out) {} + void MeshgridInferMeta(const std::vector& inputs, std::vector outputs) { const size_t inputs_num = inputs.size(); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index c0972816f3ba2816f312d9623ed937257ea60efb..0ec71e86893c3ce072db592f9de4ea4526cd6560 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -255,6 +255,22 @@ void MergedAdamInferMeta( std::vector beta2_pow_out, std::vector master_param_out); +void MergedMomentumInferMeta( + const std::vector& param, + const std::vector& grad, + const std::vector& velocity, + const std::vector& learning_rate, + const paddle::optional>& master_param, + float mu, + bool use_nesterov, + const std::vector& regularization_method, + const std::vector& regularization_coeff, + bool multi_precision, + float rescale_grad, + std::vector param_out, + std::vector velocity_out, + std::vector master_param_out); + void MeshgridInferMeta(const std::vector& inputs, std::vector outputs);