Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
376c2f01
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
376c2f01
编写于
5月 11, 2021
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add default attr; test=develop
上级
0f1e7e3d
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
53 addition
and
12 deletion
+53
-12
paddle/fluid/framework/attribute.h
paddle/fluid/framework/attribute.h
+13
-1
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+5
-1
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+1
-0
paddle/fluid/framework/op_proto_maker.cc
paddle/fluid/framework/op_proto_maker.cc
+1
-0
paddle/fluid/imperative/CMakeLists.txt
paddle/fluid/imperative/CMakeLists.txt
+2
-0
paddle/fluid/imperative/execution_context.h
paddle/fluid/imperative/execution_context.h
+14
-5
paddle/fluid/imperative/prepared_operator.cc
paddle/fluid/imperative/prepared_operator.cc
+5
-2
paddle/fluid/imperative/tests/test_layer.cc
paddle/fluid/imperative/tests/test_layer.cc
+1
-1
paddle/fluid/imperative/tracer.cc
paddle/fluid/imperative/tracer.cc
+5
-0
paddle/fluid/operators/elementwise/elementwise_op.h
paddle/fluid/operators/elementwise/elementwise_op.h
+2
-0
paddle/fluid/operators/sync_batch_norm_op.cu.h
paddle/fluid/operators/sync_batch_norm_op.cu.h
+3
-1
paddle/fluid/pybind/op_function.h
paddle/fluid/pybind/op_function.h
+1
-1
未找到文件。
paddle/fluid/framework/attribute.h
浏览文件 @
376c2f01
...
...
@@ -208,7 +208,8 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
class
AttrReader
{
public:
explicit
AttrReader
(
const
AttributeMap
&
attrs
)
:
attrs_
(
attrs
)
{}
explicit
AttrReader
(
const
AttributeMap
&
attrs
,
const
AttributeMap
&
default_attrs
=
{}
)
:
attrs_
(
attrs
),
default_attrs_
(
default_attrs
)
{}
template
<
typename
T
>
inline
const
T
&
Get
(
const
std
::
string
&
name
)
const
{
...
...
@@ -224,6 +225,7 @@ class AttrReader {
private:
const
AttributeMap
&
attrs_
;
const
AttributeMap
&
default_attrs_
;
};
// check whether a value(attribute) fit a certain limit
...
...
@@ -406,6 +408,14 @@ class OpAttrChecker {
return
default_values_map
;
}
void
InitDefaultMap
()
{
for
(
const
auto
&
checker
:
attr_checkers_
)
{
checker
(
&
default_values_map_
,
true
);
}
}
const
AttributeMap
&
default_attr_map
()
const
{
return
default_values_map_
;
}
void
RecordExplicitCheckerNum
()
{
explicit_checker_num_
=
attr_checkers_
.
size
();
}
...
...
@@ -413,6 +423,8 @@ class OpAttrChecker {
private:
std
::
vector
<
AttrChecker
>
attr_checkers_
;
AttributeMap
default_values_map_
;
// in order to improve the efficiency of dynamic graph mode,
// we divede the attribute into explicit type and implicit type.
// for explicit attribute, we mean the attribute added in the customized
...
...
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
376c2f01
...
...
@@ -195,10 +195,13 @@ void HogwildWorker::TrainFilesWithProfiler() {
void
HogwildWorker
::
TrainFiles
()
{
platform
::
SetNumThreads
(
1
);
std
::
cerr
<<
"1!!!!!"
<<
std
::
endl
;
// how to accumulate fetched values here
device_reader_
->
Start
();
int
cur_batch
;
int
i
=
0
;
while
((
cur_batch
=
device_reader_
->
Next
())
>
0
)
{
i
++
;
for
(
auto
&
op
:
ops_
)
{
bool
need_skip
=
false
;
for
(
auto
t
=
0u
;
t
<
skip_ops_
.
size
();
++
t
)
{
...
...
@@ -215,6 +218,7 @@ void HogwildWorker::TrainFiles() {
PrintFetchVars
();
thread_scope_
->
DropKids
();
}
std
::
cerr
<<
"total bacth "
<<
i
<<
std
::
endl
;
#if defined PADDLE_WITH_PSCORE
if
(
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
376c2f01
...
...
@@ -124,6 +124,7 @@ Scope* MultiTrainer::GetWorkerScope(int thread_id) {
void
MultiTrainer
::
Run
()
{
VLOG
(
3
)
<<
"Going to run"
;
LOG
(
ERROR
)
<<
"multi run "
<<
thread_num_
<<
"
\t
"
<<
debug_
;
for
(
int
thidx
=
0
;
thidx
<
thread_num_
;
++
thidx
)
{
if
(
!
debug_
)
{
threads_
.
push_back
(
...
...
paddle/fluid/framework/op_proto_maker.cc
浏览文件 @
376c2f01
...
...
@@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
op_checker_
=
attr_checker
;
Make
();
op_checker_
->
RecordExplicitCheckerNum
();
op_checker_
->
InitDefaultMap
();
AddAttr
<
int
>
(
OpRoleAttrName
(),
"The role of this operator"
)
.
InEnum
(
...
...
paddle/fluid/imperative/CMakeLists.txt
浏览文件 @
376c2f01
...
...
@@ -28,4 +28,6 @@ endif(NOT WIN32)
cc_library
(
gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function
)
cc_binary
(
tracer_test SRCS tracer_test.cc DEPS tracer layer op_registry python pybind
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
profiler
)
add_subdirectory
(
tests
)
paddle/fluid/imperative/execution_context.h
浏览文件 @
376c2f01
...
...
@@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const
framework
::
RuntimeContext
&
ctx
,
const
NameVarMap
<
VarType
>&
var_base_map_in
,
const
NameVarMap
<
VarType
>&
var_base_map_out
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
,
const
framework
::
AttributeMap
&
default_attrs
)
:
ExecutionContext
(
op
,
scope
,
device_context
,
ctx
),
var_base_map_in_
(
var_base_map_in
),
var_base_map_out_
(
var_base_map_out
),
attrs_
(
attrs
)
{}
attrs_
(
attrs
),
default_attrs_
(
default_attrs
){}
std
::
string
InputName
(
const
std
::
string
&
name
)
const
override
{
auto
it
=
var_base_map_in_
.
find
(
name
);
...
...
@@ -92,7 +94,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}
bool
HasAttr
(
const
std
::
string
&
name
)
const
override
{
return
attrs_
.
count
(
name
)
!=
0
;
return
attrs_
.
count
(
name
)
!=
0
||
default_attrs_
.
count
(
name
)
;
}
const
framework
::
AttributeMap
&
Attrs
()
const
override
{
return
attrs_
;
}
...
...
@@ -100,8 +102,14 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const
framework
::
Attribute
&
GetAttr
(
const
std
::
string
&
name
)
const
override
{
auto
it
=
attrs_
.
find
(
name
);
bool
find
=
(
it
!=
attrs_
.
end
()
);
if
(
it
==
attrs_
.
end
()
)
{
it
=
default_attrs_
.
find
(
name
);
find
=
(
it
!=
default_attrs_
.
end
()
);
}
PADDLE_ENFORCE_NE
(
it
,
attrs_
.
end
()
,
find
,
false
,
platform
::
errors
::
NotFound
(
"can not find [%s] in attrs"
,
name
));
return
it
->
second
;
...
...
@@ -192,6 +200,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const
NameVarMap
<
VarType
>&
var_base_map_in_
;
const
NameVarMap
<
VarType
>&
var_base_map_out_
;
const
framework
::
AttributeMap
&
attrs_
;
const
framework
::
AttributeMap
&
default_attrs_
;
};
}
// namespace imperative
...
...
paddle/fluid/imperative/prepared_operator.cc
浏览文件 @
376c2f01
...
...
@@ -104,10 +104,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
}
#endif
//auto *attr_checker = op_->Info().Checker();
// 1. get expected kernel key
auto
expected_kernel_key
=
op
.
GetExpectedKernelType
(
DygraphExecutionContext
<
VarType
>
(
op
,
framework
::
Scope
(),
*
dev_ctx
,
ctx
,
ins
,
outs
,
attrs
))
;
op
,
framework
::
Scope
(),
*
dev_ctx
,
ctx
,
ins
,
outs
,
attrs
,
{}
))
;
VLOG
(
3
)
<<
"expected_kernel_key:"
<<
expected_kernel_key
;
// 2. check if op[type] has kernel registered.
...
...
@@ -172,8 +173,10 @@ static void PreparedOpRunImpl(
static_cast
<
const
framework
::
OperatorWithKernel
&>
(
op
).
InferShape
(
&
infer_shape_ctx
);
auto
*
attr_checker
=
op
.
Info
().
Checker
();
func
(
DygraphExecutionContext
<
VarType
>
(
op
,
scope
,
*
dev_ctx
,
ctx
,
ins
,
outs
,
attrs
));
attrs
,
attr_checker
->
default_attr_map
()
));
/**
* [ Why need handle complex gradient to real gradient? ]
...
...
paddle/fluid/imperative/tests/test_layer.cc
浏览文件 @
376c2f01
...
...
@@ -358,7 +358,7 @@ TEST(test_layer, test_dygraph_execution_context) {
framework
::
Scope
scope
;
DygraphExecutionContext
<
imperative
::
VarBase
>
dy_exe_context
(
*
(
op
.
get
()),
scope
,
*
dev_ctx
,
ctx
,
ins
,
outs
,
concat_att_map
);
*
(
op
.
get
()),
scope
,
*
dev_ctx
,
ctx
,
ins
,
outs
,
concat_att_map
,
{}
);
ASSERT_EQ
(
dy_exe_context
.
InputSize
(
"X"
),
1u
);
ASSERT_EQ
(
dy_exe_context
.
InputName
(
"X"
),
"vin"
);
...
...
paddle/fluid/imperative/tracer.cc
浏览文件 @
376c2f01
...
...
@@ -149,12 +149,17 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
}
}
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
type
,
{},
{},
{},
false
);
const
auto
&
op_info
=
op
->
Info
();
auto
*
attr_checker
=
op_info
.
Checker
();
if
(
attr_checker
)
{
attr_checker
->
Check
(
&
attrs
,
true
);
}
NameVarBaseMap
new_ins
=
ins
;
if
(
enable_autocast_
)
{
VLOG
(
5
)
<<
"Auto mixed precision run operator: "
<<
type
;
...
...
paddle/fluid/operators/elementwise/elementwise_op.h
浏览文件 @
376c2f01
...
...
@@ -109,6 +109,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto
input_data_type
=
OperatorWithKernel
::
IndicateOrPromoteVarDataTypes
(
ctx
,
"X"
,
"Y"
);
/*
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
...
...
@@ -116,6 +117,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
framework::LibraryType::kMKLDNN);
}
#endif
*/
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
...
...
paddle/fluid/operators/sync_batch_norm_op.cu.h
浏览文件 @
376c2f01
...
...
@@ -187,11 +187,13 @@ void SyncBatchNormFunctor(const framework::ExecutionContext &ctx,
x_d
,
N
,
H
*
W
*
D
,
C
,
stats
);
}
/*
Tensor c_g_st;
auto *c_g_st_d = c_g_st.mutable_data<BatchNormParamType<T>>(
{2 * C + 1}, platform::CPUPlace());
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0);
*/
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto
*
comm
=
dev_ctx
.
nccl_comm
();
...
...
paddle/fluid/pybind/op_function.h
浏览文件 @
376c2f01
...
...
@@ -177,4 +177,4 @@ static inline void HandleViewBetweenInputAndOutput(
}
// namespace paddle
// This include must be the last line
#include "paddle/fluid/pybind/op_function_impl.h"
#include "paddle/fluid/pybind/op_function_impl
_new
.h"
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录