diff --git a/paddle/capi/gradient_machine.cpp b/paddle/capi/gradient_machine.cpp index f7ad30f3bf6e2ba83c099c0129068d1d77c0d628..b3287552db87d25edbf6e7f3d5e68121df49e9d6 100644 --- a/paddle/capi/gradient_machine.cpp +++ b/paddle/capi/gradient_machine.cpp @@ -93,15 +93,6 @@ paddle_error paddle_gradient_machine_load_parameter_from_disk( return kPD_NO_ERROR; } -paddle_error paddle_gradient_machine_load_parameter_from_buffer( - paddle_gradient_machine machine, const char* buf, uint64_t length) { - auto m = cast(machine); - if (m == nullptr || buf == nullptr || m->machine == nullptr) - return kPD_NULLPTR; - m->machine->loadParameters(buf, length); - return kPD_NO_ERROR; -} - paddle_error paddle_gradient_machine_forward(paddle_gradient_machine machine, paddle_arguments inArgs, paddle_arguments outArgs, diff --git a/paddle/capi/gradient_machine.h b/paddle/capi/gradient_machine.h index 2205e0e23aaeb097021a882570f345a7fa0e5ffa..c613ade5b24efbbf52f21c7ee86dd3189981c5ef 100644 --- a/paddle/capi/gradient_machine.h +++ b/paddle/capi/gradient_machine.h @@ -57,15 +57,6 @@ paddle_gradient_machine_create_for_inference_with_parameters( PD_API paddle_error paddle_gradient_machine_load_parameter_from_disk( paddle_gradient_machine machine, const char* path); -/** - * @brief Load parameter from buffer. - * @param machine Gradient Machine. - * @param buffer containing all parameters. - * @return paddle_error - */ -PD_API paddle_error paddle_gradient_machine_load_parameter_from_buffer( - paddle_gradient_machine machine, const char* buf, uint64_t length); - /** * @brief Forward a gradient machine * @param machine Gradient machine diff --git a/paddle/gserver/gradientmachines/GradientMachine.cpp b/paddle/gserver/gradientmachines/GradientMachine.cpp index b7678d9b2f006082037a01654064675865ffc808..b44e4dc202f01956ed21c175aa897ced8e92546b 100644 --- a/paddle/gserver/gradientmachines/GradientMachine.cpp +++ b/paddle/gserver/gradientmachines/GradientMachine.cpp @@ -14,7 +14,6 @@ limitations under the License. */ #include "GradientMachine.h" -#include #include #include "paddle/utils/Logging.h" @@ -82,48 +81,6 @@ void GradientMachine::loadParameters(const std::string& dir) { } } -void GradientMachine::loadParameters(const char* buf, uint64_t length) { - LOG(INFO) << "Loading parameter from pre-load buffer"; - - CHECK_NOTNULL(buf); - CHECK_GE(length, static_cast(sizeof(uint64_t))); - - uint64_t numFiles = 0; - memcpy(&numFiles, buf, sizeof(uint64_t)); - uint64_t position = sizeof(uint64_t); - LOG(INFO) << "numFiles: " << numFiles << ", position: " << position; - - std::map offsets; - std::map lengths; - for (uint64_t i = 0; i < numFiles; i++) { - std::string filename(buf + position); - position += filename.size() + 1; - LOG(INFO) << "filename: " << filename << ", position: " << position; - uint64_t size = 0; - memcpy(&size, buf + position, sizeof(uint64_t)); - position += sizeof(uint64_t); - offsets[filename] = const_cast(buf + position); - lengths[filename] = size; - position += size; - CHECK_GE(length, position); - } - - CHECK_GE(offsets.size(), parameters_.size()); - - for (auto& para : parameters_) { - std::string filename = para->getName(); - if (para->isFullSize()) { - if (offsets.end() == offsets.find(filename)) { - para->loadMiss(filename); - } else { - std::istringstream stream( - std::string(offsets[filename], lengths[filename])); - para->load(stream); - } - } - } -} - void GradientMachine::randParameters() { LOG(INFO) << "Initing parameters.."; diff --git a/paddle/gserver/gradientmachines/GradientMachine.h b/paddle/gserver/gradientmachines/GradientMachine.h index 081518a9d2bbdff120b1c01b4d306fc054281595..f9c82a2bef82b4e6bcbf0c73583505d2692f3926 100644 --- a/paddle/gserver/gradientmachines/GradientMachine.h +++ b/paddle/gserver/gradientmachines/GradientMachine.h @@ -221,8 +221,6 @@ public: void loadParameters(const std::string& dir); - void loadParameters(const char* buf, uint64_t length); - void randParameters(); virtual void getStats(real& cost, int64_t& numProcessed) { diff --git a/paddle/parameter/Parameter.cpp b/paddle/parameter/Parameter.cpp index 80dbb73a7dfe7c43cc7d0b63100e7c93539ea56b..ebe36d49376882fe4c1013e19dcf71f452b3e501 100644 --- a/paddle/parameter/Parameter.cpp +++ b/paddle/parameter/Parameter.cpp @@ -314,31 +314,27 @@ bool Parameter::save(std::ostream& s) const { /** * Load parameter value from a file */ -bool Parameter::loadMiss(const std::string& filename) { - LOG(INFO) << "missing parameters [" << filename << "] while loading model."; - if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) { - LOG(FATAL) << getName() << " missing, not allowed."; - return false; - } - if (kMissParameterRand == FLAGS_load_missing_parameter_strategy) { - LOG(INFO) << getName() << " missing, set to random."; - randomize(); - return true; - } - if (kMissParameterZero == FLAGS_load_missing_parameter_strategy) { - LOG(INFO) << getName() << " missing, set to zero."; - zeroMem(); - return true; - } - LOG(FATAL) << "unsupported load_missing_parameter_strategy: " - << FLAGS_load_missing_parameter_strategy; - return false; -} - bool Parameter::load(const std::string& filename) { std::ifstream fs(filename, std::ios_base::binary); if (!fs) { - loadMiss(filename); + LOG(INFO) << "missing parameters [" << filename << "] while loading model."; + if (kMissParameterFail == FLAGS_load_missing_parameter_strategy) { + LOG(FATAL) << getName() << " missing, not allowed."; + return false; + } + if (kMissParameterRand == FLAGS_load_missing_parameter_strategy) { + LOG(INFO) << getName() << " missing, set to random."; + randomize(); + return true; + } + if (kMissParameterZero == FLAGS_load_missing_parameter_strategy) { + LOG(INFO) << getName() << " missing, set to zero."; + zeroMem(); + return true; + } + LOG(FATAL) << "unsupported load_missing_parameter_strategy: " + << FLAGS_load_missing_parameter_strategy; + return false; } return load(fs); } diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index 21932f6b6eeeb32377e2f25a49c44858c2958da2..0bac76f068ec22bec52766b43e331fe109a34188 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -201,11 +201,6 @@ public: */ bool save(std::ostream& s) const; - /** - * Fill parameter when file is missed - */ - bool loadMiss(const std::string& filename); - /** * Load parameter value from a file */