Unverified commit 615a8bfc, authored by liym27, committed by GitHub

Support int32 int64 and fix bug (#24407)

* Make the attrs of op slice/strided_slice compatible with both int32 and int64. test=develop

* Polish code in nn.py  test=develop

* Fix bug: set the same dtype for the inputs of elementwise_add. test=develop

* Convert int32 to int64 in slice op to avoid data overflow. test=develop

* Convert int32 to int64 in strided_slice_op to avoid data overflow. test=develop
Parent f68d4fb3
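For context, a minimal numpy sketch (not part of the patch; values chosen for illustration) of the data overflow these changes guard against: an index at or above 2^31 wraps when narrowed to int32 but survives as int64.

```python
import numpy as np

# 2^31 = 2147483648 is one past the int32 maximum (2147483647).
end = 2147483648

# Narrowed to int32, the value wraps to a negative number, corrupting the bound.
print(np.array([end]).astype("int32")[0])  # -2147483648

# Kept as int64 the value is preserved, which is why the kernels below widen
# their attribute vectors to int64 before using them.
print(np.array([end]).astype("int64")[0])  # 2147483648
```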
......@@ -26,13 +26,13 @@ using platform::PADDLE_CUDA_NUM_THREADS;
template <size_t D>
__global__ void Padding(const paddle::platform::float16* d_out,
const int* out_dims, const int* in_dims,
const int* offsets, int64_t n,
const int64_t* out_dims, const int64_t* in_dims,
const int64_t* offsets, int64_t n,
paddle::platform::float16* d_in) {
int64_t out_idx = threadIdx.x + blockDim.x * blockIdx.x;
if (out_idx < n) {
int64_t out_idx_tmp = out_idx;
int coords[D] = {0};
int64_t coords[D] = {0};
for (int i = D - 1; i >= 0; --i) {
coords[i] = out_idx_tmp % out_dims[i];
out_idx_tmp /= out_dims[i];
......@@ -61,25 +61,26 @@ class SliceGradKernel<paddle::platform::CUDADeviceContext,
auto out_dims = d_out->dims();
auto in_dims = d_in->dims();
int rank = out_dims.size();
std::vector<int> offsets(rank, 0);
std::vector<int64_t> offsets(rank, 0);
auto axes = ctx.Attr<std::vector<int>>("axes");
auto starts = ctx.Attr<std::vector<int>>("starts");
auto starts_int = ctx.Attr<std::vector<int>>("starts");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
auto list_new_starts_tensor =
ctx.MultiInput<framework::Tensor>("StartsTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (ctx.HasInput("StartsTensor")) {
auto* starts_tensor = ctx.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
starts = GetDataFromTensor<int64_t>(starts_tensor);
}
for (size_t i = 0; i < starts.size(); ++i) {
if (starts[i] < 0) {
starts[i] += in_dims[axes[i]];
}
offsets[axes[i]] = std::max(starts[i], 0);
offsets[axes[i]] = std::max(starts[i], static_cast<int64_t>(0));
}
math::SetConstant<paddle::platform::CUDADeviceContext,
......@@ -94,14 +95,16 @@ class SliceGradKernel<paddle::platform::CUDADeviceContext,
dim3 threads(PADDLE_CUDA_NUM_THREADS);
auto stream = ctx.cuda_device_context().stream();
auto out_shape = framework::vectorize<int>(out_dims);
thrust::device_vector<int> out_dims_vec(out_shape.begin(), out_shape.end());
auto in_shape = framework::vectorize<int>(in_dims);
thrust::device_vector<int> in_dims_vec(in_shape.begin(), in_shape.end());
thrust::device_vector<int> offsets_vec(offsets.begin(), offsets.end());
const int* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data());
const int* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data());
const int* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data());
auto out_shape = framework::vectorize<int64_t>(out_dims);
thrust::device_vector<int64_t> out_dims_vec(out_shape.begin(),
out_shape.end());
auto in_shape = framework::vectorize<int64_t>(in_dims);
thrust::device_vector<int64_t> in_dims_vec(in_shape.begin(),
in_shape.end());
thrust::device_vector<int64_t> offsets_vec(offsets.begin(), offsets.end());
const int64_t* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data());
const int64_t* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data());
const int64_t* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data());
switch (rank) {
case 1:
......
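A rough Python analogue (shapes and attrs assumed; the real code is the CUDA kernel above) of how SliceGradKernel turns possibly-negative starts into the non-negative offsets handed to the Padding kernel:

```python
in_dims = [3, 4, 5]
axes = [0, 2]
starts = [-2, -1]            # negative starts count back from the end of the axis

offsets = [0] * len(in_dims)
for i, axis in enumerate(axes):
    s = starts[i]
    if s < 0:
        s += in_dims[axis]   # e.g. -2 on an axis of size 3 becomes 1
    offsets[axis] = max(s, 0)

print(offsets)  # [1, 0, 4]
```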
......@@ -18,43 +18,12 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/utils.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
inline std::vector<int> get_new_data_from_tensorlist(
const std::vector<const Tensor*>& list_new_data_tensor) {
// get tensor from
std::vector<int> vec_new_data;
for (size_t i = 0; i < list_new_data_tensor.size(); ++i) {
auto tensor = list_new_data_tensor[i];
PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
"shape of dim tensor should be [1]");
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_data.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
} else {
vec_new_data.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
}
return vec_new_data;
}
inline std::vector<int> get_new_data_from_tensor(
const Tensor* new_data_tensor) {
std::vector<int> vec_new_data;
auto* new_data = new_data_tensor->data<int>();
framework::Tensor cpu_starts_tensor;
if (platform::is_gpu_place(new_data_tensor->place())) {
TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<int>();
}
vec_new_data =
std::vector<int>(new_data, new_data + new_data_tensor->numel());
return vec_new_data;
}
template <typename DeviceContext, typename T>
class SliceKernel : public framework::OpKernel<T> {
public:
......@@ -98,9 +67,11 @@ class SliceKernel : public framework::OpKernel<T> {
bool out_is_tensor_array = out_var->IsType<framework::LoDTensorArray>();
auto axes = context.Attr<std::vector<int>>("axes");
auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends");
auto starts_int = context.Attr<std::vector<int>>("starts");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
auto ends_int = context.Attr<std::vector<int>>("ends");
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto list_new_ends_tensor =
......@@ -118,15 +89,15 @@ class SliceKernel : public framework::OpKernel<T> {
if (need_infer) {
if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
starts = GetDataFromTensor<int64_t>(starts_tensor);
} else if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
}
if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
ends = GetDataFromTensor<int64_t>(ends_tensor);
} else if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
}
}
PADDLE_ENFORCE_EQ(
......@@ -140,11 +111,12 @@ class SliceKernel : public framework::OpKernel<T> {
if (input_is_tensor_array) {
auto in_array = context.Input<framework::LoDTensorArray>("Input");
// If the input is LoDTensorArray, the rank of input is 1.
int in_size = in_array->size();
int start = starts[0] < 0 ? (starts[0] + in_size) : starts[0];
int end = ends[0] < 0 ? (ends[0] + in_size) : ends[0];
start = std::max(start, 0);
end = std::max(end, 0);
int64_t in_size = in_array->size();
int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0];
int64_t end = ends[0] < 0 ? (ends[0] + in_size) : ends[0];
start = std::max(start, static_cast<int64_t>(0));
end = std::max(end, static_cast<int64_t>(0));
end = std::min(end, in_size);
PADDLE_ENFORCE_GT(end, start,
......@@ -152,7 +124,7 @@ class SliceKernel : public framework::OpKernel<T> {
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received ends = %d, starts = %d.",
end, start));
int out_size = end - start;
int64_t out_size = end - start;
if (out_is_tensor_array) {
auto out_array = context.Output<framework::LoDTensorArray>("Out");
......@@ -187,7 +159,7 @@ class SliceKernel : public framework::OpKernel<T> {
auto in_dims = in->dims();
if (need_infer) {
out_dims = in_dims;
int dim_value, start, end;
int64_t dim_value, start, end;
for (size_t i = 0; i < axes.size(); ++i) {
dim_value = out_dims[axes[i]];
if (dim_value > 0) {
......@@ -202,17 +174,22 @@ class SliceKernel : public framework::OpKernel<T> {
start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
start = std::max(start, static_cast<int64_t>(0));
end = std::max(end, static_cast<int64_t>(0));
end = std::min(end, dim_value);
PADDLE_ENFORCE_GT(end, start, "end should greater than start");
PADDLE_ENFORCE_GT(
end, start,
platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in "
"slice op. But received ends = %d, starts = %d.",
end, start));
out_dims[axes[i]] = end - start;
}
}
out->Resize(out_dims);
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
"decrease dim should be 1");
......@@ -260,19 +237,19 @@ class SliceKernel : public framework::OpKernel<T> {
out->mutable_data<T>(context.GetPlace());
auto new_out_dims = out->dims();
auto offsets = Eigen::array<int, D>();
auto extents = Eigen::array<int, D>();
auto offsets = Eigen::array<int64_t, D>();
auto extents = Eigen::array<int64_t, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = new_out_dims[i];
}
int start;
int64_t start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
}
start = std::max(start, 0);
start = std::max(start, static_cast<int64_t>(0));
offsets[axes[i]] = start;
}
auto in_t =
......@@ -325,25 +302,30 @@ class SliceGradKernel : public framework::OpKernel<T> {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto axes = context.Attr<std::vector<int>>("axes");
auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends");
auto starts_int = context.Attr<std::vector<int>>("starts");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
auto ends_int = context.Attr<std::vector<int>>("ends");
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
starts = GetDataFromTensor<int64_t>(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
ends = GetDataFromTensor<int64_t>(ends_tensor);
}
framework::Variable* d_input_var =
context.OutputVar(framework::GradVarName("Input"));
......@@ -358,12 +340,12 @@ class SliceGradKernel : public framework::OpKernel<T> {
auto* d_input_array = context.Output<framework::LoDTensorArray>(
framework::GradVarName("Input"));
int d_in_size = input_array->size();
int64_t d_in_size = input_array->size();
d_input_array->resize(d_in_size);
// If the input is LoDTensorArray, the rank of input is 1.
// So only use the 0th element of starts.
int start = starts[0] < 0 ? (starts[0] + d_in_size) : starts[0];
start = std::max(start, 0);
int64_t start = starts[0] < 0 ? (starts[0] + d_in_size) : starts[0];
start = std::max(start, static_cast<int64_t>(0));
// set zero
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
......@@ -432,22 +414,22 @@ class SliceGradKernel : public framework::OpKernel<T> {
}
}
auto offsets = Eigen::array<int, D>();
auto extents = Eigen::array<int, D>();
auto offsets = Eigen::array<int64_t, D>();
auto extents = Eigen::array<int64_t, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = out_dims[i];
}
int start;
int64_t start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
}
start = std::max(start, 0);
start = std::max(start, static_cast<int64_t>(0));
offsets[axes[i]] = start;
}
Eigen::array<std::pair<int, int>, D> paddings;
Eigen::array<std::pair<int64_t, int64_t>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = offsets[i];
paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
......
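The shape inference in SliceKernel boils down to the small routine below, a sketch with assumed values that mirrors the normalize-then-clamp steps above:

```python
def slice_extent(dim_value, start, end):
    # Normalize negative bounds against the axis size.
    if start < 0:
        start += dim_value
    if end < 0:
        end += dim_value
    # Clamp into [0, dim_value], matching the std::max/std::min calls above.
    start = max(start, 0)
    end = min(max(end, 0), dim_value)
    assert end > start, "Attr(ends) should be greater than attr(starts)"
    return end - start

print(slice_extent(6, -3, 100))  # 3: start becomes 3, end clamps to 6
```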
......@@ -39,9 +39,15 @@ class StridedSliceOp : public framework::OperatorWithKernel {
"The dimension of StridedSlice operator's input should be less "
"than 7, but received dimension is %d.",
in_dims.size()));
auto starts = ctx->Attrs().Get<std::vector<int>>("starts");
auto ends = ctx->Attrs().Get<std::vector<int>>("ends");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
auto starts_int = ctx->Attrs().Get<std::vector<int>>("starts");
auto ends_int = ctx->Attrs().Get<std::vector<int>>("ends");
auto strides_int = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
std::vector<int64_t> strides(strides_int.begin(), strides_int.end());
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
......@@ -109,7 +115,7 @@ class StridedSliceOp : public framework::OperatorWithKernel {
}
// we need to analysis strided slice op is valid for
// the parameter that we get from python front
std::vector<int> out_dims_vector(in_dims.size(), -1);
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
if (!tensor_input) {
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
......@@ -118,7 +124,7 @@ class StridedSliceOp : public framework::OperatorWithKernel {
framework::DDim out_dims(framework::make_ddim(out_dims_vector));
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (ctx->IsRuntime() && infer_flags[i] != -1) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
......
......@@ -24,15 +24,15 @@ namespace paddle {
namespace operators {
static void StridedSliceOutDims(
const std::vector<int>& starts, const std::vector<int>& ends,
const std::vector<int>& strides, const std::vector<int>& axes,
const std::vector<int64_t>& starts, const std::vector<int64_t>& ends,
const std::vector<int64_t>& strides, const std::vector<int>& axes,
const std::vector<int>& infer_flags, const framework::DDim in_dims,
const std::vector<int>& decrease_axis, int* out_dims_vector,
const std::vector<int>& decrease_axis, int64_t* out_dims_vector,
const size_t size, bool infer_shape) {
for (int i = 0; i < in_dims.size(); i++) {
out_dims_vector[i] = in_dims[i];
}
int stride_index, start_index, end_index;
int64_t stride_index, start_index, end_index;
for (size_t i = 0; i < size; i++) {
int axes_index = axes[i];
start_index = starts[i];
......@@ -57,7 +57,8 @@ static void StridedSliceOutDims(
PADDLE_ENFORCE_NE(stride_index, 0,
platform::errors::InvalidArgument(
"stride index in StridedSlice operator is 0."));
int axis_size = in_dims[axes_index];
int64_t axis_size = in_dims[axes_index];
if (axis_size < 0) {
continue;
}
......@@ -83,22 +84,26 @@ static void StridedSliceOutDims(
platform::errors::InvalidArgument(
"The start index and end index are invalid for their "
"corresponding stride."));
int left = std::max(0, std::min(start_index, end_index));
int right = std::min(axis_size, std::max(start_index, end_index));
int step = std::abs(stride_index);
int64_t left =
std::max(static_cast<int64_t>(0), std::min(start_index, end_index));
int64_t right = std::min(axis_size, std::max(start_index, end_index));
int64_t step = std::abs(stride_index);
auto out_dims_index = (std::abs(right - left) + step - 1) / step;
out_dims_vector[axes_index] = out_dims_index;
}
}
static void StridedSliceFunctor(int* starts, int* ends, int* strides, int* axes,
int* reverse_axis, const framework::DDim dims,
static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
int64_t* strides, int* axes, int* reverse_axis,
const framework::DDim dims,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
const size_t size) {
for (size_t axis = 0; axis < size; axis++) {
int axis_size = dims[axes[axis]];
int64_t axis_size = dims[axes[axis]];
int axis_index = axis;
if (axis_size < 0) {
starts[axis_index] = 0;
......@@ -183,9 +188,14 @@ class StridedSliceKernel : public framework::OpKernel<T> {
auto out = context.Output<framework::Tensor>("Out");
auto in_dims = in->dims();
auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends");
auto strides = context.Attr<std::vector<int>>("strides");
auto starts_int = context.Attr<std::vector<int>>("starts");
auto ends_int = context.Attr<std::vector<int>>("ends");
auto strides_int = context.Attr<std::vector<int>>("strides");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
std::vector<int64_t> strides(strides_int.begin(), strides_int.end());
auto axes = context.Attr<std::vector<int>>("axes");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
......@@ -203,27 +213,27 @@ class StridedSliceKernel : public framework::OpKernel<T> {
context.MultiInput<framework::Tensor>("StridesTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
starts = GetDataFromTensor<int64_t>(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
ends = GetDataFromTensor<int64_t>(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = get_new_data_from_tensorlist(list_new_strides_tensor);
strides = GetDataFromTensorList<int64_t>(list_new_strides_tensor);
} else if (context.HasInput("StridesTensor")) {
auto* strides_tensor = context.Input<framework::Tensor>("StridesTensor");
strides = get_new_data_from_tensor(strides_tensor);
strides = GetDataFromTensor<int64_t>(strides_tensor);
}
std::vector<int> out_dims_vector(in_dims.size(), -1);
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
false);
......@@ -250,7 +260,7 @@ class StridedSliceKernel : public framework::OpKernel<T> {
auto out_dims_origin = out_dims;
if (decrease_axis.size() > 0) {
std::vector<int> new_out_shape;
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
PADDLE_ENFORCE_EQ(
out_dims[decrease_axis[i]], 1,
......@@ -350,9 +360,15 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
set_zero(dev_ctx, d_out, static_cast<T>(0));
auto out_dims = d_out->dims();
auto in_dims = d_input->dims();
auto starts = context.Attr<std::vector<int>>("starts");
auto ends = context.Attr<std::vector<int>>("ends");
auto strides = context.Attr<std::vector<int>>("strides");
auto starts_int = context.Attr<std::vector<int>>("starts");
auto ends_int = context.Attr<std::vector<int>>("ends");
auto strides_int = context.Attr<std::vector<int>>("strides");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
std::vector<int64_t> strides(strides_int.begin(), strides_int.end());
auto axes = context.Attr<std::vector<int>>("axes");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
......@@ -365,24 +381,24 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
context.MultiInput<framework::Tensor>("StridesTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = get_new_data_from_tensorlist(list_new_starts_tensor);
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = get_new_data_from_tensor(starts_tensor);
starts = GetDataFromTensor<int64_t>(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = get_new_data_from_tensorlist(list_new_ends_tensor);
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = get_new_data_from_tensor(ends_tensor);
ends = GetDataFromTensor<int64_t>(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = get_new_data_from_tensorlist(list_new_strides_tensor);
strides = GetDataFromTensorList<int64_t>(list_new_strides_tensor);
} else if (context.HasInput("StridesTensor")) {
auto* strides_tensor = context.Input<framework::Tensor>("StridesTensor");
strides = get_new_data_from_tensor(strides_tensor);
strides = GetDataFromTensor<int64_t>(strides_tensor);
}
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
......
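The StridedSliceOutDims arithmetic above is a ceiling division over the clamped range. A worked example with assumed values, cross-checked against Python's own slicing:

```python
axis_size, start, end, stride = 10, 1, 8, 3

left = max(0, min(start, end))
right = min(axis_size, max(start, end))
step = abs(stride)

# out = (|right - left| + step - 1) // step, i.e. ceil(|right - left| / step)
out = (abs(right - left) + step - 1) // step
print(out)                             # 3
print(len(range(start, end, stride)))  # 3: agrees with Python slicing
```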
......@@ -38,6 +38,7 @@ inline std::vector<T> GetDataFromTensor(const framework::Tensor* x) {
TensorCopySync(*x, platform::CPUPlace(), &cpu_attr_tensor);
data = cpu_attr_tensor.data<int64_t>();
}
// NOTE: Converting int64 to int32 may cause data overflow.
vec_new_data = std::vector<T>(data, data + x->numel());
} else {
PADDLE_THROW("The dtype of Tensor must be int32 or int64.");
......@@ -69,6 +70,7 @@ inline std::vector<T> GetDataFromTensorList(
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
// NOTE: Converting int64 to int32 may cause data overflow.
vec_new_data.push_back(static_cast<T>(*temp.data<int64_t>()));
} else {
vec_new_data.push_back(static_cast<T>(*tensor->data<int64_t>()));
......
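A rough Python analogue (data assumed) of the GetDataFromTensorList contract: each tensor in the list must have shape [1], and its single element is appended to the result. The NOTE applies because requesting T = int32 from int64 tensors narrows every element:

```python
import numpy as np

tensor_list = [np.array([3], dtype="int64"), np.array([-1], dtype="int64")]

values = []
for t in tensor_list:
    assert t.shape == (1,), "shape of dim tensor should be [1]"
    values.append(int(t[0]))  # with T = int32, this cast is where overflow occurs

print(values)  # [3, -1]
```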
......@@ -665,8 +665,7 @@ def _getitem_impl_(var, item):
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu
},
stop_gradient=True)
})
out.stop_gradient = True
return out
......@@ -706,9 +705,9 @@ def _getitem_impl_(var, item):
slice_start.append(slice_item)
slice_step.append(1)
if isinstance(slice_item, Variable):
temp_1 = var.block.create_var(dtype='int32')
temp_1 = var.block.create_var(dtype=slice_item.dtype)
fill_constant([1], 1, force_cpu=True, out=temp_1)
temp_end = target_block.create_var(dtype='int32')
temp_end = target_block.create_var(dtype=slice_item.dtype)
target_block.append_op(
type='elementwise_add',
inputs={'X': slice_item,
......
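A sketch of the fixed pattern in _getitem_impl_ (variable names assumed, using the public fluid API rather than the internal block helpers): the constant added to a Variable index now shares the index's dtype, so elementwise_add no longer mixes int32 with int64 inputs:

```python
import paddle.fluid as fluid

idx = fluid.layers.fill_constant([1], "int64", 5)  # a Variable used as a slice index
# Previously the helper constant was hard-coded to int32; it now follows the
# index's dtype, so both elementwise_add inputs share one dtype.
one = fluid.layers.fill_constant([1], "int64", 1)
end = fluid.layers.elementwise_add(x=idx, y=one)
```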
......@@ -6083,19 +6083,6 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
helper = LayerHelper("reshape2", **locals())
def get_new_shape_tensor(list_shape):
new_shape_tensor = []
for dim in list_shape:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_shape_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_shape_tensor.append(temp_out)
return new_shape_tensor
def get_attr_shape(list_shape):
unk_dim_idx = -1
attrs_shape = []
......@@ -6133,7 +6120,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
"but received %s." % len(shape))
attrs["shape"] = get_attr_shape(shape)
if utils._contain_var(shape):
inputs['ShapeTensor'] = get_new_shape_tensor(shape)
inputs['ShapeTensor'] = utils._convert_to_tensor_list(shape)
elif isinstance(actual_shape, Variable):
actual_shape.stop_gradient = True
inputs["Shape"] = actual_shape
......@@ -6258,19 +6245,6 @@ def unsqueeze(input, axes, name=None):
inputs = {"X": input}
attrs = {}
def _to_Variable_list(one_list):
Variable_list = []
for ele in one_list:
if isinstance(ele, Variable):
ele.stop_gradient = True
Variable_list.append(ele)
else:
assert (isinstance(ele, int))
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', ele, force_cpu=True, out=temp_out)
Variable_list.append(temp_out)
return Variable_list
if isinstance(axes, int):
axes = [axes]
if isinstance(axes, Variable):
......@@ -6278,7 +6252,7 @@ def unsqueeze(input, axes, name=None):
inputs["AxesTensor"] = axes
elif isinstance(axes, (list, tuple)):
if utils._contain_var(axes):
inputs["AxesTensorList"] = _to_Variable_list(axes)
inputs["AxesTensorList"] = utils._convert_to_tensor_list(axes)
else:
attrs["axes"] = axes
......@@ -10192,26 +10166,13 @@ def expand(x, expand_times, name=None):
"Each element given in expand_times must not be negative.")
return attrs_expand_times
def get_new_expand_times_tensor(list_expand_times):
new_expand_times_tensor = []
for ele in list_expand_times:
if isinstance(ele, Variable):
ele.stop_gradient = True
new_expand_times_tensor.append(ele)
else:
assert (isinstance(ele, int))
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', ele, force_cpu=True, out=temp_out)
new_expand_times_tensor.append(temp_out)
return new_expand_times_tensor
if isinstance(expand_times, Variable):
expand_times.stop_gradient = True
inputs['ExpandTimes'] = expand_times
elif isinstance(expand_times, (list, tuple)):
attrs['expand_times'] = get_attr_expand_times(expand_times)
if utils._contain_var(expand_times):
inputs['expand_times_tensor'] = get_new_expand_times_tensor(
inputs['expand_times_tensor'] = utils._convert_to_tensor_list(
expand_times)
dtype = helper.input_dtype(input_param_name='x')
......@@ -10784,19 +10745,6 @@ def slice(input, axes, starts, ends):
helper = LayerHelper('slice', **locals())
def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
inputs = {'Input': input}
attrs = {'axes': axes}
infer_flags = list(1 for i in range(len(axes)))
......@@ -10809,7 +10757,7 @@ def slice(input, axes, starts, ends):
elif isinstance(starts, (list, tuple)):
attrs['starts'] = []
if utils._contain_var(starts):
inputs['StartsTensorList'] = get_new_list_tensor(starts)
inputs['StartsTensorList'] = utils._convert_to_tensor_list(starts)
for i, dim in enumerate(starts):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
......@@ -10827,7 +10775,7 @@ def slice(input, axes, starts, ends):
elif isinstance(ends, (list, tuple)):
attrs['ends'] = []
if utils._contain_var(ends):
inputs['EndsTensorList'] = get_new_list_tensor(ends)
inputs['EndsTensorList'] = utils._convert_to_tensor_list(ends)
for i, dim in enumerate(ends):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
......
......@@ -327,3 +327,21 @@ def _get_shape_tensor_inputs(inputs, helper, attrs, shape, op_type):
inputs['ShapeTensorList'] = _get_shape_tensor(shape)
return inputs
def _convert_to_tensor_list(old_list, dtype="int32"):
"""
Converts all elements of a list to Variable.
"""
from .tensor import fill_constant
new_list_tensor = []
for ele in old_list:
if isinstance(ele, Variable):
ele.stop_gradient = True
new_list_tensor.append(ele)
else:
assert (isinstance(ele, int))
temp_out = fill_constant([1], dtype, ele, force_cpu=True)
new_list_tensor.append(temp_out)
return new_list_tensor
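A hypothetical usage of the new shared helper (assuming it is importable as paddle.fluid.layers.utils): a mixed list of Python ints and Variables becomes a uniform list of [1]-shaped tensors, ready to feed list inputs such as StartsTensorList or ShapeTensor:

```python
import paddle.fluid as fluid
from paddle.fluid.layers import utils

start_var = fluid.layers.fill_constant([1], "int32", -3)
# [Variable, int, int] -> [Variable, Variable, Variable], all stop_gradient=True
tensor_list = utils._convert_to_tensor_list([start_var, 0, 2])
```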
......@@ -168,7 +168,7 @@ class TestSliceOp_starts_ListTensor(OpTest):
starts_tensor = []
for index, ele in enumerate(self.starts):
starts_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
(1)).astype('int64') * ele))
self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}
self.outputs = {'Out': self.out}
......@@ -297,7 +297,7 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest):
self.inputs = {
'Input': self.input,
"StartsTensor": np.array(
self.starts, dtype="int32"),
self.starts, dtype="int64"),
"EndsTensor": np.array(
self.ends, dtype="int32")
}
......@@ -486,7 +486,7 @@ class TestSliceAPI(unittest.TestCase):
def test_1(self):
input = np.random.random([3, 4, 5, 6]).astype("float64")
minus_1 = fluid.layers.fill_constant([1], "int32", -1)
minus_3 = fluid.layers.fill_constant([1], "int32", -3)
minus_3 = fluid.layers.fill_constant([1], "int64", -3)
starts = fluid.layers.data(
name='starts', shape=[1, 3], append_batch_size=False)
ends = fluid.layers.data(
......@@ -498,8 +498,11 @@ class TestSliceAPI(unittest.TestCase):
append_batch_size=False,
dtype="float64")
# value_int64 is greater than 2147483647 which is the max of int32
value_int64 = fluid.layers.fill_constant([1], "int64", 2147483648)
out_1 = fluid.layers.slice(
x, axes=[0, 1, 2], starts=[-3, 0, 2], ends=[3, 100, -1])
x, axes=[0, 1, 2], starts=[-3, 0, 2], ends=[value_int64, 100, -1])
out_2 = fluid.layers.slice(
x, axes=[0, 1, 3], starts=[minus_3, 0, 2], ends=[3, 100, -1])
out_3 = fluid.layers.slice(
......@@ -564,11 +567,17 @@ class TestSliceApiWithLoDTensorArray(unittest.TestCase):
self.sliced_arr = output = arr[0]
elif case_num == 2:
end = fluid.layers.array_length(arr) - 1
end = fluid.layers.cast(end, "int32")
end = fluid.layers.array_length(
arr) - 1 # dtype of end is int64
self.sliced_arr = slice_arr = arr[self.start:end]
output, _ = fluid.layers.tensor_array_to_tensor(
slice_arr, axis=self.axis, use_stack=True)
elif case_num == 3:
value_int64 = fluid.layers.fill_constant([1], "int64",
2147483648)
self.sliced_arr = slice_arr = arr[self.start:value_int64]
output, _ = fluid.layers.tensor_array_to_tensor(
slice_arr, axis=self.axis, use_stack=True)
loss = fluid.layers.reduce_sum(output)
fluid.backward.append_backward(loss)
......@@ -608,6 +617,22 @@ class TestSliceApiWithLoDTensorArray(unittest.TestCase):
self.assertTrue(np.array_equal(self.g_x1, np.ones_like(self.data)))
self.assertTrue(np.array_equal(self.g_x2, np.zeros_like(self.data)))
def test_case_3(self):
main_program = fluid.Program()
self.set_program_and_run(main_program, 3)
self.assertTrue(
self.sliced_arr.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY)
self.assertEqual(self.sliced_arr.shape, self.shape)
self.assertTrue(
np.array_equal(
self.out,
np.stack(
[self.data, self.data, self.data], axis=self.axis)))
self.assertTrue(np.array_equal(self.g_x0, np.ones_like(self.data)))
self.assertTrue(np.array_equal(self.g_x1, np.ones_like(self.data)))
self.assertTrue(np.array_equal(self.g_x2, np.ones_like(self.data)))
if __name__ == '__main__':
unittest.main()
......@@ -486,7 +486,7 @@ class TestStridedSliceAPI(unittest.TestCase):
feed={
"x": input,
'starts': np.array([-3, 0, 2]).astype("int32"),
'ends': np.array([3, 100, -1]).astype("int32"),
'ends': np.array([3, 2147483648, -1]).astype("int64"),
'strides': np.array([1, 1, 1]).astype("int32")
},
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7])
......