diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc
index c6bd2bf3dfca77dc078eb04b1d90c7d90883203f..5a6145d5174ce1e4cdf787e73ed78a81d67a62e1 100644
--- a/paddle/fluid/operators/load_op.cc
+++ b/paddle/fluid/operators/load_op.cc
@@ -46,6 +46,19 @@ class LoadOp : public framework::OperatorBase {
     auto *tensor = out_var->GetMutable<framework::LoDTensor>();
 
     DeserializeFromStream(fin, tensor, *dev_ctx);
+
+    auto load_as_fp16 = Attr<bool>("load_as_fp16");
+    auto in_dtype = framework::ToDataType(tensor->type());
+    auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
+
+    if (in_dtype != out_dtype) {
+      // convert to float16 tensor
+      auto in_kernel_type = framework::OpKernelType(in_dtype, place);
+      auto out_kernel_type = framework::OpKernelType(out_dtype, place);
+      framework::LoDTensor fp16_tensor;
+      framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
+                               &fp16_tensor);
+    }
   }
 };
 
@@ -54,6 +67,13 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
   LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddOutput("Out", "(Tensor) The tensor need to be loaded");
+    AddAttr<bool>(
+        "load_as_fp16",
+        "(boolean, default false)"
+        "If true, the tensor will be first loaded and then "
+        "converted to float16 data type. Otherwise, the tensor will be "
+        "directly loaded without data type conversion.")
+        .SetDefault(false);
     AddAttr<std::string>("file_path",
                         "(string) "
                         "Variable will be loaded from \"file_path\".")
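
For reference, below is a minimal sketch of how the new load_as_fp16 attribute could be exercised from a C++ test, in the style of the existing save/load op tests. The test name, variable names, and the "tensor.save" path are illustrative assumptions, not part of this patch, and the final dtype check assumes the converted fp16_tensor is ultimately propagated back into the "Out" variable (not shown in the hunk above).

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"

USE_NO_KERNEL_OP(save);
USE_NO_KERNEL_OP(load);

TEST(LoadAsFP16Op, CPU) {
  paddle::framework::Scope scope;
  paddle::platform::CPUPlace place;

  // Create a float32 source tensor and persist it with the save op.
  auto var = scope.Var("test_var");
  auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
  tensor->Resize({3, 10});
  float* src = tensor->mutable_data<float>(place);
  for (int i = 0; i < 30; ++i) src[i] = static_cast<float>(i);

  paddle::framework::AttributeMap attrs;
  attrs.insert({"file_path", std::string("tensor.save")});  // illustrative path
  auto save_op = paddle::framework::OpRegistry::CreateOp(
      "save", {{"X", {"test_var"}}}, {}, attrs);
  save_op->Run(scope, place);

  // Reload the file with load_as_fp16 = true; with the conversion wired
  // through to the output, the loaded tensor should hold float16 data.
  auto load_var = scope.Var("out_var");
  attrs.insert({"load_as_fp16", true});
  auto load_op = paddle::framework::OpRegistry::CreateOp(
      "load", {}, {{"Out", {"out_var"}}}, attrs);
  load_op->Run(scope, place);

  auto& target = load_var->Get<paddle::framework::LoDTensor>();
  EXPECT_EQ(paddle::framework::ToDataType(target.type()),
            paddle::framework::proto::VarType::FP16);
}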