Commit 4f1f7e90 authored by Liu Yiqun

Delete the C-API interface paddle_gradient_machine_load_parameter_from_buffer and
related code in the Paddle core.

Parent 9dccdd77
@@ -93,15 +93,6 @@ paddle_error paddle_gradient_machine_load_parameter_from_disk(
   return kPD_NO_ERROR;
 }
 
-paddle_error paddle_gradient_machine_load_parameter_from_buffer(
-    paddle_gradient_machine machine, const char* buf, uint64_t length) {
-  auto m = cast(machine);
-  if (m == nullptr || buf == nullptr || m->machine == nullptr)
-    return kPD_NULLPTR;
-  m->machine->loadParameters(buf, length);
-  return kPD_NO_ERROR;
-}
-
 paddle_error paddle_gradient_machine_forward(paddle_gradient_machine machine,
                                              paddle_arguments inArgs,
                                              paddle_arguments outArgs,
...
@@ -57,15 +57,6 @@ paddle_gradient_machine_create_for_inference_with_parameters(
 PD_API paddle_error paddle_gradient_machine_load_parameter_from_disk(
     paddle_gradient_machine machine, const char* path);
 
-/**
- * @brief Load parameter from buffer.
- * @param machine Gradient Machine.
- * @param buffer containing all parameters.
- * @return paddle_error
- */
-PD_API paddle_error paddle_gradient_machine_load_parameter_from_buffer(
-    paddle_gradient_machine machine, const char* buf, uint64_t length);
-
 /**
  * @brief Forward a gradient machine
  * @param machine Gradient machine
...
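For context, the deleted interface accepted the full parameter pack as an in-memory byte buffer rather than a directory on disk. A caller would have used it roughly as in the sketch below; this is illustrative only, and the umbrella header path and the surrounding setup are assumptions, not part of this commit.

#include <cstdint>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

#include <paddle/capi.h>  // assumed header exposing the Paddle C-API

// Illustrative sketch: slurp a pre-packed parameter buffer from disk and hand
// it to the gradient machine in one call, via the interface this commit
// deletes. `machine` is assumed to have been created elsewhere.
paddle_error loadPackedParameters(paddle_gradient_machine machine,
                                  const std::string& packedPath) {
  std::ifstream fs(packedPath, std::ios::binary);
  if (!fs) return kPD_NULLPTR;  // reuse an existing error code for brevity

  std::vector<char> buf((std::istreambuf_iterator<char>(fs)),
                        std::istreambuf_iterator<char>());
  return paddle_gradient_machine_load_parameter_from_buffer(
      machine, buf.data(), static_cast<uint64_t>(buf.size()));
}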
@@ -14,7 +14,6 @@ limitations under the License. */
 
 #include "GradientMachine.h"
 
-#include <string.h>
 #include <fstream>
 
 #include "paddle/utils/Logging.h"
@@ -82,48 +81,6 @@ void GradientMachine::loadParameters(const std::string& dir) {
   }
 }
 
-void GradientMachine::loadParameters(const char* buf, uint64_t length) {
-  LOG(INFO) << "Loading parameter from pre-load buffer";
-
-  CHECK_NOTNULL(buf);
-  CHECK_GE(length, static_cast<uint64_t>(sizeof(uint64_t)));
-
-  uint64_t numFiles = 0;
-  memcpy(&numFiles, buf, sizeof(uint64_t));
-  uint64_t position = sizeof(uint64_t);
-  LOG(INFO) << "numFiles: " << numFiles << ", position: " << position;
-
-  std::map<std::string, char*> offsets;
-  std::map<std::string, uint64_t> lengths;
-  for (uint64_t i = 0; i < numFiles; i++) {
-    std::string filename(buf + position);
-    position += filename.size() + 1;
-    LOG(INFO) << "filename: " << filename << ", position: " << position;
-    uint64_t size = 0;
-    memcpy(&size, buf + position, sizeof(uint64_t));
-    position += sizeof(uint64_t);
-    offsets[filename] = const_cast<char*>(buf + position);
-    lengths[filename] = size;
-    position += size;
-    CHECK_GE(length, position);
-  }
-
-  CHECK_GE(offsets.size(), parameters_.size());
-
-  for (auto& para : parameters_) {
-    std::string filename = para->getName();
-    if (para->isFullSize()) {
-      if (offsets.end() == offsets.find(filename)) {
-        para->loadMiss(filename);
-      } else {
-        std::istringstream stream(
-            std::string(offsets[filename], lengths[filename]));
-        para->load(stream);
-      }
-    }
-  }
-}
-
 void GradientMachine::randParameters() {
   LOG(INFO) << "Initing parameters..";
...
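The overload removed above parsed a simple packed layout: a uint64_t file count, followed by, for each parameter, its '\0'-terminated name, a uint64_t byte size, and the raw bytes. A minimal packer that produces a buffer in that layout might look like the following; this is a sketch reconstructed from the deleted parsing code, not code from the repository.

#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Sketch: pack named parameter blobs into the layout the removed
// GradientMachine::loadParameters(const char* buf, uint64_t length) expected:
//   [uint64_t numFiles]
//   numFiles times: [name]['\0'][uint64_t size][size raw bytes]
std::vector<char> packParameters(const std::map<std::string, std::string>& params) {
  std::vector<char> buf;
  auto append = [&buf](const void* data, std::size_t n) {
    const char* p = static_cast<const char*>(data);
    buf.insert(buf.end(), p, p + n);
  };

  uint64_t numFiles = params.size();
  append(&numFiles, sizeof(numFiles));
  for (const auto& kv : params) {
    append(kv.first.c_str(), kv.first.size() + 1);  // name plus trailing '\0'
    uint64_t size = kv.second.size();
    append(&size, sizeof(size));
    append(kv.second.data(), size);  // raw serialized parameter bytes
  }
  return buf;
}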
@@ -221,8 +221,6 @@ public:
 
   void loadParameters(const std::string& dir);
 
-  void loadParameters(const char* buf, uint64_t length);
-
   void randParameters();
 
   virtual void getStats(real& cost, int64_t& numProcessed) {
...
@@ -314,31 +314,27 @@ bool Parameter::save(std::ostream& s) const {
 /**
  * Load parameter value from a file
  */
-bool Parameter::loadMiss(const std::string& filename) {
-  LOG(INFO) << "missing parameters [" << filename << "] while loading model.";
-  if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) {
-    LOG(FATAL) << getName() << " missing, not allowed.";
-    return false;
-  }
-  if (kMissParameterRand == FLAGS_load_missing_parameter_strategy) {
-    LOG(INFO) << getName() << " missing, set to random.";
-    randomize();
-    return true;
-  }
-  if (kMissParameterZero == FLAGS_load_missing_parameter_strategy) {
-    LOG(INFO) << getName() << " missing, set to zero.";
-    zeroMem();
-    return true;
-  }
-  LOG(FATAL) << "unsupported load_missing_parameter_strategy: "
-             << FLAGS_load_missing_parameter_strategy;
-  return false;
-}
-
 bool Parameter::load(const std::string& filename) {
   std::ifstream fs(filename, std::ios_base::binary);
   if (!fs) {
-    loadMiss(filename);
+    LOG(INFO) << "missing parameters [" << filename << "] while loading model.";
+    if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) {
+      LOG(FATAL) << getName() << " missing, not allowed.";
+      return false;
+    }
+    if (kMissParameterRand == FLAGS_load_missing_parameter_strategy) {
+      LOG(INFO) << getName() << " missing, set to random.";
+      randomize();
+      return true;
+    }
+    if (kMissParameterZero == FLAGS_load_missing_parameter_strategy) {
+      LOG(INFO) << getName() << " missing, set to zero.";
+      zeroMem();
+      return true;
+    }
+    LOG(FATAL) << "unsupported load_missing_parameter_strategy: "
+               << FLAGS_load_missing_parameter_strategy;
+    return false;
   }
   return load(fs);
 }
...
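With loadMiss() gone, the missing-file policy is applied entirely inside Parameter::load(), keyed on FLAGS_load_missing_parameter_strategy. The dispatch it performs can be summarized by the small stand-alone sketch below; the "fail"/"rand"/"zero" spellings are assumptions suggested by the enum names, since the actual flag definition is not part of this diff.

#include <stdexcept>
#include <string>

enum MissParameterStrategy { kMissParameterFail, kMissParameterRand, kMissParameterZero };

// Sketch of the policy Parameter::load() now applies inline when a parameter
// file cannot be opened. The string spellings are assumptions.
MissParameterStrategy parseStrategy(const std::string& s) {
  if (s == "fail") return kMissParameterFail;  // abort: the parameter is required
  if (s == "rand") return kMissParameterRand;  // fall back to randomize()
  if (s == "zero") return kMissParameterZero;  // fall back to zeroMem()
  throw std::invalid_argument("unsupported load_missing_parameter_strategy: " + s);
}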
@@ -201,11 +201,6 @@ public:
   */
  bool save(std::ostream& s) const;
 
-  /**
-   * Fill parameter when file is missed
-   */
-  bool loadMiss(const std::string& filename);
-
  /**
   * Load parameter value from a file
   */
...