未验证 提交 363825df 编写于 作者: HappyHeavyRain's avatar HappyHeavyRain 提交者: GitHub

support 'backend' in static ops (#50671)

* support 'backend' in static ops

* change bitwise_xx comment in python

* change bitwise_xxx comment in python

* change 'backend' and 'data_type' in GetExpectedKernelType
上级 92cae577
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle {
namespace operators {
template <typename OpComment>
class BinaryBitwiseOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
OpComment comment;
AddInput("X",
string::Sprintf(
"Input Tensor of ``%s`` . It is "
"a N-D Tensor of bool, uint8, int8, int16, int32, int64.",
comment.type));
AddInput("Y",
string::Sprintf(
"Input Tensor of ``%s`` . It is "
"a N-D Tensor of bool, uint8, int8, int16, int32, int64.",
comment.type));
AddOutput("Out",
string::Sprintf("Result of ``%s`` . It is a N-D Tensor with "
"the same data type of input Tensor.",
comment.type));
AddComment(string::Sprintf(R"DOC(
It operates ``%s`` on Tensor ``X`` and ``Y`` .
.. math::
%s
.. note::
``paddle.%s`` supports broadcasting. If you want know more about broadcasting, please refer to please refer to `Introduction to Tensor`_ .
.. _Introduction to Tensor: ../../guides/beginner/tensor_en.html#chapter5-broadcasting-of-tensor.
)DOC",
comment.type,
comment.equation,
comment.type));
}
};
template <typename OpComment>
class UnaryBitwiseOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
OpComment comment;
AddInput("X",
string::Sprintf(
"Input Tensor of ``%s`` . It is "
"a N-D Tensor of bool, uint8, int8, int16, int32, int64.",
comment.type));
AddOutput("Out",
string::Sprintf("Result of ``%s`` . It is a N-D Tensor with "
"the same data type of input Tensor.",
comment.type));
AddComment(string::Sprintf(R"DOC(
It operates ``%s`` on Tensor ``X`` .
.. math::
%s
)DOC",
comment.type,
comment.equation));
}
};
template <typename OpComment>
class UnaryBitwiseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
OpComment comment;
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// BitwiseOp kernel's device type is decided by input tensor place
kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
return kt;
}
};
template <typename OpComment>
class BinaryBitwiseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
OpComment comment;
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", comment.type);
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
if (dim_x == dim_y) {
context->SetOutputDim("Out", dim_x);
} else {
int max_dim = std::max(dim_x.size(), dim_y.size());
int axis = std::abs(dim_x.size() - dim_y.size());
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(dim_x,
dim_y,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
context->SetOutputDim("Out", phi::make_ddim(out_dims_array));
}
context->ShareLoD("X", "Out");
}
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// BitwiseOp kernel's device type is decided by input tensor place
kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
return kt;
}
};
} // namespace operators
} // namespace paddle
namespace ops = ::paddle::operators;
#define REGISTER_BINARY_BITWISE_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, \
ops::BinaryBitwiseOp<_##op_type##Comment>, \
ops::BinaryBitwiseOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
#define REGISTER_UNARY_BITWISE_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, \
ops::UnaryBitwiseOp<_##op_type##Comment>, \
ops::UnaryBitwiseOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_BINARY_BITWISE_OP(bitwise_and, "Out = X \\& Y");
REGISTER_BINARY_BITWISE_OP(bitwise_or, "Out = X | Y");
REGISTER_BINARY_BITWISE_OP(bitwise_xor, "Out = X ^\\wedge Y");
REGISTER_UNARY_BITWISE_OP(bitwise_not, "Out = \\sim X");
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
OpComment comment;
AddInput(
"X",
string::Sprintf("the left hand operand of %s operator", comment.type));
AddInput(
"Y",
string::Sprintf("the right hand operand of %s operator", comment.type));
AddAttr<int>(
"axis",
"The start dimension index for broadcasting Y onto X. [default -1]")
.SetDefault(-1)
.EqualGreaterThan(-1);
AddAttr<bool>("force_cpu",
"Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
"device [default true].")
.SetDefault(false);
AddOutput("Out",
string::Sprintf("n-dim bool tensor. Each element is %s",
comment.equation));
AddComment(string::Sprintf(R"DOC(
It operates element-wise on X and Y, and returns the Out. Each of them is a
N-dim tensor. X and Y could be any type. The each element of the Out tensor is
calculated by $%s$
)DOC",
comment.equation));
}
};
template <typename OpComment>
class CompareOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
phi::KernelKey kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
bool force_cpu = ctx.Attr<bool>("force_cpu");
if (force_cpu) {
kt.set_backend(phi::Backend::CPU);
} else {
if (ctx.Input<phi::DenseTensor>("X")->place().GetType() !=
phi::AllocationType::GPUPINNED) {
kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("X")->place()));
} else {
kt.set_backend(phi::TransToPhiBackend(ctx.GetPlace()));
}
}
return kt;
}
};
} // namespace operators
} // namespace paddle
#define REGISTER_COMPARE_OP_VERSION(op_type) \
REGISTER_OP_VERSION(op_type).AddCheckpoint( \
R"ROC(Upgrade compare ops, add a new attribute [force_cpu])ROC", \
paddle::framework::compatible::OpVersionDesc().ModifyAttr( \
"force_cpu", \
"In order to force fill output variable to gpu memory.", \
false));
#define REGISTER_COMPARE_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
DECLARE_INFER_SHAPE_FUNCTOR(op_type, \
op_type##_InferShapeFunctor, \
PD_INFER_META(phi::CompareRawInferMeta)); \
REGISTER_OPERATOR( \
op_type, \
::paddle::operators::CompareOp<_##op_type##Comment>, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, \
op_type##_InferShapeFunctor); \
REGISTER_COMPARE_OP_VERSION(op_type);
REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
REGISTER_COMPARE_OP(equal, "Out = X == Y");
REGISTER_COMPARE_OP(not_equal, "Out = X != Y");
......@@ -176,10 +176,14 @@ def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
'layout': None,
'data_type': None,
'dispatch': {},
'force_backend': None,
}
if 'param' in kernel_config:
kernel['param'] = kernel_config['param']
if 'force_backend' in kernel_config:
kernel['force_backend'] = kernel_config["force_backend"]
if 'backend' in kernel_config:
kernel['backend'] = parse_candidates(kernel_config["backend"])
......@@ -328,7 +332,14 @@ def check_op_config(op_entry, op_name):
'composite',
)
infer_meta_key_set = ('func', 'param')
kernel_key_set = ('func', 'param', 'data_type', 'layout', 'backend')
kernel_key_set = (
'func',
'param',
'data_type',
'layout',
'backend',
'force_backend',
)
for key in op_entry.keys():
assert (
key in base_key_set
......
......@@ -12,7 +12,7 @@ class {{op_name | to_pascal_case}}OpMaker : public framework::OpProtoAndCheckerM
{{add_output(loop.index0, output, op_name)}};
{% endfor %}
{% for attr in op["attrs"] %}
{% if attr["fluid_name"] in op["kernel"]["param"] %}
{% if attr["fluid_name"] %}
{{add_attr(loop.index0, attr, op_name)}};
{% endif %}
{% endfor %}
......@@ -260,6 +260,7 @@ paddle::small_vector<const char*> outputs {
{% set kernel = op["kernel"] %}
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
phi::KernelKey kt;
{%if kernel["data_type"] is not none %}{# data type ---------------------------------#}
{% if kernel["data_type"]["candidates"] | length == 1 %}
{% set data_type_arg = kernel["data_type"]["candidates"][0] %}
......@@ -279,12 +280,29 @@ phi::KernelKey GetExpectedKernelType(
data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_args[1] | to_opmaker_name}});
}
{% endif %}
{% elif "complex_promote" in op and "forward" not in op%}
kt = phi::KernelKey(data_type, ctx.GetPlace());
{% elif "complex_promote" in op and "forward" not in op%} {# compext data promote #}
{% set inputs = op["complex_promote"]%}
auto data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "{{inputs[0]}}", "{{inputs[1]}}");
kt = phi::KernelKey(data_type, ctx.GetPlace());
{% endif -%}
{%- if kernel["backend"] is not none %}
{% if kernel["data_type"] is none %}
kt = OperatorWithKernel::GetExpectedKernelType(ctx);
{% endif %}
kt.set_backend(
phi::TransToPhiBackend(ctx.Input<phi::DenseTensor>("{{kernel["backend"]["candidates"][0]}}")->place()));
{% endif %}
return phi::KernelKey(data_type, ctx.GetPlace());
{% if "force_backend" in op["kernel"] and op["kernel"]["force_backend"] == "force_cpu" %}
{% if kernel["backend"] is none and kernel["data_type"] is none %} {# only force_cpu#}
kt = OperatorWithKernel::GetExpectedKernelType(ctx);
{% endif %}
if (ctx.Attr<bool>("force_cpu")) {
kt.set_backend(phi::Backend::CPU);
}
{% endif %}
return kt;
}
{% endmacro -%}
......@@ -292,6 +310,7 @@ phi::KernelKey GetExpectedKernelType(
{% set skip_args = none %}
{% if op["data_transform"] is not none%}
{% if "skip_transform" in op["data_transform"] %}
{# TODO:(lizhiyu) support skip_transform and support_trans_dtype at the same time#}
{% set skip_args = op["data_transform"]["skip_transform"] %}
{% elif "support_trans_dtype" in op["data_transform"] %}
{% set skip_args = op["data_transform"]["support_trans_dtype"] %}
......@@ -339,9 +358,12 @@ class {{op["op_name"] | to_pascal_case}}Op : public framework::OperatorWithKerne
using framework::OperatorWithKernel::OperatorWithKernel;
{# ----------- get expected kernel type function -------------------------- #}
{% set kernel = op["kernel"] %}
{% if kernel["data_type"] is not none or "complex_promote" in op or "data_transform" in op%}
{% if kernel["data_type"] is not none or kernel["backend"] is not none
or kernel["force_backend"] is not none
or "complex_promote" in op or "data_transform" in op %}
protected:
{% if kernel["data_type"] is not none or "complex_promote" in op %}
{% if kernel["data_type"] is not none or kernel["backend"] is not none
or kernel["force_backend"] is not none or "complex_promote" in op %}
{% filter indent(2, True)%}
{{get_expected_kernel(op)}}
{% endfilter %}
......@@ -437,6 +459,9 @@ REGISTER_OP_VERSION({{name}})
{% if "delete_attr" in action %}
.DeleteAttr("{{action["delete_attr"]}}", "{{action["comment"]}}"){{")" if loop.last}}
{% endif %}
{% if "modify_attr" in action %}
.ModifyAttr("{{action["modify_attr"]}}", "{{action["comment"]}}", {{action["default"]}}){{")" if loop.last}}
{% endif %}
{% if "fix_bug" in action %}
.BugfixWithBehaviorChanged("{{action["comment"]}}"){{")" if loop.last}}
{% endif %}
......
......@@ -277,38 +277,6 @@
func: bincount
optional: weights
- op : bitwise_and
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_and
- op : bitwise_not
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : bitwise_not
- op : bitwise_or
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_or
- op : bitwise_xor
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_xor
- op : box_coder
args : (Tensor prior_box, Tensor prior_box_var, Tensor target_box, str code_type, bool box_normalized, int axis, float[] variance)
output : Tensor(output_box)
......
......@@ -171,6 +171,30 @@
extra :
attrs : [bool use_mkldnn = false]
- op : bitwise_and
inputs :
{x : X, y : Y}
outputs :
{out : Out}
- op : bitwise_not
inputs :
{x : X}
outputs :
{out : Out}
- op : bitwise_or
inputs :
{x : X, y : Y}
outputs :
{out : Out}
- op : bitwise_xor
inputs :
{x : X, y : Y}
outputs :
{out : Out}
- op : bmm
inputs :
{x : X, y : Y}
......@@ -489,6 +513,12 @@
int trainer_id = 0, int slot = 0, 'int64_t[] height_sections = {}', 'str[] epmap = {}',
'str[] table_names = {}']
- op : equal
inputs :
{x : X, y : Y}
outputs :
out : Out
- op : equal_all
inputs :
{x : X, y : Y}
......@@ -701,6 +731,18 @@
attrs : [bool use_mkldnn = false, str x_data_format = "", str y_data_format = "", str mkldnn_data_type = "float32",
bool use_quantizer = false, float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]
- op : greater_equal
inputs :
{x : X, y : Y}
outputs :
out : Out
- op : greater_than
inputs :
{x : X, y : Y}
outputs :
out : Out
- op : grid_sample(grid_sampler)
backward : grid_sample_grad (grid_sampler_grad)
inputs :
......@@ -879,6 +921,18 @@
outputs :
out : Out
- op : less_equal
inputs :
{x : X, y : Y}
outputs :
out : Out
- op : less_than
inputs :
{x : X, y : Y}
outputs :
out : Out
- op : lgamma
inputs :
x : X
......@@ -1119,6 +1173,12 @@
outputs :
{out : Out, total_weight : Total_weight}
- op : not_equal
inputs :
{x : X, y : Y}
outputs :
out : Out
- op : numel(size)
inputs :
x : Input
......
......@@ -36,6 +36,14 @@
- add_input : Max
comment : Pass the mix, min value as input, not attribute. Max is dispensable.
- op : equal
version :
- checkpoint : Upgrade compare ops, add a new attribute [force_cpu]
action :
- modify_attr : force_cpu
comment : In order to force fill output variable to gpu memory.
default : "false"
- op : flip
version :
- checkpoint : Upgrade flip, add new attr [axis] and delete attr [dims]
......@@ -46,6 +54,22 @@
- delete_attr : dims
comment : The attr 'dims' is deleted.
- op : greater_equal
version :
- checkpoint : Upgrade compare ops, add a new attribute [force_cpu]
action :
- modify_attr : force_cpu
comment : In order to force fill output variable to gpu memory.
default : "false"
- op : greater_than
version :
- checkpoint : Upgrade compare ops, add a new attribute [force_cpu]
action :
- modify_attr : force_cpu
comment : In order to force fill output variable to gpu memory.
default : "false"
- op : grid_sample
version :
- checkpoint : Upgrade grid_sampler add a new attribute [mode]
......@@ -54,6 +78,30 @@
comment : In order to specify interpolation mode
default : std::string("bilinear")
- op : less_equal
version :
- checkpoint : Upgrade compare ops, add a new attribute [force_cpu]
action :
- modify_attr : force_cpu
comment : In order to force fill output variable to gpu memory.
default : "false"
- op : less_than
version :
- checkpoint : Upgrade compare ops, add a new attribute [force_cpu]
action :
- modify_attr : force_cpu
comment : In order to force fill output variable to gpu memory.
default : "false"
- op : not_equal
version :
- checkpoint : Upgrade compare ops, add a new attribute [force_cpu]
action :
- modify_attr : force_cpu
comment : In order to force fill output variable to gpu memory.
default : "false"
- op : pixel_shuffle
version :
- checkpoint : Compatible upgrade of pixel_shuffle, add a new attribute [data_format]
......
......@@ -125,6 +125,42 @@
kernel :
func : bernoulli
- op : bitwise_and
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_and
backend : x
- op : bitwise_not
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : bitwise_not
backend : x
- op : bitwise_or
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_or
backend : x
- op : bitwise_xor
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_xor
backend : x
- op : bmm
args : (Tensor x, Tensor y)
output : Tensor
......
......@@ -7,6 +7,18 @@
func: embedding_with_eltwise_add_xpu
data_type: tables
- op : equal
args : (Tensor x, Tensor y, int axis = -1, bool force_cpu=false)
output : Tensor(out)
infer_meta :
func : CompareRawInferMeta
param : [x, y, axis]
kernel :
func : equal_raw
param : [x, y, axis]
backend : x
force_backend : force_cpu
- op : fc_xpu
args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha)
output : Tensor(out), Tensor(out_max)
......@@ -26,6 +38,54 @@
func : generate_sequence_xpu
data_type : dtype
- op : greater_equal
args : (Tensor x, Tensor y, int axis = -1, bool force_cpu=false)
output : Tensor(out)
infer_meta :
func : CompareRawInferMeta
param : [x, y, axis]
kernel :
func : greater_equal_raw
param : [x, y, axis]
backend : x
force_backend : force_cpu
- op : greater_than
args : (Tensor x, Tensor y, int axis = -1, bool force_cpu=false)
output : Tensor(out)
infer_meta :
func : CompareRawInferMeta
param : [x, y, axis]
kernel :
func : greater_than_raw
param : [x, y, axis]
backend : x
force_backend : force_cpu
- op : less_equal
args : (Tensor x, Tensor y, int axis = -1, bool force_cpu=false)
output : Tensor(out)
infer_meta :
func : CompareRawInferMeta
param : [x, y, axis]
kernel :
func : less_equal_raw
param : [x, y, axis]
backend : x
force_backend : force_cpu
- op : less_than
args : (Tensor x, Tensor y, int axis = -1, bool force_cpu=false)
output : Tensor(out)
infer_meta :
func : CompareRawInferMeta
param : [x, y, axis]
kernel :
func : less_than_raw
param : [x, y, axis]
backend : x
force_backend : force_cpu
- op : multi_encoder_xpu
args : (Tensor x, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx)
output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16)
......@@ -36,6 +96,18 @@
data_type : x
optional : mask, x_fp16, out_fp16
- op : not_equal
args : (Tensor x, Tensor y, int axis = -1, bool force_cpu=false)
output : Tensor(out)
infer_meta :
func : CompareRawInferMeta
param : [x, y, axis]
kernel :
func : not_equal_raw
param : [x, y, axis]
backend : x
force_backend : force_cpu
- op : share_buffer
args : (Tensor[] x, bool[] share_dims_and_dtype={})
output : Tensor[](out){x.size()}, Tensor[](xout){x.size()}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LessThanArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("less_than_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature LessEqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("less_equal_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature GreaterThanArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("greater_than_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature GreaterEqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("greater_equal_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature EqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("equal_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature NotEqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("not_equal_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(less_than, phi::LessThanArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(less_equal, phi::LessEqualArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(greater_than, phi::GreaterThanArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(greater_equal, phi::GreaterEqualArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(equal, phi::EqualArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(not_equal, phi::NotEqualArgumentMapping);
......@@ -828,18 +828,26 @@ def _bitwise_op(op_name, x, y, out=None, name=None, binary_op=True):
return out
@templatedoc()
def bitwise_and(x, y, out=None, name=None):
"""
${comment}
r"""
Apply ``bitwise_and`` on Tensor ``X`` and ``Y`` .
.. math::
Out = X \& Y
.. note::
``paddle.bitwise_and`` supports broadcasting. If you want know more about broadcasting, please refer to please refer to `Introduction to Tensor`_ .
.. _Introduction to Tensor: ../../guides/beginner/tensor_en.html#chapter5-broadcasting-of-tensor.
Args:
x (Tensor): ${x_comment}
y (Tensor): ${y_comment}
out(Tensor): ${out_comment}
x (Tensor): Input Tensor of ``bitwise_and`` . It is a N-D Tensor of bool, uint8, int8, int16, int32, int64.
y (Tensor): Input Tensor of ``bitwise_and`` . It is a N-D Tensor of bool, uint8, int8, int16, int32, int64.
out(Tensor): Result of ``bitwise_and`` . It is a N-D Tensor with the same data type of input Tensor.
Returns:
Tensor: ${out_comment}
Tensor: Result of ``bitwise_and`` . It is a N-D Tensor with the same data type of input Tensor.
Examples:
.. code-block:: python
......@@ -857,18 +865,26 @@ def bitwise_and(x, y, out=None, name=None):
)
@templatedoc()
def bitwise_or(x, y, out=None, name=None):
"""
${comment}
r"""
Apply ``bitwise_or`` on Tensor ``X`` and ``Y`` .
.. math::
Out = X | Y
.. note::
``paddle.bitwise_or`` supports broadcasting. If you want know more about broadcasting, please refer to please refer to `Introduction to Tensor`_ .
.. _Introduction to Tensor: ../../guides/beginner/tensor_en.html#chapter5-broadcasting-of-tensor.
Args:
x (Tensor): ${x_comment}
y (Tensor): ${y_comment}
out(Tensor): ${out_comment}
x (Tensor): Input Tensor of ``bitwise_or`` . It is a N-D Tensor of bool, uint8, int8, int16, int32, int64.
y (Tensor): Input Tensor of ``bitwise_or`` . It is a N-D Tensor of bool, uint8, int8, int16, int32, int64.
out(Tensor): Result of ``bitwise_or`` . It is a N-D Tensor with the same data type of input Tensor.
Returns:
Tensor: ${out_comment}
Tensor: Result of ``bitwise_or`` . It is a N-D Tensor with the same data type of input Tensor.
Examples:
.. code-block:: python
......@@ -887,18 +903,26 @@ def bitwise_or(x, y, out=None, name=None):
)
@templatedoc()
def bitwise_xor(x, y, out=None, name=None):
"""
${comment}
r"""
Apply ``bitwise_xor`` on Tensor ``X`` and ``Y`` .
.. math::
Out = X ^\wedge Y
.. note::
``paddle.bitwise_xor`` supports broadcasting. If you want know more about broadcasting, please refer to please refer to `Introduction to Tensor`_ .
.. _Introduction to Tensor: ../../guides/beginner/tensor_en.html#chapter5-broadcasting-of-tensor.
Args:
x (Tensor): ${x_comment}
y (Tensor): ${y_comment}
out(Tensor): ${out_comment}
x (Tensor): Input Tensor of ``bitwise_xor`` . It is a N-D Tensor of bool, uint8, int8, int16, int32, int64.
y (Tensor): Input Tensor of ``bitwise_xor`` . It is a N-D Tensor of bool, uint8, int8, int16, int32, int64.
out(Tensor): Result of ``bitwise_xor`` . It is a N-D Tensor with the same data type of input Tensor.
Returns:
Tensor: ${out_comment}
Tensor: Result of ``bitwise_xor`` . It is a N-D Tensor with the same data type of input Tensor.
Examples:
.. code-block:: python
......@@ -916,17 +940,25 @@ def bitwise_xor(x, y, out=None, name=None):
)
@templatedoc()
def bitwise_not(x, out=None, name=None):
"""
${comment}
r"""
Apply ``bitwise_not`` on Tensor ``X``.
.. math::
Out = \sim X
.. note::
``paddle.bitwise_not`` supports broadcasting. If you want know more about broadcasting, please refer to please refer to `Introduction to Tensor`_ .
.. _Introduction to Tensor: ../../guides/beginner/tensor_en.html#chapter5-broadcasting-of-tensor.
Args:
x(Tensor): ${x_comment}
out(Tensor): ${out_comment}
x (Tensor): Input Tensor of ``bitwise_not`` . It is a N-D Tensor of bool, uint8, int8, int16, int32, int64.
out(Tensor): Result of ``bitwise_not`` . It is a N-D Tensor with the same data type of input Tensor.
Returns:
Tensor: ${out_comment}
Tensor: Result of ``bitwise_not`` . It is a N-D Tensor with the same data type of input Tensor.
Examples:
.. code-block:: python
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册