提交 7346a23d 编写于 作者: H Hongsheng Zeng 提交者: Bo Zhou

add api set_params/get_params in Model (#56)

* add api set_params/get_params in Model; add Interface of Network and LayerFunc to solve circular imports; refine parameter_names api of Model

* remove licence in third party code; remove interface of Network and LayerFunc; move get_parameter_pairs and get_parameter_names api to Network

* refine comment

* refine commment
上级 d8449b74
......@@ -103,6 +103,7 @@ def main(argv=None):
first_line = fd.readline()
second_line = fd.readline()
if "COPYRIGHT (C)" in first_line.upper(): continue
if "THIRD PARTY" in first_line.upper(): continue
original_contents = io.open(filename, encoding="utf-8").read()
new_contents = generate_copyright(
COPYRIGHT, lang_type(filename)) + original_contents
......
......@@ -55,13 +55,27 @@ function check_style() {
trap : 0
}
function run_test() {
function run_test_with_gpu() {
mkdir -p ${REPO_ROOT}/build
cd ${REPO_ROOT}/build
cmake ..
cat <<EOF
========================================
Running unit tests ...
Running unit tests with GPU...
========================================
EOF
ctest --output-on-failure
}
function run_test_with_cpu() {
export CUDA_VISIBLE_DEVICES=""
mkdir -p ${REPO_ROOT}/build
cd ${REPO_ROOT}/build
cmake ..
cat <<EOF
========================================
Running unit tests with CPU...
========================================
EOF
ctest --output-on-failure
......@@ -76,7 +90,8 @@ function main() {
check_style
;;
test)
run_test
run_test_with_gpu
run_test_with_cpu
;;
*)
print_usage
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Third party code
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The following code are copied or modified from:
# https://github.com/tensorpack/tensorpack/blob/master/examples/DeepQNetwork/atari.py
import cv2
import gym
......@@ -25,10 +16,6 @@ __all__ = ['AtariPlayer']
ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms"
_ALE_LOCK = threading.Lock()
"""
The following AtariPlayer are copied or modified from tensorpack/tensorpack:
https://github.com/tensorpack/tensorpack/blob/master/examples/DeepQNetwork/atari.py
"""
class AtariPlayer(gym.Env):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Third party code
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The following code are copied or modified from:
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
import gym
import numpy as np
......@@ -19,10 +10,6 @@ from gym import spaces
_v0, _v1 = gym.__version__.split('.')[:2]
assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__
"""
The following wrappers are copied or modified from openai/baselines:
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""
class MapState(gym.ObservationWrapper):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Third party code
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The following code are copied or modified from:
# https://github.com/pat-coady/trpo
import numpy as np
import scipy.signal
__all__ = ['calc_discount_sum_rewards', 'calc_gae', 'Scaler']
"""
The following code are copied or modified from:
https://github.com/pat-coady/trpo
Written by Patrick Coady (pat-coady.github.io)
"""
def calc_discount_sum_rewards(rewards, gamma):
......
......@@ -17,19 +17,21 @@ Base class to define an Algorithm.
import hashlib
import paddle.fluid as fluid
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
from parl.layers.layer_wrappers import LayerFunc
from parl.plutils import *
__all__ = ['Network', 'Model']
class Network(object):
"""
A Network is an unordered set of LayerFuncs or Networks.
A Network is a collection of LayerFuncs or Networks.
"""
def sync_params_to(self,
target_net,
gpu_id=0,
gpu_id,
decay=0.0,
share_vars_parallel_executor=None):
"""
......@@ -57,13 +59,10 @@ class Network(object):
assert not target_net is self, "cannot copy between identical networks"
assert isinstance(target_net, Network)
assert self.__class__.__name__ == target_net.__class__.__name__, \
"must be the same class for para syncing!"
"must be the same class for params syncing!"
assert (decay >= 0 and decay <= 1)
# Resolve Circular Imports
from parl.plutils import get_parameter_pairs, fetch_framework_var
param_pairs = get_parameter_pairs(self, target_net)
param_pairs = self._get_parameter_pairs(self, target_net)
self._cached_sync_params_program = fluid.Program()
......@@ -106,15 +105,120 @@ class Network(object):
@property
def parameter_names(self):
""" param_attr names of all parameters in Network,
only parameter created by parl.layers included
only parameter created by parl.layers included.
The order of parameter names will be consistent between
different instances of same parl.Network.
Returns:
list of string, param_attr names of all parameters
"""
# Resolve Circular Imports
from parl.plutils import get_parameter_names
return get_parameter_names(self)
try:
return self._parameter_names
except AttributeError:
self._parameter_names = self._get_parameter_names(self)
return self._parameter_names
def get_params(self):
""" Get numpy arrays of parameters in this Network
Returns:
List of numpy array.
"""
params = []
for param_name in self.parameter_names:
param = fetch_value(param_name)
params.append(param)
return params
def set_params(self, params, gpu_id):
""" Set parameters in this Network with params
Args:
params: List of numpy array.
gpu_id: gpu id where this Network in. (if gpu_id < 0, means in cpu.)
"""
assert len(params) == len(self.parameter_names), \
'size of input params should be same as parameters number of current Network'
for (param_name, param) in list(zip(self.parameter_names, params)):
set_value(param_name, param, gpu_id)
def _get_parameter_names(self, obj):
""" Recursively get parameter names in obj,
Args:
obj (parl.Network/parl.LayerFunc/list/tuple/dict): input object
Returns:
parameter_names (list of string): all parameter names in obj
"""
parameter_names = []
for attr in sorted(obj.__dict__.keys()):
val = getattr(obj, attr)
if isinstance(val, Network):
parameter_names.extend(self._get_parameter_names(val))
elif isinstance(val, LayerFunc):
for attr in val.attr_holder.sorted():
if attr:
parameter_names.append(attr.name)
elif isinstance(val, tuple) or isinstance(val, list):
for x in val:
parameter_names.extend(self._get_parameter_names(x))
elif isinstance(val, dict):
for x in list(val.values()):
parameter_names.extend(self._get_parameter_names(x))
else:
# for any other type, won't be handled. E.g. set
pass
return parameter_names
def _get_parameter_pairs(self, src, target):
""" Recursively gets parameters in source network and
corresponding parameters in target network.
Args:
src (parl.Network/parl.LayerFunc/list/tuple/dict): source object
target (parl.Network/parl.LayerFunc/list/tuple/dict): target object
Returns:
param_pairs (list of tuple): all parameter names in source network
and corresponding parameter names in
target network.
"""
param_pairs = []
if isinstance(src, Network):
for attr in src.__dict__:
if not attr in target.__dict__:
continue
src_var = getattr(src, attr)
target_var = getattr(target, attr)
param_pairs.extend(
self._get_parameter_pairs(src_var, target_var))
elif isinstance(src, LayerFunc):
src_attrs = src.attr_holder.sorted()
target_attrs = target.attr_holder.sorted()
assert len(src_attrs) == len(target_attrs), \
"number of ParamAttr between source layer and target layer should be same."
for (src_attr, target_attr) in zip(src_attrs, target_attrs):
if src_attr:
assert target_attr, "ParamAttr between source layer and target layer is inconsistent."
param_pairs.append((src_attr.name, target_attr.name))
elif isinstance(src, tuple) or isinstance(src, list):
for src_var, target_var in zip(src, target):
param_pairs.extend(
self._get_parameter_pairs(src_var, target_var))
elif isinstance(src, dict):
for k in src.keys():
assert k in target
param_pairs.extend(
self._get_parameter_pairs(src[k], target[k]))
else:
# for any other type, won't be handled. E.g. set
pass
return param_pairs
class Model(Network):
......
......@@ -73,7 +73,7 @@ class AgentBaseTest(unittest.TestCase):
self.assertIsNotNone(output_np)
def test_agent_with_cpu(self):
agent = TestAgent(self.algorithm, gpu_id=0)
agent = TestAgent(self.algorithm, gpu_id=-1)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
self.assertIsNotNone(output_np)
......
......@@ -71,6 +71,19 @@ class TestModel3(Model):
return out
class TestModel4(Model):
def __init__(self):
self.fc1 = layers.fc(size=256)
self.fc2 = layers.fc(size=128)
self.fc3 = layers.fc(size=1)
def predict(self, obs):
out = self.fc1(obs)
out = self.fc2(out)
out = self.fc3(out)
return out
class ModelBaseTest(unittest.TestCase):
def setUp(self):
self.model = TestModel()
......@@ -80,8 +93,10 @@ class ModelBaseTest(unittest.TestCase):
gpu_count = get_gpu_count()
if gpu_count > 0:
place = fluid.CUDAPlace(0)
self.gpu_id = 0
else:
place = fluid.CPUPlace
place = fluid.CPUPlace()
self.gpu_id = -1
self.executor = fluid.Executor(place)
def test_network_copy(self):
......@@ -121,6 +136,11 @@ class ModelBaseTest(unittest.TestCase):
set(self.model.parameter_names),
set(['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b']))
# Second test for cache parameter_names
self.assertSetEqual(
set(self.model.parameter_names),
set(['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b']))
def test_sync_params_in_one_program(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
......@@ -139,7 +159,7 @@ class ModelBaseTest(unittest.TestCase):
fetch_list=[model_output, target_model_output])
self.assertNotEqual(outputs[0].flatten(), outputs[1].flatten())
self.model.sync_params_to(self.target_model)
self.model.sync_params_to(self.target_model, self.gpu_id)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
......@@ -177,7 +197,7 @@ class ModelBaseTest(unittest.TestCase):
fetch_list=[target_model_output])
self.assertNotEqual(outputs[0].flatten(), outputs_2[0].flatten())
self.model.sync_params_to(self.target_model)
self.model.sync_params_to(self.target_model, self.gpu_id)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
......@@ -245,7 +265,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -278,7 +298,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -302,7 +322,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -335,7 +355,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -359,7 +379,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -393,7 +413,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -417,7 +437,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model2, decay)
self.model.sync_params_to(self.target_model2, decay=decay)
self.model.sync_params_to(self.target_model2, self.gpu_id, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -457,7 +477,7 @@ class ModelBaseTest(unittest.TestCase):
self.assertNotEqual(
np.sum(outputs[0].flatten()), np.sum(outputs[1].flatten()))
model.sync_params_to(target_model)
model.sync_params_to(target_model, self.gpu_id)
random_obs = np.random.random(size=(N, 100)).astype('float32')
for i in range(N):
......@@ -508,7 +528,7 @@ class ModelBaseTest(unittest.TestCase):
x = np.expand_dims(random_obs[i], axis=0)
self.executor.run(program1, feed={'obs': x})
model.sync_params_to(target_model)
model.sync_params_to(target_model, self.gpu_id)
random_obs = np.random.random(size=(N, 32, 128, 128)).astype('float32')
for i in range(N):
......@@ -520,6 +540,124 @@ class ModelBaseTest(unittest.TestCase):
self.assertEqual(
np.sum(outputs[0].flatten()), np.sum(outputs[1].flatten()))
def test_get_params(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
self.executor.run(fluid.default_startup_program())
expected_params = []
for param_name in [
'fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b'
]:
expected_params.append(fetch_value(param_name))
params = self.model.get_params()
self.assertEqual(len(params), len(expected_params))
for param in params:
flag = False
for expected_param in expected_params:
if np.sum(param) - np.sum(expected_param) < 1e-5:
flag = True
break
self.assertTrue(flag)
def test_set_params(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
self.executor.run(fluid.default_startup_program())
params = self.model.get_params()
new_params = [x + 1.0 for x in params]
self.model.set_params(new_params, self.gpu_id)
for x, y in list(zip(new_params, self.model.get_params())):
self.assertEqual(np.sum(x), np.sum(y))
def test_set_params_between_different_models(self):
model1 = TestModel4()
model2 = TestModel4()
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model1_output = model1.predict(obs)
model2_output = model2.predict(obs)
self.executor.run(fluid.default_startup_program())
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[model1_output, model2_output])
self.assertNotEqual(outputs[0].flatten(), outputs[1].flatten())
# pass parameters of self.model to model2
params = model1.get_params()
model2.set_params(params, self.gpu_id)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[model1_output, model2_output])
self.assertEqual(outputs[0].flatten(), outputs[1].flatten())
def test_set_params_with_wrong_params_num(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
self.executor.run(fluid.default_startup_program())
params = self.model.get_params()
try:
self.model.set_params(params[1:], self.gpu_id)
except:
# expected
return
assert False
def test_set_params_with_wrong_params_shape(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
self.executor.run(fluid.default_startup_program())
params = self.model.get_params()
params.reverse()
self.model.set_params(params, self.gpu_id)
x = np.random.random(size=(1, 4)).astype('float32')
try:
outputs = self.executor.run(
pred_program, feed={'obs': x}, fetch_list=[model_output])
except:
# expected
return
assert False
if __name__ == '__main__':
unittest.main()
......@@ -40,7 +40,6 @@ from paddle.fluid.executor import _fetch_var
from paddle.fluid.framework import Variable
from paddle.fluid.layers import *
from paddle.fluid.param_attr import ParamAttr
from parl.framework.model_base import Network
from parl.layers.attr_holder import AttrHolder
......@@ -128,22 +127,6 @@ class LayerFunc(object):
return params_names
def check_caller_name():
stack = inspect.stack()
## we trace back to the call stack and make sure Network.__init__ is on the path
called_by_init = False
for s in stack:
try:
the_class = s[0].f_locals["self"].__class__
the_method = s[0].f_code.co_name
if issubclass(the_class, Network) and the_method == "__init__":
called_by_init = True
except:
pass
assert called_by_init, "parl.layers can only be called in Network.__init__()!"
def fc(size,
num_flatten_dims=1,
param_attr=None,
......@@ -156,7 +139,6 @@ def fc(size,
default_name = "fc"
param_attr = update_attr_name(name, default_name, param_attr, False)
bias_attr = update_attr_name(name, default_name, bias_attr, True)
check_caller_name()
class FC_(LayerFunc):
def __init__(self):
......@@ -187,7 +169,6 @@ def embedding(size,
Return a function that creates a paddle.fluid.layers.embedding.
"""
param_attr = update_attr_name(name, "embedding", param_attr, False)
check_caller_name()
class Embedding_(LayerFunc):
def __init__(self):
......@@ -222,7 +203,6 @@ def dynamic_lstm(size,
default_name = "dynamic_lstm"
param_attr = update_attr_name(name, default_name, param_attr, False)
bias_attr = update_attr_name(name, default_name, bias_attr, True)
check_caller_name()
class DynamicLstm_(LayerFunc):
def __init__(self):
......@@ -265,7 +245,6 @@ def dynamic_lstmp(size,
default_name = "dynamic_lstmp"
param_attr = update_attr_name(name, default_name, param_attr, False)
bias_attr = update_attr_name(name, default_name, bias_attr, True)
check_caller_name()
class DynamicLstmp_(LayerFunc):
def __init__(self):
......@@ -303,7 +282,6 @@ def dynamic_gru(size,
default_name = "dynamic_gru"
param_attr = update_attr_name(name, default_name, param_attr, False)
bias_attr = update_attr_name(name, default_name, bias_attr, True)
check_caller_name()
class DynamicGru_(LayerFunc):
def __init__(self):
......@@ -353,7 +331,6 @@ def sequence_conv(num_filters,
default_name = "sequence_conv"
param_attr = update_attr_name(name, default_name, param_attr, False)
bias_attr = update_attr_name(name, default_name, bias_attr, True)
check_caller_name()
class SequenceConv_(LayerFunc):
def __init__(self):
......@@ -391,7 +368,6 @@ def conv2d(num_filters,
default_name = "conv2d"
param_attr = update_attr_name(name, default_name, param_attr, False)
bias_attr = update_attr_name(name, default_name, bias_attr, True)
check_caller_name()
class Conv2D_(LayerFunc):
def __init__(self):
......@@ -432,7 +408,6 @@ def conv2d_transpose(num_filters,
default_name = "conv2d_transpose"
param_attr = update_attr_name(name, default_name, param_attr, False)
bias_attr = update_attr_name(name, default_name, bias_attr, True)
check_caller_name()
class Conv2DTranspose_(LayerFunc):
def __init__(self):
......@@ -463,7 +438,6 @@ def lstm_unit(forget_bias=0.0, param_attr=None, bias_attr=None, name=None):
default_name = "lstm_unit"
param_attr = update_attr_name(name, default_name, param_attr, False)
bias_attr = update_attr_name(name, default_name, bias_attr, True)
check_caller_name()
class LstmUnit_(LayerFunc):
def __init__(self):
......@@ -491,7 +465,6 @@ def row_conv(future_context_size, param_attr=None, act=None, name=None):
Return a function that creates a paddle.fluid.layers.row_conv.
"""
param_attr = update_attr_name(name, "row_conv", param_attr, False)
check_caller_name()
class RowConv_(LayerFunc):
def __init__(self):
......@@ -534,7 +507,6 @@ def batch_norm(act=None,
None, False)
moving_variance_attr = update_attr_name(
name, default_name + "_moving_variance", None, False)
check_caller_name()
class BatchNorm_(LayerFunc):
def __init__(self):
......@@ -578,7 +550,6 @@ def create_parameter(shape,
"""
param_attr = update_attr_name(name, "create_parameter", attr, False)
check_caller_name()
class CreateParameter_(LayerFunc):
def __init__(self):
......
......@@ -17,13 +17,8 @@ Common functions of PARL framework
import paddle.fluid as fluid
from paddle.fluid.executor import _fetch_var
from parl.layers.layer_wrappers import LayerFunc
from parl.framework.model_base import Network
__all__ = [
'fetch_framework_var', 'fetch_value', 'get_parameter_pairs',
'get_parameter_names'
]
__all__ = ['fetch_framework_var', 'fetch_value', 'set_value']
def fetch_framework_var(attr_name):
......@@ -57,77 +52,15 @@ def fetch_value(attr_name):
return _fetch_var(attr_name, return_numpy=True)
def get_parameter_pairs(src, target):
""" Recursively get pairs of parameter names between src and target
Args:
src: parl.Network/parl.LayerFunc/list/tuple/set/dict
target: parl.Network/parl.LayerFunc/list/tuple/set/dict
Returns:
param_pairs: list of all tuple(src_param_name, target_param_name, is_bias)
between src and target
"""
param_pairs = []
if isinstance(src, Network):
for attr in src.__dict__:
if not attr in target.__dict__:
continue
src_var = getattr(src, attr)
target_var = getattr(target, attr)
param_pairs.extend(get_parameter_pairs(src_var, target_var))
elif isinstance(src, LayerFunc):
src_attrs = src.attr_holder.sorted()
target_attrs = target.attr_holder.sorted()
assert len(src_attrs) == len(target_attrs), \
"number of ParamAttr between source layer and target layer should be same."
for (src_attr, target_attr) in zip(src_attrs, target_attrs):
if src_attr:
assert target_attr, "ParamAttr between source layer and target layer is inconsistent."
param_pairs.append((src_attr.name, target_attr.name))
elif isinstance(src, tuple) or isinstance(src, list) or isinstance(
src, set):
for src_var, target_var in zip(src, target):
param_pairs.extend(get_parameter_pairs(src_var, target_var))
elif isinstance(src, dict):
for k in src.keys():
assert k in target
param_pairs.extend(get_parameter_pairs(src[k], target[k]))
else:
# for any other type, won't be handled
pass
return param_pairs
def get_parameter_names(obj):
""" Recursively get parameter names in obj,
mainly used to get parameter names of a parl.Network
def set_value(attr_name, value, gpu_id):
""" Given name of ParamAttr, set numpy value to the parameter in global_scope
Args:
obj: parl.Network/parl.LayerFunc/list/tuple/set/dict
Returns:
parameter_names: list of string, all parameter names in obj
attr_name: ParamAttr name of parameter
value: numpy array
gpu_id: gpu id where the parameter in
"""
parameter_names = []
for attr in obj.__dict__:
val = getattr(obj, attr)
if isinstance(val, Network):
parameter_names.extend(get_parameter_names(val))
elif isinstance(val, LayerFunc):
for attr in val.attr_holder.tolist():
if attr:
parameter_names.append(attr.name)
elif isinstance(val, tuple) or isinstance(val, list) or isinstance(
val, set):
for x in val:
parameter_names.extend(get_parameter_names(x))
elif isinstance(val, dict):
for x in list(val.values()):
parameter_names.extend(get_parameter_names(x))
else:
# for any other type, won't be handled
pass
return parameter_names
place = fluid.CPUPlace() if gpu_id < 0 \
else fluid.CUDAPlace(gpu_id)
var = _fetch_var(attr_name, return_numpy=False)
var.set(value, place)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册