Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
980227f9
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
980227f9
编写于
4月 01, 2021
作者:
Z
Zhang Zheng
提交者:
GitHub
4月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support uint8_t for fill_constant_op (#31911)
上级
07741593
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
13 addition
and
17 deletion
+13
-17
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_constant_op.cc
+1
-0
paddle/fluid/operators/fill_constant_op.cu.cc
paddle/fluid/operators/fill_constant_op.cu.cc
+1
-0
paddle/fluid/operators/math/math_function.cc
paddle/fluid/operators/math/math_function.cc
+1
-0
paddle/fluid/operators/math/math_function.cu
paddle/fluid/operators/math/math_function.cu
+1
-0
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+7
-6
python/paddle/fluid/tests/unittests/test_fill_constant_op.py
python/paddle/fluid/tests/unittests/test_fill_constant_op.py
+1
-7
python/paddle/fluid/tests/unittests/test_full_op.py
python/paddle/fluid/tests/unittests/test_full_op.py
+1
-4
未找到文件。
paddle/fluid/operators/fill_constant_op.cc
浏览文件 @
980227f9
...
@@ -149,6 +149,7 @@ REGISTER_OPERATOR(
...
@@ -149,6 +149,7 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL
(
fill_constant
,
ops
::
FillConstantKernel
<
float
>
,
REGISTER_OP_CPU_KERNEL
(
fill_constant
,
ops
::
FillConstantKernel
<
float
>
,
ops
::
FillConstantKernel
<
double
>
,
ops
::
FillConstantKernel
<
double
>
,
ops
::
FillConstantKernel
<
uint8_t
>
,
ops
::
FillConstantKernel
<
int64_t
>
,
ops
::
FillConstantKernel
<
int64_t
>
,
ops
::
FillConstantKernel
<
int
>
,
ops
::
FillConstantKernel
<
int
>
,
ops
::
FillConstantKernel
<
bool
>
,
ops
::
FillConstantKernel
<
bool
>
,
...
...
paddle/fluid/operators/fill_constant_op.cu.cc
浏览文件 @
980227f9
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
fill_constant
,
ops
::
FillConstantKernel
<
float
>
,
REGISTER_OP_CUDA_KERNEL
(
fill_constant
,
ops
::
FillConstantKernel
<
float
>
,
ops
::
FillConstantKernel
<
double
>
,
ops
::
FillConstantKernel
<
double
>
,
ops
::
FillConstantKernel
<
uint8_t
>
,
ops
::
FillConstantKernel
<
int64_t
>
,
ops
::
FillConstantKernel
<
int64_t
>
,
ops
::
FillConstantKernel
<
int
>
,
ops
::
FillConstantKernel
<
int
>
,
ops
::
FillConstantKernel
<
bool
>
,
ops
::
FillConstantKernel
<
bool
>
,
...
...
paddle/fluid/operators/math/math_function.cc
浏览文件 @
980227f9
...
@@ -51,6 +51,7 @@ template struct SetConstant<platform::CPUDeviceContext, platform::complex128>;
...
@@ -51,6 +51,7 @@ template struct SetConstant<platform::CPUDeviceContext, platform::complex128>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
platform
::
float16
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
platform
::
float16
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
float
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
float
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
double
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
double
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
uint8_t
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
int
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
int
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
int64_t
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
int64_t
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
bool
>;
template
struct
SetConstant
<
platform
::
XPUDeviceContext
,
bool
>;
...
...
paddle/fluid/operators/math/math_function.cu
浏览文件 @
980227f9
...
@@ -35,6 +35,7 @@ using complex128 = paddle::platform::complex128;
...
@@ -35,6 +35,7 @@ using complex128 = paddle::platform::complex128;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
platform
::
float16
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
platform
::
float16
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
uint8_t
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
int
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
int
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
int64_t
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
int64_t
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
bool
>;
template
struct
SetConstant
<
platform
::
CUDADeviceContext
,
bool
>;
...
...
python/paddle/fluid/layers/tensor.py
浏览文件 @
980227f9
...
@@ -635,7 +635,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
...
@@ -635,7 +635,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64.
If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64.
dtype(np.dtype|str): Data type of the output Tensor which can
dtype(np.dtype|str): Data type of the output Tensor which can
be float16, float32, float64, int32, int64.
be float16, float32, float64,
uint8,
int32, int64.
value(bool|float|int|Tensor): The constant value used to initialize
value(bool|float|int|Tensor): The constant value used to initialize
the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor.
the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor.
force_cpu(bool, optional): data should be on CPU if it's true, default value is False.
force_cpu(bool, optional): data should be on CPU if it's true, default value is False.
...
@@ -673,7 +673,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
...
@@ -673,7 +673,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
attrs
=
{
'force_cpu'
:
force_cpu
}
attrs
=
{
'force_cpu'
:
force_cpu
}
dtype
=
convert_dtype
(
dtype
)
dtype
=
convert_dtype
(
dtype
)
if
not
isinstance
(
value
,
Variable
):
if
not
isinstance
(
value
,
Variable
):
if
dtype
in
[
'int64'
,
'int32'
]:
if
dtype
in
[
'
uint8'
,
'
int64'
,
'int32'
]:
attrs
[
'str_value'
]
=
str
(
int
(
value
))
attrs
[
'str_value'
]
=
str
(
int
(
value
))
attrs
[
'value'
]
=
int
(
value
)
attrs
[
'value'
]
=
int
(
value
)
else
:
else
:
...
@@ -686,7 +686,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
...
@@ -686,7 +686,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
out
=
_varbase_creator
(
dtype
=
dtype
)
out
=
_varbase_creator
(
dtype
=
dtype
)
if
isinstance
(
value
,
Variable
):
if
isinstance
(
value
,
Variable
):
if
dtype
in
[
'int64'
,
'int32'
]:
if
dtype
in
[
'
uint8'
,
'
int64'
,
'int32'
]:
attrs
[
'str_value'
]
=
str
(
int
(
value
.
numpy
().
item
(
0
)))
attrs
[
'str_value'
]
=
str
(
int
(
value
.
numpy
().
item
(
0
)))
else
:
else
:
attrs
[
'str_value'
]
=
str
(
float
(
value
.
numpy
().
item
(
0
)))
attrs
[
'str_value'
]
=
str
(
float
(
value
.
numpy
().
item
(
0
)))
...
@@ -706,9 +706,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
...
@@ -706,9 +706,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
inputs
[
'ValueTensor'
]
=
value
inputs
[
'ValueTensor'
]
=
value
check_shape
(
shape
)
check_shape
(
shape
)
check_dtype
(
dtype
,
'dtype'
,
check_dtype
(
[
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
dtype
,
'dtype'
,
'fill_constant'
)
[
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'uint8'
,
'int32'
,
'int64'
],
'fill_constant'
)
check_type
(
shape
,
'shape'
,
(
Variable
,
list
,
tuple
),
'fill_constant'
)
check_type
(
shape
,
'shape'
,
(
Variable
,
list
,
tuple
),
'fill_constant'
)
if
out
is
not
None
:
if
out
is
not
None
:
...
...
python/paddle/fluid/tests/unittests/test_fill_constant_op.py
浏览文件 @
980227f9
...
@@ -375,15 +375,9 @@ class TestFillConstantOpError(unittest.TestCase):
...
@@ -375,15 +375,9 @@ class TestFillConstantOpError(unittest.TestCase):
out
=
x1
)
out
=
x1
)
# The argument dtype of fill_constant_op must be one of bool, float16,
# The argument dtype of fill_constant_op must be one of bool, float16,
#float32, float64, int32 or int64
#float32, float64,
uint8,
int32 or int64
x2
=
fluid
.
layers
.
data
(
name
=
'x2'
,
shape
=
[
1
],
dtype
=
"int32"
)
x2
=
fluid
.
layers
.
data
(
name
=
'x2'
,
shape
=
[
1
],
dtype
=
"int32"
)
self
.
assertRaises
(
TypeError
,
fluid
.
layers
.
fill_constant
,
shape
=
[
1
],
value
=
5
,
dtype
=
'uint8'
)
self
.
assertRaises
(
self
.
assertRaises
(
TypeError
,
TypeError
,
fluid
.
layers
.
fill_constant
,
fluid
.
layers
.
fill_constant
,
...
...
python/paddle/fluid/tests/unittests/test_full_op.py
浏览文件 @
980227f9
...
@@ -84,10 +84,7 @@ class TestFullOpError(unittest.TestCase):
...
@@ -84,10 +84,7 @@ class TestFullOpError(unittest.TestCase):
TypeError
,
paddle
.
full
,
shape
=
[
1
],
fill_value
=
5
,
dtype
=
'uint4'
)
TypeError
,
paddle
.
full
,
shape
=
[
1
],
fill_value
=
5
,
dtype
=
'uint4'
)
# The argument dtype of full must be one of bool, float16,
# The argument dtype of full must be one of bool, float16,
#float32, float64, int32 or int64
#float32, float64, uint8, int32 or int64
self
.
assertRaises
(
TypeError
,
paddle
.
full
,
shape
=
[
1
],
fill_value
=
5
,
dtype
=
'uint8'
)
# The argument shape's type of full_op must be list, tuple or Variable.
# The argument shape's type of full_op must be list, tuple or Variable.
def
test_shape_type
():
def
test_shape_type
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录