Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)
Commit a7fa2051
Authored Dec 14, 2017 by guosheng

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into add-multiBatch-chunkEval-dev

Parents: 181db326, 9956d5f7

Showing 22 changed files with 1,163 additions and 1,038 deletions (+1163 -1038)
README.md                                                               +2    -2
paddle/capi/error.cpp                                                   +1    -1
paddle/capi/error.h                                                     +8    -0
paddle/framework/backward.cc                                            +2    -2
paddle/operators/conditional_block_op.cc                                +4    -4
paddle/operators/math/math_function.cu                                  +7    -0
paddle/operators/recurrent_op.cc                                        +1    -1
paddle/operators/while_op.cc                                            +1    -1
paddle/platform/device_context.cc                                       +16   -0
paddle/platform/device_context.h                                        +16   -0
paddle/platform/device_context_test.cc                                  +16   -0
paddle/platform/place.h                                                 +6    -1
paddle/pybind/pybind.cc                                                 +17   -0
python/paddle/v2/fluid/layers/__init__.py                               +17   -0
python/paddle/v2/fluid/layers/control_flow.py                           +13   -1022
python/paddle/v2/fluid/layers/io.py                                     +57   -0
python/paddle/v2/fluid/layers/nn.py                                     +791  -0
python/paddle/v2/fluid/layers/ops.py                                    +9    -0
python/paddle/v2/fluid/layers/tensor.py                                 +130  -0
python/paddle/v2/fluid/tests/book/test_image_classification_train.py   +2    -2
python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py    +46   -2
python/setup.py.in                                                      +1    -0
README.md
@@ -2,8 +2,8 @@
 [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
-[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://doc.paddlepaddle.org/develop/doc/)
-[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://doc.paddlepaddle.org/develop/doc_cn/)
+[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html)
+[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html)
 [![Coverage Status](https://coveralls.io/repos/github/PaddlePaddle/Paddle/badge.svg?branch=develop)](https://coveralls.io/github/PaddlePaddle/Paddle?branch=develop)
 [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
 [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
paddle/capi/error.cpp
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "error.h"
 
-const char* paddle_error_string(paddle_error err) {
+extern "C" const char* paddle_error_string(paddle_error err) {
   switch (err) {
     case kPD_NULLPTR:
       return "nullptr error";
paddle/capi/error.h
@@ -29,9 +29,17 @@ typedef enum {
   kPD_UNDEFINED_ERROR = -1,
 } paddle_error;
 
+#ifdef __cplusplus
+extern "C" {
+#endif
+
 /**
  * Error string for Paddle API.
  */
 PD_API const char* paddle_error_string(paddle_error err);
 
+#ifdef __cplusplus
+}
+#endif
+
 #endif
paddle/framework/backward.cc
@@ -430,14 +430,14 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
     std::vector<std::unique_ptr<OpDescBind>> op_grads;
 
     if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
-      int step_block_idx = (*it)->GetBlockAttr("step_block");
+      int step_block_idx = (*it)->GetBlockAttr("sub_block");
       BlockDescBind* backward_block = CreateStepBlock(program_desc, no_grad_vars,
                                                       grad_to_var, step_block_idx);
       op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
     } else if ((*it)->Type() == "conditional_block") {
       BlockDescBind* backward_block =
           CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
-                          (*it)->GetBlockAttr("block"));
+                          (*it)->GetBlockAttr("sub_block"));
       op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
     } else {
       op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
paddle/operators/conditional_block_op.cc
@@ -65,7 +65,7 @@ class ConditionalBlockOp : public ConditionalOp {
       scopes->front() = &scope.NewScope();
       auto &cur_scope = *scopes->front();
 
-      auto *block = Attr<framework::BlockDescBind *>("block");
+      auto *block = Attr<framework::BlockDescBind *>("sub_block");
       framework::Executor exec(dev_ctx);
       exec.Run(*block->Program(), &cur_scope, block->ID(), false);
     }
@@ -88,7 +88,7 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
              "unify the conditional block, rnn and while op, the type of "
              "scope is std::vector<Scope*>");
     AddAttr<framework::BlockDescBind *>(
-        "block", "The step block of conditional block operator");
+        "sub_block", "The step block of conditional block operator");
     AddComment(R"DOC(Conditional block operator
 
 Run the sub-block if X is not empty. Params is the other inputs and Out is the
@@ -117,7 +117,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
       auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
       framework::Scope &cur_scope = *scopes[0];
 
-      auto *block = Attr<framework::BlockDescBind *>("block");
+      auto *block = Attr<framework::BlockDescBind *>("sub_block");
       framework::Executor exec(dev_ctx);
       exec.Run(*block->Program(), &cur_scope, block->ID(), false);
@@ -181,7 +181,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
     grad_op->SetInput("Scope", Output("Scope"));
     grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
     grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params"));
-    grad_op->SetBlockAttr("block", *this->grad_block_[0]);
+    grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]);
     return std::unique_ptr<framework::OpDescBind>(grad_op);
   }
 };
paddle/operators/math/math_function.cu
@@ -273,6 +273,13 @@ void set_constant_with_place<platform::GPUPlace>(
                        TensorSetConstantGPU(context, tensor, value));
 }
 
+template <>
+void set_constant_with_place<platform::CudnnPlace>(
+    const platform::DeviceContext& context, framework::Tensor* tensor,
+    float value) {
+  set_constant_with_place<platform::GPUPlace>(context, tensor, value);
+}
+
 template struct RowwiseAdd<platform::CUDADeviceContext, float>;
 template struct RowwiseAdd<platform::CUDADeviceContext, double>;
 template struct ColwiseSum<platform::CUDADeviceContext, float>;
paddle/operators/recurrent_op.cc
@@ -25,7 +25,7 @@ constexpr char kOutputs[] = "outputs";
 constexpr char kStepScopes[] = "step_scopes";
 constexpr char kExStates[] = "ex_states";
 constexpr char kStates[] = "states";
-constexpr char kStepBlock[] = "step_block";
+constexpr char kStepBlock[] = "sub_block";
 constexpr char kReverse[] = "reverse";
 constexpr char kIsTrain[] = "is_train";
 #define GRAD_SUFFIX "@GRAD"
paddle/operators/while_op.cc
@@ -25,7 +25,7 @@ namespace operators {
 using StepScopeVar = std::vector<framework::Scope *>;
 using LoDTensor = framework::LoDTensor;
 
-constexpr char kStepBlock[] = "step_block";
+constexpr char kStepBlock[] = "sub_block";
 constexpr char kCondition[] = "Condition";
 constexpr char kStepScopes[] = "StepScopes";
 constexpr char kParameters[] = "X";
paddle/platform/device_context.cc
@@ -125,6 +125,22 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
 
 cudaStream_t CUDADeviceContext::stream() const { return stream_; }
 
+CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place)
+    : CUDADeviceContext(place), place_(place) {
+  PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
+  PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
+}
+
+CudnnDeviceContext::~CudnnDeviceContext() {
+  SetDeviceId(place_.device);
+  Wait();
+  PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
+}
+
+Place CudnnDeviceContext::GetPlace() const { return CudnnPlace(); }
+
+cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; }
+
 #endif
 
 }  // namespace platform
paddle/platform/device_context.h
@@ -86,6 +86,22 @@ class CUDADeviceContext : public DeviceContext {
   cublasHandle_t cublas_handle_;
 };
 
+class CudnnDeviceContext : public CUDADeviceContext {
+ public:
+  explicit CudnnDeviceContext(CudnnPlace place);
+  virtual ~CudnnDeviceContext();
+
+  /*! \brief Return place in the device context. */
+  Place GetPlace() const final;
+
+  /*! \brief Return cudnn handle in the device context. */
+  cudnnHandle_t cudnn_handle() const;
+
+ private:
+  cudnnHandle_t cudnn_handle_;
+  CudnnPlace place_;
+};
+
 #endif
 
 }  // namespace platform
paddle/platform/device_context_test.cc
@@ -46,3 +46,19 @@ TEST(Device, CUDADeviceContext) {
     delete device_context;
   }
 }
+
+TEST(Device, CudnnDeviceContext) {
+  using paddle::platform::CudnnDeviceContext;
+  using paddle::platform::CudnnPlace;
+  if (paddle::platform::dynload::HasCUDNN()) {
+    int count = paddle::platform::GetCUDADeviceCount();
+    for (int i = 0; i < count; ++i) {
+      CudnnDeviceContext* device_context =
+          new CudnnDeviceContext(CudnnPlace(i));
+      cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
+      ASSERT_NE(nullptr, cudnn_handle);
+      ASSERT_NE(nullptr, device_context->stream());
+      delete device_context;
+    }
+  }
+}
paddle/platform/place.h
@@ -43,6 +43,11 @@ struct GPUPlace {
   int device;
 };
 
+struct CudnnPlace : public GPUPlace {
+  CudnnPlace() : GPUPlace() {}
+  explicit CudnnPlace(int d) : GPUPlace(d) {}
+};
+
 struct IsGPUPlace : public boost::static_visitor<bool> {
   bool operator()(const CPUPlace &) const { return false; }
   bool operator()(const GPUPlace &gpu) const { return true; }
@@ -52,7 +57,7 @@ struct IsGPUPlace : public boost::static_visitor<bool> {
 // should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
 #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
 
-typedef boost::variant<GPUPlace, CPUPlace> Place;
+typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace> Place;
 
 // static check number of place types is less equal than
 // 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
paddle/pybind/pybind.cc
@@ -282,6 +282,23 @@ All parameter, weight, gradient are variables in Paddle.
         }
         return ret_values;
       });
+  m.def("get_grad_op_descs",
+        [](const OpDescBind &op_desc,
+           const std::unordered_set<std::string> &no_grad_set,
+           std::unordered_map<std::string, std::string> &grad_to_var,
+           const std::vector<BlockDescBind *> &grad_sub_block) {
+          std::vector<std::unique_ptr<OpDescBind>> grad_op_descs =
+              framework::OpInfoMap::Instance()
+                  .Get(op_desc.Type())
+                  .GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
+                                 grad_sub_block);
+          std::vector<OpDescBind *> grad_op_desc_ptrs(grad_op_descs.size());
+          std::transform(
+              grad_op_descs.begin(), grad_op_descs.end(),
+              grad_op_desc_ptrs.begin(),
+              [](std::unique_ptr<OpDescBind> &p) { return p.release(); });
+          return grad_op_desc_ptrs;
+        });
   m.def("prune", [](const ProgramDescBind &origin,
                     const std::vector<std::array<size_t, 2>> &targets) {
     ProgramDescBind prog_with_targets(origin);
python/paddle/v2/fluid/layers/__init__.py (new file, mode 100644)

import ops
from ops import *
import nn
from nn import *
import io
from io import *
import tensor
from tensor import *
import control_flow
from control_flow import *

__all__ = []
__all__ += nn.__all__
__all__ += io.__all__
__all__ += tensor.__all__
__all__ += control_flow.__all__
__all__ += ops.__all__
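Because each sub-module's __all__ is re-exported here, call sites that address layers through the package root keep working after the split. A minimal sketch of that usage, assuming a built paddle.v2.fluid install; the tiny network itself is illustrative, while the layer names come from the sub-modules above:

# Illustrative sketch only: data() is re-exported from layers/io.py and
# fc() from layers/nn.py by this __init__.py; the network below is not
# part of this commit.
import paddle.v2.fluid as fluid

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.fc(input=x, size=1)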
python/paddle/v2/fluid/layers.py → python/paddle/v2/fluid/layers/control_flow.py
(This diff is collapsed.)
python/paddle/v2/fluid/layers/io.py (new file, mode 100644)

from .. import core
from ..layer_helper import LayerHelper

__all__ = ['data']


def data(name,
         shape,
         append_batch_size=True,
         dtype='float32',
         lod_level=0,
         type=core.VarDesc.VarType.LOD_TENSOR,
         main_program=None,
         startup_program=None,
         stop_gradient=True):
    """
    Data Layer.

    Args:
        name: The name/alias of the function
        shape: Tuple declaring the shape.
        append_batch_size: Whether or not to append the data as a batch.
        dtype: The type of data : float32, float_16, int etc
        type: The output type. By default it is LOD_TENSOR.
        lod_level(int): The LoD Level. 0 means the input data is not a sequence.
        main_program: Name of the main program that calls this
        startup_program: Name of the startup program
        stop_gradient: A boolean that mentions whether gradient should flow.

    This function takes in input and based on whether data has
    to be returned back as a minibatch, it creates the global variable using
    the helper functions. The global variables can be accessed by all the
    following operations and layers in the graph.

    All the input variables of this function are passed in as local variables
    to the LayerHelper constructor.
    """
    helper = LayerHelper('data', **locals())
    shape = list(shape)
    for i in xrange(len(shape)):
        if shape[i] is None:
            shape[i] = -1
            append_batch_size = False
        elif shape[i] < 0:
            append_batch_size = False

    if append_batch_size:
        shape = [-1] + shape  # append batch size as -1

    return helper.create_global_variable(
        name=name,
        shape=shape,
        dtype=dtype,
        type=type,
        stop_gradient=stop_gradient,
        lod_level=lod_level)
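As the docstring above describes, data() registers a global input variable and, with append_batch_size=True, prepends a -1 batch dimension so the variable can hold any batch size. A hedged usage sketch; the feature shapes and names are illustrative, and feeding/execution are assumed context from the rest of the fluid API rather than part of this file:

# Sketch only: with append_batch_size=True the declared shape [32, 32, 3]
# becomes [-1, 32, 32, 3] on the created variable; the names 'pixel' and
# 'label' are illustrative.
import paddle.v2.fluid as fluid

images = fluid.layers.data(name='pixel', shape=[32, 32, 3], dtype='float32')
labels = fluid.layers.data(name='label', shape=[1], dtype='int64')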
python/paddle/v2/fluid/layers/nn.py (new file, mode 100644)
(This diff is collapsed.)
python/paddle/v2/fluid/layers/ops.py (new file, mode 100644)

from ..registry import register_layer

__all__ = [
    'mean', 'mul', 'dropout', 'reshape', 'sigmoid', 'scale', 'transpose',
    'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div',
    'elementwise_sub', 'elementwise_mul', 'clip', 'abs'
]

for _OP in set(__all__):
    globals()[_OP] = register_layer(_OP)
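register_layer builds a thin Python wrapper around the operator of the same name, so every entry in __all__ becomes callable like a hand-written layer. A hedged sketch of calling two of them; the keyword names mirror the operator inputs and attributes seen elsewhere in this commit (for example transpose(x=..., axis=[1, 0, 2]) in test_understand_sentiment_lstm.py), but the exact accepted kwargs depend on each op's proto:

# Sketch only: 'scale' and 'transpose' are auto-generated above by
# register_layer; the keyword arguments are assumptions based on the op
# input 'X' and the attributes 'scale' and 'axis' used in this commit.
import paddle.v2.fluid as fluid

x = fluid.layers.data(name='x', shape=[8, 32], dtype='float32')
y = fluid.layers.scale(x=x, scale=2.0)
z = fluid.layers.transpose(x=y, axis=[1, 0, 2])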
python/paddle/v2/fluid/layers/tensor.py (new file, mode 100644)

from ..layer_helper import LayerHelper

__all__ = [
    'create_tensor', 'cast', 'concat', 'sums', 'assign',
    'fill_constant_batch_size_like', 'fill_constant', 'ones', 'zeros'
]


def create_tensor(dtype, name=None, main_program=None, startup_program=None):
    helper = LayerHelper("create_tensor", **locals())
    return helper.create_variable(name=helper.name, dtype=dtype)


def cast(x, dtype, main_program=None):
    """
    This function takes in the input with input_dtype
    and casts it to the output_dtype as the output.
    """
    helper = LayerHelper('cast', **locals())
    out = helper.create_tmp_variable(dtype=dtype)
    helper.append_op(
        type='cast',
        inputs={'X': [x]},
        outputs={'Out': [out]},
        attrs={'in_dtype': x.dtype,
               'out_dtype': out.dtype})
    return out


def concat(input, axis, main_program=None, startup_program=None):
    """
    This function concats the input along the axis mentioned
    and returns that as the output.
    """
    helper = LayerHelper('concat', **locals())
    out = helper.create_tmp_variable(dtype=helper.input_dtype())
    helper.append_op(
        type='concat',
        inputs={'X': input},
        outputs={'Out': [out]},
        attrs={'axis': axis})
    return out


def sums(input, out=None, main_program=None, startup_program=None):
    """
    This function takes in the input and performs the sum operation on it
    and returns that as the output.
    """
    helper = LayerHelper('sum', **locals())
    if out is None:
        out = helper.create_tmp_variable(dtype=helper.input_dtype())
    helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out})
    return out


def assign(input, output, main_program=None, startup_program=None):
    helper = LayerHelper('assign', **locals())
    helper.append_op(
        type='scale',
        inputs={'X': [input]},
        outputs={'Out': [output]},
        attrs={'scale': 1.0})
    return output


def fill_constant(shape,
                  dtype,
                  value,
                  out=None,
                  main_program=None,
                  startup_program=None):
    """
    This function creates a tensor, with shape as mentioned in the input and
    specified dtype and fills this up with a constant value that
    comes in the input. It also sets the stop_gradient to be True.
    """
    helper = LayerHelper("fill_constant", **locals())
    if out is None:
        out = helper.create_tmp_variable(dtype=dtype)
    helper.append_op(
        type='fill_constant',
        inputs={},
        outputs={'Out': [out]},
        attrs={'shape': shape,
               'dtype': out.dtype,
               'value': float(value)})
    out.stop_gradient = True
    return out


def fill_constant_batch_size_like(input,
                                  shape,
                                  dtype,
                                  value,
                                  input_dim_idx=0,
                                  output_dim_idx=0,
                                  main_program=None,
                                  startup_program=None):
    helper = LayerHelper("fill_constant_batch_size_like", **locals())
    out = helper.create_tmp_variable(dtype=dtype)
    helper.append_op(
        type='fill_constant_batch_size_like',
        inputs={'Input': input},
        outputs={'Out': [out]},
        attrs={'shape': shape,
               'dtype': out.dtype,
               'value': float(value),
               'input_dim_idx': input_dim_idx,
               'output_dim_idx': output_dim_idx})
    out.stop_gradient = True
    return out


def ones(shape, dtype, main_program=None):
    """
    This function performs the same function as fill_constant() declared above
    with the constant value being 1.0.
    """
    return fill_constant(value=1.0, **locals())


def zeros(shape, dtype, main_program=None):
    """
    This function performs the same function as fill_constant() declared above
    with the constant value being 0.0.
    """
    return fill_constant(value=0.0, **locals())
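The fill_constant family above is what later code in this commit uses to build an initial LSTM cell state. A short usage sketch, assuming a default main program; the shape [50, 32] is illustrative:

# Sketch only: mirrors how fill_constant is used later in this commit
# (test_understand_sentiment_lstm.py builds c_pre_init this way); ones()
# and sums() are the helpers defined above.
import paddle.v2.fluid as fluid

zero_state = fluid.layers.fill_constant(shape=[50, 32], dtype='float32', value=0.0)
one_state = fluid.layers.ones(shape=[50, 32], dtype='float32')
total = fluid.layers.sums(input=[zero_state, one_state])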
python/paddle/v2/fluid/tests/book/test_image_classification_train.py

from __future__ import print_function

import numpy as np
import sys
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import sys


def resnet_cifar10(input, depth=32):
...
python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py

import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from paddle.v2.fluid.layer_helper import LayerHelper


def lstm(x, c_pre_init, hidden_dim, forget_bias=None, main_program=None,
         startup_program=None):
    """
    This function helps create an operator for the LSTM (Long Short Term
    Memory) cell that can be used inside an RNN.
    """
    helper = LayerHelper('lstm_unit', **locals())
    rnn = fluid.layers.StaticRNN()
    with rnn.step():
        c_pre = rnn.memory(init=c_pre_init)
        x_t = rnn.step_input(x)

        before_fc = fluid.layers.concat(
            input=[x_t, c_pre],
            axis=1,
            main_program=main_program,
            startup_program=startup_program)
        after_fc = fluid.layers.fc(input=before_fc,
                                   size=hidden_dim * 4,
                                   main_program=main_program,
                                   startup_program=startup_program)

        dtype = x.dtype
        c = helper.create_tmp_variable(dtype)
        h = helper.create_tmp_variable(dtype)

        helper.append_op(
            type='lstm_unit',
            inputs={"X": after_fc,
                    "C_prev": c_pre},
            outputs={"C": c,
                     "H": h},
            attrs={"forget_bias": forget_bias})

        rnn.update_memory(c_pre, c)
        rnn.output(h)

    return rnn()


def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50):
...

@@ -23,8 +68,7 @@ def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50):
     c_pre_init = fluid.layers.fill_constant(
         dtype=emb.dtype, shape=[batch_size, emb_dim], value=0.0)
     c_pre_init.stop_gradient = False
-    layer_1_out = fluid.layers.lstm(
-        emb, c_pre_init=c_pre_init, hidden_dim=emb_dim)
+    layer_1_out = lstm(emb, c_pre_init=c_pre_init, hidden_dim=emb_dim)
     layer_1_out = fluid.layers.transpose(x=layer_1_out, axis=[1, 0, 2])
 
     prediction = fluid.layers.fc(input=layer_1_out,
python/setup.py.in
@@ -68,6 +68,7 @@ packages=['paddle',
           'paddle.v2.plot',
           'paddle.v2.fluid',
           'paddle.v2.fluid.proto',
+          'paddle.v2.fluid.layers',
           'py_paddle']
 
 with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f: