diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 48e0a1993d07f801e65dfa54a991995c593fe475..e7a0895533dd8902df9a012ab230df2a67256483 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -24,8 +24,9 @@ add_custom_target(paddle_python ALL DEPENDS ${OUTPUT_DIR}/.timestamp) add_subdirectory(paddle/trainer_config_helpers/tests) -add_subdirectory(paddle/v2/reader/tests) add_subdirectory(paddle/v2/tests) +add_subdirectory(paddle/v2/reader/tests) +add_subdirectory(paddle/v2/plot/tests) install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/dist/ DESTINATION opt/paddle/share/wheels diff --git a/python/paddle/v2/plot/__init__.py b/python/paddle/v2/plot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..acd3013db4e6a57cd1b269266bea82a31e928397 --- /dev/null +++ b/python/paddle/v2/plot/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from plot import Ploter + +__all__ = ['Ploter'] diff --git a/python/paddle/v2/plot/plot.py b/python/paddle/v2/plot/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..6f7bd039b07db4832295c2374293bffa588eb4ef --- /dev/null +++ b/python/paddle/v2/plot/plot.py @@ -0,0 +1,79 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + + +class PlotData(object): + def __init__(self): + self.step = [] + self.value = [] + + def append(self, step, value): + self.step.append(step) + self.value.append(value) + + def reset(self): + self.step = [] + self.value = [] + + +class Ploter(object): + def __init__(self, *args): + self.__args__ = args + self.__plot_data__ = {} + for title in args: + self.__plot_data__[title] = PlotData() + # demo in notebooks will use Ploter to plot figure, but when we convert + # the ipydb to py file for testing, the import of matplotlib will make the + # script crash. So we can use `export DISABLE_PLOT=True` to disable import + # these libs + self.__disable_plot__ = os.environ.get("DISABLE_PLOT") + if not self.__plot_is_disabled__(): + import matplotlib.pyplot as plt + from IPython import display + self.plt = plt + self.display = display + + def __plot_is_disabled__(self): + return self.__disable_plot__ == "True" + + def append(self, title, step, value): + assert isinstance(title, basestring) + assert self.__plot_data__.has_key(title) + data = self.__plot_data__[title] + assert isinstance(data, PlotData) + data.append(step, value) + + def plot(self): + if self.__plot_is_disabled__(): + return + + titles = [] + for title in self.__args__: + data = self.__plot_data__[title] + assert isinstance(data, PlotData) + if len(data.step) > 0: + titles.append(title) + self.plt.plot(data.step, data.value) + self.plt.legend(titles, loc='upper left') + self.display.clear_output(wait=True) + self.display.display(self.plt.gcf()) + self.plt.gcf().clear() + + def reset(self): + for key in self.__plot_data__: + data = self.__plot_data__[key] + assert isinstance(data, PlotData) + data.reset() diff --git a/python/paddle/v2/plot/plot_curve.py b/python/paddle/v2/plot/plot_curve.py deleted file mode 100644 index 0f62674cb2baad9e4ecd9f6655f7e2dc00173dc6..0000000000000000000000000000000000000000 --- a/python/paddle/v2/plot/plot_curve.py +++ /dev/null @@ -1,48 +0,0 @@ -from IPython import display -import os - - -class PlotCost(object): - """ - append train and test cost in event_handle and then call plot. - """ - - def __init__(self): - self.train_costs = ([], []) - self.test_costs = ([], []) - - self.__disable_plot__ = os.environ.get("DISABLE_PLOT") - if not self.__plot_is_disabled__(): - import matplotlib.pyplot as plt - self.plt = plt - - def __plot_is_disabled__(self): - return self.__disable_plot__ == "True" - - def plot(self): - if self.__plot_is_disabled__(): - return - - self.plt.plot(*self.train_costs) - self.plt.plot(*self.test_costs) - title = [] - if len(self.train_costs[0]) > 0: - title.append('Train Cost') - if len(self.test_costs[0]) > 0: - title.append('Test Cost') - self.plt.legend(title, loc='upper left') - display.clear_output(wait=True) - display.display(self.plt.gcf()) - self.plt.gcf().clear() - - def append_train_cost(self, step, cost): - self.train_costs[0].append(step) - self.train_costs[1].append(cost) - - def append_test_cost(self, step, cost): - self.test_costs[0].append(step) - self.test_costs[1].append(cost) - - def reset(self): - self.train_costs = ([], []) - self.test_costs = ([], []) diff --git a/python/paddle/v2/plot/tests/CMakeLists.txt b/python/paddle/v2/plot/tests/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..da550a178ce0fe4832b640e0c23505279dedd27a --- /dev/null +++ b/python/paddle/v2/plot/tests/CMakeLists.txt @@ -0,0 +1,3 @@ +add_test(NAME test_ploter + COMMAND bash ${PROJ_ROOT}/python/paddle/v2/plot/tests/run_tests.sh + ${PYTHON_EXECUTABLE}) diff --git a/python/paddle/v2/plot/tests/__init__.py b/python/paddle/v2/plot/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d1abfc08f19505a9010e924e34074e5bc3cc0571 --- /dev/null +++ b/python/paddle/v2/plot/tests/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import test_ploter + +__all__ = ['test_ploter.py'] diff --git a/python/paddle/v2/plot/tests/run_tests.sh b/python/paddle/v2/plot/tests/run_tests.sh new file mode 100755 index 0000000000000000000000000000000000000000..9c1a4a71ce43f285c4f970eddf6af46a2821a40a --- /dev/null +++ b/python/paddle/v2/plot/tests/run_tests.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +pushd `dirname $0` > /dev/null +SCRIPTPATH=$PWD +popd > /dev/null + +cd $SCRIPTPATH +$1 -m pip install ../../../../../paddle/dist/*.whl + +export DISABLE_PLOT="True" +test_list="test_ploter.py" + +export PYTHONPATH=$PWD/../../../../../python/ + +for fn in $test_list +do + echo "test $fn" + $1 $fn + if [ $? -ne 0 ]; then + exit 1 + fi +done diff --git a/python/paddle/v2/plot/tests/test_ploter.py b/python/paddle/v2/plot/tests/test_ploter.py new file mode 100644 index 0000000000000000000000000000000000000000..a75f853ed933dfce651faf758f71feca7cd8d328 --- /dev/null +++ b/python/paddle/v2/plot/tests/test_ploter.py @@ -0,0 +1,40 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from paddle.v2.plot import Ploter + + +class TestCommon(unittest.TestCase): + def test_append(self): + title1 = "title1" + title2 = "title2" + plot_test = Ploter(title1, title2) + plot_test.append(title1, 1, 2) + plot_test.append(title1, 2, 5) + plot_test.append(title2, 3, 4) + self.assertEqual(plot_test.__plot_data__[title1].step, [1, 2]) + self.assertEqual(plot_test.__plot_data__[title1].value, [2, 5]) + self.assertEqual(plot_test.__plot_data__[title2].step, [3]) + self.assertEqual(plot_test.__plot_data__[title2].value, [4]) + plot_test.reset() + self.assertEqual(plot_test.__plot_data__[title1].step, []) + self.assertEqual(plot_test.__plot_data__[title1].value, []) + self.assertEqual(plot_test.__plot_data__[title2].step, []) + self.assertEqual(plot_test.__plot_data__[title2].value, []) + + +if __name__ == '__main__': + unittest.main()