提交 2f6d1e10 编写于 作者: Z zhoubo01

fix comments

上级 752974cb
...@@ -11,5 +11,5 @@ optimizer { ...@@ -11,5 +11,5 @@ optimizer {
epsilon: 1e-08 epsilon: 1e-08
} }
async_es { async_es {
model_iter_id: 0 model_iter_id: 99
} }
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef _ASYNC_ES_AGENT_H #ifndef ASYNC_ES_AGENT_H
#define _ASYNC_ES_AGENT_H #define ASYNC_ES_AGENT_H
#include "es_agent.h" #include "es_agent.h"
#include <map> #include <map>
...@@ -49,7 +49,10 @@ class AsyncESAgent: public ESAgent { ...@@ -49,7 +49,10 @@ class AsyncESAgent: public ESAgent {
std::shared_ptr<AsyncESAgent> clone(); std::shared_ptr<AsyncESAgent> clone();
/** /**
* @brief: Clone an agent for sampling. * @brief: update parameters given data collected during evaluation.
* @args:
* noisy_info: sampling information returned by add_noise function.
* noisy_reward: evaluation rewards.
*/ */
bool update( bool update(
std::vector<SamplingInfo>& noisy_info, std::vector<SamplingInfo>& noisy_info,
......
...@@ -21,15 +21,13 @@ ...@@ -21,15 +21,13 @@
#include "gaussian_sampling.h" #include "gaussian_sampling.h"
#include "deepes.pb.h" #include "deepes.pb.h"
#include <vector> #include <vector>
using namespace paddle::lite_api;
using namespace paddle::lite_api;
namespace DeepES { namespace DeepES {
int64_t ShapeProduction(const shape_t& shape); int64_t ShapeProduction(const shape_t& shape);
typedef paddle::lite_api::PaddlePredictor PaddlePredictor;
/** /**
* @brief DeepES agent with PaddleLite as backend. * @brief DeepES agent with PaddleLite as backend.
* Users mainly focus on the following functions: * Users mainly focus on the following functions:
......
...@@ -40,7 +40,7 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) { ...@@ -40,7 +40,7 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) {
bool success = true; bool success = true;
std::ifstream fin(config_file); std::ifstream fin(config_file);
if (!fin || fin.fail()) { if (!fin || fin.fail()) {
LOG(FATAL) << "open prototxt config failed: " << config_file; LOG(ERROR) << "open prototxt config failed: " << config_file;
success = false; success = false;
} else { } else {
fin.seekg(0, std::ios::end); fin.seekg(0, std::ios::end);
...@@ -52,7 +52,7 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) { ...@@ -52,7 +52,7 @@ bool load_proto_conf(const std::string& config_file, T& proto_config) {
std::string proto_str(file_content_buffer, file_size); std::string proto_str(file_content_buffer, file_size);
if (!google::protobuf::TextFormat::ParseFromString(proto_str, &proto_config)) { if (!google::protobuf::TextFormat::ParseFromString(proto_str, &proto_config)) {
LOG(FATAL) << "Failed to load config: " << config_file; LOG(ERROR) << "Failed to load config: " << config_file;
success = false; success = false;
} }
delete[] file_content_buffer; delete[] file_content_buffer;
...@@ -66,7 +66,7 @@ bool save_proto_conf(const std::string& config_file, T&proto_config) { ...@@ -66,7 +66,7 @@ bool save_proto_conf(const std::string& config_file, T&proto_config) {
bool success = true; bool success = true;
std::ofstream ofs(config_file, std::ofstream::out); std::ofstream ofs(config_file, std::ofstream::out);
if (!ofs || ofs.fail()) { if (!ofs || ofs.fail()) {
LOG(FATAL) << "open prototxt config failed: " << config_file; LOG(ERROR) << "open prototxt config failed: " << config_file;
success = false; success = false;
} else { } else {
std::string config_str; std::string config_str;
...@@ -76,6 +76,7 @@ bool save_proto_conf(const std::string& config_file, T&proto_config) { ...@@ -76,6 +76,7 @@ bool save_proto_conf(const std::string& config_file, T&proto_config) {
} }
ofs << config_str; ofs << config_str;
} }
return success;
} }
std::vector<std::string> list_all_model_dirs(std::string path); std::vector<std::string> list_all_model_dirs(std::string path);
......
...@@ -32,8 +32,6 @@ else ...@@ -32,8 +32,6 @@ else
exit 0 exit 0
fi fi
#export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
#----------------protobuf-------------# #----------------protobuf-------------#
cp ./src/proto/deepes.proto ./ cp ./src/proto/deepes.proto ./
protoc deepes.proto --cpp_out ./ protoc deepes.proto --cpp_out ./
......
...@@ -30,7 +30,7 @@ AsyncESAgent::~AsyncESAgent() { ...@@ -30,7 +30,7 @@ AsyncESAgent::~AsyncESAgent() {
bool AsyncESAgent::_save() { bool AsyncESAgent::_save() {
bool success = true; bool success = true;
if (_is_sampling_agent) { if (_is_sampling_agent) {
LOG(ERROR) << "[DeepES] Original AsyncESAgent cannot call add_noise function, please use cloned AsyncESAgent."; LOG(ERROR) << "[DeepES] Original AsyncESAgent cannot call `save`.Please use cloned AsyncESAgent.";
success = false; success = false;
return success; return success;
} }
...@@ -49,7 +49,7 @@ bool AsyncESAgent::_save() { ...@@ -49,7 +49,7 @@ bool AsyncESAgent::_save() {
model_name = "model_iter_id-"+ std::to_string(model_iter_id); model_name = "model_iter_id-"+ std::to_string(model_iter_id);
std::string model_path = _config->async_es().model_warehouse() + "/" + model_name; std::string model_path = _config->async_es().model_warehouse() + "/" + model_name;
LOG(INFO) << "[save]model_path: " << model_path; LOG(INFO) << "[save]model_path: " << model_path;
_predictor->SaveOptimizedModel(model_path, LiteModelType::kProtobuf); _predictor->SaveOptimizedModel(model_path, paddle::lite_api::LiteModelType::kProtobuf);
// save config // save config
auto async_es = _config->mutable_async_es(); auto async_es = _config->mutable_async_es();
async_es->set_model_iter_id(model_iter_id); async_es->set_model_iter_id(model_iter_id);
...@@ -93,15 +93,17 @@ bool AsyncESAgent::_compute_model_diff() { ...@@ -93,15 +93,17 @@ bool AsyncESAgent::_compute_model_diff() {
std::shared_ptr<PaddlePredictor> old_predictor = kv.second; std::shared_ptr<PaddlePredictor> old_predictor = kv.second;
float* diff = new float[_param_size]; float* diff = new float[_param_size];
memset(diff, 0, _param_size * sizeof(float)); memset(diff, 0, _param_size * sizeof(float));
for (std::string param_name: _param_names) { int offset = 0;
for (const std::string& param_name: _param_names) {
auto des_tensor = old_predictor->GetTensor(param_name); auto des_tensor = old_predictor->GetTensor(param_name);
auto src_tensor = _predictor->GetTensor(param_name); auto src_tensor = _predictor->GetTensor(param_name);
const float* des_data = des_tensor->data<float>(); const float* des_data = des_tensor->data<float>();
const float* src_data = src_tensor->data<float>(); const float* src_data = src_tensor->data<float>();
int64_t tensor_size = ShapeProduction(src_tensor->shape()); int64_t tensor_size = ShapeProduction(src_tensor->shape());
for (int i = 0; i < tensor_size; ++i) { for (int i = 0; i < tensor_size; ++i) {
diff[i] = des_data[i] - src_data[i]; diff[i + offset] = des_data[i] - src_data[i];
} }
offset += tensor_size;
} }
_param_delta[model_iter_id] = diff; _param_delta[model_iter_id] = diff;
} }
...@@ -206,6 +208,7 @@ bool AsyncESAgent::update( ...@@ -206,6 +208,7 @@ bool AsyncESAgent::update(
float reward = noisy_rewards[i]; float reward = noisy_rewards[i];
int model_iter_id = noisy_info[i].model_iter_id(); int model_iter_id = noisy_info[i].model_iter_id();
bool success = _sampling_method->resampling(key, _noise, _param_size); bool success = _sampling_method->resampling(key, _noise, _param_size);
CHECK(success) << "[DeepES] resampling error occurs at sample: " << i;
float* delta = _param_delta[model_iter_id]; float* delta = _param_delta[model_iter_id];
// compute neg_gradients // compute neg_gradients
if (model_iter_id == current_model_iter_id) { if (model_iter_id == current_model_iter_id) {
......
...@@ -17,9 +17,6 @@ ...@@ -17,9 +17,6 @@
namespace DeepES { namespace DeepES {
typedef paddle::lite_api::Tensor Tensor;
typedef paddle::lite_api::shape_t shape_t;
int64_t ShapeProduction(const shape_t& shape) { int64_t ShapeProduction(const shape_t& shape) {
int64_t res = 1; int64_t res = 1;
for (auto i : shape) res *= i; for (auto i : shape) res *= i;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册