提交 a85ae577 编写于 作者: Z zenghsh3

refine naming

上级 1f185137
......@@ -161,8 +161,17 @@ public:
private:
std::shared_ptr<T> _sampling_model;
int64_t _calculate_param_size() {
auto params = _model->named_parameters();
for (auto& param: params) {
torch::Tensor tensor = param.value().view({-1});
_param_size += tensor.size(0);
}
return _param_size;
}
std::shared_ptr<T> _model;
std::shared_ptr<T> _sampling_model;
bool _is_sampling_agent;
std::shared_ptr<SamplingMethod> _sampling_method;
std::shared_ptr<Optimizer> _optimizer;
......@@ -171,15 +180,6 @@ private:
// malloc memory of noise and neg_gradients in advance.
float* _noise;
float* _neg_gradients;
int64_t _calculate_param_size() {
auto params = _model->named_parameters();
for (auto& param: params) {
torch::Tensor tensor = param.value().view({-1});
_param_size += tensor.size(0);
}
return _param_size;
}
};
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册