提交 0fc42012 编写于 作者: D dzhwinter

"update interface"

上级 33b4deed
...@@ -10,9 +10,6 @@ public: ...@@ -10,9 +10,6 @@ public:
AdadeltaOptimizer( AdadeltaOptimizer(
Tensor *parameter, LrPolicy *lr, double rho, double epsilon, double decay) Tensor *parameter, LrPolicy *lr, double rho, double epsilon, double decay)
: ParameterOptimizer(parameter, lr), : ParameterOptimizer(parameter, lr),
accum_gradient_(nullptr),
accum_delta_(nullptr),
update_delta_(nullptr),
rho_(rho), rho_(rho),
epsilon_(epsilon), epsilon_(epsilon),
decay_(decay) { decay_(decay) {
......
...@@ -20,6 +20,8 @@ public: ...@@ -20,6 +20,8 @@ public:
if (accum_gradient_) delete accum_gradient_; if (accum_gradient_) delete accum_gradient_;
} }
void Update(const Tensor *gradient); void Update(const Tensor *gradient);
const char *SerializeState(int *state_len);
void DeSerializeState(const std::string &state);
private: private:
Tensor *accum_gradient_; Tensor *accum_gradient_;
......
...@@ -4,13 +4,6 @@ ...@@ -4,13 +4,6 @@
namespace paddle { namespace paddle {
namespace optimizer { namespace optimizer {
// Bind a parameter tensor to this optimizer and (re)allocate the Adam
// moment buffers (first/second moment estimates) to match its size.
//
// @param p  parameter tensor; must be non-null and outlive this optimizer
//           (ownership is NOT taken — only momentums_/velocitys_ are owned).
void AdamOptimizer::set_weight(Tensor *p) {
  parameter_ = p;
  size_t size = p->size();
  // Free any previously allocated moment buffers so a repeated call does
  // not leak (mirrors the cleanup done in ~AdamOptimizer()).
  if (momentums_) delete momentums_;
  if (velocitys_) delete velocitys_;
  momentums_ = new Tensor(size);
  velocitys_ = new Tensor(size);
}
void AdamOptimizer::Update(const Tensor *gradient) { void AdamOptimizer::Update(const Tensor *gradient) {
num_sample_passed_ += 1; num_sample_passed_ += 1;
double learning_rate = lr_policy_->LearningRate(num_sample_passed_); double learning_rate = lr_policy_->LearningRate(num_sample_passed_);
......
...@@ -13,18 +13,19 @@ public: ...@@ -13,18 +13,19 @@ public:
double epsilon, double epsilon,
double decay) double decay)
: ParameterOptimizer(parameter, lr), : ParameterOptimizer(parameter, lr),
momentums_(nullptr),
velocitys_(nullptr),
beta_1_(beta_1), beta_1_(beta_1),
beta_2_(beta_2), beta_2_(beta_2),
epsilon_(epsilon), epsilon_(epsilon),
decay_(decay) {} decay_(decay) {
size_t size = p->size();
momentums_ = new Tensor(size);
velocitys_ = new Tensor(size);
}
~AdamOptimizer() { ~AdamOptimizer() {
if (momentums_) delete momentums_; if (momentums_) delete momentums_;
if (velocitys_) delete velocitys_; if (velocitys_) delete velocitys_;
} }
void Update(const Tensor *gradient); void Update(const Tensor *gradient);
void set_weight(Tensor *p);
private: private:
Tensor *momentums_; Tensor *momentums_;
......
...@@ -17,7 +17,7 @@ unsigned CalStateSize(const HEAD& head, const TAIL&... tail) { ...@@ -17,7 +17,7 @@ unsigned CalStateSize(const HEAD& head, const TAIL&... tail) {
if (std::is_fundamental<HEAD>::value) { if (std::is_fundamental<HEAD>::value) {
return sizeof head + CalStateSize(tail...); return sizeof head + CalStateSize(tail...);
} else { } else {
return sizeof(head[0] * head->size()) + CalStateSize(tail...); return sizeof(head[0]) * head->size() + CalStateSize(tail...);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册