From cd28cddbfb5f5643947291e9a640ecd414dc8dae Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 9 Mar 2022 20:11:10 +0800 Subject: [PATCH] [PHI] Move set_value kernel to phi (#40195) * save code * fix bug of set_value * add coverage test --- paddle/fluid/framework/operator.cc | 65 +- paddle/fluid/framework/operator.h | 4 +- paddle/fluid/imperative/execution_context.h | 5 + paddle/fluid/imperative/prepared_operator.h | 61 +- paddle/fluid/operators/set_value_op.cc | 7 - paddle/fluid/operators/set_value_op.cu | 7 - paddle/fluid/operators/set_value_op.h | 195 ----- paddle/phi/core/kernel_utils.h | 1 + paddle/phi/kernels/cpu/set_value_kernel.cc | 38 + paddle/phi/kernels/gpu/set_value_kernel.cu | 38 + .../phi/kernels/impl/set_value_kernel_impl.h | 337 ++++++++ paddle/phi/kernels/set_value_kernel.h | 49 ++ paddle/phi/ops/compat/set_value_sig.cc | 736 ++++++++++++++++++ paddle/phi/tests/ops/test_op_signature.cc | 370 +++++++++ 14 files changed, 1701 insertions(+), 212 deletions(-) create mode 100644 paddle/phi/kernels/cpu/set_value_kernel.cc create mode 100644 paddle/phi/kernels/gpu/set_value_kernel.cu create mode 100644 paddle/phi/kernels/impl/set_value_kernel_impl.h create mode 100644 paddle/phi/kernels/set_value_kernel.h create mode 100644 paddle/phi/ops/compat/set_value_sig.cc diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index eff6d9a910..f8e30c1ee2 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -539,6 +539,20 @@ bool ExecutionContext::HasInput(const std::string& name) const { return var != nullptr; } +bool ExecutionContext::HasInputs(const std::string& name) const { + const auto& ins = ctx_.inputs; + auto it = ins.find(name); + if (it == ins.end() || it->second.empty()) { + return false; + } + for (const auto* input : it->second) { + if (input == nullptr) { + return false; + } + } + return true; +} + bool ExecutionContext::HasOutput(const std::string& name) const { auto* var = OutputVar(name); return var != nullptr; @@ -2189,6 +2203,51 @@ void OperatorWithKernel::BuildPhiKernelContext( std::move(experimental::MakePhiScalarFromVar(*ins_vector.front()))); } + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + auto& attr = Attrs().at(attr_names[i]); + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector when " + "construct KernelContext.", + attr_names[i])); + } } else { // TODO(chenweihang): support other attrs later auto& attr = Attrs().at(attr_names[i]); @@ -2212,7 +2271,11 @@ void OperatorWithKernel::BuildPhiKernelContext( } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + std::type_index(typeid(std::vector))) { + pt_kernel_context->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { // Emplace Back Attr according to the type of Phi_Kernel args. const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); const std::vector vector_int64_attr(vector_int_attr.begin(), diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index e33d4feb82..1a1171f1db 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -295,6 +295,8 @@ class ExecutionContext { virtual bool HasInput(const std::string& name) const; + virtual bool HasInputs(const std::string& name) const; + virtual bool HasOutput(const std::string& name) const; virtual size_t InputSize(const std::string& name) const { @@ -449,7 +451,7 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { : ctx_(ctx) {} bool HasInput(const std::string& name) const override { - return ctx_.HasInput(name); + return ctx_.HasInputs(name); } bool HasOutput(const std::string& name) const override { diff --git a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h index fe5ac73b00..fbc47f81fd 100644 --- a/paddle/fluid/imperative/execution_context.h +++ b/paddle/fluid/imperative/execution_context.h @@ -133,6 +133,11 @@ class DygraphExecutionContext : public framework::ExecutionContext { return (it != var_map_in_.end() && it->second.size() > 0); } + bool HasInputs(const std::string& name) const override { + auto it = var_map_in_.find(name); + return (it != var_map_in_.end() && it->second.size() > 0); + } + bool HasOutput(const std::string& name) const override { auto it = var_map_out_.find(name); return (it != var_map_out_.end() && it->second.size() > 0); diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 30dbe07d7a..d7c0c8cc54 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -332,6 +332,7 @@ void BuildDygraphPhiKernelContext( } for (size_t i = 0; i < attr_names.size(); ++i) { + VLOG(1) << "############## attr_name: " << i << " : " << attr_names[i]; if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) { if (attrs.find(attr_names[i]) != attrs.end()) { // shape is in the attribute @@ -409,6 +410,60 @@ void BuildDygraphPhiKernelContext( experimental::MakePhiScalarFromVar(ins_vector[0]->Var()))); } + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + const auto& vec = BOOST_GET_CONST(std::vector, attr); + std::vector scalar_list; + scalar_list.reserve(vec.size()); + for (const auto& val : vec) { + scalar_list.emplace_back(val); + } + kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported cast op attribute `%s` to vector when " + "construct KernelContext.", + attr_names[i])); + } } else { // TODO(chenweihang): support other attrs later auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); @@ -432,7 +487,11 @@ void BuildDygraphPhiKernelContext( } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + std::type_index(typeid(std::vector))) { + kernel_ctx->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { // Emplace Back Attr according to the type of Phi_Kernel args. const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); const std::vector vector_int64_attr(vector_int_attr.begin(), diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc index ec3e04e71f..7d0d782b83 100644 --- a/paddle/fluid/operators/set_value_op.cc +++ b/paddle/fluid/operators/set_value_op.cc @@ -241,13 +241,6 @@ REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker, ops::SetValueGradMaker, ops::SetValueOpInplaceInferer); -REGISTER_OP_CPU_KERNEL( - set_value, ops::SetValueKernel, - ops::SetValueKernel, - ops::SetValueKernel, - ops::SetValueKernel, - ops::SetValueKernel); - REGISTER_OPERATOR(set_value_grad, ops::SetValueGrad); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/fluid/operators/set_value_op.cu b/paddle/fluid/operators/set_value_op.cu index f9701b0aca..9f291a863c 100644 --- a/paddle/fluid/operators/set_value_op.cu +++ b/paddle/fluid/operators/set_value_op.cu @@ -16,13 +16,6 @@ namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - set_value, ops::SetValueKernel, - ops::SetValueKernel, - ops::SetValueKernel, - ops::SetValueKernel, - ops::SetValueKernel); - REGISTER_OP_CUDA_KERNEL( set_value_grad, ops::SetValueGradKernel, diff --git a/paddle/fluid/operators/set_value_op.h b/paddle/fluid/operators/set_value_op.h index 9dd7279592..4d459f8c01 100644 --- a/paddle/fluid/operators/set_value_op.h +++ b/paddle/fluid/operators/set_value_op.h @@ -121,201 +121,6 @@ inline void CheckIsDimsMatch(const framework::DDim first, "of target shape: %d, but now shape is %d.", second.to_str(), first.to_str())); } - -template -class SetValueKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - const int rank = ctx.Input("Input")->dims().size(); - - // TODO(liym27): A more elegent code to do this. C++ has to make template - // integer as constant, but we had better have alternative writing in the - // future. - switch (rank) { - case 1: - SetValueCompute<1>(ctx); - break; - case 2: - SetValueCompute<2>(ctx); - break; - case 3: - SetValueCompute<3>(ctx); - break; - case 4: - SetValueCompute<4>(ctx); - break; - case 5: - SetValueCompute<5>(ctx); - break; - case 6: - SetValueCompute<6>(ctx); - break; - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "The rank of input should be less than 7, but received %d.", rank)); - } - } - - private: - template - void SetValueCompute(const framework::ExecutionContext& ctx) const { - auto* in = ctx.Input("Input"); - auto* value_tensor = ctx.Input("ValueTensor"); - auto* out = ctx.Output("Out"); - - auto starts_tensor_list = - ctx.MultiInput("StartsTensorList"); - auto ends_tensor_list = ctx.MultiInput("EndsTensorList"); - auto steps_tensor_list = - ctx.MultiInput("StepsTensorList"); - - auto axes = ctx.Attr>("axes"); - auto starts = ctx.Attr>("starts"); - auto ends = ctx.Attr>("ends"); - auto steps = ctx.Attr>("steps"); - auto shape = ctx.Attr>("shape"); - auto decrease_axes = ctx.Attr>("decrease_axes"); - auto none_axes = ctx.Attr>("none_axes"); - - if (!starts_tensor_list.empty()) { - starts = GetDataFromTensorList(starts_tensor_list); - } - if (!ends_tensor_list.empty()) { - ends = GetDataFromTensorList(ends_tensor_list); - } - if (!steps_tensor_list.empty()) { - steps = GetDataFromTensorList(steps_tensor_list); - } - - auto in_dims = in->dims(); - CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, &steps); - auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, &steps); - auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes); - - auto slice_dims_for_assign = decrease_slice_dims; - if (!none_axes.empty()) { - std::vector slice_dims_with_none; - - size_t none_axes_cur = 0, decrease_axes_cur = 0; - for (int i = 0; i < slice_dims.size(); ++i) { - while (none_axes_cur < none_axes.size() && - none_axes[none_axes_cur] <= i) { - slice_dims_with_none.push_back(1); - none_axes_cur++; - } - if (decrease_axes_cur < decrease_axes.size() && - decrease_axes[decrease_axes_cur] == i) { - decrease_axes_cur++; - } else { - slice_dims_with_none.push_back(slice_dims[i]); - } - } - while (none_axes_cur < none_axes.size()) { - slice_dims_with_none.push_back(1); - none_axes_cur++; - } - - slice_dims_for_assign = phi::make_ddim(slice_dims_with_none); - } - - auto place = ctx.GetPlace(); - auto& eigen_place = - *ctx.template device_context().eigen_device(); - - // Here copy data from input to avoid data loss at PE and Graph level. - // TODO(liym27): Speed up in the future version. - // - Q: Why don't call ShareDataWith to speed up? - // - A: Because it's not supported to ShareDataWith on OP's input and output - // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP - // - Q: Why don't delete Input, after all, the input and output are the same - // Tensor at program level? - // - A: If deleting Input, the graph will be complex, such as there will - // be two ops points to the output in graph: op1 -> output <- set_value. - // In this case, we have to find a way to handle the running order of - // set_value is what we want. - paddle::framework::TensorCopy(*in, place, out); - - Tensor slice_tensor(in->dtype()), pad_tensor(in->dtype()); - slice_tensor.mutable_data(slice_dims, place); - pad_tensor.mutable_data(in_dims, place); - - auto pad_e = framework::EigenTensor::From(pad_tensor, in_dims); - auto out_e = framework::EigenTensor::From(*out); - auto slice_e = framework::EigenTensor::From(slice_tensor, slice_dims); - - // Step 1: Set the value of out at `_index` to zero - slice_e.device(eigen_place) = slice_e.constant(T(0)); - - auto starts_indices = Eigen::DSizes(); - auto ends_indices = Eigen::DSizes(); - auto strides_indices = Eigen::DSizes(); - - for (size_t i = 0; i < D; ++i) { - starts_indices[i] = 0; - ends_indices[i] = slice_dims[i]; - strides_indices[i] = 1; - } - for (size_t i = 0; i < axes.size(); i++) { - int axis_index = axes[i]; - starts_indices[axis_index] = starts[i]; - ends_indices[axis_index] = ends[i]; - strides_indices[axis_index] = steps[i]; - if (starts[i] == ends[i]) { // slice is empty, data will not be changed - return; - } - } - - out_e.stridedSlice(starts_indices, ends_indices, strides_indices) - .device(eigen_place) = slice_e; - - // Step 2: Set a tensor with the same shape as out tensor. And its data at - // '_index' is the same as value_tensor, and data out of '_index' to zero - - // - Step 2.1 Set slice tensor with value - - // NOTE(liym27): [ Why resize slice_tensor here? ] - // A: When do broadcasting on slice_tensor and value_tensor, the shape of - // slice_tensor should be decreased dims. - // e.g. - // x[:,0] = value_tensor - // x's shape = [3, 4], value_tensor's shape = [3] - // We get slice_dims = [3, 1], decrease_slice_dims = [3] - // If do broadcasting on Tensor with shape [3, 1] and [3], the result's - // shape is [3, 3], which cross the border; - // If do broadcasting on Tensor with shape [3] and [3], the result's shape - // is [3], which is right. - - slice_tensor.Resize(slice_dims_for_assign); - if (value_tensor != nullptr) { - CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims()); - // ElementwiseComputeEx can do broadcasting - ElementwiseComputeEx, DeviceContext, T>( - ctx, &slice_tensor, value_tensor, -1, SubFunctor(), &slice_tensor); - } else { - Tensor value_t(in->dtype()); - auto value_dims = phi::make_ddim(shape); - CheckIsDimsMatch(slice_dims_for_assign, value_dims); - - value_t.mutable_data(value_dims, place); - auto value_name = - GetValueName(framework::TransToProtoVarType(in->dtype())); - CopyVecotorToTensor(value_name.c_str(), &value_t, ctx); - value_t.Resize(value_dims); - ElementwiseComputeEx, DeviceContext, T>( - ctx, &slice_tensor, &value_t, -1, SubFunctor(), &slice_tensor); - } - slice_tensor.Resize(slice_dims); - - // - Step 2.2 Pad slice tensor with 0 - pad_e.device(eigen_place) = pad_e.constant(T(0)); - pad_e.stridedSlice(starts_indices, ends_indices, strides_indices) - .device(eigen_place) = slice_e; - - // Step 3: Set out tensor with value_tensor - out_e.device(eigen_place) = out_e - pad_e; - } -}; - template class SetValueGradKernel : public framework::OpKernel { public: diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index baa549d7a6..2cc82772cf 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -252,6 +252,7 @@ struct KernelImpl { PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); + PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); /* Output Helpers */ diff --git a/paddle/phi/kernels/cpu/set_value_kernel.cc b/paddle/phi/kernels/cpu/set_value_kernel.cc new file mode 100644 index 0000000000..dcf278cd94 --- /dev/null +++ b/paddle/phi/kernels/cpu/set_value_kernel.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/set_value_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/set_value_kernel_impl.h" + +PD_REGISTER_KERNEL(set_value, + CPU, + ALL_LAYOUT, + phi::SetValueKernel, + float, + double, + int, + int64_t, + bool) {} +PD_REGISTER_KERNEL(set_value_with_tensor, + CPU, + ALL_LAYOUT, + phi::SetTensorValueKernel, + float, + double, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/gpu/set_value_kernel.cu b/paddle/phi/kernels/gpu/set_value_kernel.cu new file mode 100644 index 0000000000..f788da010b --- /dev/null +++ b/paddle/phi/kernels/gpu/set_value_kernel.cu @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/set_value_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/set_value_kernel_impl.h" + +PD_REGISTER_KERNEL(set_value, + GPU, + ALL_LAYOUT, + phi::SetValueKernel, + float, + double, + int, + int64_t, + bool) {} +PD_REGISTER_KERNEL(set_value_with_tensor, + GPU, + ALL_LAYOUT, + phi::SetTensorValueKernel, + float, + double, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/impl/set_value_kernel_impl.h b/paddle/phi/kernels/impl/set_value_kernel_impl.h new file mode 100644 index 0000000000..5aebffe51b --- /dev/null +++ b/paddle/phi/kernels/impl/set_value_kernel_impl.h @@ -0,0 +1,337 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" + +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" + +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/slice_utils.h" + +namespace phi { + +// check whether the tensor with dimension of second can assign to the +// tensor with dimension of first +inline void CheckIsDimsMatch(const DDim& first, const DDim& second) { + int ignore_axis1 = 0, ignore_axis2 = 0; + for (; ignore_axis1 < first.size(); ++ignore_axis1) { + if (first[ignore_axis1] != 1) { + break; + } + } + for (; ignore_axis2 < second.size(); ++ignore_axis2) { + if (second[ignore_axis2] != 1) { + break; + } + } + + if (second.size() == ignore_axis2) { + // second tensor has only one value + return; + } + + if (first.size() - ignore_axis1 >= second.size() - ignore_axis2) { + auto idx1 = first.size() - 1; + auto idx2 = second.size() - 1; + bool is_match = true; + for (; idx2 >= ignore_axis2; idx2--) { + if (first[idx1--] != second[idx2] && second[idx2] != 1) { + is_match = false; + break; + } + } + if (is_match) { + return; + } + } + PADDLE_THROW(errors::InvalidArgument( + "The shape of tensor assigned value must match the shape " + "of target shape: %d, but now shape is %d.", + second.to_str(), + first.to_str())); +} + +template +void SetValueImpl(const Context& dev_ctx, + const DenseTensor& in, + const DenseTensor& value, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + DenseTensor* out) { + auto in_dims = in.dims(); + std::vector starts_local = starts.GetData(); + std::vector ends_local = ends.GetData(); + std::vector steps_local = steps.GetData(); + paddle::operators::CheckAndUpdateSliceAttrs( + in_dims, axes, &starts_local, &ends_local, &steps_local); + auto slice_dims = paddle::operators::GetSliceDims( + in_dims, axes, starts_local, ends_local, &steps_local); + auto decrease_slice_dims = + paddle::operators::GetDecreasedDims(slice_dims, decrease_axes); + + auto slice_dims_for_assign = decrease_slice_dims; + if (!none_axes.empty()) { + std::vector slice_dims_with_none; + + size_t none_axes_cur = 0, decrease_axes_cur = 0; + for (int i = 0; i < slice_dims.size(); ++i) { + while (none_axes_cur < none_axes.size() && + none_axes[none_axes_cur] <= i) { + slice_dims_with_none.push_back(1); + none_axes_cur++; + } + if (decrease_axes_cur < decrease_axes.size() && + decrease_axes[decrease_axes_cur] == i) { + decrease_axes_cur++; + } else { + slice_dims_with_none.push_back(slice_dims[i]); + } + } + while (none_axes_cur < none_axes.size()) { + slice_dims_with_none.push_back(1); + none_axes_cur++; + } + + slice_dims_for_assign = phi::make_ddim(slice_dims_with_none); + } + + auto place = dev_ctx.GetPlace(); + auto& eigen_place = *dev_ctx.eigen_device(); + + // Here copy data from input to avoid data loss at PE and Graph level. + // TODO(liym27): Speed up in the future version. + // - Q: Why don't call ShareDataWith to speed up? + // - A: Because it's not supported to ShareDataWith on OP's input and output + // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP + // - Q: Why don't delete Input, after all, the input and output are the same + // Tensor at program level? + // - A: If deleting Input, the graph will be complex, such as there will + // be two ops points to the output in graph: op1 -> output <- set_value. + // In this case, we have to find a way to handle the running order of + // set_value is what we want. + Copy(dev_ctx, in, place, false, out); + + DenseTensor slice_tensor = + Empty(dev_ctx, ScalarArray{slice_dims.Get(), slice_dims.size()}); + DenseTensor pad_tensor = + Empty(dev_ctx, ScalarArray{in_dims.Get(), in_dims.size()}); + + auto pad_e = EigenTensor::From(pad_tensor, in_dims); + auto out_e = EigenTensor::From(*out); + auto slice_e = EigenTensor::From(slice_tensor, slice_dims); + + // Step 1: Set the value of out at `_index` to zero + slice_e.device(eigen_place) = slice_e.constant(T(0)); + + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); + + for (size_t i = 0; i < RANK; ++i) { + starts_indices[i] = 0; + ends_indices[i] = slice_dims[i]; + strides_indices[i] = 1; + } + for (size_t i = 0; i < axes.size(); i++) { + int axis_index = axes[i]; + starts_indices[axis_index] = starts_local[i]; + ends_indices[axis_index] = ends_local[i]; + strides_indices[axis_index] = steps_local[i]; + if (starts_local[i] == + ends_local[i]) { // slice is empty, data will not be changed + return; + } + } + + out_e.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(eigen_place) = slice_e; + + // Step 2: Set a tensor with the same shape as out tensor. And its data at + // '_index' is the same as value, and data out of '_index' to zero + + // - Step 2.1 Set slice tensor with value + + // NOTE(liym27): [ Why resize slice_tensor here? ] + // A: When do broadcasting on slice_tensor and value, the shape of + // slice_tensor should be decreased dims. + // e.g. + // x[:,0] = value + // x's shape = [3, 4], value's shape = [3] + // We get slice_dims = [3, 1], decrease_slice_dims = [3] + // If do broadcasting on Tensor with shape [3, 1] and [3], the result's + // shape is [3, 3], which cross the border; + // If do broadcasting on Tensor with shape [3] and [3], the result's shape + // is [3], which is right. + + slice_tensor.Resize(slice_dims_for_assign); + CheckIsDimsMatch(slice_dims_for_assign, value.dims()); + // ElementwiseComputeEx can do broadcasting + funcs::ElementwiseCompute, T>( + dev_ctx, + slice_tensor, + value, + -1, + funcs::SubtractFunctor(), + &slice_tensor); + + slice_tensor.Resize(slice_dims); + + // - Step 2.2 Pad slice tensor with 0 + pad_e.device(eigen_place) = pad_e.constant(T(0)); + pad_e.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(eigen_place) = slice_e; + + // Step 3: Set out tensor with value + out_e.device(eigen_place) = out_e - pad_e; +} + +template +void SetTensorValueKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& value, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + DenseTensor* out) { + const int rank = x.dims().size(); + + switch (rank) { + case 1: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + case 2: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + case 3: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + case 4: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + case 5: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + case 6: + SetValueImpl(dev_ctx, + x, + value, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); + break; + default: + PADDLE_THROW(errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", rank)); + } +} + +template +void SetValueKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + const std::vector& shape, + const std::vector& values, + DenseTensor* out) { + std::vector assgin_values; + assgin_values.reserve(values.size()); + for (const auto& val : values) { + assgin_values.push_back(val.to()); + } + DenseTensor value_tensor = Empty(dev_ctx, shape); + paddle::framework::TensorFromVector(assgin_values, dev_ctx, &value_tensor); + value_tensor.Resize(phi::make_ddim(shape)); + + SetTensorValueKernel(dev_ctx, + x, + value_tensor, + starts, + ends, + steps, + axes, + decrease_axes, + none_axes, + out); +} + +} // namespace phi diff --git a/paddle/phi/kernels/set_value_kernel.h b/paddle/phi/kernels/set_value_kernel.h new file mode 100644 index 0000000000..271691b1a3 --- /dev/null +++ b/paddle/phi/kernels/set_value_kernel.h @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/unary.h" + +namespace phi { + +template +void SetTensorValueKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& value, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + DenseTensor* out); + +template +void SetValueKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& steps, + const std::vector& axes, + const std::vector& decrease_axes, + const std::vector& none_axes, + const std::vector& shape, + const std::vector& values, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/set_value_sig.cc b/paddle/phi/ops/compat/set_value_sig.cc new file mode 100644 index 0000000000..eacfff26d5 --- /dev/null +++ b/paddle/phi/ops/compat/set_value_sig.cc @@ -0,0 +1,736 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInput("Input")) { + if (ctx.HasInput("StartsTensorList")) { + if (ctx.HasInput("EndsTensorList")) { + if (ctx.HasInput("StepsTensorList")) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } else { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } + } else { + if (ctx.HasInput("StepsTensorList")) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } else { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"StartsTensorList", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } + } + } else { + if (ctx.HasInput("EndsTensorList")) { + if (ctx.HasInput("StepsTensorList")) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } else { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "EndsTensorList", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } + } else { + if (ctx.HasInput("StepsTensorList")) { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "StepsTensorList", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } else { + if (ctx.HasInput("ValueTensor")) { + return KernelSignature("set_value_with_tensor", + {"Input", "ValueTensor"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes"}, + {"Out"}); + } else if (ctx.HasAttr("fp32_values") && + !paddle::any_cast>( + ctx.Attr("fp32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp32_values"}, + {"Out"}); + } else if (ctx.HasAttr("fp64_values") && + !paddle::any_cast>( + ctx.Attr("fp64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp64_values"}, + {"Out"}); + } else if (ctx.HasAttr("int32_values") && + !paddle::any_cast>( + ctx.Attr("int32_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int32_values"}, + {"Out"}); + } else if (ctx.HasAttr("int64_values") && + !paddle::any_cast>( + ctx.Attr("int64_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "int64_values"}, + {"Out"}); + } else if (ctx.HasAttr("bool_values") && + !paddle::any_cast>( + ctx.Attr("bool_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "bool_values"}, + {"Out"}); + } + } + } + } + } + return KernelSignature("unregistered", {}, {}, {}); +} +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(set_value, phi::SetValueOpArgumentMapping); diff --git a/paddle/phi/tests/ops/test_op_signature.cc b/paddle/phi/tests/ops/test_op_signature.cc index a6c9a27de7..88c9193a8f 100644 --- a/paddle/phi/tests/ops/test_op_signature.cc +++ b/paddle/phi/tests/ops/test_op_signature.cc @@ -114,5 +114,375 @@ TEST(ARG_MAP, fill_constant) { ASSERT_EQ(signature9.name, "full_sr"); } +TEST(ARG_MAP, set_value) { + TestArgumentMappingContext arg_case( + {"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"}, + {}, + {{"fp32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case).name, + "set_value"); + + TestArgumentMappingContext arg_case1( + {"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case1).name, + "set_value"); + + TestArgumentMappingContext arg_case2( + {"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case2).name, + "set_value"); + + TestArgumentMappingContext arg_case3( + {"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case3).name, + "set_value"); + + TestArgumentMappingContext arg_case4( + {"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case4).name, + "set_value"); + + TestArgumentMappingContext arg_case5( + {"Input", "StartsTensorList", "EndsTensorList", "ValueTensor"}, + {}, + {}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case5).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case6( + {"Input", "StartsTensorList", "EndsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case6).name, + "set_value"); + + TestArgumentMappingContext arg_case7( + {"Input", "StartsTensorList", "EndsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case7).name, + "set_value"); + + TestArgumentMappingContext arg_case8( + {"Input", "StartsTensorList", "EndsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case8).name, + "set_value"); + + TestArgumentMappingContext arg_case9( + {"Input", "StartsTensorList", "EndsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case9).name, + "set_value"); + + TestArgumentMappingContext arg_case10( + {"Input", "StartsTensorList", "StepsTensorList", "ValueTensor"}, + {}, + {}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case10).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case11( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case11).name, + "set_value"); + + TestArgumentMappingContext arg_case12( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case12).name, + "set_value"); + + TestArgumentMappingContext arg_case13( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case13).name, + "set_value"); + + TestArgumentMappingContext arg_case14( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case14).name, + "set_value"); + + TestArgumentMappingContext arg_case15( + {"Input", "StartsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case15).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case16( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"fp32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case16).name, + "set_value"); + + TestArgumentMappingContext arg_case17( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case17).name, + "set_value"); + + TestArgumentMappingContext arg_case18( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case18).name, + "set_value"); + + TestArgumentMappingContext arg_case19( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case19).name, + "set_value"); + + TestArgumentMappingContext arg_case20( + {"Input", "StartsTensorList", "StepsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case20).name, + "set_value"); + + TestArgumentMappingContext arg_case21( + {"Input", "EndsTensorList", "StepsTensorList", "ValueTensor"}, + {}, + {}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case21).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case22( + {"Input", "EndsTensorList", "StepsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case22).name, + "set_value"); + + TestArgumentMappingContext arg_case23( + {"Input", "EndsTensorList", "StepsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case23).name, + "set_value"); + + TestArgumentMappingContext arg_case24( + {"Input", "EndsTensorList", "StepsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case24).name, + "set_value"); + + TestArgumentMappingContext arg_case25( + {"Input", "EndsTensorList", "StepsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case25).name, + "set_value"); + + TestArgumentMappingContext arg_case26( + {"Input", "EndsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case26).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case27( + {"Input", "EndsTensorList"}, + {}, + {{"fp32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case27).name, + "set_value"); + + TestArgumentMappingContext arg_case28( + {"Input", "EndsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case28).name, + "set_value"); + + TestArgumentMappingContext arg_case29( + {"Input", "EndsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case29).name, + "set_value"); + + TestArgumentMappingContext arg_case30( + {"Input", "EndsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case30).name, + "set_value"); + + TestArgumentMappingContext arg_case31( + {"Input", "EndsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case31).name, + "set_value"); + + TestArgumentMappingContext arg_case32( + {"Input", "StepsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case32).name, + "set_value_with_tensor"); + + TestArgumentMappingContext arg_case33( + {"Input", "StepsTensorList"}, + {}, + {{"fp32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case33).name, + "set_value"); + + TestArgumentMappingContext arg_case34( + {"Input", "StepsTensorList"}, + {}, + {{"fp64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case34).name, + "set_value"); + + TestArgumentMappingContext arg_case35( + {"Input", "StepsTensorList"}, + {}, + {{"int32_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case35).name, + "set_value"); + + TestArgumentMappingContext arg_case36( + {"Input", "StepsTensorList"}, + {}, + {{"int64_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case36).name, + "set_value"); + + TestArgumentMappingContext arg_case37( + {"Input", "StepsTensorList"}, + {}, + {{"bool_values", paddle::any{std::vector{1}}}}, + {"Out"}, + {}); + ASSERT_EQ( + OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case37).name, + "set_value"); +} + } // namespace tests } // namespace phi -- GitLab