未验证 提交 c34b24ed 编写于 作者: L lujun 提交者: GitHub

Merge pull request #16425 from junjun315/checkpoint-hotfix

Checkpoint hotfix
...@@ -88,4 +88,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -88,4 +88,5 @@ REGISTER_OP_CPU_KERNEL(
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, float>, ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, double>, ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int>, ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -64,4 +64,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -64,4 +64,5 @@ REGISTER_OP_CPU_KERNEL(
load, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, float>, load, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, double>, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int>, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -19,11 +19,27 @@ limitations under the License. */ ...@@ -19,11 +19,27 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class SaveCombineOp : public framework::OperatorWithKernel { class SaveCombineOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
// TODO(lujun): The override here is just to bypass transform
// in operator impl, which is not elegant enough.
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return expected_kernel_type;
}
}; };
class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...@@ -54,7 +70,7 @@ to a file on disk. ...@@ -54,7 +70,7 @@ to a file on disk.
"(string)" "(string)"
"The \"file_path\" where the LoDTensor variables will be saved.") "The \"file_path\" where the LoDTensor variables will be saved.")
.AddCustomChecker( .AddCustomChecker(
[](const std::string &path) { return !path.empty(); }); [](const std::string& path) { return !path.empty(); });
} }
}; };
...@@ -70,5 +86,4 @@ REGISTER_OP_CPU_KERNEL( ...@@ -70,5 +86,4 @@ REGISTER_OP_CPU_KERNEL(
save_combine, save_combine,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>, ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>, ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>, ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>);
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -20,6 +20,4 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -20,6 +20,4 @@ REGISTER_OP_CUDA_KERNEL(
save_combine, save_combine,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, float>, ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, double>, ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int>, ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int>);
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册