diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc index 2948cf71a911b296f8cee7ff9a2fb75f644dbe71..63d3f809f263588bc1fbcd9ee4305e2ce9321e38 100644 --- a/paddle/fluid/operators/load_combine_op.cc +++ b/paddle/fluid/operators/load_combine_op.cc @@ -88,4 +88,5 @@ REGISTER_OP_CPU_KERNEL( ops::LoadCombineOpKernel, ops::LoadCombineOpKernel, ops::LoadCombineOpKernel, + ops::LoadCombineOpKernel, ops::LoadCombineOpKernel); diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 2d8e6ca854b55e01dacd1e0e7898ba59ea6078dc..656728c609eb19f90390d9dec72d9e30fd3040fd 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -64,4 +64,5 @@ REGISTER_OP_CPU_KERNEL( load, ops::LoadOpKernel, ops::LoadOpKernel, ops::LoadOpKernel, + ops::LoadOpKernel, ops::LoadOpKernel); diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc index 62b1e09737a4af4d0fe08eafcb3b2999d97032c1..5c4be7a7f312bb814d433db402d35a29dd6c9ab6 100644 --- a/paddle/fluid/operators/save_combine_op.cc +++ b/paddle/fluid/operators/save_combine_op.cc @@ -24,6 +24,13 @@ class SaveCombineOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override {} + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), + ctx.GetPlace()); + } }; class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { @@ -71,4 +78,5 @@ REGISTER_OP_CPU_KERNEL( ops::SaveCombineOpKernel, ops::SaveCombineOpKernel, ops::SaveCombineOpKernel, + ops::SaveCombineOpKernel, ops::SaveCombineOpKernel);