Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
125129cf
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
125129cf
编写于
1月 16, 2018
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix: has_header_field
Former-commit-id:
cf20a9f8
上级
088423ba
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
41 addition
and
18 deletion
+41
-18
oneflow/core/operator/basic_data_loader_op.cpp
oneflow/core/operator/basic_data_loader_op.cpp
+1
-0
oneflow/core/operator/recurrent_op.cpp
oneflow/core/operator/recurrent_op.cpp
+1
-1
oneflow/core/register/blob.cpp
oneflow/core/register/blob.cpp
+6
-6
oneflow/core/register/blob.h
oneflow/core/register/blob.h
+3
-2
oneflow/core/register/blob_desc.cpp
oneflow/core/register/blob_desc.cpp
+20
-4
oneflow/core/register/blob_desc.h
oneflow/core/register/blob_desc.h
+6
-2
oneflow/core/register/blob_desc.proto
oneflow/core/register/blob_desc.proto
+4
-3
未找到文件。
oneflow/core/operator/basic_data_loader_op.cpp
浏览文件 @
125129cf
...
@@ -30,6 +30,7 @@ void BasicDataLoaderOp::InferBlobDescs(
...
@@ -30,6 +30,7 @@ void BasicDataLoaderOp::InferBlobDescs(
}
}
out
->
mut_shape
()
=
Shape
(
dim_vec
);
out
->
mut_shape
()
=
Shape
(
dim_vec
);
out
->
set_data_type
(
conf
.
data_type
());
out
->
set_data_type
(
conf
.
data_type
());
out
->
set_has_header_field
(
true
);
out
->
set_has_data_id_field
(
JobDesc
::
Singleton
()
->
SizeOfOneDataId
()
>
0
);
out
->
set_has_data_id_field
(
JobDesc
::
Singleton
()
->
SizeOfOneDataId
()
>
0
);
out
->
set_has_col_num_field
(
conf
.
max_sequence_size
()
>
1
);
out
->
set_has_col_num_field
(
conf
.
max_sequence_size
()
>
1
);
out
->
set_max_col_num
(
conf
.
max_sequence_size
());
out
->
set_max_col_num
(
conf
.
max_sequence_size
());
...
...
oneflow/core/operator/recurrent_op.cpp
浏览文件 @
125129cf
...
@@ -13,7 +13,7 @@ void InferBasicRnnCellBlobDesc(
...
@@ -13,7 +13,7 @@ void InferBasicRnnCellBlobDesc(
int64_t
piece_size
=
in_blob_desc
->
shape
().
At
(
0
);
int64_t
piece_size
=
in_blob_desc
->
shape
().
At
(
0
);
BlobDesc
data_tmp_blob_desc
=
BlobDesc
data_tmp_blob_desc
=
BlobDesc
(
Shape
({
embedding_size
,
hidden_size
}),
BlobDesc
(
Shape
({
embedding_size
,
hidden_size
}),
JobDesc
::
Singleton
()
->
DefaultDataType
(),
false
,
false
,
JobDesc
::
Singleton
()
->
DefaultDataType
(),
false
,
false
,
false
,
in_blob_desc
->
max_col_num
());
in_blob_desc
->
max_col_num
());
*
GetBlobDesc4BnInOp
(
"in_ip_op_out"
)
=
data_tmp_blob_desc
;
*
GetBlobDesc4BnInOp
(
"in_ip_op_out"
)
=
data_tmp_blob_desc
;
*
GetBlobDesc4BnInOp
(
"hidden_ip_op_out"
)
=
data_tmp_blob_desc
;
*
GetBlobDesc4BnInOp
(
"hidden_ip_op_out"
)
=
data_tmp_blob_desc
;
...
...
oneflow/core/register/blob.cpp
浏览文件 @
125129cf
...
@@ -7,8 +7,12 @@ namespace oneflow {
...
@@ -7,8 +7,12 @@ namespace oneflow {
Blob
::
Blob
(
const
BlobDesc
*
blob_desc
,
char
*
mem_ptr
,
Blob
::
Blob
(
const
BlobDesc
*
blob_desc
,
char
*
mem_ptr
,
const
void
*
comm_net_token
)
{
const
void
*
comm_net_token
)
{
mem_ptr_
=
mem_ptr
;
if
(
blob_desc
->
has_header_field
())
{
blob_header_
=
reinterpret_cast
<
BlobHeader
*>
(
mem_ptr
);
blob_header_
=
reinterpret_cast
<
BlobHeader
*>
(
mem_ptr
);
}
else
{
blob_header_
=
nullptr
;
}
if
(
blob_desc
->
has_data_id_field
())
{
if
(
blob_desc
->
has_data_id_field
())
{
data_id_ptr_
=
mem_ptr
+
blob_desc
->
ByteSizeOfBlobHeaderField
();
data_id_ptr_
=
mem_ptr
+
blob_desc
->
ByteSizeOfBlobHeaderField
();
}
else
{
}
else
{
...
@@ -43,10 +47,6 @@ void Blob::set_col_num(int32_t no, int32_t val) {
...
@@ -43,10 +47,6 @@ void Blob::set_col_num(int32_t no, int32_t val) {
*
(
col_num_ptr_
+
no
)
=
val
;
*
(
col_num_ptr_
+
no
)
=
val
;
}
}
const
void
*
Blob
::
memory_ptr
()
const
{
return
reinterpret_cast
<
void
*>
(
blob_header_
);
}
size_t
Blob
::
ByteSizeOfBlobHeaderField
()
const
{
size_t
Blob
::
ByteSizeOfBlobHeaderField
()
const
{
return
blob_desc_
->
ByteSizeOfBlobHeaderField
();
return
blob_desc_
->
ByteSizeOfBlobHeaderField
();
}
}
...
...
oneflow/core/register/blob.h
浏览文件 @
125129cf
...
@@ -24,8 +24,8 @@ class Blob final {
...
@@ -24,8 +24,8 @@ class Blob final {
int32_t
col_num
(
int32_t
no
)
const
;
int32_t
col_num
(
int32_t
no
)
const
;
void
set_col_num
(
int32_t
no
,
int32_t
val
);
void
set_col_num
(
int32_t
no
,
int32_t
val
);
const
void
*
memory_ptr
()
const
;
const
void
*
memory_ptr
()
const
{
return
mem_ptr_
;
}
void
*
mut_memory_ptr
()
{
return
const_cast
<
void
*>
(
memory_ptr
())
;
}
void
*
mut_memory_ptr
()
{
return
mem_ptr_
;
}
template
<
typename
T
=
void
>
template
<
typename
T
=
void
>
const
T
*
dptr
()
const
{
const
T
*
dptr
()
const
{
...
@@ -82,6 +82,7 @@ class Blob final {
...
@@ -82,6 +82,7 @@ class Blob final {
<<
blob_desc_
->
data_type
()
<<
" "
<<
GetDataType
<
T
>::
val
;
<<
blob_desc_
->
data_type
()
<<
" "
<<
GetDataType
<
T
>::
val
;
}
}
void
*
mem_ptr_
;
BlobHeader
*
blob_header_
;
BlobHeader
*
blob_header_
;
char
*
data_id_ptr_
;
char
*
data_id_ptr_
;
int32_t
*
col_num_ptr_
;
int32_t
*
col_num_ptr_
;
...
...
oneflow/core/register/blob_desc.cpp
浏览文件 @
125129cf
...
@@ -5,12 +5,14 @@ namespace oneflow {
...
@@ -5,12 +5,14 @@ namespace oneflow {
BlobDesc
::
BlobDesc
()
BlobDesc
::
BlobDesc
()
:
BlobDesc
(
Shape
(),
JobDesc
::
Singleton
()
->
DefaultDataType
(),
false
,
false
,
:
BlobDesc
(
Shape
(),
JobDesc
::
Singleton
()
->
DefaultDataType
(),
false
,
false
,
1
)
{}
false
,
1
)
{}
BlobDesc
::
BlobDesc
(
Shape
shape
,
DataType
data_type
,
bool
has_data_id_field
,
BlobDesc
::
BlobDesc
(
Shape
shape
,
DataType
data_type
,
bool
has_header_field
,
bool
has_col_num_field
,
int32_t
max_col_num
)
bool
has_data_id_field
,
bool
has_col_num_field
,
int32_t
max_col_num
)
:
shape_
(
shape
),
:
shape_
(
shape
),
data_type_
(
data_type
),
data_type_
(
data_type
),
has_header_field_
(
has_header_field
),
has_data_id_field_
(
has_data_id_field
),
has_data_id_field_
(
has_data_id_field
),
has_col_num_field_
(
has_col_num_field
),
has_col_num_field_
(
has_col_num_field
),
max_col_num_
(
max_col_num
)
{}
max_col_num_
(
max_col_num
)
{}
...
@@ -18,6 +20,7 @@ BlobDesc::BlobDesc(Shape shape, DataType data_type, bool has_data_id_field,
...
@@ -18,6 +20,7 @@ BlobDesc::BlobDesc(Shape shape, DataType data_type, bool has_data_id_field,
BlobDesc
::
BlobDesc
(
const
BlobDescProto
&
proto
)
{
BlobDesc
::
BlobDesc
(
const
BlobDescProto
&
proto
)
{
shape_
=
Shape
(
proto
.
shape
());
shape_
=
Shape
(
proto
.
shape
());
data_type_
=
proto
.
data_type
();
data_type_
=
proto
.
data_type
();
has_header_field_
=
proto
.
has_header_field
();
has_data_id_field_
=
proto
.
has_data_id_field
();
has_data_id_field_
=
proto
.
has_data_id_field
();
has_col_num_field_
=
proto
.
has_col_num_field
();
has_col_num_field_
=
proto
.
has_col_num_field
();
max_col_num_
=
proto
.
max_col_num
();
max_col_num_
=
proto
.
max_col_num
();
...
@@ -26,11 +29,20 @@ BlobDesc::BlobDesc(const BlobDescProto& proto) {
...
@@ -26,11 +29,20 @@ BlobDesc::BlobDesc(const BlobDescProto& proto) {
void
BlobDesc
::
ToProto
(
BlobDescProto
*
proto
)
const
{
void
BlobDesc
::
ToProto
(
BlobDescProto
*
proto
)
const
{
shape_
.
ToProto
(
proto
->
mutable_shape
());
shape_
.
ToProto
(
proto
->
mutable_shape
());
proto
->
set_data_type
(
data_type_
);
proto
->
set_data_type
(
data_type_
);
proto
->
set_has_header_field
(
has_header_field_
);
proto
->
set_has_data_id_field
(
has_data_id_field_
);
proto
->
set_has_data_id_field
(
has_data_id_field_
);
proto
->
set_has_col_num_field
(
has_col_num_field_
);
proto
->
set_has_col_num_field
(
has_col_num_field_
);
proto
->
set_max_col_num
(
max_col_num_
);
proto
->
set_max_col_num
(
max_col_num_
);
}
}
size_t
BlobDesc
::
ByteSizeOfBlobHeaderField
()
const
{
if
(
has_header_field_
)
{
return
sizeof
(
BlobHeader
);
}
else
{
return
0
;
}
}
size_t
BlobDesc
::
ByteSizeOfDataIdField
()
const
{
size_t
BlobDesc
::
ByteSizeOfDataIdField
()
const
{
if
(
has_data_id_field_
)
{
if
(
has_data_id_field_
)
{
return
shape_
.
At
(
0
)
*
JobDesc
::
Singleton
()
->
SizeOfOneDataId
();
return
shape_
.
At
(
0
)
*
JobDesc
::
Singleton
()
->
SizeOfOneDataId
();
...
@@ -58,6 +70,7 @@ size_t BlobDesc::TotalByteSize() const {
...
@@ -58,6 +70,7 @@ size_t BlobDesc::TotalByteSize() const {
bool
BlobDesc
::
operator
==
(
const
BlobDesc
&
rhs
)
const
{
bool
BlobDesc
::
operator
==
(
const
BlobDesc
&
rhs
)
const
{
return
shape_
==
rhs
.
shape_
&&
data_type_
==
rhs
.
data_type_
return
shape_
==
rhs
.
shape_
&&
data_type_
==
rhs
.
data_type_
&&
has_header_field_
==
rhs
.
has_header_field_
&&
has_data_id_field_
==
rhs
.
has_data_id_field_
&&
has_data_id_field_
==
rhs
.
has_data_id_field_
&&
has_col_num_field_
==
rhs
.
has_col_num_field_
&&
has_col_num_field_
==
rhs
.
has_col_num_field_
&&
max_col_num_
==
rhs
.
max_col_num_
;
&&
max_col_num_
==
rhs
.
max_col_num_
;
...
@@ -67,6 +80,7 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
...
@@ -67,6 +80,7 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
int64_t
total_byte_size
=
0
;
int64_t
total_byte_size
=
0
;
int64_t
total_data_content_byte_size
=
0
;
int64_t
total_data_content_byte_size
=
0
;
HashSet
<
int
>
data_type_set
;
HashSet
<
int
>
data_type_set
;
bool
has_header_field
=
false
;
bool
has_data_id_field
=
false
;
bool
has_data_id_field
=
false
;
bool
has_col_num_field
=
false
;
bool
has_col_num_field
=
false
;
int32_t
max_col_num
=
-
1
;
int32_t
max_col_num
=
-
1
;
...
@@ -78,6 +92,7 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
...
@@ -78,6 +92,7 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
data_type_set
.
insert
(
static_cast
<
int
>
(
blob_desc
->
data_type
()));
data_type_set
.
insert
(
static_cast
<
int
>
(
blob_desc
->
data_type
()));
has_data_id_field
=
has_data_id_field
||
blob_desc
->
has_data_id_field
();
has_data_id_field
=
has_data_id_field
||
blob_desc
->
has_data_id_field
();
has_col_num_field
=
has_col_num_field
||
blob_desc
->
has_col_num_field
();
has_col_num_field
=
has_col_num_field
||
blob_desc
->
has_col_num_field
();
has_header_field
=
has_header_field
||
blob_desc
->
has_header_field
();
if
(
max_col_num
==
-
1
)
{
if
(
max_col_num
==
-
1
)
{
max_col_num
=
blob_desc
->
max_col_num
();
max_col_num
=
blob_desc
->
max_col_num
();
}
else
{
}
else
{
...
@@ -90,7 +105,8 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
...
@@ -90,7 +105,8 @@ BlobDesc ComputePackedBlobDesc(std::function<const BlobDesc*()> NextBlobDesc) {
if
(
blob_desc_cnt
<=
1
)
{
return
ret
;
}
if
(
blob_desc_cnt
<=
1
)
{
return
ret
;
}
CHECK_EQ
(
has_col_num_field
,
false
);
CHECK_EQ
(
has_col_num_field
,
false
);
CHECK_EQ
(
max_col_num
,
1
);
CHECK_EQ
(
max_col_num
,
1
);
if
(
has_data_id_field
==
false
&&
data_type_set
.
size
()
==
1
)
{
if
(
has_header_field
==
false
&&
has_data_id_field
==
false
&&
data_type_set
.
size
()
==
1
)
{
DataType
sole_data_type
=
static_cast
<
DataType
>
(
*
(
data_type_set
.
begin
()));
DataType
sole_data_type
=
static_cast
<
DataType
>
(
*
(
data_type_set
.
begin
()));
int64_t
size_of_one_elem
=
GetSizeOfDataType
(
sole_data_type
);
int64_t
size_of_one_elem
=
GetSizeOfDataType
(
sole_data_type
);
CHECK_EQ
(
total_data_content_byte_size
%
size_of_one_elem
,
0
);
CHECK_EQ
(
total_data_content_byte_size
%
size_of_one_elem
,
0
);
...
...
oneflow/core/register/blob_desc.h
浏览文件 @
125129cf
...
@@ -19,7 +19,7 @@ class BlobDesc final {
...
@@ -19,7 +19,7 @@ class BlobDesc final {
~
BlobDesc
()
=
default
;
~
BlobDesc
()
=
default
;
BlobDesc
();
BlobDesc
();
BlobDesc
(
Shape
shape
,
DataType
data_type
,
bool
has_data_id_field
,
BlobDesc
(
Shape
,
DataType
,
bool
has_header_field
,
bool
has_data_id_field
,
bool
has_col_num_field
,
int32_t
max_col_num
);
bool
has_col_num_field
,
int32_t
max_col_num
);
BlobDesc
(
Shape
shape
)
:
BlobDesc
()
{
shape_
=
shape
;
}
BlobDesc
(
Shape
shape
)
:
BlobDesc
()
{
shape_
=
shape
;
}
BlobDesc
(
const
BlobDescProto
&
proto
);
BlobDesc
(
const
BlobDescProto
&
proto
);
...
@@ -30,6 +30,9 @@ class BlobDesc final {
...
@@ -30,6 +30,9 @@ class BlobDesc final {
DataType
data_type
()
const
{
return
data_type_
;
}
DataType
data_type
()
const
{
return
data_type_
;
}
void
set_data_type
(
DataType
val
)
{
data_type_
=
val
;
}
void
set_data_type
(
DataType
val
)
{
data_type_
=
val
;
}
bool
has_header_field
()
const
{
return
has_header_field_
;
}
void
set_has_header_field
(
bool
val
)
{
has_header_field_
=
val
;
}
bool
has_data_id_field
()
const
{
return
has_data_id_field_
;
}
bool
has_data_id_field
()
const
{
return
has_data_id_field_
;
}
void
set_has_data_id_field
(
bool
val
)
{
has_data_id_field_
=
val
;
}
void
set_has_data_id_field
(
bool
val
)
{
has_data_id_field_
=
val
;
}
...
@@ -40,7 +43,7 @@ class BlobDesc final {
...
@@ -40,7 +43,7 @@ class BlobDesc final {
void
set_max_col_num
(
int32_t
val
)
{
max_col_num_
=
val
;
}
void
set_max_col_num
(
int32_t
val
)
{
max_col_num_
=
val
;
}
void
ToProto
(
BlobDescProto
*
proto
)
const
;
void
ToProto
(
BlobDescProto
*
proto
)
const
;
size_t
ByteSizeOfBlobHeaderField
()
const
{
return
sizeof
(
BlobHeader
);
}
size_t
ByteSizeOfBlobHeaderField
()
const
;
size_t
ByteSizeOfDataIdField
()
const
;
size_t
ByteSizeOfDataIdField
()
const
;
size_t
ByteSizeOfColNumField
()
const
;
size_t
ByteSizeOfColNumField
()
const
;
size_t
ByteSizeOfDataContentField
()
const
;
size_t
ByteSizeOfDataContentField
()
const
;
...
@@ -50,6 +53,7 @@ class BlobDesc final {
...
@@ -50,6 +53,7 @@ class BlobDesc final {
private:
private:
Shape
shape_
;
Shape
shape_
;
DataType
data_type_
;
DataType
data_type_
;
bool
has_header_field_
;
bool
has_data_id_field_
;
bool
has_data_id_field_
;
bool
has_col_num_field_
;
bool
has_col_num_field_
;
int64_t
max_col_num_
;
int64_t
max_col_num_
;
...
...
oneflow/core/register/blob_desc.proto
浏览文件 @
125129cf
...
@@ -7,7 +7,8 @@ import "oneflow/core/common/data_type.proto";
...
@@ -7,7 +7,8 @@ import "oneflow/core/common/data_type.proto";
message
BlobDescProto
{
message
BlobDescProto
{
required
ShapeProto
shape
=
1
;
required
ShapeProto
shape
=
1
;
required
DataType
data_type
=
2
;
required
DataType
data_type
=
2
;
required
bool
has_data_id_field
=
3
;
required
bool
has_header_field
=
3
;
required
bool
has_col_num_field
=
4
;
required
bool
has_data_id_field
=
4
;
required
int32
max_col_num
=
5
;
required
bool
has_col_num_field
=
5
;
required
int32
max_col_num
=
6
;
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录