Commit 1f185137 authored by zenghsh3

refine naming

Parent 7fa307a6
@@ -78,8 +78,10 @@ class ESAgent {
 std::shared_ptr<PaddlePredictor> get_predictor();
 private:
+int64_t _calculate_param_size();
 std::shared_ptr<PaddlePredictor> _predictor;
-std::shared_ptr<PaddlePredictor> _sample_predictor;
+std::shared_ptr<PaddlePredictor> _sampling_predictor;
 bool _is_sampling_agent;
 std::shared_ptr<SamplingMethod> _sampling_method;
 std::shared_ptr<Optimizer> _optimizer;
@@ -89,8 +91,6 @@ class ESAgent {
 // malloc memory of noise and neg_gradients in advance.
 float* _noise;
 float* _neg_gradients;
-int64_t _calculate_param_size();
 };
 }
......
@@ -51,7 +51,7 @@ public:
 _sampling_method->load_config(*_config);
 _optimizer = std::make_shared<SGDOptimizer>(_config->optimizer().base_lr());
 // Origin agent can't be used to sample, so keep it same with _model for evaluating.
-_sampled_model = model;
+_sampling_model = model;
 _param_size = _calculate_param_size();
 _noise = new float [_param_size];
@@ -70,7 +70,7 @@ public:
 new_agent->_model = _model;
 std::shared_ptr<T> new_model = _model->clone();
-new_agent->_sampled_model = new_model;
+new_agent->_sampling_model = new_model;
 new_agent->_is_sampling_agent = true;
 new_agent->_sampling_method = _sampling_method;
@@ -89,7 +89,7 @@ public:
 * if _is_sampling_agent is false, will use the original model without added noise.
 */
 torch::Tensor predict(const torch::Tensor& x) {
-return _sampled_model->forward(x);
+return _sampling_model->forward(x);
 }
 /**
@@ -139,19 +139,19 @@ public:
 return false;
 }
-auto sampled_params = _sampled_model->named_parameters();
+auto sampling_params = _sampling_model->named_parameters();
 auto params = _model->named_parameters();
 int key = _sampling_method->sampling(_noise, _param_size);
 sampling_key.add_key(key);
 int64_t counter = 0;
-for (auto& param: sampled_params) {
-torch::Tensor sampled_tensor = param.value().view({-1});
+for (auto& param: sampling_params) {
+torch::Tensor sampling_tensor = param.value().view({-1});
 std::string param_name = param.key();
 torch::Tensor tensor = params.find(param_name)->view({-1});
-auto sampled_tensor_a = sampled_tensor.accessor<float,1>();
+auto sampling_tensor_a = sampling_tensor.accessor<float,1>();
 auto tensor_a = tensor.accessor<float,1>();
 for (int64_t j = 0; j < tensor.size(0); ++j) {
-sampled_tensor_a[j] = tensor_a[j] + _noise[counter + j];
+sampling_tensor_a[j] = tensor_a[j] + _noise[counter + j];
 }
 counter += tensor.size(0);
 }
@@ -161,7 +161,7 @@ public:
 private:
-std::shared_ptr<T> _sampled_model;
+std::shared_ptr<T> _sampling_model;
 std::shared_ptr<T> _model;
 bool _is_sampling_agent;
 std::shared_ptr<SamplingMethod> _sampling_method;
......
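Taken together, the torch-side hunks document the workflow the new names describe: the original agent keeps _model and _sampling_model pointing at the same network, while clone() gives each sampling agent its own copy whose parameters add_noise() overwrites with "original parameter + Gaussian noise", keyed by the SamplingKey so the same noise can be regenerated at update time. The sketch below is not part of the commit; it only illustrates that flow and assumes a templated ESAgent<CartPoleModel>, a {1, 4} observation shape, the include path, and the placeholder reward bookkeeping.

// Illustrative sketch only (not from this commit). ESAgent, SamplingKey, clone(),
// add_noise() and predict() are taken from the hunks above; CartPoleModel, the
// observation shape, the include path and the reward bookkeeping are assumptions.
#include <memory>
#include <vector>
#include <torch/torch.h>
#include "es_agent.h"  // assumed DeepES torch header

void sample_one(std::shared_ptr<ESAgent<CartPoleModel>> agent,
                std::vector<SamplingKey>& keys, std::vector<float>& rewards) {
    // Each clone owns its own _sampling_model copy; the shared _model stays untouched.
    auto sampling_agent = agent->clone();

    SamplingKey key;
    // add_noise() writes "_model parameter + Gaussian noise" into _sampling_model
    // and records the noise key for later gradient reconstruction.
    sampling_agent->add_noise(key);

    torch::Tensor obs = torch::zeros({1, 4});          // assumed observation shape
    torch::Tensor act = sampling_agent->predict(obs);  // forward pass through the noisy copy
    float reward = act.sum().item<float>();            // placeholder "reward" for the sketch

    keys.push_back(key);
    rewards.push_back(reward);
    // keys/rewards are later fed back to the original agent to update _model.
}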
@@ -14,7 +14,9 @@ if [ $1 = "paddle" ]; then
 fi
 # Initialization model
-unzip ./demo/paddle/cartpole_init_model.zip -d ./demo/paddle/
+if [ ! -d ./demo/paddle/cartpole_init_model ]; then
+    unzip ./demo/paddle/cartpole_init_model.zip -d ./demo/paddle/
+fi
 FLAGS=" -DWITH_PADDLE=ON"
 elif [ $1 = "torch" ]; then
......
@@ -42,7 +42,7 @@ ESAgent::ESAgent(
 _is_sampling_agent = false;
 _predictor = predictor;
 // Original agent can't be used to sample, so keep it same with _predictor for evaluating.
-_sample_predictor = predictor;
+_sampling_predictor = predictor;
 _config = std::make_shared<DeepESConfig>();
 load_proto_conf(config_path, *_config);
@@ -60,20 +60,20 @@ ESAgent::ESAgent(
 }
 std::shared_ptr<ESAgent> ESAgent::clone() {
-std::shared_ptr<PaddlePredictor> new_sample_predictor = _predictor->Clone();
+std::shared_ptr<PaddlePredictor> new_sampling_predictor = _predictor->Clone();
 std::shared_ptr<ESAgent> new_agent = std::make_shared<ESAgent>();
-float* new_noise = new float [_param_size];
+float* noise = new float [_param_size];
 new_agent->_predictor = _predictor;
-new_agent->_sample_predictor = new_sample_predictor;
+new_agent->_sampling_predictor = new_sampling_predictor;
 new_agent->_is_sampling_agent = true;
 new_agent->_sampling_method = _sampling_method;
 new_agent->_param_names = _param_names;
 new_agent->_param_size = _param_size;
-new_agent->_noise = new_noise;
+new_agent->_noise = noise;
 return new_agent;
 }
@@ -126,7 +126,7 @@ bool ESAgent::add_noise(SamplingKey& sampling_key) {
 int64_t counter = 0;
 for (std::string param_name: _param_names) {
-std::unique_ptr<Tensor> sample_tensor = _sample_predictor->GetMutableTensor(param_name);
+std::unique_ptr<Tensor> sample_tensor = _sampling_predictor->GetMutableTensor(param_name);
 std::unique_ptr<const Tensor> tensor = _predictor->GetTensor(param_name);
 int64_t tensor_size = ShapeProduction(tensor->shape());
 for (int64_t j = 0; j < tensor_size; ++j) {
@@ -140,7 +140,7 @@ bool ESAgent::add_noise(SamplingKey& sampling_key) {
 std::shared_ptr<PaddlePredictor> ESAgent::get_predictor() {
-return _sample_predictor;
+return _sampling_predictor;
 }
 int64_t ESAgent::_calculate_param_size() {
......
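On the Paddle side the same pattern holds: in the original agent _sampling_predictor simply aliases _predictor (evaluation only), while each clone receives _predictor->Clone() as its own _sampling_predictor, which add_noise() then perturbs tensor by tensor. The sketch below is not part of the commit; it only illustrates how the renamed accessor is reached, and the include path and the feed/run details of PaddlePredictor are assumptions left as comments.

// Illustrative sketch only (not from this commit). ESAgent, SamplingKey, clone(),
// add_noise() and get_predictor() appear in the hunks above; the include path and
// the feed/run calls on PaddlePredictor are assumptions.
#include <memory>
#include "es_agent.h"  // assumed DeepES Paddle header

void evaluate_noisy_copy(std::shared_ptr<ESAgent> agent) {
    // clone() builds a sampling agent whose _sampling_predictor is _predictor->Clone();
    // the original agent keeps both members pointing at the same predictor.
    std::shared_ptr<ESAgent> sampling_agent = agent->clone();

    SamplingKey key;
    // Copies each named parameter of _predictor plus Gaussian noise into the clone's
    // _sampling_predictor and stores the noise key.
    sampling_agent->add_noise(key);

    // get_predictor() now returns the perturbed predictor; feed observations into its
    // input tensors and call Run() as with any PaddlePredictor (omitted here).
    std::shared_ptr<PaddlePredictor> noisy = sampling_agent->get_predictor();
    (void)noisy;  // rollout and reward collection are outside the scope of this sketch
}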