Unverified commit ecc05377 authored by Zhou Wei, committed by GitHub

Add bitwise_and/or/xor/not OP/API and unittest (#33524)

Parent a50d1296
......@@ -208,7 +208,7 @@ function(op_library TARGET)
endif()
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "nccl_op"
foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
......
......@@ -7,8 +7,6 @@ set(pybind_file ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h.tmp CACHE INTE
set(pybind_file_final ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h)
file(WRITE ${pybind_file} "// Generated by the paddle/fluid/operators/CMakeLists.txt. DO NOT EDIT!\n\n")
copy_if_different(${pybind_file} ${pybind_file_final})
add_subdirectory(math)
add_subdirectory(eigen)
add_subdirectory(controlflow)
......@@ -203,3 +201,5 @@ endif()
if (WITH_GPU OR WITH_ASCEND_CL)
cc_test(copy_cross_scope_test SRCS copy_cross_scope_test.cc DEPS op_registry copy_cross_scope_op scope device_context enforce executor)
endif()
copy_if_different(${pybind_file} ${pybind_file_final})
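# Note: copy_if_different now runs after all add_subdirectory() calls, presumably so the
# final pybind.h is only refreshed once every subdirectory has appended its USE_OP lines.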
......@@ -19,4 +19,6 @@ else()
target_link_libraries(conditional_block_infer_op conditional_block_op)
endif()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n")
file(APPEND ${pybind_file} "USE_OP(logical_and);\nUSE_OP(logical_or);\nUSE_OP(logical_xor);\nUSE_OP(logical_not);\n")
file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n")
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/bitwise_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename OpComment>
class BinaryBitwiseOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
OpComment comment;
AddInput("X", string::Sprintf(
"Input Tensor of ``%s`` . It is "
"a N-D Tensor of bool, uint8, int8, int16, int32, int64.",
comment.type));
AddInput("Y", string::Sprintf(
"Input Tensor of ``%s`` . It is "
"a N-D Tensor of bool, uint8, int8, int16, int32, int64.",
comment.type));
AddOutput("Out",
string::Sprintf("Result of ``%s`` . It is a N-D Tensor with "
"the same data type of input Tensor.",
comment.type));
AddComment(string::Sprintf(R"DOC(
It performs the ``%s`` operation on Tensors ``X`` and ``Y`` .
.. math::
%s
.. note::
``paddle.%s`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.
)DOC",
comment.type, comment.equation, comment.type));
}
};
template <typename OpComment>
class UnaryBitwiseOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
OpComment comment;
AddInput("X", string::Sprintf(
"Input Tensor of ``%s`` . It is "
"a N-D Tensor of bool, uint8, int8, int16, int32, int64.",
comment.type));
AddOutput("Out",
string::Sprintf("Result of ``%s`` . It is a N-D Tensor with "
"the same data type of input Tensor.",
comment.type));
AddComment(string::Sprintf(R"DOC(
It performs the ``%s`` operation on Tensor ``X`` .
.. math::
%s
)DOC",
comment.type, comment.equation));
}
};
class BitwiseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// BitwiseOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};
template <typename OpComment>
class UnaryBitwiseOp : public BitwiseOp {
public:
using BitwiseOp::BitwiseOp;
protected:
void InferShape(framework::InferShapeContext *context) const override {
OpComment comment;
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
template <typename OpComment>
class BinaryBitwiseOp : public BitwiseOp {
public:
using BitwiseOp::BitwiseOp;
protected:
void InferShape(framework::InferShapeContext *context) const override {
OpComment comment;
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", comment.type);
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
if (dim_x == dim_y) {
context->SetOutputDim("Out", dim_x);
} else {
int max_dim = std::max(dim_x.size(), dim_y.size());
int axis = std::abs(dim_x.size() - dim_y.size());
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(dim_x, dim_y, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(),
max_dim, axis);
context->SetOutputDim("Out", framework::make_ddim(out_dims_array));
}
context->ShareLoD("X", "Out");
}
};
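// Broadcast example for BinaryBitwiseOp::InferShape above (a sketch): with
// dim_x = [2, 3, 4, 5] and dim_y = [4, 1], max_dim = 4 and axis = 2, so the trailing
// axes are aligned and the inferred output shape is [2, 3, 4, 5], following the
// usual NumPy-style broadcast rule.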
} // namespace operators
} // namespace paddle
namespace ops = ::paddle::operators;
#define REGISTER_BINARY_BITWISE_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ops::BinaryBitwiseOp<_##op_type##Comment>, \
ops::BinaryBitwiseOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
#define REGISTER_UNARY_BITWISE_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ops::UnaryBitwiseOp<_##op_type##Comment>, \
ops::UnaryBitwiseOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_BINARY_BITWISE_OP(bitwise_and, "Out = X \\& Y");
REGISTER_BINARY_BITWISE_OP(bitwise_or, "Out = X | Y");
REGISTER_BINARY_BITWISE_OP(bitwise_xor, "Out = X ^\\wedge Y");
REGISTER_UNARY_BITWISE_OP(bitwise_not, "Out = \\sim X");
REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CPU, ops::BitwiseAndFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CPU, ops::BitwiseOrFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CPU, ops::BitwiseXorFunctor);
REGISTER_UNARY_BITWISE_KERNEL(bitwise_not, CPU, ops::BitwiseNotFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/bitwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
namespace paddle {
namespace operators {
#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \
template <typename T> \
struct Bitwise##func##CUDAFunctor { \
using ELEM_TYPE = T; \
HOSTDEVICE T operator()(const T* args) const { \
return args[0] expr args[1]; \
} \
}; \
\
template <> \
struct Bitwise##func##CUDAFunctor<bool> { \
using ELEM_TYPE = bool; \
HOSTDEVICE bool operator()(const bool* args) const { \
return args[0] bool_expr args[1]; \
} \
};
BITWISE_BINARY_FUNCTOR(And, &, &&)
BITWISE_BINARY_FUNCTOR(Or, |, ||)
BITWISE_BINARY_FUNCTOR(Xor, ^, !=)
#undef BITWISE_BINARY_FUNCTOR
template <typename T>
struct BitwiseNotCUDAFunctor {
using ELEM_TYPE = T;
HOSTDEVICE T operator()(const T* args) const { return ~args[0]; }
};
template <>
struct BitwiseNotCUDAFunctor<bool> {
using ELEM_TYPE = bool;
HOSTDEVICE bool operator()(const bool* args) const { return !args[0]; }
};
template <typename Functor>
class BinaryBitwiseOpKernel<platform::CUDADeviceContext, Functor>
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
using T = typename Functor::ELEM_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
auto functor = Functor();
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
if (ins.size() == 1) {
LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
cuda_ctx, ins, &outs, axis, functor);
} else {
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, functor);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = ::paddle::operators;
namespace plat = ::paddle::platform;
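// bitwise_not has only input "X", so PackTensorsIntoVector packs a single tensor and
// BinaryBitwiseOpKernel takes its unary launch path above; this is why bitwise_not
// reuses the binary registration macro below.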
REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CUDA, ops::BitwiseAndCUDAFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CUDA, ops::BitwiseOrCUDAFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CUDA, ops::BitwiseXorCUDAFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_not, CUDA, ops::BitwiseNotCUDAFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \
template <typename T> \
struct Bitwise##func##Functor { \
using ELEM_TYPE = T; \
HOSTDEVICE T operator()(const T& a, const T& b) const { return a expr b; } \
}; \
\
template <> \
struct Bitwise##func##Functor<bool> { \
using ELEM_TYPE = bool; \
HOSTDEVICE bool operator()(const bool& a, const bool& b) const { \
return a bool_expr b; \
} \
};
BITWISE_BINARY_FUNCTOR(And, &, &&)
BITWISE_BINARY_FUNCTOR(Or, |, ||)
BITWISE_BINARY_FUNCTOR(Xor, ^, !=)
#undef BITWISE_BINARY_FUNCTOR
template <typename T>
struct BitwiseNotFunctor {
using ELEM_TYPE = T;
HOSTDEVICE T operator()(const T& a) const { return ~a; }
};
template <>
struct BitwiseNotFunctor<bool> {
using ELEM_TYPE = bool;
HOSTDEVICE bool operator()(const bool& a) const { return !a; }
};
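// The bool specializations above fall back to logical operators so the result stays a
// canonical true/false: ~true would promote to int and yield -2 (which converts back
// to true), whereas !true correctly gives false. For integer types the plain bitwise
// operators apply, e.g. BitwiseNotFunctor<int32_t>()(5) == -6 in two's complement.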
template <typename DeviceContext, typename Functor>
class BinaryBitwiseOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
auto func = Functor();
auto* x = context.Input<framework::Tensor>("X");
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
ElementwiseComputeEx<Functor, DeviceContext, T>(context, x, y, -1, func,
out);
}
};
template <typename DeviceContext, typename Functor>
class UnaryBitwiseOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
auto func = Functor();
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
platform::Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x->data<T>(),
x->data<T>() + x->numel(), out->mutable_data<T>(context.GetPlace()),
func);
}
};
} // namespace operators
} // namespace paddle
namespace ops = ::paddle::operators;
namespace plat = ::paddle::platform;
#define REGISTER_BINARY_BITWISE_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<bool>>, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<uint8_t>>, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int8_t>>, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int16_t>>, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int>>, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int64_t>>);
#define REGISTER_UNARY_BITWISE_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<bool>>, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<uint8_t>>, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int8_t>>, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int16_t>>, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int>>, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int64_t>>);
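// Both macros instantiate kernels for bool, uint8, int8, int16, int32 and int64 only;
// floating-point types are not registered because the bitwise ops are defined solely
// for bool and integer tensors.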
......@@ -12,9 +12,11 @@ register_unity_group(cc
fetch_op.cc
get_places_op.cc
logical_op.cc
bitwise_op.cc
tensor_array_read_write_op.cc
while_op.cc)
register_unity_group(cu
logical_op.cu
bitwise_op.cu
compare_op.cu
compare_all_op.cu)
......@@ -78,7 +78,7 @@ if not defined PYTHON_ROOT set PYTHON_ROOT=C:\Python37
rem -------set cache build directory-----------
rmdir build\python /s/q
rmdir build\paddle\third_party\externalError /s/q
rmdir build\paddle\fluid\pybind /s/q
rem rmdir build\paddle\fluid\pybind /s/q
rmdir build\paddle_install_dir /s/q
rmdir build\paddle_inference_install_dir /s/q
rmdir build\paddle_inference_c_install_dir /s/q
......
......@@ -108,6 +108,10 @@ from .tensor.logic import logical_and # noqa: F401
from .tensor.logic import logical_not # noqa: F401
from .tensor.logic import logical_or # noqa: F401
from .tensor.logic import logical_xor # noqa: F401
from .tensor.logic import bitwise_and # noqa: F401
from .tensor.logic import bitwise_not # noqa: F401
from .tensor.logic import bitwise_or # noqa: F401
from .tensor.logic import bitwise_xor # noqa: F401
from .tensor.logic import not_equal # noqa: F401
from .tensor.logic import allclose # noqa: F401
from .tensor.logic import equal_all # noqa: F401
......@@ -371,6 +375,10 @@ __all__ = [ # noqa
'max',
'norm',
'logical_or',
'bitwise_and',
'bitwise_or',
'bitwise_xor',
'bitwise_not',
'mm',
'flip',
'histogram',
......
......@@ -319,10 +319,13 @@ def monkey_patch_math_varbase():
else:
import paddle.tensor
# Tensor method from module paddle.tensor
tensor_methods = paddle.tensor.tensor_method_func
for method_name in tensor_methods:
for method_name in paddle.tensor.tensor_method_func:
if hasattr(core.VarBase, method_name): continue
method_impl = getattr(paddle.tensor, method_name, None)
if method_impl: setattr(core.VarBase, method_name, method_impl)
for magic_method, origin_method in paddle.tensor.magic_method_func:
impl = getattr(paddle.tensor, origin_method, None)
if impl: setattr(core.VarBase, magic_method, impl)
_already_patch_varbase = True
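With these magic methods patched onto VarBase, the Python operators dispatch to the new bitwise APIs in dygraph mode. A minimal sketch of the resulting behaviour (values match the docstrings in python/paddle/tensor/logic.py below):
import numpy as np
import paddle

x = paddle.to_tensor(np.array([-5, -1, 1], dtype="int32"))
y = paddle.to_tensor(np.array([4, 2, -3], dtype="int32"))
print((x & y).numpy())  # __and__ -> paddle.bitwise_and -> [0, 2, 1]
print((~x).numpy())     # __invert__ -> paddle.bitwise_not -> [4, 0, -2]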
......@@ -364,10 +364,13 @@ def monkey_patch_variable():
setattr(Variable, method_name, method_impl)
else:
import paddle.tensor
variabel_methods = paddle.tensor.tensor_method_func
for method_name in variabel_methods:
for method_name in paddle.tensor.tensor_method_func:
if hasattr(Variable, method_name): continue
method_impl = getattr(paddle.tensor, method_name, None)
if method_impl: setattr(Variable, method_name, method_impl)
for magic_method, origin_method in paddle.tensor.magic_method_func:
impl = getattr(paddle.tensor, origin_method, None)
if impl: setattr(Variable, magic_method, impl)
_already_patch_variable = True
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from op_test import OpTest
paddle.enable_static()
################## TEST OP: BitwiseAnd ##################
class TestBitwiseAnd(OpTest):
def setUp(self):
self.op_type = "bitwise_and"
self.init_dtype()
self.init_shape()
self.init_bound()
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
y = np.random.randint(
self.low, self.high, self.y_shape, dtype=self.dtype)
out = np.bitwise_and(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
pass
def init_dtype(self):
self.dtype = np.int32
def init_shape(self):
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
def init_bound(self):
self.low = -100
self.high = 100
class TestBitwiseAndUInt8(TestBitwiseAnd):
def init_dtype(self):
self.dtype = np.uint8
def init_bound(self):
self.low = 0
self.high = 100
class TestBitwiseAndInt8(TestBitwiseAnd):
def init_dtype(self):
self.dtype = np.int8
def init_shape(self):
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
class TestBitwiseAndInt16(TestBitwiseAnd):
def init_dtype(self):
self.dtype = np.int16
def init_shape(self):
self.x_shape = [2, 3, 4, 5]
self.y_shape = [4, 1]
class TestBitwiseAndInt64(TestBitwiseAnd):
def init_dtype(self):
self.dtype = np.int64
def init_shape(self):
self.x_shape = [1, 4, 1]
self.y_shape = [2, 3, 4, 5]
class TestBitwiseAndBool(TestBitwiseAnd):
def setUp(self):
self.op_type = "bitwise_and"
self.init_shape()
x = np.random.choice([True, False], self.x_shape)
y = np.random.choice([True, False], self.y_shape)
out = np.bitwise_and(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': out}
################## TEST OP: BitwiseOr ##################
class TestBitwiseOr(OpTest):
def setUp(self):
self.op_type = "bitwise_or"
self.init_dtype()
self.init_shape()
self.init_bound()
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
y = np.random.randint(
self.low, self.high, self.y_shape, dtype=self.dtype)
out = np.bitwise_or(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
pass
def init_dtype(self):
self.dtype = np.int32
def init_shape(self):
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
def init_bound(self):
self.low = -100
self.high = 100
class TestBitwiseOrUInt8(TestBitwiseOr):
def init_dtype(self):
self.dtype = np.uint8
def init_bound(self):
self.low = 0
self.high = 100
class TestBitwiseOrInt8(TestBitwiseOr):
def init_dtype(self):
self.dtype = np.int8
def init_shape(self):
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
class TestBitwiseOrInt16(TestBitwiseOr):
def init_dtype(self):
self.dtype = np.int16
def init_shape(self):
self.x_shape = [2, 3, 4, 5]
self.y_shape = [4, 1]
class TestBitwiseOrInt64(TestBitwiseOr):
def init_dtype(self):
self.dtype = np.int64
def init_shape(self):
self.x_shape = [1, 4, 1]
self.y_shape = [2, 3, 4, 5]
class TestBitwiseOrBool(TestBitwiseOr):
def setUp(self):
self.op_type = "bitwise_or"
self.init_shape()
x = np.random.choice([True, False], self.x_shape)
y = np.random.choice([True, False], self.y_shape)
out = np.bitwise_or(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': out}
################## TEST OP: BitwiseXor ##################
class TestBitwiseXor(OpTest):
def setUp(self):
self.op_type = "bitwise_xor"
self.init_dtype()
self.init_shape()
self.init_bound()
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
y = np.random.randint(
self.low, self.high, self.y_shape, dtype=self.dtype)
out = np.bitwise_xor(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
pass
def init_dtype(self):
self.dtype = np.int32
def init_shape(self):
self.x_shape = [2, 3, 4, 5]
self.y_shape = [2, 3, 4, 5]
def init_bound(self):
self.low = -100
self.high = 100
class TestBitwiseXorUInt8(TestBitwiseXor):
def init_dtype(self):
self.dtype = np.uint8
def init_bound(self):
self.low = 0
self.high = 100
class TestBitwiseXorInt8(TestBitwiseXor):
def init_dtype(self):
self.dtype = np.int8
def init_shape(self):
self.x_shape = [4, 5]
self.y_shape = [2, 3, 4, 5]
class TestBitwiseXorInt16(TestBitwiseXor):
def init_dtype(self):
self.dtype = np.int16
def init_shape(self):
self.x_shape = [2, 3, 4, 5]
self.y_shape = [4, 1]
class TestBitwiseXorInt64(TestBitwiseXor):
def init_dtype(self):
self.dtype = np.int64
def init_shape(self):
self.x_shape = [1, 4, 1]
self.y_shape = [2, 3, 4, 5]
class TestBitwiseXorBool(TestBitwiseXor):
def setUp(self):
self.op_type = "bitwise_xor"
self.init_shape()
x = np.random.choice([True, False], self.x_shape)
y = np.random.choice([True, False], self.y_shape)
out = np.bitwise_xor(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': out}
################## TEST OP: BitwiseNot ##################
class TestBitwiseNot(OpTest):
def setUp(self):
self.op_type = "bitwise_not"
self.init_dtype()
self.init_shape()
self.init_bound()
x = np.random.randint(
self.low, self.high, self.x_shape, dtype=self.dtype)
out = np.bitwise_not(x)
self.inputs = {'X': x}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
pass
def init_dtype(self):
self.dtype = np.int32
def init_shape(self):
self.x_shape = [2, 3, 4, 5]
def init_bound(self):
self.low = -100
self.high = 100
class TestBitwiseNotUInt8(TestBitwiseNot):
def init_dtype(self):
self.dtype = np.uint8
def init_bound(self):
self.low = 0
self.high = 100
class TestBitwiseNotInt8(TestBitwiseNot):
def init_dtype(self):
self.dtype = np.int8
def init_shape(self):
self.x_shape = [4, 5]
class TestBitwiseNotInt16(TestBitwiseNot):
def init_dtype(self):
self.dtype = np.int16
def init_shape(self):
self.x_shape = [2, 3, 4, 5]
self.y_shape = [4, 1]
class TestBitwiseNotInt64(TestBitwiseNot):
def init_dtype(self):
self.dtype = np.int64
def init_shape(self):
self.x_shape = [1, 4, 1]
class TestBitwiseNotBool(TestBitwiseNot):
def setUp(self):
self.op_type = "bitwise_not"
self.init_shape()
x = np.random.choice([True, False], self.x_shape)
out = np.bitwise_not(x)
self.inputs = {'X': x}
self.outputs = {'Out': out}
if __name__ == "__main__":
unittest.main()
......@@ -19,6 +19,7 @@ from decorator_helper import prog_scope
import paddle
import paddle.fluid as fluid
import numpy
import numpy as np
class TestMathOpPatches(unittest.TestCase):
......@@ -270,6 +271,71 @@ class TestMathOpPatches(unittest.TestCase):
fetch_list=[b])
self.assertTrue(numpy.allclose(a_np.astype('float32'), b_np))
@prog_scope()
def test_bitwise_and(self):
x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
out_np = x_np & y_np
x = paddle.static.data(name="x", shape=[2, 3, 5], dtype="int32")
y = paddle.static.data(name="y", shape=[2, 3, 5], dtype="int32")
z = x & y
exe = fluid.Executor()
out = exe.run(fluid.default_main_program(),
feed={"x": x_np,
"y": y_np},
fetch_list=[z])
self.assertTrue(np.array_equal(out[0], out_np))
@prog_scope()
def test_bitwise_or(self):
x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
out_np = x_np | y_np
x = paddle.static.data(name="x", shape=[2, 3, 5], dtype="int32")
y = paddle.static.data(name="y", shape=[2, 3, 5], dtype="int32")
z = x | y
exe = fluid.Executor()
out = exe.run(fluid.default_main_program(),
feed={"x": x_np,
"y": y_np},
fetch_list=[z])
self.assertTrue(np.array_equal(out[0], out_np))
@prog_scope()
def test_bitwise_xor(self):
x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
out_np = x_np ^ y_np
x = paddle.static.data(name="x", shape=[2, 3, 5], dtype="int32")
y = paddle.static.data(name="y", shape=[2, 3, 5], dtype="int32")
z = x ^ y
exe = fluid.Executor()
out = exe.run(fluid.default_main_program(),
feed={"x": x_np,
"y": y_np},
fetch_list=[z])
self.assertTrue(np.array_equal(out[0], out_np))
@prog_scope()
def test_bitwise_not(self):
x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32")
out_np = ~x_np
x = paddle.static.data(name="x", shape=[2, 3, 5], dtype="int32")
z = ~x
exe = fluid.Executor()
out = exe.run(fluid.default_main_program(),
feed={"x": x_np},
fetch_list=[z])
self.assertTrue(np.array_equal(out[0], out_np))
if __name__ == '__main__':
unittest.main()
......@@ -141,6 +141,31 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
res = a % b
self.assertTrue(np.array_equal(res.numpy(), a_np % b_np))
# for bitwise and/or/xor/not
def test_bitwise(self):
paddle.disable_static()
x_np = np.random.randint(-100, 100, [2, 3, 5])
y_np = np.random.randint(-100, 100, [2, 3, 5])
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
out_np = x_np & y_np
out = x & y
self.assertTrue(np.array_equal(out.numpy(), out_np))
out_np = x_np | y_np
out = x | y
self.assertTrue(np.array_equal(out.numpy(), out_np))
out_np = x_np ^ y_np
out = x ^ y
self.assertTrue(np.array_equal(out.numpy(), out_np))
out_np = ~x_np
out = ~x
self.assertTrue(np.array_equal(out.numpy(), out_np))
# for logical compare
def test_equal(self):
a_np = np.asarray([1, 2, 3, 4, 5])
......
......@@ -54,6 +54,10 @@ from .logic import logical_and # noqa: F401
from .logic import logical_not # noqa: F401
from .logic import logical_or # noqa: F401
from .logic import logical_xor # noqa: F401
from .logic import bitwise_and # noqa: F401
from .logic import bitwise_or # noqa: F401
from .logic import bitwise_xor # noqa: F401
from .logic import bitwise_not # noqa: F401
from .logic import not_equal # noqa: F401
from .logic import allclose # noqa: F401
from .logic import equal_all # noqa: F401
......@@ -352,4 +356,16 @@ tensor_method_func = [ #noqa
'imag',
'trunc',
'digamma',
'bitwise_and',
'bitwise_or',
'bitwise_xor',
'bitwise_not',
]
# This list is used in math_op_patch.py to bind magic methods.
magic_method_func = [
('__and__', 'bitwise_and'),
('__or__', 'bitwise_or'),
('__xor__', 'bitwise_xor'),
('__invert__', 'bitwise_not'),
]
......@@ -16,7 +16,7 @@ from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_type, check_variable_and_dtype
from ..fluid.layers.layer_function_generator import templatedoc
from .. import fluid
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import in_dygraph_mode, Variable
from ..framework import VarBase as Tensor
# TODO: define logic functions of a tensor
......@@ -437,3 +437,140 @@ def is_tensor(x):
"""
return isinstance(x, Tensor)
def _bitwise_op(op_name, x, y, out=None, name=None, binary_op=True):
if in_dygraph_mode():
op = getattr(core.ops, op_name)
if binary_op:
return op(x, y)
else:
return op(x)
check_variable_and_dtype(
x, "x", ["bool", "uint8", "int8", "int16", "int32", "int64"], op_name)
if y is not None:
check_variable_and_dtype(
y, "y", ["bool", "uint8", "int8", "int16", "int32", "int64"],
op_name)
if out is not None:
check_type(out, "out", Variable, op_name)
helper = LayerHelper(op_name, **locals())
if binary_op:
assert x.dtype == y.dtype
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
if binary_op:
helper.append_op(
type=op_name, inputs={"X": x,
"Y": y}, outputs={"Out": out})
else:
helper.append_op(type=op_name, inputs={"X": x}, outputs={"Out": out})
return out
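# _bitwise_op takes the fast path through core.ops in dygraph mode and otherwise appends
# an operator to the current program via LayerHelper when building a static graph; the
# four public wrappers below differ only in op_name and arity.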
@templatedoc()
def bitwise_and(x, y, out=None, name=None):
"""
${comment}
Args:
x (Tensor): ${x_comment}
y (Tensor): ${y_comment}
out(Tensor): ${out_comment}
Returns:
Tensor: ${out_comment}
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-5, -1, 1])
y = paddle.to_tensor([4, 2, -3])
res = paddle.bitwise_and(x, y)
print(res) # [0, 2, 1]
"""
return _bitwise_op(
op_name="bitwise_and", x=x, y=y, name=name, out=out, binary_op=True)
@templatedoc()
def bitwise_or(x, y, out=None, name=None):
"""
${comment}
Args:
x (Tensor): ${x_comment}
y (Tensor): ${y_comment}
out(Tensor): ${out_comment}
Returns:
Tensor: ${out_comment}
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-5, -1, 1])
y = paddle.to_tensor([4, 2, -3])
res = paddle.bitwise_or(x, y)
print(res) # [-1, -1, -3]
"""
return _bitwise_op(
op_name="bitwise_or", x=x, y=y, name=name, out=out, binary_op=True)
@templatedoc()
def bitwise_xor(x, y, out=None, name=None):
"""
${comment}
Args:
x (Tensor): ${x_comment}
y (Tensor): ${y_comment}
out(Tensor): ${out_comment}
Returns:
Tensor: ${out_comment}
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-5, -1, 1])
y = paddle.to_tensor([4, 2, -3])
res = paddle.bitwise_xor(x, y)
print(res) # [-1, -3, -4]
"""
return _bitwise_op(
op_name="bitwise_xor", x=x, y=y, name=name, out=out, binary_op=True)
@templatedoc()
def bitwise_not(x, out=None, name=None):
"""
${comment}
Args:
x(Tensor): ${x_comment}
out(Tensor): ${out_comment}
Returns:
Tensor: ${out_comment}
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([-5, -1, 1])
res = paddle.bitwise_not(x)
print(res) # [4, 0, -2]
"""
return _bitwise_op(
op_name="bitwise_not", x=x, y=None, name=name, out=out, binary_op=False)