Unverified commit 8c9bf1fa authored by Bo Zhou, committed by GitHub

add optimizer factory (#230)

* add optimizer factory

* add file: optimizer_factory

* copyright

* remove dependence on algorithm.h
Parent 03468ce6
@@ -16,7 +16,6 @@
#include <glog/logging.h>
#include <omp.h>
#include "cartpole.h"
#include "gaussian_sampling.h"
#include "es_agent.h"
#include "paddle_api.h"
@@ -101,7 +100,7 @@ int main(int argc, char* argv[]) {
noisy_keys.resize(ITER);
omp_set_num_threads(10);
-    for (int epoch = 0; epoch < 10000; ++epoch) {
+    for (int epoch = 0; epoch < 1000; ++epoch) {
#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < ITER; ++i) {
std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OPTIMIZER_FACTORY_H
#define OPTIMIZER_FACTORY_H
#include <algorithm>
#include <memory>
#include "optimizer.h"
#include "sgd_optimizer.h"
#include "adam_optimizer.h"
#include "deepes.pb.h"
namespace DeepES{
/* @brief: create an optimizer according to the configuration.
 * @args:
 *     optimizer_config: configuration for the optimizer
 */
std::shared_ptr<Optimizer> create_optimizer(const OptimizerConfig& optimizer_config);
}//namespace
#endif
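To make the new entry point concrete, here is a minimal usage sketch of `create_optimizer` with an `OptimizerConfig` built in code. It is not part of the commit: the field names (`type`, `base_lr`, `beta1`, `beta2`, `epsilon`) come from `optimizer_factory.cc` below, the `DeepES` namespace for the generated message is assumed from its unqualified use inside `namespace DeepES`, and the numeric values are illustrative only.

```cpp
// Minimal usage sketch (not from the commit): build an OptimizerConfig in code
// and hand it to the new factory.
#include <memory>
#include "optimizer_factory.h"
#include "deepes.pb.h"

int main() {
    DeepES::OptimizerConfig opt_conf;
    opt_conf.set_type("Adam");         // matching is case-insensitive in the factory
    opt_conf.set_base_lr(0.05);        // illustrative values
    opt_conf.set_beta1(0.9);
    opt_conf.set_beta2(0.999);
    opt_conf.set_epsilon(1e-8);
    std::shared_ptr<DeepES::Optimizer> opt = DeepES::create_optimizer(opt_conf);
    return opt ? 0 : 1;                // the factory returns an empty pointer for unknown types
}
```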
@@ -16,7 +16,7 @@
#define DEEPES_PADDLE_ES_AGENT_H_
#include "paddle_api.h"
#include "optimizer.h"
#include "optimizer_factory.h"
#include "utils.h"
#include "gaussian_sampling.h"
#include "deepes.pb.h"
@@ -17,7 +17,7 @@
#include <memory>
#include <string>
#include "optimizer.h"
#include "optimizer_factory.h"
#include "utils.h"
#include "gaussian_sampling.h"
#include "deepes.pb.h"
@@ -49,7 +49,7 @@ public:
load_proto_conf(config_path, *_config);
_sampling_method = std::make_shared<GaussianSampling>();
_sampling_method->load_config(*_config);
-        _optimizer = std::make_shared<SGDOptimizer>(_config->optimizer().base_lr());
+        _optimizer = create_optimizer(_config->optimizer());
// Origin agent can't be used to sample, so keep it same with _model for evaluating.
_sampling_model = model;
_param_size = _calculate_param_size();
@@ -125,7 +125,7 @@ public:
for (auto& param: params) {
torch::Tensor tensor = param.value().view({-1});
auto tensor_a = tensor.accessor<float,1>();
-            _optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0));
+            _optimizer->update(tensor_a, _neg_gradients+counter, tensor.size(0), param.key());
counter += tensor.size(0);
}
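In both the torch agent above and the Paddle agent further below, `Optimizer::update` now receives the parameter name as a fourth argument. The optimizer classes themselves are not shown in this commit, but the natural reading is that stateful optimizers such as Adam keep per-parameter buffers (e.g. moment estimates) keyed by that name. The sketch below is an assumption illustrating that idea, not the repository's `AdamOptimizer`; bias correction is omitted for brevity.

```cpp
// Hypothetical stand-in for a stateful optimizer; illustrates why update()
// takes the parameter name. It works with either a raw float* (Paddle path)
// or a torch accessor (torch path), since both support operator[].
#include <cmath>
#include <map>
#include <string>
#include <vector>

class AdamLikeOptimizer {
public:
    AdamLikeOptimizer(float lr, float beta1, float beta2, float eps)
        : _lr(lr), _beta1(beta1), _beta2(beta2), _eps(eps) {}

    template <typename ArrayLike>
    bool update(ArrayLike weights, float* gradient, int size, const std::string& param_name) {
        // One pair of moment buffers per named parameter.
        std::vector<float>& m = _m[param_name];
        std::vector<float>& v = _v[param_name];
        if (m.empty()) { m.assign(size, 0.0f); v.assign(size, 0.0f); }
        for (int i = 0; i < size; ++i) {
            m[i] = _beta1 * m[i] + (1.0f - _beta1) * gradient[i];
            v[i] = _beta2 * v[i] + (1.0f - _beta2) * gradient[i] * gradient[i];
            weights[i] -= _lr * m[i] / (std::sqrt(v[i]) + _eps);  // bias correction omitted
        }
        return true;
    }

private:
    float _lr, _beta1, _beta2, _eps;
    std::map<std::string, std::vector<float>> _m, _v;
};
```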
@@ -14,7 +14,7 @@ if [ $1 = "paddle" ]; then
fi
# Initialization model
-if [ ! -d ./demo/paddle/cartpole_init_model]; then
+if [ ! -d ./demo/paddle/cartpole_init_model ]; then
unzip ./demo/paddle/cartpole_init_model.zip -d ./demo/paddle/
fi
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "optimizer_factory.h"
namespace DeepES{
std::shared_ptr<Optimizer> create_optimizer(const OptimizerConfig& optimizer_config) {
std::shared_ptr<Optimizer> optimizer;
std::string opt_type = optimizer_config.type();
std::transform(opt_type.begin(), opt_type.end(), opt_type.begin(), ::tolower);
if (opt_type == "sgd") {
optimizer = std::make_shared<SGDOptimizer>(optimizer_config.base_lr(), \
optimizer_config.momentum());
}else if (opt_type == "adam") {
optimizer = std::make_shared<AdamOptimizer>(optimizer_config.base_lr(), \
optimizer_config.beta1(), \
optimizer_config.beta2(), \
optimizer_config.epsilon());
}else {
// TODO: NotImplementedError
}
return optimizer;
}
}//namespace
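One consequence of the factory as committed: when the type string is neither "sgd" nor "adam", the TODO branch leaves the returned `shared_ptr` empty. Until the NotImplementedError handling lands, a caller-side guard such as the following sketch (an assumption, not code from the commit) can fail fast; glog is already used elsewhere in the repo.

```cpp
// Sketch of a defensive wrapper around the factory (assumption, not from the
// commit). CHECK aborts with the message if the factory returned nullptr.
#include <memory>
#include <glog/logging.h>
#include "optimizer_factory.h"
#include "deepes.pb.h"

std::shared_ptr<DeepES::Optimizer> checked_create_optimizer(
        const DeepES::OptimizerConfig& optimizer_config) {
    std::shared_ptr<DeepES::Optimizer> optimizer =
            DeepES::create_optimizer(optimizer_config);
    CHECK(optimizer) << "Unsupported optimizer type: " << optimizer_config.type();
    return optimizer;
}
```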
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "es_agent.h"
namespace DeepES {
@@ -50,7 +49,7 @@ ESAgent::ESAgent(
_sampling_method = std::make_shared<GaussianSampling>();
_sampling_method->load_config(*_config);
-    _optimizer = std::make_shared<SGDOptimizer>(_config->optimizer().base_lr());
+    _optimizer = create_optimizer(_config->optimizer());
_param_names = _predictor->GetParamNames();
_param_size = _calculate_param_size();
@@ -108,7 +107,7 @@ bool ESAgent::update(
std::unique_ptr<Tensor> tensor = _predictor->GetMutableTensor(param_name);
float* tensor_data = tensor->mutable_data<float>();
int64_t tensor_size = ShapeProduction(tensor->shape());
-        _optimizer->update(tensor_data, _neg_gradients + counter, tensor_size);
+        _optimizer->update(tensor_data, _neg_gradients + counter, tensor_size, param_name);
counter += tensor_size;
}
return true;