提交 a4fd3756 编写于 作者: T tangwei12

bug fix

上级 f9d4b9da
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <fstream>
#include <numeric>
#include <sstream>
#include <streambuf>
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
......@@ -43,7 +44,13 @@ static std::string GenePath(const std::string &dir, const std::string &file) {
file_path.append(file_path);
file_path.append("/");
file_path.append(file);
return full_path;
return file_path;
}
// Returns true when `s` is non-empty and consists solely of decimal digits.
// An empty string, or any string containing a non-digit character (sign,
// whitespace, non-ASCII byte, ...), yields false.
static bool IsNumber(const std::string &s) {
  if (s.empty()) {
    return false;
  }
  for (const char c : s) {
    // std::isdigit has undefined behaviour for negative arguments other than
    // EOF, so the (possibly signed) char is converted to unsigned char first.
    if (!std::isdigit(static_cast<unsigned char>(c))) {
      return false;
    }
  }
  return true;
}
static void LoadInputVars(const framework::Scope &scope,
......@@ -62,7 +69,7 @@ static void LoadInputVars(const framework::Scope &scope,
"Cannot find variable %s for save_combine_op",
inp_var_names[i]);
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
"SaveCombineOp only supports LoDTensor, %s has wrong type",
"LoadCombineOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]);
std::string var_file = GenePath(dir, inp_var_names[i]);
......@@ -78,21 +85,18 @@ static void LoadInputVars(const framework::Scope &scope,
// Loads string-valued variables from disk into `scope`.
//
// For every name in `argv`, the variable of that name is looked up in
// `scope`, the file "<dir>/<name>" is opened, and its first line is stored
// into the variable's std::string payload.
//
// scope: scope that must already contain every variable named in `argv`.
// place: unused here; kept so the signature matches the other load helpers.
// argv:  names of the string variables to load.
// dir:   directory that holds one file per variable.
//
// Fails via PADDLE_ENFORCE when a variable is missing from the scope or its
// file cannot be opened.
static void LoadStringArgv(const framework::Scope &scope,
                           const platform::Place &place,
                           const std::vector<std::string> &argv,
                           const std::string &dir) {
  for (size_t i = 0; i < argv.size(); i++) {
    auto *var = scope.FindVar(argv[i]);
    // FindVar may return nullptr; fail loudly instead of dereferencing it.
    PADDLE_ENFORCE(var != nullptr,
                   "Cannot find variable %s for checkpoint_load_op", argv[i]);
    std::string *var_str = var->GetMutable<std::string>();

    std::string var_file = GenePath(dir, argv[i]);
    std::ifstream fin(var_file);
    PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
                   var_file);
    std::getline(fin, *var_str);
    fin.close();

    // Log the loaded string itself, not the address of the variable.
    VLOG(3) << " load String argv: " << argv[i] << " value is: " << *var_str;
  }
}
......@@ -108,22 +112,24 @@ class CheckpointLoadOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
std::string dir = Attr<std::string>("dir");
std::string serial_num = Attr<std::string>("Serial");
std::string serial_num_attr = Attr<std::string>("Serial");
PADDLE_ENFORCE(IsNumber(serial_num_attr),
"Checkpoint Serial must be a number");
std::string serial_var_name = std::string(SERIAL_VAR);
auto *serial_var = scope.FindVar(serial_var_name);
if (serial_var == nullptr) {
*serial_var = scope.Var(serial_var_name);
auto *serial_tmp = serial_var->GetMutable<std::string>();
serial_tmp->append("0");
}
PADDLE_ENFORCE(serial_var != nullptr,
"Cannot find variable %s for checkpoint_load_op",
serial_var_name);
auto *serial_num = serial_var->GetMutable<std::string>();
VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER
serial_num = serial_num_attr;
VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR
<< " value: " << serial_num;
std::string success = GenePath(dir, serial_num);
std::string success = GenePath(dir, serial_num->c_str());
VLOG(3) << "Load checkpoint from dir: " << success;
success = GenePath(success, SUCCESS);
bool is_present = FileExists(success);
......@@ -137,11 +143,11 @@ class CheckpointLoadOp : public framework::OperatorBase {
auto inp_var_names = Inputs("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0");
LoadInputVars(scope, place, &inp_var_names);
LoadInputVars(scope, place, inp_var_names, dir);
VLOG(3) << "Ready to load string argv to scope";
auto argv = Inputs("Argv");
LoadStringArgv(scope, place, &argv, &dir);
// VLOG(3) << "Ready to load string argv to scope";
// auto argv = Output("Argv");
// LoadStringArgv(scope, place, argv, dir);
}
};
......@@ -153,14 +159,13 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"X",
"(vector) Input LoDTensors that need to be saved together in a file.")
.AsDuplicable();
AddInput(
AddOutput(
"Argv",
"(vector) Input LoDTensors that need to be saved together in a file.")
.AsDuplicable();
"(vector) Input LoDTensors that need to be saved together in a file.");
AddComment(R"DOC(
CheckpointLoad operator
This operator will serialize and write a list of input LoDTensor variables
This operator will serialize and write a list of input LoDTensor variables
to a file on disk.
)DOC");
......@@ -177,10 +182,32 @@ to a file on disk.
}
};
// Variable-type inference for the checkpoint_load operator: marks the first
// variable listed under the "Argv" output as VarType::RAW, creating the
// variable in the block (or resolving it recursively) when needed.
class CheckpointLoadOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    // Only the first "Argv" output name is considered.
    auto argv_name = op_desc.Output("Argv").front();
    auto &argv_var = block->FindRecursiveOrCreateVar(argv_name);
    argv_var.SetType(framework::proto::VarType::RAW);
  }
};
// Shape inference for the checkpoint_load operator. The call operator has an
// empty body, so no shape information is read from or written to `ctx`.
class CheckpointLoadOpShapeInference : public framework::InferShapeBase {
public:
// Intentionally a no-op.
void operator()(framework::InferShapeContext *ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

// Registers checkpoint_load together with its empty grad-op maker, proto
// maker, variable-type inference and shape inference.
REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
                  paddle::framework::EmptyGradOpMaker,
                  ops::CheckpointLoadOpProtoMaker,
                  ops::CheckpointLoadOpVarTypeInference,
                  ops::CheckpointLoadOpShapeInference);

// REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp,
// ops::CheckpointLoadOpProtoMaker);
......@@ -44,7 +44,7 @@ TEST(CheckpointSaveOp, CPU) {
attrs.insert({"dir", std::string("ckpt")});
auto save_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_save", {{"X", {"test_var"}}}, attrs);
"checkpoint_save", {{"X", {"test_var"}}}, {}, attrs);
save_op->Run(scope, place);
}
......@@ -52,13 +52,29 @@ TEST(CheckpointLoadOp, CPU) {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
scope.Var("test_var");
auto var = scope.Var("test_var");
auto tensor = var->GetMutable<paddle::framework::LoDTensor>();
tensor->Resize({3, 10});
paddle::framework::LoD expect_lod;
expect_lod.resize(1);
expect_lod[0].push_back(0);
expect_lod[0].push_back(1);
expect_lod[0].push_back(2);
expect_lod[0].push_back(3);
tensor->set_lod(expect_lod);
float* expect = tensor->mutable_data<float>(place);
for (int64_t i = 0; i < tensor->numel(); ++i) {
expect[i] = static_cast<float>(paddle::platform::float16(i));
}
scope.Var("SERIAL_NUMBER");
paddle::framework::AttributeMap attrs;
attrs.insert({"dir", std::string("ckpt")});
attrs.insert({"Serial", std::string("SERIAL_NUMBER")});
auto load_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_load", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}},
attrs);
"checkpoint_load", {{"X", {"test_var"}}}, {{"Argv", {}}}, attrs);
load_op->Run(scope, place);
}
......@@ -33,12 +33,18 @@ constexpr char kSEP = '/';
const char SUCCESS[] = "_SUCCESS";
const char SERIAL_VAR[] = "SERIAL_NUMBER";
// Returns true when `s` is non-empty and every character is a decimal digit;
// false for the empty string or when any other character is present.
static bool IsNumber(const std::string &s) {
  if (s.empty()) {
    return false;
  }
  for (const char c : s) {
    // Passing a negative char to std::isdigit is undefined behaviour, so the
    // value is widened through unsigned char before classification.
    if (!std::isdigit(static_cast<unsigned char>(c))) {
      return false;
    }
  }
  return true;
}
// Builds the path "<dir>/<file>" by joining the directory and the file name
// with a single '/' separator. No normalisation is performed: an empty `dir`
// yields "/<file>", and separators already present in `dir` are kept.
static std::string GenePath(const std::string &dir, const std::string &file) {
  std::string file_path;
  file_path.append(dir);
  file_path.append("/");
  file_path.append(file);
  return file_path;
}
static bool FileExists(const std::string &filepath) {
......@@ -79,28 +85,24 @@ class CheckpointSaveOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto dir = Attr<std::string>("dir");
auto ck_dir = Attr<std::string>("dir");
auto overwrite = Attr<bool>("overwrite");
std::string serial_var_name = std::string(SERIAL_VAR);
auto *serial_var = scope.FindVar(serial_var_name);
if (serial_var == nullptr) {
*serial_var = scope.Var(serial_var_name);
*serial_tmp = serial_var->GetMutable<std::string>();
serial_tmp->append("0");
}
auto *serial_num = serial_var->GetMutable<std::string>();
VLOG(1) << "CheckpointSaveOp get " << SERIAL_NUMBER
auto *serial_num =
scope.FindVar(serial_var_name)->GetMutable<std::string>();
VLOG(1) << "CheckpointSaveOp get " << SERIAL_VAR
<< " value: " << serial_num;
auto *serial_num = serial_var->GetMutable<std::string>();
serial_num->append("1");
if (!IsNumber(serial_num)) {
serial_num = "0";
}
dir = GenePath(dir, serial_num);
std::string dir = GenePath(ck_dir, serial_num->c_str());
VLOG(1) << "CheckpointSaveOp current dir: " << dir;
bool is_present = FileExists(dir);
if (is_present && !overwrite) {
PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir,
PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir,
overwrite);
}
MkDirRecursively(dir.c_str());
......@@ -150,7 +152,7 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
CheckpointSave operator
This operator will serialize and write a list of input LoDTensor variables
This operator will serialize and write a list of input LoDTensor variables
to a file on disk.
)DOC");
AddAttr<bool>("overwrite",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册