提交 acfdc312 编写于 作者: X xzl

support trainconfig and modelconfig of MergedModel

上级 0dc4b298
...@@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters( ...@@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters(
modelConfigProtobuf.resize(modelConfigSize); modelConfigProtobuf.resize(modelConfigSize);
is.read(&modelConfigProtobuf[0], modelConfigSize); is.read(&modelConfigProtobuf[0], modelConfigSize);
paddle::TrainerConfig config; paddle::TrainerConfig config;
paddle::ModelConfig modelConfig;
if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) { if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) {
if (!modelConfig.ParseFromString(modelConfigProtobuf) ||
!modelConfig.IsInitialized()) {
return kPD_PROTOBUF_ERROR; return kPD_PROTOBUF_ERROR;
} }
} else {
modelConfig = config.model_config();
}
auto ptr = new paddle::capi::CGradientMachine(); auto ptr = new paddle::capi::CGradientMachine();
ptr->machine.reset(paddle::GradientMachine::create( ptr->machine.reset(paddle::GradientMachine::create(
config.model_config(), CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE})); modelConfig, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
std::vector<paddle::ParameterPtr>& parameters = ptr->machine->getParameters(); std::vector<paddle::ParameterPtr>& parameters = ptr->machine->getParameters();
for (auto& para : parameters) { for (auto& para : parameters) {
para->load(is); para->load(is);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册