Unverified commit db412585, authored by Shibo Tao, committed by GitHub

add API serialize_program, serialize_persistables, save_to_file, deserialize_program, deserialize_persistables, load_from_file. (#29034)
Parent 14013a2e
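A minimal end-to-end sketch of the new API surface, adapted from the docstrings added below (paths are illustrative; assumes a build containing this commit):

    import paddle

    paddle.enable_static()
    image = paddle.static.data(name='img', shape=[None, 784], dtype='float32')
    predict = paddle.static.nn.fc(image, 10, activation='softmax')
    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(paddle.static.default_startup_program())

    # program/parameters -> bytes -> file
    paddle.static.io.save_to_file(
        "./infer.pdmodel", paddle.static.serialize_program([image], [predict]))
    paddle.static.io.save_to_file(
        "./infer.pdiparams",
        paddle.static.serialize_persistables([image], [predict], exe))

    # file -> bytes -> program/parameters
    program = paddle.static.deserialize_program(
        paddle.static.io.load_from_file("./infer.pdmodel"))
    paddle.static.deserialize_persistables(
        program, paddle.static.io.load_from_file("./infer.pdiparams"), exe)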
@@ -281,7 +281,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
PADDLE_ENFORCE_EQ(
version, 0U,
platform::errors::InvalidArgument(
"Tensor version %u is not supported, only version 0 is supported.",
"Deserialize to tensor failed, maybe the loaded file is "
"not a paddle model(expected file format: 0, but %u found).",
version));
}
{
@@ -307,7 +308,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
PADDLE_ENFORCE_EQ(
version, 0U,
platform::errors::InvalidArgument(
"Tensor version %u is not supported, only version 0 is supported.",
"Deserialize to tensor failed, maybe the loaded file is "
"not a paddle model(expected file format: 0, but %u found).",
version));
}
{
@@ -226,8 +226,8 @@ class TestSaveInferenceModelNew(unittest.TestCase):
'y': tensor_y},
fetch_list=[avg_cost])
-        self.assertRaises(ValueError, paddle.static.save_inference_model,
-                          None, ['x', 'y'], [avg_cost], exe)
+        self.assertRaises(ValueError, paddle.static.save_inference_model, None,
+                          ['x', 'y'], [avg_cost], exe)
self.assertRaises(ValueError, paddle.static.save_inference_model,
MODEL_DIR + "/", [x, y], [avg_cost], exe)
self.assertRaises(ValueError, paddle.static.save_inference_model,
@@ -251,7 +251,8 @@ class TestSaveInferenceModelNew(unittest.TestCase):
MODEL_DIR + "_isdir", [x, y], [avg_cost], exe)
os.rmdir(params_path)
-        paddle.static.io.save_inference_model(MODEL_DIR, [x, y], [avg_cost], exe)
+        paddle.static.io.save_inference_model(MODEL_DIR, [x, y], [avg_cost],
+                                              exe)
self.assertTrue(os.path.exists(MODEL_DIR + ".pdmodel"))
self.assertTrue(os.path.exists(MODEL_DIR + ".pdiparams"))
@@ -263,20 +264,34 @@ class TestSaveInferenceModelNew(unittest.TestCase):
six.moves.reload_module(executor) # reload to build a new scope
-        self.assertRaises(ValueError, paddle.static.load_inference_model,
-                          None, exe)
+        self.assertRaises(ValueError, paddle.static.load_inference_model, None,
+                          exe)
self.assertRaises(ValueError, paddle.static.load_inference_model,
MODEL_DIR + "/", exe)
self.assertRaises(ValueError, paddle.static.load_inference_model,
[MODEL_DIR], exe)
-        self.assertRaises(ValueError, paddle.static.load_inference_model,
-                          MODEL_DIR, exe, pserver_endpoints=None)
-        self.assertRaises(ValueError, paddle.static.load_inference_model,
-                          MODEL_DIR, exe, unsupported_param=None)
-        self.assertRaises((TypeError, ValueError), paddle.static.load_inference_model,
-                          None, exe, model_filename="illegal", params_filename="illegal")
-        model = InferModel(paddle.static.io.load_inference_model(MODEL_DIR, exe))
+        self.assertRaises(
+            ValueError,
+            paddle.static.load_inference_model,
+            MODEL_DIR,
+            exe,
+            pserver_endpoints=None)
+        self.assertRaises(
+            ValueError,
+            paddle.static.load_inference_model,
+            MODEL_DIR,
+            exe,
+            unsupported_param=None)
+        self.assertRaises(
+            (TypeError, ValueError),
+            paddle.static.load_inference_model,
+            None,
+            exe,
+            model_filename="illegal",
+            params_filename="illegal")
+        model = InferModel(
+            paddle.static.io.load_inference_model(MODEL_DIR, exe))
outs = exe.run(model.program,
feed={
@@ -289,7 +304,57 @@ class TestSaveInferenceModelNew(unittest.TestCase):
self.assertEqual(model.feed_var_names, ["x", "y"])
self.assertEqual(len(model.fetch_vars), 1)
self.assertEqual(expected, actual)
# test save_to_file content type should be bytes
self.assertRaises(ValueError, paddle.static.io.save_to_file, '', 123)
# test _get_valid_program
self.assertRaises(TypeError, paddle.static.io._get_valid_program, 0)
p = Program()
cp = CompiledProgram(p)
paddle.static.io._get_valid_program(cp)
self.assertTrue(paddle.static.io._get_valid_program(cp) is p)
cp._program = None
self.assertRaises(TypeError, paddle.static.io._get_valid_program, cp)
def test_serialize_program_and_persistables(self):
init_program = fluid.default_startup_program()
program = fluid.default_main_program()
# fake program without feed/fetch
with program_guard(program, init_program):
x = layers.data(name='x', shape=[2], dtype='float32')
y = layers.data(name='y', shape=[1], dtype='float32')
y_predict = layers.fc(input=x, size=1, act=None)
cost = layers.square_error_cost(input=y_predict, label=y)
avg_cost = layers.mean(cost)
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost, init_program)
place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
tensor_x = np.array([[1, 1], [1, 2], [5, 2]]).astype("float32")
tensor_y = np.array([[-2], [-3], [-7]]).astype("float32")
for i in six.moves.xrange(3):
exe.run(program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])
# test if return type of serialize_program is bytes
res1 = paddle.static.io.serialize_program([x, y], [avg_cost])
self.assertTrue(isinstance(res1, bytes))
# test if return type of serialize_persistables is bytes
res2 = paddle.static.io.serialize_persistables([x, y], [avg_cost], exe)
self.assertTrue(isinstance(res2, bytes))
        # test when there is no variable in the program
res = paddle.static.io._serialize_persistables(Program(), None)
self.assertEqual(res, None)
self.assertRaises(TypeError, paddle.static.io.deserialize_persistables,
None, None, None)
class TestLoadInferenceModelError(unittest.TestCase):
@@ -18,28 +18,43 @@ import errno
import inspect
import logging
import os
import warnings
import six
import numpy as np
import paddle
-from paddle.fluid import core, Variable, CompiledProgram, program_guard, default_main_program, Program
-from paddle.fluid.framework import static_only
-from paddle.fluid import layers
-from paddle.fluid.io import _get_valid_program, save_vars, _save_distributed_persistables
-from paddle.fluid.io import prepend_feed_ops, append_fetch_ops, save_persistables
-from paddle.fluid.io import load_persistables, _endpoints_replacement
+from paddle.fluid import (
+    core,
+    Variable,
+    CompiledProgram,
+    default_main_program,
+    Program,
+    layers,
+    unique_name,
+    program_guard, )
+from paddle.fluid.io import prepend_feed_ops, append_fetch_ops
+from paddle.fluid.framework import static_only, Parameter
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.log_helper import get_logger
__all__ = [
'save_inference_model',
'load_inference_model',
'serialize_program',
'serialize_persistables',
'save_to_file',
'deserialize_program',
'deserialize_persistables',
'load_from_file',
]
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
-def _check_args(caller, args, supported_args=[], deprecated_args=[]):
+def _check_args(caller, args, supported_args=None, deprecated_args=None):
+    supported_args = [] if supported_args is None else supported_args
+    deprecated_args = [] if deprecated_args is None else deprecated_args
for arg in args:
if arg in deprecated_args:
raise ValueError(
@@ -51,6 +66,319 @@ def _check_args(caller, args, supported_args=[], deprecated_args=[]):
format(caller, arg, supported_args))
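# Illustrative aside, not part of this module: the signature change above is
# the standard fix for Python's shared-mutable-default pitfall. A default
# list is created once at definition time and reused across calls, so any
# mutation would leak between calls. The hypothetical helper below shows it.
def _mutable_default_demo(arg, seen=[]):
    seen.append(arg)
    return seen

assert _mutable_default_demo('a') == ['a']
assert _mutable_default_demo('b') == ['a', 'b']  # previous call's state leaked
# Defaulting to None and building a fresh list inside the body, as
# _check_args now does, gives every call an independent list. Here the lists
# are only read, so the change is defensive rather than a bug fix.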
def _check_vars(name, var_list):
if not isinstance(var_list, list):
var_list = [var_list]
if not var_list or not all([isinstance(var, Variable) for var in var_list]):
raise ValueError(
"'{}' should be a Variable or a list of Variable.".format(name))
def _normalize_path_prefix(path_prefix):
"""
convert path_prefix to absolute path.
"""
if not isinstance(path_prefix, six.string_types):
raise ValueError("'path_prefix' should be a string.")
if path_prefix.endswith("/"):
raise ValueError("'path_prefix' should not be a directory")
path_prefix = os.path.normpath(path_prefix)
path_prefix = os.path.abspath(path_prefix)
return path_prefix
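# Behavior sketch for _normalize_path_prefix (results assume a POSIX working
# directory of /home/user; values are illustrative only):
#
#     _normalize_path_prefix("./models/../infer")  # -> "/home/user/infer"
#     _normalize_path_prefix("infer/")             # raises ValueError (directory)
#     _normalize_path_prefix(123)                  # raises ValueError (not a str)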
def _get_valid_program(program=None):
"""
return default main program if program is None.
"""
if program is None:
program = default_main_program()
elif isinstance(program, CompiledProgram):
program = program._program
if program is None:
raise TypeError(
"The type of input program is invalid, expected tyep is Program, but received None"
)
warnings.warn(
"The input is a CompiledProgram, this is not recommended.")
if not isinstance(program, Program):
raise TypeError(
"The type of input program is invalid, expected type is fluid.Program, but received %s"
% type(program))
return program
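# Behavior sketch, mirroring the new unit test: a CompiledProgram is unwrapped
# back to its underlying Program (with a warning), and anything else raises.
#
#     p = Program()
#     cp = CompiledProgram(p)
#     assert _get_valid_program(cp) is p   # unwrapped, warns
#     _get_valid_program(0)                # raises TypeError
#     cp._program = None
#     _get_valid_program(cp)               # raises TypeError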
def _clone_var_in_block(block, var):
assert isinstance(var, Variable)
if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR:
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=True)
else:
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=True)
def _normalize_program(program, feed_vars, fetch_vars):
"""
    Optimize the program according to feed_vars and fetch_vars.
    """
    # remind users to set auc_states to 0 if an auc op is found.
for op in program.global_block().ops:
# clear device of Op
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
op._set_attr(device_attr_name, "")
if op.type == 'auc':
warnings.warn("Be sure that you have set auc states to 0 "
"before saving inference model.")
break
    # fix the bug that the activation op's output as target will be pruned,
    # which will affect the inference performance.
# TODO(Superjomn) add an IR pass to remove 1-scale op.
with program_guard(program):
uniq_fetch_vars = []
for i, var in enumerate(fetch_vars):
var = layers.scale(
var, 1., name="save_infer_model/scale_{}".format(i))
uniq_fetch_vars.append(var)
fetch_vars = uniq_fetch_vars
# serialize program
copy_program = program.clone()
global_block = copy_program.global_block()
remove_op_idx = []
for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch":
remove_op_idx.append(i)
for idx in remove_op_idx[::-1]:
global_block._remove_op(idx)
copy_program.desc.flush()
feed_var_names = [var.name for var in feed_vars]
copy_program = copy_program._prune_with_input(
feeded_var_names=feed_var_names, targets=fetch_vars)
copy_program = copy_program._inference_optimize(prune_read_op=True)
fetch_var_names = [var.name for var in fetch_vars]
prepend_feed_ops(copy_program, feed_var_names)
append_fetch_ops(copy_program, fetch_var_names)
copy_program.desc._set_version()
return copy_program
def is_persistable(var):
"""
Check whether the given variable is persistable.
Args:
var(Variable): The variable to be checked.
Returns:
bool: True if the given `var` is persistable
False if not.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
paddle.enable_static()
param = fluid.default_main_program().global_block().var('fc.b')
res = fluid.io.is_persistable(param)
"""
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.READER:
return False
return var.persistable
@static_only
def serialize_program(feed_vars, fetch_vars):
"""
:api_attr: Static Graph
Serialize default main program according to feed_vars and fetch_vars.
Args:
feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
Returns:
bytes: serialized program.
Raises:
ValueError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
ValueError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
            # User defined network, here a softmax regression example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize the default main program to bytes.
serialized_program = paddle.static.serialize_program([image], [predict])
# deserialize bytes to program
deserialized_program = paddle.static.deserialize_program(serialized_program)
"""
# verify feed_vars
_check_vars('feed_vars', feed_vars)
# verify fetch_vars
_check_vars('fetch_vars', fetch_vars)
program = _get_valid_program()
program = _normalize_program(program, feed_vars, fetch_vars)
return _serialize_program(program)
def _serialize_program(program):
"""
serialize given program to bytes.
"""
return program.desc.serialize_to_string()
@static_only
def serialize_persistables(feed_vars, fetch_vars, executor):
"""
:api_attr: Static Graph
Serialize parameters using given executor and default main program according to feed_vars and fetch_vars.
    Args:
        feed_vars(Variable | list[Variable]): Variables needed by inference.
        fetch_vars(Variable | list[Variable]): Variables returned by inference.
        executor(Executor): The executor that runs the save op to serialize parameters.
    Returns:
        bytes: serialized parameters.
Raises:
ValueError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
ValueError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
            # User defined network, here a softmax regression example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# serialize parameters to bytes.
serialized_params = paddle.static.serialize_persistables([image], [predict], exe)
# deserialize bytes to parameters.
main_program = paddle.static.default_main_program()
deserialized_params = paddle.static.deserialize_persistables(main_program, serialized_params, exe)
"""
# verify feed_vars
_check_vars('feed_vars', feed_vars)
# verify fetch_vars
_check_vars('fetch_vars', fetch_vars)
program = _get_valid_program()
program = _normalize_program(program, feed_vars, fetch_vars)
return _serialize_persistables(program, executor)
def _serialize_persistables(program, executor):
"""
Serialize parameters using given program and executor.
"""
vars_ = list(filter(is_persistable, program.list_vars()))
# warn if no variable found in model
if len(vars_) == 0:
warnings.warn("no variable in your model, please ensure there are any "
"variables in your model to save")
return None
# create a new program and clone persitable vars to it
save_program = Program()
save_block = save_program.global_block()
save_var_map = {}
for var in vars_:
if var.type != core.VarDesc.VarType.RAW:
var_copy = _clone_var_in_block(save_block, var)
save_var_map[var_copy.name] = var
# create in_vars and out_var, then append a save_combine op to save_program
in_vars = []
for name in sorted(save_var_map.keys()):
in_vars.append(save_var_map[name])
out_var_name = unique_name.generate("out_var")
out_var = save_block.create_var(
type=core.VarDesc.VarType.RAW, name=out_var_name)
out_var.desc.set_persistable(True)
save_block.append_op(
type='save_combine',
inputs={'X': in_vars},
outputs={'Y': out_var},
attrs={'file_path': '',
'save_to_memory': True})
# run save_program to save vars
# NOTE(zhiqiu): save op will add variable kLookupTablePath to save_program.desc,
# which leads to diff between save_program and its desc. Call _sync_with_cpp
# to keep consistency.
save_program._sync_with_cpp()
executor.run(save_program)
# return serialized bytes in out_var
return global_scope().find_var(out_var_name).get_bytes()
def save_to_file(path, content):
"""
Save content to given path.
Args:
path(str): Path to write content to.
content(bytes): Content to write.
Returns:
None
"""
if not isinstance(content, bytes):
raise ValueError("'content' type should be bytes.")
with open(path, "wb") as f:
f.write(content)
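# Usage sketch: save_to_file accepts only bytes, which is what the
# serialize_* helpers above return.
#
#     program_bytes = serialize_program([image], [predict])
#     save_to_file("./infer_model.pdmodel", program_bytes)
#     save_to_file("./infer_model.pdmodel", "not bytes")  # raises ValueError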
@static_only
def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
"""
@@ -106,13 +434,9 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
# and parameters are going to be saved in file "./infer_model.pdiparams".
"""
# check path_prefix, set model_path and params_path
-    if not isinstance(path_prefix, six.string_types):
-        raise ValueError("'path_prefix' should be a string.")
-    if path_prefix.endswith("/"):
-        raise ValueError("'path_prefix' should not be a directory")
-    path_prefix = os.path.normpath(path_prefix)
-    path_prefix = os.path.abspath(path_prefix)
+    path_prefix = _normalize_path_prefix(path_prefix)
try:
# mkdir may conflict if pserver and trainer are running on the same machine
dirname = os.path.dirname(path_prefix)
@@ -128,74 +452,118 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor):
raise ValueError("'{}' is an existing directory.".format(params_path))
# verify feed_vars
-    if not isinstance(feed_vars, list):
-        feed_vars = [feed_vars]
-    if not feed_vars or not all(
-            [isinstance(var, Variable) for var in feed_vars]):
-        raise ValueError(
-            "'feed_vars' should be a Variable or a list of Variable.")
+    _check_vars('feed_vars', feed_vars)
    # verify fetch_vars
-    if not isinstance(fetch_vars, list):
-        fetch_vars = [fetch_vars]
-    if not fetch_vars or not all(
-            [isinstance(var, Variable) for var in fetch_vars]):
-        raise ValueError(
-            "'fetch_vars' should be a Variable or a list of Variable.")
+    _check_vars('fetch_vars', fetch_vars)
-    main_program = _get_valid_program()
-    # remind users to set auc_states to 0 if auc op were found.
-    for op in main_program.global_block().ops:
-        # clear device of Op
-        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
-        op._set_attr(device_attr_name, "")
-        if op.type == 'auc':
-            warnings.warn(
-                "Be sure that you have set auc states to 0 before saving inference model."
-            )
-            break
+    program = _get_valid_program()
+    program = _normalize_program(program, feed_vars, fetch_vars)
+    # serialize and save program
+    program_bytes = _serialize_program(program)
+    save_to_file(model_path, program_bytes)
+    # serialize and save params
+    params_bytes = _serialize_persistables(program, executor)
+    save_to_file(params_path, params_bytes)
-    # fix the bug that the activation op's output as target will be pruned.
-    # will affect the inference performance.
-    # TODO(Superjomn) add an IR pass to remove 1-scale op.
-    with program_guard(main_program):
-        uniq_fetch_vars = []
-        for i, var in enumerate(fetch_vars):
-            var = layers.scale(
-                var, 1., name="save_infer_model/scale_{}".format(i))
-            uniq_fetch_vars.append(var)
-        fetch_vars = uniq_fetch_vars
-    # save model
-    origin_program = main_program.clone()
-    main_program = main_program.clone()
-    global_block = main_program.global_block()
-    remove_op_idx = []
-    for i, op in enumerate(global_block.ops):
-        op.desc.set_is_target(False)
-        if op.type == "feed" or op.type == "fetch":
-            remove_op_idx.append(i)
-    for idx in remove_op_idx[::-1]:
-        global_block._remove_op(idx)
-    main_program.desc.flush()
-    feed_var_names = [var.name for var in feed_vars]
-    main_program = main_program._prune_with_input(
-        feeded_var_names=feed_var_names, targets=fetch_vars)
-    main_program = main_program._inference_optimize(prune_read_op=True)
-    fetch_var_names = [var.name for var in fetch_vars]
-    prepend_feed_ops(main_program, feed_var_names)
-    append_fetch_ops(main_program, fetch_var_names)
-    main_program.desc._set_version()
-    paddle.fluid.core.save_op_version_info(main_program.desc)
-    with open(model_path, "wb") as f:
-        f.write(main_program.desc.serialize_to_string())
-    main_program._copy_dist_param_info_from(origin_program)
-    # save params
-    dirname = os.path.dirname(params_path)
-    basename = os.path.basename(params_path)
-    save_persistables(executor, dirname, main_program, basename)

+@static_only
+def deserialize_program(data):
+    """
+    :api_attr: Static Graph
+
+    Deserialize given data to a program.
+
+    Args:
+        data(bytes): serialized program.
+    Returns:
+        Program: deserialized program.
+    """
+    program = Program.parse_from_string(data)
+    if not core._is_program_version_supported(program._version()):
+        raise ValueError("Unsupported program version: %d\n" %
+                         program._version())
+    return program
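# Round-trip sketch: deserialize_program inverts _serialize_program above.
#
#     data = _serialize_program(_get_valid_program())
#     prog = deserialize_program(data)  # a Program equivalent to the input
#     deserialize_program(b"garbage")   # expected to fail: not a valid ProgramDesc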
@static_only
def deserialize_persistables(program, data, executor):
"""
:api_attr: Static Graph
Deserialize given data to parameters according to given program and executor.
Args:
program(Program): program that contains parameter names (to deserialize).
data(bytes): serialized parameters.
executor(Executor): executor used to run load op.
    Returns:
        None: the deserialized parameters are loaded into the scope used by executor.
"""
if not isinstance(program, Program):
raise TypeError(
"program type must be `fluid.Program`, but received `%s`" %
type(program))
# load params to a tmp program
load_program = Program()
load_block = load_program.global_block()
vars_ = list(filter(is_persistable, program.list_vars()))
origin_shape_map = {}
load_var_map = {}
check_vars = []
sparse_vars = []
for var in vars_:
assert isinstance(var, Variable)
if var.type == core.VarDesc.VarType.RAW:
continue
if isinstance(var, Parameter):
origin_shape_map[var.name] = tuple(var.desc.get_shape())
if var.type == core.VarDesc.VarType.SELECTED_ROWS:
sparse_vars.append(var)
continue
var_copy = _clone_var_in_block(load_block, var)
check_vars.append(var)
load_var_map[var_copy.name] = var_copy
    # append a load_combine op to load parameters
load_var_list = []
for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name])
load_block.append_op(
type='load_combine',
inputs={},
outputs={"Out": load_var_list},
        # if loading from memory, file_path holds the serialized data
attrs={'file_path': data,
'model_from_memory': True})
executor.run(load_program)
# check var shape
for var in check_vars:
if not isinstance(var, Parameter):
continue
var_tmp = paddle.fluid.global_scope().find_var(var.name)
        assert var_tmp is not None, "cannot find var: " + var.name
        new_shape = (np.array(var_tmp.get_tensor())).shape
        assert var.name in origin_shape_map, var.name + " MUST be in var list."
origin_shape = origin_shape_map.get(var.name)
if new_shape != origin_shape:
raise RuntimeError(
"Shape mismatch, program needs a parameter with shape ({}), "
"but the loaded parameter ('{}') has a shape of ({}).".format(
origin_shape, var.name, new_shape))
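# Usage sketch: restore parameters serialized by serialize_persistables into
# the scope used by the executor, given a program that declares them.
#
#     data = serialize_persistables([image], [predict], exe)
#     deserialize_persistables(default_main_program(), data, exe)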
def load_from_file(path):
"""
Load file in binary mode.
Args:
path(str): Path of an existed file.
Returns:
bytes: Content of file.
"""
with open(path, 'rb') as f:
data = f.read()
return data
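# Usage sketch: load_from_file is the read-side counterpart of save_to_file;
# together they round-trip the serialized bytes.
#
#     save_to_file("./infer.pdiparams", params_bytes)
#     assert load_from_file("./infer.pdiparams") == params_bytes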
@static_only
@@ -277,18 +645,13 @@ def load_inference_model(path_prefix, executor, **configs):
if params_filename is None:
raise ValueError(
"params_filename cannot be None when path_prefix is None.")
-        load_dirname = path_prefix
-        program_desc_str = model_filename
+        load_dirname = ''
+        program_bytes = model_filename
params_filename = params_filename
# load from file
else:
# check and norm path_prefix
-        if not isinstance(path_prefix, six.string_types):
-            raise ValueError("'path_prefix' should be a string.")
-        if path_prefix.endswith("/"):
-            raise ValueError("'path_prefix' should not be a directory")
-        path_prefix = os.path.normpath(path_prefix)
-        path_prefix = os.path.abspath(path_prefix)
+        path_prefix = _normalize_path_prefix(path_prefix)
# set model_path and params_path in new way,
# path_prefix represents a file path without suffix in this case.
@@ -319,17 +682,17 @@ def load_inference_model(path_prefix, executor, **configs):
_logger.warning("The old way to load inference model is deprecated."
" model path: {}, params path: {}".format(
model_path, params_path))
with open(model_path, "rb") as f:
program_desc_str = f.read()
program_bytes = load_from_file(model_path)
load_dirname = os.path.dirname(params_path)
params_filename = os.path.basename(params_path)
-    program = Program.parse_from_string(program_desc_str)
-    if not core._is_program_version_supported(program._version()):
-        raise ValueError("Unsupported program version: %d\n" %
-                         program._version())
-    # Binary data also need versioning.
-    load_persistables(executor, load_dirname, program, params_filename)
+    # deserialize bytes to program
+    program = deserialize_program(program_bytes)
+    # load params data
+    params_path = os.path.join(load_dirname, params_filename)
+    params_bytes = load_from_file(params_path)
+    # deserialize bytes to params
+    deserialize_persistables(program, params_bytes, executor)
feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names()