Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
02cde244
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
02cde244
编写于
7月 24, 2017
作者:
F
fengjiayi
提交者:
GitHub
7月 24, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2949 from dzhwinter/backward
Backward
上级
238f7c82
81df39fe
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
393 addition
and
58 deletion
+393
-58
doc/design/simple_op_design.md
doc/design/simple_op_design.md
+1
-0
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+6
-3
paddle/framework/grad_op_creator.cc
paddle/framework/grad_op_creator.cc
+115
-0
paddle/framework/grad_op_creator.h
paddle/framework/grad_op_creator.h
+48
-0
paddle/framework/grad_op_creator_test.cc
paddle/framework/grad_op_creator_test.cc
+26
-0
paddle/framework/net.cc
paddle/framework/net.cc
+11
-1
paddle/framework/net.h
paddle/framework/net.h
+2
-0
paddle/framework/net_op_test.cc
paddle/framework/net_op_test.cc
+39
-8
paddle/framework/net_test.cc
paddle/framework/net_test.cc
+0
-24
paddle/framework/op_proto.proto
paddle/framework/op_proto.proto
+6
-0
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+77
-21
paddle/framework/operator.h
paddle/framework/operator.h
+5
-0
paddle/operators/add_op.cc
paddle/operators/add_op.cc
+13
-0
paddle/operators/add_op_test.cc
paddle/operators/add_op_test.cc
+6
-1
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+13
-0
paddle/operators/sigmoid_op.cc
paddle/operators/sigmoid_op.cc
+13
-0
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+12
-0
未找到文件。
doc/design/simple_op_design.md
浏览文件 @
02cde244
...
...
@@ -49,6 +49,7 @@ message AttrProto {
message
VarProto
{
required
string
name
=
1
;
required
string
comment
=
2
;
required
bool
is_tensor
=
3
;
};
message
OpProto
{
...
...
paddle/framework/CMakeLists.txt
浏览文件 @
02cde244
...
...
@@ -19,8 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library
(
operator SRCS operator.cc DEPS op_desc device_context tensor
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto op_desc
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry operator
)
cc_library
(
grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
cc_test
(
grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op
)
py_proto_compile
(
framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto
)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
...
...
@@ -28,5 +30,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies
(
framework_py_proto framework_py_proto_init
)
proto_library
(
net_proto SRCS net_proto.proto DEPS op_proto
)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library
(
net SRCS net.cc DEPS operator net_proto op_registry
)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net
)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net
add_op mul_op sigmoid_op softmax_op fc_op
)
paddle/framework/grad_op_creator.cc
0 → 100644
浏览文件 @
02cde244
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
framework
{
OperatorBase
*
GradOpCreator
::
Create
()
{
BuildOpInOutArgList
();
OperatorBase
*
grad_op
=
OpRegistry
::
grad_creators
().
at
(
op_
->
type_
)();
CompleteGradOp
(
grad_op
);
return
grad_op
;
}
OpInOutArg
*
GradOpCreator
::
BuildArg
(
const
VarProto
&
var
,
const
VarIndexMap
&
var_map
,
const
std
::
vector
<
int
>&
format
,
InOutType
type
)
{
int
idx
=
var_map
.
at
(
var
.
name
());
int
begin_idx
=
format
.
empty
()
?
idx
:
format
.
at
(
idx
);
int
end_idx
=
format
.
empty
()
?
idx
+
1
:
format
.
at
(
idx
+
1
);
return
new
OpInOutArg
(
var
.
name
(),
type
,
!
var
.
ignore_gradient
(),
begin_idx
,
end_idx
);
}
void
GradOpCreator
::
BuildOpInOutArgList
()
{
const
OpProto
&
op_proto
=
OpRegistry
::
protos
().
at
(
op_
->
type_
);
const
auto
&
var_map
=
*
(
OpRegistry
::
VarIndexMaps
().
at
(
op_
->
type_
));
const
std
::
vector
<
int
>&
in_format
=
op_
->
attrs_
.
count
(
"input_format"
)
?
op_
->
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
)
:
std
::
vector
<
int
>
();
const
std
::
vector
<
int
>&
out_format
=
op_
->
attrs_
.
count
(
"output_format"
)
?
op_
->
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
)
:
std
::
vector
<
int
>
();
for
(
const
auto
&
var
:
op_proto
.
inputs
())
{
arg_list_
.
emplace_back
(
std
::
shared_ptr
<
OpInOutArg
>
(
BuildArg
(
var
,
var_map
,
in_format
,
IN
)));
}
for
(
const
auto
&
var
:
op_proto
.
outputs
())
{
arg_list_
.
emplace_back
(
std
::
shared_ptr
<
OpInOutArg
>
(
BuildArg
(
var
,
var_map
,
out_format
,
OUT
)));
}
}
void
GradOpCreator
::
AddArgIntoGradOp
(
const
OpInOutArg
*
arg
,
std
::
vector
<
std
::
string
>&
in_out
,
std
::
vector
<
int
>&
format
,
VarIndexMap
*
varmap
,
int
&
idx
,
bool
is_grad
)
const
{
std
::
string
var_name
=
arg
->
proto_name_
;
if
(
is_grad
)
{
var_name
+=
OperatorBase
::
GRAD_VAR_SUFFIX
();
}
(
*
varmap
)[
var_name
]
=
idx
++
;
size_t
pre_sz
=
in_out
.
size
();
auto
base_it
=
arg
->
type_
==
IN
?
op_
->
inputs_
.
begin
()
:
op_
->
outputs_
.
begin
();
std
::
copy
(
base_it
+
arg
->
begin_idx_
,
base_it
+
arg
->
end_idx_
,
std
::
back_inserter
(
in_out
));
if
(
is_grad
)
{
for
(
size_t
i
=
pre_sz
;
i
<
in_out
.
size
();
++
i
)
{
in_out
[
i
]
+=
OperatorBase
::
GRAD_VAR_SUFFIX
();
}
}
format
.
push_back
(
in_out
.
size
());
}
void
GradOpCreator
::
CompleteGradOp
(
OperatorBase
*
grad_op
)
const
{
grad_op
->
type_
=
op_
->
type_
+
"@GRAD"
;
// not necessary
grad_op
->
attrs_
=
op_
->
attrs_
;
grad_op
->
attrs_
.
erase
(
"input_format"
);
grad_op
->
attrs_
.
erase
(
"output_format"
);
VarIndexMap
*
grad_varmap
=
new
VarIndexMap
();
int
in_idx
=
0
;
int
out_idx
=
0
;
std
::
vector
<
int
>
in_format
({
0
});
std
::
vector
<
int
>
out_format
({
0
});
for
(
const
auto
&
arg
:
arg_list_
)
{
// op_'s inputs_ and outputs_
if
(
arg
->
needed_in_grad_
)
{
AddArgIntoGradOp
(
arg
.
get
(),
grad_op
->
inputs_
,
in_format
,
grad_varmap
,
in_idx
,
false
);
}
if
(
arg
->
type_
==
IN
)
{
// gradients of op_'s inputs_
AddArgIntoGradOp
(
arg
.
get
(),
grad_op
->
outputs_
,
out_format
,
grad_varmap
,
out_idx
,
true
);
}
else
{
// gradients of op_'s outputs_
AddArgIntoGradOp
(
arg
.
get
(),
grad_op
->
inputs_
,
in_format
,
grad_varmap
,
in_idx
,
true
);
}
}
grad_op
->
attrs_
[
"input_format"
]
=
in_format
;
grad_op
->
attrs_
[
"output_format"
]
=
out_format
;
grad_op
->
in_out_idxs_
.
reset
(
grad_varmap
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/grad_op_creator.h
0 → 100644
浏览文件 @
02cde244
#pragma once
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
framework
{
class
OpRegistry
;
enum
InOutType
{
IN
,
OUT
};
struct
OpInOutArg
{
OpInOutArg
(
const
std
::
string
&
proto_name
,
const
InOutType
&
type
,
bool
needed_in_grad
,
size_t
begin_idx
,
size_t
end_idx
)
:
proto_name_
(
proto_name
),
type_
(
type
),
needed_in_grad_
(
needed_in_grad
),
begin_idx_
(
begin_idx
),
end_idx_
(
end_idx
)
{}
std
::
string
proto_name_
;
InOutType
type_
;
bool
needed_in_grad_
;
size_t
begin_idx_
;
size_t
end_idx_
;
};
class
GradOpCreator
{
using
VarIndexMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
public:
GradOpCreator
(
const
OperatorBase
*
op
)
:
op_
(
op
)
{}
OperatorBase
*
Create
();
private:
OpInOutArg
*
BuildArg
(
const
VarProto
&
var
,
const
VarIndexMap
&
var_map
,
const
std
::
vector
<
int
>&
format
,
InOutType
type
);
void
BuildOpInOutArgList
();
void
AddArgIntoGradOp
(
const
OpInOutArg
*
arg
,
std
::
vector
<
std
::
string
>&
in_out
,
std
::
vector
<
int
>&
format
,
VarIndexMap
*
varmap
,
int
&
idx
,
bool
is_grad
)
const
;
void
CompleteGradOp
(
OperatorBase
*
grad_op
)
const
;
const
OperatorBase
*
op_
;
std
::
vector
<
std
::
shared_ptr
<
OpInOutArg
>>
arg_list_
;
};
}
// namespace framework
}
// namespace paddle
paddle/framework/grad_op_creator_test.cc
0 → 100644
浏览文件 @
02cde244
#include "paddle/framework/grad_op_creator.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
USE_OP
(
add_two
);
namespace
paddle
{
namespace
framework
{
TEST
(
GradOpCreator
,
AddTwo
)
{
std
::
shared_ptr
<
OperatorBase
>
add_op
(
OpRegistry
::
CreateOp
(
"add_two"
,
{
"x"
,
"y"
},
{
"out"
},
{}));
std
::
shared_ptr
<
OperatorBase
>
grad_add_op
=
OpRegistry
::
CreateGradOp
(
add_op
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
inputs_
.
size
()),
4
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
outputs_
.
size
()),
2
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"X"
),
"x"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Y"
),
"y"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Out"
),
"out"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Out@GRAD"
),
"out@GRAD"
);
EXPECT_EQ
(
grad_add_op
->
Output
(
"X@GRAD"
),
"x@GRAD"
);
EXPECT_EQ
(
grad_add_op
->
Output
(
"Y@GRAD"
),
"y@GRAD"
);
}
}
// namespace framework
}
// namespace paddle
\ No newline at end of file
paddle/framework/net.cc
浏览文件 @
02cde244
...
...
@@ -15,14 +15,24 @@
*/
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
framework
{
std
::
shared_ptr
<
PlainNet
>
AddBackwardOp
(
std
::
shared_ptr
<
PlainNet
>
ForwardOps
)
{
auto
grad_ops
=
std
::
make_shared
<
PlainNet
>
();
for
(
auto
&
op
:
ForwardOps
->
ops_
)
{
auto
op_grad
=
OpRegistry
::
CreateGradOp
(
op
);
grad_ops
->
AddOp
(
op_grad
);
}
grad_ops
->
CompleteAddOp
();
return
grad_ops
;
}
void
PlainNet
::
CompleteAddOp
(
bool
calc
)
{
add_op_done_
=
true
;
if
(
!
calc
)
return
;
std
::
unordered_set
<
std
::
string
>
input_set
;
std
::
unordered_set
<
std
::
string
>
output_set
;
std
::
unordered_set
<
std
::
string
>
temp_output
;
...
...
paddle/framework/net.h
浏览文件 @
02cde244
...
...
@@ -100,5 +100,7 @@ class PlainNet : public Net {
}
};
std
::
shared_ptr
<
PlainNet
>
AddBackwardOp
(
std
::
shared_ptr
<
PlainNet
>
ForwardOps
);
}
// namespace framework
}
// namespace paddle
paddle/framework/net_op_test.cc
浏览文件 @
02cde244
...
...
@@ -3,17 +3,24 @@
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
namespace
pd
=
paddle
::
framework
;
USE_OP
(
add_two
);
USE_OP
(
mul
);
USE_OP
(
sigmoid
);
USE_OP
(
softmax
);
namespace
paddle
{
namespace
framework
{
static
int
infer_shape_cnt
=
0
;
static
int
run_cnt
=
0
;
class
TestOp
:
public
pd
::
OperatorBase
{
class
TestOp
:
public
OperatorBase
{
public:
void
InferShape
(
const
std
::
shared_ptr
<
pd
::
Scope
>&
scope
)
const
override
{
void
InferShape
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
scope
)
const
override
{
++
infer_shape_cnt
;
}
void
Run
(
const
std
::
shared_ptr
<
pd
::
Scope
>&
scope
,
void
Run
(
const
std
::
shared_ptr
<
framework
::
Scope
>&
scope
,
const
paddle
::
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
++
run_cnt
;
}
...
...
@@ -33,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}
TEST
(
OpKernel
,
all
)
{
auto
net
=
std
::
make_shared
<
paddle
::
framework
::
PlainNet
>
();
auto
net
=
std
::
make_shared
<
PlainNet
>
();
ASSERT_NE
(
net
,
nullptr
);
auto
op1
=
std
::
make_shared
<
TestOp
>
();
...
...
@@ -55,13 +62,37 @@ TEST(OpKernel, all) {
ASSERT_EQ
(
1UL
,
tmp_idx
.
size
());
ASSERT_EQ
(
"y"
,
net
->
outputs_
[
tmp_idx
[
0
]]);
auto
scope
=
std
::
make_shared
<
pd
::
Scope
>
();
p
addle
::
p
latform
::
CPUDeviceContext
dev_ctx
;
auto
scope
=
std
::
make_shared
<
Scope
>
();
platform
::
CPUDeviceContext
dev_ctx
;
net
->
InferShape
(
scope
);
net
->
Run
(
scope
,
dev_ctx
);
ASSERT_EQ
(
2
,
infer_shape_cnt
);
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
std
::
runtime_error
);
}
TEST
(
AddBackwardOp
,
TestGradOp
)
{
auto
net
=
std
::
make_shared
<
PlainNet
>
();
ASSERT_NE
(
net
,
nullptr
);
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"mul"
,
{
"X"
,
"Y"
},
{
"Out"
},
{}));
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"add_two"
,
{
"X"
,
"Y"
},
{
"Out"
},
{}));
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"add_two"
,
{
"X"
,
"Y"
},
{
""
},
{}));
auto
grad_ops
=
AddBackwardOp
(
net
);
for
(
auto
&
op
:
grad_ops
->
ops_
)
{
op
->
DebugString
();
}
}
// TODO(zhihong): add fc grad without registering.
// TEST(AddBackwardOp, TestNoGradOp) {
// auto net = std::make_shared<PlainNet>();
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
// }
}
// namespace framework
}
// namespace paddle
paddle/framework/net_test.cc
已删除
100644 → 0
浏览文件 @
238f7c82
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
framework
{
class
FakeFC
:
public
Operator
{}
}
// namespace framework
}
// namespace paddle
paddle/framework/op_proto.proto
浏览文件 @
02cde244
...
...
@@ -84,6 +84,11 @@ message VarProto {
// "temporary_index": [1]
// }
optional
bool
temporary
=
4
[
default
=
false
];
// The gradient of operator can be ignored immediately
// e.g. operator AddOp, y = x1 + x2, the gradient of dy/dx1, dy/dx2
// can be ignored for the future optimized on graph.
optional
bool
ignore_gradient
=
6
;
}
// Op protocol message for 3rd-party language binding.
...
...
@@ -105,4 +110,5 @@ message OpProto {
// The type of that Op.
required
string
type
=
5
;
}
paddle/framework/op_registry.h
浏览文件 @
02cde244
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
...
...
@@ -6,9 +20,9 @@
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -73,25 +87,29 @@ class OpProtoAndCheckerMaker {
protected:
void
AddInput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
bool
multiple
=
false
)
{
bool
multiple
=
false
,
bool
ignore_gradient
=
false
)
{
auto
input
=
proto_
->
mutable_inputs
()
->
Add
();
*
input
->
mutable_name
()
=
name
;
*
input
->
mutable_comment
()
=
comment
;
input
->
set_ignore_gradient
(
ignore_gradient
);
input
->
set_multiple
(
multiple
);
if
(
multiple
)
{
SetHasMultipleInput
();
}
}
void
AddInputs
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
)
{
AddInput
(
name
,
comment
,
true
);
void
AddInputs
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
bool
ignore_gradient
=
false
)
{
AddInput
(
name
,
comment
,
true
,
ignore_gradient
);
}
void
AddOutput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
bool
temporary
=
false
,
bool
multiple
=
false
)
{
bool
temporary
=
false
,
bool
multiple
=
false
,
bool
ignore_gradient
=
false
)
{
auto
output
=
proto_
->
mutable_outputs
()
->
Add
();
*
output
->
mutable_name
()
=
name
;
*
output
->
mutable_comment
()
=
comment
;
output
->
set_ignore_gradient
(
ignore_gradient
);
output
->
set_multiple
(
multiple
);
if
(
multiple
)
{
SetHasMultipleOutput
();
...
...
@@ -103,8 +121,8 @@ class OpProtoAndCheckerMaker {
}
void
AddOutputs
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
bool
temporary
=
false
)
{
AddOutput
(
name
,
comment
,
temporary
,
true
);
bool
temporary
=
false
,
bool
ignore_gradient
=
false
)
{
AddOutput
(
name
,
comment
,
temporary
,
true
,
ignore_gradient
);
}
template
<
typename
T
>
...
...
@@ -205,8 +223,8 @@ class OpRegistry {
template
<
typename
OpType
,
typename
ProtoMakerType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
OpProto
&
op_proto
=
protos
()[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpProto
&
op_proto
=
protos
()[
op_type
];
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
maker
.
Validate
();
*
op_proto
.
mutable_type
()
=
op_type
;
...
...
@@ -227,18 +245,24 @@ class OpRegistry {
}
}
template
<
typename
OpType
>
static
void
RegisterGradOp
(
const
std
::
string
&
op_type
)
{
grad_creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
const
VarNameList
&
inputs
,
const
VarNameList
&
outputs
,
const
AttributeMap
&
attrs
)
{
auto
op_create_it
=
creators
().
find
(
type
);
PADDLE_ENFORCE
(
op_create_it
!=
creators
().
end
(),
"Operator %s cannot be found"
,
type
);
"Operator %s cannot be found
.
"
,
type
);
auto
op
=
op_create_it
->
second
();
op
->
type_
=
type
;
op
->
inputs_
=
inputs
;
op
->
outputs_
=
outputs
;
op
->
attrs_
=
attrs
;
op_checkers
().
at
(
type
).
Check
(
op
->
attrs_
);
...
...
@@ -274,18 +298,41 @@ class OpRegistry {
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateGradOp
(
std
::
shared_ptr
<
OperatorBase
>
op
)
{
GradOpCreator
creator
(
op
.
get
());
std
::
shared_ptr
<
OperatorBase
>
grad_op
(
creator
.
Create
());
grad_op
->
Init
();
return
grad_op
;
}
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
return
protos_
;
};
private:
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
grad_creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
grad_creators_
;
return
grad_creators_
;
}
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>&
VarIndexMaps
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>
maps_
;
return
maps_
;
}
private:
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
creators_
;
return
creators_
;
}
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>&
op_checkers
()
{
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
return
op_checkers_
;
};
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
for
(
auto
&
outname
:
op
->
outputs_
)
{
...
...
@@ -296,16 +343,6 @@ class OpRegistry {
}
}
}
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
creators_
;
return
creators_
;
}
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>&
op_checkers
()
{
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
return
op_checkers_
;
};
};
template
<
typename
OpType
,
typename
ProtoMakerType
>
...
...
@@ -316,6 +353,14 @@ class OpRegisterHelper {
}
};
template
<
typename
OpType
>
class
GradOpRegisterHelper
{
public:
GradOpRegisterHelper
(
const
char
*
op_type
)
{
OpRegistry
::
RegisterGradOp
<
OpType
>
(
op_type
);
}
};
/**
* check if MACRO is used in GLOBAL NAMESPACE.
*/
...
...
@@ -335,6 +380,17 @@ class OpRegisterHelper {
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
/**
* Macro to Register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__op_class> \
__op_gradient_register_##__op_type##__(#__op_type); \
int __op_gradient_register_##__op_type##_handle__() { return 0; }
/**
* Macro to Register OperatorKernel.
*/
...
...
paddle/framework/operator.h
浏览文件 @
02cde244
...
...
@@ -62,6 +62,11 @@ class OperatorBase {
/// but it will be convert to a unique name in scope after OpCreator.
static
std
::
string
TMP_VAR_NAME
()
{
return
"@TEMP@"
;
}
/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another varibale.
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
static
std
::
string
GRAD_VAR_SUFFIX
()
{
return
"@GRAD"
;
}
virtual
~
OperatorBase
()
{}
template
<
typename
T
>
...
...
paddle/operators/add_op.cc
浏览文件 @
02cde244
...
...
@@ -49,9 +49,22 @@ The equation is: Out = X + Y
)DOC"
);
}
};
class
AddOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"AddOpGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP
(
add_two
,
paddle
::
operators
::
AddOp
,
paddle
::
operators
::
AddOpMaker
);
REGISTER_GRADIENT_OP
(
add_two
,
paddle
::
operators
::
AddOpGrad
);
REGISTER_OP_CPU_KERNEL
(
add_two
,
paddle
::
operators
::
AddKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/add_op_test.cc
浏览文件 @
02cde244
...
...
@@ -16,8 +16,13 @@ limitations under the License. */
#define private public
#include <paddle/framework/op_registry.h>
USE_OP
(
add_two
);
// USE_OP(add_two_grad);
TEST
(
AddOp
,
GetOpProto
)
{
auto
&
protos
=
paddle
::
framework
::
OpRegistry
::
protos
();
auto
it
=
protos
.
find
(
"add_two"
);
ASSERT_NE
(
it
,
protos
.
end
());
}
\ No newline at end of file
auto
&
grad_creators
=
paddle
::
framework
::
OpRegistry
::
grad_creators
();
auto
it1
=
grad_creators
.
find
(
"add_two"
);
ASSERT_NE
(
it1
,
grad_creators
.
end
());
}
paddle/operators/mul_op.cc
浏览文件 @
02cde244
...
...
@@ -52,9 +52,22 @@ The equation is: Out = X * Y
}
};
class
MulOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"MulGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP
(
mul
,
paddle
::
operators
::
MulOp
,
paddle
::
operators
::
MulOpMaker
);
REGISTER_GRADIENT_OP
(
mul
,
paddle
::
operators
::
MulOpGrad
);
REGISTER_OP_CPU_KERNEL
(
mul
,
paddle
::
operators
::
MulKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/sigmoid_op.cc
浏览文件 @
02cde244
...
...
@@ -39,12 +39,25 @@ public:
}
};
class
SigmoidOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"SigmoidGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP
(
sigmoid
,
paddle
::
operators
::
SigmoidOp
,
paddle
::
operators
::
SigmoidOpMaker
);
REGISTER_GRADIENT_OP
(
sigmoid
,
paddle
::
operators
::
SigmoidOpGrad
);
REGISTER_OP_CPU_KERNEL
(
sigmoid
,
paddle
::
operators
::
SigmoidKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/softmax_op.cc
浏览文件 @
02cde244
...
...
@@ -42,11 +42,23 @@ public:
}
};
class
SoftmaxOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
framework
::
Tensor
*>
&
outputs
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"SoftmaxOpGrad"
;
return
""
;
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
softmax
,
ops
::
SoftmaxOp
,
ops
::
SoftmaxOpMaker
);
REGISTER_GRADIENT_OP
(
softmax
,
paddle
::
operators
::
SoftmaxOpGrad
);
REGISTER_OP_CPU_KERNEL
(
softmax
,
ops
::
SoftmaxKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录