提交 1f185137 编写于 作者: Z zenghsh3

refine naming

上级 7fa307a6
...@@ -78,8 +78,10 @@ class ESAgent { ...@@ -78,8 +78,10 @@ class ESAgent {
std::shared_ptr<PaddlePredictor> get_predictor(); std::shared_ptr<PaddlePredictor> get_predictor();
private: private:
int64_t _calculate_param_size();
std::shared_ptr<PaddlePredictor> _predictor; std::shared_ptr<PaddlePredictor> _predictor;
std::shared_ptr<PaddlePredictor> _sample_predictor; std::shared_ptr<PaddlePredictor> _sampling_predictor;
bool _is_sampling_agent; bool _is_sampling_agent;
std::shared_ptr<SamplingMethod> _sampling_method; std::shared_ptr<SamplingMethod> _sampling_method;
std::shared_ptr<Optimizer> _optimizer; std::shared_ptr<Optimizer> _optimizer;
...@@ -89,8 +91,6 @@ class ESAgent { ...@@ -89,8 +91,6 @@ class ESAgent {
// malloc memory of noise and neg_gradients in advance. // malloc memory of noise and neg_gradients in advance.
float* _noise; float* _noise;
float* _neg_gradients; float* _neg_gradients;
int64_t _calculate_param_size();
}; };
} }
......
...@@ -51,7 +51,7 @@ public: ...@@ -51,7 +51,7 @@ public:
_sampling_method->load_config(*_config); _sampling_method->load_config(*_config);
_optimizer = std::make_shared<SGDOptimizer>(_config->optimizer().base_lr()); _optimizer = std::make_shared<SGDOptimizer>(_config->optimizer().base_lr());
// Original agent can't be used to sample, so keep it same with _model for evaluating. // Original agent can't be used to sample, so keep it same with _model for evaluating.
_sampled_model = model; _sampling_model = model;
_param_size = _calculate_param_size(); _param_size = _calculate_param_size();
_noise = new float [_param_size]; _noise = new float [_param_size];
...@@ -70,7 +70,7 @@ public: ...@@ -70,7 +70,7 @@ public:
new_agent->_model = _model; new_agent->_model = _model;
std::shared_ptr<T> new_model = _model->clone(); std::shared_ptr<T> new_model = _model->clone();
new_agent->_sampled_model = new_model; new_agent->_sampling_model = new_model;
new_agent->_is_sampling_agent = true; new_agent->_is_sampling_agent = true;
new_agent->_sampling_method = _sampling_method; new_agent->_sampling_method = _sampling_method;
...@@ -89,7 +89,7 @@ public: ...@@ -89,7 +89,7 @@ public:
* if _is_sampling_agent is false, will use the original model without added noise. * if _is_sampling_agent is false, will use the original model without added noise.
*/ */
torch::Tensor predict(const torch::Tensor& x) { torch::Tensor predict(const torch::Tensor& x) {
return _sampled_model->forward(x); return _sampling_model->forward(x);
} }
/** /**
...@@ -139,19 +139,19 @@ public: ...@@ -139,19 +139,19 @@ public:
return false; return false;
} }
auto sampled_params = _sampled_model->named_parameters(); auto sampling_params = _sampling_model->named_parameters();
auto params = _model->named_parameters(); auto params = _model->named_parameters();
int key = _sampling_method->sampling(_noise, _param_size); int key = _sampling_method->sampling(_noise, _param_size);
sampling_key.add_key(key); sampling_key.add_key(key);
int64_t counter = 0; int64_t counter = 0;
for (auto& param: sampled_params) { for (auto& param: sampling_params) {
torch::Tensor sampled_tensor = param.value().view({-1}); torch::Tensor sampling_tensor = param.value().view({-1});
std::string param_name = param.key(); std::string param_name = param.key();
torch::Tensor tensor = params.find(param_name)->view({-1}); torch::Tensor tensor = params.find(param_name)->view({-1});
auto sampled_tensor_a = sampled_tensor.accessor<float,1>(); auto sampling_tensor_a = sampling_tensor.accessor<float,1>();
auto tensor_a = tensor.accessor<float,1>(); auto tensor_a = tensor.accessor<float,1>();
for (int64_t j = 0; j < tensor.size(0); ++j) { for (int64_t j = 0; j < tensor.size(0); ++j) {
sampled_tensor_a[j] = tensor_a[j] + _noise[counter + j]; sampling_tensor_a[j] = tensor_a[j] + _noise[counter + j];
} }
counter += tensor.size(0); counter += tensor.size(0);
} }
...@@ -161,7 +161,7 @@ public: ...@@ -161,7 +161,7 @@ public:
private: private:
std::shared_ptr<T> _sampled_model; std::shared_ptr<T> _sampling_model;
std::shared_ptr<T> _model; std::shared_ptr<T> _model;
bool _is_sampling_agent; bool _is_sampling_agent;
std::shared_ptr<SamplingMethod> _sampling_method; std::shared_ptr<SamplingMethod> _sampling_method;
......
...@@ -14,7 +14,9 @@ if [ $1 = "paddle" ]; then ...@@ -14,7 +14,9 @@ if [ $1 = "paddle" ]; then
fi fi
# Initialization model # Initialization model
unzip ./demo/paddle/cartpole_init_model.zip -d ./demo/paddle/ if [ ! -d ./demo/paddle/cartpole_init_model ]; then
unzip ./demo/paddle/cartpole_init_model.zip -d ./demo/paddle/
fi
FLAGS=" -DWITH_PADDLE=ON" FLAGS=" -DWITH_PADDLE=ON"
elif [ $1 = "torch" ]; then elif [ $1 = "torch" ]; then
......
...@@ -42,7 +42,7 @@ ESAgent::ESAgent( ...@@ -42,7 +42,7 @@ ESAgent::ESAgent(
_is_sampling_agent = false; _is_sampling_agent = false;
_predictor = predictor; _predictor = predictor;
// Original agent can't be used to sample, so keep it same with _predictor for evaluating. // Original agent can't be used to sample, so keep it same with _predictor for evaluating.
_sample_predictor = predictor; _sampling_predictor = predictor;
_config = std::make_shared<DeepESConfig>(); _config = std::make_shared<DeepESConfig>();
load_proto_conf(config_path, *_config); load_proto_conf(config_path, *_config);
...@@ -60,20 +60,20 @@ ESAgent::ESAgent( ...@@ -60,20 +60,20 @@ ESAgent::ESAgent(
} }
std::shared_ptr<ESAgent> ESAgent::clone() { std::shared_ptr<ESAgent> ESAgent::clone() {
std::shared_ptr<PaddlePredictor> new_sample_predictor = _predictor->Clone(); std::shared_ptr<PaddlePredictor> new_sampling_predictor = _predictor->Clone();
std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>(); std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>();
float* new_noise = new float [_param_size]; float* noise = new float [_param_size];
new_agent->_predictor = _predictor; new_agent->_predictor = _predictor;
new_agent->_sample_predictor = new_sample_predictor; new_agent->_sampling_predictor = new_sampling_predictor;
new_agent->_is_sampling_agent = true; new_agent->_is_sampling_agent = true;
new_agent->_sampling_method = _sampling_method; new_agent->_sampling_method = _sampling_method;
new_agent->_param_names = _param_names; new_agent->_param_names = _param_names;
new_agent->_param_size = _param_size; new_agent->_param_size = _param_size;
new_agent->_noise = new_noise; new_agent->_noise = noise;
return new_agent; return new_agent;
} }
...@@ -126,7 +126,7 @@ bool ESAgent::add_noise(SamplingKey& sampling_key) { ...@@ -126,7 +126,7 @@ bool ESAgent::add_noise(SamplingKey& sampling_key) {
int64_t counter = 0; int64_t counter = 0;
for (std::string param_name: _param_names) { for (std::string param_name: _param_names) {
std::unique_ptr<Tensor> sample_tensor = _sample_predictor->GetMutableTensor(param_name); std::unique_ptr<Tensor> sample_tensor = _sampling_predictor->GetMutableTensor(param_name);
std::unique_ptr<const Tensor> tensor = _predictor->GetTensor(param_name); std::unique_ptr<const Tensor> tensor = _predictor->GetTensor(param_name);
int64_t tensor_size = ShapeProduction(tensor->shape()); int64_t tensor_size = ShapeProduction(tensor->shape());
for (int64_t j = 0; j < tensor_size; ++j) { for (int64_t j = 0; j < tensor_size; ++j) {
...@@ -140,7 +140,7 @@ bool ESAgent::add_noise(SamplingKey& sampling_key) { ...@@ -140,7 +140,7 @@ bool ESAgent::add_noise(SamplingKey& sampling_key) {
std::shared_ptr<PaddlePredictor> ESAgent::get_predictor() { std::shared_ptr<PaddlePredictor> ESAgent::get_predictor() {
return _sample_predictor; return _sampling_predictor;
} }
int64_t ESAgent::_calculate_param_size() { int64_t ESAgent::_calculate_param_size() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册