Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
468a2a17
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
468a2a17
编写于
3月 01, 2022
作者:
R
ronnywang
提交者:
GitHub
3月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[phi] migrate where kernel into phi (#39811)
上级
4fbcf6f4
变更
21
显示空白变更内容
内联
并排
Showing
21 changed file
with
352 addition
and
257 deletion
+352
-257
paddle/fluid/operators/where_op.cc
paddle/fluid/operators/where_op.cc
+8
-38
paddle/fluid/operators/where_op.cu
paddle/fluid/operators/where_op.cu
+0
-126
paddle/fluid/operators/where_op.h
paddle/fluid/operators/where_op.h
+0
-73
paddle/fluid/operators/where_op_npu.cc
paddle/fluid/operators/where_op_npu.cc
+1
-1
paddle/fluid/operators/where_op_xpu.cc
paddle/fluid/operators/where_op_xpu.cc
+1
-1
paddle/phi/infermeta/binary.cc
paddle/phi/infermeta/binary.cc
+1
-2
paddle/phi/infermeta/multiary.cc
paddle/phi/infermeta/multiary.cc
+25
-0
paddle/phi/infermeta/multiary.h
paddle/phi/infermeta/multiary.h
+4
-0
paddle/phi/kernels/cpu/atan2_grad_kernel.cc
paddle/phi/kernels/cpu/atan2_grad_kernel.cc
+2
-3
paddle/phi/kernels/cpu/atan2_kernel.cc
paddle/phi/kernels/cpu/atan2_kernel.cc
+2
-3
paddle/phi/kernels/cpu/where_grad_kernel.cc
paddle/phi/kernels/cpu/where_grad_kernel.cc
+54
-0
paddle/phi/kernels/cpu/where_kernel.cc
paddle/phi/kernels/cpu/where_kernel.cc
+40
-0
paddle/phi/kernels/gpu/atan2_grad_kernel.cu
paddle/phi/kernels/gpu/atan2_grad_kernel.cu
+2
-3
paddle/phi/kernels/gpu/atan2_kernel.cu
paddle/phi/kernels/gpu/atan2_kernel.cu
+2
-3
paddle/phi/kernels/gpu/where_grad_kernel.cu
paddle/phi/kernels/gpu/where_grad_kernel.cu
+64
-0
paddle/phi/kernels/gpu/where_kernel.cu
paddle/phi/kernels/gpu/where_kernel.cu
+48
-0
paddle/phi/kernels/impl/atan2_grad_kernel_impl.h
paddle/phi/kernels/impl/atan2_grad_kernel_impl.h
+3
-2
paddle/phi/kernels/impl/atan2_kernel_impl.h
paddle/phi/kernels/impl/atan2_kernel_impl.h
+3
-2
paddle/phi/kernels/where_grad_kernel.h
paddle/phi/kernels/where_grad_kernel.h
+33
-0
paddle/phi/kernels/where_kernel.h
paddle/phi/kernels/where_kernel.h
+31
-0
paddle/phi/ops/compat/where_grad_sig.cc
paddle/phi/ops/compat/where_grad_sig.cc
+28
-0
未找到文件。
paddle/fluid/operators/where_op.cc
浏览文件 @
468a2a17
...
...
@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -21,31 +23,6 @@ class WhereOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Condition"
),
"Input"
,
"Condition"
,
"Where"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"Where"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Y"
),
"Input"
,
"Y"
,
"Where"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"Where"
);
auto
cond_dims
=
ctx
->
GetInputDim
(
"Condition"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
PADDLE_ENFORCE_EQ
(
cond_dims
,
x_dims
,
platform
::
errors
::
InvalidArgument
(
"The dims of Inputs(Condition) and Inputs(X) should be same. "
"But received Condition's shape is [%s], X's shape is [%s]"
,
cond_dims
,
x_dims
));
PADDLE_ENFORCE_EQ
(
x_dims
,
y_dims
,
platform
::
errors
::
InvalidArgument
(
"The dims of Inputs(X) and Inputs(Y) should be same. "
"But received X's shape is [%s], Y's shape is [%s]"
,
x_dims
,
y_dims
));
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
@@ -140,19 +117,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInferer, "X", "Y");
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
DELCARE_INFER_SHAPE_FUNCTOR
(
where
,
WhereInferShapeFunctor
,
PT_INFER_META
(
phi
::
WhereInferMeta
));
REGISTER_OPERATOR
(
where
,
ops
::
WhereOp
,
ops
::
WhereOpMaker
,
ops
::
WhereOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
WhereOpGradMaker
<
paddle
::
imperative
::
OpBase
>
);
ops
::
WhereOpGradMaker
<
paddle
::
imperative
::
OpBase
>
,
WhereInferShapeFunctor
);
REGISTER_OPERATOR
(
where_grad
,
ops
::
WhereGradOp
,
ops
::
WhereGradNoNeedBufferVarsInferer
);
REGISTER_OP_CPU_KERNEL
(
where
,
ops
::
WhereKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
WhereKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
WhereKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
WhereKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
where_grad
,
ops
::
WhereGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
WhereGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
WhereGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
WhereGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
paddle/fluid/operators/where_op.cu
已删除
100644 → 0
浏览文件 @
4fbcf6f4
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
namespace
platform
=
paddle
::
platform
;
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
CondFunctor
{
HOSTDEVICE
inline
CondFunctor
()
{}
HOSTDEVICE
inline
T
operator
()(
const
bool
cond
,
const
T
x
,
const
T
y
)
const
{
return
cond
?
x
:
y
;
}
};
template
<
typename
T
>
__global__
void
WhereCUDAKernel
(
const
int
N
,
const
bool
*
cond
,
const
T
*
x
,
const
T
*
y
,
T
*
out
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
out
[
idx
]
=
cond
[
idx
]
?
x
[
idx
]
:
y
[
idx
];
}
}
template
<
typename
T
>
__global__
void
WhereGradCUDAKernel
(
const
int
N
,
const
T
*
dout
,
const
bool
*
cond
,
T
*
dx
,
T
*
dy
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
if
(
dx
!=
nullptr
)
{
dx
[
idx
]
=
cond
[
idx
]
?
dout
[
idx
]
:
0.
;
}
if
(
dy
!=
nullptr
)
{
dy
[
idx
]
=
cond
[
idx
]
?
0.
:
dout
[
idx
];
}
}
}
template
<
typename
T
>
class
WhereKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
condition
=
context
.
Input
<
framework
::
Tensor
>
(
"Condition"
);
auto
*
X
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
Y
=
context
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
numel
=
condition
->
numel
();
// TODO(GaaoWei8): Input of where can be broadcast
const
bool
*
cond_data
=
condition
->
data
<
bool
>
();
const
T
*
x_data
=
X
->
data
<
T
>
();
const
T
*
y_data
=
Y
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
functor
=
CondFunctor
<
T
>
();
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
condition
,
X
,
Y
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
out
};
paddle
::
operators
::
LaunchSameDimsElementwiseCudaKernel
<
T
>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
}
};
template
<
typename
T
>
class
WhereGradKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
condition
=
context
.
Input
<
framework
::
Tensor
>
(
"Condition"
);
const
bool
*
cond_data
=
condition
->
data
<
bool
>
();
auto
numel
=
condition
->
numel
();
auto
*
dout_t
=
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx_t
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy_t
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
dout
=
dout_t
->
data
<
T
>
();
T
*
dx
=
(
dx_t
!=
nullptr
)
?
dx_t
->
mutable_data
<
T
>
(
context
.
GetPlace
())
:
nullptr
;
T
*
dy
=
(
dy_t
!=
nullptr
)
?
dy_t
->
mutable_data
<
T
>
(
context
.
GetPlace
())
:
nullptr
;
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
config
=
GetGpuLaunchConfig1D
(
dev_ctx
,
condition
->
numel
());
WhereGradCUDAKernel
<
T
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
stream
>>>
(
numel
,
dout
,
cond_data
,
dx
,
dy
);
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
where
,
paddle
::
operators
::
WhereKernel
<
platform
::
CUDADeviceContext
,
float
>
,
paddle
::
operators
::
WhereKernel
<
platform
::
CUDADeviceContext
,
double
>
,
paddle
::
operators
::
WhereKernel
<
platform
::
CUDADeviceContext
,
int
>
,
paddle
::
operators
::
WhereKernel
<
platform
::
CUDADeviceContext
,
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
where_grad
,
paddle
::
operators
::
WhereGradKernel
<
platform
::
CUDADeviceContext
,
float
>
,
paddle
::
operators
::
WhereGradKernel
<
platform
::
CUDADeviceContext
,
double
>
,
paddle
::
operators
::
WhereGradKernel
<
platform
::
CUDADeviceContext
,
int
>
,
paddle
::
operators
::
WhereGradKernel
<
platform
::
CUDADeviceContext
,
int64_t
>
);
paddle/fluid/operators/where_op.h
已删除
100644 → 0
浏览文件 @
4fbcf6f4
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
WhereKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
condition
=
context
.
Input
<
framework
::
Tensor
>
(
"Condition"
);
auto
*
X
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
Y
=
context
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
const
bool
*
cond_data
=
condition
->
data
<
bool
>
();
const
T
*
x_data
=
X
->
data
<
T
>
();
const
T
*
y_data
=
Y
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
x_numel
=
X
->
numel
();
for
(
int
i
=
0
;
i
<
x_numel
;
i
++
)
{
out_data
[
i
]
=
cond_data
[
i
]
?
x_data
[
i
]
:
y_data
[
i
];
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
WhereGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
condition
=
context
.
Input
<
framework
::
LoDTensor
>
(
"Condition"
);
const
auto
*
cond_data
=
condition
->
data
<
bool
>
();
auto
numel
=
condition
->
numel
();
auto
*
dout_t
=
context
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx_t
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy_t
=
context
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
dout
=
dout_t
->
data
<
T
>
();
if
(
dx_t
!=
nullptr
)
{
auto
*
dx
=
dx_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
numel
;
i
++
)
{
dx
[
i
]
=
dout
[
i
]
*
(
cond_data
[
i
]
?
1.
:
0.
);
}
}
if
(
dy_t
!=
nullptr
)
{
auto
*
dy
=
dy_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
numel
;
i
++
)
{
dy
[
i
]
=
dout
[
i
]
*
(
cond_data
[
i
]
?
0.
:
1.
);
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/where_op_npu.cc
浏览文件 @
468a2a17
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/
operators/where_op
.h"
#include "paddle/fluid/
framework/op_registry
.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace
paddle
{
...
...
paddle/fluid/operators/where_op_xpu.cc
浏览文件 @
468a2a17
...
...
@@ -14,7 +14,7 @@
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/
operators/where_op
.h"
#include "paddle/fluid/
framework/op_registry
.h"
namespace
paddle
{
namespace
operators
{
...
...
paddle/phi/infermeta/binary.cc
浏览文件 @
468a2a17
...
...
@@ -306,8 +306,7 @@ void CrossInferMeta(const MetaTensor& x,
}
void
Atan2InferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
MetaTensor
*
out
)
{
auto
in_dims
=
x
.
dims
();
out
->
set_dims
(
in_dims
);
out
->
share_meta
(
x
);
}
void
BCELossInferMeta
(
const
MetaTensor
&
input
,
...
...
paddle/phi/infermeta/multiary.cc
浏览文件 @
468a2a17
...
...
@@ -133,4 +133,29 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
out
->
share_lod
(
*
x
.
at
(
0
));
}
void
WhereInferMeta
(
const
MetaTensor
&
condition
,
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
MetaTensor
*
out
)
{
auto
cond_dims
=
condition
.
dims
();
auto
x_dims
=
x
.
dims
();
auto
y_dims
=
y
.
dims
();
PADDLE_ENFORCE_EQ
(
cond_dims
,
x_dims
,
phi
::
errors
::
InvalidArgument
(
"The dims of Inputs(Condition) and Inputs(X) should be same. "
"But received Condition's shape is [%s], X's shape is [%s]"
,
cond_dims
,
x_dims
));
PADDLE_ENFORCE_EQ
(
x_dims
,
y_dims
,
phi
::
errors
::
InvalidArgument
(
"The dims of Inputs(X) and Inputs(Y) should be same. "
"But received X's shape is [%s], Y's shape is [%s]"
,
x_dims
,
y_dims
));
out
->
share_meta
(
x
);
}
}
// namespace phi
paddle/phi/infermeta/multiary.h
浏览文件 @
468a2a17
...
...
@@ -30,4 +30,8 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
MetaTensor
*
out
,
MetaConfig
config
=
MetaConfig
());
void
WhereInferMeta
(
const
MetaTensor
&
condition
,
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
MetaTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/cpu/atan2_grad_kernel.cc
浏览文件 @
468a2a17
...
...
@@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/atan2_grad_kernel.h"
#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h"
PD_REGISTER_KERNEL
(
atan2_grad
,
CPU
,
...
...
paddle/phi/kernels/cpu/atan2_kernel.cc
浏览文件 @
468a2a17
...
...
@@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/atan2_kernel.h"
#include "paddle/phi/kernels/impl/atan2_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/atan2_kernel_impl.h"
PD_REGISTER_KERNEL
(
atan2
,
CPU
,
...
...
paddle/phi/kernels/cpu/where_grad_kernel.cc
0 → 100644
浏览文件 @
468a2a17
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/where_grad_kernel.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
WhereGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
condition
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
out_grad
,
DenseTensor
*
x_grad
,
DenseTensor
*
y_grad
)
{
const
auto
*
cond_data
=
condition
.
data
<
bool
>
();
auto
numel
=
condition
.
numel
();
auto
*
dout
=
out_grad
.
data
<
T
>
();
if
(
x_grad
!=
nullptr
)
{
auto
*
dx
=
ctx
.
template
Alloc
<
T
>(
x_grad
);
for
(
int
i
=
0
;
i
<
numel
;
i
++
)
{
dx
[
i
]
=
dout
[
i
]
*
(
cond_data
[
i
]
?
1.
:
0.
);
}
}
if
(
y_grad
!=
nullptr
)
{
auto
*
dy
=
ctx
.
template
Alloc
<
T
>(
y_grad
);
for
(
int
i
=
0
;
i
<
numel
;
i
++
)
{
dy
[
i
]
=
dout
[
i
]
*
(
cond_data
[
i
]
?
0.
:
1.
);
}
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
where_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
WhereGradKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/cpu/where_kernel.cc
0 → 100644
浏览文件 @
468a2a17
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/where_kernel.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
WhereKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
condition
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
DenseTensor
*
out
)
{
const
bool
*
cond_data
=
condition
.
data
<
bool
>
();
const
T
*
x_data
=
x
.
data
<
T
>
();
const
T
*
y_data
=
y
.
data
<
T
>
();
auto
x_numel
=
x
.
numel
();
T
*
out_data
=
ctx
.
template
Alloc
<
T
>(
out
);
for
(
int
i
=
0
;
i
<
x_numel
;
i
++
)
{
out_data
[
i
]
=
cond_data
[
i
]
?
x_data
[
i
]
:
y_data
[
i
];
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
where
,
CPU
,
ALL_LAYOUT
,
phi
::
WhereKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/gpu/atan2_grad_kernel.cu
浏览文件 @
468a2a17
...
...
@@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/atan2_grad_kernel.h"
#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h"
PD_REGISTER_KERNEL
(
atan2_grad
,
GPU
,
...
...
paddle/phi/kernels/gpu/atan2_kernel.cu
浏览文件 @
468a2a17
...
...
@@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/atan2_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/atan2_kernel.h"
#include "paddle/phi/kernels/impl/atan2_kernel_impl.h"
PD_REGISTER_KERNEL
(
atan2
,
GPU
,
...
...
paddle/phi/kernels/gpu/where_grad_kernel.cu
0 → 100644
浏览文件 @
468a2a17
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/where_grad_kernel.h"
namespace
phi
{
template
<
typename
T
>
__global__
void
WhereGradCUDAKernel
(
const
int
N
,
const
T
*
dout
,
const
bool
*
cond
,
T
*
dx
,
T
*
dy
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
if
(
dx
!=
nullptr
)
{
dx
[
idx
]
=
cond
[
idx
]
?
dout
[
idx
]
:
0.
;
}
if
(
dy
!=
nullptr
)
{
dy
[
idx
]
=
cond
[
idx
]
?
0.
:
dout
[
idx
];
}
}
}
template
<
typename
T
,
typename
Context
>
void
WhereGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
condition
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
out_grad
,
DenseTensor
*
x_grad
,
DenseTensor
*
y_grad
)
{
const
bool
*
cond_data
=
condition
.
data
<
bool
>
();
auto
numel
=
condition
.
numel
();
auto
*
dout
=
out_grad
.
data
<
T
>
();
T
*
dx
=
(
x_grad
!=
nullptr
)
?
ctx
.
template
Alloc
<
T
>(
x_grad
)
:
nullptr
;
T
*
dy
=
(
y_grad
!=
nullptr
)
?
ctx
.
template
Alloc
<
T
>(
y_grad
)
:
nullptr
;
auto
stream
=
ctx
.
stream
();
auto
config
=
backends
::
gpu
::
GetGpuLaunchConfig1D
(
ctx
,
numel
);
WhereGradCUDAKernel
<
T
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
stream
>>>
(
numel
,
dout
,
cond_data
,
dx
,
dy
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
where_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
WhereGradKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/gpu/where_kernel.cu
0 → 100644
浏览文件 @
468a2a17
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/where_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
namespace
phi
{
// Cond
template
<
typename
T
>
struct
CondFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
bool
cond
,
const
T
x
,
const
T
y
)
const
{
return
cond
?
x
:
y
;
}
};
template
<
typename
T
,
typename
Context
>
void
WhereKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
condition
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
DenseTensor
*
out
)
{
std
::
vector
<
const
DenseTensor
*>
ins
=
{
&
condition
,
&
x
,
&
y
};
std
::
vector
<
DenseTensor
*>
outs
=
{
out
};
ctx
.
template
Alloc
<
T
>(
out
);
CondFunctor
<
T
>
func
;
funcs
::
BroadcastKernel
<
ElementwiseType
::
kTernary
,
T
,
T
>
(
ctx
,
ins
,
&
outs
,
-
1
,
func
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
where
,
GPU
,
ALL_LAYOUT
,
phi
::
WhereKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/impl/atan2_grad_kernel_impl.h
浏览文件 @
468a2a17
...
...
@@ -14,9 +14,10 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/atan2_grad_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
...
...
paddle/phi/kernels/impl/atan2_kernel_impl.h
浏览文件 @
468a2a17
...
...
@@ -14,9 +14,10 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/atan2_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
>
...
...
paddle/phi/kernels/where_grad_kernel.h
0 → 100644
浏览文件 @
468a2a17
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
WhereGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
condition
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
out_grad
,
DenseTensor
*
x_grad
,
DenseTensor
*
y_grad
);
}
// namespace phi
paddle/phi/kernels/where_kernel.h
0 → 100644
浏览文件 @
468a2a17
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
WhereKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
condition
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/ops/compat/where_grad_sig.cc
0 → 100644
浏览文件 @
468a2a17
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
WhereGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"where_grad"
,
{
"Condition"
,
"X"
,
"Y"
,
GradVarName
(
"Out"
)},
{},
{
GradVarName
(
"X"
),
GradVarName
(
"Y"
)});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
where_grad
,
phi
::
WhereGradOpArgumentMapping
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录