提交 294b58a9 编写于 作者: K ktlichkid

Changed registered type

上级 df80b6ea
...@@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) { ...@@ -197,8 +197,7 @@ std::string ItemToString(const BeamSearch::Item &item) {
return stream.str(); return stream.str();
} }
class BeamSearchOpMaker class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
: public framework::OpProtoAndCheckerMaker {
public: public:
BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker) BeamSearchOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
...@@ -225,29 +224,15 @@ class BeamSearchOpMaker ...@@ -225,29 +224,15 @@ class BeamSearchOpMaker
}; };
class BeamSearchOp : public framework::OperatorWithKernel { class BeamSearchOp : public framework::OperatorWithKernel {
/*
public:
BeamSearchOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
BeamSearchOp(const BeamSearchOp& o)
: framework::OperatorWithKernel(
static_cast<const framework::OperatorBase&>(o)) {
PADDLE_THROW("Not Implemented");
}
*/
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
for (const std::string &arg : for (const std::string &arg :
std::vector<std::string>({"pre_ids", "ids", "scores"})) { std::vector<std::string>({"pre_ids", "ids", "scores"})) {
PADDLE_ENFORCE(ctx->HasInput(arg), PADDLE_ENFORCE(ctx->HasInput(arg), "BeamSearch need input argument '%s'",
"BeamSearch need input argument '%s'", arg); arg);
} }
for (const std::string &arg : for (const std::string &arg :
std::vector<std::string>({"selected_ids", "selected_scores"})) { std::vector<std::string>({"selected_ids", "selected_scores"})) {
...@@ -263,62 +248,13 @@ class BeamSearchOp : public framework::OperatorWithKernel { ...@@ -263,62 +248,13 @@ class BeamSearchOp : public framework::OperatorWithKernel {
framework::OpKernelType kt = framework::OpKernelType( framework::OpKernelType kt = framework::OpKernelType(
framework::ToDataType( framework::ToDataType(
ctx.Input<framework::LoDTensor>("pre_ids")->type()), ctx.Input<framework::LoDTensor>("pre_ids")->type()),
platform::CPUPlace()); platform::CPUPlace());
std::cout << "Get Expected type 2\n"; std::cout << "Get Expected type 2\n";
// kt.place_ = ctx.Input<framework::LoDTensor>("pre_ids")->place();
// std::cout << "Get Expected type 3\n";
return kt; return kt;
} }
/*
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto ids_var = scope.FindVar(Input("ids"));
auto scores_var = scope.FindVar(Input("scores"));
auto pre_ids_var = scope.FindVar(Input("pre_ids"));
PADDLE_ENFORCE_NOT_NULL(ids_var);
PADDLE_ENFORCE_NOT_NULL(scores_var);
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
auto& ids = ids_var->Get<framework::LoDTensor>();
auto& scores = scores_var->Get<framework::LoDTensor>();
auto& pre_ids = pre_ids_var->Get<framework::LoDTensor>();
size_t level = Attr<int>("level");
size_t beam_size = Attr<int>("beam_size");
int end_id = Attr<int>("end_id");
BeamSearch alg(ids, scores, level, beam_size, end_id);
auto selected_ids_var = scope.FindVar(Output("selected_ids"));
auto selected_scores_var = scope.FindVar(Output("selected_scores"));
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
auto& selected_ids_tensor =
*selected_ids_var->GetMutable<framework::LoDTensor>();
auto& selected_scores_tensor =
*selected_scores_var->GetMutable<framework::LoDTensor>();
alg(pre_ids, &selected_ids_tensor, &selected_scores_tensor);
}
*/
}; };
/*
class BeamSearchInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
for (const std::string &arg :
std::vector<std::string>({"pre_ids", "ids", "scores"})) {
PADDLE_ENFORCE(context->HasInput(arg),
"BeamSearch need input argument '%s'", arg);
}
for (const std::string &arg :
std::vector<std::string>({"selected_ids", "selected_scores"})) {
PADDLE_ENFORCE(context->HasOutput(arg),
"BeamSearch need output argument '%s'", arg);
}
}
};
*/
class BeamSearchInferVarType : public framework::VarTypeInference { class BeamSearchInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDesc &op_desc, void operator()(const framework::OpDesc &op_desc,
...@@ -334,18 +270,15 @@ class BeamSearchInferVarType : public framework::VarTypeInference { ...@@ -334,18 +270,15 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
/*
REGISTER_OPERATOR(beam_search, paddle::operators::BeamSearchOp,
paddle::operators::BeamSearchProtoAndCheckerMaker,
paddle::operators::BeamSearchInferShape,
paddle::operators::BeamSearchInferVarType,
paddle::framework::EmptyGradOpMaker);
*/
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(beam_search, ops::BeamSearchOp,
ops::BeamSearchOpMaker, REGISTER_OPERATOR(beam_search, ops::BeamSearchOp, ops::BeamSearchOpMaker,
ops::BeamSearchInferVarType); ops::BeamSearchInferVarType);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
beam_search, beam_search,
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, float>, ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, double>); ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::BeamSearchOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -195,7 +195,7 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item); ...@@ -195,7 +195,7 @@ std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item);
std::string ItemToString(const BeamSearch::Item& item); std::string ItemToString(const BeamSearch::Item& item);
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class BeamSearchOpKernel : public framework::OpKernel<T>{ class BeamSearchOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
std::cout << "Compute 1\n"; std::cout << "Compute 1\n";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册