未验证 提交 c068512f 编写于 作者: G GaoWei8 提交者: GitHub

Implement a new C++ operator where and API tensor.where (#23220)

上级 9b82e4c1
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/where_op.h"
namespace paddle {
namespace operators {
class WhereOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "Where");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Where");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Where");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Where");
auto cond_dims = ctx->GetInputDim("Condition");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(
cond_dims, x_dims,
platform::errors::InvalidArgument(
"The dims of Inputs(Condition) and Inputs(X) should be same. "
"But received Condition's shape is [%s], X's shape is [%s]",
cond_dims, x_dims));
PADDLE_ENFORCE_EQ(x_dims, y_dims,
platform::errors::InvalidArgument(
"The dims of Inputs(X) and Inputs(Y) should be same. "
"But received X's shape is [%s], Y's shape is [%s]",
x_dims, y_dims));
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class WhereGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "Where");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Where");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Where");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "Where");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
class WhereOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Condition",
"(Tensor) A bool tensor whose rank is at least 1. When Condition "
"is True, yield x, otherwise yield y");
AddInput("X",
"(Tensor), The first input tensor of where op. When the "
"corresponding position of the condition is true, the output "
"takes the element of X.");
AddInput("Y",
"(Tensor), The second input tensor of where op. When the "
"corresponding position of condition is false, the output takes "
"the element of Y.");
AddOutput("Out", "(Tensor), The output tensor of mul op.");
AddComment(R"DOC(
Where Operator.
Return a tensor of elements selected from either $X$ or $Y$, depending on condition.
The equation is:
$$
Out_i =
\begin{cases}
\X_i, \quad \text{if} \ cond_i is True \\
\Y_i, \quad \text{if} \ cond_i is False \\
\end{cases}
$$
)DOC");
}
};
template <typename T>
class WhereOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("where_grad");
grad->SetInput("Condition", this->Input("Condition"));
grad->SetInput("X", this->Input("X"));
grad->SetInput("Y", this->Input("Y"));
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInference, "X",
"Y");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(where, ops::WhereOp, ops::WhereOpMaker,
ops::WhereOpGradMaker<paddle::framework::OpDesc>,
ops::WhereOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(where_grad, ops::WhereGradOp,
ops::WhereGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
where, ops::WhereKernel<paddle::platform::CPUDeviceContext, float>,
ops::WhereKernel<paddle::platform::CPUDeviceContext, double>,
ops::WhereKernel<paddle::platform::CPUDeviceContext, int>,
ops::WhereKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
where_grad, ops::WhereGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::WhereGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::WhereGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::WhereGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
namespace platform = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
__global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x,
const T* y, T* out) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) {
out[idx] = cond[idx] ? x[idx] : y[idx];
}
}
template <typename T>
__global__ void WhereGradCUDAKernel(const int N, const T* out, const bool* cond,
T* x, T* y) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) {
if (x != nullptr) {
x[idx] = out[idx] * (cond[idx] ? 1. : 0.);
}
if (y != nullptr) {
y[idx] = out[idx] * (cond[idx] ? 0. : 1.);
}
}
}
template <typename T>
class WhereKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::PermissionDenied("It must use CUDAPlace."));
auto* condition = context.Input<framework::Tensor>("Condition");
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
auto numel = condition->numel();
// TODO(GaaoWei8): Input of where can be broadcast
const bool* cond_data = condition->data<bool>();
const T* x_data = X->data<T>();
const T* y_data = Y->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
auto stream = context.cuda_device_context().stream();
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
auto config = GetGpuLaunchConfig1D(dev_ctx, numel);
WhereCUDAKernel<
T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
numel, cond_data, x_data, y_data, out_data);
}
};
template <typename T>
class WhereGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::PermissionDenied("It must use CUDAPlace."));
auto* condition = context.Input<framework::Tensor>("Condition");
const bool* cond_data = condition->data<bool>();
auto numel = condition->numel();
auto* dout_t =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx_t = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy_t = context.Output<framework::Tensor>(framework::GradVarName("Y"));
auto* dout = dout_t->data<T>();
T* dx =
(dx_t != nullptr) ? dx_t->mutable_data<T>(context.GetPlace()) : nullptr;
T* dy =
(dy_t != nullptr) ? dy_t->mutable_data<T>(context.GetPlace()) : nullptr;
auto stream = context.cuda_device_context().stream();
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
auto config = GetGpuLaunchConfig1D(dev_ctx, condition->numel());
WhereGradCUDAKernel<
T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
numel, dout, cond_data, dx, dy);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
where, paddle::operators::WhereKernel<platform::CUDADeviceContext, float>,
paddle::operators::WhereKernel<platform::CUDADeviceContext, double>,
paddle::operators::WhereKernel<platform::CUDADeviceContext, int>,
paddle::operators::WhereKernel<platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
where_grad,
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, float>,
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, double>,
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, int>,
paddle::operators::WhereGradKernel<platform::CUDADeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class WhereKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::Tensor>("Condition");
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
const bool* cond_data = condition->data<bool>();
const T* x_data = X->data<T>();
const T* y_data = Y->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
auto x_numel = X->numel();
for (int i = 0; i < x_numel; i++) {
out_data[i] = cond_data[i] ? x_data[i] : y_data[i];
}
}
};
template <typename DeviceContext, typename T>
class WhereGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::LoDTensor>("Condition");
const auto* cond_data = condition->data<bool>();
auto numel = condition->numel();
auto* dout_t =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx_t = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy_t = context.Output<framework::Tensor>(framework::GradVarName("Y"));
auto* dout = dout_t->data<T>();
if (dx_t != nullptr) {
auto* dx = dx_t->mutable_data<T>(context.GetPlace());
for (int i = 0; i < numel; i++) {
dx[i] = dout[i] * (cond_data[i] ? 1. : 0.);
}
}
if (dy_t != nullptr) {
auto* dy = dy_t->mutable_data<T>(context.GetPlace());
for (int i = 0; i < numel; i++) {
dy[i] = dout[i] * (cond_data[i] ? 0. : 1.);
}
}
}
};
} // namespace operators
} // namespace paddle
#Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.tensor as tensor
import paddle.fluid.core as core
from op_test import OpTest
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.op import Operator
from paddle.fluid.backward import append_backward
class TestWhereOp(OpTest):
def setUp(self):
self.op_type = "where"
self.init_config()
self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y}
self.outputs = {'Out': np.where(self.cond, self.x, self.y)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'Y'], 'Out')
def init_config(self):
self.x = np.random.uniform(-3, 5, (100)).astype("float64")
self.y = np.random.uniform(-3, 5, (100)).astype("float64")
self.cond = np.zeros((100)).astype("bool")
class TestWhereOp2(TestWhereOp):
def init_config(self):
self.x = np.random.uniform(-5, 5, (60, 2)).astype("float64")
self.y = np.random.uniform(-5, 5, (60, 2)).astype("float64")
self.cond = np.ones((60, 2)).astype("bool")
class TestWhereOp3(TestWhereOp):
def init_config(self):
self.x = np.random.uniform(-3, 5, (20, 2, 4)).astype("float64")
self.y = np.random.uniform(-3, 5, (20, 2, 4)).astype("float64")
self.cond = np.array(np.random.randint(2, size=(20, 2, 4)), dtype=bool)
class TestWhereAPI(unittest.TestCase):
def test_api(self, use_cuda=False):
main_program = Program()
with fluid.program_guard(main_program):
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
y = fluid.layers.data(name='y', shape=[4], dtype='float32')
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
cond_i = np.array([False, False, True, True]).astype("bool")
result = tensor.where(x > 1, X=x, Y=y)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
out = exe.run(fluid.default_main_program(),
feed={'x': x_i,
'y': y_i},
fetch_list=[result])
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
def test_grad(self, use_cuda=False):
main_program = Program()
for x_stop_gradient, y_stop_gradient in [[False, False], [True, False],
[False, True]]:
with fluid.program_guard(main_program):
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
y = fluid.layers.data(name='y', shape=[4], dtype='float32')
x.stop_gradient = x_stop_gradient
y.stop_gradient = y_stop_gradient
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float32")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float32")
cond_i = np.array([False, False, True, True]).astype("bool")
result = tensor.where(x > 1, X=x, Y=y)
x_mean = layers.mean(x)
append_backward(x_mean)
y_mean = layers.mean(y)
append_backward(y_mean)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
out = exe.run(fluid.default_main_program(),
feed={'x': x_i,
'y': y_i},
fetch_list=[result, x.grad_name, y.grad_name])
x_grad = [0.25] * 4
y_grad = [0.25] * 4
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
assert np.array_equal(out[1], x_grad)
assert np.array_equal(out[2], y_grad)
def test_api_broadcast(self, use_cuda=False):
main_program = Program()
with fluid.program_guard(main_program):
x = fluid.layers.data(name='x', shape=[4, 1], dtype='float32')
y = fluid.layers.data(name='y', shape=[4, 2], dtype='float32')
x_i = np.array([[0.9383, 0.1983, 3.2, 1.2]]).astype("float32")
y_i = np.array(
[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]).astype("float32")
cond_i = np.array([[False, False, True, True],
[False, False, True, True]]).astype("bool")
result = tensor.where(x > 1, X=x, Y=y)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
out = exe.run(fluid.default_main_program(),
feed={'x': x_i,
'y': y_i},
fetch_list=[result])
assert np.array_equal(out[0], np.where(cond_i, x_i, y_i))
def test_fw_bw(self):
if core.is_compiled_with_cuda():
self.test_api(use_cuda=True)
self.test_api_broadcast(use_cuda=True)
self.test_grad(use_cuda=True)
class TestWhereDygraphAPI(unittest.TestCase):
def test_api(self):
with fluid.dygraph.guard():
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64")
cond_i = np.array([False, False, True, True]).astype("bool")
x = fluid.dygraph.to_variable(x_i)
y = fluid.dygraph.to_variable(y_i)
cond = fluid.dygraph.to_variable(cond_i)
out = tensor.where(cond, x, y)
assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i))
class TestWhereOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64")
cond_i = np.array([False, False, True, True]).astype("bool")
def test_Variable():
tensor.where(cond_i, x_i, y_i)
self.assertRaises(TypeError, test_Variable)
def test_type():
x = fluid.layers.data(name='x', shape=[4], dtype='bool')
y = fluid.layers.data(name='y', shape=[4], dtype='float16')
cond = fluid.layers.data(name='cond', shape=[4], dtype='int32')
tensor.where(cond, x, y)
self.assertRaises(TypeError, test_type)
if __name__ == '__main__':
unittest.main()
...@@ -165,7 +165,7 @@ from .search import argmax #DEFINE_ALIAS ...@@ -165,7 +165,7 @@ from .search import argmax #DEFINE_ALIAS
# from .search import has_nan #DEFINE_ALIAS # from .search import has_nan #DEFINE_ALIAS
# from .search import masked_select #DEFINE_ALIAS # from .search import masked_select #DEFINE_ALIAS
# from .search import topk #DEFINE_ALIAS # from .search import topk #DEFINE_ALIAS
# from .search import where #DEFINE_ALIAS from .search import where #DEFINE_ALIAS
# from .search import index_select #DEFINE_ALIAS # from .search import index_select #DEFINE_ALIAS
from .search import index_sample # DEFINE_ALIAS from .search import index_sample # DEFINE_ALIAS
# from .search import nonzero #DEFINE_ALIAS # from .search import nonzero #DEFINE_ALIAS
......
...@@ -12,10 +12,21 @@ ...@@ -12,10 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import numpy as np
import warnings
import six
import os
import inspect
from ..fluid.layer_helper import LayerHelper from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
from ..fluid.initializer import Normal, Constant, NumpyArrayInitializer
# TODO: define searching & indexing functions of a tensor from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, dygraph_only, _dygraph_tracer, default_main_program
from ..fluid import dygraph_utils
from ..fluid.param_attr import ParamAttr
from ..fluid import unique_name
from ..fluid import core, layers
# TODO: define searching & indexing functions of a tensor
__all__ = [ __all__ = [
'argmax', 'argmax',
# 'argmin', # 'argmin',
...@@ -24,7 +35,7 @@ __all__ = [ ...@@ -24,7 +35,7 @@ __all__ = [
# 'has_nan', # 'has_nan',
# 'masked_select', # 'masked_select',
# 'topk', # 'topk',
# 'where', 'where',
# 'index_select', # 'index_select',
# 'nonzero', # 'nonzero',
'sort', 'sort',
...@@ -213,6 +224,64 @@ def sort(input, axis=-1, descending=False, out=None, name=None): ...@@ -213,6 +224,64 @@ def sort(input, axis=-1, descending=False, out=None, name=None):
return out, ids return out, ids
def where(Condition, X, Y):
"""
Return a tensor of elements selected from either $X$ or $Y$, depending on $Condition$.
Args:
Condition(Variable): A bool tensor with rank at least 1, the data type is bool.
X(Variable): X is a Tensor Variable.
Y(Variable): Y is a Tensor Variable.
Returns:
out : The tensor.
Examples:
.. code-block:: python
import numpy as np
import paddle as paddle
import paddle.fluid as fluid
with fluid.dygraph.guard():
x_i = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
y_i = np.array([1.0, 1.0, 1.0, 1.0]).astype("float64")
x = fluid.dygraph.to_variable(x_i)
y = fluid.dygraph.to_variable(y_i)
out = paddle.where(x>1, x, y)
print(out.numpy())
#out: [1.0, 1.0, 3.2, 1.2]
"""
if not in_dygraph_mode():
check_variable_and_dtype(Condition, 'Condition', ['bool'], 'where')
check_variable_and_dtype(
X, 'X', ['float32', 'float64', 'int32', 'int64'], 'where')
check_variable_and_dtype(
Y, 'Y', ['float32', 'float64', 'int32', 'int64'], 'where')
X_shape = list(X.shape)
Y_shape = list(Y.shape)
if X_shape == Y_shape:
if in_dygraph_mode():
return core.ops.where(Condition, X, Y)
else:
helper = LayerHelper("where", **locals())
dtype = helper.input_dtype()
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='where',
inputs={'Condition': Condition,
'X': X,
'Y': Y},
outputs={'Out': [out]})
return out
else:
cond_int = layers.cast(Condition, X.dtype)
cond_not_int = layers.cast(layers.logical_not(Condition), X.dtype)
out1 = layers.elementwise_mul(X, cond_int)
out2 = layers.elementwise_mul(Y, cond_not_int)
out = layers.elementwise_add(out1, out2)
return out
def index_sample(x, index): def index_sample(x, index):
""" """
**IndexSample Layer** **IndexSample Layer**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册