From eb95417e0528cb625341d081aef4a20d5fcb2b32 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Thu, 3 May 2018 15:45:00 -0700 Subject: [PATCH] initial commit --- paddle/fluid/operators/load_op.cc | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index c6bd2bf3d..5a6145d51 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -46,6 +46,19 @@ class LoadOp : public framework::OperatorBase { auto *tensor = out_var->GetMutable(); DeserializeFromStream(fin, tensor, *dev_ctx); + + auto load_as_fp16 = Attr("load_as_fp16"); + auto in_dtype = framework::ToDataType(tensor->type()); + auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + + if (in_dtype != out_dtype) { + // convert to float16 tensor + auto in_kernel_type = framework::OpKernelType(in_dtype, place); + auto out_kernel_type = framework::OpKernelType(out_dtype, place); + framework::LoDTensor fp16_tensor; + framework::TransDataType(in_kernel_type, out_kernel_type, *tensor, + &fp16_tensor); + } } }; @@ -54,6 +67,13 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddOutput("Out", "(Tensor) The tensor need to be loaded"); + AddAttr( + "load_as_fp16", + "(boolean, default false)" + "If true, the tensor will be first loaded and then " + "converted to float16 data type. Otherwise, the tensor will be " + "directly loaded without data type conversion.") + .SetDefault(false); AddAttr("file_path", "(string) " "Variable will be loaded from \"file_path\".") -- GitLab