Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
72d36970
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
72d36970
编写于
6月 16, 2021
作者:
Z
zhangbo9674
提交者:
GitHub
6月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Feature] add paddle.trunc (#33371)
* new api trunc, test=develop
上级
32e3353f
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
396 addition
and
0 deletion
+396
-0
paddle/fluid/framework/unused_var_check.cc
paddle/fluid/framework/unused_var_check.cc
+1
-0
paddle/fluid/operators/trunc_op.cc
paddle/fluid/operators/trunc_op.cc
+89
-0
paddle/fluid/operators/trunc_op.cu
paddle/fluid/operators/trunc_op.cu
+115
-0
paddle/fluid/operators/trunc_op.h
paddle/fluid/operators/trunc_op.h
+55
-0
python/paddle/__init__.py
python/paddle/__init__.py
+2
-0
python/paddle/fluid/tests/unittests/test_trunc_op.py
python/paddle/fluid/tests/unittests/test_trunc_op.py
+88
-0
python/paddle/tensor/__init__.py
python/paddle/tensor/__init__.py
+2
-0
python/paddle/tensor/math.py
python/paddle/tensor/math.py
+44
-0
未找到文件。
paddle/fluid/framework/unused_var_check.cc
浏览文件 @
72d36970
...
...
@@ -75,6 +75,7 @@ static const std::unordered_set<std::string> &GetOpWithUnusedVarAllowSet() {
"data_norm_grad"
,
// 0
"update_loss_scaling"
,
// 0
"fused_embedding_eltwise_layernorm"
,
// 0
"trunc_grad"
,
// 1
});
return
*
allow_set
;
}
...
...
paddle/fluid/operators/trunc_op.cc
0 → 100644
浏览文件 @
72d36970
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/trunc_op.h"
namespace
paddle
{
namespace
operators
{
class
TruncOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"trunc"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"trunc"
);
auto
input_dims
=
ctx
->
GetInputDim
(
"X"
);
ctx
->
SetOutputDim
(
"Out"
,
input_dims
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
class
TruncOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor), The input tensor of trunc op."
);
AddOutput
(
"Out"
,
"(Tensor), The output tensor of trunc op."
);
AddComment
(
R"DOC(
Trunc Operator.
Returns a new tensor with the truncated integer values of input.
$$out = trunc(x)$$
)DOC"
);
}
};
class
TruncGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
framework
::
GradVarName
(
"Out"
),
"TruncGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output"
,
framework
::
GradVarName
(
"X"
),
"TruncGrad"
);
auto
dout_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
dout_dims
);
}
};
template
<
typename
T
>
class
TruncGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
void
Apply
(
GradOpPtr
<
T
>
retv
)
const
override
{
retv
->
SetType
(
"trunc_grad"
);
retv
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
retv
->
SetAttrMap
(
this
->
Attrs
());
retv
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
trunc
,
ops
::
TruncOp
,
ops
::
TruncOpMaker
,
ops
::
TruncGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
TruncGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
trunc_grad
,
ops
::
TruncGradOp
);
REGISTER_OP_CPU_KERNEL
(
trunc
,
ops
::
TruncKernel
<
float
>
,
ops
::
TruncKernel
<
double
>
,
ops
::
TruncKernel
<
int
>
,
ops
::
TruncKernel
<
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
trunc_grad
,
ops
::
TruncGradKernel
<
float
>
,
ops
::
TruncGradKernel
<
double
>
,
ops
::
TruncGradKernel
<
int
>
,
ops
::
TruncGradKernel
<
int64_t
>
);
paddle/fluid/operators/trunc_op.cu
0 → 100644
浏览文件 @
72d36970
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/trunc_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace
paddle
{
namespace
operators
{
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
template
<
typename
T
>
class
TruncFunctor
{
public:
__device__
TruncFunctor
(
const
T
x
)
:
x_
(
x
)
{}
__device__
T
operator
()()
{
return
trunc
(
x_
);
}
public:
const
T
x_
;
};
template
<
>
class
TruncFunctor
<
int
>
{
public:
__device__
TruncFunctor
(
const
int
x
)
:
x_
(
x
)
{}
__device__
int
operator
()()
{
return
x_
;
}
public:
const
int
x_
;
};
template
<
>
class
TruncFunctor
<
int64_t
>
{
public:
__device__
TruncFunctor
(
const
int64_t
x
)
:
x_
(
x
)
{}
__device__
int64_t
operator
()()
{
return
x_
;
}
public:
const
int64_t
x_
;
};
template
<
typename
T
>
__global__
void
Trunc
(
const
T
*
x
,
T
*
out
,
int64_t
N
)
{
CUDA_KERNEL_LOOP
(
index
,
N
)
{
TruncFunctor
<
T
>
functor
(
x
[
index
]);
out
[
index
]
=
functor
();
}
}
template
<
typename
T
>
__global__
void
TruncGrad
(
T
*
dx
,
int64_t
N
)
{
CUDA_KERNEL_LOOP
(
index
,
N
)
{
dx
[
index
]
=
static_cast
<
T
>
(
0.0
);
}
}
template
<
typename
T
>
class
TruncCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
const
auto
*
x_data
=
x
->
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int64_t
numel
=
x
->
numel
();
int
theads
=
PADDLE_CUDA_NUM_THREADS
;
int
blocks
=
(
numel
+
theads
-
1
)
/
theads
;
Trunc
<<<
blocks
,
theads
>>>
(
x_data
,
out_data
,
numel
);
}
};
template
<
typename
T
>
class
TruncCUDAGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
dout
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
const
auto
*
dout_data
=
dout
->
data
<
T
>
();
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int64_t
numel
=
dout
->
numel
();
int
theads
=
PADDLE_CUDA_NUM_THREADS
;
int
blocks
=
(
numel
+
theads
-
1
)
/
theads
;
TruncGrad
<<<
blocks
,
theads
>>>
(
dx_data
,
numel
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
trunc
,
ops
::
TruncCUDAKernel
<
float
>
,
ops
::
TruncCUDAKernel
<
double
>
,
ops
::
TruncCUDAKernel
<
int
>
,
ops
::
TruncCUDAKernel
<
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
trunc_grad
,
ops
::
TruncCUDAGradKernel
<
float
>
,
ops
::
TruncCUDAGradKernel
<
double
>
,
ops
::
TruncCUDAGradKernel
<
int
>
,
ops
::
TruncCUDAGradKernel
<
int64_t
>
);
paddle/fluid/operators/trunc_op.h
0 → 100644
浏览文件 @
72d36970
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
TruncKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
Tensor
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
Tensor
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
size_t
numel
=
x
->
numel
();
const
T
*
x_data
=
x
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
size_t
i
=
0
;
i
<
numel
;
i
++
)
{
out_data
[
i
]
=
trunc
(
x_data
[
i
]);
}
}
};
template
<
typename
T
>
class
TruncGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
dx
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
numel
=
dx
->
numel
();
memset
(
dx_data
,
0.0
,
numel
*
sizeof
(
T
));
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/__init__.py
浏览文件 @
72d36970
...
...
@@ -205,6 +205,7 @@ from .tensor.math import isnan # noqa: F401
from
.tensor.math
import
prod
# noqa: F401
from
.tensor.math
import
broadcast_shape
# noqa: F401
from
.tensor.math
import
conj
# noqa: F401
from
.tensor.math
import
trunc
# noqa: F401
from
.tensor.math
import
digamma
# noqa: F401
from
.tensor.math
import
neg
# noqa: F401
from
.tensor.math
import
lgamma
# noqa: F401
...
...
@@ -490,6 +491,7 @@ __all__ = [ # noqa
'log10'
,
'concat'
,
'check_shape'
,
'trunc'
'digamma'
,
'standard_normal'
]
python/paddle/fluid/tests/unittests/test_trunc_op.py
0 → 100644
浏览文件 @
72d36970
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle
import
paddle.fluid.core
as
core
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
paddle
.
enable_static
()
class
TestTruncOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"trunc"
self
.
dtype
=
np
.
float64
np
.
random
.
seed
(
2021
)
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
20
,
20
)).
astype
(
self
.
dtype
)}
self
.
outputs
=
{
'Out'
:
(
np
.
trunc
(
self
.
inputs
[
'X'
]))}
def
init_dtype_type
(
self
):
self
.
dtype
=
np
.
float64
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
,
numeric_grad_delta
=
1e-5
)
class
TestFloatTruncOp
(
TestTruncOp
):
def
init_dtype_type
(
self
):
self
.
dtype
=
np
.
float32
class
TestIntTruncOp
(
TestTruncOp
):
def
init_dtype_type
(
self
):
self
.
dtype
=
np
.
int32
class
TestTruncAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
[
20
,
20
]
self
.
x
=
np
.
random
.
random
((
20
,
20
)).
astype
(
np
.
float32
)
self
.
place
=
paddle
.
CPUPlace
()
def
test_api_static
(
self
):
paddle
.
enable_static
()
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
fluid
.
data
(
'X'
,
self
.
shape
)
out
=
paddle
.
trunc
(
x
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x
},
fetch_list
=
[
out
])
out_ref
=
np
.
trunc
(
self
.
x
)
for
out
in
res
:
self
.
assertEqual
(
np
.
allclose
(
out
,
out_ref
,
rtol
=
1e-08
),
True
)
def
test_api_dygraph
(
self
):
paddle
.
disable_static
(
self
.
place
)
x_tensor
=
paddle
.
to_tensor
(
self
.
x
)
out
=
paddle
.
trunc
(
x_tensor
)
out_ref
=
np
.
trunc
(
self
.
x
)
self
.
assertEqual
(
np
.
allclose
(
out
.
numpy
(),
out_ref
,
rtol
=
1e-08
),
True
)
paddle
.
enable_static
()
def
test_errors
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
fluid
.
data
(
'X'
,
[
20
,
20
],
'bool'
)
self
.
assertRaises
(
TypeError
,
paddle
.
trunc
,
x
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/tensor/__init__.py
浏览文件 @
72d36970
...
...
@@ -162,6 +162,7 @@ from .math import all # noqa: F401
from
.math
import
any
# noqa: F401
from
.math
import
broadcast_shape
# noqa: F401
from
.math
import
conj
# noqa: F401
from
.math
import
trunc
# noqa: F401
from
.math
import
digamma
# noqa: F401
from
.math
import
neg
# noqa: F401
from
.math
import
lgamma
# noqa: F401
...
...
@@ -349,5 +350,6 @@ tensor_method_func = [ #noqa
'shape'
,
'real'
,
'imag'
,
'trunc'
'digamma'
]
python/paddle/tensor/math.py
浏览文件 @
72d36970
...
...
@@ -857,6 +857,50 @@ def add_n(inputs, name=None):
return
out
def
trunc
(
input
,
name
=
None
):
'''
This API is used to returns a new tensor with the truncated integer values of input.
Args:
input (Tensor): The input tensor, it's data type should be int32, int64, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The output Tensor of trunc.
Examples:
.. code-block:: python
import paddle
input = paddle.rand([2,2],'float32')
print(input)
# Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0.02331470, 0.42374918],
# [0.79647720, 0.74970269]])
output = paddle.trunc(input)
print(output)
# Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0., 0.],
# [0., 0.]]))
'''
if
in_dygraph_mode
():
return
core
.
ops
.
trunc
(
input
)
else
:
inputs
=
{
"X"
:
input
}
attrs
=
{}
helper
=
LayerHelper
(
"trunc"
,
**
locals
())
check_variable_and_dtype
(
input
,
'X'
,
[
'int32'
,
'int64'
,
'float32'
,
'float64'
],
'trunc'
)
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
)
helper
.
append_op
(
type
=
"trunc"
,
inputs
=
inputs
,
attrs
=
attrs
,
outputs
=
{
"Out"
:
out
})
return
out
def
mm
(
input
,
mat2
,
name
=
None
):
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录