Unverified commit 468a2a17, authored by ronnywang, committed by GitHub

[phi] migrate where kernel into phi (#39811)

Parent 4fbcf6f4
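In broad strokes, this migration turns the fluid `OpKernel` classes into free functions in the `phi` namespace, registered per backend. A minimal before/after sketch, condensed from the diff below:

// Before (fluid): the kernel is a class templated on the device context.
template <typename DeviceContext, typename T>
class WhereKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override;
};

// After (phi): the kernel is a free function templated on the backend
// context and registered with PD_REGISTER_KERNEL for CPU and GPU.
template <typename T, typename Context>
void WhereKernel(const Context& ctx,
                 const DenseTensor& condition,
                 const DenseTensor& x,
                 const DenseTensor& y,
                 DenseTensor* out);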
@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
@@ -21,31 +23,6 @@ class WhereOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "Where");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Where");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Where");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Where");
auto cond_dims = ctx->GetInputDim("Condition");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(
cond_dims, x_dims,
platform::errors::InvalidArgument(
"The dims of Inputs(Condition) and Inputs(X) should be same. "
"But received Condition's shape is [%s], X's shape is [%s]",
cond_dims, x_dims));
PADDLE_ENFORCE_EQ(x_dims, y_dims,
platform::errors::InvalidArgument(
"The dims of Inputs(X) and Inputs(Y) should be same. "
"But received X's shape is [%s], Y's shape is [%s]",
x_dims, y_dims));
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -140,19 +117,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInferer, "X", "Y");
} // namespace paddle
namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(where, WhereInferShapeFunctor,
PT_INFER_META(phi::WhereInferMeta));
REGISTER_OPERATOR(where, ops::WhereOp, ops::WhereOpMaker,
ops::WhereOpGradMaker<paddle::framework::OpDesc>,
ops::WhereOpGradMaker<paddle::imperative::OpBase>);
ops::WhereOpGradMaker<paddle::imperative::OpBase>,
WhereInferShapeFunctor);
REGISTER_OPERATOR(where_grad, ops::WhereGradOp,
ops::WhereGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
where, ops::WhereKernel<paddle::platform::CPUDeviceContext, float>,
ops::WhereKernel<paddle::platform::CPUDeviceContext, double>,
ops::WhereKernel<paddle::platform::CPUDeviceContext, int>,
ops::WhereKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
where_grad, ops::WhereGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::WhereGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::WhereGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::WhereGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
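The two REGISTER_OP_CPU_KERNEL blocks deleted above are superseded by phi registrations added later in this commit; for comparison, the forward CPU replacement further down in this diff is the single line:

PD_REGISTER_KERNEL(
    where, CPU, ALL_LAYOUT, phi::WhereKernel, float, double, int, int64_t) {}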
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
namespace platform = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
struct CondFunctor {
HOSTDEVICE inline CondFunctor() {}
HOSTDEVICE inline T operator()(const bool cond, const T x, const T y) const {
return cond ? x : y;
}
};
template <typename T>
__global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x,
const T* y, T* out) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) {
out[idx] = cond[idx] ? x[idx] : y[idx];
}
}
template <typename T>
__global__ void WhereGradCUDAKernel(const int N, const T* dout,
const bool* cond, T* dx, T* dy) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) {
if (dx != nullptr) {
dx[idx] = cond[idx] ? dout[idx] : 0.;
}
if (dy != nullptr) {
dy[idx] = cond[idx] ? 0. : dout[idx];
}
}
}
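// Both kernels above use the grid-stride loop idiom: each thread starts at
// its global index and advances by the total number of launched threads, so
// a launch of any grid size covers all N elements:
//
//   int idx = blockDim.x * blockIdx.x + threadIdx.x;
//   for (; idx < N; idx += blockDim.x * gridDim.x) { /* handle idx */ }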
template <typename T>
class WhereKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::Tensor>("Condition");
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
auto numel = condition->numel();
// TODO(GaaoWei8): Input of where can be broadcast
const bool* cond_data = condition->data<bool>();
const T* x_data = X->data<T>();
const T* y_data = Y->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
auto stream = context.cuda_device_context().stream();
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
auto functor = CondFunctor<T>();
std::vector<const framework::Tensor*> ins = {condition, X, Y};
std::vector<framework::Tensor*> outs = {out};
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
&outs, functor);
}
};
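// Note: Compute now hands {Condition, X, Y} to the element-wise launcher as
// three same-shaped inputs and applies CondFunctor per element; the raw
// pointers and stream fetched above are leftovers from the earlier direct
// WhereCUDAKernel launch and are no longer used.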
template <typename T>
class WhereGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::Tensor>("Condition");
const bool* cond_data = condition->data<bool>();
auto numel = condition->numel();
auto* dout_t =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx_t = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy_t = context.Output<framework::Tensor>(framework::GradVarName("Y"));
auto* dout = dout_t->data<T>();
T* dx =
(dx_t != nullptr) ? dx_t->mutable_data<T>(context.GetPlace()) : nullptr;
T* dy =
(dy_t != nullptr) ? dy_t->mutable_data<T>(context.GetPlace()) : nullptr;
auto stream = context.cuda_device_context().stream();
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
auto config = GetGpuLaunchConfig1D(dev_ctx, condition->numel());
WhereGradCUDAKernel<
T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
numel, dout, cond_data, dx, dy);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
where, paddle::operators::WhereKernel<platform::CUDADeviceContext, float>,
paddle::operators::WhereKernel<platform::CUDADeviceContext, double>,
paddle::operators::WhereKernel<platform::CUDADeviceContext, int>,
paddle::operators::WhereKernel<platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
where_grad,
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, float>,
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, double>,
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, int>,
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class WhereKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::Tensor>("Condition");
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
const bool* cond_data = condition->data<bool>();
const T* x_data = X->data<T>();
const T* y_data = Y->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
auto x_numel = X->numel();
for (int i = 0; i < x_numel; i++) {
out_data[i] = cond_data[i] ? x_data[i] : y_data[i];
}
}
};
template <typename DeviceContext, typename T>
class WhereGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::LoDTensor>("Condition");
const auto* cond_data = condition->data<bool>();
auto numel = condition->numel();
auto* dout_t =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx_t = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy_t = context.Output<framework::Tensor>(framework::GradVarName("Y"));
auto* dout = dout_t->data<T>();
if (dx_t != nullptr) {
auto* dx = dx_t->mutable_data<T>(context.GetPlace());
for (int i = 0; i < numel; i++) {
dx[i] = dout[i] * (cond_data[i] ? 1. : 0.);
}
}
if (dy_t != nullptr) {
auto* dy = dy_t->mutable_data<T>(context.GetPlace());
for (int i = 0; i < numel; i++) {
dy[i] = dout[i] * (cond_data[i] ? 0. : 1.);
}
}
}
};
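// Gradient derivation: since out[i] = cond[i] ? x[i] : y[i], it follows that
//   dx[i] = dout[i] * (cond[i] ? 1 : 0) and dy[i] = dout[i] * (cond[i] ? 0 : 1),
// which is exactly what the two loops above compute.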
} // namespace operators
} // namespace paddle
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
@@ -14,7 +14,7 @@
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
......
@@ -306,8 +306,7 @@ void CrossInferMeta(const MetaTensor& x,
}
void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
auto in_dims = x.dims();
out->set_dims(in_dims);
out->share_meta(x);
}
void BCELossInferMeta(const MetaTensor& input,
......
@@ -133,4 +133,29 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
out->share_lod(*x.at(0));
}
void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out) {
auto cond_dims = condition.dims();
auto x_dims = x.dims();
auto y_dims = y.dims();
PADDLE_ENFORCE_EQ(
cond_dims,
x_dims,
phi::errors::InvalidArgument(
"The dims of Inputs(Condition) and Inputs(X) should be same. "
"But received Condition's shape is [%s], X's shape is [%s]",
cond_dims,
x_dims));
PADDLE_ENFORCE_EQ(x_dims,
y_dims,
phi::errors::InvalidArgument(
"The dims of Inputs(X) and Inputs(Y) should be same. "
"But received X's shape is [%s], Y's shape is [%s]",
x_dims,
y_dims));
out->share_meta(x);
}
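// share_meta(x) propagates dims, dtype, layout, and LoD from X to Out in one
// call, subsuming the SetOutputDim/ShareLoD pair deleted from
// WhereOp::InferShape above.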
} // namespace phi
@@ -30,4 +30,8 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out);
} // namespace phi
@@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/atan2_grad_kernel.h"
#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h"
PD_REGISTER_KERNEL(atan2_grad,
CPU,
......
@@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/atan2_kernel.h"
#include "paddle/phi/kernels/impl/atan2_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/atan2_kernel_impl.h"
PD_REGISTER_KERNEL(atan2,
CPU,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/where_grad_kernel.h"
namespace phi {
template <typename T, typename Context>
void WhereGradKernel(const Context& ctx,
const DenseTensor& condition,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
const auto* cond_data = condition.data<bool>();
auto numel = condition.numel();
auto* dout = out_grad.data<T>();
if (x_grad != nullptr) {
auto* dx = ctx.template Alloc<T>(x_grad);
for (int i = 0; i < numel; i++) {
dx[i] = dout[i] * (cond_data[i] ? 1. : 0.);
}
}
if (y_grad != nullptr) {
auto* dy = ctx.template Alloc<T>(y_grad);
for (int i = 0; i < numel; i++) {
dy[i] = dout[i] * (cond_data[i] ? 0. : 1.);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(where_grad,
CPU,
ALL_LAYOUT,
phi::WhereGradKernel,
float,
double,
int,
int64_t) {}
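// The trailing {} is the registration's customization body, where a kernel
// may adjust per-argument properties such as dtypes; where_grad needs no such
// customization, so it is left empty.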
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/where_kernel.h"
namespace phi {
template <typename T, typename Context>
void WhereKernel(const Context& ctx,
const DenseTensor& condition,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
const bool* cond_data = condition.data<bool>();
const T* x_data = x.data<T>();
const T* y_data = y.data<T>();
auto x_numel = x.numel();
T* out_data = ctx.template Alloc<T>(out);
for (int i = 0; i < x_numel; i++) {
out_data[i] = cond_data[i] ? x_data[i] : y_data[i];
}
}
} // namespace phi
PD_REGISTER_KERNEL(
where, CPU, ALL_LAYOUT, phi::WhereKernel, float, double, int, int64_t) {}
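// A purely illustrative direct invocation of the CPU kernel; the context and
// tensor setup below are assumed for the sketch and are not part of this
// patch:
//
//   phi::CPUContext dev_ctx;                   // assumed initialized
//   phi::DenseTensor cond, x, y, out;          // cond holds bool, x/y float
//   phi::WhereKernel<float, phi::CPUContext>(dev_ctx, cond, x, y, &out);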
@@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/atan2_grad_kernel.h"
#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h"
PD_REGISTER_KERNEL(atan2_grad,
GPU,
......
@@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/atan2_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/atan2_kernel.h"
#include "paddle/phi/kernels/impl/atan2_kernel_impl.h"
PD_REGISTER_KERNEL(atan2,
GPU,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/where_grad_kernel.h"
namespace phi {
template <typename T>
__global__ void WhereGradCUDAKernel(
const int N, const T* dout, const bool* cond, T* dx, T* dy) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) {
if (dx != nullptr) {
dx[idx] = cond[idx] ? dout[idx] : 0.;
}
if (dy != nullptr) {
dy[idx] = cond[idx] ? 0. : dout[idx];
}
}
}
template <typename T, typename Context>
void WhereGradKernel(const Context& ctx,
const DenseTensor& condition,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad) {
const bool* cond_data = condition.data<bool>();
auto numel = condition.numel();
auto* dout = out_grad.data<T>();
T* dx = (x_grad != nullptr) ? ctx.template Alloc<T>(x_grad) : nullptr;
T* dy = (y_grad != nullptr) ? ctx.template Alloc<T>(y_grad) : nullptr;
auto stream = ctx.stream();
auto config = backends::gpu::GetGpuLaunchConfig1D(ctx, numel);
WhereGradCUDAKernel<
T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
numel, dout, cond_data, dx, dy);
}
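// GetGpuLaunchConfig1D sizes a 1-D grid and block for `numel`; combined with
// the grid-stride loop in WhereGradCUDAKernel, the launch stays correct even
// if the grid is capped below one thread per element.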
} // namespace phi
PD_REGISTER_KERNEL(where_grad,
GPU,
ALL_LAYOUT,
phi::WhereGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/where_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
namespace phi {
// Cond
template <typename T>
struct CondFunctor {
inline HOSTDEVICE T operator()(const bool cond, const T x, const T y) const {
return cond ? x : y;
}
};
template <typename T, typename Context>
void WhereKernel(const Context& ctx,
const DenseTensor& condition,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
std::vector<const DenseTensor*> ins = {&condition, &x, &y};
std::vector<DenseTensor*> outs = {out};
ctx.template Alloc<T>(out);
CondFunctor<T> func;
funcs::BroadcastKernel<ElementwiseType::kTernary, T, T>(
ctx, ins, &outs, -1, func);
}
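// funcs::BroadcastKernel with ElementwiseType::kTernary applies CondFunctor
// across the three inputs; axis = -1 asks it to infer the broadcast axis, and
// under the equal-dims constraint enforced by WhereInferMeta it behaves as a
// plain same-dims element-wise apply.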
} // namespace phi
PD_REGISTER_KERNEL(
where, GPU, ALL_LAYOUT, phi::WhereKernel, float, double, int, int64_t) {}
@@ -14,9 +14,10 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/atan2_grad_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
......
@@ -14,9 +14,10 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/atan2_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T>
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void WhereGradKernel(const Context& ctx,
const DenseTensor& condition,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void WhereKernel(const Context& ctx,
const DenseTensor& condition,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature WhereGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("where_grad",
{"Condition", "X", "Y", GradVarName("Out")},
{},
{GradVarName("X"), GradVarName("Y")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(where_grad, phi::WhereGradOpArgumentMapping);
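// Reading the mapping: fluid inputs {Condition, X, Y, Out@GRAD} feed the phi
// kernel parameters (condition, x, y, out_grad), there are no attributes, and
// the outputs bind to {X@GRAD, Y@GRAD}, matching the WhereGradKernel
// signature declared in where_grad_kernel.h.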