From 868a3203eba4745d43be8dec1adad32994cb80c4 Mon Sep 17 00:00:00 2001
From: hong <43953930+phlrain@users.noreply.github.com>
Date: Sun, 3 Apr 2022 14:54:15 +0800
Subject: [PATCH] Add infer meta (#41054)

* add some infer meta

* fix bug

* fix bugs;

* fix bug and add set data type

* revert infer shape of lookup table

* recover test
---
 paddle/fluid/operators/meshgrid_op.cc         |  33 ++---
 .../fluid/operators/optimizers/adagrad_op.cc  |  42 ++----
 .../fluid/operators/optimizers/rmsprop_op.cc  |  88 ++-----------
 paddle/fluid/operators/optimizers/sgd_op.cc   |  48 +------
 paddle/fluid/operators/temporal_shift_op.cc   |  52 ++------
 paddle/phi/infermeta/binary.cc                |  26 ++++
 paddle/phi/infermeta/binary.h                 |   5 +
 paddle/phi/infermeta/multiary.cc              | 124 ++++++++++++++++++
 paddle/phi/infermeta/multiary.h               |  34 +++++
 paddle/phi/infermeta/unary.cc                 |  46 +++++++
 paddle/phi/infermeta/unary.h                  |   7 +
 11 files changed, 281 insertions(+), 224 deletions(-)

diff --git a/paddle/fluid/operators/meshgrid_op.cc b/paddle/fluid/operators/meshgrid_op.cc
index 103169fedb..5a6862f380 100644
--- a/paddle/fluid/operators/meshgrid_op.cc
+++ b/paddle/fluid/operators/meshgrid_op.cc
@@ -19,6 +19,10 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"

+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
+
 namespace paddle {
 namespace operators {

@@ -28,30 +32,6 @@ class MeshgridOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_GE(
-        ctx->Inputs("X").size(), 1UL,
-        platform::errors::InvalidArgument("Input(X) should not be empty."));
-    PADDLE_ENFORCE_GE(
-        ctx->Outputs("Out").size(), 1UL,
-        platform::errors::InvalidArgument("Output(Out) should not be empty."));
-
-    auto inputs_dims = ctx->GetInputsDim("X");
-    const size_t inputs_num = inputs_dims.size();
-    auto outs_names = ctx->Outputs("Out");
-    const size_t outputs_num = outs_names.size();
-
-    auto out_shape = std::vector<int64_t>(inputs_num);
-
-    for (size_t i = 0; i < inputs_num; i++) {
-      out_shape[i] = inputs_dims[i][0];
-    }
-    auto out_dims = phi::make_ddim(std::vector<int64_t>(out_shape));
-    std::vector<framework::DDim> outs_dims(outputs_num, out_dims);
-    ctx->SetOutputsDim("Out", outs_dims);
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -142,7 +122,10 @@ class MeshgridGradOpMaker : public framework::SingleGradOpMaker {
 }  // namespace paddle

 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(meshgrid, MeshgridInferShapeFunctor,
+                            PD_INFER_META(phi::MeshgridInferMeta));
 REGISTER_OPERATOR(meshgrid, ops::MeshgridOp, ops::MeshgridOpMaker,
                   ops::MeshgridGradOpMaker<paddle::framework::OpDesc>,
-                  ops::MeshgridGradOpMaker<paddle::imperative::OpBase>);
+                  ops::MeshgridGradOpMaker<paddle::imperative::OpBase>,
+                  MeshgridInferShapeFunctor);
 REGISTER_OPERATOR(meshgrid_grad, ops::MeshgridGradOp);
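The shape rule that the deleted MeshgridOp::InferShape and the new phi::MeshgridInferMeta (added in multiary.cc below) both encode: N one-dimensional inputs of lengths d1..dN produce N outputs, each of shape [d1, ..., dN]. A minimal standalone sketch of that rule, using plain std::vector instead of Paddle's DDim; the function name here is ours, purely illustrative:

#include <cassert>
#include <cstdint>
#include <vector>

// Each 1-D input contributes its length d_i; every output tensor then
// gets the shared shape [d_1, ..., d_N].
std::vector<int64_t> MeshgridOutShape(
    const std::vector<std::vector<int64_t>>& input_dims) {
  std::vector<int64_t> out_shape(input_dims.size());
  for (size_t i = 0; i < input_dims.size(); ++i) {
    assert(!input_dims[i].empty());     // inputs are expected to be 1-D
    out_shape[i] = input_dims[i][0];    // length of the i-th input
  }
  return out_shape;                     // shared by every output tensor
}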
diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cc b/paddle/fluid/operators/optimizers/adagrad_op.cc
index 33c4cf94cf..91bad14306 100644
--- a/paddle/fluid/operators/optimizers/adagrad_op.cc
+++ b/paddle/fluid/operators/optimizers/adagrad_op.cc
@@ -19,6 +19,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
+
 namespace paddle {
 namespace operators {

@@ -27,39 +31,6 @@ class AdagradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "Adagrad");
-    OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "Adagrad");
-    OP_INOUT_CHECK(ctx->HasInput("Moment"), "Input", "Moment", "Adagrad");
-    OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
-                   "Adagrad");
-    OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut", "Adagrad");
-    OP_INOUT_CHECK(ctx->HasOutput("MomentOut"), "Output", "MomentOut",
-                   "Adagrad");
-
-    auto lr_dims = ctx->GetInputDim("LearningRate");
-    PADDLE_ENFORCE_NE(phi::product(lr_dims), 0,
-                      platform::errors::InvalidArgument(
-                          "Maybe the Input variable LearningRate has not "
-                          "been initialized. You may need to confirm "
-                          "if you put exe.run(startup_program) "
-                          "after optimizer.minimize function."));
-    PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1,
-                      platform::errors::InvalidArgument(
-                          "LearningRate should have one element"));
-    auto param_dims = ctx->GetInputDim("Param");
-    PADDLE_ENFORCE_EQ(
-        param_dims, ctx->GetInputDim("Grad"),
-        platform::errors::InvalidArgument("Param and Grad input of AdagradOp "
-                                          "should have the same dimension."));
-    PADDLE_ENFORCE_EQ(
-        param_dims, ctx->GetInputDim("Moment"),
-        platform::errors::InvalidArgument("Param and Moment input of AdagradOp "
-                                          "should have the same dimension."));
-
-    ctx->SetOutputDim("ParamOut", param_dims);
-    ctx->SetOutputDim("MomentOut", param_dims);
-  }
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
@@ -105,4 +76,7 @@ for numerical stability to avoid the division by zero error.
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker);
+DECLARE_INFER_SHAPE_FUNCTOR(adagrad, AdagradInferShapeFunctor,
+                            PD_INFER_META(phi::AdagradInferMeta));
+REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker,
+                             AdagradInferShapeFunctor);
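The shape checks that moved into phi::AdagradInferMeta guard the update the kernel performs, which is why Param, Grad, and Moment must agree in shape. As a reminder of the underlying math (standard Adagrad; a sketch with our own names, not Paddle's kernel):

#include <cmath>
#include <vector>

// Element-wise Adagrad step. epsilon keeps the denominator away from zero,
// which is the numerical-stability point the OpMaker doc above refers to.
void AdagradStep(std::vector<float>* param, std::vector<float>* moment,
                 const std::vector<float>& grad, float lr, float epsilon) {
  for (size_t i = 0; i < param->size(); ++i) {
    (*moment)[i] += grad[i] * grad[i];  // accumulate squared gradients
    (*param)[i] -= lr * grad[i] / (std::sqrt((*moment)[i]) + epsilon);
  }
}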
diff --git a/paddle/fluid/operators/optimizers/rmsprop_op.cc b/paddle/fluid/operators/optimizers/rmsprop_op.cc
index cd6fdcf34e..b345872448 100644
--- a/paddle/fluid/operators/optimizers/rmsprop_op.cc
+++ b/paddle/fluid/operators/optimizers/rmsprop_op.cc
@@ -14,91 +14,16 @@ limitations under the License. */

 #include "paddle/fluid/framework/op_registry.h"

+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
+
 namespace paddle {
 namespace operators {

 class RmspropOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
-                      platform::errors::NotFound(
-                          "Input(Param) of RmspropOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("MeanSquare"), true,
-        platform::errors::NotFound(
-            "Input(MeanSquare) of RmspropOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("LearningRate"), true,
-        platform::errors::NotFound(
-            "Input(LearningRate) of RmspropOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
-                      platform::errors::NotFound(
-                          "Input(Grad) of RmspropOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Moment"), true,
-                      platform::errors::NotFound(
-                          "Input(Moment) of RmspropOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(),
-                      framework::proto::VarType::LOD_TENSOR,
-                      platform::errors::InvalidArgument(
-                          "The input var's type in RmspropOp should be "
-                          "LoDTensor, but the received is %s",
-                          ctx->GetInputsVarType("Param").front()));
-
-    PADDLE_ENFORCE_EQ(
-        ctx->HasOutput("ParamOut"), true,
-        platform::errors::NotFound(
-            "Output(param_out) of RmspropOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasOutput("MomentOut"), true,
-        platform::errors::NotFound(
-            "Output(MomentOut) of RmspropOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasOutput("MeanSquareOut"), true,
-        platform::errors::NotFound(
-            "Output(MeanSquareOut) of RmspropOp should not be null."));
-    if (ctx->Attrs().Get<bool>("centered")) {
-      PADDLE_ENFORCE_EQ(
-          ctx->HasOutput("MeanGradOut"), true,
-          platform::errors::NotFound(
-              "Output(MeanGradOut) of RmspropOp should not be null."));
-    }
-
-    auto param_dim = ctx->GetInputDim("Param");
-    PADDLE_ENFORCE_EQ(
-        param_dim, ctx->GetInputDim("Grad"),
-        platform::errors::InvalidArgument(
-            "Param and grad input of RmspropOp should have the same dimension. "
-            "But received Param's dim [%s] and Grad's dim [%s].",
-            param_dim, ctx->GetInputDim("Grad")));
-    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"),
-                      platform::errors::InvalidArgument(
-                          "Param and Momentum input of RmspropOp "
-                          "should have the same dimension. But received "
-                          "Param's dim [%s] and Moment [%s]",
-                          param_dim, ctx->GetInputDim("Moment")));
-    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"),
-                      platform::errors::InvalidArgument(
-                          "Param and Momentum input of RmspropOp "
-                          "should have the same dimension. But received "
-                          "Param's dim [%s] and MeanSquare [%s]",
-                          param_dim, ctx->GetInputDim("MeanSquare")));
-
-    auto lr_dim = ctx->GetInputDim("LearningRate");
-    PADDLE_ENFORCE_EQ(phi::product(lr_dim), 1,
-                      platform::errors::InvalidArgument(
-                          "Learning Rate of RmspropOp should be a scalar. But "
-                          "received LearningRate's dim [%s]",
-                          phi::product(lr_dim)));
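For context on why MeanGradOut exists only in the centered case: the centered variant additionally tracks a running mean of the gradient and subtracts its square from the running mean-square before taking the root. A scalar sketch of the standard RMSProp step (our formulation, not Paddle's kernel code):

#include <cmath>

// One scalar RMSProp step. mean_grad is only read/updated when centered,
// mirroring why the op only requires MeanGradOut when "centered" is set.
void RmspropStep(float* param, float* mean_square, float* moment,
                 float* mean_grad, float grad, float lr, float epsilon,
                 float decay, float momentum, bool centered) {
  *mean_square = decay * *mean_square + (1.f - decay) * grad * grad;
  float denom;
  if (centered) {
    *mean_grad = decay * *mean_grad + (1.f - decay) * grad;
    denom = std::sqrt(*mean_square - *mean_grad * *mean_grad + epsilon);
  } else {
    denom = std::sqrt(*mean_square + epsilon);
  }
  *moment = momentum * *moment + lr * grad / denom;
  *param -= *moment;
}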
-
-    ctx->SetOutputDim("ParamOut", param_dim);
-    ctx->SetOutputDim("MomentOut", param_dim);
-    ctx->SetOutputDim("MeanSquareOut", param_dim);
-    if (ctx->Attrs().Get<bool>("centered")) {
-      ctx->SetOutputDim("MeanGradOut", param_dim);
-    }
-  }
 };

 class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -169,4 +94,7 @@ http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker);
+DECLARE_INFER_SHAPE_FUNCTOR(rmsprop, RmspropInferShapeFunctor,
+                            PD_INFER_META(phi::RmspropInferMeta));
+REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker,
+                             RmspropInferShapeFunctor);
diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc
index 0e3f895d27..f51d776d71 100644
--- a/paddle/fluid/operators/optimizers/sgd_op.cc
+++ b/paddle/fluid/operators/optimizers/sgd_op.cc
@@ -19,6 +19,10 @@ limitations under the License. */
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif

+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
+
 namespace paddle {
 namespace operators {

@@ -26,46 +30,6 @@ class SGDOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
-                      platform::errors::NotFound(
-                          "Input(Param) of SGDOp should not be null."));
-    PADDLE_ENFORCE_EQ(
-        ctx->HasInput("Grad"), true,
-        platform::errors::NotFound("Input(Grad) of SGDOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), true,
-                      platform::errors::NotFound(
-                          "Input(LearningRate) of SGDOp should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
-                      platform::errors::NotFound(
-                          "Output(ParamOut) of SGDOp should not be null."));
-
-    auto lr_dims = ctx->GetInputDim("LearningRate");
-    PADDLE_ENFORCE_NE(phi::product(lr_dims), 0,
-                      platform::errors::NotFound(
-                          "Maybe the Input variable LearningRate has not "
-                          "been initialized. You may need to confirm "
-                          "if you put exe.run(startup_program) "
-                          "after optimizer.minimize function."));
-    PADDLE_ENFORCE_EQ(phi::product(lr_dims), 1,
-                      platform::errors::InvalidArgument(
-                          "Learning rate should have 1 element. But received "
-                          "LearningRate dims [%s]",
-                          phi::product(lr_dims)));
-    auto param_dim = ctx->GetInputDim("Param");
-    if (ctx->GetInputsVarType("Grad")[0] ==
-        framework::proto::VarType::LOD_TENSOR) {
-      PADDLE_ENFORCE_EQ(
-          param_dim, ctx->GetInputDim("Grad"),
-          platform::errors::InvalidArgument(
-              "SGD Operator's input Param and Grad dimensions do not match. "
-              "The Param %s shape is [%s], but the Grad %s shape is [%s].",
-              ctx->Inputs("Param")[0], param_dim, ctx->Inputs("Grad")[0],
-              ctx->GetInputDim("Grad")));
-    }
-    ctx->SetOutputDim("ParamOut", param_dim);
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -161,8 +125,10 @@ $$param\_out = param - learning\_rate * grad$$
 }  // namespace paddle

 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(sgd, SGDInferShapeFunctor,
+                            PD_INFER_META(phi::SGDInferMeta));
 REGISTER_OPERATOR(
     sgd, ops::SGDOp, ops::SGDOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    ops::SGDOpInferVarType);
+    ops::SGDOpInferVarType, SGDInferShapeFunctor);
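Both the deleted SGDOp::InferShape and the new phi::SGDInferMeta express the same learning-rate contract: the tensor must contain exactly one element (product of dims == 1), and a zero product usually means the variable was never initialized. A standalone sketch of that check, with an illustrative name of our own:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Element count is the product of the dims; a scalar learning rate has
// exactly one element. Zero means the tensor was never initialized.
bool LearningRateIsScalar(const std::vector<int64_t>& lr_dims) {
  int64_t numel = std::accumulate(lr_dims.begin(), lr_dims.end(),
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>());
  return numel == 1;
}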
" - "The Param %s shape is [%s], but the Grad %s shape is [%s].", - ctx->Inputs("Param")[0], param_dim, ctx->Inputs("Grad")[0], - ctx->GetInputDim("Grad"))); - } - ctx->SetOutputDim("ParamOut", param_dim); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -161,8 +125,10 @@ $$param\_out = param - learning\_rate * grad$$ } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(sgd, SGDInferShapeFunctor, + PD_INFER_META(phi::SGDInferMeta)); REGISTER_OPERATOR( sgd, ops::SGDOp, ops::SGDOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, - ops::SGDOpInferVarType); + ops::SGDOpInferVarType, SGDInferShapeFunctor); diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc index acf99d09ff..3bdb9cb972 100644 --- a/paddle/fluid/operators/temporal_shift_op.cc +++ b/paddle/fluid/operators/temporal_shift_op.cc @@ -15,6 +15,10 @@ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace operators { @@ -24,49 +28,6 @@ class TemporalShiftOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SpectralNorm"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm"); - - auto dim_x = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ(dim_x.size(), 4, - platform::errors::InvalidArgument( - "Input(X) rank should be 4 in shape of [N*T, C, H, " - "W], but received X rank(%d)", - dim_x.size())); - - int seg_num = ctx->Attrs().Get("seg_num"); - float shift_ratio = ctx->Attrs().Get("shift_ratio"); - PADDLE_ENFORCE_GT( - seg_num, 0, - platform::errors::InvalidArgument( - "Attr(seg_num) should be greater than 0, but received %d", - seg_num)); - PADDLE_ENFORCE_GT( - shift_ratio, 0., - platform::errors::InvalidArgument( - "Attr(shift_ratio) should be greater than 0, but received %d", - shift_ratio)); - PADDLE_ENFORCE_LT( - shift_ratio, 0.5, - platform::errors::InvalidArgument( - "Attr(shift_ratio) should be less than 0.5, but received %d", - shift_ratio)); - - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0, - platform::errors::InvalidArgument( - "Input(X) dimension[0] should be divided exactly " - "by Attr(seg_num), but received X dimension[0](%d) " - "mod seg_num(%d) != 0", - dim_x[0], seg_num)); - } - - ctx->SetOutputDim("Out", dim_x); - ctx->ShareLoD("X", "Out"); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -186,10 +147,13 @@ class TemporalShiftGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(temporal_shift, TemporalShiftInferShapeFunctor, + PD_INFER_META(phi::TemporalShiftInferMeta)); REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp, ops::TemporalShiftOpMaker, ops::TemporalShiftGradOpMaker, - ops::TemporalShiftGradOpMaker); + ops::TemporalShiftGradOpMaker, + TemporalShiftInferShapeFunctor); REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad); REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel, ops::TemporalShiftKernel); diff --git a/paddle/phi/infermeta/binary.cc 
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 44ae53a00d..ab13df081a 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -75,6 +75,32 @@ void AllValueCompareInferMeta(const MetaTensor& x,
   out->set_dtype(DataType::BOOL);
 }

+void EmbeddingInferMeta(const MetaTensor& input,
+                        const MetaTensor& weight,
+                        int64_t padding_idx,
+                        MetaTensor* out) {
+  auto table_dims = weight.dims();
+  auto ids_dims = input.dims();
+  int ids_rank = ids_dims.size();
+  VLOG(5) << "ids rank is " << ids_rank << std::endl;
+  PADDLE_ENFORCE_EQ(
+      table_dims.size(),
+      2,
+      phi::errors::InvalidArgument(
+          "ShapeError: The dimensions of the 'lookup table' must be 2. "
+          "But received lookup table's dimensions = %d, "
+          "lookup table's shape = [%s].",
+          table_dims.size(),
+          table_dims));
+
+  auto output_dims = phi::vectorize(ids_dims);
+  output_dims.push_back(table_dims[1]);
+
+  out->set_dims(phi::make_ddim(output_dims));
+  out->set_dtype(weight.dtype());
+  out->share_lod(input);
+}
+
 void KLDivInferMeta(const MetaTensor& x,
                     const MetaTensor& label,
                     const std::string& reduction,
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 751422a4de..3fcbf69c35 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -37,6 +37,11 @@ void AllValueCompareInferMeta(const MetaTensor& x,
                               MetaTensor* out,
                               MetaConfig config = MetaConfig());

+void EmbeddingInferMeta(const MetaTensor& input,
+                        const MetaTensor& weight,
+                        int64_t padding_idx,
+                        MetaTensor* out);
+
 void KLDivInferMeta(const MetaTensor& x,
                     const MetaTensor& label,
                     const std::string& reduction,
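The shape rule EmbeddingInferMeta encodes above: a 2-D lookup table of shape [V, D], indexed by ids of arbitrary shape S, yields an output of shape S + [D]. A standalone sketch of just that rule, with plain vectors and an illustrative name:

#include <cstdint>
#include <vector>

// Output shape for an embedding lookup: the ids shape with the embedding
// width D appended. table_dims is [V, D]; its rank-2 requirement is what
// the PADDLE_ENFORCE_EQ in EmbeddingInferMeta checks.
std::vector<int64_t> EmbeddingOutShape(const std::vector<int64_t>& ids_dims,
                                       const std::vector<int64_t>& table_dims) {
  std::vector<int64_t> out = ids_dims;
  out.push_back(table_dims[1]);  // append the embedding width D
  return out;
}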
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 8e4f0b1fbb..4fbd264f10 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -66,6 +66,32 @@ void AdadeltaInferMeta(const MetaTensor& param,
   avg_squared_update_out->set_dtype(avg_squared_update.dtype());
 }

+void AdagradInferMeta(const MetaTensor& param,
+                      const MetaTensor& grad,
+                      const MetaTensor& moment,
+                      const MetaTensor& learning_rate,
+                      float epsilon,
+                      MetaTensor* param_out,
+                      MetaTensor* moment_out) {
+  auto lr_dims = learning_rate.dims();
+  PADDLE_ENFORCE_EQ(
+      phi::product(lr_dims),
+      1,
+      phi::errors::InvalidArgument("LearningRate should have one element"));
+  auto param_dims = param.dims();
+
+  PADDLE_ENFORCE_EQ(
+      param_dims,
+      moment.dims(),
+      phi::errors::InvalidArgument("Param and Moment input of AdagradOp "
+                                   "should have the same dimension."));
+
+  param_out->set_dims(param_dims);
+  param_out->set_dtype(param.dtype());
+  moment_out->set_dims(param_dims);
+  moment_out->set_dtype(moment.dtype());
+}
+
 void AdamInferMeta(const MetaTensor& param,
                    const MetaTensor& grad,
                    const MetaTensor& learning_rate,
@@ -1390,6 +1416,22 @@ void InterpolateInferMeta(
   }
 }

+void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
+                       std::vector<MetaTensor*> outputs) {
+  const size_t inputs_num = inputs.size();
+
+  auto out_shape = std::vector<int64_t>(inputs_num);
+
+  for (size_t i = 0; i < inputs.size(); i++) {
+    out_shape[i] = inputs[i]->dims()[0];
+  }
+  auto out_dims = phi::make_ddim(std::vector<int64_t>(out_shape));
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    outputs[i]->set_dims(out_dims);
+    outputs[i]->set_dtype(inputs[0]->dtype());
+  }
+}
+
 void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
   auto inputs_dims = GetMetaTensorsDim(x);

@@ -1582,6 +1624,65 @@ void PsroiPoolInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }

+void RmspropInferMeta(const MetaTensor& param,
+                      const MetaTensor& mean_square,
+                      const MetaTensor& grad,
+                      const MetaTensor& moment,
+                      const MetaTensor& learning_rate,
+                      paddle::optional<const MetaTensor&> mean_grad,
+                      float epsilon,
+                      float decay,
+                      float momentum,
+                      bool centered,
+                      MetaTensor* param_out,
+                      MetaTensor* moment_out,
+                      MetaTensor* mean_square_out,
+                      MetaTensor* mean_grad_out) {
+  if (centered) {
+    PADDLE_ENFORCE_NOT_NULL(
+        mean_grad_out,
+        phi::errors::InvalidArgument(
+            "Output(MeanGradOut) of RmspropOp should not be null."));
+  }
+
+  auto param_dim = param.dims();
+  PADDLE_ENFORCE_EQ(param_dim,
+                    moment.dims(),
+                    phi::errors::InvalidArgument(
+                        "Param and Momentum input of RmspropOp "
+                        "should have the same dimension. But received "
+                        "Param's dim [%s] and Moment [%s]",
+                        param_dim,
+                        moment.dims()));
+  PADDLE_ENFORCE_EQ(param_dim,
+                    mean_square.dims(),
+                    phi::errors::InvalidArgument(
+                        "Param and Momentum input of RmspropOp "
+                        "should have the same dimension. But received "
+                        "Param's dim [%s] and MeanSquare [%s]",
+                        param_dim,
+                        mean_square.dims()));
+
+  auto lr_dim = learning_rate.dims();
+  PADDLE_ENFORCE_EQ(phi::product(lr_dim),
+                    1,
+                    phi::errors::InvalidArgument(
+                        "Learning Rate of RmspropOp should be a scalar. But "
+                        "received LearningRate's dim [%s]",
+                        phi::product(lr_dim)));
+
+  param_out->set_dims(param_dim);
+  param_out->set_dtype(param.dtype());
+  moment_out->set_dims(param_dim);
+  moment_out->set_dtype(moment.dtype());
+  mean_square_out->set_dims(param_dim);
+  mean_square_out->set_dtype(mean_square.dtype());
+  if (centered) {
+    mean_grad_out->set_dims(param_dim);
+    mean_grad_out->set_dtype(mean_grad.get_ptr()->dtype());
+  }
+}
+
 void RnnInferMeta(const MetaTensor& x,
                   const std::vector<MetaTensor*>& pre_state,
                   const std::vector<MetaTensor*>& weight_list,
@@ -1667,6 +1768,29 @@ void RnnInferMeta(const MetaTensor& x,
   }
 }

+void SGDInferMeta(const MetaTensor& param,
+                  const MetaTensor& learning_rate,
+                  const MetaTensor& grad,
+                  paddle::optional<const MetaTensor&> master_param,
+                  bool multi_precision,
+                  MetaTensor* param_out,
+                  MetaTensor* master_param_out) {
+  PADDLE_ENFORCE_NOT_NULL(param_out,
+                          phi::errors::InvalidArgument(
+                              "Output(ParamOut) of SGDOp should not be null."));
+
+  auto lr_dims = learning_rate.dims();
+  PADDLE_ENFORCE_EQ(phi::product(lr_dims),
+                    1,
+                    phi::errors::InvalidArgument(
+                        "Learning rate should have 1 element. But received "
+                        "LearningRate dims [%s]",
+                        phi::product(lr_dims)));
+
+  param_out->set_dims(param.dims());
+  param_out->set_dtype(param.dtype());
+}
+
 void StackInferMeta(const std::vector<MetaTensor*>& x,
                     int axis,
                     MetaTensor* out) {
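The pattern this patch applies throughout: shape logic becomes a free phi function over MetaTensor, and DECLARE_INFER_SHAPE_FUNCTOR / PD_INFER_META adapt it back to fluid's InferShape interface at registration time. A much-simplified toy illustration of that adapter idea — our own stand-in types, not the real macros or phi classes:

#include <functional>
#include <vector>

// Toy stand-ins for phi::MetaTensor and the infer-shape calling convention.
struct MetaTensor { std::vector<long> dims; };

// A free "InferMeta" function: pure shape propagation, framework-agnostic.
void IdentityInferMeta(const MetaTensor& x, MetaTensor* out) {
  out->dims = x.dims;
}

// The functor adapts the free function to whatever interface the op
// registry expects -- conceptually what PD_INFER_META generates.
struct InferShapeFunctor {
  std::function<void(const MetaTensor&, MetaTensor*)> fn;
  void operator()(const MetaTensor& in, MetaTensor* out) const { fn(in, out); }
};

int main() {
  InferShapeFunctor functor{IdentityInferMeta};
  MetaTensor x{{2, 3}}, out;
  functor(x, &out);  // out.dims == {2, 3}
  return 0;
}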
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index 72c64e8500..64a11ed0b2 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -47,6 +47,14 @@ void AdadeltaInferMeta(const MetaTensor& param,
                        MetaTensor* avg_squared_grad_out,
                        MetaTensor* avg_squared_update_out);

+void AdagradInferMeta(const MetaTensor& param,
+                      const MetaTensor& grad,
+                      const MetaTensor& moment,
+                      const MetaTensor& learning_rate,
+                      float epsilon,
+                      MetaTensor* param_out,
+                      MetaTensor* moment_out);
+
 void AdamaxInferMeta(const MetaTensor& param,
                      const MetaTensor& grad,
                      const MetaTensor& learning_rate,
@@ -215,6 +223,9 @@ void InterpolateInferMeta(
     MetaTensor* output,
     MetaConfig config = MetaConfig());

+void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
+                       std::vector<MetaTensor*> outputs);
+
 void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);

 void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
@@ -230,6 +241,21 @@ void PsroiPoolInferMeta(const MetaTensor& x,
                         float spatial_scale,
                         MetaTensor* out);

+void RmspropInferMeta(const MetaTensor& param,
+                      const MetaTensor& mean_square,
+                      const MetaTensor& grad,
+                      const MetaTensor& moment,
+                      const MetaTensor& learning_rate,
+                      paddle::optional<const MetaTensor&> mean_grad,
+                      float epsilon,
+                      float decay,
+                      float momentum,
+                      bool centered,
+                      MetaTensor* param_out,
+                      MetaTensor* moment_out,
+                      MetaTensor* mean_square_out,
+                      MetaTensor* mean_grad_out);
+
 void RnnInferMeta(const MetaTensor& x,
                   const std::vector<MetaTensor*>& pre_state,
                   const std::vector<MetaTensor*>& weight_list,
@@ -247,6 +273,14 @@ void RnnInferMeta(const MetaTensor& x,
                   std::vector<MetaTensor*> state,
                   MetaTensor* reserve);

+void SGDInferMeta(const MetaTensor& param,
+                  const MetaTensor& learning_rate,
+                  const MetaTensor& grad,
+                  paddle::optional<const MetaTensor&> master_param,
+                  bool multi_precision,
+                  MetaTensor* param_out,
+                  MetaTensor* master_param_out);
+
 void StackInferMeta(const std::vector<MetaTensor*>& x,
                     int axis,
                     MetaTensor* out);
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 6bf7a36b06..36c192cbf2 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -2102,6 +2102,52 @@ void SumRawInferMeta(const MetaTensor& x,
   out->set_layout(x.layout());
 }

+void TemporalShiftInferMeta(const MetaTensor& x,
+                            int seg_num,
+                            float shift_ratio,
+                            const std::string& data_format,
+                            MetaTensor* out,
+                            MetaConfig config) {
+  auto dim_x = x.dims();
+  PADDLE_ENFORCE_EQ(dim_x.size(),
+                    4,
+                    phi::errors::InvalidArgument(
+                        "Input(X) rank should be 4 in shape of [N*T, C, H, "
+                        "W], but received X rank(%d)",
+                        dim_x.size()));
+
+  PADDLE_ENFORCE_GT(
+      seg_num,
+      0,
+      phi::errors::InvalidArgument(
+          "Attr(seg_num) should be greater than 0, but received %d", seg_num));
+  PADDLE_ENFORCE_GT(
+      shift_ratio,
+      0.,
+      phi::errors::InvalidArgument(
+          "Attr(shift_ratio) should be greater than 0, but received %d",
+          shift_ratio));
+  PADDLE_ENFORCE_LT(
+      shift_ratio,
+      0.5,
+      phi::errors::InvalidArgument(
+          "Attr(shift_ratio) should be less than 0.5, but received %d",
+          shift_ratio));
+
+  if (config.is_runtime) {
+    PADDLE_ENFORCE_EQ(dim_x[0] % seg_num,
+                      0,
+                      phi::errors::InvalidArgument(
+                          "Input(X) dimension[0] should be divided exactly "
+                          "by Attr(seg_num), but received X dimension[0](%d) "
+                          "mod seg_num(%d) != 0",
+                          dim_x[0],
+                          seg_num));
+  }
+
+  out->share_meta(x);
+}
+
 void TileInferMeta(const MetaTensor& x,
                    const IntArray& repeat_times,
                    MetaTensor* out,
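Note why the divisibility check above is gated on config.is_runtime: at graph-construction time a batch-like dimension is often still unknown (conventionally -1 in Paddle), and taking `-1 % seg_num` would raise a spurious error, so the check is deferred until real shapes exist. A small sketch of that deferral pattern, names ours:

#include <cstdint>

// Validate dim0 % seg_num == 0 only when the dimension is actually known.
// At compile time an unknown dimension (negative placeholder) passes
// unconditionally; the real check happens once is_runtime is true.
bool SegNumDividesDim0(int64_t dim0, int seg_num, bool is_runtime) {
  if (!is_runtime && dim0 < 0) return true;  // shape not yet known: defer
  return dim0 % seg_num == 0;
}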
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 54f70d8d55..bda9c83fce 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -315,6 +315,13 @@ void SumRawInferMeta(const MetaTensor& x,
                      DataType dtype,
                      MetaTensor* out);

+void TemporalShiftInferMeta(const MetaTensor& x,
+                            int seg_num,
+                            float shift_ratio,
+                            const std::string& data_format,
+                            MetaTensor* out,
+                            MetaConfig config = MetaConfig());
+
 void TileInferMeta(const MetaTensor& x,
                    const IntArray& repeat_times,
                    MetaTensor* out,
--
GitLab