Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5e8e7fb6
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5e8e7fb6
编写于
8月 30, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change data type
上级
f5329d65
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
29 addition
and
15 deletion
+29
-15
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/data_type.h
paddle/fluid/framework/data_type.h
+14
-8
paddle/fluid/framework/data_type_transform.cc
paddle/fluid/framework/data_type_transform.cc
+1
-1
paddle/fluid/framework/tensor_util.cc
paddle/fluid/framework/tensor_util.cc
+5
-3
paddle/fluid/framework/tensor_util.h
paddle/fluid/framework/tensor_util.h
+2
-2
paddle/fluid/operators/cast_op.h
paddle/fluid/operators/cast_op.h
+6
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
5e8e7fb6
...
...
@@ -9,7 +9,7 @@ function(windows_symbolic TARGET)
if
(
NOT EXISTS
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
src
}
.cc OR NOT EXISTS
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
src
}
.cu
)
message
(
FATAL
"
${
src
}
.cc and
${
src
}
.cu must exsits, and
${
src
}
.cu must be symbolic file."
)
endif
()
add_custom_command
(
OUTPUT .
${
src
}
.cu
PRE_BUILD
add_custom_command
(
OUTPUT .
${
src
}
.cu
COMMAND
${
CMAKE_COMMAND
}
-E remove
${
CMAKE_CURRENT_SOURCE_DIR
}
/.
${
src
}
.cu
COMMAND
${
CMAKE_COMMAND
}
-E copy
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
src
}
.cc"
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/.
${
src
}
.cu"
COMMENT
"create hidden file of
${
src
}
.cu"
)
...
...
paddle/fluid/framework/data_type.h
浏览文件 @
5e8e7fb6
...
...
@@ -64,33 +64,39 @@ template <typename Visitor>
inline
void
VisitDataType
(
proto
::
VarType
::
Type
type
,
Visitor
visitor
)
{
switch
(
type
)
{
case
proto
::
VarType
::
FP16
:
typename
visitor
.
operator
()
<
platform
::
float16
>
();
visitor
.
template
apply
<
platform
::
float16
>();
break
;
case
proto
::
VarType
::
FP32
:
visitor
.
operator
()
<
float
>
();
visitor
.
template
apply
<
float
>();
break
;
case
proto
::
VarType
::
FP64
:
visitor
.
operator
()
<
double
>
();
visitor
.
template
apply
<
double
>();
break
;
case
proto
::
VarType
::
INT32
:
visitor
.
operator
()
<
int
>
();
visitor
.
template
apply
<
int
>();
break
;
case
proto
::
VarType
::
INT64
:
visitor
.
operator
()
<
int64_t
>
();
visitor
.
template
apply
<
int64_t
>();
break
;
case
proto
::
VarType
::
BOOL
:
visitor
.
operator
()
<
bool
>
();
visitor
.
template
apply
<
bool
>();
break
;
case
proto
::
VarType
::
UINT8
:
visitor
.
operator
()
<
uint8_t
>
();
visitor
.
template
apply
<
uint8_t
>();
break
;
case
proto
::
VarType
::
INT16
:
visitor
.
operator
()
<
int16_t
>
();
visitor
.
template
apply
<
int16_t
>();
break
;
default:
PADDLE_THROW
(
"Not supported %d"
,
type
);
}
}
template
<
typename
InT
>
void
*
AnyCast
(
const
InT
*
t
)
{
return
static_cast
<
void
*>
(
const_cast
<
InT
*>
(
t
));
}
#endif // _WIN32
extern
std
::
string
DataTypeToString
(
const
proto
::
VarType
::
Type
type
);
...
...
paddle/fluid/framework/data_type_transform.cc
浏览文件 @
5e8e7fb6
...
...
@@ -37,7 +37,7 @@ struct CastDataType {
const
platform
::
DeviceContext
*
ctx_
;
template
<
typename
OutType
>
void
operator
()()
{
void
apply
()()
{
auto
*
in_begin
=
in_
.
data
<
InType
>
();
auto
*
in_end
=
in_begin
+
in_
.
numel
();
auto
*
out_begin
=
out_
->
mutable_data
<
OutType
>
(
in_
.
place
());
...
...
paddle/fluid/framework/tensor_util.cc
浏览文件 @
5e8e7fb6
...
...
@@ -137,6 +137,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
#endif
}
/*
template <typename Predicate, typename DevCtx>
struct AnyDTypeVisitor {
Predicate predicate_;
...
...
@@ -149,7 +150,7 @@ struct AnyDTypeVisitor {
: predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
template <typename T>
void
operator
()()
const
{
void
apply
()() const {
auto t = EigenVector<T>::Flatten(tensor_);
auto o = EigenScalar<bool>::From(*out_);
// return any of predicate_(t) is true.
...
...
@@ -173,7 +174,7 @@ struct AnyVisitor : public boost::static_visitor<bool> {
: tensor_(tensor), predicate_(std::move(predicate)) {}
template <typename Place>
bool
operator
()(
const
Place
&
place
)
const
{
bool
apply
()(const Place& place) const {
framework::Tensor out;
out.Resize({1});
out.mutable_data<bool>(place);
...
...
@@ -240,6 +241,7 @@ bool TensorContainsInf(const framework::Tensor& tensor) {
ContainsInfPredicate predicate;
return Any(tensor, predicate);
}
*/
void
TensorToStream
(
std
::
ostream
&
os
,
const
Tensor
&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
...
...
@@ -302,7 +304,7 @@ struct DeserializedDataFunctor {
:
buf_
(
buf
),
tensor_
(
tensor
),
place_
(
place
)
{}
template
<
typename
T
>
void
operator
()
()
{
void
apply
()
{
*
buf_
=
tensor_
->
mutable_data
<
T
>
(
place_
);
}
...
...
paddle/fluid/framework/tensor_util.h
浏览文件 @
5e8e7fb6
...
...
@@ -57,8 +57,8 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
template
<
typename
T
>
void
TesnorToVector
(
const
Tensor
&
src
,
std
::
vector
<
T
>*
dst
);
bool
TensorContainsNAN
(
const
framework
::
Tensor
&
tensor
);
bool
TensorContainsInf
(
const
framework
::
Tensor
&
tensor
);
//
bool TensorContainsNAN(const framework::Tensor& tensor);
//
bool TensorContainsInf(const framework::Tensor& tensor);
void
TensorToStream
(
std
::
ostream
&
os
,
const
Tensor
&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
);
...
...
paddle/fluid/operators/cast_op.h
浏览文件 @
5e8e7fb6
...
...
@@ -54,11 +54,17 @@ class CastOpKernel : public framework::OpKernel<InT> {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
#if !defined(_MSC_VER)
framework
::
VisitDataType
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
context
.
Attr
<
int
>
(
"out_dtype"
)),
CastOpFunctor
<
DeviceContext
,
InT
>
(
in
,
out
,
context
.
template
device_context
<
DeviceContext
>()));
#else
auto
type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
context
.
Attr
<
int
>
(
"out_dtype"
));
trans
#endif // msvc
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录