From 20eed5401a5d7d29b1397def1c6ce9952d985076 Mon Sep 17 00:00:00 2001 From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com> Date: Tue, 31 Mar 2020 10:23:13 +0800 Subject: [PATCH] =?UTF-8?q?Change=20fluid.layers.where=E2=80=98s=20C++=20o?= =?UTF-8?q?perator=20name=20(#23250)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../{where_op.cc => where_index_op.cc} | 30 +++++++++++-------- .../{where_op.cu => where_index_op.cu} | 10 +++---- .../{where_op.h => where_index_op.h} | 12 ++++---- python/paddle/fluid/layers/nn.py | 6 ++-- .../{test_where.py => test_where_index.py} | 26 ++++++++++++---- 5 files changed, 53 insertions(+), 31 deletions(-) rename paddle/fluid/operators/{where_op.cc => where_index_op.cc} (62%) rename paddle/fluid/operators/{where_op.cu => where_index_op.cu} (89%) rename paddle/fluid/operators/{where_op.h => where_index_op.h} (88%) rename python/paddle/fluid/tests/unittests/{test_where.py => test_where_index.py} (75%) diff --git a/paddle/fluid/operators/where_op.cc b/paddle/fluid/operators/where_index_op.cc similarity index 62% rename from paddle/fluid/operators/where_op.cc rename to paddle/fluid/operators/where_index_op.cc index 3b53ebec0b2..c02afe51e31 100644 --- a/paddle/fluid/operators/where_op.cc +++ b/paddle/fluid/operators/where_index_op.cc @@ -12,23 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/where_op.h" +#include "paddle/fluid/operators/where_index_op.h" namespace paddle { namespace operators { -class WhereOp : public framework::OperatorWithKernel { +class WhereIndexOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Condition"), - "Input(Condition) of WhereOp should not be null."); - PADDLE_ENFORCE( - ctx->GetInputDim("Condition").size() >= 1, - "Input(Condition) should have number of dimension at least 1"); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(OUt) of WhereOp should not be null."); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Condition"), true, + platform::errors::NotFound( + "Input(Condition) of layers.where should not be null.")); + PADDLE_ENFORCE_GE( + ctx->GetInputDim("Condition").size(), 1UL, + platform::errors::InvalidArgument( + "Input(Condition) should have number of dimension at least 1")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + platform::errors::NotFound( + "Output(Out) of layers.where should not be null.")); + ctx->SetOutputDim("Out", {-1, ctx->GetInputDim("Condition").size()}); } @@ -40,7 +45,7 @@ class WhereOp : public framework::OperatorWithKernel { } }; -class WhereOpMaker : public framework::OpProtoAndCheckerMaker { +class WhereIndexOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("Condition", "A bool tensor whose rank is at least 1"); @@ -54,5 +59,6 @@ class WhereOpMaker : public framework::OpProtoAndCheckerMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(where, ops::WhereOp, ops::WhereOpMaker); -REGISTER_OP_CPU_KERNEL(where, ops::CPUWhereKernel); +REGISTER_OP_WITHOUT_GRADIENT(where_index, ops::WhereIndexOp, + ops::WhereIndexOpMaker); +REGISTER_OP_CPU_KERNEL(where_index, ops::CPUWhereIndexKernel); diff --git a/paddle/fluid/operators/where_op.cu b/paddle/fluid/operators/where_index_op.cu similarity index 89% rename from paddle/fluid/operators/where_op.cu rename to paddle/fluid/operators/where_index_op.cu index 27682f869c7..7a40932b016 100644 --- a/paddle/fluid/operators/where_op.cu +++ b/paddle/fluid/operators/where_index_op.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/where_op.h" +#include "paddle/fluid/operators/where_index_op.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/for_range.h" @@ -25,7 +25,7 @@ namespace operators { using CUDADeviceContext = paddle::platform::CUDADeviceContext; template -class CUDAWhereKernel : public framework::OpKernel { +class CUDAWhereIndexKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* condition = context.Input("Condition"); @@ -67,8 +67,8 @@ class CUDAWhereKernel : public framework::OpKernel { int* ptr_stride = thrust::raw_pointer_cast(d_stride.data()); auto& dev_ctx = context.template device_context(); - WhereFunctor functor(ptr_true_index, true_num, ptr_stride, rank, - out_ptr); + WhereIndexFunctor functor(ptr_true_index, true_num, ptr_stride, rank, + out_ptr); platform::ForRange for_range(dev_ctx, true_num); for_range(functor); } @@ -78,4 +78,4 @@ class CUDAWhereKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(where, ops::CUDAWhereKernel); +REGISTER_OP_CUDA_KERNEL(where_index, ops::CUDAWhereIndexKernel); diff --git a/paddle/fluid/operators/where_op.h b/paddle/fluid/operators/where_index_op.h similarity index 88% rename from paddle/fluid/operators/where_op.h rename to paddle/fluid/operators/where_index_op.h index 6a161a2668f..e327120e0ab 100644 --- a/paddle/fluid/operators/where_op.h +++ b/paddle/fluid/operators/where_index_op.h @@ -24,9 +24,9 @@ namespace paddle { namespace operators { template -struct WhereFunctor { - WhereFunctor(const T& true_index, int true_num, const T& stride, int rank, - int64_t* out) +struct WhereIndexFunctor { + WhereIndexFunctor(const T& true_index, int true_num, const T& stride, + int rank, int64_t* out) : true_index_(true_index), true_num_(true_num), stride_(stride), @@ -51,7 +51,7 @@ struct WhereFunctor { using CPUDeviceContext = paddle::platform::CPUDeviceContext; template -class CPUWhereKernel : public framework::OpKernel { +class CPUWhereIndexKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* condition = context.Input("Condition"); @@ -84,8 +84,8 @@ class CPUWhereKernel : public framework::OpKernel { } auto& dev_ctx = context.template device_context(); - WhereFunctor functor(true_index.data(), true_num, stride.data(), rank, - out_ptr); + WhereIndexFunctor functor(true_index.data(), true_num, stride.data(), + rank, out_ptr); platform::ForRange for_range(dev_ctx, true_num); for_range(functor); } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 584aad997c5..f7f1c867b66 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12976,13 +12976,15 @@ def where(condition): out = layers.where(condition) # [[]] """ - helper = LayerHelper("where", **locals()) + helper = LayerHelper("where_index", **locals()) out = helper.create_variable_for_type_inference( dtype=core.VarDesc.VarType.INT64) helper.append_op( - type='where', inputs={'Condition': condition}, outputs={'Out': [out]}) + type='where_index', + inputs={'Condition': condition}, + outputs={'Out': [out]}) return out diff --git a/python/paddle/fluid/tests/unittests/test_where.py b/python/paddle/fluid/tests/unittests/test_where_index.py similarity index 75% rename from python/paddle/fluid/tests/unittests/test_where.py rename to python/paddle/fluid/tests/unittests/test_where_index.py index ee0fa161309..05528caf986 100644 --- a/python/paddle/fluid/tests/unittests/test_where.py +++ b/python/paddle/fluid/tests/unittests/test_where_index.py @@ -19,11 +19,13 @@ import numpy as np from op_test import OpTest import paddle.fluid.core as core from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard -class TestWhereOp(OpTest): +class TestWhereIndexOp(OpTest): def setUp(self): - self.op_type = "where" + self.op_type = "where_index" self.init_config() def test_check_output(self): @@ -37,7 +39,7 @@ class TestWhereOp(OpTest): class TestAllFalse(unittest.TestCase): def setUp(self): - self.op_type = "where" + self.op_type = "where_index" self.init_config() def check_with_place(self, place): @@ -48,7 +50,7 @@ class TestAllFalse(unittest.TestCase): out = scope.var("Out").get_tensor() out.set(np.full(self.shape, 0).astype('int64'), place) - op = Operator("where", Condition="Condition", Out="Out") + op = Operator("where_index", Condition="Condition", Out="Out") op.run(scope, place) out_array = np.array(out) @@ -66,14 +68,14 @@ class TestAllFalse(unittest.TestCase): self.check_with_place(core.CUDAPlace(0)) -class TestRank2(TestWhereOp): +class TestRank2(TestWhereIndexOp): def init_config(self): self.inputs = {'Condition': np.array([[True, False], [False, True]]), } self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')} -class TestRank3(TestWhereOp): +class TestRank3(TestWhereIndexOp): def init_config(self): self.inputs = { 'Condition': np.array([[[True, False], [False, True]], @@ -88,5 +90,17 @@ class TestRank3(TestWhereOp): } +class TestWhereOpError(unittest.TestCase): + def test_api(self): + with program_guard(Program(), Program()): + cond = fluid.layers.data(name='cond', shape=[4], dtype='bool') + result = fluid.layers.where(cond) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + cond_i = np.array([True, False, False, False]).astype("bool") + out = exe.run(fluid.default_main_program(), feed={'cond': cond_i}) + + if __name__ == "__main__": unittest.main() -- GitLab