Commit b75bd29c: Remove debug info
Authored on December 12, 2018 by minqiyang
Parent: 7a43e517

Showing 6 changed files with 224 additions and 263 deletions (+224 / -263)
paddle/fluid/framework/details/computation_op_handle.cc    +8   -37
paddle/fluid/framework/details/op_handle_base.cc            +1   -1
paddle/fluid/framework/ir/graph.cc                          +83  -49
paddle/fluid/framework/operator.cc                          +60  -100
paddle/fluid/operators/elementwise/elementwise_op.h         +33  -36
paddle/fluid/operators/optimizers/adam_op.cc                +39  -40
paddle/fluid/framework/details/computation_op_handle.cc

@@ -26,46 +26,17 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
       scope_(scope),
       place_(place) {}
 
-struct RecordTime {
-  RecordTime(const std::string &name, const std::string &type)
-      : name_(name), type_(type), start_(std::chrono::system_clock::now()) {}
-
-  ~RecordTime() {
-    if (type_ == "elementsize_add") {
-      end_ = std::chrono::system_clock::now();
-      std::chrono::duration<double> diff = end_ - start_;
-      VLOG(1) << name_ << " " << type_ << " time record: " << diff.count();
-    }
-  }
-
-  std::string name_;
-  std::string type_;
-  std::chrono::system_clock::time_point start_;
-  std::chrono::system_clock::time_point end_;
-};
-
 void ComputationOpHandle::RunImpl() {
-  {
-    RecordTime rt("ComputationOpHandle::RunImpl", "Wait");
-    WaitInputVarGenerated(place_);
-  }
-
-  Scope *scope = nullptr;
-  {
-    RecordTime rt("ComputationOpHandle::RunImpl", "PrepareScope");
-    scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
-  }
-
-  {
-    RecordTime rt("ComputationOpHandle::RunImpl", "ReallyRun " + op_->Type());
-
-    auto run_func = [this, scope]() { op_->Run(*scope, place_); };
-
-    if (is_lock_and_record_event_free_) {
-      run_func();
-    } else {
-      this->RunAndRecordEvent(run_func);
-    }
-  }
-}
+  WaitInputVarGenerated(place_);
+
+  auto run_func = [this]() {
+    op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
+  };
+
+  if (is_lock_and_record_event_free_) {
+    run_func();
+  } else {
+    this->RunAndRecordEvent(run_func);
+  }
+}
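The RecordTime helper deleted above is a scope-based (RAII) timer: it stamps the clock in its constructor and reports the elapsed time in its destructor, so whatever runs inside the enclosing braces gets measured, including early exits. Below is a minimal, self-contained sketch of the same pattern for readers who want to try it outside Paddle. It is illustrative only: ScopedTimer is an invented name and std::cerr stands in for VLOG; this is not the project's code.

// scoped_timer.cc : minimal RAII timing sketch (illustrative, not Paddle code)
#include <chrono>
#include <iostream>
#include <string>

struct ScopedTimer {
  explicit ScopedTimer(std::string label)
      : label_(std::move(label)), start_(std::chrono::steady_clock::now()) {}

  // The destructor runs when the enclosing scope ends, so the report covers
  // everything executed inside that scope, including early returns.
  ~ScopedTimer() {
    std::chrono::duration<double> diff =
        std::chrono::steady_clock::now() - start_;
    std::cerr << label_ << " took " << diff.count() << " s\n";
  }

  std::string label_;
  std::chrono::steady_clock::time_point start_;
};

int main() {
  ScopedTimer whole("main");
  {
    ScopedTimer busy("busy_loop");
    long sink = 0;
    for (long i = 0; i < 10000000; ++i) sink += i;
    std::cerr << "sink = " << sink << "\n";  // keep the loop observable
  }  // "busy_loop" reports here
  return 0;  // "main" reports here
}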
paddle/fluid/framework/details/op_handle_base.cc

@@ -41,7 +41,7 @@ OpHandleBase::~OpHandleBase() {
 
 void OpHandleBase::Run(bool use_cuda) {
 #ifdef PADDLE_WITH_CUDA
-  if (events_.empty() && use_cuda && !dev_ctxes_.empty()) {
+  if (events_.empty() && use_cuda) {
     for (auto &p : dev_ctxes_) {
       int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
       PADDLE_ENFORCE(cudaSetDevice(dev_id));
paddle/fluid/framework/ir/graph.cc

@@ -20,6 +20,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
 
+DEFINE_bool(enforce_when_check_program, true,
+            "Checking whether the program is correct or not. We will log "
+            "errors rather than throwing exceptions if this flag turned off");
+
 namespace paddle {
 namespace framework {
 namespace ir {
@@ -28,55 +32,85 @@ namespace {
 
 void CheckProgram(const ProgramDesc &program) {
 #define _INT(role) static_cast<int>(role)
 
-  // std::map<int, bool> visit;
-  // for (OpDesc *op : program.Block(0).AllOps()) {
-  //   // For backward compatibility, some program doesn't have role added.
-  //   if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
-  //   int role_id =
-  //       boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
-  //   visit[role_id] = true;
-  //   switch (role_id) {
-  //     case _INT(OpRole::kForward):
-  //       if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
-  //         LOG(ERROR)
-  //             << "Cannot add backward operator before forward operator %s."
-  //             << op->Type();
-  //       }
-  //       break;
-  //     case _INT(OpRole::kBackward):
-  //     case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
-  //       PADDLE_ENFORCE(
-  //           visit.find(_INT(OpRole::kOptimize)) == visit.end(),
-  //           "Cannot add backward operator %s after optimize operator.",
-  //           op->Type());
-  //       break;
-  //     case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
-  //       PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
-  //                                 _INT(OpRole::kLoss)) == visit.end(),
-  //                      "Cannot add backward|loss operator before "
-  //                      "forward|loss operator %s.",
-  //                      op->Type());
-  //       PADDLE_ENFORCE(
-  //           visit.find(_INT(OpRole::kOptimize)) == visit.end(),
-  //           "Cannot add forward|loss operator %s after optimize operator.",
-  //           op->Type());
-  //       break;
-  //     case _INT(OpRole::kOptimize):
-  //     case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
-  //       PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
-  //                      "Optimize operators %s must follow backward operator.",
-  //                      op->Type());
-  //       break;
-  //     case _INT(OpRole::kLRSched):
-  //     case _INT(OpRole::kDist):
-  //     case _INT(OpRole::kRPC):
-  //     case _INT(OpRole::kNotSpecified):
-  //       break;
-  //     default:
-  //       LOG(FATAL) << "Unknown operator role. Don't add new role because "
-  //                     "you don't know what you are doing.";
-  //   }
-  // }
+  std::map<int, bool> visit;
+  for (OpDesc *op : program.Block(0).AllOps()) {
+    // For backward compatibility, some program doesn't have role added.
+    if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
+    int role_id =
+        boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
+    visit[role_id] = true;
+    switch (role_id) {
+      case _INT(OpRole::kForward):
+        if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
+          LOG(ERROR)
+              << "Cannot add backward operator before forward operator %s."
+              << op->Type();
+        }
+        break;
+      case _INT(OpRole::kBackward):
+      case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
+        if (!FLAGS_enforce_when_check_program) {
+          PADDLE_ENFORCE(
+              visit.find(_INT(OpRole::kOptimize)) == visit.end(),
+              "Cannot add backward operator %s after optimize operator.",
+              op->Type());
+        } else {
+          if (visit.find(_INT(OpRole::kOptimize)) != visit.end()) {
+            LOG(ERROR)
+                << "Cannot add backward operator %s after optimize operator."
+                << op->Type();
+          }
+        }
+        break;
+      case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
+        if (!FLAGS_enforce_when_check_program) {
+          PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
+                                    _INT(OpRole::kLoss)) == visit.end(),
+                         "Cannot add backward|loss operator before "
+                         "forward|loss operator %s.",
+                         op->Type());
+          PADDLE_ENFORCE(
+              visit.find(_INT(OpRole::kOptimize)) == visit.end(),
+              "Cannot add forward|loss operator %s after optimize operator.",
+              op->Type());
+        } else {
+          if (visit.find(_INT(OpRole::kBackward) | _INT(OpRole::kLoss)) !=
+              visit.end()) {
+            LOG(ERROR) << "Cannot add backward|loss operator before "
+                       << "forward|loss operator %s." << op->Type();
+          }
+          if (visit.find(_INT(OpRole::kOptimize)) != visit.end()) {
+            LOG(ERROR) << "Cannot add forward|loss operator %s after optimize "
+                          "operator."
+                       << op->Type();
+          }
+        }
+        break;
+      case _INT(OpRole::kOptimize):
+      case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
+        if (!FLAGS_enforce_when_check_program) {
+          PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
+                         "Optimize operators %s must follow backward operator.",
+                         op->Type());
+        } else {
+          if (visit.find(_INT(OpRole::kBackward)) == visit.end()) {
+            LOG(ERROR)
+                << "Optimize operators %s must follow backward operator."
+                << op->Type();
+          }
+        }
+        break;
+      case _INT(OpRole::kLRSched):
+      case _INT(OpRole::kDist):
+      case _INT(OpRole::kRPC):
+      case _INT(OpRole::kNotSpecified):
+        break;
+      default:
+        LOG(FATAL) << "Unknown operator role. Don't add new role because "
+                      "you don't know what you are doing.";
+    }
+  }
 #undef _INT
 }
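The graph.cc hunk above declares the gflags-style flag FLAGS_enforce_when_check_program and then branches on it inside CheckProgram. For reference, here is a minimal standalone sketch of how such a DEFINE_bool flag is declared and read with the gflags library. The parsing call and program structure are illustrative assumptions; Paddle wires its flags through its own initialization path rather than a plain main() like this.

// flag_demo.cc : standalone DEFINE_bool / FLAGS_* sketch (assumes gflags)
#include <gflags/gflags.h>
#include <iostream>

DEFINE_bool(enforce_when_check_program, true,
            "If true, malformed programs raise exceptions; if false, errors "
            "are only logged.");

int main(int argc, char* argv[]) {
  // In Paddle the flag is parsed by the framework's init code; a plain
  // binary would parse it like this.
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);

  if (FLAGS_enforce_when_check_program) {
    std::cout << "strict mode: program check failures throw\n";
  } else {
    std::cout << "lenient mode: program check failures are only logged\n";
  }
  return 0;
}

// Example invocation:  ./flag_demo --enforce_when_check_program=false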
paddle/fluid/framework/operator.cc

@@ -701,125 +701,85 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
   this->InferShape(&infer_shape_ctx);
 }
 
-struct RecordTime {
-  RecordTime(const std::string& name, const std::string& type)
-      : name_(name), type_(type), start_(std::chrono::system_clock::now()) {}
-
-  void inline stop() {
-    end_ = std::chrono::system_clock::now();
-    std::chrono::duration<double> diff = end_ - start_;
-    VLOG(1) << name_ << " " << type_ << " time record: " << diff.count();
-  }
-
-  ~RecordTime() {
-    if (type_ == "elementwise_add") {
-      stop();
-    }
-    // stop();
-  }
-
-  std::string name_;
-  std::string type_;
-  std::chrono::system_clock::time_point start_;
-  std::chrono::system_clock::time_point end_;
-};
-
 void OperatorWithKernel::RunImpl(const Scope& scope,
                                  const platform::Place& place) const {
-  RecordTime rt("OperatorWithKernel::All", type_);
-  {
-    RecordTime rt("OperatorWithKernel::InferShape", type_);
-    RuntimeInferShapeContext infer_shape_ctx(*this, scope);
-    this->InferShape(&infer_shape_ctx);
-  }
-
-  {
-    RecordTime* rt_1 = new RecordTime("OperatorWithKernel::Compute1", type_);
-    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    auto* dev_ctx = pool.Get(place);
+  RuntimeInferShapeContext infer_shape_ctx(*this, scope);
+  this->InferShape(&infer_shape_ctx);
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(place);
 
-    // check if op[type] has kernel registered.
-    auto& all_op_kernels = AllOpKernels();
-    auto kernels_iter = all_op_kernels.find(type_);
-    if (kernels_iter == all_op_kernels.end()) {
-      PADDLE_THROW(
-          "There are no kernels which are registered in the %s operator.",
-          type_);
-    }
+  // check if op[type] has kernel registered.
+  auto& all_op_kernels = AllOpKernels();
+  auto kernels_iter = all_op_kernels.find(type_);
+  if (kernels_iter == all_op_kernels.end()) {
+    PADDLE_THROW(
+        "There are no kernels which are registered in the %s operator.", type_);
+  }
 
-    OpKernelMap& kernels = kernels_iter->second;
+  OpKernelMap& kernels = kernels_iter->second;
 
-    // TODO(dzhwinter) : kernel fallback mechanism will be added when all the
-    // transform functions are ready.
+  // TODO(dzhwinter) : kernel fallback mechanism will be added when all the
+  // transform functions are ready.
 
-    // for (auto& candidate : kKernelPriority) {
-    //   Do selection
-    // }
+  // for (auto& candidate : kKernelPriority) {
+  //   Do selection
+  // }
 
-    auto expected_kernel_key =
-        this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
-    VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
+  auto expected_kernel_key =
+      this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
+  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
 
-    auto kernel_iter = kernels.find(expected_kernel_key);
+  auto kernel_iter = kernels.find(expected_kernel_key);
 #ifdef PADDLE_WITH_MKLDNN
-    // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
-    if (kernel_iter == kernels.end() &&
-        expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
-      VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
-      expected_kernel_key.library_type_ = LibraryType::kPlain;
-      expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
-      kernel_iter = kernels.find(expected_kernel_key);
-    }
+  // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
+  if (kernel_iter == kernels.end() &&
+      expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
+    VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
+    expected_kernel_key.library_type_ = LibraryType::kPlain;
+    expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
+    kernel_iter = kernels.find(expected_kernel_key);
+  }
 #endif
-    if (kernel_iter == kernels.end()) {
-      PADDLE_THROW("op %s does not have kernel for %s", type_,
-                   KernelTypeToString(expected_kernel_key));
-    }
+  if (kernel_iter == kernels.end()) {
+    PADDLE_THROW("op %s does not have kernel for %s", type_,
+                 KernelTypeToString(expected_kernel_key));
+  }
 
-    // do data transformScope &transfer_scope;
-    std::vector<std::string> transfered_inplace_vars;
-    Scope* transfer_scope = nullptr;
-    // auto* transfer_scope =
-    //     TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);
+  // do data transformScope &transfer_scope;
+  std::vector<std::string> transfered_inplace_vars;
+  auto* transfer_scope =
+      TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);
 
-    // exec scope is the scope that kernel actually executed on.
-    const Scope& exec_scope = scope;
-    // const Scope& exec_scope =
-    //     (transfer_scope == nullptr ? scope : *transfer_scope);
+  // exec scope is the scope that kernel actually executed on.
+  const Scope& exec_scope =
+      (transfer_scope == nullptr ? scope : *transfer_scope);
 
-    if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
-      dev_ctx = pool.Get(expected_kernel_key.place_);
-    }
-    delete rt_1;
+  if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
+    dev_ctx = pool.Get(expected_kernel_key.place_);
+  }
 
-    RecordTime* rt_2 = new RecordTime("OperatorWithKernel::Compute2", type_);
-    kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
-    delete rt_2;
+  kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
 
-    RecordTime* rt_3 = new RecordTime("OperatorWithKernel::Compute3", type_);
-    if (!transfered_inplace_vars.empty()) {
-      // there is inplace variable has been transfered.
-      TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
-    }
+  if (!transfered_inplace_vars.empty()) {
+    // there is inplace variable has been transfered.
+    TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
+  }
 
-    /*For profiling/benchmark only*/
-    if (FLAGS_benchmark) {
-      dev_ctx->Wait();
-    }
+  /*For profiling/benchmark only*/
+  if (FLAGS_benchmark) {
+    dev_ctx->Wait();
+  }
 
-    if (FLAGS_check_nan_inf) {
-      for (auto& vname : OutputVars(true)) {
-        auto* var = exec_scope.FindVar(vname);
-        if (var == nullptr) continue;
-        if (var->IsType<framework::LoDTensor>()) {
-          CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
-        } else if (var->IsType<framework::SelectedRows>()) {
-          CheckTensorNANOrInf(vname,
-                              var->Get<framework::SelectedRows>().value());
-        }
-      }
-      delete rt_3;
-    }
-  }
-}
+  if (FLAGS_check_nan_inf) {
+    for (auto& vname : OutputVars(true)) {
+      auto* var = exec_scope.FindVar(vname);
+      if (var == nullptr) continue;
+      if (var->IsType<framework::LoDTensor>()) {
+        CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
+      } else if (var->IsType<framework::SelectedRows>()) {
+        CheckTensorNANOrInf(vname,
+                            var->Get<framework::SelectedRows>().value());
+      }
+    }
+  }
+}
 
 void OperatorWithKernel::TransferInplaceVarsBack(
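The restored RunImpl above selects a kernel by looking the operator type up in a registry, retrying with a plain key when the MKLDNN variant is missing, and throwing if nothing matches. The sketch below shows just that lookup-and-fallback flow with ordinary standard-library types; the registry shape, key strings, and function names are illustrative stand-ins, not Paddle's actual data structures.

// kernel_lookup.cc : lookup-then-fallback sketch (illustrative, not Paddle code)
#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

using Kernel = std::function<void()>;
using KernelMap = std::map<std::string, Kernel>;  // key stands in for place/layout/library

void Run(const KernelMap& kernels, std::string key) {
  auto it = kernels.find(key);
  if (it == kernels.end() && key == "MKLDNN") {
    // mirror of the MKLDNN workaround: retry with the plain key
    std::cerr << "missing MKLDNN kernel: falling back to PLAIN\n";
    key = "PLAIN";
    it = kernels.find(key);
  }
  if (it == kernels.end()) {
    throw std::runtime_error("no kernel registered for key " + key);
  }
  it->second();  // run the selected kernel
}

int main() {
  KernelMap kernels{{"PLAIN", [] { std::cout << "plain kernel\n"; }}};
  Run(kernels, "MKLDNN");  // falls back to the plain kernel
  return 0;
}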
paddle/fluid/operators/elementwise/elementwise_op.h

@@ -33,37 +33,34 @@ class ElementwiseOp : public framework::OperatorWithKernel {
   using Tensor = framework::Tensor;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    if (!ctx->IsRuntime()) {
-      PADDLE_ENFORCE(ctx->HasInput("X"),
-                     "Input(X) of elementwise op should not be null.");
-      PADDLE_ENFORCE(ctx->HasInput("Y"),
-                     "Input(Y) of elementwise op should not be null.");
-      PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                     "Output(Out) of elementwise op should not be null.");
-
-      PADDLE_ENFORCE(
-          ctx->GetInputsVarType("Y").front() ==
-              framework::proto::VarType::LOD_TENSOR,
-          "The input var's type should be LoDTensor, but the "
-          "received is %s [%s]",
-          ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front());
-
-      if (ctx->GetInputsVarType("X").front() ==
-          framework::proto::VarType::LOD_TENSOR) {
-        auto x_dim = ctx->GetInputDim("X");
-        auto y_dim = ctx->GetInputDim("Y");
-        PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
-                          "Rank of first input must >= rank of second input.");
-      } else if (ctx->GetInputsVarType("X").front() ==
-                 framework::proto::VarType::SELECTED_ROWS) {
-        PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) &&
-                           (ctx->GetInputDim("Y")[0] == 1),
-                       "For elementwise_op, if X is Sparse, "
-                       "Y must be scalar.");
-      } else {
-        PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
-                     ctx->GetInputsVarType("X").front());
-      }
-    }
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of elementwise op should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"),
+                   "Input(Y) of elementwise op should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of elementwise op should not be null.");
+
+    PADDLE_ENFORCE(
+        ctx->GetInputsVarType("Y").front() ==
+            framework::proto::VarType::LOD_TENSOR,
+        "The input var's type should be LoDTensor, but the received is %s [%s]",
+        ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front());
+
+    if (ctx->GetInputsVarType("X").front() ==
+        framework::proto::VarType::LOD_TENSOR) {
+      auto x_dim = ctx->GetInputDim("X");
+      auto y_dim = ctx->GetInputDim("Y");
+      PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
+                        "Rank of first input must >= rank of second input.");
+    } else if (ctx->GetInputsVarType("X").front() ==
+               framework::proto::VarType::SELECTED_ROWS) {
+      PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) &&
+                         (ctx->GetInputDim("Y")[0] == 1),
+                     "For elementwise_op, if X is Sparse, "
+                     "Y must be scalar.");
+    } else {
+      PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
+                   ctx->GetInputsVarType("X").front());
+    }
 
     ctx->ShareDim("X", /*->*/ "Out");

@@ -128,7 +125,7 @@ The equation is:
 
 $$%s$$
 
-- $X$: a tensor of any dimension. 
+- $X$: a tensor of any dimension.
 - $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$.
 
 There are two cases for this operator:

@@ -138,10 +135,10 @@ There are two cases for this operator:
 
 For case 2:
 
-1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index 
-   for broadcasting $Y$ onto $X$. 
+1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
+   for broadcasting $Y$ onto $X$.
 2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
-3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of 
+3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
    subsequence, such as shape(Y) = (2, 1) => (2).
 
 For example:

@@ -155,7 +152,7 @@ For example:
 
   shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
   shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
 
-The inputs $X$ and $Y$ can carry the different LoD information. 
+The inputs $X$ and $Y$ can carry the different LoD information.
 But the output only shares the LoD information with the input $X$.
 
 )DOC",
paddle/fluid/operators/optimizers/adam_op.cc

@@ -23,57 +23,56 @@ class AdamOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    // PADDLE_ENFORCE(ctx->HasInput("Param"),
-    //                "Input(Param) of AdamOp should not be null.");
-    // PADDLE_ENFORCE(ctx->HasInput("Grad"),
-    //                "Input(Grad) of AdamOp should not be null.");
-    // PADDLE_ENFORCE(ctx->HasInput("Moment1"),
-    //                "Input(Moment1) of AdamOp should not be null.");
-    // PADDLE_ENFORCE(ctx->HasInput("Moment2"),
-    //                "Input(Moment2) of AdamOp should not be null.");
-    // PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
-    //                "Input(LearningRate) of AdamOp should not be null.");
-    // PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
-    //                "Input(Beta1Pow) of AdamOp should not be null.");
-    // PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
-    //                "Input(Beta2Pow) of AdamOp should not be null.");
-    // PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
-    //                "Output(ParamOut) of AdamOp should not be null.");
-    // PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
-    //                "Output(Moment1Out) of AdamOp should not be null.");
-    // PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
-    //                "Output(Moment2Out) of AdamOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Param"),
+                   "Input(Param) of AdamOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Grad"),
+                   "Input(Grad) of AdamOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Moment1"),
+                   "Input(Moment1) of AdamOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Moment2"),
+                   "Input(Moment2) of AdamOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
+                   "Input(LearningRate) of AdamOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
+                   "Input(Beta1Pow) of AdamOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
+                   "Input(Beta2Pow) of AdamOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
+                   "Output(ParamOut) of AdamOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
+                   "Output(Moment1Out) of AdamOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
+                   "Output(Moment2Out) of AdamOp should not be null.");
 
     auto lr_dims = ctx->GetInputDim("LearningRate");
-    // PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
-    //                   "Learning rate should have 1 dimension");
+    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
+                      "Learning rate should have 1 dimension");
     auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
-    // PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
-    //                   "Beta1 power accumulator should have 1 dimension");
+    PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
+                      "Beta1 power accumulator should have 1 dimension");
    auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
-    // PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
-    //                   "Beta2 power accumulator should have 1 dimension");
+    PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
+                      "Beta2 power accumulator should have 1 dimension");
 
     auto param_dims = ctx->GetInputDim("Param");
-    // if (ctx->GetInputsVarType("Grad")[0] ==
-    //     framework::proto::VarType::LOD_TENSOR) {
-    //   PADDLE_ENFORCE_EQ(
-    //       param_dims, ctx->GetInputDim("Grad"),
-    //       "Param and Grad input of AdamOp should have same dimension");
-    // }
-    // PADDLE_ENFORCE_EQ(
-    //     param_dims, ctx->GetInputDim("Moment1"),
-    //     "Param and Moment1 input of AdamOp should have same dimension");
-    // PADDLE_ENFORCE_EQ(
-    //     param_dims, ctx->GetInputDim("Moment2"),
-    //     "Param and Moment2 input of AdamOp should have same dimension");
+    if (ctx->GetInputsVarType("Grad")[0] ==
+        framework::proto::VarType::LOD_TENSOR) {
+      PADDLE_ENFORCE_EQ(
+          param_dims, ctx->GetInputDim("Grad"),
+          "Param and Grad input of AdamOp should have same dimension");
+    }
+    PADDLE_ENFORCE_EQ(
+        param_dims, ctx->GetInputDim("Moment1"),
+        "Param and Moment1 input of AdamOp should have same dimension");
+    PADDLE_ENFORCE_EQ(
+        param_dims, ctx->GetInputDim("Moment2"),
+        "Param and Moment2 input of AdamOp should have same dimension");
 
     ctx->SetOutputDim("ParamOut", param_dims);
     ctx->SetOutputDim("Moment1Out", param_dims);
     ctx->SetOutputDim("Moment2Out", param_dims);
   }
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type =
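adam_op.cc re-enables the PADDLE_ENFORCE / PADDLE_ENFORCE_EQ shape checks that the debug commit had commented out. The checks follow a common pattern: evaluate a condition and throw with a descriptive message when it fails. A minimal sketch of such a macro is below; MY_ENFORCE is an invented stand-in and deliberately much simpler than Paddle's real macro, which also formats arguments and records source location.

// enforce_demo.cc : minimal "enforce" style check (illustrative, not Paddle's macro)
#include <iostream>
#include <sstream>
#include <stdexcept>

#define MY_ENFORCE(cond, msg)                               \
  do {                                                      \
    if (!(cond)) {                                          \
      std::ostringstream oss;                               \
      oss << "Enforce failed: " << #cond << " - " << (msg); \
      throw std::runtime_error(oss.str());                  \
    }                                                       \
  } while (0)

int main() {
  int lr_rank = 1;
  try {
    MY_ENFORCE(lr_rank == 1, "Learning rate should have 1 dimension");
    MY_ENFORCE(lr_rank == 2, "demonstrating a failing check");
  } catch (const std::runtime_error& e) {
    std::cerr << e.what() << "\n";  // prints the failed condition and message
  }
  return 0;
}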