提交 df70d5f1 编写于 作者: K ktlichkid

Fixed some bugs

上级 d060b5df
...@@ -223,33 +223,37 @@ class BeamSearchOpMaker ...@@ -223,33 +223,37 @@ class BeamSearchOpMaker
}; };
class BeamSearchOp : public framework::OperatorWithKernel { class BeamSearchOp : public framework::OperatorWithKernel {
/*
public: public:
BeamSearchOp(const std::string& type, BeamSearchOp(const std::string& type,
const framework::VariableNameMap& inputs, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs, const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
BeamSearchOp(const BeamSearchOp& o) BeamSearchOp(const BeamSearchOp& o)
: framework::OperatorBase( : framework::OperatorWithKernel(
static_cast<const framework::OperatorBase&>(o)) { static_cast<const framework::OperatorBase&>(o)) {
PADDLE_THROW("Not Implemented"); PADDLE_THROW("Not Implemented");
} }
*/
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
for (const std::string &arg : for (const std::string &arg :
std::vector<std::string>({"pre_ids", "ids", "scores"})) { std::vector<std::string>({"pre_ids", "ids", "scores"})) {
PADDLE_ENFORCE(context->HasInput(arg), PADDLE_ENFORCE(ctx->HasInput(arg),
"BeamSearch need input argument '%s'", arg); "BeamSearch need input argument '%s'", arg);
} }
for (const std::string &arg : for (const std::string &arg :
std::vector<std::string>({"selected_ids", "selected_scores"})) { std::vector<std::string>({"selected_ids", "selected_scores"})) {
PADDLE_ENFORCE(context->HasOutput(arg), PADDLE_ENFORCE(ctx->HasOutput(arg),
"BeamSearch need output argument '%s'", arg); "BeamSearch need output argument '%s'", arg);
} }
} }
/*
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
...@@ -278,9 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel { ...@@ -278,9 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
*selected_scores_var->GetMutable<framework::LoDTensor>(); *selected_scores_var->GetMutable<framework::LoDTensor>();
alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
} }
*/
public:
using framework::OperatorWithKernel::OperatorWithKernel;
}; };
......
...@@ -196,33 +196,33 @@ template <typename DeviceContext, typename T> ...@@ -196,33 +196,33 @@ template <typename DeviceContext, typename T>
class BeamSearchOpKernel : public framework::OpKernel<T>{ class BeamSearchOpKernel : public framework::OpKernel<T>{
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* ids_var = context.Input<framework::Tensor>("ids"); auto ids_var = context.Input<framework::LoDTensor>("ids");
auto* scores_var = context.Input<framework::Tensor>("scores"); auto scores_var = context.Input<framework::LoDTensor>("scores");
auto* pre_ids_var = context.Input<framework::Tensor>("pre_ids"); auto pre_ids_var = context.Input<framework::LoDTensor>("pre_ids");
PADDLE_ENFORCE_NOT_NULL(ids_var); PADDLE_ENFORCE_NOT_NULL(ids_var);
PADDLE_ENFORCE_NOT_NULL(scores_var); PADDLE_ENFORCE_NOT_NULL(scores_var);
PADDLE_ENFORCE_NOT_NULL(pre_ids_var); PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
auto& ids = ids_var->Get<framework::LoDTensor>(); //auto& ids = ids_var->Get<framework::LoDTensor>();
auto& scores = scores_var->Get<framework::LoDTensor>(); //auto& scores = scores_var->Get<framework::LoDTensor>();
auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>(); //auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
size_t level = Attr<int>("level"); size_t level = context.Attr<int>("level");
size_t beam_size = Attr<int>("beam_size"); size_t beam_size = context.Attr<int>("beam_size");
int end_id = Attr<int>("end_id"); int end_id = context.Attr<int>("end_id");
BeamSearch alg(ids, scores, level, beam_size, end_id); BeamSearch alg(*ids_var, *scores_var, level, beam_size, end_id);
auto* selected_ids_var = context.Output<framework::Tensor>("selected_ids"); auto selected_ids_var = context.Output<framework::LoDTensor>("selected_ids");
auto* selected_scores_var = context.Output<framework::Tensor>("selected_scores"); auto selected_scores_var = context.Output<framework::LoDTensor>("selected_scores");
PADDLE_ENFORCE_NOT_NULL(selected_ids_var); PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
PADDLE_ENFORCE_NOT_NULL(selected_scores_var); PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
auto& selected_ids_tensor = //auto& selected_ids_tensor =
*selected_ids_var->GetMutable<framework::LoDTensor>(); // *selected_ids_var->GetMutable<framework::LoDTensor>();
auto& selected_scores_tensor = //auto& selected_scores_tensor =
*selected_scores_var->GetMutable<framework::LoDTensor>(); // *selected_scores_var->GetMutable<framework::LoDTensor>();
alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor); alg(*pre_ids_var, selected_ids_var, selected_scores_var);
} }
} };
/* /*
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册