未验证 提交 20eed540 编写于 作者: G GaoWei8 提交者: GitHub

Change fluid.layers.where‘s C++ operator name (#23250)

上级 12355ccc
......@@ -12,23 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/operators/where_index_op.h"
namespace paddle {
namespace operators {
class WhereOp : public framework::OperatorWithKernel {
class WhereIndexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Condition"),
"Input(Condition) of WhereOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputDim("Condition").size() >= 1,
"Input(Condition) should have number of dimension at least 1");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(OUt) of WhereOp should not be null.");
PADDLE_ENFORCE_EQ(
ctx->HasInput("Condition"), true,
platform::errors::NotFound(
"Input(Condition) of layers.where should not be null."));
PADDLE_ENFORCE_GE(
ctx->GetInputDim("Condition").size(), 1UL,
platform::errors::InvalidArgument(
"Input(Condition) should have number of dimension at least 1"));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of layers.where should not be null."));
ctx->SetOutputDim("Out", {-1, ctx->GetInputDim("Condition").size()});
}
......@@ -40,7 +45,7 @@ class WhereOp : public framework::OperatorWithKernel {
}
};
class WhereOpMaker : public framework::OpProtoAndCheckerMaker {
class WhereIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Condition", "A bool tensor whose rank is at least 1");
......@@ -54,5 +59,6 @@ class WhereOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(where, ops::WhereOp, ops::WhereOpMaker);
REGISTER_OP_CPU_KERNEL(where, ops::CPUWhereKernel<int64_t>);
REGISTER_OP_WITHOUT_GRADIENT(where_index, ops::WhereIndexOp,
ops::WhereIndexOpMaker);
REGISTER_OP_CPU_KERNEL(where_index, ops::CPUWhereIndexKernel<int64_t>);
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <thrust/device_vector.h>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/where_op.h"
#include "paddle/fluid/operators/where_index_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h"
......@@ -25,7 +25,7 @@ namespace operators {
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
template <typename T>
class CUDAWhereKernel : public framework::OpKernel<T> {
class CUDAWhereIndexKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::Tensor>("Condition");
......@@ -67,8 +67,8 @@ class CUDAWhereKernel : public framework::OpKernel<T> {
int* ptr_stride = thrust::raw_pointer_cast(d_stride.data());
auto& dev_ctx = context.template device_context<CUDADeviceContext>();
WhereFunctor<int*> functor(ptr_true_index, true_num, ptr_stride, rank,
out_ptr);
WhereIndexFunctor<int*> functor(ptr_true_index, true_num, ptr_stride, rank,
out_ptr);
platform::ForRange<CUDADeviceContext> for_range(dev_ctx, true_num);
for_range(functor);
}
......@@ -78,4 +78,4 @@ class CUDAWhereKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(where, ops::CUDAWhereKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL(where_index, ops::CUDAWhereIndexKernel<int64_t>);
......@@ -24,9 +24,9 @@ namespace paddle {
namespace operators {
template <typename T>
struct WhereFunctor {
WhereFunctor(const T& true_index, int true_num, const T& stride, int rank,
int64_t* out)
struct WhereIndexFunctor {
WhereIndexFunctor(const T& true_index, int true_num, const T& stride,
int rank, int64_t* out)
: true_index_(true_index),
true_num_(true_num),
stride_(stride),
......@@ -51,7 +51,7 @@ struct WhereFunctor {
using CPUDeviceContext = paddle::platform::CPUDeviceContext;
template <typename T>
class CPUWhereKernel : public framework::OpKernel<T> {
class CPUWhereIndexKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::Tensor>("Condition");
......@@ -84,8 +84,8 @@ class CPUWhereKernel : public framework::OpKernel<T> {
}
auto& dev_ctx = context.template device_context<CPUDeviceContext>();
WhereFunctor<int*> functor(true_index.data(), true_num, stride.data(), rank,
out_ptr);
WhereIndexFunctor<int*> functor(true_index.data(), true_num, stride.data(),
rank, out_ptr);
platform::ForRange<CPUDeviceContext> for_range(dev_ctx, true_num);
for_range(functor);
}
......
......@@ -12976,13 +12976,15 @@ def where(condition):
out = layers.where(condition) # [[]]
"""
helper = LayerHelper("where", **locals())
helper = LayerHelper("where_index", **locals())
out = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.INT64)
helper.append_op(
type='where', inputs={'Condition': condition}, outputs={'Out': [out]})
type='where_index',
inputs={'Condition': condition},
outputs={'Out': [out]})
return out
......
......@@ -19,11 +19,13 @@ import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestWhereOp(OpTest):
class TestWhereIndexOp(OpTest):
def setUp(self):
self.op_type = "where"
self.op_type = "where_index"
self.init_config()
def test_check_output(self):
......@@ -37,7 +39,7 @@ class TestWhereOp(OpTest):
class TestAllFalse(unittest.TestCase):
def setUp(self):
self.op_type = "where"
self.op_type = "where_index"
self.init_config()
def check_with_place(self, place):
......@@ -48,7 +50,7 @@ class TestAllFalse(unittest.TestCase):
out = scope.var("Out").get_tensor()
out.set(np.full(self.shape, 0).astype('int64'), place)
op = Operator("where", Condition="Condition", Out="Out")
op = Operator("where_index", Condition="Condition", Out="Out")
op.run(scope, place)
out_array = np.array(out)
......@@ -66,14 +68,14 @@ class TestAllFalse(unittest.TestCase):
self.check_with_place(core.CUDAPlace(0))
class TestRank2(TestWhereOp):
class TestRank2(TestWhereIndexOp):
def init_config(self):
self.inputs = {'Condition': np.array([[True, False], [False, True]]), }
self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')}
class TestRank3(TestWhereOp):
class TestRank3(TestWhereIndexOp):
def init_config(self):
self.inputs = {
'Condition': np.array([[[True, False], [False, True]],
......@@ -88,5 +90,17 @@ class TestRank3(TestWhereOp):
}
class TestWhereOpError(unittest.TestCase):
def test_api(self):
with program_guard(Program(), Program()):
cond = fluid.layers.data(name='cond', shape=[4], dtype='bool')
result = fluid.layers.where(cond)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
cond_i = np.array([True, False, False, False]).astype("bool")
out = exe.run(fluid.default_main_program(), feed={'cond': cond_i})
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册