Unverified · Commit a1ac2da3, authored by rical730, committed by GitHub

add es torch effect test (#236)

* add es torch effect test

* test the performance using the gtest framework

* update

* delete last blank line of CMakeLists.txt

* optimize sin demo and use vector for data storage

* optimize sin demo test
Parent d18740e5
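In brief, the new effect test fits a small torch model to noisy sin(x) data with the ES agent and checks that the resulting train/test losses stay below fixed thresholds. As a rough sketch of the sampling-and-update loop the test exercises (names taken from the test code further down in this diff; error handling omitted):

std::vector<SamplingKey> noisy_keys(iter);
std::vector<float> noisy_rewards(iter, 0.0f);
for (int i = 0; i < iter; ++i) {
    auto sampling_agent = sampling_agents[i];      // clone of the main agent
    sampling_agent->add_noise(noisy_keys[i]);      // perturb parameters, record the noise key
    noisy_rewards[i] = evaluate(x_list, y_list, train_data_size, sampling_agent);
}
agent->update(noisy_keys, noisy_rewards);          // update the main agent from the sampled rewards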
@@ -162,6 +162,7 @@ public:
private:
int64_t _calculate_param_size() {
int _param_size = 0;
auto params = _model->named_parameters();
for (auto& param: params) {
torch::Tensor tensor = param.value().view({-1});
......
@@ -2,6 +2,7 @@ cmake_minimum_required (VERSION 2.6)
project (DeepES)
set(TARGET unit_test_main)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
@@ -14,13 +15,20 @@ if (OPENMP_FOUND)
set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
endif()
file(GLOB core_src "../src/*.cc" "../src/*.cpp" "../benchmark/*.cc")
file(GLOB test_src "../test/src/*.cc")
# Torch lib
list(APPEND CMAKE_PREFIX_PATH "../libtorch")
find_package(Torch REQUIRED)
# include and source
file(GLOB test_src "../test/src/*.cc")
file(GLOB core_src "../src/*.cc")
file(GLOB agent_src "../src/torch/*.cc")
include_directories("../include/torch")
include_directories("../include")
include_directories("../benchmark")
include_directories("../test/include")
add_executable(${TARGET} "unit_test.cc" ${core_src} ${test_src} ${lib_src}) # ${demo_src}
target_link_libraries(${TARGET} gflags protobuf pthread glog gtest) # "${TORCH_LIBRARIES}"
add_executable(${TARGET} "unit_test.cc" ${core_src} ${agent_src} ${test_src})
target_link_libraries(${TARGET} gflags protobuf pthread glog gtest "${TORCH_LIBRARIES}")
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _TORCH_DEMO_MODEL_H
#define _TORCH_DEMO_MODEL_H
#include <torch/torch.h>
struct Model : public torch::nn::Module{
Model() = delete;
Model(const int obs_dim, const int act_dim) {
_obs_dim = obs_dim;
_act_dim = act_dim;
int hid1_size = 30;
int hid2_size = 15;
fc1 = register_module("fc1", torch::nn::Linear(obs_dim, hid1_size));
fc2 = register_module("fc2", torch::nn::Linear(hid1_size, hid2_size));
fc3 = register_module("fc3", torch::nn::Linear(hid2_size, act_dim));
}
torch::Tensor forward(torch::Tensor x) {
x = x.reshape({-1, _obs_dim});
x = torch::tanh(fc1->forward(x));
x = torch::tanh(fc2->forward(x));
x = torch::tanh(fc3->forward(x));
return x;
}
std::shared_ptr<Model> clone() {
std::shared_ptr<Model> model = std::make_shared<Model>(_obs_dim, _act_dim);
std::vector<torch::Tensor> parameters1 = parameters();
std::vector<torch::Tensor> parameters2 = model->parameters();
for (int i = 0; i < parameters1.size(); ++i) {
torch::Tensor src = parameters1[i].view({-1});
torch::Tensor des = parameters2[i].view({-1});
auto src_a = src.accessor<float, 1>();
auto des_a = des.accessor<float, 1>();
for (int j = 0; j < src.size(0); ++j) {
des_a[j] = src_a[j];
}
}
return model;
}
int _act_dim;
int _obs_dim;
torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};
#endif
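A minimal, hypothetical usage sketch of this demo model (the shapes match how the test below instantiates it as Model(1, 1)); it assumes libtorch is on the include and link paths:

#include <torch/torch.h>
#include "torch_demo_model.h"

int main() {
    auto model = std::make_shared<Model>(1, 1);    // obs_dim = 1, act_dim = 1
    torch::Tensor x = torch::tensor({0.5f});       // one scalar observation
    torch::Tensor y = model->forward(x);           // shape {1, 1}, tanh output in [-1, 1]
    std::shared_ptr<Model> copy = model->clone();  // element-wise parameter copy
    return 0;
}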
#!/bin/bash
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
#---------------libtorch-------------#
if [ ! -d "./libtorch" ];then
echo "Cannot find the torch library: ../libtorch"
echo "Downloading Torch library"
wget -q https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.4.0%2Bcpu.zip
unzip -q libtorch-cxx11-abi-shared-with-deps-1.4.0+cpu.zip
rm -rf libtorch-cxx11-abi-shared-with-deps-1.4.0+cpu.zip
echo "Torch library Downloaded"
fi
#----------------protobuf-------------#
cp ./src/proto/deepes.proto ./
protoc deepes.proto --cpp_out ./
mv deepes.pb.h ./include
mv deepes.pb.cc ./src
#----------------build---------------#
rm -rf build
mkdir build
cd build
cmake ../test # -DWITH_TORCH=ON
cmake ../test
make -j10
#-----------------run----------------#
./unit_test_main
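To run only the new torch effect test, the standard gtest filter flag should work (assuming the default gtest main is used in unit_test.cc):

./unit_test_main --gtest_filter=TorchDemoTest.*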
@@ -19,6 +19,7 @@
namespace DeepES {
TEST(SGDOptimizersTest, Method_update) {
std::shared_ptr<DeepESConfig> config = std::make_shared<DeepESConfig>();
auto optimizer_config = config->mutable_optimizer();
@@ -55,5 +56,5 @@ TEST(AdamOptimizersTest, Method_update) {
EXPECT_FALSE(optimizer->update(adam_wei, adam_grad, 9, "test"));
}
}
} // namespace
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include <torch/torch.h>
#include <glog/logging.h>
#include <omp.h>
#include "gaussian_sampling.h"
#include "torch_demo_model.h"
#include "es_agent.h"
#include <memory>
#include <vector>
#include <random>
#include <math.h>
namespace DeepES {
// Test fixture for the ES torch sin-regression effect test.
class TorchDemoTest : public ::testing::Test {
protected:
float evaluate(std::vector<float>& x_list, std::vector<float>& y_list, int size, std::shared_ptr<ESAgent<Model>> agent) {
float total_loss = 0.0;
for (int i = 0; i < size; ++i) {
torch::Tensor x_input = torch::tensor(x_list[i], torch::dtype(torch::kFloat32));
torch::Tensor predict_y = agent->predict(x_input);
auto pred_y = predict_y.accessor<float,2>();
float loss = pow((pred_y[0][0] - y_list[i]), 2);
total_loss += loss;
}
return -total_loss / float(size);
}
float train_loss() {
return -1.0 * evaluate(x_list, y_list, train_data_size, agent);
}
float test_loss() {
return -1.0 * evaluate(test_x_list, test_y_list, test_data_size, agent);
}
float train_test_gap() {
float train_lo = train_loss();
float test_lo = test_loss();
if ( train_lo > test_lo) {
return train_lo - test_lo;
}
else {
return test_lo - train_lo;
}
}
void SetUp() override {
std::default_random_engine generator(0); // fix seed
std::uniform_real_distribution<float> uniform(-3.0, 9.0);
std::normal_distribution<float> norm;
for (int i = 0; i < train_data_size; ++i) {
float x_i = uniform(generator); // generate data between [-3, 9]
float y_i = sin(x_i) + norm(generator)*0.05; // noise std 0.05
x_list.push_back(x_i);
y_list.push_back(y_i);
}
for (int i = 0; i < test_data_size; ++i) {
float x_i = uniform(generator);
float y_i = sin(x_i);
test_x_list.push_back(x_i);
test_y_list.push_back(y_i);
}
std::shared_ptr<Model> model = std::make_shared<Model>(1, 1);
agent = std::make_shared<ESAgent<Model>>(model, "../test/torch_sin_config.prototxt");
// Clone agents to sample (explore).
std::vector<std::shared_ptr<ESAgent<Model>>> sampling_agents;
for (int i = 0; i < iter; ++i) {
sampling_agents.push_back(agent->clone());
}
std::vector<SamplingKey> noisy_keys;
std::vector<float> noisy_rewards(iter, 0.0f);
noisy_keys.resize(iter);
LOG(INFO) << "start training...";
for (int epoch = 0; epoch < 1001; ++epoch) {
#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < iter; ++i) {
auto sampling_agent = sampling_agents[i];
SamplingKey key;
bool success = sampling_agent->add_noise(key);
float reward = evaluate(x_list, y_list, train_data_size, sampling_agent);
noisy_keys[i] = key;
noisy_rewards[i] = reward;
}
bool success = agent->update(noisy_keys, noisy_rewards);
if (epoch % 100 == 0) {
float reward = evaluate(test_x_list, test_y_list, test_data_size, agent);
float train_reward = evaluate(x_list, y_list, train_data_size, agent);
LOG(INFO) << "Epoch:" << epoch << " Loss: " << -reward << ", Train loss" << -train_reward;
}
}
}
// Class members declared here can be used by all tests in the test suite
int train_data_size = 300;
int test_data_size = 100;
int iter = 10;
std::vector<float> x_list;
std::vector<float> y_list;
std::vector<float> test_x_list;
std::vector<float> test_y_list;
std::shared_ptr<ESAgent<Model>> agent;
};
TEST_F(TorchDemoTest, TrainingEffectTest) {
EXPECT_LT(train_loss(), 0.05);
EXPECT_LT(test_loss(), 0.05);
EXPECT_LT(train_test_gap(), 0.03);
}
} // namespace
@@ -26,5 +26,5 @@ TEST(UtilsTest, Method_compute_centered_ranks) {
}
}
} // namespace
seed : 1024
gaussian_sampling {
std: 0.005
}
optimizer {
type: "Adam",
base_lr: 0.005,
momentum: 0.9,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
}
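For context, these fields correspond to the usual Adam update. The sketch below is illustrative only and not necessarily how DeepES's AdamOptimizer is implemented (it may, for example, handle bias correction differently):

#include <cmath>

// One Adam step over a flat parameter array, using the settings above:
// base_lr = 0.005, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8.
void adam_step(float* w, const float* grad, float* m, float* v,
               int size, int t, float lr = 0.005f, float beta1 = 0.9f,
               float beta2 = 0.999f, float eps = 1e-8f) {
    for (int i = 0; i < size; ++i) {
        m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];           // first moment
        v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i]; // second moment
        float m_hat = m[i] / (1.0f - std::pow(beta1, t));         // bias correction
        float v_hat = v[i] / (1.0f - std::pow(beta2, t));
        w[i] -= lr * m_hat / (std::sqrt(v_hat) + eps);
    }
}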