s920243400 / PaddleDetection (forked from PaddlePaddle / PaddleDetection)
Commit 6824c09d
Authored on Aug 02, 2017 by QI JUN
Committed by GitHub on Aug 02, 2017
Merge pull request #3050 from QiJune/op_gpu_test
enable operator gpu unittest
Parents: bfaea910, 043e983b
Showing 20 changed files with 162 additions and 76 deletions (+162, -76)
cmake/flags.cmake (+5, -0)
paddle/framework/detail/tensor-inl.h (+4, -3)
paddle/operators/add_op.cu (+1, -0)
paddle/operators/cross_entropy_op.cu (+1, -0)
paddle/operators/mul_op.cu (+1, -0)
paddle/operators/rowwise_add_op.cu (+1, -0)
paddle/operators/sgd_op.cu (+1, -0)
paddle/operators/sigmoid_op.cu (+1, -0)
paddle/operators/softmax_op.cu (+1, -0)
paddle/platform/enforce.h (+6, -6)
paddle/pybind/pybind.cc (+50, -10)
paddle/pybind/tensor_bind.h (+36, -12)
python/paddle/v2/framework/tests/CMakeLists.txt (+0, -1)
python/paddle/v2/framework/tests/op_test_util.py (+32, -27)
python/paddle/v2/framework/tests/test_add_two_op.py (+2, -2)
python/paddle/v2/framework/tests/test_fc_op.py (+6, -4)
python/paddle/v2/framework/tests/test_mul_op.py (+2, -2)
python/paddle/v2/framework/tests/test_rowwise_add_op.py (+2, -2)
python/paddle/v2/framework/tests/test_sgd_op.py (+2, -2)
python/paddle/v2/framework/tests/test_tensor.py (+8, -5)
cmake/flags.cmake

@@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag)
         if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
             message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
         endif()
+        # TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem.
+        # Use Debug mode instead for now.
+        if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9)
+            set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE)
+        endif()
     elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
         # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
         # Apple Clang is a different compiler than upstream Clang which havs different version numbers.
paddle/framework/detail/tensor-inl.h

@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
 #include "paddle/memory/memcpy.h"
 
 namespace paddle {

@@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) {
   if (platform::is_cpu_place(place)) {
     holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
         boost::get<platform::CPUPlace>(place), size));
+  } else if (platform::is_gpu_place(place)) {
+#ifdef PADDLE_ONLY_CPU
+    PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
   }
-#ifndef PADDLE_ONLY_CPU
-  else if (platform::is_gpu_place(place)) {
+#else
     holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
         boost::get<platform::GPUPlace>(place), size));
   }
paddle/operators/add_op.cu

+#define EIGEN_USE_GPU
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/add_op.h"
paddle/operators/cross_entropy_op.cu

+#define EIGEN_USE_GPU
 #include "paddle/operators/cross_entropy_op.h"
 
 REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
paddle/operators/mul_op.cu

@@ -12,6 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#define EIGEN_USE_GPU
 #include "paddle/operators/mul_op.h"
 
 REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
\ No newline at end of file
paddle/operators/rowwise_add_op.cu

+#define EIGEN_USE_GPU
 #include "paddle/operators/rowwise_add_op.h"
 
 REGISTER_OP_GPU_KERNEL(rowwise_add,
paddle/operators/sgd_op.cu

+#define EIGEN_USE_GPU
 #include "paddle/operators/sgd_op.h"
 
 REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
\ No newline at end of file
paddle/operators/sigmoid_op.cu

+#define EIGEN_USE_GPU
 #include "paddle/operators/sigmoid_op.h"
 
 REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
paddle/operators/softmax_op.cu

+#define EIGEN_USE_GPU
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/softmax_op.h"
paddle/platform/enforce.h

@@ -148,7 +148,7 @@ inline void throw_on_error(T e) {
   do {                                                                  \
     throw ::paddle::platform::EnforceNotMet(                            \
         std::make_exception_ptr(                                        \
-            std::runtime_error(string::Sprintf(__VA_ARGS__))),          \
+            std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))),  \
         __FILE__, __LINE__);                                            \
   } while (0)
paddle/pybind/pybind.cc

@@ -20,6 +20,8 @@ limitations under the License. */
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/scope.h"
+#include "paddle/platform/enforce.h"
+#include "paddle/platform/place.h"
 #include "paddle/pybind/tensor_bind.h"
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"

@@ -55,6 +57,14 @@ static size_t UniqueIntegerGenerator() {
   return generator.fetch_add(1);
 }
 
+bool IsCompileGPU() {
+#ifdef PADDLE_ONLY_CPU
+  return false;
+#else
+  return true;
+#endif
+}
+
 PYBIND11_PLUGIN(core) {
   py::module m("core", "C++ core of PaddlePaddle");

@@ -69,15 +79,27 @@ PYBIND11_PLUGIN(core) {
             self.Resize(pd::make_ddim(dim));
           })
       .def("alloc_float",
-           [](pd::Tensor& self) {
-             self.mutable_data<float>(paddle::platform::CPUPlace());
+           [](pd::Tensor& self, paddle::platform::GPUPlace& place) {
+             self.mutable_data<float>(place);
+           })
+      .def("alloc_float",
+           [](pd::Tensor& self, paddle::platform::CPUPlace& place) {
+             self.mutable_data<float>(place);
+           })
+      .def("alloc_int",
+           [](pd::Tensor& self, paddle::platform::CPUPlace& place) {
+             self.mutable_data<int>(place);
           })
       .def("alloc_int",
-           [](pd::Tensor& self) {
-             self.mutable_data<int>(paddle::platform::CPUPlace());
+           [](pd::Tensor& self, paddle::platform::GPUPlace& place) {
+             self.mutable_data<int>(place);
           })
-      .def("set", paddle::pybind::PyTensorSetFromArray<float>)
-      .def("set", paddle::pybind::PyTensorSetFromArray<int>)
+      .def("set", paddle::pybind::PyCPUTensorSetFromArray<float>)
+      .def("set", paddle::pybind::PyCPUTensorSetFromArray<int>)
+#ifndef PADDLE_ONLY_CPU
+      .def("set", paddle::pybind::PyCUDATensorSetFromArray<float>)
+      .def("set", paddle::pybind::PyCUDATensorSetFromArray<int>)
+#endif
       .def("shape",
            [](pd::Tensor& self) { return pd::vectorize(self.dims()); });

@@ -136,11 +158,27 @@ All parameter, weight, gradient are variables in Paddle.
       "The module will return special predefined variable name in Paddle")
       .def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
       .def("temp", pd::OperatorBase::TMP_VAR_NAME);
+  // clang-format off
   py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
-      .def_static("cpu_context", []() -> paddle::platform::DeviceContext* {
+      .def_static("create",
+                  [](paddle::platform::CPUPlace& place)
+                      -> paddle::platform::DeviceContext* {
         return new paddle::platform::CPUDeviceContext();
+      })
+      .def_static("create",
+                  [](paddle::platform::GPUPlace& place)
+                      -> paddle::platform::DeviceContext* {
+#ifdef PADDLE_ONLY_CPU
+        PADDLE_THROW("GPUPlace is not supported in CPU device.");
+#else
+        return new paddle::platform::CUDADeviceContext(place);
+#endif
       });
+  // clang-format on
+
+  py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>());
+
+  py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>());
 
   py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base(
       m, "Operator");

@@ -176,5 +214,7 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("unique_integer", UniqueIntegerGenerator);
 
+  m.def("is_compile_gpu", IsCompileGPU);
+
   return m.ptr();
 }
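Taken together, these bindings move device selection to the Python side: a tensor is allocated against an explicit CPUPlace or GPUPlace, and the matching DeviceContext comes from the overloaded create factory. A minimal sketch of the intended call pattern, assuming the extension module is imported as paddle.v2.framework.core (the path used by the test files further below):

import paddle.v2.framework.core as core

# is_compile_gpu() reflects whether Paddle was built with PADDLE_ONLY_CPU.
place = core.GPUPlace(0) if core.is_compile_gpu() else core.CPUPlace()

# One "create" name, two overloads: a CPUPlace yields a CPUDeviceContext,
# a GPUPlace yields a CUDADeviceContext (or throws in a CPU-only build).
ctx = core.DeviceContext.create(place)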
paddle/pybind/tensor_bind.h

@@ -13,9 +13,11 @@
 limitations under the License. */
 
 #pragma once
-#include <paddle/framework/tensor.h>
-#include <pybind11/numpy.h>
-#include <pybind11/pybind11.h>
+#include <string>
+#include "paddle/framework/tensor.h"
+#include "paddle/memory/memcpy.h"
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
 
 namespace py = pybind11;

@@ -40,9 +42,6 @@ template <size_t I, typename... ARGS>
 struct CastToPyBufferImpl<true, I, ARGS...> {
   using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
   py::buffer_info operator()(framework::Tensor &tensor) {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
-                   "Only CPU tensor can cast to numpy array");
-
     if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
       auto dim_vec = framework::vectorize(tensor.dims());
       std::vector<size_t> dims_outside;

@@ -56,12 +55,17 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
         strides[i - 1] = sizeof(CUR_TYPE) * prod;
         prod *= dims_outside[i - 1];
       }
+      framework::Tensor dst_tensor;
+      if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
+        dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
+      } else if (paddle::platform::is_cpu_place(tensor.holder_->place())) {
+        dst_tensor = tensor;
+      }
       return py::buffer_info(
-          tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
+          dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()),
           sizeof(CUR_TYPE),
           py::format_descriptor<CUR_TYPE>::format(),
-          (size_t)framework::arity(tensor.dims()),
+          (size_t)framework::arity(dst_tensor.dims()),
           dims_outside,
           strides);
     } else {

@@ -77,9 +81,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
 }
 
 template <typename T>
-void PyTensorSetFromArray(
+void PyCPUTensorSetFromArray(
     framework::Tensor &self,
-    py::array_t<T, py::array::c_style | py::array::forcecast> array) {
+    py::array_t<T, py::array::c_style | py::array::forcecast> array,
+    paddle::platform::CPUPlace &place) {
   std::vector<int> dims;
   dims.reserve(array.ndim());
   for (size_t i = 0; i < array.ndim(); ++i) {

@@ -87,9 +92,28 @@ void PyTensorSetFromArray(
   }
 
   self.Resize(framework::make_ddim(dims));
-  auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace());
+  auto *dst = self.mutable_data<T>(place);
   std::memcpy(dst, array.data(), sizeof(T) * array.size());
 }
 
+#ifndef PADDLE_ONLY_CPU
+template <typename T>
+void PyCUDATensorSetFromArray(
+    framework::Tensor &self,
+    py::array_t<T, py::array::c_style | py::array::forcecast> array,
+    paddle::platform::GPUPlace &place) {
+  std::vector<int> dims;
+  dims.reserve(array.ndim());
+  for (size_t i = 0; i < array.ndim(); ++i) {
+    dims.push_back((int)array.shape()[i]);
+  }
+  self.Resize(framework::make_ddim(dims));
+  auto *dst = self.mutable_data<T>(place);
+  paddle::platform::GpuMemcpySync(
+      dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice);
+}
+#endif
+
 }  // namespace pybind
 }  // namespace paddle
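The practical effect of these helpers on the Python side: set(array, place) picks the CPU or CUDA copy path, while reading a tensor back through numpy.array() always yields host data, because CastToPyBuffer now copies a GPU-resident tensor into a temporary CPU tensor first. A rough round-trip sketch under the same import assumption as above (the GPU branch needs a CUDA-enabled build):

import numpy
import paddle.v2.framework.core as core

place = core.GPUPlace(0)       # device 0; requires a CUDA-enabled build
scope = core.Scope()
tensor = scope.new_var("t").get_tensor()
tensor.set_dims([4, 4])
tensor.alloc_float(place)                                  # allocate on the device
tensor.set(numpy.ones((4, 4)).astype("float32"), place)    # host -> device copy

host_view = numpy.array(tensor)   # device -> host copy happens inside CastToPyBuffer
print(host_view.shape)            # (4, 4)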
python/paddle/v2/framework/tests/CMakeLists.txt

@@ -8,7 +8,6 @@ add_python_test(test_framework
     test_fc_op.py
     test_add_two_op.py
     test_sgd_op.py
-    test_cross_entropy_op.py
    test_mul_op.py
    test_mean_op.py
    test_sigmoid_op.py
python/paddle/v2/framework/tests/op_test_util.py

@@ -26,14 +26,19 @@ class OpTestMeta(type):
             scope = core.Scope()
             kwargs = dict()
+            places = []
+            places.append(core.CPUPlace())
+            if core.is_compile_gpu():
+                places.append(core.GPUPlace(0))
 
-            for in_name in func.all_input_args:
-                if hasattr(self, in_name):
-                    kwargs[in_name] = in_name
-                    var = scope.new_var(in_name).get_tensor()
-                    arr = getattr(self, in_name)
-                    var.set_dims(arr.shape)
-                    var.set(arr)
-                else:
-                    kwargs[in_name] = "@EMPTY@"
+            for place in places:
+                for in_name in func.all_input_args:
+                    if hasattr(self, in_name):
+                        kwargs[in_name] = in_name
+                        var = scope.new_var(in_name).get_tensor()
+                        arr = getattr(self, in_name)
+                        var.set_dims(arr.shape)
+                        var.set(arr, place)
+                    else:
+                        kwargs[in_name] = "@EMPTY@"

@@ -50,7 +55,7 @@ class OpTestMeta(type):
                 op.infer_shape(scope)
 
-                ctx = core.DeviceContext.cpu_context()
+                ctx = core.DeviceContext.create(place)
                 op.run(scope, ctx)
 
                 for out_name in func.all_output_args:
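With this change, an operator test written against OpTestMeta does not mention devices at all: setUp only declares numpy inputs and expected outputs, and the generated test_all runs the op on CPUPlace and, when core.is_compile_gpu() is true, on GPUPlace(0) as well. A skeletal example of the pattern, mirroring the concrete tests updated below (import paths assumed to match those sibling files):

import unittest
import numpy
from op_test_util import OpTestMeta


class TestMyOp(unittest.TestCase):
    # OpTestMeta generates test_all(), which loops over the available places.
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "add_two"   # operator name registered with the op creation methods
        self.X = numpy.random.random((102, 105)).astype("float32")
        self.Y = numpy.random.random((102, 105)).astype("float32")
        self.Out = self.X + self.Y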
python/paddle/v2/framework/tests/test_add_two_op.py

@@ -8,8 +8,8 @@ class TestAddOp(unittest.TestCase):
     def setUp(self):
         self.type = "add_two"
-        self.X = numpy.random.random((342, 345)).astype("float32")
-        self.Y = numpy.random.random((342, 345)).astype("float32")
+        self.X = numpy.random.random((102, 105)).astype("float32")
+        self.Y = numpy.random.random((102, 105)).astype("float32")
         self.Out = self.X + self.Y
python/paddle/v2/framework/tests/test_fc_op.py

@@ -7,17 +7,19 @@ import paddle.v2.framework.create_op_creation_methods as creation
 class TestFc(unittest.TestCase):
     def test_fc(self):
         scope = core.Scope()
+        place = core.CPUPlace()
+
         x = scope.new_var("X")
         x_tensor = x.get_tensor()
         x_tensor.set_dims([1000, 784])
-        x_tensor.alloc_float()
+        x_tensor.alloc_float(place)
 
         w = scope.new_var("W")
         w_tensor = w.get_tensor()
         w_tensor.set_dims([784, 100])
-        w_tensor.alloc_float()
+        w_tensor.alloc_float(place)
 
-        w_tensor.set(numpy.random.random((784, 100)).astype("float32"))
+        w_tensor.set(numpy.random.random((784, 100)).astype("float32"), place)
 
         # Set a real numpy array here.
         # x_tensor.set(numpy.array([]))

@@ -32,7 +34,7 @@ class TestFc(unittest.TestCase):
         op.infer_shape(scope)
         self.assertEqual([1000, 100], tensor.shape())
 
-        ctx = core.DeviceContext.cpu_context()
+        ctx = core.DeviceContext.create(place)
 
         op.run(scope, ctx)
python/paddle/v2/framework/tests/test_mul_op.py

@@ -8,8 +8,8 @@ class TestMulOp(unittest.TestCase):
     def setUp(self):
         self.type = "mul"
-        self.X = np.random.random((32, 784)).astype("float32")
-        self.Y = np.random.random((784, 100)).astype("float32")
+        self.X = np.random.random((32, 84)).astype("float32")
+        self.Y = np.random.random((84, 100)).astype("float32")
         self.Out = np.dot(self.X, self.Y)
python/paddle/v2/framework/tests/test_rowwise_add_op.py

@@ -8,8 +8,8 @@ class TestRowwiseAddOp(unittest.TestCase):
     def setUp(self):
         self.type = "rowwise_add"
-        self.X = np.random.random((32, 784)).astype("float32")
-        self.b = np.random.random(784).astype("float32")
+        self.X = np.random.random((32, 84)).astype("float32")
+        self.b = np.random.random(84).astype("float32")
         self.Out = np.add(self.X, self.b)
python/paddle/v2/framework/tests/test_sgd_op.py

@@ -8,8 +8,8 @@ class TestSGD(unittest.TestCase):
     def setUp(self):
         self.type = "sgd"
-        self.param = numpy.random.random((342, 345)).astype("float32")
-        self.grad = numpy.random.random((342, 345)).astype("float32")
+        self.param = numpy.random.random((102, 105)).astype("float32")
+        self.grad = numpy.random.random((102, 105)).astype("float32")
         self.learning_rate = 0.1
         self.param_out = self.param - self.learning_rate * self.grad
python/paddle/v2/framework/tests/test_tensor.py

@@ -7,16 +7,17 @@ class TestScope(unittest.TestCase):
     def test_int_tensor(self):
         scope = core.Scope()
         var = scope.new_var("test_tensor")
+        place = core.CPUPlace()
         tensor = var.get_tensor()
 
         tensor.set_dims([1000, 784])
-        tensor.alloc_int()
+        tensor.alloc_int(place)
 
         tensor_array = numpy.array(tensor)
         self.assertEqual((1000, 784), tensor_array.shape)
         tensor_array[3, 9] = 1
         tensor_array[19, 11] = 2
-        tensor.set(tensor_array)
+        tensor.set(tensor_array, place)
 
         tensor_array_2 = numpy.array(tensor)
         self.assertEqual(1.0, tensor_array_2[3, 9])

@@ -25,16 +26,18 @@ class TestScope(unittest.TestCase):
     def test_float_tensor(self):
         scope = core.Scope()
         var = scope.new_var("test_tensor")
+        place = core.CPUPlace()
         tensor = var.get_tensor()
 
         tensor.set_dims([1000, 784])
-        tensor.alloc_float()
+        tensor.alloc_float(place)
 
         tensor_array = numpy.array(tensor)
         self.assertEqual((1000, 784), tensor_array.shape)
         tensor_array[3, 9] = 1.0
         tensor_array[19, 11] = 2.0
-        tensor.set(tensor_array)
+        tensor.set(tensor_array, place)
 
         tensor_array_2 = numpy.array(tensor)
         self.assertAlmostEqual(1.0, tensor_array_2[3, 9])