From c068512f345fc19199f9069f7e08ac147ac090b6 Mon Sep 17 00:00:00 2001
From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com>
Date: Wed, 8 Apr 2020 17:14:55 +0800
Subject: [PATCH] Implement a new C++ operator where and API tensor.where
 (#23220)

---
 paddle/fluid/operators/where_op.cc            | 159 ++++++++++++++++
 paddle/fluid/operators/where_op.cu            | 122 ++++++++++++
 paddle/fluid/operators/where_op.h             |  73 ++++++++
 .../fluid/tests/unittests/test_where_op.py    | 173 ++++++++++++++++++
 python/paddle/tensor/__init__.py              |   2 +-
 python/paddle/tensor/search.py                |  75 +++++++-
 6 files changed, 600 insertions(+), 4 deletions(-)
 create mode 100644 paddle/fluid/operators/where_op.cc
 create mode 100644 paddle/fluid/operators/where_op.cu
 create mode 100644 paddle/fluid/operators/where_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_where_op.py

diff --git a/paddle/fluid/operators/where_op.cc b/paddle/fluid/operators/where_op.cc
new file mode 100644
index 00000000000..bdb3fb24ded
--- /dev/null
+++ b/paddle/fluid/operators/where_op.cc
@@ -0,0 +1,159 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/where_op.h"
+
+namespace paddle {
+namespace operators {
+
+class WhereOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "Where");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Where");
+    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Where");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Where");
+
+    auto cond_dims = ctx->GetInputDim("Condition");
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
+    PADDLE_ENFORCE_EQ(
+        cond_dims, x_dims,
+        platform::errors::InvalidArgument(
+            "The dims of Inputs(Condition) and Inputs(X) should be the same. "
+            "But received Condition's shape is [%s], X's shape is [%s]",
+            cond_dims, x_dims));
+    PADDLE_ENFORCE_EQ(
+        x_dims, y_dims,
+        platform::errors::InvalidArgument(
+            "The dims of Inputs(X) and Inputs(Y) should be the same. "
+            "But received X's shape is [%s], Y's shape is [%s]",
+            x_dims, y_dims));
+
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class WhereGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "Where");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Where");
+    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Where");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "Where");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
+
+    auto x_grad_name = framework::GradVarName("X");
+    auto y_grad_name = framework::GradVarName("Y");
+
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->SetOutputDim(y_grad_name, y_dims);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
+  }
+};
+
+class WhereOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Condition",
+             "(Tensor) A bool tensor whose rank is at least 1. When Condition "
+             "is True, yield X, otherwise yield Y.");
+    AddInput("X",
+             "(Tensor), The first input tensor of where op. When the "
+             "corresponding position of the condition is true, the output "
+             "takes the element of X.");
+    AddInput("Y",
+             "(Tensor), The second input tensor of where op. When the "
+             "corresponding position of condition is false, the output takes "
+             "the element of Y.");
+    AddOutput("Out", "(Tensor), The output tensor of where op.");
+    AddComment(R"DOC(
+  Where Operator.
+  Return a tensor of elements selected from either $X$ or $Y$, depending on condition.
+  The equation is:
+  $$
+  Out_i =
+  \begin{cases}
+  X_i, & \text{if } Condition_i \text{ is True} \\
+  Y_i, & \text{if } Condition_i \text{ is False}
+  \end{cases}
+  $$
+)DOC");
+  }
+};
+
+template <typename T>
+class WhereOpGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> grad) const override {
+    grad->SetType("where_grad");
+    grad->SetInput("Condition", this->Input("Condition"));
+    grad->SetInput("X", this->Input("X"));
+    grad->SetInput("Y", this->Input("Y"));
+    grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInference, "X",
+                                    "Y");
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(where, ops::WhereOp, ops::WhereOpMaker,
+                  ops::WhereOpGradMaker<paddle::framework::OpDesc>,
+                  ops::WhereOpGradMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(where_grad, ops::WhereGradOp,
+                  ops::WhereGradNoNeedBufferVarsInference);
+REGISTER_OP_CPU_KERNEL(
+    where, ops::WhereKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::WhereKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::WhereKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::WhereKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    where_grad,
+    ops::WhereGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::WhereGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::WhereGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::WhereGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/where_op.cu b/paddle/fluid/operators/where_op.cu
new file mode 100644
index 00000000000..0ec1a3c6fa6
--- /dev/null
+++ b/paddle/fluid/operators/where_op.cu
@@ -0,0 +1,122 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/where_op.h"
+#include "paddle/fluid/platform/gpu_launch_param_config.h"
+
+namespace platform = paddle::platform;
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+__global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x,
+                                const T* y, T* out) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < N; idx += blockDim.x * gridDim.x) {
+    out[idx] = cond[idx] ? x[idx] : y[idx];
+  }
+}
+
+template <typename T>
+__global__ void WhereGradCUDAKernel(const int N, const T* out,
+                                    const bool* cond, T* x, T* y) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < N; idx += blockDim.x * gridDim.x) {
+    if (x != nullptr) {
+      x[idx] = out[idx] * (cond[idx] ? 1. : 0.);
+    }
+    if (y != nullptr) {
+      y[idx] = out[idx] * (cond[idx] ? 0. : 1.);
+    }
+  }
+}
+
+template <typename T>
+class WhereKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(context.GetPlace()), true,
+        platform::errors::PermissionDenied("It must use CUDAPlace."));
+    auto* condition = context.Input<framework::Tensor>("Condition");
+    auto* X = context.Input<framework::Tensor>("X");
+    auto* Y = context.Input<framework::Tensor>("Y");
+    auto* out = context.Output<framework::Tensor>("Out");
+    auto numel = condition->numel();
+
+    // TODO(GaaoWei8): Input of where can be broadcast
+    const bool* cond_data = condition->data<bool>();
+    const T* x_data = X->data<T>();
+    const T* y_data = Y->data<T>();
+    T* out_data = out->mutable_data<T>(context.GetPlace());
+
+    auto stream = context.cuda_device_context().stream();
+    auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+    auto config = GetGpuLaunchConfig1D(dev_ctx, numel);
+    WhereCUDAKernel<T><<<config.block_per_grid.x, config.thread_per_block.x,
+                         0, stream>>>(numel, cond_data, x_data, y_data,
+                                      out_data);
+  }
+};
+
+template <typename T>
+class WhereGradKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(context.GetPlace()), true,
+        platform::errors::PermissionDenied("It must use CUDAPlace."));
+
+    auto* condition = context.Input<framework::Tensor>("Condition");
+    const bool* cond_data = condition->data<bool>();
+    auto numel = condition->numel();
+
+    auto* dout_t =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dx_t =
+        context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* dy_t =
+        context.Output<framework::Tensor>(framework::GradVarName("Y"));
+    auto* dout = dout_t->data<T>();
+    T* dx = (dx_t != nullptr) ? dx_t->mutable_data<T>(context.GetPlace())
+                              : nullptr;
+    T* dy = (dy_t != nullptr) ? dy_t->mutable_data<T>(context.GetPlace())
+                              : nullptr;
+
+    auto stream = context.cuda_device_context().stream();
+    auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
+    auto config = GetGpuLaunchConfig1D(dev_ctx, condition->numel());
+    WhereGradCUDAKernel<T><<<config.block_per_grid.x,
+                             config.thread_per_block.x, 0, stream>>>(
+        numel, dout, cond_data, dx, dy);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP_CUDA_KERNEL(
+    where,
+    paddle::operators::WhereKernel<paddle::platform::CUDADeviceContext, float>,
+    paddle::operators::WhereKernel<paddle::platform::CUDADeviceContext,
+                                   double>,
+    paddle::operators::WhereKernel<paddle::platform::CUDADeviceContext, int>,
+    paddle::operators::WhereKernel<paddle::platform::CUDADeviceContext,
+                                   int64_t>);
REGISTER_OP_CUDA_KERNEL(
+    where_grad,
+    paddle::operators::WhereGradKernel<paddle::platform::CUDADeviceContext,
+                                       float>,
+    paddle::operators::WhereGradKernel<paddle::platform::CUDADeviceContext,
+                                       double>,
+    paddle::operators::WhereGradKernel<paddle::platform::CUDADeviceContext,
+                                       int>,
+    paddle::operators::WhereGradKernel<paddle::platform::CUDADeviceContext,
+                                       int64_t>);
diff --git a/paddle/fluid/operators/where_op.h b/paddle/fluid/operators/where_op.h
new file mode 100644
index 00000000000..fdb65858eff
--- /dev/null
+++ b/paddle/fluid/operators/where_op.h
@@ -0,0 +1,73 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class WhereKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* condition = context.Input<framework::Tensor>("Condition");
+    auto* X = context.Input<framework::Tensor>("X");
+    auto* Y = context.Input<framework::Tensor>("Y");
+    auto* out = context.Output<framework::Tensor>("Out");
+
+    const bool* cond_data = condition->data<bool>();
+    const T* x_data = X->data<T>();
+    const T* y_data = Y->data<T>();
+    T* out_data = out->mutable_data<T>(context.GetPlace());
+
+    auto x_numel = X->numel();
+    for (int i = 0; i < x_numel; i++) {
+      out_data[i] = cond_data[i] ? x_data[i] : y_data[i];
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class WhereGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* condition = context.Input<framework::Tensor>("Condition");
+    const auto* cond_data = condition->data<bool>();
+    auto numel = condition->numel();
+
+    auto* dout_t =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dx_t =
+        context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* dy_t =
+        context.Output<framework::Tensor>(framework::GradVarName("Y"));
+
+    auto* dout = dout_t->data<T>();
+    if (dx_t != nullptr) {
+      auto* dx = dx_t->mutable_data<T>(context.GetPlace());
+      for (int i = 0; i < numel; i++) {
+        dx[i] = dout[i] * (cond_data[i] ? 1. : 0.);
+      }
+    }
+    if (dy_t != nullptr) {
+      auto* dy = dy_t->mutable_data<T>(context.GetPlace());
+      for (int i = 0; i < numel; i++) {
+        dy[i] = dout[i] * (cond_data[i] ? 0. : 1.);
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_where_op.py b/python/paddle/fluid/tests/unittests/test_where_op.py
new file mode 100644
index 00000000000..1ae311a9c46
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_where_op.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+import paddle.tensor as tensor
+import paddle.fluid.core as core
+from op_test import OpTest
+from paddle.fluid import compiler, Program, program_guard
+from paddle.fluid.op import Operator
+from paddle.fluid.backward import append_backward
+
+
+class TestWhereOp(OpTest):
+    def setUp(self):
+        self.op_type = "where"
+        self.init_config()
+        self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y}
+        self.outputs = {'Out': np.where(self.cond, self.x, self.y)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X', 'Y'], 'Out')
+
+    def init_config(self):
+        self.x = np.random.uniform(-3, 5, (100)).astype("float64")
+        self.y = np.random.uniform(-3, 5, (100)).astype("float64")
+        self.cond = np.zeros((100)).astype("bool")
+
+
+class TestWhereOp2(TestWhereOp):
+    def init_config(self):
+        self.x = np.random.uniform(-5, 5, (60, 2)).astype("float64")
+        self.y = np.random.uniform(-5, 5, (60, 2)).astype("float64")
+        self.cond = np.ones((60, 2)).astype("bool")
+
+
+class TestWhereOp3(TestWhereOp):
+    def init_config(self):
+        self.x = np.random.uniform(-3, 5, (20, 2, 4)).astype("float64")
+        self.y = np.random.uniform(-3, 5, (20, 2, 4)).astype("float64")
+        self.cond = np.array(np.random.randint(2, size=(20, 2, 4)), dtype=bool)
+
+
+class TestWhereAPI(unittest.TestCase):
+    def test_api(self, use_cuda=False):
+        main_program = Program()
+        with fluid.program_guard(main_program):
+            x = fluid.layers.data(name='x', shape=[4], dtype='float32')
+            y = fluid.layers.data(name='y', shape=[4], dtype='float32')
+            x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
+            y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
+            cond_i = np.array([False, False, True, True]).astype("bool")
+            result = tensor.where(x > 1, X=x, Y=y)
+
+            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            out = exe.run(fluid.default_main_program(),
+                          feed={'x': x_i,
+                                'y': y_i},
+                          fetch_list=[result])
+            assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
+
+    def test_grad(self, use_cuda=False):
+        main_program = Program()
+        for x_stop_gradient, y_stop_gradient in [[False, False], [True, False],
+                                                 [False, True]]:
+            with fluid.program_guard(main_program):
+                x = fluid.layers.data(name='x', shape=[4], dtype='float32')
+                y = fluid.layers.data(name='y', shape=[4], dtype='float32')
+                x.stop_gradient = x_stop_gradient
+                y.stop_gradient = y_stop_gradient
+                x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
+                y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
+                cond_i = np.array([False, False, True, True]).astype("bool")
+                result = tensor.where(x > 1, X=x, Y=y)
+                x_mean = layers.mean(x)
+                append_backward(x_mean)
+                y_mean = layers.mean(y)
+                append_backward(y_mean)
+
+                place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+                exe = fluid.Executor(place)
+                out = exe.run(fluid.default_main_program(),
+                              feed={'x': x_i,
+                                    'y': y_i},
+                              fetch_list=[result, x.grad_name, y.grad_name])
+                x_grad = [0.25] * 4
+                y_grad = [0.25] * 4
+                assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
+                assert np.array_equal(out[1], x_grad)
+                assert np.array_equal(out[2], y_grad)
+
+    def test_api_broadcast(self, use_cuda=False):
+        main_program = Program()
+        with fluid.program_guard(main_program):
+            x = fluid.layers.data(name='x', shape=[4, 1], dtype='float32')
+            y = fluid.layers.data(name='y', shape=[4, 2], dtype='float32')
+            x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32")
+            y_i = np.array(
+                [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype("float32")
+            cond_i = np.array([[False, False, True, True],
+                               [False, False, True, True]]).astype("bool")
+            result = tensor.where(x > 1, X=x, Y=y)
+
+            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            out = exe.run(fluid.default_main_program(),
+                          feed={'x': x_i,
+                                'y': y_i},
+                          fetch_list=[result])
+            assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
+
+    def test_fw_bw(self):
+        if core.is_compiled_with_cuda():
+            self.test_api(use_cuda=True)
+            self.test_api_broadcast(use_cuda=True)
+            self.test_grad(use_cuda=True)
+
+
+class TestWhereDygraphAPI(unittest.TestCase):
+    def test_api(self):
+        with fluid.dygraph.guard():
+            x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
+            y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64")
+            cond_i = np.array([False, False, True, True]).astype("bool")
+            x = fluid.dygraph.to_variable(x_i)
+            y = fluid.dygraph.to_variable(y_i)
+            cond = fluid.dygraph.to_variable(cond_i)
+            out = tensor.where(cond, x, y)
+            assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i))
+
+
+class TestWhereOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
+            y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64")
+            cond_i = np.array([False, False, True, True]).astype("bool")
+
+            def test_Variable():
+                tensor.where(cond_i, x_i, y_i)
+
+            self.assertRaises(TypeError, test_Variable)
+
+            def test_type():
+                x = fluid.layers.data(name='x', shape=[4], dtype='bool')
+                y = fluid.layers.data(name='y', shape=[4], dtype='float16')
+                cond = fluid.layers.data(name='cond', shape=[4], dtype='int32')
+                tensor.where(cond, x, y)
+
+            self.assertRaises(TypeError, test_type)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 99b5453435f..64b4e307891 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -165,7 +165,7 @@ from .search import argmax  #DEFINE_ALIAS
 # from .search import has_nan  #DEFINE_ALIAS
 # from .search import masked_select  #DEFINE_ALIAS
 # from .search import topk  #DEFINE_ALIAS
-# from .search import where  #DEFINE_ALIAS
+from .search import where  #DEFINE_ALIAS
 # from .search import index_select  #DEFINE_ALIAS
 from .search import index_sample  # DEFINE_ALIAS
 # from .search import nonzero  #DEFINE_ALIAS
diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py
index 3a4ca7ab7bd..c66effd1d75 100644
--- a/python/paddle/tensor/search.py
+++ b/python/paddle/tensor/search.py
@@ -12,10 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import print_function
+import numpy as np
+import warnings
+import six
+import os
+import inspect
 from ..fluid.layer_helper import LayerHelper
 from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
-
-# TODO: define searching & indexing functions of a tensor
+from ..fluid.initializer import Normal, Constant, NumpyArrayInitializer
+from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program
+from ..fluid import dygraph_utils
+from ..fluid.param_attr import ParamAttr
+from ..fluid import unique_name
+from ..fluid import core, layers
+
+# TODO: define searching & indexing functions of a tensor
 __all__ = [
     'argmax',
     #  'argmin',
@@ -24,7 +35,7 @@ __all__ = [
     #  'has_nan',
     #  'masked_select',
     #  'topk',
-    #  'where',
+    'where',
     #  'index_select',
     #  'nonzero',
     'sort',
@@ -213,6 +224,64 @@ def sort(input, axis=-1, descending=False, out=None, name=None):
     return out, ids
 
 
+def where(Condition, X, Y):
+    """
+    Return a tensor of elements selected from either $X$ or $Y$, depending on $Condition$.
+
+    Args:
+        Condition(Variable): A bool tensor with rank at least 1, the data type is bool.
+        X(Variable): X is a Tensor Variable.
+        Y(Variable): Y is a Tensor Variable.
+
+    Returns:
+        Variable: A Tensor with the same data type as X, whose elements are
+            taken from X where Condition is True and from Y otherwise.
+
+    Examples:
+        .. code-block:: python
+
+          import numpy as np
+          import paddle
+          import paddle.fluid as fluid
+
+          with fluid.dygraph.guard():
+              x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
+              y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64")
+              x = fluid.dygraph.to_variable(x_i)
+              y = fluid.dygraph.to_variable(y_i)
+              out = paddle.where(x>1, x, y)
+              print(out.numpy())
+              #out: [1.0, 1.0, 3.2, 1.2]
+    """
+    if not in_dygraph_mode():
+        check_variable_and_dtype(Condition, 'Condition', ['bool'], 'where')
+        check_variable_and_dtype(
+            X, 'X', ['float32', 'float64', 'int32', 'int64'], 'where')
+        check_variable_and_dtype(
+            Y, 'Y', ['float32', 'float64', 'int32', 'int64'], 'where')
+
+    X_shape = list(X.shape)
+    Y_shape = list(Y.shape)
+    if X_shape == Y_shape:
+        if in_dygraph_mode():
+            return core.ops.where(Condition, X, Y)
+        else:
+            helper = LayerHelper("where", **locals())
+            dtype = helper.input_dtype()
+            out = helper.create_variable_for_type_inference(dtype)
+
+            helper.append_op(
+                type='where',
+                inputs={'Condition': Condition,
+                        'X': X,
+                        'Y': Y},
+                outputs={'Out': [out]})
+            return out
+    else:
+        cond_int = layers.cast(Condition, X.dtype)
+        cond_not_int = layers.cast(layers.logical_not(Condition), X.dtype)
+        out1 = layers.elementwise_mul(X, cond_int)
+        out2 = layers.elementwise_mul(Y, cond_not_int)
+        out = layers.elementwise_add(out1, out2)
+        return out
+
+
 def index_sample(x, index):
     """
     **IndexSample Layer**
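
A minimal end-to-end sketch of the new API in imperative mode, assuming a Paddle build that includes this patch (input values mirror the unit tests above):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(
            np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64"))
        y = fluid.dygraph.to_variable(np.ones(4).astype("float64"))
        # Element-wise select: take x where x > 1 holds, otherwise take y.
        out = paddle.where(x > 1, x, y)
        print(out.numpy())  # expected: [1.0, 1.0, 3.2, 1.2]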
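
Note on the shape-mismatch branch of `where`: when X and Y have different static shapes, the Python layer does not dispatch to the C++ kernel (whose InferShape requires identical dims) and instead composes the result as X * cast(Condition) + Y * cast(not Condition), relying on elementwise broadcasting. A NumPy sketch of that identity, with illustrative shapes taken from test_api_broadcast:

    import numpy as np

    x = np.array([[0.9383, 0.1983, 3.2, 1.2]])  # shape (1, 4)
    y = np.ones((2, 4))                         # shape (2, 4)
    cond = x > 1                                # shape (1, 4)
    # Fallback identity: x * cast(cond) + y * cast(~cond); broadcasting
    # yields the full (2, 4) result, matching np.where.
    out = x * cond.astype(x.dtype) + y * (~cond).astype(y.dtype)
    assert np.array_equal(out, np.where(cond, x, y))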