Unverified · Commit 6b3c48c1 · authored by Huang Jiyi, committed by GitHub

[phi decoupling] move serialization from phi to fluid (#50608)

* move save_op to fluid

* fix namespace

* move load_kernel

* fix kernel_register

* move serialization to fluid

* fix test

* fix bugs
Parent 7fc9f433
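For reference, the net effect of the move is a namespace change at call sites: stream (de)serialization of DenseTensor and SelectedRows now lives in paddle::framework instead of phi, with the same signatures and byte format. A minimal round-trip sketch against the new entry points (hypothetical helper, assuming a Paddle build; the pool-backed overloads used here are the ones this diff rewrites in lod_tensor.cc):

#include <sstream>
#include "paddle/fluid/framework/lod_tensor.h"

// Hypothetical helper, not part of this PR: write a tensor to an in-memory
// stream and read it back through the moved fluid entry points.
void RoundTrip(const phi::DenseTensor &src, phi::DenseTensor *dst) {
  std::stringstream ss;
  // Before this PR these fluid functions forwarded to phi::SerializeToStream /
  // phi::DeserializeFromStream; now fluid owns the implementation.
  paddle::framework::SerializeToStream(ss, src);      // device context from pool
  paddle::framework::DeserializeFromStream(ss, dst);  // CPU context by default
}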
@@ -166,13 +166,7 @@ cc_test(
 cc_library(
   lod_tensor
   SRCS lod_tensor.cc
-  DEPS ddim
-       mixed_vector
-       place
-       tensor
-       framework_proto
-       version
-       serialization)
+  DEPS ddim mixed_vector place tensor framework_proto version)
 cc_test(
   lod_tensor_test
@@ -1103,7 +1097,7 @@ cc_test(
 cc_library(
   selected_rows_utils
   SRCS selected_rows_utils.cc
-  DEPS selected_rows serialization device_context)
+  DEPS selected_rows device_context)
 cc_test(
   selected_rows_utils_test
   SRCS selected_rows_utils_test.cc
......
@@ -18,7 +18,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/version.h"
-#include "paddle/phi/core/serialization.h"
 namespace paddle {
 namespace framework {
@@ -207,7 +206,31 @@ LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod,
 void SerializeToStream(std::ostream &os,
                        const phi::DenseTensor &tensor,
                        const platform::DeviceContext &dev_ctx) {
-  phi::SerializeToStream(os, tensor, dev_ctx);
+  {  // the 1st field, uint32_t version for DenseTensor
+    os.write(
+        reinterpret_cast<const char *>(&paddle::framework::kCurTensorVersion),
+        sizeof(paddle::framework::kCurTensorVersion));
+  }
+  {
+    // the 2nd field, LoD information
+    // uint64_t lod_level
+    // uint64_t lod_level_1 size in bytes
+    // int* lod_level_1 data
+    // ...
+    auto lod = tensor.lod();
+    uint64_t size = lod.size();
+    os.write(reinterpret_cast<const char *>(&size), sizeof(size));
+    for (auto &each : lod) {
+      size = each.size() * sizeof(framework::LoD::value_type::value_type);
+      os.write(reinterpret_cast<const char *>(&size), sizeof(size));
+      os.write(reinterpret_cast<const char *>(each.data()),
+               static_cast<std::streamsize>(size));
+    }
+  }
+  // the 3rd field, Tensor
+  paddle::framework::TensorToStream(
+      os, static_cast<phi::DenseTensor>(tensor), dev_ctx);
 }
 void SerializeToStream(std::ostream &os, const phi::DenseTensor &tensor) {
@@ -215,14 +238,14 @@ void SerializeToStream(std::ostream &os, const phi::DenseTensor &tensor) {
   const platform::DeviceContext *dev_ctx;
   auto place = tensor.place();
   dev_ctx = pool.Get(place);
-  phi::SerializeToStream(os, tensor, *dev_ctx);
+  SerializeToStream(os, tensor, *dev_ctx);
 }
 void DeserializeFromStream(std::istream &os, phi::DenseTensor *tensor) {
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   const platform::DeviceContext *dev_ctx;
   dev_ctx = pool.Get(platform::CPUPlace());
-  phi::DeserializeFromStream(os, tensor, *dev_ctx);
+  DeserializeFromStream(os, tensor, *dev_ctx);
 }
 void DeserializeFromStream(std::istream &is,
@@ -230,13 +253,71 @@ void DeserializeFromStream(std::istream &is,
                            const platform::DeviceContext &dev_ctx,
                            const size_t &seek,
                            const std::vector<int64_t> &shape) {
-  phi::DeserializeFromStream(is, tensor, dev_ctx, seek, shape);
+  {
+    // the 1st field, uint32_t version for DenseTensor
+    uint32_t version;
+    is.read(reinterpret_cast<char *>(&version), sizeof(version));
+    PADDLE_ENFORCE_EQ(paddle::framework::IsTensorVersionSupported(version),
+                      true,
+                      phi::errors::InvalidArgument(
+                          "Tensor version %u is not supported.", version));
+    PADDLE_ENFORCE_EQ(
+        version,
+        0U,
+        phi::errors::InvalidArgument(
+            "Deserialize to tensor failed, maybe the loaded file is "
+            "not a paddle model (expected file format: 0, but %u found).",
+            version));
+  }
+  {
+    // the 2nd field, LoD information
+    uint64_t lod_level;
+    is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
+    auto &lod = *tensor->mutable_lod();
+    lod.resize(lod_level);
+  }
+  // the 3rd field, Tensor
+  paddle::framework::TensorFromStream(
+      is, static_cast<phi::DenseTensor *>(tensor), dev_ctx, seek, shape);
 }
 void DeserializeFromStream(std::istream &is,
                            phi::DenseTensor *tensor,
                            const platform::DeviceContext &dev_ctx) {
-  phi::DeserializeFromStream(is, tensor, dev_ctx);
+  {
+    // the 1st field, uint32_t version for DenseTensor
+    uint32_t version;
+    is.read(reinterpret_cast<char *>(&version), sizeof(version));
+    PADDLE_ENFORCE_EQ(paddle::framework::IsTensorVersionSupported(version),
+                      true,
+                      phi::errors::InvalidArgument(
+                          "Tensor version %u is not supported.", version));
+    PADDLE_ENFORCE_EQ(
+        version,
+        0U,
+        phi::errors::InvalidArgument(
+            "Deserialize to tensor failed, maybe the loaded file is "
+            "not a paddle model (expected file format: 0, but %u found).",
+            version));
+  }
+  {
+    // the 2nd field, LoD information
+    uint64_t lod_level;
+    is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
+    auto &lod = *tensor->mutable_lod();
+    lod.resize(lod_level);
+    for (uint64_t i = 0; i < lod_level; ++i) {
+      uint64_t size;
+      is.read(reinterpret_cast<char *>(&size), sizeof(size));
+      std::vector<size_t> tmp(size / sizeof(size_t));
+      is.read(reinterpret_cast<char *>(tmp.data()),
+              static_cast<std::streamsize>(size));
+      lod[i] = tmp;
+    }
+  }
+  // the 3rd field, Tensor
+  paddle::framework::TensorFromStream(
+      is, static_cast<phi::DenseTensor *>(tensor), dev_ctx);
 }
 LoD ConvertToOffsetBasedLoD(const LoD &length_lod) {
......
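The writer and reader above pin down the on-disk layout: a uint32 version (0 today), a uint64 LoD level count, then per level a uint64 byte size followed by raw size_t offsets, and finally the tensor payload from TensorToStream. A standalone sketch of just the header round trip, using only the standard library (the field-3 payload is omitted):

#include <cstdint>
#include <iostream>
#include <sstream>
#include <vector>

int main() {
  using LoD = std::vector<std::vector<size_t>>;
  std::stringstream ss;

  // field 1: uint32_t version (the deserializer enforces 0).
  const uint32_t version = 0;
  ss.write(reinterpret_cast<const char *>(&version), sizeof(version));

  // field 2: uint64_t lod_level, then per level: byte size + raw offsets.
  const LoD lod = {{0, 2, 5}};
  uint64_t levels = lod.size();
  ss.write(reinterpret_cast<const char *>(&levels), sizeof(levels));
  for (const auto &level : lod) {
    uint64_t bytes = level.size() * sizeof(size_t);
    ss.write(reinterpret_cast<const char *>(&bytes), sizeof(bytes));
    ss.write(reinterpret_cast<const char *>(level.data()),
             static_cast<std::streamsize>(bytes));
  }

  // Read the header back the same way DeserializeFromStream does.
  uint32_t v = 1;
  ss.read(reinterpret_cast<char *>(&v), sizeof(v));
  uint64_t n = 0;
  ss.read(reinterpret_cast<char *>(&n), sizeof(n));
  LoD out(n);
  for (auto &level : out) {
    uint64_t bytes = 0;
    ss.read(reinterpret_cast<char *>(&bytes), sizeof(bytes));
    level.resize(bytes / sizeof(size_t));
    ss.read(reinterpret_cast<char *>(level.data()),
            static_cast<std::streamsize>(bytes));
  }
  std::cout << "version=" << v << " lod_level=" << n
            << " offsets=" << out[0].size() << "\n";  // version=0 lod_level=1 offsets=3
  return 0;
}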
@@ -14,15 +14,32 @@ limitations under the License. */
 #include "paddle/fluid/framework/selected_rows_utils.h"
-#include "paddle/phi/core/serialization.h"
 namespace paddle {
 namespace framework {
 void SerializeToStream(std::ostream& os,
                        const phi::SelectedRows& selected_rows,
                        const platform::DeviceContext& dev_ctx) {
-  phi::SerializeToStream(os, selected_rows, dev_ctx);
+  {  // the 1st field, uint32_t version
+    constexpr uint32_t version = 0;
+    os.write(reinterpret_cast<const char*>(&version), sizeof(version));
+  }
+  {
+    // the 2nd field, rows information
+    auto& rows = selected_rows.rows();
+    uint64_t size = rows.size();
+    os.write(reinterpret_cast<const char*>(&size), sizeof(size));
+    for (uint64_t i = 0; i < size; ++i) {
+      os.write(reinterpret_cast<const char*>(&rows[i]), sizeof(rows[i]));
+    }
+  }
+  {
+    // the 3rd field, the height of SelectedRows
+    int64_t height = selected_rows.height();
+    os.write(reinterpret_cast<const char*>(&height), sizeof(height));
+  }
+  // the 4th field, Tensor data
+  paddle::framework::TensorToStream(os, selected_rows.value(), dev_ctx);
 }
 void SerializeToStream(std::ostream& os,
@@ -31,20 +48,51 @@ void SerializeToStream(std::ostream& os,
   const platform::DeviceContext* dev_ctx;
   auto place = selected_rows.place();
   dev_ctx = pool.Get(place);
-  phi::SerializeToStream(os, selected_rows, *dev_ctx);
+  SerializeToStream(os, selected_rows, *dev_ctx);
 }
 void DeserializeFromStream(std::istream& is, phi::SelectedRows* selected_rows) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   const platform::DeviceContext* dev_ctx;
   dev_ctx = pool.Get(platform::CPUPlace());
-  phi::DeserializeFromStream(is, selected_rows, *dev_ctx);
+  DeserializeFromStream(is, selected_rows, *dev_ctx);
 }
 void DeserializeFromStream(std::istream& is,
                            phi::SelectedRows* selected_rows,
                            const platform::DeviceContext& dev_ctx) {
-  phi::DeserializeFromStream(is, selected_rows, dev_ctx);
+  {
+    // the 1st field, uint32_t version for SelectedRows
+    uint32_t version;
+    is.read(reinterpret_cast<char*>(&version), sizeof(version));
+    PADDLE_ENFORCE_EQ(version,
+                      0U,
+                      phi::errors::InvalidArgument(
+                          "Only version 0 SelectedRows is supported."));
+  }
+  {
+    // the 2nd field, rows information
+    uint64_t size = 0;
+    is.read(reinterpret_cast<char*>(&size), sizeof(size));
+    PADDLE_ENFORCE_EQ(
+        is.good(),
+        true,
+        phi::errors::Unavailable("Cannot read the number of rows."));
+    auto& rows = *selected_rows->mutable_rows();
+    rows.resize(size);
+    for (uint64_t i = 0; i < size; ++i) {
+      is.read(reinterpret_cast<char*>(&rows[i]), sizeof(int64_t));
+    }
+  }
+  {
+    // the 3rd field, the height of the SelectedRows
+    int64_t height;
+    is.read(reinterpret_cast<char*>(&height), sizeof(int64_t));
+    selected_rows->set_height(height);
+  }
+  // the 4th field, tensor which contains the data
+  paddle::framework::TensorFromStream(
+      is, selected_rows->mutable_value(), dev_ctx);
 }
 }  // namespace framework
......
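The SelectedRows writer above fixes the container's on-disk layout: everything before the value tensor is a flat header in host byte order. A sketch reconstructed from the writer/reader pair (field 4 is the TensorToStream payload and is not expanded here):

// Byte layout of a serialized SelectedRows:
//
//   uint32_t version;          // field 1: must be 0
//   uint64_t row_count;        // field 2: rows().size()
//   int64_t  rows[row_count];  //          one id per selected row
//   int64_t  height;           // field 3: logical first dimension
//   ...                        // field 4: value() tensor via TensorToStream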
@@ -14,20 +14,84 @@ limitations under the License. */
 #include <string>
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/selected_rows_utils.h"
+#include "paddle/phi/kernels/cast_kernel.h"
 namespace paddle {
 namespace operators {
+template <typename T, typename Context>
+void LoadKernel(const Context& dev_ctx,
+                const std::string& file_path,
+                int64_t seek,
+                const std::vector<int64_t>& shape,
+                bool load_as_fp16,
+                phi::DenseTensor* out) {
+  // FIXME(yuyang18): We save variable to local file now, but we should change
+  // it to save an output stream.
+  std::ifstream fin(file_path, std::ios::binary);
+  PADDLE_ENFORCE_EQ(static_cast<bool>(fin),
+                    true,
+                    phi::errors::Unavailable(
+                        "Load operator failed to open file %s, please check "
+                        "whether the model file is complete or damaged.",
+                        file_path));
+  PADDLE_ENFORCE_NOT_NULL(out,
+                          phi::errors::InvalidArgument(
+                              "The variable to be loaded cannot be found."));
+  if (seek != -1) {
+    PADDLE_ENFORCE_GE(seek,
+                      0,
+                      phi::errors::InvalidArgument(
+                          "seek with tensor must be greater than or "
+                          "equal to 0"));
+    framework::DeserializeFromStream(fin, out, dev_ctx, seek, shape);
+  } else {
+    framework::DeserializeFromStream(fin, out, dev_ctx);
+  }
+  auto in_dtype = out->dtype();
+  auto out_dtype = load_as_fp16 ? phi::DataType::FLOAT16 : in_dtype;
+  if (in_dtype != out_dtype) {
+    phi::CastKernel<T>(dev_ctx, *out, out_dtype, out);
+  }
+}
+template <typename T, typename Context>
+void LoadSelectedRowsKernel(const Context& dev_ctx,
+                            const std::string& file_path,
+                            int64_t seek,
+                            const std::vector<int64_t>& shape,
+                            bool load_as_fp16,
+                            phi::SelectedRows* out) {
+  // FIXME(yuyang18): We save variable to local file now, but we should change
+  // it to save an output stream.
+  std::ifstream fin(file_path, std::ios::binary);
+  PADDLE_ENFORCE_EQ(static_cast<bool>(fin),
+                    true,
+                    phi::errors::Unavailable(
+                        "Load operator failed to open file %s, please check "
+                        "whether the model file is complete or damaged.",
+                        file_path));
+  PADDLE_ENFORCE_NOT_NULL(out,
+                          phi::errors::InvalidArgument(
+                              "The variable to be loaded cannot be found."));
+  framework::DeserializeFromStream(fin, out, dev_ctx);
+}
 class LoadOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext *ctx) const override {}
+  void InferShape(framework::InferShapeContext* ctx) const override {}
  protected:
   phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
+      const framework::ExecutionContext& ctx) const override {
     return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
   }
 };
@@ -45,7 +109,7 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
   AddAttr<std::string>("file_path",
                        R"(Variable will be loaded from "file_path")")
       .AddCustomChecker(
-          [](const std::string &path) { return !path.empty(); });
+          [](const std::string& path) { return !path.empty(); });
   AddAttr<int64_t>("seek", "(int64_t) Starting for load tensor from seek pos")
       .SetDefault(-1);
   AddAttr<std::vector<int64_t>>("shape",
@@ -64,3 +128,17 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker);
+PD_REGISTER_KERNEL(load, CPU, ALL_LAYOUT, ops::LoadKernel, float) {}
+PD_REGISTER_KERNEL(
+    load_sr, CPU, ALL_LAYOUT, ops::LoadSelectedRowsKernel, float) {}
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PD_REGISTER_KERNEL(load, GPU, ALL_LAYOUT, ops::LoadKernel, float) {}
+PD_REGISTER_KERNEL(
+    load_sr, GPU, ALL_LAYOUT, ops::LoadSelectedRowsKernel, float) {}
+#endif
+#ifdef PADDLE_WITH_XPU
+PD_REGISTER_KERNEL(load, XPU, ALL_LAYOUT, ops::LoadKernel, float) {}
+#endif
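Note the kernels keep their phi registration names (load, load_sr), so operator dispatch is unchanged; only the defining translation unit moved into fluid. A hedged sketch of exercising the seek path directly (hypothetical file path and shape; assumes a CPU context fetched from the pool and that the platform context aliases phi::CPUContext):

// Hypothetical slice load: a nonnegative seek plus an explicit shape reads a
// sub-span of the stored payload instead of the whole tensor.
phi::DenseTensor slice;
auto* ctx = static_cast<phi::CPUContext*>(
    paddle::platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
paddle::operators::LoadKernel<float>(*ctx,
                                     "/tmp/param.bin",
                                     /*seek=*/2,
                                     /*shape=*/{2, 3},
                                     /*load_as_fp16=*/false,
                                     &slice);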
@@ -32,7 +32,6 @@ limitations under the License. */
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/phi/backends/dynload/port.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/serialization.h"
 namespace paddle {
 namespace operators {
@@ -111,9 +110,9 @@ void SaveCombineTensorKernel(const Context& dev_ctx,
     framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
     // copy LoD info to the new tensor
     out.set_lod(tensor.lod());
-    phi::SerializeToStream(ss, out, dev_ctx);
+    framework::SerializeToStream(ss, out, dev_ctx);
   } else {
-    phi::SerializeToStream(ss, tensor, dev_ctx);
+    framework::SerializeToStream(ss, tensor, dev_ctx);
   }
 }
......
@@ -88,3 +88,61 @@ REGISTER_OPERATOR(save,
                   ops::SaveOp,
                   ops::SaveOpProtoMaker,
                   ops::SaveOpVarTypeInference);
+PD_REGISTER_KERNEL(save,
+                   CPU,
+                   ALL_LAYOUT,
+                   ops::SaveKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(save_sr,
+                   CPU,
+                   ALL_LAYOUT,
+                   ops::SaveSelectedRowsKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PD_REGISTER_KERNEL(save,
+                   GPU,
+                   ALL_LAYOUT,
+                   ops::SaveKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(save_sr,
+                   GPU,
+                   ALL_LAYOUT,
+                   ops::SaveSelectedRowsKernel,
+                   float,
+                   double,
+                   int,
+                   uint8_t,
+                   int8_t,
+                   int16_t,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+#endif
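save and save_sr are registered for the full set of dense dtypes while load is registered only for float; this is consistent because the stored dtype travels in the stream, and the kernel's T only drives the optional fp16 cast. A hedged sketch (hypothetical path; assumes a CPU context as in the previous sketch):

// Saving a double tensor and loading it back through LoadKernel<float>
// still yields FLOAT64: the dtype is recovered from the stream, not from T.
phi::DenseTensor d, back;
d.Resize(phi::make_ddim({4}));
auto* ctx = static_cast<phi::CPUContext*>(
    paddle::platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
ctx->Alloc<double>(&d);
paddle::operators::SaveKernel<double>(
    *ctx, d, "/tmp/t.bin", /*overwrite=*/true, /*save_as_fp16=*/false);
paddle::operators::LoadKernel<float>(
    *ctx, "/tmp/t.bin", /*seek=*/-1, /*shape=*/{}, /*load_as_fp16=*/false,
    &back);
// back.dtype() == phi::DataType::FLOAT64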
@@ -24,111 +24,118 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/selected_rows_utils.h"
 #include "paddle/fluid/framework/variable.h"
+#include "paddle/phi/kernels/cast_kernel.h"
 namespace paddle {
 namespace operators {
+template <typename T, typename Context>
+void SaveKernel(const Context& dev_ctx,
+                const phi::DenseTensor& x,
+                const std::string& file_path,
+                bool overwrite,
+                bool save_as_fp16) {
+  PADDLE_ENFORCE_EQ(
+      FileExists(file_path) && !overwrite,
+      false,
+      phi::errors::PreconditionNotMet(
+          "%s exists; cannot save to it when overwrite is set to false.",
+          file_path,
+          overwrite));
+  MkDirRecursively(DirName(file_path).c_str());
+  // FIXME(yuyang18): We save variable to local file now, but we should change
+  // it to save an output stream.
+  std::ofstream fout(file_path, std::ios::binary);
+  PADDLE_ENFORCE_EQ(
+      static_cast<bool>(fout),
+      true,
+      phi::errors::Unavailable("Cannot open %s to save variables.", file_path));
+  auto in_dtype = x.dtype();
+  auto out_dtype = save_as_fp16 ? phi::DataType::FLOAT16 : in_dtype;
+  if (in_dtype != out_dtype) {
+    auto out = phi::Cast<T>(dev_ctx, x, out_dtype);
+    framework::SerializeToStream(fout, out, dev_ctx);
+  } else {
+    framework::SerializeToStream(fout, x, dev_ctx);
+  }
+  fout.close();
+}
+template <typename T, typename Context>
+void SaveSelectedRowsKernel(const Context& dev_ctx,
+                            const phi::SelectedRows& x,
+                            const std::string& file_path,
+                            bool overwrite,
+                            bool save_as_fp16) {
+  PADDLE_ENFORCE_EQ(
+      FileExists(file_path) && !overwrite,
+      false,
+      phi::errors::PreconditionNotMet(
+          "%s exists; cannot save to it when overwrite is set to false.",
+          file_path,
+          overwrite));
+  PADDLE_ENFORCE_EQ(save_as_fp16,
+                    false,
+                    phi::errors::Unimplemented(
+                        "SelectedRows is not supported to save as float16."));
+  MkDirRecursively(DirName(file_path).c_str());
+  // FIXME(yuyang18): We save variable to local file now, but we should change
+  // it to save an output stream.
+  std::ofstream fout(file_path, std::ios::binary);
+  PADDLE_ENFORCE_EQ(
+      static_cast<bool>(fout),
+      true,
+      phi::errors::Unavailable("Cannot open %s to save variables.", file_path));
+  framework::SerializeToStream(fout, x, dev_ctx);
+  fout.close();
+}
 // define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
 // to directory specified.
 constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
 template <typename DeviceContext, typename T>
 class SaveOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
+  void Compute(const framework::ExecutionContext& ctx) const override {
     auto place = ctx.GetPlace();
-    auto *input_var = ctx.InputVar("X");
+    auto* input_var = ctx.InputVar("X");
     auto iname = ctx.InputNames("X").data();
     PADDLE_ENFORCE_NOT_NULL(
         input_var,
-        platform::errors::InvalidArgument(
+        phi::errors::InvalidArgument(
             "The variable %s to be saved cannot be found.", iname));
     auto filename = ctx.Attr<std::string>("file_path");
     auto overwrite = ctx.Attr<bool>("overwrite");
+    auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
     VLOG(4) << "save output file_path: " << filename;
-    PADDLE_ENFORCE_EQ(
-        FileExists(filename) && !overwrite,
-        false,
-        platform::errors::PreconditionNotMet(
-            "%s exists!, cannot save to it when overwrite is set to false.",
-            filename,
-            overwrite));
-    MkDirRecursively(DirName(filename).c_str());
+    // get device context from pool
+    platform::DeviceContextPool& pool =
+        platform::DeviceContextPool::Instance();
+    auto& dev_ctx = *pool.Get(place);
     if (input_var->IsType<phi::DenseTensor>()) {
-      SaveLodTensor(ctx, place, input_var, filename);
+      auto& tensor = input_var->Get<phi::DenseTensor>();
+      SaveKernel<T>(dev_ctx, tensor, filename, overwrite, save_as_fp16);
     } else if (input_var->IsType<phi::SelectedRows>()) {
-      SaveSelectedRows(ctx, place, input_var, filename);
+      auto& selectedRows = input_var->Get<phi::SelectedRows>();
+      SaveSelectedRowsKernel<T>(
+          dev_ctx, selectedRows, filename, overwrite, save_as_fp16);
     } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
+      PADDLE_THROW(phi::errors::InvalidArgument(
           "Save operator only supports saving phi::DenseTensor and "
           "SelectedRows "
           "variable, %s has wrong type",
          iname));
     }
   }
-  void SaveLodTensor(const framework::ExecutionContext &ctx,
-                     const platform::Place &place,
-                     const framework::Variable *var,
-                     const std::string &filename) const {
-    auto &tensor = var->Get<phi::DenseTensor>();
-    // get device context from pool
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-    // FIXME(yuyang18): We save variable to local file now, but we should change
-    // it to save an output stream.
-    std::ofstream fout(filename, std::ios::binary);
-    PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
-                      true,
-                      platform::errors::Unavailable(
-                          "Cannot open %s to save variables.", filename));
-    auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
-    auto in_dtype = tensor.dtype();
-    auto out_dtype = save_as_fp16 ? phi::DataType::FLOAT16 : in_dtype;
-    if (in_dtype != out_dtype) {
-      auto in_kernel_type =
-          phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, in_dtype);
-      auto out_kernel_type =
-          phi::KernelKey(place, phi::DataLayout::ALL_LAYOUT, out_dtype);
-      phi::DenseTensor out;
-      framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
-      // copy LoD info to the new tensor
-      out.set_lod(tensor.lod());
-      framework::SerializeToStream(fout, out, dev_ctx);
-    } else {
-      framework::SerializeToStream(fout, tensor, dev_ctx);
-    }
-    fout.close();
-  }
-  void SaveSelectedRows(const framework::ExecutionContext &ctx,
-                        const platform::Place &place,
-                        const framework::Variable *var,
-                        const std::string &filename) const {
-    auto &selectedRows = var->Get<phi::SelectedRows>();
-    // get device context from pool
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
-    // FIXME(yuyang18): We save variable to local file now, but we should change
-    // it to save an output stream.
-    std::ofstream fout(filename, std::ios::binary);
-    PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
-                      true,
-                      platform::errors::Unavailable(
-                          "Cannot open %s to save variables.", filename));
-    framework::SerializeToStream(fout, selectedRows, dev_ctx);
-    fout.close();
-  }
 };
 }  // namespace operators
......
@@ -109,10 +109,6 @@ cc_library(
   phi_device_context
   SRCS device_context.cc
   DEPS dense_tensor selected_rows)
-cc_library(
-  serialization
-  SRCS serialization.cc
-  DEPS version tensor phi_device_context)
 cc_library(
   custom_kernel
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/serialization.h"
#include "paddle/phi/core/enforce.h"
// Note: The TensorToStream depends on framework.proto,
// it is difficult to move into phi
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/version.h"
namespace phi {
void SerializeToStream(std::ostream &os,
const DenseTensor &tensor,
const DeviceContext &dev_ctx) {
{ // the 1st field, uint32_t version for DenseTensor
os.write(
reinterpret_cast<const char *>(&paddle::framework::kCurTensorVersion),
sizeof(paddle::framework::kCurTensorVersion));
}
{
// the 2nd field, LoD information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
auto lod = tensor.lod();
uint64_t size = lod.size();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
for (auto &each : lod) {
size = each.size() * sizeof(phi::LoD::value_type::value_type);
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
os.write(reinterpret_cast<const char *>(each.data()),
static_cast<std::streamsize>(size));
}
}
// the 3rd field, Tensor
paddle::framework::TensorToStream(
os, static_cast<DenseTensor>(tensor), dev_ctx);
}
void DeserializeFromStream(std::istream &is,
DenseTensor *tensor,
const DeviceContext &dev_ctx,
const size_t &seek,
const std::vector<int64_t> &shape) {
{
// the 1st field, uint32_t version for DenseTensor
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(paddle::framework::IsTensorVersionSupported(version),
true,
phi::errors::InvalidArgument(
"Tensor version %u is not supported.", version));
PADDLE_ENFORCE_EQ(
version,
0U,
phi::errors::InvalidArgument(
"Deserialize to tensor failed, maybe the loaded file is "
"not a paddle model(expected file format: 0, but %u found).",
version));
}
{
// the 2nd field, LoD information
uint64_t lod_level;
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
}
// the 3rd field, Tensor
paddle::framework::TensorFromStream(
is, static_cast<DenseTensor *>(tensor), dev_ctx, seek, shape);
}
void DeserializeFromStream(std::istream &is,
DenseTensor *tensor,
const DeviceContext &dev_ctx) {
{
// the 1st field, uint32_t version for DenseTensor
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(paddle::framework::IsTensorVersionSupported(version),
true,
phi::errors::InvalidArgument(
"Tensor version %u is not supported.", version));
PADDLE_ENFORCE_EQ(
version,
0U,
phi::errors::InvalidArgument(
"Deserialize to tensor failed, maybe the loaded file is "
"not a paddle model(expected file format: 0, but %u found).",
version));
}
{
// the 2nd field, LoD information
uint64_t lod_level;
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::vector<size_t> tmp(size / sizeof(size_t));
is.read(reinterpret_cast<char *>(tmp.data()),
static_cast<std::streamsize>(size));
lod[i] = tmp;
}
}
// the 3rd field, Tensor
paddle::framework::TensorFromStream(
is, static_cast<DenseTensor *>(tensor), dev_ctx);
}
void SerializeToStream(std::ostream &os,
const SelectedRows &selected_rows,
const DeviceContext &dev_ctx) {
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{
// the 2nd field, rows information
auto &rows = selected_rows.rows();
uint64_t size = rows.size();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
for (uint64_t i = 0; i < size; ++i) {
os.write(reinterpret_cast<const char *>(&rows[i]), sizeof(rows[i]));
}
}
{
// the 3rd field, the height of SelectedRows
int64_t height = selected_rows.height();
os.write(reinterpret_cast<const char *>(&height), sizeof(height));
}
// the 4th field, Tensor data
paddle::framework::TensorToStream(os, selected_rows.value(), dev_ctx);
}
void DeserializeFromStream(std::istream &is,
SelectedRows *selected_rows,
const DeviceContext &dev_ctx) {
{
// the 1st field, uint32_t version for SelectedRows
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version,
0U,
phi::errors::InvalidArgument(
"Only version 0 SelectedRows is supported."));
}
{
// the 2nd field, rows information
uint64_t size = 0;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
PADDLE_ENFORCE_EQ(
is.good(),
true,
phi::errors::Unavailable("Cannot read the number of rows."));
auto &rows = *selected_rows->mutable_rows();
rows.resize(size);
for (uint64_t i = 0; i < size; ++i) {
is.read(reinterpret_cast<char *>(&rows[i]), sizeof(int64_t));
}
}
{
// the 3rd field, the height of the SelectedRows
int64_t height;
is.read(reinterpret_cast<char *>(&height), sizeof(int64_t));
selected_rows->set_height(height);
}
// the 4th field, tensor which contains the data
paddle::framework::TensorFromStream(
is, selected_rows->mutable_value(), dev_ctx);
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
/*
* Serialize/Deserialize DenseTensor to std::ostream.
* You can pass an ofstream or ostringstream to serialize to a file
* or to an in-memory string. GPU tensors will be copied to CPU.
*/
void SerializeToStream(std::ostream& os,
const DenseTensor& tensor,
const DeviceContext& dev_ctx);
void DeserializeFromStream(std::istream& is,
DenseTensor* tensor,
const DeviceContext& dev_ctx);
void DeserializeFromStream(std::istream& is,
DenseTensor* tensor,
const DeviceContext& dev_ctx,
const size_t& seek,
const std::vector<int64_t>& shape);
/*
* Serialize/Deserialize SelectedRows to std::ostream.
* You can pass an ofstream or ostringstream to serialize to a file
* or to an in-memory string. GPU tensors will be copied to CPU.
*/
void SerializeToStream(std::ostream& os,
const SelectedRows& selected_rows,
const DeviceContext& dev_ctx);
void DeserializeFromStream(std::istream& is,
SelectedRows* selected_rows,
const DeviceContext& dev_ctx);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/load_kernel.h"
#include <fstream>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/serialization.h"
#include "paddle/phi/kernels/cast_kernel.h"
namespace phi {
template <typename T, typename Context>
void LoadKernel(const Context& dev_ctx,
const std::string& file_path,
int64_t seek,
const std::vector<int64_t>& shape,
bool load_as_fp16,
DenseTensor* out) {
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ifstream fin(file_path, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fin),
true,
phi::errors::Unavailable(
"Load operator fail to open file %s, please check "
"whether the model file is complete or damaged.",
file_path));
PADDLE_ENFORCE_NOT_NULL(out,
phi::errors::InvalidArgument(
"The variable to be loaded cannot be found."));
if (seek != -1) {
PADDLE_ENFORCE_GE(seek,
0,
phi::errors::InvalidArgument(
"seek witn tensor must great than or equal to 0"));
DeserializeFromStream(fin, out, dev_ctx, seek, shape);
} else {
DeserializeFromStream(fin, out, dev_ctx);
}
auto in_dtype = out->dtype();
auto out_dtype = load_as_fp16 ? DataType::FLOAT16 : in_dtype;
if (in_dtype != out_dtype) {
CastKernel<T>(dev_ctx, *out, out_dtype, out);
}
}
} // namespace phi
PD_REGISTER_KERNEL(load, CPU, ALL_LAYOUT, phi::LoadKernel, float) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(load, GPU, ALL_LAYOUT, phi::LoadKernel, float) {}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(load, XPU, ALL_LAYOUT, phi::LoadKernel, float) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LoadKernel(const Context& dev_ctx,
const std::string& file_path,
int64_t seek,
const std::vector<int64_t>& shape,
bool load_as_fp16,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/save_kernel.h"
#include <fstream>
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/serialization.h"
#include "paddle/phi/kernels/cast_kernel.h"
namespace phi {
template <typename T, typename Context>
void SaveKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& file_path,
bool overwrite,
bool save_as_fp16) {
PADDLE_ENFORCE_EQ(
FileExists(file_path) && !overwrite,
false,
phi::errors::PreconditionNotMet(
"%s exists!, cannot save to it when overwrite is set to false.",
file_path,
overwrite));
MkDirRecursively(DirName(file_path).c_str());
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(file_path, std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fout),
true,
phi::errors::Unavailable("Cannot open %s to save variables.", file_path));
auto in_dtype = x.dtype();
auto out_dtype = save_as_fp16 ? DataType::FLOAT16 : in_dtype;
if (in_dtype != out_dtype) {
auto out = Cast<T>(dev_ctx, x, out_dtype);
SerializeToStream(fout, out, dev_ctx);
} else {
SerializeToStream(fout, x, dev_ctx);
}
fout.close();
}
} // namespace phi
PD_REGISTER_KERNEL(save,
CPU,
ALL_LAYOUT,
phi::SaveKernel,
float,
double,
int,
uint8_t,
int8_t,
int16_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(save,
GPU,
ALL_LAYOUT,
phi::SaveKernel,
float,
double,
int,
uint8_t,
int8_t,
int16_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SaveKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& file_path,
bool overwrite,
bool save_as_fp16);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/load_kernel.h"
#include <fstream>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/serialization.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void LoadKernel(const Context& dev_ctx,
const std::string& file_path,
int64_t seek,
const std::vector<int64_t>& shape,
bool load_as_fp16,
SelectedRows* out) {
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ifstream fin(file_path, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fin),
true,
phi::errors::Unavailable(
"Load operator fail to open file %s, please check "
"whether the model file is complete or damaged.",
file_path));
PADDLE_ENFORCE_NOT_NULL(out,
phi::errors::InvalidArgument(
"The variable to be loaded cannot be found."));
DeserializeFromStream(fin, out, dev_ctx);
}
} // namespace sr
} // namespace phi
PD_REGISTER_KERNEL(load_sr, CPU, ALL_LAYOUT, phi::sr::LoadKernel, float) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(load_sr, GPU, ALL_LAYOUT, phi::sr::LoadKernel, float) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void LoadKernel(const Context& dev_ctx,
const std::string& file_path,
int64_t seek,
const std::vector<int64_t>& shape,
bool load_as_fp16,
SelectedRows* out);
} // namespace sr
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/selected_rows/save_kernel.h"
#include <fstream>
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/serialization.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void SaveKernel(const Context& dev_ctx,
const SelectedRows& x,
const std::string& file_path,
bool overwrite,
bool save_as_fp16) {
PADDLE_ENFORCE_EQ(
FileExists(file_path) && !overwrite,
false,
phi::errors::PreconditionNotMet(
"%s exists!, cannot save to it when overwrite is set to false.",
file_path,
overwrite));
PADDLE_ENFORCE_EQ(save_as_fp16,
false,
phi::errors::Unimplemented(
"SelectedRows is not supported to save as float16."));
MkDirRecursively(DirName(file_path).c_str());
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(file_path, std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fout),
true,
phi::errors::Unavailable("Cannot open %s to save variables.", file_path));
SerializeToStream(fout, x, dev_ctx);
fout.close();
}
} // namespace sr
} // namespace phi
PD_REGISTER_KERNEL(save_sr,
CPU,
ALL_LAYOUT,
phi::sr::SaveKernel,
float,
double,
int,
uint8_t,
int8_t,
int16_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(save_sr,
GPU,
ALL_LAYOUT,
phi::sr::SaveKernel,
float,
double,
int,
uint8_t,
int8_t,
int16_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void SaveKernel(const Context& dev_ctx,
const SelectedRows& x,
const std::string& file_path,
bool overwrite,
bool save_as_fp16);
} // namespace sr
} // namespace phi