diff --git a/deepes/include/torch/es_agent.h b/deepes/include/torch/es_agent.h index 76a456c3199a0c6fae78a04fad9a1841118ea723..486b74e207ac939c27339dd60ef9eb23ce2285b2 100644 --- a/deepes/include/torch/es_agent.h +++ b/deepes/include/torch/es_agent.h @@ -161,8 +161,17 @@ public: private: - std::shared_ptr _sampling_model; + int64_t _calculate_param_size() { + auto params = _model->named_parameters(); + for (auto& param: params) { + torch::Tensor tensor = param.value().view({-1}); + _param_size += tensor.size(0); + } + return _param_size; + } + std::shared_ptr _model; + std::shared_ptr _sampling_model; bool _is_sampling_agent; std::shared_ptr _sampling_method; std::shared_ptr _optimizer; @@ -171,15 +180,6 @@ private: // malloc memory of noise and neg_gradients in advance. float* _noise; float* _neg_gradients; - - int64_t _calculate_param_size() { - auto params = _model->named_parameters(); - for (auto& param: params) { - torch::Tensor tensor = param.value().view({-1}); - _param_size += tensor.size(0); - } - return _param_size; - } }; }