From 6c88f1ae6e2a517492280a3e355dba6f0fdb10b6 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Fri, 27 Apr 2018 10:14:49 -0700 Subject: [PATCH] add save op float16 support --- paddle/fluid/operators/save_load_op_test.cc | 32 +++++++++++++++++++++ paddle/fluid/operators/save_op.cc | 21 +++++++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/save_load_op_test.cc b/paddle/fluid/operators/save_load_op_test.cc index a7ba1e0ae1..0cfb7fb730 100644 --- a/paddle/fluid/operators/save_load_op_test.cc +++ b/paddle/fluid/operators/save_load_op_test.cc @@ -61,3 +61,35 @@ TEST(SaveLoadOp, CPU) { } } } + +TEST(SaveLoadFP16Op, CPU) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace place; + + auto var = scope.Var("test_var"); + auto tensor = var->GetMutable(); + tensor->Resize({3, 10}); + + float* expect = tensor->mutable_data(place); + for (int64_t i = 0; i < tensor->numel(); ++i) { + expect[i] = static_cast(i); + } + + paddle::framework::AttributeMap attrs; + attrs.insert({"file_path", std::string("tensor.save")}); + attrs.insert({"save_as_fp16", true}); + + auto save_op = paddle::framework::OpRegistry::CreateOp( + "save", {{"X", {"test_var"}}}, {}, attrs); + save_op->Run(scope, place); + + auto load_var = scope.Var("out_var"); + auto target = load_var->GetMutable(); + auto load_op = paddle::framework::OpRegistry::CreateOp( + "load", {}, {{"Out", {"out_var"}}}, attrs); + load_op->Run(scope, place); + paddle::platform::float16* actual = target->data(); + for (int64_t i = 0; i < tensor->numel(); ++i) { + EXPECT_EQ(expect[i], static_cast(actual[i])); + } +} diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 4a715c4baa..f45d07ed90 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" @@ -68,6 +69,7 @@ class SaveOp : public framework::OperatorBase { const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); + auto save_as_fp16 = Attr("save_as_fp16"); if (FileExists(filename) && !overwrite) { PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", @@ -96,7 +98,18 @@ class SaveOp : public framework::OperatorBase { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - framework::SerializeToStream(fout, tensor, dev_ctx); + auto in_dtype = framework::ToDataType(tensor.type()); + auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + + if (in_dtype != out_dtype) { + auto in_kernel_type = framework::OpKernelType(in_dtype, place); + auto out_kernel_type = framework::OpKernelType(out_dtype, place); + framework::LoDTensor out; + framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); + framework::SerializeToStream(fout, out, dev_ctx); + } else { + framework::SerializeToStream(fout, tensor, dev_ctx); + } } }; @@ -114,6 +127,12 @@ This operator will serialize and write a tensor variable to file on disk. "(boolean, default true)" "Overwrite the output file if exist") .SetDefault(true); + AddAttr("save_as_fp16", + "(boolean, default false)" + "If true, the tensor will be converted to float16 data " + "type and then saved. Otherwise, the tensor will be " + "directly saved without data type conversion.") + .SetDefault(false); AddAttr("file_path", "(string)" "The \"file_path\" where the variable will be saved.") -- GitLab