Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3af47711
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3af47711
编写于
3月 20, 2020
作者:
Y
Yiqun Liu
提交者:
GitHub
3月 20, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the detection and code-generation of sqrt and square in fusion_group (#23095)
上级
d066d6f9
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
128 addition
and
118 deletion
+128
-118
paddle/fluid/framework/ir/fusion_group/code_generator.cc
paddle/fluid/framework/ir/fusion_group/code_generator.cc
+15
-10
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
.../fluid/framework/ir/fusion_group/code_generator_helper.cc
+4
-4
paddle/fluid/framework/ir/fusion_group/cuda_resources.h
paddle/fluid/framework/ir/fusion_group/cuda_resources.h
+12
-6
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
...d/framework/ir/fusion_group/elementwise_group_detector.cc
+29
-40
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h
...id/framework/ir/fusion_group/elementwise_group_detector.h
+1
-1
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
+3
-1
paddle/fluid/framework/ir/fusion_group/operation.cc
paddle/fluid/framework/ir/fusion_group/operation.cc
+22
-9
python/paddle/fluid/tests/unittests/ir/pass_test.py
python/paddle/fluid/tests/unittests/ir/pass_test.py
+5
-2
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
...dle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
+37
-45
未找到文件。
paddle/fluid/framework/ir/fusion_group/code_generator.cc
浏览文件 @
3af47711
...
...
@@ -25,15 +25,21 @@ namespace ir {
namespace
fusion_group
{
std
::
string
ExtractDataType
(
const
std
::
vector
<
Node
*>&
nodes
)
{
std
::
string
dtype_str
=
"float"
;
auto
data_type
=
nodes
.
back
()
->
Var
()
->
GetDataType
();
if
(
data_type
==
proto
::
VarType
::
FP32
)
{
dtype_str
=
"float"
;
}
else
if
(
data_type
==
proto
::
VarType
::
FP64
)
{
dtype_str
=
"double"
;
}
else
if
(
data_type
==
proto
::
VarType
::
FP16
)
{
dtype_str
=
"float16"
;
std
::
string
dtype_str
=
""
;
for
(
const
auto
*
n
:
nodes
)
{
if
(
n
&&
n
->
IsVar
()
&&
n
->
Var
())
{
// The data type of all inputs/outputs must be the same, which are
// checked when detecting the subgraph.
auto
dtype
=
n
->
Var
()
->
GetDataType
();
if
(
dtype
==
proto
::
VarType
::
FP32
)
{
dtype_str
=
"float"
;
}
else
if
(
dtype
==
proto
::
VarType
::
FP64
)
{
dtype_str
=
"double"
;
}
else
if
(
dtype
==
proto
::
VarType
::
FP16
)
{
dtype_str
=
"float16"
;
}
break
;
}
}
return
dtype_str
;
...
...
@@ -80,7 +86,6 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
for
(
auto
&
name
:
input_names
)
{
// Some input vars are not used in grad ops, such as
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
if
((
HasInput
(
node
,
name
)
&&
op
->
Input
(
name
).
size
()
>=
1U
))
{
for
(
size_t
i
=
0
;
i
<
op
->
Input
(
name
).
size
();
i
++
)
{
PADDLE_ENFORCE_NE
(
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
浏览文件 @
3af47711
...
...
@@ -38,13 +38,13 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
int
start_pos
=
rhs
.
find
(
"["
,
0
);
int
end_pos
=
rhs
.
find
(
"]"
,
0
);
std
::
string
sum_rhs
=
rhs
.
substr
(
0
,
start_pos
);
std
::
string
sum_rhs
_component
=
std
::
string
repeated
_component
=
rhs
.
substr
(
start_pos
+
1
,
(
end_pos
-
start_pos
-
1
));
int
replace_pos
=
sum_rhs
_component
.
find
(
"?"
,
0
);
int
replace_pos
=
repeated
_component
.
find
(
"?"
,
0
);
for
(
size_t
i
=
1
;
i
<
input_size
;
i
++
)
{
std
::
string
append_str
=
sum_rhs_component
.
replace
(
replace_pos
,
1
,
std
::
to_string
(
i
));
std
::
string
append_str
=
repeated_component
;
append_str
.
replace
(
replace_pos
,
1
,
std
::
to_string
(
i
));
sum_rhs
=
sum_rhs
+
append_str
;
}
return
sum_rhs
;
...
...
paddle/fluid/framework/ir/fusion_group/cuda_resources.h
浏览文件 @
3af47711
...
...
@@ -20,20 +20,26 @@ namespace ir {
namespace
fusion_group
{
static
constexpr
char
predefined_cuda_functions_fp32
[]
=
R"(
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }
__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
__device__ inline float Exp(float x) { return expf(x); }
__device__ inline float Log(float x) { return logf(x); }
__device__ inline float Sqrt(float x) { return sqrtf(x); }
)"
;
static
constexpr
char
predefined_cuda_functions_fp64
[]
=
R"(
__device__ inline double real_exp(double x) { return ::exp(x); }
__device__ inline double real_log(double x) { return ::log(x); }
__device__ inline double Max(double x, double y) { return fmax(x, y); }
__device__ inline double Exp(double x) { return exp(x); }
__device__ inline double Log(double x) { return log(x); }
__device__ inline double Sqrt(double x) { return sqrt(x); }
)"
;
static
constexpr
char
predefined_cuda_functions_fp16
[]
=
R"(
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }
__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
__device__ inline float Exp(float x) { return expf(x); }
__device__ inline float Log(float x) { return logf(x); }
__device__ inline float Sqrt(float x) { return sqrtf(x); }
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
...
...
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
浏览文件 @
3af47711
...
...
@@ -60,52 +60,41 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
return
l
.
size
()
!=
0U
&&
r
.
size
()
!=
0U
&&
l
==
r
;
}
bool
GroupDetector
::
IsFusionGroupOp
(
const
Node
*
n
)
{
if
(
!
(
n
&&
n
->
IsOp
()
&&
n
->
Op
()))
return
false
;
bool
is_first
=
true
;
proto
::
VarType
::
Type
i_data_type
=
proto
::
VarType
::
FP32
;
proto
::
VarType
::
Type
o_data_type
=
proto
::
VarType
::
FP32
;
for
(
auto
*
i_node
:
n
->
inputs
)
{
if
(
!
i_node
->
Var
())
return
false
;
if
(
i_node
->
Var
()
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
return
false
;
}
if
(
is_first
)
{
i_data_type
=
i_node
->
Var
()
->
GetDataType
();
is_first
=
false
;
}
else
{
if
(
i_data_type
!=
i_node
->
Var
()
->
GetDataType
())
return
false
;
}
}
bool
GroupDetector
::
CheckPrecondition
(
const
Node
*
n
)
{
auto
check_data_type
=
[
&
](
const
std
::
vector
<
Node
*>&
nodes
)
->
bool
{
bool
is_first
=
true
;
proto
::
VarType
::
Type
data_type_0
;
for
(
auto
*
n
:
nodes
)
{
if
(
n
&&
n
->
IsVar
()
&&
n
->
Var
())
{
if
(
n
->
Var
()
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
return
false
;
}
is_first
=
true
;
for
(
auto
*
o_node
:
n
->
outputs
)
{
if
(
!
o_node
->
Var
())
return
false
;
if
(
o_node
->
Var
()
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
return
false
;
}
if
(
is_first
)
{
o_data_type
=
o_node
->
Var
()
->
GetDataType
();
is_first
=
false
;
}
else
{
if
(
o_data_type
!=
o_node
->
Var
()
->
GetDataType
())
return
false
;
proto
::
VarType
::
Type
data_type_i
=
n
->
Var
()
->
GetDataType
();
if
(
data_type_i
==
proto
::
VarType
::
FP32
||
data_type_i
==
proto
::
VarType
::
FP64
||
data_type_i
==
proto
::
VarType
::
FP16
)
{
if
(
is_first
)
{
data_type_0
=
data_type_i
;
is_first
=
false
;
}
else
if
(
data_type_0
!=
data_type_i
)
{
return
false
;
}
}
else
{
return
false
;
}
}
}
}
if
(
!
(
i_data_type
==
proto
::
VarType
::
FP32
||
i_data_type
==
proto
::
VarType
::
FP64
||
i_data_type
==
proto
::
VarType
::
FP16
)
||
!
(
o_data_type
==
proto
::
VarType
::
FP32
||
o_data_type
==
proto
::
VarType
::
FP64
||
o_data_type
==
proto
::
VarType
::
FP16
))
return
false
;
return
true
;
};
return
true
;
return
n
&&
n
->
IsOp
()
&&
n
->
Op
()
&&
check_data_type
(
n
->
inputs
)
&&
check_data_type
(
n
->
outputs
);
}
bool
ElementwiseGroupDetector
::
IsElementwiseOp
(
const
Node
*
n
)
{
if
(
IsSpecifiedOp
(
GetElementwiseOpTypes
(),
n
))
{
// Check whether all inputs have the same shape.
std
::
vector
<
int64_t
>
shape_0
;
for
(
size_t
i
=
0
;
i
<
n
->
inputs
.
size
();
++
i
)
{
auto
*
in_i
=
n
->
inputs
[
i
];
...
...
@@ -130,7 +119,7 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
std
::
vector
<
std
::
vector
<
Node
*>>
ElementwiseGroupDetector
::
operator
()(
Graph
*
graph
)
{
auto
teller
=
[
&
](
const
Node
*
n
)
->
bool
{
return
IsFusionGroupOp
(
n
)
&&
IsElementwiseOp
(
n
);
return
CheckPrecondition
(
n
)
&&
IsElementwiseOp
(
n
);
};
return
SubgraphDetector
(
graph
,
teller
)();
...
...
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h
浏览文件 @
3af47711
...
...
@@ -25,7 +25,7 @@ namespace fusion_group {
class
GroupDetector
{
protected:
bool
IsFusionGroupOp
(
const
Node
*
n
);
bool
CheckPrecondition
(
const
Node
*
n
);
};
class
ElementwiseGroupDetector
:
GroupDetector
{
...
...
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
浏览文件 @
3af47711
...
...
@@ -33,6 +33,8 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
fusion_group
::
OperationMap
::
Init
();
int
num_elementwise_groups
=
DetectFusionGroup
(
graph
,
0
);
AddStatis
(
num_elementwise_groups
);
LOG
(
INFO
)
<<
"Detect "
<<
num_elementwise_groups
<<
" elementwise fusion groups."
;
}
}
...
...
@@ -54,7 +56,7 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
VLOG
(
3
)
<<
"subgraph: {
\n
"
<<
DebugString
(
subgraph
.
SortedNodes
())
<<
"}
\n
"
;
if
(
subgraph
.
IsValid
(
min_subgraph_size
))
{
subgraph
.
SetFuncName
(
"
fused_elementwise_
"
+
std
::
to_string
(
index
++
));
subgraph
.
SetFuncName
(
"
FusedElementwise
"
+
std
::
to_string
(
index
++
));
if
(
GenerateCode
(
&
subgraph
))
{
InsertFusionGroupOp
(
graph
,
&
subgraph
);
num_subgraphs
++
;
...
...
paddle/fluid/framework/ir/fusion_group/operation.cc
浏览文件 @
3af47711
...
...
@@ -95,20 +95,29 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// sigmoid:
// out = f(x) = 1.0 / (1.0 + exp(-x))
// dx = dout * out * (1 - out)
insert_handler
(
"sigmoid"
,
"1.0 / (1.0 +
real_e
xp(- ${0}))"
,
insert_handler
(
"sigmoid"
,
"1.0 / (1.0 +
E
xp(- ${0}))"
,
{
"${2} * ${1} * (1.0 - ${1})"
});
// tanh:
// out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
// dx = dout * (1 - out * out)
insert_handler
(
"tanh"
,
"2.0 / (1.0 +
real_e
xp(-2.0 * ${0})) - 1.0"
,
insert_handler
(
"tanh"
,
"2.0 / (1.0 +
E
xp(-2.0 * ${0})) - 1.0"
,
{
"${2} * (1.0 - ${1} * ${1})"
});
// cast
// out = static_cast<T>(d)
// dx = static_cast<T>(d_out)
// cast:
// out = static_cast<T>(x)
// TODO(wangchaochaohu): This is not the compelete definition of
// cast Op, We need refine it later.
insert_handler
(
"cast"
,
"${0}"
,
{
"${0}"
});
insert_handler
(
"cast"
,
"${0}"
,
{});
// sqrt:
// out = x^(1/2)
// dx = dout * 0.5 / out
insert_handler
(
"sqrt"
,
"Sqrt(${0})"
,
{
"${2} * 0.5 / ${1}"
});
// square:
// out = x^2
// dx = dout * 2.0 * x
insert_handler
(
"square"
,
"${0} * ${0}"
,
{
"${2} * 2.0 * ${0}"
});
}
void
OperationMap
::
InsertBinaryElementwiseOperations
()
{
...
...
@@ -168,9 +177,13 @@ void OperationMap::InsertMultivariateElementwiseOperations() {
Insert
(
type
,
num_oprands
,
op_type
,
expr
,
grad_exprs
,
{
"X"
},
{
"Out"
});
};
// here [] represent the number of input is positive(>=0).
// if input list size of Sum Op is 3, It will expand as
// ${0} + ${1} + ${2}
// sum:
// out = x_0 + x_1 + ... + x_N-1
//
// For sum with N inputs, the expression inside "[]" will be expanded
// N - 1 times. The ${?} represents the number of inputs starting with is 1.
// For example, sum with 4 inputs, the expanded expression is:
// ${0} + ${1} + ${2} + ${3}
insert_handler
(
"sum"
,
"${0}[ + ${?}]"
,
{});
}
...
...
python/paddle/fluid/tests/unittests/ir/pass_test.py
浏览文件 @
3af47711
...
...
@@ -38,7 +38,6 @@ class PassTest(unittest.TestCase):
self
.
pass_attrs
=
{}
self
.
fused_op_type
=
None
self
.
num_fused_ops
=
-
1
self
.
backward
=
True
np
.
random
.
seed
(
123
)
random
.
seed
(
124
)
...
...
@@ -49,7 +48,11 @@ class PassTest(unittest.TestCase):
places
.
append
(
fluid
.
CUDAPlace
(
0
))
return
places
def
append_gradinets
(
self
,
outs
):
def
grad
(
self
,
var
):
grad_name
=
var
.
name
+
"@GRAD"
return
self
.
main_program
.
global_block
().
var
(
grad_name
)
def
append_gradients
(
self
,
outs
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
loss
=
fluid
.
layers
.
mean
(
outs
)
fluid
.
backward
.
append_backward
(
loss
)
...
...
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
浏览文件 @
3af47711
...
...
@@ -35,15 +35,12 @@ class FusionGroupPassTest(PassTest):
# subgraph with 2 op nodes
tmp_2
=
layers
.
relu
(
tmp_0
+
tmp_1
)
self
.
num_fused_ops
=
1
self
.
fetch_list
=
[
tmp_2
.
name
,
tmp_1
.
name
+
"@GRAD"
]
self
.
append_gradients
(
tmp_2
)
if
self
.
backward
:
self
.
append_gradinets
(
tmp_2
)
self
.
num_fused_ops
=
2
self
.
num_fused_ops
=
2
self
.
fetch_list
=
[
tmp_2
,
self
.
grad
(
tmp_1
)]
def
setUp
(
self
):
self
.
backward
=
True
self
.
build_program
(
"float32"
)
self
.
feeds
=
self
.
_feed_random_data
(
self
.
feed_vars
)
self
.
pass_names
=
"fusion_group_pass"
...
...
@@ -91,13 +88,10 @@ class FusionGroupPassTest1(FusionGroupPassTest):
self
.
feed_vars
[
2
])
*
layers
.
tanh
(
self
.
feed_vars
[
3
])
tmp_2
=
layers
.
tanh
(
tmp_1
)
+
layers
.
sigmoid
(
self
.
feed_vars
[
4
])
if
self
.
backward
:
self
.
append_gradinets
(
tmp_2
)
self
.
num_fused_ops
=
2
else
:
self
.
num_fused_ops
=
1
self
.
append_gradients
(
tmp_2
)
self
.
fetch_list
=
[
tmp_2
.
name
,
tmp_0
.
name
+
"@GRAD"
]
self
.
num_fused_ops
=
2
self
.
fetch_list
=
[
tmp_2
,
self
.
grad
(
tmp_0
)]
class
FusionGroupPassTest2
(
FusionGroupPassTest
):
...
...
@@ -115,20 +109,11 @@ class FusionGroupPassTest2(FusionGroupPassTest):
tmp_2
=
layers
.
relu
(
layers
.
sigmoid
(
self
.
feed_vars
[
3
]))
tmp_3
=
layers
.
mul
(
tmp_1
,
tmp_2
)
self
.
num_fused_ops
=
2
self
.
fetch_list
=
[
tmp_3
.
name
]
#TODO(wangchaochaohu): we need to deal with the condition of stop gradient
if
self
.
backward
:
self
.
append_gradinets
(
tmp_3
)
self
.
num_fused_ops
=
3
# TODO(wangchaochaohu): support the case when some vars are set
# stop_gradient = True.
def
setUp
(
self
):
self
.
backward
=
False
self
.
build_program
(
"float32"
)
self
.
feeds
=
self
.
_feed_random_data
(
self
.
feed_vars
)
self
.
pass_names
=
"fusion_group_pass"
self
.
fused_op_type
=
"fusion_group"
self
.
num_fused_ops
=
2
self
.
fetch_list
=
[
tmp_3
]
class
FusionGroupPassTestFP64
(
FusionGroupPassTest
):
...
...
@@ -147,32 +132,41 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
fluid
.
data
(
name
=
"data2"
,
shape
=
[
128
,
128
],
dtype
=
dtype
))
# subgraph with 2 op nodes
tmp_0
=
self
.
feed_vars
[
0
]
*
self
.
feed_vars
[
1
]
tmp_1
=
layers
.
mul
(
tmp_0
,
self
.
feed_vars
[
2
])
tmp_3
=
layers
.
cast
(
tmp_1
,
dtype
=
"float16"
)
tmp_2
=
layers
.
cast
(
tmp_0
,
dtype
=
"float16"
)
tmp_4
=
layers
.
relu
(
tmp_2
+
tmp_3
)
tmp_1
=
layers
.
cast
(
tmp_0
,
dtype
=
"float16"
)
tmp_2
=
layers
.
mul
(
tmp_0
,
self
.
feed_vars
[
2
])
# subgraph with 4 op nodes
tmp_3
=
layers
.
cast
(
tmp_2
,
dtype
=
"float16"
)
tmp_4
=
layers
.
relu
(
tmp_1
+
tmp_3
)
tmp_5
=
layers
.
cast
(
tmp_4
,
dtype
=
dtype
)
self
.
num_fused_ops
=
1
self
.
fetch_list
=
[
tmp_5
.
name
]
self
.
append_gradients
(
tmp_5
)
if
self
.
backward
:
self
.
num_fused_ops
=
4
self
.
append_gradinets
(
tmp_5
)
self
.
num_fused_ops
=
3
self
.
fetch_list
=
[
tmp_5
,
self
.
grad
(
tmp_0
)]
class
FusionGroupPassSumTest
(
FusionGroupPassTest
):
def
build_program
(
self
,
dtype
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
self
.
feed_vars
=
self
.
_prepare_feed_vars
([
32
,
128
],
dtype
,
5
)
self
.
feed_vars
=
self
.
_prepare_feed_vars
([
32
,
128
],
dtype
,
3
)
self
.
feed_vars
.
append
(
fluid
.
data
(
name
=
"data3"
,
shape
=
[
128
,
128
],
dtype
=
dtype
))
tmp_0
=
layers
.
elementwise_add
(
self
.
feed_vars
[
0
],
self
.
feed_vars
[
1
])
tmp_1
=
layers
.
sum
([
tmp_0
,
self
.
feed_vars
[
2
],
self
.
feed_vars
[
3
]])
tmp_2
=
layers
.
sum
([
tmp_1
,
self
.
feed_vars
[
4
]])
# subgraph with 2 op nodes
tmp_0
=
layers
.
sum
(
[
self
.
feed_vars
[
0
],
self
.
feed_vars
[
1
],
self
.
feed_vars
[
2
]])
tmp_1
=
layers
.
sqrt
(
tmp_0
)
tmp_2
=
layers
.
mul
(
tmp_0
,
self
.
feed_vars
[
3
])
# subgraph with 2 op nodes
tmp_3
=
layers
.
square
(
layers
.
sum
([
tmp_1
,
tmp_2
]))
self
.
fetch_list
=
[
tmp_0
,
tmp_1
,
tmp_2
]
self
.
num_fused_ops
=
1
self
.
append_gradients
(
tmp_3
)
self
.
num_fused_ops
=
3
self
.
fetch_list
=
[
tmp_3
,
self
.
grad
(
tmp_0
)]
class
FusionGroupPassCastTest
(
FusionGroupPassTest
):
...
...
@@ -184,12 +178,10 @@ class FusionGroupPassCastTest(FusionGroupPassTest):
tmp_1
=
layers
.
cast
(
tmp_0
,
dtype
=
"double"
)
tmp_2
=
layers
.
cast
(
tmp_1
,
dtype
=
"float32"
)
self
.
fetch_list
=
[
tmp_2
.
name
,
tmp_1
.
name
+
"@GRAD"
]
self
.
num_fused_ops
=
1
self
.
append_gradients
(
tmp_2
)
if
self
.
backward
:
self
.
num_fused_ops
=
2
self
.
append_gradinets
(
tmp_2
)
self
.
num_fused_ops
=
2
self
.
fetch_list
=
[
tmp_2
,
self
.
grad
(
tmp_0
)]
def
setUp
(
self
):
self
.
build_program
(
"float64"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录