cartpole_solver_parallel.cc 3.7 KB
Newer Older
Z
zenghsh3 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <glog/logging.h>
#include <omp.h>
#include "cartpole.h"
#include "es_agent.h"
#include "paddle_api.h"

using namespace DeepES;
using namespace paddle::lite_api;

// Number of sampling slots used by main(): one CartPole environment and one
// cloned sampling agent are created per slot, and each epoch evaluates all of
// them in the parallel loop.
constexpr int ITER = 10;

std::shared_ptr<PaddlePredictor> create_paddle_predictor(const std::string& model_dir) {
  // 1. Create CxxConfig
  CxxConfig config;
  config.set_model_dir(model_dir);
  config.set_valid_places({
    Place{TARGET(kX86), PRECISION(kFloat)},
    Place{TARGET(kHost), PRECISION(kFloat)}
  });

  // 2. Create PaddlePredictor by CxxConfig
  std::shared_ptr<PaddlePredictor> predictor = CreatePaddlePredictor<CxxConfig>(config);
  return predictor;
}

// Use PaddlePredictor of CartPole model to predict the action.
// Runs one forward pass of the CartPole policy model.
//
// @param predictor  Predictor holding the policy model. It is taken by const
//                   reference, which avoids an atomic refcount inc/dec on every
//                   environment step.
// @param obs        Observation buffer. The input tensor is resized to {1, 4},
//                   so 4 floats are read from it.
// @return           Every value of output tensor 0. The buffer length comes
//                   from the tensor's own shape. For the CartPole policy this
//                   is presumably one score per action (2); confirm against
//                   the model.
std::vector<float> forward(const std::shared_ptr<PaddlePredictor>& predictor, const float* obs) {
  // GetInput/GetOutput already return unique_ptr prvalues. Wrapping them in
  // std::move only prevents copy elision.
  std::unique_ptr<Tensor> input_tensor = predictor->GetInput(0);
  input_tensor->Resize({1, 4});
  input_tensor->CopyFromCpu(obs);

  predictor->Run();

  std::unique_ptr<const Tensor> output_tensor = predictor->GetOutput(0);

  // CopyToCpu writes the tensor's full element count. Sizing the host buffer
  // from the real shape (instead of a hard-coded 2) makes an overrun impossible
  // if the model's output width ever differs. An empty shape counts as a
  // scalar (1 element).
  std::size_t numel = 1;
  for (const auto dim : output_tensor->shape()) {
    numel *= static_cast<std::size_t>(dim);
  }

  std::vector<float> probs(numel, 0.0f);
  output_tensor->CopyToCpu(probs.data());
  return probs;
}

// Returns the index of the largest value in `vec`. On ties, the earliest index
// wins. An empty vector yields 0, because max_element returns end() == begin().
int arg_max(const std::vector<float>& vec) {
  const auto largest = std::max_element(vec.cbegin(), vec.cend());
  return static_cast<int>(largest - vec.cbegin());
}


61
float evaluate(CartPole& env, std::shared_ptr<ESAgent> agent) {
Z
zenghsh3 已提交
62 63 64 65 66
  float total_reward = 0.0;
  env.reset();
  const float* obs = env.getState();

  std::shared_ptr<PaddlePredictor> paddle_predictor;
67
  paddle_predictor = agent->get_predictor();
Z
zenghsh3 已提交
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91

  while (true) {
    std::vector<float> probs = forward(paddle_predictor, obs); 
    int act = arg_max(probs);
    env.step(act);
    float reward = env.getReward(); 
    bool done = env.isDone();
    total_reward += reward;
    if (done) break;
    obs = env.getState();
  }
  return total_reward;
}


int main(int argc, char* argv[]) {
  std::vector<CartPole> envs;
  for (int i = 0; i < ITER; ++i) {
    envs.push_back(CartPole());
  }

  std::shared_ptr<PaddlePredictor> paddle_predictor = create_paddle_predictor("../demo/paddle/cartpole_init_model");
  std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>(paddle_predictor, "../benchmark/cartpole_config.prototxt");

92 93 94
  // Clone agents to sample (explore).
  std::vector< std::shared_ptr<ESAgent> > sampling_agents;
  for (int i = 0; i < ITER; ++i) {
Z
zenghsh3 已提交
95 96 97
    sampling_agents.push_back(agent->clone());
  }

98
  std::vector<SamplingInfo> noisy_keys;
Z
zenghsh3 已提交
99 100 101 102
  std::vector<float> noisy_rewards(ITER, 0.0f);
  noisy_keys.resize(ITER);

  omp_set_num_threads(10);
103
  for (int epoch = 0; epoch < 100; ++epoch) {
Z
zenghsh3 已提交
104 105 106
#pragma omp parallel for schedule(dynamic, 1)
    for (int i = 0; i < ITER; ++i) {
      std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
107
      SamplingInfo key;
108
      bool success = sampling_agent->add_noise(key);
Z
zenghsh3 已提交
109 110 111 112 113 114 115
      float reward = evaluate(envs[i], sampling_agent);

      noisy_keys[i] = key;
      noisy_rewards[i] = reward;
    }

    // NOTE: all parameters of sampling_agents will be updated
116
    bool success = agent->update(noisy_keys, noisy_rewards);
Z
zenghsh3 已提交
117
  
118
    int reward = evaluate(envs[0], agent);
Z
zenghsh3 已提交
119 120 121
    LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward;
  }
}