提交 881ea62b 编写于 作者: K ktlichkid

Added BeamSearchOpMaker class

上级 17212696
......@@ -195,10 +195,10 @@ std::string ItemToString(const BeamSearch::Item &item) {
return stream.str();
}
class BeamSearchProtoAndCheckerMaker
class BeamSearchOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
BeamSearchProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
// inputs and outputs stored in proto
AddInput("pre_ids", "ids in previous step");
......@@ -222,6 +222,59 @@ class BeamSearchProtoAndCheckerMaker
}
};
class BeamSearchOp : public framework::OperatorWithKernel {
public:
BeamSearchOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
BeamSearchOp(const BeamSearchOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
PADDLE_THROW("Not Implemented");
}
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
}
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto ids_var = scope.FindVar(Input("ids"));
auto scores_var = scope.FindVar(Input("scores"));
auto pre_ids_var = scope.FindVar(Input("pre_ids"));
PADDLE_ENFORCE_NOT_NULL(ids_var);
PADDLE_ENFORCE_NOT_NULL(scores_var);
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
auto& ids = ids_var->Get<framework::LoDTensor>();
auto& scores = scores_var->Get<framework::LoDTensor>();
auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
size_t level = Attr<int>("level");
size_t beam_size = Attr<int>("beam_size");
int end_id = Attr<int>("end_id");
BeamSearch alg(ids, scores, level, beam_size, end_id);
auto selected_ids_var = scope.FindVar(Output("selected_ids"));
auto selected_scores_var = scope.FindVar(Output("selected_scores"));
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
auto& selected_ids_tensor =
*selected_ids_var->GetMutable<framework::LoDTensor>();
auto& selected_scores_tensor =
*selected_scores_var->GetMutable<framework::LoDTensor>();
alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
}
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
/*
class BeamSearchInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
......@@ -250,7 +303,7 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
}
}
};
*/
} // namespace operators
} // namespace paddle
......
......@@ -192,56 +192,10 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item);
std::string ItemToString(const BeamSearch::Item& item);
class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker{
public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker){
}
}
class BeamSearchOp : public framework::OperatorBase {
public:
BeamSearchOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
BeamSearchOp(const BeamSearchOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
PADDLE_THROW("Not Implemented");
}
template <typename DeviceContext, typename T>
class BeamSearchKernel : public framework::OpKernel<T>{
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto ids_var = scope.FindVar(Input("ids"));
auto scores_var = scope.FindVar(Input("scores"));
auto pre_ids_var = scope.FindVar(Input("pre_ids"));
PADDLE_ENFORCE_NOT_NULL(ids_var);
PADDLE_ENFORCE_NOT_NULL(scores_var);
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
auto& ids = ids_var->Get<framework::LoDTensor>();
auto& scores = scores_var->Get<framework::LoDTensor>();
auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
size_t level = Attr<int>("level");
size_t beam_size = Attr<int>("beam_size");
int end_id = Attr<int>("end_id");
BeamSearch alg(ids, scores, level, beam_size, end_id);
auto selected_ids_var = scope.FindVar(Output("selected_ids"));
auto selected_scores_var = scope.FindVar(Output("selected_scores"));
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
auto& selected_ids_tensor =
*selected_ids_var->GetMutable<framework::LoDTensor>();
auto& selected_scores_tensor =
*selected_scores_var->GetMutable<framework::LoDTensor>();
alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
}
};
}
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册