提交 4273b351 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Merge pull request #4473 from NHZlX/fix_merge_model

refine paddle_merge_model
......@@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters(
modelConfigProtobuf.resize(modelConfigSize);
is.read(&modelConfigProtobuf[0], modelConfigSize);
paddle::TrainerConfig config;
paddle::ModelConfig modelConfig;
if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) {
if (!modelConfig.ParseFromString(modelConfigProtobuf) ||
!modelConfig.IsInitialized()) {
return kPD_PROTOBUF_ERROR;
}
} else {
modelConfig = config.model_config();
}
auto ptr = new paddle::capi::CGradientMachine();
ptr->machine.reset(paddle::GradientMachine::create(
config.model_config(), CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
modelConfig, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
std::vector<paddle::ParameterPtr>& parameters = ptr->machine->getParameters();
for (auto& para : parameters) {
para->load(is);
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/utils/PythonUtil.h"
DEFINE_string(model_dir, "", "Directory for separated model files");
DEFINE_string(config_file, "", "Config file for the model");
DEFINE_string(model_file, "", "File for merged model file");
using namespace paddle; // NOLINT
......@@ -28,7 +29,8 @@ using namespace std; // NOLINT
int main(int argc, char** argv) {
initMain(argc, argv);
initPython(argc, argv);
string confFile = TrainerConfigHelper::getConfigNameFromPath(FLAGS_model_dir);
string confFile = FLAGS_config_file;
#ifndef PADDLE_WITH_CUDA
FLAGS_use_gpu = false;
#endif
......
......@@ -19,7 +19,7 @@ import "ModelConfig.proto";
package paddle;
message OptimizationConfig {
required int32 batch_size = 3;
optional int32 batch_size = 3 [ default = 1 ];
required string algorithm = 4 [ default = "async_sgd" ];
optional int32 num_batches_per_send_parameter = 5 [ default = 1 ];
optional int32 num_batches_per_get_parameter = 6 [ default = 1 ];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册