Unverified · Commit 8c9bf1fa authored by Bo Zhou, committed by GitHub

add optimizer factory (#230)

* add optimizer factory

* add file:optimizer_factory

* copyright

* remove dependence on algorithm.h
Parent commit: 03468ce6
@@ -16,7 +16,6 @@
 #include <glog/logging.h>
 #include <omp.h>
 #include "cartpole.h"
-#include "gaussian_sampling.h"
 #include "es_agent.h"
 #include "paddle_api.h"
@@ -101,7 +100,7 @@ int main(int argc, char* argv[]) {
     noisy_keys.resize(ITER);
     omp_set_num_threads(10);
-    for (int epoch = 0; epoch < 10000; ++epoch) {
+    for (int epoch = 0; epoch < 1000; ++epoch) {
 #pragma omp parallel for schedule(dynamic, 1)
         for (int i = 0; i < ITER; ++i) {
             std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OPTIMIZER_FACTORY_H
#define OPTIMIZER_FACTORY_H
#include <algorithm>
#include <memory>
#include "optimizer.h"
#include "sgd_optimizer.h"
#include "adam_optimizer.h"
#include "deepes.pb.h"
namespace DeepES {
/* @brief: create an optimizer according to the configuration.
 * @args:
 *     optimizer_config: configuration for the optimizer (type, learning rate,
 *                       and optimizer-specific hyperparameters).
 * @return: a shared_ptr to the created Optimizer.
 */
std::shared_ptr<Optimizer> create_optimizer(const OptimizerConfig& optimizer_config);
} // namespace DeepES
#endif
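A minimal usage sketch of the factory (not part of this commit): the flat parameter buffer, its size, the apply_one_step helper, and the "fc_w" name are illustrative, and the four-argument update() mirrors the es_agent changes below.

#include "optimizer_factory.h"

using namespace DeepES;

// Hypothetical example: apply one optimizer step to a flat parameter buffer.
void apply_one_step(const OptimizerConfig& optimizer_config,
                    float* param, float* neg_gradient, int size) {
    std::shared_ptr<Optimizer> optimizer = create_optimizer(optimizer_config);
    // The last argument is the parameter name; stateful optimizers (e.g. Adam)
    // can use it to keep per-parameter state. "fc_w" is an example name only.
    optimizer->update(param, neg_gradient, size, "fc_w");
}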
@@ -16,7 +16,7 @@
 #define DEEPES_PADDLE_ES_AGENT_H_
 #include "paddle_api.h"
-#include "optimizer.h"
+#include "optimizer_factory.h"
 #include "utils.h"
 #include "gaussian_sampling.h"
 #include "deepes.pb.h"
......
@@ -17,7 +17,7 @@
 #include <memory>
 #include <string>
-#include "optimizer.h"
+#include "optimizer_factory.h"
 #include "utils.h"
 #include "gaussian_sampling.h"
 #include "deepes.pb.h"
@@ -49,7 +49,7 @@ public:
         load_proto_conf(config_path, *_config);
         _sampling_method = std::make_shared<GaussianSampling>();
         _sampling_method->load_config(*_config);
-        _optimizer = std::make_shared<SGDOptimizer>(_config->optimizer().base_lr());
+        _optimizer = create_optimizer(_config->optimizer());
         // Origin agent can't be used to sample, so keep it same with _model for evaluating.
         _sampling_model = model;
         _param_size = _calculate_param_size();
@@ -125,7 +125,7 @@ public:
         for (auto& param: params) {
             torch::Tensor tensor = param.value().view({-1});
             auto tensor_a = tensor.accessor<float,1>();
-            _optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0));
+            _optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0), param.key());
             counter += tensor.size(0);
         }
......
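The extra argument to update() is the parameter name. A hedged sketch of why this is useful (hypothetical code, not this commit's actual AdamOptimizer): a stateful optimizer can key its per-parameter buffers on the name instead of assuming one flat buffer.

#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical illustration of per-parameter optimizer state keyed by name.
// Class and member names are made up for this sketch.
class PerParamState {
public:
    // Return (and lazily create) the state buffer for a named parameter.
    std::vector<float>& buffer_for(const std::string& param_name, int size) {
        std::vector<float>& buf = _buffers[param_name];
        if (buf.empty()) {
            buf.assign(size, 0.0f);
        }
        return buf;
    }
private:
    std::unordered_map<std::string, std::vector<float>> _buffers;
};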
@@ -14,7 +14,7 @@ if [ $1 = "paddle" ]; then
 fi
 # Initialization model
-if [ ! -d ./demo/paddle/cartpole_init_model]; then
+if [ ! -d ./demo/paddle/cartpole_init_model ]; then
     unzip ./demo/paddle/cartpole_init_model.zip -d ./demo/paddle/
 fi
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "optimizer_factory.h"
namespace DeepES {

std::shared_ptr<Optimizer> create_optimizer(const OptimizerConfig& optimizer_config) {
    std::shared_ptr<Optimizer> optimizer;
    // Normalize the configured type so matching is case-insensitive.
    std::string opt_type = optimizer_config.type();
    std::transform(opt_type.begin(), opt_type.end(), opt_type.begin(), ::tolower);
    if (opt_type == "sgd") {
        optimizer = std::make_shared<SGDOptimizer>(optimizer_config.base_lr(),
                                                   optimizer_config.momentum());
    } else if (opt_type == "adam") {
        optimizer = std::make_shared<AdamOptimizer>(optimizer_config.base_lr(),
                                                    optimizer_config.beta1(),
                                                    optimizer_config.beta2(),
                                                    optimizer_config.epsilon());
    } else {
        // TODO: NotImplementedError. Unknown types currently fall through and
        // return a null optimizer.
    }
    return optimizer;
}

} // namespace DeepES
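For reference, a hedged sketch of filling in an OptimizerConfig that selects the Adam branch above; the setters assume the usual protobuf-generated accessors for the fields read in create_optimizer(), and the values are illustrative only.

// Hypothetical example configuration (values are illustrative).
OptimizerConfig adam_config;
adam_config.set_type("Adam");   // matched case-insensitively by the factory
adam_config.set_base_lr(0.05);
adam_config.set_beta1(0.9);
adam_config.set_beta2(0.999);
adam_config.set_epsilon(1e-8);
std::shared_ptr<Optimizer> optimizer = create_optimizer(adam_config);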
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include <vector>
 #include "es_agent.h"
 namespace DeepES {
@@ -50,7 +49,7 @@ ESAgent::ESAgent(
     _sampling_method = std::make_shared<GaussianSampling>();
     _sampling_method->load_config(*_config);
-    _optimizer = std::make_shared<SGDOptimizer>(_config->optimizer().base_lr());
+    _optimizer = create_optimizer(_config->optimizer());
     _param_names = _predictor->GetParamNames();
     _param_size = _calculate_param_size();
@@ -108,7 +107,7 @@ bool ESAgent::update(
         std::unique_ptr<Tensor> tensor = _predictor->GetMutableTensor(param_name);
         float* tensor_data = tensor->mutable_data<float>();
         int64_t tensor_size = ShapeProduction(tensor->shape());
-        _optimizer->update(tensor_data, _neg_gradients + counter, tensor_size);
+        _optimizer->update(tensor_data, _neg_gradients + counter, tensor_size, param_name);
         counter += tensor_size;
     }
     return true;
......