Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
01425309
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
01425309
编写于
11月 07, 2017
作者:
Y
Yang Yu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename shrink_state -> shrink_rnn_memory
Follow comments
上级
b4dddb29
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
33 addition
and
41 deletion
+33
-41
paddle/operators/shrink_rnn_memory_op.cc
paddle/operators/shrink_rnn_memory_op.cc
+30
-37
paddle/operators/tensor_array_read_write_op.cc
paddle/operators/tensor_array_read_write_op.cc
+0
-1
python/paddle/v2/framework/layers.py
python/paddle/v2/framework/layers.py
+1
-1
python/paddle/v2/framework/tests/test_shrink_rnn_memory.py
python/paddle/v2/framework/tests/test_shrink_rnn_memory.py
+2
-2
未找到文件。
paddle/operators/shrink_
state
_op.cc
→
paddle/operators/shrink_
rnn_memory
_op.cc
浏览文件 @
01425309
...
@@ -18,9 +18,9 @@
...
@@ -18,9 +18,9 @@
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
class
Shrink
State
Op
:
public
ArrayOp
{
class
Shrink
RNNMemory
Op
:
public
ArrayOp
{
public:
public:
Shrink
State
Op
(
const
std
::
string
&
type
,
Shrink
RNNMemory
Op
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
...
@@ -36,18 +36,12 @@ class ShrinkStateOp : public ArrayOp {
...
@@ -36,18 +36,12 @@ class ShrinkStateOp : public ArrayOp {
PADDLE_ENFORCE
(
rank_table_var
!=
nullptr
,
"RankTable must be set"
);
PADDLE_ENFORCE
(
rank_table_var
!=
nullptr
,
"RankTable must be set"
);
auto
&
rank_table
=
rank_table_var
->
Get
<
framework
::
LoDRankTable
>
();
auto
&
rank_table
=
rank_table_var
->
Get
<
framework
::
LoDRankTable
>
();
int
dst_num_rows
=
0
;
{
auto
&
rank_items
=
rank_table
.
items
();
auto
&
rank_items
=
rank_table
.
items
();
for
(
auto
&
rank_item
:
rank_items
)
{
int
dst_num_rows
=
if
(
offset
<
rank_item
.
length
)
{
std
::
lower_bound
(
rank_items
.
begin
(),
rank_items
.
end
(),
offset
,
++
dst_num_rows
;
[](
const
framework
::
LoDRankTable
::
TableItem
&
a
,
}
else
{
size_t
b
)
{
return
a
.
length
>
b
;
})
-
break
;
rank_items
.
begin
();
}
}
}
auto
*
out_var
=
scope
.
FindVar
(
Output
(
"Out"
));
auto
*
out_var
=
scope
.
FindVar
(
Output
(
"Out"
));
PADDLE_ENFORCE
(
out_var
!=
nullptr
,
"Output Out must be set"
);
PADDLE_ENFORCE
(
out_var
!=
nullptr
,
"Output Out must be set"
);
...
@@ -58,9 +52,9 @@ class ShrinkStateOp : public ArrayOp {
...
@@ -58,9 +52,9 @@ class ShrinkStateOp : public ArrayOp {
}
}
};
};
class
Shrink
State
OpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
Shrink
RNNMemory
OpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
Shrink
State
OpProtoMaker
(
framework
::
OpProto
*
proto
,
Shrink
RNNMemory
OpProtoMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
""
);
AddInput
(
"X"
,
""
);
...
@@ -71,7 +65,7 @@ class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -71,7 +65,7 @@ class ShrinkStateOpProtoMaker : public framework::OpProtoAndCheckerMaker {
}
}
};
};
class
Shrink
StateOp
InferShape
:
public
framework
::
InferShapeBase
{
class
Shrink
RNNMemory
InferShape
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInput
(
"X"
));
PADDLE_ENFORCE
(
context
->
HasInput
(
"X"
));
...
@@ -81,9 +75,9 @@ class ShrinkStateOpInferShape : public framework::InferShapeBase {
...
@@ -81,9 +75,9 @@ class ShrinkStateOpInferShape : public framework::InferShapeBase {
}
}
};
};
class
Shrink
State
GradOp
:
public
ArrayOp
{
class
Shrink
RNNMemory
GradOp
:
public
ArrayOp
{
public:
public:
Shrink
State
GradOp
(
const
std
::
string
&
type
,
Shrink
RNNMemory
GradOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
...
@@ -92,8 +86,7 @@ class ShrinkStateGradOp : public ArrayOp {
...
@@ -92,8 +86,7 @@ class ShrinkStateGradOp : public ArrayOp {
void
Run
(
const
framework
::
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
auto
*
dout_var
=
scope
.
FindVar
(
Input
(
framework
::
GradVarName
(
"Out"
)));
auto
*
dout_var
=
scope
.
FindVar
(
Input
(
framework
::
GradVarName
(
"Out"
)));
auto
dx_name
=
Output
(
framework
::
GradVarName
(
"X"
));
auto
*
dx_var
=
scope
.
FindVar
(
Output
(
framework
::
GradVarName
(
"X"
)));
auto
*
dx_var
=
scope
.
FindVar
(
dx_name
);
PADDLE_ENFORCE
(
dx_var
!=
nullptr
,
"Input Gradient should not be nullptr"
);
PADDLE_ENFORCE
(
dx_var
!=
nullptr
,
"Input Gradient should not be nullptr"
);
auto
*
x_var
=
scope
.
FindVar
(
Input
(
"X"
));
auto
*
x_var
=
scope
.
FindVar
(
Input
(
"X"
));
PADDLE_ENFORCE
(
x_var
!=
nullptr
);
PADDLE_ENFORCE
(
x_var
!=
nullptr
);
...
@@ -110,7 +103,7 @@ class ShrinkStateGradOp : public ArrayOp {
...
@@ -110,7 +103,7 @@ class ShrinkStateGradOp : public ArrayOp {
auto
height
=
dout_tensor
.
dims
()[
0
];
auto
height
=
dout_tensor
.
dims
()[
0
];
dx_tensor
.
Slice
(
0
,
static_cast
<
int
>
(
height
))
dx_tensor
.
Slice
(
0
,
static_cast
<
int
>
(
height
))
.
CopyFrom
(
dout_tensor
,
dout_tensor
.
place
(),
dev_ctx
);
.
CopyFrom
(
dout_tensor
,
dout_tensor
.
place
(),
dev_ctx
);
if
(
height
<
dout_tensor
.
dims
()[
0
]
)
{
if
(
dx_tensor
.
dims
()[
0
]
<
height
)
{
auto
rest_tensor
=
dx_tensor
.
Slice
(
auto
rest_tensor
=
dx_tensor
.
Slice
(
static_cast
<
int
>
(
height
),
static_cast
<
int
>
(
dout_tensor
.
dims
()[
0
]));
static_cast
<
int
>
(
height
),
static_cast
<
int
>
(
dout_tensor
.
dims
()[
0
]));
math
::
set_constant
(
dev_ctx
,
&
rest_tensor
,
0.0
f
);
math
::
set_constant
(
dev_ctx
,
&
rest_tensor
,
0.0
f
);
...
@@ -119,7 +112,7 @@ class ShrinkStateGradOp : public ArrayOp {
...
@@ -119,7 +112,7 @@ class ShrinkStateGradOp : public ArrayOp {
}
}
};
};
class
Shri
kState
GradInferShape
:
public
framework
::
InferShapeBase
{
class
Shri
nkRNNMemory
GradInferShape
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInput
(
"X"
));
PADDLE_ENFORCE
(
context
->
HasInput
(
"X"
));
...
@@ -129,14 +122,14 @@ class ShrikStateGradInferShape : public framework::InferShapeBase {
...
@@ -129,14 +122,14 @@ class ShrikStateGradInferShape : public framework::InferShapeBase {
}
}
};
};
class
Shrink
State
GradOpMaker
:
public
framework
::
SingleGradOpDescMaker
{
class
Shrink
RNN
GradOpMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDescBind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDescBind
>
Apply
()
const
override
{
auto
*
op
=
new
framework
::
OpDescBind
();
auto
*
op
=
new
framework
::
OpDescBind
();
op
->
SetType
(
"shrink_
state
_grad"
);
op
->
SetType
(
"shrink_
rnn_memory
_grad"
);
op
->
SetInput
(
"X"
,
Input
(
"X"
));
op
->
SetInput
(
"X"
,
Input
(
"X"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
...
@@ -149,8 +142,8 @@ class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker {
...
@@ -149,8 +142,8 @@ class ShrinkStateGradOpMaker : public framework::SingleGradOpDescMaker {
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
shrink_
state
,
ops
::
ShrinkState
Op
,
REGISTER_OPERATOR
(
shrink_
rnn_memory
,
ops
::
ShrinkRNNMemory
Op
,
ops
::
Shrink
StateOpInferShape
,
ops
::
ShrinkStateOpProtoMaker
,
ops
::
Shrink
RNNMemoryInferShape
,
ops
::
Shrink
State
GradOpMaker
);
ops
::
Shrink
RNNMemoryOpProtoMaker
,
ops
::
ShrinkRNN
GradOpMaker
);
REGISTER_OPERATOR
(
shrink_
state_grad
,
ops
::
ShrinkState
GradOp
,
REGISTER_OPERATOR
(
shrink_
rnn_memory_grad
,
ops
::
ShrinkRNNMemory
GradOp
,
ops
::
Shri
kState
GradInferShape
);
ops
::
Shri
nkRNNMemory
GradInferShape
);
paddle/operators/tensor_array_read_write_op.cc
浏览文件 @
01425309
...
@@ -85,7 +85,6 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
...
@@ -85,7 +85,6 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
public:
public:
void
operator
()(
const
framework
::
OpDescBind
&
op_desc
,
void
operator
()(
const
framework
::
OpDescBind
&
op_desc
,
framework
::
BlockDescBind
*
block
)
const
override
{
framework
::
BlockDescBind
*
block
)
const
override
{
VLOG
(
10
)
<<
"I am here?"
;
for
(
auto
&
out_var
:
op_desc
.
OutputArgumentNames
())
{
for
(
auto
&
out_var
:
op_desc
.
OutputArgumentNames
())
{
VLOG
(
10
)
<<
"Set Variable "
<<
out_var
<<
" as LOD_TENSOR_ARRAY"
;
VLOG
(
10
)
<<
"Set Variable "
<<
out_var
<<
" as LOD_TENSOR_ARRAY"
;
block
->
Var
(
out_var
)
->
SetType
(
framework
::
VarDesc
::
LOD_TENSOR_ARRAY
);
block
->
Var
(
out_var
)
->
SetType
(
framework
::
VarDesc
::
LOD_TENSOR_ARRAY
);
...
...
python/paddle/v2/framework/layers.py
浏览文件 @
01425309
...
@@ -844,7 +844,7 @@ def shrink_memory(x, i, table, main_program=None):
...
@@ -844,7 +844,7 @@ def shrink_memory(x, i, table, main_program=None):
helper
=
LayerHelper
(
'shrink_memory'
,
**
locals
())
helper
=
LayerHelper
(
'shrink_memory'
,
**
locals
())
out
=
helper
.
create_tmp_variable
(
dtype
=
x
.
data_type
)
out
=
helper
.
create_tmp_variable
(
dtype
=
x
.
data_type
)
helper
.
append_op
(
helper
.
append_op
(
type
=
'shrink_
state
'
,
type
=
'shrink_
rnn_memory
'
,
inputs
=
{
'X'
:
[
x
],
inputs
=
{
'X'
:
[
x
],
'I'
:
[
i
],
'I'
:
[
i
],
'RankTable'
:
[
table
]},
'RankTable'
:
[
table
]},
...
...
python/paddle/v2/framework/tests/test_shrink_
state
.py
→
python/paddle/v2/framework/tests/test_shrink_
rnn_memory
.py
浏览文件 @
01425309
...
@@ -7,8 +7,8 @@ from paddle.v2.framework.framework import g_main_program
...
@@ -7,8 +7,8 @@ from paddle.v2.framework.framework import g_main_program
import
numpy
import
numpy
class
TestShrink
State
(
unittest
.
TestCase
):
class
TestShrink
RNNMemory
(
unittest
.
TestCase
):
def
test_shrink_
state
(
self
):
def
test_shrink_
rnn_memory
(
self
):
x
=
layers
.
data
(
'x'
,
shape
=
[
100
],
data_type
=
'float32'
)
x
=
layers
.
data
(
'x'
,
shape
=
[
100
],
data_type
=
'float32'
)
x
.
stop_gradient
=
False
x
.
stop_gradient
=
False
table
=
layers
.
lod_rank_table
(
x
=
x
)
table
=
layers
.
lod_rank_table
(
x
=
x
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录