diff --git a/paddle/fluid/operators/gather_tree_op.cc b/paddle/fluid/operators/gather_tree_op.cc index 2868c3697eda19ed3e7cc1fb4c74e9beeaca9c0d..7f6c82032fe39da9d4de768330dcbcfc48610bcd 100644 --- a/paddle/fluid/operators/gather_tree_op.cc +++ b/paddle/fluid/operators/gather_tree_op.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -21,20 +24,6 @@ class GatherTreeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "GatherTree"); - OP_INOUT_CHECK(ctx->HasInput("Parents"), "Input", "Parents", "GatherTree"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GatherTree"); - - auto ids_dims = ctx->GetInputDim("Ids"); - auto parents_dims = ctx->GetInputDim("Parents"); - PADDLE_ENFORCE_EQ(ids_dims == parents_dims, true, - platform::errors::InvalidArgument( - "The shape of Input(Parents) must be same with the " - "shape of Input(Ids).")); - ctx->SetOutputDim("Out", ids_dims); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -72,4 +61,8 @@ selected ids. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker); +DELCARE_INFER_SHAPE_FUNCTOR(gather_tree, GatherTreeInferShapeFunctor, + PT_INFER_META(phi::GatherTreeMeta)); + +REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker, + GatherTreeInferShapeFunctor); diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 675e68af74339b508f589a55a9c3cf3aed37cecb..7682f6b3d49b9281f4fabef26137a7db1a5b6126 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -348,4 +348,17 @@ void BCELossInferMeta(const MetaTensor& input, out->share_lod(input); } +void GatherTreeMeta(const MetaTensor& ids, + const MetaTensor& parents, + MetaTensor* out) { + auto ids_dims = ids.dims(); + auto parents_dims = parents.dims(); + PADDLE_ENFORCE_EQ(ids_dims == parents_dims, + true, + phi::errors::InvalidArgument( + "The shape of Input(Parents) must be same with the " + "shape of Input(Ids).")); + out->set_dims(ids_dims); +} + } // namespace phi diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index a0140c9a5799f79af541b45847d5e44f982a3f58..5906e06b2935504babf993b657dbded403348175 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -68,4 +68,8 @@ void BCELossInferMeta(const MetaTensor& input, const MetaTensor& label, MetaTensor* out, MetaConfig config = MetaConfig()); + +void GatherTreeMeta(const MetaTensor& ids, + const MetaTensor& parents, + MetaTensor* out); } // namespace phi