Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3af47711
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3af47711
编写于
3月 20, 2020
作者:
Y
Yiqun Liu
提交者:
GitHub
3月 20, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the detection and code-generation of sqrt and square in fusion_group (#23095)
上级
d066d6f9
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
128 addition
and
118 deletion
+128
-118
paddle/fluid/framework/ir/fusion_group/code_generator.cc
paddle/fluid/framework/ir/fusion_group/code_generator.cc
+15
-10
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
.../fluid/framework/ir/fusion_group/code_generator_helper.cc
+4
-4
paddle/fluid/framework/ir/fusion_group/cuda_resources.h
paddle/fluid/framework/ir/fusion_group/cuda_resources.h
+12
-6
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
...d/framework/ir/fusion_group/elementwise_group_detector.cc
+29
-40
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h
...id/framework/ir/fusion_group/elementwise_group_detector.h
+1
-1
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
+3
-1
paddle/fluid/framework/ir/fusion_group/operation.cc
paddle/fluid/framework/ir/fusion_group/operation.cc
+22
-9
python/paddle/fluid/tests/unittests/ir/pass_test.py
python/paddle/fluid/tests/unittests/ir/pass_test.py
+5
-2
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
...dle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
+37
-45
未找到文件。
paddle/fluid/framework/ir/fusion_group/code_generator.cc
浏览文件 @
3af47711
...
...
@@ -25,15 +25,21 @@ namespace ir {
namespace
fusion_group
{
std
::
string
ExtractDataType
(
const
std
::
vector
<
Node
*>&
nodes
)
{
std
::
string
dtype_str
=
"float"
;
auto
data_type
=
nodes
.
back
()
->
Var
()
->
GetDataType
();
if
(
data_type
==
proto
::
VarType
::
FP32
)
{
dtype_str
=
"float"
;
}
else
if
(
data_type
==
proto
::
VarType
::
FP64
)
{
dtype_str
=
"double"
;
}
else
if
(
data_type
==
proto
::
VarType
::
FP16
)
{
dtype_str
=
"float16"
;
std
::
string
dtype_str
=
""
;
for
(
const
auto
*
n
:
nodes
)
{
if
(
n
&&
n
->
IsVar
()
&&
n
->
Var
())
{
// The data type of all inputs/outputs must be the same, which are
// checked when detecting the subgraph.
auto
dtype
=
n
->
Var
()
->
GetDataType
();
if
(
dtype
==
proto
::
VarType
::
FP32
)
{
dtype_str
=
"float"
;
}
else
if
(
dtype
==
proto
::
VarType
::
FP64
)
{
dtype_str
=
"double"
;
}
else
if
(
dtype
==
proto
::
VarType
::
FP16
)
{
dtype_str
=
"float16"
;
}
break
;
}
}
return
dtype_str
;
...
...
@@ -80,7 +86,6 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
for
(
auto
&
name
:
input_names
)
{
// Some input vars are not used in grad ops, such as
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
if
((
HasInput
(
node
,
name
)
&&
op
->
Input
(
name
).
size
()
>=
1U
))
{
for
(
size_t
i
=
0
;
i
<
op
->
Input
(
name
).
size
();
i
++
)
{
PADDLE_ENFORCE_NE
(
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
浏览文件 @
3af47711
...
...
@@ -38,13 +38,13 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
int
start_pos
=
rhs
.
find
(
"["
,
0
);
int
end_pos
=
rhs
.
find
(
"]"
,
0
);
std
::
string
sum_rhs
=
rhs
.
substr
(
0
,
start_pos
);
std
::
string
sum_rhs
_component
=
std
::
string
repeated
_component
=
rhs
.
substr
(
start_pos
+
1
,
(
end_pos
-
start_pos
-
1
));
int
replace_pos
=
sum_rhs
_component
.
find
(
"?"
,
0
);
int
replace_pos
=
repeated
_component
.
find
(
"?"
,
0
);
for
(
size_t
i
=
1
;
i
<
input_size
;
i
++
)
{
std
::
string
append_str
=
sum_rhs_component
.
replace
(
replace_pos
,
1
,
std
::
to_string
(
i
));
std
::
string
append_str
=
repeated_component
;
append_str
.
replace
(
replace_pos
,
1
,
std
::
to_string
(
i
));
sum_rhs
=
sum_rhs
+
append_str
;
}
return
sum_rhs
;
...
...
paddle/fluid/framework/ir/fusion_group/cuda_resources.h
浏览文件 @
3af47711
...
...
@@ -20,20 +20,26 @@ namespace ir {
namespace
fusion_group
{
static
constexpr
char
predefined_cuda_functions_fp32
[]
=
R"(
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }
__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
__device__ inline float Exp(float x) { return expf(x); }
__device__ inline float Log(float x) { return logf(x); }
__device__ inline float Sqrt(float x) { return sqrtf(x); }
)"
;
static
constexpr
char
predefined_cuda_functions_fp64
[]
=
R"(
__device__ inline double real_exp(double x) { return ::exp(x); }
__device__ inline double real_log(double x) { return ::log(x); }
__device__ inline double Max(double x, double y) { return fmax(x, y); }
__device__ inline double Exp(double x) { return exp(x); }
__device__ inline double Log(double x) { return log(x); }
__device__ inline double Sqrt(double x) { return sqrt(x); }
)"
;
static
constexpr
char
predefined_cuda_functions_fp16
[]
=
R"(
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }
__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
__device__ inline float Exp(float x) { return expf(x); }
__device__ inline float Log(float x) { return logf(x); }
__device__ inline float Sqrt(float x) { return sqrtf(x); }
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
...
...
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.cc
浏览文件 @
3af47711
...
...
@@ -60,52 +60,41 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
return
l
.
size
()
!=
0U
&&
r
.
size
()
!=
0U
&&
l
==
r
;
}
bool
GroupDetector
::
IsFusionGroupOp
(
const
Node
*
n
)
{
if
(
!
(
n
&&
n
->
IsOp
()
&&
n
->
Op
()))
return
false
;
bool
is_first
=
true
;
proto
::
VarType
::
Type
i_data_type
=
proto
::
VarType
::
FP32
;
proto
::
VarType
::
Type
o_data_type
=
proto
::
VarType
::
FP32
;
for
(
auto
*
i_node
:
n
->
inputs
)
{
if
(
!
i_node
->
Var
())
return
false
;
if
(
i_node
->
Var
()
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
return
false
;
}
if
(
is_first
)
{
i_data_type
=
i_node
->
Var
()
->
GetDataType
();
is_first
=
false
;
}
else
{
if
(
i_data_type
!=
i_node
->
Var
()
->
GetDataType
())
return
false
;
}
}
bool
GroupDetector
::
CheckPrecondition
(
const
Node
*
n
)
{
auto
check_data_type
=
[
&
](
const
std
::
vector
<
Node
*>&
nodes
)
->
bool
{
bool
is_first
=
true
;
proto
::
VarType
::
Type
data_type_0
;
for
(
auto
*
n
:
nodes
)
{
if
(
n
&&
n
->
IsVar
()
&&
n
->
Var
())
{
if
(
n
->
Var
()
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
return
false
;
}
is_first
=
true
;
for
(
auto
*
o_node
:
n
->
outputs
)
{
if
(
!
o_node
->
Var
())
return
false
;
if
(
o_node
->
Var
()
->
GetType
()
!=
proto
::
VarType
::
LOD_TENSOR
)
{
return
false
;
}
if
(
is_first
)
{
o_data_type
=
o_node
->
Var
()
->
GetDataType
();
is_first
=
false
;
}
else
{
if
(
o_data_type
!=
o_node
->
Var
()
->
GetDataType
())
return
false
;
proto
::
VarType
::
Type
data_type_i
=
n
->
Var
()
->
GetDataType
();
if
(
data_type_i
==
proto
::
VarType
::
FP32
||
data_type_i
==
proto
::
VarType
::
FP64
||
data_type_i
==
proto
::
VarType
::
FP16
)
{
if
(
is_first
)
{
data_type_0
=
data_type_i
;
is_first
=
false
;
}
else
if
(
data_type_0
!=
data_type_i
)
{
return
false
;
}
}
else
{
return
false
;
}
}
}
}
if
(
!
(
i_data_type
==
proto
::
VarType
::
FP32
||
i_data_type
==
proto
::
VarType
::
FP64
||
i_data_type
==
proto
::
VarType
::
FP16
)
||
!
(
o_data_type
==
proto
::
VarType
::
FP32
||
o_data_type
==
proto
::
VarType
::
FP64
||
o_data_type
==
proto
::
VarType
::
FP16
))
return
false
;
return
true
;
};
return
true
;
return
n
&&
n
->
IsOp
()
&&
n
->
Op
()
&&
check_data_type
(
n
->
inputs
)
&&
check_data_type
(
n
->
outputs
);
}
bool
ElementwiseGroupDetector
::
IsElementwiseOp
(
const
Node
*
n
)
{
if
(
IsSpecifiedOp
(
GetElementwiseOpTypes
(),
n
))
{
// Check whether all inputs have the same shape.
std
::
vector
<
int64_t
>
shape_0
;
for
(
size_t
i
=
0
;
i
<
n
->
inputs
.
size
();
++
i
)
{
auto
*
in_i
=
n
->
inputs
[
i
];
...
...
@@ -130,7 +119,7 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
std
::
vector
<
std
::
vector
<
Node
*>>
ElementwiseGroupDetector
::
operator
()(
Graph
*
graph
)
{
auto
teller
=
[
&
](
const
Node
*
n
)
->
bool
{
return
IsFusionGroupOp
(
n
)
&&
IsElementwiseOp
(
n
);
return
CheckPrecondition
(
n
)
&&
IsElementwiseOp
(
n
);
};
return
SubgraphDetector
(
graph
,
teller
)();
...
...
paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h
浏览文件 @
3af47711
...
...
@@ -25,7 +25,7 @@ namespace fusion_group {
class
GroupDetector
{
protected:
bool
IsFusionGroupOp
(
const
Node
*
n
);
bool
CheckPrecondition
(
const
Node
*
n
);
};
class
ElementwiseGroupDetector
:
GroupDetector
{
...
...
paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
浏览文件 @
3af47711
...
...
@@ -33,6 +33,8 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
fusion_group
::
OperationMap
::
Init
();
int
num_elementwise_groups
=
DetectFusionGroup
(
graph
,
0
);
AddStatis
(
num_elementwise_groups
);
LOG
(
INFO
)
<<
"Detect "
<<
num_elementwise_groups
<<
" elementwise fusion groups."
;
}
}
...
...
@@ -54,7 +56,7 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
VLOG
(
3
)
<<
"subgraph: {
\n
"
<<
DebugString
(
subgraph
.
SortedNodes
())
<<
"}
\n
"
;
if
(
subgraph
.
IsValid
(
min_subgraph_size
))
{
subgraph
.
SetFuncName
(
"
fused_elementwise_
"
+
std
::
to_string
(
index
++
));
subgraph
.
SetFuncName
(
"
FusedElementwise
"
+
std
::
to_string
(
index
++
));
if
(
GenerateCode
(
&
subgraph
))
{
InsertFusionGroupOp
(
graph
,
&
subgraph
);
num_subgraphs
++
;
...
...
paddle/fluid/framework/ir/fusion_group/operation.cc
浏览文件 @
3af47711
...
...
@@ -95,20 +95,29 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// sigmoid:
// out = f(x) = 1.0 / (1.0 + exp(-x))
// dx = dout * out * (1 - out)
insert_handler
(
"sigmoid"
,
"1.0 / (1.0 +
real_e
xp(- ${0}))"
,
insert_handler
(
"sigmoid"
,
"1.0 / (1.0 +
E
xp(- ${0}))"
,
{
"${2} * ${1} * (1.0 - ${1})"
});
// tanh:
// out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
// dx = dout * (1 - out * out)
insert_handler
(
"tanh"
,
"2.0 / (1.0 +
real_e
xp(-2.0 * ${0})) - 1.0"
,
insert_handler
(
"tanh"
,
"2.0 / (1.0 +
E
xp(-2.0 * ${0})) - 1.0"
,
{
"${2} * (1.0 - ${1} * ${1})"
});
// cast
// out = static_cast<T>(d)
// dx = static_cast<T>(d_out)
// cast:
// out = static_cast<T>(x)
// TODO(wangchaochaohu): This is not the compelete definition of
// cast Op, We need refine it later.
insert_handler
(
"cast"
,
"${0}"
,
{
"${0}"
});
insert_handler
(
"cast"
,
"${0}"
,
{});
// sqrt:
// out = x^(1/2)
// dx = dout * 0.5 / out
insert_handler
(
"sqrt"
,
"Sqrt(${0})"
,
{
"${2} * 0.5 / ${1}"
});
// square:
// out = x^2
// dx = dout * 2.0 * x
insert_handler
(
"square"
,
"${0} * ${0}"
,
{
"${2} * 2.0 * ${0}"
});
}
void
OperationMap
::
InsertBinaryElementwiseOperations
()
{
...
...
@@ -168,9 +177,13 @@ void OperationMap::InsertMultivariateElementwiseOperations() {
Insert
(
type
,
num_oprands
,
op_type
,
expr
,
grad_exprs
,
{
"X"
},
{
"Out"
});
};
// here [] represent the number of input is positive(>=0).
// if input list size of Sum Op is 3, It will expand as
// ${0} + ${1} + ${2}
// sum:
// out = x_0 + x_1 + ... + x_N-1
//
// For sum with N inputs, the expression inside "[]" will be expanded
// N - 1 times. The ${?} represents the number of inputs starting with is 1.
// For example, sum with 4 inputs, the expanded expression is:
// ${0} + ${1} + ${2} + ${3}
insert_handler
(
"sum"
,
"${0}[ + ${?}]"
,
{});
}
...
...
python/paddle/fluid/tests/unittests/ir/pass_test.py
浏览文件 @
3af47711
...
...
@@ -38,7 +38,6 @@ class PassTest(unittest.TestCase):
self
.
pass_attrs
=
{}
self
.
fused_op_type
=
None
self
.
num_fused_ops
=
-
1
self
.
backward
=
True
np
.
random
.
seed
(
123
)
random
.
seed
(
124
)
...
...
@@ -49,7 +48,11 @@ class PassTest(unittest.TestCase):
places
.
append
(
fluid
.
CUDAPlace
(
0
))
return
places
def
append_gradinets
(
self
,
outs
):
def
grad
(
self
,
var
):
grad_name
=
var
.
name
+
"@GRAD"
return
self
.
main_program
.
global_block
().
var
(
grad_name
)
def
append_gradients
(
self
,
outs
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
loss
=
fluid
.
layers
.
mean
(
outs
)
fluid
.
backward
.
append_backward
(
loss
)
...
...
python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py
浏览文件 @
3af47711
...
...
@@ -35,15 +35,12 @@ class FusionGroupPassTest(PassTest):
# subgraph with 2 op nodes
tmp_2
=
layers
.
relu
(
tmp_0
+
tmp_1
)
self
.
num_fused_ops
=
1
self
.
fetch_list
=
[
tmp_2
.
name
,
tmp_1
.
name
+
"@GRAD"
]
self
.
append_gradients
(
tmp_2
)
if
self
.
backward
:
self
.
append_gradinets
(
tmp_2
)
self
.
num_fused_ops
=
2
self
.
num_fused_ops
=
2
self
.
fetch_list
=
[
tmp_2
,
self
.
grad
(
tmp_1
)]
def
setUp
(
self
):
self
.
backward
=
True
self
.
build_program
(
"float32"
)
self
.
feeds
=
self
.
_feed_random_data
(
self
.
feed_vars
)
self
.
pass_names
=
"fusion_group_pass"
...
...
@@ -91,13 +88,10 @@ class FusionGroupPassTest1(FusionGroupPassTest):
self
.
feed_vars
[
2
])
*
layers
.
tanh
(
self
.
feed_vars
[
3
])
tmp_2
=
layers
.
tanh
(
tmp_1
)
+
layers
.
sigmoid
(
self
.
feed_vars
[
4
])
if
self
.
backward
:
self
.
append_gradinets
(
tmp_2
)
self
.
num_fused_ops
=
2
else
:
self
.
num_fused_ops
=
1
self
.
append_gradients
(
tmp_2
)
self
.
fetch_list
=
[
tmp_2
.
name
,
tmp_0
.
name
+
"@GRAD"
]
self
.
num_fused_ops
=
2
self
.
fetch_list
=
[
tmp_2
,
self
.
grad
(
tmp_0
)]
class
FusionGroupPassTest2
(
FusionGroupPassTest
):
...
...
@@ -115,20 +109,11 @@ class FusionGroupPassTest2(FusionGroupPassTest):
tmp_2
=
layers
.
relu
(
layers
.
sigmoid
(
self
.
feed_vars
[
3
]))
tmp_3
=
layers
.
mul
(
tmp_1
,
tmp_2
)
self
.
num_fused_ops
=
2
self
.
fetch_list
=
[
tmp_3
.
name
]
#TODO(wangchaochaohu): we need to deal with the condition of stop gradient
if
self
.
backward
:
self
.
append_gradinets
(
tmp_3
)
self
.
num_fused_ops
=
3
# TODO(wangchaochaohu): support the case when some vars are set
# stop_gradient = True.
def
setUp
(
self
):
self
.
backward
=
False
self
.
build_program
(
"float32"
)
self
.
feeds
=
self
.
_feed_random_data
(
self
.
feed_vars
)
self
.
pass_names
=
"fusion_group_pass"
self
.
fused_op_type
=
"fusion_group"
self
.
num_fused_ops
=
2
self
.
fetch_list
=
[
tmp_3
]
class
FusionGroupPassTestFP64
(
FusionGroupPassTest
):
...
...
@@ -147,32 +132,41 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
fluid
.
data
(
name
=
"data2"
,
shape
=
[
128
,
128
],
dtype
=
dtype
))
# subgraph with 2 op nodes
tmp_0
=
self
.
feed_vars
[
0
]
*
self
.
feed_vars
[
1
]
tmp_1
=
layers
.
mul
(
tmp_0
,
self
.
feed_vars
[
2
])
tmp_3
=
layers
.
cast
(
tmp_1
,
dtype
=
"float16"
)
tmp_2
=
layers
.
cast
(
tmp_0
,
dtype
=
"float16"
)
tmp_4
=
layers
.
relu
(
tmp_2
+
tmp_3
)
tmp_1
=
layers
.
cast
(
tmp_0
,
dtype
=
"float16"
)
tmp_2
=
layers
.
mul
(
tmp_0
,
self
.
feed_vars
[
2
])
# subgraph with 4 op nodes
tmp_3
=
layers
.
cast
(
tmp_2
,
dtype
=
"float16"
)
tmp_4
=
layers
.
relu
(
tmp_1
+
tmp_3
)
tmp_5
=
layers
.
cast
(
tmp_4
,
dtype
=
dtype
)
self
.
num_fused_ops
=
1
self
.
fetch_list
=
[
tmp_5
.
name
]
self
.
append_gradients
(
tmp_5
)
if
self
.
backward
:
self
.
num_fused_ops
=
4
self
.
append_gradinets
(
tmp_5
)
self
.
num_fused_ops
=
3
self
.
fetch_list
=
[
tmp_5
,
self
.
grad
(
tmp_0
)]
class
FusionGroupPassSumTest
(
FusionGroupPassTest
):
def
build_program
(
self
,
dtype
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
self
.
feed_vars
=
self
.
_prepare_feed_vars
([
32
,
128
],
dtype
,
5
)
self
.
feed_vars
=
self
.
_prepare_feed_vars
([
32
,
128
],
dtype
,
3
)
self
.
feed_vars
.
append
(
fluid
.
data
(
name
=
"data3"
,
shape
=
[
128
,
128
],
dtype
=
dtype
))
tmp_0
=
layers
.
elementwise_add
(
self
.
feed_vars
[
0
],
self
.
feed_vars
[
1
])
tmp_1
=
layers
.
sum
([
tmp_0
,
self
.
feed_vars
[
2
],
self
.
feed_vars
[
3
]])
tmp_2
=
layers
.
sum
([
tmp_1
,
self
.
feed_vars
[
4
]])
# subgraph with 2 op nodes
tmp_0
=
layers
.
sum
(
[
self
.
feed_vars
[
0
],
self
.
feed_vars
[
1
],
self
.
feed_vars
[
2
]])
tmp_1
=
layers
.
sqrt
(
tmp_0
)
tmp_2
=
layers
.
mul
(
tmp_0
,
self
.
feed_vars
[
3
])
# subgraph with 2 op nodes
tmp_3
=
layers
.
square
(
layers
.
sum
([
tmp_1
,
tmp_2
]))
self
.
fetch_list
=
[
tmp_0
,
tmp_1
,
tmp_2
]
self
.
num_fused_ops
=
1
self
.
append_gradients
(
tmp_3
)
self
.
num_fused_ops
=
3
self
.
fetch_list
=
[
tmp_3
,
self
.
grad
(
tmp_0
)]
class
FusionGroupPassCastTest
(
FusionGroupPassTest
):
...
...
@@ -184,12 +178,10 @@ class FusionGroupPassCastTest(FusionGroupPassTest):
tmp_1
=
layers
.
cast
(
tmp_0
,
dtype
=
"double"
)
tmp_2
=
layers
.
cast
(
tmp_1
,
dtype
=
"float32"
)
self
.
fetch_list
=
[
tmp_2
.
name
,
tmp_1
.
name
+
"@GRAD"
]
self
.
num_fused_ops
=
1
self
.
append_gradients
(
tmp_2
)
if
self
.
backward
:
self
.
num_fused_ops
=
2
self
.
append_gradinets
(
tmp_2
)
self
.
num_fused_ops
=
2
self
.
fetch_list
=
[
tmp_2
,
self
.
grad
(
tmp_0
)]
def
setUp
(
self
):
self
.
build_program
(
"float64"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录