diff --git a/paddle/fluid/operators/save_load_op_test.cc b/paddle/fluid/operators/save_load_op_test.cc index a7ba1e0ae1d22a22cf2943c9aaf0c394ef4ae326..74385ee47543e3f4887081c2225212996d3df3f1 100644 --- a/paddle/fluid/operators/save_load_op_test.cc +++ b/paddle/fluid/operators/save_load_op_test.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/float16.h" USE_NO_KERNEL_OP(save); USE_NO_KERNEL_OP(load); @@ -61,3 +62,35 @@ TEST(SaveLoadOp, CPU) { } } } + +TEST(SaveLoadFP16Op, CPU) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace place; + + auto var = scope.Var("test_var"); + auto tensor = var->GetMutable(); + tensor->Resize({3, 10}); + + float* expect = tensor->mutable_data(place); + for (int64_t i = 0; i < tensor->numel(); ++i) { + expect[i] = static_cast(paddle::platform::float16(i)); + } + + paddle::framework::AttributeMap attrs; + attrs.insert({"file_path", std::string("tensor.save")}); + attrs.insert({"save_as_fp16", true}); + + auto save_op = paddle::framework::OpRegistry::CreateOp( + "save", {{"X", {"test_var"}}}, {}, attrs); + save_op->Run(scope, place); + + auto load_var = scope.Var("out_var"); + auto target = load_var->GetMutable(); + auto load_op = paddle::framework::OpRegistry::CreateOp( + "load", {}, {{"Out", {"out_var"}}}, attrs); + load_op->Run(scope, place); + paddle::platform::float16* actual = target->data(); + for (int64_t i = 0; i < tensor->numel(); ++i) { + EXPECT_EQ(expect[i], static_cast(actual[i])); + } +} diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index 4a715c4baab2da7b7af86ada22ee88a16b05a814..f45d07ed90d52d204e9a3a5c2efe6df6b27ebfe6 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" @@ -68,6 +69,7 @@ class SaveOp : public framework::OperatorBase { const platform::Place &place) const override { auto filename = Attr("file_path"); auto overwrite = Attr("overwrite"); + auto save_as_fp16 = Attr("save_as_fp16"); if (FileExists(filename) && !overwrite) { PADDLE_THROW("%s is existed, cannot save to it when overwrite=false", @@ -96,7 +98,18 @@ class SaveOp : public framework::OperatorBase { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); - framework::SerializeToStream(fout, tensor, dev_ctx); + auto in_dtype = framework::ToDataType(tensor.type()); + auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + + if (in_dtype != out_dtype) { + auto in_kernel_type = framework::OpKernelType(in_dtype, place); + auto out_kernel_type = framework::OpKernelType(out_dtype, place); + framework::LoDTensor out; + framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); + framework::SerializeToStream(fout, out, dev_ctx); + } else { + framework::SerializeToStream(fout, tensor, dev_ctx); + } } }; @@ -114,6 +127,12 @@ This operator will serialize and write a tensor variable to file on disk. "(boolean, default true)" "Overwrite the output file if exist") .SetDefault(true); + AddAttr("save_as_fp16", + "(boolean, default false)" + "If true, the tensor will be converted to float16 data " + "type and then saved. Otherwise, the tensor will be " + "directly saved without data type conversion.") + .SetDefault(false); AddAttr("file_path", "(string)" "The \"file_path\" where the variable will be saved.")