Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
4c8254e3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4c8254e3
编写于
3月 27, 2019
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
revert some loop op revision
test=develop
上级
16f09947
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
380 addition
and
40 deletion
+380
-40
paddle/fluid/operators/controlflow/CMakeLists.txt
paddle/fluid/operators/controlflow/CMakeLists.txt
+1
-1
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+14
-7
paddle/fluid/operators/controlflow/while_op_helper.cc
paddle/fluid/operators/controlflow/while_op_helper.cc
+291
-0
paddle/fluid/operators/controlflow/while_op_helper.h
paddle/fluid/operators/controlflow/while_op_helper.h
+43
-0
paddle/fluid/operators/interpolate_op.cc
paddle/fluid/operators/interpolate_op.cc
+4
-0
paddle/fluid/operators/lstm_op.cc
paddle/fluid/operators/lstm_op.cc
+1
-0
paddle/fluid/operators/margin_rank_loss_op.cc
paddle/fluid/operators/margin_rank_loss_op.cc
+1
-0
paddle/fluid/operators/mean_op.cc
paddle/fluid/operators/mean_op.cc
+3
-0
paddle/fluid/operators/multiplex_op.cc
paddle/fluid/operators/multiplex_op.cc
+1
-0
paddle/fluid/operators/recurrent_op.cc
paddle/fluid/operators/recurrent_op.cc
+20
-32
paddle/fluid/operators/scatter_op.cc
paddle/fluid/operators/scatter_op.cc
+1
-0
未找到文件。
paddle/fluid/operators/controlflow/CMakeLists.txt
浏览文件 @
4c8254e3
include
(
operators
)
register_operators
(
DEPS naive_executor
)
cc_library
(
loop_op_helper SRCS loop
_op_helper.cc DEPS operator
)
cc_library
(
while_op_helper SRCS while
_op_helper.cc DEPS operator
)
file
(
APPEND
${
pybind_file
}
"USE_OP(less_than);
\n
USE_OP(logical_and);
\n
USE_NO_KERNEL_OP(read_from_array);
\n
"
)
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
4c8254e3
...
...
@@ -18,21 +18,28 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/controlflow/
loop
_op_helper.h"
#include "paddle/fluid/operators/controlflow/
while
_op_helper.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace
paddle
{
namespace
operators
{
static
constexpr
char
kCondition
[]
=
"Condition"
;
static
constexpr
char
kStepScopes
[]
=
"StepScopes"
;
static
constexpr
char
kX
[]
=
"X"
;
static
constexpr
char
kXGRAD
[]
=
"X@GRAD"
;
static
constexpr
char
kOutputs
[]
=
"Out"
;
using
StepScopeVar
=
std
::
vector
<
framework
::
Scope
*>
;
using
LoDTensor
=
framework
::
LoDTensor
;
namespace
{
// NOLINT
static
std
::
string
GetSkipEagerDeletionVarsDebugString
(
const
std
::
vector
<
std
::
string
>
&
vars
)
{
std
::
string
str
=
"Skip "
+
std
::
to_string
(
vars
.
size
())
+
" var(s) in eager deletion mode: "
;
for
(
auto
&
var
:
vars
)
{
str
.
append
(
var
);
str
.
push_back
(
' '
);
}
return
str
;
}
}
// NOLINT
class
WhileOp
:
public
framework
::
OperatorBase
{
public:
WhileOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
...
...
paddle/fluid/operators/controlflow/while_op_helper.cc
0 → 100644
浏览文件 @
4c8254e3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include <string>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
operators
{
// OpVariant is a wrapper class of OpDesc and OperatorBase
// So that API would be the same.
class
OpVariant
{
struct
InputsVisitor
:
public
boost
::
static_visitor
<
const
framework
::
VariableNameMap
*>
{
template
<
typename
OpType
>
const
framework
::
VariableNameMap
*
operator
()(
const
OpType
*
op
)
const
{
return
&
(
op
->
Inputs
());
}
};
struct
OutputsVisitor
:
public
boost
::
static_visitor
<
const
framework
::
VariableNameMap
*>
{
template
<
typename
OpType
>
const
framework
::
VariableNameMap
*
operator
()(
const
OpType
*
op
)
const
{
return
&
(
op
->
Outputs
());
}
};
struct
AttributeMapVisitor
:
public
boost
::
static_visitor
<
const
framework
::
AttributeMap
*>
{
const
framework
::
AttributeMap
*
operator
()(
const
framework
::
OpDesc
*
op
)
const
{
return
&
(
op
->
GetAttrMap
());
}
const
framework
::
AttributeMap
*
operator
()(
const
framework
::
OperatorBase
*
op
)
const
{
return
&
(
op
->
Attrs
());
}
};
struct
RawPointerVisitor
:
public
boost
::
static_visitor
<
const
void
*>
{
template
<
typename
OpType
>
const
void
*
operator
()(
const
OpType
*
op
)
const
{
return
op
;
}
};
public:
OpVariant
(
const
framework
::
OperatorBase
*
op
)
:
op_
(
op
)
{}
// NOLINT
OpVariant
(
const
framework
::
OpDesc
*
op
)
:
op_
(
op
)
{}
// NOLINT
const
framework
::
VariableNameMap
&
Inputs
()
const
{
return
*
boost
::
apply_visitor
(
InputsVisitor
(),
op_
);
}
const
framework
::
VariableNameMap
&
Outputs
()
const
{
return
*
boost
::
apply_visitor
(
OutputsVisitor
(),
op_
);
}
const
framework
::
AttributeMap
&
Attrs
()
const
{
return
*
boost
::
apply_visitor
(
AttributeMapVisitor
(),
op_
);
}
template
<
typename
AttrType
>
const
AttrType
&
Attr
(
const
std
::
string
&
name
)
const
{
auto
&
attrs
=
Attrs
();
auto
it
=
attrs
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
attrs
.
end
(),
"Cannot find attribute %s"
,
name
);
return
boost
::
get
<
AttrType
>
(
it
->
second
);
}
bool
operator
==
(
const
OpVariant
&
other
)
const
{
return
RawPointer
()
==
other
.
RawPointer
();
}
const
void
*
RawPointer
()
const
{
return
boost
::
apply_visitor
(
RawPointerVisitor
(),
op_
);
}
int
which
()
const
{
return
static_cast
<
int
>
(
op_
.
which
());
}
struct
Hasher
{
size_t
operator
()(
const
OpVariant
&
op
)
const
{
return
reinterpret_cast
<
size_t
>
(
op
.
RawPointer
());
}
};
private:
const
boost
::
variant
<
const
framework
::
OperatorBase
*
,
const
framework
::
OpDesc
*>
op_
;
};
static
std
::
string
GetDebugString
(
const
std
::
vector
<
std
::
string
>
&
names
)
{
if
(
names
.
empty
())
return
""
;
std
::
string
ret
=
names
[
0
];
for
(
size_t
i
=
1
;
i
<
names
.
size
();
++
i
)
{
ret
+=
(
" "
+
names
[
i
]);
}
return
ret
;
}
// Set skip variables of while_op and while_grad_op
// These variables should be skipped when eager deletion enables.
// It is because:
// 1. while_grad_op needs some variables defined in while_op.
// 2. while_grad_op needs variables from the previous time step.
static
void
SetSkipVars
(
const
OpVariant
&
op
,
std
::
vector
<
std
::
string
>
attr
)
{
auto
&
attrs
=
const_cast
<
framework
::
AttributeMap
&>
(
op
.
Attrs
());
VLOG
(
2
)
<<
"Prepare to skip "
<<
attr
.
size
()
<<
" var(s): "
<<
GetDebugString
(
attr
);
attrs
[
kSkipEagerDeletionVars
]
=
std
::
move
(
attr
);
}
// Check whether the forward while_op and while_grad_op match
// The program may have many while_ops.
static
bool
IsMatchedWhileOpAndWhileGradOp
(
const
OpVariant
&
fwd_op
,
const
OpVariant
&
grad_op
)
{
return
fwd_op
.
Inputs
().
at
(
kX
)
==
grad_op
.
Inputs
().
at
(
kX
)
&&
fwd_op
.
Outputs
().
at
(
kOutputs
)
==
grad_op
.
Inputs
().
at
(
kOutputs
);
}
// Test whether the variable is skippable in forward while_op
// The variable is skippable in while_op when the variable used in while_grad
// is not from grad_block.
static
bool
IsSkippableVar
(
const
std
::
string
&
name
,
framework
::
BlockDesc
*
grad_block
)
{
return
name
!=
framework
::
kEmptyVarName
&&
!
grad_block
->
HasVar
(
name
);
}
static
void
ModifyWhileOpAndWhileGradOpAttr
(
const
OpVariant
&
fwd_op
,
const
OpVariant
&
bwd_op
)
{
auto
*
grad_block
=
bwd_op
.
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
);
// Find all skippable variables in forward while_op
std
::
unordered_set
<
std
::
string
>
forward_skip_vars
;
for
(
auto
*
op_desc
:
grad_block
->
AllOps
())
{
for
(
auto
&
in_arg_name
:
op_desc
->
InputArgumentNames
())
{
if
(
IsSkippableVar
(
in_arg_name
,
grad_block
))
{
forward_skip_vars
.
insert
(
in_arg_name
);
}
}
for
(
auto
&
out_arg_name
:
op_desc
->
OutputArgumentNames
())
{
if
(
IsSkippableVar
(
out_arg_name
,
grad_block
))
{
forward_skip_vars
.
insert
(
out_arg_name
);
}
}
}
SetSkipVars
(
fwd_op
,
std
::
vector
<
std
::
string
>
(
forward_skip_vars
.
begin
(),
forward_skip_vars
.
end
()));
// Find all skippable variables in while_grad_op
// The skipped variables are those which would be used across time steps.
auto
&
fwd_input
=
fwd_op
.
Inputs
().
at
(
kX
);
auto
&
in_grads
=
bwd_op
.
Outputs
().
at
(
framework
::
GradVarName
(
kX
));
PADDLE_ENFORCE_EQ
(
fwd_input
.
size
(),
in_grads
.
size
(),
"Backward input gradient number does not match forward input number."
);
std
::
unordered_set
<
std
::
string
>
backward_skip_vars
;
for
(
size_t
i
=
0
;
i
<
in_grads
.
size
();
++
i
)
{
if
(
in_grads
[
i
]
==
framework
::
kEmptyVarName
)
{
continue
;
}
backward_skip_vars
.
insert
(
in_grads
[
i
]);
backward_skip_vars
.
insert
(
framework
::
GradVarName
(
fwd_input
[
i
]));
}
SetSkipVars
(
bwd_op
,
std
::
vector
<
std
::
string
>
(
backward_skip_vars
.
begin
(),
backward_skip_vars
.
end
()));
}
// Find all while_ops and while_grad_ops in the graph or program
// The while_grad_op and while_op may located in different blocks
// So we should traverse all blocks in the program and find them out.
static
void
FindAllWhileAndWhileGradOp
(
std
::
vector
<
OpVariant
>
*
while_ops
,
std
::
vector
<
OpVariant
>
*
while_grad_ops
)
{
PADDLE_ENFORCE_GE
(
while_ops
->
size
(),
while_grad_ops
->
size
());
if
(
while_ops
->
empty
())
return
;
const
auto
*
program
=
while_ops
->
front
().
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
)
->
Program
();
for
(
size_t
i
=
1
;
i
<
program
->
Size
();
++
i
)
{
auto
&
block
=
program
->
Block
(
i
);
for
(
size_t
j
=
0
;
j
<
block
.
OpSize
();
++
j
)
{
auto
*
op
=
block
.
Op
(
j
);
if
(
op
->
Type
()
==
"while"
)
{
while_ops
->
emplace_back
(
op
);
}
else
if
(
op
->
Type
()
==
"while_grad"
)
{
while_grad_ops
->
emplace_back
(
op
);
}
}
}
PADDLE_ENFORCE_GE
(
while_ops
->
size
(),
while_grad_ops
->
size
(),
"There are extra while_grad ops in the graph or program"
);
}
static
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
std
::
vector
<
OpVariant
>
*
while_ops
,
std
::
vector
<
OpVariant
>
*
while_grad_ops
)
{
FindAllWhileAndWhileGradOp
(
while_ops
,
while_grad_ops
);
VLOG
(
2
)
<<
"Found while op num: "
<<
while_ops
->
size
()
<<
", while grad op num: "
<<
while_grad_ops
->
size
();
if
(
while_grad_ops
->
empty
())
{
return
;
}
std
::
unordered_set
<
OpVariant
,
OpVariant
::
Hasher
>
while_op_set
(
while_ops
->
begin
(),
while_ops
->
end
());
for
(
auto
&
bwd_op
:
*
while_grad_ops
)
{
const
OpVariant
*
matched_fwd_op
=
nullptr
;
for
(
auto
&
fwd_op
:
while_op_set
)
{
if
(
IsMatchedWhileOpAndWhileGradOp
(
fwd_op
,
bwd_op
))
{
PADDLE_ENFORCE
(
matched_fwd_op
==
nullptr
,
"Found multiple matched while ops"
);
matched_fwd_op
=
&
fwd_op
;
}
}
PADDLE_ENFORCE_NOT_NULL
(
matched_fwd_op
,
"Cannot find matched forward while op."
);
ModifyWhileOpAndWhileGradOpAttr
(
*
matched_fwd_op
,
bwd_op
);
while_op_set
.
erase
(
*
matched_fwd_op
);
}
}
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
int
block_id
,
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
)
{
// If block_id is not 0, returns
// This is because all while_ops and while_grad_ops in the whole program
// would be processed when block_id is 0 (i.e. when Executor::Run() or
// ParallelExecutor constructs).
// What's more, all while_ops and while_grad_ops must be processed when
// block_id is zero. If not, while_op may run first and erase variables
// used in while_grad_op, and in this moment, while_grad_ops may be not
// constructed yet.
if
(
block_id
!=
0
)
return
;
std
::
vector
<
OpVariant
>
fwd_ops
,
bwd_ops
;
for
(
auto
&
op
:
all_ops
)
{
if
(
op
->
Type
()
==
"while"
)
{
fwd_ops
.
emplace_back
(
op
.
get
());
}
else
if
(
op
->
Type
()
==
"while_grad"
)
{
bwd_ops
.
emplace_back
(
op
.
get
());
}
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
&
fwd_ops
,
&
bwd_ops
);
}
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_grad_ops
)
{
std
::
vector
<
OpVariant
>
fwd_ops
,
bwd_ops
;
fwd_ops
.
reserve
(
while_ops
.
size
());
for
(
auto
*
op
:
while_ops
)
{
fwd_ops
.
emplace_back
(
op
);
}
bwd_ops
.
reserve
(
while_grad_ops
.
size
());
for
(
auto
*
op
:
while_grad_ops
)
{
bwd_ops
.
emplace_back
(
op
);
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
&
fwd_ops
,
&
bwd_ops
);
}
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/controlflow/while_op_helper.h
0 → 100644
浏览文件 @
4c8254e3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/variant.h"
namespace
paddle
{
namespace
operators
{
static
constexpr
char
kStepBlock
[]
=
"sub_block"
;
static
constexpr
char
kCondition
[]
=
"Condition"
;
static
constexpr
char
kStepScopes
[]
=
"StepScopes"
;
static
constexpr
char
kX
[]
=
"X"
;
static
constexpr
char
kXGRAD
[]
=
"X@GRAD"
;
static
constexpr
char
kOutputs
[]
=
"Out"
;
static
constexpr
char
kSkipEagerDeletionVars
[]
=
"skip_eager_deletion_vars"
;
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
int
block_id
,
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
);
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_grad_ops
);
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/interpolate_op.cc
浏览文件 @
4c8254e3
...
...
@@ -10,6 +10,7 @@
limitations under the License. */
#include "paddle/fluid/operators/interpolate_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -209,6 +210,9 @@ class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
std
::
unique_ptr
<
framework
::
OpDesc
>
op
(
new
framework
::
OpDesc
());
op
->
SetType
(
ForwardOp
().
Type
()
+
"_grad"
);
op
->
SetInput
(
"X"
,
Input
(
"X"
));
if
(
ForwardOp
().
Inputs
().
count
(
"OutSize"
)
>
0
)
{
op
->
SetInput
(
"OutSize"
,
Input
(
"OutSize"
));
}
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
op
->
SetAttrMap
(
Attrs
());
...
...
paddle/fluid/operators/lstm_op.cc
浏览文件 @
4c8254e3
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lstm_op.h"
#include <memory>
#include <string>
namespace
paddle
{
...
...
paddle/fluid/operators/margin_rank_loss_op.cc
浏览文件 @
4c8254e3
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/margin_rank_loss_op.h"
#include <memory>
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/mean_op.cc
浏览文件 @
4c8254e3
...
...
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mean_op.h"
#include <memory>
#include <string>
#include <unordered_map>
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/multiplex_op.cc
浏览文件 @
4c8254e3
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/multiplex_op.h"
#include <memory>
#include <vector>
namespace
paddle
{
...
...
paddle/fluid/operators/recurrent_op.cc
浏览文件 @
4c8254e3
...
...
@@ -15,24 +15,24 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/controlflow/loop_op_helper.h"
namespace
paddle
{
namespace
operators
{
using
recurrent
::
kInputs
;
using
recurrent
::
kInitialStates
;
using
recurrent
::
kParameters
;
using
recurrent
::
kOutputs
;
using
recurrent
::
kStepScopes
;
using
recurrent
::
kExStates
;
using
recurrent
::
kStates
;
using
recurrent
::
kReverse
;
using
recurrent
::
kIsTrain
;
using
recurrent
::
kInputGrads
;
using
recurrent
::
kOutputGrads
;
using
recurrent
::
kParamGrads
;
using
recurrent
::
kInitStateGrads
;
constexpr
char
kInputs
[]
=
"inputs"
;
constexpr
char
kInitialStates
[]
=
"initial_states"
;
constexpr
char
kParameters
[]
=
"parameters"
;
constexpr
char
kOutputs
[]
=
"outputs"
;
constexpr
char
kStepScopes
[]
=
"step_scopes"
;
constexpr
char
kExStates
[]
=
"ex_states"
;
constexpr
char
kStates
[]
=
"states"
;
constexpr
char
kStepBlock
[]
=
"sub_block"
;
constexpr
char
kReverse
[]
=
"reverse"
;
constexpr
char
kIsTrain
[]
=
"is_train"
;
#define GRAD_SUFFIX "@GRAD"
constexpr
char
kInputGrads
[]
=
"inputs"
GRAD_SUFFIX
;
constexpr
char
kOutputGrads
[]
=
"outputs"
GRAD_SUFFIX
;
constexpr
char
kParamGrads
[]
=
"parameters"
GRAD_SUFFIX
;
constexpr
char
kInitStateGrads
[]
=
"initial_states"
GRAD_SUFFIX
;
using
StepScopeVar
=
std
::
vector
<
framework
::
Scope
*>
;
...
...
@@ -249,9 +249,6 @@ class RecurrentOp : public RecurrentBase {
framework
::
Executor
executor
(
place
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
);
auto
&
keep_vars
=
Attr
<
std
::
vector
<
std
::
string
>>
(
kSkipEagerDeletionVars
);
VLOG
(
2
)
<<
GetSkipEagerDeletionVarsDebugString
(
keep_vars
);
auto
*
program
=
block
->
Program
();
for
(
size_t
i
=
0
;
i
<
seq_len
;
++
i
)
{
...
...
@@ -286,7 +283,8 @@ class RecurrentOp : public RecurrentBase {
// Every inputs are linked now, execute!
executor
.
Run
(
*
program
,
&
cur_scope
,
block
->
ID
(),
false
/*create_local_scope*/
,
true
/*create_vars*/
,
keep_vars
);
std
::
vector
<
std
::
string
>
()
/*skip_ref_cnt_vars*/
,
true
/*force_disable_gc*/
);
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
...
...
@@ -343,9 +341,6 @@ class RecurrentGradOp : public RecurrentBase {
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
);
auto
*
program
=
block
->
Program
();
auto
&
keep_vars
=
Attr
<
std
::
vector
<
std
::
string
>>
(
kSkipEagerDeletionVars
);
VLOG
(
2
)
<<
GetSkipEagerDeletionVarsDebugString
(
keep_vars
);
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
@@ -406,7 +401,8 @@ class RecurrentGradOp : public RecurrentBase {
// Run step block with cur_scope
executor
.
Run
(
*
program
,
&
cur_scope
,
block
->
ID
(),
false
/*create_local_scope*/
,
true
/*create_vars*/
,
keep_vars
);
std
::
vector
<
std
::
string
>
()
/*skip_ref_cnt_vars*/
,
true
/*force_disable_gc*/
);
VLOG
(
5
)
<<
"executor.Run finished "
;
...
...
@@ -583,10 +579,6 @@ if reverse is True
o o o o
)DOC"
).
SetDefault
(
false
);
AddAttr
<
bool
>
(
kIsTrain
,
""
).
SetDefault
(
true
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
kSkipEagerDeletionVars
,
"Skip vars that would "
"be used in backward ops"
)
.
SetDefault
(
std
::
vector
<
std
::
string
>
());
AddComment
(
R"DOC(
Static Length Recurrent Operator.
...
...
@@ -622,11 +614,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
this
->
OutputGrad
(
output_param
));
}
}
auto
attrs
=
this
->
Attrs
();
attrs
.
insert
({
kSkipEagerDeletionVars
,
std
::
vector
<
std
::
string
>
()});
grad
->
SetAttrMap
(
attrs
);
grad
->
SetAttrMap
(
this
->
Attrs
());
grad
->
SetBlockAttr
(
kStepBlock
,
grad_block_
[
0
]);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad
);
...
...
paddle/fluid/operators/scatter_op.cc
浏览文件 @
4c8254e3
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/scatter_op.h"
#include <memory>
#include "paddle/fluid/framework/ddim.h"
namespace
paddle
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录