未验证 提交 e7985729 编写于 作者: C Chen Weihang 提交者: GitHub

Add all_parameters api for Program (#22180), test=release/1.7 (#22894)

上级 1ccdaa4d
......@@ -4524,6 +4524,65 @@ class Program(object):
for each_var in list(each_block.vars.values()):
yield each_var
@dygraph_not_support
def all_parameters(self):
"""
Get all :ref:`api_guide_parameter_en` from this Program. A list object is returned.
Returns:
list[ :ref:`api_guide_parameter_en` ]: The list contians all parameters in this program.
Examples:
.. code-block:: python
import paddle.fluid as fluid
program = fluid.default_main_program()
data = fluid.data(name='x', shape=[None, 13], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
for param in program.all_parameters():
print(param)
# Here will print all parameters in current program, in this example,
# the result is like:
#
# name: "fc_0.w_0"
# type {
# type: LOD_TENSOR
# lod_tensor {
# tensor {
# data_type: FP32
# dims: 13
# dims: 10
# }
# }
# }
# persistable: true
#
# name: "fc_0.b_0"
# type {
# type: LOD_TENSOR
# lod_tensor {
# tensor {
# data_type: FP32
# dims: 10
# }
# }
# }
# persistable: true
#
# Here print(param) will print out all the properties of a parameter,
# including name, type and persistable, you can access to specific
# property of a parameter, such as param.name, param.type
"""
parameters = []
for each_block in self.blocks:
parameters.extend(each_block.all_parameters())
return parameters
@six.add_metaclass(ParameterMetaClass)
class Parameter(Variable):
......
......@@ -132,6 +132,19 @@ class TestProgram(unittest.TestCase):
for i in range(len(no_read_ops)):
self.assertEqual(no_read_ops[i].type, keep_read_ops[i + 2].type)
def test_program_all_parameters(self):
program = fluid.default_main_program()
data = fluid.data(name='x', shape=[None, 13], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
# NOTE: here the parameters are fc_0.w_0 and fc_0.b_0
param_list = program.all_parameters()
self.assertEqual(len(param_list), 2)
self.assertEqual(param_list[0].name, "fc_0.w_0")
self.assertEqual(param_list[1].name, "fc_0.b_0")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册