Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
23761bea
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
23761bea
编写于
1月 02, 2019
作者:
Q
Qiyang Min
提交者:
GitHub
1月 02, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14971 from velconia/imperative_mnist
Imperative Optimizer
上级
8eb1f262
22956530
变更
20
展开全部
隐藏空白更改
内联
并排
Showing
20 changed file
with
857 addition
and
266 deletion
+857
-266
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+9
-0
paddle/fluid/framework/operator_test.cc
paddle/fluid/framework/operator_test.cc
+27
-0
paddle/fluid/imperative/layer.cc
paddle/fluid/imperative/layer.cc
+25
-7
paddle/fluid/imperative/layer.h
paddle/fluid/imperative/layer.h
+20
-11
paddle/fluid/imperative/tracer.h
paddle/fluid/imperative/tracer.h
+6
-8
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+2
-3
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+23
-8
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+41
-8
python/paddle/fluid/imperative/__init__.py
python/paddle/fluid/imperative/__init__.py
+4
-0
python/paddle/fluid/imperative/base.py
python/paddle/fluid/imperative/base.py
+2
-3
python/paddle/fluid/imperative/layers.py
python/paddle/fluid/imperative/layers.py
+14
-19
python/paddle/fluid/imperative/nn.py
python/paddle/fluid/imperative/nn.py
+250
-0
python/paddle/fluid/initializer.py
python/paddle/fluid/initializer.py
+16
-8
python/paddle/fluid/layer_helper.py
python/paddle/fluid/layer_helper.py
+13
-12
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+71
-115
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+25
-19
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+64
-29
python/paddle/fluid/tests/unittests/test_imperative.py
python/paddle/fluid/tests/unittests/test_imperative.py
+9
-16
python/paddle/fluid/tests/unittests/test_imperative_base.py
python/paddle/fluid/tests/unittests/test_imperative_base.py
+30
-0
python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
...paddle/fluid/tests/unittests/test_imperative_optimizer.py
+206
-0
未找到文件。
paddle/fluid/framework/operator.h
浏览文件 @
23761bea
...
...
@@ -69,6 +69,15 @@ inline std::string GradVarName(const std::string& var_name) {
return
result
;
}
inline
std
::
string
GradOriginalVarName
(
const
std
::
string
&
grad_var_name
)
{
std
::
size_t
pos
=
grad_var_name
.
rfind
(
kGradVarSuffix
);
if
(
pos
==
std
::
string
::
npos
)
{
return
grad_var_name
;
}
else
{
return
grad_var_name
.
substr
(
0
,
pos
);
}
}
proto
::
VarType
::
Type
GetDataTypeOfVar
(
const
Variable
*
var
);
const
Tensor
*
GetLoDTensorOrSelectedRowsValueFromVar
(
const
Variable
&
var
);
Tensor
*
GetMutableLoDTensorOrSelectedRowsValueFromVar
(
Variable
*
var
);
...
...
paddle/fluid/framework/operator_test.cc
浏览文件 @
23761bea
...
...
@@ -288,3 +288,30 @@ TEST(OpKernel, multi_inputs) {
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
op
->
Run
(
scope
,
cpu_place
);
}
TEST
(
VarNameTest
,
all
)
{
std
::
string
var_name
(
"X"
);
std
::
string
grad_var_name
=
paddle
::
framework
::
GradVarName
(
var_name
);
ASSERT_EQ
(
grad_var_name
,
"X@GRAD"
);
std
::
string
original_var_name
=
paddle
::
framework
::
GradOriginalVarName
(
grad_var_name
);
ASSERT_EQ
(
original_var_name
,
"X"
);
original_var_name
=
paddle
::
framework
::
GradOriginalVarName
(
original_var_name
);
ASSERT_EQ
(
original_var_name
,
"X"
);
std
::
string
var_name_2
(
"XYZ"
);
grad_var_name
=
paddle
::
framework
::
GradVarName
(
var_name_2
);
ASSERT_EQ
(
grad_var_name
,
"XYZ@GRAD"
);
original_var_name
=
paddle
::
framework
::
GradOriginalVarName
(
grad_var_name
);
ASSERT_EQ
(
original_var_name
,
"XYZ"
);
original_var_name
=
paddle
::
framework
::
GradOriginalVarName
(
original_var_name
);
ASSERT_EQ
(
original_var_name
,
"XYZ"
);
std
::
string
var_name_3
(
""
);
grad_var_name
=
paddle
::
framework
::
GradVarName
(
var_name_3
);
ASSERT_EQ
(
grad_var_name
,
"@GRAD"
);
original_var_name
=
paddle
::
framework
::
GradOriginalVarName
(
grad_var_name
);
ASSERT_EQ
(
original_var_name
,
""
);
original_var_name
=
paddle
::
framework
::
GradOriginalVarName
(
original_var_name
);
ASSERT_EQ
(
original_var_name
,
""
);
}
paddle/fluid/imperative/layer.cc
浏览文件 @
23761bea
...
...
@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
...
...
@@ -31,8 +32,14 @@ using framework::Variable;
void
AddTo
(
Variable
*
src
,
Variable
*
dst
)
{
framework
::
LoDTensor
*
dst_tensor
=
dst
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
LoDTensor
*
src_tensor
=
src
->
GetMutable
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE
(
dst_tensor
->
numel
()
==
src_tensor
->
numel
(),
"%lld vs %lld"
,
dst_tensor
->
numel
(),
src_tensor
->
numel
());
// FIXME(minqiyang): loss_grad op will pass a zero grad of label
// ugly fix for it
if
(
src_tensor
->
numel
()
==
0
)
{
return
;
}
PADDLE_ENFORCE
(
dst_tensor
->
numel
()
==
src_tensor
->
numel
(),
"dst_numel %lld vs. src_numel %lld"
,
dst_tensor
->
numel
(),
src_tensor
->
numel
());
float
*
dst_data
=
dst_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
const
float
*
src_data
=
src_tensor
->
data
<
float
>
();
for
(
size_t
i
=
0
;
i
<
src_tensor
->
numel
();
++
i
)
{
...
...
@@ -45,6 +52,10 @@ class Autograd {
Autograd
()
{}
void
RunBackward
(
VarBase
*
var
)
{
if
(
var
->
stop_gradient_
)
{
return
;
}
std
::
deque
<
OpBase
*>
ready
;
ready
.
push_back
(
var
->
pre_op_
);
...
...
@@ -60,6 +71,9 @@ class Autograd {
const
std
::
vector
<
VarBase
*>&
ingrads
=
it
.
second
;
for
(
size_t
i
=
0
;
i
<
ingrads
.
size
();
++
i
)
{
if
(
!
ingrads
[
i
])
continue
;
if
(
ready_op
->
input_vars_
[
it
.
first
][
i
]
->
stop_gradient_
)
{
continue
;
}
OpBase
*
pre_op
=
ready_op
->
pre_ops_
[
it
.
first
][
i
];
if
(
!
pre_op
)
continue
;
...
...
@@ -107,7 +121,7 @@ framework::LoDTensor& VarBase::Grad() {
std
::
map
<
std
::
string
,
std
::
vector
<
VarBase
*>>
OpBase
::
ApplyGrad
()
{
if
(
!
grad_op_desc_
)
{
VLOG
(
3
)
<<
"op with no grad: "
<<
op_desc_
->
Type
();
LOG
(
WARNING
)
<<
"op with no grad: "
<<
op_desc_
->
Type
();
return
{};
}
VLOG
(
3
)
<<
"op grad "
<<
grad_op_desc_
->
Type
();
...
...
@@ -117,15 +131,18 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
for
(
auto
it
:
grad_output_vars_
)
{
auto
&
outputs
=
grad_outputs
[
it
.
first
];
for
(
size_t
i
=
0
;
i
<
it
.
second
.
size
();
++
i
)
{
tmp_vars
.
emplace_back
(
new
framework
::
Variable
());
outputs
.
push_back
(
tmp_vars
.
back
().
get
());
outputs
.
back
()
->
GetMutable
<
framework
::
LoDTensor
>
();
// Allocate a new variable
Variable
*
tmp_var
=
new
framework
::
Variable
();
tmp_var
->
GetMutable
<
framework
::
LoDTensor
>
();
tmp_vars
.
emplace_back
(
tmp_var
);
outputs
.
push_back
(
tmp_var
);
}
}
framework
::
RuntimeContext
ctx
(
grad_input_vars_
,
grad_outputs
);
// No need to do
static
infer shape here.
// No need to do
compile time
infer shape here.
// grad_op_desc_->InferShape(*block_);
grad_op_desc_
->
InferVarType
(
block_
);
...
...
@@ -144,6 +161,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
for
(
auto
it
:
grad_output_vars_
)
{
auto
&
outputs
=
grad_outputs
[
it
.
first
];
auto
&
origin_outputs
=
it
.
second
;
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
framework
::
Variable
*
orig_grad
=
origin_outputs
[
i
];
AddTo
(
outputs
[
i
],
orig_grad
);
...
...
paddle/fluid/imperative/layer.h
浏览文件 @
23761bea
...
...
@@ -86,23 +86,30 @@ class VarBase {
pre_op_out_idx_
(
-
1
),
var_desc_
(
nullptr
),
var_
(
new
framework
::
Variable
()),
grads_
(
new
framework
::
Variable
())
{}
grads_
(
new
framework
::
Variable
()),
stop_gradient_
(
false
)
{}
virtual
~
VarBase
()
{
if
(
var_
)
{
delete
var_
;
var_
=
nullptr
;
}
if
(
grads_
)
{
delete
grads_
;
grads_
=
nullptr
;
}
}
explicit
VarBase
(
bool
stop_gradient
)
:
pre_op_
(
nullptr
),
pre_op_out_idx_
(
-
1
),
var_desc_
(
nullptr
),
var_
(
new
framework
::
Variable
()),
grads_
(
new
framework
::
Variable
()),
stop_gradient_
(
stop_gradient
)
{}
virtual
~
VarBase
()
{}
void
RunBackward
();
framework
::
LoDTensor
&
Grad
();
inline
std
::
string
GradName
()
const
{
PADDLE_ENFORCE
(
var_desc_
,
"Couldn't get gradient variable's name, please call backward() first"
);
return
string
::
Sprintf
(
"%s@IGrad"
,
var_desc_
->
Name
());
}
OpBase
*
pre_op_
;
std
::
string
pre_op_out_name_
;
int
pre_op_out_idx_
;
...
...
@@ -110,6 +117,8 @@ class VarBase {
framework
::
VarDesc
*
var_desc_
;
framework
::
Variable
*
var_
;
framework
::
Variable
*
grads_
;
bool
stop_gradient_
;
};
class
OpBase
{
...
...
paddle/fluid/imperative/tracer.h
浏览文件 @
23761bea
...
...
@@ -50,16 +50,14 @@ void InitVar(framework::Variable* var, framework::Variable* grad_var) {
class
Tracer
{
public:
explicit
Tracer
(
framework
::
BlockDesc
*
root_block
,
framework
::
BlockDesc
*
startup_block
)
:
root_block_
(
root_block
),
startup_block_
(
startup_block
)
{}
explicit
Tracer
(
framework
::
BlockDesc
*
root_block
)
:
root_block_
(
root_block
)
{}
virtual
~
Tracer
()
{}
void
Trace
(
OpBase
*
op
,
const
std
::
map
<
std
::
string
,
std
::
vector
<
VarBase
*>>&
inputs
,
const
std
::
map
<
std
::
string
,
std
::
vector
<
VarBase
*>>&
outputs
,
framework
::
BlockDesc
*
block
)
{
framework
::
BlockDesc
*
block
,
const
bool
stop_gradient
=
false
)
{
std
::
map
<
std
::
string
,
VarBase
*>
vars
;
framework
::
OpDesc
*
op_desc
=
op
->
op_desc_
;
...
...
@@ -107,6 +105,7 @@ class Tracer {
}
else
{
LOG
(
ERROR
)
<<
"tracer doesn't support yet"
;
}
out
->
stop_gradient_
=
stop_gradient
;
out
->
pre_op_
=
op
;
out
->
pre_op_out_name_
=
it
.
first
;
out
->
pre_op_out_idx_
=
i
;
...
...
@@ -130,9 +129,7 @@ class Tracer {
p
.
op
.
RuntimeInferShape
(
scope
,
place
,
ctx
);
p
.
func
(
framework
::
ExecutionContext
(
p
.
op
,
scope
,
*
p
.
dev_ctx
,
p
.
ctx
));
if
(
block
==
startup_block_
)
{
op
->
grad_op_desc_
=
nullptr
;
}
else
{
if
(
!
stop_gradient
)
{
framework
::
OpDesc
*
grad_op_desc
;
auto
grad_to_var
=
new
std
::
unordered_map
<
std
::
string
,
std
::
string
>
();
CreateGradOp
(
*
op_desc
,
{},
{
block
},
&
grad_op_desc
,
grad_to_var
);
...
...
@@ -156,6 +153,7 @@ class Tracer {
}
}
}
for
(
auto
it
:
grad_op_desc
->
Outputs
())
{
auto
&
grad_out_vars
=
op
->
grad_output_vars_
[
it
.
first
];
for
(
const
std
::
string
&
grad_outvar
:
it
.
second
)
{
...
...
@@ -170,12 +168,12 @@ class Tracer {
}
}
}
op
->
block_
=
block
;
}
private:
framework
::
BlockDesc
*
root_block_
;
framework
::
BlockDesc
*
startup_block_
;
};
}
// namespace imperative
...
...
paddle/fluid/pybind/imperative.cc
浏览文件 @
23761bea
...
...
@@ -23,9 +23,8 @@ namespace pybind {
void
BindTracer
(
pybind11
::
module
*
m
)
{
pybind11
::
class_
<
imperative
::
Tracer
>
(
*
m
,
"Tracer"
,
""
)
.
def
(
"__init__"
,
[](
imperative
::
Tracer
&
self
,
framework
::
BlockDesc
*
root_block
,
framework
::
BlockDesc
*
startup_block
)
{
new
(
&
self
)
imperative
::
Tracer
(
root_block
,
startup_block
);
[](
imperative
::
Tracer
&
self
,
framework
::
BlockDesc
*
root_block
)
{
new
(
&
self
)
imperative
::
Tracer
(
root_block
);
})
.
def
(
"trace"
,
&
imperative
::
Tracer
::
Trace
);
}
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
23761bea
...
...
@@ -125,11 +125,26 @@ PYBIND11_MODULE(core, m) {
m
.
add_object
(
"_cleanup"
,
py
::
capsule
([]()
{
ScopePool
::
Instance
().
Clear
();
}));
py
::
class_
<
imperative
::
VarBase
,
PyVarBase
>
(
m
,
"VarBase"
,
R"DOC()DOC"
)
.
def
(
py
::
init
<>
())
py
::
class_
<
imperative
::
VarBase
,
std
::
shared_ptr
<
imperative
::
VarBase
>>
(
m
,
"VarBase"
,
R"DOC()DOC"
)
// .def(py::init<>())
.
def
(
py
::
init
<
bool
>
(),
py
::
arg
(
"stop_gradient"
)
=
false
)
.
def
(
"_run_backward"
,
[](
imperative
::
VarBase
&
self
)
{
self
.
RunBackward
();
})
.
def
(
"_grad_name"
,
&
imperative
::
VarBase
::
GradName
)
.
def
(
"_grad"
,
&
imperative
::
VarBase
::
Grad
)
.
def_property
(
"grad_value"
,
[](
const
imperative
::
VarBase
&
self
)
{
return
self
.
grads_
;
},
[](
imperative
::
VarBase
&
self
,
framework
::
Variable
*
grad
)
{
self
.
grads_
=
grad
;
},
py
::
return_value_policy
::
reference
)
.
def_property
(
"value"
,
[](
const
imperative
::
VarBase
&
self
)
{
return
self
.
var_
;
},
[](
imperative
::
VarBase
&
self
,
framework
::
Variable
*
var
)
{
self
.
var_
=
var
;
},
py
::
return_value_policy
::
reference
)
.
def_property
(
"desc"
,
[](
const
imperative
::
VarBase
&
self
)
{
return
self
.
var_desc_
;
},
...
...
@@ -137,12 +152,12 @@ PYBIND11_MODULE(core, m) {
self
.
var_desc_
=
var_desc
;
},
py
::
return_value_policy
::
reference
)
.
def_property
(
"var"
,
[](
const
imperative
::
VarBase
&
self
)
{
return
self
.
var_
;
}
,
[](
imperative
::
VarBase
&
self
,
framework
::
Variable
*
var
)
{
self
.
var_
=
var
;
},
py
::
return_value_policy
::
reference
);
.
def_property
(
"stop_gradient"
,
[](
const
imperative
::
VarBase
&
self
)
{
return
self
.
stop_gradient_
;
},
[](
imperative
::
VarBase
&
self
,
bool
stop_gradient
)
{
self
.
stop_gradient_
=
stop_gradient
;
}
);
py
::
class_
<
imperative
::
OpBase
,
PyOpBase
>
(
m
,
"OpBase"
,
R"DOC()DOC"
)
.
def
(
py
::
init
<>
())
...
...
python/paddle/fluid/framework.py
浏览文件 @
23761bea
...
...
@@ -20,7 +20,6 @@ import contextlib
import
os
import
re
import
six
import
sys
import
numpy
as
np
...
...
@@ -368,9 +367,10 @@ class Variable(object):
if
_in_imperative_mode
():
self
.
_ivar
=
core
.
VarBase
()
self
.
_ivar
.
desc
=
self
.
desc
self
.
_ivar
.
stop_gradient
=
stop_gradient
def
_numpy
(
self
):
tensor
=
self
.
_ivar
.
va
r
.
get_tensor
()
tensor
=
self
.
_ivar
.
va
lue
.
get_tensor
()
return
np
.
array
(
tensor
)
def
_backward
(
self
):
...
...
@@ -379,6 +379,14 @@ class Variable(object):
def
_gradient
(
self
):
return
np
.
array
(
self
.
_ivar
.
_grad
())
@
property
def
_value
(
self
):
return
self
.
_ivar
.
value
@
_value
.
setter
def
_value
(
self
,
v
):
self
.
_ivar
.
value
=
v
def
__str__
(
self
):
return
self
.
to_string
(
True
)
...
...
@@ -422,6 +430,14 @@ class Variable(object):
"""
self
.
desc
=
input
@
property
def
_stop_gradient
(
self
):
return
self
.
_ivar
.
stop_gradient
@
_stop_gradient
.
setter
def
_stop_gradient
(
self
,
s
):
self
.
_ivar
.
stop_gradient
=
s
@
property
def
persistable
(
self
):
return
self
.
desc
.
persistable
()
...
...
@@ -681,9 +697,11 @@ class Operator(object):
self
.
_update_desc_attr
(
attr_name
,
attr_val
)
self
.
desc
.
check_attrs
()
if
self
.
_has_kernel
(
type
):
self
.
desc
.
infer_var_type
(
self
.
block
.
desc
)
self
.
desc
.
infer_shape
(
self
.
block
.
desc
)
if
_in_imperative_mode
():
self
.
iop
=
core
.
OpBase
()
self
.
iop
.
desc
=
self
.
desc
...
...
@@ -1266,12 +1284,22 @@ class Block(object):
Operator: the append Operator.
"""
op_desc
=
self
.
desc
.
append_op
()
op
=
Operator
(
block
=
self
,
desc
=
op_desc
,
*
args
,
**
kwargs
)
if
_in_imperative_mode
():
_imperative_tracer
().
trace
(
op
.
iop
,
op
.
inputs
,
op
.
outputs
,
self
.
desc
)
op
=
Operator
(
block
=
self
,
desc
=
op_desc
,
type
=
kwargs
.
get
(
"type"
,
None
),
inputs
=
kwargs
.
get
(
"inputs"
,
None
),
outputs
=
kwargs
.
get
(
"outputs"
,
None
),
attrs
=
kwargs
.
get
(
"attrs"
,
None
))
self
.
ops
.
append
(
op
)
self
.
_trace_op
(
op
,
kwargs
.
get
(
"stop_gradient"
,
False
))
return
op
def
_trace_op
(
self
,
op
,
stop_gradient
=
False
):
if
_in_imperative_mode
():
_imperative_tracer
().
trace
(
op
.
iop
,
op
.
inputs
,
op
.
outputs
,
self
.
desc
,
stop_gradient
)
def
_insert_op
(
self
,
index
,
*
args
,
**
kwargs
):
"""
Insert a Operator according to the giving arguments.
...
...
@@ -1317,10 +1345,15 @@ class Block(object):
def
_prepend_op
(
self
,
*
args
,
**
kwargs
):
op_desc
=
self
.
desc
.
_prepend_op
()
op
=
Operator
(
self
,
op_desc
,
*
args
,
**
kwargs
)
if
_in_imperative_mode
():
_imperative_tracer
().
trace
(
op
.
iop
,
op
.
inputs
,
op
.
outputs
,
self
.
desc
)
op
=
Operator
(
self
,
op_desc
,
type
=
kwargs
.
get
(
"type"
,
None
),
inputs
=
kwargs
.
get
(
"inputs"
,
None
),
outputs
=
kwargs
.
get
(
"outputs"
,
None
),
attrs
=
kwargs
.
get
(
"attrs"
,
None
))
self
.
ops
.
insert
(
0
,
op
)
self
.
_trace_op
(
op
,
kwargs
.
get
(
"stop_gradient"
,
False
))
return
op
def
_sync_with_cpp
(
self
):
...
...
python/paddle/fluid/imperative/__init__.py
浏览文件 @
23761bea
...
...
@@ -20,6 +20,10 @@ from .base import *
from
.
import
layers
from
.layers
import
*
from
.
import
nn
from
.nn
import
*
__all__
=
[]
__all__
+=
layers
.
__all__
__all__
+=
base
.
__all__
__all__
+=
nn
.
__all__
python/paddle/fluid/imperative/base.py
浏览文件 @
23761bea
...
...
@@ -28,8 +28,7 @@ def enabled():
def
guard
():
train
=
framework
.
Program
()
startup
=
framework
.
Program
()
tracer
=
core
.
Tracer
(
train
.
current_block
().
desc
,
startup
.
current_block
().
desc
)
tracer
=
core
.
Tracer
(
train
.
current_block
().
desc
)
with
framework
.
program_guard
(
train
,
startup
):
with
framework
.
unique_name
.
guard
():
with
framework
.
_imperative_guard
(
tracer
):
...
...
@@ -46,7 +45,7 @@ def to_variable(value, block=None):
name
=
None
,
shape
=
value
.
shape
,
dtype
=
value
.
dtype
)
var
=
py_var
.
_ivar
.
va
r
var
=
py_var
.
_ivar
.
va
lue
tensor
=
var
.
get_tensor
()
tensor
.
set
(
value
,
core
.
CPUPlace
())
return
py_var
...
...
python/paddle/fluid/imperative/layers.py
浏览文件 @
23761bea
...
...
@@ -24,26 +24,21 @@ __all__ = ['PyLayer']
class
PyLayer
(
core
.
Layer
):
def
__init__
(
self
):
self
.
_built
=
False
def
__call__
(
self
,
inputs
):
if
not
isinstance
(
inputs
,
list
)
and
not
isinstance
(
inputs
,
tuple
):
inputs
=
[
inputs
]
var_inputs
=
[]
for
x
in
inputs
:
py_var
=
base
.
to_variable
(
x
)
var_inputs
.
append
(
py_var
)
if
not
self
.
_built
:
self
.
_build_once
(
inputs
)
self
.
_built
=
True
outputs
=
self
.
forward
(
var_inputs
)
return
outputs
def
__init__
(
self
,
dtype
=
core
.
VarDesc
.
VarType
.
FP32
,
name
=
None
):
self
.
_once_built
=
False
self
.
_dtype
=
dtype
def
_build_once
(
self
,
inputs
):
pass
def
forward
(
self
,
inputs
):
return
[]
def
__call__
(
self
,
*
inputs
):
if
not
self
.
_once_built
:
self
.
_build_once
(
*
inputs
)
self
.
_once_built
=
True
outputs
=
self
.
forward
(
*
inputs
)
return
outputs
def
forward
(
self
,
*
inputs
):
raise
NotImplementedError
python/paddle/fluid/imperative/nn.py
0 → 100644
浏览文件 @
23761bea
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
six.moves
import
reduce
from
..
import
core
from
..layers
import
utils
from
.
import
layers
from
..framework
import
Variable
,
OpProtoHolder
from
..param_attr
import
ParamAttr
from
..initializer
import
Normal
,
Constant
__all__
=
[
'Conv2D'
,
'Pool2D'
,
'FC'
,
]
class
Conv2D
(
layers
.
PyLayer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
filter_size
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
None
,
use_cudnn
=
True
,
act
=
None
,
param_attr
=
None
,
bias_attr
=
None
,
name
=
None
,
dtype
=
core
.
VarDesc
.
VarType
.
FP32
):
assert
param_attr
is
not
False
,
"param_attr should not be False here."
super
(
Conv2D
,
self
).
__init__
(
name
=
name
,
dtype
=
dtype
)
from
..layer_helper
import
LayerHelper
self
.
_helper
=
LayerHelper
(
type
(
self
).
__name__
,
param_attr
=
param_attr
,
bias_attr
=
bias_attr
,
dtype
=
dtype
,
name
=
name
)
self
.
_groups
=
groups
self
.
_stride
=
utils
.
convert_to_list
(
stride
,
2
,
'stride'
)
self
.
_padding
=
utils
.
convert_to_list
(
padding
,
2
,
'padding'
)
self
.
_dilation
=
utils
.
convert_to_list
(
dilation
,
2
,
'dilation'
)
if
not
isinstance
(
use_cudnn
,
bool
):
raise
ValueError
(
"use_cudnn should be True or False"
)
self
.
_use_cudnn
=
use_cudnn
self
.
_num_channels
=
num_channels
if
(
self
.
_num_channels
==
self
.
_groups
and
num_filters
%
self
.
_num_channels
==
0
and
not
self
.
_use_cudnn
):
self
.
_l_type
=
'depthwise_conv2d'
else
:
self
.
_l_type
=
'conv2d'
if
groups
is
None
:
num_filter_channels
=
num_channels
else
:
if
num_channels
%
groups
!=
0
:
raise
ValueError
(
"num_channels must be divisible by groups."
)
num_filter_channels
=
num_channels
//
groups
filter_size
=
utils
.
convert_to_list
(
filter_size
,
2
,
'filter_size'
)
filter_shape
=
[
num_filters
,
int
(
num_filter_channels
)]
+
filter_size
def
_get_default_param_initializer
():
filter_elem_num
=
filter_size
[
0
]
*
filter_size
[
1
]
*
num_channels
std
=
(
2.0
/
filter_elem_num
)
**
0.5
return
Normal
(
0.0
,
std
,
0
)
self
.
_filter_param
=
self
.
_helper
.
create_parameter
(
attr
=
self
.
_helper
.
param_attr
,
shape
=
filter_shape
,
dtype
=
self
.
_dtype
,
default_initializer
=
_get_default_param_initializer
())
if
self
.
_use_cudnn
:
self
.
_helper
.
create_variable
(
name
=
"kCUDNNFwdAlgoCache"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
self
.
_helper
.
create_variable
(
name
=
"kCUDNNBwdDataAlgoCache"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
self
.
_helper
.
create_variable
(
name
=
"kCUDNNBwdFilterAlgoCache"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
self
.
_bias_param
=
self
.
_helper
.
create_parameter
(
attr
=
self
.
_helper
.
bias_attr
,
shape
=
[
num_filters
],
dtype
=
self
.
_dtype
,
is_bias
=
True
)
def
forward
(
self
,
input
):
pre_bias
=
self
.
_helper
.
create_variable_for_type_inference
(
dtype
=
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
self
.
_l_type
,
inputs
=
{
'Input'
:
input
,
'Filter'
:
self
.
_filter_param
,
},
outputs
=
{
"Output"
:
pre_bias
},
attrs
=
{
'strides'
:
self
.
_stride
,
'paddings'
:
self
.
_padding
,
'dilations'
:
self
.
_dilation
,
'groups'
:
self
.
_groups
,
'use_cudnn'
:
self
.
_use_cudnn
,
'use_mkldnn'
:
False
,
})
pre_act
=
self
.
_helper
.
create_variable_for_type_inference
(
dtype
=
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
'elementwise_add'
,
inputs
=
{
'X'
:
[
pre_bias
],
'Y'
:
[
self
.
_bias_param
]},
outputs
=
{
'Out'
:
[
pre_act
]},
attrs
=
{
'axis'
:
1
})
return
self
.
_helper
.
append_activation
(
pre_act
)
class
Pool2D
(
layers
.
PyLayer
):
def
__init__
(
self
,
pool_size
=-
1
,
pool_type
=
"max"
,
pool_stride
=
1
,
pool_padding
=
0
,
global_pooling
=
False
,
use_cudnn
=
True
,
ceil_mode
=
False
,
exclusive
=
True
,
name
=
None
,
dtype
=
core
.
VarDesc
.
VarType
.
FP32
):
if
pool_type
not
in
[
"max"
,
"avg"
]:
raise
ValueError
(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'."
,
str
(
pool_type
))
if
global_pooling
is
False
and
pool_size
==
-
1
:
raise
ValueError
(
"When the global_pooling is False, pool_size must be passed "
"and be a valid value. Received pool_size: "
+
str
(
pool_size
))
if
not
isinstance
(
use_cudnn
,
bool
):
raise
ValueError
(
"use_cudnn should be True or False"
)
super
(
Pool2D
,
self
).
__init__
(
name
=
name
,
dtype
=
dtype
)
from
..layer_helper
import
LayerHelper
self
.
_helper
=
LayerHelper
(
type
(
self
).
__name__
,
dtype
=
dtype
,
name
=
name
)
self
.
_pool_type
=
pool_type
self
.
_pool_size
=
utils
.
convert_to_list
(
pool_size
,
2
,
'pool_size'
)
self
.
_pool_padding
=
utils
.
convert_to_list
(
pool_padding
,
2
,
'pool_padding'
)
self
.
_pool_stride
=
utils
.
convert_to_list
(
pool_stride
,
2
,
'pool_stride'
)
self
.
_global_pooling
=
global_pooling
self
.
_use_cudnn
=
use_cudnn
self
.
_ceil_mode
=
ceil_mode
self
.
_exclusive
=
exclusive
self
.
_l_type
=
'pool2d'
def
forward
(
self
,
input
):
pool_out
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
self
.
_l_type
,
inputs
=
{
"X"
:
input
},
outputs
=
{
"Out"
:
pool_out
},
attrs
=
{
"pooling_type"
:
self
.
_pool_type
,
"ksize"
:
self
.
_pool_size
,
"global_pooling"
:
self
.
_global_pooling
,
"strides"
:
self
.
_pool_stride
,
"paddings"
:
self
.
_pool_padding
,
"use_cudnn"
:
self
.
_use_cudnn
,
"ceil_mode"
:
self
.
_ceil_mode
,
"use_mkldnn"
:
False
,
"exclusive"
:
self
.
_exclusive
,
})
return
pool_out
class
FC
(
layers
.
PyLayer
):
def
__init__
(
self
,
size
,
param_attr
=
None
,
num_flatten_dims
=
1
,
dtype
=
core
.
VarDesc
.
VarType
.
FP32
):
super
(
FC
,
self
).
__init__
()
self
.
_size
=
size
self
.
_num_flatten_dims
=
num_flatten_dims
self
.
_dtype
=
dtype
from
..layer_helper
import
LayerHelper
self
.
_helper
=
LayerHelper
(
'FC'
,
param_attr
=
param_attr
)
def
_build_once
(
self
,
input
):
input_shape
=
input
.
shape
param_shape
=
[
reduce
(
lambda
a
,
b
:
a
*
b
,
input_shape
[
self
.
_num_flatten_dims
:],
1
)
]
+
[
self
.
_size
]
self
.
_w
=
self
.
_helper
.
create_parameter
(
attr
=
self
.
_helper
.
param_attr
,
shape
=
param_shape
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)
def
forward
(
self
,
input
):
tmp
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
"mul"
,
inputs
=
{
"X"
:
input
,
"Y"
:
self
.
_w
},
outputs
=
{
"Out"
:
tmp
},
attrs
=
{
"x_num_col_dims"
:
self
.
_num_flatten_dims
,
"y_num_col_dims"
:
1
})
out
=
self
.
_helper
.
create_variable_for_type_inference
(
self
.
_dtype
)
self
.
_helper
.
append_op
(
type
=
"sum"
,
inputs
=
{
"X"
:
[
tmp
]},
outputs
=
{
"Out"
:
out
},
attrs
=
{
"use_mkldnn"
:
False
})
return
out
python/paddle/fluid/initializer.py
浏览文件 @
23761bea
...
...
@@ -162,7 +162,8 @@ class ConstantInitializer(Initializer):
"dtype"
:
int
(
var
.
dtype
),
"value"
:
float
(
self
.
_value
),
'force_cpu'
:
self
.
_force_cpu
or
force_init_on_cpu
()
})
},
stop_gradient
=
True
)
var
.
op
=
op
return
op
...
...
@@ -231,7 +232,8 @@ class UniformInitializer(Initializer):
"min"
:
self
.
_low
,
"max"
:
self
.
_high
,
"seed"
:
self
.
_seed
})
},
stop_gradient
=
True
)
if
var
.
dtype
==
VarDesc
.
VarType
.
FP16
:
block
.
append_op
(
...
...
@@ -309,7 +311,8 @@ class NormalInitializer(Initializer):
"std"
:
self
.
_std_dev
,
"seed"
:
self
.
_seed
,
"use_mkldnn"
:
False
})
},
stop_gradient
=
True
)
if
var
.
dtype
==
VarDesc
.
VarType
.
FP16
:
block
.
append_op
(
...
...
@@ -371,7 +374,8 @@ class TruncatedNormalInitializer(Initializer):
"mean"
:
self
.
_mean
,
"std"
:
self
.
_std_dev
,
"seed"
:
self
.
_seed
})
},
stop_gradient
=
True
)
var
.
op
=
op
return
op
...
...
@@ -461,7 +465,8 @@ class XavierInitializer(Initializer):
"min"
:
-
limit
,
"max"
:
limit
,
"seed"
:
self
.
_seed
})
},
stop_gradient
=
True
)
else
:
std
=
np
.
sqrt
(
2.0
/
float
(
fan_in
+
fan_out
))
...
...
@@ -474,7 +479,8 @@ class XavierInitializer(Initializer):
"mean"
:
0.0
,
"std"
:
std
,
"seed"
:
self
.
_seed
})
},
stop_gradient
=
True
)
var
.
op
=
op
return
op
...
...
@@ -559,7 +565,8 @@ class MSRAInitializer(Initializer):
"min"
:
-
limit
,
"max"
:
limit
,
"seed"
:
self
.
_seed
})
},
stop_gradient
=
True
)
else
:
std
=
np
.
sqrt
(
2.0
/
float
(
fan_in
))
...
...
@@ -572,7 +579,8 @@ class MSRAInitializer(Initializer):
"mean"
:
0.0
,
"std"
:
std
,
"seed"
:
self
.
_seed
})
},
stop_gradient
=
True
)
var
.
op
=
op
return
op
...
...
python/paddle/fluid/layer_helper.py
浏览文件 @
23761bea
...
...
@@ -22,8 +22,8 @@ import numpy as np
from
.framework
import
Variable
,
Parameter
,
default_main_program
,
default_startup_program
,
dtype_is_floating
,
_in_imperative_mode
from
.
import
unique_name
from
paddle.fluid.imperative
import
base
as
imperative_base
from
paddle.fluid.initializer
import
Constant
,
Xavier
from
paddle.fluid.imperative
import
base
from
.param_attr
import
ParamAttr
,
WeightNormParamAttr
from
.
import
core
from
six.moves
import
zip
...
...
@@ -50,7 +50,7 @@ class LayerHelper(object):
return
default_startup_program
()
def
to_variable
(
self
,
x
):
return
base
.
to_variable
(
x
,
self
.
main_program
.
current_block
())
return
imperative_
base
.
to_variable
(
x
,
self
.
main_program
.
current_block
())
def
append_op
(
self
,
*
args
,
**
kwargs
):
return
self
.
main_program
.
current_block
().
append_op
(
*
args
,
**
kwargs
)
...
...
@@ -314,11 +314,9 @@ class LayerHelper(object):
WeightNormParamAttr
.
params_with_weight_norm
.
append
(
param
)
return
param
if
_in_imperative_mode
():
self
.
main_program
.
global_block
().
create_parameter
(
dtype
=
dtype
,
shape
=
shape
,
**
attr
.
_to_kwargs
())
# In imperative mode, we want the returned parameter to be
# initialized so that it can be used imperatively.
return
self
.
startup
_program
.
global_block
().
create_parameter
(
return
self
.
main
_program
.
global_block
().
create_parameter
(
dtype
=
dtype
,
shape
=
shape
,
**
attr
.
_to_kwargs
(
with_initializer
=
True
))
...
...
@@ -380,13 +378,16 @@ class LayerHelper(object):
def
set_variable_initializer
(
self
,
var
,
initializer
):
assert
isinstance
(
var
,
Variable
)
self
.
startup_program
.
global_block
().
create_var
(
name
=
var
.
name
,
type
=
var
.
type
,
dtype
=
var
.
dtype
,
shape
=
var
.
shape
,
persistable
=
True
,
initializer
=
initializer
)
if
imperative_base
.
enabled
():
initializer
(
var
,
var
.
block
)
else
:
self
.
startup_program
.
global_block
().
create_var
(
name
=
var
.
name
,
type
=
var
.
type
,
dtype
=
var
.
dtype
,
shape
=
var
.
shape
,
persistable
=
True
,
initializer
=
initializer
)
def
append_bias_op
(
self
,
input_var
,
dim_start
=
1
,
dim_end
=
None
):
"""
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
23761bea
此差异已折叠。
点击以展开。
python/paddle/fluid/layers/tensor.py
浏览文件 @
23761bea
...
...
@@ -20,6 +20,7 @@ from ..framework import convert_np_dtype_to_dtype_
from
..framework
import
Variable
from
..initializer
import
Constant
,
force_init_on_cpu
from
..core
import
VarDesc
from
..imperative
import
base
as
imperative_base
from
.layer_function_generator
import
templatedoc
import
numpy
...
...
@@ -104,15 +105,15 @@ def create_global_var(shape,
Args:
shape(list[int]): shape of the variable
value(float): the value of the variable. The new created
value(float): the value of the variable. The new created
variable will be filled with it.
dtype(string): data type of the variable
persistable(bool): if this variable is persistable.
persistable(bool): if this variable is persistable.
Default: False
force_cpu(bool): force this variable to be on CPU.
force_cpu(bool): force this variable to be on CPU.
Default: False
name(str|None): The name of the variable. If set to None the variable
name will be generated automatically.
name(str|None): The name of the variable. If set to None the variable
name will be generated automatically.
Default: None
Returns:
...
...
@@ -121,21 +122,26 @@ def create_global_var(shape,
Examples:
.. code-block:: python
var = fluid.create_global_var(shape=[2,3], value=1.0, dtype='float32',
var = fluid.create_global_var(shape=[2,3], value=1.0, dtype='float32',
persistable=True, force_cpu=True, name='new_var')
"""
helper
=
LayerHelper
(
"global_var"
,
**
locals
())
var
=
helper
.
create_global_variable
(
dtype
=
dtype
,
shape
=
shape
,
persistable
=
persistable
,
name
=
name
)
dtype
=
dtype
,
shape
=
shape
,
persistable
=
persistable
,
name
=
name
,
stop_gradient
=
True
)
helper
.
set_variable_initializer
(
var
,
initializer
=
Constant
(
value
=
float
(
value
),
force_cpu
=
force_cpu
))
return
var
def
cast
(
x
,
dtype
):
"""
This layer takes in the Variable :attr:`x` with :attr:`x.dtype` and casts
This layer takes in the Variable :attr:`x` with :attr:`x.dtype` and casts
it to the output with :attr:`dtype`.
Args:
...
...
@@ -199,9 +205,9 @@ def tensor_array_to_tensor(input, axis=1, name=None):
and returns that as the output.
A simple example as below:
.. code-block:: text
Given:
input.data = {[[0.6, 0.1, 0.3],
...
...
@@ -210,9 +216,9 @@ def tensor_array_to_tensor(input, axis=1, name=None):
[1.8]],
[[2.3, 2.1],
[2.5, 2.4]]}
axis = 1
Then:
output.data = [[0.6, 0.1, 0.3, 1.3, 2.3, 2.1],
...
...
@@ -498,12 +504,12 @@ def argmax(x, axis=0):
def
argsort
(
input
,
axis
=-
1
,
name
=
None
):
"""
Performs sorting on the input Variable along the given axis, and outputs
sorted data Varibale and its corresponding index Variable with the same
Performs sorting on the input Variable along the given axis, and outputs
sorted data Varibale and its corresponding index Variable with the same
shape as :attr:`input`.
.. code-block:: text
For example, the given axis is -1 and the input Variable
input = [[0.15849551, 0.45865775, 0.8563702 ],
...
...
@@ -516,15 +522,15 @@ def argsort(input, axis=-1, name=None):
and the sorted indices along the given axis turn outs to be
indices = [[0, 1, 2],
indices = [[0, 1, 2],
[0, 2, 1]]
Args:
input(Variable): The input Variable for sorting.
axis(int): The axis along which to sort the input Variable. When
:attr:`axis` < 0, the actual axis will be :attr:`axis` +
axis(int): The axis along which to sort the input Variable. When
:attr:`axis` < 0, the actual axis will be :attr:`axis` +
rank(:attr:`input`). Default -1, the last dimension.
name(str|None): (optional) A name for this layer. If set None, the
name(str|None): (optional) A name for this layer. If set None, the
layer will be named automatically.
Returns:
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
23761bea
...
...
@@ -30,6 +30,7 @@ from .initializer import Constant
from
.layer_helper
import
LayerHelper
from
.layers
import
ops
from
.regularizer
import
append_regularization_ops
from
.imperative
import
base
as
imperative_base
__all__
=
[
'SGD'
,
'Momentum'
,
'Adagrad'
,
'Adam'
,
'Adamax'
,
'DecayedAdagrad'
,
'Ftrl'
,
...
...
@@ -301,25 +302,45 @@ class Optimizer(object):
This method combines interface `append_backward()` and
`create_optimization_pass()` into one.
"""
params_grads
=
append_backward
(
loss
,
parameter_list
,
no_grad_set
,
[
error_clip_callback
])
if
imperative_base
.
enabled
():
if
parameter_list
is
not
None
:
params_grads
=
parameter_list
else
:
program
=
loss
.
block
.
program
parameters
=
program
.
global_block
().
all_parameters
()
params_grads
=
[]
for
param
in
parameters
:
# create gradient variable
grad_var
=
Variable
(
block
=
loss
.
block
,
name
=
param
.
_ivar
.
_grad_name
(),
stop_gradient
=
True
)
grad_var
.
_value
=
param
.
_ivar
.
grad_value
params_grads
.
append
((
param
,
grad_var
))
optimize_ops
=
self
.
_create_optimization_pass
(
params_grads
,
loss
,
startup_program
)
else
:
params_grads
=
append_backward
(
loss
,
parameter_list
,
no_grad_set
,
[
error_clip_callback
])
params_grads
=
sorted
(
params_grads
,
key
=
lambda
x
:
x
[
0
].
name
)
params_grads
=
sorted
(
params_grads
,
key
=
lambda
x
:
x
[
0
].
name
)
params_grads
,
table_param_and_grad
,
table_optimize_op
=
\
self
.
_process_distribute_lookuptable
(
params_grads
,
loss
,
startup_program
)
params_grads
,
table_param_and_grad
,
table_optimize_op
=
\
self
.
_process_distribute_lookuptable
(
params_grads
,
loss
,
startup_program
)
params_grads
=
append_gradient_clip_ops
(
params_grads
)
params_grads
=
append_gradient_clip_ops
(
params_grads
)
# Add regularization if any
params_grads
=
append_regularization_ops
(
params_grads
,
self
.
regularization
)
# Add regularization if any
params_grads
=
append_regularization_ops
(
params_grads
,
self
.
regularization
)
optimize_ops
=
self
.
_create_optimization_pass
(
params_grads
,
loss
,
startup_program
)
if
table_optimize_op
is
not
None
:
optimize_ops
.
append
(
table_optimize_op
)
params_grads
.
append
(
table_param_and_grad
)
optimize_ops
=
self
.
_create_optimization_pass
(
params_grads
,
loss
,
startup_program
)
if
table_optimize_op
is
not
None
:
optimize_ops
.
append
(
table_optimize_op
)
params_grads
.
append
(
table_param_and_grad
)
return
optimize_ops
,
params_grads
...
...
@@ -364,7 +385,8 @@ class SGDOptimizer(Optimizer):
"Grad"
:
param_and_grad
[
1
],
"LearningRate"
:
self
.
_create_param_lr
(
param_and_grad
)
},
outputs
=
{
"ParamOut"
:
param_and_grad
[
0
]})
outputs
=
{
"ParamOut"
:
param_and_grad
[
0
]},
stop_gradient
=
True
)
return
sgd_op
...
...
@@ -448,7 +470,8 @@ class MomentumOptimizer(Optimizer):
"VelocityOut"
:
velocity_acc
},
attrs
=
{
"mu"
:
self
.
_momentum
,
"use_nesterov"
:
self
.
_use_nesterov
})
"use_nesterov"
:
self
.
_use_nesterov
},
stop_gradient
=
True
)
return
momentum_op
...
...
@@ -477,7 +500,7 @@ class LarsMomentumOptimizer(Optimizer):
regularization: A Regularizer, such as
fluid.regularizer.L2DecayRegularizer.
name: A optional name prefix.
Examples:
.. code-block:: python
...
...
@@ -533,7 +556,8 @@ class LarsMomentumOptimizer(Optimizer):
"mu"
:
self
.
_momentum
,
"lars_coeff"
:
self
.
_lars_coeff
,
"lars_weight_decay"
:
self
.
_lars_weight_decay
})
},
stop_gradient
=
True
)
return
momentum_op
...
...
@@ -608,7 +632,8 @@ class AdagradOptimizer(Optimizer):
},
outputs
=
{
"ParamOut"
:
param_and_grad
[
0
],
"MomentOut"
:
moment_acc
},
attrs
=
{
"epsilon"
:
self
.
_epsilon
})
attrs
=
{
"epsilon"
:
self
.
_epsilon
},
stop_gradient
=
True
)
return
adagrad_op
...
...
@@ -738,7 +763,8 @@ class AdamOptimizer(Optimizer):
"beta2"
:
self
.
_beta2
,
"epsilon"
:
self
.
_epsilon
,
"lazy_mode"
:
self
.
_lazy_mode
})
},
stop_gradient
=
True
)
return
adam_op
...
...
@@ -760,13 +786,15 @@ class AdamOptimizer(Optimizer):
type
=
"scale"
,
inputs
=
{
"X"
:
beta1_pow_acc
},
outputs
=
{
"Out"
:
beta1_pow_acc
},
attrs
=
{
"scale"
:
self
.
_beta1
})
attrs
=
{
"scale"
:
self
.
_beta1
},
stop_gradient
=
True
)
main_block
.
append_op
(
type
=
"scale"
,
inputs
=
{
"X"
:
beta2_pow_acc
},
outputs
=
{
"Out"
:
beta2_pow_acc
},
attrs
=
{
"scale"
:
self
.
_beta2
})
attrs
=
{
"scale"
:
self
.
_beta2
},
stop_gradient
=
True
)
class
AdamaxOptimizer
(
Optimizer
):
...
...
@@ -877,7 +905,8 @@ class AdamaxOptimizer(Optimizer):
"beta1"
:
self
.
_beta1
,
"beta2"
:
self
.
_beta2
,
"epsilon"
:
self
.
_epsilon
})
},
stop_gradient
=
True
)
return
adamax_op
...
...
@@ -897,7 +926,8 @@ class AdamaxOptimizer(Optimizer):
type
=
"scale"
,
inputs
=
{
"X"
:
beta1_pow_acc
},
outputs
=
{
"Out"
:
beta1_pow_acc
},
attrs
=
{
"scale"
:
self
.
_beta1
})
attrs
=
{
"scale"
:
self
.
_beta1
},
stop_gradient
=
True
)
class
DecayedAdagradOptimizer
(
Optimizer
):
...
...
@@ -979,7 +1009,8 @@ class DecayedAdagradOptimizer(Optimizer):
},
outputs
=
{
"ParamOut"
:
param_and_grad
[
0
],
"MomentOut"
:
moment_acc
},
attrs
=
{
"epsilon"
:
self
.
_epsilon
})
attrs
=
{
"epsilon"
:
self
.
_epsilon
},
stop_gradient
=
True
)
return
decayed_adagrad_op
...
...
@@ -1075,7 +1106,8 @@ class AdadeltaOptimizer(Optimizer):
"AvgSquaredUpdateOut"
:
avg_squared_update_acc
},
attrs
=
{
"epsilon"
:
self
.
_epsilon
,
"rho"
:
self
.
_rho
})
"rho"
:
self
.
_rho
},
stop_gradient
=
True
)
return
adadelta_op
...
...
@@ -1224,7 +1256,8 @@ class RMSPropOptimizer(Optimizer):
"decay"
:
self
.
_rho
,
"momentum"
:
self
.
_momentum
,
"centered"
:
self
.
_centered
})
},
stop_gradient
=
True
)
return
rmsprop_op
...
...
@@ -1345,7 +1378,8 @@ class FtrlOptimizer(Optimizer):
},
attrs
=
{
"l1"
:
self
.
_l1
,
"l2"
:
self
.
_l1
,
"lr_power"
:
self
.
_lr_power
})
"lr_power"
:
self
.
_lr_power
},
stop_gradient
=
True
)
return
ftrl_op
...
...
@@ -1509,7 +1543,8 @@ class ModelAverage(Optimizer):
"average_window"
:
self
.
average_window
,
"min_average_window"
:
self
.
min_average_window
,
"max_average_window"
:
self
.
max_average_window
,
})
},
stop_gradient
=
True
)
@
contextmanager
def
apply
(
self
,
executor
,
need_restore
=
True
):
...
...
python/paddle/fluid/tests/unittests/test_imperative.py
浏览文件 @
23761bea
...
...
@@ -18,17 +18,8 @@ import numpy as np
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
from
paddle.fluid.layers.nn
import
FC
@
contextlib
.
contextmanager
def
new_program_scope
():
prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
scope
=
fluid
.
core
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
with
fluid
.
program_guard
(
prog
,
startup_prog
):
yield
from
paddle.fluid.imperative.nn
import
FC
from
test_imperative_base
import
new_program_scope
class
MyLayer
(
fluid
.
imperative
.
PyLayer
):
...
...
@@ -36,7 +27,7 @@ class MyLayer(fluid.imperative.PyLayer):
super
(
MyLayer
,
self
).
__init__
()
def
forward
(
self
,
inputs
):
x
=
fluid
.
layers
.
relu
(
inputs
[
0
]
)
x
=
fluid
.
layers
.
relu
(
inputs
)
self
.
_x_for_debug
=
x
x
=
fluid
.
layers
.
elementwise_mul
(
x
,
x
)
x
=
fluid
.
layers
.
reduce_sum
(
x
)
...
...
@@ -54,7 +45,7 @@ class MLP(fluid.imperative.PyLayer):
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.1
)))
def
forward
(
self
,
inputs
):
x
=
self
.
_fc1
(
inputs
[
0
]
)
x
=
self
.
_fc1
(
inputs
)
x
=
self
.
_fc2
(
x
)
x
=
fluid
.
layers
.
reduce_sum
(
x
)
return
x
...
...
@@ -66,13 +57,14 @@ class TestImperative(unittest.TestCase):
cl
=
core
.
Layer
()
cl
.
forward
([])
l
=
fluid
.
imperative
.
PyLayer
()
l
.
forward
(
[])
self
.
assertRaises
(
NotImplementedError
,
l
.
forward
,
[])
def
test_layer_in_out
(
self
):
np_inp
=
np
.
array
([
1.0
,
2.0
,
-
1.0
],
dtype
=
np
.
float32
)
with
fluid
.
imperative
.
guard
():
var_inp
=
fluid
.
imperative
.
base
.
to_variable
(
np_inp
)
l
=
MyLayer
()
x
=
l
(
np
_inp
)[
0
]
x
=
l
(
var
_inp
)[
0
]
self
.
assertIsNotNone
(
x
)
dy_out
=
x
.
_numpy
()
x
.
_backward
()
...
...
@@ -97,8 +89,9 @@ class TestImperative(unittest.TestCase):
def
test_mlp
(
self
):
np_inp
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
np
.
float32
)
with
fluid
.
imperative
.
guard
():
var_inp
=
fluid
.
imperative
.
base
.
to_variable
(
np_inp
)
mlp
=
MLP
()
out
=
mlp
(
np
_inp
)
out
=
mlp
(
var
_inp
)
dy_out
=
out
.
_numpy
()
out
.
_backward
()
dy_grad
=
mlp
.
_fc1
.
_w
.
_gradient
()
...
...
python/paddle/fluid/tests/unittests/test_imperative_base.py
0 → 100644
浏览文件 @
23761bea
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
contextlib
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
@
contextlib
.
contextmanager
def
new_program_scope
():
prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
scope
=
fluid
.
core
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
with
fluid
.
program_guard
(
prog
,
startup_prog
):
yield
python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
0 → 100644
浏览文件 @
23761bea
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
contextlib
import
unittest
import
numpy
as
np
import
six
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
from
paddle.fluid.optimizer
import
SGDOptimizer
from
paddle.fluid.imperative.nn
import
Conv2D
,
Pool2D
,
FC
from
paddle.fluid.imperative.base
import
to_variable
from
test_imperative_base
import
new_program_scope
class
SimpleImgConvPool
(
fluid
.
imperative
.
PyLayer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
filter_size
,
pool_size
,
pool_stride
,
pool_padding
=
0
,
pool_type
=
'max'
,
global_pooling
=
False
,
conv_stride
=
1
,
conv_padding
=
0
,
conv_dilation
=
1
,
conv_groups
=
1
,
act
=
None
,
use_cudnn
=
False
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
SimpleImgConvPool
,
self
).
__init__
()
self
.
_conv2d
=
Conv2D
(
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
conv_stride
,
padding
=
conv_padding
,
dilation
=
conv_dilation
,
groups
=
conv_groups
,
param_attr
=
None
,
bias_attr
=
None
,
use_cudnn
=
use_cudnn
)
self
.
_pool2d
=
Pool2D
(
pool_size
=
pool_size
,
pool_type
=
pool_type
,
pool_stride
=
pool_stride
,
pool_padding
=
pool_padding
,
global_pooling
=
global_pooling
,
use_cudnn
=
use_cudnn
)
def
forward
(
self
,
inputs
):
x
=
self
.
_conv2d
(
inputs
)
x
=
self
.
_pool2d
(
x
)
return
x
class
MNIST
(
fluid
.
imperative
.
PyLayer
):
def
__init__
(
self
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
MNIST
,
self
).
__init__
()
self
.
_simple_img_conv_pool_1
=
SimpleImgConvPool
(
1
,
20
,
5
,
2
,
2
,
act
=
"relu"
)
self
.
_simple_img_conv_pool_2
=
SimpleImgConvPool
(
20
,
50
,
5
,
2
,
2
,
act
=
"relu"
)
pool_2_shape
=
50
*
8
*
8
SIZE
=
10
scale
=
(
2.0
/
(
pool_2_shape
**
2
*
SIZE
))
**
0.5
self
.
_fc
=
FC
(
10
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NormalInitializer
(
loc
=
0.0
,
scale
=
scale
)))
def
forward
(
self
,
inputs
):
x
=
self
.
_simple_img_conv_pool_1
(
inputs
)
x
=
self
.
_simple_img_conv_pool_2
(
x
)
x
=
self
.
_fc
(
x
)
return
x
class
TestImperativeMnist
(
unittest
.
TestCase
):
def
test_mnist_cpu_float32
(
self
):
seed
=
90
with
fluid
.
imperative
.
guard
():
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
# mnist = Conv2D(1, 20, 5)
mnist
=
MNIST
()
sgd
=
SGDOptimizer
(
learning_rate
=
1e-3
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
128
)
dy_param_init_value
=
{}
for
batch_id
,
data
in
enumerate
(
train_reader
()):
if
batch_id
>=
2
:
break
x_data
=
np
.
array
(
[
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
128
,
1
)
img
=
to_variable
(
x_data
)
label
=
to_variable
(
y_data
)
label
.
_stop_gradient
=
True
cost
=
mnist
(
img
)
loss
=
fluid
.
layers
.
reduce_mean
(
cost
)
dy_out
=
loss
.
_numpy
()
if
batch_id
==
0
:
for
param
in
fluid
.
default_main_program
().
global_block
(
).
all_parameters
():
dy_param_init_value
[
param
.
name
]
=
param
.
_numpy
()
loss
.
_backward
()
sgd
.
minimize
(
loss
)
dy_param_value
=
{}
for
param
in
fluid
.
default_main_program
().
global_block
(
).
all_parameters
():
dy_param_value
[
param
.
name
]
=
param
.
_numpy
()
with
new_program_scope
():
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
# mnist = Conv2D(1, 20, 5)
mnist
=
MNIST
()
sgd
=
SGDOptimizer
(
learning_rate
=
1e-3
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
128
)
img
=
fluid
.
layers
.
data
(
name
=
'pixel'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
cost
=
mnist
(
img
)
loss
=
fluid
.
layers
.
reduce_mean
(
cost
)
sgd
.
minimize
(
loss
)
# initialize params and fetch them
static_param_init_value
=
{}
static_param_name_list
=
[]
for
param
in
fluid
.
default_startup_program
().
global_block
(
).
all_parameters
():
static_param_name_list
.
append
(
param
.
name
)
out
=
exe
.
run
(
fluid
.
default_startup_program
(),
fetch_list
=
static_param_name_list
)
for
i
in
range
(
len
(
static_param_name_list
)):
static_param_init_value
[
static_param_name_list
[
i
]]
=
out
[
i
]
for
batch_id
,
data
in
enumerate
(
train_reader
()):
if
batch_id
>=
2
:
break
x_data
=
np
.
array
(
[
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
[
128
,
1
])
fetch_list
=
[
loss
.
name
]
fetch_list
.
extend
(
static_param_name_list
)
out
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
"pixel"
:
x_data
,
"label"
:
y_data
},
fetch_list
=
fetch_list
)
static_param_value
=
{}
static_out
=
out
[
0
]
for
i
in
range
(
1
,
len
(
out
)):
static_param_value
[
static_param_name_list
[
i
-
1
]]
=
out
[
i
]
for
key
,
value
in
six
.
iteritems
(
static_param_init_value
):
self
.
assertTrue
(
np
.
allclose
(
value
.
all
(),
dy_param_init_value
[
key
].
all
()))
self
.
assertTrue
(
np
.
allclose
(
static_out
.
all
(),
dy_out
.
all
()))
for
key
,
value
in
six
.
iteritems
(
static_param_value
):
self
.
assertTrue
(
np
.
allclose
(
value
.
all
(),
dy_param_value
[
key
].
all
()))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录