From acfdc312f903e5cfb02843ee82487443ec5e0a92 Mon Sep 17 00:00:00 2001 From: xzl Date: Wed, 25 Oct 2017 17:34:20 +0800 Subject: [PATCH] support trainconfig and modelconfig of MergedModel --- paddle/capi/gradient_machine.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/paddle/capi/gradient_machine.cpp b/paddle/capi/gradient_machine.cpp index 629449bbd49..482b51e8a84 100644 --- a/paddle/capi/gradient_machine.cpp +++ b/paddle/capi/gradient_machine.cpp @@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters( modelConfigProtobuf.resize(modelConfigSize); is.read(&modelConfigProtobuf[0], modelConfigSize); paddle::TrainerConfig config; + paddle::ModelConfig modelConfig; if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) { - return kPD_PROTOBUF_ERROR; + if (!modelConfig.ParseFromString(modelConfigProtobuf) || + !modelConfig.IsInitialized()) { + return kPD_PROTOBUF_ERROR; + } + } else { + modelConfig = config.model_config(); } auto ptr = new paddle::capi::CGradientMachine(); ptr->machine.reset(paddle::GradientMachine::create( - config.model_config(), CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE})); + modelConfig, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE})); std::vector& parameters = ptr->machine->getParameters(); for (auto& para : parameters) { para->load(is); -- GitLab