Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ef7e76fc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ef7e76fc
编写于
7月 26, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into make_network_op
上级
0b4880ba
eff17a68
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
262 addition
and
69 deletion
+262
-69
paddle/framework/net_op_test.cc
paddle/framework/net_op_test.cc
+1
-1
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+5
-5
paddle/framework/scope.h
paddle/framework/scope.h
+16
-5
paddle/framework/scope_test.cc
paddle/framework/scope_test.cc
+5
-0
paddle/framework/tensor_test.cc
paddle/framework/tensor_test.cc
+2
-2
paddle/platform/enforce.h
paddle/platform/enforce.h
+37
-31
paddle/platform/enforce_test.cc
paddle/platform/enforce_test.cc
+1
-1
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+20
-13
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+1
-2
python/paddle/v2/__init__.py
python/paddle/v2/__init__.py
+8
-0
python/paddle/v2/dataset/mq2007.py
python/paddle/v2/dataset/mq2007.py
+2
-2
python/paddle/v2/framework/create_op_creation_methods.py
python/paddle/v2/framework/create_op_creation_methods.py
+3
-0
python/paddle/v2/framework/network.py
python/paddle/v2/framework/network.py
+124
-0
python/paddle/v2/framework/tests/CMakeLists.txt
python/paddle/v2/framework/tests/CMakeLists.txt
+3
-2
python/paddle/v2/framework/tests/test_net.py
python/paddle/v2/framework/tests/test_net.py
+2
-2
python/paddle/v2/framework/tests/test_network.py
python/paddle/v2/framework/tests/test_network.py
+32
-0
python/paddle/v2/layer.py
python/paddle/v2/layer.py
+0
-3
未找到文件。
paddle/framework/net_op_test.cc
浏览文件 @
ef7e76fc
...
...
@@ -69,7 +69,7 @@ TEST(OpKernel, all) {
net
->
Run
(
scope
,
dev_ctx
);
ASSERT_EQ
(
2
,
infer_shape_cnt
);
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
std
::
runtime_error
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
paddle
::
platform
::
EnforceNotMet
);
}
TEST
(
AddBackwardOp
,
TestGradOp
)
{
auto
net
=
std
::
make_shared
<
PlainNet
>
();
...
...
paddle/framework/op_registry_test.cc
浏览文件 @
ef7e76fc
...
...
@@ -90,7 +90,7 @@ TEST(OpRegistry, IllegalAttr) {
bool
caught
=
false
;
try
{
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
std
::
runtime_error
&
err
)
{
}
catch
(
paddle
::
platform
::
EnforceNotMet
err
)
{
caught
=
true
;
std
::
string
msg
=
"larger_than check fail"
;
const
char
*
err_msg
=
err
.
what
();
...
...
@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) {
bool
caught
=
false
;
try
{
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
std
::
runtime_error
&
err
)
{
}
catch
(
paddle
::
platform
::
EnforceNotMet
err
)
{
caught
=
true
;
std
::
string
msg
=
"Attribute 'test_attr' is required!"
;
const
char
*
err_msg
=
err
.
what
();
...
...
@@ -154,7 +154,7 @@ TEST(OpRegistry, CustomChecker) {
caught
=
false
;
try
{
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
std
::
runtime_error
&
err
)
{
}
catch
(
paddle
::
platform
::
EnforceNotMet
err
)
{
caught
=
true
;
std
::
string
msg
=
"'test_attr' must be even!"
;
const
char
*
err_msg
=
err
.
what
();
...
...
@@ -192,7 +192,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
pd
::
OpProto
op_proto
;
pd
::
OpAttrChecker
op_checker
;
auto
proto_maker
=
TestAttrProtoMaker
(
&
op_proto
,
&
op_checker
);
ASSERT_THROW
(
proto_maker
.
Validate
(),
std
::
runtime_error
);
ASSERT_THROW
(
proto_maker
.
Validate
(),
paddle
::
platform
::
EnforceNotMet
);
}
class
TestInOutProtoMaker
:
public
pd
::
OpProtoAndCheckerMaker
{
...
...
@@ -208,5 +208,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
pd
::
OpProto
op_proto
;
pd
::
OpAttrChecker
op_checker
;
auto
proto_maker
=
TestInOutProtoMaker
(
&
op_proto
,
&
op_checker
);
ASSERT_THROW
(
proto_maker
.
Validate
(),
std
::
runtime_error
);
ASSERT_THROW
(
proto_maker
.
Validate
(),
paddle
::
platform
::
EnforceNotMet
);
}
paddle/framework/scope.h
浏览文件 @
ef7e76fc
...
...
@@ -56,7 +56,9 @@ class Scope {
if
(
var
)
{
return
var
;
}
else
{
vars_
[
name
]
=
std
::
unique_ptr
<
Variable
>
(
new
Variable
());
auto
ptr
=
new
Variable
();
name_to_var_
[
name
]
=
std
::
unique_ptr
<
Variable
>
(
ptr
);
var_to_name_
[
ptr
]
=
name
;
return
GetVariable
(
name
);
}
}
...
...
@@ -68,8 +70,8 @@ class Scope {
* from it's parent scope. Return nullptr if not found.
*/
Variable
*
GetVariable
(
const
std
::
string
&
name
)
const
{
auto
it
=
vars
_
.
find
(
name
);
if
(
it
!=
vars
_
.
end
())
{
auto
it
=
name_to_var
_
.
find
(
name
);
if
(
it
!=
name_to_var
_
.
end
())
{
return
it
->
second
.
get
();
}
else
if
(
parent_
!=
nullptr
)
{
return
parent_
->
GetVariable
(
name
);
...
...
@@ -84,12 +86,21 @@ class Scope {
* Find if there is a Variable in this scope and it's parent scope
*/
bool
HasVariable
(
const
std
::
string
&
name
)
const
{
return
(
vars_
.
find
(
name
)
!=
vars
_
.
end
()
||
return
(
name_to_var_
.
find
(
name
)
!=
name_to_var
_
.
end
()
||
(
parent_
&&
parent_
->
HasVariable
(
name
)));
}
std
::
string
GetVariableName
(
Variable
*
const
var
)
const
{
try
{
return
var_to_name_
.
at
(
var
);
}
catch
(...)
{
return
""
;
}
}
private:
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
Variable
>>
vars_
;
std
::
unordered_map
<
Variable
*
,
std
::
string
>
var_to_name_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
Variable
>>
name_to_var_
;
std
::
shared_ptr
<
Scope
>
parent_
{
nullptr
};
};
...
...
paddle/framework/scope_test.cc
浏览文件 @
ef7e76fc
...
...
@@ -40,6 +40,11 @@ TEST(Scope, Create) {
/// already exist.
Variable
*
var4
=
scope
->
CreateVariable
(
"a"
);
EXPECT_EQ
(
var4
,
var2
);
EXPECT_EQ
(
"a"
,
scope
->
GetVariableName
(
var4
));
Scope
scope2
;
auto
var
=
scope2
.
CreateVariable
(
"tmp"
);
EXPECT_EQ
(
""
,
scope
->
GetVariableName
(
var
));
}
TEST
(
Scope
,
Parent
)
{
...
...
paddle/framework/tensor_test.cc
浏览文件 @
ef7e76fc
...
...
@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
bool
caught
=
false
;
try
{
src_tensor
.
data
<
double
>
();
}
catch
(
std
::
runtime_error
&
err
)
{
}
catch
(
paddle
::
platform
::
EnforceNotMet
err
)
{
caught
=
true
;
std
::
string
msg
=
"Tenosr holds no memory. Call Tensor::mutable_data first."
;
...
...
@@ -107,7 +107,7 @@ TEST(Tensor, ShareDataWith) {
bool
caught
=
false
;
try
{
dst_tensor
.
ShareDataWith
<
float
>
(
src_tensor
);
}
catch
(
std
::
runtime_error
&
err
)
{
}
catch
(
paddle
::
platform
::
EnforceNotMet
err
)
{
caught
=
true
;
std
::
string
msg
=
"Tenosr holds no memory. Call Tensor::mutable_data first."
;
...
...
paddle/platform/enforce.h
浏览文件 @
ef7e76fc
...
...
@@ -36,6 +36,21 @@ limitations under the License. */
namespace
paddle
{
namespace
platform
{
struct
EnforceNotMet
:
public
std
::
exception
{
std
::
exception_ptr
exp_
;
std
::
string
err_str_
;
EnforceNotMet
(
std
::
exception_ptr
e
,
const
char
*
f
,
int
l
)
:
exp_
(
e
)
{
try
{
std
::
rethrow_exception
(
exp_
);
}
catch
(
const
std
::
exception
&
exp
)
{
err_str_
=
string
::
Sprintf
(
"%s at [%s:%d]"
,
exp
.
what
(),
f
,
l
);
}
}
const
char
*
what
()
const
noexcept
{
return
err_str_
.
c_str
();
}
};
// Because most enforce conditions would evaluate to true, we can use
// __builtin_expect to instruct the C++ compiler to generate code that
// always forces branch prediction of true.
...
...
@@ -52,9 +67,7 @@ template <typename... Args>
inline
typename
std
::
enable_if
<
sizeof
...(
Args
)
!=
0
,
void
>::
type
throw_on_error
(
int
stat
,
const
Args
&
...
args
)
{
if
(
UNLIKELY
(
!
(
stat
)))
{
throw
std
::
runtime_error
(
string
::
Sprintf
(
args
...)
+
string
::
Sprintf
(
" at [%s:%s];"
,
__FILE__
,
__LINE__
));
throw
std
::
runtime_error
(
string
::
Sprintf
(
args
...));
}
}
...
...
@@ -64,12 +77,8 @@ template <typename... Args>
inline
typename
std
::
enable_if
<
sizeof
...(
Args
)
!=
0
,
void
>::
type
throw_on_error
(
cudaError_t
e
,
const
Args
&
...
args
)
{
if
(
UNLIKELY
(
e
))
{
// clang-format off
throw
thrust
::
system_error
(
e
,
thrust
::
cuda_category
(),
string
::
Sprintf
(
args
...)
+
string
::
Sprintf
(
" at [%s:%s];"
,
__FILE__
,
__LINE__
));
// clang-format on
throw
thrust
::
system_error
(
e
,
thrust
::
cuda_category
(),
string
::
Sprintf
(
args
...));
}
}
...
...
@@ -77,12 +86,8 @@ template <typename... Args>
inline
typename
std
::
enable_if
<
sizeof
...(
Args
)
!=
0
,
void
>::
type
throw_on_error
(
curandStatus_t
stat
,
const
Args
&
...
args
)
{
if
(
stat
!=
CURAND_STATUS_SUCCESS
)
{
// clang-format off
throw
thrust
::
system_error
(
cudaErrorLaunchFailure
,
thrust
::
cuda_category
(),
string
::
Sprintf
(
args
...)
+
string
::
Sprintf
(
" at [%s:%s];"
,
__FILE__
,
__LINE__
));
// clang-format on
throw
thrust
::
system_error
(
cudaErrorLaunchFailure
,
thrust
::
cuda_category
(),
string
::
Sprintf
(
args
...));
}
}
...
...
@@ -92,12 +97,8 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
if
(
stat
==
CUDNN_STATUS_SUCCESS
)
{
return
;
}
else
{
// clang-format off
throw
std
::
runtime_error
(
platform
::
dynload
::
cudnnGetErrorString
(
stat
)
+
string
::
Sprintf
(
args
...)
+
string
::
Sprintf
(
" at [%s:%s];"
,
__FILE__
,
__LINE__
));
// clang-format on
throw
std
::
runtime_error
(
platform
::
dynload
::
cudnnGetErrorString
(
stat
)
+
string
::
Sprintf
(
args
...));
}
}
...
...
@@ -126,22 +127,27 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
}
else
if
(
stat
==
CUBLAS_STATUS_LICENSE_ERROR
)
{
err
=
"CUBLAS: license error, "
;
}
throw
std
::
runtime_error
(
err
+
string
::
Sprintf
(
args
...)
+
string
::
Sprintf
(
" at [%s:%s];"
,
__FILE__
,
__LINE__
));
throw
std
::
runtime_error
(
err
+
string
::
Sprintf
(
args
...));
}
#endif // PADDLE_ONLY_CPU
#define PADDLE_THROW(...) \
do { \
throw std::runtime_error( \
string::Sprintf(__VA_ARGS__) + \
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \
throw ::paddle::platform::EnforceNotMet( \
std::make_exception_ptr( \
std::runtime_error(string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \
} while (0)
#define PADDLE_ENFORCE(...) \
do { \
try { \
::paddle::platform::throw_on_error(__VA_ARGS__); \
} catch (...) { \
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \
} \
} while (0)
}
// namespace platform
...
...
paddle/platform/enforce_test.cc
浏览文件 @
ef7e76fc
...
...
@@ -23,7 +23,7 @@ TEST(ENFORCE, FAILED) {
bool
in_catch
=
false
;
try
{
PADDLE_ENFORCE
(
false
,
"Enforce is not ok %d at all"
,
123
);
}
catch
(
const
std
::
runtime_error
&
error
)
{
}
catch
(
paddle
::
platform
::
EnforceNotMet
error
)
{
// your error handling code here
in_catch
=
true
;
std
::
string
msg
=
"Enforce is not ok 123 at all"
;
...
...
paddle/pybind/pybind.cc
浏览文件 @
ef7e76fc
...
...
@@ -48,6 +48,11 @@ void ExposeOperator(ClassType& m) {
.
def
(
"__str__"
,
&
ClassType
::
type
::
DebugString
);
}
static
size_t
UniqueIntegerGenerator
()
{
static
std
::
atomic
<
size_t
>
generator
;
return
generator
.
fetch_add
(
1
);
}
PYBIND11_PLUGIN
(
core
)
{
py
::
module
m
(
"core"
,
"C++ core of PaddlePaddle"
);
...
...
@@ -98,7 +103,8 @@ All parameter, weight, gradient are variables in Paddle.
py
::
return_value_policy
::
reference
)
.
def
(
"create_var"
,
&
pd
::
Scope
::
CreateVariable
,
py
::
return_value_policy
::
reference
);
py
::
return_value_policy
::
reference
)
.
def
(
"get_var_name"
,
&
pd
::
Scope
::
GetVariableName
);
//! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python.
...
...
@@ -141,10 +147,9 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator
(
operator_base
);
using
PlainNetPtr
=
std
::
shared_ptr
<
pd
::
PlainNet
>
;
py
::
class_
<
pd
::
PlainNet
,
PlainNetPtr
>
plain_net
(
m
,
"Plain
Net"
);
py
::
class_
<
pd
::
PlainNet
,
PlainNetPtr
>
net
(
m
,
"
Net"
);
plain_net
.
def_static
(
"create"
,
net
.
def_static
(
"create"
,
[]()
->
std
::
shared_ptr
<
pd
::
PlainNet
>
{
auto
retv
=
std
::
make_shared
<
pd
::
PlainNet
>
();
retv
->
type_
=
"plain_net"
;
...
...
@@ -152,12 +157,14 @@ All parameter, weight, gradient are variables in Paddle.
})
.
def
(
"add_op"
,
&
pd
::
PlainNet
::
AddOp
)
.
def
(
"add_op"
,
[](
PlainNetPtr
&
self
,
const
PlainNetPtr
&
plain_
net
)
->
void
{
self
->
AddOp
(
std
::
static_pointer_cast
<
pd
::
OperatorBase
>
(
plain_
net
));
[](
PlainNetPtr
&
self
,
const
PlainNetPtr
&
net
)
->
void
{
self
->
AddOp
(
std
::
static_pointer_cast
<
pd
::
OperatorBase
>
(
net
));
})
.
def
(
"complete_add_op"
,
&
pd
::
PlainNet
::
CompleteAddOp
)
.
def
(
"complete_add_op"
,
[](
PlainNetPtr
&
self
)
{
self
->
CompleteAddOp
();
});
ExposeOperator
(
plain_net
);
ExposeOperator
(
net
);
m
.
def
(
"unique_integer"
,
UniqueIntegerGenerator
);
return
m
.
ptr
();
}
python/paddle/trainer/config_parser.py
浏览文件 @
ef7e76fc
...
...
@@ -2055,8 +2055,7 @@ class BatchNormLayer(LayerBase):
# Automatically select cudnn_batch_norm for GPU and batch_norm for CPU.
# Also based on cudnn version.
use_cudnn
=
use_gpu
and
batch_norm_type
!=
"batch_norm"
and
\
((
not
parallel_nn
)
or
self
.
config
.
device
>
-
1
)
and
\
cudnn_version
>=
4007
((
not
parallel_nn
)
or
self
.
config
.
device
>
-
1
)
self
.
layer_type
=
"cudnn_batch_norm"
if
use_cudnn
else
"batch_norm"
super
(
BatchNormLayer
,
self
).
__init__
(
name
,
self
.
layer_type
,
0
,
inputs
=
inputs
,
**
xargs
)
...
...
python/paddle/v2/__init__.py
浏览文件 @
ef7e76fc
...
...
@@ -34,6 +34,7 @@ import minibatch
import
plot
import
image
import
model
import
paddle.trainer.config_parser
as
cp
__all__
=
[
'optimizer'
,
...
...
@@ -58,6 +59,8 @@ __all__ = [
'model'
,
]
cp
.
begin_parse
()
def
init
(
**
kwargs
):
import
py_paddle.swig_paddle
as
api
...
...
@@ -73,6 +76,11 @@ def init(**kwargs):
for
key
in
args_dict
.
keys
():
args
.
append
(
'--%s=%s'
%
(
key
,
str
(
args_dict
[
key
])))
if
'use_gpu'
in
kwargs
:
cp
.
g_command_config_args
[
'use_gpu'
]
=
kwargs
[
'use_gpu'
]
assert
'parallel_nn'
not
in
kwargs
,
(
"currently 'parallel_nn' is not "
"supported in v2 APIs."
)
api
.
initPaddle
(
*
args
)
...
...
python/paddle/v2/dataset/mq2007.py
浏览文件 @
ef7e76fc
...
...
@@ -242,9 +242,9 @@ def gen_list(querylist):
if
not
isinstance
(
querylist
,
QueryList
):
querylist
=
QueryList
(
querylist
)
querylist
.
_correct_ranking_
()
relevance_score_list
=
[
query
.
relevance_score
for
query
in
querylist
]
relevance_score_list
=
[
[
query
.
relevance_score
]
for
query
in
querylist
]
feature_vector_list
=
[
query
.
feature_vector
for
query
in
querylist
]
yield
np
.
array
(
relevance_score_list
)
.
T
,
np
.
array
(
feature_vector_list
)
yield
np
.
array
(
relevance_score_list
),
np
.
array
(
feature_vector_list
)
def
query_filter
(
querylists
):
...
...
python/paddle/v2/framework/create_op_creation_methods.py
浏览文件 @
ef7e76fc
...
...
@@ -220,6 +220,9 @@ def create_op_creation_method(op_proto):
__impl__
.
all_input_args
=
[
var
.
name
for
var
in
op_proto
.
inputs
]
__impl__
.
all_output_args
=
[
var
.
name
for
var
in
op_proto
.
outputs
]
__impl__
.
all_attr_args
=
[
attr
.
name
for
attr
in
op_proto
.
attrs
]
__impl__
.
all_not_temp_output_args
=
[
var
.
name
for
var
in
op_proto
.
outputs
if
not
var
.
temporary
]
return
__impl__
...
...
python/paddle/v2/framework/network.py
0 → 100644
浏览文件 @
ef7e76fc
import
paddle.v2.framework.core
as
core
from
paddle.v2.framework.create_op_creation_methods
import
op_creations
from
default_scope_funcs
import
create_var
,
get_var
,
get_cur_scope
__all__
=
[
'Network'
]
# Only expose Network
class
NetworkFunctor
(
object
):
"""
Network Op Creation Function. Used internally in this module.
It convert string input to Variable. If it is not created before, just
create in scope.
It is a functor object. means the instances are callable.
:param func: The op creation function which generated in Python.
:param net: The Network instance.
"""
def
__init__
(
self
,
func
,
net
):
self
.
func
=
func
self
.
net
=
net
def
__call__
(
self
,
*
args
,
**
kwargs
):
if
len
(
args
)
!=
0
:
raise
ValueError
(
"Paddle must use keyword argument"
)
inputs
=
self
.
func
.
all_input_args
for
ipt
in
inputs
:
if
ipt
in
kwargs
:
var
=
kwargs
[
ipt
]
if
isinstance
(
var
,
basestring
):
var
=
create_var
(
var
)
if
not
isinstance
(
var
,
core
.
Variable
):
raise
TypeError
(
"Input of op creation must be string or variable"
)
kwargs
[
ipt
]
=
get_cur_scope
().
get_var_name
(
var
)
notemp_outputs
=
self
.
func
.
all_not_temp_output_args
for
name
in
notemp_outputs
:
if
name
not
in
kwargs
:
kwargs
[
name
]
=
self
.
func
.
__name__
+
"@OUT@%d"
%
core
.
unique_integer
(
)
outputs
=
self
.
func
.
all_output_args
for
opt
in
outputs
:
if
opt
in
kwargs
:
var
=
kwargs
[
opt
]
if
isinstance
(
var
,
basestring
):
var
=
create_var
(
var
)
if
not
isinstance
(
var
,
core
.
Variable
):
raise
TypeError
(
"Output of op creation must be string or variable"
)
kwargs
[
opt
]
=
get_cur_scope
().
get_var_name
(
var
)
op
=
self
.
func
(
**
kwargs
)
self
.
net
.
net
.
add_op
(
op
)
lst
=
[
get_var
(
kwargs
[
opt
])
for
opt
in
notemp_outputs
]
if
len
(
lst
)
==
1
:
return
lst
[
0
]
elif
len
(
lst
)
==
0
:
return
None
else
:
return
lst
class
Network
(
object
):
"""
The network concept. It avoid user to manually create operator, create
variable, and combine them into a Net. Just use Network.xxx can create the
operator, create variables in default scope, and add them into `self.net`.
For example:
.. code-block: python
net = Network()
out = net.add_two(X="a", Y="b")
fc_out = net.fc(X="out", W="fc.w")
net.run(...)
"""
def
__init__
(
self
):
self
.
net
=
core
.
Net
.
create
()
funcs
=
(
func_name
for
func_name
in
dir
(
op_creations
)
if
not
func_name
.
startswith
(
"__"
))
# TODO(yuyang18): This code can work, but do not generate a good
# docstring, try to give a better way generate function in runtime
# later.
for
func_name
in
funcs
:
func
=
getattr
(
op_creations
,
func_name
)
impl
=
NetworkFunctor
(
func
,
self
)
setattr
(
self
,
func_name
,
impl
.
__call__
)
self
.
__complete_add_op__
=
False
def
infer_shape
(
self
):
self
.
complete_add_op
()
self
.
net
.
infer_shape
(
get_cur_scope
())
def
run
(
self
,
device_context
):
self
.
complete_add_op
()
self
.
net
.
run
(
get_cur_scope
(),
device_context
)
def
__str__
(
self
):
return
str
(
self
.
net
)
def
complete_add_op
(
self
):
if
not
self
.
__complete_add_op__
:
self
.
net
.
complete_add_op
()
self
.
__complete_add_op__
=
True
if
__name__
==
'__main__'
:
net
=
Network
()
out
=
net
.
add_two
(
X
=
"a"
,
Y
=
"b"
)
fc_out
=
net
.
fc
(
X
=
out
,
W
=
"fc.w"
,
b
=
"fc.b"
,
activation
=
"softmax"
)
net
.
complete_add_op
()
print
net
python/paddle/v2/framework/tests/CMakeLists.txt
浏览文件 @
ef7e76fc
...
...
@@ -3,7 +3,7 @@ add_python_test(test_framework
test_scope.py
test_default_scope_funcs.py
test_op_creation_methods.py
test_
plain_
net.py
test_net.py
test_tensor.py
test_fc_op.py
test_add_two_op.py
...
...
@@ -12,4 +12,5 @@ add_python_test(test_framework
test_mul_op.py
test_sigmoid_op.py
test_softmax_op.py
test_rowwise_add_op.py
)
test_rowwise_add_op.py
test_network.py
)
python/paddle/v2/framework/tests/test_
plain_
net.py
→
python/paddle/v2/framework/tests/test_net.py
浏览文件 @
ef7e76fc
...
...
@@ -5,11 +5,11 @@ import unittest
class
TestNet
(
unittest
.
TestCase
):
def
test_net_all
(
self
):
net
=
core
.
Plain
Net
.
create
()
net
=
core
.
Net
.
create
()
op1
=
op_creations
.
add_two
(
X
=
"X"
,
Y
=
"Y"
,
Out
=
"Out"
)
net
.
add_op
(
op1
)
net2
=
core
.
Plain
Net
.
create
()
net2
=
core
.
Net
.
create
()
net2
.
add_op
(
op_creations
.
fc
(
X
=
"X"
,
W
=
"w"
,
Y
=
"fc.out"
))
net2
.
complete_add_op
(
True
)
net
.
add_op
(
net2
)
...
...
python/paddle/v2/framework/tests/test_network.py
0 → 100644
浏览文件 @
ef7e76fc
from
paddle.v2.framework.network
import
Network
import
paddle.v2.framework.core
as
core
import
unittest
class
TestNet
(
unittest
.
TestCase
):
def
test_net_all
(
self
):
net
=
Network
()
out
=
net
.
add_two
(
X
=
"X"
,
Y
=
"Y"
)
fc_out
=
net
.
fc
(
X
=
out
,
W
=
"w"
)
net
.
complete_add_op
()
self
.
assertTrue
(
isinstance
(
fc_out
,
core
.
Variable
))
self
.
assertEqual
(
'''Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0).
Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0).
Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1).
'''
,
str
(
net
))
net2
=
Network
()
tmp
=
net2
.
add_two
(
X
=
"X"
,
Y
=
"Y"
)
self
.
assertTrue
(
isinstance
(
tmp
,
core
.
Variable
))
net2
.
complete_add_op
()
self
.
assertEqual
(
'''Op(plain_net), inputs:(X, Y), outputs:(add_two@OUT@2).
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@2).
'''
,
str
(
net2
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/v2/layer.py
浏览文件 @
ef7e76fc
...
...
@@ -324,6 +324,3 @@ def parse_network(output_layers, extra_layers=None):
def
get_layer
(
name
):
return
config_base
.
__layer_map__
.
get
(
name
)
cp
.
begin_parse
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录