diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index bee0a29e9052c30be9752931eb40e222d5c6cc51..b0e284a26032dc06b72ac8aff83f22b8a5f41b08 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -318,7 +318,7 @@ class BeamSearchInferShape : public framework::InferShapeBase { } } }; - +*/ class BeamSearchInferVarType : public framework::VarTypeInference { public: void operator()(const framework::OpDesc &op_desc, @@ -331,7 +331,7 @@ class BeamSearchInferVarType : public framework::VarTypeInference { } } }; -*/ + } // namespace operators } // namespace paddle /* @@ -343,7 +343,8 @@ REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp, */ namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(beam_search, ops::BeamSearchOp, - ops::BeamSearchOpMaker); + ops::BeamSearchOpMaker, + ops::BeamSearchInferVarType); REGISTER_OP_CPU_KERNEL( beam_search, ops::BeamSearchOpKernel,