diff --git a/paddle/fluid/operators/where_op.cc b/paddle/fluid/operators/where_op.cc index 92ed2bbdc33f55315b3dddf8dc106b7716e97a6f..0f10efefa137b698b59db23b67122df990cfa366 100644 --- a/paddle/fluid/operators/where_op.cc +++ b/paddle/fluid/operators/where_op.cc @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/where_op.h" - +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { @@ -21,31 +23,6 @@ class WhereOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "Where"); - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Where"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Where"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Where"); - - auto cond_dims = ctx->GetInputDim("Condition"); - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_EQ( - cond_dims, x_dims, - platform::errors::InvalidArgument( - "The dims of Inputs(Condition) and Inputs(X) should be same. " - "But received Condition's shape is [%s], X's shape is [%s]", - cond_dims, x_dims)); - PADDLE_ENFORCE_EQ(x_dims, y_dims, - platform::errors::InvalidArgument( - "The dims of Inputs(X) and Inputs(Y) should be same. " - "But received X's shape is [%s], Y's shape is [%s]", - x_dims, y_dims)); - - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - ctx->ShareLoD("X", /*->*/ "Out"); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -140,19 +117,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInferer, "X", "Y"); } // namespace paddle namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(where, WhereInferShapeFunctor, + PT_INFER_META(phi::WhereInferMeta)); REGISTER_OPERATOR(where, ops::WhereOp, ops::WhereOpMaker, ops::WhereOpGradMaker, - ops::WhereOpGradMaker); + ops::WhereOpGradMaker, + WhereInferShapeFunctor); REGISTER_OPERATOR(where_grad, ops::WhereGradOp, ops::WhereGradNoNeedBufferVarsInferer); -REGISTER_OP_CPU_KERNEL( - where, ops::WhereKernel, - ops::WhereKernel, - ops::WhereKernel, - ops::WhereKernel); -REGISTER_OP_CPU_KERNEL( - where_grad, ops::WhereGradKernel, - ops::WhereGradKernel, - ops::WhereGradKernel, - ops::WhereGradKernel); diff --git a/paddle/fluid/operators/where_op.cu b/paddle/fluid/operators/where_op.cu deleted file mode 100644 index 61a1691e4fe265035917ed2407d5e3e24aa6bd88..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/where_op.cu +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#include "paddle/fluid/operators/where_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" - -namespace platform = paddle::platform; - -namespace paddle { -namespace operators { - -template -struct CondFunctor { - HOSTDEVICE inline CondFunctor() {} - - HOSTDEVICE inline T operator()(const bool cond, const T x, const T y) const { - return cond ? x : y; - } -}; - -template -__global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x, - const T* y, T* out) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - for (; idx < N; idx += blockDim.x * gridDim.x) { - out[idx] = cond[idx] ? x[idx] : y[idx]; - } -} - -template -__global__ void WhereGradCUDAKernel(const int N, const T* dout, - const bool* cond, T* dx, T* dy) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - for (; idx < N; idx += blockDim.x * gridDim.x) { - if (dx != nullptr) { - dx[idx] = cond[idx] ? dout[idx] : 0.; - } - if (dy != nullptr) { - dy[idx] = cond[idx] ? 0. : dout[idx]; - } - } -} - -template -class WhereKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* condition = context.Input("Condition"); - auto* X = context.Input("X"); - auto* Y = context.Input("Y"); - auto* out = context.Output("Out"); - auto numel = condition->numel(); - - // TODO(GaaoWei8): Input of where can be broadcast - const bool* cond_data = condition->data(); - const T* x_data = X->data(); - const T* y_data = Y->data(); - T* out_data = out->mutable_data(context.GetPlace()); - - auto stream = context.cuda_device_context().stream(); - auto& dev_ctx = - context.template device_context(); - auto functor = CondFunctor(); - std::vector ins = {condition, X, Y}; - std::vector outs = {out}; - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); - } -}; - -template -class WhereGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* condition = context.Input("Condition"); - const bool* cond_data = condition->data(); - auto numel = condition->numel(); - - auto* dout_t = - context.Input(framework::GradVarName("Out")); - auto* dx_t = context.Output(framework::GradVarName("X")); - auto* dy_t = context.Output(framework::GradVarName("Y")); - auto* dout = dout_t->data(); - T* dx = - (dx_t != nullptr) ? dx_t->mutable_data(context.GetPlace()) : nullptr; - T* dy = - (dy_t != nullptr) ? dy_t->mutable_data(context.GetPlace()) : nullptr; - - auto stream = context.cuda_device_context().stream(); - auto& dev_ctx = - context.template device_context(); - auto config = GetGpuLaunchConfig1D(dev_ctx, condition->numel()); - WhereGradCUDAKernel< - T><<>>( - numel, dout, cond_data, dx, dy); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OP_CUDA_KERNEL( - where, paddle::operators::WhereKernel, - paddle::operators::WhereKernel, - paddle::operators::WhereKernel, - paddle::operators::WhereKernel); -REGISTER_OP_CUDA_KERNEL( - where_grad, - paddle::operators::WhereGradKernel, - paddle::operators::WhereGradKernel, - paddle::operators::WhereGradKernel, - paddle::operators::WhereGradKernel); diff --git a/paddle/fluid/operators/where_op.h b/paddle/fluid/operators/where_op.h deleted file mode 100644 index 5398ee024a2890e38e88fc981721872e1ba34d60..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/where_op.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -template -class WhereKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* condition = context.Input("Condition"); - auto* X = context.Input("X"); - auto* Y = context.Input("Y"); - auto* out = context.Output("Out"); - - const bool* cond_data = condition->data(); - const T* x_data = X->data(); - const T* y_data = Y->data(); - T* out_data = out->mutable_data(context.GetPlace()); - - auto x_numel = X->numel(); - for (int i = 0; i < x_numel; i++) { - out_data[i] = cond_data[i] ? x_data[i] : y_data[i]; - } - } -}; - -template -class WhereGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* condition = context.Input("Condition"); - const auto* cond_data = condition->data(); - auto numel = condition->numel(); - - auto* dout_t = - context.Input(framework::GradVarName("Out")); - auto* dx_t = context.Output(framework::GradVarName("X")); - auto* dy_t = context.Output(framework::GradVarName("Y")); - - auto* dout = dout_t->data(); - if (dx_t != nullptr) { - auto* dx = dx_t->mutable_data(context.GetPlace()); - for (int i = 0; i < numel; i++) { - dx[i] = dout[i] * (cond_data[i] ? 1. : 0.); - } - } - if (dy_t != nullptr) { - auto* dy = dy_t->mutable_data(context.GetPlace()); - for (int i = 0; i < numel; i++) { - dy[i] = dout[i] * (cond_data[i] ? 0. : 1.); - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/where_op_npu.cc b/paddle/fluid/operators/where_op_npu.cc index d4294393daa34612aae815b0ebfab7d55f0b9f46..35508950941783753734a916aa7c2dcff7731181 100755 --- a/paddle/fluid/operators/where_op_npu.cc +++ b/paddle/fluid/operators/where_op_npu.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/where_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/where_op_xpu.cc b/paddle/fluid/operators/where_op_xpu.cc index 3a4875c07005119e90f5d5cb448a63bcf62a09a4..41232c8b5e8d88564e59e0343a26a4ae98d5ed90 100644 --- a/paddle/fluid/operators/where_op_xpu.cc +++ b/paddle/fluid/operators/where_op_xpu.cc @@ -14,7 +14,7 @@ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/where_op.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 1905e33bd03160d526f44addba1f614ff2ac3bd1..675e68af74339b508f589a55a9c3cf3aed37cecb 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -306,8 +306,7 @@ void CrossInferMeta(const MetaTensor& x, } void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { - auto in_dims = x.dims(); - out->set_dims(in_dims); + out->share_meta(x); } void BCELossInferMeta(const MetaTensor& input, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 8857c2cf424e4441ef8f08de1e1a2fa4b44c84d6..7634e5e01aca4cdaf7fb46399f9594897f2d0e36 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -133,4 +133,29 @@ void ConcatInferMeta(const std::vector& x, out->share_lod(*x.at(0)); } +void WhereInferMeta(const MetaTensor& condition, + const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out) { + auto cond_dims = condition.dims(); + auto x_dims = x.dims(); + auto y_dims = y.dims(); + PADDLE_ENFORCE_EQ( + cond_dims, + x_dims, + phi::errors::InvalidArgument( + "The dims of Inputs(Condition) and Inputs(X) should be same. " + "But received Condition's shape is [%s], X's shape is [%s]", + cond_dims, + x_dims)); + PADDLE_ENFORCE_EQ(x_dims, + y_dims, + phi::errors::InvalidArgument( + "The dims of Inputs(X) and Inputs(Y) should be same. " + "But received X's shape is [%s], Y's shape is [%s]", + x_dims, + y_dims)); + out->share_meta(x); +} + } // namespace phi diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 473845c6e409f481a04ba73b5e418028eafa0116..2afb79daa355cc897e3bf4076003e9a41de8b96c 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -30,4 +30,8 @@ void ConcatInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void WhereInferMeta(const MetaTensor& condition, + const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/cpu/atan2_grad_kernel.cc b/paddle/phi/kernels/cpu/atan2_grad_kernel.cc index 6ff7431f0c8c556770b54e1328251e5996850fc9..7a519aab0ad71e4cd20270b216bf65262cab8ba6 100644 --- a/paddle/phi/kernels/cpu/atan2_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/atan2_grad_kernel.cc @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/atan2_grad_kernel.h" +#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h" + #include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h" PD_REGISTER_KERNEL(atan2_grad, CPU, diff --git a/paddle/phi/kernels/cpu/atan2_kernel.cc b/paddle/phi/kernels/cpu/atan2_kernel.cc index eb38a6c90b7938ef16cf9d56dfdb93903cc3c6a1..df6f5f59ac0056f36749faec8a300c1b5a1da1c9 100644 --- a/paddle/phi/kernels/cpu/atan2_kernel.cc +++ b/paddle/phi/kernels/cpu/atan2_kernel.cc @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/atan2_kernel.h" +#include "paddle/phi/kernels/impl/atan2_kernel_impl.h" + #include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/atan2_kernel_impl.h" PD_REGISTER_KERNEL(atan2, CPU, diff --git a/paddle/phi/kernels/cpu/where_grad_kernel.cc b/paddle/phi/kernels/cpu/where_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..67c8cee1038c7a990e5961a3fcd17e8d7c591207 --- /dev/null +++ b/paddle/phi/kernels/cpu/where_grad_kernel.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/where_grad_kernel.h" + +namespace phi { + +template +void WhereGradKernel(const Context& ctx, + const DenseTensor& condition, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* y_grad) { + const auto* cond_data = condition.data(); + auto numel = condition.numel(); + auto* dout = out_grad.data(); + + if (x_grad != nullptr) { + auto* dx = ctx.template Alloc(x_grad); + for (int i = 0; i < numel; i++) { + dx[i] = dout[i] * (cond_data[i] ? 1. : 0.); + } + } + if (y_grad != nullptr) { + auto* dy = ctx.template Alloc(y_grad); + for (int i = 0; i < numel; i++) { + dy[i] = dout[i] * (cond_data[i] ? 0. : 1.); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(where_grad, + CPU, + ALL_LAYOUT, + phi::WhereGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/where_kernel.cc b/paddle/phi/kernels/cpu/where_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..f624c13c262296964cef6b98f7d5d26dfc0b7d56 --- /dev/null +++ b/paddle/phi/kernels/cpu/where_kernel.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/where_kernel.h" + +namespace phi { + +template +void WhereKernel(const Context& ctx, + const DenseTensor& condition, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + const bool* cond_data = condition.data(); + const T* x_data = x.data(); + const T* y_data = y.data(); + auto x_numel = x.numel(); + + T* out_data = ctx.template Alloc(out); + + for (int i = 0; i < x_numel; i++) { + out_data[i] = cond_data[i] ? x_data[i] : y_data[i]; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + where, CPU, ALL_LAYOUT, phi::WhereKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/gpu/atan2_grad_kernel.cu b/paddle/phi/kernels/gpu/atan2_grad_kernel.cu index 1cc3311c3639820ef9b6d3a29d9274ac93bb5963..6652d242de5ce44f3bf64d91e6fae16c648c2726 100644 --- a/paddle/phi/kernels/gpu/atan2_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/atan2_grad_kernel.cu @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h" + #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/atan2_grad_kernel.h" -#include "paddle/phi/kernels/impl/atan2_grad_kernel_impl.h" PD_REGISTER_KERNEL(atan2_grad, GPU, diff --git a/paddle/phi/kernels/gpu/atan2_kernel.cu b/paddle/phi/kernels/gpu/atan2_kernel.cu index 702c959b78f75d0e52511d9bdc9d4330c6838aa4..dd0bba177defef7cdbd41ef7944110d126ca2d7c 100644 --- a/paddle/phi/kernels/gpu/atan2_kernel.cu +++ b/paddle/phi/kernels/gpu/atan2_kernel.cu @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/impl/atan2_kernel_impl.h" + #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/atan2_kernel.h" -#include "paddle/phi/kernels/impl/atan2_kernel_impl.h" PD_REGISTER_KERNEL(atan2, GPU, diff --git a/paddle/phi/kernels/gpu/where_grad_kernel.cu b/paddle/phi/kernels/gpu/where_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..f21aca80e21b30de8931b4fcd4ae3922be959958 --- /dev/null +++ b/paddle/phi/kernels/gpu/where_grad_kernel.cu @@ -0,0 +1,64 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/where_grad_kernel.h" + +namespace phi { + +template +__global__ void WhereGradCUDAKernel( + const int N, const T* dout, const bool* cond, T* dx, T* dy) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < N; idx += blockDim.x * gridDim.x) { + if (dx != nullptr) { + dx[idx] = cond[idx] ? dout[idx] : 0.; + } + if (dy != nullptr) { + dy[idx] = cond[idx] ? 0. : dout[idx]; + } + } +} + +template +void WhereGradKernel(const Context& ctx, + const DenseTensor& condition, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* y_grad) { + const bool* cond_data = condition.data(); + auto numel = condition.numel(); + auto* dout = out_grad.data(); + + T* dx = (x_grad != nullptr) ? ctx.template Alloc(x_grad) : nullptr; + T* dy = (y_grad != nullptr) ? ctx.template Alloc(y_grad) : nullptr; + + auto stream = ctx.stream(); + auto config = backends::gpu::GetGpuLaunchConfig1D(ctx, numel); + WhereGradCUDAKernel< + T><<>>( + numel, dout, cond_data, dx, dy); +} + +} // namespace phi + +PD_REGISTER_KERNEL(where_grad, + GPU, + ALL_LAYOUT, + phi::WhereGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/where_kernel.cu b/paddle/phi/kernels/gpu/where_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..03c24eea3a95af1ed57f5c8df42b01fd09af1fa2 --- /dev/null +++ b/paddle/phi/kernels/gpu/where_kernel.cu @@ -0,0 +1,48 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/where_kernel.h" + +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" + +namespace phi { + +// Cond +template +struct CondFunctor { + inline HOSTDEVICE T operator()(const bool cond, const T x, const T y) const { + return cond ? x : y; + } +}; + +template +void WhereKernel(const Context& ctx, + const DenseTensor& condition, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + std::vector ins = {&condition, &x, &y}; + std::vector outs = {out}; + ctx.template Alloc(out); + + CondFunctor func; + funcs::BroadcastKernel( + ctx, ins, &outs, -1, func); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + where, GPU, ALL_LAYOUT, phi::WhereKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/impl/atan2_grad_kernel_impl.h b/paddle/phi/kernels/impl/atan2_grad_kernel_impl.h index d0dd18298518ab351918aa2492eb48d11d3cf1d7..0eff1378f41de9b31a35375f86ca69a427d19f4f 100644 --- a/paddle/phi/kernels/impl/atan2_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/atan2_grad_kernel_impl.h @@ -14,9 +14,10 @@ #pragma once -#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/atan2_grad_kernel.h" -#include "paddle/phi/kernels/funcs/for_range.h" + +#include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/core/dense_tensor.h" namespace phi { diff --git a/paddle/phi/kernels/impl/atan2_kernel_impl.h b/paddle/phi/kernels/impl/atan2_kernel_impl.h index 2cae914e2f61555377f7a41b3d89cdbb2b589247..7653032f2113c6e181673c57feaec2efd6472838 100644 --- a/paddle/phi/kernels/impl/atan2_kernel_impl.h +++ b/paddle/phi/kernels/impl/atan2_kernel_impl.h @@ -14,9 +14,10 @@ #pragma once -#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/atan2_kernel.h" -#include "paddle/phi/kernels/funcs/for_range.h" + +#include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/core/dense_tensor.h" namespace phi { template diff --git a/paddle/phi/kernels/where_grad_kernel.h b/paddle/phi/kernels/where_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1a3c66ee6ed8403d0b453ed38d21e4beed02661c --- /dev/null +++ b/paddle/phi/kernels/where_grad_kernel.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void WhereGradKernel(const Context& ctx, + const DenseTensor& condition, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* y_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/where_kernel.h b/paddle/phi/kernels/where_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..254271ac9c7238c66d09ffe41d12e29fe8f23237 --- /dev/null +++ b/paddle/phi/kernels/where_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void WhereKernel(const Context& ctx, + const DenseTensor& condition, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/where_grad_sig.cc b/paddle/phi/ops/compat/where_grad_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..71984a26d35afd841654d82480c263799bdbf181 --- /dev/null +++ b/paddle/phi/ops/compat/where_grad_sig.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature WhereGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("where_grad", + {"Condition", "X", "Y", GradVarName("Out")}, + {}, + {GradVarName("X"), GradVarName("Y")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(where_grad, phi::WhereGradOpArgumentMapping);