Commit b4aa0eca authored by: D dzhwinter

"modify update interface"

Parent 8610ba1c
......@@ -27,3 +27,4 @@ add_dependencies(optimizer gen_proto_cpp)
add_simple_unittest(optimizer_test)
add_simple_unittest(optimizer_factory_test)
+add_simple_unittest(Tensor_test)
......@@ -5,34 +5,42 @@
*/
#include <string.h>
#include "paddle/math/BaseMatrix.h"
#include "paddle/utils/Common.h"
#include "paddle/utils/Logging.h"
namespace paddle {
namespace optimizer {
template <class T>
using TensorBase = BaseMatrixT<T>;
template <class T>
class TensorT : public TensorBase<T> {
class TensorT {
public:
TensorT(T* data, int size) : TensorBase<T>(1, size, 0, data, false, false) {}
TensorT(size_t h, size_t w, T* data) : height_(h), width_(w), data_(data_) {}
TensorT(T* data, int size) : height_(1), width_(size), data_(data) {}
TensorT(const TensorT& t)
: TensorBase<T>(1, t.size(), 0, t.get_buffer(), false, false) {}
: TensorT(1, t.size(), 0, t.get_buffer(), false, false) {}
TensorT& operator=(const TensorT& t) {
this->size_ = t.size();
this->width_ = t.size();
this->data_ = t.get_buffer();
}
T* get_buffer() { return this->data_; }
T& operator[](const int idx) {
CHECK(idx >= 0 && idx < this->width_) << "out of index range";
return this->data_[idx];
return data_[idx];
}
T& operator[](const int idx) const {
CHECK(idx >= 0 && idx < this->width_) << "out of index range";
return data_[idx];
}
// TODO: replace with tensorshape
size_t size() const { return this->width_; }
protected:
size_t height_;
size_t width_;
T* data_;
};
// TODO(zhihong): design problem of dynamic datatype, need to fix
// TODO(zhihong): design problem of dynamic datatype, need to fix it
typedef TensorT<real> Tensor;
} // namespace optimizer
......
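For orientation, a minimal sketch (not part of the commit) of how the reworked TensorT behaves now that it is a plain view over a caller-owned buffer rather than a BaseMatrixT subclass; the buffer and values below are illustrative assumptions:

// TensorT never allocates or frees; it only wraps an external buffer.
float buf[4] = {0.f, 1.f, 2.f, 3.f};
TensorT<float> view(buf, 4);  // height_ = 1, width_ = 4
view[2] = 42.f;               // writes through to buf[2]; CHECK guards the index
size_t n = view.size();       // 4 (the width)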
#include "Tensor.h"
#include <iostream>
#include "gtest/gtest.h"
using namespace paddle;
using namespace paddle::optimizer;
TEST(Tensor, indexer) {
  real* ptr = new real[3];
  Tensor t(ptr, 3);
  for (size_t i = 0; i < t.size(); ++i) {
    t[i] = i;
  }
  ASSERT_EQ(t[2], 2);
  ASSERT_EQ(t[1], 1);
  delete[] ptr;  // Tensor is a non-owning view, so free the buffer here
}

int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
......@@ -2,6 +2,7 @@
#include <string>
#include "parameter_optimizer.h"
using namespace paddle;
+using namespace paddle::optimizer;
template <paddle_element_type VALUE>
......@@ -50,8 +51,8 @@ int paddle_update_parameter(paddle_optimizer* o,
const void* grad_buffer,
int num_bytes) {
  // TODO(zhihong): datatype does not work yet; need to add the runtime datatype
-  auto grad = reinterpret_cast<const real*>(grad_buffer);
-  Tensor gradient(const_cast<real*>(grad), num_bytes);
-  o->impl->update(gradient);
+  auto grad_type = reinterpret_cast<const real*>(grad_buffer);
+  Tensor* gradient = new Tensor(const_cast<real*>(grad_type), num_bytes);
+  o->impl->update(*gradient);
  return PADDLE_SUCCESS;
}
......
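Two things worth flagging in this hunk: the heap-allocated Tensor is never freed, and num_bytes is forwarded directly as the Tensor's element count. A hedged sketch of what a byte-count-correct, leak-free version could look like (an assumption, not the committed code):

// Sketch only: treat num_bytes as a byte count and keep the Tensor on
// the stack, since update() takes it by const reference.
const real* grad = reinterpret_cast<const real*>(grad_buffer);
int num_elements = static_cast<int>(num_bytes / sizeof(real));
Tensor gradient(const_cast<real*>(grad), num_elements);
o->impl->update(gradient);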
#include <glog/logging.h>
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
// #include "adadelta_optimizer.h"
// #include "adagrad_optimizer.h"
// #include "adam_optimizer.h"
#include "lr_policy.h"
#include "sgd_optimizer.h"
......@@ -36,20 +36,20 @@ ParameterOptimizer *ParameterOptimizer::create(
config.sgd().nesterov(),
lr);
}
if (s == "Adadelta") {
return new AdagradOptimizer(
config.adagrad().epsilon(), config.adagrad().decay(), lr);
}
if (s == "Adagrad") {
return new AdagradOptimizer(
config.adagrad().epsilon(), config.adagrad().decay(), lr);
}
if (s == "Adam") {
return new AdadeltaOptimizer(config.adadelta().rho(),
config.adadelta().epsilon(),
config.adadelta().decay(),
lr);
}
// if (s == "Adadelta") {
// return new AdagradOptimizer(
// config.adagrad().epsilon(), config.adagrad().decay(), lr);
// }
// if (s == "Adagrad") {
// return new AdagradOptimizer(
// config.adagrad().epsilon(), config.adagrad().decay(), lr);
// }
// if (s == "Adam") {
// return new AdadeltaOptimizer(config.adadelta().rho(),
// config.adadelta().epsilon(),
// config.adadelta().decay(),
// lr);
// }
// default
return new SGDOptimizer(config.sgd().momentum(),
config.sgd().decay(),
......
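Since every non-SGD branch is now commented out, any optimizer name falls through to the SGD path. The truncated call above completes, per the argument list visible in the earlier branch of this hunk, roughly as follows (a sketch, assuming `config` and `lr` as set up earlier in create):

// All names ("Adam", "Adagrad", ...) currently produce an SGDOptimizer.
return new SGDOptimizer(config.sgd().momentum(),
                        config.sgd().decay(),
                        config.sgd().nesterov(),
                        lr);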
......@@ -16,7 +16,8 @@ void SGDOptimizer::set_weight(Tensor *p) {
void SGDOptimizer::update(const Tensor &gradient) {
num_sample_passed += 1;
double learning_rate = lr_policy->get_learning_rate(num_sample_passed);
double velocity = 0.0;
real velocity = 0.0;
Tensor &param = *parameter_;
for (size_t i = 0; i < parameter_->size(); ++i) {
if (momentum == 0.0) {
velocity =
......
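For reference, a hedged sketch of the momentum update this loop presumably implements (the committed loop body is truncated above, so the weight-decay placement is an assumption). The `param` alias and the const `operator[]` added in Tensor.h are what make the `gradient[i]` reads possible:

// Assumed classic momentum SGD with L2 decay; not the committed body.
for (size_t i = 0; i < param.size(); ++i) {
  if (momentum == 0.0) {
    velocity = -learning_rate * (gradient[i] + decay * param[i]);
  } else {
    velocity = momentum * velocity -
               learning_rate * (gradient[i] + decay * param[i]);
  }
  param[i] += velocity;
}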