Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6a5f6046
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6a5f6046
编写于
12月 28, 2018
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support stop_gradients var in imperative backward
test=develop
上级
9e3155e0
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
40 addition
and
17 deletion
+40
-17
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+9
-0
paddle/fluid/framework/operator_test.cc
paddle/fluid/framework/operator_test.cc
+9
-0
paddle/fluid/imperative/layer.cc
paddle/fluid/imperative/layer.cc
+19
-14
paddle/fluid/imperative/tracer.h
paddle/fluid/imperative/tracer.h
+1
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+2
-2
未找到文件。
paddle/fluid/framework/operator.h
浏览文件 @
6a5f6046
...
...
@@ -69,6 +69,15 @@ inline std::string GradVarName(const std::string& var_name) {
return
result
;
}
inline
std
::
string
OriginVarName
(
const
std
::
string
&
grad_var_name
)
{
std
::
size_t
pos
=
grad_var_name
.
find_last_of
(
kGradVarSuffix
);
if
(
pos
==
std
::
string
::
npos
)
{
return
grad_var_name
;
}
else
{
return
grad_var_name
.
substr
(
0
,
pos
);
}
}
proto
::
VarType
::
Type
GetDataTypeOfVar
(
const
Variable
*
var
);
const
Tensor
*
GetLoDTensorOrSelectedRowsValueFromVar
(
const
Variable
&
var
);
Tensor
*
GetMutableLoDTensorOrSelectedRowsValueFromVar
(
Variable
*
var
);
...
...
paddle/fluid/framework/operator_test.cc
浏览文件 @
6a5f6046
...
...
@@ -288,3 +288,12 @@ TEST(OpKernel, multi_inputs) {
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
op
->
Run
(
scope
,
cpu_place
);
}
TEST
(
Functions
,
all
)
{
std
::
string
var_name
(
"X"
);
std
::
string
grad_var_name
=
paddle
::
framework
::
GradVarName
(
var_name
);
ASSERT_EQ
(
grad_var_name
.
c_str
(),
"X@GRAD"
);
std
::
string
original_var_name
=
paddle
::
framework
::
OriginVarName
(
grad_var_name
);
ASSERT_EQ
(
original_var_name
.
c_str
(),
"X"
);
}
paddle/fluid/imperative/layer.cc
浏览文件 @
6a5f6046
...
...
@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
...
...
@@ -31,8 +32,9 @@ using framework::Variable;
void
AddTo
(
Variable
*
src
,
Variable
*
dst
)
{
framework
::
LoDTensor
*
dst_tensor
=
dst
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
LoDTensor
*
src_tensor
=
src
->
GetMutable
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE
(
dst_tensor
->
numel
()
==
src_tensor
->
numel
(),
"%lld vs %lld"
,
dst_tensor
->
numel
(),
src_tensor
->
numel
());
PADDLE_ENFORCE
(
dst_tensor
->
numel
()
==
src_tensor
->
numel
(),
"dst_numel %lld vs. src_numel %lld"
,
dst_tensor
->
numel
(),
src_tensor
->
numel
());
float
*
dst_data
=
dst_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
const
float
*
src_data
=
src_tensor
->
data
<
float
>
();
for
(
size_t
i
=
0
;
i
<
src_tensor
->
numel
();
++
i
)
{
...
...
@@ -114,7 +116,7 @@ framework::LoDTensor& VarBase::Grad() {
std
::
map
<
std
::
string
,
std
::
vector
<
VarBase
*>>
OpBase
::
ApplyGrad
()
{
if
(
!
grad_op_desc_
)
{
VLOG
(
3
)
<<
"op with no grad: "
<<
op_desc_
->
Type
();
LOG
(
WARNING
)
<<
"op with no grad: "
<<
op_desc_
->
Type
();
return
{};
}
VLOG
(
3
)
<<
"op grad "
<<
grad_op_desc_
->
Type
();
...
...
@@ -124,20 +126,18 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
for
(
auto
it
:
grad_output_vars_
)
{
auto
&
outputs
=
grad_outputs
[
it
.
first
];
for
(
size_t
i
=
0
;
i
<
it
.
second
.
size
();
++
i
)
{
tmp_vars
.
emplace_back
(
new
framework
::
Variable
());
outputs
.
push_back
(
tmp_vars
.
back
().
get
());
outputs
.
back
()
->
GetMutable
<
framework
::
LoDTensor
>
();
// Allocate a new variable
Variable
*
tmp_var
=
new
framework
::
Variable
();
tmp_var
->
GetMutable
<
framework
::
LoDTensor
>
();
tmp_vars
.
emplace_back
(
tmp_var
);
outputs
.
push_back
(
tmp_var
);
}
grad_invar_desc
.
SetShape
(
framework
::
vectorize
(
var
->
Get
<
framework
::
LoDTensor
>
().
dims
()));
VLOG
(
3
)
<<
"set op grad var desc's shape size "
<<
framework
::
vectorize
(
var
->
Get
<
framework
::
LoDTensor
>
().
dims
()).
size
();
}
framework
::
RuntimeContext
ctx
(
grad_input_vars_
,
grad_outputs
);
// No need to do
static
infer shape here.
// No need to do
compile time
infer shape here.
// grad_op_desc_->InferShape(*block_);
grad_op_desc_
->
InferVarType
(
block_
);
...
...
@@ -156,11 +156,16 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
for
(
auto
it
:
grad_output_vars_
)
{
auto
&
outputs
=
grad_outputs
[
it
.
first
];
auto
&
origin_outputs
=
it
.
second
;
auto
&
forward_inputs
=
input_vars_
[
framework
::
OriginVarName
(
it
.
first
)];
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
if
(
!
forward_inputs
[
i
]
->
stop_gradient_
)
{
framework
::
Variable
*
orig_grad
=
origin_outputs
[
i
];
AddTo
(
outputs
[
i
],
orig_grad
);
}
}
}
return
input_vars_
;
}
...
...
paddle/fluid/imperative/tracer.h
浏览文件 @
6a5f6046
...
...
@@ -57,7 +57,7 @@ class Tracer {
void
Trace
(
OpBase
*
op
,
const
std
::
map
<
std
::
string
,
std
::
vector
<
VarBase
*>>&
inputs
,
const
std
::
map
<
std
::
string
,
std
::
vector
<
VarBase
*>>&
outputs
,
framework
::
BlockDesc
*
block
)
{
framework
::
BlockDesc
*
block
,
const
bool
stop_gradient
)
{
std
::
map
<
std
::
string
,
VarBase
*>
vars
;
framework
::
OpDesc
*
op_desc
=
op
->
op_desc_
;
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
6a5f6046
...
...
@@ -152,7 +152,7 @@ PYBIND11_MODULE(core, m) {
[](
const
imperative
::
VarBase
&
self
)
{
return
self
.
stop_gradient_
;
},
[](
imperative
::
VarBase
&
self
,
bool
stop_gradient
)
{
self
.
stop_gradient_
=
stop_gradient
;
})
})
;
py
::
class_
<
imperative
::
OpBase
,
PyOpBase
>
(
m
,
"OpBase"
,
R"DOC()DOC"
)
.
def
(
py
::
init
<>
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录