Commit e11b40c5, authored by Hongsheng Zeng, committed by Bo Zhou

sync paras in program, fix deepcopy bug, python3 compatibility (#28)

* sync paras in program, fix deepcopy bug, python3 compatibility

* refactor code, add plutils directory, clean import order

* remove old comment

* refine comment

* fix codestyle

* cache sync program, add gputils module, refine model_base unittest

* fix codestyle

* refine sync params cache

* add fetch_value module
Parent 942c3c5c
......@@ -14,8 +14,8 @@
import paddle.fluid as fluid
import parl.layers as layers
from abc import ABCMeta, abstractmethod
from parl.framework.model_base import Network, Model
__all__ = ['Algorithm']
......
......@@ -15,8 +15,9 @@
Base classes to define a Network and a Model.
"""
import hashlib
import paddle.fluid as fluid
from abc import ABCMeta, abstractmethod
from parl.utils.utils import has_func
__all__ = ['Network', 'Model']
......@@ -26,36 +27,64 @@ class Network(object):
A Network is an unordered set of LayerFuncs or Networks.
"""
def sync_paras_to(self, target_net):
assert target_net is not self, "cannot copy between identical networks"
assert isinstance(target_net, Network)
assert self.__class__.__name__ == target_net.__class__.__name__, \
"must be the same class for para syncing!"
for attr in self.__dict__:
if attr not in target_net.__dict__:
continue
val = getattr(self, attr)
target_val = getattr(target_net, attr)
assert type(val) == type(target_val), \
"[Error]sync_paras_to failed, \
ensure that the target network is a deep copy of the source network"
### TODO: sync paras recursively
if has_func(val, 'sync_paras_to'):
val.sync_paras_to(target_val)
elif isinstance(val, (tuple, list, set)):
for v, tv in zip(val, target_val):
v.sync_paras_to(tv)
elif isinstance(val, dict):
for k in val.keys():
assert k in target_val
val[k].sync_paras_to(target_val[k])
else:
# for any other type, we do not copy
pass
def sync_params_to(self, target_net, gpu_id=0, decay=0.0):
"""
Synchronize parameters of the current network to target_net.
Args:
target_net: Network object deep-copied from the source network
gpu_id: gpu id of target_net
decay: Float. The decay to use:
target_net_weights = decay * target_net_weights + (1 - decay) * source_net_weights
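e.g. with decay=0.9, a target weight of 1.0 and a source weight of 0.0
become 0.9 * 1.0 + 0.1 * 0.0 = 0.9 after one sync.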
"""
args_hash_id = hashlib.md5('{}_{}_{}'.format(
id(target_net), gpu_id, decay).encode('utf-8')).hexdigest()
has_cached = False
try:
if self._cached_id == args_hash_id:
has_cached = True
except AttributeError:
has_cached = False
if not has_cached:
# The cached program cannot be reused; create a new program
self._cached_id = args_hash_id
assert target_net is not self, "cannot copy between identical networks"
assert isinstance(target_net, Network)
assert self.__class__.__name__ == target_net.__class__.__name__, \
"must be the same class for params syncing!"
assert 0.0 <= decay <= 1.0
# Resolve Circular Imports
from parl.plutils import get_parameter_pairs, fetch_framework_var
param_pairs = get_parameter_pairs(self, target_net)
place = fluid.CPUPlace() if gpu_id < 0 \
else fluid.CUDAPlace(gpu_id)
self._cached_fluid_executor = fluid.Executor(place)
self._cached_sync_params_program = fluid.Program()
with fluid.program_guard(self._cached_sync_params_program):
for (src_var_name, target_var_name, is_bias) in param_pairs:
src_var = fetch_framework_var(src_var_name, is_bias)
target_var = fetch_framework_var(target_var_name, is_bias)
fluid.layers.assign(
decay * target_var + (1 - decay) * src_var, target_var)
self._cached_fluid_executor.run(self._cached_sync_params_program)
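# A minimal usage sketch (hypothetical model names; assumes the startup
# program has already been run, as in the unittests below):
#
#   target_model = deepcopy(model)
#   model.sync_params_to(target_model)              # hard copy
#   model.sync_params_to(target_model, decay=0.99)  # soft / moving-average update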
@property
def parameter_names(self):
""" param_attr names of all parameters in Network,
only parameter created by parl.layers included
Returns:
list of string, param_attr names of all parameters
"""
# Resolve Circular Imports
from parl.plutils import get_parameter_names
return get_parameter_names(self)
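# e.g. for a network whose layers were built with ParamAttr(name='fc1.w')
# and ParamAttr(name='fc1.b'), parameter_names contains 'fc1.w' and 'fc1.b'
# (see the model_base unittest below).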
class Model(Network):
......@@ -80,7 +109,7 @@ class Model(Network):
Note that it's the model structure that is copied from the initial actor;
parameters in the initial model haven't been copied to the target model.
To copy parameters, you must explicitly use the sync_params_to function after the program is initialized.
"""
__metaclass__ = ABCMeta
......
......@@ -13,10 +13,10 @@
# limitations under the License.
import parl.layers as layers
from abc import ABCMeta, abstractmethod
from paddle.fluid.framework import Variable
from parl.layers import common_functions as comf
from paddle.fluid.framework import convert_np_dtype_to_dtype_
class PolicyDistribution(object):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.model_base import Model
from parl.framework.algorithm_base import Algorithm
from copy import deepcopy
import numpy as np
import unittest
import sys
class Value(Model):
def __init__(self, obs_dim, act_dim):
self.obs_dim = obs_dim
self.act_dim = act_dim
self.fc1 = layers.fc(size=256, act='relu')
self.fc2 = layers.fc(size=128, act='relu')
self.fc3 = layers.fc(size=self.act_dim)
def value(self, obs):
out = self.fc1(obs)
out = self.fc2(out)
value = self.fc3(out)
return value
class QLearning(Algorithm):
def __init__(self, critic_model):
self.critic_model = critic_model
self.target_model = deepcopy(critic_model)
def define_predict(self, obs):
self.q_value = self.critic_model.value(obs)
self.q_target_value = self.target_model.value(obs)
class AlgorithmBaseTest(unittest.TestCase):
def test_sync_paras_in_one_program(self):
critic_model = Value(obs_dim=4, act_dim=1)
dqn = QLearning(critic_model)
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
dqn.define_predict(obs)
place = fluid.CUDAPlace(0)
executor = fluid.Executor(place)
executor.run(fluid.default_startup_program())
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = executor.run(
pred_program,
feed={'obs': x},
fetch_list=[dqn.q_value, dqn.q_target_value])
self.assertNotEqual(outputs[0].flatten(), outputs[1].flatten())
critic_model.sync_paras_to(dqn.target_model)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = executor.run(
pred_program,
feed={'obs': x},
fetch_list=[dqn.q_value, dqn.q_target_value])
self.assertEqual(outputs[0].flatten(), outputs[1].flatten())
def test_sync_paras_among_programs(self):
critic_model = Value(obs_dim=4, act_dim=1)
dqn = QLearning(critic_model)
dqn_2 = deepcopy(dqn)
pred_program = fluid.Program()
pred_program_2 = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
dqn.define_predict(obs)
# algorithm #2
with fluid.program_guard(pred_program_2):
obs_2 = layers.data(name='obs_2', shape=[4], dtype='float32')
dqn_2.define_predict(obs_2)
place = fluid.CUDAPlace(0)
executor = fluid.Executor(place)
executor.run(fluid.default_startup_program())
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = executor.run(
pred_program, feed={'obs': x}, fetch_list=[dqn.q_value])
outputs_2 = executor.run(
pred_program_2, feed={'obs_2': x}, fetch_list=[dqn_2.q_value])
self.assertNotEqual(outputs[0].flatten(), outputs_2[0].flatten())
dqn.critic_model.sync_paras_to(dqn_2.critic_model)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = executor.run(
pred_program, feed={'obs': x}, fetch_list=[dqn.q_value])
outputs_2 = executor.run(
pred_program_2, feed={'obs_2': x}, fetch_list=[dqn_2.q_value])
self.assertEqual(outputs[0].flatten(), outputs_2[0].flatten())
if __name__ == '__main__':
unittest.main()
......@@ -12,31 +12,405 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
import unittest
from copy import deepcopy
from paddle.fluid import ParamAttr
from parl.framework.model_base import Model
from parl.utils import get_gpu_count
from parl.plutils import fetch_value
class TestModel(Model):
def __init__(self):
self.fc1 = layers.fc(
size=256,
act=None,
param_attr=ParamAttr(name='fc1.w'),
bias_attr=ParamAttr(name='fc1.b'))
self.fc2 = layers.fc(
size=128,
act=None,
param_attr=ParamAttr(name='fc2.w'),
bias_attr=ParamAttr(name='fc2.b'))
self.fc3 = layers.fc(
size=1,
act=None,
param_attr=ParamAttr(name='fc3.w'),
bias_attr=ParamAttr(name='fc3.b'))
def predict(self, obs):
out = self.fc1(obs)
out = self.fc2(out)
out = self.fc3(out)
return out
class ModelBaseTest(unittest.TestCase):
def setUp(self):
self.model = TestModel()
self.target_model = deepcopy(self.model)
self.target_model2 = deepcopy(self.model)
gpu_count = get_gpu_count()
if gpu_count > 0:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
self.executor = fluid.Executor(place)
def test_network_copy(self):
self.assertNotEqual(self.model.fc1.param_name,
self.target_model.fc1.param_name)
self.assertNotEqual(self.model.fc1.bias_name,
self.target_model.fc1.bias_name)
self.assertNotEqual(self.model.fc2.param_name,
self.target_model.fc2.param_name)
self.assertNotEqual(self.model.fc2.bias_name,
self.target_model.fc2.bias_name)
self.assertNotEqual(self.model.fc3.param_name,
self.target_model.fc3.param_name)
self.assertNotEqual(self.model.fc3.bias_name,
self.target_model.fc3.bias_name)
def test_network_copy_with_multi_copy(self):
self.assertNotEqual(self.target_model.fc1.param_name,
self.target_model2.fc1.param_name)
self.assertNotEqual(self.target_model.fc1.bias_name,
self.target_model2.fc1.bias_name)
self.assertNotEqual(self.target_model.fc2.param_name,
self.target_model2.fc2.param_name)
self.assertNotEqual(self.target_model.fc2.bias_name,
self.target_model2.fc2.bias_name)
self.assertNotEqual(self.target_model.fc3.param_name,
self.target_model2.fc3.param_name)
self.assertNotEqual(self.target_model.fc3.bias_name,
self.target_model2.fc3.bias_name)
def test_network_parameter_names(self):
self.assertSetEqual(
set(self.model.parameter_names),
set(['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b']))
def test_sync_params_in_one_program(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
target_model_output = self.target_model.predict(obs)
self.executor.run(fluid.default_startup_program())
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[model_output, target_model_output])
self.assertNotEqual(outputs[0].flatten(), outputs[1].flatten())
self.model.sync_params_to(self.target_model)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[model_output, target_model_output])
self.assertEqual(outputs[0].flatten(), outputs[1].flatten())
def test_sync_params_among_programs(self):
pred_program = fluid.Program()
pred_program_2 = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
# program 2
with fluid.program_guard(pred_program_2):
obs = layers.data(name='obs', shape=[4], dtype='float32')
target_model_output = self.target_model.predict(obs)
self.executor.run(fluid.default_startup_program())
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program, feed={'obs': x}, fetch_list=[model_output])
outputs_2 = self.executor.run(
pred_program_2,
feed={'obs': x},
fetch_list=[target_model_output])
self.assertNotEqual(outputs[0].flatten(), outputs_2[0].flatten())
self.model.sync_params_to(self.target_model)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
outputs = self.executor.run(
pred_program, feed={'obs': x}, fetch_list=[model_output])
outputs_2 = self.executor.run(
pred_program_2,
feed={'obs': x},
fetch_list=[target_model_output])
self.assertEqual(outputs[0].flatten(), outputs_2[0].flatten())
def _numpy_update(self, target_model, decay):
model_fc1_w = fetch_value('fc1.w')
model_fc1_b = fetch_value('fc1.b')
model_fc2_w = fetch_value('fc2.w')
model_fc2_b = fetch_value('fc2.b')
model_fc3_w = fetch_value('fc3.w')
model_fc3_b = fetch_value('fc3.b')
unique_id = target_model.parameter_names[0].split('_')[-1]
target_model_fc1_w = fetch_value(
'PARL_target_fc1.w_{}'.format(unique_id))
target_model_fc1_b = fetch_value(
'PARL_target_fc1.b_{}'.format(unique_id))
target_model_fc2_w = fetch_value(
'PARL_target_fc2.w_{}'.format(unique_id))
target_model_fc2_b = fetch_value(
'PARL_target_fc2.b_{}'.format(unique_id))
target_model_fc3_w = fetch_value(
'PARL_target_fc3.w_{}'.format(unique_id))
target_model_fc3_b = fetch_value(
'PARL_target_fc3.b_{}'.format(unique_id))
# update the self.target_model parameter values the numpy way
target_model_fc1_w = decay * target_model_fc1_w + (
1 - decay) * model_fc1_w
target_model_fc1_b = decay * target_model_fc1_b + (
1 - decay) * model_fc1_b
target_model_fc2_w = decay * target_model_fc2_w + (
1 - decay) * model_fc2_w
target_model_fc2_b = decay * target_model_fc2_b + (
1 - decay) * model_fc2_b
target_model_fc3_w = decay * target_model_fc3_w + (
1 - decay) * model_fc3_w
target_model_fc3_b = decay * target_model_fc3_b + (
1 - decay) * model_fc3_b
return (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w, target_model_fc3_b)
def test_sync_params_with_decay(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
target_model_output = self.target_model.predict(obs)
self.executor.run(fluid.default_startup_program())
decay = 0.9
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
def test_sync_params_with_decay_with_multi_sync(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
target_model_output = self.target_model.predict(obs)
self.executor.run(fluid.default_startup_program())
decay = 0.9
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
decay = 0.9
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
def test_sync_params_with_different_decay(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
target_model_output = self.target_model.predict(obs)
self.executor.run(fluid.default_startup_program())
decay = 0.9
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
decay = 0.8
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
def test_sync_params_with_multi_target_model(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
model_output = self.model.predict(obs)
target_model_output = self.target_model.predict(obs)
target_model_output2 = self.target_model2.predict(obs)
self.executor.run(fluid.default_startup_program())
decay = 0.9
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
decay = 0.8
# update in numpy way
(target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model2, decay)
self.model.sync_params_to(self.target_model2, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
x = np.expand_dims(random_obs[i], axis=0)
real_target_outputs = self.executor.run(
pred_program,
feed={'obs': x},
fetch_list=[target_model_output2])[0]
# Ideal target output
out_np = np.dot(x, target_model_fc1_w) + target_model_fc1_b
out_np = np.dot(out_np, target_model_fc2_w) + target_model_fc2_b
out_np = np.dot(out_np, target_model_fc3_w) + target_model_fc3_b
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
if __name__ == '__main__':
......
......@@ -15,15 +15,16 @@
Wrappers for fluid.layers so that the layers can share parameters conveniently.
"""
import inspect
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.unique_name as unique_name
import six
from copy import deepcopy
from paddle.fluid.executor import _fetch_var
from paddle.fluid.framework import Variable
from paddle.fluid.layers import *
from paddle.fluid.param_attr import ParamAttr
from parl.framework.model_base import Network
......@@ -62,28 +63,6 @@ class LayerFunc(object):
self.param_attr = param_attr
self.bias_attr = bias_attr
def sync_paras_to(self, target_layer, gpu_id=0):
"""
Copy the parameters from self to a target layer
"""
## isinstance can handle subclasses
assert isinstance(target_layer, LayerFunc)
src_attrs = [self.param_attr, self.bias_attr]
target_attrs = [target_layer.param_attr, target_layer.bias_attr]
place = fluid.CPUPlace() if gpu_id < 0 \
else fluid.CUDAPlace(gpu_id)
for src_attr, target_attr in zip(src_attrs, target_attrs):
assert (src_attr and target_attr) \
or (not src_attr and not target_attr)
if not src_attr:
continue
src_var = _fetch_var(src_attr.name)
target_var = _fetch_var(target_attr.name, return_numpy=False)
target_var.set(src_var, place)
def __deepcopy__(self, memo):
cls = self.__class__
## __new__ won't init the class, we need to do that ourselves
......@@ -92,15 +71,14 @@ class LayerFunc(object):
memo[id(self)] = copied
## first copy all content
for k, v in six.iteritems(self.__dict__):
setattr(copied, k, deepcopy(v, memo))
## then we need to create new para names for self.param_attr and self.bias_attr
def create_new_para_name(attr):
if attr:
assert attr.name, "attr should have a name already!"
name_key = 'PARL_target_' + attr.name
attr.name = unique_name.generate(name_key)
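# e.g. an attr named 'fc1.w' becomes something like 'PARL_target_fc1.w_0';
# the numeric suffix comes from unique_name.generate and differs per copy
# (the model_base unittest relies on this naming scheme).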
create_new_para_name(copied.param_attr)
......
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import parl.layers as layers
import unittest
from parl.framework.model_base import Network
......
......@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
import unittest
from parl.framework.model_base import Network
class MyNetWork(Network):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.plutils.common import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common functions of PARL framework
"""
import paddle.fluid as fluid
from paddle.fluid.executor import _fetch_var
from parl.layers.layer_wrappers import LayerFunc
from parl.framework.model_base import Network
__all__ = [
'fetch_framework_var', 'fetch_value', 'get_parameter_pairs',
'get_parameter_names'
]
def fetch_framework_var(attr_name, is_bias):
""" Fetch a framework variable according to the given attr_name.
Returns a new variable that reuses the existing parameter, created via create_parameter.
Args:
attr_name: string, attr name of the parameter
is_bias: bool, whether the parameter is a bias
Returns:
framework_var: framework.Variable
"""
scope = fluid.executor.global_scope()
core_var = scope.find_var(attr_name)
shape = core_var.get_tensor().shape()
framework_var = fluid.layers.create_parameter(
shape=shape,
dtype='float32',
attr=fluid.ParamAttr(name=attr_name),
is_bias=is_bias)
return framework_var
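# A sketch of how fetch_framework_var is used when building a sync program
# (mirrors Network.sync_params_to above; the parameter names are illustrative):
#
#   with fluid.program_guard(sync_program):
#       src_var = fetch_framework_var('fc1.w', is_bias=False)
#       target_var = fetch_framework_var('PARL_target_fc1.w_0', is_bias=False)
#       fluid.layers.assign(0.9 * target_var + 0.1 * src_var, target_var)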
def fetch_value(attr_name):
""" Given the name of a ParamAttr, fetch the numpy value of the parameter from the global scope.
Args:
attr_name: ParamAttr name of the parameter
Returns:
numpy.ndarray
"""
return _fetch_var(attr_name, return_numpy=True)
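# e.g. fc1_weights = fetch_value('fc1.w') returns the current numpy value of
# the parameter created with ParamAttr(name='fc1.w'), assuming the startup
# program has been run (see the _numpy_update helper in the unittest above).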
def get_parameter_pairs(src, target):
""" Recursively get pairs of parameter names between src and target
Args:
src: parl.Network/parl.LayerFunc/list/tuple/set/dict
target: parl.Network/parl.LayerFunc/list/tuple/set/dict
Returns:
param_pairs: list of tuples (src_param_name, target_param_name, is_bias)
between src and target
"""
param_pairs = []
if isinstance(src, Network):
for attr in src.__dict__:
if attr not in target.__dict__:
continue
src_var = getattr(src, attr)
target_var = getattr(target, attr)
param_pairs.extend(get_parameter_pairs(src_var, target_var))
elif isinstance(src, LayerFunc):
param_pairs.append((src.param_attr.name, target.param_attr.name,
False))
if src.bias_attr:
param_pairs.append((src.bias_attr.name, target.bias_attr.name,
True))
elif isinstance(src, (tuple, list, set)):
for src_var, target_var in zip(src, target):
param_pairs.extend(get_parameter_pairs(src_var, target_var))
elif isinstance(src, dict):
for k in src.keys():
assert k in target
param_pairs.extend(get_parameter_pairs(src[k], target[k]))
else:
# for any other type, won't be handled
pass
return param_pairs
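# e.g. for the TestModel in the unittest above, get_parameter_pairs(model,
# target_model) yields tuples such as ('fc1.w', 'PARL_target_fc1.w_0', False)
# and ('fc1.b', 'PARL_target_fc1.b_0', True), where the trailing id is
# whatever unique_name generated for the deep copy.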
def get_parameter_names(obj):
""" Recursively get parameter names in obj,
mainly used to get parameter names of a parl.Network
Args:
obj: parl.Network/parl.LayerFunc/list/tuple/set/dict
Returns:
parameter_names: list of string, all parameter names in obj
"""
parameter_names = []
for attr in obj.__dict__:
val = getattr(obj, attr)
if isinstance(val, Network):
parameter_names.extend(get_parameter_names(val))
elif isinstance(val, LayerFunc):
parameter_names.append(val.param_name)
if val.bias_name is not None:
parameter_names.append(val.bias_name)
elif isinstance(val, (tuple, list, set)):
for x in val:
parameter_names.extend(get_parameter_names(x))
elif isinstance(val, dict):
for x in list(val.values()):
parameter_names.extend(get_parameter_names(x))
else:
# for any other type, won't be handled
pass
return parameter_names
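# e.g. get_parameter_names(model) for the TestModel above returns
# ['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b'].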
......@@ -13,3 +13,4 @@
# limitations under the License.
from parl.utils.utils import *
from parl.utils.gputils import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from parl.utils import logger
__all__ = ['get_gpu_count']
def get_gpu_count():
""" Get the available GPU count.
Returns:
gpu_count: int
"""
gpu_count = 0
env_cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
if env_cuda_devices is not None:
assert isinstance(env_cuda_devices, str)
# an empty CUDA_VISIBLE_DEVICES hides all GPUs
gpu_count = 0 if env_cuda_devices == '' else len(
env_cuda_devices.split(','))
logger.info(
'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count))
else:
try:
gpu_count = str(subprocess.check_output(["nvidia-smi",
"-L"])).count('UUID')
logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count))
except Exception as e:
# Python 3 compatibility: Exception has no .message attribute
logger.warn(str(e))
gpu_count = 0
return gpu_count
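# Typical use, as in the model_base unittest's setUp:
#
#   gpu_count = get_gpu_count()
#   place = fluid.CUDAPlace(0) if gpu_count > 0 else fluid.CPUPlace()
#   executor = fluid.Executor(place)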
......@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import errno
import logging
import os
import os.path
import sys
from termcolor import colored
__all__ = ['set_dir', 'get_dir', 'set_level']
......