Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
376c2f01
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
376c2f01
编写于
5月 11, 2021
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add default attr; test=develop
上级
0f1e7e3d
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
53 addition
and
12 deletion
+53
-12
paddle/fluid/framework/attribute.h
paddle/fluid/framework/attribute.h
+13
-1
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+5
-1
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+1
-0
paddle/fluid/framework/op_proto_maker.cc
paddle/fluid/framework/op_proto_maker.cc
+1
-0
paddle/fluid/imperative/CMakeLists.txt
paddle/fluid/imperative/CMakeLists.txt
+2
-0
paddle/fluid/imperative/execution_context.h
paddle/fluid/imperative/execution_context.h
+14
-5
paddle/fluid/imperative/prepared_operator.cc
paddle/fluid/imperative/prepared_operator.cc
+5
-2
paddle/fluid/imperative/tests/test_layer.cc
paddle/fluid/imperative/tests/test_layer.cc
+1
-1
paddle/fluid/imperative/tracer.cc
paddle/fluid/imperative/tracer.cc
+5
-0
paddle/fluid/operators/elementwise/elementwise_op.h
paddle/fluid/operators/elementwise/elementwise_op.h
+2
-0
paddle/fluid/operators/sync_batch_norm_op.cu.h
paddle/fluid/operators/sync_batch_norm_op.cu.h
+3
-1
paddle/fluid/pybind/op_function.h
paddle/fluid/pybind/op_function.h
+1
-1
未找到文件。
paddle/fluid/framework/attribute.h
浏览文件 @
376c2f01
...
@@ -208,7 +208,8 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
...
@@ -208,7 +208,8 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
class
AttrReader
{
class
AttrReader
{
public:
public:
explicit
AttrReader
(
const
AttributeMap
&
attrs
)
:
attrs_
(
attrs
)
{}
explicit
AttrReader
(
const
AttributeMap
&
attrs
,
const
AttributeMap
&
default_attrs
=
{}
)
:
attrs_
(
attrs
),
default_attrs_
(
default_attrs
)
{}
template
<
typename
T
>
template
<
typename
T
>
inline
const
T
&
Get
(
const
std
::
string
&
name
)
const
{
inline
const
T
&
Get
(
const
std
::
string
&
name
)
const
{
...
@@ -224,6 +225,7 @@ class AttrReader {
...
@@ -224,6 +225,7 @@ class AttrReader {
private:
private:
const
AttributeMap
&
attrs_
;
const
AttributeMap
&
attrs_
;
const
AttributeMap
&
default_attrs_
;
};
};
// check whether a value(attribute) fit a certain limit
// check whether a value(attribute) fit a certain limit
...
@@ -406,6 +408,14 @@ class OpAttrChecker {
...
@@ -406,6 +408,14 @@ class OpAttrChecker {
return
default_values_map
;
return
default_values_map
;
}
}
void
InitDefaultMap
()
{
for
(
const
auto
&
checker
:
attr_checkers_
)
{
checker
(
&
default_values_map_
,
true
);
}
}
const
AttributeMap
&
default_attr_map
()
const
{
return
default_values_map_
;
}
void
RecordExplicitCheckerNum
()
{
void
RecordExplicitCheckerNum
()
{
explicit_checker_num_
=
attr_checkers_
.
size
();
explicit_checker_num_
=
attr_checkers_
.
size
();
}
}
...
@@ -413,6 +423,8 @@ class OpAttrChecker {
...
@@ -413,6 +423,8 @@ class OpAttrChecker {
private:
private:
std
::
vector
<
AttrChecker
>
attr_checkers_
;
std
::
vector
<
AttrChecker
>
attr_checkers_
;
AttributeMap
default_values_map_
;
// in order to improve the efficiency of dynamic graph mode,
// in order to improve the efficiency of dynamic graph mode,
// we divede the attribute into explicit type and implicit type.
// we divede the attribute into explicit type and implicit type.
// for explicit attribute, we mean the attribute added in the customized
// for explicit attribute, we mean the attribute added in the customized
...
...
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
376c2f01
...
@@ -194,11 +194,14 @@ void HogwildWorker::TrainFilesWithProfiler() {
...
@@ -194,11 +194,14 @@ void HogwildWorker::TrainFilesWithProfiler() {
void
HogwildWorker
::
TrainFiles
()
{
void
HogwildWorker
::
TrainFiles
()
{
platform
::
SetNumThreads
(
1
);
platform
::
SetNumThreads
(
1
);
std
::
cerr
<<
"1!!!!!"
<<
std
::
endl
;
// how to accumulate fetched values here
// how to accumulate fetched values here
device_reader_
->
Start
();
device_reader_
->
Start
();
int
cur_batch
;
int
cur_batch
;
int
i
=
0
;
while
((
cur_batch
=
device_reader_
->
Next
())
>
0
)
{
while
((
cur_batch
=
device_reader_
->
Next
())
>
0
)
{
i
++
;
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
bool
need_skip
=
false
;
bool
need_skip
=
false
;
for
(
auto
t
=
0u
;
t
<
skip_ops_
.
size
();
++
t
)
{
for
(
auto
t
=
0u
;
t
<
skip_ops_
.
size
();
++
t
)
{
...
@@ -215,6 +218,7 @@ void HogwildWorker::TrainFiles() {
...
@@ -215,6 +218,7 @@ void HogwildWorker::TrainFiles() {
PrintFetchVars
();
PrintFetchVars
();
thread_scope_
->
DropKids
();
thread_scope_
->
DropKids
();
}
}
std
::
cerr
<<
"total bacth "
<<
i
<<
std
::
endl
;
#if defined PADDLE_WITH_PSCORE
#if defined PADDLE_WITH_PSCORE
if
(
thread_barrier_
)
{
if
(
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
376c2f01
...
@@ -124,6 +124,7 @@ Scope* MultiTrainer::GetWorkerScope(int thread_id) {
...
@@ -124,6 +124,7 @@ Scope* MultiTrainer::GetWorkerScope(int thread_id) {
void
MultiTrainer
::
Run
()
{
void
MultiTrainer
::
Run
()
{
VLOG
(
3
)
<<
"Going to run"
;
VLOG
(
3
)
<<
"Going to run"
;
LOG
(
ERROR
)
<<
"multi run "
<<
thread_num_
<<
"
\t
"
<<
debug_
;
for
(
int
thidx
=
0
;
thidx
<
thread_num_
;
++
thidx
)
{
for
(
int
thidx
=
0
;
thidx
<
thread_num_
;
++
thidx
)
{
if
(
!
debug_
)
{
if
(
!
debug_
)
{
threads_
.
push_back
(
threads_
.
push_back
(
...
...
paddle/fluid/framework/op_proto_maker.cc
浏览文件 @
376c2f01
...
@@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
...
@@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
op_checker_
=
attr_checker
;
op_checker_
=
attr_checker
;
Make
();
Make
();
op_checker_
->
RecordExplicitCheckerNum
();
op_checker_
->
RecordExplicitCheckerNum
();
op_checker_
->
InitDefaultMap
();
AddAttr
<
int
>
(
OpRoleAttrName
(),
"The role of this operator"
)
AddAttr
<
int
>
(
OpRoleAttrName
(),
"The role of this operator"
)
.
InEnum
(
.
InEnum
(
...
...
paddle/fluid/imperative/CMakeLists.txt
浏览文件 @
376c2f01
...
@@ -28,4 +28,6 @@ endif(NOT WIN32)
...
@@ -28,4 +28,6 @@ endif(NOT WIN32)
cc_library
(
gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function
)
cc_library
(
gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function
)
cc_binary
(
tracer_test SRCS tracer_test.cc DEPS tracer layer op_registry python pybind
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
profiler
)
add_subdirectory
(
tests
)
add_subdirectory
(
tests
)
paddle/fluid/imperative/execution_context.h
浏览文件 @
376c2f01
...
@@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
...
@@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const
framework
::
RuntimeContext
&
ctx
,
const
framework
::
RuntimeContext
&
ctx
,
const
NameVarMap
<
VarType
>&
var_base_map_in
,
const
NameVarMap
<
VarType
>&
var_base_map_in
,
const
NameVarMap
<
VarType
>&
var_base_map_out
,
const
NameVarMap
<
VarType
>&
var_base_map_out
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
,
const
framework
::
AttributeMap
&
default_attrs
)
:
ExecutionContext
(
op
,
scope
,
device_context
,
ctx
),
:
ExecutionContext
(
op
,
scope
,
device_context
,
ctx
),
var_base_map_in_
(
var_base_map_in
),
var_base_map_in_
(
var_base_map_in
),
var_base_map_out_
(
var_base_map_out
),
var_base_map_out_
(
var_base_map_out
),
attrs_
(
attrs
)
{}
attrs_
(
attrs
),
default_attrs_
(
default_attrs
){}
std
::
string
InputName
(
const
std
::
string
&
name
)
const
override
{
std
::
string
InputName
(
const
std
::
string
&
name
)
const
override
{
auto
it
=
var_base_map_in_
.
find
(
name
);
auto
it
=
var_base_map_in_
.
find
(
name
);
...
@@ -92,16 +94,22 @@ class DygraphExecutionContext : public framework::ExecutionContext {
...
@@ -92,16 +94,22 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}
}
bool
HasAttr
(
const
std
::
string
&
name
)
const
override
{
bool
HasAttr
(
const
std
::
string
&
name
)
const
override
{
return
attrs_
.
count
(
name
)
!=
0
;
return
attrs_
.
count
(
name
)
!=
0
||
default_attrs_
.
count
(
name
)
;
}
}
const
framework
::
AttributeMap
&
Attrs
()
const
override
{
return
attrs_
;
}
const
framework
::
AttributeMap
&
Attrs
()
const
override
{
return
attrs_
;
}
const
framework
::
Attribute
&
GetAttr
(
const
std
::
string
&
name
)
const
override
{
const
framework
::
Attribute
&
GetAttr
(
const
std
::
string
&
name
)
const
override
{
auto
it
=
attrs_
.
find
(
name
);
auto
it
=
attrs_
.
find
(
name
);
bool
find
=
(
it
!=
attrs_
.
end
()
);
if
(
it
==
attrs_
.
end
()
)
{
it
=
default_attrs_
.
find
(
name
);
find
=
(
it
!=
default_attrs_
.
end
()
);
}
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
it
,
attrs_
.
end
()
,
find
,
false
,
platform
::
errors
::
NotFound
(
"can not find [%s] in attrs"
,
name
));
platform
::
errors
::
NotFound
(
"can not find [%s] in attrs"
,
name
));
return
it
->
second
;
return
it
->
second
;
...
@@ -192,6 +200,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
...
@@ -192,6 +200,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const
NameVarMap
<
VarType
>&
var_base_map_in_
;
const
NameVarMap
<
VarType
>&
var_base_map_in_
;
const
NameVarMap
<
VarType
>&
var_base_map_out_
;
const
NameVarMap
<
VarType
>&
var_base_map_out_
;
const
framework
::
AttributeMap
&
attrs_
;
const
framework
::
AttributeMap
&
attrs_
;
const
framework
::
AttributeMap
&
default_attrs_
;
};
};
}
// namespace imperative
}
// namespace imperative
...
...
paddle/fluid/imperative/prepared_operator.cc
浏览文件 @
376c2f01
...
@@ -104,10 +104,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
...
@@ -104,10 +104,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
}
}
#endif
#endif
//auto *attr_checker = op_->Info().Checker();
// 1. get expected kernel key
// 1. get expected kernel key
auto
expected_kernel_key
=
auto
expected_kernel_key
=
op
.
GetExpectedKernelType
(
DygraphExecutionContext
<
VarType
>
(
op
.
GetExpectedKernelType
(
DygraphExecutionContext
<
VarType
>
(
op
,
framework
::
Scope
(),
*
dev_ctx
,
ctx
,
ins
,
outs
,
attrs
))
;
op
,
framework
::
Scope
(),
*
dev_ctx
,
ctx
,
ins
,
outs
,
attrs
,
{}
))
;
VLOG
(
3
)
<<
"expected_kernel_key:"
<<
expected_kernel_key
;
VLOG
(
3
)
<<
"expected_kernel_key:"
<<
expected_kernel_key
;
// 2. check if op[type] has kernel registered.
// 2. check if op[type] has kernel registered.
...
@@ -172,8 +173,10 @@ static void PreparedOpRunImpl(
...
@@ -172,8 +173,10 @@ static void PreparedOpRunImpl(
static_cast
<
const
framework
::
OperatorWithKernel
&>
(
op
).
InferShape
(
static_cast
<
const
framework
::
OperatorWithKernel
&>
(
op
).
InferShape
(
&
infer_shape_ctx
);
&
infer_shape_ctx
);
auto
*
attr_checker
=
op
.
Info
().
Checker
();
func
(
DygraphExecutionContext
<
VarType
>
(
op
,
scope
,
*
dev_ctx
,
ctx
,
ins
,
outs
,
func
(
DygraphExecutionContext
<
VarType
>
(
op
,
scope
,
*
dev_ctx
,
ctx
,
ins
,
outs
,
attrs
));
attrs
,
attr_checker
->
default_attr_map
()
));
/**
/**
* [ Why need handle complex gradient to real gradient? ]
* [ Why need handle complex gradient to real gradient? ]
...
...
paddle/fluid/imperative/tests/test_layer.cc
浏览文件 @
376c2f01
...
@@ -358,7 +358,7 @@ TEST(test_layer, test_dygraph_execution_context) {
...
@@ -358,7 +358,7 @@ TEST(test_layer, test_dygraph_execution_context) {
framework
::
Scope
scope
;
framework
::
Scope
scope
;
DygraphExecutionContext
<
imperative
::
VarBase
>
dy_exe_context
(
DygraphExecutionContext
<
imperative
::
VarBase
>
dy_exe_context
(
*
(
op
.
get
()),
scope
,
*
dev_ctx
,
ctx
,
ins
,
outs
,
concat_att_map
);
*
(
op
.
get
()),
scope
,
*
dev_ctx
,
ctx
,
ins
,
outs
,
concat_att_map
,
{}
);
ASSERT_EQ
(
dy_exe_context
.
InputSize
(
"X"
),
1u
);
ASSERT_EQ
(
dy_exe_context
.
InputSize
(
"X"
),
1u
);
ASSERT_EQ
(
dy_exe_context
.
InputName
(
"X"
),
"vin"
);
ASSERT_EQ
(
dy_exe_context
.
InputName
(
"X"
),
"vin"
);
...
...
paddle/fluid/imperative/tracer.cc
浏览文件 @
376c2f01
...
@@ -149,11 +149,16 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
...
@@ -149,11 +149,16 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
}
}
}
}
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
type
,
{},
{},
{},
false
);
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
type
,
{},
{},
{},
false
);
const
auto
&
op_info
=
op
->
Info
();
const
auto
&
op_info
=
op
->
Info
();
auto
*
attr_checker
=
op_info
.
Checker
();
auto
*
attr_checker
=
op_info
.
Checker
();
if
(
attr_checker
)
{
if
(
attr_checker
)
{
attr_checker
->
Check
(
&
attrs
,
true
);
attr_checker
->
Check
(
&
attrs
,
true
);
}
}
NameVarBaseMap
new_ins
=
ins
;
NameVarBaseMap
new_ins
=
ins
;
if
(
enable_autocast_
)
{
if
(
enable_autocast_
)
{
...
...
paddle/fluid/operators/elementwise/elementwise_op.h
浏览文件 @
376c2f01
...
@@ -109,6 +109,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
...
@@ -109,6 +109,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto
input_data_type
=
auto
input_data_type
=
OperatorWithKernel
::
IndicateOrPromoteVarDataTypes
(
ctx
,
"X"
,
"Y"
);
OperatorWithKernel
::
IndicateOrPromoteVarDataTypes
(
ctx
,
"X"
,
"Y"
);
/*
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
...
@@ -116,6 +117,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
...
@@ -116,6 +117,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
framework::LibraryType::kMKLDNN);
framework::LibraryType::kMKLDNN);
}
}
#endif
#endif
*/
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
...
...
paddle/fluid/operators/sync_batch_norm_op.cu.h
浏览文件 @
376c2f01
...
@@ -186,12 +186,14 @@ void SyncBatchNormFunctor(const framework::ExecutionContext &ctx,
...
@@ -186,12 +186,14 @@ void SyncBatchNormFunctor(const framework::ExecutionContext &ctx,
framework
::
DataLayout
::
kNHWC
><<<
grid
,
threads
,
0
,
stream
>>>
(
framework
::
DataLayout
::
kNHWC
><<<
grid
,
threads
,
0
,
stream
>>>
(
x_d
,
N
,
H
*
W
*
D
,
C
,
stats
);
x_d
,
N
,
H
*
W
*
D
,
C
,
stats
);
}
}
/*
Tensor c_g_st;
Tensor c_g_st;
auto *c_g_st_d = c_g_st.mutable_data<BatchNormParamType<T>>(
auto *c_g_st_d = c_g_st.mutable_data<BatchNormParamType<T>>(
{2 * C + 1}, platform::CPUPlace());
{2 * C + 1}, platform::CPUPlace());
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0);
memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0);
*/
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto
*
comm
=
dev_ctx
.
nccl_comm
();
auto
*
comm
=
dev_ctx
.
nccl_comm
();
...
...
paddle/fluid/pybind/op_function.h
浏览文件 @
376c2f01
...
@@ -177,4 +177,4 @@ static inline void HandleViewBetweenInputAndOutput(
...
@@ -177,4 +177,4 @@ static inline void HandleViewBetweenInputAndOutput(
}
// namespace paddle
}
// namespace paddle
// This include must be the last line
// This include must be the last line
#include "paddle/fluid/pybind/op_function_impl.h"
#include "paddle/fluid/pybind/op_function_impl
_new
.h"
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录