未验证 提交 20eed540 编写于 作者: G GaoWei8 提交者: GitHub

Change fluid.layers.where‘s C++ operator name (#23250)

上级 12355ccc
...@@ -12,23 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,23 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/where_op.h" #include "paddle/fluid/operators/where_index_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class WhereOp : public framework::OperatorWithKernel { class WhereIndexOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Condition"), PADDLE_ENFORCE_EQ(
"Input(Condition) of WhereOp should not be null."); ctx->HasInput("Condition"), true,
PADDLE_ENFORCE( platform::errors::NotFound(
ctx->GetInputDim("Condition").size() >= 1, "Input(Condition) of layers.where should not be null."));
"Input(Condition) should have number of dimension at least 1"); PADDLE_ENFORCE_GE(
PADDLE_ENFORCE(ctx->HasOutput("Out"), ctx->GetInputDim("Condition").size(), 1UL,
"Output(OUt) of WhereOp should not be null."); platform::errors::InvalidArgument(
"Input(Condition) should have number of dimension at least 1"));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of layers.where should not be null."));
ctx->SetOutputDim("Out", {-1, ctx->GetInputDim("Condition").size()}); ctx->SetOutputDim("Out", {-1, ctx->GetInputDim("Condition").size()});
} }
...@@ -40,7 +45,7 @@ class WhereOp : public framework::OperatorWithKernel { ...@@ -40,7 +45,7 @@ class WhereOp : public framework::OperatorWithKernel {
} }
}; };
class WhereOpMaker : public framework::OpProtoAndCheckerMaker { class WhereIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Condition", "A bool tensor whose rank is at least 1"); AddInput("Condition", "A bool tensor whose rank is at least 1");
...@@ -54,5 +59,6 @@ class WhereOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -54,5 +59,6 @@ class WhereOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(where, ops::WhereOp, ops::WhereOpMaker); REGISTER_OP_WITHOUT_GRADIENT(where_index, ops::WhereIndexOp,
REGISTER_OP_CPU_KERNEL(where, ops::CPUWhereKernel<int64_t>); ops::WhereIndexOpMaker);
REGISTER_OP_CPU_KERNEL(where_index, ops::CPUWhereIndexKernel<int64_t>);
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/where_op.h" #include "paddle/fluid/operators/where_index_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
...@@ -25,7 +25,7 @@ namespace operators { ...@@ -25,7 +25,7 @@ namespace operators {
using CUDADeviceContext = paddle::platform::CUDADeviceContext; using CUDADeviceContext = paddle::platform::CUDADeviceContext;
template <typename T> template <typename T>
class CUDAWhereKernel : public framework::OpKernel<T> { class CUDAWhereIndexKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::Tensor>("Condition"); auto* condition = context.Input<framework::Tensor>("Condition");
...@@ -67,8 +67,8 @@ class CUDAWhereKernel : public framework::OpKernel<T> { ...@@ -67,8 +67,8 @@ class CUDAWhereKernel : public framework::OpKernel<T> {
int* ptr_stride = thrust::raw_pointer_cast(d_stride.data()); int* ptr_stride = thrust::raw_pointer_cast(d_stride.data());
auto& dev_ctx = context.template device_context<CUDADeviceContext>(); auto& dev_ctx = context.template device_context<CUDADeviceContext>();
WhereFunctor<int*> functor(ptr_true_index, true_num, ptr_stride, rank, WhereIndexFunctor<int*> functor(ptr_true_index, true_num, ptr_stride, rank,
out_ptr); out_ptr);
platform::ForRange<CUDADeviceContext> for_range(dev_ctx, true_num); platform::ForRange<CUDADeviceContext> for_range(dev_ctx, true_num);
for_range(functor); for_range(functor);
} }
...@@ -78,4 +78,4 @@ class CUDAWhereKernel : public framework::OpKernel<T> { ...@@ -78,4 +78,4 @@ class CUDAWhereKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(where, ops::CUDAWhereKernel<int64_t>); REGISTER_OP_CUDA_KERNEL(where_index, ops::CUDAWhereIndexKernel<int64_t>);
...@@ -24,9 +24,9 @@ namespace paddle { ...@@ -24,9 +24,9 @@ namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
struct WhereFunctor { struct WhereIndexFunctor {
WhereFunctor(const T& true_index, int true_num, const T& stride, int rank, WhereIndexFunctor(const T& true_index, int true_num, const T& stride,
int64_t* out) int rank, int64_t* out)
: true_index_(true_index), : true_index_(true_index),
true_num_(true_num), true_num_(true_num),
stride_(stride), stride_(stride),
...@@ -51,7 +51,7 @@ struct WhereFunctor { ...@@ -51,7 +51,7 @@ struct WhereFunctor {
using CPUDeviceContext = paddle::platform::CPUDeviceContext; using CPUDeviceContext = paddle::platform::CPUDeviceContext;
template <typename T> template <typename T>
class CPUWhereKernel : public framework::OpKernel<T> { class CPUWhereIndexKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::Tensor>("Condition"); auto* condition = context.Input<framework::Tensor>("Condition");
...@@ -84,8 +84,8 @@ class CPUWhereKernel : public framework::OpKernel<T> { ...@@ -84,8 +84,8 @@ class CPUWhereKernel : public framework::OpKernel<T> {
} }
auto& dev_ctx = context.template device_context<CPUDeviceContext>(); auto& dev_ctx = context.template device_context<CPUDeviceContext>();
WhereFunctor<int*> functor(true_index.data(), true_num, stride.data(), rank, WhereIndexFunctor<int*> functor(true_index.data(), true_num, stride.data(),
out_ptr); rank, out_ptr);
platform::ForRange<CPUDeviceContext> for_range(dev_ctx, true_num); platform::ForRange<CPUDeviceContext> for_range(dev_ctx, true_num);
for_range(functor); for_range(functor);
} }
......
...@@ -12976,13 +12976,15 @@ def where(condition): ...@@ -12976,13 +12976,15 @@ def where(condition):
out = layers.where(condition) # [[]] out = layers.where(condition) # [[]]
""" """
helper = LayerHelper("where", **locals()) helper = LayerHelper("where_index", **locals())
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.INT64) dtype=core.VarDesc.VarType.INT64)
helper.append_op( helper.append_op(
type='where', inputs={'Condition': condition}, outputs={'Out': [out]}) type='where_index',
inputs={'Condition': condition},
outputs={'Out': [out]})
return out return out
......
...@@ -19,11 +19,13 @@ import numpy as np ...@@ -19,11 +19,13 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestWhereOp(OpTest): class TestWhereIndexOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "where" self.op_type = "where_index"
self.init_config() self.init_config()
def test_check_output(self): def test_check_output(self):
...@@ -37,7 +39,7 @@ class TestWhereOp(OpTest): ...@@ -37,7 +39,7 @@ class TestWhereOp(OpTest):
class TestAllFalse(unittest.TestCase): class TestAllFalse(unittest.TestCase):
def setUp(self): def setUp(self):
self.op_type = "where" self.op_type = "where_index"
self.init_config() self.init_config()
def check_with_place(self, place): def check_with_place(self, place):
...@@ -48,7 +50,7 @@ class TestAllFalse(unittest.TestCase): ...@@ -48,7 +50,7 @@ class TestAllFalse(unittest.TestCase):
out = scope.var("Out").get_tensor() out = scope.var("Out").get_tensor()
out.set(np.full(self.shape, 0).astype('int64'), place) out.set(np.full(self.shape, 0).astype('int64'), place)
op = Operator("where", Condition="Condition", Out="Out") op = Operator("where_index", Condition="Condition", Out="Out")
op.run(scope, place) op.run(scope, place)
out_array = np.array(out) out_array = np.array(out)
...@@ -66,14 +68,14 @@ class TestAllFalse(unittest.TestCase): ...@@ -66,14 +68,14 @@ class TestAllFalse(unittest.TestCase):
self.check_with_place(core.CUDAPlace(0)) self.check_with_place(core.CUDAPlace(0))
class TestRank2(TestWhereOp): class TestRank2(TestWhereIndexOp):
def init_config(self): def init_config(self):
self.inputs = {'Condition': np.array([[True, False], [False, True]]), } self.inputs = {'Condition': np.array([[True, False], [False, True]]), }
self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')} self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')}
class TestRank3(TestWhereOp): class TestRank3(TestWhereIndexOp):
def init_config(self): def init_config(self):
self.inputs = { self.inputs = {
'Condition': np.array([[[True, False], [False, True]], 'Condition': np.array([[[True, False], [False, True]],
...@@ -88,5 +90,17 @@ class TestRank3(TestWhereOp): ...@@ -88,5 +90,17 @@ class TestRank3(TestWhereOp):
} }
class TestWhereOpError(unittest.TestCase):
def test_api(self):
with program_guard(Program(), Program()):
cond = fluid.layers.data(name='cond', shape=[4], dtype='bool')
result = fluid.layers.where(cond)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
cond_i = np.array([True, False, False, False]).astype("bool")
out = exe.run(fluid.default_main_program(), feed={'cond': cond_i})
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册