Commit e11b40c5 authored by Hongsheng Zeng, committed by Bo Zhou

sync paras in program, fix deepcopy bug, python3 compatibility (#28)

* sync paras in program, fix deepcopy bug, python3 compatibility

* refactor code, add plutil directory, clean import order

* remove old comment

* refine comment

* fix codestyle

* cache sync program, add gputils module, refine model_base unittest

* fix codestyle

* refine sync params cache

* add fetch_value module
Parent 942c3c5c
@@ -14,8 +14,8 @@

 import paddle.fluid as fluid
 import parl.layers as layers
-from parl.framework.model_base import Network, Model
 from abc import ABCMeta, abstractmethod
+from parl.framework.model_base import Network, Model

 __all__ = ['Algorithm']
......
@@ -15,8 +15,9 @@
 Base class to define an Algorithm.
 """
+import hashlib
+import paddle.fluid as fluid
 from abc import ABCMeta, abstractmethod
-from parl.utils.utils import has_func

 __all__ = ['Network', 'Model']
@@ -26,36 +27,64 @@ class Network(object):
     A Network is an unordered set of LayerFuncs or Networks.
     """

-    def sync_paras_to(self, target_net):
-        assert not target_net is self, "cannot copy between identical networks"
-        assert isinstance(target_net, Network)
-        assert self.__class__.__name__ == target_net.__class__.__name__, \
-            "must be the same class for para syncing!"
-
-        for attr in self.__dict__:
-            if not attr in target_net.__dict__:
-                continue
-            val = getattr(self, attr)
-            target_val = getattr(target_net, attr)
-
-            assert type(val) == type(target_val), \
-                "[Error]sync_paras_to failed, \
-                ensure that the destination model is generated by deep copied from source model"
-
-            ### TODO: sync paras recursively
-            if has_func(val, 'sync_paras_to'):
-                val.sync_paras_to(target_val)
-            elif isinstance(val, tuple) or isinstance(val, list) or isinstance(
-                    val, set):
-                for v, tv in zip(val, target_val):
-                    v.sync_paras_to(tv)
-            elif isinstance(val, dict):
-                for k in val.keys():
-                    assert k in target_val
-                    val[k].sync_paras_to(target_val[k])
-            else:
-                # for any other type, we do not copy
-                pass
+    def sync_params_to(self, target_net, gpu_id=0, decay=0.0):
+        """
+        Args:
+            target_net: Network object deepcopy from source network
+            gpu_id: gpu id of target_net
+            decay: Float. The decay to use.
+                target_net_weights = decay * target_net_weights + (1 - decay) * source_net_weights
+        """
+        args_hash_id = hashlib.md5('{}_{}_{}'.format(
+            id(target_net), gpu_id, decay).encode('utf-8')).hexdigest()
+        has_cached = False
+        try:
+            if self._cached_id == args_hash_id:
+                has_cached = True
+        except AttributeError:
+            has_cached = False
+
+        if not has_cached:
+            # Can not run _cached program, need create a new program
+            self._cached_id = args_hash_id
+
+            assert not target_net is self, "cannot copy between identical networks"
+            assert isinstance(target_net, Network)
+            assert self.__class__.__name__ == target_net.__class__.__name__, \
+                "must be the same class for para syncing!"
+            assert (decay >= 0 and decay <= 1)
+
+            # Resolve Circular Imports
+            from parl.plutils import get_parameter_pairs, fetch_framework_var
+
+            param_pairs = get_parameter_pairs(self, target_net)
+
+            place = fluid.CPUPlace() if gpu_id < 0 \
+                else fluid.CUDAPlace(gpu_id)
+            self._cached_fluid_executor = fluid.Executor(place)
+            self._cached_sync_params_program = fluid.Program()
+            with fluid.program_guard(self._cached_sync_params_program):
+                for (src_var_name, target_var_name, is_bias) in param_pairs:
+                    src_var = fetch_framework_var(src_var_name, is_bias)
+                    target_var = fetch_framework_var(target_var_name, is_bias)
+                    fluid.layers.assign(
+                        decay * target_var + (1 - decay) * src_var, target_var)
+
+        self._cached_fluid_executor.run(self._cached_sync_params_program)
+
+    @property
+    def parameter_names(self):
+        """ param_attr names of all parameters in Network,
+        only parameters created by parl.layers are included
+
+        Returns:
+            list of string, param_attr names of all parameters
+        """
+        # Resolve Circular Imports
+        from parl.plutils import get_parameter_names
+        return get_parameter_names(self)
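For orientation, here is a minimal usage sketch of the new sync_params_to API as exercised by the tests in this commit: deep-copy a model, run the startup program, then soft-update the copy with a decay factor. The CriticModel class below is a hypothetical stand-in, not part of the commit; gpu_id=-1 selects CPUPlace per the code above.

    import paddle.fluid as fluid
    import parl.layers as layers
    from copy import deepcopy
    from parl.framework.model_base import Model

    class CriticModel(Model):
        """Hypothetical Model subclass, used only for illustration."""

        def __init__(self):
            self.fc1 = layers.fc(size=256, act='relu')
            self.fc2 = layers.fc(size=1)

        def value(self, obs):
            return self.fc2(self.fc1(obs))

    model = CriticModel()
    target_model = deepcopy(model)  # copies structure; parameters get new names

    program = fluid.Program()
    with fluid.program_guard(program):
        obs = layers.data(name='obs', shape=[4], dtype='float32')
        v = model.value(obs)
        target_v = target_model.value(obs)

    executor = fluid.Executor(fluid.CPUPlace())
    executor.run(fluid.default_startup_program())

    # Soft update: target = 0.9 * target + 0.1 * source.
    # The first call builds and caches the sync program; later calls with the
    # same (target, gpu_id, decay) arguments reuse the cached program.
    model.sync_params_to(target_model, gpu_id=-1, decay=0.9)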
@@ -80,7 +109,7 @@ class Model(Network):
     Note that it's the model structure that is copied from initial actor,
     parameters in initial model haven't been copied to target model.
-    To copy parameters, you must explicitly use the sync_paras_to function after the program is initialized.
+    To copy parameters, you must explicitly use the sync_params_to function after the program is initialized.
     """
     __metaclass__ = ABCMeta
......
@@ -13,10 +13,10 @@
 # limitations under the License.

 import parl.layers as layers
+from abc import ABCMeta, abstractmethod
 from paddle.fluid.framework import Variable
 from parl.layers import common_functions as comf
 from paddle.fluid.framework import convert_np_dtype_to_dtype_
-from abc import ABCMeta, abstractmethod

 class PolicyDistribution(object):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.model_base import Model
from parl.framework.algorithm_base import Algorithm
from copy import deepcopy
import numpy as np
import unittest
import sys
class Value(Model):
def __init__(self, obs_dim, act_dim):
self.obs_dim = obs_dim
self.act_dim = act_dim
self.fc1 = layers.fc(size=256, act='relu')
self.fc2 = layers.fc(size=128, act='relu')
self.fc3 = layers.fc(size=self.act_dim)
def value(self, obs):
out = self.fc1(obs)
out = self.fc2(out)
value = self.fc3(out)
return value
class QLearning(Algorithm):
def __init__(self, critic_model):
self.critic_model = critic_model
self.target_model = deepcopy(critic_model)
def define_predict(self, obs):
self.q_value = self.critic_model.value(obs)
self.q_target_value = self.target_model.value(obs)
class AlgorithmBaseTest(unittest.TestCase):
def test_sync_paras_in_one_program(self):
critic_model = Value(obs_dim=4, act_dim=1)
dqn = QLearning(critic_model)
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
dqn.define_predict(obs)
place = fluid.CUDAPlace(0)
executor = fluid.Executor(place)
executor.run(fluid.default_startup_program())
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = executor.run(
pred_program,
feed={'obs': x},
fetch_list=[dqn.q_value, dqn.q_target_value])
self.assertNotEqual(outputs[0].flatten(), outputs[1].flatten())
        critic_model.sync_params_to(dqn.target_model)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = executor.run(
pred_program,
feed={'obs': x},
fetch_list=[dqn.q_value, dqn.q_target_value])
self.assertEqual(outputs[0].flatten(), outputs[1].flatten())
def test_sync_paras_among_programs(self):
critic_model = Value(obs_dim=4, act_dim=1)
dqn = QLearning(critic_model)
dqn_2 = deepcopy(dqn)
pred_program = fluid.Program()
pred_program_2 = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
dqn.define_predict(obs)
# algorithm #2
with fluid.program_guard(pred_program_2):
obs_2 = layers.data(name='obs_2', shape=[4], dtype='float32')
dqn_2.define_predict(obs_2)
place = fluid.CUDAPlace(0)
executor = fluid.Executor(place)
executor.run(fluid.default_startup_program())
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = executor.run(
pred_program, feed={'obs': x}, fetch_list=[dqn.q_value])
outputs_2 = executor.run(
pred_program_2, feed={'obs_2': x}, fetch_list=[dqn_2.q_value])
self.assertNotEqual(outputs[0].flatten(), outputs_2[0].flatten())
        dqn.critic_model.sync_params_to(dqn_2.critic_model)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = executor.run(
pred_program, feed={'obs': x}, fetch_list=[dqn.q_value])
outputs_2 = executor.run(
pred_program_2, feed={'obs_2': x}, fetch_list=[dqn_2.q_value])
self.assertEqual(outputs[0].flatten(), outputs_2[0].flatten())
if __name__ == '__main__':
unittest.main()
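The tests above exercise synchronization in isolation. In a real DQN-style loop the same call typically runs periodically; a minimal sketch, continuing from the QLearning setup in the test above, where sample_batch and train_step are hypothetical stubs rather than part of this commit:

    # Hypothetical training driver; sample_batch and train_step stand in
    # for a replay buffer and a learning update on dqn.critic_model.
    def sample_batch():
        return None

    def train_step(batch):
        pass

    SYNC_INTERVAL = 100  # assumed target-update frequency
    for step in range(1000):
        train_step(sample_batch())
        if step % SYNC_INTERVAL == 0:
            # decay defaults to 0.0, i.e. a hard copy of source into target
            dqn.critic_model.sync_params_to(dqn.target_model)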
@@ -12,31 +12,405 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import paddle.fluid as fluid
-import parl.layers as layers
-from parl.framework.model_base import Model
-from copy import deepcopy
-import unittest
-
-
-class Value(Model):
-    def __init__(self, obs_dim, act_dim):
-        self.obs_dim = obs_dim
-        self.act_dim = act_dim
-        self.fc1 = layers.fc(size=256, act='relu')
-        self.fc2 = layers.fc(size=128, act='relu')
-
-
-class ModelBaseTest(unittest.TestCase):
-    def test_network_copy(self):
-        value = Value(obs_dim=2, act_dim=1)
-        target_value = deepcopy(value)
-        self.assertNotEqual(value.fc1.param_name, target_value.fc1.param_name)
-        self.assertNotEqual(value.fc1.bias_name, target_value.fc1.bias_name)
-        self.assertNotEqual(value.fc2.param_name, target_value.fc2.param_name)
-        self.assertNotEqual(value.fc2.param_name, target_value.fc2.param_name)

import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
import unittest
from copy import deepcopy
from paddle.fluid import ParamAttr
from parl.framework.model_base import Model
from parl.utils import get_gpu_count
from parl.plutils import fetch_value


class TestModel(Model):
    def __init__(self):
        self.fc1 = layers.fc(
            size=256,
            act=None,
            param_attr=ParamAttr(name='fc1.w'),
            bias_attr=ParamAttr(name='fc1.b'))
        self.fc2 = layers.fc(
            size=128,
            act=None,
            param_attr=ParamAttr(name='fc2.w'),
            bias_attr=ParamAttr(name='fc2.b'))
        self.fc3 = layers.fc(
            size=1,
            act=None,
            param_attr=ParamAttr(name='fc3.w'),
            bias_attr=ParamAttr(name='fc3.b'))

    def predict(self, obs):
        out = self.fc1(obs)
        out = self.fc2(out)
        out = self.fc3(out)
        return out


class ModelBaseTest(unittest.TestCase):
    def setUp(self):
        self.model = TestModel()
        self.target_model = deepcopy(self.model)
        self.target_model2 = deepcopy(self.model)

        gpu_count = get_gpu_count()
        if gpu_count > 0:
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        self.executor = fluid.Executor(place)

    def test_network_copy(self):
        self.assertNotEqual(self.model.fc1.param_name,
                            self.target_model.fc1.param_name)
        self.assertNotEqual(self.model.fc1.bias_name,
                            self.target_model.fc1.bias_name)
        self.assertNotEqual(self.model.fc2.param_name,
                            self.target_model.fc2.param_name)
        self.assertNotEqual(self.model.fc2.bias_name,
                            self.target_model.fc2.bias_name)
        self.assertNotEqual(self.model.fc3.param_name,
                            self.target_model.fc3.param_name)
        self.assertNotEqual(self.model.fc3.bias_name,
                            self.target_model.fc3.bias_name)
def test_network_copy_with_multi_copy(self):
self.assertNotEqual(self.target_model.fc1.param_name,
self.target_model2.fc1.param_name)
self.assertNotEqual(self.target_model.fc1.bias_name,
self.target_model2.fc1.bias_name)
self.assertNotEqual(self.target_model.fc2.param_name,
self.target_model2.fc2.param_name)
self.assertNotEqual(self.target_model.fc2.bias_name,
self.target_model2.fc2.bias_name)
self.assertNotEqual(self.target_model.fc3.param_name,
self.target_model2.fc3.param_name)
self.assertNotEqual(self.target_model.fc3.bias_name,
self.target_model2.fc3.bias_name)
def test_network_parameter_names(self):
self.assertSetEqual(
set(self.model.parameter_names),
set(['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b']))
def test_sync_params_in_one_program(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
target_model_output = self.target_model.predict(obs)
self.executor.run(fluid.default_startup_program())
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[model_output, target_model_output])
self.assertNotEqual(outputs[0].flatten(), outputs[1].flatten())
self.model.sync_params_to(self.target_model)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[model_output, target_model_output])
self.assertEqual(outputs[0].flatten(), outputs[1].flatten())
def test_sync_params_among_programs(self):
pred_program = fluid.Program()
pred_program_2 = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
# program 2
with fluid.program_guard(pred_program_2):
obs = layers.data(name='obs', shape=[4], dtype='float32')
target_model_output = self.target_model.predict(obs)
self.executor.run(fluid.default_startup_program())
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program, feed={'obs': x}, fetch_list=[model_output])
outputs_2 = self.executor.run(
pred_program_2,
feed={'obs': x},
fetch_list=[target_model_output])
self.assertNotEqual(outputs[0].flatten(), outputs_2[0].flatten())
self.model.sync_params_to(self.target_model)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program, feed={'obs': x}, fetch_list=[model_output])
outputs_2 = self.executor.run(
pred_program_2,
feed={'obs': x},
fetch_list=[target_model_output])
self.assertEqual(outputs[0].flatten(), outputs_2[0].flatten())
def _numpy_update(self, target_model, decay):
model_fc1_w = fetch_value('fc1.w')
model_fc1_b = fetch_value('fc1.b')
model_fc2_w = fetch_value('fc2.w')
model_fc2_b = fetch_value('fc2.b')
model_fc3_w = fetch_value('fc3.w')
model_fc3_b = fetch_value('fc3.b')
unique_id = target_model.parameter_names[0].split('_')[-1]
target_model_fc1_w = fetch_value(
'PARL_target_fc1.w_{}'.format(unique_id))
target_model_fc1_b = fetch_value(
'PARL_target_fc1.b_{}'.format(unique_id))
target_model_fc2_w = fetch_value(
'PARL_target_fc2.w_{}'.format(unique_id))
target_model_fc2_b = fetch_value(
'PARL_target_fc2.b_{}'.format(unique_id))
target_model_fc3_w = fetch_value(
'PARL_target_fc3.w_{}'.format(unique_id))
target_model_fc3_b = fetch_value(
'PARL_target_fc3.b_{}'.format(unique_id))
# updated self.target_model parameters value in numpy way
target_model_fc1_w = decay * target_model_fc1_w + (
1 - decay) * model_fc1_w
target_model_fc1_b = decay * target_model_fc1_b + (
1 - decay) * model_fc1_b
target_model_fc2_w = decay * target_model_fc2_w + (
1 - decay) * model_fc2_w
target_model_fc2_b = decay * target_model_fc2_b + (
1 - decay) * model_fc2_b
target_model_fc3_w = decay * target_model_fc3_w + (
1 - decay) * model_fc3_w
target_model_fc3_b = decay * target_model_fc3_b + (
1 - decay) * model_fc3_b
return (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w, target_model_fc3_b)
def test_sync_params_with_decay(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
target_model_output = self.target_model.predict(obs)
self.executor.run(fluid.default_startup_program())
decay = 0.9
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
def test_sync_params_with_decay_with_multi_sync(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
target_model_output = self.target_model.predict(obs)
self.executor.run(fluid.default_startup_program())
decay = 0.9
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
decay = 0.9
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
def test_sync_params_with_different_decay(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
target_model_output = self.target_model.predict(obs)
self.executor.run(fluid.default_startup_program())
decay = 0.9
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
decay = 0.8
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
def test_sync_params_with_multi_target_model(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
target_model_output = self.target_model.predict(obs)
target_model_output2 = self.target_model2.predict(obs)
self.executor.run(fluid.default_startup_program())
decay = 0.9
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
decay = 0.8
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model2, decay)
self.model.sync_params_to(self.target_model2, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output2])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
            self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)


if __name__ == '__main__':
......
@@ -15,15 +15,16 @@
 Wrappers for fluid.layers so that the layers can share parameters conveniently.
 """
-from paddle.fluid.executor import _fetch_var
-from paddle.fluid.layers import *
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.framework import Variable
+import inspect
+import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import paddle.fluid.unique_name as unique_name
-import paddle.fluid as fluid
+import six
 from copy import deepcopy
-import inspect
+from paddle.fluid.executor import _fetch_var
+from paddle.fluid.framework import Variable
+from paddle.fluid.layers import *
+from paddle.fluid.param_attr import ParamAttr
 from parl.framework.model_base import Network
@@ -62,28 +63,6 @@ class LayerFunc(object):
         self.param_attr = param_attr
         self.bias_attr = bias_attr

-    def sync_paras_to(self, target_layer, gpu_id=0):
-        """
-        Copy the paras from self to a target layer
-        """
-        ## isinstance can handle subclass
-        assert isinstance(target_layer, LayerFunc)
-        src_attrs = [self.param_attr, self.bias_attr]
-        target_attrs = [target_layer.param_attr, target_layer.bias_attr]
-        place = fluid.CPUPlace() if gpu_id < 0 \
-            else fluid.CUDAPlace(gpu_id)
-        for i, attrs in enumerate(zip(src_attrs, target_attrs)):
-            src_attr, target_attr = attrs
-            assert (src_attr and target_attr) \
-                or (not src_attr and not target_attr)
-            if not src_attr:
-                continue
-            src_var = _fetch_var(src_attr.name)
-            target_var = _fetch_var(target_attr.name, return_numpy=False)
-            target_var.set(src_var, place)
-
     def __deepcopy__(self, memo):
         cls = self.__class__
         ## __new__ won't init the class, we need to do that ourselves
@@ -92,15 +71,14 @@
         memo[id(self)] = copied

         ## first copy all content
-        for k, v in self.__dict__.iteritems():
+        for k, v in six.iteritems(self.__dict__):
             setattr(copied, k, deepcopy(v, memo))

         ## then we need to create new para names for self.param_attr and self.bias_attr
         def create_new_para_name(attr):
             if attr:
                 assert attr.name, "attr should have a name already!"
-                ## remove the last number id but keep the name key
-                name_key = "_".join(attr.name.split("_")[:-1])
+                name_key = 'PARL_target_' + attr.name
                 attr.name = unique_name.generate(name_key)

         create_new_para_name(copied.param_attr)
......
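A small sketch of what the new naming scheme implies, assuming parameters are named via ParamAttr as in the tests in this commit; the exact generated suffix is an assumption, since unique_name.generate appends a running counter:

    from copy import deepcopy
    from paddle.fluid import ParamAttr
    import parl.layers as layers

    fc = layers.fc(
        size=64,
        param_attr=ParamAttr(name='fc1.w'),
        bias_attr=ParamAttr(name='fc1.b'))
    fc_copy = deepcopy(fc)  # __deepcopy__ renames the copied parameters

    print(fc.param_name)       # 'fc1.w'
    print(fc_copy.param_name)  # e.g. 'PARL_target_fc1.w_0' (suffix may vary)

Keeping the full source name under a 'PARL_target_' prefix (instead of stripping the trailing id) makes the copied parameter traceable back to its source, which is what get_parameter_pairs relies on.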
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import unittest
 import parl.layers as layers
-import unittest
 from parl.framework.model_base import Network
......
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import unittest
+import numpy as np
+import paddle.fluid as fluid
 import parl.layers as layers
+import unittest
 from parl.framework.model_base import Network
-import paddle.fluid as fluid
-import numpy as np

 class MyNetWork(Network):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.plutils.common import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common functions of PARL framework
"""
import paddle.fluid as fluid
from paddle.fluid.executor import _fetch_var
from parl.layers.layer_wrappers import LayerFunc
from parl.framework.model_base import Network
__all__ = [
'fetch_framework_var', 'fetch_value', 'get_parameter_pairs',
'get_parameter_names'
]
def fetch_framework_var(attr_name, is_bias):
    """ Fetch framework variable according to the given attr_name.
        Returns a new variable that reuses the existing parameter,
        created through the create_parameter way.

    Args:
        attr_name: string, attr name of parameter
        is_bias: bool, decide whether the parameter is bias

    Returns:
        framework_var: framework.Variable
    """
scope = fluid.executor.global_scope()
core_var = scope.find_var(attr_name)
shape = core_var.get_tensor().shape()
framework_var = fluid.layers.create_parameter(
shape=shape,
dtype='float32',
attr=fluid.ParamAttr(name=attr_name),
is_bias=is_bias)
return framework_var
def fetch_value(attr_name):
""" Given name of ParamAttr, fetch numpy value of the parameter in global_scope
Args:
attr_name: ParamAttr name of parameter
Returns:
numpy.ndarray
"""
return _fetch_var(attr_name, return_numpy=True)
def get_parameter_pairs(src, target):
""" Recursively get pairs of parameter names between src and target
Args:
src: parl.Network/parl.LayerFunc/list/tuple/set/dict
target: parl.Network/parl.LayerFunc/list/tuple/set/dict
Returns:
param_pairs: list of all tuple(src_param_name, target_param_name, is_bias)
between src and target
"""
param_pairs = []
if isinstance(src, Network):
for attr in src.__dict__:
if not attr in target.__dict__:
continue
src_var = getattr(src, attr)
target_var = getattr(target, attr)
param_pairs.extend(get_parameter_pairs(src_var, target_var))
elif isinstance(src, LayerFunc):
param_pairs.append((src.param_attr.name, target.param_attr.name,
False))
if src.bias_attr:
param_pairs.append((src.bias_attr.name, target.bias_attr.name,
True))
elif isinstance(src, tuple) or isinstance(src, list) or isinstance(
src, set):
for src_var, target_var in zip(src, target):
param_pairs.extend(get_parameter_pairs(src_var, target_var))
elif isinstance(src, dict):
for k in src.keys():
assert k in target
param_pairs.extend(get_parameter_pairs(src[k], target[k]))
else:
# for any other type, won't be handled
pass
return param_pairs
def get_parameter_names(obj):
""" Recursively get parameter names in obj,
mainly used to get parameter names of a parl.Network
Args:
obj: parl.Network/parl.LayerFunc/list/tuple/set/dict
Returns:
parameter_names: list of string, all parameter names in obj
"""
parameter_names = []
for attr in obj.__dict__:
val = getattr(obj, attr)
if isinstance(val, Network):
parameter_names.extend(get_parameter_names(val))
elif isinstance(val, LayerFunc):
parameter_names.append(val.param_name)
if val.bias_name is not None:
parameter_names.append(val.bias_name)
elif isinstance(val, tuple) or isinstance(val, list) or isinstance(
val, set):
for x in val:
parameter_names.extend(get_parameter_names(x))
elif isinstance(val, dict):
for x in list(val.values()):
parameter_names.extend(get_parameter_names(x))
else:
# for any other type, won't be handled
pass
return parameter_names
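A short sketch of how these helpers fit together; it assumes `model` is a parl Model whose parameters are named 'fc1.w', 'fc1.b', ... and whose startup program has already been run, as in model_base_test above:

    from copy import deepcopy
    from parl.plutils import fetch_value, get_parameter_names, get_parameter_pairs

    # Assumption: `model` is an initialized parl Model, built as in the tests.
    print(get_parameter_names(model))  # e.g. ['fc1.w', 'fc1.b', ...]

    target = deepcopy(model)
    for src_name, target_name, is_bias in get_parameter_pairs(model, target):
        # each tuple maps a source parameter to its renamed deep-copied twin
        print(src_name, '->', target_name, 'is_bias:', is_bias)

    # numpy value of a parameter currently living in the global scope
    print(fetch_value('fc1.w').shape)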
@@ -13,3 +13,4 @@
 # limitations under the License.

 from parl.utils.utils import *
+from parl.utils.gputils import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from parl.utils import logger
__all__ = ['get_gpu_count']
def get_gpu_count():
""" get avaliable gpu count
Returns:
gpu_count: int
"""
gpu_count = 0
env_cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
if env_cuda_devices is not None:
assert isinstance(env_cuda_devices, str)
gpu_count = len(env_cuda_devices.split(','))
logger.info(
'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count))
else:
try:
gpu_count = str(subprocess.check_output(["nvidia-smi",
"-L"])).count('UUID')
logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count))
except Exception as e:
            logger.warn(str(e))  # e.message does not exist in Python 3
gpu_count = 0
return gpu_count
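For reference, the device-selection idiom this module enables, mirroring the pattern used in ModelBaseTest.setUp above:

    import paddle.fluid as fluid
    from parl.utils import get_gpu_count

    # Prefer the first GPU when one is visible, otherwise fall back to CPU.
    place = fluid.CUDAPlace(0) if get_gpu_count() > 0 else fluid.CPUPlace()
    executor = fluid.Executor(place)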
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import errno
 import logging
 import os
-import errno
 import os.path
-from termcolor import colored
 import sys
+from termcolor import colored

 __all__ = ['set_dir', 'get_dir', 'set_level']
......