未验证 提交 171df5b5 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #16303 from junjun315/checkpoint

for Checkpoint save and load
......@@ -11,89 +11,27 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/load_combine_op.h"
namespace paddle {
namespace operators {
class LoadCombineOp : public framework::OperatorBase {
class LoadCombineOp : public framework::OperatorWithKernel {
public:
LoadCombineOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path");
auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto model_from_memory = Attr<bool>("model_from_memory");
auto out_var_names = Outputs("Out");
PADDLE_ENFORCE_GT(
static_cast<int>(out_var_names.size()), 0,
"The number of output variables should be greater than 0.");
if (!model_from_memory) {
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin),
"Cannot open file %s for load_combine op", filename);
LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
} else {
PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory");
std::stringstream fin(filename, std::ios::in | std::ios::binary);
LoadParamsFromBuffer(scope, place, &fin, load_as_fp16, out_var_names);
}
}
void LoadParamsFromBuffer(
const framework::Scope &scope, const platform::Place &place,
std::istream *buffer, bool load_as_fp16,
const std::vector<std::string> &out_var_names) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
for (size_t i = 0; i < out_var_names.size(); i++) {
auto *out_var = scope.FindVar(out_var_names[i]);
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
out_var_names[i]);
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
// Error checking
PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");
// Get data from fin to tensor
DeserializeFromStream(*buffer, tensor, dev_ctx);
auto in_dtype = tensor->type();
auto out_dtype =
load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
// copy LoD info to the new tensor
fp16_tensor.set_lod(tensor->lod());
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
&fp16_tensor);
// reset output tensor
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
buffer->peek();
PADDLE_ENFORCE(buffer->eof(),
"You are not allowed to load partial data via "
"load_combine_op, use load_op instead.");
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType(
framework::proto::VarType::FP32, ctx.GetPlace());
return kt;
}
};
......@@ -124,21 +62,30 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
LoadCombine Operator.
LoadCombine operator loads LoDTensor variables from a file, which could be
loaded in memory already. The file should contain one or more LoDTensors
LoadCombine operator loads LoDTensor variables from a file, which could be
loaded in memory already. The file should contain one or more LoDTensors
serialized using the SaveCombine operator. The
LoadCombine operator applies a deserialization strategy to appropriately load
the LodTensors, and this strategy complements the serialization strategy used
LoadCombine operator applies a deserialization strategy to appropriately load
the LodTensors, and this strategy complements the serialization strategy used
in the SaveCombine operator. Hence, the LoadCombine operator is tightly coupled
with the SaveCombine operator, and can only deserialize one or more LoDTensors
with the SaveCombine operator, and can only deserialize one or more LoDTensors
that were saved using the SaveCombine operator.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(load_combine, ops::LoadCombineOp,
ops::LoadCombineOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
load_combine,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/load_combine_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
load_combine,
ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::LoadCombineOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LoadCombineOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
auto filename = ctx.Attr<std::string>("file_path");
auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
auto model_from_memory = ctx.Attr<bool>("model_from_memory");
auto &out_var_names = ctx.Outputs("Out");
PADDLE_ENFORCE_GT(
static_cast<int>(out_var_names.size()), 0,
"The number of output variables should be greater than 0.");
if (!model_from_memory) {
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin),
"Cannot open file %s for load_combine op", filename);
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
} else {
PADDLE_ENFORCE(!filename.empty(), "Cannot load file from memory");
std::stringstream fin(filename, std::ios::in | std::ios::binary);
LoadParamsFromBuffer(ctx, place, &fin, load_as_fp16, out_var_names);
}
}
void LoadParamsFromBuffer(
const framework::ExecutionContext &context, const platform::Place &place,
std::istream *buffer, bool load_as_fp16,
const std::vector<std::string> &out_var_names) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto out_vars = context.MultiOutputVar("Out");
for (size_t i = 0; i < out_var_names.size(); i++) {
PADDLE_ENFORCE(out_vars[i] != nullptr,
"Output variable %s cannot be found", out_var_names[i]);
auto *tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
// Error checking
PADDLE_ENFORCE(static_cast<bool>(*buffer), "Cannot read more");
// Get data from fin to tensor
DeserializeFromStream(*buffer, tensor, dev_ctx);
auto in_dtype = tensor->type();
auto out_dtype =
load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
// copy LoD info to the new tensor
fp16_tensor.set_lod(tensor->lod());
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
&fp16_tensor);
// reset output tensor
out_vars[i]->Clear();
tensor = out_vars[i]->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
buffer->peek();
PADDLE_ENFORCE(buffer->eof(),
"You are not allowed to load partial data via "
"load_combine_op, use load_op instead.");
}
};
} // namespace operators
} // namespace paddle
......@@ -11,89 +11,26 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
#include <string>
#include "paddle/fluid/operators/load_op.h"
namespace paddle {
namespace operators {
class LoadOp : public framework::OperatorBase {
class LoadOp : public framework::OperatorWithKernel {
public:
LoadOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
filename);
using framework::OperatorWithKernel::OperatorWithKernel;
auto out_var_name = Output("Out");
auto *out_var = scope.FindVar(out_var_name);
PADDLE_ENFORCE(out_var != nullptr,
"Output variable %s cannot be found in scope %p",
out_var_name, &scope);
void InferShape(framework::InferShapeContext *ctx) const override {}
if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(fin, place, out_var);
} else if (out_var->IsType<framework::SelectedRows>()) {
LoadSelectedRows(fin, place, out_var);
} else {
PADDLE_ENFORCE(
false,
"Load only support LoDTensor and SelectedRows, %s has wrong type",
out_var_name);
}
}
void LoadLodTensor(std::istream &fin, const platform::Place &place,
framework::Variable *var) const {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto *tensor = var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto in_dtype = tensor->type();
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
// copy LoD info to the new tensor
fp16_tensor.set_lod(tensor->lod());
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
&fp16_tensor);
// reset output tensor
var->Clear();
tensor = var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
void LoadSelectedRows(std::istream &fin, const platform::Place &place,
framework::Variable *var) const {
auto *selectedRows = var->GetMutable<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
selectedRows->SyncIndex();
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType(
framework::proto::VarType::FP32, platform::CPUPlace());
return kt;
}
};
......@@ -116,8 +53,15 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"file.");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
load, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/load_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
load, ops::LoadOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::LoadOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::LoadOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::LoadOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::LoadOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <string>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LoadOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto filename = ctx.Attr<std::string>("file_path");
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
filename);
auto out_var_name = ctx.Outputs("Out").data();
auto *out_var = ctx.OutputVar("Out");
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found ",
out_var_name);
PADDLE_ENFORCE(out_var != nullptr, "Output variable cannot be found ");
if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(fin, place, out_var, ctx);
} else if (out_var->IsType<framework::SelectedRows>()) {
LoadSelectedRows(fin, place, out_var);
} else {
PADDLE_ENFORCE(
false,
"Load only support LoDTensor and SelectedRows, %s has wrong type",
out_var_name);
}
}
void LoadLodTensor(std::istream &fin, const platform::Place &place,
framework::Variable *var,
const framework::ExecutionContext &ctx) const {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto *tensor = var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
auto in_dtype = tensor->type();
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
// copy LoD info to the new tensor
fp16_tensor.set_lod(tensor->lod());
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
&fp16_tensor);
// reset output tensor
var->Clear();
tensor = var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
void LoadSelectedRows(std::istream &fin, const platform::Place &place,
framework::Variable *var) const {
auto *selectedRows = var->GetMutable<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
selectedRows->SyncIndex();
}
};
} // namespace operators
} // namespace paddle
......@@ -12,87 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdint.h>
#include <fstream>
#include <numeric>
#include <sstream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h"
#include <string>
#include "paddle/fluid/operators/save_combine_op.h"
namespace paddle {
namespace operators {
class SaveCombineOp : public framework::OperatorBase {
class SaveCombineOp : public framework::OperatorWithKernel {
public:
SaveCombineOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
auto save_as_fp16 = Attr<bool>("save_as_fp16");
bool is_present = FileExists(filename);
if (is_present && !overwrite) {
PADDLE_THROW("%s exists!, cannot save_combine to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto inp_var_names = Inputs("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0");
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
using framework::OperatorWithKernel::OperatorWithKernel;
for (size_t i = 0; i < inp_var_names.size(); i++) {
auto *var = scope.FindVar(inp_var_names[i]);
PADDLE_ENFORCE(var != nullptr,
"Cannot find variable %s for save_combine_op",
inp_var_names[i]);
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
"SaveCombineOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]);
auto &tensor = var->Get<framework::LoDTensor>();
// Serialize tensors one by one
// Check types to see if a fp16 transformation is required
auto in_dtype = tensor.type();
auto out_dtype =
save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor out;
// copy LoD info to the new tensor
out.set_lod(tensor.lod());
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
framework::SerializeToStream(fout, out, dev_ctx);
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
}
fout.close();
}
void InferShape(framework::InferShapeContext *ctx) const override {}
};
class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
......@@ -105,7 +36,7 @@ class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
SaveCombine operator
This operator will serialize and write a list of input LoDTensor variables
This operator will serialize and write a list of input LoDTensor variables
to a file on disk.
)DOC");
AddAttr<bool>("overwrite",
......@@ -134,3 +65,10 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(save_combine, ops::SaveCombineOp,
ops::SaveCombineOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
save_combine,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/save_combine_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
save_combine,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <fstream>
#include <numeric>
#include <sstream>
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SaveCombineOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
auto filename = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite");
auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
bool is_present = FileExists(filename);
if (is_present && !overwrite) {
PADDLE_THROW("%s exists!, cannot save_combine to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto &inp_var_names = ctx.Inputs("X");
auto &inp_vars = ctx.MultiInputVar("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0");
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
for (size_t i = 0; i < inp_var_names.size(); i++) {
PADDLE_ENFORCE(inp_vars[i] != nullptr,
"Cannot find variable %s for save_combine_op",
inp_var_names[i]);
PADDLE_ENFORCE(inp_vars[i]->IsType<framework::LoDTensor>(),
"SaveCombineOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]);
auto &tensor = inp_vars[i]->Get<framework::LoDTensor>();
// Serialize tensors one by one
// Check types to see if a fp16 transformation is required
auto in_dtype = tensor.type();
auto out_dtype =
save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor out;
// copy LoD info to the new tensor
out.set_lod(tensor.lod());
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
framework::SerializeToStream(fout, out, dev_ctx);
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
}
fout.close();
}
};
} // namespace operators
} // namespace paddle
......@@ -19,8 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
USE_NO_KERNEL_OP(save_combine);
USE_NO_KERNEL_OP(load_combine);
USE_CPU_ONLY_OP(save_combine);
USE_CPU_ONLY_OP(load_combine);
template <typename T, typename U>
T* CreateForSaveCombineOp(int x, int y, const std::vector<int>& lod_info,
......
......@@ -16,8 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
USE_NO_KERNEL_OP(save);
USE_NO_KERNEL_OP(load);
USE_CPU_ONLY_OP(save);
USE_CPU_ONLY_OP(load);
TEST(SaveLoadOp, CPU) {
paddle::framework::Scope scope;
......
......@@ -15,118 +15,24 @@ limitations under the License. */
#include <stdint.h>
#include <fstream>
#include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/operators/save_op.h"
namespace paddle {
namespace operators {
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
class SaveOp : public framework::OperatorBase {
class SaveOp : public framework::OperatorWithKernel {
public:
SaveOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto iname = Input("X");
auto *var = scope.FindVar(iname);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
iname);
if (var->IsType<framework::LoDTensor>()) {
SaveLodTensor(place, var);
} else if (var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(scope, place, var);
} else {
PADDLE_ENFORCE(
false,
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
iname);
}
}
using framework::OperatorWithKernel::OperatorWithKernel;
void SaveLodTensor(const platform::Place &place,
framework::Variable *var) const {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto save_as_fp16 = Attr<bool>("save_as_fp16");
auto in_dtype = tensor.type();
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor out;
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
// copy LoD info to the new tensor
out.set_lod(tensor.lod());
framework::SerializeToStream(fout, out, dev_ctx);
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
fout.close();
}
void InferShape(framework::InferShapeContext *ctx) const override {}
void SaveSelectedRows(const framework::Scope &scope,
const platform::Place &place,
framework::Variable *var) const {
auto *lt_var = scope.FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
PADDLE_ENFORCE(
lt_var != nullptr,
"Can not find variable kLookupTablePath for SaveSelectedRows");
std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename;
MkDirRecursively(DirName(filename).c_str());
auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
framework::SerializeToStream(fout, selectedRows, dev_ctx);
fout.close();
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.GetPlace());
}
};
......@@ -154,14 +60,20 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file
"The \"file_path\" where the variable will be saved.")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
AddOutput(LOOKUP_TABLE_PATH,
"(string)"
"for pserver: The \"kLookupTablePath\" where checkpoint notify "
"to save lookup table variables"
" to directory specified.")
.AsDispensable();
}
};
class SaveOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output(LOOKUP_TABLE_PATH).front();
ctx->SetType(out_var_name, framework::proto::VarType::RAW);
auto var_type = framework::proto::VarType::RAW;
ctx->SetType(LOOKUP_TABLE_PATH, var_type);
}
};
......@@ -169,11 +81,18 @@ class SaveOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(save, ops::SaveOp, paddle::framework::EmptyGradOpMaker,
ops::SaveOpProtoMaker, ops::SaveOpVarTypeInference,
ops::SaveOpShapeInference);
REGISTER_OPERATOR(save, ops::SaveOp, ops::SaveOpProtoMaker,
ops::SaveOpVarTypeInference, ops::SaveOpShapeInference);
REGISTER_OP_CPU_KERNEL(
save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/save_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
save, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <fstream>
#include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace operators {
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
template <typename DeviceContext, typename T>
class SaveOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
auto *input_var = ctx.InputVar("X");
auto iname = ctx.Inputs("X").data();
PADDLE_ENFORCE(input_var != nullptr, "Cannot find variable %s for save_op",
iname);
if (input_var->IsType<framework::LoDTensor>()) {
SaveLodTensor(ctx, place, input_var);
} else if (input_var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(ctx, place, input_var);
} else {
PADDLE_ENFORCE(
false,
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
iname);
}
}
void SaveLodTensor(const framework::ExecutionContext &ctx,
const platform::Place &place,
const framework::Variable *var) const {
auto filename = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
filename, overwrite);
}
MkDirRecursively(DirName(filename).c_str());
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto save_as_fp16 = ctx.Attr<bool>("save_as_fp16");
auto in_dtype = tensor.type();
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor out;
framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out);
// copy LoD info to the new tensor
out.set_lod(tensor.lod());
framework::SerializeToStream(fout, out, dev_ctx);
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
fout.close();
}
void SaveSelectedRows(const framework::ExecutionContext &ctx,
const platform::Place &place,
const framework::Variable *var) const {
framework::Variable *out_put_var = ctx.OutputVar(LOOKUP_TABLE_PATH);
PADDLE_ENFORCE(
out_put_var != nullptr,
"Can not find variable kLookupTablePath for SaveSelectedRows");
auto *lt_var = out_put_var->GetMutable<std::string>();
std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename;
MkDirRecursively(DirName(filename).c_str());
auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename, std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
framework::SerializeToStream(fout, selectedRows, dev_ctx);
fout.close();
}
};
} // namespace operators
} // namespace paddle
......@@ -644,10 +644,9 @@ class Operator(object):
outputs={"Out": [var1]})
"""
OP_WITHOUT_KERNEL_SET = {
'feed', 'fetch', 'save', 'load', 'recurrent', 'go',
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv',
'listen_and_serv', 'save_combine', 'load_combine', 'ncclInit', 'select',
'checkpoint_notify', 'gen_nccl_id'
'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
'ncclInit', 'select', 'checkpoint_notify', 'gen_nccl_id'
}
def __init__(self,
......
......@@ -29,9 +29,13 @@ from .tracer import *
from . import profiler
from .profiler import *
from . import checkpoint
from .checkpoint import *
__all__ = []
__all__ += layers.__all__
__all__ += base.__all__
__all__ += nn.__all__
__all__ += tracer.__all__
__all__ += profiler.__all__
__all__ += checkpoint.__all__
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import collections
from .. import core
from ..framework import Variable, default_main_program
__all__ = ['save_persistables', 'load_persistables']
def save_persistables(vardict, dirname, filename=None):
"""
This function filters out all variables in layer.parameters from the
give `layer` and then trys to load these variables from the folder
`dirname` or the file `filename`.
Use the `dirname` to specify the folder where persistable variables were
saved. If variables were saved in separate files, set `filename` None;
if all variables were saved in a single file, use `filename` to specify
the file name.
Args:
vardict(dict of Parameters): The parameters will
be saved. If it is None, nothing
will be deal.
dirname(str): The directory path.
filename(str|None): The file which saved all variables. If variables were
saved in differnet files, set it to None.
Default: None
Returns:
Examples:
.. code-block:: python
ptb_model = PtbModel(
hidden_size=hidden_size,
vocab_size=vocab_size,
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale)
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1))
y_data = y_data.reshape((-1, 1))
init_hidden_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
init_cell_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
x = to_variable(x_data)
y = to_variable(y_data)
init_hidden = to_variable(init_hidden_data)
init_cell = to_variable(init_cell_data)
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
init_cell)
param_path = "./my_paddle_model"
fluid.imperative.checkpoint.save_persistables(ptb_model.state_dict(), dirname=param_path,
layer=ptb_model)
"""
if isinstance(vardict, collections.OrderedDict):
_save_var_to_file(vardict, dirname, filename)
def load_persistables(vardict, dirname, filename=None):
"""
This function trys to load persistable variables from the folder
`dirname` or the file `filename`.
Use the `dirname` to specify the folder where persistable variables were
saved. If variables were saved in separate files, set `filename` None;
if all variables were saved in a single file, use `filename` to specify
the file name.
Args:
vardict(dict of Parameters): The parameters will be loaded.
dirname(str): The directory path.
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
saved in differnet files, set it to None.
Default: None
Returns:
dict: The parameter-dict resumed from file
Examples:
.. code-block:: python
my_layer = layer(fluid.imperative.Layer)
param_path = "./my_paddle_model"
param_dict = fluid.imperative.checkpoint.load_persistables(my_layer.parameters(), param_path)
param_1 = param_dict['PtbModel_0.w_1']
or:
my_layer = layer(fluid.imperative.Layer)
param_path = "./my_paddle_model"
filename = "model.file"
param_dict = fluid.imperative.checkpoint.load_persistables(my_layer.state_dict(), param_path,
filename=filename)
param_1 = param_dict['PtbModel_0.w_1']
"""
if isinstance(vardict, collections.OrderedDict):
return _load_var_from_file(vardict, dirname, filename)
return {}
def _save_var_to_file(stat_dict, file_dir, file_name):
save_block = default_main_program().global_block()
save_var_map = {}
for each_var in stat_dict.items():
save_var_map[each_var.name] = each_var
if file_name is None:
save_block.append_op(
type='save',
inputs={'X': [each_var]},
outputs={},
attrs={'file_path': os.path.join(file_dir, each_var.name)})
if file_name is not None:
save_var_list = []
for name in sorted(save_var_map.keys()):
save_var_list.append(save_var_map[name])
save_block.append_op(
type='save_combine',
inputs={'X': save_var_list},
outputs={},
attrs={'file_path': os.path.join(file_dir, file_name)})
def _load_var_from_file(stat_dict, file_dir, file_name):
load_block = default_main_program().global_block()
load_var_map = {}
for each_var in stat_dict.items():
assert isinstance(each_var, Variable)
if each_var.type == core.VarDesc.VarType.RAW:
continue
new_var = _clone_var_in_block_(load_block, each_var)
if file_name is None:
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [new_var]},
attrs={'file_path': os.path.join(file_dir, each_var.name)})
load_var_map[new_var.name] = new_var
if file_name is not None:
load_var_list = []
for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name])
load_block.append_op(
type='load_combine',
inputs={},
outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(file_dir, file_name)})
for res_var in load_var_list:
load_var_map[res_var.name] = res_var
return load_var_map
def _clone_var_in_block_(block, var):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=True)
......@@ -212,6 +212,34 @@ class Layer(core.Layer):
else:
object.__delattr__(self, name)
def state_dict(self, destination=None, prefix='', include_sublayers=True):
if destination is None:
destination = collections.OrderedDict()
for name, data in self._parameters.items():
if data is not None:
destination[prefix + name] = data
if include_sublayers:
for layer_name, layer_item in self._sub_layers.items():
if layer_item is not None:
destination_temp = destination.copy()
destination_temp.update(
layer_item.state_dict(destination_temp, prefix +
layer_name + ".",
include_sublayers))
destination = destination_temp
return destination
def load_dict(self, stat_dict, include_sublayers=True):
for name, item in self.__dict__.get('_parameters', None).items():
if item.name in stat_dict:
self.__setattr__(name, stat_dict[item.name])
if include_sublayers:
for layer_name, layer_item in self._sub_layers.items():
if layer_item is not None:
layer_item.load_dict(stat_dict)
class PyLayer(core.PyLayer):
"""Layers composed of user-defined python codes."""
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.imperative.nn import Conv2D, Pool2D, FC
from paddle.fluid.imperative.base import to_variable
class SimpleImgConvPool(fluid.imperative.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
pool_size,
pool_stride,
pool_padding=0,
pool_type='max',
global_pooling=False,
conv_stride=1,
conv_padding=0,
conv_dilation=1,
conv_groups=1,
act=None,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__(name_scope)
self._conv2d = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
padding=conv_padding,
dilation=conv_dilation,
groups=conv_groups,
param_attr=None,
bias_attr=None,
use_cudnn=use_cudnn)
self._pool2d = Pool2D(
self.full_name(),
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
pool_padding=pool_padding,
global_pooling=global_pooling,
use_cudnn=use_cudnn)
def forward(self, inputs):
x = self._conv2d(inputs)
x = self._pool2d(x)
return x
class MNIST(fluid.imperative.Layer):
def __init__(self, name_scope):
super(MNIST, self).__init__(name_scope)
self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 4 * 4
SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
self._fc = FC(self.full_name(),
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs):
x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x)
x = self._fc(x)
return x
class TestImperativeCheckpoint(unittest.TestCase):
def save_load_persistables(self):
seed = 90
epoch_num = 1
with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
mnist = MNIST("mnist")
sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
dy_param_init_value = {}
step = 0
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(128, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label._stop_gradient = True
cost = mnist(img)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
dy_out = avg_loss._numpy()
avg_loss._backward()
sgd.minimize(avg_loss)
fluid.imperative.save_persistables(mnist, "save_dir")
mnist.clear_gradients()
for param in mnist.parameters():
dy_param_init_value[param.name] = param._numpy()
mnist.load_dict(
fluid.imperative.load_persistables(mnist, "save_dir"))
restore = mnist.parameters()
self.assertEqual(len(dy_param_init_value), len(restore))
for value in restore:
self.assertTrue(
np.allclose(value, dy_param_init_value[value.name]))
self.assertTrue(np.isfinite(value.all()))
self.assertFalse(np.isnan(value.any()))
step += 1
if step > 20:
break
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册