diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc index 5c4be7a7f312bb814d433db402d35a29dd6c9ab6..ac1c2dde3bb8a706001116f05ba5240717f57712 100644 --- a/paddle/fluid/operators/save_combine_op.cc +++ b/paddle/fluid/operators/save_combine_op.cc @@ -23,14 +23,21 @@ class SaveCombineOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override {} + void InferShape(framework::InferShapeContext* ctx) const override {} protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, ctx.GetPlace()); } + // TODO(lujun): The override here is just to bypass transform + // in operator impl, which is not elegant enough. + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + return expected_kernel_type; + } }; class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { @@ -61,7 +68,7 @@ to a file on disk. "(string)" "The \"file_path\" where the LoDTensor variables will be saved.") .AddCustomChecker( - [](const std::string &path) { return !path.empty(); }); + [](const std::string& path) { return !path.empty(); }); } }; @@ -77,6 +84,4 @@ REGISTER_OP_CPU_KERNEL( save_combine, ops::SaveCombineOpKernel, ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel); + ops::SaveCombineOpKernel); diff --git a/paddle/fluid/operators/save_combine_op.cu b/paddle/fluid/operators/save_combine_op.cu index bc4478b51b111518439fe250a70b8dee0df53ad9..78607823a0368d216310bbbb390fd7face002839 100644 --- a/paddle/fluid/operators/save_combine_op.cu +++ b/paddle/fluid/operators/save_combine_op.cu @@ -20,6 +20,4 @@ REGISTER_OP_CUDA_KERNEL( save_combine, ops::SaveCombineOpKernel, ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel); + ops::SaveCombineOpKernel);