Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
93eaa8e3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
93eaa8e3
编写于
12月 29, 2017
作者:
Y
Yu Yang
提交者:
GitHub
12月 29, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7068 from reyoung/feature/is_nan
Feature/is nan
上级
9b847380
f97205ee
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
266 addition
and
3 deletion
+266
-3
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+10
-2
paddle/framework/tensor_util.cc
paddle/framework/tensor_util.cc
+119
-0
paddle/framework/tensor_util.cu
paddle/framework/tensor_util.cu
+1
-0
paddle/framework/tensor_util.h
paddle/framework/tensor_util.h
+8
-0
paddle/framework/tensor_util_test.cc
paddle/framework/tensor_util_test.cc
+24
-0
paddle/framework/tensor_util_test.cu
paddle/framework/tensor_util_test.cu
+57
-0
paddle/platform/device_context.h
paddle/platform/device_context.h
+20
-0
paddle/platform/place.h
paddle/platform/place.h
+27
-1
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
93eaa8e3
...
@@ -5,10 +5,18 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3)
...
@@ -5,10 +5,18 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test
(
ddim_test SRCS ddim_test.cc DEPS ddim
)
cc_test
(
ddim_test SRCS ddim_test.cc DEPS ddim
)
nv_test
(
dim_test SRCS dim_test.cu DEPS ddim
)
nv_test
(
dim_test SRCS dim_test.cu DEPS ddim
)
cc_library
(
tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context framework_proto
)
if
(
WITH_GPU
)
nv_library
(
tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context framework_proto
)
else
()
cc_library
(
tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context framework_proto
)
endif
()
cc_test
(
tensor_test SRCS tensor_test.cc DEPS tensor
)
cc_test
(
tensor_test SRCS tensor_test.cc DEPS tensor
)
cc_test
(
tensor_util_test SRCS tensor_util_test.cc DEPS tensor
)
if
(
WITH_GPU
)
nv_test
(
tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor
)
else
()
cc_test
(
tensor_util_test SRCS tensor_util_test.cc DEPS tensor
)
endif
()
cc_test
(
eigen_test SRCS eigen_test.cc DEPS tensor
)
cc_test
(
eigen_test SRCS eigen_test.cc DEPS tensor
)
...
...
paddle/framework/tensor_util.cc
0 → 100644
浏览文件 @
93eaa8e3
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/tensor_util.h"
namespace
paddle
{
namespace
framework
{
template
<
typename
Predicate
,
typename
DevCtx
>
struct
AnyDTypeVisitor
{
Predicate
predicate_
;
const
Tensor
&
tensor_
;
const
DevCtx
&
ctx_
;
Tensor
*
out_
;
AnyDTypeVisitor
(
Predicate
predicate
,
const
Tensor
&
tensor
,
const
DevCtx
&
ctx
,
Tensor
*
out
)
:
predicate_
(
predicate
),
tensor_
(
tensor
),
ctx_
(
ctx
),
out_
(
out
)
{}
template
<
typename
T
>
void
operator
()()
const
{
auto
t
=
EigenVector
<
T
>::
Flatten
(
tensor_
);
auto
o
=
EigenScalar
<
bool
>::
From
(
*
out_
);
// return any of predicate_(t) is true.
o
.
device
(
*
ctx_
.
eigen_device
())
=
predicate_
(
t
).
any
();
}
};
template
<
typename
Predicate
,
typename
DevCtx
>
inline
void
AnyImpl
(
Predicate
predicate
,
const
framework
::
Tensor
&
tensor
,
const
DevCtx
&
ctx
,
framework
::
Tensor
*
out
)
{
VisitDataType
(
ToDataType
(
tensor
.
type
()),
AnyDTypeVisitor
<
Predicate
,
DevCtx
>
(
predicate
,
tensor
,
ctx
,
out
));
}
template
<
typename
Predicate
>
struct
AnyVisitor
:
public
boost
::
static_visitor
<
bool
>
{
const
framework
::
Tensor
&
tensor_
;
Predicate
predicate_
;
AnyVisitor
(
const
framework
::
Tensor
&
tensor
,
Predicate
predicate
)
:
tensor_
(
tensor
),
predicate_
(
std
::
move
(
predicate
))
{}
template
<
typename
Place
>
bool
operator
()(
const
Place
&
place
)
const
{
framework
::
Tensor
out
;
out
.
Resize
({
1
});
out
.
mutable_data
<
bool
>
(
place
);
auto
*
ctx
=
platform
::
DeviceContextPool
::
Instance
().
GetByPlace
(
place
);
AnyImpl
(
predicate_
,
tensor_
,
*
ctx
,
&
out
);
return
this
->
GetResult
(
out
,
place
);
}
bool
GetResult
(
const
framework
::
Tensor
&
out
,
const
platform
::
CUDAPlace
&
gpu
)
const
{
platform
::
CPUPlace
cpu
;
framework
::
Tensor
tmp
;
tmp
.
Resize
({
1
});
tmp
.
mutable_data
<
bool
>
(
cpu
);
auto
gpuctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
gpu
);
gpuctx
->
Wait
();
CopyFrom
(
out
,
cpu
,
*
gpuctx
,
&
tmp
);
gpuctx
->
Wait
();
return
GetResult
(
tmp
,
cpu
);
}
bool
GetResult
(
const
framework
::
Tensor
&
out
,
const
platform
::
CPUPlace
&
cpu
)
const
{
return
*
out
.
data
<
bool
>
();
}
};
template
<
typename
Predicate
>
inline
bool
Any
(
const
framework
::
Tensor
&
tensor
,
Predicate
predicate
)
{
AnyVisitor
<
Predicate
>
visitor
(
tensor
,
predicate
);
auto
place
=
tensor
.
place
();
return
platform
::
VisitPlace
(
place
,
visitor
);
}
struct
HasNANPredicate
{
template
<
typename
T
>
auto
operator
()(
const
T
&
eigen_vec
)
const
->
decltype
(
std
::
declval
<
T
>
().
isnan
())
{
// Cast eigen_vector to vector of bool. true if is inf.
return
eigen_vec
.
isnan
();
}
};
bool
HasNAN
(
const
framework
::
Tensor
&
tensor
)
{
HasNANPredicate
predicate
;
return
Any
(
tensor
,
predicate
);
}
struct
HasInfPredicate
{
template
<
typename
T
>
auto
operator
()(
const
T
&
eigen_vec
)
const
->
decltype
(
std
::
declval
<
T
>
().
isinf
())
{
// Cast eigen_vector to vector of bool. true if is inf.
return
eigen_vec
.
isinf
();
}
};
bool
HasInf
(
const
framework
::
Tensor
&
tensor
)
{
HasInfPredicate
predicate
;
return
Any
(
tensor
,
predicate
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/tensor_util.cu
0 → 120000
浏览文件 @
93eaa8e3
.
/
tensor_util
.
cc
\ No newline at end of file
paddle/framework/tensor_util.h
浏览文件 @
93eaa8e3
...
@@ -14,8 +14,10 @@ limitations under the License. */
...
@@ -14,8 +14,10 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/data_type.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -207,6 +209,12 @@ inline void CopyToVector(const Tensor& src, std::vector<T>* dst) {
...
@@ -207,6 +209,12 @@ inline void CopyToVector(const Tensor& src, std::vector<T>* dst) {
src_ptr
,
size
);
src_ptr
,
size
);
}
}
// Returns true if a tensor contains NAN, i.e., Not A Number.
extern
bool
HasNAN
(
const
framework
::
Tensor
&
tensor
);
// Returns true if a tensor contains Inf, i.e., Infinity.
extern
bool
HasInf
(
const
framework
::
Tensor
&
tensor
);
inline
void
SerializeToStream
(
std
::
ostream
&
os
,
const
Tensor
&
tensor
,
inline
void
SerializeToStream
(
std
::
ostream
&
os
,
const
Tensor
&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
const
platform
::
DeviceContext
&
dev_ctx
)
{
// TODO(typhoonzero): serialize to ostream
// TODO(typhoonzero): serialize to ostream
...
...
paddle/framework/tensor_util_test.cc
浏览文件 @
93eaa8e3
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
#include "paddle/framework/tensor_util.h"
#include "paddle/framework/tensor_util.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include <string>
namespace
paddle
{
namespace
paddle
{
...
@@ -230,6 +231,29 @@ TEST(CopyToVector, Tensor) {
...
@@ -230,6 +231,29 @@ TEST(CopyToVector, Tensor) {
#endif
#endif
}
}
TEST
(
IsNAN
,
CPU
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
Tensor
src
;
float
*
buf
=
src
.
mutable_data
<
float
>
({
3
},
CPUPlace
());
buf
[
0
]
=
0.0
;
buf
[
1
]
=
NAN
;
buf
[
2
]
=
0.0
;
ASSERT_TRUE
(
HasNAN
(
src
));
}
TEST
(
IsInf
,
CPU
)
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
Tensor
src
;
double
*
buf
=
src
.
mutable_data
<
double
>
({
3
},
CPUPlace
());
buf
[
0
]
=
1.0
;
buf
[
1
]
=
INFINITY
;
buf
[
2
]
=
0.0
;
ASSERT_TRUE
(
HasInf
(
src
));
}
TEST
(
Tensor
,
SerializeAndDeserialize
)
{
TEST
(
Tensor
,
SerializeAndDeserialize
)
{
framework
::
Tensor
src_tensor
;
framework
::
Tensor
src_tensor
;
int
array
[
6
]
=
{
1
,
2
,
3
,
4
,
5
,
6
};
int
array
[
6
]
=
{
1
,
2
,
3
,
4
,
5
,
6
};
...
...
paddle/framework/tensor_util_test.cu
0 → 100644
浏览文件 @
93eaa8e3
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
namespace
paddle
{
namespace
framework
{
static
__global__
void
FillNAN
(
float
*
buf
)
{
buf
[
0
]
=
0.0
;
buf
[
1
]
=
0.1
;
buf
[
2
]
=
NAN
;
}
static
__global__
void
FillInf
(
float
*
buf
)
{
buf
[
0
]
=
0.0
;
buf
[
1
]
=
INFINITY
;
buf
[
2
]
=
0.5
;
}
TEST
(
HasNAN
,
GPU
)
{
Tensor
tensor
;
platform
::
CUDAPlace
gpu
(
0
);
auto
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
cuda_ctx
=
pool
.
GetByPlace
(
gpu
);
float
*
buf
=
tensor
.
mutable_data
<
float
>
({
3
},
gpu
);
FillNAN
<<<
1
,
1
,
0
,
cuda_ctx
->
stream
()
>>>
(
buf
);
cuda_ctx
->
Wait
();
ASSERT_TRUE
(
HasNAN
(
tensor
));
}
TEST
(
HasInf
,
GPU
)
{
Tensor
tensor
;
platform
::
CUDAPlace
gpu
(
0
);
auto
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
cuda_ctx
=
pool
.
GetByPlace
(
gpu
);
float
*
buf
=
tensor
.
mutable_data
<
float
>
({
3
},
gpu
);
FillInf
<<<
1
,
1
,
0
,
cuda_ctx
->
stream
()
>>>
(
buf
);
cuda_ctx
->
Wait
();
ASSERT_TRUE
(
HasInf
(
tensor
));
}
}
// namespace framework
}
// namespace paddle
paddle/platform/device_context.h
浏览文件 @
93eaa8e3
...
@@ -52,6 +52,14 @@ class CPUDeviceContext : public DeviceContext {
...
@@ -52,6 +52,14 @@ class CPUDeviceContext : public DeviceContext {
std
::
unique_ptr
<
Eigen
::
DefaultDevice
>
eigen_device_
;
std
::
unique_ptr
<
Eigen
::
DefaultDevice
>
eigen_device_
;
};
};
template
<
typename
Place
>
struct
DefaultDeviceContextType
;
template
<
>
struct
DefaultDeviceContextType
<
platform
::
CPUPlace
>
{
using
TYPE
=
CPUDeviceContext
;
};
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
class
EigenCudaStreamDevice
;
class
EigenCudaStreamDevice
;
...
@@ -90,6 +98,11 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -90,6 +98,11 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t
cublas_handle_
;
cublasHandle_t
cublas_handle_
;
};
};
template
<
>
struct
DefaultDeviceContextType
<
platform
::
CUDAPlace
>
{
using
TYPE
=
CUDADeviceContext
;
};
class
CUDNNDeviceContext
:
public
CUDADeviceContext
{
class
CUDNNDeviceContext
:
public
CUDADeviceContext
{
public:
public:
explicit
CUDNNDeviceContext
(
CUDAPlace
place
);
explicit
CUDNNDeviceContext
(
CUDAPlace
place
);
...
@@ -125,6 +138,13 @@ class DeviceContextPool {
...
@@ -125,6 +138,13 @@ class DeviceContextPool {
/*! \brief Return handle of single device context. */
/*! \brief Return handle of single device context. */
const
platform
::
DeviceContext
*
Get
(
const
platform
::
Place
&
place
);
const
platform
::
DeviceContext
*
Get
(
const
platform
::
Place
&
place
);
template
<
typename
Place
>
const
typename
DefaultDeviceContextType
<
Place
>::
TYPE
*
GetByPlace
(
const
Place
&
place
)
{
return
reinterpret_cast
<
const
typename
DefaultDeviceContextType
<
Place
>::
TYPE
*>
(
Get
(
place
));
}
private:
private:
static
DeviceContextPool
*
pool
;
static
DeviceContextPool
*
pool
;
constexpr
static
int
LEFT_SHIFT
=
8
;
constexpr
static
int
LEFT_SHIFT
=
8
;
...
...
paddle/platform/place.h
浏览文件 @
93eaa8e3
...
@@ -15,7 +15,7 @@ limitations under the License. */
...
@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <iostream>
#include <iostream>
#include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h"
#include "paddle/platform/variant.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -64,5 +64,31 @@ bool places_are_same_class(const Place &, const Place &);
...
@@ -64,5 +64,31 @@ bool places_are_same_class(const Place &, const Place &);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Place
&
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Place
&
);
template
<
typename
Visitor
>
struct
PlaceVisitorWrapper
:
public
boost
::
static_visitor
<
typename
Visitor
::
result_type
>
{
const
Visitor
&
visitor_
;
explicit
PlaceVisitorWrapper
(
const
Visitor
&
visitor
)
:
visitor_
(
visitor
)
{}
typename
Visitor
::
result_type
operator
()(
const
CPUPlace
&
cpu
)
const
{
return
visitor_
(
cpu
);
}
typename
Visitor
::
result_type
operator
()(
const
CUDAPlace
&
cuda
)
const
{
#ifdef PADDLE_WITH_CUDA
return
visitor_
(
cuda
);
#else
PADDLE_THROW
(
"Paddle is not compiled with CUDA. Cannot visit cuda device"
);
return
typename
Visitor
::
result_type
();
#endif
}
};
template
<
typename
Visitor
>
typename
Visitor
::
result_type
VisitPlace
(
const
Place
&
place
,
const
Visitor
&
visitor
)
{
return
boost
::
apply_visitor
(
PlaceVisitorWrapper
<
Visitor
>
(
visitor
),
place
);
}
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录