Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
0a71d580
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0a71d580
编写于
9月 06, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into fix_dist_base
上级
a0b68653
88685255
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
818 addition
and
242 deletion
+818
-242
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+0
-3
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+1
-1
paddle/fluid/framework/ir/graph_viz_pass.cc
paddle/fluid/framework/ir/graph_viz_pass.cc
+42
-17
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+0
-1
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+84
-80
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+0
-1
paddle/fluid/inference/api/helper.h
paddle/fluid/inference/api/helper.h
+41
-0
paddle/fluid/operators/flatten_op.cc
paddle/fluid/operators/flatten_op.cc
+115
-0
paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/reshape_op.cc
+100
-0
paddle/fluid/operators/squeeze_op.cc
paddle/fluid/operators/squeeze_op.cc
+119
-7
paddle/fluid/operators/transpose_op.cc
paddle/fluid/operators/transpose_op.cc
+103
-3
paddle/fluid/operators/transpose_op.cu.cc
paddle/fluid/operators/transpose_op.cu.cc
+7
-0
paddle/fluid/operators/unsqueeze_op.cc
paddle/fluid/operators/unsqueeze_op.cc
+117
-6
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+21
-11
python/paddle/fluid/tests/unittests/op_test.py
python/paddle/fluid/tests/unittests/op_test.py
+13
-5
python/paddle/fluid/tests/unittests/test_flatten_op.py
python/paddle/fluid/tests/unittests/test_flatten_op.py
+6
-3
python/paddle/fluid/tests/unittests/test_reshape_op.py
python/paddle/fluid/tests/unittests/test_reshape_op.py
+30
-94
python/paddle/fluid/tests/unittests/test_squeeze_op.py
python/paddle/fluid/tests/unittests/test_squeeze_op.py
+6
-3
python/paddle/fluid/tests/unittests/test_transpose_op.py
python/paddle/fluid/tests/unittests/test_transpose_op.py
+7
-4
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
+6
-3
未找到文件。
paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
浏览文件 @
0a71d580
...
@@ -13,13 +13,10 @@
...
@@ -13,13 +13,10 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/api/helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
0a71d580
...
@@ -85,7 +85,7 @@ void GraphPatternDetector::operator()(Graph* graph,
...
@@ -85,7 +85,7 @@ void GraphPatternDetector::operator()(Graph* graph,
LOG
(
INFO
)
<<
"detect "
<<
subgraphs
.
size
()
<<
" subgraph matches the pattern"
;
LOG
(
INFO
)
<<
"detect "
<<
subgraphs
.
size
()
<<
" subgraph matches the pattern"
;
int
id
=
0
;
int
id
=
0
;
for
(
auto
&
g
:
subgraphs
)
{
for
(
auto
&
g
:
subgraphs
)
{
LOG
(
INFO
)
<<
"optimizing #"
<<
id
++
<<
" subgraph"
;
VLOG
(
3
)
<<
"optimizing #"
<<
id
++
<<
" subgraph"
;
handler
(
g
,
graph
);
handler
(
g
,
graph
);
}
}
}
}
...
...
paddle/fluid/framework/ir/graph_viz_pass.cc
浏览文件 @
0a71d580
...
@@ -50,20 +50,37 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
...
@@ -50,20 +50,37 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
Dot
dot
;
Dot
dot
;
std
::
vector
<
Dot
::
Attr
>
op_attrs
({
Dot
::
Attr
(
"style"
,
"filled"
),
const
std
::
vector
<
Dot
::
Attr
>
op_attrs
({
Dot
::
Attr
(
"shape"
,
"box"
),
Dot
::
Attr
(
"style"
,
"rounded,filled,bold"
),
//
Dot
::
Attr
(
"fillcolor"
,
"red"
)});
Dot
::
Attr
(
"shape"
,
"box"
),
//
std
::
vector
<
Dot
::
Attr
>
var_attrs
({
Dot
::
Attr
(
"style"
,
"filled,rounded"
),
Dot
::
Attr
(
"color"
,
"#303A3A"
),
//
// Dot::Attr("shape", "diamond"),
Dot
::
Attr
(
"fontcolor"
,
"#ffffff"
),
//
Dot
::
Attr
(
"fillcolor"
,
"yellow"
)});
Dot
::
Attr
(
"width"
,
"1.3"
),
//
Dot
::
Attr
(
"height"
,
"0.84"
),
//
std
::
vector
<
Dot
::
Attr
>
marked_op_attrs
({
Dot
::
Attr
(
"style"
,
"filled"
),
Dot
::
Attr
(
"fontname"
,
"Arial"
),
//
Dot
::
Attr
(
"shape"
,
"box"
),
});
Dot
::
Attr
(
"fillcolor"
,
"lightgray"
)});
const
std
::
vector
<
Dot
::
Attr
>
arg_attrs
({
std
::
vector
<
Dot
::
Attr
>
marked_var_attrs
(
Dot
::
Attr
(
"shape"
,
"box"
),
//
{
Dot
::
Attr
(
"style"
,
"filled,rounded"
),
Dot
::
Attr
(
"style"
,
"rounded,filled,bold"
),
//
// Dot::Attr("shape", "diamond"),
Dot
::
Attr
(
"fontname"
,
"Arial"
),
//
Dot
::
Attr
(
"fillcolor"
,
"lightgray"
)});
Dot
::
Attr
(
"fillcolor"
,
"#999999"
),
//
Dot
::
Attr
(
"color"
,
"#dddddd"
),
//
});
const
std
::
vector
<
Dot
::
Attr
>
param_attrs
({
Dot
::
Attr
(
"shape"
,
"box"
),
//
Dot
::
Attr
(
"style"
,
"rounded,filled,bold"
),
//
Dot
::
Attr
(
"fontname"
,
"Arial"
),
//
Dot
::
Attr
(
"color"
,
"#148b97"
),
//
Dot
::
Attr
(
"fontcolor"
,
"#ffffff"
),
//
});
const
std
::
vector
<
Dot
::
Attr
>
marked_op_attrs
(
{
Dot
::
Attr
(
"style"
,
"rounded,filled,bold"
),
Dot
::
Attr
(
"shape"
,
"box"
),
Dot
::
Attr
(
"fillcolor"
,
"yellow"
)});
const
std
::
vector
<
Dot
::
Attr
>
marked_var_attrs
(
{
Dot
::
Attr
(
"style"
,
"filled,rounded"
),
Dot
::
Attr
(
"shape"
,
"box"
),
Dot
::
Attr
(
"fillcolor"
,
"yellow"
)});
auto
marked_nodes
=
ConsumeMarkedNodes
(
graph
.
get
());
auto
marked_nodes
=
ConsumeMarkedNodes
(
graph
.
get
());
// Create nodes
// Create nodes
...
@@ -74,9 +91,17 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
...
@@ -74,9 +91,17 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
marked_nodes
.
count
(
n
)
?
marked_op_attrs
:
op_attrs
;
marked_nodes
.
count
(
n
)
?
marked_op_attrs
:
op_attrs
;
dot
.
AddNode
(
node_id
,
attr
,
node_id
);
dot
.
AddNode
(
node_id
,
attr
,
node_id
);
}
else
if
(
n
->
IsVar
())
{
}
else
if
(
n
->
IsVar
())
{
decltype
(
op_attrs
)
attr
=
decltype
(
op_attrs
)
*
attr
;
marked_nodes
.
count
(
n
)
?
marked_var_attrs
:
var_attrs
;
if
(
marked_nodes
.
count
(
n
))
{
dot
.
AddNode
(
node_id
,
attr
,
node_id
);
attr
=
&
marked_var_attrs
;
}
else
if
(
const_cast
<
Node
*>
(
n
)
->
Var
()
&&
const_cast
<
Node
*>
(
n
)
->
Var
()
->
Persistable
())
{
attr
=
&
param_attrs
;
}
else
{
attr
=
&
arg_attrs
;
}
dot
.
AddNode
(
node_id
,
*
attr
,
node_id
);
}
}
node2dot
[
n
]
=
node_id
;
node2dot
[
n
]
=
node_id
;
}
}
...
...
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
0a71d580
...
@@ -106,7 +106,6 @@ void Analyzer::Run(Argument* argument) {
...
@@ -106,7 +106,6 @@ void Analyzer::Run(Argument* argument) {
}
}
}
}
passes
.
push_back
(
"graph_viz_pass"
);
passes
.
push_back
(
"graph_viz_pass"
);
// Ugly support fluid-to-ir-pass
argument
->
Set
(
kFluidToIrPassesAttr
,
new
std
::
vector
<
std
::
string
>
(
passes
));
argument
->
Set
(
kFluidToIrPassesAttr
,
new
std
::
vector
<
std
::
string
>
(
passes
));
for
(
auto
&
x
:
data_
)
{
for
(
auto
&
x
:
data_
)
{
...
...
paddle/fluid/inference/analysis/analyzer_tester.cc
浏览文件 @
0a71d580
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <google/protobuf/text_format.h>
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <thread> // NOLINT
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
...
@@ -24,12 +25,12 @@
...
@@ -24,12 +25,12 @@
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string
(
infer_ditu_rnn_model
,
""
,
"model path for ditu RNN"
);
DEFINE_string
(
infer_ditu_rnn_model
,
""
,
"model path for ditu RNN"
);
DEFINE_string
(
infer_ditu_rnn_data
,
""
,
"data path for ditu RNN"
);
DEFINE_string
(
infer_ditu_rnn_data
,
""
,
"data path for ditu RNN"
);
DEFINE_int32
(
batch_size
,
10
,
"batch size."
);
DEFINE_int32
(
batch_size
,
10
,
"batch size."
);
DEFINE_int32
(
repeat
,
1
,
"Running the inference program repeat times."
);
DEFINE_int32
(
repeat
,
1
,
"Running the inference program repeat times."
);
DEFINE_int32
(
num_threads
,
1
,
"Running the inference program in multi-threads."
);
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -220,39 +221,6 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
...
@@ -220,39 +221,6 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
}
}
}
}
std
::
string
DescribeTensor
(
const
PaddleTensor
&
tensor
)
{
std
::
stringstream
os
;
os
<<
"Tensor ["
<<
tensor
.
name
<<
"]
\n
"
;
os
<<
" - type: "
;
switch
(
tensor
.
dtype
)
{
case
PaddleDType
::
FLOAT32
:
os
<<
"float32"
;
break
;
case
PaddleDType
::
INT64
:
os
<<
"int64"
;
break
;
default:
os
<<
"unset"
;
}
os
<<
'\n'
;
os
<<
" - shape: "
<<
to_string
(
tensor
.
shape
)
<<
'\n'
;
os
<<
" - lod: "
;
for
(
auto
&
l
:
tensor
.
lod
)
{
os
<<
to_string
(
l
)
<<
"; "
;
}
os
<<
"
\n
"
;
os
<<
" - data: "
;
int
dim
=
std
::
accumulate
(
tensor
.
shape
.
begin
(),
tensor
.
shape
.
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
for
(
int
i
=
0
;
i
<
dim
;
i
++
)
{
os
<<
static_cast
<
float
*>
(
tensor
.
data
.
data
())[
i
]
<<
" "
;
}
os
<<
'\n'
;
return
os
.
str
();
}
}
// namespace
}
// namespace
const
float
ditu_rnn_target_data
[]
=
{
const
float
ditu_rnn_target_data
[]
=
{
...
@@ -266,11 +234,29 @@ const float ditu_rnn_target_data[] = {
...
@@ -266,11 +234,29 @@ const float ditu_rnn_target_data[] = {
10.7286
,
12.0595
,
10.6672
,
0
,
0
,
0
,
0
,
0
,
10.7286
,
12.0595
,
10.6672
,
0
,
0
,
0
,
0
,
0
,
93.5771
,
3.84641
,
0
,
0
,
0
,
0
,
0
,
0
,
93.5771
,
3.84641
,
0
,
0
,
0
,
0
,
0
,
0
,
169.426
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
169.426
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
void
CompareResult
(
const
std
::
vector
<
PaddleTensor
>
&
outputs
,
const
std
::
vector
<
PaddleTensor
>
&
base_outputs
)
{
PADDLE_ENFORCE_GT
(
outputs
.
size
(),
0
);
PADDLE_ENFORCE_EQ
(
outputs
.
size
(),
base_outputs
.
size
());
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
i
++
)
{
auto
&
out
=
outputs
[
i
];
auto
&
base_out
=
base_outputs
[
i
];
size_t
size
=
std
::
accumulate
(
out
.
shape
.
begin
(),
out
.
shape
.
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
size_t
size1
=
std
::
accumulate
(
base_out
.
shape
.
begin
(),
base_out
.
shape
.
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
PADDLE_ENFORCE_EQ
(
size
,
size1
);
PADDLE_ENFORCE_GT
(
size
,
0
);
float
*
data
=
static_cast
<
float
*>
(
out
.
data
.
data
());
float
*
base_data
=
static_cast
<
float
*>
(
base_out
.
data
.
data
());
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
EXPECT_NEAR
(
data
[
i
],
base_data
[
i
],
1e-3
);
}
}
}
// Test with a really complicate model.
// Test with a really complicate model.
void
TestDituRNNPrediction
(
const
std
::
string
&
model_path
,
void
TestDituRNNPrediction
(
bool
use_analysis
,
bool
activate_ir
,
const
std
::
string
&
data_path
,
int
batch_size
,
int
num_threads
)
{
bool
use_analysis
,
bool
activate_ir
,
int
num_times
=
1
)
{
AnalysisConfig
config
;
AnalysisConfig
config
;
config
.
prog_file
=
FLAGS_infer_ditu_rnn_model
+
"/__model__"
;
config
.
prog_file
=
FLAGS_infer_ditu_rnn_model
+
"/__model__"
;
config
.
param_file
=
FLAGS_infer_ditu_rnn_model
+
"/param"
;
config
.
param_file
=
FLAGS_infer_ditu_rnn_model
+
"/param"
;
...
@@ -281,6 +267,8 @@ void TestDituRNNPrediction(const std::string &model_path,
...
@@ -281,6 +267,8 @@ void TestDituRNNPrediction(const std::string &model_path,
PADDLE_ENFORCE
(
config
.
ir_mode
==
PADDLE_ENFORCE
(
config
.
ir_mode
==
AnalysisConfig
::
IrPassMode
::
kExclude
);
// default
AnalysisConfig
::
IrPassMode
::
kExclude
);
// default
config
.
ir_passes
.
clear
();
// Do not exclude any pass.
config
.
ir_passes
.
clear
();
// Do not exclude any pass.
int
batch_size
=
FLAGS_batch_size
;
int
num_times
=
FLAGS_repeat
;
auto
base_predictor
=
auto
base_predictor
=
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config
);
CreatePaddlePredictor
<
NativeConfig
,
PaddleEngineKind
::
kNative
>
(
config
);
...
@@ -288,40 +276,55 @@ void TestDituRNNPrediction(const std::string &model_path,
...
@@ -288,40 +276,55 @@ void TestDituRNNPrediction(const std::string &model_path,
CreatePaddlePredictor
<
AnalysisConfig
,
PaddleEngineKind
::
kAnalysis
>
(
CreatePaddlePredictor
<
AnalysisConfig
,
PaddleEngineKind
::
kAnalysis
>
(
config
);
config
);
std
::
vector
<
PaddleTensor
>
input_slots
;
std
::
vector
<
PaddleTensor
>
input_slots
;
DataRecord
data
(
data_path
,
batch_size
);
DataRecord
data
(
FLAGS_infer_ditu_rnn_data
,
batch_size
);
// Prepare inputs.
// Prepare inputs.
PrepareInputs
(
&
input_slots
,
&
data
,
batch_size
);
PrepareInputs
(
&
input_slots
,
&
data
,
batch_size
);
std
::
vector
<
PaddleTensor
>
outputs
,
base_outputs
;
std
::
vector
<
PaddleTensor
>
outputs
,
base_outputs
;
base_predictor
->
Run
(
input_slots
,
&
base_outputs
);
base_predictor
->
Run
(
input_slots
,
&
base_outputs
);
Timer
timer
;
timer
.
tic
();
for
(
int
i
=
0
;
i
<
num_times
;
i
++
)
{
predictor
->
Run
(
input_slots
,
&
outputs
);
}
LOG
(
INFO
)
<<
"===========profile result==========="
;
LOG
(
INFO
)
<<
"===========profile result==========="
;
LOG
(
INFO
)
<<
"batch_size: "
<<
batch_size
<<
", repeat: "
<<
num_times
if
(
num_threads
==
1
)
{
<<
", latency: "
<<
timer
.
toc
()
/
num_times
<<
"ms"
;
// Prepare inputs.
LOG
(
INFO
)
<<
"====================================="
;
Timer
timer
;
timer
.
tic
();
PADDLE_ENFORCE_GT
(
outputs
.
size
(),
0
);
for
(
int
i
=
0
;
i
<
num_times
;
i
++
)
{
PADDLE_ENFORCE_EQ
(
outputs
.
size
(),
base_outputs
.
size
());
predictor
->
Run
(
input_slots
,
&
outputs
);
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
i
++
)
{
}
auto
&
out
=
outputs
[
i
];
PrintTime
(
batch_size
,
num_times
,
1
,
0
,
timer
.
toc
()
/
num_times
);
auto
&
base_out
=
base_outputs
[
i
];
CompareResult
(
outputs
,
base_outputs
);
size_t
size
=
std
::
accumulate
(
out
.
shape
.
begin
(),
out
.
shape
.
end
(),
1
,
}
else
{
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
std
::
vector
<
std
::
thread
>
threads
;
size_t
size1
=
std
::
accumulate
(
base_out
.
shape
.
begin
(),
base_out
.
shape
.
end
(),
std
::
vector
<
std
::
unique_ptr
<
PaddlePredictor
>>
predictors
;
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
// TODO(yanchunwei): Bug here, the analyzer phase can't be parallelled
PADDLE_ENFORCE_EQ
(
size
,
size1
);
// because AttentionLSTM's hard code nodeid will be damanged.
PADDLE_ENFORCE_GT
(
size
,
0
);
for
(
int
tid
=
0
;
tid
<
num_threads
;
++
tid
)
{
float
*
data
=
static_cast
<
float
*>
(
out
.
data
.
data
());
predictors
.
emplace_back
(
float
*
base_data
=
static_cast
<
float
*>
(
base_out
.
data
.
data
());
CreatePaddlePredictor
<
AnalysisConfig
,
PaddleEngineKind
::
kAnalysis
>
(
for
(
size_t
j
=
0
;
j
<
size
;
j
++
)
{
config
));
EXPECT_NEAR
(
data
[
j
],
base_data
[
j
],
1e-3
);
}
for
(
int
tid
=
0
;
tid
<
num_threads
;
++
tid
)
{
threads
.
emplace_back
([
&
,
tid
]()
{
// Each thread should have local input_slots and outputs.
std
::
vector
<
PaddleTensor
>
input_slots
;
DataRecord
data
(
FLAGS_infer_ditu_rnn_data
,
batch_size
);
PrepareInputs
(
&
input_slots
,
&
data
,
batch_size
);
std
::
vector
<
PaddleTensor
>
outputs
;
Timer
timer
;
timer
.
tic
();
for
(
int
i
=
0
;
i
<
num_times
;
i
++
)
{
predictors
[
tid
]
->
Run
(
input_slots
,
&
outputs
);
}
PrintTime
(
batch_size
,
num_times
,
num_threads
,
tid
,
timer
.
toc
()
/
num_times
);
CompareResult
(
outputs
,
base_outputs
);
});
}
for
(
int
i
=
0
;
i
<
num_threads
;
++
i
)
{
threads
[
i
].
join
();
}
}
}
}
LOG
(
INFO
)
<<
"====================================="
;
if
(
use_analysis
&&
activate_ir
)
{
if
(
use_analysis
&&
activate_ir
)
{
AnalysisPredictor
*
analysis_predictor
=
AnalysisPredictor
*
analysis_predictor
=
...
@@ -350,25 +353,26 @@ void TestDituRNNPrediction(const std::string &model_path,
...
@@ -350,25 +353,26 @@ void TestDituRNNPrediction(const std::string &model_path,
}
}
}
}
// Directly infer with the original model.
// Inference with analysis and IR, easy for profiling independently.
TEST
(
Analyzer
,
DituRNN_without_analysis
)
{
TEST
(
Analyzer
,
DituRNN
)
{
TestDituRNNPrediction
(
FLAGS_infer_ditu_rnn_model
,
FLAGS_infer_ditu_rnn_data
,
TestDituRNNPrediction
(
true
,
true
,
FLAGS_num_threads
);
FLAGS_batch_size
,
false
,
false
,
FLAGS_repeat
);
}
}
// Inference with the original model with the analysis turned on, the analysis
// Other unit-tests of DituRNN, test different options of use_analysis,
// module will transform the program to a data flow graph.
// activate_ir and multi-threads.
TEST
(
Analyzer
,
DituRNN_with_analysis
)
{
TEST
(
Analyzer
,
DituRNN_tests
)
{
LOG
(
INFO
)
<<
"ditu rnn with analysis"
;
int
num_threads
[
2
]
=
{
1
,
4
};
TestDituRNNPrediction
(
FLAGS_infer_ditu_rnn_model
,
FLAGS_infer_ditu_rnn_data
,
for
(
auto
i
:
num_threads
)
{
FLAGS_batch_size
,
true
,
false
,
FLAGS_repeat
);
// Directly infer with the original model.
}
TestDituRNNPrediction
(
false
,
false
,
i
);
// Inference with the original model with the analysis turned on, the
// Inference with analysis and IR. The IR module will fuse some large kernels.
// analysis
TEST
(
Analyzer
,
DituRNN_with_analysis_with_IR
)
{
// module will transform the program to a data flow graph.
LOG
(
INFO
)
<<
"ditu rnn with analysis and IR fuse"
;
TestDituRNNPrediction
(
true
,
false
,
i
);
TestDituRNNPrediction
(
FLAGS_infer_ditu_rnn_model
,
FLAGS_infer_ditu_rnn_data
,
// Inference with analysis and IR. The IR module will fuse some large
FLAGS_batch_size
,
true
,
true
,
FLAGS_repeat
);
// kernels.
TestDituRNNPrediction
(
true
,
true
,
i
);
}
}
}
}
// namespace analysis
}
// namespace analysis
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
0a71d580
...
@@ -35,7 +35,6 @@ bool AnalysisPredictor::Init(
...
@@ -35,7 +35,6 @@ bool AnalysisPredictor::Init(
}
else
{
}
else
{
place_
=
paddle
::
platform
::
CPUPlace
();
place_
=
paddle
::
platform
::
CPUPlace
();
}
}
PADDLE_ENFORCE
(
!
parent_scope
);
if
(
parent_scope
)
{
if
(
parent_scope
)
{
scope_
=
parent_scope
;
scope_
=
parent_scope
;
sub_scope_
=
&
(
parent_scope
->
NewScope
());
sub_scope_
=
&
(
parent_scope
->
NewScope
());
...
...
paddle/fluid/inference/api/helper.h
浏览文件 @
0a71d580
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <glog/logging.h>
#include <sys/time.h>
#include <sys/time.h>
#include <algorithm>
#include <algorithm>
#include <numeric>
#include <numeric>
...
@@ -88,5 +89,45 @@ static void TensorAssignData(PaddleTensor *tensor,
...
@@ -88,5 +89,45 @@ static void TensorAssignData(PaddleTensor *tensor,
}
}
}
}
std
::
string
DescribeTensor
(
const
PaddleTensor
&
tensor
)
{
std
::
stringstream
os
;
os
<<
"Tensor ["
<<
tensor
.
name
<<
"]
\n
"
;
os
<<
" - type: "
;
switch
(
tensor
.
dtype
)
{
case
PaddleDType
::
FLOAT32
:
os
<<
"float32"
;
break
;
case
PaddleDType
::
INT64
:
os
<<
"int64"
;
break
;
default:
os
<<
"unset"
;
}
os
<<
'\n'
;
os
<<
" - shape: "
<<
to_string
(
tensor
.
shape
)
<<
'\n'
;
os
<<
" - lod: "
;
for
(
auto
&
l
:
tensor
.
lod
)
{
os
<<
to_string
(
l
)
<<
"; "
;
}
os
<<
"
\n
"
;
os
<<
" - data: "
;
int
dim
=
std
::
accumulate
(
tensor
.
shape
.
begin
(),
tensor
.
shape
.
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
for
(
int
i
=
0
;
i
<
dim
;
i
++
)
{
os
<<
static_cast
<
float
*>
(
tensor
.
data
.
data
())[
i
]
<<
" "
;
}
os
<<
'\n'
;
return
os
.
str
();
}
void
PrintTime
(
int
batch_size
,
int
repeat
,
int
num_threads
,
int
tid
,
double
latency
)
{
LOG
(
INFO
)
<<
"batch_size: "
<<
batch_size
<<
", repeat: "
<<
repeat
<<
", threads: "
<<
num_threads
<<
", thread id: "
<<
tid
<<
", latency: "
<<
latency
<<
"ms"
;
}
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/flatten_op.cc
浏览文件 @
0a71d580
...
@@ -157,6 +157,116 @@ class FlattenGradOp : public framework::OperatorBase {
...
@@ -157,6 +157,116 @@ class FlattenGradOp : public framework::OperatorBase {
}
}
};
};
// FIXME(zcd): flatten2 adds an intermediate output(XShape) based on flatten,
// the XShape is used to carry the shape and lod of X which will be used in
// flatten_grad, in this way, the framework can reuse the memory of X
// immediately the flatten2_op is finished.
// Considering compatibility issues, we could not fix flatten2_op
class
Flatten2OpInferShape
:
public
FlattenOpInferShape
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
FlattenOpInferShape
::
operator
()(
ctx
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"XShape"
),
"Output (XShape) of Flatten op should not be null."
);
const
auto
&
in_dims
=
ctx
->
GetInputDim
(
"X"
);
std
::
vector
<
int64_t
>
xshape_dims
(
in_dims
.
size
()
+
1
);
xshape_dims
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
in_dims
.
size
();
++
i
)
{
xshape_dims
[
i
+
1
]
=
in_dims
[
i
];
}
ctx
->
SetOutputDim
(
"XShape"
,
framework
::
make_ddim
(
xshape_dims
));
ctx
->
ShareLoD
(
"X"
,
"XShape"
);
}
};
class
Flatten2Op
:
public
framework
::
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
axis
=
Attr
<
int
>
(
"axis"
);
auto
in_dims
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
().
dims
();
const
auto
&
out_dims
=
FlattenOpInferShape
::
GetOutputShape
(
axis
,
in_dims
);
framework
::
AttributeMap
attrs
;
attrs
[
"shape"
]
=
out_dims
;
attrs
[
"inplace"
]
=
false
;
// Invoke Reshape Op
auto
reshape_op
=
framework
::
OpRegistry
::
CreateOp
(
"reshape2"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Shape"
,
{}}},
{{
"Out"
,
{
Output
(
"Out"
)}},
{
"XShape"
,
{
Output
(
"XShape"
)}}},
attrs
);
reshape_op
->
Run
(
scope
,
place
);
}
};
class
Flatten2OpMaker
:
public
FlattenOpMaker
{
public:
void
Make
()
override
{
FlattenOpMaker
::
Make
();
AddOutput
(
"XShape"
,
"XShape is just used to store the shape and lod of X, which will "
"be used in FlattenGradOp."
)
.
AsIntermediate
();
}
};
class
Flatten2GradOpMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"flatten2_grad"
);
grad_op
->
SetInput
(
"XShape"
,
Output
(
"XShape"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
};
class
Flatten2GradInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInput
(
"XShape"
),
"Input(XShape) shouldn't be null."
);
PADDLE_ENFORCE
(
context
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) shouldn't be null."
);
auto
xshape_dims
=
context
->
GetInputDim
(
"XShape"
);
auto
x_dims
=
framework
::
slice_ddim
(
xshape_dims
,
1
,
xshape_dims
.
size
());
context
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
context
->
ShareLoD
(
"XShape"
,
framework
::
GradVarName
(
"X"
));
}
};
class
Flatten2GradOp
:
public
framework
::
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
dx_name
=
Output
(
framework
::
GradVarName
(
"X"
));
auto
dout_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
xshape_name
=
Input
(
"XShape"
);
auto
xshape_dims
=
scope
.
FindVar
(
xshape_name
)
->
Get
<
framework
::
LoDTensor
>
().
dims
();
auto
x_dims
=
framework
::
slice_ddim
(
xshape_dims
,
1
,
xshape_dims
.
size
());
framework
::
AttributeMap
attrs
;
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
x_dims
);
attrs
[
"inplace"
]
=
false
;
auto
reshape_op
=
framework
::
OpRegistry
::
CreateOp
(
"reshape2"
,
{{
"X"
,
{
dout_name
}},
{
"Shape"
,
{}}},
{{
"Out"
,
{
dx_name
}},
{
"XShape"
,
{
xshape_name
}}},
attrs
);
reshape_op
->
Run
(
scope
,
place
);
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
@@ -167,3 +277,8 @@ REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
...
@@ -167,3 +277,8 @@ REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
ops
::
FlattenOpInferShape
,
ops
::
FlattenOpInferShape
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
flatten_grad
,
ops
::
FlattenGradOp
,
ops
::
FlattenGradInferShape
);
REGISTER_OPERATOR
(
flatten_grad
,
ops
::
FlattenGradOp
,
ops
::
FlattenGradInferShape
);
REGISTER_OPERATOR
(
flatten2
,
ops
::
Flatten2Op
,
ops
::
Flatten2OpMaker
,
ops
::
Flatten2OpInferShape
,
ops
::
Flatten2GradOpMaker
);
REGISTER_OPERATOR
(
flatten2_grad
,
ops
::
Flatten2GradOp
,
ops
::
Flatten2GradInferShape
);
paddle/fluid/operators/reshape_op.cc
浏览文件 @
0a71d580
...
@@ -246,6 +246,88 @@ class ReshapeGradKernel {
...
@@ -246,6 +246,88 @@ class ReshapeGradKernel {
}
}
};
};
// FIXME(zcd): reshape2 adds an intermediate output(XShape) based on reshape,
// the XShape is used to carry the shape and lod of X which will be used in
// reshape_grad, in this way, the framework can reuse the memory of X
// immediately the reshape_op is finished.
// Considering compatibility issues, we could not fix reshape_op
class
Reshape2Op
:
public
ReshapeOp
{
public:
Reshape2Op
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
ReshapeOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ReshapeOp
::
InferShape
(
ctx
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"XShape"
),
"Output(XShape) of ReshapeOp should not be null."
);
const
auto
&
x_dims
=
ctx
->
GetInputDim
(
"X"
);
std
::
vector
<
int64_t
>
xshape_dims
(
x_dims
.
size
()
+
1
);
xshape_dims
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
x_dims
.
size
();
++
i
)
{
xshape_dims
[
i
+
1
]
=
x_dims
[
i
];
}
ctx
->
SetOutputDim
(
"XShape"
,
framework
::
make_ddim
(
xshape_dims
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"XShape"
);
}
};
class
Reshape2OpMaker
:
public
ReshapeOpMaker
{
public:
void
Make
()
override
{
ReshapeOpMaker
::
Make
();
AddOutput
(
"XShape"
,
"XShape is just used to store the shape and lod of X, which will "
"be used in FlattenGradOp."
)
.
AsIntermediate
();
}
};
class
Reshape2GradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"reshape2_grad"
);
grad_op
->
SetInput
(
"XShape"
,
Output
(
"XShape"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
};
class
Reshape2GradOp
:
public
framework
::
OperatorWithKernel
{
public:
Reshape2GradOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"XShape"
),
"Input(XShape) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) shouldn't be null."
);
auto
xshape_dims
=
ctx
->
GetInputDim
(
"XShape"
);
auto
x_dims
=
framework
::
slice_ddim
(
xshape_dims
,
1
,
xshape_dims
.
size
());
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
ctx
->
ShareLoD
(
"XShape"
,
framework
::
GradVarName
(
"X"
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
()),
ctx
.
device_context
());
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
...
@@ -261,6 +343,17 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
...
@@ -261,6 +343,17 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
ops
::
ReshapeGradKernel
,
int64_t
,
ops
::
ReshapeGradKernel
,
int64_t
,
ops
::
ReshapeGradKernel
);
ops
::
ReshapeGradKernel
);
REGISTER_OPERATOR
(
reshape2
,
ops
::
Reshape2Op
,
ops
::
Reshape2OpMaker
,
ops
::
Reshape2GradMaker
);
REGISTER_OPERATOR
(
reshape2_grad
,
ops
::
Reshape2GradOp
);
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
reshape2
,
float
,
ops
::
ReshapeKernel
,
double
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
int64_t
,
ops
::
ReshapeKernel
);
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
reshape2_grad
,
float
,
ops
::
ReshapeGradKernel
,
double
,
ops
::
ReshapeGradKernel
,
int
,
ops
::
ReshapeGradKernel
,
int64_t
,
ops
::
ReshapeGradKernel
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR
(
reshape
,
float
,
ops
::
ReshapeKernel
,
double
,
REGISTER_OP_CUDA_KERNEL_FUNCTOR
(
reshape
,
float
,
ops
::
ReshapeKernel
,
double
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
...
@@ -269,4 +362,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
...
@@ -269,4 +362,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double
,
ops
::
ReshapeGradKernel
,
int
,
double
,
ops
::
ReshapeGradKernel
,
int
,
ops
::
ReshapeGradKernel
,
int64_t
,
ops
::
ReshapeGradKernel
,
int64_t
,
ops
::
ReshapeGradKernel
);
ops
::
ReshapeGradKernel
);
REGISTER_OP_CUDA_KERNEL_FUNCTOR
(
reshape2
,
float
,
ops
::
ReshapeKernel
,
double
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
int64_t
,
ops
::
ReshapeKernel
);
REGISTER_OP_CUDA_KERNEL_FUNCTOR
(
reshape2_grad
,
float
,
ops
::
ReshapeGradKernel
,
double
,
ops
::
ReshapeGradKernel
,
int
,
ops
::
ReshapeGradKernel
,
int64_t
,
ops
::
ReshapeGradKernel
);
#endif
#endif
paddle/fluid/operators/squeeze_op.cc
浏览文件 @
0a71d580
...
@@ -126,15 +126,15 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -126,15 +126,15 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
.
SetDefault
({});
.
SetDefault
({});
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Squeeze Operator.
Squeeze Operator.
Remove single-dimensional entries from the shape of a tensor.
Remove single-dimensional entries from the shape of a tensor.
Takes a parameter axes with a list of axes to squeeze.
Takes a parameter axes with a list of axes to squeeze.
If axes is not provided, all the single dimensions will be removed from the shape.
If axes is not provided, all the single dimensions will be removed from the shape.
If an axis is selected with shape entry not equal to one, an error is raised.
If an axis is selected with shape entry not equal to one, an error is raised.
Examples:
Examples:
Case 1:
Case 1:
Given
Given
X.shape = (1, 3, 1, 5)
X.shape = (1, 3, 1, 5)
and
and
axes = [0]
axes = [0]
...
@@ -144,7 +144,7 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -144,7 +144,7 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
Case 2:
Case 2:
Given
Given
X.shape = (1, 3, 1, 5)
X.shape = (1, 3, 1, 5)
and
and
axes = []
axes = []
we get:
we get:
Out.shape = (3, 5)
Out.shape = (3, 5)
...
@@ -181,6 +181,113 @@ class SqueezeGradOp : public framework::OperatorBase {
...
@@ -181,6 +181,113 @@ class SqueezeGradOp : public framework::OperatorBase {
}
}
};
};
// FIXME(zcd): squeeze2 adds an intermediate output(XShape) based on squeeze,
// the XShape is used to carry the shape and lod of X which will be used in
// squeeze_grad, in this way, the framework can reuse the memory of X
// immediately the squeeze2_op is finished.
// Considering compatibility issues, we could not fix squeeze2_op
class
Squeeze2OpMaker
:
public
SqueezeOpMaker
{
public:
void
Make
()
override
{
SqueezeOpMaker
::
Make
();
AddOutput
(
"XShape"
,
"XShape is just used to store the shape and lod of X, which will "
"be used in SqueezeGradOp."
)
.
AsIntermediate
();
}
};
class
Squeeze2OpInferShape
:
public
SqueezeOpInferShape
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
SqueezeOpInferShape
::
operator
()(
ctx
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"XShape"
),
"Output(XShape) of Squeeze operator should not be null."
);
const
auto
&
x_dims
=
ctx
->
GetInputDim
(
"X"
);
std
::
vector
<
int64_t
>
xshape_dims
(
x_dims
.
size
()
+
1
);
xshape_dims
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
x_dims
.
size
();
++
i
)
{
xshape_dims
[
i
+
1
]
=
x_dims
[
i
];
}
ctx
->
SetOutputDim
(
"XShape"
,
framework
::
make_ddim
(
xshape_dims
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"XShape"
);
}
};
class
Squeeze2Op
:
public
framework
::
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
axes
=
Attr
<
std
::
vector
<
int
>>
(
"axes"
);
auto
x_dims
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
().
dims
();
auto
out_dims
=
Squeeze2OpInferShape
::
GetOutputShape
(
axes
,
x_dims
);
framework
::
AttributeMap
attrs
;
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
out_dims
);
// Invoke Reshape Op
auto
reshape_op
=
framework
::
OpRegistry
::
CreateOp
(
"reshape2"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Shape"
,
{}}},
{{
"Out"
,
{
Output
(
"Out"
)}},
{
"XShape"
,
{
Output
(
"XShape"
)}}},
attrs
);
reshape_op
->
Run
(
scope
,
place
);
}
};
class
Squeeze2GradOpMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"squeeze2_grad"
);
grad_op
->
SetInput
(
"XShape"
,
Output
(
"XShape"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
};
class
Squeeze2GradInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInput
(
"XShape"
),
"Input(XShape) shouldn't be null."
);
PADDLE_ENFORCE
(
context
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) shouldn't be null."
);
auto
xshape_dims
=
context
->
GetInputDim
(
"XShape"
);
auto
x_dims
=
framework
::
slice_ddim
(
xshape_dims
,
1
,
xshape_dims
.
size
());
context
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
context
->
ShareLoD
(
"XShape"
,
framework
::
GradVarName
(
"X"
));
}
};
class
Squeeze2GradOp
:
public
framework
::
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
dx_name
=
Output
(
framework
::
GradVarName
(
"X"
));
auto
dout_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
xshape_name
=
Input
(
"XShape"
);
auto
xshape_dims
=
scope
.
FindVar
(
xshape_name
)
->
Get
<
framework
::
LoDTensor
>
().
dims
();
auto
x_dims
=
framework
::
slice_ddim
(
xshape_dims
,
1
,
xshape_dims
.
size
());
framework
::
AttributeMap
attrs
;
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
x_dims
);
auto
reshape_op
=
framework
::
OpRegistry
::
CreateOp
(
"reshape2"
,
{{
"X"
,
{
dout_name
}},
{
"Shape"
,
{}}},
{{
"Out"
,
{
dx_name
}},
{
"XShape"
,
{
xshape_name
}}},
attrs
);
reshape_op
->
Run
(
scope
,
place
);
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
@@ -192,3 +299,8 @@ REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
...
@@ -192,3 +299,8 @@ REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
ops
::
SqueezeOpInferShape
,
ops
::
SqueezeOpInferShape
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
squeeze_grad
,
ops
::
SqueezeGradOp
,
ops
::
SqueezeGradInferShape
);
REGISTER_OPERATOR
(
squeeze_grad
,
ops
::
SqueezeGradOp
,
ops
::
SqueezeGradInferShape
);
REGISTER_OPERATOR
(
squeeze2
,
ops
::
Squeeze2Op
,
ops
::
Squeeze2OpMaker
,
ops
::
Squeeze2OpInferShape
,
ops
::
Squeeze2GradOpMaker
);
REGISTER_OPERATOR
(
squeeze2_grad
,
ops
::
Squeeze2GradOp
,
ops
::
Squeeze2GradInferShape
);
paddle/fluid/operators/transpose_op.cc
浏览文件 @
0a71d580
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include <string>
#include <vector>
#include <vector>
namespace
paddle
{
namespace
paddle
{
...
@@ -24,7 +25,7 @@ class TransposeOp : public framework::OperatorWithKernel {
...
@@ -24,7 +25,7 @@ class TransposeOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) should not be null"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
...
@@ -90,7 +91,7 @@ The behavior of this operator is similar to how `numpy.transpose` works.
...
@@ -90,7 +91,7 @@ The behavior of this operator is similar to how `numpy.transpose` works.
2 &5
2 &5
\end{pmatrix}$$
\end{pmatrix}$$
- Given a input tensor with shape $(N, C, H, W)$ and the `axes` is
- Given a input tensor with shape $(N, C, H, W)$ and the `axes` is
$[0, 2, 3, 1]$, then shape of the output tensor will be: $(N, H, W, C)$.
$[0, 2, 3, 1]$, then shape of the output tensor will be: $(N, H, W, C)$.
)DOC"
);
)DOC"
);
...
@@ -101,7 +102,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
...
@@ -101,7 +102,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
"Input(Out@GRAD) should not be null"
);
...
@@ -113,6 +114,93 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
...
@@ -113,6 +114,93 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
}
}
};
};
// FIXME(zcd): transpose2 adds an intermediate output(XShape) based on
// transpose, the XShape is used to carry the shape and lod of X which
// will be used in transpose_grad, in this way, the framework can reuse
// the memory of X immediately the transpose2_op is finished.
// Considering compatibility issues, we could not fix transpose2_op
class
Transpose2Op
:
public
TransposeOp
{
public:
Transpose2Op
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
TransposeOp
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
TransposeOp
::
InferShape
(
ctx
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"XShape"
),
"Output(XShape) should not be null"
);
const
auto
&
in_dims
=
ctx
->
GetInputDim
(
"X"
);
std
::
vector
<
int64_t
>
x_shape_dim
(
in_dims
.
size
()
+
1
);
x_shape_dim
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
in_dims
.
size
();
++
i
)
{
x_shape_dim
[
i
+
1
]
=
in_dims
[
i
];
}
ctx
->
SetOutputDim
(
"XShape"
,
framework
::
make_ddim
(
x_shape_dim
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"XShape"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
class
Transpose2OpMaker
:
public
TransposeOpMaker
{
public:
void
Make
()
override
{
TransposeOpMaker
::
Make
();
AddOutput
(
"XShape"
,
"(Tensor)The output tensor."
).
AsIntermediate
();
}
};
class
Transpose2GradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"transpose2_grad"
);
grad_op
->
SetInput
(
"XShape"
,
Output
(
"XShape"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
};
class
Transpose2OpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"XShape"
),
"Input(XShape) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
auto
xshape_dim
=
ctx
->
GetInputDim
(
"XShape"
);
auto
x_shape_dim
=
framework
::
slice_ddim
(
xshape_dim
,
1
,
xshape_dim
.
size
());
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_shape_dim
);
ctx
->
ShareLoD
(
"XShape"
,
framework
::
GradVarName
(
"X"
));
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
()),
ctx
.
device_context
());
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
@@ -120,8 +208,20 @@ namespace ops = paddle::operators;
...
@@ -120,8 +208,20 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
transpose
,
ops
::
TransposeOp
,
ops
::
TransposeOpMaker
,
REGISTER_OPERATOR
(
transpose
,
ops
::
TransposeOp
,
ops
::
TransposeOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
transpose_grad
,
ops
::
TransposeOpGrad
);
REGISTER_OPERATOR
(
transpose_grad
,
ops
::
TransposeOpGrad
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
transpose_grad
,
transpose_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
REGISTER_OPERATOR
(
transpose2
,
ops
::
Transpose2Op
,
ops
::
Transpose2OpMaker
,
ops
::
Transpose2GradMaker
);
REGISTER_OPERATOR
(
transpose2_grad
,
ops
::
Transpose2OpGrad
);
REGISTER_OP_CPU_KERNEL
(
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
transpose2_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
paddle/fluid/operators/transpose_op.cu.cc
浏览文件 @
0a71d580
...
@@ -21,3 +21,10 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -21,3 +21,10 @@ REGISTER_OP_CUDA_KERNEL(
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
transpose_grad
,
transpose_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose2_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
paddle/fluid/operators/unsqueeze_op.cc
浏览文件 @
0a71d580
...
@@ -127,13 +127,13 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -127,13 +127,13 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
});
});
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Unsqueeze Operator.
Unsqueeze Operator.
Insert single-dimensional entries to the shape of a tensor.
Takes one required argument axes, a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.
For example:
Insert single-dimensional entries to the shape of a tensor.
Given a tensor such that tensor with shape [3, 4, 5],
Takes one required argument axes, a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.
For example:
Given a tensor such that tensor with shape [3, 4, 5],
then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
)DOC"
);
)DOC"
);
}
}
...
@@ -168,6 +168,112 @@ class UnsqueezeGradOp : public framework::OperatorBase {
...
@@ -168,6 +168,112 @@ class UnsqueezeGradOp : public framework::OperatorBase {
}
}
};
};
// FIXME(zcd): unsqueeze2 adds an intermediate output(XShape) based on
// unsqueeze, the XShape is used to carry the shape and lod of X which
// will be used in unsqueeze_grad, in this way, the framework can reuse
// the memory of X immediately the unsqueeze2_op is finished.
// Considering compatibility issues, we could not fix unsqueeze2_op
class
Unsqueeze2OpInferShape
:
public
UnsqueezeOpInferShape
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
UnsqueezeOpInferShape
::
operator
()(
ctx
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"XShape"
),
"Output(XShape) of Unsqueeze operator should not be null."
);
const
auto
&
x_dims
=
ctx
->
GetInputDim
(
"X"
);
std
::
vector
<
int64_t
>
xshape_dims
(
x_dims
.
size
()
+
1
);
xshape_dims
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
x_dims
.
size
();
++
i
)
{
xshape_dims
[
i
+
1
]
=
x_dims
[
i
];
}
ctx
->
SetOutputDim
(
"XShape"
,
framework
::
make_ddim
(
xshape_dims
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"XShape"
);
}
};
class
Unsqueeze2OpMaker
:
public
UnsqueezeOpMaker
{
public:
void
Make
()
override
{
UnsqueezeOpMaker
::
Make
();
AddOutput
(
"XShape"
,
"XShape is just used to store the shape and lod of X, which will "
"be used in UnsqueezeGradOp."
)
.
AsIntermediate
();
}
};
class
Unsqueeze2Op
:
public
framework
::
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
axes
=
Attr
<
std
::
vector
<
int
>>
(
"axes"
);
auto
x_dims
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
().
dims
();
auto
out_dims
=
Unsqueeze2OpInferShape
::
GetOutputShape
(
axes
,
x_dims
);
framework
::
AttributeMap
attrs
;
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
out_dims
);
// Invoke Reshape op.
auto
reshape_op
=
framework
::
OpRegistry
::
CreateOp
(
"reshape2"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Shape"
,
{}}},
{{
"Out"
,
{
Output
(
"Out"
)}},
{
"XShape"
,
{
Output
(
"XShape"
)}}},
attrs
);
reshape_op
->
Run
(
scope
,
place
);
}
};
class
Unsqueeze2GradOpMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"unsqueeze2_grad"
);
grad_op
->
SetInput
(
"XShape"
,
Output
(
"XShape"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
};
class
Unsqueeze2GradInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInput
(
"XShape"
),
"Input(XShape) shouldn't be null."
);
PADDLE_ENFORCE
(
context
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) shouldn't be null."
);
auto
xshape_dims
=
context
->
GetInputDim
(
"XShape"
);
auto
x_dims
=
framework
::
slice_ddim
(
xshape_dims
,
1
,
xshape_dims
.
size
());
context
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
context
->
ShareLoD
(
"XShape"
,
framework
::
GradVarName
(
"X"
));
}
};
class
Unsqueeze2GradOp
:
public
framework
::
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
dx_name
=
Output
(
framework
::
GradVarName
(
"X"
));
auto
dout_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
xshape_name
=
Input
(
"XShape"
);
auto
xshape_dims
=
scope
.
FindVar
(
xshape_name
)
->
Get
<
framework
::
LoDTensor
>
().
dims
();
auto
x_dims
=
framework
::
slice_ddim
(
xshape_dims
,
1
,
xshape_dims
.
size
());
framework
::
AttributeMap
attrs
;
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
x_dims
);
auto
reshape_op
=
framework
::
OpRegistry
::
CreateOp
(
"reshape2"
,
{{
"X"
,
{
dout_name
}},
{
"Shape"
,
{}}},
{{
"Out"
,
{
dx_name
}},
{
"XShape"
,
{
xshape_name
}}},
attrs
);
reshape_op
->
Run
(
scope
,
place
);
}
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
@@ -180,3 +286,8 @@ REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
...
@@ -180,3 +286,8 @@ REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
unsqueeze_grad
,
ops
::
UnsqueezeGradOp
,
REGISTER_OPERATOR
(
unsqueeze_grad
,
ops
::
UnsqueezeGradOp
,
ops
::
UnsqueezeGradInferShape
);
ops
::
UnsqueezeGradInferShape
);
REGISTER_OPERATOR
(
unsqueeze2
,
ops
::
Unsqueeze2Op
,
ops
::
Unsqueeze2OpMaker
,
ops
::
Unsqueeze2OpInferShape
,
ops
::
Unsqueeze2GradOpMaker
);
REGISTER_OPERATOR
(
unsqueeze2_grad
,
ops
::
Unsqueeze2GradOp
,
ops
::
Unsqueeze2GradInferShape
);
python/paddle/fluid/layers/nn.py
浏览文件 @
0a71d580
...
@@ -4025,10 +4025,12 @@ def transpose(x, perm, name=None):
...
@@ -4025,10 +4025,12 @@ def transpose(x, perm, name=None):
helper
=
LayerHelper
(
'transpose'
,
**
locals
())
helper
=
LayerHelper
(
'transpose'
,
**
locals
())
out
=
helper
.
create_tmp_variable
(
x
.
dtype
)
out
=
helper
.
create_tmp_variable
(
x
.
dtype
)
x_shape
=
helper
.
create_tmp_variable
(
x
.
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
'transpose'
,
type
=
'transpose
2
'
,
inputs
=
{
'X'
:
[
x
]},
inputs
=
{
'X'
:
[
x
]},
outputs
=
{
'Out'
:
[
out
]},
outputs
=
{
'Out'
:
[
out
],
'XShape'
:
[
x_shape
]},
attrs
=
{
'axis'
:
perm
})
attrs
=
{
'axis'
:
perm
})
return
out
return
out
...
@@ -4520,13 +4522,15 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
...
@@ -4520,13 +4522,15 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
"Each dimension size given in shape must not be negtive "
"Each dimension size given in shape must not be negtive "
"except one unknown dimension."
)
"except one unknown dimension."
)
helper
=
LayerHelper
(
"reshape"
,
**
locals
())
helper
=
LayerHelper
(
"reshape
2
"
,
**
locals
())
out
=
helper
.
create_tmp_variable
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_tmp_variable
(
dtype
=
x
.
dtype
)
x_shape
=
helper
.
create_tmp_variable
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
"reshape"
,
type
=
"reshape
2
"
,
inputs
=
inputs
,
inputs
=
inputs
,
attrs
=
{
"shape"
:
shape
},
attrs
=
{
"shape"
:
shape
},
outputs
=
{
"Out"
:
out
})
outputs
=
{
"Out"
:
out
,
"XShape"
:
x_shape
})
return
helper
.
append_activation
(
out
)
return
helper
.
append_activation
(
out
)
...
@@ -4570,11 +4574,13 @@ def squeeze(input, axes, name=None):
...
@@ -4570,11 +4574,13 @@ def squeeze(input, axes, name=None):
"""
"""
helper
=
LayerHelper
(
"squeeze"
,
**
locals
())
helper
=
LayerHelper
(
"squeeze"
,
**
locals
())
out
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
)
out
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
)
x_shape
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
"squeeze"
,
type
=
"squeeze
2
"
,
inputs
=
{
"X"
:
input
},
inputs
=
{
"X"
:
input
},
attrs
=
{
"axes"
:
axes
},
attrs
=
{
"axes"
:
axes
},
outputs
=
{
"Out"
:
out
})
outputs
=
{
"Out"
:
out
,
"XShape"
:
x_shape
})
return
out
return
out
...
@@ -4605,11 +4611,13 @@ def unsqueeze(input, axes, name=None):
...
@@ -4605,11 +4611,13 @@ def unsqueeze(input, axes, name=None):
"""
"""
helper
=
LayerHelper
(
"unsqueeze"
,
**
locals
())
helper
=
LayerHelper
(
"unsqueeze"
,
**
locals
())
out
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
)
out
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
)
x_shape
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
"unsqueeze"
,
type
=
"unsqueeze
2
"
,
inputs
=
{
"X"
:
input
},
inputs
=
{
"X"
:
input
},
attrs
=
{
"axes"
:
axes
},
attrs
=
{
"axes"
:
axes
},
outputs
=
{
"Out"
:
out
})
outputs
=
{
"Out"
:
out
,
"XShape"
:
x_shape
})
return
out
return
out
...
@@ -5811,10 +5819,12 @@ def flatten(x, axis=1, name=None):
...
@@ -5811,10 +5819,12 @@ def flatten(x, axis=1, name=None):
raise
ValueError
(
"The axis should be a int, and in range [0, rank(x)]"
)
raise
ValueError
(
"The axis should be a int, and in range [0, rank(x)]"
)
out
=
helper
.
create_tmp_variable
(
x
.
dtype
)
out
=
helper
.
create_tmp_variable
(
x
.
dtype
)
x_shape
=
helper
.
create_tmp_variable
(
x
.
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
'flatten'
,
type
=
'flatten
2
'
,
inputs
=
{
"X"
:
x
},
inputs
=
{
"X"
:
x
},
outputs
=
{
'Out'
:
out
},
outputs
=
{
'Out'
:
out
,
'XShape'
:
x_shape
},
attrs
=
{
"axis"
:
axis
})
attrs
=
{
"axis"
:
axis
})
return
out
return
out
...
...
python/paddle/fluid/tests/unittests/op_test.py
浏览文件 @
0a71d580
...
@@ -249,7 +249,7 @@ class OpTest(unittest.TestCase):
...
@@ -249,7 +249,7 @@ class OpTest(unittest.TestCase):
outs
,
_
=
self
.
_calc_output
(
place
)
outs
,
_
=
self
.
_calc_output
(
place
)
return
outs
return
outs
def
_calc_output
(
self
,
place
,
parallel
=
False
):
def
_calc_output
(
self
,
place
,
parallel
=
False
,
no_check_set
=
None
):
program
=
Program
()
program
=
Program
()
block
=
program
.
global_block
()
block
=
program
.
global_block
()
...
@@ -273,6 +273,8 @@ class OpTest(unittest.TestCase):
...
@@ -273,6 +273,8 @@ class OpTest(unittest.TestCase):
# if not, fill the fetch_list by the user configured outputs in test.
# if not, fill the fetch_list by the user configured outputs in test.
if
len
(
fetch_list
)
==
0
:
if
len
(
fetch_list
)
==
0
:
for
var_name
,
var
in
six
.
iteritems
(
outputs
):
for
var_name
,
var
in
six
.
iteritems
(
outputs
):
if
no_check_set
is
not
None
and
var_name
in
no_check_set
:
continue
if
isinstance
(
var
,
list
):
if
isinstance
(
var
,
list
):
for
v
in
var
:
for
v
in
var
:
fetch_list
.
append
(
v
)
fetch_list
.
append
(
v
)
...
@@ -291,11 +293,17 @@ class OpTest(unittest.TestCase):
...
@@ -291,11 +293,17 @@ class OpTest(unittest.TestCase):
return_numpy
=
False
)
return_numpy
=
False
)
return
outs
,
fetch_list
return
outs
,
fetch_list
def
check_output_with_place
(
self
,
place
,
atol
,
equal_nan
=
False
):
def
check_output_with_place
(
self
,
outs
,
fetch_list
=
self
.
_calc_output
(
place
)
place
,
atol
,
no_check_set
=
None
,
equal_nan
=
False
):
outs
,
fetch_list
=
self
.
_calc_output
(
place
,
no_check_set
=
no_check_set
)
for
out_name
,
out_dup
in
Operator
.
get_op_outputs
(
self
.
op_type
):
for
out_name
,
out_dup
in
Operator
.
get_op_outputs
(
self
.
op_type
):
if
out_name
not
in
self
.
outputs
:
if
out_name
not
in
self
.
outputs
:
continue
continue
if
no_check_set
is
not
None
and
out_name
in
no_check_set
:
continue
def
find_actual
(
target_name
,
fetch_list
):
def
find_actual
(
target_name
,
fetch_list
):
found
=
[
found
=
[
...
@@ -360,10 +368,10 @@ class OpTest(unittest.TestCase):
...
@@ -360,10 +368,10 @@ class OpTest(unittest.TestCase):
places
.
append
(
core
.
CUDAPlace
(
0
))
places
.
append
(
core
.
CUDAPlace
(
0
))
return
places
return
places
def
check_output
(
self
,
atol
=
1e-5
,
equal_nan
=
False
):
def
check_output
(
self
,
atol
=
1e-5
,
no_check_set
=
None
,
equal_nan
=
False
):
places
=
self
.
_get_places
()
places
=
self
.
_get_places
()
for
place
in
places
:
for
place
in
places
:
self
.
check_output_with_place
(
place
,
atol
,
equal_nan
)
self
.
check_output_with_place
(
place
,
atol
,
no_check_set
,
equal_nan
)
def
check_output_customized
(
self
,
checker
):
def
check_output_customized
(
self
,
checker
):
places
=
self
.
_get_places
()
places
=
self
.
_get_places
()
...
...
python/paddle/fluid/tests/unittests/test_flatten_op.py
浏览文件 @
0a71d580
...
@@ -22,14 +22,17 @@ from op_test import OpTest
...
@@ -22,14 +22,17 @@ from op_test import OpTest
class
TestFlattenOp
(
OpTest
):
class
TestFlattenOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"flatten"
self
.
op_type
=
"flatten
2
"
self
.
init_test_case
()
self
.
init_test_case
()
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
self
.
in_shape
).
astype
(
"float32"
)}
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
self
.
in_shape
).
astype
(
"float32"
)}
self
.
init_attrs
()
self
.
init_attrs
()
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
self
.
new_shape
)}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
self
.
new_shape
),
"XShape"
:
np
.
random
.
random
(
self
.
in_shape
).
astype
(
"float32"
)
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
no_check_set
=
[
"XShape"
]
)
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
self
.
check_grad
([
"X"
],
"Out"
)
...
...
python/paddle/fluid/tests/unittests/test_reshape_op.py
浏览文件 @
0a71d580
...
@@ -22,106 +22,39 @@ from op_test import OpTest
...
@@ -22,106 +22,39 @@ from op_test import OpTest
class
TestReshapeOp
(
OpTest
):
class
TestReshapeOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
ori_shape
=
(
2
,
25
)
self
.
init_data
()
new_shape
=
(
5
,
10
)
self
.
op_type
=
"reshape2"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
"float32"
)}
self
.
op_type
=
"reshape"
self
.
attrs
=
{
"shape"
:
self
.
new_shape
}
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
outputs
=
{
self
.
attrs
=
{
"shape"
:
new_shape
}
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
self
.
infered_shape
),
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
'XShape'
:
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
"float32"
)
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
class
TestReshapeOpDimInfer1
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
5
,
10
)
new_shape
=
(
5
,
-
1
,
5
)
self
.
op_type
=
"reshape"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"shape"
:
new_shape
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
self
.
attrs
[
"shape"
])}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
class
TestReshapeOpDimInfer2
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
2
,
2
,
6
)
new_shape
=
(
2
,
0
,
3
,
-
1
)
infered_shape
=
(
2
,
2
,
3
,
-
1
)
self
.
op_type
=
"reshape"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"shape"
:
new_shape
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
infered_shape
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
class
TestReshapeOpInplace
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
2
,
25
)
new_shape
=
(
5
,
10
)
self
.
op_type
=
"reshape"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"shape"
:
new_shape
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
class
TestReshapeOpDimInferInplace1
(
OpTest
):
def
setUp
(
self
):
ori_shape
=
(
5
,
10
)
new_shape
=
(
5
,
-
1
,
5
)
self
.
op_type
=
"reshape"
def
init_data
(
self
):
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
ori_shape
=
(
2
,
25
)
self
.
attrs
=
{
"shape"
:
new_shape
}
self
.
new_shape
=
(
5
,
10
)
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
self
.
infered_shape
=
(
5
,
10
)
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
no_check_set
=
[
'XShape'
]
)
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
self
.
check_grad
([
"X"
],
"Out"
)
class
TestReshapeOpDimInferInplace2
(
OpTest
):
class
TestReshapeOpDimInfer1
(
TestReshapeOp
):
def
setUp
(
self
):
def
init_data
(
self
):
ori_shape
=
(
2
,
2
,
6
)
self
.
ori_shape
=
(
5
,
10
)
new_shape
=
(
2
,
0
,
3
,
-
1
)
self
.
new_shape
=
(
5
,
-
1
,
5
)
infered_shape
=
(
2
,
2
,
3
,
-
1
)
self
.
infered_shape
=
(
5
,
-
1
,
5
)
self
.
op_type
=
"reshape"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"shape"
:
new_shape
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
infered_shape
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
class
TestReshapeOpDimInfer2
(
TestReshapeOp
):
self
.
check_grad
([
"X"
],
"Out"
)
def
init_data
(
self
):
self
.
ori_shape
=
(
2
,
2
,
6
)
self
.
new_shape
=
(
2
,
0
,
3
,
-
1
)
self
.
infered_shape
=
(
2
,
2
,
3
,
-
1
)
class
TestReshapeOpWithInputShape
(
OpTest
):
class
TestReshapeOpWithInputShape
(
OpTest
):
...
@@ -130,20 +63,23 @@ class TestReshapeOpWithInputShape(OpTest):
...
@@ -130,20 +63,23 @@ class TestReshapeOpWithInputShape(OpTest):
new_shape
=
(
0
,
-
1
,
5
)
new_shape
=
(
0
,
-
1
,
5
)
actual_shape
=
(
2
,
3
,
5
)
actual_shape
=
(
2
,
3
,
5
)
self
.
op_type
=
"reshape"
self
.
op_type
=
"reshape
2
"
self
.
inputs
=
{
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
),
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
),
"Shape"
:
np
.
array
(
"Shape"
:
np
.
array
(
actual_shape
,
dtype
=
"int32"
)
actual_shape
,
dtype
=
"int32"
)
}
}
self
.
attrs
=
{
"shape"
:
new_shape
}
self
.
attrs
=
{
"shape"
:
new_shape
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
actual_shape
)}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
actual_shape
),
'XShape'
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
no_check_set
=
[
'XShape'
]
)
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
self
.
check_grad
([
"X"
],
"Out"
,
sum_outputs
=
[
"Out"
]
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
python/paddle/fluid/tests/unittests/test_squeeze_op.py
浏览文件 @
0a71d580
...
@@ -23,14 +23,17 @@ from op_test import OpTest
...
@@ -23,14 +23,17 @@ from op_test import OpTest
# Correct: General.
# Correct: General.
class
TestSqueezeOp
(
OpTest
):
class
TestSqueezeOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"squeeze"
self
.
op_type
=
"squeeze
2
"
self
.
init_test_case
()
self
.
init_test_case
()
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
"float32"
)}
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
"float32"
)}
self
.
init_attrs
()
self
.
init_attrs
()
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
self
.
new_shape
)}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
self
.
new_shape
),
"XShape"
:
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
"float32"
)
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
no_check_set
=
[
'XShape'
]
)
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
self
.
check_grad
([
"X"
],
"Out"
)
...
...
python/paddle/fluid/tests/unittests/test_transpose_op.py
浏览文件 @
0a71d580
...
@@ -22,16 +22,19 @@ from op_test import OpTest
...
@@ -22,16 +22,19 @@ from op_test import OpTest
class
TestTransposeOp
(
OpTest
):
class
TestTransposeOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
initTestCase
()
self
.
op_type
=
"transpose"
self
.
op_type
=
"transpose
2
"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
'axis'
:
list
(
self
.
axis
)}
self
.
attrs
=
{
'axis'
:
list
(
self
.
axis
)}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
].
transpose
(
self
.
axis
)}
self
.
outputs
=
{
'XShape'
:
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
),
'Out'
:
self
.
inputs
[
'X'
].
transpose
(
self
.
axis
)
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
no_check_set
=
[
'XShape'
]
)
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
)
self
.
check_grad
([
'X'
],
'Out'
,
sum_outputs
=
[
'Out'
]
)
def
initTestCase
(
self
):
def
initTestCase
(
self
):
self
.
shape
=
(
3
,
4
)
self
.
shape
=
(
3
,
4
)
...
...
python/paddle/fluid/tests/unittests/test_unsqueeze_op.py
浏览文件 @
0a71d580
...
@@ -24,13 +24,16 @@ from op_test import OpTest
...
@@ -24,13 +24,16 @@ from op_test import OpTest
class
TestUnsqueezeOp
(
OpTest
):
class
TestUnsqueezeOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
init_test_case
()
self
.
init_test_case
()
self
.
op_type
=
"unsqueeze"
self
.
op_type
=
"unsqueeze
2
"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
"float32"
)}
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
"float32"
)}
self
.
init_attrs
()
self
.
init_attrs
()
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
self
.
new_shape
)}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
self
.
new_shape
),
"XShape"
:
np
.
random
.
random
(
self
.
ori_shape
).
astype
(
"float32"
)
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
no_check_set
=
[
"XShape"
]
)
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
self
.
check_grad
([
"X"
],
"Out"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录