//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TORCH_ESAGENT_H
#define TORCH_ESAGENT_H

#include <memory>
#include <string>
#include <vector>

#include <torch/torch.h>

#include "optimizer_factory.h"
#include "utils.h"
#include "gaussian_sampling.h"
#include "deepes.pb.h"

namespace DeepES{

/**
 * @brief DeepES agent for Torch.
 *
 * Our implementation is flexible enough to support any model that subclasses torch::nn::Module.
 * That is, we can instantiate an agent by: es_agent = ESAgent<Model>(model, config_path);
 * After that, users can clone the agent for multi-threaded processing, add parametric noise for exploration,
 * and update the parameters according to the evaluation results of the noisy parameters.
 * (A usage sketch follows the class definition at the end of this header.)
 */
template <class T>
class ESAgent{
public:
  ESAgent() {}

  ~ESAgent() {
    delete[] _noise;
    if (!_is_sampling_agent)
      delete[] _neg_gradients;
  }

  ESAgent(std::shared_ptr<T> model, std::string config_path): _model(model) {
    _is_sampling_agent = false;
    _config = std::make_shared<DeepESConfig>();
    load_proto_conf(config_path, *_config);
    _sampling_method = std::make_shared<GaussianSampling>();
    _sampling_method->load_config(*_config);
    _optimizer = create_optimizer(_config->optimizer());
    // The original agent cannot be used to sample, so keep _sampling_model identical to _model for evaluation.
    _sampling_model = model;
    _param_size = _calculate_param_size();

    _noise = new float [_param_size];
    _neg_gradients = new float [_param_size];
  }

  /**
   * @brief Clone a sampling agent.
   *
   * Only a cloned ESAgent can call the `add_noise` function.
   * Each cloned ESAgent holds a copy of the original parameters,
   * which supports sampling in a multi-threaded way.
   */
  std::shared_ptr<ESAgent> clone() {
    std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>();

    new_agent->_model = _model;
    std::shared_ptr<T> new_model = _model->clone();
    new_agent->_sampling_model = new_model;
  
    new_agent->_is_sampling_agent = true;
    new_agent->_sampling_method = _sampling_method;
    new_agent->_param_size = _param_size;

    float* new_noise = new float [_param_size];
    new_agent->_noise = new_noise;

    return new_agent;
  }

  /**
   * @brief Use the model to predict.
   *
   * If _is_sampling_agent is true, the sampling model with added noise is used;
   * if _is_sampling_agent is false, the original model without added noise is used.
   */
  torch::Tensor predict(const torch::Tensor& x) {
    return _sampling_model->forward(x);
  }

  /**
   * @brief Update the parameters of the model based on the ES algorithm.
   *
   * Only the original (not cloned) ESAgent can call the `update` function.
   * Parameters of cloned agents are updated as well, since they share the original model.
   */
  bool update(std::vector<SamplingInfo>& noisy_info, std::vector<float>& noisy_rewards) {
    if (_is_sampling_agent) {
      LOG(ERROR) << "[DeepES] Cloned ESAgent cannot call update function, please use original ESAgent.";
      return false;
    }

    compute_centered_ranks(noisy_rewards);

    memset(_neg_gradients, 0, _param_size * sizeof(float));
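    // Estimate the negative gradient: _neg_gradients[j] = -(1/N) * sum_i rank(r_i) * noise_i[j],
    // where rank(r_i) is the centered-rank reward of sample i and noise_i is recovered from the
    // reported sampling key via _sampling_method->resampling(). Negating the estimate lets a
    // descent-style optimizer move the parameters toward higher reward.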
    for (int i = 0; i < noisy_info.size(); ++i) {
      int key = noisy_info[i].key(0);
      float reward = noisy_rewards[i];
      bool success = _sampling_method->resampling(key, _noise, _param_size);
      for (int64_t j = 0; j < _param_size; ++j) {
        _neg_gradients[j] += _noise[j] * reward;
      }
    }
    for (int64_t j = 0; j < _param_size; ++j) {
      _neg_gradients[j] /= -1.0 * noisy_info.size();
    }

    // Update each flattened parameter tensor through the optimizer.
    auto params = _model->named_parameters();
    int64_t counter = 0;
    for (auto& param: params) {
      torch::Tensor tensor = param.value().view({-1});
      auto tensor_a = tensor.accessor<float,1>();
      _optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0), param.key());
      counter += tensor.size(0);
    }

    return true;
  }

  /**
   * @brief Add sampled noise to the copied parameters: sampling model = original parameters + noise.
   */
  bool add_noise(SamplingInfo& sampling_info) {
    if (!_is_sampling_agent) {
      LOG(ERROR) << "[DeepES] Original ESAgent cannot call add_noise function, please use cloned ESAgent.";
      return false;
    }

    auto sampling_params = _sampling_model->named_parameters();
    auto params = _model->named_parameters();
    int key = _sampling_method->sampling(_noise, _param_size);
    sampling_info.add_key(key);
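    // Report the sampling key back to the caller so that update() can recover the
    // corresponding noise through _sampling_method->resampling().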
    int64_t counter = 0;
    for (auto& param: sampling_params) {
      torch::Tensor sampling_tensor = param.value().view({-1});
      std::string param_name = param.key();
      torch::Tensor tensor = params.find(param_name)->view({-1});
      auto sampling_tensor_a = sampling_tensor.accessor<float,1>();
      auto tensor_a = tensor.accessor<float,1>();
      for (int64_t j = 0; j < tensor.size(0); ++j) {
        sampling_tensor_a[j] = tensor_a[j] + _noise[counter + j];
      }
      counter += tensor.size(0);
    }
    return true;
  }

private:
  int64_t _calculate_param_size() {
    _param_size = 0;
    auto params = _model->named_parameters();
    for (auto& param: params) {
      torch::Tensor tensor = param.value().view({-1});
      _param_size += tensor.size(0);
    }
    return _param_size;
  }

  std::shared_ptr<T> _model;
  std::shared_ptr<T> _sampling_model;
  bool _is_sampling_agent = false;
  std::shared_ptr<SamplingMethod> _sampling_method;
  std::shared_ptr<Optimizer> _optimizer;
  std::shared_ptr<DeepESConfig> _config;
  int64_t _param_size = 0;
  // Memory for the noise and the negative gradients is allocated once, in advance.
  float* _noise = nullptr;
  float* _neg_gradients = nullptr;
};
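
/* Example usage (a minimal sketch): `MyModel` is a hypothetical subclass of
 * torch::nn::Module and `evaluate` is a user-provided rollout function that
 * calls predict(); neither is defined in this header, and the config path is
 * a placeholder.
 *
 *   auto model = std::make_shared<MyModel>();
 *   auto agent = std::make_shared<ESAgent<MyModel>>(model, "./deepes_config");
 *
 *   std::vector<SamplingInfo> noisy_info;
 *   std::vector<float> noisy_rewards;
 *   for (int i = 0; i < 100; ++i) {            // typically run in parallel, one cloned agent per thread
 *     auto sampling_agent = agent->clone();
 *     SamplingInfo info;
 *     sampling_agent->add_noise(info);          // perturb the copied parameters
 *     float reward = evaluate(*sampling_agent); // roll out episodes with the noisy parameters
 *     noisy_info.push_back(info);
 *     noisy_rewards.push_back(reward);
 *   }
 *
 *   agent->update(noisy_info, noisy_rewards);   // only the original agent updates the parameters
 */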

}  // namespace DeepES

#endif /* TORCH_ESAGENT_H */