提交 a4fd3756 编写于 作者: T tangwei12

bug fix

上级 f9d4b9da
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <fstream> #include <fstream>
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include <streambuf>
#include <string> #include <string>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/data_type_transform.h"
...@@ -43,7 +44,13 @@ static std::string GenePath(const std::string &dir, const std::string &file) { ...@@ -43,7 +44,13 @@ static std::string GenePath(const std::string &dir, const std::string &file) {
file_path.append(file_path); file_path.append(file_path);
file_path.append("/"); file_path.append("/");
file_path.append(file); file_path.append(file);
return full_path; return file_path;
}
static bool IsNumber(const std::string &s) {
std::string::const_iterator it = s.begin();
while (it != s.end() && std::isdigit(*it)) ++it;
return !s.empty() && it == s.end();
} }
static void LoadInputVars(const framework::Scope &scope, static void LoadInputVars(const framework::Scope &scope,
...@@ -62,7 +69,7 @@ static void LoadInputVars(const framework::Scope &scope, ...@@ -62,7 +69,7 @@ static void LoadInputVars(const framework::Scope &scope,
"Cannot find variable %s for save_combine_op", "Cannot find variable %s for save_combine_op",
inp_var_names[i]); inp_var_names[i]);
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(), PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
"SaveCombineOp only supports LoDTensor, %s has wrong type", "LoadCombineOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]); inp_var_names[i]);
std::string var_file = GenePath(dir, inp_var_names[i]); std::string var_file = GenePath(dir, inp_var_names[i]);
...@@ -78,21 +85,18 @@ static void LoadInputVars(const framework::Scope &scope, ...@@ -78,21 +85,18 @@ static void LoadInputVars(const framework::Scope &scope,
static void LoadStringArgv(const framework::Scope &scope, static void LoadStringArgv(const framework::Scope &scope,
const platform::Place &place, const platform::Place &place,
const std::string &argv, const std::string &dir) { const std::vector<std::string> &argv,
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); const std::string &dir) {
auto &dev_ctx = *pool.Get(place);
for (size_t i = 0; i < argv.size(); i++) { for (size_t i = 0; i < argv.size(); i++) {
auto *var = scope.FindVar(inp_var_names[i]); auto *var = scope.FindVar(argv[i]);
std::string *var_str = var->GetMutable<std::string>(); std::string *var_str = var->GetMutable<std::string>();
std::string var_file = GenePath(dir, argv[i]);
std::string var_file = GenePath(dir, argv);
std::ifstream fin(var_file); std::ifstream fin(var_file);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op", PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
var_file); var_file);
std::getline(fin, var_str); std::getline(fin, *var_str);
fin.close(); fin.close();
VLOG(3) << " load String argv: " << argv << " value is: " << var_str; VLOG(3) << " load String argv: " << argv[i] << " value is: " << var_str;
} }
} }
...@@ -108,22 +112,24 @@ class CheckpointLoadOp : public framework::OperatorBase { ...@@ -108,22 +112,24 @@ class CheckpointLoadOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
std::string dir = Attr<std::string>("dir"); std::string dir = Attr<std::string>("dir");
std::string serial_num = Attr<std::string>("Serial"); std::string serial_num_attr = Attr<std::string>("Serial");
PADDLE_ENFORCE(IsNumber(serial_num_attr),
"Checkpoint Serial must be a number");
std::string serial_var_name = std::string(SERIAL_VAR); std::string serial_var_name = std::string(SERIAL_VAR);
auto *serial_var = scope.FindVar(serial_var_name); auto *serial_var = scope.FindVar(serial_var_name);
PADDLE_ENFORCE(serial_var != nullptr,
if (serial_var == nullptr) { "Cannot find variable %s for checkpoint_load_op",
*serial_var = scope.Var(serial_var_name); serial_var_name);
auto *serial_tmp = serial_var->GetMutable<std::string>();
serial_tmp->append("0");
}
auto *serial_num = serial_var->GetMutable<std::string>(); auto *serial_num = serial_var->GetMutable<std::string>();
VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER serial_num = serial_num_attr;
VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR
<< " value: " << serial_num; << " value: " << serial_num;
std::string success = GenePath(dir, serial_num); std::string success = GenePath(dir, serial_num->c_str());
VLOG(3) << "Load checkpoint from dir: " << success; VLOG(3) << "Load checkpoint from dir: " << success;
success = GenePath(success, SUCCESS); success = GenePath(success, SUCCESS);
bool is_present = FileExists(success); bool is_present = FileExists(success);
...@@ -137,11 +143,11 @@ class CheckpointLoadOp : public framework::OperatorBase { ...@@ -137,11 +143,11 @@ class CheckpointLoadOp : public framework::OperatorBase {
auto inp_var_names = Inputs("X"); auto inp_var_names = Inputs("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0, PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0"); "The number of input variables should be greater than 0");
LoadInputVars(scope, place, &inp_var_names); LoadInputVars(scope, place, inp_var_names, dir);
VLOG(3) << "Ready to load string argv to scope"; // VLOG(3) << "Ready to load string argv to scope";
auto argv = Inputs("Argv"); // auto argv = Output("Argv");
LoadStringArgv(scope, place, &argv, &dir); // LoadStringArgv(scope, place, argv, dir);
} }
}; };
...@@ -153,14 +159,13 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -153,14 +159,13 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"X", "X",
"(vector) Input LoDTensors that need to be saved together in a file.") "(vector) Input LoDTensors that need to be saved together in a file.")
.AsDuplicable(); .AsDuplicable();
AddInput( AddOutput(
"Argv", "Argv",
"(vector) Input LoDTensors that need to be saved together in a file.") "(vector) Input LoDTensors that need to be saved together in a file.");
.AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
CheckpointLoad operator CheckpointLoad operator
This operator will serialize and write a list of input LoDTensor variables This operator will serialize and write a list of input LoDTensor variables
to a file on disk. to a file on disk.
)DOC"); )DOC");
...@@ -177,10 +182,32 @@ to a file on disk. ...@@ -177,10 +182,32 @@ to a file on disk.
} }
}; };
class CheckpointLoadOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Argv").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class CheckpointLoadOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp, REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
ops::CheckpointLoadOpProtoMaker); paddle::framework::EmptyGradOpMaker,
ops::CheckpointLoadOpProtoMaker,
ops::CheckpointLoadOpVarTypeInference,
ops::CheckpointLoadOpShapeInference);
// REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
// ops::CheckpointLoadOpProtoMaker);
...@@ -44,7 +44,7 @@ TEST(CheckpointSaveOp, CPU) { ...@@ -44,7 +44,7 @@ TEST(CheckpointSaveOp, CPU) {
attrs.insert({"dir", std::string("ckpt")}); attrs.insert({"dir", std::string("ckpt")});
auto save_op = paddle::framework::OpRegistry::CreateOp( auto save_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_save", {{"X", {"test_var"}}}, attrs); "checkpoint_save", {{"X", {"test_var"}}}, {}, attrs);
save_op->Run(scope, place); save_op->Run(scope, place);
} }
...@@ -52,13 +52,29 @@ TEST(CheckpointLoadOp, CPU) { ...@@ -52,13 +52,29 @@ TEST(CheckpointLoadOp, CPU) {
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUPlace place; paddle::platform::CPUPlace place;
scope.Var("test_var"); auto var = scope.Var("test_var");
auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize({3, 10});
paddle::framework::LoD expect_lod;
expect_lod.resize(1);
expect_lod[0].push_back(0);
expect_lod[0].push_back(1);
expect_lod[0].push_back(2);
expect_lod[0].push_back(3);
tensor->set_lod(expect_lod);
float* expect = tensor->mutable_data<float>(place);
for (int64_t i = 0; i < tensor->numel(); ++i) {
expect[i] = static_cast<float>(paddle::platform::float16(i));
}
scope.Var("SERIAL_NUMBER");
paddle::framework::AttributeMap attrs; paddle::framework::AttributeMap attrs;
attrs.insert({"dir", std::string("ckpt")}); attrs.insert({"dir", std::string("ckpt")});
attrs.insert({"Serial", std::string("SERIAL_NUMBER")});
auto load_op = paddle::framework::OpRegistry::CreateOp( auto load_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_load", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}}, "checkpoint_load", {{"X", {"test_var"}}}, {{"Argv", {}}}, attrs);
attrs);
load_op->Run(scope, place); load_op->Run(scope, place);
} }
...@@ -33,12 +33,18 @@ constexpr char kSEP = '/'; ...@@ -33,12 +33,18 @@ constexpr char kSEP = '/';
const char SUCCESS[] = "_SUCCESS"; const char SUCCESS[] = "_SUCCESS";
const char SERIAL_VAR[] = "SERIAL_NUMBER"; const char SERIAL_VAR[] = "SERIAL_NUMBER";
static bool IsNumber(const std::string &s) {
std::string::const_iterator it = s.begin();
while (it != s.end() && std::isdigit(*it)) ++it;
return !s.empty() && it == s.end();
}
static std::string GenePath(const std::string &dir, const std::string &file) { static std::string GenePath(const std::string &dir, const std::string &file) {
std::string file_path; std::string file_path;
file_path.append(file_path); file_path.append(dir);
file_path.append("/"); file_path.append("/");
file_path.append(file); file_path.append(file);
return full_path; return file_path;
} }
static bool FileExists(const std::string &filepath) { static bool FileExists(const std::string &filepath) {
...@@ -79,28 +85,24 @@ class CheckpointSaveOp : public framework::OperatorBase { ...@@ -79,28 +85,24 @@ class CheckpointSaveOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto dir = Attr<std::string>("dir"); auto ck_dir = Attr<std::string>("dir");
auto overwrite = Attr<bool>("overwrite"); auto overwrite = Attr<bool>("overwrite");
std::string serial_var_name = std::string(SERIAL_VAR); std::string serial_var_name = std::string(SERIAL_VAR);
auto *serial_var = scope.FindVar(serial_var_name); auto *serial_num =
scope.FindVar(serial_var_name)->GetMutable<std::string>();
if (serial_var == nullptr) { VLOG(1) << "CheckpointSaveOp get " << SERIAL_VAR
*serial_var = scope.Var(serial_var_name);
*serial_tmp = serial_var->GetMutable<std::string>();
serial_tmp->append("0");
}
auto *serial_num = serial_var->GetMutable<std::string>();
VLOG(1) << "CheckpointSaveOp get " << SERIAL_NUMBER
<< " value: " << serial_num; << " value: " << serial_num;
auto *serial_num = serial_var->GetMutable<std::string>(); if (!IsNumber(serial_num)) {
serial_num->append("1"); serial_num = "0";
}
dir = GenePath(dir, serial_num); std::string dir = GenePath(ck_dir, serial_num->c_str());
VLOG(1) << "CheckpointSaveOp current dir: " << dir;
bool is_present = FileExists(dir); bool is_present = FileExists(dir);
if (is_present && !overwrite) { if (is_present && !overwrite) {
PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir, PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir,
overwrite); overwrite);
} }
MkDirRecursively(dir.c_str()); MkDirRecursively(dir.c_str());
...@@ -150,7 +152,7 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -150,7 +152,7 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
CheckpointSave operator CheckpointSave operator
This operator will serialize and write a list of input LoDTensor variables This operator will serialize and write a list of input LoDTensor variables
to a file on disk. to a file on disk.
)DOC"); )DOC");
AddAttr<bool>("overwrite", AddAttr<bool>("overwrite",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册