未验证 提交 4613aeba 编写于 作者: K Kexin Zhao 提交者: GitHub

Merge pull request #10272 from kexinzhao/save_fp16

Add float16 support to save op
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
USE_NO_KERNEL_OP(save);
USE_NO_KERNEL_OP(load);
......@@ -61,3 +62,35 @@ TEST(SaveLoadOp, CPU) {
}
}
}
TEST(SaveLoadFP16Op, CPU) {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
auto var = scope.Var("test_var");
auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize({3, 10});
float* expect = tensor->mutable_data<float>(place);
for (int64_t i = 0; i < tensor->numel(); ++i) {
expect[i] = static_cast<float>(paddle::platform::float16(i));
}
paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", std::string("tensor.save")});
attrs.insert({"save_as_fp16", true});
auto save_op = paddle::framework::OpRegistry::CreateOp(
"save", {{"X", {"test_var"}}}, {}, attrs);
save_op->Run(scope, place);
auto load_var = scope.Var("out_var");
auto target = load_var->GetMutable<paddle::framework::LoDTensor>();
auto load_op = paddle::framework::OpRegistry::CreateOp(
"load", {}, {{"Out", {"out_var"}}}, attrs);
load_op->Run(scope, place);
paddle::platform::float16* actual = target->data<paddle::platform::float16>();
for (int64_t i = 0; i < tensor->numel(); ++i) {
EXPECT_EQ(expect[i], static_cast<float>(actual[i]));
}
}
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <numeric>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -68,6 +69,7 @@ class SaveOp : public framework::OperatorBase {
const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
auto save_as_fp16 = Attr<bool>("save_as_fp16");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
......@@ -96,7 +98,18 @@ class SaveOp : public framework::OperatorBase {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::SerializeToStream(fout, tensor, dev_ctx);
auto in_dtype = framework::ToDataType(tensor.type());
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor out;
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
framework::SerializeToStream(fout, out, dev_ctx);
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
}
};
......@@ -114,6 +127,12 @@ This operator will serialize and write a tensor variable to file on disk.
"(boolean, default true)"
"Overwrite the output file if exist")
.SetDefault(true);
AddAttr<bool>("save_as_fp16",
"(boolean, default false)"
"If true, the tensor will be converted to float16 data "
"type and then saved. Otherwise, the tensor will be "
"directly saved without data type conversion.")
.SetDefault(false);
AddAttr<std::string>("file_path",
"(string)"
"The \"file_path\" where the variable will be saved.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册