提交 2ddf4c11 编写于 作者: L LI Yunxiang 提交者: Bo Zhou

torch test env (#156)

* torch test env

* Update build.sh

* update torch unit test
上级 757cc391
...@@ -15,6 +15,6 @@ ...@@ -15,6 +15,6 @@
# A dev image based on paddle production image # A dev image based on paddle production image
FROM parl/parl-test:cuda9.0-cudnn7-v1 FROM parl/parl-test:cuda9.0-cudnn7-v2
COPY ./requirements.txt /root/ COPY ./requirements.txt /root/
...@@ -82,7 +82,7 @@ function run_test_with_cpu() { ...@@ -82,7 +82,7 @@ function run_test_with_cpu() {
if [ $# -eq 1 ];then if [ $# -eq 1 ];then
cmake .. cmake ..
else else
cmake .. -DIS_TESTING_SERIALLY=ON cmake .. -$2=ON
fi fi
cat <<EOF cat <<EOF
===================================================== =====================================================
...@@ -145,20 +145,30 @@ function main() { ...@@ -145,20 +145,30 @@ function main() {
;; ;;
test) test)
# test code compability in environments with various python versions # test code compability in environments with various python versions
declare -a envs=("py27" "py36" "py37") declare -a envs=("py36_torch" "py37_torch" "py27" "py36" "py37")
for env in "${envs[@]}";do for env in "${envs[@]}";do
cd /work cd /work
source ~/.bashrc source ~/.bashrc
export PATH="/root/miniconda3/bin:$PATH" export PATH="/root/miniconda3/bin:$PATH"
source activate $env source activate $env
python -m pip install --upgrade pip
echo ======================================== echo ========================================
echo Running tests in $env .. echo Running tests in $env ..
echo `which pip` echo `which pip`
echo ======================================== echo ========================================
pip install . pip install .
pip install -r .teamcity/requirements.txt if [ \( $env == "py27" -o $env == "py36" -o $env == "py37" \) ]
run_test_with_cpu $env then
run_test_with_cpu $env "DIS_TESTING_SERIALLY" pip install -r .teamcity/requirements.txt
run_test_with_cpu $env
run_test_with_cpu $env "DIS_TESTING_SERIALLY"
else
echo ========================================
echo "in torch environment"
echo ========================================
pip install -r .teamcity/requirements_torch.txt
run_test_with_cpu $env "DIS_TESTING_TORCH"
fi
done done
run_test_with_gpu run_test_with_gpu
......
...@@ -4,4 +4,3 @@ gym ...@@ -4,4 +4,3 @@ gym
details details
parameterized parameterized
timeout_decorator timeout_decorator
torch==1.2.0
# requirements for torch unittest
gym
details
parameterized
timeout_decorator
...@@ -21,6 +21,7 @@ option(IS_TESTING_SERIALLY "testing scripts that cannot run in parallel" OFF) ...@@ -21,6 +21,7 @@ option(IS_TESTING_SERIALLY "testing scripts that cannot run in parallel" OFF)
option(IS_TESTING_IMPORT "testing import parl" OFF) option(IS_TESTING_IMPORT "testing import parl" OFF)
option(IS_TESTING_DOCS "testing compling the docs" OFF) option(IS_TESTING_DOCS "testing compling the docs" OFF)
option(IS_TESTING_GPU "testing GPU environment" OFF) option(IS_TESTING_GPU "testing GPU environment" OFF)
option(IS_TESTING_TORCH "testing torch parts" OFF)
set(PADDLE_PYTHON_PATH "" CACHE STRING "Python path to PaddlePaddle Fluid") set(PADDLE_PYTHON_PATH "" CACHE STRING "Python path to PaddlePaddle Fluid")
...@@ -64,6 +65,12 @@ if (WITH_TESTING) ...@@ -64,6 +65,12 @@ if (WITH_TESTING)
foreach(src ${TEST_OPS}) foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH}) py_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endforeach() endforeach()
elseif (IS_TESTING_TORCH)
file(GLOB_RECURSE TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_test_torch.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endforeach()
else () else ()
file(GLOB_RECURSE TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_test.py") file(GLOB_RECURSE TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_test.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import unittest
import os
import torch
import torch.nn as nn
import torch.optim as optim
from parl.core.torch.model import Model
from parl.core.torch.algorithm import Algorithm
from parl.core.torch.agent import Agent
class TestModel(Model):
def __init__(self):
super(TestModel, self).__init__()
self.fc1 = nn.Linear(10, 256)
self.fc2 = nn.Linear(256, 1)
def forward(self, obs):
out = self.fc1(obs)
out = self.fc2(out)
return out
class TestAlgorithm(Algorithm):
def __init__(self, model):
self.model = model
self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
def predict(self, obs):
return self.model(obs)
def learn(self, obs, label):
pred_output = self.model(obs)
cost = (pre_output - obs).pow(2)
self.optimizer.zero_grad()
cost.backward()
self.optimizer.step()
return cost.item()
class TestAgent(Agent):
def __init__(self, algorithm):
self.alg = algorithm
def learn(self, obs, label):
cost = self.alg.lean(obs, label)
def predict(self, obs):
return self.alg.predict(obs)
class AgentBaseTest(unittest.TestCase):
def setUp(self):
self.model = TestModel()
self.alg = TestAlgorithm(self.model)
def test_agent(self):
agent = TestAgent(self.alg)
obs = torch.randn(3, 10)
output = agent.predict(obs)
self.assertIsNotNone(output)
def test_save(self):
agent = TestAgent(self.alg)
obs = torch.randn(3, 10)
save_path1 = './model.ckpt'
save_path2 = './my_model/model-2.ckpt'
agent.save(save_path1)
agent.save(save_path2)
self.assertTrue(os.path.exists(save_path1))
self.assertTrue(os.path.exists(save_path2))
def test_restore(self):
agent = TestAgent(self.alg)
obs = torch.randn(3, 10)
output = agent.predict(obs)
save_path1 = './model.ckpt'
previous_output = agent.predict(obs).detach().cpu().numpy()
agent.save(save_path1)
agent.restore(save_path1)
current_output = agent.predict(obs).detach().cpu().numpy()
np.testing.assert_equal(current_output, previous_output)
if __name__ == '__main__':
unittest.main()
...@@ -20,12 +20,10 @@ import torch ...@@ -20,12 +20,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from parl.core.torch.model import Model import parl
from parl.core.torch.algorithm import Algorithm
from parl.core.torch.agent import Agent
class TestModel(Model): class TestModel(parl.Model):
def __init__(self): def __init__(self):
super(TestModel, self).__init__() super(TestModel, self).__init__()
self.fc1 = nn.Linear(10, 256) self.fc1 = nn.Linear(10, 256)
...@@ -37,7 +35,7 @@ class TestModel(Model): ...@@ -37,7 +35,7 @@ class TestModel(Model):
return out return out
class TestAlgorithm(Algorithm): class TestAlgorithm(parl.Algorithm):
def __init__(self, model): def __init__(self, model):
self.model = model self.model = model
self.optimizer = optim.Adam(self.model.parameters(), lr=0.001) self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
...@@ -47,19 +45,19 @@ class TestAlgorithm(Algorithm): ...@@ -47,19 +45,19 @@ class TestAlgorithm(Algorithm):
def learn(self, obs, label): def learn(self, obs, label):
pred_output = self.model(obs) pred_output = self.model(obs)
cost = (pre_output - obs).pow(2) cost = (pred_output - obs).pow(2)
self.optimizer.zero_grad() self.optimizer.zero_grad()
cost.backward() cost.backward()
self.optimizer.step() self.optimizer.step()
return cost.item() return cost.item()
class TestAgent(Agent): class TestAgent(parl.Agent):
def __init__(self, algorithm): def __init__(self, algorithm):
self.alg = algorithm self.alg = algorithm
def learn(self, obs, label): def learn(self, obs, label):
cost = self.alg.lean(obs, label) cost = self.alg.learn(obs, label)
def predict(self, obs): def predict(self, obs):
return self.alg.predict(obs) return self.alg.predict(obs)
......
...@@ -22,12 +22,10 @@ import torch.nn as nn ...@@ -22,12 +22,10 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from parl.utils import get_gpu_count from parl.utils import get_gpu_count
from parl.core.torch.model import Model import parl
from parl.core.torch.algorithm import Algorithm
from parl.core.torch.agent import Agent
class TestModel(Model): class TestModel(parl.Model):
def __init__(self): def __init__(self):
super(TestModel, self).__init__() super(TestModel, self).__init__()
self.fc1 = nn.Linear(4, 256) self.fc1 = nn.Linear(4, 256)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册