Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
08793179
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
08793179
编写于
1月 21, 2022
作者:
Y
Yuang Liu
提交者:
GitHub
1月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[fleet executor] add a tensor wrapper to support python numpy input (#39076)
上级
3dd7f353
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
401 addition
and
6 deletion
+401
-6
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
+2
-2
paddle/fluid/distributed/fleet_executor/dist_model.cc
paddle/fluid/distributed/fleet_executor/dist_model.cc
+2
-2
paddle/fluid/distributed/fleet_executor/dist_model.h
paddle/fluid/distributed/fleet_executor/dist_model.h
+3
-2
paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.cc
...d/distributed/fleet_executor/dist_model_tensor_wrapper.cc
+100
-0
paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h
...id/distributed/fleet_executor/dist_model_tensor_wrapper.h
+77
-0
paddle/fluid/pybind/bind_fleet_executor.cc
paddle/fluid/pybind/bind_fleet_executor.cc
+153
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/test_dist_model_tensor.py
...on/paddle/fluid/tests/unittests/test_dist_model_tensor.py
+63
-0
未找到文件。
paddle/fluid/distributed/fleet_executor/CMakeLists.txt
浏览文件 @
08793179
...
...
@@ -12,8 +12,8 @@ endif()
cc_library
(
task_loop_thread_pool SRCS task_loop_thread_pool.cc task_loop_thread.cc task_loop.cc DEPS enforce glog
)
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc dist_model.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus
.cc
cc_library
(
fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc dist_model.cc
interceptor.cc
compute_interceptor.cc amplifier_interceptor.cc message_service.cc message_bus.cc dist_model_tensor_wrapper
.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto task_loop_thread_pool collective_helper
op_registry executor_gc_helper gflags glog
${
BRPC_DEPS
}
)
...
...
paddle/fluid/distributed/fleet_executor/dist_model.cc
浏览文件 @
08793179
...
...
@@ -355,8 +355,8 @@ bool DistModel::PrepareFeedAndFetch() {
return
true
;
}
void
DistModel
::
Run
(
const
std
::
vector
<
paddle
::
framework
::
Tensor
>
&
input_data
,
std
::
vector
<
paddle
::
framework
::
Tensor
>
*
output_data
)
{
void
DistModel
::
Run
(
const
std
::
vector
<
DistModel
Tensor
>
&
input_data
,
std
::
vector
<
DistModel
Tensor
>
*
output_data
)
{
/* TODO(fleet exe dev): implement this funct */
}
...
...
paddle/fluid/distributed/fleet_executor/dist_model.h
浏览文件 @
08793179
...
...
@@ -17,6 +17,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/macros.h"
...
...
@@ -56,8 +57,8 @@ class DistModel {
public:
explicit
DistModel
(
const
DistModelConfig
&
config
)
:
config_
(
config
)
{}
bool
Init
();
void
Run
(
const
std
::
vector
<
paddle
::
framework
::
Tensor
>&
input_data
,
std
::
vector
<
paddle
::
framework
::
Tensor
>*
output_data
);
void
Run
(
const
std
::
vector
<
DistModel
Tensor
>&
input_data
,
std
::
vector
<
DistModel
Tensor
>*
output_data
);
~
DistModel
()
=
default
;
private:
...
...
paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.cc
0 → 100644
浏览文件 @
08793179
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
void
DistModelDataBuf
::
Reset
(
void
*
data
,
size_t
length
)
{
Free
();
memory_owned_
=
false
;
data_
=
data
;
length_
=
length
;
}
void
DistModelDataBuf
::
Free
()
{
if
(
memory_owned_
&&
data_
)
{
PADDLE_ENFORCE_GT
(
length_
,
0UL
,
platform
::
errors
::
PreconditionNotMet
(
"Error occurred when deconstruct DistModelDataBuf: "
"it contains no data!"
));
// NOTE: if own the memory, it must be char* type
delete
[]
static_cast
<
char
*>
(
data_
);
data_
=
nullptr
;
length_
=
0
;
}
}
void
DistModelDataBuf
::
Resize
(
size_t
length
)
{
if
(
length_
>=
length
)
{
return
;
}
if
(
memory_owned_
)
{
Free
();
data_
=
new
char
[
length
];
length_
=
length
;
memory_owned_
=
true
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"The memory is allocated externally, can not Resized"
));
}
}
DistModelDataBuf
&
DistModelDataBuf
::
operator
=
(
const
DistModelDataBuf
&
other
)
{
if
(
!
other
.
memory_owned_
)
{
data_
=
other
.
data_
;
length_
=
other
.
length_
;
memory_owned_
=
other
.
memory_owned_
;
}
else
{
Resize
(
other
.
length_
);
if
(
other
.
length
()
&&
other
.
data
())
{
std
::
memcpy
(
data_
,
other
.
data
(),
other
.
length
());
}
else
if
(
other
.
length
())
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Invalid argument, null pointer data with length %u is passed"
,
other
.
length
()));
}
length_
=
other
.
length_
;
memory_owned_
=
true
;
}
return
*
this
;
}
DistModelDataBuf
&
DistModelDataBuf
::
operator
=
(
DistModelDataBuf
&&
other
)
{
data_
=
other
.
data_
;
memory_owned_
=
other
.
memory_owned_
;
length_
=
other
.
length_
;
other
.
data_
=
nullptr
;
other
.
length_
=
0
;
other
.
memory_owned_
=
false
;
return
*
this
;
}
DistModelDataBuf
::
DistModelDataBuf
(
DistModelDataBuf
&&
other
)
:
data_
(
other
.
data_
),
length_
(
other
.
length_
),
memory_owned_
(
other
.
memory_owned_
)
{
other
.
memory_owned_
=
false
;
other
.
data_
=
nullptr
;
other
.
length_
=
0
;
}
DistModelDataBuf
::
DistModelDataBuf
(
const
DistModelDataBuf
&
other
)
{
*
this
=
other
;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h
0 → 100644
浏览文件 @
08793179
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/platform/macros.h"
namespace
paddle
{
namespace
distributed
{
enum
DistModelDataType
{
FLOAT16
,
FLOAT32
,
INT64
,
INT32
,
INT8
};
template
<
typename
T
>
constexpr
DistModelDataType
DistModelGetDtype
();
template
<
>
constexpr
DistModelDataType
DistModelGetDtype
<
int32_t
>
()
{
return
DistModelDataType
::
INT32
;
}
template
<
>
constexpr
DistModelDataType
DistModelGetDtype
<
int64_t
>
()
{
return
DistModelDataType
::
INT64
;
}
template
<
>
constexpr
DistModelDataType
DistModelGetDtype
<
float
>
()
{
return
DistModelDataType
::
FLOAT32
;
}
class
DistModelDataBuf
{
public:
explicit
DistModelDataBuf
(
size_t
length
)
:
data_
(
new
char
[
length
]),
length_
(
length
),
memory_owned_
(
true
)
{}
DistModelDataBuf
(
void
*
data
,
size_t
length
)
:
data_
(
data
),
length_
(
length
),
memory_owned_
(
false
)
{}
void
Reset
(
void
*
data
,
size_t
length
);
size_t
length
()
const
{
return
length_
;
}
void
*
data
()
const
{
return
data_
;
}
~
DistModelDataBuf
()
{
Free
();
}
DistModelDataBuf
()
=
default
;
void
Resize
(
size_t
length
);
DistModelDataBuf
&
operator
=
(
const
DistModelDataBuf
&
other
);
DistModelDataBuf
&
operator
=
(
DistModelDataBuf
&&
other
);
DistModelDataBuf
(
DistModelDataBuf
&&
other
);
DistModelDataBuf
(
const
DistModelDataBuf
&
other
);
private:
void
Free
();
void
*
data_
{
nullptr
};
size_t
length_
{
0
};
bool
memory_owned_
{
false
};
};
struct
DistModelTensor
{
std
::
string
name
;
std
::
vector
<
int
>
shape
;
DistModelDataBuf
data
;
DistModelDataType
dtype
;
std
::
vector
<
std
::
vector
<
size_t
>>
lod
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/pybind/bind_fleet_executor.cc
浏览文件 @
08793179
...
...
@@ -13,8 +13,12 @@
// limitations under the License.
#include "paddle/fluid/pybind/bind_fleet_executor.h"
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/dist_model.h"
#include "paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
...
...
@@ -31,9 +35,90 @@ using paddle::distributed::FleetExecutor;
using
paddle
::
distributed
::
TaskNode
;
using
paddle
::
distributed
::
DistModelConfig
;
using
paddle
::
distributed
::
DistModel
;
using
paddle
::
distributed
::
DistModelDataBuf
;
using
paddle
::
distributed
::
DistModelTensor
;
using
paddle
::
distributed
::
DistModelDataType
;
using
paddle
::
framework
::
OpDesc
;
using
paddle
::
framework
::
ProgramDesc
;
template
<
typename
T
>
DistModelDataBuf
DistModelDataBufCreate
(
py
::
array_t
<
T
,
py
::
array
::
c_style
|
py
::
array
::
forcecast
>
data
)
{
// accept numpy array directly
DistModelDataBuf
buf
(
data
.
size
()
*
sizeof
(
T
));
std
::
copy_n
(
static_cast
<
const
T
*>
(
data
.
data
()),
data
.
size
(),
static_cast
<
T
*>
(
buf
.
data
()));
return
buf
;
}
template
<
typename
T
>
void
DistModelDataBufReset
(
DistModelDataBuf
&
buf
,
// NOLINT
py
::
array_t
<
T
,
py
::
array
::
c_style
|
py
::
array
::
forcecast
>
data
)
{
// NOLINT
// reset the data with numpy array directly
buf
.
Resize
(
data
.
size
()
*
sizeof
(
T
));
std
::
copy_n
(
static_cast
<
const
T
*>
(
data
.
data
()),
data
.
size
(),
static_cast
<
T
*>
(
buf
.
data
()));
}
template
<
typename
T
>
DistModelTensor
DistModelTensorCreate
(
py
::
array_t
<
T
,
py
::
array
::
c_style
|
py
::
array
::
forcecast
>
data
,
const
std
::
string
name
,
const
std
::
vector
<
std
::
vector
<
size_t
>>&
lod
,
bool
copy
)
{
DistModelTensor
tensor
;
if
(
copy
)
{
DistModelDataBuf
buf
(
data
.
size
()
*
sizeof
(
T
));
std
::
copy_n
(
static_cast
<
const
T
*>
(
data
.
data
()),
data
.
size
(),
static_cast
<
T
*>
(
buf
.
data
()));
tensor
.
data
=
std
::
move
(
buf
);
}
else
{
tensor
.
data
=
DistModelDataBuf
(
data
.
mutable_data
(),
data
.
size
()
*
sizeof
(
T
));
}
tensor
.
dtype
=
paddle
::
distributed
::
DistModelGetDtype
<
T
>
();
tensor
.
name
=
name
;
tensor
.
lod
=
lod
;
tensor
.
shape
.
resize
(
data
.
ndim
());
std
::
copy_n
(
data
.
shape
(),
data
.
ndim
(),
tensor
.
shape
.
begin
());
return
tensor
;
}
py
::
dtype
DistModelTypeToNumpyDType
(
DistModelDataType
dtype
)
{
py
::
dtype
dt
;
switch
(
dtype
)
{
case
DistModelDataType
::
INT32
:
dt
=
py
::
dtype
::
of
<
int32_t
>
();
break
;
case
DistModelDataType
::
INT64
:
dt
=
py
::
dtype
::
of
<
int64_t
>
();
break
;
case
DistModelDataType
::
FLOAT32
:
dt
=
py
::
dtype
::
of
<
float
>
();
break
;
case
DistModelDataType
::
INT8
:
dt
=
py
::
dtype
::
of
<
int8_t
>
();
break
;
case
DistModelDataType
::
FLOAT16
:
dt
=
py
::
dtype
::
of
<
paddle
::
platform
::
float16
>
();
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported data type. Now only supports INT32, INT64, INT8, "
"FLOAT16 and FLOAT32."
));
}
return
dt
;
}
py
::
array
DistModelTensorGetData
(
DistModelTensor
&
tensor
)
{
// NOLINT
py
::
dtype
dt
=
DistModelTypeToNumpyDType
(
tensor
.
dtype
);
return
py
::
array
(
std
::
move
(
dt
),
{
tensor
.
shape
},
tensor
.
data
.
data
());
}
void
BindFleetExecutor
(
py
::
module
*
m
)
{
py
::
class_
<
FleetExecutor
>
(
*
m
,
"FleetExecutor"
)
.
def
(
py
::
init
<
const
std
::
string
&>
())
...
...
@@ -78,6 +163,74 @@ void BindFleetExecutor(py::module* m) {
.
def
(
py
::
init
<
const
DistModelConfig
&>
())
.
def
(
"init"
,
&
DistModel
::
Init
)
.
def
(
"run"
,
&
DistModel
::
Run
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
class_
<
DistModelDataBuf
>
(
*
m
,
"DistModelDataBuf"
)
.
def
(
py
::
init
<
size_t
>
())
.
def
(
py
::
init
([](
std
::
vector
<
float
>&
data
)
{
auto
buf
=
DistModelDataBuf
(
data
.
size
()
*
sizeof
(
float
));
std
::
memcpy
(
buf
.
data
(),
static_cast
<
void
*>
(
data
.
data
()),
buf
.
length
());
return
buf
;
}))
.
def
(
py
::
init
(
&
DistModelDataBufCreate
<
int32_t
>
))
.
def
(
py
::
init
(
&
DistModelDataBufCreate
<
int64_t
>
))
.
def
(
py
::
init
(
&
DistModelDataBufCreate
<
float
>
))
.
def
(
"reset"
,
[](
DistModelDataBuf
&
self
,
std
::
vector
<
float
>&
data
)
{
self
.
Resize
(
data
.
size
()
*
sizeof
(
float
));
std
::
memcpy
(
self
.
data
(),
data
.
data
(),
self
.
length
());
})
.
def
(
"reset"
,
&
DistModelDataBufReset
<
int32_t
>
)
.
def
(
"reset"
,
&
DistModelDataBufReset
<
int64_t
>
)
.
def
(
"reset"
,
&
DistModelDataBufReset
<
float
>
)
.
def
(
"length"
,
&
DistModelDataBuf
::
length
)
.
def
(
"tolist"
,
[](
DistModelDataBuf
&
self
,
const
std
::
string
&
dtype
)
->
py
::
list
{
py
::
list
l
;
if
(
dtype
==
"int32"
)
{
auto
*
data
=
static_cast
<
int32_t
*>
(
self
.
data
());
auto
size
=
self
.
length
()
/
sizeof
(
int32_t
);
l
=
py
::
cast
(
std
::
vector
<
int32_t
>
(
data
,
data
+
size
));
}
else
if
(
dtype
==
"int64"
)
{
auto
*
data
=
static_cast
<
int64_t
*>
(
self
.
data
());
auto
size
=
self
.
length
()
/
sizeof
(
int64_t
);
l
=
py
::
cast
(
std
::
vector
<
int64_t
>
(
data
,
data
+
size
));
}
else
if
(
dtype
==
"float32"
)
{
auto
*
data
=
static_cast
<
float
*>
(
self
.
data
());
auto
size
=
self
.
length
()
/
sizeof
(
float
);
l
=
py
::
cast
(
std
::
vector
<
float
>
(
data
,
data
+
size
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported data type. Now only supports INT32, INT64 and "
"FLOAT32."
));
}
return
l
;
});
py
::
class_
<
DistModelTensor
>
(
*
m
,
"DistModelTensor"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
(
&
DistModelTensorCreate
<
int32_t
>
),
py
::
arg
(
"data"
),
py
::
arg
(
"name"
)
=
""
,
py
::
arg
(
"lod"
)
=
std
::
vector
<
std
::
vector
<
size_t
>>
(),
py
::
arg
(
"copy"
)
=
true
)
.
def
(
py
::
init
(
&
DistModelTensorCreate
<
int64_t
>
),
py
::
arg
(
"data"
),
py
::
arg
(
"name"
)
=
""
,
py
::
arg
(
"lod"
)
=
std
::
vector
<
std
::
vector
<
size_t
>>
(),
py
::
arg
(
"copy"
)
=
true
)
.
def
(
py
::
init
(
&
DistModelTensorCreate
<
float
>
),
py
::
arg
(
"data"
),
py
::
arg
(
"name"
)
=
""
,
py
::
arg
(
"lod"
)
=
std
::
vector
<
std
::
vector
<
size_t
>>
(),
py
::
arg
(
"copy"
)
=
true
)
.
def_readwrite
(
"name"
,
&
DistModelTensor
::
name
)
.
def_readwrite
(
"shape"
,
&
DistModelTensor
::
shape
)
.
def_readwrite
(
"data"
,
&
DistModelTensor
::
data
)
.
def_readwrite
(
"dtype"
,
&
DistModelTensor
::
dtype
)
.
def_readwrite
(
"lod"
,
&
DistModelTensor
::
lod
)
.
def
(
"as_ndarray"
,
&
DistModelTensorGetData
);
py
::
enum_
<
DistModelDataType
>
(
*
m
,
"DistModelDataType"
)
.
value
(
"FLOAT32"
,
DistModelDataType
::
FLOAT32
)
.
value
(
"INT64"
,
DistModelDataType
::
INT64
)
.
value
(
"INT32"
,
DistModelDataType
::
INT32
);
}
}
// namespace pybind
}
// namespace paddle
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
08793179
...
...
@@ -152,6 +152,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST
(
REMOVE_ITEM TEST_OPS test_fleet_executor_origin_scheduler
)
LIST
(
REMOVE_ITEM TEST_OPS test_auto_parallel_mapper
)
LIST
(
REMOVE_ITEM TEST_OPS test_fleet_executor_task_node
)
LIST
(
REMOVE_ITEM TEST_OPS test_dist_model_tensor
)
endif
()
# Temporally disable test_deprecated_decorator
...
...
python/paddle/fluid/tests/unittests/test_dist_model_tensor.py
0 → 100644
浏览文件 @
08793179
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
import
numpy
as
np
from
paddle.fluid.core
import
DistModelTensor
from
paddle.fluid.core
import
DistModelDataType
paddle
.
enable_static
()
class
TestDistModelTensor
(
unittest
.
TestCase
):
def
test_dist_model_tensor
(
self
):
tensor_32
=
np
.
random
.
randint
(
10
,
20
,
size
=
[
20
,
2
]).
astype
(
'int32'
)
dist_tensor32
=
DistModelTensor
(
tensor_32
,
'32_tensor'
)
self
.
assertEqual
(
dist_tensor32
.
dtype
,
DistModelDataType
.
INT32
)
self
.
assertEqual
(
dist_tensor32
.
data
.
tolist
(
'int32'
),
tensor_32
.
ravel
().
tolist
())
# the length is how many byte the data contains
self
.
assertEqual
(
dist_tensor32
.
data
.
length
(),
40
*
4
)
self
.
assertEqual
(
dist_tensor32
.
name
,
'32_tensor'
)
dist_tensor32
.
data
.
reset
(
tensor_32
)
self
.
assertEqual
(
dist_tensor32
.
as_ndarray
().
ravel
().
tolist
(),
tensor_32
.
ravel
().
tolist
())
tensor_64
=
np
.
random
.
randint
(
10
,
20
,
size
=
[
20
,
2
]).
astype
(
'int64'
)
dist_tensor64
=
DistModelTensor
(
tensor_64
,
'64_tensor'
)
self
.
assertEqual
(
dist_tensor64
.
dtype
,
DistModelDataType
.
INT64
)
self
.
assertEqual
(
dist_tensor64
.
data
.
tolist
(
'int64'
),
tensor_64
.
ravel
().
tolist
())
self
.
assertEqual
(
dist_tensor64
.
data
.
length
(),
40
*
8
)
self
.
assertEqual
(
dist_tensor64
.
name
,
'64_tensor'
)
dist_tensor64
.
data
.
reset
(
tensor_64
)
self
.
assertEqual
(
dist_tensor64
.
as_ndarray
().
ravel
().
tolist
(),
tensor_64
.
ravel
().
tolist
())
tensor_float
=
np
.
random
.
randn
(
20
,
2
).
astype
(
'float32'
)
dist_tensor_float
=
DistModelTensor
(
tensor_float
,
'float_tensor'
)
self
.
assertEqual
(
dist_tensor_float
.
dtype
,
DistModelDataType
.
FLOAT32
)
self
.
assertEqual
(
dist_tensor_float
.
data
.
tolist
(
'float32'
),
tensor_float
.
ravel
().
tolist
())
self
.
assertEqual
(
dist_tensor_float
.
data
.
length
(),
40
*
4
)
self
.
assertEqual
(
dist_tensor_float
.
name
,
'float_tensor'
)
dist_tensor_float
.
data
.
reset
(
tensor_float
)
self
.
assertEqual
(
dist_tensor_float
.
as_ndarray
().
ravel
().
tolist
(),
tensor_float
.
ravel
().
tolist
())
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录