提交 ab95661b 编写于 作者: J jingqinghe

add program_saver api

上级 b6506624
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
from .fl_job import FLCompileTimeJob
......@@ -198,6 +199,27 @@ class JobGenerator(object):
local_job.set_strategy(fl_strategy)
local_job.save(output)
def save_program(self, program_path, loss):
if not os.path.exists(program_path):
os.makedirs(program_path)
main_program_str = fluid.default_main_program(
).desc.serialize_to_string()
startup_program_str = fluid.default_startup_program(
).desc.serialize_to_string()
params = fluid.default_main_program().global_block().all_parameters()
para_info = []
for pa in params:
para_info.append(pa.name)
with open(program_path + '/para_info', 'w') as fout:
for item in para_info:
fout.write("%s\n" % item)
with open(program_path + '/startup_program', "wb") as fout:
fout.write(startup_program_str)
with open(program_path + '/main_program', "wb") as fout:
fout.write(main_program_str)
with open(program_path + '/loss_name', 'w') as fout:
fout.write(loss.name)
def generate_fl_job_from_program(self, strategy, endpoints, worker_num,
program_input, output):
local_job = FLCompileTimeJob()
......
......@@ -15,6 +15,7 @@
import os
import json
import paddle.fluid as fluid
from paddle_fl.core.master.job_generator import JobGenerator
input = fluid.layers.data(name='input', shape=[1, 28, 28], dtype="float32")
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
......@@ -28,27 +29,7 @@ place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
def save_program(program_path):
if not os.path.exists(program_path):
os.makedirs(program_path)
main_program_str = fluid.default_main_program().desc.serialize_to_string()
startup_program_str = fluid.default_startup_program(
).desc.serialize_to_string()
params = fluid.default_main_program().global_block().all_parameters()
para_info = []
for pa in params:
para_info.append(pa.name)
with open(program_path + '/para_info', 'w') as fout:
for item in para_info:
fout.write("%s\n" % item)
with open(program_path + '/startup_program', "wb") as fout:
fout.write(startup_program_str)
with open(program_path + '/main_program', "wb") as fout:
fout.write(main_program_str)
with open(program_path + '/loss_name', 'w') as fout:
fout.write(avg_cost.name)
job_generator = JobGenerator()
program_path = './load_file'
save_program(program_path)
job_generator.save_program(program_path, avg_cost)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册