Unverified commit c5fc413a, authored by GGBond8488 and committed by GitHub

【inplace api】Batch add inplace api gt_, ge_, lt_, le_, eq_, not_equal_, logical_and_, logical_or_, logical_xor_, logical_not_, divide_, floor_divide_, bitwise_and_, bitwise_or_, bitwise_xor_, bitwise_not_ (#55509)

* tmp commit

* add atan2

* add inplace api

* fix error

* add inplace divide

* add inplace api

* add more inplace

* add more inplace

* fix logical_not error

* support sinh and cosh in cpu

* support asin, acos, atan, asinh, acosh, atanh in cpu

* fix typo

* fix typo

* mv out atan2 ldexp

* mv out atan2 ldexp

* support sinh and cosh in gpu

* support asin, acos, atan, asinh, acosh, atanh in gpu

* fix ge error

* fix dygraph compare error

* fix dygraph compare error

* check complex in python

* fix cast inplace error

* open inplace test

* fix ops.yaml error

* mv cast inplace to python

* fix coverage ci

* add last inplace

* fix inplace error

* fix cast error

* fix error

* add nan_to_num_

* fix typo

* fix sparse cast error

* remove gpu 4

* fix static cast error

* tmp commit

* add atan2

* add inplace api

* fix error

* add inplace divide

* add inplace api

* add more inplace

* add more inplace

* fix logical_not error

* fix typo

* fix typo

* mv out atan2 ldexp

* mv out atan2 ldexp

* fix ge error

* fix dygraph compare error

* fix dygraph compare error

* fix cast inplace error

* open inplace test

* fix ops.yaml error

* mv cast inplace to python

* fix coverage ci

* add last inplace

* fix inplace error

* fix cast error

* fix error

* add nan_to_num_

* fix typo

* fix sparse cast error

* remove gpu 4

* fix static cast error

* fix cast error

* fix

* Revert "check complex in python"

This reverts commit c822064261d774dd58ad46a4f90ba8b467700a05.

* add renorm, fix error

* add coverage

* fix cumsum inplace version error

* add cast inplace impl

* rm test.log

* fix multiply_dyfunction and add multiply_backward test

* add and use is_same_tensor

* fix typo

* fix some errors

* fix typo

---------
Co-authored-by: Scotty <jmhgchn@gmail.com>
Co-authored-by: Scotty <527407973@qq.com>
Parent f9c51e8c
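A minimal dygraph sketch of how the new in-place APIs are meant to be used (illustrative values, assuming a Paddle build that includes this change; most of the new trailing-underscore APIs are restricted to dynamic-graph mode):

import paddle

x = paddle.to_tensor([6.0, 4.0, 2.0])
y = paddle.to_tensor([3.0, 4.0, 5.0])

x.divide_(y)                 # x is overwritten with x / y -> [2.0, 1.0, 0.4]
paddle.greater_equal_(x, y)  # x <- (x >= y), written back into x's buffer
paddle.cast_(x, 'bool')      # x is retyped in place
print(x)                     # [False, False, False]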
......@@ -262,6 +262,32 @@ paddle::Tensor& multiply__ad_func(paddle::Tensor& x, // NOLINT
VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str);
}
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(
trace_backward, x_autograd_meta, y_autograd_meta);
// Node Declaration
std::shared_ptr<MultiplyGradNode> grad_node;
// Set grad_node before API Call
if (require_any_grad) {
paddle::platform::RecordEvent node_creation_record_event(
"multiply node_creation",
paddle::platform::TracerEventType::OperatorInner,
1);
grad_node = std::shared_ptr<MultiplyGradNode>(new MultiplyGradNode(1, 2));
// Set for forward trace
if (FLAGS_check_nan_inf) {
grad_node->SetForwardTrace(egr::Controller::Instance().GetPythonStack());
}
// SetAttributes if needed
grad_node->SetAttributeaxis(-1);
// Set TensorWrappers for Forward Inputs if needed
auto x_clone = paddle::experimental::assign(x);
grad_node->SetTensorWrapperx(x_clone);
grad_node->SetTensorWrappery(y);
}
// Forward API Call
auto& api_result = paddle::experimental::multiply_(x, y);
// Check NaN and Inf if needed
......@@ -275,10 +301,6 @@ paddle::Tensor& multiply__ad_func(paddle::Tensor& x, // NOLINT
// Get Output AutoGradMeta
egr::AutogradMeta* out_autograd_meta = egr::EagerUtils::autograd_meta(&out);
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(
trace_backward, x_autograd_meta, y_autograd_meta);
// Check Inplace if needed
egr::EagerUtils::CheckInplace(x, x_autograd_meta, require_any_grad);
......@@ -289,25 +311,7 @@ paddle::Tensor& multiply__ad_func(paddle::Tensor& x, // NOLINT
// Node Creation
if (require_any_grad) {
paddle::platform::RecordEvent node_creation_record_event(
"multiply node_creation",
paddle::platform::TracerEventType::OperatorInner,
1);
egr::EagerUtils::PassStopGradient(false, out_autograd_meta);
// Node Construction
auto grad_node =
std::shared_ptr<MultiplyGradNode>(new MultiplyGradNode(1, 2));
// Set for forward trace
if (FLAGS_check_nan_inf) {
grad_node->SetForwardTrace(egr::Controller::Instance().GetPythonStack());
}
// SetAttributes if needed
grad_node->SetAttributeaxis(-1);
// Set TensorWrappers for Forward Inputs if needed
grad_node->SetTensorWrapperx(x);
grad_node->SetTensorWrappery(y);
// SetGradOutMeta & SetEdges
grad_node->SetGradOutMeta(x, 0);
grad_node->SetGradOutMeta(y, 1);
......@@ -429,7 +433,6 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
input_str += input_y_str;
VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str);
}
// Forward API Call
auto api_result = paddle::experimental::sparse::multiply(x, y);
// Check NaN and Inf if needed
......
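The generated eager code above now creates the MultiplyGradNode before the forward call and wraps a copy of x (via paddle::experimental::assign) as the tensor wrapper, so multiply_ no longer has to run under paddle.no_grad(). A sketch of what that enables (illustrative values, assuming a build with this change):

import paddle

a = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
y = paddle.to_tensor([4.0, 5.0, 6.0])

x = a * 1.0         # use a non-leaf tensor; in-place ops on grad-requiring leaf tensors are rejected
x.multiply_(y)      # previously asserted that the tracer had gradients disabled
x.sum().backward()
print(a.grad)       # [4., 5., 6.]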
......@@ -144,13 +144,14 @@
- op : cast
args : (Tensor x, DataType dtype)
output : Tensor
output : Tensor(out)
infer_meta :
func : CastInferMeta
kernel :
func : cast
param : [x, dtype]
data_type : x
inplace: (x -> out)
backward : cast_grad
- op : channel_shuffle
......@@ -232,11 +233,12 @@
- op : divide
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : divide
inplace: (x -> out)
backward : divide_grad
- op : dropout
......@@ -334,6 +336,7 @@
func : CompareInferMeta
kernel :
func : equal
inplace: (x -> out)
- op : exponential_
args : (Tensor x, float lam)
......@@ -365,6 +368,7 @@
func : ElementwiseInferMeta
kernel :
func : floor_divide
inplace: (x -> out)
- op : frobenius_norm
args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all)
......@@ -499,6 +503,7 @@
func : CompareInferMeta
kernel :
func : greater_equal
inplace: (x -> out)
- op : greater_than
args : (Tensor x, Tensor y)
......@@ -507,6 +512,7 @@
func : CompareInferMeta
kernel :
func : greater_than
inplace: (x -> out)
- op : hardswish
args : (Tensor x)
......@@ -545,6 +551,7 @@
func : CompareInferMeta
kernel :
func : less_equal
inplace: (x -> out)
- op : less_than
args : (Tensor x, Tensor y)
......@@ -553,6 +560,7 @@
func : CompareInferMeta
kernel :
func : less_than
inplace: (x -> out)
- op : linspace
args : (Tensor start, Tensor stop, Tensor number, DataType dtype, Place place)
......@@ -721,6 +729,7 @@
func : CompareInferMeta
kernel :
func : not_equal
inplace: (x -> out)
- op : one_hot
args : (Tensor x, Scalar(int) num_classes)
......
......@@ -330,6 +330,7 @@
kernel :
func : bitwise_and
backend : x
inplace: (x -> out)
- op : bitwise_not
args : (Tensor x)
......@@ -339,6 +340,7 @@
kernel :
func : bitwise_not
backend : x
inplace: (x -> out)
- op : bitwise_or
args : (Tensor x, Tensor y)
......@@ -348,6 +350,7 @@
kernel :
func : bitwise_or
backend : x
inplace: (x -> out)
- op : bitwise_xor
args : (Tensor x, Tensor y)
......@@ -357,6 +360,7 @@
kernel :
func : bitwise_xor
backend : x
inplace: (x -> out)
- op : bmm
args : (Tensor x, Tensor y)
......@@ -618,6 +622,7 @@
func : UnchangedInferMetaCheckAxis
kernel :
func : cumprod
inplace: (x -> out)
backward : cumprod_grad
- op : cumsum
......@@ -628,6 +633,7 @@
kernel :
func : cumsum
data_type : x
inplace: (x -> out)
backward : cumsum_grad
- op : data
......@@ -1514,6 +1520,7 @@
func : logical_and
data_type : x
backend : x
inplace: (x -> out)
- op : logical_not
args : (Tensor x)
......@@ -1524,6 +1531,7 @@
func : logical_not
data_type : x
backend : x
inplace: (x -> out)
- op : logical_or
args : (Tensor x, Tensor y)
......@@ -1534,6 +1542,7 @@
func : logical_or
data_type : x
backend : x
inplace: (x -> out)
- op : logical_xor
args : (Tensor x, Tensor y)
......@@ -1544,6 +1553,7 @@
func : logical_xor
data_type : x
backend : x
inplace: (x -> out)
- op : logit
args : (Tensor x, float eps = 1e-6f)
......@@ -2054,12 +2064,13 @@
- op : renorm
args : (Tensor x, float p, int axis, float max_norm)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : renorm
inplace: (x -> out)
backward : renorm_grad
- op : reverse
......@@ -2779,11 +2790,12 @@
- op : where
args : (Tensor condition, Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : WhereInferMeta
kernel :
func : where
inplace: (x -> out)
backward : where_grad
- op : yolo_box
......
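Each new `inplace: (x -> out)` entry in these yaml files tells the op code generator to also emit a trailing-underscore variant that reuses x's buffer for out; the Python wrappers later in this diff dispatch to those generated `_C_ops.*_` functions. A rough dygraph sketch with illustrative values:

import paddle

x = paddle.to_tensor([7, 8, 9])
y = paddle.to_tensor([2, 2, 2])

paddle.floor_divide_(x, y)   # backed by the generated _C_ops.floor_divide_
print(x)                     # x itself now holds [3, 4, 4]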
......@@ -210,6 +210,10 @@ bool MetaTensor::is_selected_rows() const {
}
bool MetaTensor::is_tensor_array() const { return false; }
bool MetaTensor::is_same_tensor(const MetaTensor& meta_tensor) const {
return tensor_ != nullptr && tensor_ == meta_tensor.tensor();
}
void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
ValidCheck(*this);
bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
......
......@@ -86,6 +86,8 @@ class MetaTensor {
// and it will be deleted in the future.
virtual bool is_tensor_array() const;
virtual bool is_same_tensor(const MetaTensor& meta_tensor) const;
virtual operator unspecified_bool_type() const {
return tensor_ == nullptr ? 0 : unspecified_bool_true;
}
......
......@@ -380,8 +380,9 @@ void CompareRawInferMeta(const MetaTensor& x,
out->set_dims(make_ddim(out_dims_array));
out->share_lod(x);
}
out->set_dtype(DataType::BOOL);
if (!out->is_same_tensor(x)) {
out->set_dtype(DataType::BOOL);
}
}
void CompareInferMeta(const MetaTensor& x,
......
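Skipping set_dtype(DataType::BOOL) when out and x are the same tensor implies that the in-place comparisons keep x's original dtype and store 1/0 values in it instead of producing a bool tensor. A sketch of the expected behavior (illustrative values):

import paddle

x = paddle.to_tensor([1.0, 5.0, 3.0])
y = paddle.to_tensor([2.0, 2.0, 3.0])

paddle.greater_than_(x, y)
print(x.dtype)   # paddle.float32, dtype preserved by the in-place variant
print(x)         # [0., 1., 0.]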
......@@ -384,9 +384,14 @@ void BatchSizeLikeInferMeta(const MetaTensor& x,
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(out_dtype);
out->set_layout(x.layout());
out->share_lod(x);
// In inplace case, setting the dtype of out will reset the dtype of x at the
// same time, which will cause bugs, so move the dtype setting of out to the
// kernel
if (!(out->is_same_tensor(x))) {
out->set_dtype(out_dtype);
}
}
void CConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
......
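Because CastInferMeta can no longer set out's dtype in the in-place case (doing so would clobber x's dtype before the kernel reads it), the kernels below call set_type themselves. From Python the visible effect is simply that cast_ retypes the tensor in place; a minimal sketch:

import paddle

x = paddle.to_tensor([1.25, 2.5, 3.75])
paddle.cast_(x, 'int32')   # dtype is set inside the kernel, not in InferMeta
print(x.dtype)             # paddle.int32
print(x)                   # [1, 2, 3]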
......@@ -26,7 +26,8 @@ void CastGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
PD_VISIT_ALL_TYPES(x.dtype(), "CastKernelImpl", ([&] {
CastKernelImpl<T, data_t>(dev_ctx, out_grad, x_grad);
CastKernelImpl<T, data_t>(
dev_ctx, out_grad, x_grad->dtype(), x_grad);
}));
}
......
......@@ -29,12 +29,35 @@ struct CastOpTransformFunctor {
template <typename InT, typename OutT>
void CastKernelImpl(const CPUContext& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DenseTensor* out) {
auto* in_begin = x.data<InT>();
auto numel = x.numel();
auto* in_end = in_begin + numel;
auto* out_begin = dev_ctx.Alloc<OutT>(out);
out->set_type(out_dtype);
phi::Transform<CPUContext> trans;
trans(dev_ctx,
in_begin,
in_end,
out_begin,
CastOpTransformFunctor<InT, OutT>());
}
template <typename InT, typename OutT>
void CastInplaceKernelImpl(const CPUContext& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DenseTensor* out) {
auto x_origin = x;
auto* in_begin = x_origin.data<InT>();
auto numel = x_origin.numel();
auto* in_end = in_begin + numel;
auto* out_begin = dev_ctx.Alloc<OutT>(out);
out->set_type(out_dtype);
phi::Transform<CPUContext> trans;
trans(dev_ctx,
......
......@@ -25,9 +25,16 @@ void CastKernel(const Context& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DenseTensor* out) {
PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] {
CastKernelImpl<T, data_t>(dev_ctx, x, out);
}));
if (out->IsSharedWith(x)) {
PD_VISIT_ALL_TYPES(out_dtype, "CastInplaceKernelImpl", ([&] {
CastInplaceKernelImpl<T, data_t>(
dev_ctx, x, out_dtype, out);
}));
} else {
PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] {
CastKernelImpl<T, data_t>(dev_ctx, x, out_dtype, out);
}));
}
}
} // namespace phi
......
......@@ -30,13 +30,22 @@ inline void CompareKernelImpl(const Context& ctx,
const DenseTensor& y,
int axis,
DenseTensor* out) {
ctx.template Alloc<bool>(out);
if (x.dims().size() >= y.dims().size()) {
funcs::ElementwiseCompute<Functor, T, bool>(
ctx, x, y, Functor(), out, axis);
if (!out->IsSharedWith(x)) {
ctx.template Alloc<bool>(out);
if (x.dims().size() >= y.dims().size()) {
funcs::ElementwiseCompute<Functor, T, bool>(
ctx, x, y, Functor(), out, axis);
} else {
funcs::ElementwiseCompute<InverseFunctor, T, bool>(
ctx, x, y, InverseFunctor(), out, axis);
}
} else {
funcs::ElementwiseCompute<InverseFunctor, T, bool>(
ctx, x, y, InverseFunctor(), out, axis);
if (x.dims().size() >= y.dims().size()) {
funcs::ElementwiseCompute<Functor, T, T>(ctx, x, y, Functor(), out, axis);
} else {
funcs::ElementwiseCompute<InverseFunctor, T, T>(
ctx, x, y, InverseFunctor(), out, axis);
}
}
}
......@@ -83,21 +92,19 @@ PD_REGISTER_KERNEL(equal_all,
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) {}
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
......
......@@ -24,15 +24,20 @@
namespace phi {
#define DEFINE_LOGICAL_BINARY_KERNEL(type) \
template <typename T, typename Context> \
void Logical##type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
funcs::Logical##type##Functor<T> binary_func; \
funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, bool>( \
dev_ctx, x, y, binary_func, out); \
#define DEFINE_LOGICAL_BINARY_KERNEL(type) \
template <typename T, typename Context> \
void Logical##type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
funcs::Logical##type##Functor<T> binary_func; \
if (out->IsSharedWith(x)) { \
funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, T>( \
dev_ctx, x, y, binary_func, out); \
} else { \
funcs::ElementwiseCompute<funcs::Logical##type##Functor<T>, T, bool>( \
dev_ctx, x, y, binary_func, out); \
} \
}
DEFINE_LOGICAL_BINARY_KERNEL(And)
......@@ -44,11 +49,19 @@ template <typename T, typename Context>
void LogicalNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto* out_ptr = dev_ctx.template Alloc<bool>(out);
funcs::LogicalNotFunctor<T> unary_func;
phi::Transform<Context> trans;
trans(dev_ctx, x.data<T>(), x.data<T>() + x.numel(), out_ptr, unary_func);
if (!out->IsSharedWith(x)) {
auto* out_ptr = dev_ctx.template Alloc<bool>(out);
trans(dev_ctx, x.data<T>(), x.data<T>() + x.numel(), out_ptr, unary_func);
} else {
trans(dev_ctx,
x.data<T>(),
x.data<T>() + x.numel(),
reinterpret_cast<T*>(out->data()),
unary_func);
}
}
} // namespace phi
......@@ -64,9 +77,7 @@ void LogicalNotKernel(const Context& dev_ctx,
int64_t, \
int, \
int8_t, \
int16_t) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
int16_t) {}
REGISTER_LOGICAL_CPU_KERNEL(logical_and, And)
REGISTER_LOGICAL_CPU_KERNEL(logical_or, Or)
......
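As with the comparison kernels, the logical kernels only allocate a bool output when out does not share x's buffer, so the in-place forms write the result back in x's own dtype. A small sketch (illustrative values):

import paddle

x = paddle.to_tensor([0.0, 2.0, 0.0])
paddle.logical_not_(x)
print(x.dtype)   # paddle.float32, unchanged
print(x)         # [1., 0., 1.]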
......@@ -26,7 +26,8 @@ void CastGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
PD_VISIT_ALL_TYPES(x.dtype(), "CastCUDAKernelImpl", ([&] {
CastCUDAKernelImpl<T, data_t>(dev_ctx, out_grad, x_grad);
CastCUDAKernelImpl<T, data_t>(
dev_ctx, out_grad, x_grad->dtype(), x_grad);
}));
}
......
......@@ -29,12 +29,31 @@ struct CastFunctor {
template <typename InT, typename OutT>
void CastCUDAKernelImpl(const GPUContext& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
outputs.emplace_back(out);
dev_ctx.Alloc<OutT>(out);
out->set_type(out_dtype);
phi::funcs::ElementwiseKernel<OutT>(
dev_ctx, inputs, &outputs, CastFunctor<InT, OutT>());
}
template <typename InT, typename OutT>
void CastInplaceCUDAKernelImpl(const GPUContext& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
// inplace case
auto x_origin = x;
inputs.emplace_back(&x_origin);
outputs.emplace_back(out);
dev_ctx.Alloc<OutT>(out);
out->set_type(out_dtype);
phi::funcs::ElementwiseKernel<OutT>(
dev_ctx, inputs, &outputs, CastFunctor<InT, OutT>());
}
......
......@@ -25,9 +25,17 @@ void CastKernel(const Context& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DenseTensor* out) {
PD_VISIT_ALL_TYPES(out_dtype, "CastCUDAKernelImpl", ([&] {
CastCUDAKernelImpl<T, data_t>(dev_ctx, x, out);
}));
if (out->IsSharedWith(x)) {
PD_VISIT_ALL_TYPES(out_dtype, "CastInplaceCUDAKernelImpl", ([&] {
CastInplaceCUDAKernelImpl<T, data_t>(
dev_ctx, x, out_dtype, out);
}));
} else {
PD_VISIT_ALL_TYPES(out_dtype, "CastCUDAKernelImpl", ([&] {
CastCUDAKernelImpl<T, data_t>(
dev_ctx, x, out_dtype, out);
}));
}
}
} // namespace phi
......
......@@ -52,10 +52,16 @@ inline void CompareKernelImpl(const Context& ctx,
const DenseTensor& y,
int axis,
DenseTensor* out) {
ctx.template Alloc<bool>(out);
if (!out->IsSharedWith(x)) {
ctx.template Alloc<bool>(out);
}
std::vector<const DenseTensor*> ins{&x, &y};
std::vector<DenseTensor*> outs{out};
funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
if (!out->IsSharedWith(x)) {
funcs::BroadcastKernel<bool>(ctx, ins, &outs, Functor(), axis);
} else {
funcs::BroadcastKernel<T>(ctx, ins, &outs, Functor(), axis);
}
}
#ifndef PADDLE_WITH_XPU_KP
......@@ -128,21 +134,18 @@ PD_REGISTER_KERNEL(equal_all,
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \
KPS, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \
KPS, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16) {}
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
......
......@@ -25,17 +25,24 @@
namespace phi {
#define DEFINE_LOGICAL_BINARY_KERNEL(type) \
template <typename T, typename Context> \
void Logical##type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
dev_ctx.template Alloc<bool>(out); \
funcs::Logical##type##Functor<T> binary_func; \
std::vector<const DenseTensor*> ins = {&x, &y}; \
std::vector<DenseTensor*> outs = {out}; \
funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func); \
#define DEFINE_LOGICAL_BINARY_KERNEL(type) \
template <typename T, typename Context> \
void Logical##type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
if (!out->IsSharedWith(x)) { \
dev_ctx.template Alloc<bool>(out); \
} \
\
funcs::Logical##type##Functor<T> binary_func; \
std::vector<const DenseTensor*> ins = {&x, &y}; \
std::vector<DenseTensor*> outs = {out}; \
if (!out->IsSharedWith(x)) { \
funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, binary_func); \
} else { \
funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, binary_func); \
} \
}
DEFINE_LOGICAL_BINARY_KERNEL(And)
......@@ -47,11 +54,17 @@ template <typename T, typename Context>
void LogicalNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
dev_ctx.template Alloc<bool>(out);
if (!out->IsSharedWith(x)) {
dev_ctx.template Alloc<bool>(out);
}
funcs::LogicalNotFunctor<T> unary_func;
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, unary_func);
if (!out->IsSharedWith(x)) {
funcs::BroadcastKernel<bool>(dev_ctx, ins, &outs, unary_func);
} else {
funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, unary_func);
}
}
} // namespace phi
......@@ -84,9 +97,7 @@ PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {
int64_t, \
int, \
int8_t, \
int16_t) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
int16_t) {}
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, And)
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, Or)
......
......@@ -155,7 +155,6 @@ void CastCooKernel(const Context& dev_ctx,
} else {
phi::MetaTensor meta(out_values);
meta.set_dims(x_values.dims());
meta.set_dtype(value_dtype);
phi::CastKernel<T, Context>(dev_ctx, x_values, value_dtype, out_values);
}
out->SetIndicesDict(x.GetIndicesDict());
......@@ -201,7 +200,6 @@ void CastCsrKernel(const Context& dev_ctx,
} else {
phi::MetaTensor meta(out_values);
meta.set_dims(x_values.dims());
meta.set_dtype(value_dtype);
phi::CastKernel<T, Context>(dev_ctx, x_values, value_dtype, out_values);
}
}
......
......@@ -136,26 +136,41 @@ from .tensor.linalg import histogram # noqa: F401
from .tensor.linalg import bincount # noqa: F401
from .tensor.linalg import mv # noqa: F401
from .tensor.logic import equal # noqa: F401
from .tensor.logic import equal_ # noqa: F401
from .tensor.linalg import eigvalsh # noqa: F401
from .tensor.logic import greater_equal # noqa: F401
from .tensor.logic import greater_equal_ # noqa: F401
from .tensor.logic import greater_than # noqa: F401
from .tensor.logic import greater_than_ # noqa: F401
from .tensor.logic import is_empty # noqa: F401
from .tensor.logic import less_equal # noqa: F401
from .tensor.logic import less_equal_ # noqa: F401
from .tensor.logic import less_than # noqa: F401
from .tensor.logic import less_than_ # noqa: F401
from .tensor.logic import logical_and # noqa: F401
from .tensor.logic import logical_and_ # noqa: F401
from .tensor.logic import logical_not # noqa: F401
from .tensor.logic import logical_not_ # noqa: F401
from .tensor.logic import logical_or # noqa: F401
from .tensor.logic import logical_or_ # noqa: F401
from .tensor.logic import logical_xor # noqa: F401
from .tensor.logic import logical_xor_ # noqa: F401
from .tensor.logic import bitwise_and # noqa: F401
from .tensor.logic import bitwise_and_ # noqa: F401
from .tensor.logic import bitwise_not # noqa: F401
from .tensor.logic import bitwise_not_ # noqa: F401
from .tensor.logic import bitwise_or # noqa: F401
from .tensor.logic import bitwise_or_ # noqa: F401
from .tensor.logic import bitwise_xor # noqa: F401
from .tensor.logic import bitwise_xor_ # noqa: F401
from .tensor.logic import not_equal # noqa: F401
from .tensor.logic import not_equal_ # noqa: F401
from .tensor.logic import allclose # noqa: F401
from .tensor.logic import isclose # noqa: F401
from .tensor.logic import equal_all # noqa: F401
from .tensor.logic import is_tensor # noqa: F401
from .tensor.manipulation import cast # noqa: F401
from .tensor.manipulation import cast_ # noqa: F401
from .tensor.manipulation import concat # noqa: F401
from .tensor.manipulation import broadcast_tensors # noqa: F401
from .tensor.manipulation import expand # noqa: F401
......@@ -225,9 +240,11 @@ from .tensor.math import tan_ # noqa: F401
from .tensor.math import cosh # noqa: F401
from .tensor.math import cosh_ # noqa: F401
from .tensor.math import cumsum # noqa: F401
from .tensor.math import cumsum_ # noqa: F401
from .tensor.math import cummax # noqa: F401
from .tensor.math import cummin # noqa: F401
from .tensor.math import cumprod # noqa: F401
from .tensor.math import cumprod_ # noqa: F401
from .tensor.math import logcumsumexp # noqa: F401
from .tensor.math import logit # noqa: F401
from .tensor.math import logit_ # noqa: F401
......@@ -262,6 +279,7 @@ from .tensor.math import square_ # noqa: F401
from .tensor.math import stanh # noqa: F401
from .tensor.math import sum # noqa: F401
from .tensor.math import nan_to_num # noqa: F401
from .tensor.math import nan_to_num_ # noqa: F401
from .tensor.math import nansum # noqa: F401
from .tensor.math import nanmean # noqa: F401
from .tensor.math import count_nonzero # noqa: F401
......@@ -276,13 +294,19 @@ from .tensor.math import minimum # noqa: F401
from .tensor.math import amin # noqa: F401
from .tensor.math import mm # noqa: F401
from .tensor.math import divide # noqa: F401
from .tensor.math import divide_ # noqa: F401
from .tensor.math import floor_divide # noqa: F401
from .tensor.math import floor_divide_ # noqa: F401
from .tensor.math import remainder # noqa: F401
from .tensor.math import remainder_ # noqa: F401
from .tensor.math import mod # noqa: F401
from .tensor.math import mod_ # noqa: F401
from .tensor.math import floor_mod # noqa: F401
from .tensor.math import floor_mod_ # noqa: F401
from .tensor.math import multiply # noqa: F401
from .tensor.math import multiply_ # noqa: F401
from .tensor.math import renorm # noqa: F401
from .tensor.math import renorm_ # noqa: F401
from .tensor.math import add # noqa: F401
from .tensor.math import subtract # noqa: F401
from .tensor.math import logsumexp # noqa: F401
......@@ -323,7 +347,9 @@ from .tensor.math import erfinv # noqa: F401
from .tensor.math import rad2deg # noqa: F401
from .tensor.math import deg2rad # noqa: F401
from .tensor.math import gcd # noqa: F401
from .tensor.math import gcd_ # noqa: F401
from .tensor.math import lcm # noqa: F401
from .tensor.math import lcm_ # noqa: F401
from .tensor.math import diff # noqa: F401
from .tensor.math import angle # noqa: F401
from .tensor.math import fmax # noqa: F401
......@@ -337,6 +363,7 @@ from .tensor.math import sgn # noqa: F401
from .tensor.math import take # noqa: F401
from .tensor.math import frexp # noqa: F401
from .tensor.math import ldexp # noqa: F401
from .tensor.math import ldexp_ # noqa: F401
from .tensor.math import trapezoid # noqa: F401
from .tensor.math import cumulative_trapezoid # noqa: F401
from .tensor.math import vander # noqa: F401
......@@ -368,6 +395,7 @@ from .tensor.search import bucketize # noqa: F401
from .tensor.search import masked_select # noqa: F401
from .tensor.search import topk # noqa: F401
from .tensor.search import where # noqa: F401
from .tensor.search import where_ # noqa: F401
from .tensor.search import index_select # noqa: F401
from .tensor.search import nonzero # noqa: F401
from .tensor.search import sort # noqa: F401
......@@ -491,9 +519,11 @@ __all__ = [ # noqa
'empty_like',
'eye',
'cumsum',
'cumsum_',
'cummax',
'cummin',
'cumprod',
'cumprod_',
'logaddexp',
'logcumsumexp',
'logit',
......@@ -502,12 +532,14 @@ __all__ = [ # noqa
'sign',
'is_empty',
'equal',
'equal_',
'equal_all',
'is_tensor',
'is_complex',
'is_integer',
'cross',
'where',
'where_',
'log1p',
'cos',
'cos_',
......@@ -536,8 +568,10 @@ __all__ = [ # noqa
'split',
'vsplit',
'logical_and',
'logical_and_',
'full_like',
'less_than',
'less_than_',
'kron',
'clip',
'Tensor',
......@@ -558,17 +592,25 @@ __all__ = [ # noqa
'isinf',
'uniform',
'floor_divide',
'floor_divide_',
'remainder',
'remainder_',
'floor_mod',
'floor_mod_',
'roll',
'batch',
'max',
'amax',
'logical_or',
'logical_or_',
'bitwise_and',
'bitwise_and_',
'bitwise_or',
'bitwise_or_',
'bitwise_xor',
'bitwise_xor_',
'bitwise_not',
'bitwise_not_',
'mm',
'flip',
'rot90',
......@@ -585,6 +627,7 @@ __all__ = [ # noqa
'reciprocal',
'rand',
'less_equal',
'less_equal_',
'triu',
'triu_',
'sin',
......@@ -605,6 +648,7 @@ __all__ = [ # noqa
'set_grad_enabled',
'is_grad_enabled',
'mod',
'mod_',
'abs',
'abs_',
'tril',
......@@ -645,6 +689,7 @@ __all__ = [ # noqa
'square',
'square_',
'divide',
'divide_',
'ceil',
'atan',
'atan_',
......@@ -652,12 +697,15 @@ __all__ = [ # noqa
'rad2deg',
'deg2rad',
'gcd',
'gcd_',
'lcm',
'lcm_',
'expand',
'broadcast_to',
'ones_like',
'index_sample',
'cast',
'cast_',
'grad',
'all',
'ones',
......@@ -668,6 +716,7 @@ __all__ = [ # noqa
'count_nonzero',
'tile',
'greater_equal',
'greater_equal_',
'isfinite',
'create_parameter',
'dot',
......@@ -679,6 +728,7 @@ __all__ = [ # noqa
'tolist',
'tensordot',
'greater_than',
'greater_than_',
'shard_index',
'argsort',
'tanh',
......@@ -695,6 +745,7 @@ __all__ = [ # noqa
'flatten',
'asin',
'multiply',
'multiply_',
'disable_static',
'masked_select',
'var',
......@@ -715,6 +766,7 @@ __all__ = [ # noqa
'nonzero',
'CUDAPinnedPlace',
'logical_not',
'logical_not_',
'add_n',
'minimum',
'scatter',
......@@ -755,9 +807,11 @@ __all__ = [ # noqa
'clone',
'kthvalue',
'renorm',
'renorm_',
'take_along_axis',
'put_along_axis',
'nan_to_num',
'nan_to_num_',
'heaviside',
'tril_indices',
'index_add',
......@@ -769,6 +823,7 @@ __all__ = [ # noqa
'take',
'frexp',
'ldexp',
'ldexp_',
'trapezoid',
'cumulative_trapezoid',
'polar',
......
......@@ -75,25 +75,40 @@ from .linalg import lu # noqa: F401
from .linalg import lu_unpack # noqa: F401
from .linalg import cdist # noqa: F401
from .logic import equal # noqa: F401
from .logic import equal_ # noqa: F401
from .logic import greater_equal # noqa: F401
from .logic import greater_equal_ # noqa: F401
from .logic import greater_than # noqa: F401
from .logic import greater_than_ # noqa: F401
from .logic import is_empty # noqa: F401
from .logic import less_equal # noqa: F401
from .logic import less_equal_ # noqa: F401
from .logic import less_than # noqa: F401
from .logic import less_than_ # noqa: F401
from .logic import logical_and # noqa: F401
from .logic import logical_and_ # noqa: F401
from .logic import logical_not # noqa: F401
from .logic import logical_not_ # noqa: F401
from .logic import logical_or # noqa: F401
from .logic import logical_or_ # noqa: F401
from .logic import logical_xor # noqa: F401
from .logic import logical_xor_ # noqa: F401
from .logic import bitwise_and # noqa: F401
from .logic import bitwise_and_ # noqa: F401
from .logic import bitwise_or # noqa: F401
from .logic import bitwise_or_ # noqa: F401
from .logic import bitwise_xor # noqa: F401
from .logic import bitwise_xor_ # noqa: F401
from .logic import bitwise_not # noqa: F401
from .logic import bitwise_not_ # noqa: F401
from .logic import not_equal # noqa: F401
from .logic import not_equal_ # noqa: F401
from .logic import allclose # noqa: F401
from .logic import isclose # noqa: F401
from .logic import equal_all # noqa: F401
from .logic import is_tensor # noqa: F401
from .manipulation import cast # noqa: F401
from .manipulation import cast_ # noqa: F401
from .manipulation import concat # noqa: F401
from .manipulation import expand # noqa: F401
from .manipulation import broadcast_to # noqa: F401
......@@ -163,9 +178,11 @@ from .math import tan_ # noqa: F401
from .math import cosh # noqa: F401
from .math import cosh_ # noqa: F401
from .math import cumsum # noqa: F401
from .math import cumsum_ # noqa: F401
from .math import cummax # noqa: F401
from .math import cummin # noqa: F401
from .math import cumprod # noqa: F401
from .math import cumprod_ # noqa: F401
from .math import logcumsumexp # noqa: F401
from .math import logit # noqa: F401
from .math import logit_ # noqa: F401
......@@ -199,6 +216,7 @@ from .math import square # noqa: F401
from .math import stanh # noqa: F401
from .math import sum # noqa: F401
from .math import nan_to_num # noqa: F401
from .math import nan_to_num_ # noqa: F401
from .math import nansum # noqa: F401
from .math import nanmean # noqa: F401
from .math import count_nonzero # noqa: F401
......@@ -213,11 +231,15 @@ from .math import amin # noqa: F401
from .math import minimum # noqa: F401
from .math import mm # noqa: F401
from .math import divide # noqa: F401
from .math import divide_ # noqa: F401
from .math import floor_divide # noqa: F401
from .math import floor_divide_ # noqa: F401
from .math import remainder # noqa: F401
from .math import remainder_ # noqa: F401
from .math import mod # noqa: F401
from .math import mod_ # noqa: F401
from .math import floor_mod # noqa: F401
from .math import floor_mod_ # noqa: F401
from .math import multiply # noqa: F401
from .math import multiply_ # noqa: F401
from .math import add # noqa: F401
......@@ -271,7 +293,9 @@ from .math import erfinv_ # noqa: F401
from .math import rad2deg # noqa: F401
from .math import deg2rad # noqa: F401
from .math import gcd # noqa: F401
from .math import gcd_ # noqa: F401
from .math import lcm # noqa: F401
from .math import lcm_ # noqa: F401
from .math import diff # noqa: F401
from .math import angle # noqa: F401
from .math import fmax # noqa: F401
......@@ -285,6 +309,7 @@ from .math import sgn # noqa: F401
from .math import take # noqa: F401
from .math import frexp # noqa: F401
from .math import ldexp # noqa: F401
from .math import ldexp_ # noqa: F401
from .math import trapezoid # noqa: F401
from .math import cumulative_trapezoid # noqa: F401
from .math import sigmoid # noqa: F401
......@@ -318,6 +343,7 @@ from .search import searchsorted # noqa: F401
from .search import bucketize # noqa: F401
from .search import topk # noqa: F401
from .search import where # noqa: F401
from .search import where_ # noqa: F401
from .search import index_select # noqa: F401
from .search import nonzero # noqa: F401
from .search import sort # noqa: F401
......@@ -380,9 +406,11 @@ tensor_method_func = [ # noqa
'cos',
'cosh',
'cumsum',
'cumsum_',
'cummax',
'cummin',
'cumprod',
'cumprod_',
'logcumsumexp',
'logit',
'logit_',
......@@ -421,6 +449,7 @@ tensor_method_func = [ # noqa
'stanh',
'sum',
'nan_to_num',
'nan_to_num_',
'nansum',
'nanmean',
'count_nonzero',
......@@ -439,11 +468,15 @@ tensor_method_func = [ # noqa
'inner',
'outer',
'divide',
'divide_',
'floor_divide',
'floor_divide_',
'remainder',
'remainder_',
'mod',
'mod_',
'floor_mod',
'floor_mod_',
'multiply',
'multiply_',
'add',
......@@ -473,6 +506,7 @@ tensor_method_func = [ # noqa
'lgamma',
'lgamma_',
'equal',
'equal_',
'equal_all',
'greater_equal',
'greater_equal_',
......@@ -484,15 +518,20 @@ tensor_method_func = [ # noqa
'less_than',
'less_than_',
'logical_and',
'logical_and_',
'logical_not',
'logical_not_',
'logical_or',
'logical_or_',
'logical_xor',
'logical_xor_',
'not_equal',
'not_equal_',
'allclose',
'isclose',
'is_tensor',
'cast',
'cast_',
'concat',
'expand',
'broadcast_to',
......@@ -535,6 +574,7 @@ tensor_method_func = [ # noqa
'masked_select',
'topk',
'where',
'where_',
'index_select',
'nonzero',
'sort',
......@@ -562,9 +602,13 @@ tensor_method_func = [ # noqa
'frac',
'frac_',
'bitwise_and',
'bitwise_and_',
'bitwise_or',
'bitwise_or_',
'bitwise_xor',
'bitwise_xor_',
'bitwise_not',
'bitwise_not_',
'broadcast_tensors',
'eig',
'uniform_',
......@@ -583,7 +627,9 @@ tensor_method_func = [ # noqa
'rad2deg',
'deg2rad',
'gcd',
'gcd_',
'lcm',
'lcm_',
'diff',
"mode",
'lerp',
......@@ -607,6 +653,7 @@ tensor_method_func = [ # noqa
'sgn',
'frexp',
'ldexp',
'ldexp_',
'trapezoid',
'cumulative_trapezoid',
'polar',
......
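Because the new names are appended to tensor_method_func, they are also patched onto paddle.Tensor and can be called method-style. For example (illustrative values):

import paddle

x = paddle.to_tensor([7.0, 9.0])
y = paddle.to_tensor([2.0, 3.0])

x.divide_(y)     # x -> [3.5, 3.0]
x.multiply_(y)   # x -> [7.0, 9.0]
print(x)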
......@@ -24,6 +24,8 @@ Tensor = paddle.fluid.framework.core.eager.Tensor
from paddle import _C_ops
from paddle.tensor.creation import full
from paddle.tensor.math import broadcast_shape
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
from ..framework import LayerHelper, in_dynamic_mode
......@@ -138,6 +140,23 @@ def logical_and(x, y, out=None, name=None):
)
@inplace_apis_in_dygraph_only
def logical_and_(x, y, name=None):
r"""
Inplace version of ``logical_and`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_logical_and`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.logical_and_(x, y)
def logical_or(x, y, out=None, name=None):
"""
......@@ -182,6 +201,23 @@ def logical_or(x, y, out=None, name=None):
)
@inplace_apis_in_dygraph_only
def logical_or_(x, y, name=None):
r"""
Inplace version of ``logical_or`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_logical_or`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.logical_or_(x, y)
def logical_xor(x, y, out=None, name=None):
r"""
......@@ -227,6 +263,23 @@ def logical_xor(x, y, out=None, name=None):
)
@inplace_apis_in_dygraph_only
def logical_xor_(x, y, name=None):
r"""
Inplace version of ``logical_xor`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_logical_xor`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.logical_xor_(x, y)
def logical_not(x, out=None, name=None):
"""
......@@ -266,6 +319,16 @@ def logical_not(x, out=None, name=None):
)
@inplace_apis_in_dygraph_only
def logical_not_(x, name=None):
r"""
Inplace version of ``logical_not`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_logical_not`.
"""
if in_dynamic_mode():
return _C_ops.logical_not_(x)
def is_empty(x, name=None):
"""
......@@ -506,6 +569,23 @@ def equal(x, y, name=None):
return out
@inplace_apis_in_dygraph_only
def equal_(x, y, name=None):
r"""
Inplace version of ``equal`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_equal`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.equal_(x, y)
@templatedoc()
def greater_equal(x, y, name=None):
"""
......@@ -575,6 +655,23 @@ def greater_equal(x, y, name=None):
return out
@inplace_apis_in_dygraph_only
def greater_equal_(x, y, name=None):
r"""
Inplace version of ``greater_equal`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_greater_equal`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.greater_equal_(x, y)
@templatedoc()
def greater_than(x, y, name=None):
"""
......@@ -644,6 +741,23 @@ def greater_than(x, y, name=None):
return out
@inplace_apis_in_dygraph_only
def greater_than_(x, y, name=None):
r"""
Inplace version of ``greater_than`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_greater_than`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.greater_than_(x, y)
@templatedoc()
def less_equal(x, y, name=None):
"""
......@@ -714,6 +828,23 @@ def less_equal(x, y, name=None):
return out
@inplace_apis_in_dygraph_only
def less_equal_(x, y, name=None):
r"""
Inplace version of ``less_equal`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_less_equal`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.less_equal_(x, y)
@templatedoc()
def less_than(x, y, name=None):
"""
......@@ -784,6 +915,23 @@ def less_than(x, y, name=None):
return out
@inplace_apis_in_dygraph_only
def less_than_(x, y, name=None):
r"""
Inplace version of ``less_than`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_less_than`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.less_than_(x, y)
@templatedoc()
def not_equal(x, y, name=None):
"""
......@@ -854,6 +1002,23 @@ def not_equal(x, y, name=None):
return out
@inplace_apis_in_dygraph_only
def not_equal_(x, y, name=None):
r"""
Inplace version of ``not_equal`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_not_equal`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.not_equal_(x, y)
def is_tensor(x):
"""
......@@ -965,6 +1130,23 @@ def bitwise_and(x, y, out=None, name=None):
)
@inplace_apis_in_dygraph_only
def bitwise_and_(x, y, name=None):
r"""
Inplace version of ``bitwise_and`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_bitwise_and`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.bitwise_and_(x, y)
def bitwise_or(x, y, out=None, name=None):
r"""
......@@ -1003,6 +1185,23 @@ def bitwise_or(x, y, out=None, name=None):
)
@inplace_apis_in_dygraph_only
def bitwise_or_(x, y, name=None):
r"""
Inplace version of ``bitwise_or`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_bitwise_or`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.bitwise_or_(x, y)
def bitwise_xor(x, y, out=None, name=None):
r"""
......@@ -1040,6 +1239,23 @@ def bitwise_xor(x, y, out=None, name=None):
)
@inplace_apis_in_dygraph_only
def bitwise_xor_(x, y, name=None):
r"""
Inplace version of ``bitwise_xor`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_bitwise_xor`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
if in_dynamic_mode():
return _C_ops.bitwise_xor_(x, y)
def bitwise_not(x, out=None, name=None):
r"""
......@@ -1076,6 +1292,16 @@ def bitwise_not(x, out=None, name=None):
)
@inplace_apis_in_dygraph_only
def bitwise_not_(x, name=None):
r"""
Inplace version of ``bitwise_not`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_bitwise_not`.
"""
if in_dynamic_mode():
return _C_ops.bitwise_not_(x)
@templatedoc()
def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
r"""
......
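Every binary in-place wrapper above first checks that broadcasting x against y would not change x's shape, since the result has to fit into x's existing buffer. A sketch of the failure mode (illustrative shapes):

import paddle

x = paddle.to_tensor([1.0, 2.0])
y = paddle.ones([2, 2])

try:
    paddle.greater_than_(x, y)   # broadcast shape [2, 2] != x.shape [2]
except ValueError as e:
    print(e)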
......@@ -230,6 +230,18 @@ def cast(x, dtype):
return out
@inplace_apis_in_dygraph_only
def cast_(x, dtype):
"""
Inplace version of ``cast`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_cast`.
"""
if in_dynamic_mode():
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
return _C_ops.cast_(x, dtype)
def slice(input, axes, starts, ends):
"""
This operator produces a slice of ``input`` along multiple axes. Similar to numpy:
......
......@@ -31,14 +31,13 @@ from ..fluid.data_feeder import (
)
from ..framework import (
LayerHelper,
_dygraph_tracer,
convert_np_dtype_to_dtype_,
core,
in_dynamic_mode,
)
from .creation import _complex_to_real_dtype
from .layer_function_generator import generate_layer_fn, templatedoc
from .manipulation import cast
from .manipulation import cast, cast_
from .ops import abs # noqa: F401
from .ops import abs_ # noqa: F401
from .ops import acos # noqa: F401
......@@ -882,6 +881,22 @@ def divide(x, y, name=None):
return _elementwise_op(LayerHelper('elementwise_div', **locals()))
@inplace_apis_in_dygraph_only
def divide_(x, y, name=None):
r"""
Inplace version of ``divide`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_divide`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
return _C_ops.divide_(x, y)
def floor_divide(x, y, name=None):
"""
Floor divide two tensors element-wise and round the quotients to the nearest integer toward zero. The equation is:
......@@ -927,6 +942,22 @@ def floor_divide(x, y, name=None):
return _elementwise_op(LayerHelper('elementwise_floordiv', **locals()))
@inplace_apis_in_dygraph_only
def floor_divide_(x, y, name=None):
r"""
Inplace version of ``floor_divide`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_floor_divide`.
"""
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
return _C_ops.floor_divide_(x, y)
def remainder(x, y, name=None):
r"""
Mod two tensors element-wise. The equation is:
......@@ -986,6 +1017,16 @@ def remainder_(x, y, name=None):
mod = remainder # noqa: F841
floor_mod = remainder # noqa: F841
mod_ = remainder_ # noqa: F841
mod_.__doc__ = r"""
Inplace version of ``mod`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_mod`.
"""
floor_mod_ = remainder_ # noqa: F841
floor_mod_.__doc__ = r"""
Inplace version of ``floor_mod`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_floor_mod`.
"""
def multiply(x, y, name=None):
......@@ -1048,10 +1089,6 @@ def multiply_(x, y, name=None):
Please refer to :ref:`api_tensor_multiply`.
"""
assert (
_dygraph_tracer()._has_grad is False
), "The current inplace version of multiply_ needs to be used in the context of paddle.no_grad() since inplace multiply_grad is not yet supported."
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
......@@ -1592,6 +1629,36 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None, name=None):
return x
@inplace_apis_in_dygraph_only
def nan_to_num_(x, nan=0.0, posinf=None, neginf=None, name=None):
r"""
Inplace version of ``nan_to_num`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_nan_to_num`.
"""
# NOTE(tiancaishaonvjituizi): it seems that paddle handles the dtype of python float number
# incorrectly, so we have to explicitly construct tensors here
posinf_value = paddle.full_like(x, float("+inf"))
neginf_value = paddle.full_like(x, float("-inf"))
nan = paddle.full_like(x, nan)
assert x.dtype in [paddle.float32, paddle.float64]
is_float32 = x.dtype == paddle.float32
if posinf is None:
posinf = (
np.finfo(np.float32).max if is_float32 else np.finfo(np.float64).max
)
posinf = paddle.full_like(x, posinf)
if neginf is None:
neginf = (
np.finfo(np.float32).min if is_float32 else np.finfo(np.float64).min
)
neginf = paddle.full_like(x, neginf)
x_not_nan = paddle.logical_not(paddle.isnan(x))
x = paddle.where_(x_not_nan, x, nan)
x = paddle.where_(x != posinf_value, x, posinf)
x = paddle.where_(x != neginf_value, x, neginf)
return x
def nansum(x, axis=None, dtype=None, keepdim=False, name=None):
"""
Computes the sum of tensor elements over the given axis, treating Not a Numbers (NaNs) as zero.
......@@ -2344,6 +2411,32 @@ def renorm(x, p, axis, max_norm):
return out
@inplace_apis_in_dygraph_only
def renorm_(x, p, axis, max_norm):
"""
Inplace version of ``renorm`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_renorm`.
"""
input_shape = x.shape
if not axis < len(input_shape):
raise ValueError(
"the axis:{} should be less then the shape's size {}:{}".format(
axis, len(input_shape), input_shape
)
)
if not axis >= 0:
if not axis >= -1 * len(input_shape):
raise ValueError(
"the axis:{} should not be less than -1 * length of input_shape:{}".format(
axis, -1 * len(input_shape)
)
)
axis = axis + len(input_shape)
if in_dynamic_mode():
out = _C_ops.renorm_(x, p, axis, max_norm)
return out
def inner(x, y, name=None):
"""
......@@ -3852,6 +3945,25 @@ def cumsum(x, axis=None, dtype=None, name=None):
return _cum_sum_(**kwargs)
@inplace_apis_in_dygraph_only
def cumsum_(x, axis=None, dtype=None, name=None):
r"""
Inplace version of ``cumsum`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_cumsum`.
"""
if axis is None:
flatten = True
else:
flatten = False
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast_(x, dtype)
if in_dynamic_mode():
if axis is None:
axis = -1
return _C_ops.cumsum_(x, axis, flatten, False, False)
def cummax(x, axis=None, dtype='int64', name=None):
"""
The cumulative max of the elements along a given axis.
......@@ -4196,6 +4308,19 @@ def cumprod(x, dim=None, dtype=None, name=None):
return out
@inplace_apis_in_dygraph_only
def cumprod_(x, dim=None, dtype=None, name=None):
r"""
Inplace version of ``cumprod`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_cumprod`.
"""
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast_(x, dtype)
if in_dynamic_mode():
return _C_ops.cumprod_(x, dim)
def isfinite(x, name=None):
"""
......@@ -5449,6 +5574,51 @@ def gcd(x, y, name=None):
return out
def gcd_(x, y, name=None):
r"""
Inplace version of ``gcd`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_gcd`.
"""
shape = paddle.broadcast_shape(x.shape, y.shape)
if shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
shape, x.shape
)
)
y = paddle.broadcast_to(y, shape)
x = paddle.abs_(x)
y = paddle.abs(y)
def _gcd_cond_fn(x, y):
return paddle.any(y != 0)
def _gcd_body_fn(x, y):
# paddle.mod will raise an error when any element of y is 0. To avoid
# that, we change those zeros to ones. Their values don't matter because
# they won't be used.
y_equal_0 = y == 0
y_safe = paddle.where(y_equal_0, paddle.ones(y.shape, y.dtype), y)
y, x = (
paddle.where(
y_equal_0,
paddle.zeros(y.shape, y.dtype),
paddle.mod(x, y_safe),
),
paddle.where_(y_equal_0, x, y),
)
return (
paddle.where(x < y, x, y),
paddle.where_(x >= y, x, y),
)
if in_dynamic_mode():
while _gcd_cond_fn(x, y):
y, x = _gcd_body_fn(x, y)
return x
def lcm(x, y, name=None):
"""
Computes the element-wise least common multiple (LCM) of input |x| and |y|.
......@@ -5509,6 +5679,25 @@ def lcm(x, y, name=None):
return out
def lcm_(x, y, name=None):
r"""
Inplace version of ``lcm`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_lcm`.
"""
d = paddle.gcd(x, y)
# paddle.mod will raise an error when any element of y is 0. To avoid
# that, we change those zeros to ones. Their values don't matter because
# they won't be used.
d_not_equal_0 = d != 0
d_safe = paddle.where(d_not_equal_0, d, paddle.ones(d.shape, d.dtype))
out = paddle.where_(
d_not_equal_0,
paddle.abs_(x.multiply_(y)).floor_divide_(d_safe),
paddle.zeros(d.shape, d.dtype),
)
return out
def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
r"""
Computes the n-th forward difference along the given axis.
......@@ -6651,6 +6840,7 @@ def polygamma(x, n, name=None):
return out
@inplace_apis_in_dygraph_only
def polygamma_(x, n, name=None):
r"""
Inplace version of ``polygamma`` API, the output Tensor will be inplaced with input ``x``.
......@@ -6722,3 +6912,22 @@ def ldexp(x, y, name=None):
y = paddle.cast(y, dtype=out_dtype)
two = paddle.to_tensor(2, dtype=out_dtype)
return paddle.multiply(x, paddle.pow(two, y))
def ldexp_(x, y, name=None):
r"""
Inplace version of ``ldexp`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_ldexp`.
"""
if not isinstance(x, (paddle.Tensor, Variable)):
raise TypeError(f"x must be tensor type, but got {type(x)}")
if not isinstance(y, (paddle.Tensor, Variable)):
raise TypeError(f"y must be tensor type, but got {type(y)}")
if x.dtype == paddle.float64 or y.dtype == paddle.float64:
out_dtype = paddle.float64
else:
out_dtype = paddle.get_default_dtype()
x = paddle.cast_(x, dtype=out_dtype)
y = paddle.cast(y, dtype=out_dtype)
two = paddle.to_tensor(2, dtype=out_dtype)
return paddle.multiply_(x, paddle.pow(two, y))
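A quick sketch of two of the composite in-place helpers added above: gcd_ iterates paddle.where_/paddle.mod until y is all zeros, and ldexp_ rewrites x as x * 2**y. Values are illustrative:

import paddle

x = paddle.to_tensor([12, 20])
y = paddle.to_tensor([20, 8])
x.gcd_(y)             # x -> [4, 4]

a = paddle.to_tensor([1.0, 2.0, 3.0])
b = paddle.to_tensor([2], dtype='int32')
paddle.ldexp_(a, b)   # a -> [4., 8., 12.]
print(x, a)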
......@@ -19,6 +19,7 @@ import numpy as np
import paddle
from paddle import _C_ops
from paddle.common_ops_import import VarDesc, Variable
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
from ..fluid.data_feeder import check_dtype, check_variable_and_dtype
from ..framework import (
......@@ -708,6 +709,43 @@ def where(condition, x=None, y=None, name=None):
return out
@inplace_apis_in_dygraph_only
def where_(condition, x=None, y=None, name=None):
r"""
Inplace version of ``where`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_where`.
"""
if np.isscalar(x) or np.isscalar(y):
raise ValueError("either both or neither of x and y should be given")
if x is None or y is None:
raise ValueError("either both or neither of x and y should be given")
condition_shape = list(condition.shape)
x_shape = list(x.shape)
y_shape = list(y.shape)
if x_shape == y_shape and condition_shape == x_shape:
broadcast_condition = condition
broadcast_x = x
broadcast_y = y
else:
zeros_like_x = paddle.zeros_like(x)
zeros_like_y = paddle.zeros_like(y)
zeros_like_condition = paddle.zeros_like(condition)
zeros_like_condition = paddle.cast(zeros_like_condition, x.dtype)
cast_cond = paddle.cast(condition, x.dtype)
broadcast_zeros = paddle.add(zeros_like_x, zeros_like_y)
broadcast_zeros = paddle.add(broadcast_zeros, zeros_like_condition)
broadcast_x = x.add_(broadcast_zeros)
broadcast_y = paddle.add(y, broadcast_zeros)
broadcast_condition = paddle.add(cast_cond, broadcast_zeros)
broadcast_condition = paddle.cast(broadcast_condition, 'bool')
if in_dynamic_mode():
return _C_ops.where_(broadcast_condition, broadcast_x, broadcast_y)
def index_sample(x, index):
"""
**IndexSample Layer**
......
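where_ overwrites x after broadcasting condition, x and y to a common shape (using add_ on x so the broadcast also happens in place). A dygraph sketch with illustrative values:

import paddle

cond = paddle.to_tensor([True, False, True])
x = paddle.to_tensor([1.0, 2.0, 3.0])
y = paddle.to_tensor([9.0, 9.0, 9.0])

paddle.where_(cond, x, y)
print(x)   # [1., 9., 3.]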
This diff is collapsed.