diff --git a/paddle/capi/Arguments.cpp b/paddle/capi/Arguments.cpp index 8b81ec69e60399af86f055d2258276ac06e0b13a..1ec403077e7ea0bc8299e6266167b50ed81c3b08 100644 --- a/paddle/capi/Arguments.cpp +++ b/paddle/capi/Arguments.cpp @@ -90,6 +90,18 @@ paddle_error paddle_arguments_set_ids(paddle_arguments args, return kPD_NO_ERROR; } +paddle_error paddle_arguments_set_frame_shape(paddle_arguments args, + uint64_t ID, + uint64_t frameHeight, + uint64_t frameWidth) { + if (args == nullptr) return kPD_NULLPTR; + auto a = castArg(args); + if (ID >= a->args.size()) return kPD_OUT_OF_RANGE; + a->args[ID].setFrameHeight(frameHeight); + a->args[ID].setFrameWidth(frameWidth); + return kPD_NO_ERROR; +} + paddle_error paddle_arguments_set_sequence_start_pos(paddle_arguments args, uint64_t ID, uint32_t nestedLevel, diff --git a/paddle/capi/arguments.h b/paddle/capi/arguments.h index d71ea26a5d1aff130d974541532fda3b09bf6fe5..ba49d692ad19accdcc84ba8e4c44ab23a1d05ac1 100644 --- a/paddle/capi/arguments.h +++ b/paddle/capi/arguments.h @@ -111,6 +111,19 @@ PD_API paddle_error paddle_arguments_set_ids(paddle_arguments args, uint64_t ID, paddle_ivector ids); +/** + * @brief paddle_arguments_set_frame_shape Set the fram size of one argument + * in array, which index is `ID`. + * @param [in] args arguments array + * @param [in] ID array index + * @param [out] ids integer vector pointer + * @return paddle_error + */ +PD_API paddle_error paddle_arguments_set_frame_shape(paddle_arguments args, + uint64_t ID, + uint64_t frameHeight, + uint64_t frameWidth); + /** * @brief PDArgsSetSequenceStartPos Set sequence start position vector of one * argument in array, which index is `ID`. diff --git a/paddle/capi/gradient_machine.cpp b/paddle/capi/gradient_machine.cpp index 00f76e0152366834eafc22df710cf3d6c7b8471f..e2d2d30ddcd205c26e64243fcf1c13642f19277e 100644 --- a/paddle/capi/gradient_machine.cpp +++ b/paddle/capi/gradient_machine.cpp @@ -68,6 +68,15 @@ paddle_error paddle_gradient_machine_load_parameter_from_disk( return kPD_NO_ERROR; } +paddle_error paddle_gradient_machine_load_parameter_from_buffer( + paddle_gradient_machine machine, const char* buf, uint64_t length) { + auto m = cast(machine); + if (m == nullptr || buf == nullptr || m->machine == nullptr) + return kPD_NULLPTR; + m->machine->loadParameters(buf, length); + return kPD_NO_ERROR; +} + paddle_error paddle_gradient_machine_forward(paddle_gradient_machine machine, paddle_arguments inArgs, paddle_arguments outArgs, diff --git a/paddle/capi/gradient_machine.h b/paddle/capi/gradient_machine.h index d7e2dd9bf8037ed474971624d4518160604abe4d..242683905080326c44890cc411895c6951c89a0c 100644 --- a/paddle/capi/gradient_machine.h +++ b/paddle/capi/gradient_machine.h @@ -45,6 +45,15 @@ PD_API paddle_error paddle_gradient_machine_create_for_inference( PD_API paddle_error paddle_gradient_machine_load_parameter_from_disk( paddle_gradient_machine machine, const char* path); +/** + * @brief Load parameter from buffer. + * @param machine Gradient Machine. + * @param buffer containing all parameters. + * @return paddle_error + */ +PD_API paddle_error paddle_gradient_machine_load_parameter_from_buffer( + paddle_gradient_machine machine, const char* buf, uint64_t length); + /** * @brief Forward a gradient machine * @param machine Gradient machine diff --git a/paddle/gserver/gradientmachines/GradientMachine.cpp b/paddle/gserver/gradientmachines/GradientMachine.cpp index b44e4dc202f01956ed21c175aa897ced8e92546b..b7678d9b2f006082037a01654064675865ffc808 100644 --- a/paddle/gserver/gradientmachines/GradientMachine.cpp +++ b/paddle/gserver/gradientmachines/GradientMachine.cpp @@ -14,6 +14,7 @@ limitations under the License. */ #include "GradientMachine.h" +#include #include #include "paddle/utils/Logging.h" @@ -81,6 +82,48 @@ void GradientMachine::loadParameters(const std::string& dir) { } } +void GradientMachine::loadParameters(const char* buf, uint64_t length) { + LOG(INFO) << "Loading parameter from pre-load buffer"; + + CHECK_NOTNULL(buf); + CHECK_GE(length, static_cast(sizeof(uint64_t))); + + uint64_t numFiles = 0; + memcpy(&numFiles, buf, sizeof(uint64_t)); + uint64_t position = sizeof(uint64_t); + LOG(INFO) << "numFiles: " << numFiles << ", position: " << position; + + std::map offsets; + std::map lengths; + for (uint64_t i = 0; i < numFiles; i++) { + std::string filename(buf + position); + position += filename.size() + 1; + LOG(INFO) << "filename: " << filename << ", position: " << position; + uint64_t size = 0; + memcpy(&size, buf + position, sizeof(uint64_t)); + position += sizeof(uint64_t); + offsets[filename] = const_cast(buf + position); + lengths[filename] = size; + position += size; + CHECK_GE(length, position); + } + + CHECK_GE(offsets.size(), parameters_.size()); + + for (auto& para : parameters_) { + std::string filename = para->getName(); + if (para->isFullSize()) { + if (offsets.end() == offsets.find(filename)) { + para->loadMiss(filename); + } else { + std::istringstream stream( + std::string(offsets[filename], lengths[filename])); + para->load(stream); + } + } + } +} + void GradientMachine::randParameters() { LOG(INFO) << "Initing parameters.."; diff --git a/paddle/gserver/gradientmachines/GradientMachine.h b/paddle/gserver/gradientmachines/GradientMachine.h index f9c82a2bef82b4e6bcbf0c73583505d2692f3926..081518a9d2bbdff120b1c01b4d306fc054281595 100644 --- a/paddle/gserver/gradientmachines/GradientMachine.h +++ b/paddle/gserver/gradientmachines/GradientMachine.h @@ -221,6 +221,8 @@ public: void loadParameters(const std::string& dir); + void loadParameters(const char* buf, uint64_t length); + void randParameters(); virtual void getStats(real& cost, int64_t& numProcessed) { diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index cfa80a89365af5111746eec9599d16e37532a9f7..148296d20bdaeece2194727a2111d7fc6cb5ed55 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -24,6 +24,8 @@ limitations under the License. */ #include "paddle/gserver/layers/AgentLayer.h" #include "paddle/utils/Stat.h" +#include + namespace paddle { void parameterInitNN(int paramId, Parameter* para, diff --git a/paddle/parameter/Parameter.cpp b/paddle/parameter/Parameter.cpp index ebe36d49376882fe4c1013e19dcf71f452b3e501..80dbb73a7dfe7c43cc7d0b63100e7c93539ea56b 100644 --- a/paddle/parameter/Parameter.cpp +++ b/paddle/parameter/Parameter.cpp @@ -314,27 +314,31 @@ bool Parameter::save(std::ostream& s) const { /** * Load parameter value from a file */ +bool Parameter::loadMiss(const std::string& filename) { + LOG(INFO) << "missing parameters [" << filename << "] while loading model."; + if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) { + LOG(FATAL) << getName() << " missing, not allowed."; + return false; + } + if (kMissParameterRand == FLAGS_load_missing_parameter_strategy) { + LOG(INFO) << getName() << " missing, set to random."; + randomize(); + return true; + } + if (kMissParameterZero == FLAGS_load_missing_parameter_strategy) { + LOG(INFO) << getName() << " missing, set to zero."; + zeroMem(); + return true; + } + LOG(FATAL) << "unsupported load_missing_parameter_strategy: " + << FLAGS_load_missing_parameter_strategy; + return false; +} + bool Parameter::load(const std::string& filename) { std::ifstream fs(filename, std::ios_base::binary); if (!fs) { - LOG(INFO) << "missing parameters [" << filename << "] while loading model."; - if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) { - LOG(FATAL) << getName() << " missing, not allowed."; - return false; - } - if (kMissParameterRand == FLAGS_load_missing_parameter_strategy) { - LOG(INFO) << getName() << " missing, set to random."; - randomize(); - return true; - } - if (kMissParameterZero == FLAGS_load_missing_parameter_strategy) { - LOG(INFO) << getName() << " missing, set to zero."; - zeroMem(); - return true; - } - LOG(FATAL) << "unsupported load_missing_parameter_strategy: " - << FLAGS_load_missing_parameter_strategy; - return false; + loadMiss(filename); } return load(fs); } diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index 0bac76f068ec22bec52766b43e331fe109a34188..21932f6b6eeeb32377e2f25a49c44858c2958da2 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -201,6 +201,11 @@ public: */ bool save(std::ostream& s) const; + /** + * Fill parameter when file is missed + */ + bool loadMiss(const std::string& filename); + /** * Load parameter value from a file */