From 2e5d9cb0e9f610a49cc7c2e75deaa17ad4fcf7ac Mon Sep 17 00:00:00 2001 From: Hongsheng Zeng Date: Mon, 18 May 2020 12:19:31 +0800 Subject: [PATCH] add unittest of get_weights set_weights with create_parameter (#262) --- parl/core/fluid/tests/model_base_test.py | 37 ++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/parl/core/fluid/tests/model_base_test.py b/parl/core/fluid/tests/model_base_test.py index 1656366..faa1368 100644 --- a/parl/core/fluid/tests/model_base_test.py +++ b/parl/core/fluid/tests/model_base_test.py @@ -690,6 +690,43 @@ class ModelBaseTest(unittest.TestCase): self.executor.run( pred_program, feed={'obs': x}, fetch_list=[model_output]) + def test_get_weights_set_weights_with_create_parameter(self): + model1 = TestModel2() + model2 = TestModel2() + + pred_program = fluid.Program() + with fluid.program_guard(pred_program): + obs = layers.data(name='obs', shape=[100], dtype='float32') + model1_output = model1.predict(obs) + model2_output = model2.predict(obs) + + self.executor.run(fluid.default_startup_program()) + + N = 10 + random_obs = np.random.random(size=(N, 100)).astype('float32') + for i in range(N): + x = np.expand_dims(random_obs[i], axis=0) + outputs = self.executor.run( + pred_program, + feed={'obs': x}, + fetch_list=[model1_output, model2_output]) + self.assertNotEqual( + np.sum(outputs[0].flatten()), np.sum(outputs[1].flatten())) + + # pass parameters of self.model to model2 + params = model1.get_weights() + model2.set_weights(params) + + random_obs = np.random.random(size=(N, 100)).astype('float32') + for i in range(N): + x = np.expand_dims(random_obs[i], axis=0) + outputs = self.executor.run( + pred_program, + feed={'obs': x}, + fetch_list=[model1_output, model2_output]) + self.assertEqual( + np.sum(outputs[0].flatten()), np.sum(outputs[1].flatten())) + if __name__ == '__main__': unittest.main() -- GitLab