未验证 提交 3149e399 编写于 作者: Z zyfncg 提交者: GitHub

[PHI] Move set_value_grad kernel form fluid to phi (#40478)

* move set_value_grad kernel form fluid to phi

* add unittest for passing coverage ci
上级 3f219160
......@@ -243,14 +243,6 @@ REGISTER_OPERATOR(set_value, ops::SetValue, ops::SetValueMaker,
REGISTER_OPERATOR(set_value_grad, ops::SetValueGrad);
REGISTER_OP_CPU_KERNEL(
set_value_grad,
ops::SetValueGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SetValueGradKernel<plat::CPUDeviceContext, int64_t>,
ops::SetValueGradKernel<plat::CPUDeviceContext, float>,
ops::SetValueGradKernel<plat::CPUDeviceContext, double>,
ops::SetValueGradKernel<plat::CPUDeviceContext, bool>);
REGISTER_OP_VERSION(set_value)
.AddCheckpoint(
R"ROC(
......
......@@ -19,14 +19,10 @@
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/fluid/operators/strided_slice_op.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -36,23 +32,6 @@ namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
inline void GetOffsets(const DDim& big_dim, const DDim& small_dim,
DDim start_offset, int cur_dim,
std::vector<DDim>* offsets) {
if (cur_dim == big_dim.size()) {
offsets->push_back(start_offset);
return;
}
if (small_dim[cur_dim] == big_dim[cur_dim]) {
GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
} else {
for (int i = 0; i < big_dim[cur_dim]; i++) {
GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
start_offset[cur_dim] += 1;
}
}
}
inline std::string GetValueName(framework::proto::VarType::Type data_type) {
std::string value_name;
switch (data_type) {
......@@ -121,253 +100,6 @@ inline void CheckIsDimsMatch(const framework::DDim first,
"of target shape: %d, but now shape is %d.",
second.to_str(), first.to_str()));
}
template <typename DeviceContext, typename T>
class SetValueGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int rank = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
switch (rank) {
case 1:
SetValueGradCompute<1>(ctx);
break;
case 2:
SetValueGradCompute<2>(ctx);
break;
case 3:
SetValueGradCompute<3>(ctx);
break;
case 4:
SetValueGradCompute<4>(ctx);
break;
case 5:
SetValueGradCompute<5>(ctx);
break;
case 6:
SetValueGradCompute<6>(ctx);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of set_value_grad's input should be less than 7, but "
"received %d.",
rank));
}
}
private:
template <size_t D>
void SetValueGradCompute(const framework::ExecutionContext& context) const {
auto starts = context.Attr<std::vector<int64_t>>("starts");
auto ends = context.Attr<std::vector<int64_t>>("ends");
auto steps = context.Attr<std::vector<int64_t>>("steps");
auto axes_int64 = context.Attr<std::vector<int64_t>>("axes");
std::vector<int> axes(axes_int64.begin(), axes_int64.end());
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto steps_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>();
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_steps_tensor =
context.MultiInput<framework::Tensor>("StepsTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
}
if (list_new_steps_tensor.size() > 0) {
steps = GetDataFromTensorList<int64_t>(list_new_steps_tensor);
}
auto in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(
in->IsInitialized(), true,
platform::errors::PermissionDenied(
"The input of `set_value_grad`(%s) has not been initialized",
framework::GradVarName("Out")));
auto grad_value = context.Output<framework::Tensor>(
framework::GradVarName("ValueTensor"));
auto grad_input =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
auto in_dims = in->dims();
auto decrease_axis_int64 =
context.Attr<std::vector<int64_t>>("decrease_axes");
std::vector<int> decrease_axis(decrease_axis_int64.begin(),
decrease_axis_int64.end());
std::vector<int> infer_flags(axes.size(), 1);
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, steps, axes, infer_flags, in_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
false);
framework::DDim out_dims(phi::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), steps.data(), axes.data(),
reverse_vector.data(), in_dims, infer_flags,
decrease_axis, starts.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
steps_indices[axis] = 1;
reverse_axis[axis] = false;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices[axis_index] = starts[axis];
ends_indices[axis_index] = ends[axis];
steps_indices[axis_index] = steps[axis];
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
auto& dev_ctx = context.template device_context<DeviceContext>();
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
phi::funcs::SetConstant<DeviceContext, T> set_zero;
if (grad_input) {
// Set gradient of `Input`
paddle::framework::TensorCopy(*in, context.GetPlace(), grad_input);
auto grad_input_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*grad_input);
framework::Tensor tmp(grad_input->dtype());
tmp.mutable_data<T>(out_dims, context.GetPlace());
set_zero(dev_ctx, &tmp, static_cast<T>(0));
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp);
grad_input_t.stridedSlice(starts_indices, ends_indices, steps_indices)
.device(place) = tmp_t;
}
if (grad_value) {
grad_value->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, grad_value, static_cast<T>(0));
auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*in);
if (grad_value->dims() == out_dims) {
auto grad_value_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*grad_value);
if (need_reverse) {
framework::Tensor tmp(grad_value->dtype());
tmp.mutable_data<T>(out_dims, context.GetPlace());
set_zero(dev_ctx, &tmp, static_cast<T>(0));
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
grad_value_t.device(place) = tmp_t.reverse(reverse_axis);
} else {
grad_value_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
}
} else {
int out_dims_size = out_dims.size();
auto grad_value_dims = grad_value->dims();
auto fake_grad_value_dims = out_dims;
// Create an extented shape according to the rules of broadcast.
auto grad_value_dims_size = grad_value_dims.size();
int num_decrease = 0;
int decrease_axis_size = decrease_axis.size();
for (int i = 0; i < out_dims_size; i++) {
if (decrease_axis.end() !=
std::find(decrease_axis.begin(), decrease_axis.end(), i)) {
fake_grad_value_dims[i] = 1;
num_decrease++;
} else if (i < out_dims_size - (grad_value_dims_size +
decrease_axis_size - num_decrease)) {
fake_grad_value_dims[i] = 1;
} else {
auto index_grad =
i - (out_dims_size - (grad_value_dims_size +
decrease_axis_size - num_decrease));
fake_grad_value_dims[i] = grad_value_dims[index_grad];
PADDLE_ENFORCE_EQ((out_dims[i] == grad_value_dims[index_grad]) ||
(grad_value_dims[index_grad] == 1),
true,
platform::errors::InvalidArgument(
"An error occurred while calculating %s: "
"[%s] can not be accumulated into [%s].",
framework::GradVarName("ValueTensor"),
out_dims, grad_value_dims));
}
}
VLOG(3) << "Dimensions of " << framework::GradVarName("ValueTensor")
<< "([" << grad_value_dims << "])is broadcasted into ["
<< fake_grad_value_dims << "].";
auto extent = Eigen::DSizes<Eigen::DenseIndex, D>();
auto offset = out_dims;
for (int i = 0; i < out_dims_size; i++) {
offset[i] = 0;
extent[i] = fake_grad_value_dims[i];
}
std::vector<DDim> offsets;
GetOffsets(out_dims, fake_grad_value_dims, offset, 0, &offsets);
auto grad_value_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::
From(*grad_value, fake_grad_value_dims);
framework::Tensor tmp(grad_value->dtype());
tmp.mutable_data<T>(out_dims, context.GetPlace());
set_zero(dev_ctx, &tmp, static_cast<T>(0));
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
// accumulate gradient
for (auto offset : offsets) {
grad_value_t.device(place) =
grad_value_t +
tmp_t.slice(framework::EigenDim<D>::From(offset), extent);
}
if (need_reverse) {
framework::Tensor tmp_value(grad_value->dtype());
tmp_value.mutable_data<T>(fake_grad_value_dims, context.GetPlace());
auto tmp_value_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp_value);
tmp_value_t.device(place) = grad_value_t.reverse(reverse_axis);
grad_value_t.device(place) = tmp_value_t;
}
}
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,14 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/set_value_op.h"
#include "paddle/phi/kernels/set_value_grad_kernel.h"
namespace ops = paddle::operators;
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h"
REGISTER_OP_CUDA_KERNEL(
set_value_grad,
ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, bool>);
PD_REGISTER_KERNEL(set_value_grad,
CPU,
ALL_LAYOUT,
phi::SetValueGradKernel,
float,
double,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/set_value_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/set_value_grad_kernel_impl.h"
PD_REGISTER_KERNEL(set_value_grad,
GPU,
ALL_LAYOUT,
phi::SetValueGradKernel,
float,
double,
int,
int64_t,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/operators/strided_slice_op.h"
namespace phi {
inline void GetOffsets(const DDim& big_dim,
const DDim& small_dim,
DDim start_offset,
int cur_dim,
std::vector<DDim>* offsets) {
if (cur_dim == big_dim.size()) {
offsets->push_back(start_offset);
return;
}
if (small_dim[cur_dim] == big_dim[cur_dim]) {
GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
} else {
for (int i = 0; i < big_dim[cur_dim]; i++) {
GetOffsets(big_dim, small_dim, start_offset, cur_dim + 1, offsets);
start_offset[cur_dim] += 1;
}
}
}
template <typename T, typename Context, size_t RANK>
void SetValueGradImpl(const Context& dev_ctx,
const DenseTensor& out_grad,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad,
DenseTensor* value_grad) {
PADDLE_ENFORCE_EQ(
out_grad.IsInitialized(),
true,
errors::PermissionDenied(
"The input of `set_value_grad`(out_grad) has not been initialized"));
auto in_dims = out_grad.dims();
std::vector<int> decrease_axis_int32(decrease_axes.begin(),
decrease_axes.end());
std::vector<int> axes_int32(axes.begin(), axes.end());
std::vector<int> infer_flags(axes.size(), 1);
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
std::vector<int64_t> starts_local = starts.GetData();
std::vector<int64_t> ends_local = ends.GetData();
std::vector<int64_t> steps_local = steps.GetData();
paddle::operators::StridedSliceOutDims(starts_local,
ends_local,
steps_local,
axes_int32,
infer_flags,
in_dims,
decrease_axis_int32,
out_dims_vector.data(),
axes.size(),
false);
DDim out_dims(phi::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts_local.size(), 0);
paddle::operators::StridedSliceFunctor(starts_local.data(),
ends_local.data(),
steps_local.data(),
axes_int32.data(),
reverse_vector.data(),
in_dims,
infer_flags,
decrease_axis_int32,
starts_local.size());
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
auto steps_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
auto reverse_axis = Eigen::array<bool, RANK>();
for (size_t axis = 0; axis < RANK; axis++) {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
steps_indices[axis] = 1;
reverse_axis[axis] = false;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices[axis_index] = starts_local[axis];
ends_indices[axis_index] = ends_local[axis];
steps_indices[axis_index] = steps_local[axis];
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
auto& place = *dev_ctx.eigen_device();
phi::funcs::SetConstant<Context, T> set_zero;
if (x_grad) {
// Set gradient of `Input`
Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
auto x_grad_t =
EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(*x_grad);
DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
auto tmp_t =
EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(tmp);
x_grad_t.stridedSlice(starts_indices, ends_indices, steps_indices)
.device(place) = tmp_t;
}
if (value_grad) {
dev_ctx.template Alloc<T>(value_grad);
set_zero(dev_ctx, value_grad, static_cast<T>(0));
auto in_t = EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
out_grad);
if (value_grad->dims() == out_dims) {
auto value_grad_t =
EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
*value_grad);
if (need_reverse) {
DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
auto tmp_t =
EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
value_grad_t.device(place) = tmp_t.reverse(reverse_axis);
} else {
value_grad_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
}
} else {
int out_dims_size = out_dims.size();
auto value_grad_dims = value_grad->dims();
auto fake_value_grad_dims = out_dims;
// Create an extented shape according to the rules of broadcast.
auto value_grad_dims_size = value_grad_dims.size();
int num_decrease = 0;
int decrease_axis_size = decrease_axes.size();
for (int i = 0; i < out_dims_size; i++) {
if (decrease_axes.end() !=
std::find(decrease_axes.begin(), decrease_axes.end(), i)) {
fake_value_grad_dims[i] = 1;
num_decrease++;
} else if (i < out_dims_size - (value_grad_dims_size +
decrease_axis_size - num_decrease)) {
fake_value_grad_dims[i] = 1;
} else {
auto index_grad =
i - (out_dims_size -
(value_grad_dims_size + decrease_axis_size - num_decrease));
fake_value_grad_dims[i] = value_grad_dims[index_grad];
PADDLE_ENFORCE_EQ((out_dims[i] == value_grad_dims[index_grad]) ||
(value_grad_dims[index_grad] == 1),
true,
errors::InvalidArgument(
"An error occurred while calculating %s: "
"[%s] can not be accumulated into [%s].",
paddle::framework::GradVarName("ValueTensor"),
out_dims,
value_grad_dims));
}
}
VLOG(3) << "Dimensions of "
<< paddle::framework::GradVarName("ValueTensor") << "(["
<< value_grad_dims << "])is broadcasted into ["
<< fake_value_grad_dims << "].";
auto extent = Eigen::DSizes<Eigen::DenseIndex, RANK>();
auto offset = out_dims;
for (int i = 0; i < out_dims_size; i++) {
offset[i] = 0;
extent[i] = fake_value_grad_dims[i];
}
std::vector<DDim> offsets;
GetOffsets(out_dims, fake_value_grad_dims, offset, 0, &offsets);
auto value_grad_t =
EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
*value_grad, fake_value_grad_dims);
DenseTensor tmp = Full<T>(dev_ctx, out_dims_vector, static_cast<T>(0));
auto tmp_t =
EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, steps_indices);
// accumulate gradient
for (auto offset : offsets) {
value_grad_t.device(place) =
value_grad_t + tmp_t.slice(EigenDim<RANK>::From(offset), extent);
}
if (need_reverse) {
DenseTensor tmp_value =
Full<T>(dev_ctx,
{fake_value_grad_dims.Get(), fake_value_grad_dims.size()},
static_cast<T>(0));
auto tmp_value_t =
EigenTensor<T, RANK, Eigen::RowMajor, Eigen::DenseIndex>::From(
tmp_value);
tmp_value_t.device(place) = value_grad_t.reverse(reverse_axis);
value_grad_t.device(place) = tmp_value_t;
}
}
}
}
template <typename T, typename Context>
void SetValueGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad,
DenseTensor* value_grad) {
const int rank = out_grad.dims().size();
switch (rank) {
case 1:
SetValueGradImpl<T, Context, 1>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
value_grad);
break;
case 2:
SetValueGradImpl<T, Context, 2>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
value_grad);
break;
case 3:
SetValueGradImpl<T, Context, 3>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
value_grad);
break;
case 4:
SetValueGradImpl<T, Context, 4>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
value_grad);
break;
case 5:
SetValueGradImpl<T, Context, 5>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
value_grad);
break;
case 6:
SetValueGradImpl<T, Context, 6>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
value_grad);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of set_value_grad's input should be less than 7, but "
"received %d.",
rank));
}
}
} // namespace phi
......@@ -25,7 +25,6 @@
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/slice_utils.h"
namespace phi {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SetValueGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad,
DenseTensor* value_grad);
} // namespace phi
......@@ -731,6 +731,108 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature SetValueGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.HasInput("StartsTensorList")) {
if (ctx.HasInput("EndsTensorList")) {
if (ctx.HasInput("StepsTensorList")) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
{"StartsTensorList",
"EndsTensorList",
"StepsTensorList",
"axes",
"decrease_axes",
"none_axes"},
{GradVarName("Input"), GradVarName("ValueTensor")});
} else {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
{"StartsTensorList",
"EndsTensorList",
"steps",
"axes",
"decrease_axes",
"none_axes"},
{GradVarName("Input"), GradVarName("ValueTensor")});
}
} else {
if (ctx.HasInput("StepsTensorList")) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
{"StartsTensorList",
"ends",
"StepsTensorList",
"axes",
"decrease_axes",
"none_axes"},
{GradVarName("Input"), GradVarName("ValueTensor")});
} else {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
{"StartsTensorList",
"ends",
"steps",
"axes",
"decrease_axes",
"none_axes"},
{GradVarName("Input"), GradVarName("ValueTensor")});
}
}
} else {
if (ctx.HasInput("EndsTensorList")) {
if (ctx.HasInput("StepsTensorList")) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
{"starts",
"EndsTensorList",
"StepsTensorList",
"axes",
"decrease_axes",
"none_axes"},
{GradVarName("Input"), GradVarName("ValueTensor")});
} else {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
{"starts",
"EndsTensorList",
"steps",
"axes",
"decrease_axes",
"none_axes"},
{GradVarName("Input"), GradVarName("ValueTensor")});
}
} else {
if (ctx.HasInput("StepsTensorList")) {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
{"starts",
"ends",
"StepsTensorList",
"axes",
"decrease_axes",
"none_axes"},
{GradVarName("Input"), GradVarName("ValueTensor")});
} else {
return KernelSignature(
"set_value_grad",
{GradVarName("Out")},
{"starts", "ends", "steps", "axes", "decrease_axes", "none_axes"},
{GradVarName("Input"), GradVarName("ValueTensor")});
}
}
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(set_value, phi::SetValueOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(set_value_grad, phi::SetValueGradOpArgumentMapping);
......@@ -484,6 +484,71 @@ TEST(ARG_MAP, set_value) {
"set_value");
}
TEST(ARG_MAP, set_value_grad) {
TestArgumentMappingContext arg_case(
{"Out@GRAD", "StartsTensorList", "EndsTensorList"},
{},
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case)
.name,
"set_value_grad");
TestArgumentMappingContext arg_case1(
{"Out@GRAD", "StartsTensorList", "StepsTensorList"},
{},
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case1)
.name,
"set_value_grad");
TestArgumentMappingContext arg_case2({"Out@GRAD", "StartsTensorList"},
{},
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case2)
.name,
"set_value_grad");
TestArgumentMappingContext arg_case3(
{"Out@GRAD", "EndsTensorList", "StepsTensorList"},
{},
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case3)
.name,
"set_value_grad");
TestArgumentMappingContext arg_case4({"Out@GRAD", "EndsTensorList"},
{},
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case4)
.name,
"set_value_grad");
TestArgumentMappingContext arg_case5({"Out@GRAD", "StepsTensorList"},
{},
{},
{"Input@GRAD", "ValueTensor@GRAD"},
{});
ASSERT_EQ(OpUtilsMap::Instance()
.GetArgumentMappingFn("set_value_grad")(arg_case5)
.name,
"set_value_grad");
}
TEST(ARG_MAP, allclose) {
TestArgumentMappingContext arg_case1(
{"Input", "Other", "Rtol"},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册