Commit 6f06b322 authored by: K ktlichkid

Added GetExpectedKernelType and Debug message

Parent commit: df70d5f1
...@@ -21,6 +21,8 @@ limitations under the License. */ ...@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include <iostream>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -252,6 +254,17 @@ class BeamSearchOp : public framework::OperatorWithKernel { ...@@ -252,6 +254,17 @@ class BeamSearchOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput(arg), PADDLE_ENFORCE(ctx->HasOutput(arg),
"BeamSearch need output argument '%s'", arg); "BeamSearch need output argument '%s'", arg);
} }
std::cout << "Done Infer Shape\n";
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
std::cout << "Get Expected type 1\n";
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
std::cout << "Get Expected type 2\n";
kt.place_ = ctx.Input<framework::LoDTensor>("pre_ids")->place();
std::cout << "Get Expected type 3\n";
return kt;
} }
/* /*
private: private:
......
...@@ -23,6 +23,8 @@ limitations under the License. */ ...@@ -23,6 +23,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include <iostream>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -196,31 +198,47 @@ template <typename DeviceContext, typename T> ...@@ -196,31 +198,47 @@ template <typename DeviceContext, typename T>
/// Device-independent kernel for the beam_search operator.
///
/// Reads the candidate ids/scores of the current step plus the ids selected
/// at the previous step ("pre_ids"), runs the BeamSearch algorithm, and
/// writes the pruned candidates into "selected_ids"/"selected_scores".
class BeamSearchOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    // Inputs: per-step candidate ids and their scores, plus the ids chosen
    // in the previous decoding step.
    auto ids_var = context.Input<framework::LoDTensor>("ids");
    auto scores_var = context.Input<framework::LoDTensor>("scores");
    auto pre_ids_var = context.Input<framework::LoDTensor>("pre_ids");
    PADDLE_ENFORCE_NOT_NULL(ids_var);
    PADDLE_ENFORCE_NOT_NULL(scores_var);
    PADDLE_ENFORCE_NOT_NULL(pre_ids_var);

    // Attributes are declared as int; level/beam_size are used as sizes.
    size_t level = context.Attr<int>("level");
    size_t beam_size = context.Attr<int>("beam_size");
    int end_id = context.Attr<int>("end_id");

    // The BeamSearch helper owns the selection logic; this kernel only
    // wires tensors and attributes into it.
    BeamSearch alg(*ids_var, *scores_var, level, beam_size, end_id);

    auto selected_ids_var =
        context.Output<framework::LoDTensor>("selected_ids");
    auto selected_scores_var =
        context.Output<framework::LoDTensor>("selected_scores");
    PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
    PADDLE_ENFORCE_NOT_NULL(selected_scores_var);

    alg(*pre_ids_var, selected_ids_var, selected_scores_var);
  }
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To post a comment, please register.