提交 586671ea 编写于 作者: P phlrain

fix error

上级 d35f5882
......@@ -28,10 +28,103 @@ using Variable = framework::Variable;
using LoDTensorArray = framework::LoDTensorArray;
using DDim = framework::DDim;
inline void DealTensorArray(const framework::ExecutionContext& ctx,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
bool out_is_array) {
auto in_array = ctx.Input<LoDTensorArray>("Input");
// If the input is LoDTensorArray, the rank of input is 1.
int64_t in_size = in_array->size();
int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0];
int64_t end = ends[0] < 0 ? (ends[0] + in_size) : ends[0];
start = std::max(start, static_cast<int64_t>(0));
end = std::max(end, static_cast<int64_t>(0));
end = std::min(end, in_size);
if (starts[0] == -1 && end == 0) {
end = start + 1;
}
PADDLE_ENFORCE_GT(end, start,
platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received end = %d, start = %d.",
ends[0], starts[0]));
int64_t out_size = end - start;
if (out_is_array) {
auto out_array = ctx.Output<LoDTensorArray>("Out");
out_array->resize(out_size);
for (int i = 0; i < out_size; ++i) {
auto* out_tensor = &out_array->at(i);
auto in_tensor = in_array->at(i + start);
out_tensor->set_lod(in_tensor.lod());
if (in_tensor.memory_size() > 0) {
paddle::framework::TensorCopy(in_tensor, ctx.GetPlace(), out_tensor);
} else {
VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so "
"nothing has been written to output array["
<< i << "].";
}
}
} else {
auto out = ctx.Output<Tensor>("Out");
auto in_tensor = in_array->at(start);
paddle::framework::TensorCopy(in_tensor, ctx.GetPlace(), out);
}
}
template <typename DeviceContext, typename T>
class SliceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {}
void Compute(const framework::ExecutionContext& ctx) const override {
const Variable* input_var = ctx.InputVar("Input");
Variable* out_var = ctx.OutputVar("Out");
bool input_is_array = input_var->IsType<LoDTensorArray>();
bool out_is_array = out_var->IsType<LoDTensorArray>();
auto axes_int = ctx.Attr<std::vector<int>>("axes");
auto starts_int = ctx.Attr<std::vector<int>>("starts");
auto ends_int = ctx.Attr<std::vector<int>>("ends");
std::vector<int64_t> axes(axes_int.begin(), axes_int.end());
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
auto infer_flags = ctx.Attr<std::vector<int>>("infer_flags");
// Step 1: Get the accurate attribute value of starts and ends
auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
}
PADDLE_ENFORCE_EQ(
starts.size(), axes.size(),
platform::errors::InvalidArgument(
"The size of starts must be equal to the size of axes."));
PADDLE_ENFORCE_EQ(
ends.size(), axes.size(),
platform::errors::InvalidArgument(
"The size of ends must be equal to the size of axes."));
// Step 2: Compute output
if (input_is_array) {
DealTensorArray(ctx, starts, ends, out_is_array);
return;
}
}
private:
};
......@@ -39,7 +132,73 @@ class SliceKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T>
class SliceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {}
void Compute(const framework::ExecutionContext& ctx) const override {
auto axes = ctx.Attr<std::vector<int>>("axes");
auto starts_int = ctx.Attr<std::vector<int>>("starts");
auto ends_int = ctx.Attr<std::vector<int>>("ends");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
// Get the accurate attribute value of starts and ends
auto starts_tensor_list = ctx.MultiInput<Tensor>("StartsTensorList");
if (ctx.HasInput("StartsTensor")) {
starts = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("StartsTensor"));
} else if (starts_tensor_list.size() > 0) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
auto ends_tensor_list = ctx.MultiInput<Tensor>("EndsTensorList");
if (ctx.HasInput("EndsTensor")) {
ends = GetDataFromTensor<int64_t>(ctx.Input<Tensor>("EndsTensor"));
} else if (ends_tensor_list.size() > 0) {
ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
}
Variable* d_input_var = ctx.OutputVar(framework::GradVarName("Input"));
const Variable* d_out_var = ctx.InputVar(framework::GradVarName("Out"));
bool d_input_is_array = d_input_var->IsType<LoDTensorArray>();
bool d_out_is_array = d_out_var->IsType<LoDTensorArray>();
if (d_input_is_array) {
auto* input_array = ctx.Input<LoDTensorArray>("Input");
auto* d_in_arr =
ctx.Output<LoDTensorArray>(framework::GradVarName("Input"));
int64_t d_in_size = input_array->size();
d_in_arr->resize(d_in_size);
// If the input is LoDTensorArray, the rank of input is 1.
// So only use the 0th element of starts.
int64_t start = starts[0] < 0 ? (starts[0] + d_in_size) : starts[0];
start = std::max(start, static_cast<int64_t>(0));
// set zero
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T> functor;
for (int i = 0; i < d_in_size; ++i) {
auto dim = input_array->at(i).dims();
d_in_arr->at(i).Resize(dim);
d_in_arr->at(i).mutable_data<T>(ctx.GetPlace());
functor(reinterpret_cast<const DeviceContext&>(dev_ctx),
&d_in_arr->at(i), static_cast<T>(0));
}
if (d_out_is_array) {
auto* d_out_arr =
ctx.Input<LoDTensorArray>(framework::GradVarName("Out"));
int d_out_size = d_out_arr->size();
for (int i = 0; i < d_out_size; ++i) {
paddle::framework::TensorCopy(d_out_arr->at(i), ctx.GetPlace(),
&(d_in_arr->at(start + i)));
}
} else {
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
paddle::framework::TensorCopy(*d_out, ctx.GetPlace(),
&(d_in_arr->at(start)));
}
return;
}
}
private:
};
......
......@@ -29,4 +29,5 @@ PD_REGISTER_KERNEL(slice_grad,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::float16) {}
......@@ -30,6 +30,8 @@ void LaunchEigenPadding(
const DDim& out_dims,
const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) {
auto& place = *context.template eigen_device();
LOG(ERROR) << D << "\t" << in_dims;
LOG(ERROR) << out_dims;
auto d_in_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_input, in_dims);
auto d_out_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
......@@ -150,12 +152,12 @@ void EigenPaddingCompute(
// the second dimension do not need padding, set padding[1] zero
reshaped_padding[1].first = reshaped_padding[1].second = 0;
LaunchEigenPadding<T, Context>(context,
d_input,
reshaped_in_dims,
d_out,
reshaped_out_dims,
reshaped_padding);
LaunchEigenPadding<T, Context, 2>(context,
d_input,
reshaped_in_dims,
d_out,
reshaped_out_dims,
reshaped_padding);
} else {
// other dimension need padding
// reshape the dimension of tensor in 3:
......@@ -190,12 +192,13 @@ void EigenPaddingCompute(
// the third dimension do not need padding, set padding[2] zero
reshaped_padding[2].first = reshaped_padding[2].second = 0;
LaunchEigenPadding<T, Context>(context,
d_input,
reshaped_in_dims,
d_out,
reshaped_out_dims,
reshaped_padding);
LOG(ERROR) << "run here";
LaunchEigenPadding<T, Context, 3>(context,
d_input,
reshaped_in_dims,
d_out,
reshaped_out_dims,
reshaped_padding);
}
} else {
// need padding at many dimension, cannot reduce dimension
......@@ -270,14 +273,18 @@ void SliceGradCompute(const Context& ctx,
template <typename T, typename Context>
void SliceGradRawKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const ScalarArray& starts_arr,
const ScalarArray& ends_arr,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* input_grad) {
size_t rank = out_grad.dims().size();
size_t rank = input.dims().size();
auto& starts = starts_arr.GetData();
auto& ends = ends_arr.GetData();
switch (rank) {
case 1:
......
......@@ -110,13 +110,16 @@ template <typename T, typename Context>
void SliceRawKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const ScalarArray& starts_arr,
const ScalarArray& ends_arr,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out) {
int rank = input.dims().size();
auto& starts = starts_arr.GetData();
auto& ends = ends_arr.GetData();
switch (rank) {
case 1:
SliceCompute<T, Context, 1>(
......
......@@ -14,16 +14,18 @@
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SliceGradRawKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const ScalarArray& starts,
const ScalarArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* input_grad);
......
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
......@@ -22,8 +23,8 @@ template <typename T, typename Context>
void SliceRawKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const ScalarArray& starts,
const ScalarArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out);
......
......@@ -17,19 +17,155 @@
namespace phi {
KernelSignature SliceOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"slice",
{"Input"},
{"axes", "starts", "ends", "infer_flags", "decrease_axis"},
{"Out"});
if (ctx.HasInput("StartsTensor")) {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature("slice",
{"Input"},
{"axes",
"StartsTensor",
"EndsTensor",
"infer_flags",
"decrease_axis"},
{"Out"});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature("slice",
{"Input"},
{"axes",
"StartsTensor",
"EndsTensorList",
"infer_flags",
"decrease_axis"},
{"Out"});
} else {
return KernelSignature(
"slice",
{"Input"},
{"axes", "StartsTensor", "ends", "infer_flags", "decrease_axis"},
{"Out"});
}
} else if (ctx.InputSize("StartsTensorList") > 0) {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature("slice",
{"Input"},
{"axes",
"StartsTensorList",
"EndsTensor",
"infer_flags",
"decrease_axis"},
{"Out"});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature("slice",
{"Input"},
{"axes",
"StartsTensorList",
"EndsTensorList",
"infer_flags",
"decrease_axis"},
{"Out"});
} else {
return KernelSignature(
"slice",
{"Input"},
{"axes", "StartsTensorList", "ends", "infer_flags", "decrease_axis"},
{"Out"});
}
} else {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature(
"slice",
{"Input"},
{"axes", "starts", "EndsTensor", "infer_flags", "decrease_axis"},
{"Out"});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature(
"slice",
{"Input"},
{"axes", "starts", "EndsTensorList", "infer_flags", "decrease_axis"},
{"Out"});
} else {
return KernelSignature(
"slice",
{"Input"},
{"axes", "starts", "ends", "infer_flags", "decrease_axis"},
{"Out"});
}
}
}
KernelSignature SliceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"slice_grad",
{GradVarName("Out")},
{"axes", "starts", "ends", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
if (ctx.HasInput("StartsTensor")) {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature("slice_grad",
{"Input", GradVarName("Out")},
{"axes",
"StartsTensor",
"EndsTensor",
"infer_flags",
"decrease_axis"},
{GradVarName("Input")});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature("slice_grad",
{"Input", GradVarName("Out")},
{"axes",
"StartsTensor",
"EndsTensorList",
"infer_flags",
"decrease_axis"},
{GradVarName("Input")});
} else {
return KernelSignature(
"slice_grad",
{"Input", GradVarName("Out")},
{"axes", "StartsTensor", "ends", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
}
} else if (ctx.InputSize("StartsTensorList") > 0) {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature("slice_grad",
{"Input", GradVarName("Out")},
{"axes",
"StartsTensorList",
"EndsTensor",
"infer_flags",
"decrease_axis"},
{GradVarName("Input")});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature("slice_grad",
{"Input", GradVarName("Out")},
{"axes",
"StartsTensorList",
"EndsTensorList",
"infer_flags",
"decrease_axis"},
{GradVarName("Input")});
} else {
return KernelSignature(
"slice_grad",
{"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "ends", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
}
} else {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature(
"slice_grad",
{"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensor", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature(
"slice_grad",
{"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensorList", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
} else {
return KernelSignature(
"slice_grad",
{"Input", GradVarName("Out")},
{"axes", "starts", "ends", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
}
}
}
} // namespace phi
......
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
......@@ -22,8 +23,8 @@ template <typename T, typename Context>
void SliceRawKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const ScalarArray& starts,
const ScalarArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册