Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b4aa0eca
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b4aa0eca
编写于
6月 05, 2017
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"modify update interface"
上级
8610ba1c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
62 addition
and
30 deletion
+62
-30
paddle/optimizer/CMakeLists.txt
paddle/optimizer/CMakeLists.txt
+1
-0
paddle/optimizer/Tensor.h
paddle/optimizer/Tensor.h
+18
-10
paddle/optimizer/Tensor_test.cpp
paddle/optimizer/Tensor_test.cpp
+21
-0
paddle/optimizer/optimizer.cc
paddle/optimizer/optimizer.cc
+3
-2
paddle/optimizer/parameter_optimizer.cc
paddle/optimizer/parameter_optimizer.cc
+17
-17
paddle/optimizer/sgd_optmizer.cc
paddle/optimizer/sgd_optmizer.cc
+2
-1
未找到文件。
paddle/optimizer/CMakeLists.txt
浏览文件 @
b4aa0eca
...
...
@@ -27,3 +27,4 @@ add_dependencies(optimizer gen_proto_cpp)
add_simple_unittest
(
optimizer_test
)
add_simple_unittest
(
optimizer_factory_test
)
add_simple_unittest
(
Tensor_test
)
paddle/optimizer/Tensor.h
浏览文件 @
b4aa0eca
...
...
@@ -5,34 +5,42 @@
*/
#include <string.h>
#include "paddle/math/BaseMatrix.h"
#include "paddle/utils/Common.h"
#include "paddle/utils/Logging.h"
namespace
paddle
{
namespace
optimizer
{
template
<
class
T
>
using
TensorBase
=
BaseMatrixT
<
T
>
;
template
<
class
T
>
class
TensorT
:
public
TensorBase
<
T
>
{
class
TensorT
{
public:
TensorT
(
T
*
data
,
int
size
)
:
TensorBase
<
T
>
(
1
,
size
,
0
,
data
,
false
,
false
)
{}
TensorT
(
size_t
h
,
size_t
w
,
T
*
data
)
:
height_
(
h
),
width_
(
w
),
data_
(
data_
)
{}
TensorT
(
T
*
data
,
int
size
)
:
height_
(
1
),
width_
(
size
),
data_
(
data
)
{}
TensorT
(
const
TensorT
&
t
)
:
Tensor
Base
<
T
>
(
1
,
t
.
size
(),
0
,
t
.
get_buffer
(),
false
,
false
)
{}
:
Tensor
T
(
1
,
t
.
size
(),
0
,
t
.
get_buffer
(),
false
,
false
)
{}
TensorT
&
operator
=
(
const
TensorT
&
t
)
{
this
->
size
_
=
t
.
size
();
this
->
width
_
=
t
.
size
();
this
->
data_
=
t
.
get_buffer
();
}
T
*
get_buffer
()
{
return
this
->
data_
;
}
T
&
operator
[](
const
int
idx
)
{
CHECK
(
idx
>=
0
&&
idx
<
this
->
width_
)
<<
"out of index range"
;
return
this
->
data_
[
idx
];
return
data_
[
idx
];
}
T
&
operator
[](
const
int
idx
)
const
{
CHECK
(
idx
>=
0
&&
idx
<
this
->
width_
)
<<
"out of index range"
;
return
data_
[
idx
];
}
// TODO: replace with tensorshape
size_t
size
()
const
{
return
this
->
width_
;
}
protected:
size_t
height_
;
size_t
width_
;
T
*
data_
;
};
// TODO(zhihong): design problem of dynamic datatype, need to fix
// TODO(zhihong): design problem of dynamic datatype, need to fix
it
typedef
TensorT
<
real
>
Tensor
;
}
// namespace optimizer
...
...
paddle/optimizer/Tensor_test.cpp
0 → 100644
浏览文件 @
b4aa0eca
#include "Tensor.h"
#include <iostream>
#include "gtest/gtest.h"
using
namespace
paddle
;
using
namespace
paddle
::
optimizer
;
TEST
(
Tensor
,
indexer
)
{
real
*
ptr
=
new
real
[
3
];
Tensor
t
(
ptr
,
3
);
for
(
auto
i
=
0
;
i
<
t
.
size
();
++
i
)
{
t
[
i
]
=
i
;
}
ASSERT_EQ
(
t
[
2
],
2
);
ASSERT_EQ
(
t
[
1
],
1
);
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
return
RUN_ALL_TESTS
();
}
paddle/optimizer/optimizer.cc
浏览文件 @
b4aa0eca
...
...
@@ -2,6 +2,7 @@
#include <string>
#include "parameter_optimizer.h"
using
namespace
paddle
;
using
namespace
paddle
::
optimizer
;
template
<
paddle_element_type
VALUE
>
...
...
@@ -50,8 +51,8 @@ int paddle_update_parameter(paddle_optimizer* o,
const
void
*
grad_buffer
,
int
num_bytes
)
{
// TOOD(zhihong): datatype not work. need to add the runtime datatype
auto
grad
=
reinterpret_cast
<
const
real
*>
(
grad_buffer
);
Tensor
gradient
(
const_cast
<
real
*>
(
grad
),
num_bytes
);
auto
grad
_type
=
reinterpret_cast
<
const
real
*>
(
grad_buffer
);
Tensor
*
gradient
=
new
Tensor
(
const_cast
<
real
*>
(
grad_type
),
num_bytes
);
o
->
impl
->
update
(
gradient
);
return
PADDLE_SUCCESS
;
}
...
...
paddle/optimizer/parameter_optimizer.cc
浏览文件 @
b4aa0eca
#include <glog/logging.h>
#include "adadelta_optimizer.h"
#include "adagrad_optimizer.h"
#include "adam_optimizer.h"
//
#include "adadelta_optimizer.h"
//
#include "adagrad_optimizer.h"
//
#include "adam_optimizer.h"
#include "lr_policy.h"
#include "sgd_optimizer.h"
...
...
@@ -36,20 +36,20 @@ ParameterOptimizer *ParameterOptimizer::create(
config
.
sgd
().
nesterov
(),
lr
);
}
if
(
s
==
"Adadelta"
)
{
return
new
AdagradOptimizer
(
config
.
adagrad
().
epsilon
(),
config
.
adagrad
().
decay
(),
lr
);
}
if
(
s
==
"Adagrad"
)
{
return
new
AdagradOptimizer
(
config
.
adagrad
().
epsilon
(),
config
.
adagrad
().
decay
(),
lr
);
}
if
(
s
==
"Adam"
)
{
return
new
AdadeltaOptimizer
(
config
.
adadelta
().
rho
(),
config
.
adadelta
().
epsilon
(),
config
.
adadelta
().
decay
(),
lr
);
}
//
if (s == "Adadelta") {
//
return new AdagradOptimizer(
//
config.adagrad().epsilon(), config.adagrad().decay(), lr);
//
}
//
if (s == "Adagrad") {
//
return new AdagradOptimizer(
//
config.adagrad().epsilon(), config.adagrad().decay(), lr);
//
}
//
if (s == "Adam") {
//
return new AdadeltaOptimizer(config.adadelta().rho(),
//
config.adadelta().epsilon(),
//
config.adadelta().decay(),
//
lr);
//
}
// default
return
new
SGDOptimizer
(
config
.
sgd
().
momentum
(),
config
.
sgd
().
decay
(),
...
...
paddle/optimizer/sgd_optmizer.cc
浏览文件 @
b4aa0eca
...
...
@@ -16,7 +16,8 @@ void SGDOptimizer::set_weight(Tensor *p) {
void
SGDOptimizer
::
update
(
const
Tensor
&
gradient
)
{
num_sample_passed
+=
1
;
double
learning_rate
=
lr_policy
->
get_learning_rate
(
num_sample_passed
);
double
velocity
=
0.0
;
real
velocity
=
0.0
;
Tensor
&
param
=
*
parameter_
;
for
(
size_t
i
=
0
;
i
<
parameter_
->
size
();
++
i
)
{
if
(
momentum
==
0.0
)
{
velocity
=
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录