Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d97a2b42
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d97a2b42
编写于
8月 08, 2017
作者:
Y
Yi Wang
提交者:
GitHub
8月 08, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3 from reyoung/feature/refactorize_framework_proto
Step 1: Make code compile well.
上级
72e3ba50
dba618c0
变更
35
隐藏空白更改
内联
并排
Showing
35 changed file
with
927 addition
and
960 deletion
+927
-960
.gitignore
.gitignore
+2
-1
paddle/framework/attribute.cc
paddle/framework/attribute.cc
+1
-1
paddle/framework/attribute.h
paddle/framework/attribute.h
+2
-3
paddle/framework/backward.cc
paddle/framework/backward.cc
+41
-24
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+224
-213
paddle/framework/ddim.cc
paddle/framework/ddim.cc
+6
-0
paddle/framework/ddim.h
paddle/framework/ddim.h
+2
-0
paddle/framework/framework.proto
paddle/framework/framework.proto
+3
-3
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+4
-3
paddle/framework/grad_op_builder_test.cc
paddle/framework/grad_op_builder_test.cc
+10
-6
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+33
-87
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+28
-8
paddle/framework/operator.cc
paddle/framework/operator.cc
+44
-55
paddle/framework/operator.h
paddle/framework/operator.h
+10
-35
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+33
-33
paddle/framework/pybind.cc
paddle/framework/pybind.cc
+4
-3
paddle/operators/add_op.cc
paddle/operators/add_op.cc
+4
-9
paddle/operators/add_op.h
paddle/operators/add_op.h
+3
-3
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+7
-13
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+1
-1
paddle/operators/fc_op.cc
paddle/operators/fc_op.cc
+8
-8
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+2
-10
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+3
-5
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+3
-5
paddle/operators/net_op.cc
paddle/operators/net_op.cc
+22
-18
paddle/operators/net_op.h
paddle/operators/net_op.h
+1
-2
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+10
-9
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+8
-3
paddle/operators/recurrent_op_test.cc
paddle/operators/recurrent_op_test.cc
+379
-370
paddle/operators/rowwise_add_op.cc
paddle/operators/rowwise_add_op.cc
+4
-6
paddle/operators/rowwise_add_op.h
paddle/operators/rowwise_add_op.h
+2
-2
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+4
-8
paddle/operators/sigmoid_op.cc
paddle/operators/sigmoid_op.cc
+1
-3
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+0
-8
paddle/platform/enforce.h
paddle/platform/enforce.h
+18
-2
未找到文件。
.gitignore
浏览文件 @
d97a2b42
...
@@ -24,4 +24,5 @@ cmake-build-*
...
@@ -24,4 +24,5 @@ cmake-build-*
python/paddle/v2/framework/core.so
python/paddle/v2/framework/core.so
CMakeFiles
CMakeFiles
cmake_install.cmake
cmake_install.cmake
paddle/.timestamp
python/paddlepaddle.egg-info/
paddle/framework/attribute.cc
浏览文件 @
d97a2b42
...
@@ -44,7 +44,7 @@ AttrType AttrTypeID<std::vector<std::string>>() {
...
@@ -44,7 +44,7 @@ AttrType AttrTypeID<std::vector<std::string>>() {
return
STRINGS
;
return
STRINGS
;
}
}
Attribute
GetAttrValue
(
const
AttrDesc
&
attr_desc
)
{
Attribute
GetAttrValue
(
const
OpDesc
::
Attr
&
attr_desc
)
{
switch
(
attr_desc
.
type
())
{
switch
(
attr_desc
.
type
())
{
case
paddle
::
framework
::
AttrType
::
INT
:
{
case
paddle
::
framework
::
AttrType
::
INT
:
{
return
attr_desc
.
i
();
return
attr_desc
.
i
();
...
...
paddle/framework/attribute.h
浏览文件 @
d97a2b42
...
@@ -21,8 +21,7 @@ limitations under the License. */
...
@@ -21,8 +21,7 @@ limitations under the License. */
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
#include "paddle/framework/attribute.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -37,7 +36,7 @@ typedef std::unordered_map<std::string, Attribute> AttributeMap;
...
@@ -37,7 +36,7 @@ typedef std::unordered_map<std::string, Attribute> AttributeMap;
template
<
typename
T
>
template
<
typename
T
>
AttrType
AttrTypeID
();
AttrType
AttrTypeID
();
Attribute
GetAttrValue
(
const
AttrDesc
&
attr_desc
);
Attribute
GetAttrValue
(
const
OpDesc
::
Attr
&
attr_desc
);
// check whether a value(attribute) fit a certain limit
// check whether a value(attribute) fit a certain limit
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/framework/backward.cc
浏览文件 @
d97a2b42
...
@@ -20,15 +20,24 @@
...
@@ -20,15 +20,24 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
static
bool
AllInSet
(
const
std
::
vector
<
std
::
string
>&
names
,
template
<
typename
Map
,
typename
T
>
const
std
::
string
&
suffix
,
static
void
ForEachVarName
(
Map
&
names
,
T
callback
)
{
const
std
::
unordered_set
<
std
::
string
>&
set
)
{
for
(
auto
&
name
:
names
)
{
for
(
auto
&
name
:
names
)
{
if
(
set
.
find
(
name
+
suffix
)
==
set
.
end
()
)
{
for
(
auto
&
n
:
name
.
second
)
{
return
false
;
if
(
callback
(
n
))
break
;
}
}
}
}
return
true
;
}
static
bool
AllInSet
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
names
,
const
std
::
string
&
suffix
,
const
std
::
unordered_set
<
std
::
string
>&
set
)
{
bool
ret_val
=
true
;
ForEachVarName
(
names
,
[
&
ret_val
,
&
set
,
&
suffix
](
const
std
::
string
&
n
)
{
ret_val
=
set
.
find
(
n
+
suffix
)
==
set
.
end
();
return
!
ret_val
;
});
return
ret_val
;
}
}
static
std
::
shared_ptr
<
OperatorBase
>
NOP
()
{
static
std
::
shared_ptr
<
OperatorBase
>
NOP
()
{
...
@@ -67,10 +76,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
...
@@ -67,10 +76,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// Then all input gradients cannot be computed at all, and we put them into
// Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
// `no_grad_names` set. Return an NOP.
if
(
AllInSet
(
forwardOp
.
outputs_
,
kGradVarSuffix
,
no_grad_names
))
{
if
(
AllInSet
(
forwardOp
.
outputs_
,
kGradVarSuffix
,
no_grad_names
))
{
for
(
auto
&
name
:
forwardOp
.
inputs_
)
{
ForEachVarName
(
forwardOp
.
inputs_
,
// Mark all input is not need
[
&
no_grad_names
](
const
std
::
string
&
name
)
->
bool
{
no_grad_names
.
insert
(
name
+
kGradVarSuffix
);
no_grad_names
.
insert
(
GradVarName
(
name
));
}
return
false
;
});
return
NOP
();
return
NOP
();
}
}
...
@@ -92,9 +102,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
...
@@ -92,9 +102,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
auto
fwd
=
*
it
;
auto
fwd
=
*
it
;
auto
bwd
=
BackwardRecursive
(
*
fwd
,
no_grad_names
,
uniq_id
);
auto
bwd
=
BackwardRecursive
(
*
fwd
,
no_grad_names
,
uniq_id
);
net
->
AddOp
(
bwd
);
net
->
AddOp
(
bwd
);
for
(
auto
&
out
:
bwd
->
outputs_
)
{
ForEachVarName
(
bwd
->
outputs_
,
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
[
&
dup_output_ops
,
local_op_id
](
const
std
::
string
&
out
)
{
}
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
return
false
;
});
}
}
// Get unique ID for this method.
// Get unique ID for this method.
auto
uid
=
uniq_id
++
;
auto
uid
=
uniq_id
++
;
...
@@ -116,7 +128,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
...
@@ -116,7 +128,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
insert_position
.
push_back
(
insert_position
.
push_back
(
{
dup_op
.
back
(),
{
dup_op
.
back
(),
OpRegistry
::
CreateOp
(
OpRegistry
::
CreateOp
(
"add"
,
{
dup_outputs
},
{
name
},
"add"
,
{
{
"X"
,
{
dup_outputs
}}},
{{
"Out"
,
{
name
}}
},
{{
"input_format"
,
{{
"input_format"
,
std
::
vector
<
int
>
{
0
,
static_cast
<
int
>
(
dup_outputs
.
size
())}}})});
std
::
vector
<
int
>
{
0
,
static_cast
<
int
>
(
dup_outputs
.
size
())}}})});
}
}
...
@@ -130,7 +142,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
...
@@ -130,7 +142,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
}
else
{
}
else
{
std
::
shared_ptr
<
OperatorBase
>
grad_op
=
OpRegistry
::
CreateGradOp
(
forwardOp
);
std
::
shared_ptr
<
OperatorBase
>
grad_op
=
OpRegistry
::
CreateGradOp
(
forwardOp
);
for
(
std
::
string
&
grad_input
:
grad_op
->
inputs_
)
{
ForEachVarName
(
grad_op
->
inputs_
,
[
&
no_grad_names
,
&
net
](
std
::
string
&
grad_input
)
{
if
(
no_grad_names
.
count
(
grad_input
))
{
if
(
no_grad_names
.
count
(
grad_input
))
{
std
::
string
prefix
=
std
::
string
prefix
=
grad_input
.
substr
(
0
,
grad_input
.
size
()
-
kGradVarSuffix
.
size
());
grad_input
.
substr
(
0
,
grad_input
.
size
()
-
kGradVarSuffix
.
size
());
...
@@ -138,16 +152,19 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
...
@@ -138,16 +152,19 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// If part of input gradient of that operator is not calculated, fill
// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
// zero variables to that input gradient.
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{
prefix
},
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{
{
"Src"
,
{
prefix
}}
},
{
grad_input
},
{}));
{
{
"Dst"
,
{
grad_input
}}
},
{}));
}
}
}
return
false
;
});
for
(
std
::
string
&
grad_output
:
grad_op
->
outputs_
)
{
if
(
no_grad_names
.
count
(
grad_output
))
{
ForEachVarName
(
grad_op
->
outputs_
,
grad_output
=
kEmptyVarName
;
[
&
no_grad_names
](
std
::
string
&
grad_output
)
{
}
if
(
no_grad_names
.
count
(
grad_output
))
{
}
grad_output
=
kEmptyVarName
;
}
return
false
;
});
if
(
net
->
ops_
.
empty
())
{
// Current no aux op is added to network
if
(
net
->
ops_
.
empty
())
{
// Current no aux op is added to network
return
grad_op
;
return
grad_op
;
...
...
paddle/framework/backward_test.cc
浏览文件 @
d97a2b42
...
@@ -44,8 +44,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker {
...
@@ -44,8 +44,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker {
public:
public:
MulOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
MulOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"
A
"
,
"A"
);
AddInput
(
"
X
"
,
"A"
);
AddInput
(
"
B
"
,
"B"
);
AddInput
(
"
Y
"
,
"B"
);
AddOutput
(
"Out"
,
"Out"
);
AddOutput
(
"Out"
,
"Out"
);
AddComment
(
"Mul"
);
AddComment
(
"Mul"
);
}
}
...
@@ -56,7 +56,7 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
...
@@ -56,7 +56,7 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
SigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
SigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"X"
);
AddInput
(
"X"
,
"X"
);
AddOutput
(
"
Y
"
,
"Y"
);
AddOutput
(
"
Out
"
,
"Y"
);
AddComment
(
"Sigmoid"
);
AddComment
(
"Sigmoid"
);
}
}
};
};
...
@@ -66,7 +66,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
...
@@ -66,7 +66,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
NoGradOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
NoGradOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"X input"
);
AddInput
(
"X"
,
"X input"
);
AddOutput
(
"
Y
"
,
"Y output"
);
AddOutput
(
"
Out
"
,
"Y output"
);
AddComment
(
"NoGradOp, same input output. no Grad"
);
AddComment
(
"NoGradOp, same input output. no Grad"
);
}
}
};
};
...
@@ -74,13 +74,15 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
...
@@ -74,13 +74,15 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
class
FcOp
:
public
ops
::
NetOp
{
class
FcOp
:
public
ops
::
NetOp
{
public:
public:
void
Init
()
override
{
void
Init
()
override
{
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
Input
(
"X"
),
Input
(
"W"
)},
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
Output
(
"mul_result"
)},
{}));
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
{{
"Out"
,
{
Output
(
"mul_result"
)}}},
{}));
auto
b_name
=
Input
(
"b"
);
auto
b_name
=
Input
(
"b"
);
std
::
string
before_act
=
"mul_result"
;
std
::
string
before_act
=
"mul_result"
;
if
(
b_name
!=
kEmptyVarName
)
{
if
(
b_name
!=
kEmptyVarName
)
{
AddOp
(
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
Output
(
"mul_result"
),
b_name
},
AddOp
(
OpRegistry
::
CreateOp
(
{
Output
(
"add_result"
)},
{}));
"rowwise_add"
,
{{
"X"
,
{
Output
(
"mul_result"
)}},
{
"b"
,
{
b_name
}}},
{{
"Out"
,
{
Output
(
"add_result"
)}}},
{}));
before_act
=
"add_result"
;
before_act
=
"add_result"
;
}
else
{
}
else
{
auto
out_varname
=
Output
(
"add_result"
);
auto
out_varname
=
Output
(
"add_result"
);
...
@@ -89,8 +91,8 @@ class FcOp : public ops::NetOp {
...
@@ -89,8 +91,8 @@ class FcOp : public ops::NetOp {
}
}
}
}
AddOp
(
OpRegistry
::
CreateOp
(
"sigmoid"
,
{
Output
(
before_act
)},
{
Output
(
"Out"
)
},
AddOp
(
OpRegistry
::
CreateOp
(
"sigmoid"
,
{
{
"X"
,
{
Output
(
before_act
)}}
},
{}));
{
{
"Out"
,
{
Output
(
"Out"
)}}},
{
}));
CompleteAddOp
(
false
);
CompleteAddOp
(
false
);
}
}
};
};
...
@@ -158,206 +160,215 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
...
@@ -158,206 +160,215 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP
(
many_output_op
,
f
::
EmptyOp
,
f
::
ManyOutputOpMaker
);
REGISTER_OP
(
many_output_op
,
f
::
EmptyOp
,
f
::
ManyOutputOpMaker
);
REGISTER_GRADIENT_OP
(
many_output_op
,
many_output_op_grad
,
f
::
EmptyOp
);
REGISTER_GRADIENT_OP
(
many_output_op
,
many_output_op_grad
,
f
::
EmptyOp
);
TEST
(
Backward
,
simple_op_grad
)
{
//
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
"X"
,
"b"
},
{
"Out"
},
{});
// TEST(Backward, simple_op_grad) {
ASSERT_NE
(
fwd
,
nullptr
);
// auto fwd = f::OpRegistry::CreateOp(
auto
gop
=
f
::
OpRegistry
::
CreateGradOp
(
*
fwd
);
// "rowwise_add", {{"X", {"X"}}, {"b", {"b"}}}, {{"Out", {"Out"}}}, {});
ASSERT_EQ
(
4UL
,
gop
->
inputs_
.
size
());
// ASSERT_NE(fwd, nullptr);
ASSERT_EQ
(
f
::
kEmptyVarName
,
gop
->
inputs_
[
0
]);
// auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ
(
"rowwise_add_grad"
,
gop
->
type_
);
// ASSERT_EQ(4UL, gop->inputs_.size());
ASSERT_EQ
(
"X"
+
f
::
kGradVarSuffix
,
gop
->
outputs_
[
0
]);
// ASSERT_EQ(f::kEmptyVarName, gop->inputs_[0]);
ASSERT_EQ
(
"b"
+
f
::
kGradVarSuffix
,
gop
->
outputs_
[
1
]);
// ASSERT_EQ("rowwise_add_grad", gop->type_);
// ASSERT_EQ("X" + f::kGradVarSuffix, gop->outputs_[0]);
ASSERT_EQ
(
"X"
+
f
::
kGradVarSuffix
,
gop
->
Output
(
"X"
+
f
::
kGradVarSuffix
));
// ASSERT_EQ("b" + f::kGradVarSuffix, gop->outputs_[1]);
}
//
// ASSERT_EQ("X" + f::kGradVarSuffix, gop->Output("X" + f::kGradVarSuffix));
TEST
(
Backward
,
simple_op_not_need_grad
)
{
//}
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
"X"
,
"b"
},
{
"Out"
},
{});
//
ASSERT_NE
(
fwd
,
nullptr
);
// TEST(Backward, simple_op_not_need_grad) {
auto
gop
=
f
::
Backward
(
*
fwd
,
{
"X"
});
// auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_EQ
(
std
::
find
(
gop
->
outputs_
.
begin
(),
gop
->
outputs_
.
end
(),
// ASSERT_NE(fwd, nullptr);
"X"
+
f
::
kGradVarSuffix
),
// auto gop = f::Backward(*fwd, {"X"});
gop
->
outputs_
.
end
());
// ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
// "X" + f::kGradVarSuffix),
auto
no_input_gop
=
f
::
Backward
(
*
fwd
,
{
"X"
,
"b"
});
// gop->outputs_.end());
ASSERT_NE
(
no_input_gop
,
nullptr
);
//
ASSERT_TRUE
(
no_input_gop
->
IsNetOp
());
// auto no_input_gop = f::Backward(*fwd, {"X", "b"});
ASSERT_EQ
(
0UL
,
// ASSERT_NE(no_input_gop, nullptr);
std
::
static_pointer_cast
<
ops
::
NetOp
>
(
no_input_gop
)
->
ops_
.
size
());
// ASSERT_TRUE(no_input_gop->IsNetOp());
}
// ASSERT_EQ(0UL,
// std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
TEST
(
Backward
,
net_fc_backward_normal
)
{
//}
std
::
shared_ptr
<
f
::
OperatorBase
>
fwd
=
f
::
OpRegistry
::
CreateOp
(
//
"fc"
,
{
"X"
,
"w"
,
"b"
},
{
"mul_result"
,
"add_result"
,
"out"
},
{});
// TEST(Backward, net_fc_backward_normal) {
ASSERT_NE
(
fwd
,
nullptr
);
// std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
std
::
shared_ptr
<
f
::
OperatorBase
>
gop
=
f
::
Backward
(
*
fwd
,
{});
// "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
ASSERT_TRUE
(
gop
->
IsNetOp
());
// ASSERT_NE(fwd, nullptr);
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
gop
.
get
());
// std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
// ASSERT_TRUE(gop->IsNetOp());
ASSERT_NO_THROW
(
net
->
DebugString
());
// auto net = static_cast<ops::NetOp *>(gop.get());
//
ASSERT_EQ
(
3UL
,
net
->
ops_
.
size
());
// ASSERT_NO_THROW(net->DebugString());
//
f
::
OperatorBase
&
d_sigmoid
=
*
net
->
ops_
[
0
];
// ASSERT_EQ(3UL, net->ops_.size());
ASSERT_EQ
(
"sigmoid_grad"
,
d_sigmoid
.
type_
);
//
// f::OperatorBase &d_sigmoid = *net->ops_[0];
f
::
OperatorBase
&
d_add
=
*
net
->
ops_
[
1
];
// ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
ASSERT_EQ
(
"rowwise_add_grad"
,
d_add
.
type_
);
//
// f::OperatorBase &d_add = *net->ops_[1];
f
::
OperatorBase
&
d_mul
=
*
net
->
ops_
[
2
];
// ASSERT_EQ("rowwise_add_grad", d_add.type_);
ASSERT_EQ
(
"mul_grad"
,
d_mul
.
type_
);
//
}
// f::OperatorBase &d_mul = *net->ops_[2];
// ASSERT_EQ("mul_grad", d_mul.type_);
TEST
(
Backward
,
net_fc_backward_not_have_b
)
{
//}
std
::
shared_ptr
<
f
::
OperatorBase
>
fwd
=
//
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"X"
,
"w"
,
f
::
kEmptyVarName
},
// TEST(Backward, net_fc_backward_not_have_b) {
{
"mul_result"
,
"add_result"
,
"tmp"
},
{});
// std::shared_ptr<f::OperatorBase> fwd =
ASSERT_NE
(
fwd
,
nullptr
);
// f::OpRegistry::CreateOp("fc", {"X", "w", f::kEmptyVarName},
std
::
shared_ptr
<
f
::
OperatorBase
>
gop
=
f
::
Backward
(
*
fwd
,
{});
// {"mul_result", "add_result", "tmp"}, {});
ASSERT_TRUE
(
gop
->
IsNetOp
());
// ASSERT_NE(fwd, nullptr);
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
gop
.
get
());
// std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
// ASSERT_TRUE(gop->IsNetOp());
ASSERT_NO_THROW
(
net
->
DebugString
());
// auto net = static_cast<ops::NetOp *>(gop.get());
//
ASSERT_EQ
(
2UL
,
net
->
ops_
.
size
());
// ASSERT_NO_THROW(net->DebugString());
//
f
::
OperatorBase
&
d_sigmoid
=
*
net
->
ops_
[
0
];
// ASSERT_EQ(2UL, net->ops_.size());
ASSERT_EQ
(
"sigmoid_grad"
,
d_sigmoid
.
type_
);
//
// f::OperatorBase &d_sigmoid = *net->ops_[0];
f
::
OperatorBase
&
d_mul
=
*
net
->
ops_
[
1
];
// ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
ASSERT_EQ
(
"mul_grad"
,
d_mul
.
type_
);
//
}
// f::OperatorBase &d_mul = *net->ops_[1];
// ASSERT_EQ("mul_grad", d_mul.type_);
TEST
(
Backward
,
net_input_of_network_not_need_grad
)
{
//}
ops
::
NetOp
net
;
//
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"X"
,
"W1"
,
"b1"
},
// TEST(Backward, net_input_of_network_not_need_grad) {
{
"mul_tmp_0"
,
"add_tmp_0"
,
"hidden0"
},
{}));
// ops::NetOp net;
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"hidden0"
,
"W2"
,
"b2"
},
// net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
{
"mul_tmp_1"
,
"add_tmp_1"
,
"hidden1"
},
{}));
// {"mul_tmp_0", "add_tmp_0", "hidden0"},
net
.
CompleteAddOp
();
// {}));
auto
bwd
=
Backward
(
net
,
{
"X"
});
// X@GRAD is not need.
// net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
ASSERT_TRUE
(
bwd
->
IsNetOp
());
// {"mul_tmp_1", "add_tmp_1", "hidden1"},
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
bwd
.
get
());
// {}));
// net.CompleteAddOp();
std
::
unordered_set
<
std
::
string
>
all_output
=
std
::
unordered_set
<
std
::
string
>
(
// auto bwd = Backward(net, {"X"}); // X@GRAD is not need.
bwd_net
->
outputs_
.
begin
(),
bwd_net
->
outputs_
.
end
());
// ASSERT_TRUE(bwd->IsNetOp());
all_output
.
erase
(
f
::
kEmptyVarName
);
// auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
//
for
(
auto
&
out
:
{
"W1"
,
"b1"
,
"hidden0"
,
"W2"
,
"b2"
})
{
// std::unordered_set<std::string> all_output =
ASSERT_NE
(
all_output
.
find
(
out
+
f
::
kGradVarSuffix
),
all_output
.
end
());
// std::unordered_set<std::string>(
}
// bwd_net->outputs_.begin(), bwd_net->outputs_.end());
// all_output.erase(f::kEmptyVarName);
// Not Generated X
//
ASSERT_EQ
(
all_output
.
find
(
"X"
+
f
::
kGradVarSuffix
),
all_output
.
end
());
// for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
// ASSERT_NE(all_output.find(out + f::kGradVarSuffix), all_output.end());
ASSERT_EQ
(
2UL
,
bwd_net
->
ops_
.
size
());
// }
ASSERT_TRUE
(
bwd_net
->
ops_
[
1
]
->
IsNetOp
());
//
auto
first_fc_grad
=
static_cast
<
ops
::
NetOp
*>
(
bwd_net
->
ops_
[
1
].
get
());
// // Not Generated X
ASSERT_EQ
(
3UL
,
first_fc_grad
->
ops_
.
size
());
// ASSERT_EQ(all_output.find("X" + f::kGradVarSuffix), all_output.end());
ASSERT_EQ
(
f
::
kEmptyVarName
,
//
first_fc_grad
->
ops_
[
2
]
->
Output
(
"A"
+
f
::
kGradVarSuffix
));
// ASSERT_EQ(2UL, bwd_net->ops_.size());
}
// ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
// auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
TEST
(
Backward
,
net_shared_weight
)
{
// ASSERT_EQ(3UL, first_fc_grad->ops_.size());
ops
::
NetOp
net
;
// ASSERT_EQ(f::kEmptyVarName,
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{
"X"
,
"W"
},
{
"Out"
},
{}));
// first_fc_grad->ops_[2]->Output("A" + f::kGradVarSuffix));
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{
"Out"
,
"W"
},
{
"FinalOut"
},
{}));
//}
net
.
CompleteAddOp
();
//
// TEST(Backward, net_shared_weight) {
auto
bwd
=
f
::
Backward
(
net
,
{});
// ops::NetOp net;
ASSERT_TRUE
(
bwd
->
IsNetOp
());
// net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
bwd
.
get
());
// net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
ASSERT_EQ
(
3UL
,
bwd_net
->
ops_
.
size
());
// net.CompleteAddOp();
ASSERT_EQ
(
"add"
,
bwd_net
->
ops_
[
2
]
->
type_
);
//
}
// auto bwd = f::Backward(net, {});
// ASSERT_TRUE(bwd->IsNetOp());
TEST
(
Backward
,
op_register_grad_not_for_network
)
{
// auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
// ASSERT_EQ(3UL, bwd_net->ops_.size());
"fc"
,
{
"X"
,
"W"
,
"b"
},
{
"mul_out"
,
"add_out"
,
"out1"
},
// ASSERT_EQ("add", bwd_net->ops_[2]->type_);
{{
"temporary_index"
,
std
::
vector
<
int
>
{
0
,
1
}}});
//}
//
ASSERT_THROW
(
f
::
OpRegistry
::
CreateGradOp
(
*
fwd
),
EnforceNotMet
);
// TEST(Backward, op_register_grad_not_for_network) {
}
// auto fwd = f::OpRegistry::CreateOp(
// "fc", {"X", "W", "b"}, {"mul_out", "add_out", "out1"},
TEST
(
Backward
,
op_all_input_are_not_need
)
{
// {{"temporary_index", std::vector<int>{0, 1}}});
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
"X"
,
"b"
},
{
"Out"
},
{});
//
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"X"
,
"b"
});
// ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
ASSERT_TRUE
(
backward
->
IsNetOp
());
//}
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
//
ASSERT_TRUE
(
net
->
ops_
.
empty
());
// TEST(Backward, op_all_input_are_not_need) {
}
// auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
// auto backward = f::Backward(*fwd, {"X", "b"});
TEST
(
Backward
,
op_all_output_are_not_need
)
{
// ASSERT_TRUE(backward->IsNetOp());
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
"X"
,
"b"
},
{
"Out"
},
{});
// auto net = static_cast<ops::NetOp *>(backward.get());
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"Out"
});
// ASSERT_TRUE(net->ops_.empty());
ASSERT_TRUE
(
backward
->
IsNetOp
());
//}
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
//
ASSERT_TRUE
(
net
->
ops_
.
empty
());
// TEST(Backward, op_all_output_are_not_need) {
}
// auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
// auto backward = f::Backward(*fwd, {"Out"});
TEST
(
Backward
,
op_part_of_output_are_not_need
)
{
// ASSERT_TRUE(backward->IsNetOp());
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"many_output_op"
,
{
"X"
},
{
"Y"
,
"Z"
},
{});
// auto net = static_cast<ops::NetOp *>(backward.get());
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"Z"
});
// ASSERT_TRUE(net->ops_.empty());
ASSERT_TRUE
(
backward
->
IsNetOp
());
//}
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
//
ASSERT_EQ
(
net
->
ops_
.
size
(),
2UL
);
// TEST(Backward, op_part_of_output_are_not_need) {
// auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
auto
&
fill_zero
=
*
net
->
ops_
[
0
];
// auto backward = f::Backward(*fwd, {"Z"});
ASSERT_EQ
(
"fill_zeros_like"
,
fill_zero
.
type_
);
// ASSERT_TRUE(backward->IsNetOp());
ASSERT_EQ
(
1UL
,
fill_zero
.
inputs_
.
size
());
// auto net = static_cast<ops::NetOp *>(backward.get());
ASSERT_EQ
(
"Z"
,
fill_zero
.
inputs_
[
0
]);
// ASSERT_EQ(net->ops_.size(), 2UL);
ASSERT_EQ
(
1UL
,
fill_zero
.
outputs_
.
size
());
//
ASSERT_EQ
(
"Z"
+
f
::
kZeroVarSuffix
,
fill_zero
.
outputs_
[
0
]);
// auto &fill_zero = *net->ops_[0];
// ASSERT_EQ("fill_zeros_like", fill_zero.type_);
auto
&
d_many_out
=
*
net
->
ops_
[
1
];
// ASSERT_EQ(1UL, fill_zero.inputs_.size());
ASSERT_EQ
(
"many_output_op_grad"
,
d_many_out
.
type_
);
// ASSERT_EQ("Z", fill_zero.inputs_[0]);
ASSERT_EQ
(
1UL
+
2UL
+
2UL
,
d_many_out
.
inputs_
.
size
());
// I/O/OG
// ASSERT_EQ(1UL, fill_zero.outputs_.size());
ASSERT_EQ
(
"Z"
+
f
::
kZeroVarSuffix
,
d_many_out
.
Input
(
"z"
+
f
::
kGradVarSuffix
));
// ASSERT_EQ("Z" + f::kZeroVarSuffix, fill_zero.outputs_[0]);
ASSERT_EQ
(
"Y"
+
f
::
kGradVarSuffix
,
d_many_out
.
Input
(
"y"
+
f
::
kGradVarSuffix
));
//
ASSERT_EQ
(
"X"
+
f
::
kGradVarSuffix
,
// auto &d_many_out = *net->ops_[1];
d_many_out
.
Output
(
"x"
+
f
::
kGradVarSuffix
));
// ASSERT_EQ("many_output_op_grad", d_many_out.type_);
}
// ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG
// ASSERT_EQ("Z" + f::kZeroVarSuffix, d_many_out.Input("z" +
TEST
(
Backward
,
op_part_of_input_are_not_need
)
{
// f::kGradVarSuffix));
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{
"a"
,
"b"
},
{
"out"
},
{});
// ASSERT_EQ("Y" + f::kGradVarSuffix, d_many_out.Input("y" +
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"a"
});
// f::kGradVarSuffix));
auto
&
grad_mul
=
*
backward
;
// ASSERT_EQ("X" + f::kGradVarSuffix,
ASSERT_EQ
(
grad_mul
.
type_
,
"mul_grad"
);
// d_many_out.Output("x" + f::kGradVarSuffix));
ASSERT_EQ
(
grad_mul
.
inputs_
.
size
(),
2UL
+
1UL
+
1UL
);
//}
ASSERT_EQ
(
grad_mul
.
outputs_
.
size
(),
2UL
);
//
ASSERT_EQ
(
grad_mul
.
Output
(
"A"
+
f
::
kGradVarSuffix
),
f
::
kEmptyVarName
);
// TEST(Backward, op_part_of_input_are_not_need) {
ASSERT_EQ
(
grad_mul
.
Output
(
"B"
+
f
::
kGradVarSuffix
),
"b"
+
f
::
kGradVarSuffix
);
// auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
ASSERT_EQ
(
grad_mul
.
Input
(
"Out"
+
f
::
kGradVarSuffix
),
// auto backward = f::Backward(*fwd, {"a"});
"out"
+
f
::
kGradVarSuffix
);
// auto &grad_mul = *backward;
ASSERT_EQ
(
grad_mul
.
Input
(
"A"
),
"a"
);
// ASSERT_EQ(grad_mul.type_, "mul_grad");
ASSERT_EQ
(
grad_mul
.
Input
(
"B"
),
"b"
);
// ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
ASSERT_EQ
(
grad_mul
.
Input
(
"Out"
),
"out"
);
// ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
}
// ASSERT_EQ(grad_mul.Output("A" + f::kGradVarSuffix), f::kEmptyVarName);
// ASSERT_EQ(grad_mul.Output("B" + f::kGradVarSuffix), "b" +
TEST
(
Backward
,
linear_net_intermediate_variable_has_no_grad
)
{
// f::kGradVarSuffix);
ops
::
NetOp
net
;
// ASSERT_EQ(grad_mul.Input("Out" + f::kGradVarSuffix),
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"x1"
,
"w1"
,
"b1"
},
// "out" + f::kGradVarSuffix);
{
"mul_out1"
,
"add_out1"
,
"out1"
},
{}));
// ASSERT_EQ(grad_mul.Input("A"), "a");
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"out1"
,
"w2"
,
"b2"
},
// ASSERT_EQ(grad_mul.Input("B"), "b");
{
"mul_out2"
,
"tmp_out2"
,
"out2"
},
{}));
// ASSERT_EQ(grad_mul.Input("Out"), "out");
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"out2"
,
"w3"
,
"b3"
},
//}
{
"mul_out3"
,
"tmp_out3"
,
"out3"
},
{}));
//
net
.
CompleteAddOp
();
// TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
auto
backward
=
f
::
Backward
(
net
,
{
"mul_out2"
,
"tmp_out2"
,
"out2"
});
// ops::NetOp net;
ASSERT_TRUE
(
backward
->
IsNetOp
());
// net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
// {"mul_out1", "add_out1", "out1"}, {}));
ASSERT_EQ
(
bwd_net
->
ops_
.
size
(),
3UL
);
// net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"},
auto
&
grad_fc
=
*
bwd_net
->
ops_
[
0
];
// {"mul_out2", "tmp_out2", "out2"}, {}));
EXPECT_EQ
(
grad_fc
.
inputs_
.
size
(),
// net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
3UL
/* external input number */
// {"mul_out3", "tmp_out3", "out3"}, {}));
+
1UL
/* external output number*/
// net.CompleteAddOp();
+
1UL
/* number of gradient of external output*/
// auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
+
2U
/* internal variable number*/
);
// ASSERT_TRUE(backward->IsNetOp());
EXPECT_EQ
(
grad_fc
.
outputs_
.
size
(),
2UL
/* input number of mul*/
// auto bwd_net = static_cast<ops::NetOp *>(backward.get());
+
2UL
/* input number of rowwise_add */
// ASSERT_EQ(bwd_net->ops_.size(), 3UL);
+
1UL
/* input number of sigmod */
);
// auto &grad_fc = *bwd_net->ops_[0];
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
inputs_
.
size
(),
0UL
);
// EXPECT_EQ(grad_fc.inputs_.size(),
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
outputs_
.
size
(),
0UL
);
// 3UL /* external input number */
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
inputs_
.
size
(),
0UL
);
// + 1UL /* external output number*/
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
outputs_
.
size
(),
0UL
);
// + 1UL /* number of gradient of external output*/
}
// + 2U /* internal variable number*/);
// EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/
// + 2UL /* input number of rowwise_add
// */
// + 1UL /* input number of sigmod */);
// EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL);
// EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL);
// EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL);
// EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
//}
paddle/framework/ddim.cc
浏览文件 @
d97a2b42
...
@@ -284,5 +284,11 @@ DDim::DDim(std::initializer_list<int> init_list) {
...
@@ -284,5 +284,11 @@ DDim::DDim(std::initializer_list<int> init_list) {
*
this
=
make_ddim
(
init_list
);
*
this
=
make_ddim
(
init_list
);
}
}
std
::
string
DDim
::
DebugString
()
const
{
std
::
ostringstream
ss
;
ss
<<
*
this
;
return
ss
.
str
();
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/ddim.h
浏览文件 @
d97a2b42
...
@@ -73,6 +73,8 @@ struct DDim {
...
@@ -73,6 +73,8 @@ struct DDim {
DDim
operator
*
(
DDim
d
)
const
;
DDim
operator
*
(
DDim
d
)
const
;
ssize_t
size
()
const
;
ssize_t
size
()
const
;
std
::
string
DebugString
()
const
;
};
};
/**
/**
...
...
paddle/framework/framework.proto
浏览文件 @
d97a2b42
...
@@ -40,8 +40,8 @@ message OpDesc {
...
@@ -40,8 +40,8 @@ message OpDesc {
};
};
message
Var
{
message
Var
{
required
string
name
;
// e.g. "X"
required
string
op_proto_name
=
1
;
optional
int
dup
=
2
[
default
=
0
];
// e.g., "1"
repeated
string
var_names
=
2
;
};
};
required
string
type
=
3
;
required
string
type
=
3
;
...
@@ -57,7 +57,7 @@ message OpProto {
...
@@ -57,7 +57,7 @@ message OpProto {
message
Var
{
message
Var
{
required
string
name
=
1
;
required
string
name
=
1
;
required
string
comment
=
2
;
required
string
comment
=
2
;
// OpDesc::Var::dup indices the duplica.
optional
bool
duplicable
=
3
[
default
=
false
];
optional
bool
duplicable
=
3
[
default
=
false
];
optional
bool
intermediate
=
4
[
default
=
false
];
optional
bool
intermediate
=
4
[
default
=
false
];
optional
bool
no_gradient
=
5
[
default
=
false
];
optional
bool
no_gradient
=
5
[
default
=
false
];
...
...
paddle/framework/grad_op_builder.cc
浏览文件 @
d97a2b42
...
@@ -13,12 +13,12 @@ express or implied. See the License for the specific language governing
...
@@ -13,12 +13,12 @@ express or implied. See the License for the specific language governing
permissions and limitations under the License. */
permissions and limitations under the License. */
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/
op_proto
.pb.h"
#include "paddle/framework/
framework
.pb.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
/**
class OpRegistry;
class OpRegistry;
using VarIndexMap = std::unordered_map<std::string, int>;
using VarIndexMap = std::unordered_map<std::string, int>;
...
@@ -98,6 +98,7 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
...
@@ -98,6 +98,7 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true); // IG
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true); // IG
return grad_op;
return grad_op;
}
}
**/
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
)
{
return
nullptr
;
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/grad_op_builder_test.cc
浏览文件 @
d97a2b42
...
@@ -47,8 +47,8 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
...
@@ -47,8 +47,8 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
namespace
f
=
paddle
::
framework
;
namespace
f
=
paddle
::
framework
;
TEST
(
GradOpBuilder
,
AddTwo
)
{
TEST
(
GradOpBuilder
,
AddTwo
)
{
std
::
shared_ptr
<
f
::
OperatorBase
>
add_op
(
std
::
shared_ptr
<
f
::
OperatorBase
>
add_op
(
f
::
OpRegistry
::
CreateOp
(
f
::
OpRegistry
::
CreateOp
(
"add_two"
,
{
"x"
,
"y"
},
{
"out"
},
{}));
"add_two"
,
{{
"X"
,
{
"x"
}},
{
"Y"
,
{
"y"
}}},
{{
"Out"
,
{
"out"
}}
},
{}));
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_add_op
=
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_add_op
=
f
::
OpRegistry
::
CreateGradOp
(
*
add_op
);
f
::
OpRegistry
::
CreateGradOp
(
*
add_op
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
inputs_
.
size
()),
4
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
inputs_
.
size
()),
4
);
...
@@ -70,8 +70,10 @@ TEST(GradOpBuilder, MutiInOut) {
...
@@ -70,8 +70,10 @@ TEST(GradOpBuilder, MutiInOut) {
f
::
AttributeMap
attrs
{{
"input_format"
,
std
::
vector
<
int
>
{
0
,
1
,
4
,
5
}},
f
::
AttributeMap
attrs
{{
"input_format"
,
std
::
vector
<
int
>
{
0
,
1
,
4
,
5
}},
{
"output_format"
,
std
::
vector
<
int
>
{
0
,
1
,
3
}}};
{
"output_format"
,
std
::
vector
<
int
>
{
0
,
1
,
3
}}};
std
::
shared_ptr
<
f
::
OperatorBase
>
test_op
(
f
::
OpRegistry
::
CreateOp
(
std
::
shared_ptr
<
f
::
OperatorBase
>
test_op
(
f
::
OpRegistry
::
CreateOp
(
"mult_io"
,
{
"in1"
,
"in2_1"
,
"in2_2"
,
"in2_3"
,
"in3"
},
"mult_io"
,
{{
"In1"
,
{
"in1"
}},
{
"out1"
,
"out2_1"
,
"out2_2"
},
attrs
));
{
"In2_mult"
,
{
"in2_1"
,
"in2_2"
,
"in2_3"
}},
{
"In3"
,
{
"in3"
}}},
{{
"Out1"
,
{
"Out2_mult"
}},
{
"Out2"
,
{
"out2_1"
,
"out2_2"
}}},
attrs
));
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_test_op
=
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_test_op
=
f
::
OpRegistry
::
CreateGradOp
(
*
test_op
);
f
::
OpRegistry
::
CreateGradOp
(
*
test_op
);
...
@@ -104,8 +106,10 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
...
@@ -104,8 +106,10 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
f
::
AttributeMap
attrs
{{
"input_format"
,
std
::
vector
<
int
>
{
0
,
1
,
3
,
5
}},
f
::
AttributeMap
attrs
{{
"input_format"
,
std
::
vector
<
int
>
{
0
,
1
,
3
,
5
}},
{
"output_format"
,
std
::
vector
<
int
>
{
0
,
2
,
3
}}};
{
"output_format"
,
std
::
vector
<
int
>
{
0
,
2
,
3
}}};
std
::
shared_ptr
<
f
::
OperatorBase
>
test_op
(
f
::
OpRegistry
::
CreateOp
(
std
::
shared_ptr
<
f
::
OperatorBase
>
test_op
(
f
::
OpRegistry
::
CreateOp
(
"io_ignored"
,
{
"in1"
,
"in2_1"
,
"in2_2"
,
"in3_1"
,
"in3_2"
},
"io_ignored"
,
{{
"In1"
,
{
"in1"
}},
{
"out1_1"
,
"out1_2"
,
"out2"
},
attrs
));
{
"In2_mult"
,
{
"in2_1"
,
"in2_2"
}},
{
"In3_mult"
,
{
"in3_1"
,
"in3_2"
}}},
{{
"Out1_mult"
,
{
"out1_1"
,
"out1_2"
}},
{
"Out2"
,
{
"out2"
}}},
attrs
));
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_test_op
=
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_test_op
=
f
::
OpRegistry
::
CreateGradOp
(
*
test_op
);
f
::
OpRegistry
::
CreateGradOp
(
*
test_op
);
...
...
paddle/framework/op_registry.h
浏览文件 @
d97a2b42
...
@@ -20,8 +20,8 @@ limitations under the License. */
...
@@ -20,8 +20,8 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include "paddle/framework/attribute.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/scope.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -44,25 +44,20 @@ class OpProtoAndCheckerMaker {
...
@@ -44,25 +44,20 @@ class OpProtoAndCheckerMaker {
protected:
protected:
struct
VariableBuilder
{
struct
VariableBuilder
{
VarProto
*
var_
;
OpProto
::
Var
*
var_
;
std
::
function
<
void
()
>
on_multiple_
;
std
::
function
<
void
()
>
on_temporary_
;
VariableBuilder
&
SetMultiple
()
{
VariableBuilder
&
SetMultiple
()
{
var_
->
set_multiple
(
true
);
var_
->
set_duplicable
(
true
);
on_multiple_
();
return
*
this
;
return
*
this
;
}
}
VariableBuilder
&
SetTemporary
()
{
VariableBuilder
&
SetTemporary
()
{
PADDLE_ENFORCE
(
bool
(
on_temporary_
),
"Cannot set temporary"
);
var_
->
set_intermediate
(
true
);
var_
->
set_temporary
(
true
);
on_temporary_
();
return
*
this
;
return
*
this
;
}
}
VariableBuilder
&
IgnoreGradient
()
{
VariableBuilder
&
IgnoreGradient
()
{
var_
->
set_
ignore
_gradient
(
true
);
var_
->
set_
no
_gradient
(
true
);
return
*
this
;
return
*
this
;
}
}
};
};
...
@@ -72,8 +67,7 @@ class OpProtoAndCheckerMaker {
...
@@ -72,8 +67,7 @@ class OpProtoAndCheckerMaker {
auto
input
=
proto_
->
mutable_inputs
()
->
Add
();
auto
input
=
proto_
->
mutable_inputs
()
->
Add
();
*
input
->
mutable_name
()
=
name
;
*
input
->
mutable_name
()
=
name
;
*
input
->
mutable_comment
()
=
comment
;
*
input
->
mutable_comment
()
=
comment
;
return
VariableBuilder
{
input
,
[
=
]
{
this
->
SetHasMultipleInput
();
},
return
VariableBuilder
{
input
};
nullptr
};
}
}
VariableBuilder
AddOutput
(
const
std
::
string
&
name
,
VariableBuilder
AddOutput
(
const
std
::
string
&
name
,
...
@@ -81,8 +75,7 @@ class OpProtoAndCheckerMaker {
...
@@ -81,8 +75,7 @@ class OpProtoAndCheckerMaker {
auto
output
=
proto_
->
mutable_outputs
()
->
Add
();
auto
output
=
proto_
->
mutable_outputs
()
->
Add
();
*
output
->
mutable_name
()
=
name
;
*
output
->
mutable_name
()
=
name
;
*
output
->
mutable_comment
()
=
comment
;
*
output
->
mutable_comment
()
=
comment
;
return
VariableBuilder
{
output
,
[
=
]
{
this
->
SetHasMultipleOutput
();
},
return
VariableBuilder
{
output
};
[
=
]
{
this
->
SetHasTemporaryOutput
();
}};
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -102,53 +95,6 @@ class OpProtoAndCheckerMaker {
...
@@ -102,53 +95,6 @@ class OpProtoAndCheckerMaker {
}
}
private:
private:
void
SetHasMultiple
(
const
std
::
string
&
in_out
,
bool
*
flag
)
{
if
(
!*
flag
)
{
AddAttr
<
std
::
vector
<
int
>>
(
in_out
+
"_format"
,
"The multiple index of "
+
in_out
+
"
\n
"
R
"DOC(
This attribute is used by Paddle core framework. Paddle's Op support each input
or output could be a list of variable. This attribute is used to show how that
list organized.
e.g.
input = ["
a
", "
b
", "
c
", "
d
", "
e
", "
f
"]
input_format = [0, 4, 5, 6]
means
The number of all input variables this op is six, and they are segmented into
three inputs.
The first input is input[0:4], second is input[4:5], third is input[5:6].
)DOC"
,
/*generated*/
true
);
*
flag
=
true
;
}
}
void
SetHasMultipleInput
()
{
SetHasMultiple
(
"input"
,
&
has_multiple_input_
);
}
void
SetHasMultipleOutput
()
{
SetHasMultiple
(
"output"
,
&
has_multiple_output_
);
}
void
SetHasTemporaryOutput
()
{
if
(
!
has_temporary_output_
)
{
AddAttr
<
std
::
vector
<
int
>>
(
"temporary_index"
,
R
"DOC(The temporary index of output.
Not all output of Paddle Op is used by user. For faster computation, each op
could output some its internal state to other op, other op could take that
output to make compute faster.
Add a mark to which output is temporary is helpful for future optimization.
)DOC"
,
/*generated*/
true
)
.
SetDefault
(
std
::
vector
<
int
>
());
has_temporary_output_
=
true
;
}
}
void
CheckNoDuplicatedInOutAttrs
()
{
void
CheckNoDuplicatedInOutAttrs
()
{
std
::
unordered_set
<
std
::
string
>
names
;
std
::
unordered_set
<
std
::
string
>
names
;
auto
checker
=
[
&
](
const
std
::
string
&
name
)
{
auto
checker
=
[
&
](
const
std
::
string
&
name
)
{
...
@@ -169,15 +115,12 @@ Add a mark to which output is temporary is helpful for future optimization.
...
@@ -169,15 +115,12 @@ Add a mark to which output is temporary is helpful for future optimization.
OpProto
*
proto_
;
OpProto
*
proto_
;
OpAttrChecker
*
op_checker_
;
OpAttrChecker
*
op_checker_
;
bool
validated_
{
false
};
bool
validated_
{
false
};
bool
has_multiple_input_
{
false
};
bool
has_multiple_output_
{
false
};
bool
has_temporary_output_
{
false
};
};
};
class
OpRegistry
{
class
OpRegistry
{
using
OpCreator
=
std
::
function
<
OperatorBase
*
()
>
;
using
OpCreator
=
std
::
function
<
OperatorBase
*
()
>
;
using
VarIndexMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
using
VarIndexMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
using
VarName
List
=
std
::
vector
<
std
::
string
>
;
using
VarName
Map
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>
>
;
public:
public:
template
<
typename
OpType
,
typename
ProtoMakerType
>
template
<
typename
OpType
,
typename
ProtoMakerType
>
...
@@ -213,8 +156,8 @@ class OpRegistry {
...
@@ -213,8 +156,8 @@ class OpRegistry {
}
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
const
VarName
List
&
inputs
,
const
VarName
Map
&
inputs
,
const
VarName
List
&
outputs
,
const
VarName
Map
&
outputs
,
const
AttributeMap
&
attrs
)
{
const
AttributeMap
&
attrs
)
{
auto
op_create_it
=
op_creators
().
find
(
type
);
auto
op_create_it
=
op_creators
().
find
(
type
);
PADDLE_ENFORCE
(
op_create_it
!=
op_creators
().
end
(),
PADDLE_ENFORCE
(
op_create_it
!=
op_creators
().
end
(),
...
@@ -230,27 +173,28 @@ class OpRegistry {
...
@@ -230,27 +173,28 @@ class OpRegistry {
GenerateTempVariableName
(
op
);
GenerateTempVariableName
(
op
);
{
auto
var_index_it
=
VarIndexMaps
().
find
(
type
);
if
(
var_index_it
!=
VarIndexMaps
().
end
())
{
op
->
in_out_idxs_
=
var_index_it
->
second
;
}
}
op
->
Init
();
op
->
Init
();
return
std
::
shared_ptr
<
OperatorBase
>
(
op
);
return
std
::
shared_ptr
<
OperatorBase
>
(
op
);
}
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
)
{
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
)
{
std
::
vector
<
std
::
string
>
inputs
;
VarNameMap
inputs
;
inputs
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
for
(
auto
&
input
:
op_desc
.
inputs
())
{
std
::
copy
(
op_desc
.
inputs
().
begin
(),
op_desc
.
inputs
().
end
(),
auto
&
var_names
=
inputs
[
input
.
op_proto_name
()];
std
::
back_inserter
(
inputs
));
auto
&
var_names_in_proto
=
input
.
var_names
();
var_names
.
reserve
(
static_cast
<
size_t
>
(
var_names_in_proto
.
size
()));
std
::
copy
(
var_names_in_proto
.
begin
(),
var_names_in_proto
.
end
(),
std
::
back_inserter
(
var_names
));
}
std
::
vector
<
std
::
string
>
outputs
;
VarNameMap
outputs
;
outputs
.
reserve
((
size_t
)
op_desc
.
outputs_size
());
for
(
auto
&
output
:
op_desc
.
outputs
())
{
std
::
copy
(
op_desc
.
outputs
().
begin
(),
op_desc
.
outputs
().
end
(),
auto
&
var_names
=
outputs
[
output
.
op_proto_name
()];
std
::
back_inserter
(
outputs
));
auto
&
var_names_in_proto
=
output
.
var_names
();
var_names
.
reserve
(
static_cast
<
size_t
>
(
var_names_in_proto
.
size
()));
std
::
copy
(
var_names_in_proto
.
begin
(),
var_names_in_proto
.
end
(),
std
::
back_inserter
(
var_names
));
}
AttributeMap
attrs
;
AttributeMap
attrs
;
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
...
@@ -303,11 +247,13 @@ class OpRegistry {
...
@@ -303,11 +247,13 @@ class OpRegistry {
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
for
(
auto
&
outname
:
op
->
outputs_
)
{
for
(
auto
&
output
:
op
->
outputs_
)
{
if
(
outname
==
kTempVarName
)
{
for
(
auto
&
output_name
:
output
.
second
)
{
outname
+=
op
->
type_
;
if
(
output_name
==
kTempVarName
)
{
outname
+=
"@"
;
output_name
+=
op
->
type_
;
outname
+=
std
::
to_string
(
gUniqId
.
fetch_add
(
1
));
output_name
+=
"@"
;
output_name
+=
std
::
to_string
(
gUniqId
.
fetch_add
(
1
));
}
}
}
}
}
}
}
...
...
paddle/framework/op_registry_test.cc
浏览文件 @
d97a2b42
...
@@ -57,8 +57,13 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
...
@@ -57,8 +57,13 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
TEST
(
OpRegistry
,
CreateOp
)
{
TEST
(
OpRegistry
,
CreateOp
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
auto
input
=
op_desc
.
add_inputs
();
op_desc
.
add_outputs
(
"bb"
);
input
->
set_op_proto_name
(
"input"
);
*
input
->
mutable_var_names
()
->
Add
()
=
"aa"
;
auto
output
=
op_desc
.
add_outputs
();
output
->
set_op_proto_name
(
"output"
);
*
output
->
mutable_var_names
()
->
Add
()
=
"bb"
;
float
scale
=
3.3
;
float
scale
=
3.3
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
...
@@ -78,8 +83,13 @@ TEST(OpRegistry, CreateOp) {
...
@@ -78,8 +83,13 @@ TEST(OpRegistry, CreateOp) {
TEST
(
OpRegistry
,
IllegalAttr
)
{
TEST
(
OpRegistry
,
IllegalAttr
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
auto
input
=
op_desc
.
add_inputs
();
op_desc
.
add_outputs
(
"bb"
);
input
->
set_op_proto_name
(
"input"
);
*
input
->
mutable_var_names
()
->
Add
()
=
"aa"
;
auto
output
=
op_desc
.
add_outputs
();
output
->
set_op_proto_name
(
"output"
);
*
output
->
mutable_var_names
()
->
Add
()
=
"bb"
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
...
@@ -103,8 +113,13 @@ TEST(OpRegistry, IllegalAttr) {
...
@@ -103,8 +113,13 @@ TEST(OpRegistry, IllegalAttr) {
TEST
(
OpRegistry
,
DefaultValue
)
{
TEST
(
OpRegistry
,
DefaultValue
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
auto
input
=
op_desc
.
add_inputs
();
op_desc
.
add_outputs
(
"bb"
);
input
->
set_op_proto_name
(
"input"
);
*
input
->
mutable_var_names
()
->
Add
()
=
"aa"
;
auto
output
=
op_desc
.
add_outputs
();
output
->
set_op_proto_name
(
"output"
);
*
output
->
mutable_var_names
()
->
Add
()
=
"bb"
;
ASSERT_TRUE
(
op_desc
.
IsInitialized
());
ASSERT_TRUE
(
op_desc
.
IsInitialized
());
...
@@ -127,8 +142,13 @@ static void SetInputFormat(paddle::framework::OpDesc* desc) {
...
@@ -127,8 +142,13 @@ static void SetInputFormat(paddle::framework::OpDesc* desc) {
TEST
(
OpRegistry
,
CustomChecker
)
{
TEST
(
OpRegistry
,
CustomChecker
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"my_test_op"
);
op_desc
.
set_type
(
"my_test_op"
);
op_desc
.
add_inputs
(
"ii"
);
auto
input
=
op_desc
.
add_inputs
();
op_desc
.
add_outputs
(
"oo"
);
input
->
set_op_proto_name
(
"input"
);
*
input
->
mutable_var_names
()
->
Add
()
=
"ii"
;
auto
output
=
op_desc
.
add_outputs
();
output
->
set_op_proto_name
(
"output"
);
*
output
->
mutable_var_names
()
->
Add
()
=
"oo"
;
SetInputFormat
(
&
op_desc
);
SetInputFormat
(
&
op_desc
);
// attr 'test_attr' is not set
// attr 'test_attr' is not set
...
...
paddle/framework/operator.cc
浏览文件 @
d97a2b42
...
@@ -34,83 +34,72 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
...
@@ -34,83 +34,72 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
#endif
#endif
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
in_out_idxs_
!=
nullptr
,
auto
it
=
inputs_
.
find
(
name
);
"Input Output Indices could not be nullptr"
);
PADDLE_ENFORCE
(
it
!=
inputs_
.
end
(),
"Op %s does not have output %s"
,
type_
,
auto
it
=
in_out_idxs_
->
find
(
name
);
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
->
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
name
);
if
(
attrs_
.
count
(
"input_format"
)
==
0
)
{
PADDLE_ENFORCE_EQ
(
it
->
second
.
size
(),
1UL
,
return
inputs_
.
at
((
size_t
)
it
->
second
);
"Op %s input %s should contain only one variable"
,
type_
,
}
else
{
name
);
const
auto
&
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
return
it
->
second
[
0
];
int
idx
=
input_format
[
it
->
second
];
return
inputs_
.
at
((
size_t
)
idx
);
}
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
Inputs
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>&
OperatorBase
::
Inputs
(
PADDLE_ENFORCE
(
in_out_idxs_
!=
nullptr
,
"IO Idx could not be nullptr"
);
const
std
::
string
&
name
)
const
{
auto
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
return
inputs_
.
at
(
name
);
auto
offset
=
in_out_idxs_
->
at
(
name
);
PADDLE_ENFORCE
(
input_format
.
at
(
static_cast
<
size_t
>
(
offset
)
+
1
)
<=
static_cast
<
int
>
(
inputs_
.
size
()),
"Input Out Of Range"
);
return
std
::
vector
<
std
::
string
>
{
inputs_
.
begin
()
+
input_format
.
at
(
offset
),
inputs_
.
begin
()
+
input_format
.
at
(
offset
+
1
)};
}
}
const
std
::
string
&
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE
(
in_out_idxs_
!=
nullptr
,
"InOut Indice could not be nullptr"
);
auto
it
=
outputs_
.
find
(
name
);
auto
it
=
in_out_idxs_
->
find
(
name
);
PADDLE_ENFORCE
(
it
!=
outputs_
.
end
(),
"Op %s does not have output %s"
,
type_
,
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
->
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
name
);
if
(
attrs_
.
count
(
"output_format"
)
==
0
)
{
PADDLE_ENFORCE_EQ
(
it
->
second
.
size
(),
1UL
,
return
outputs_
.
at
((
size_t
)
it
->
second
);
"Op %s input %s should contain only one variable"
,
type_
,
}
else
{
name
);
const
auto
&
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
return
it
->
second
[
0
];
int
idx
=
output_format
[
it
->
second
];
return
outputs_
.
at
((
size_t
)
idx
);
}
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
Outputs
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>&
OperatorBase
::
Outputs
(
PADDLE_ENFORCE
(
in_out_idxs_
!=
nullptr
,
"InOut Indice could not be nullptr"
);
const
std
::
string
&
name
)
const
{
auto
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
return
outputs_
.
at
(
name
);
auto
offset
=
in_out_idxs_
->
at
(
name
);
PADDLE_ENFORCE
(
output_format
.
at
(
static_cast
<
size_t
>
(
offset
)
+
1
)
<=
static_cast
<
int
>
(
outputs_
.
size
()),
"Output Out of Range"
);
return
std
::
vector
<
std
::
string
>
{
outputs_
.
begin
()
+
output_format
.
at
(
offset
),
outputs_
.
begin
()
+
output_format
.
at
(
offset
+
1
)};
}
}
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
"Op("
<<
type_
<<
"), inputs:("
;
ss
<<
"Op("
<<
type_
<<
"), inputs:{"
;
for
(
size_t
i
=
0
;
i
<
inputs_
.
size
();
++
i
)
{
for
(
auto
&
input
:
inputs_
)
{
ss
<<
inputs_
[
i
];
ss
<<
input
.
first
<<
"["
;
if
(
i
!=
inputs_
.
size
()
-
1
)
{
for
(
size_t
i
=
0
;
i
<
input
.
second
.
size
();
++
i
)
{
ss
<<
", "
;
ss
<<
input
.
second
[
i
];
if
(
i
!=
input
.
second
.
size
()
-
1
)
{
ss
<<
", "
;
}
}
}
ss
<<
"]"
;
}
}
ss
<<
"), outputs:("
;
ss
<<
"}, outputs:{"
;
for
(
size_t
i
=
0
;
i
<
outputs_
.
size
();
++
i
)
{
for
(
auto
&
output
:
outputs_
)
{
ss
<<
outputs_
[
i
];
ss
<<
output
.
first
<<
"["
;
if
(
i
!=
outputs_
.
size
()
-
1
)
{
for
(
size_t
i
=
0
;
i
<
output
.
second
.
size
();
++
i
)
{
ss
<<
", "
;
ss
<<
output
.
second
[
i
];
if
(
i
!=
output
.
second
.
size
()
-
1
)
{
ss
<<
", "
;
}
}
}
ss
<<
"]"
;
}
}
ss
<<
"
)
."
;
ss
<<
"
}
."
;
return
ss
.
str
();
return
ss
.
str
();
}
}
void
OperatorBase
::
Rename
(
const
std
::
string
&
old_name
,
void
OperatorBase
::
Rename
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
)
{
const
std
::
string
&
new_name
)
{
std
::
replace
(
inputs_
.
begin
(),
inputs_
.
end
(),
old_name
,
new_name
);
for
(
auto
&
input
:
inputs_
)
{
std
::
replace
(
outputs_
.
begin
(),
outputs_
.
end
(),
old_name
,
new_name
);
std
::
replace
(
input
.
second
.
begin
(),
input
.
second
.
end
(),
old_name
,
new_name
);
}
for
(
auto
&
output
:
outputs_
)
{
std
::
replace
(
output
.
second
.
begin
(),
output
.
second
.
end
(),
old_name
,
new_name
);
}
}
}
}
// namespace framework
}
// namespace framework
...
...
paddle/framework/operator.h
浏览文件 @
d97a2b42
...
@@ -21,8 +21,7 @@ limitations under the License. */
...
@@ -21,8 +21,7 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/framework/attribute.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/device_context.h"
...
@@ -95,13 +94,12 @@ class OperatorBase {
...
@@ -95,13 +94,12 @@ class OperatorBase {
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
//! Get a input which has multiple variables.
//! Get a input which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
const
std
::
vector
<
std
::
string
>&
Inputs
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
std
::
string
>
Inputs
(
const
std
::
string
&
name
)
const
;
//! Get a output with argument's name described in `op_proto`
//! Get a output with argument's name described in `op_proto`
const
std
::
string
&
Output
(
const
std
::
string
&
name
)
const
;
const
std
::
string
&
Output
(
const
std
::
string
&
name
)
const
;
//! Get an output which has multiple variables.
//! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
//! TODO add a vector_view to prevent memory copy.
std
::
vector
<
std
::
string
>
Outputs
(
const
std
::
string
&
name
)
const
;
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
;
public:
public:
std
::
string
type_
;
std
::
string
type_
;
...
@@ -109,13 +107,12 @@ class OperatorBase {
...
@@ -109,13 +107,12 @@ class OperatorBase {
// I (Inputs)
// I (Inputs)
// O (Outputs)
// O (Outputs)
// OG (Output Gradients)
// OG (Output Gradients)
std
::
vector
<
std
::
string
>
inputs_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
inputs_
;
// NOTE: in case of OpGrad, outputs_ contains
// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
// IG (Inputs Gradients)
std
::
vector
<
std
::
string
>
outputs_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>
>
outputs_
;
AttributeMap
attrs_
;
AttributeMap
attrs_
;
// store the arguments' offset described in op_desc.
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
int
>>
in_out_idxs_
;
};
};
class
OperatorContext
{
class
OperatorContext
{
...
@@ -123,16 +120,12 @@ class OperatorContext {
...
@@ -123,16 +120,12 @@ class OperatorContext {
OperatorContext
(
const
OperatorBase
*
op
,
const
Scope
&
scope
)
OperatorContext
(
const
OperatorBase
*
op
,
const
Scope
&
scope
)
:
op_
(
*
op
),
scope_
(
scope
)
{}
:
op_
(
*
op
),
scope_
(
scope
)
{}
size_t
InputSize
()
const
{
return
op_
.
inputs_
.
size
();
}
size_t
InputSize
(
const
std
::
string
&
name
)
const
{
return
op_
.
inputs_
.
at
(
name
).
size
();
size_t
OutputSize
()
const
{
return
op_
.
outputs_
.
size
();
}
const
Variable
*
InputVar
(
const
size_t
index
)
const
{
return
scope_
.
FindVar
(
op_
.
inputs_
.
at
(
index
));
}
}
Variable
*
OutputVar
(
const
size_t
index
)
const
{
size_t
OutputSize
(
const
std
::
string
&
name
)
const
{
return
scope_
.
FindVar
(
op_
.
outputs_
.
at
(
index
)
);
return
op_
.
outputs_
.
at
(
name
).
size
(
);
}
}
const
Variable
*
InputVar
(
const
std
::
string
&
name
)
const
{
const
Variable
*
InputVar
(
const
std
::
string
&
name
)
const
{
...
@@ -164,24 +157,6 @@ class OperatorContext {
...
@@ -164,24 +157,6 @@ class OperatorContext {
return
res
;
return
res
;
}
}
template
<
typename
T
>
const
T
*
Input
(
const
size_t
index
)
const
{
auto
var
=
InputVar
(
index
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Input(%d) should not be nullptr"
,
index
);
return
&
var
->
Get
<
T
>
();
}
template
<
typename
T
>
T
*
Output
(
const
size_t
index
)
const
{
auto
var
=
OutputVar
(
index
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Output(%d) not be nullptr, which means variable [%s] does not "
"exist in scope"
,
index
,
op_
.
outputs_
[
index
]);
return
var
->
GetMutable
<
T
>
();
}
template
<
typename
T
>
template
<
typename
T
>
const
T
*
Input
(
const
std
::
string
&
name
)
const
{
const
T
*
Input
(
const
std
::
string
&
name
)
const
{
auto
var
=
InputVar
(
name
);
auto
var
=
InputVar
(
name
);
...
...
paddle/framework/operator_test.cc
浏览文件 @
d97a2b42
...
@@ -27,12 +27,12 @@ class OpWithoutKernelTest : public OperatorBase {
...
@@ -27,12 +27,12 @@ class OpWithoutKernelTest : public OperatorBase {
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
op_run_num
++
;
++
op_run_num
;
ASSERT_EQ
(
(
int
)
inputs_
.
size
(
),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
inputs_
.
size
()
),
1
);
ASSERT_EQ
(
(
int
)
outputs_
.
size
(
),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
outputs_
.
size
()
),
1
);
ASSERT_EQ
(
scope
.
FindVar
(
inputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
scope
.
FindVar
(
inputs_
.
at
(
"input"
)
[
0
]),
nullptr
);
ASSERT_EQ
(
x
,
1
);
ASSERT_EQ
(
x
,
1
);
ASSERT_NE
(
scope
.
FindVar
(
outputs_
[
0
]),
nullptr
);
ASSERT_NE
(
scope
.
FindVar
(
outputs_
.
at
(
"output"
)
[
0
]),
nullptr
);
}
}
public:
public:
...
@@ -60,8 +60,13 @@ REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest,
...
@@ -60,8 +60,13 @@ REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest,
TEST
(
OperatorBase
,
all
)
{
TEST
(
OperatorBase
,
all
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"test_operator"
);
op_desc
.
set_type
(
"test_operator"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"IN1"
;
auto
*
ipt
=
op_desc
.
mutable_inputs
()
->
Add
();
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"OUT1"
;
*
ipt
->
mutable_var_names
()
->
Add
()
=
"IN1"
;
ipt
->
set_op_proto_name
(
"input"
);
auto
*
output
=
op_desc
.
mutable_outputs
()
->
Add
();
*
output
->
mutable_var_names
()
->
Add
()
=
"OUT1"
;
output
->
set_op_proto_name
(
"output"
);
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
...
@@ -113,24 +118,6 @@ class CPUKernelTest : public OpKernel {
...
@@ -113,24 +118,6 @@ class CPUKernelTest : public OpKernel {
}
}
};
};
// multiple inputs test
class
OperatorMultiInputsTest
:
public
OperatorBase
{
public:
void
Init
()
override
{
x
=
1
;
}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
ASSERT_EQ
(
scope
.
FindVar
(
inputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
x
,
1
);
ASSERT_NE
(
scope
.
FindVar
(
outputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
Input
(
"x"
),
"IN1"
);
ASSERT_EQ
(
Input
(
"y"
),
"OUT1"
);
}
public:
float
x
=
0
;
};
class
OpKernelTestMultiInputsProtoAndCheckerMaker
class
OpKernelTestMultiInputsProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
:
public
OpProtoAndCheckerMaker
{
public:
public:
...
@@ -196,8 +183,14 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
...
@@ -196,8 +183,14 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
TEST
(
OpKernel
,
all
)
{
TEST
(
OpKernel
,
all
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_with_kernel"
);
op_desc
.
set_type
(
"op_with_kernel"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"IN1"
;
auto
*
ipt
=
op_desc
.
mutable_inputs
()
->
Add
();
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"OUT1"
;
*
ipt
->
mutable_var_names
()
->
Add
()
=
"IN1"
;
ipt
->
set_op_proto_name
(
"input"
);
auto
*
output
=
op_desc
.
mutable_outputs
()
->
Add
();
*
output
->
mutable_var_names
()
->
Add
()
=
"OUT1"
;
output
->
set_op_proto_name
(
"output"
);
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
...
@@ -223,12 +216,19 @@ TEST(OpKernel, multi_inputs) {
...
@@ -223,12 +216,19 @@ TEST(OpKernel, multi_inputs) {
OpDesc
op_desc
;
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_multi_inputs_with_kernel"
);
op_desc
.
set_type
(
"op_multi_inputs_with_kernel"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x0"
;
auto
x
=
op_desc
.
mutable_inputs
()
->
Add
();
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x1"
;
x
->
set_op_proto_name
(
"xs"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x2"
;
*
x
->
mutable_var_names
()
->
Add
()
=
"x0"
;
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"k0"
;
*
x
->
mutable_var_names
()
->
Add
()
=
"x1"
;
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"y0"
;
*
x
->
mutable_var_names
()
->
Add
()
=
"x2"
;
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"y1"
;
auto
k
=
op_desc
.
mutable_inputs
()
->
Add
();
k
->
set_op_proto_name
(
"k"
);
*
k
->
mutable_var_names
()
->
Add
()
=
"k0"
;
auto
y
=
op_desc
.
mutable_outputs
()
->
Add
();
y
->
set_op_proto_name
(
"ys"
);
*
y
->
mutable_var_names
()
->
Add
()
=
"y0"
;
*
y
->
mutable_var_names
()
->
Add
()
=
"y1"
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
...
...
paddle/framework/pybind.cc
浏览文件 @
d97a2b42
...
@@ -53,9 +53,10 @@ void ExposeOperator(ClassType &m) {
...
@@ -53,9 +53,10 @@ void ExposeOperator(ClassType &m) {
return
op
.
type_
;
return
op
.
type_
;
})
})
.
def
(
"outputs"
,
.
def
(
"outputs"
,
[](
const
typename
ClassType
::
type
&
op
)
->
std
::
vector
<
std
::
string
>
{
[](
const
typename
ClassType
::
type
&
op
)
return
op
.
outputs_
;
->
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
{
})
return
op
.
outputs_
;
})
.
def
(
"__str__"
,
&
ClassType
::
type
::
DebugString
);
.
def
(
"__str__"
,
&
ClassType
::
type
::
DebugString
);
}
}
...
...
paddle/operators/add_op.cc
浏览文件 @
d97a2b42
...
@@ -20,15 +20,10 @@ namespace operators {
...
@@ -20,15 +20,10 @@ namespace operators {
class
AddOp
:
public
OperatorWithKernel
{
class
AddOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
.
InputSize
(),
2
);
PADDLE_ENFORCE_EQ
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
(),
PADDLE_ENFORCE_EQ
(
ctx
.
OutputSize
(),
1
);
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
dims
(),
PADDLE_ENFORCE
(
ctx
.
InputVar
(
0
)
!=
nullptr
&&
ctx
.
InputVar
(
1
)
!=
nullptr
,
"Two input of Add Op's dimension must be same."
);
"Inputs of AddOp must all be set"
);
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
PADDLE_ENFORCE
(
ctx
.
OutputVar
(
0
)
!=
nullptr
,
"Outputs of AddOp must all be set"
);
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
()
==
ctx
.
Input
<
Tensor
>
(
1
)
->
dims
(),
"Two input of Add Op's dimension must be same."
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
}
}
};
};
...
...
paddle/operators/add_op.h
浏览文件 @
d97a2b42
...
@@ -22,9 +22,9 @@ template <typename Place, typename T>
...
@@ -22,9 +22,9 @@ template <typename Place, typename T>
class
AddKernel
:
public
OpKernel
{
class
AddKernel
:
public
OpKernel
{
public:
public:
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
ExecutionContext
&
context
)
const
override
{
auto
input0
=
context
.
Input
<
Tensor
>
(
0
);
auto
*
input0
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
input1
=
context
.
Input
<
Tensor
>
(
1
);
auto
*
input1
=
context
.
Input
<
Tensor
>
(
"Y"
);
auto
output
=
context
.
Output
<
Tensor
>
(
0
);
auto
*
output
=
context
.
Output
<
Tensor
>
(
"Out"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
d97a2b42
...
@@ -20,19 +20,13 @@ namespace operators {
...
@@ -20,19 +20,13 @@ namespace operators {
class
OnehotCrossEntropyOp
:
public
OperatorWithKernel
{
class
OnehotCrossEntropyOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
"Input size of OnehotCrossEntropyOp must be two"
);
auto
*
label
=
ctx
.
Input
<
Tensor
>
(
"label"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Output size of OnehotCrossEntropyOp must be one"
);
PADDLE_ENFORCE_EQ
(
X
->
dims
().
size
(),
2
,
"X's dimension must be 2."
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
0
)
!=
nullptr
&&
ctx
.
InputVar
(
1
)
!=
nullptr
,
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
1
,
"label's dimension must be 1."
);
"Inputs of OnehotCrossEntropyOp must all be set"
);
PADDLE_ENFORCE_EQ
(
X
->
dims
()[
0
],
label
->
dims
()[
0
]);
PADDLE_ENFORCE
(
ctx
.
OutputVar
(
0
)
!=
nullptr
,
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
({
X
->
dims
()[
0
]});
"Outputs of OnehotCrossEntropyOp must all be set"
);
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
().
size
()
==
2
,
"X's dimension must be 2."
);
PADDLE_ENFORCE
(
ctx
.
Output
<
Tensor
>
(
0
)
->
dims
().
size
()
==
1
,
"label's dimension must be 1."
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
({
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
()[
0
]});
}
}
};
};
...
...
paddle/operators/cross_entropy_op.h
浏览文件 @
d97a2b42
...
@@ -43,7 +43,7 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
...
@@ -43,7 +43,7 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
void
Compute
(
const
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
ExecutionContext
&
ctx
)
const
override
{
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
T
*
Xdata
=
X
->
data
<
T
>
();
const
T
*
Xdata
=
X
->
data
<
T
>
();
const
int
*
label_data
=
ctx
.
Input
<
Tensor
>
(
1
)
->
data
<
int
>
();
const
int
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"label"
)
->
data
<
int
>
();
auto
Y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
Y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
Y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
Y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
paddle/operators/fc_op.cc
浏览文件 @
d97a2b42
...
@@ -22,19 +22,19 @@ class FullyConnectedOp : public NetOp {
...
@@ -22,19 +22,19 @@ class FullyConnectedOp : public NetOp {
void
Init
()
override
{
void
Init
()
override
{
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
{
Input
(
"X"
),
Input
(
"W"
)
,
{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}
,
},
},
{
Output
(
"before_act"
)
},
{}));
{
{
"Out"
,
{
Output
(
"before_act"
)}}
},
{}));
auto
b
=
Input
(
"b"
);
auto
b
=
Input
(
"b"
);
if
(
b
!=
framework
::
kEmptyVarName
)
{
if
(
b
!=
framework
::
kEmptyVarName
)
{
AddOp
(
OpRegistry
::
CreateOp
(
"rowwise_add"
,
AddOp
(
OpRegistry
::
CreateOp
(
{
Output
(
"before_act"
),
Input
(
"b"
)
},
"rowwise_add"
,
{{
"X"
,
{
Output
(
"before_act"
)}},
{
"b"
,
{
Input
(
"b"
)}}
},
{
Output
(
"before_act"
)
},
{}));
{{
"Out"
,
{
Output
(
"before_act"
)}}
},
{}));
}
}
auto
activation
=
GetAttr
<
std
::
string
>
(
"activation"
);
auto
activation
=
GetAttr
<
std
::
string
>
(
"activation"
);
AddOp
(
OpRegistry
::
CreateOp
(
activation
,
{
Output
(
"before_act"
)
},
AddOp
(
OpRegistry
::
CreateOp
(
activation
,
{
{
"X"
,
{
Output
(
"before_act"
)}}
},
{
Output
(
"Y"
)
},
{}));
{
{
"Out"
,
{
Output
(
"Out"
)}}
},
{}));
CompleteAddOp
(
false
);
CompleteAddOp
(
false
);
}
}
};
};
...
@@ -47,7 +47,7 @@ class FullyConnectedOpMaker : public OpProtoAndCheckerMaker {
...
@@ -47,7 +47,7 @@ class FullyConnectedOpMaker : public OpProtoAndCheckerMaker {
AddInput
(
"W"
,
"the weight of fc operator"
);
AddInput
(
"W"
,
"the weight of fc operator"
);
AddInput
(
"b"
,
"the bias of fc operator"
);
AddInput
(
"b"
,
"the bias of fc operator"
);
AddOutput
(
"
Y
"
,
"the output of fc operator"
);
AddOutput
(
"
Out
"
,
"the output of fc operator"
);
AddOutput
(
"before_act"
,
"the before activation output of fc operator"
)
AddOutput
(
"before_act"
,
"the before activation output of fc operator"
)
.
SetTemporary
();
.
SetTemporary
();
AddAttr
<
std
::
string
>
(
"activation"
,
"The activation key for fc layer"
)
AddAttr
<
std
::
string
>
(
"activation"
,
"The activation key for fc layer"
)
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
d97a2b42
...
@@ -20,16 +20,8 @@ namespace operators {
...
@@ -20,16 +20,8 @@ namespace operators {
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1UL
,
ctx
.
Output
<
framework
::
Tensor
>
(
"Dst"
)
->
Resize
(
"Input size of FillZerosLikeOp must be one."
);
ctx
.
Input
<
framework
::
Tensor
>
(
"Src"
)
->
dims
());
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1UL
,
"Output size of AddOp must be one."
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
0
)
!=
nullptr
,
"Input of FillZerosLikeOp must be set."
);
PADDLE_ENFORCE
(
ctx
.
OutputVar
(
0
)
!=
nullptr
,
"Output of FillZerosLikeOp must be set."
);
ctx
.
Output
<
framework
::
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
framework
::
Tensor
>
(
0
)
->
dims
());
}
}
};
};
...
...
paddle/operators/mean_op.cc
浏览文件 @
d97a2b42
...
@@ -20,11 +20,9 @@ namespace operators {
...
@@ -20,11 +20,9 @@ namespace operators {
class
MeanOp
:
public
OperatorWithKernel
{
class
MeanOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1
,
"Input size of AddOp must be one"
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
"X"
)
!=
nullptr
,
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Output size of AddOp must be one"
);
"Input of MeanOp must be initialized."
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
0
)
!=
nullptr
&&
ctx
.
OutputVar
(
0
)
!=
nullptr
,
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
({
1
});
"Input/Output of MeanOp must be initialized."
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
framework
::
make_ddim
({
1
}));
}
}
};
};
...
...
paddle/operators/mul_op.cc
浏览文件 @
d97a2b42
...
@@ -20,9 +20,8 @@ namespace operators {
...
@@ -20,9 +20,8 @@ namespace operators {
class
MulOp
:
public
OperatorWithKernel
{
class
MulOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
"The mul op must take two inputs"
);
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
();
auto
dim1
=
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
dims
();
auto
dim1
=
ctx
.
Input
<
Tensor
>
(
1
)
->
dims
();
PADDLE_ENFORCE_EQ
(
dim0
.
size
(),
2
,
PADDLE_ENFORCE_EQ
(
dim0
.
size
(),
2
,
"input X(%s) should be a tensor with 2 dims, a matrix"
,
"input X(%s) should be a tensor with 2 dims, a matrix"
,
ctx
.
op_
.
Input
(
"X"
));
ctx
.
op_
.
Input
(
"X"
));
...
@@ -32,8 +31,7 @@ class MulOp : public OperatorWithKernel {
...
@@ -32,8 +31,7 @@ class MulOp : public OperatorWithKernel {
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
dim0
[
1
],
dim1
[
0
],
dim0
[
1
],
dim1
[
0
],
"First matrix's width must be equal with second matrix's height."
);
"First matrix's width must be equal with second matrix's height."
);
PADDLE_ENFORCE_EQ
(
ctx
.
OutputSize
(),
1
,
"The mul op takes only one output"
);
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
({
dim0
[
0
],
dim1
[
1
]});
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
({
dim0
[
0
],
dim1
[
1
]});
}
}
};
};
...
...
paddle/operators/net_op.cc
浏览文件 @
d97a2b42
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
*/
*/
#include "paddle/operators/net_op.h"
#include "paddle/operators/net_op.h"
#include <set>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -23,36 +24,39 @@ namespace operators {
...
@@ -23,36 +24,39 @@ namespace operators {
void
NetOp
::
CompleteAddOp
(
bool
calc
)
{
void
NetOp
::
CompleteAddOp
(
bool
calc
)
{
add_op_done_
=
true
;
add_op_done_
=
true
;
if
(
!
calc
)
return
;
if
(
!
calc
)
return
;
std
::
unordered_
set
<
std
::
string
>
input_set
;
std
::
set
<
std
::
string
>
input_set
;
std
::
unordered_
set
<
std
::
string
>
output_set
;
std
::
set
<
std
::
string
>
output_set
;
std
::
unordered_
set
<
std
::
string
>
temp_output
;
std
::
set
<
std
::
string
>
temp_output
;
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
ipt
:
op
->
inputs_
)
{
for
(
auto
&
ipt
:
op
->
inputs_
)
{
if
(
!
Contains
(
output_set
,
ipt
))
{
// Not other op's output
for
(
auto
&
var_name
:
ipt
.
second
)
{
input_set
.
insert
(
ipt
);
if
(
!
Contains
(
output_set
,
var_name
))
{
// Not other op's output
}
else
{
input_set
.
insert
(
var_name
);
temp_output
.
insert
(
ipt
);
}
else
{
temp_output
.
insert
(
var_name
);
}
}
}
}
}
for
(
auto
&
opt
:
op
->
outputs_
)
{
for
(
auto
&
opt
:
op
->
outputs_
)
{
output_set
.
insert
(
opt
);
for
(
auto
&
var_name
:
opt
.
second
)
{
output_set
.
insert
(
var_name
);
}
}
}
}
}
auto
&
inputs
=
inputs_
[
"all"
];
inputs
.
reserve
(
input_set
.
size
());
std
::
copy
(
input_set
.
begin
(),
input_set
.
end
(),
std
::
back_inserter
(
inputs
));
auto
&
outputs
=
outputs_
[
"all"
];
outputs
.
reserve
(
output_set
.
size
());
std
::
copy
(
output_set
.
begin
(),
output_set
.
end
(),
std
::
back_inserter
(
outputs
));
inputs_
.
reserve
(
input_set
.
size
());
//! TODO figure out how to generate temporary_index in Network.
std
::
copy
(
input_set
.
begin
(),
input_set
.
end
(),
std
::
back_inserter
(
inputs_
));
std
::
sort
(
inputs_
.
begin
(),
inputs_
.
end
());
outputs_
.
reserve
(
output_set
.
size
());
std
::
copy
(
output_set
.
begin
(),
output_set
.
end
(),
std
::
back_inserter
(
outputs_
));
std
::
sort
(
outputs_
.
begin
(),
outputs_
.
end
());
std
::
vector
<
int
>
tmp_index
;
std
::
vector
<
int
>
tmp_index
;
tmp_index
.
reserve
(
temp_output
.
size
());
tmp_index
.
reserve
(
temp_output
.
size
());
int
output_len
=
static_cast
<
int
>
(
outputs
_
.
size
());
int
output_len
=
static_cast
<
int
>
(
outputs
.
size
());
for
(
int
i
=
0
;
i
<
output_len
;
++
i
)
{
for
(
int
i
=
0
;
i
<
output_len
;
++
i
)
{
if
(
Contains
(
temp_output
,
outputs
_
[
i
]))
{
if
(
Contains
(
temp_output
,
outputs
[
i
]))
{
tmp_index
.
push_back
(
i
);
tmp_index
.
push_back
(
i
);
}
}
}
}
...
...
paddle/operators/net_op.h
浏览文件 @
d97a2b42
...
@@ -14,8 +14,7 @@ limitations under the License. */
...
@@ -14,8 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/scope.h"
...
...
paddle/operators/net_op_test.cc
浏览文件 @
d97a2b42
...
@@ -47,23 +47,24 @@ TEST(OpKernel, all) {
...
@@ -47,23 +47,24 @@ TEST(OpKernel, all) {
ASSERT_NE
(
net
,
nullptr
);
ASSERT_NE
(
net
,
nullptr
);
auto
op1
=
std
::
make_shared
<
TestOp
>
();
auto
op1
=
std
::
make_shared
<
TestOp
>
();
op1
->
inputs_
=
{
"x"
,
"w1"
,
"b1"
};
op1
->
inputs_
=
{
{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}
};
op1
->
outputs_
=
{
"y"
};
op1
->
outputs_
=
{
{
"Out"
,
{
"y"
}}
};
net
->
AddOp
(
op1
);
net
->
AddOp
(
op1
);
auto
op2
=
std
::
make_shared
<
TestOp
>
();
auto
op2
=
std
::
make_shared
<
TestOp
>
();
op2
->
inputs_
=
{
"y"
,
"w2"
,
"b2"
};
op2
->
inputs_
=
{
{
"X"
,
{
"y"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}
};
op2
->
outputs_
=
{
"z"
};
op2
->
outputs_
=
{
{
"Out"
,
{
"z"
}}
};
net
->
AddOp
(
op2
);
net
->
AddOp
(
op2
);
net
->
CompleteAddOp
();
net
->
CompleteAddOp
();
AssertSameVectorWithoutOrder
({
"x"
,
"w1"
,
"b1"
,
"w2"
,
"b2"
},
net
->
inputs_
);
AssertSameVectorWithoutOrder
({
"x"
,
"w1"
,
"b1"
,
"w2"
,
"b2"
},
AssertSameVectorWithoutOrder
({
"y"
,
"z"
},
net
->
outputs_
);
net
->
inputs_
.
at
(
"__all__"
));
AssertSameVectorWithoutOrder
({
"y"
,
"z"
},
net
->
outputs_
.
at
(
"__all__"
));
auto
tmp_idx_iter
=
net
->
attrs_
.
find
(
"temporary_index"
);
auto
tmp_idx_iter
=
net
->
attrs_
.
find
(
"temporary_index"
);
ASSERT_NE
(
net
->
attrs_
.
end
(),
tmp_idx_iter
);
ASSERT_NE
(
net
->
attrs_
.
end
(),
tmp_idx_iter
);
auto
&
tmp_idx
=
boost
::
get
<
std
::
vector
<
int
>>
(
tmp_idx_iter
->
second
);
auto
&
tmp_idx
=
boost
::
get
<
std
::
vector
<
int
>>
(
tmp_idx_iter
->
second
);
ASSERT_EQ
(
1UL
,
tmp_idx
.
size
());
ASSERT_EQ
(
1UL
,
tmp_idx
.
size
());
ASSERT_EQ
(
"y"
,
net
->
outputs_
[
tmp_idx
[
0
]]);
ASSERT_EQ
(
"y"
,
net
->
outputs_
.
at
(
"__all__"
)
[
tmp_idx
[
0
]]);
Scope
scope
;
Scope
scope
;
platform
::
CPUDeviceContext
dev_ctx
;
platform
::
CPUDeviceContext
dev_ctx
;
...
@@ -78,8 +79,8 @@ TEST(OpKernel, all) {
...
@@ -78,8 +79,8 @@ TEST(OpKernel, all) {
TEST
(
NetOp
,
insert_op
)
{
TEST
(
NetOp
,
insert_op
)
{
NetOp
net
;
NetOp
net
;
auto
op1
=
std
::
make_shared
<
EmptyOp
>
();
auto
op1
=
std
::
make_shared
<
EmptyOp
>
();
op1
->
inputs_
=
{
"x"
,
"w1"
,
"b1"
};
op1
->
inputs_
=
{
{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}
};
op1
->
outputs_
=
{
"y"
};
op1
->
outputs_
=
{
{
"Out"
,
{
"y"
}}
};
net
.
AddOp
(
op1
);
net
.
AddOp
(
op1
);
net
.
InsertOp
(
0
,
op1
);
net
.
InsertOp
(
0
,
op1
);
ASSERT_EQ
(
2UL
,
net
.
ops_
.
size
());
ASSERT_EQ
(
2UL
,
net
.
ops_
.
size
());
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
d97a2b42
...
@@ -89,12 +89,17 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
...
@@ -89,12 +89,17 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
// create step net's temp inputs
// create step net's temp inputs
for
(
auto
&
input
:
net_op
->
inputs_
)
{
for
(
auto
&
input
:
net_op
->
inputs_
)
{
// the weight are located in parent scope
// the weight are located in parent scope
if
(
!
step_scope
.
FindVar
(
input
))
for
(
auto
&
var_name
:
input
.
second
)
{
step_scope
.
NewVar
(
input
)
->
GetMutable
<
Tensor
>
();
if
(
!
step_scope
.
FindVar
(
var_name
))
{
step_scope
.
NewVar
(
var_name
)
->
GetMutable
<
Tensor
>
();
}
}
}
}
// create stepnet's outputs
// create stepnet's outputs
for
(
const
auto
&
output
:
net_op
->
outputs_
)
{
for
(
const
auto
&
output
:
net_op
->
outputs_
)
{
step_scope
.
NewVar
(
output
);
for
(
auto
&
var_name
:
output
.
second
)
{
step_scope
.
NewVar
(
var_name
);
}
}
}
step_scopes
->
emplace_back
(
&
step_scope
);
step_scopes
->
emplace_back
(
&
step_scope
);
}
}
...
...
paddle/operators/recurrent_op_test.cc
浏览文件 @
d97a2b42
...
@@ -22,373 +22,382 @@
...
@@ -22,373 +22,382 @@
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/net_op.h"
namespace
paddle
{
TEST
(
rnn
,
bad
)
{
ASSERT_TRUE
(
false
);
}
namespace
operators
{
// namespace paddle {
using
framework
::
make_ddim
;
// namespace operators {
using
framework
::
DDim
;
//
// using framework::make_ddim;
class
RecurrentOpTest
:
public
::
testing
::
Test
{
// using framework::DDim;
protected:
//
virtual
void
SetUp
()
override
{
// class RecurrentOpTest : public ::testing::Test {
CreateGlobalVariables
();
// protected:
CreateStepNet
();
// virtual void SetUp() override {
CreateRNNOp
();
// CreateGlobalVariables();
}
// CreateStepNet();
// CreateRNNOp();
virtual
void
TearDown
()
override
{}
// }
//
void
CreateGlobalVariables
()
{
// virtual void TearDown() override {}
// create input, and init content
//
LOG
(
INFO
)
<<
"create global variable x"
;
// void CreateGlobalVariables() {
for
(
auto
inlink
:
std
::
vector
<
std
::
string
>
{
"x"
,
"x0"
,
"x1"
,
"h"
})
{
// // create input, and init content
Variable
*
x
=
scope_
.
NewVar
(
inlink
);
// LOG(INFO) << "create global variable x";
DDim
dims
=
make_ddim
(
std
::
vector
<
int
>
{
// for (auto inlink : std::vector<std::string>{"x", "x0", "x1", "h"}) {
10
/*sent size*/
,
20
/*batch size*/
,
30
/*input dim*/
});
// Variable* x = scope_.NewVar(inlink);
x
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
dims
,
platform
::
CPUPlace
());
// DDim dims = make_ddim(std::vector<int>{
}
// 10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/});
// create output alias just for test
// x->GetMutable<Tensor>()->mutable_data<float>(dims,
for
(
auto
inlink
:
std
::
vector
<
std
::
string
>
{
"h@alias"
})
{
// platform::CPUPlace());
Variable
*
x
=
scope_
.
NewVar
(
inlink
);
// }
DDim
dims
=
// // create output alias just for test
make_ddim
(
std
::
vector
<
int
>
{
20
/*batch size*/
,
30
/*input dim*/
});
// for (auto inlink : std::vector<std::string>{"h@alias"}) {
x
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
dims
,
platform
::
CPUPlace
());
// Variable* x = scope_.NewVar(inlink);
}
// DDim dims =
// make_ddim(std::vector<int>{20 /*batch size*/, 30 /*input dim*/});
LOG
(
INFO
)
<<
"create global variable w"
;
// x->GetMutable<Tensor>()->mutable_data<float>(dims,
Variable
*
w
=
scope_
.
NewVar
(
"rnn/w"
);
// platform::CPUPlace());
w
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
// }
make_ddim
(
std
::
vector
<
int
>
{
30
,
30
}),
platform
::
CPUPlace
());
//
// LOG(INFO) << "create global variable w";
for
(
auto
boot
:
std
::
vector
<
std
::
string
>
{
"h_boot"
})
{
// Variable* w = scope_.NewVar("rnn/w");
LOG
(
INFO
)
<<
"create global variable "
<<
boot
;
// w->GetMutable<Tensor>()->mutable_data<float>(
Variable
*
h_boot
=
scope_
.
NewVar
(
boot
);
// make_ddim(std::vector<int>{30, 30}), platform::CPUPlace());
h_boot
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
//
make_ddim
(
std
::
vector
<
int
>
{
20
/*batch size*/
,
30
/*input dim*/
}),
// for (auto boot : std::vector<std::string>{"h_boot"}) {
platform
::
CPUPlace
());
// LOG(INFO) << "create global variable " << boot;
}
// Variable* h_boot = scope_.NewVar(boot);
// h_boot->GetMutable<Tensor>()->mutable_data<float>(
LOG
(
INFO
)
<<
"create variable step_scopes"
;
// make_ddim(std::vector<int>{20 /*batch size*/, 30 /*input dim*/}),
scope_
.
NewVar
(
"step_scopes"
);
// platform::CPUPlace());
// }
LOG
(
INFO
)
<<
"create variable h"
;
//
scope_
.
NewVar
(
"h"
);
// LOG(INFO) << "create variable step_scopes";
}
// scope_.NewVar("step_scopes");
//
void
CreateRNNOp
()
{
// LOG(INFO) << "create variable h";
framework
::
OpDesc
op_desc
;
// scope_.NewVar("h");
// }
op_desc
.
set_type
(
"recurrent_op"
);
//
// inlinks 0
// void CreateRNNOp() {
op_desc
.
add_inputs
(
"x"
);
// framework::OpDesc op_desc;
op_desc
.
add_inputs
(
"x0"
);
//
op_desc
.
add_inputs
(
"x1"
);
// op_desc.set_type("recurrent_op");
// boot_memories 3
// // inlinks 0
op_desc
.
add_inputs
(
"h_boot"
);
// op_desc.add_inputs("x");
// step net 5
// op_desc.add_inputs("x0");
op_desc
.
add_inputs
(
"step_net"
);
// op_desc.add_inputs("x1");
// outlinks 6
// // boot_memories 3
op_desc
.
add_outputs
(
"h"
);
// op_desc.add_inputs("h_boot");
// step scopes 7
// // step net 5
op_desc
.
add_outputs
(
"step_scopes"
);
// op_desc.add_inputs("step_net");
// // outlinks 6
auto
_input_format
=
std
::
vector
<
int
>
{
// op_desc.add_outputs("h");
0
,
// in_link
// // step scopes 7
3
,
// memories
// op_desc.add_outputs("step_scopes");
4
// step_net
//
};
// auto _input_format = std::vector<int>{
auto
input_format
=
op_desc
.
add_attrs
();
// 0, // in_link
input_format
->
set_name
(
"input_format"
);
// 3, // memories
input_format
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
// 4 // step_net
for
(
auto
i
:
_input_format
)
{
// };
input_format
->
add_ints
(
i
);
// auto input_format = op_desc.add_attrs();
}
// input_format->set_name("input_format");
// input_format->set_type(paddle::framework::AttrType::INTS);
auto
output_format
=
op_desc
.
add_attrs
();
// for (auto i : _input_format) {
output_format
->
set_name
(
"output_format"
);
// input_format->add_ints(i);
output_format
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
// }
for
(
auto
i
:
std
::
vector
<
int
>
{
0
,
1
,
2
})
{
//
output_format
->
add_ints
(
i
);
// auto output_format = op_desc.add_attrs();
}
// output_format->set_name("output_format");
// output_format->set_type(paddle::framework::AttrType::INTS);
auto
inlink_alias
=
op_desc
.
add_attrs
();
// for (auto i : std::vector<int>{0, 1, 2}) {
inlink_alias
->
set_name
(
"inlink_alias"
);
// output_format->add_ints(i);
inlink_alias
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
// }
//
auto
outlink_alias
=
op_desc
.
add_attrs
();
// auto inlink_alias = op_desc.add_attrs();
outlink_alias
->
set_name
(
"outlink_alias"
);
// inlink_alias->set_name("inlink_alias");
outlink_alias
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
// inlink_alias->set_type(paddle::framework::AttrType::STRINGS);
//
auto
pre_memories
=
op_desc
.
add_attrs
();
// auto outlink_alias = op_desc.add_attrs();
pre_memories
->
set_name
(
"pre_memories"
);
// outlink_alias->set_name("outlink_alias");
pre_memories
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
// outlink_alias->set_type(paddle::framework::AttrType::STRINGS);
//
auto
memories
=
op_desc
.
add_attrs
();
// auto pre_memories = op_desc.add_attrs();
memories
->
set_name
(
"memories"
);
// pre_memories->set_name("pre_memories");
memories
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
// pre_memories->set_type(paddle::framework::AttrType::STRINGS);
//
// create inlink_alias
// auto memories = op_desc.add_attrs();
for
(
const
auto
&
item
:
// memories->set_name("memories");
std
::
vector
<
std
::
string
>
{
"x@alias"
,
"x0@alias"
,
"x1@alias"
})
{
// memories->set_type(paddle::framework::AttrType::STRINGS);
inlink_alias
->
add_strings
(
item
);
//
}
// // create inlink_alias
// pre memories
// for (const auto& item :
for
(
const
auto
&
item
:
std
::
vector
<
std
::
string
>
{
"rnn/h@pre"
})
{
// std::vector<std::string>{"x@alias", "x0@alias", "x1@alias"}) {
pre_memories
->
add_strings
(
item
);
// inlink_alias->add_strings(item);
}
// }
// memories
// // pre memories
for
(
const
auto
&
item
:
std
::
vector
<
std
::
string
>
{
"rnn/h"
})
{
// for (const auto& item : std::vector<std::string>{"rnn/h@pre"}) {
memories
->
add_strings
(
item
);
// pre_memories->add_strings(item);
}
// }
// output alias
// // memories
for
(
const
auto
&
item
:
std
::
vector
<
std
::
string
>
{
"h@alias"
})
{
// for (const auto& item : std::vector<std::string>{"rnn/h"}) {
outlink_alias
->
add_strings
(
item
);
// memories->add_strings(item);
}
// }
// // output alias
rnn_op_
=
OpRegistry
::
CreateOp
(
op_desc
);
// for (const auto& item : std::vector<std::string>{"h@alias"}) {
// outlink_alias->add_strings(item);
LOG
(
INFO
)
<<
"rnn_op finish init"
;
// }
}
//
// rnn_op_ = OpRegistry::CreateOp(op_desc);
void
CreateStepNet
()
{
//
LOG
(
INFO
)
<<
"create variable step_net"
;
// LOG(INFO) << "rnn_op finish init";
Variable
*
var
=
scope_
.
NewVar
(
"step_net"
);
// }
auto
net
=
var
->
GetMutable
<
NetOp
>
();
//
net
->
AddOp
(
// void CreateStepNet() {
OpRegistry
::
CreateOp
(
"mul"
,
{
"rnn/h@pre"
,
"rnn/w"
},
{
"rnn/s"
},
{}));
// LOG(INFO) << "create variable step_net";
// Variable* var = scope_.NewVar("step_net");
net
->
AddOp
(
// auto net = var->GetMutable<NetOp>();
OpRegistry
::
CreateOp
(
"add_two"
,
{
"x@alias"
,
"rnn/s"
},
{
"rnn/h"
},
{}));
// net->AddOp(
net
->
CompleteAddOp
();
// OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {}));
}
//
// net->AddOp(
// father scope
// OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {}));
Scope
scope_
;
// net->CompleteAddOp();
std
::
shared_ptr
<
OperatorBase
>
rnn_op_
;
// }
};
//
// // father scope
TEST_F
(
RecurrentOpTest
,
Run
)
{
// Scope scope_;
platform
::
CPUDeviceContext
ctx
;
// std::shared_ptr<OperatorBase> rnn_op_;
rnn_op_
->
InferShape
(
scope_
);
//};
rnn_op_
->
Run
(
scope_
,
ctx
);
//
}
// TEST_F(RecurrentOpTest, Run) {
// platform::CPUDeviceContext ctx;
class
RecurrentGradientAlgorithmTest
:
public
::
testing
::
Test
{
// rnn_op_->InferShape(scope_);
protected:
// rnn_op_->Run(scope_, ctx);
virtual
void
SetUp
()
override
{
//}
CreateGlobalVariables
();
//
CreateStepScopes
();
// class RecurrentGradientAlgorithmTest : public ::testing::Test {
CreateStepNet
();
// protected:
CreateRNNGradientAlgorithm
();
// virtual void SetUp() override {
// CreateGlobalVariables();
// segment inputs
// CreateStepScopes();
SegmentInputs
();
// CreateStepNet();
// link forward memories
// CreateRNNGradientAlgorithm();
LinkeMemories
();
//
}
// // segment inputs
// SegmentInputs();
virtual
void
TearDown
()
override
{}
// // link forward memories
// LinkeMemories();
void
CreateGlobalVariables
()
{
// }
// inputs: x
//
LOG
(
INFO
)
<<
"create global variable x"
;
// virtual void TearDown() override {}
Variable
*
x
=
scope_
.
NewVar
(
"x"
);
//
DDim
dims
=
// void CreateGlobalVariables() {
make_ddim
({
10
/*sent size*/
,
20
/*batch size*/
,
30
/*input dim*/
});
// // inputs: x
x
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
dims
,
platform
::
CPUPlace
());
// LOG(INFO) << "create global variable x";
// inputs: h_boot
// Variable* x = scope_.NewVar("x");
LOG
(
INFO
)
<<
"create global variable h_boot"
;
// DDim dims =
Variable
*
h_boot
=
scope_
.
NewVar
(
"h_boot"
);
// make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/});
h_boot
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
// x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
make_ddim
({
20
/*batch size*/
,
30
/*input dim*/
}),
platform
::
CPUPlace
());
// // inputs: h_boot
// inputs: w
// LOG(INFO) << "create global variable h_boot";
LOG
(
INFO
)
<<
"create global variable w"
;
// Variable* h_boot = scope_.NewVar("h_boot");
Variable
*
w
=
scope_
.
NewVar
(
"rnn/w"
);
// h_boot->GetMutable<Tensor>()->mutable_data<float>(
w
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
make_ddim
({
30
,
30
}),
// make_ddim({20 /*batch size*/, 30 /*input dim*/}),
platform
::
CPUPlace
());
// platform::CPUPlace());
// inputs: h_grad
// // inputs: w
LOG
(
INFO
)
<<
"create variable h_grad"
;
// LOG(INFO) << "create global variable w";
Variable
*
dh
=
scope_
.
NewVar
(
"h_grad"
);
// Variable* w = scope_.NewVar("rnn/w");
dh
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
make_ddim
({
10
,
20
,
30
}),
// w->GetMutable<Tensor>()->mutable_data<float>(make_ddim({30, 30}),
platform
::
CPUPlace
());
// platform::CPUPlace());
// inputs: step_scopes
// // inputs: h_grad
LOG
(
INFO
)
<<
"create variable step_scopes"
;
// LOG(INFO) << "create variable h_grad";
scope_
.
NewVar
(
"step_scopes"
);
// Variable* dh = scope_.NewVar("h_grad");
// inputs: step_net
// dh->GetMutable<Tensor>()->mutable_data<float>(make_ddim({10, 20, 30}),
LOG
(
INFO
)
<<
"create variable step_net"
;
// platform::CPUPlace());
scope_
.
NewVar
(
"step_net"
);
// // inputs: step_scopes
// outputs: w_grad
// LOG(INFO) << "create variable step_scopes";
LOG
(
INFO
)
<<
"create global variable w_grad"
;
// scope_.NewVar("step_scopes");
scope_
.
NewVar
(
"rnn/w_grad"
);
// // inputs: step_net
// outputs: x_grad
// LOG(INFO) << "create variable step_net";
LOG
(
INFO
)
<<
"create global variable x_grad"
;
// scope_.NewVar("step_net");
scope_
.
NewVar
(
"x_grad"
);
// // outputs: w_grad
// outputs: h_boot_grad
// LOG(INFO) << "create global variable w_grad";
LOG
(
INFO
)
<<
"create global variable h_boot_grad"
;
// scope_.NewVar("rnn/w_grad");
scope_
.
NewVar
(
"h_boot_grad"
);
// // outputs: x_grad
}
// LOG(INFO) << "create global variable x_grad";
// scope_.NewVar("x_grad");
void
CreateStepScopes
()
{
// // outputs: h_boot_grad
auto
step_scopes
=
// LOG(INFO) << "create global variable h_boot_grad";
scope_
.
FindVar
(
"step_scopes"
)
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
// scope_.NewVar("h_boot_grad");
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
// }
auto
&
scope
=
scope_
.
NewScope
();
//
auto
pre_t
=
scope
.
NewVar
(
"rnn/pre_h"
)
->
GetMutable
<
Tensor
>
();
// void CreateStepScopes() {
pre_t
->
mutable_data
<
float
>
({
20
,
30
},
platform
::
CPUPlace
());
// auto step_scopes =
auto
tensor
=
scope
.
NewVar
(
"rnn/h"
)
->
GetMutable
<
Tensor
>
();
// scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
tensor
->
mutable_data
<
float
>
({
20
,
30
},
platform
::
CPUPlace
());
// for (int i = 0; i < 10; ++i) {
// auto& scope = scope_.NewScope();
// for unit test of ConcatOutputs
// auto pre_t = scope.NewVar("rnn/pre_h")->GetMutable<Tensor>();
auto
xg
=
scope
.
NewVar
(
"rnn/x_grad"
)
->
GetMutable
<
Tensor
>
();
// pre_t->mutable_data<float>({20, 30}, platform::CPUPlace());
xg
->
mutable_data
<
float
>
({
20
,
30
},
platform
::
CPUPlace
());
// auto tensor = scope.NewVar("rnn/h")->GetMutable<Tensor>();
// tensor->mutable_data<float>({20, 30}, platform::CPUPlace());
step_scopes
->
emplace_back
(
&
scope
);
//
}
// // for unit test of ConcatOutputs
// auto xg = scope.NewVar("rnn/x_grad")->GetMutable<Tensor>();
// last time step
// xg->mutable_data<float>({20, 30}, platform::CPUPlace());
auto
g
=
(
*
step_scopes
)[
9
]
->
NewVar
(
"rnn/h_pre_grad"
)
->
GetMutable
<
Tensor
>
();
//
g
->
mutable_data
<
float
>
({
20
,
30
},
platform
::
CPUPlace
());
// step_scopes->emplace_back(&scope);
}
// }
//
void
CreateRNNGradientAlgorithm
()
{
// // last time step
std
::
unique_ptr
<
rnn
::
Argument
>
arg
(
new
rnn
::
Argument
());
// auto g =
arg
->
step_net
=
"step_net"
;
// (*step_scopes)[9]->NewVar("rnn/h_pre_grad")->GetMutable<Tensor>();
arg
->
step_scopes
=
"step_scopes"
;
// g->mutable_data<float>({20, 30}, platform::CPUPlace());
rnn
::
Link
inlink
;
// }
inlink
.
external
=
"h_grad"
;
//
inlink
.
internal
=
"rnn/h_grad"
;
// void CreateRNNGradientAlgorithm() {
arg
->
inlinks
=
std
::
vector
<
rnn
::
Link
>
{
inlink
};
// std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
// arg->step_net = "step_net";
rnn
::
Link
outlink
;
// arg->step_scopes = "step_scopes";
outlink
.
external
=
"x_grad"
;
// rnn::Link inlink;
outlink
.
internal
=
"rnn/x_grad"
;
// inlink.external = "h_grad";
arg
->
outlinks
=
std
::
vector
<
rnn
::
Link
>
{
outlink
};
// inlink.internal = "rnn/h_grad";
// arg->inlinks = std::vector<rnn::Link>{inlink};
rnn
::
MemoryAttr
mem_attr
;
//
mem_attr
.
pre_var
=
"rnn/h_pre_grad"
;
// rnn::Link outlink;
mem_attr
.
var
=
"rnn/h_grad"
;
// outlink.external = "x_grad";
mem_attr
.
boot_var
=
"h_boot_grad"
;
// outlink.internal = "rnn/x_grad";
arg
->
memories
=
std
::
vector
<
rnn
::
MemoryAttr
>
{
mem_attr
};
// arg->outlinks = std::vector<rnn::Link>{outlink};
//
rnn_grad_algo_
.
Init
(
std
::
move
(
arg
));
// rnn::MemoryAttr mem_attr;
}
// mem_attr.pre_var = "rnn/h_pre_grad";
// mem_attr.var = "rnn/h_grad";
void
CreateStepNet
()
{
// mem_attr.boot_var = "h_boot_grad";
LOG
(
INFO
)
<<
"create variable step_net"
;
// arg->memories = std::vector<rnn::MemoryAttr>{mem_attr};
Variable
*
var
=
scope_
.
NewVar
(
"step_net"
);
//
auto
net
=
var
->
GetMutable
<
NetOp
>
();
// rnn_grad_algo_.Init(std::move(arg));
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
"rnn/h_pre"
,
"rnn/w"
,
"rnn/s_grad"
},
// }
{
"rnn/h_pre_grad"
,
"rnn/w_grad"
},
{}));
//
// void CreateStepNet() {
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"add_two"
,
{
"rnn/h_grad"
},
// LOG(INFO) << "create variable step_net";
{
"rnn/x_grad"
,
"rnn/s_grad"
},
{}));
// Variable* var = scope_.NewVar("step_net");
net
->
CompleteAddOp
();
// auto net = var->GetMutable<NetOp>();
}
// net->AddOp(OpRegistry::CreateOp("mul", {"rnn/h_pre", "rnn/w",
// "rnn/s_grad"},
void
SegmentInputs
()
{
// {"rnn/h_pre_grad", "rnn/w_grad"}, {}));
LOG
(
INFO
)
<<
"segment inputs"
;
//
std
::
vector
<
std
::
string
>
inlinks
=
{
"x"
};
// net->AddOp(OpRegistry::CreateOp("add_two", {"rnn/h_grad"},
std
::
vector
<
std
::
string
>
inlinks_alias
=
{
"rnn/x"
};
// {"rnn/x_grad", "rnn/s_grad"}, {}));
// net->CompleteAddOp();
rnn
::
Link
inlink
;
// }
inlink
.
external
=
"x"
;
//
inlink
.
internal
=
"rnn/x"
;
// void SegmentInputs() {
auto
step_scopes
=
// LOG(INFO) << "segment inputs";
scope_
.
FindVar
(
"step_scopes"
)
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
// std::vector<std::string> inlinks = {"x"};
rnn
::
SegmentInputs
(
*
step_scopes
,
std
::
vector
<
rnn
::
Link
>
{
inlink
},
10
,
// std::vector<std::string> inlinks_alias = {"rnn/x"};
true
/*infer_shape_mode*/
);
//
}
// rnn::Link inlink;
// inlink.external = "x";
void
LinkeMemories
()
{
// inlink.internal = "rnn/x";
LOG
(
INFO
)
<<
"link memories"
;
// auto step_scopes =
rnn
::
MemoryAttr
mem_attr
;
// scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
mem_attr
.
pre_var
=
"rnn/h_pre"
;
// rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10,
mem_attr
.
var
=
"rnn/h"
;
// true /*infer_shape_mode*/);
mem_attr
.
boot_var
=
"boot_h"
;
// }
std
::
vector
<
rnn
::
MemoryAttr
>
memories
;
//
memories
.
push_back
(
mem_attr
);
// void LinkeMemories() {
auto
step_scopes
=
// LOG(INFO) << "link memories";
scope_
.
FindVar
(
"step_scopes"
)
->
GetMutable
<
std
::
vector
<
Scope
*>>
();
// rnn::MemoryAttr mem_attr;
for
(
int
i
=
1
;
i
<
10
;
++
i
)
{
// mem_attr.pre_var = "rnn/h_pre";
rnn
::
LinkMemories
(
*
step_scopes
,
memories
,
i
,
-
1
,
// mem_attr.var = "rnn/h";
true
/*infer_shape_mode*/
);
// mem_attr.boot_var = "boot_h";
}
// std::vector<rnn::MemoryAttr> memories;
}
// memories.push_back(mem_attr);
// auto step_scopes =
Scope
scope_
;
// scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
RecurrentGradientAlgorithm
rnn_grad_algo_
;
// for (int i = 1; i < 10; ++i) {
};
// rnn::LinkMemories(*step_scopes, memories, i, -1,
// true /*infer_shape_mode*/);
// TEST_F(RecurrentGradientAlgorithmTest, Run) {
// }
// platform::CPUDeviceContext ctx;
// }
// rnn_grad_algo_.Run(scope_, ctx);
//
// }
// Scope scope_;
// RecurrentGradientAlgorithm rnn_grad_algo_;
}
// namespace operators
//};
}
// namespace paddle
//
//// TEST_F(RecurrentGradientAlgorithmTest, Run) {
TEST
(
RecurrentOp
,
LinkMemories
)
{
//// platform::CPUDeviceContext ctx;
using
namespace
paddle
::
framework
;
//// rnn_grad_algo_.Run(scope_, ctx);
using
namespace
paddle
::
platform
;
//// }
using
namespace
paddle
::
operators
;
//
//} // namespace operators
// create and init step scopes
//} // namespace paddle
size_t
len
=
10
;
//
std
::
vector
<
Scope
*>
step_scopes
;
// TEST(RecurrentOp, LinkMemories) {
for
(
size_t
i
=
0
;
i
<
len
;
++
i
)
{
// using namespace paddle::framework;
auto
scope
=
new
Scope
();
// using namespace paddle::platform;
scope
->
NewVar
(
"pre_h"
);
// using namespace paddle::operators;
auto
tensor
=
scope
->
NewVar
(
"h"
)
->
GetMutable
<
Tensor
>
();
//
float
*
data
=
tensor
->
mutable_data
<
float
>
({
15
,
20
},
CPUPlace
());
// // create and init step scopes
for
(
size_t
j
=
0
;
j
<
15
*
20
;
++
j
)
{
// size_t len = 10;
data
[
j
]
=
rand
()
*
(
1.
/
(
double
)
RAND_MAX
);
// std::vector<Scope*> step_scopes;
}
// for (size_t i = 0; i < len; ++i) {
step_scopes
.
push_back
(
scope
);
// auto scope = new Scope();
}
// scope->NewVar("pre_h");
// auto tensor = scope->NewVar("h")->GetMutable<Tensor>();
// create MemoryAttr
// float* data = tensor->mutable_data<float>({15, 20}, CPUPlace());
rnn
::
MemoryAttr
mem_attr
;
// for (size_t j = 0; j < 15 * 20; ++j) {
mem_attr
.
pre_var
=
"pre_h"
;
// data[j] = rand() * (1. / (double)RAND_MAX);
mem_attr
.
var
=
"h"
;
// }
mem_attr
.
boot_var
=
"boot_h"
;
// step_scopes.push_back(scope);
std
::
vector
<
rnn
::
MemoryAttr
>
memories
;
// }
memories
.
push_back
(
mem_attr
);
//
// // create MemoryAttr
for
(
size_t
i
=
1
;
i
<
len
;
++
i
)
{
// rnn::MemoryAttr mem_attr;
rnn
::
LinkMemories
(
step_scopes
,
memories
,
i
,
-
1
,
false
/*infer_shape_mode*/
);
// mem_attr.pre_var = "pre_h";
}
// mem_attr.var = "h";
// check
// mem_attr.boot_var = "boot_h";
for
(
size_t
i
=
0
;
i
<
len
-
1
;
++
i
)
{
// std::vector<rnn::MemoryAttr> memories;
const
float
*
a
=
// memories.push_back(mem_attr);
step_scopes
[
i
]
->
FindVar
(
"h"
)
->
GetMutable
<
Tensor
>
()
->
data
<
float
>
();
//
const
float
*
b
=
step_scopes
[
i
+
1
]
// for (size_t i = 1; i < len; ++i) {
->
FindVar
(
"pre_h"
)
// rnn::LinkMemories(step_scopes, memories, i, -1, false
->
GetMutable
<
Tensor
>
()
// /*infer_shape_mode*/);
->
data
<
float
>
();
// }
for
(
size_t
j
=
0
;
j
<
15
*
20
;
++
j
)
{
// // check
ASSERT_FLOAT_EQ
(
a
[
j
],
b
[
j
]);
// for (size_t i = 0; i < len - 1; ++i) {
}
// const float* a =
}
// step_scopes[i]->FindVar("h")->GetMutable<Tensor>()->data<float>();
// const float* b = step_scopes[i + 1]
for
(
int
i
=
len
-
2
;
i
>=
0
;
--
i
)
{
// ->FindVar("pre_h")
rnn
::
LinkMemories
(
step_scopes
,
memories
,
i
,
1
,
false
/*infer_shape_mode*/
);
// ->GetMutable<Tensor>()
}
// ->data<float>();
// check
// for (size_t j = 0; j < 15 * 20; ++j) {
for
(
int
i
=
len
-
2
;
i
>=
0
;
--
i
)
{
// ASSERT_FLOAT_EQ(a[j], b[j]);
const
float
*
a
=
// }
step_scopes
[
i
]
->
FindVar
(
"pre_h"
)
->
GetMutable
<
Tensor
>
()
->
data
<
float
>
();
// }
const
float
*
b
=
//
step_scopes
[
i
+
1
]
->
FindVar
(
"h"
)
->
GetMutable
<
Tensor
>
()
->
data
<
float
>
();
// for (int i = len - 2; i >= 0; --i) {
for
(
size_t
j
=
0
;
j
<
15
*
20
;
++
j
)
{
// rnn::LinkMemories(step_scopes, memories, i, 1, false
ASSERT_FLOAT_EQ
(
a
[
j
],
b
[
j
]);
// /*infer_shape_mode*/);
}
// }
}
// // check
// for (int i = len - 2; i >= 0; --i) {
for
(
auto
s
:
step_scopes
)
{
// const float* a =
delete
s
;
// step_scopes[i]->FindVar("pre_h")->GetMutable<Tensor>()->data<float>();
}
// const float* b =
}
// step_scopes[i + 1]->FindVar("h")->GetMutable<Tensor>()->data<float>();
// for (size_t j = 0; j < 15 * 20; ++j) {
USE_OP
(
add_two
);
// ASSERT_FLOAT_EQ(a[j], b[j]);
USE_OP
(
mul
);
// }
USE_OP_WITHOUT_KERNEL
(
recurrent_op
);
// }
//
// for (auto s : step_scopes) {
// delete s;
// }
//}
//
// USE_OP(add_two);
// USE_OP(mul);
// USE_OP_WITHOUT_KERNEL(recurrent_op);
paddle/operators/rowwise_add_op.cc
浏览文件 @
d97a2b42
...
@@ -19,16 +19,14 @@ namespace operators {
...
@@ -19,16 +19,14 @@ namespace operators {
class
RowWiseAddOp
:
public
OperatorWithKernel
{
class
RowWiseAddOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2UL
,
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
"Two inputs is needed by rowwise add"
);
auto
dim1
=
ctx
.
Input
<
Tensor
>
(
"b"
)
->
dims
();
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
();
auto
dim1
=
ctx
.
Input
<
Tensor
>
(
1
)
->
dims
();
PADDLE_ENFORCE
(
dim0
.
size
()
==
2
,
"Input 0 must be matrix"
);
PADDLE_ENFORCE
(
dim0
.
size
()
==
2
,
"Input 0 must be matrix"
);
PADDLE_ENFORCE
(
dim1
.
size
()
==
1
,
"The second input must be vector"
);
PADDLE_ENFORCE
(
dim1
.
size
()
==
1
,
"The second input must be vector"
);
PADDLE_ENFORCE
(
dim0
[
1
]
==
dim1
[
0
],
"The width of two input must be same"
);
PADDLE_ENFORCE
(
dim0
[
1
]
==
dim1
[
0
],
"The width of two input must be same"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"The output size must be 1"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
(
"Out"
)
==
1
,
"The output size must be 1"
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
}
}
};
};
...
...
paddle/operators/rowwise_add_op.h
浏览文件 @
d97a2b42
...
@@ -25,8 +25,8 @@ class RowWiseAddKernel : public OpKernel {
...
@@ -25,8 +25,8 @@ class RowWiseAddKernel : public OpKernel {
auto
out
=
context
.
Output
<
Tensor
>
(
0
);
auto
out
=
context
.
Output
<
Tensor
>
(
0
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
input
=
EigenMatrix
<
T
>::
From
(
*
context
.
Input
<
Tensor
>
(
0
));
auto
input
=
EigenMatrix
<
T
>::
From
(
*
context
.
Input
<
Tensor
>
(
"X"
));
auto
bias
=
EigenVector
<
T
>::
From
(
*
context
.
Input
<
Tensor
>
(
1
));
auto
bias
=
EigenVector
<
T
>::
From
(
*
context
.
Input
<
Tensor
>
(
"b"
));
auto
output
=
EigenMatrix
<
T
>::
From
(
*
out
);
auto
output
=
EigenMatrix
<
T
>::
From
(
*
out
);
const
int
bias_size
=
bias
.
dimension
(
0
);
const
int
bias_size
=
bias
.
dimension
(
0
);
...
...
paddle/operators/sgd_op.cc
浏览文件 @
d97a2b42
...
@@ -20,14 +20,10 @@ namespace operators {
...
@@ -20,14 +20,10 @@ namespace operators {
class
SGDOp
:
public
OperatorWithKernel
{
class
SGDOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
"Input size of SGDOp must be two"
);
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Output size of SGDOp must be one"
);
ctx
.
Input
<
Tensor
>
(
"param"
)
->
dims
()
==
ctx
.
Input
<
Tensor
>
(
"grad"
)
->
dims
(),
PADDLE_ENFORCE
(
ctx
.
InputVar
(
0
)
!=
nullptr
,
"inputs[0] mast be set"
);
"Two input of SGD Op's dimension must be same."
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
1
)
!=
nullptr
,
"inputs[1] mast be set"
);
ctx
.
Output
<
Tensor
>
(
"param_out"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"param"
)
->
dims
());
PADDLE_ENFORCE
(
ctx
.
OutputVar
(
0
)
!=
nullptr
,
"outputs[0] mast be set"
);
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
()
==
ctx
.
Input
<
Tensor
>
(
1
)
->
dims
(),
"Two input of SGD Op's dimension must be same."
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
}
}
};
};
...
...
paddle/operators/sigmoid_op.cc
浏览文件 @
d97a2b42
...
@@ -19,9 +19,7 @@ namespace operators {
...
@@ -19,9 +19,7 @@ namespace operators {
class
SigmoidOp
:
public
OperatorWithKernel
{
class
SigmoidOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1
,
"Sigmoid Op only have one input"
);
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Sigmoid Op only have one output"
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
}
}
};
};
...
...
paddle/operators/softmax_op.cc
浏览文件 @
d97a2b42
...
@@ -20,12 +20,8 @@ namespace operators {
...
@@ -20,12 +20,8 @@ namespace operators {
class
SoftmaxOp
:
public
OperatorWithKernel
{
class
SoftmaxOp
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1UL
,
"Only one input is need for softmax"
);
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
()
==
2UL
,
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
()
==
2UL
,
"The input of softmax op must be matrix"
);
"The input of softmax op must be matrix"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1UL
,
"Only one output is need for softmax"
);
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
}
}
};
};
...
@@ -43,10 +39,6 @@ class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
...
@@ -43,10 +39,6 @@ class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
class
SoftmaxOpGrad
:
public
OperatorWithKernel
{
class
SoftmaxOpGrad
:
public
OperatorWithKernel
{
protected:
protected:
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
3UL
,
"Input of SoftmaxOpGrad should be 3, X, Y, YG"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1UL
,
"Output of SoftmaxOpGrad should be 1"
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
"Y"
)
!=
nullptr
,
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
"Y"
)
!=
nullptr
,
"Input(Y) should not be null"
);
PADDLE_ENFORCE
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
))
!=
nullptr
,
PADDLE_ENFORCE
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
))
!=
nullptr
,
"Input(Y@GRAD) should not be null"
);
"Input(Y@GRAD) should not be null"
);
...
...
paddle/platform/enforce.h
浏览文件 @
d97a2b42
...
@@ -195,12 +195,28 @@ struct CompatibleType {
...
@@ -195,12 +195,28 @@ struct CompatibleType {
typedef
typename
std
::
conditional
<
t1_to_t2
,
T2
,
T1
>::
type
type
;
typedef
typename
std
::
conditional
<
t1_to_t2
,
T2
,
T1
>::
type
type
;
};
};
template
<
typename
T
>
inline
std
::
string
enforce_to_string
(
const
T
&
val
)
{
std
::
ostringstream
sout
;
sout
<<
val
;
return
sout
.
str
();
}
template
<
>
inline
std
::
string
enforce_to_string
(
const
std
::
string
&
val
)
{
return
val
;
}
template
<
>
inline
std
::
string
enforce_to_string
(
const
char
*
const
&
val
)
{
return
std
::
string
(
val
);
}
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
PADDLE_ENFORCE(__COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL0) \
PADDLE_ENFORCE(__COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL0) \
__CMP __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL1), \
__CMP __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL1), \
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
#__VAL0, #__VAL1, std::to_string(__VAL0), \
#__VAL0, #__VAL1, \
std::to_string(__VAL1), \
paddle::platform::enforce_to_string(__VAL0), \
paddle::platform::enforce_to_string(__VAL1), \
paddle::string::Sprintf("" __VA_ARGS__));
paddle::string::Sprintf("" __VA_ARGS__));
#define __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL) \
#define __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL) \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录