未验证 提交 b15c6755 编写于 作者: E emailweixu 提交者: GitHub

Merge pull request #7421 from emailweixu/fetch_var

helper functions fetch_var and get_var
......@@ -17,7 +17,9 @@ import contextlib
from framework import Program, default_main_program
from . import core
__all__ = ['Executor', 'global_scope', 'scope_guard', 'switch_scope']
__all__ = [
'Executor', 'global_scope', 'scope_guard', 'switch_scope', 'fetch_var'
]
g_scope = core.Scope()
......@@ -146,6 +148,35 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
return fetch_count > 0
def fetch_var(name, scope=None, return_numpy=True):
"""
Fetch the value of the variable with the given name from the given scope
Args:
name(str): name of the variable. Typically, only persistable variables
can be found in the scope used for running the program.
scope(core.Scope|None): scope object. It should be the scope where
you pass to Executor.run() when running your program.
If None, global_scope() will be used.
return_numpy(bool): whether convert the tensor to numpy.ndarray
Returns:
LodTensor|numpy.ndarray
"""
assert isinstance(name, str)
if scope is None:
scope = global_scope()
assert isinstance(scope, core.Scope)
var = global_scope().find_var(name)
assert var is not None, (
"Cannot find " + name + " in scope. Perhaps you need to make the"
" variable persistable by using var.persistable = True in your"
" program.")
tensor = var.get_tensor()
if return_numpy:
tensor = as_numpy(tensor)
return tensor
class Executor(object):
def __init__(self, places):
if not isinstance(places, list) and not isinstance(places, tuple):
......
......@@ -31,6 +31,7 @@ __all__ = [
'program_guard',
'switch_startup_program',
'switch_main_program',
'get_var',
]
EMPTY_VAR_NAME = core.kEmptyVarName()
......@@ -1123,3 +1124,22 @@ def program_guard(main_program, startup_program=None):
switch_main_program(main_program)
if startup_program is not None:
switch_startup_program(startup_program)
def get_var(name, program=None):
"""
Get a variable by name from the global block of a program
Args:
name(str): name of the variable
program(Program|None): program object.
If None, default_global_program() will be used.
Returns:
Variable
"""
if program is None:
program = default_main_program()
assert isinstance(name, str)
assert isinstance(name, Program)
return program.global_block().var(name)
......@@ -35,13 +35,15 @@ __all__ = [
]
def create_tensor(dtype, name=None):
def create_tensor(dtype, name=None, persistable=False):
helper = LayerHelper("create_tensor", **locals())
return helper.create_variable(name=helper.name, dtype=dtype)
return helper.create_variable(
name=helper.name, dtype=dtype, persistable=persistable)
def create_parameter(shape,
dtype,
name=None,
attr=None,
is_bias=False,
default_initializer=None):
......@@ -62,7 +64,7 @@ def create_parameter(shape,
"""
helper = LayerHelper("create_parameter", **locals())
if attr is None:
attr = ParamAttr()
attr = ParamAttr(name=name)
return helper.create_parameter(attr, shape, dtype, is_bias,
default_initializer)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers
import op_test
import numpy
import unittest
class TestFetchVar(op_test.OpTest):
def test_fetch_var(self):
val = numpy.array([1, 3, 5]).astype(numpy.int32)
x = layers.create_tensor(dtype="int32", persistable=True, name="x")
layers.assign(input=val, output=x)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_main_program(), feed={}, fetch_list=[])
fetched_x = fluid.fetch_var("x")
self.assertTrue(
numpy.array_equal(fetched_x, val),
"fetch_x=%s val=%s" % (fetched_x, val))
self.assertEqual(fetched_x.dtype, val.dtype)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册