提交 0fc42012 编写于 作者: D dzhwinter

"update interface"

上级 33b4deed
......@@ -10,9 +10,6 @@ public:
AdadeltaOptimizer(
Tensor *parameter, LrPolicy *lr, double rho, double epsilon, double decay)
: ParameterOptimizer(parameter, lr),
accum_gradient_(nullptr),
accum_delta_(nullptr),
update_delta_(nullptr),
rho_(rho),
epsilon_(epsilon),
decay_(decay) {
......
......@@ -20,6 +20,8 @@ public:
if (accum_gradient_) delete accum_gradient_;
}
void Update(const Tensor *gradient);
const char *SerializeState(int *state_len);
void DeSerializeState(const std::string &state);
private:
Tensor *accum_gradient_;
......
......@@ -4,13 +4,6 @@
namespace paddle {
namespace optimizer {
void AdamOptimizer::set_weight(Tensor *p) {
parameter_ = p;
size_t size = p->size();
momentums_ = new Tensor(size);
velocitys_ = new Tensor(size);
}
void AdamOptimizer::Update(const Tensor *gradient) {
num_sample_passed_ += 1;
double learning_rate = lr_policy_->LearningRate(num_sample_passed_);
......
......@@ -13,18 +13,19 @@ public:
double epsilon,
double decay)
: ParameterOptimizer(parameter, lr),
momentums_(nullptr),
velocitys_(nullptr),
beta_1_(beta_1),
beta_2_(beta_2),
epsilon_(epsilon),
decay_(decay) {}
decay_(decay) {
size_t size = p->size();
momentums_ = new Tensor(size);
velocitys_ = new Tensor(size);
}
~AdamOptimizer() {
if (momentums_) delete momentums_;
if (velocitys_) delete velocitys_;
}
void Update(const Tensor *gradient);
void set_weight(Tensor *p);
private:
Tensor *momentums_;
......
......@@ -17,7 +17,7 @@ unsigned CalStateSize(const HEAD& head, const TAIL&... tail) {
if (std::is_fundamental<HEAD>::value) {
return sizeof head + CalStateSize(tail...);
} else {
return sizeof(head[0] * head->size()) + CalStateSize(tail...);
return sizeof(head[0]) * head->size() + CalStateSize(tail...);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册