Unverified commit cd28cddb, authored by zyfncg and committed by GitHub

[PHI] Move set_value kernel to phi (#40195)

* save code

* fix bug of set_value

* add coverage test
Parent 63fb0347
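In outline, the change deletes the fluid-style SetValueKernel class together with its REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL registrations, and re-adds the computation as phi functional kernels registered per backend, plus the vector<Scalar> attribute plumbing the new kernels need. A condensed before/after sketch (both sides appear in full in the diff below):

// Before (fluid): the kernel is an OpKernel class, registered per DeviceContext.
REGISTER_OP_CPU_KERNEL(
    set_value, ops::SetValueKernel<paddle::platform::CPUDeviceContext, int>,
    ops::SetValueKernel<plat::CPUDeviceContext, float> /* , ... */);

// After (phi): the kernel is a free function template, registered per backend.
PD_REGISTER_KERNEL(
    set_value, CPU, ALL_LAYOUT, phi::SetValueKernel,
    float, double, int, int64_t, bool) {}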
@@ -539,6 +539,20 @@ bool ExecutionContext::HasInput(const std::string& name) const {
  return var != nullptr;
}
bool ExecutionContext::HasInputs(const std::string& name) const {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) {
return false;
}
for (const auto* input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}
bool ExecutionContext::HasOutput(const std::string& name) const {
  auto* var = OutputVar(name);
  return var != nullptr;
@@ -2189,6 +2203,51 @@ void OperatorWithKernel::BuildPhiKernelContext(
            std::move(experimental::MakePhiScalarFromVar(*ins_vector.front())));
      }
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = Attrs().at(attr_names[i]);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct KernelContext.",
attr_names[i]));
}
    } else {
      // TODO(chenweihang): support other attrs later
      auto& attr = Attrs().at(attr_names[i]);
@@ -2212,6 +2271,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
    } else if (attr_defs[i].type_index ==
               std::type_index(typeid(std::vector<int64_t>))) {
      if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (std::type_index(attr.type()) ==
                 std::type_index(typeid(std::vector<int>))) {
        // Emplace Back Attr according to the type of Phi_Kernel args.
        const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
......
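The four vector<phi::Scalar> branches above differ only in the element type handed to BOOST_GET_CONST. A hedged sketch of how they could be factored with a hypothetical template helper (not part of this PR):

template <typename T>
std::vector<phi::Scalar> ToScalarList(const framework::Attribute& attr) {
  // Extract the typed vector and wrap each element in a phi::Scalar.
  const auto& vec = BOOST_GET_CONST(std::vector<T>, attr);
  std::vector<phi::Scalar> scalar_list;
  scalar_list.reserve(vec.size());
  for (const auto& val : vec) {
    scalar_list.emplace_back(val);
  }
  return scalar_list;
}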
@@ -295,6 +295,8 @@ class ExecutionContext {
  virtual bool HasInput(const std::string& name) const;
virtual bool HasInputs(const std::string& name) const;
  virtual bool HasOutput(const std::string& name) const;
  virtual size_t InputSize(const std::string& name) const {
@@ -449,7 +451,7 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
      : ctx_(ctx) {}

  bool HasInput(const std::string& name) const override {
-    return ctx_.HasInput(name);
+    return ctx_.HasInputs(name);
  }
......
@@ -133,6 +133,11 @@ class DygraphExecutionContext : public framework::ExecutionContext {
    return (it != var_map_in_.end() && it->second.size() > 0);
  }
bool HasInputs(const std::string& name) const override {
auto it = var_map_in_.find(name);
return (it != var_map_in_.end() && it->second.size() > 0);
}
  bool HasOutput(const std::string& name) const override {
    auto it = var_map_out_.find(name);
    return (it != var_map_out_.end() && it->second.size() > 0);
......
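For context (not part of the diff): HasInput resolves a single variable, while inputs such as StartsTensorList bind a vector of variables, which is why ExecutionArgumentMappingContext::HasInput now forwards to the new HasInputs. A standalone analogue of the framework-side semantics, with Variable as a stand-in for the real type:

#include <string>
#include <unordered_map>
#include <vector>

struct Variable {};  // stand-in for framework::Variable

// True iff `name` maps to a non-empty vector whose entries are all non-null.
bool HasInputs(const std::unordered_map<std::string,
                                        std::vector<const Variable*>>& inputs,
               const std::string& name) {
  auto it = inputs.find(name);
  if (it == inputs.end() || it->second.empty()) {
    return false;
  }
  for (const auto* var : it->second) {
    if (var == nullptr) {
      return false;
    }
  }
  return true;
}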
@@ -332,6 +332,7 @@ void BuildDygraphPhiKernelContext(
  }

  for (size_t i = 0; i < attr_names.size(); ++i) {
VLOG(1) << "############## attr_name: " << i << " : " << attr_names[i];
    if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) {
      if (attrs.find(attr_names[i]) !=
          attrs.end()) {  // shape is in the attribute
@@ -409,6 +410,60 @@ void BuildDygraphPhiKernelContext(
            experimental::MakePhiScalarFromVar(ins_vector[0]->Var())));
      }
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<bool>))) {
const auto& vec = BOOST_GET_CONST(std::vector<bool>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct KernelContext.",
attr_names[i]));
}
    } else {
      // TODO(chenweihang): support other attrs later
      auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
@@ -432,6 +487,10 @@ void BuildDygraphPhiKernelContext(
    } else if (attr_defs[i].type_index ==
               std::type_index(typeid(std::vector<int64_t>))) {
      if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (std::type_index(attr.type()) ==
                 std::type_index(typeid(std::vector<int>))) {
        // Emplace Back Attr according to the type of Phi_Kernel args.
        const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
......
@@ -241,13 +241,6 @@ REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
                  ops::SetValueGradMaker<paddle::imperative::OpBase>,
                  ops::SetValueOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
set_value, ops::SetValueKernel<paddle::platform::CPUDeviceContext, int>,
ops::SetValueKernel<plat::CPUDeviceContext, int64_t>,
ops::SetValueKernel<plat::CPUDeviceContext, float>,
ops::SetValueKernel<plat::CPUDeviceContext, double>,
ops::SetValueKernel<plat::CPUDeviceContext, bool>);
REGISTER_OPERATOR(set_value_grad, ops::SetValueGrad);

REGISTER_OP_CPU_KERNEL(
......
@@ -16,13 +16,6 @@
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
set_value, ops::SetValueKernel<paddle::platform::CUDADeviceContext, int>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, float>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, double>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
    set_value_grad,
    ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, int>,
......
@@ -121,201 +121,6 @@ inline void CheckIsDimsMatch(const framework::DDim first,
                        "of target shape: %d, but now shape is %d.",
                        second.to_str(), first.to_str()));
}
template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const int rank = ctx.Input<framework::LoDTensor>("Input")->dims().size();
    // TODO(liym27): Find a more elegant way to do this. C++ requires the
    // template integer to be a compile-time constant, so we dispatch on rank
    // here; a better alternative would be welcome in the future.
switch (rank) {
case 1:
SetValueCompute<1>(ctx);
break;
case 2:
SetValueCompute<2>(ctx);
break;
case 3:
SetValueCompute<3>(ctx);
break;
case 4:
SetValueCompute<4>(ctx);
break;
case 5:
SetValueCompute<5>(ctx);
break;
case 6:
SetValueCompute<6>(ctx);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
private:
template <size_t D>
void SetValueCompute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::LoDTensor>("Input");
auto* value_tensor = ctx.Input<framework::LoDTensor>("ValueTensor");
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto starts_tensor_list =
ctx.MultiInput<framework::Tensor>("StartsTensorList");
auto ends_tensor_list = ctx.MultiInput<framework::Tensor>("EndsTensorList");
auto steps_tensor_list =
ctx.MultiInput<framework::Tensor>("StepsTensorList");
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
auto starts = ctx.Attr<std::vector<int64_t>>("starts");
auto ends = ctx.Attr<std::vector<int64_t>>("ends");
auto steps = ctx.Attr<std::vector<int64_t>>("steps");
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
auto decrease_axes = ctx.Attr<std::vector<int64_t>>("decrease_axes");
auto none_axes = ctx.Attr<std::vector<int64_t>>("none_axes");
if (!starts_tensor_list.empty()) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
if (!ends_tensor_list.empty()) {
ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
}
if (!steps_tensor_list.empty()) {
steps = GetDataFromTensorList<int64_t>(steps_tensor_list);
}
auto in_dims = in->dims();
CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, &steps);
auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, &steps);
auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) {
std::vector<int64_t> slice_dims_with_none;
size_t none_axes_cur = 0, decrease_axes_cur = 0;
for (int i = 0; i < slice_dims.size(); ++i) {
while (none_axes_cur < none_axes.size() &&
none_axes[none_axes_cur] <= i) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
if (decrease_axes_cur < decrease_axes.size() &&
decrease_axes[decrease_axes_cur] == i) {
decrease_axes_cur++;
} else {
slice_dims_with_none.push_back(slice_dims[i]);
}
}
while (none_axes_cur < none_axes.size()) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
}
auto place = ctx.GetPlace();
auto& eigen_place =
*ctx.template device_context<DeviceContext>().eigen_device();
    // Here we copy the data from the input to avoid data loss at the PE and
    // Graph levels.
    // TODO(liym27): Speed up in a future version.
    // - Q: Why not call ShareDataWith to speed this up?
    // - A: Because ShareDataWith on an OP's input and output is not supported:
    // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
    // - Q: Why not delete Input, given that the input and output are the same
    //   Tensor at the program level?
    // - A: Deleting Input would complicate the graph; for example, two ops
    //   would point to the output: op1 -> output <- set_value. We would then
    //   need a way to guarantee that set_value runs in the intended order.
paddle::framework::TensorCopy(*in, place, out);
Tensor slice_tensor(in->dtype()), pad_tensor(in->dtype());
slice_tensor.mutable_data<T>(slice_dims, place);
pad_tensor.mutable_data<T>(in_dims, place);
auto pad_e = framework::EigenTensor<T, D>::From(pad_tensor, in_dims);
auto out_e = framework::EigenTensor<T, D>::From(*out);
auto slice_e = framework::EigenTensor<T, D>::From(slice_tensor, slice_dims);
// Step 1: Set the value of out at `_index` to zero
slice_e.device(eigen_place) = slice_e.constant(T(0));
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
for (size_t i = 0; i < D; ++i) {
starts_indices[i] = 0;
ends_indices[i] = slice_dims[i];
strides_indices[i] = 1;
}
for (size_t i = 0; i < axes.size(); i++) {
int axis_index = axes[i];
starts_indices[axis_index] = starts[i];
ends_indices[axis_index] = ends[i];
strides_indices[axis_index] = steps[i];
if (starts[i] == ends[i]) { // slice is empty, data will not be changed
return;
}
}
out_e.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(eigen_place) = slice_e;
    // Step 2: Build a tensor with the same shape as out, whose data at
    // '_index' equals value_tensor and is zero outside '_index'.

    // - Step 2.1 Set slice tensor with value

    // NOTE(liym27): [ Why resize slice_tensor here? ]
    // A: When broadcasting slice_tensor against value_tensor, slice_tensor
    // should have the decreased dims.
    // e.g.
    //  x[:,0] = value_tensor
    // x's shape = [3, 4], value_tensor's shape = [3]
    // We get slice_dims = [3, 1], decrease_slice_dims = [3]
    // Broadcasting Tensors of shape [3, 1] and [3] gives shape [3, 3], which
    // is out of bounds; broadcasting [3] and [3] gives [3], which is correct.
slice_tensor.Resize(slice_dims_for_assign);
if (value_tensor != nullptr) {
CheckIsDimsMatch(slice_dims_for_assign, value_tensor->dims());
// ElementwiseComputeEx can do broadcasting
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_tensor, value_tensor, -1, SubFunctor<T>(), &slice_tensor);
} else {
Tensor value_t(in->dtype());
auto value_dims = phi::make_ddim(shape);
CheckIsDimsMatch(slice_dims_for_assign, value_dims);
value_t.mutable_data<T>(value_dims, place);
auto value_name =
GetValueName(framework::TransToProtoVarType(in->dtype()));
CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
value_t.Resize(value_dims);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_tensor, &value_t, -1, SubFunctor<T>(), &slice_tensor);
}
slice_tensor.Resize(slice_dims);
// - Step 2.2 Pad slice tensor with 0
pad_e.device(eigen_place) = pad_e.constant(T(0));
pad_e.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(eigen_place) = slice_e;
// Step 3: Set out tensor with value_tensor
out_e.device(eigen_place) = out_e - pad_e;
}
};
template <typename DeviceContext, typename T>
class SetValueGradKernel : public framework::OpKernel<T> {
 public:
......
@@ -252,6 +252,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
  PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<std::string>&);
PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<Scalar>&);
  /* Output Helpers */
......
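The one-line addition above is what lets a phi kernel declare a const std::vector<Scalar>& attribute: KernelImpl can now unpack that argument type from the KernelContext. The set_value kernel declared later in this diff relies on it; its signature (repeated here for orientation) is:

template <typename T, typename Context>
void SetValueKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const ScalarArray& starts,
                    const ScalarArray& ends,
                    const ScalarArray& steps,
                    const std::vector<int64_t>& axes,
                    const std::vector<int64_t>& decrease_axes,
                    const std::vector<int64_t>& none_axes,
                    const std::vector<int64_t>& shape,
                    const std::vector<Scalar>& values,  // needs the new helper
                    DenseTensor* out);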
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/set_value_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_kernel_impl.h"
PD_REGISTER_KERNEL(set_value,
CPU,
ALL_LAYOUT,
phi::SetValueKernel,
float,
double,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(set_value_with_tensor,
CPU,
ALL_LAYOUT,
phi::SetTensorValueKernel,
float,
double,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/set_value_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_kernel_impl.h"
PD_REGISTER_KERNEL(set_value,
GPU,
ALL_LAYOUT,
phi::SetValueKernel,
float,
double,
int,
int64_t,
bool) {}
PD_REGISTER_KERNEL(set_value_with_tensor,
GPU,
ALL_LAYOUT,
phi::SetTensorValueKernel,
float,
double,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/slice_utils.h"
namespace phi {
// Check whether a tensor with dims `second` can be assigned to a tensor
// with dims `first`.
inline void CheckIsDimsMatch(const DDim& first, const DDim& second) {
int ignore_axis1 = 0, ignore_axis2 = 0;
for (; ignore_axis1 < first.size(); ++ignore_axis1) {
if (first[ignore_axis1] != 1) {
break;
}
}
for (; ignore_axis2 < second.size(); ++ignore_axis2) {
if (second[ignore_axis2] != 1) {
break;
}
}
if (second.size() == ignore_axis2) {
// second tensor has only one value
return;
}
if (first.size() - ignore_axis1 >= second.size() - ignore_axis2) {
auto idx1 = first.size() - 1;
auto idx2 = second.size() - 1;
bool is_match = true;
for (; idx2 >= ignore_axis2; idx2--) {
if (first[idx1--] != second[idx2] && second[idx2] != 1) {
is_match = false;
break;
}
}
if (is_match) {
return;
}
}
  PADDLE_THROW(errors::InvalidArgument(
      "The shape of the assigned value must match the target shape: "
      "target shape is %s, but value shape is %s.",
      first.to_str(),
      second.to_str()));
}
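A worked example of the matching rule above: leading 1s are skipped on both sides; if nothing of `second` remains it is a single value and always matches; otherwise each trailing dim of `second` must equal the corresponding dim of `first` or be 1.

// first = [3, 4], second = [4]     -> match (4 == 4)
// first = [3, 4], second = [1, 4]  -> match (leading 1 skipped, 4 == 4)
// first = [3, 4], second = [1]     -> match (second holds a single value)
// first = [3, 4], second = [3]     -> mismatch (3 != 4 and 3 != 1), throws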
template <typename T, typename Context, size_t RANK>
void SetValueImpl(const Context& dev_ctx,
const DenseTensor& in,
const DenseTensor& value,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* out) {
auto in_dims = in.dims();
std::vector<int64_t> starts_local = starts.GetData();
std::vector<int64_t> ends_local = ends.GetData();
std::vector<int64_t> steps_local = steps.GetData();
paddle::operators::CheckAndUpdateSliceAttrs(
in_dims, axes, &starts_local, &ends_local, &steps_local);
auto slice_dims = paddle::operators::GetSliceDims(
in_dims, axes, starts_local, ends_local, &steps_local);
auto decrease_slice_dims =
paddle::operators::GetDecreasedDims(slice_dims, decrease_axes);
auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) {
std::vector<int64_t> slice_dims_with_none;
size_t none_axes_cur = 0, decrease_axes_cur = 0;
for (int i = 0; i < slice_dims.size(); ++i) {
while (none_axes_cur < none_axes.size() &&
none_axes[none_axes_cur] <= i) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
if (decrease_axes_cur < decrease_axes.size() &&
decrease_axes[decrease_axes_cur] == i) {
decrease_axes_cur++;
} else {
slice_dims_with_none.push_back(slice_dims[i]);
}
}
while (none_axes_cur < none_axes.size()) {
slice_dims_with_none.push_back(1);
none_axes_cur++;
}
slice_dims_for_assign = phi::make_ddim(slice_dims_with_none);
}
auto place = dev_ctx.GetPlace();
auto& eigen_place = *dev_ctx.eigen_device();
  // Here we copy the data from the input to avoid data loss at the PE and
  // Graph levels.
  // TODO(liym27): Speed up in a future version.
  // - Q: Why not call ShareDataWith to speed this up?
  // - A: Because ShareDataWith on an OP's input and output is not supported:
  // https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
  // - Q: Why not delete Input, given that the input and output are the same
  //   Tensor at the program level?
  // - A: Deleting Input would complicate the graph; for example, two ops
  //   would point to the output: op1 -> output <- set_value. We would then
  //   need a way to guarantee that set_value runs in the intended order.
Copy(dev_ctx, in, place, false, out);
DenseTensor slice_tensor =
Empty<T>(dev_ctx, ScalarArray{slice_dims.Get(), slice_dims.size()});
DenseTensor pad_tensor =
Empty<T>(dev_ctx, ScalarArray{in_dims.Get(), in_dims.size()});
auto pad_e = EigenTensor<T, RANK>::From(pad_tensor, in_dims);
auto out_e = EigenTensor<T, RANK>::From(*out);
auto slice_e = EigenTensor<T, RANK>::From(slice_tensor, slice_dims);
// Step 1: Set the value of out at `_index` to zero
slice_e.device(eigen_place) = slice_e.constant(T(0));
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
for (size_t i = 0; i < RANK; ++i) {
starts_indices[i] = 0;
ends_indices[i] = slice_dims[i];
strides_indices[i] = 1;
}
for (size_t i = 0; i < axes.size(); i++) {
int axis_index = axes[i];
starts_indices[axis_index] = starts_local[i];
ends_indices[axis_index] = ends_local[i];
strides_indices[axis_index] = steps_local[i];
if (starts_local[i] ==
ends_local[i]) { // slice is empty, data will not be changed
return;
}
}
out_e.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(eigen_place) = slice_e;
  // Step 2: Build a tensor with the same shape as out, whose data at
  // '_index' equals value and is zero outside '_index'.

  // - Step 2.1 Set slice tensor with value

  // NOTE(liym27): [ Why resize slice_tensor here? ]
  // A: When broadcasting slice_tensor against value, slice_tensor should
  // have the decreased dims.
  // e.g.
  //  x[:,0] = value
  // x's shape = [3, 4], value's shape = [3]
  // We get slice_dims = [3, 1], decrease_slice_dims = [3]
  // Broadcasting Tensors of shape [3, 1] and [3] gives shape [3, 3], which
  // is out of bounds; broadcasting [3] and [3] gives [3], which is correct.
slice_tensor.Resize(slice_dims_for_assign);
CheckIsDimsMatch(slice_dims_for_assign, value.dims());
  // ElementwiseCompute can do broadcasting
funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx,
slice_tensor,
value,
-1,
funcs::SubtractFunctor<T>(),
&slice_tensor);
slice_tensor.Resize(slice_dims);
// - Step 2.2 Pad slice tensor with 0
pad_e.device(eigen_place) = pad_e.constant(T(0));
pad_e.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(eigen_place) = slice_e;
// Step 3: Set out tensor with value
out_e.device(eigen_place) = out_e - pad_e;
}
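A small worked trace of the three steps above, for a 1-D x with x[1:3] = v (shapes chosen only for illustration):

// out = x                         (Copy of the input)
// Step 1: out[1:3] = 0            (stridedSlice assignment of the zeroed slice)
// Step 2: slice = 0 - v = -v, then pad = 0 everywhere except pad[1:3] = -v
// Step 3: out = out - pad, so out[1:3] = 0 - (-v) = v; the rest is unchanged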
template <typename T, typename Context>
void SetTensorValueKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& value,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* out) {
const int rank = x.dims().size();
switch (rank) {
case 1:
SetValueImpl<T, Context, 1>(dev_ctx,
x,
value,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
out);
break;
case 2:
SetValueImpl<T, Context, 2>(dev_ctx,
x,
value,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
out);
break;
case 3:
SetValueImpl<T, Context, 3>(dev_ctx,
x,
value,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
out);
break;
case 4:
SetValueImpl<T, Context, 4>(dev_ctx,
x,
value,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
out);
break;
case 5:
SetValueImpl<T, Context, 5>(dev_ctx,
x,
value,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
out);
break;
case 6:
SetValueImpl<T, Context, 6>(dev_ctx,
x,
value,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
out);
break;
default:
PADDLE_THROW(errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
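The switch above instantiates SetValueImpl once per rank because Eigen requires the rank as a compile-time constant. A hedged sketch of a macro that would compact the repeated cases (hypothetical, not in the PR):

#define SET_VALUE_RANK_CASE(rank)                                    \
  case rank:                                                         \
    SetValueImpl<T, Context, rank>(dev_ctx, x, value, starts, ends,  \
                                   steps, axes, decrease_axes,       \
                                   none_axes, out);                  \
    break;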
template <typename T, typename Context>
void SetValueKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
const std::vector<int64_t>& shape,
const std::vector<Scalar>& values,
DenseTensor* out) {
  std::vector<T> assign_values;
  assign_values.reserve(values.size());
  for (const auto& val : values) {
    assign_values.push_back(val.to<T>());
  }
  DenseTensor value_tensor = Empty<T>(dev_ctx, shape);
  paddle::framework::TensorFromVector(assign_values, dev_ctx, &value_tensor);
value_tensor.Resize(phi::make_ddim(shape));
SetTensorValueKernel<T, Context>(dev_ctx,
x,
value_tensor,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
out);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
template <typename T, typename Context>
void SetTensorValueKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& value,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* out);
template <typename T, typename Context>
void SetValueKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
const std::vector<int64_t>& shape,
const std::vector<Scalar>& values,
DenseTensor* out);
} // namespace phi
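A hedged usage sketch of the new interface (assuming an initialized phi::CPUContext dev_ctx and a 1-D float DenseTensor x; roughly equivalent to x[0:2] = 10 at the Python level — the call site here is illustrative, not from the diff):

// Assign the scalar 10 to x[0:2] along axis 0.
phi::DenseTensor out;
phi::SetValueKernel<float, phi::CPUContext>(
    dev_ctx, x,
    phi::ScalarArray(std::vector<int64_t>{0}),  // starts
    phi::ScalarArray(std::vector<int64_t>{2}),  // ends
    phi::ScalarArray(std::vector<int64_t>{1}),  // steps
    {0},                                        // axes
    {},                                         // decrease_axes
    {},                                         // none_axes
    {1},                                        // shape of the value
    {phi::Scalar(10.0f)},                       // values
    &out);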
This diff is collapsed.
@@ -114,5 +114,375 @@ TEST(ARG_MAP, fill_constant) {
  ASSERT_EQ(signature9.name, "full_sr");
}
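The cases below exercise the set_value argument mapping: when a ValueTensor input is present the op resolves to the set_value_with_tensor kernel, otherwise to set_value, with whichever *_values attribute is set (fp32_values, fp64_values, int32_values, int64_values, bool_values) supplying the scalar list. A hedged reading of the dispatch rule (the mapping source itself is in the collapsed diff above):

// Sketch of the dispatch, not the actual mapping code:
// has non-empty input "ValueTensor"  -> signature "set_value_with_tensor"
// otherwise                          -> signature "set_value"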
TEST(ARG_MAP, set_value) {
TestArgumentMappingContext arg_case(
{"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"},
{},
{{"fp32_values", paddle::any{std::vector<float>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case).name,
"set_value");
TestArgumentMappingContext arg_case1(
{"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"},
{},
{{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case1).name,
"set_value");
TestArgumentMappingContext arg_case2(
{"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"},
{},
{{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case2).name,
"set_value");
TestArgumentMappingContext arg_case3(
{"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"},
{},
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case3).name,
"set_value");
TestArgumentMappingContext arg_case4(
{"Input", "StartsTensorList", "EndsTensorList", "StepsTensorList"},
{},
{{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case4).name,
"set_value");
TestArgumentMappingContext arg_case5(
{"Input", "StartsTensorList", "EndsTensorList", "ValueTensor"},
{},
{},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case5).name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case6(
{"Input", "StartsTensorList", "EndsTensorList"},
{},
{{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case6).name,
"set_value");
TestArgumentMappingContext arg_case7(
{"Input", "StartsTensorList", "EndsTensorList"},
{},
{{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case7).name,
"set_value");
TestArgumentMappingContext arg_case8(
{"Input", "StartsTensorList", "EndsTensorList"},
{},
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case8).name,
"set_value");
TestArgumentMappingContext arg_case9(
{"Input", "StartsTensorList", "EndsTensorList"},
{},
{{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case9).name,
"set_value");
TestArgumentMappingContext arg_case10(
{"Input", "StartsTensorList", "StepsTensorList", "ValueTensor"},
{},
{},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case10).name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case11(
{"Input", "StartsTensorList", "StepsTensorList"},
{},
{{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case11).name,
"set_value");
TestArgumentMappingContext arg_case12(
{"Input", "StartsTensorList", "StepsTensorList"},
{},
{{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case12).name,
"set_value");
TestArgumentMappingContext arg_case13(
{"Input", "StartsTensorList", "StepsTensorList"},
{},
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case13).name,
"set_value");
TestArgumentMappingContext arg_case14(
{"Input", "StartsTensorList", "StepsTensorList"},
{},
{{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case14).name,
"set_value");
TestArgumentMappingContext arg_case15(
{"Input", "StartsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case15).name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case16(
{"Input", "StartsTensorList", "StepsTensorList"},
{},
{{"fp32_values", paddle::any{std::vector<float>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case16).name,
"set_value");
TestArgumentMappingContext arg_case17(
{"Input", "StartsTensorList", "StepsTensorList"},
{},
{{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case17).name,
"set_value");
TestArgumentMappingContext arg_case18(
{"Input", "StartsTensorList", "StepsTensorList"},
{},
{{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case18).name,
"set_value");
TestArgumentMappingContext arg_case19(
{"Input", "StartsTensorList", "StepsTensorList"},
{},
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case19).name,
"set_value");
TestArgumentMappingContext arg_case20(
{"Input", "StartsTensorList", "StepsTensorList"},
{},
{{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case20).name,
"set_value");
TestArgumentMappingContext arg_case21(
{"Input", "EndsTensorList", "StepsTensorList", "ValueTensor"},
{},
{},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case21).name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case22(
{"Input", "EndsTensorList", "StepsTensorList"},
{},
{{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case22).name,
"set_value");
TestArgumentMappingContext arg_case23(
{"Input", "EndsTensorList", "StepsTensorList"},
{},
{{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case23).name,
"set_value");
TestArgumentMappingContext arg_case24(
{"Input", "EndsTensorList", "StepsTensorList"},
{},
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case24).name,
"set_value");
TestArgumentMappingContext arg_case25(
{"Input", "EndsTensorList", "StepsTensorList"},
{},
{{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case25).name,
"set_value");
TestArgumentMappingContext arg_case26(
{"Input", "EndsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case26).name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case27(
{"Input", "EndsTensorList"},
{},
{{"fp32_values", paddle::any{std::vector<float>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case27).name,
"set_value");
TestArgumentMappingContext arg_case28(
{"Input", "EndsTensorList"},
{},
{{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case28).name,
"set_value");
TestArgumentMappingContext arg_case29(
{"Input", "EndsTensorList"},
{},
{{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case29).name,
"set_value");
TestArgumentMappingContext arg_case30(
{"Input", "EndsTensorList"},
{},
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case30).name,
"set_value");
TestArgumentMappingContext arg_case31(
{"Input", "EndsTensorList"},
{},
{{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case31).name,
"set_value");
TestArgumentMappingContext arg_case32(
{"Input", "StepsTensorList", "ValueTensor"}, {}, {}, {"Out"}, {});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case32).name,
"set_value_with_tensor");
TestArgumentMappingContext arg_case33(
{"Input", "StepsTensorList"},
{},
{{"fp32_values", paddle::any{std::vector<float>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case33).name,
"set_value");
TestArgumentMappingContext arg_case34(
{"Input", "StepsTensorList"},
{},
{{"fp64_values", paddle::any{std::vector<double>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case34).name,
"set_value");
TestArgumentMappingContext arg_case35(
{"Input", "StepsTensorList"},
{},
{{"int32_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case35).name,
"set_value");
TestArgumentMappingContext arg_case36(
{"Input", "StepsTensorList"},
{},
{{"int64_values", paddle::any{std::vector<int64_t>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case36).name,
"set_value");
TestArgumentMappingContext arg_case37(
{"Input", "StepsTensorList"},
{},
{{"bool_values", paddle::any{std::vector<int>{1}}}},
{"Out"},
{});
ASSERT_EQ(
OpUtilsMap::Instance().GetArgumentMappingFn("set_value")(arg_case37).name,
"set_value");
}
}  // namespace tests
}  // namespace phi