diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc index 2948cf71a911b296f8cee7ff9a2fb75f644dbe71..63d3f809f263588bc1fbcd9ee4305e2ce9321e38 100644 --- a/paddle/fluid/operators/load_combine_op.cc +++ b/paddle/fluid/operators/load_combine_op.cc @@ -88,4 +88,5 @@ REGISTER_OP_CPU_KERNEL( ops::LoadCombineOpKernel, ops::LoadCombineOpKernel, ops::LoadCombineOpKernel, + ops::LoadCombineOpKernel, ops::LoadCombineOpKernel); diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index 2d8e6ca854b55e01dacd1e0e7898ba59ea6078dc..656728c609eb19f90390d9dec72d9e30fd3040fd 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -64,4 +64,5 @@ REGISTER_OP_CPU_KERNEL( load, ops::LoadOpKernel, ops::LoadOpKernel, ops::LoadOpKernel, + ops::LoadOpKernel, ops::LoadOpKernel); diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc index 62b1e09737a4af4d0fe08eafcb3b2999d97032c1..953e2655d13328b986a67398dca54f8a5e3aedcf 100644 --- a/paddle/fluid/operators/save_combine_op.cc +++ b/paddle/fluid/operators/save_combine_op.cc @@ -19,11 +19,27 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; + class SaveCombineOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override {} + void InferShape(framework::InferShapeContext* ctx) const override {} + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); + } + // TODO(lujun): The override here is just to bypass transform + // in operator impl, which is not elegant enough. + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + return expected_kernel_type; + } }; class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { @@ -54,7 +70,7 @@ to a file on disk. "(string)" "The \"file_path\" where the LoDTensor variables will be saved.") .AddCustomChecker( - [](const std::string &path) { return !path.empty(); }); + [](const std::string& path) { return !path.empty(); }); } }; @@ -70,5 +86,4 @@ REGISTER_OP_CPU_KERNEL( save_combine, ops::SaveCombineOpKernel, ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel); + ops::SaveCombineOpKernel); diff --git a/paddle/fluid/operators/save_combine_op.cu b/paddle/fluid/operators/save_combine_op.cu index bc4478b51b111518439fe250a70b8dee0df53ad9..78607823a0368d216310bbbb390fd7face002839 100644 --- a/paddle/fluid/operators/save_combine_op.cu +++ b/paddle/fluid/operators/save_combine_op.cu @@ -20,6 +20,4 @@ REGISTER_OP_CUDA_KERNEL( save_combine, ops::SaveCombineOpKernel, ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel); + ops::SaveCombineOpKernel);