提交 c7a247b7 编写于 作者: L Liu Yiqun

Support loading parameters from an in-memory buffer in the C-API.

上级 fb5cd7f8
...@@ -90,6 +90,18 @@ paddle_error paddle_arguments_set_ids(paddle_arguments args, ...@@ -90,6 +90,18 @@ paddle_error paddle_arguments_set_ids(paddle_arguments args,
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
/// Set the frame shape (height x width) of the `ID`-th argument in the array.
/// Returns kPD_NULLPTR for a null handle and kPD_OUT_OF_RANGE for a bad index.
paddle_error paddle_arguments_set_frame_shape(paddle_arguments args,
                                              uint64_t ID,
                                              uint64_t frameHeight,
                                              uint64_t frameWidth) {
  if (args == nullptr) return kPD_NULLPTR;
  auto argArray = castArg(args);
  if (ID >= argArray->args.size()) return kPD_OUT_OF_RANGE;
  // Update both dimensions of the selected argument in place.
  auto& arg = argArray->args[ID];
  arg.setFrameHeight(frameHeight);
  arg.setFrameWidth(frameWidth);
  return kPD_NO_ERROR;
}
paddle_error paddle_arguments_set_sequence_start_pos(paddle_arguments args, paddle_error paddle_arguments_set_sequence_start_pos(paddle_arguments args,
uint64_t ID, uint64_t ID,
uint32_t nestedLevel, uint32_t nestedLevel,
......
...@@ -111,6 +111,19 @@ PD_API paddle_error paddle_arguments_set_ids(paddle_arguments args, ...@@ -111,6 +111,19 @@ PD_API paddle_error paddle_arguments_set_ids(paddle_arguments args,
uint64_t ID, uint64_t ID,
paddle_ivector ids); paddle_ivector ids);
/**
 * @brief paddle_arguments_set_frame_shape Set the frame size of one argument
 *        in array, which index is `ID`.
 * @param [in] args arguments array
 * @param [in] ID array index
 * @param [in] frameHeight frame height of the argument
 * @param [in] frameWidth frame width of the argument
 * @return paddle_error
 */
PD_API paddle_error paddle_arguments_set_frame_shape(paddle_arguments args,
                                                     uint64_t ID,
                                                     uint64_t frameHeight,
                                                     uint64_t frameWidth);
/** /**
* @brief PDArgsSetSequenceStartPos Set sequence start position vector of one * @brief PDArgsSetSequenceStartPos Set sequence start position vector of one
* argument in array, which index is `ID`. * argument in array, which index is `ID`.
......
...@@ -68,6 +68,15 @@ paddle_error paddle_gradient_machine_load_parameter_from_disk( ...@@ -68,6 +68,15 @@ paddle_error paddle_gradient_machine_load_parameter_from_disk(
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
/// Load all parameters of `machine` from a serialized in-memory buffer of
/// `length` bytes. Returns kPD_NULLPTR when any required pointer is null.
paddle_error paddle_gradient_machine_load_parameter_from_buffer(
    paddle_gradient_machine machine, const char* buf, uint64_t length) {
  auto wrapper = cast(machine);
  bool anyNull =
      (wrapper == nullptr) || (buf == nullptr) || (wrapper->machine == nullptr);
  if (anyNull) {
    return kPD_NULLPTR;
  }
  wrapper->machine->loadParameters(buf, length);
  return kPD_NO_ERROR;
}
paddle_error paddle_gradient_machine_forward(paddle_gradient_machine machine, paddle_error paddle_gradient_machine_forward(paddle_gradient_machine machine,
paddle_arguments inArgs, paddle_arguments inArgs,
paddle_arguments outArgs, paddle_arguments outArgs,
......
...@@ -45,6 +45,15 @@ PD_API paddle_error paddle_gradient_machine_create_for_inference( ...@@ -45,6 +45,15 @@ PD_API paddle_error paddle_gradient_machine_create_for_inference(
PD_API paddle_error paddle_gradient_machine_load_parameter_from_disk( PD_API paddle_error paddle_gradient_machine_load_parameter_from_disk(
paddle_gradient_machine machine, const char* path); paddle_gradient_machine machine, const char* path);
/**
 * @brief Load parameter from buffer.
 * @param [in] machine Gradient Machine.
 * @param [in] buf buffer containing the serialized parameters.
 * @param [in] length number of bytes in `buf`.
 * @return paddle_error
 */
PD_API paddle_error paddle_gradient_machine_load_parameter_from_buffer(
    paddle_gradient_machine machine, const char* buf, uint64_t length);
/** /**
* @brief Forward a gradient machine * @brief Forward a gradient machine
* @param machine Gradient machine * @param machine Gradient machine
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "GradientMachine.h" #include "GradientMachine.h"
#include <string.h>
#include <fstream> #include <fstream>
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
...@@ -81,6 +82,48 @@ void GradientMachine::loadParameters(const std::string& dir) { ...@@ -81,6 +82,48 @@ void GradientMachine::loadParameters(const std::string& dir) {
} }
} }
void GradientMachine::loadParameters(const char* buf, uint64_t length) {
LOG(INFO) << "Loading parameter from pre-load buffer";
CHECK_NOTNULL(buf);
CHECK_GE(length, static_cast<uint64_t>(sizeof(uint64_t)));
uint64_t numFiles = 0;
memcpy(&numFiles, buf, sizeof(uint64_t));
uint64_t position = sizeof(uint64_t);
LOG(INFO) << "numFiles: " << numFiles << ", position: " << position;
std::map<std::string, char*> offsets;
std::map<std::string, uint64_t> lengths;
for (uint64_t i = 0; i < numFiles; i++) {
std::string filename(buf + position);
position += filename.size() + 1;
LOG(INFO) << "filename: " << filename << ", position: " << position;
uint64_t size = 0;
memcpy(&size, buf + position, sizeof(uint64_t));
position += sizeof(uint64_t);
offsets[filename] = const_cast<char*>(buf + position);
lengths[filename] = size;
position += size;
CHECK_GE(length, position);
}
CHECK_GE(offsets.size(), parameters_.size());
for (auto& para : parameters_) {
std::string filename = para->getName();
if (para->isFullSize()) {
if (offsets.end() == offsets.find(filename)) {
para->loadMiss(filename);
} else {
std::istringstream stream(
std::string(offsets[filename], lengths[filename]));
para->load(stream);
}
}
}
}
void GradientMachine::randParameters() { void GradientMachine::randParameters() {
LOG(INFO) << "Initing parameters.."; LOG(INFO) << "Initing parameters..";
......
...@@ -221,6 +221,8 @@ public: ...@@ -221,6 +221,8 @@ public:
void loadParameters(const std::string& dir); void loadParameters(const std::string& dir);
/**
 * Load all parameters from a serialized in-memory buffer of `length` bytes
 * (same binary layout as the merged parameter file on disk).
 */
void loadParameters(const char* buf, uint64_t length);
void randParameters(); void randParameters();
virtual void getStats(real& cost, int64_t& numProcessed) { virtual void getStats(real& cost, int64_t& numProcessed) {
......
...@@ -24,6 +24,8 @@ limitations under the License. */ ...@@ -24,6 +24,8 @@ limitations under the License. */
#include "paddle/gserver/layers/AgentLayer.h" #include "paddle/gserver/layers/AgentLayer.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
#include <iostream>
namespace paddle { namespace paddle {
void parameterInitNN(int paramId, void parameterInitNN(int paramId,
Parameter* para, Parameter* para,
......
...@@ -314,27 +314,31 @@ bool Parameter::save(std::ostream& s) const { ...@@ -314,27 +314,31 @@ bool Parameter::save(std::ostream& s) const {
/** /**
* Load parameter value from a file * Load parameter value from a file
*/ */
/**
 * Handle a parameter whose file could not be found, according to
 * FLAGS_load_missing_parameter_strategy:
 *  - fail: abort via LOG(FATAL)
 *  - rand: fill with random values, return true
 *  - zero: fill with zeros, return true
 * Any other strategy value also aborts.
 */
bool Parameter::loadMiss(const std::string& filename) {
  LOG(INFO) << "missing parameters [" << filename << "] while loading model.";
  if (FLAGS_load_missing_parameter_strategy == kMissParameterFail) {
    LOG(FATAL) << getName() << " missing, not allowed.";
    return false;  // unreachable: LOG(FATAL) aborts
  } else if (FLAGS_load_missing_parameter_strategy == kMissParameterRand) {
    LOG(INFO) << getName() << " missing, set to random.";
    randomize();
    return true;
  } else if (FLAGS_load_missing_parameter_strategy == kMissParameterZero) {
    LOG(INFO) << getName() << " missing, set to zero.";
    zeroMem();
    return true;
  }
  LOG(FATAL) << "unsupported load_missing_parameter_strategy: "
             << FLAGS_load_missing_parameter_strategy;
  return false;  // unreachable: LOG(FATAL) aborts
}
bool Parameter::load(const std::string& filename) { bool Parameter::load(const std::string& filename) {
std::ifstream fs(filename, std::ios_base::binary); std::ifstream fs(filename, std::ios_base::binary);
if (!fs) { if (!fs) {
LOG(INFO) << "missing parameters [" << filename << "] while loading model."; loadMiss(filename);
if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) {
LOG(FATAL) << getName() << " missing, not allowed.";
return false;
}
if (kMissParameterRand == FLAGS_load_missing_parameter_strategy) {
LOG(INFO) << getName() << " missing, set to random.";
randomize();
return true;
}
if (kMissParameterZero == FLAGS_load_missing_parameter_strategy) {
LOG(INFO) << getName() << " missing, set to zero.";
zeroMem();
return true;
}
LOG(FATAL) << "unsupported load_missing_parameter_strategy: "
<< FLAGS_load_missing_parameter_strategy;
return false;
} }
return load(fs); return load(fs);
} }
......
...@@ -201,6 +201,11 @@ public: ...@@ -201,6 +201,11 @@ public:
*/ */
bool save(std::ostream& s) const; bool save(std::ostream& s) const;
/**
 * Fill the parameter when its file is missing, following
 * FLAGS_load_missing_parameter_strategy (fail / rand / zero).
 */
bool loadMiss(const std::string& filename);
/** /**
* Load parameter value from a file * Load parameter value from a file
*/ */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册