diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index 0499d8cbef75cd31b6633702ea10208ded0426da..bee0a29e9052c30be9752931eb40e222d5c6cc51 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -260,10 +260,13 @@ class BeamSearchOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { std::cout << "Get Expected type 1\n"; - framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + framework::OpKernelType kt = framework::OpKernelType( + framework::ToDataType( + ctx.Input("pre_ids")->type()), + platform::CPUPlace()); std::cout << "Get Expected type 2\n"; - kt.place_ = ctx.Input("pre_ids")->place(); - std::cout << "Get Expected type 3\n"; + // kt.place_ = ctx.Input("pre_ids")->place(); + // std::cout << "Get Expected type 3\n"; return kt; } /*