Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b1365d25
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b1365d25
编写于
1月 18, 2022
作者:
Y
Yiqun Liu
提交者:
GitHub
1月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Unify the functor of elementwise and logical ops. (#35767)
上级
dfa242e4
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
48 addition
and
56 deletion
+48
-56
paddle/fluid/operators/controlflow/logical_op.cu
paddle/fluid/operators/controlflow/logical_op.cu
+9
-5
paddle/fluid/operators/controlflow/logical_op.h
paddle/fluid/operators/controlflow/logical_op.h
+3
-0
paddle/fluid/operators/elementwise/elementwise_functor.h
paddle/fluid/operators/elementwise/elementwise_functor.h
+1
-0
paddle/fluid/operators/elementwise/elementwise_mod_op.cc
paddle/fluid/operators/elementwise/elementwise_mod_op.cc
+2
-2
paddle/fluid/operators/elementwise/elementwise_mod_op.cu
paddle/fluid/operators/elementwise/elementwise_mod_op.cu
+3
-3
paddle/fluid/operators/elementwise/elementwise_mod_op.h
paddle/fluid/operators/elementwise/elementwise_mod_op.h
+5
-35
paddle/fluid/operators/elementwise/elementwise_op_function.h
paddle/fluid/operators/elementwise/elementwise_op_function.h
+5
-4
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+0
-5
paddle/fluid/operators/elementwise/elementwise_pow_op.cu
paddle/fluid/operators/elementwise/elementwise_pow_op.cu
+3
-0
paddle/fluid/operators/elementwise/elementwise_pow_op.h
paddle/fluid/operators/elementwise/elementwise_pow_op.h
+4
-1
paddle/fluid/operators/svd_helper.h
paddle/fluid/operators/svd_helper.h
+0
-1
paddle/pten/kernels/funcs/elementwise_functor.h
paddle/pten/kernels/funcs/elementwise_functor.h
+13
-0
未找到文件。
paddle/fluid/operators/controlflow/logical_op.cu
浏览文件 @
b1365d25
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -12,9 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/controlflow/logical_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
paddle
{
namespace
operators
{
...
...
@@ -22,9 +22,10 @@ template <typename Functor>
class
BinaryLogicalOpKernel
<
platform
::
CUDADeviceContext
,
Functor
>
:
public
framework
::
OpKernel
<
typename
Functor
::
ELEMENT_TYPE
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
InT
=
typename
Functor
::
ELEMENT_TYPE
;
using
OutT
=
bool
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
functor
=
Functor
();
std
::
vector
<
const
framework
::
Tensor
*>
ins
;
std
::
vector
<
framework
::
Tensor
*>
outs
;
...
...
@@ -45,6 +46,9 @@ class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
#define REGISTER_LOGICAL_CUDA_KERNEL(op_name, func) \
REGISTER_OP_CUDA_KERNEL( \
op_name, \
...
...
paddle/fluid/operators/controlflow/logical_op.h
浏览文件 @
b1365d25
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
paddle/fluid/operators/elementwise/elementwise_functor.h
浏览文件 @
b1365d25
...
...
@@ -22,6 +22,7 @@ namespace paddle {
namespace
operators
{
// Define the binary functors used in elementwise ops.
// Note: InverseXxxFunctor is needed when calling ElementwiseComputeEx on CPU.
// Add
template
<
typename
T
>
...
...
paddle/fluid/operators/elementwise/elementwise_mod_op.cc
浏览文件 @
b1365d25
...
...
@@ -66,8 +66,8 @@ REGISTER_OP_CPU_KERNEL(
elementwise_mod
,
ops
::
ElementwiseModKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
ElementwiseModKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
ops
::
ElementwiseMod
FP
Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ElementwiseMod
FP
Kernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
ops
::
ElementwiseModKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
ElementwiseModKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_VERSION
(
elementwise_mod
)
.
AddCheckpoint
(
...
...
paddle/fluid/operators/elementwise/elementwise_mod_op.cu
浏览文件 @
b1365d25
...
...
@@ -14,9 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
namespace
paddle
{
namespace
operators
{
...
...
@@ -38,6 +35,9 @@ class ElementwiseModKernel<platform::CUDADeviceContext, T>
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
elementwise_mod
,
ops
::
ElementwiseModKernel
<
plat
::
CUDADeviceContext
,
int
>
,
ops
::
ElementwiseModKernel
<
plat
::
CUDADeviceContext
,
int64_t
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_mod_op.h
浏览文件 @
b1365d25
...
...
@@ -44,9 +44,9 @@ struct ModFunctor<T,
}
};
template
<
typename
T
>
template
<
typename
T
,
typename
Enable
=
void
>
struct
InverseModFunctor
{
inline
HOSTDEVICE
T
operator
()(
T
a
,
T
b
)
const
{
inline
HOSTDEVICE
T
operator
()(
const
T
a
,
const
T
b
)
const
{
T
res
=
b
%
a
;
if
((
res
!=
0
)
&&
((
res
<
0
)
!=
(
a
<
0
)))
res
+=
a
;
return
res
;
...
...
@@ -54,8 +54,9 @@ struct InverseModFunctor {
};
template
<
typename
T
>
struct
InverseModFunctorFP
{
inline
HOSTDEVICE
T
operator
()(
T
a
,
T
b
)
const
{
struct
InverseModFunctor
<
T
,
typename
std
::
enable_if_t
<
std
::
is_floating_point
<
T
>::
value
>>
{
inline
HOSTDEVICE
T
operator
()(
const
T
a
,
const
T
b
)
const
{
T
res
=
fmod
(
b
,
a
);
if
((
res
!=
0
)
&&
((
a
<
0
)
!=
(
res
<
0
)))
res
+=
a
;
return
res
;
...
...
@@ -78,22 +79,6 @@ void elementwise_mod(const framework::ExecutionContext &ctx,
}
}
template
<
typename
DeviceContext
,
typename
T
>
void
elementwise_mod_fp
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
x_dims
=
x
->
dims
();
auto
y_dims
=
y
->
dims
();
if
(
x_dims
.
size
()
>=
y_dims
.
size
())
{
ElementwiseComputeEx
<
ModFunctor
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
axis
,
ModFunctor
<
T
>
(),
z
);
}
else
{
ElementwiseComputeEx
<
InverseModFunctorFP
<
T
>
,
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
axis
,
InverseModFunctorFP
<
T
>
(),
z
);
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseModKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -109,20 +94,5 @@ class ElementwiseModKernel : public framework::OpKernel<T> {
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseModFPKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
z
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// dtype of x and y is float or double
elementwise_mod_fp
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
z
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/elementwise/elementwise_op_function.h
浏览文件 @
b1365d25
...
...
@@ -199,10 +199,6 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
int
axis
,
Functor
func
,
framework
::
Tensor
*
z
)
{
z
->
mutable_data
<
OutType
>
(
ctx
.
GetPlace
());
auto
pt_x
=
paddle
::
experimental
::
MakePtenDenseTensor
(
*
x
);
auto
pt_y
=
paddle
::
experimental
::
MakePtenDenseTensor
(
*
y
);
auto
pt_z
=
paddle
::
experimental
::
MakePtenDenseTensor
(
*
z
);
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#if defined(__NVCC__) || defined(__HIPCC__)
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
x
,
y
};
...
...
@@ -217,6 +213,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
return
;
}
z
->
mutable_data
<
OutType
>
(
ctx
.
GetPlace
());
auto
pt_x
=
paddle
::
experimental
::
MakePtenDenseTensor
(
*
x
);
auto
pt_y
=
paddle
::
experimental
::
MakePtenDenseTensor
(
*
y
);
auto
pt_z
=
paddle
::
experimental
::
MakePtenDenseTensor
(
*
z
);
const
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>();
pten
::
ElementwiseCompute
<
Functor
,
T
,
OutType
>
(
...
...
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
浏览文件 @
b1365d25
...
...
@@ -16,9 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/function_traits.h"
// only can include the headers in paddle/top/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
...
...
@@ -27,8 +24,6 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
namespace
kps
=
paddle
::
operators
::
kernel_primitives
;
using
ElementwiseType
=
pten
::
ElementwiseType
;
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
,
...
...
paddle/fluid/operators/elementwise/elementwise_pow_op.cu
浏览文件 @
b1365d25
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
paddle/fluid/operators/elementwise/elementwise_pow_op.h
浏览文件 @
b1365d25
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -21,7 +24,7 @@ namespace operators {
template
<
typename
T
>
struct
PowFunctor
{
inline
HOSTDEVICE
T
operator
()(
T
a
,
T
b
)
const
{
inline
HOSTDEVICE
T
operator
()(
const
T
a
,
const
T
b
)
const
{
// TODO(wujionghao): A potential speed improvement is supporting different
// types in C++.
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
...
...
paddle/fluid/operators/svd_helper.h
浏览文件 @
b1365d25
...
...
@@ -26,7 +26,6 @@
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
...
...
paddle/pten/kernels/funcs/elementwise_functor.h
浏览文件 @
b1365d25
...
...
@@ -22,6 +22,7 @@ namespace pten {
namespace
funcs
{
// Define the binary functors used in elementwise ops.
// Note: InverseXxxFunctor is needed when calling ElementwiseComputeEx on CPU.
// Add
template
<
typename
T
>
...
...
@@ -48,10 +49,22 @@ template <typename T>
struct
MultiplyFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
a
,
const
T
b
)
const
{
return
a
*
b
;
}
};
template
<
>
struct
MultiplyFunctor
<
bool
>
{
inline
HOSTDEVICE
bool
operator
()(
const
bool
a
,
const
bool
b
)
const
{
return
a
&&
b
;
}
};
template
<
typename
T
>
struct
InverseMultiplyFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
a
,
const
T
b
)
const
{
return
b
*
a
;
}
};
template
<
>
struct
InverseMultiplyFunctor
<
bool
>
{
inline
HOSTDEVICE
bool
operator
()(
const
bool
a
,
const
bool
b
)
const
{
return
b
&&
a
;
}
};
// Divide
#define DIV_ERROR_INFO \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录