Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
bef475c9
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bef475c9
编写于
11月 22, 2018
作者:
P
peizhilin
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'upstream/develop' into windows/build
上级
f10e196f
5d4d117e
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
503 addition
and
170 deletion
+503
-170
CMakeLists.txt
CMakeLists.txt
+4
-0
paddle/fluid/inference/analysis/CMakeLists.txt
paddle/fluid/inference/analysis/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
+89
-53
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
+9
-7
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-0
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu
+64
-0
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h
+111
-0
paddle/fluid/memory/allocation/best_fit_allocator_test.cc
paddle/fluid/memory/allocation/best_fit_allocator_test.cc
+1
-0
paddle/fluid/memory/allocation/best_fit_allocator_test.cu
paddle/fluid/memory/allocation/best_fit_allocator_test.cu
+1
-0
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+15
-7
paddle/fluid/operators/conv_fusion_op.cu.cc
paddle/fluid/operators/conv_fusion_op.cu.cc
+4
-0
paddle/fluid/operators/math/pooling.cu
paddle/fluid/operators/math/pooling.cu
+36
-0
paddle/fluid/operators/math/pooling.h
paddle/fluid/operators/math/pooling.h
+13
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+149
-101
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+4
-0
python/requirements.txt
python/requirements.txt
+1
-1
未找到文件。
CMakeLists.txt
浏览文件 @
bef475c9
...
@@ -139,6 +139,10 @@ if (WIN32)
...
@@ -139,6 +139,10 @@ if (WIN32)
"Disable MKL when compiling for Windows"
FORCE
)
"Disable MKL when compiling for Windows"
FORCE
)
set
(
WITH_DISTRIBUTE OFF CACHE STRING
set
(
WITH_DISTRIBUTE OFF CACHE STRING
"Disable DISTRIBUTE when compiling for Windows"
FORCE
)
"Disable DISTRIBUTE when compiling for Windows"
FORCE
)
set
(
WITH_C_API OFF CACHE STRING
"Disable C_API when compiling for Windows"
FORCE
)
set
(
WITH_FLUID_ONLY ON CACHE STRING
"Enable FLUID_ONLY when compiling for Windows"
FORCE
)
endif
()
endif
()
set
(
THIRD_PARTY_PATH
"
${
CMAKE_BINARY_DIR
}
/third_party"
CACHE STRING
set
(
THIRD_PARTY_PATH
"
${
CMAKE_BINARY_DIR
}
/third_party"
CACHE STRING
...
...
paddle/fluid/inference/analysis/CMakeLists.txt
浏览文件 @
bef475c9
...
@@ -35,4 +35,4 @@ function(inference_analysis_test TARGET)
...
@@ -35,4 +35,4 @@ function(inference_analysis_test TARGET)
endif
()
endif
()
endfunction
(
inference_analysis_test
)
endfunction
(
inference_analysis_test
)
inference_analysis_test
(
test_analyzer SRCS analyzer_tester.cc EXTRA_DEPS paddle_inference_api
)
inference_analysis_test
(
test_analyzer SRCS analyzer_tester.cc EXTRA_DEPS
reset_tensor_array
paddle_inference_api
)
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
浏览文件 @
bef475c9
...
@@ -13,25 +13,57 @@ See the License for the specific language governing permissions and
...
@@ -13,25 +13,57 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
namespace
tensorrt
{
namespace
tensorrt
{
void
DealCeilMode
(
const
nvinfer1
::
Dims
&
input_shape
,
std
::
vector
<
int
>
ksize
,
std
::
vector
<
int
>
strides
,
std
::
vector
<
int
>
paddings
,
nvinfer1
::
DimsHW
*
pre_pad
,
nvinfer1
::
DimsHW
*
post_pad
,
int
input_dims
)
{
int
input_height
=
input_shape
.
d
[
input_dims
-
2
];
int
input_width
=
input_shape
.
d
[
input_dims
-
1
];
int
floor_h_output_size
=
(
input_height
-
ksize
[
0
]
+
2
*
paddings
[
0
])
/
strides
[
0
]
+
1
;
int
ceil_h_output_size
=
(
input_height
-
ksize
[
0
]
+
2
*
paddings
[
0
]
+
strides
[
0
]
-
1
)
/
strides
[
0
]
+
1
;
int
floor_w_output_size
=
(
input_width
-
ksize
[
1
]
+
2
*
paddings
[
1
])
/
strides
[
1
]
+
1
;
int
ceil_w_output_size
=
(
input_width
-
ksize
[
1
]
+
2
*
paddings
[
1
]
+
strides
[
1
]
-
1
)
/
strides
[
1
]
+
1
;
if
(
floor_h_output_size
!=
ceil_h_output_size
)
{
post_pad
->
h
()
=
strides
[
0
]
-
1
;
}
if
(
floor_w_output_size
!=
ceil_w_output_size
)
{
post_pad
->
w
()
=
strides
[
1
]
-
1
;
}
}
/*
/*
* Pool2dOp, IPoolingLayer in TRT. This Layer doesn't has weights.
* Pool2dOp, IPoolingLayer in TRT. This Layer doesn't has weights.
*/
*/
class
Pool2dOpConverter
:
public
OpConverter
{
class
Pool2dOpConverter
:
public
OpConverter
{
public:
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
VLOG
(
40
)
<<
"convert a fluid pool2d op to tensorrt pool2d layer without bias"
;
<<
"convert a fluid pool2d op to tensorrt pool2d layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
// Declare inputs
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"X"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"X"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Out"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Out"
).
size
(),
1
);
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
nvinfer1
::
Dims
input_shape
=
input1
->
getDimensions
();
int
input_dims
=
input_shape
.
nbDims
;
PADDLE_ENFORCE_EQ
(
input_dims
,
3UL
);
bool
global_pooling
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"global_pooling"
));
bool
global_pooling
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"global_pooling"
));
std
::
string
pool_type
=
std
::
string
pool_type
=
...
@@ -44,23 +76,6 @@ class Pool2dOpConverter : public OpConverter {
...
@@ -44,23 +76,6 @@ class Pool2dOpConverter : public OpConverter {
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"paddings"
));
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"paddings"
));
bool
ceil_mode
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"ceil_mode"
));
bool
ceil_mode
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"ceil_mode"
));
nvinfer1
::
Dims
input_shape
=
input1
->
getDimensions
();
int
nbDims
=
input_shape
.
nbDims
;
nvinfer1
::
DimsHW
nv_ksize
(
ksize
[
0
],
ksize
[
1
]);
nvinfer1
::
DimsHW
nv_strides
(
strides
[
0
],
strides
[
1
]);
nvinfer1
::
DimsHW
nv_paddings
(
paddings
[
0
],
paddings
[
1
]);
if
(
global_pooling
==
true
)
{
nv_ksize
.
d
[
0
]
=
input_shape
.
d
[
nbDims
-
2
];
nv_ksize
.
d
[
1
]
=
input_shape
.
d
[
nbDims
-
1
];
nv_strides
.
h
()
=
1
;
nv_strides
.
w
()
=
1
;
nv_paddings
.
h
()
=
0
;
nv_paddings
.
w
()
=
0
;
}
PADDLE_ENFORCE_EQ
(
input1
->
getDimensions
().
nbDims
,
3UL
);
nvinfer1
::
PoolingType
nv_pool_type
=
nvinfer1
::
PoolingType
::
kMAX
;
nvinfer1
::
PoolingType
nv_pool_type
=
nvinfer1
::
PoolingType
::
kMAX
;
if
(
pool_type
==
"max"
)
{
if
(
pool_type
==
"max"
)
{
nv_pool_type
=
nvinfer1
::
PoolingType
::
kMAX
;
nv_pool_type
=
nvinfer1
::
PoolingType
::
kMAX
;
...
@@ -70,42 +85,63 @@ class Pool2dOpConverter : public OpConverter {
...
@@ -70,42 +85,63 @@ class Pool2dOpConverter : public OpConverter {
PADDLE_THROW
(
"TensorRT unsupported pooling type!"
);
PADDLE_THROW
(
"TensorRT unsupported pooling type!"
);
}
}
if
(
ceil_mode
)
{
nvinfer1
::
DimsHW
nv_ksize
(
ksize
[
0
],
ksize
[
1
]);
nvinfer1
::
DimsHW
pre_pad
(
0
,
0
);
nvinfer1
::
DimsHW
nv_strides
(
strides
[
0
],
strides
[
1
]);
nvinfer1
::
DimsHW
post_pad
(
0
,
0
);
nvinfer1
::
DimsHW
nv_paddings
(
paddings
[
0
],
paddings
[
1
]);
int
input_height
=
input_shape
.
d
[
nbDims
-
2
];
int
input_width
=
input_shape
.
d
[
nbDims
-
1
];
nvinfer1
::
ILayer
*
layer
=
nullptr
;
int
floor_h_output_size
=
(
input_height
-
ksize
[
0
]
+
2
*
paddings
[
0
])
/
strides
[
0
]
+
1
;
if
(
global_pooling
==
true
)
{
int
ceil_h_output_size
=
nv_ksize
.
d
[
0
]
=
input_shape
.
d
[
input_dims
-
2
];
(
input_height
-
ksize
[
0
]
+
2
*
paddings
[
0
]
+
strides
[
0
]
-
1
)
/
nv_ksize
.
d
[
1
]
=
input_shape
.
d
[
input_dims
-
1
];
strides
[
0
]
+
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
1
;
engine_
,
Pooling
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
nv_pool_type
,
nv_ksize
);
int
floor_w_output_size
=
PADDLE_ENFORCE_NOT_NULL
(
layer
,
"pool layer could not be created."
);
(
input_width
-
ksize
[
1
]
+
2
*
paddings
[
1
])
/
strides
[
1
]
+
1
;
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
int
ceil_w_output_size
=
layer
->
setName
((
"pool2d (Output: "
+
output_name
+
")"
).
c_str
());
(
input_width
-
ksize
[
1
]
+
2
*
paddings
[
1
]
+
strides
[
1
]
-
1
)
/
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
strides
[
1
]
+
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
1
;
if
(
test_mode
)
{
if
(
floor_h_output_size
!=
ceil_h_output_size
)
{
engine_
->
DeclareOutput
(
output_name
);
post_pad
.
h
()
=
strides
[
0
]
-
1
;
}
}
return
;
}
if
(
floor_w_output_size
!=
ceil_w_output_size
)
{
if
(
pool_type
==
"max"
)
{
post_pad
.
w
()
=
strides
[
1
]
-
1
;
nvinfer1
::
DimsHW
pre_pad
(
paddings
[
0
],
paddings
[
1
]);
nvinfer1
::
DimsHW
post_pad
(
paddings
[
0
],
paddings
[
1
]);
if
(
ceil_mode
)
{
// If ceil mode is true, we will pad the appropriate size to the input.
DealCeilMode
(
input_shape
,
ksize
,
strides
,
paddings
,
&
pre_pad
,
&
post_pad
,
input_dims
);
auto
*
pad_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Padding
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
pre_pad
,
post_pad
);
PADDLE_ENFORCE_NOT_NULL
(
pad_layer
,
"pad layer in poolOp converter could not be created."
);
input1
=
pad_layer
->
getOutput
(
0
);
}
auto
*
pool_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Pooling
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
nv_pool_type
,
nv_ksize
);
PADDLE_ENFORCE_NOT_NULL
(
pool_layer
,
"pool layer could not be created."
);
pool_layer
->
setStride
(
nv_strides
);
pool_layer
->
setPadding
(
nv_paddings
);
layer
=
pool_layer
;
}
else
{
// Average pooling needs to exclude the padding pixels from the average
// mean.
// It is not supported well by TRT, we use a plugin here.
std
::
vector
<
int
>
input_shape_v
;
for
(
int
i
=
0
;
i
<
input_dims
;
i
++
)
{
input_shape_v
.
push_back
(
input_shape
.
d
[
i
]);
}
}
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
plugin
::
AvgPoolPlugin
*
plugin
=
new
plugin
::
AvgPoolPlugin
(
engine_
,
Padding
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
pre_pad
,
ceil_mode
,
ksize
,
strides
,
paddings
,
input_shape_v
);
post_pad
);
auto
*
avg_pool_layer
=
engine_
->
AddPlugin
(
&
input1
,
1
,
plugin
);
input1
=
layer
->
getOutput
(
0
)
;
layer
=
avg_pool_layer
;
}
}
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Pooling
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
nv_pool_type
,
nv_ksize
);
PADDLE_ENFORCE_NOT_NULL
(
layer
,
"pool layer could not be created."
);
layer
->
setStride
(
nv_strides
);
layer
->
setPadding
(
nv_paddings
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
layer
->
setName
((
"pool2d (Output: "
+
output_name
+
")"
).
c_str
());
layer
->
setName
((
"pool2d (Output: "
+
output_name
+
")"
).
c_str
());
...
...
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
浏览文件 @
bef475c9
...
@@ -20,20 +20,21 @@ namespace paddle {
...
@@ -20,20 +20,21 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
tensorrt
{
namespace
tensorrt
{
void
test_pool2d
(
bool
global_pooling
,
bool
ceil_mode
)
{
void
test_pool2d
(
bool
global_pooling
,
bool
ceil_mode
,
std
::
string
pool_type
=
"max"
)
{
framework
::
Scope
scope
;
framework
::
Scope
scope
;
std
::
unordered_set
<
std
::
string
>
parameters
;
std
::
unordered_set
<
std
::
string
>
parameters
;
TRTConvertValidation
validator
(
5
,
parameters
,
scope
,
1
<<
15
);
TRTConvertValidation
validator
(
5
,
parameters
,
scope
,
1
<<
15
);
// The ITensor's Dims should not contain the batch size.
// The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W.
// So, the ITensor's Dims of input and output should be C * H * W.
validator
.
DeclInputVar
(
"pool2d-X"
,
nvinfer1
::
Dims3
(
3
,
13
,
14
));
validator
.
DeclInputVar
(
"pool2d-X"
,
nvinfer1
::
Dims3
(
3
,
6
,
7
));
if
(
global_pooling
)
if
(
global_pooling
)
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
1
,
1
));
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
1
,
1
));
else
if
(
ceil_mode
)
else
if
(
ceil_mode
)
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
6
,
7
));
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
3
,
4
));
else
else
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
6
,
6
));
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
3
,
3
));
// Prepare Op description
// Prepare Op description
framework
::
OpDesc
desc
;
framework
::
OpDesc
desc
;
...
@@ -41,10 +42,10 @@ void test_pool2d(bool global_pooling, bool ceil_mode) {
...
@@ -41,10 +42,10 @@ void test_pool2d(bool global_pooling, bool ceil_mode) {
desc
.
SetInput
(
"X"
,
{
"pool2d-X"
});
desc
.
SetInput
(
"X"
,
{
"pool2d-X"
});
desc
.
SetOutput
(
"Out"
,
{
"pool2d-Out"
});
desc
.
SetOutput
(
"Out"
,
{
"pool2d-Out"
});
std
::
vector
<
int
>
ksize
({
3
,
3
});
std
::
vector
<
int
>
ksize
({
2
,
2
});
std
::
vector
<
int
>
strides
({
2
,
2
});
std
::
vector
<
int
>
strides
({
2
,
2
});
std
::
vector
<
int
>
paddings
({
0
,
0
});
std
::
vector
<
int
>
paddings
({
0
,
0
});
std
::
string
pooling_t
=
"max"
;
std
::
string
pooling_t
=
pool_type
;
desc
.
SetAttr
(
"pooling_type"
,
pooling_t
);
desc
.
SetAttr
(
"pooling_type"
,
pooling_t
);
desc
.
SetAttr
(
"ksize"
,
ksize
);
desc
.
SetAttr
(
"ksize"
,
ksize
);
...
@@ -63,7 +64,8 @@ void test_pool2d(bool global_pooling, bool ceil_mode) {
...
@@ -63,7 +64,8 @@ void test_pool2d(bool global_pooling, bool ceil_mode) {
TEST
(
Pool2dOpConverter
,
normal
)
{
test_pool2d
(
false
,
false
);
}
TEST
(
Pool2dOpConverter
,
normal
)
{
test_pool2d
(
false
,
false
);
}
TEST
(
Pool2dOpConverter
,
test_global_pooling
)
{
test_pool2d
(
true
,
false
);
}
TEST
(
Pool2dOpConverter
,
test_global_pooling
)
{
test_pool2d
(
true
,
false
);
}
TEST
(
Pool2dOpConverter
,
test_ceil_mode
)
{
test_pool2d
(
false
,
true
);
}
TEST
(
Pool2dOpConverter
,
max_ceil_test
)
{
test_pool2d
(
false
,
true
);
}
TEST
(
Pool2dOpConverter
,
avg_ceil_test
)
{
test_pool2d
(
false
,
true
,
"avg"
);
}
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
bef475c9
nv_library
(
tensorrt_plugin
nv_library
(
tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu
avg_pool_op_plugin.cu
DEPS enforce tensorrt_engine
)
DEPS enforce tensorrt_engine
)
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.cu
0 → 100644
浏览文件 @
bef475c9
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
nvinfer1
::
Dims
AvgPoolPlugin
::
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputDims
,
int
nbInputs
)
{
assert
(
nbInputs
==
1
);
assert
(
index
==
0
);
assert
(
inputDims
[
0
].
nbDims
==
3
);
nvinfer1
::
Dims
const
&
input_dims
=
inputDims
[
0
];
nvinfer1
::
Dims
output_dims
=
input_dims
;
output_dims
.
d
[
1
]
=
output_shape_
[
1
];
output_dims
.
d
[
2
]
=
output_shape_
[
2
];
return
output_dims
;
}
int
AvgPoolPlugin
::
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
const
&
input_dims
=
this
->
getInputDims
(
0
);
int
input_size
=
0
;
float
const
*
idata
=
reinterpret_cast
<
float
const
*>
(
inputs
[
0
]);
float
**
odatas
=
reinterpret_cast
<
float
**>
(
outputs
);
paddle
::
operators
::
math
::
AvgPool
<
float
>
pool_process
;
paddle
::
operators
::
math
::
Pool2dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
AvgPool
<
float
>
,
float
>
pool2d_forward
;
std
::
vector
<
int
>
input_shape
=
input_shape_
;
std
::
vector
<
int
>
output_shape
=
output_shape_
;
input_shape
.
insert
(
input_shape
.
begin
(),
batchSize
);
output_shape
.
insert
(
output_shape
.
begin
(),
batchSize
);
pool2d_forward
(
idata
,
input_shape
,
output_shape
,
ksize_
,
strides_
,
paddings_
,
pool_process
,
true
,
odatas
[
0
],
stream
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/avg_pool_op_plugin.h
0 → 100644
浏览文件 @
bef475c9
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
class
AvgPoolPlugin
:
public
PluginTensorRT
{
private:
bool
ceil_mode_
;
std
::
vector
<
int
>
ksize_
;
std
::
vector
<
int
>
strides_
;
std
::
vector
<
int
>
paddings_
;
std
::
vector
<
int
>
input_shape_
;
std
::
vector
<
int
>
output_shape_
;
protected:
size_t
getSerializationSize
()
override
{
return
SerializedSize
(
ceil_mode_
)
+
SerializedSize
(
ksize_
)
+
SerializedSize
(
strides_
)
+
SerializedSize
(
paddings_
)
+
SerializedSize
(
input_shape_
)
+
getBaseSerializationSize
();
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void
serialize
(
void
*
buffer
)
override
{
serializeBase
(
buffer
);
SerializeValue
(
&
buffer
,
ceil_mode_
);
SerializeValue
(
&
buffer
,
ksize_
);
SerializeValue
(
&
buffer
,
strides_
);
SerializeValue
(
&
buffer
,
paddings_
);
SerializeValue
(
&
buffer
,
input_shape_
);
}
public:
AvgPoolPlugin
(
bool
ceil_mode
,
std
::
vector
<
int
>
ksize
,
std
::
vector
<
int
>
strides
,
std
::
vector
<
int
>
paddings
,
std
::
vector
<
int
>
input_shape
)
:
ceil_mode_
(
ceil_mode
),
ksize_
(
ksize
),
strides_
(
strides
),
paddings_
(
paddings
),
input_shape_
(
input_shape
)
{
int
output_h
,
output_w
;
output_shape_
=
input_shape_
;
if
(
!
ceil_mode_
)
{
output_h
=
(
input_shape
[
1
]
-
ksize_
[
0
]
+
2
*
paddings_
[
0
])
/
strides_
[
0
]
+
1
;
output_w
=
(
input_shape
[
2
]
-
ksize_
[
1
]
+
2
*
paddings_
[
1
])
/
strides_
[
1
]
+
1
;
}
else
{
output_h
=
(
input_shape
[
1
]
-
ksize_
[
0
]
+
2
*
paddings_
[
0
]
+
strides_
[
0
]
-
1
)
/
strides_
[
0
]
+
1
;
output_w
=
(
input_shape
[
2
]
-
ksize_
[
1
]
+
2
*
paddings_
[
1
]
+
strides_
[
1
]
-
1
)
/
strides_
[
1
]
+
1
;
}
output_shape_
[
1
]
=
output_h
;
output_shape_
[
2
]
=
output_w
;
}
// It was used for tensorrt deserialization.
// It should not be called by users.
AvgPoolPlugin
(
void
const
*
serialData
,
size_t
serialLength
)
{
deserializeBase
(
serialData
,
serialLength
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
ceil_mode_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
ksize_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
strides_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
paddings_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
input_shape_
);
}
AvgPoolPlugin
*
clone
()
const
override
{
return
new
AvgPoolPlugin
(
ceil_mode_
,
ksize_
,
strides_
,
paddings_
,
input_shape_
);
}
const
char
*
getPluginType
()
const
override
{
return
"avg_pool"
;
}
int
getNbOutputs
()
const
override
{
return
1
;
}
nvinfer1
::
Dims
getOutputDimensions
(
int
index
,
const
nvinfer1
::
Dims
*
inputs
,
int
nbInputDims
)
override
;
int
initialize
()
override
{
return
0
;
}
int
enqueue
(
int
batchSize
,
const
void
*
const
*
inputs
,
void
**
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
};
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/memory/allocation/best_fit_allocator_test.cc
浏览文件 @
bef475c9
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/memory/allocation/best_fit_allocator.h"
#include "paddle/fluid/memory/allocation/best_fit_allocator.h"
#include <random>
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
...
paddle/fluid/memory/allocation/best_fit_allocator_test.cu
浏览文件 @
bef475c9
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <random>
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include <vector>
#include "gtest/gtest.h"
#include "gtest/gtest.h"
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
bef475c9
...
@@ -32,31 +32,39 @@ if (WITH_GPU AND TENSORRT_FOUND)
...
@@ -32,31 +32,39 @@ if (WITH_GPU AND TENSORRT_FOUND)
add_subdirectory
(
tensorrt
)
add_subdirectory
(
tensorrt
)
endif
()
endif
()
register_operators
(
EXCLUDES warpctc_op conv_fusion_op
)
SET
(
OP_HEADER_DEPS xxhash
)
if
(
WITH_GPU
)
SET
(
OP_HEADER_DEPS
${
OP_HEADER_DEPS
}
cub
)
endif
()
# warpctc_cudnn need cudnn 7 above
register_operators
(
EXCLUDES warpctc_op conv_fusion_op DEPS
${
OP_HEADER_DEPS
}
)
# warpctc_op needs cudnn 7 above
if
(
WITH_GPU AND NOT WIN32
)
if
(
WITH_GPU AND NOT WIN32
)
if
(
${
CUDNN_MAJOR_VERSION
}
VERSION_LESS 7
)
if
(
${
CUDNN_MAJOR_VERSION
}
VERSION_LESS 7
)
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc
)
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc
)
else
()
else
()
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
endif
()
endif
()
op_library
(
conv_fusion_op
)
# conv_fusion_op needs cudnn 7 above
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(conv2d_fusion);
\n
"
)
if
(
NOT
${
CUDNN_MAJOR_VERSION
}
VERSION_LESS 7
)
op_library
(
conv_fusion_op
)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(conv2d_fusion);
\n
"
)
endif
()
else
()
else
()
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
endif
()
endif
()
set
(
COMMON_OP_DEPS
""
)
set
(
COMMON_OP_DEPS
${
OP_HEADER_DEPS
}
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
xxhash
selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor
)
if
(
NOT WIN32
)
if
(
NOT WIN32
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
dynload_warpctc
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
dynload_warpctc
)
endif
()
endif
()
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
depthwise_conv
cub
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
depthwise_conv
)
endif
()
endif
()
# FIXME(typhoonzero): operator deps may not needed.
# FIXME(typhoonzero): operator deps may not needed.
...
...
paddle/fluid/operators/conv_fusion_op.cu.cc
浏览文件 @
bef475c9
...
@@ -22,6 +22,7 @@ DECLARE_bool(cudnn_exhaustive_search);
...
@@ -22,6 +22,7 @@ DECLARE_bool(cudnn_exhaustive_search);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
#if CUDNN_VERSION >= 7001
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
using
ScopedTensorDescriptor
=
platform
::
ScopedTensorDescriptor
;
using
ScopedTensorDescriptor
=
platform
::
ScopedTensorDescriptor
;
using
ScopedFilterDescriptor
=
platform
::
ScopedFilterDescriptor
;
using
ScopedFilterDescriptor
=
platform
::
ScopedFilterDescriptor
;
...
@@ -178,10 +179,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
...
@@ -178,10 +179,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
workspace_handle
.
RunFunc
(
cudnn_func
,
workspace_size_in_bytes
);
workspace_handle
.
RunFunc
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
};
};
#endif
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
#if CUDNN_VERSION >= 7001
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
conv2d_fusion
,
ops
::
CUDNNConvFusionOpKernel
<
float
>
,
REGISTER_OP_CUDA_KERNEL
(
conv2d_fusion
,
ops
::
CUDNNConvFusionOpKernel
<
float
>
,
ops
::
CUDNNConvFusionOpKernel
<
double
>
);
ops
::
CUDNNConvFusionOpKernel
<
double
>
);
#endif
paddle/fluid/operators/math/pooling.cu
浏览文件 @
bef475c9
...
@@ -153,6 +153,37 @@ __global__ void KernelMaxPool2DGrad(
...
@@ -153,6 +153,37 @@ __global__ void KernelMaxPool2DGrad(
}
}
}
}
template
<
typename
PoolProcess
,
typename
T
>
void
Pool2dDirectCUDAFunctor
<
PoolProcess
,
T
>::
operator
()(
const
T
*
input
,
const
std
::
vector
<
int
>&
input_shape
,
const
std
::
vector
<
int
>&
output_shape
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
bool
exclusive
,
T
*
output
,
cudaStream_t
stream
)
{
const
int
batch_size
=
input_shape
[
0
];
const
int
input_channels
=
input_shape
[
1
];
const
int
input_height
=
input_shape
[
2
];
const
int
input_width
=
input_shape
[
3
];
const
int
output_channels
=
output_shape
[
1
];
const
int
output_height
=
output_shape
[
2
];
const
int
output_width
=
output_shape
[
3
];
const
int
ksize_height
=
ksize
[
0
];
const
int
ksize_width
=
ksize
[
1
];
const
int
stride_height
=
strides
[
0
];
const
int
stride_width
=
strides
[
1
];
const
int
padding_height
=
paddings
[
0
];
const
int
padding_width
=
paddings
[
1
];
int
nthreads
=
batch_size
*
output_channels
*
output_height
*
output_width
;
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
dim3
threads
(
1024
,
1
);
dim3
grid
(
blocks
,
1
);
KernelPool2D
<
PoolProcess
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
nthreads
,
input
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
pool_compute
,
exclusive
,
output
);
}
/*
/*
* All tensors are in NCHW format.
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* Ksize, strides, paddings are two elements. These two elements represent
...
@@ -291,6 +322,11 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
...
@@ -291,6 +322,11 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
}
}
};
};
template
class
Pool2dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
MaxPool
<
float
>,
float
>
;
template
class
Pool2dDirectCUDAFunctor
<
paddle
::
operators
::
math
::
AvgPool
<
float
>,
float
>
;
template
class
MaxPool2dGradFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
MaxPool2dGradFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
MaxPool2dGradFunctor
<
platform
::
CUDADeviceContext
,
double
>;
template
class
MaxPool2dGradFunctor
<
platform
::
CUDADeviceContext
,
double
>;
...
...
paddle/fluid/operators/math/pooling.h
浏览文件 @
bef475c9
...
@@ -82,6 +82,19 @@ class AvgPoolGrad {
...
@@ -82,6 +82,19 @@ class AvgPoolGrad {
* This is different from average pooling. So we rewrite the max_pool_grad:
* This is different from average pooling. So we rewrite the max_pool_grad:
* MaxPool2dGradFunctor, MaxPool3dGradFunctor.
* MaxPool2dGradFunctor, MaxPool3dGradFunctor.
*/
*/
#ifdef PADDLE_WITH_CUDA
template
<
typename
PoolProcess
,
typename
T
>
class
Pool2dDirectCUDAFunctor
{
public:
void
operator
()(
const
T
*
input
,
const
std
::
vector
<
int
>&
input_shape
,
const
std
::
vector
<
int
>&
output_shape
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
bool
exclusive
,
T
*
output
,
cudaStream_t
stream
);
};
#endif
template
<
typename
DeviceContext
,
typename
PoolProcess
,
typename
T
>
template
<
typename
DeviceContext
,
typename
PoolProcess
,
typename
T
>
class
Pool2dFunctor
{
class
Pool2dFunctor
{
public:
public:
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
bef475c9
...
@@ -724,11 +724,11 @@ def dynamic_gru(input,
...
@@ -724,11 +724,11 @@ def dynamic_gru(input,
create ParamAttr as param_attr. If the Initializer of the param_attr
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias
of GRU. Note that the bias with :math:`(1
\\
times 3D)` concatenates
of GRU. Note that the bias with :math:`(1
\\
times 3D)` concatenates
the bias in the update gate, reset gate and candidate calculations.
the bias in the update gate, reset gate and candidate calculations.
If it is set to False, no bias will be applied to the update gate,
If it is set to False, no bias will be applied to the update gate,
reset gate and candidate calculations. If it is set to None or one
reset gate and candidate calculations. If it is set to None or one
attribute of ParamAttr, dynamic_gru will create ParamAttr as
attribute of ParamAttr, dynamic_gru will create ParamAttr as
bias_attr. If the Initializer of the bias_attr is not set, the bias
bias_attr. If the Initializer of the bias_attr is not set, the bias
is initialized zero. Default: None.
is initialized zero. Default: None.
is_reverse(bool): Whether to compute reversed GRU, default
is_reverse(bool): Whether to compute reversed GRU, default
...
@@ -845,11 +845,11 @@ def gru_unit(input,
...
@@ -845,11 +845,11 @@ def gru_unit(input,
create ParamAttr as param_attr. If the Initializer of the param_attr
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias
of GRU. Note that the bias with :math:`(1
\\
times 3D)` concatenates
of GRU. Note that the bias with :math:`(1
\\
times 3D)` concatenates
the bias in the update gate, reset gate and candidate calculations.
the bias in the update gate, reset gate and candidate calculations.
If it is set to False, no bias will be applied to the update gate,
If it is set to False, no bias will be applied to the update gate,
reset gate and candidate calculations. If it is set to None or one
reset gate and candidate calculations. If it is set to None or one
attribute of ParamAttr, gru_unit will create ParamAttr as
attribute of ParamAttr, gru_unit will create ParamAttr as
bias_attr. If the Initializer of the bias_attr is not set, the bias
bias_attr. If the Initializer of the bias_attr is not set, the bias
is initialized zero. Default: None.
is initialized zero. Default: None.
activation (string): The activation type for cell (actNode).
activation (string): The activation type for cell (actNode).
...
@@ -1058,9 +1058,9 @@ def dropout(x,
...
@@ -1058,9 +1058,9 @@ def dropout(x,
inference: out = input
inference: out = input
(make is a tensor same shape with input, value is 0 or 1
(make is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob)
ratio of 0 is dropout_prob)
dropout op can be removed from the program.
dropout op can be removed from the program.
the program will be efficient
the program will be efficient
Returns:
Returns:
...
@@ -2143,7 +2143,7 @@ def pool2d(input,
...
@@ -2143,7 +2143,7 @@ def pool2d(input,
ceil_mode (bool): ${ceil_mode_comment}
ceil_mode (bool): ${ceil_mode_comment}
name (str|None): A name for this layer(optional). If set None, the
name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically.
layer will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
mode, default is true
Returns:
Returns:
...
@@ -2234,7 +2234,7 @@ def pool3d(input,
...
@@ -2234,7 +2234,7 @@ def pool3d(input,
ceil_mode (bool): ${ceil_mode_comment}
ceil_mode (bool): ${ceil_mode_comment}
name (str): A name for this layer(optional). If set None, the layer
name (str): A name for this layer(optional). If set None, the layer
will be named automatically.
will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
mode, default is true
Returns:
Returns:
...
@@ -4336,7 +4336,7 @@ def nce(input,
...
@@ -4336,7 +4336,7 @@ def nce(input,
sampler (str): The sampler used to sample class from negtive classes.
sampler (str): The sampler used to sample class from negtive classes.
It can be 'uniform', 'log_uniform' or 'custom_dist'.
It can be 'uniform', 'log_uniform' or 'custom_dist'.
default: 'uniform'.
default: 'uniform'.
custom_dist (Variable): A tensor with shape [num_total_classes].
custom_dist (Variable): A tensor with shape [num_total_classes].
It is used when sampler is set to 'custom_dist'.
It is used when sampler is set to 'custom_dist'.
custom_dist[i] is the probsbility of i-th class to be sampled.
custom_dist[i] is the probsbility of i-th class to be sampled.
default: None.
default: None.
...
@@ -4379,7 +4379,7 @@ def nce(input,
...
@@ -4379,7 +4379,7 @@ def nce(input,
num_neg_samples=3,
num_neg_samples=3,
sampler="custom_dist",
sampler="custom_dist",
custom_dist=dist)
custom_dist=dist)
"""
"""
helper
=
LayerHelper
(
'nce'
,
**
locals
())
helper
=
LayerHelper
(
'nce'
,
**
locals
())
assert
isinstance
(
input
,
Variable
)
assert
isinstance
(
input
,
Variable
)
...
@@ -4550,9 +4550,9 @@ def transpose(x, perm, name=None):
...
@@ -4550,9 +4550,9 @@ def transpose(x, perm, name=None):
Examples:
Examples:
.. code-block:: python
.. code-block:: python
# use append_batch_size=False to avoid prepending extra
# use append_batch_size=False to avoid prepending extra
# batch size in shape
# batch size in shape
x = fluid.layers.data(name='x', shape=[5, 10, 15],
x = fluid.layers.data(name='x', shape=[5, 10, 15],
dtype='float32', append_batch_size=False)
dtype='float32', append_batch_size=False)
x_transposed = layers.transpose(x, perm=[1, 0, 2])
x_transposed = layers.transpose(x, perm=[1, 0, 2])
"""
"""
...
@@ -4829,7 +4829,7 @@ def softmax_with_cross_entropy(logits,
...
@@ -4829,7 +4829,7 @@ def softmax_with_cross_entropy(logits,
3) If numeric_stable_mode is True, softmax is calculated first by:
3) If numeric_stable_mode is True, softmax is calculated first by:
.. math::
.. math::
max_j =
\\
max_{i=0}^{K}{
\\
text{logit}_i}
max_j =
\\
max_{i=0}^{K}{
\\
text{logit}_i}
log
\\
_max
\\
_sum_j =
\\
log
\\
sum_{i=0}^{K}
\\
exp(logit_i - max_j)
log
\\
_max
\\
_sum_j =
\\
log
\\
sum_{i=0}^{K}
\\
exp(logit_i - max_j)
...
@@ -4852,18 +4852,18 @@ def softmax_with_cross_entropy(logits,
...
@@ -4852,18 +4852,18 @@ def softmax_with_cross_entropy(logits,
numeric_stable_mode (bool): A flag to indicate whether to use a more
numeric_stable_mode (bool): A flag to indicate whether to use a more
numerically stable algorithm. Only valid
numerically stable algorithm. Only valid
when soft_label is False and GPU is used.
when soft_label is False and GPU is used.
When soft_label is True or CPU is used,
When soft_label is True or CPU is used,
the algorithm is always numerically stable.
the algorithm is always numerically stable.
Note that the speed may be slower when use
Note that the speed may be slower when use
stable algorithm. Default: False
stable algorithm. Default: False
return_softmax (bool): A flag indicating whether to return the softmax
return_softmax (bool): A flag indicating whether to return the softmax
along with the cross entropy loss. Default: False
along with the cross entropy loss. Default: False
Returns:
Returns:
Variable or Tuple of two Variables: Return the cross entropy loss if
Variable or Tuple of two Variables: Return the cross entropy loss if
`return_softmax` is False, otherwise the tuple
`return_softmax` is False, otherwise the tuple
(loss, softmax), where the cross entropy loss is
(loss, softmax), where the cross entropy loss is
a 2-D tensor with shape [N x 1], and softmax is a
a 2-D tensor with shape [N x 1], and softmax is a
2-D tensor with shape [N x K].
2-D tensor with shape [N x K].
Examples:
Examples:
...
@@ -5744,20 +5744,20 @@ def image_resize(input,
...
@@ -5744,20 +5744,20 @@ def image_resize(input,
Default: None
Default: None
name(str|None): A name for this layer(optional). If set None, the layer
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
will be named automatically.
resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST'
resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST'
currently.
currently.
Default: 'BILINEAR'
Default: 'BILINEAR'
actual_shape(Variable): An optional input to specify output shape
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
dynamically. If provided, image resize
according to this given shape rather than
according to this given shape rather than
:attr:`out_shape` and :attr:`scale` specifying
:attr:`out_shape` and :attr:`scale` specifying
shape. That is to say actual_shape has the
shape. That is to say actual_shape has the
highest priority. It is recommended to use
highest priority. It is recommended to use
actual_shape instead of :attr:`out_shape` if you
actual_shape instead of :attr:`out_shape` if you
want to specify output shape dynamically. When
want to specify output shape dynamically. When
using actual_shape to specify output shape, one of
using actual_shape to specify output shape, one of
:attr:`out_shape` and :attr:`scale` should also be
:attr:`out_shape` and :attr:`scale` should also be
set, otherwise errors would be occured in graph
set, otherwise errors would be occured in graph
constructing stage.
constructing stage.
Default: None
Default: None
...
@@ -5768,7 +5768,7 @@ def image_resize(input,
...
@@ -5768,7 +5768,7 @@ def image_resize(input,
Raises:
Raises:
TypeError: out_shape should be a list or tuple or Variable.
TypeError: out_shape should be a list or tuple or Variable.
TypeError: actual_shape should either be Variable or None.
TypeError: actual_shape should either be Variable or None.
ValueError: The 'resample' of image_resize can only be 'BILINEAR'
ValueError: The 'resample' of image_resize can only be 'BILINEAR'
or 'NEAREST' currently.
or 'NEAREST' currently.
ValueError: One of out_shape and scale must not be None.
ValueError: One of out_shape and scale must not be None.
ValueError: out_shape length should be 2.
ValueError: out_shape length should be 2.
...
@@ -5840,17 +5840,17 @@ def resize_bilinear(input,
...
@@ -5840,17 +5840,17 @@ def resize_bilinear(input,
name
=
None
,
name
=
None
,
actual_shape
=
None
):
actual_shape
=
None
):
"""
"""
Resize input by performing bilinear interpolation based on given
Resize input by performing bilinear interpolation based on given
output shape which specified by actual_shape, out_shape and scale
output shape which specified by actual_shape, out_shape and scale
in priority order.
in priority order.
Bilinear interpolation is an extension of linear interpolation for
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid. The key idea is
W-direction in this op) on a rectilinear 2D grid. The key idea is
to perform linear interpolation first in one direction, and then
to perform linear interpolation first in one direction, and then
again in the other direction.
again in the other direction.
For details of bilinear interpolation, please refer to Wikipedia:
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
https://en.wikipedia.org/wiki/Bilinear_interpolation
Args:
Args:
...
@@ -5863,17 +5863,17 @@ def resize_bilinear(input,
...
@@ -5863,17 +5863,17 @@ def resize_bilinear(input,
a higher priority than scale. Default: None.
a higher priority than scale. Default: None.
name(str|None): The output variable name.
name(str|None): The output variable name.
actual_shape(Variable): An optional input to specify output shape
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
dynamically. If provided, image resize
according to this given shape rather than
according to this given shape rather than
:attr:`out_shape` and :attr:`scale` specifying
:attr:`out_shape` and :attr:`scale` specifying
shape. That is to say actual_shape has the
shape. That is to say actual_shape has the
highest priority. It is recommended to use
highest priority. It is recommended to use
actual_shape instead of :attr:`out_shape` if you
actual_shape instead of :attr:`out_shape` if you
want to specify output shape dynamically. When
want to specify output shape dynamically. When
using actual_shape to specify output shape, one of
using actual_shape to specify output shape, one of
:attr:`out_shape` and :attr:`scale` should also be
:attr:`out_shape` and :attr:`scale` should also be
set, otherwise errors would be occured in graph
set, otherwise errors would be occured in graph
constructing stage.
constructing stage.
Default: None
Default: None
...
@@ -5897,11 +5897,11 @@ def resize_nearest(input,
...
@@ -5897,11 +5897,11 @@ def resize_nearest(input,
actual_shape
=
None
):
actual_shape
=
None
):
"""
"""
Resize input by performing nearest neighbor interpolation in both the
Resize input by performing nearest neighbor interpolation in both the
3rd dimention(in height direction) and the 4th dimention(in width
3rd dimention(in height direction) and the 4th dimention(in width
direction) based on given output shape which specified by actual_shape,
direction) based on given output shape which specified by actual_shape,
out_shape and scale in priority order.
out_shape and scale in priority order.
For details of nearest neighbor interpolation, please refer to Wikipedia:
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
Args:
Args:
...
@@ -5914,17 +5914,17 @@ def resize_nearest(input,
...
@@ -5914,17 +5914,17 @@ def resize_nearest(input,
a higher priority than scale. Default: None.
a higher priority than scale. Default: None.
name(str|None): The output variable name.
name(str|None): The output variable name.
actual_shape(Variable): An optional input to specify output shape
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
dynamically. If provided, image resize
according to this given shape rather than
according to this given shape rather than
:attr:`out_shape` and :attr:`scale` specifying
:attr:`out_shape` and :attr:`scale` specifying
shape. That is to say actual_shape has the
shape. That is to say actual_shape has the
highest priority. It is recommended to use
highest priority. It is recommended to use
actual_shape instead of :attr:`out_shape` if you
actual_shape instead of :attr:`out_shape` if you
want to specify output shape dynamically. When
want to specify output shape dynamically. When
using actual_shape to specify output shape, one of
using actual_shape to specify output shape, one of
:attr:`out_shape` and :attr:`scale` should also be
:attr:`out_shape` and :attr:`scale` should also be
set, otherwise errors would be occured in graph
set, otherwise errors would be occured in graph
constructing stage.
constructing stage.
Default: None
Default: None
...
@@ -6434,15 +6434,15 @@ def affine_grid(theta, out_shape, name=None):
...
@@ -6434,15 +6434,15 @@ def affine_grid(theta, out_shape, name=None):
[x_14, x_15, x_16]]
[x_14, x_15, x_16]]
[[x_21, x_22, x_23]
[[x_21, x_22, x_23]
[x_24, x_25, x_26]]]
[x_24, x_25, x_26]]]
out_shape = [2, 3, 5, 5]
out_shape = [2, 3, 5, 5]
Step 1:
Step 1:
Generate normalized coordinates according to out_shape.
Generate normalized coordinates according to out_shape.
The values of the normalized coordinates are in the interval between -1 and 1.
The values of the normalized coordinates are in the interval between -1 and 1.
The shape of the normalized coordinates is [2, H, W] as below:
The shape of the normalized coordinates is [2, H, W] as below:
C = [[[-1. -1. -1. -1. -1. ]
C = [[[-1. -1. -1. -1. -1. ]
[-0.5 -0.5 -0.5 -0.5 -0.5]
[-0.5 -0.5 -0.5 -0.5 -0.5]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]
...
@@ -7690,6 +7690,15 @@ def logical_and(x, y, out=None, name=None):
...
@@ -7690,6 +7690,15 @@ def logical_and(x, y, out=None, name=None):
Returns:
Returns:
out(${out_type}): ${out_comment}
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
left = fluid.layers.data(
name='left', shape=[1], dtype='int32')
right = fluid.layers.data(
name='right', shape=[1], dtype='int32')
result = fluid.layers.logical_and(x=left, y=right)
"""
"""
return
_logical_op
(
return
_logical_op
(
...
@@ -7709,6 +7718,15 @@ def logical_or(x, y, out=None, name=None):
...
@@ -7709,6 +7718,15 @@ def logical_or(x, y, out=None, name=None):
Returns:
Returns:
out(${out_type}): ${out_comment}
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
left = fluid.layers.data(
name='left', shape=[1], dtype='int32')
right = fluid.layers.data(
name='right', shape=[1], dtype='int32')
result = fluid.layers.logical_or(x=left, y=right)
"""
"""
return
_logical_op
(
return
_logical_op
(
...
@@ -7728,6 +7746,15 @@ def logical_xor(x, y, out=None, name=None):
...
@@ -7728,6 +7746,15 @@ def logical_xor(x, y, out=None, name=None):
Returns:
Returns:
out(${out_type}): ${out_comment}
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
left = fluid.layers.data(
name='left', shape=[1], dtype='int32')
right = fluid.layers.data(
name='right', shape=[1], dtype='int32')
result = fluid.layers.logical_xor(x=left, y=right)
"""
"""
return
_logical_op
(
return
_logical_op
(
...
@@ -7746,6 +7773,13 @@ def logical_not(x, out=None, name=None):
...
@@ -7746,6 +7773,13 @@ def logical_not(x, out=None, name=None):
Returns:
Returns:
out(${out_type}): ${out_comment}
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
left = fluid.layers.data(
name='left', shape=[1], dtype='int32')
result = fluid.layers.logical_not(x=left)
"""
"""
return
_logical_op
(
return
_logical_op
(
...
@@ -7765,6 +7799,13 @@ def clip(x, min, max, name=None):
...
@@ -7765,6 +7799,13 @@ def clip(x, min, max, name=None):
Returns:
Returns:
out(${out_type}): ${out_comment}
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
input = fluid.layers.data(
name='data', shape=[1], dtype='float32')
reward = fluid.layers.clip(x=input, min=-1.0, max=1.0)
"""
"""
helper
=
LayerHelper
(
"clip"
,
**
locals
())
helper
=
LayerHelper
(
"clip"
,
**
locals
())
...
@@ -7797,6 +7838,13 @@ def clip_by_norm(x, max_norm, name=None):
...
@@ -7797,6 +7838,13 @@ def clip_by_norm(x, max_norm, name=None):
Returns:
Returns:
out(${out_type}): ${out_comment}
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
input = fluid.layers.data(
name='data', shape=[1], dtype='float32')
reward = fluid.layers.clip_by_norm(x=input, max_norm=1.0)
"""
"""
helper
=
LayerHelper
(
"clip_by_norm"
,
**
locals
())
helper
=
LayerHelper
(
"clip_by_norm"
,
**
locals
())
...
@@ -7942,19 +7990,19 @@ def maxout(x, groups, name=None):
...
@@ -7942,19 +7990,19 @@ def maxout(x, groups, name=None):
def
space_to_depth
(
x
,
blocksize
,
name
=
None
):
def
space_to_depth
(
x
,
blocksize
,
name
=
None
):
"""
"""
Gives a blocksize to space_to_depth the input LoDtensor with Layout: [batch, channel, height, width]
Gives a blocksize to space_to_depth the input LoDtensor with Layout: [batch, channel, height, width]
This op rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the
This op rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the
input LoDtensor where values from the height and width dimensions are moved to the channel dimension.
input LoDtensor where values from the height and width dimensions are moved to the channel dimension.
The attr blocksize indicates the input block size.
The attr blocksize indicates the input block size.
space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according
space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according
to blocksize to construct output with shape [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize]:
to blocksize to construct output with shape [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize]:
space_to_depth is used to This operation is useful for resizing the activations between convolutions
space_to_depth is used to This operation is useful for resizing the activations between convolutions
(but keeping all data)
(but keeping all data)
- Non-overlapping blocks of size block_size x block size are rearranged into depth at each location.
- Non-overlapping blocks of size block_size x block size are rearranged into depth at each location.
- The depth of the output tensor is block_size * block_size * input channel
- The depth of the output tensor is block_size * block_size * input channel
- The Y, X coordinates within each block of the input become the high order component of the output channel index
- The Y, X coordinates within each block of the input become the high order component of the output channel index
- channel should be divisible by square of blocksize
- channel should be divisible by square of blocksize
- height, width should be divsible by blocksize
- height, width should be divsible by blocksize
...
@@ -8001,7 +8049,7 @@ def space_to_depth(x, blocksize, name=None):
...
@@ -8001,7 +8049,7 @@ def space_to_depth(x, blocksize, name=None):
@
templatedoc
()
@
templatedoc
()
def
sequence_reverse
(
x
,
name
=
None
):
def
sequence_reverse
(
x
,
name
=
None
):
"""
"""
${comment}
${comment}
Args:
Args:
...
@@ -8068,21 +8116,21 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
...
@@ -8068,21 +8116,21 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
def
similarity_focus
(
input
,
axis
,
indexes
,
name
=
None
):
def
similarity_focus
(
input
,
axis
,
indexes
,
name
=
None
):
"""
"""
SimilarityFocus Operator
SimilarityFocus Operator
Generate a similarity focus mask with the same shape of input using the following method:
Generate a similarity focus mask with the same shape of input using the following method:
1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding
1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding
to the axis according to the indexes. For example, if axis=1 and indexes=[a],
to the axis according to the indexes. For example, if axis=1 and indexes=[a],
it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X
it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X
is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C).
is (BatchSize, A, B, C), the shape of tensor T is (BatchSize, B, C).
2. For each index, find the largest numbers in the tensor T, so that the same
2. For each index, find the largest numbers in the tensor T, so that the same
row and same column has at most one number(what it means is that if the
row and same column has at most one number(what it means is that if the
largest number has been found in the i-th row and the j-th column, then
largest number has been found in the i-th row and the j-th column, then
the numbers in the i-th row or j-th column will be skipped. And then the
the numbers in the i-th row or j-th column will be skipped. And then the
next largest number will be selected from the remaining numbers. Obviously
next largest number will be selected from the remaining numbers. Obviously
there will be min(B, C) numbers), and mark the corresponding position of the
there will be min(B, C) numbers), and mark the corresponding position of the
3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for
3-D similarity focus mask as 1, otherwise as 0. Do elementwise-or for
each index.
each index.
3. Broadcast the 3-D similarity focus mask to the same shape of input X.
3. Broadcast the 3-D similarity focus mask to the same shape of input X.
...
@@ -8138,16 +8186,16 @@ def similarity_focus(input, axis, indexes, name=None):
...
@@ -8138,16 +8186,16 @@ def similarity_focus(input, axis, indexes, name=None):
[1.0, 0.0]]]]
[1.0, 0.0]]]]
Args:
Args:
input(Variable): The input tensor variable(default float). It should
input(Variable): The input tensor variable(default float). It should
be a 4-D tensor with shape [BatchSize, A, B, C].
be a 4-D tensor with shape [BatchSize, A, B, C].
axis(int): Indicating the dimension to be selected. It can only be
axis(int): Indicating the dimension to be selected. It can only be
1, 2 or 3.
1, 2 or 3.
indexes(list): Indicating the indexes of the selected dimension.
indexes(list): Indicating the indexes of the selected dimension.
Returns:
Returns:
Variable: A tensor variable with the same shape and same type
Variable: A tensor variable with the same shape and same type
as the input.
as the input.
Examples:
Examples:
.. code-block:: python
.. code-block:: python
data = fluid.layers.data(
data = fluid.layers.data(
...
@@ -8250,12 +8298,12 @@ def hash(input, hash_size, num_hash=1, name=None):
...
@@ -8250,12 +8298,12 @@ def hash(input, hash_size, num_hash=1, name=None):
@
templatedoc
()
@
templatedoc
()
def
grid_sampler
(
x
,
grid
,
name
=
None
):
def
grid_sampler
(
x
,
grid
,
name
=
None
):
"""
"""
This operation samples input X by using bilinear interpolation based on
This operation samples input X by using bilinear interpolation based on
flow field grid, which is usually gennerated by affine_grid. The grid of
flow field grid, which is usually gennerated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexng the 3rd
(in width dimension) of input data x and grid_y is indexng the 3rd
dimention (in height dimension), finally results is the bilinear
dimention (in height dimension), finally results is the bilinear
interpolation value of 4 nearest corner points.
interpolation value of 4 nearest corner points.
Step 1:
Step 1:
...
@@ -8265,7 +8313,7 @@ def grid_sampler(x, grid, name=None):
...
@@ -8265,7 +8313,7 @@ def grid_sampler(x, grid, name=None):
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points.
interpolate point value by 4 nearest points.
wn ------- y_n ------- en
wn ------- y_n ------- en
...
@@ -8302,7 +8350,7 @@ def grid_sampler(x, grid, name=None):
...
@@ -8302,7 +8350,7 @@ def grid_sampler(x, grid, name=None):
name (str, default None): The name of this layer.
name (str, default None): The name of this layer.
Returns:
Returns:
out(Variable): Output of shape [N, C, H, W] data samples input X
out(Variable): Output of shape [N, C, H, W] data samples input X
using bilnear interpolation based on input grid.
using bilnear interpolation based on input grid.
Exmples:
Exmples:
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
bef475c9
...
@@ -23,6 +23,10 @@ if(NOT WITH_DISTRIBUTE)
...
@@ -23,6 +23,10 @@ if(NOT WITH_DISTRIBUTE)
LIST
(
REMOVE_ITEM TEST_OPS test_dist_text_classification
)
LIST
(
REMOVE_ITEM TEST_OPS test_dist_text_classification
)
endif
(
NOT WITH_DISTRIBUTE
)
endif
(
NOT WITH_DISTRIBUTE
)
if
(
${
CUDNN_MAJOR_VERSION
}
VERSION_LESS 7
)
LIST
(
REMOVE_ITEM TEST_OPS test_conv2d_fusion_op
)
endif
()
list
(
REMOVE_ITEM TEST_OPS test_seq_concat_op
)
# FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
list
(
REMOVE_ITEM TEST_OPS test_seq_concat_op
)
# FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
list
(
REMOVE_ITEM TEST_OPS test_modified_huber_loss_op
)
# FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184
list
(
REMOVE_ITEM TEST_OPS test_modified_huber_loss_op
)
# FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184
list
(
REMOVE_ITEM TEST_OPS test_lstm_unit_op
)
# # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
list
(
REMOVE_ITEM TEST_OPS test_lstm_unit_op
)
# # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
...
...
python/requirements.txt
浏览文件 @
bef475c9
requests==2.9.2
requests==2.9.2
numpy>=1.12
,<=1.14 #TODO:change to ">=1.12" when numpy fix bug in 1.15 and higher version
numpy>=1.12
protobuf==3.1
protobuf==3.1
recordio>=0.1.0
recordio>=0.1.0
matplotlib==2.2.3 # TODO: let python3 paddlepaddle package use latest matplotlib
matplotlib==2.2.3 # TODO: let python3 paddlepaddle package use latest matplotlib
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录