提交 f57efeb6 编写于 作者: K ktlichkid

Added GetExpectedKernelType and Debug message

上级 6f06b322
...@@ -260,10 +260,13 @@ class BeamSearchOp : public framework::OperatorWithKernel { ...@@ -260,10 +260,13 @@ class BeamSearchOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
std::cout << "Get Expected type 1\n"; std::cout << "Get Expected type 1\n";
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::LoDTensor>("pre_ids")->type()),
platform::CPUPlace());
std::cout << "Get Expected type 2\n"; std::cout << "Get Expected type 2\n";
kt.place_ = ctx.Input<framework::LoDTensor>("pre_ids")->place(); // kt.place_ = ctx.Input<framework::LoDTensor>("pre_ids")->place();
std::cout << "Get Expected type 3\n"; // std::cout << "Get Expected type 3\n";
return kt; return kt;
} }
/* /*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册