提交 eb95417e 编写于 作者: K Kexin Zhao

initial commit

上级 0446220e
...@@ -46,6 +46,19 @@ class LoadOp : public framework::OperatorBase { ...@@ -46,6 +46,19 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>(); auto *tensor = out_var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, *dev_ctx); DeserializeFromStream(fin, tensor, *dev_ctx);
auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto in_dtype = framework::ToDataType(tensor->type());
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
&fp16_tensor);
}
} }
}; };
...@@ -54,6 +67,13 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -54,6 +67,13 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "(Tensor) The tensor need to be loaded"); AddOutput("Out", "(Tensor) The tensor need to be loaded");
AddAttr<bool>(
"load_as_fp16",
"(boolean, default false)"
"If true, the tensor will be first loaded and then "
"converted to float16 data type. Otherwise, the tensor will be "
"directly loaded without data type conversion.")
.SetDefault(false);
AddAttr<std::string>("file_path", AddAttr<std::string>("file_path",
"(string) " "(string) "
"Variable will be loaded from \"file_path\".") "Variable will be loaded from \"file_path\".")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册