Unverified commit abc17ef7, authored by ShenLiang, committed by GitHub

Fix gather infer shape using axis (#33413)

* fix gather shape bug

* fix None

* fix topo
Parent 9d8d5317
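
The shape bug being fixed: before this change, gather's compile-time shape inference always applied the axis-0 rule, even when a non-zero axis was given. A minimal sketch of the intended behavior, assuming Paddle 2.x dygraph, with shapes chosen purely for illustration:

import paddle

x = paddle.rand([3, 4, 5])
index = paddle.to_tensor([0, 2], dtype='int64')

# Gather along axis=1: dims 0 and 2 are kept, dim 1 becomes len(index).
out = paddle.gather(x, index, axis=1)
print(out.shape)  # [3, 2, 5], i.e. x.shape[:1] + [index.shape[0]] + x.shape[2:]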
@@ -202,12 +202,11 @@ __global__ void GatherGradGPUKernel(const T* input, const U* index, T* out,
}
}
template <typename T, typename U, typename V>
template <typename T, typename U>
void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
const Tensor* axis, Tensor* out,
const int axis, Tensor* out,
const paddle::platform::Place& place,
const framework::ExecutionContext& ctx) {
int axis_size = axis->numel();
int index_size = index->numel();
int input_size = input->numel();
auto input_dim = input->dims();
@@ -215,12 +214,8 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
auto* index_data = index->data<U>();
if (input->numel() == 0) return;
PADDLE_ENFORCE_EQ(axis_size, 1,
platform::errors::InvalidArgument(
"Axis size should be 1, but received %d", axis_size));
Tensor cpu_axis;
framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis);
int axis_index = cpu_axis.data<V>()[0];
int axis_index = axis;
int index_dim_size = input_dim[axis_index];
int inner_dim_size = 1;
@@ -251,26 +246,19 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
index_size, index_dim_size, out_size);
}
template <typename T, typename U, typename V>
template <typename T, typename U>
void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
const Tensor* axis, Tensor* out,
const int axis, Tensor* out,
const paddle::platform::Place& place,
const framework::ExecutionContext& ctx) {
auto* index_data = index->data<U>();
int axis_size = axis->numel();
int index_size = index->numel();
int input_size = input->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();
if (input->numel() == 0) return;
PADDLE_ENFORCE_EQ(axis_size, 1,
platform::errors::InvalidArgument(
"Axis size should be 1, but received %d", axis_size));
Tensor cpu_axis;
framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis);
int axis_index = cpu_axis.data<V>()[0];
int axis_index = axis;
int input_index_dim_size = input_dim[axis_index];
int inner_dim_size = 1;
......
@@ -126,24 +126,17 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
}
}
template <typename T, typename U, typename V>
void GatherV2Function(const Tensor* input, const Tensor* index,
const Tensor* axis, Tensor* out,
const paddle::platform::Place& place) {
auto* axis_data = axis->data<V>();
template <typename T, typename U>
void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
Tensor* out, const paddle::platform::Place& place) {
auto* index_data = index->data<U>();
int axis_size = axis->numel();
int index_size = index->numel();
int input_size = input->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();
if (input->numel() == 0) return;
PADDLE_ENFORCE_EQ(axis_size, 1,
platform::errors::InvalidArgument(
"Axis size should be 1, but received %d", axis_size));
int axis_index = axis_data[0];
int axis_index = axis;
int input_index_dim_size = input_dim[axis_index];
for (int i = 0; i < index_size; i++) {
@@ -186,22 +179,17 @@ void GatherV2Function(const Tensor* input, const Tensor* index,
}
}
template <typename T, typename U, typename V>
template <typename T, typename U>
void GatherV2GradFunction(const Tensor* input, const Tensor* index,
const Tensor* axis, Tensor* out,
const int axis, Tensor* out,
const paddle::platform::Place& place) {
auto* axis_data = axis->data<V>();
auto* index_data = index->data<U>();
int axis_size = axis->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();
if (input->numel() == 0) return;
PADDLE_ENFORCE_EQ(axis_size, 1,
platform::errors::InvalidArgument(
"Axis size should be 1, but received %d", axis_size));
int axis_index = axis_data[0];
int axis_index = axis;
int input_index_dim_size = input_dim[axis_index];
int inner_dim_size = 1;
......
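
The GatherV2Function / GatherV2GradFunction changes above drop the axis Tensor argument in favor of a plain int. For reference, a minimal NumPy sketch of the same gather-along-axis computation, using the same outer / index / inner size decomposition as the C++ loops; the helper name gather_along_axis is illustrative only, not from the codebase:

import numpy as np

def gather_along_axis(x, index, axis):
    """Reference gather: pick `index` positions from x's `axis` dimension."""
    x = np.asarray(x)
    index = np.asarray(index)
    outer_dim_size = int(np.prod(x.shape[:axis], dtype=np.int64))
    index_dim_size = x.shape[axis]
    inner_dim_size = int(np.prod(x.shape[axis + 1:], dtype=np.int64))
    # Same outer / index / inner decomposition the C++ functor loops over.
    flat_x = x.reshape(outer_dim_size, index_dim_size, inner_dim_size)
    out = flat_x[:, index, :]  # select along the middle (axis) dimension
    return out.reshape(x.shape[:axis] + (index.size,) + x.shape[axis + 1:])

# e.g. gather_along_axis(np.arange(24).reshape(2, 3, 4), [0, 2], axis=1).shape == (2, 2, 4)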
@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
@@ -52,11 +53,29 @@ class GatherOp : public framework::OperatorWithKernel {
index_dims.size()));
}
int batch_size = ctx->GetInputDim("Index")[0];
framework::DDim output_dims(ctx->GetInputDim("X"));
output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
auto axis = ctx->Attrs().Get<int>("axis");
auto input_dim = ctx->GetInputDim("X");
if (ctx->HasInput("Axis") || axis == 0) {
// if HasInput("Axis"), we can not obtain correct shape of output
int batch_size = index_dims[0];
framework::DDim output_dims(input_dim);
output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis; i++) {
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
for (int i = axis + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = framework::make_ddim(out_dim_vec);
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected:
@@ -120,6 +139,10 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
"If true, update the grad using the overwrite mode in same index,"
"If false, using the accumulate mode in same index.")
.SetDefault(true);
AddAttr<int>(
"axis",
"The Tensor which contains the axis that we do gather operation.")
.SetDefault(0);
AddComment(R"DOC(
Gather Operator.
......
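
The new InferShape branch above computes the static output shape from the int axis attribute, and keeps the old axis-0 rule only when the axis arrives as a runtime Axis input (or equals 0). A plain-Python mirror of that rule, for illustration only:

def infer_gather_out_shape(x_shape, index_shape, axis, has_axis_input=False):
    # Mirrors GatherOp::InferShape after this change (illustrative only).
    if has_axis_input or axis == 0:
        # A runtime Axis tensor (or axis == 0): fall back to the old rule and
        # just replace dim 0 of X with the number of gathered indices.
        return [index_shape[0]] + list(x_shape[1:])
    # Compile-time int axis: splice the index length in at position `axis`.
    return list(x_shape[:axis]) + [index_shape[0]] + list(x_shape[axis + 1:])

assert infer_gather_out_shape([3, 4, 5], [2], axis=1) == [3, 2, 5]
assert infer_gather_out_shape([3, 4, 5], [2], axis=0) == [2, 4, 5]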
@@ -31,47 +31,33 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int>("axis");
// get axis from tensor
if (ctx.HasInput("Axis")) {
const Tensor *axis = ctx.Input<Tensor>("Axis");
const auto &index_type = index->type();
const auto &axis_type = axis->type();
auto place = ctx.GetPlace();
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2CUDAFunction<T, int32_t, int32_t>(x, index, axis, output, place,
ctx);
}
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2CUDAFunction<T, int32_t, int64_t>(x, index, axis, output, place,
ctx);
Tensor cpu_axis;
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
const auto &axis_type = axis_tensor->type();
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2CUDAFunction<T, int64_t, int32_t>(x, index, axis, output, place,
ctx);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2CUDAFunction<T, int64_t, int64_t>(x, index, axis, output, place,
ctx);
}
const auto &place = ctx.GetPlace();
const auto &index_type = index->type();
if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) {
GatherV2CUDAFunction<T, int32_t>(x, index, axis, output, place, ctx);
} else if (index_type == framework::proto::VarType::INT64) {
GatherV2CUDAFunction<T, int64_t>(x, index, axis, output, place, ctx);
}
return;
}
output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return;
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
GPUGather<T, int>(ctx.device_context(), *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
@@ -91,30 +77,27 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("Axis")) {
const Tensor *axis = ctx.Input<Tensor>("Axis");
const auto &index_type = index->type();
const auto &axis_type = axis->type();
auto place = ctx.GetPlace();
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2GradCUDAFunction<T, int32_t, int32_t>(dO, index, axis, dX,
place, ctx);
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
Tensor cpu_axis;
framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
const auto &axis_type = axis_tensor->type();
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
}
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2GradCUDAFunction<T, int32_t, int64_t>(dO, index, axis, dX,
place, ctx);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2GradCUDAFunction<T, int64_t, int32_t>(dO, index, axis, dX,
place, ctx);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2GradCUDAFunction<T, int64_t, int64_t>(dO, index, axis, dX,
place, ctx);
}
const auto &index_type = index->type();
if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) {
GatherV2GradCUDAFunction<T, int32_t>(dO, index, axis, dX,
ctx.GetPlace(), ctx);
} else if (index_type == framework::proto::VarType::INT64) {
GatherV2GradCUDAFunction<T, int64_t>(dO, index, axis, dX,
ctx.GetPlace(), ctx);
}
return;
}
@@ -125,19 +108,6 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
GPUScatterAssign<T, int>(ctx, *dO, *index, dX,
ctx.Attr<bool>("overwrite"));
......
@@ -35,45 +35,30 @@ class GatherOpKernel : public framework::OpKernel<T> {
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int>("axis");
// get axis from tensor
if (ctx.HasInput("Axis")) {
const Tensor *axis = ctx.Input<Tensor>("Axis");
const auto &index_type = index->type();
const auto &axis_type = axis->type();
auto place = ctx.GetPlace();
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2Function<T, int32_t, int32_t>(x, index, axis, output, place);
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
const auto &axis_type = axis_tensor->type();
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
}
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2Function<T, int32_t, int64_t>(x, index, axis, output, place);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2Function<T, int64_t, int32_t>(x, index, axis, output, place);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2Function<T, int64_t, int64_t>(x, index, axis, output, place);
}
const auto &place = ctx.GetPlace();
const auto &index_type = index->type();
if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) {
GatherV2Function<T, int32_t>(x, index, axis, output, place);
} else if (index_type == framework::proto::VarType::INT64) {
GatherV2Function<T, int64_t>(x, index, axis, output, place);
}
return;
}
output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return;
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
CPUGather<T, int>(ctx.device_context(), *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
@@ -94,26 +79,23 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("Axis")) {
const Tensor *axis = ctx.Input<Tensor>("Axis");
const auto &index_type = index->type();
const auto &axis_type = axis->type();
auto place = ctx.GetPlace();
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2GradFunction<T, int32_t, int32_t>(dO, index, axis, dX, place);
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
const auto &axis_type = axis_tensor->type();
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
}
if (index_type == framework::proto::VarType::INT32 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2GradFunction<T, int32_t, int64_t>(dO, index, axis, dX, place);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT32) {
GatherV2GradFunction<T, int64_t, int32_t>(dO, index, axis, dX, place);
}
if (index_type == framework::proto::VarType::INT64 &&
axis_type == framework::proto::VarType::INT64) {
GatherV2GradFunction<T, int64_t, int64_t>(dO, index, axis, dX, place);
}
const auto &index_type = index->type();
if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) {
GatherV2GradFunction<T, int32_t>(dO, index, axis, dX, ctx.GetPlace());
} else if (index_type == framework::proto::VarType::INT64) {
GatherV2GradFunction<T, int64_t>(dO, index, axis, dX, ctx.GetPlace());
}
return;
}
@@ -126,18 +108,6 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
if (dO->numel() == 0) return;
bool overwrite = ctx.Attr<bool>("overwrite");
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s].",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
if (overwrite) {
ScatterAssign<T, int32_t>(ctx.device_context(), *dO, *index, dX);
......
@@ -182,6 +182,7 @@ class TestGatherOp4(TestGatherOp1):
self.index_type = "int64"
self.axis = [0]
self.axis_type = "int32"
self.attrs = {'overwrite': False}
class API_TestGather(unittest.TestCase):
......
@@ -862,34 +862,39 @@ def gather(x, index, axis=None, name=None):
"""
if axis is None:
axis = 0
axis_tensor = axis
if not isinstance(axis, Variable) and axis == 0:
return paddle.fluid.layers.gather(input=x, index=index, overwrite=False)
if not isinstance(axis, Variable):
with device_guard("cpu"):
axis_tensor = fill_constant(
shape=[1], dtype='int64', value=axis, force_cpu=True)
if in_dygraph_mode():
return core.ops.gather(x, index, axis_tensor)
axis = axis.item() if isinstance(axis, paddle.Tensor) else axis
return core.ops.gather(x, index, None, "axis", axis, "overwrite", False)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
'gather')
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather')
if isinstance(axis, Variable):
check_variable_and_dtype(axis, 'axis', ['int32', 'int64'], 'gather')
else:
check_type(axis, 'axis', (int), 'gather')
helper = LayerHelper('gather', **locals())
dtype = helper.input_dtype('x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="gather",
inputs={"X": x,
"Index": index,
"Axis": axis_tensor},
outputs={"Out": out})
if not isinstance(axis, Variable):
helper.append_op(
type="gather",
inputs={"X": x,
"Index": index},
attrs={'axis': axis,
'overwrite': False},
outputs={"Out": out})
else:
helper.append_op(
type="gather",
inputs={"X": x,
"Index": index,
"Axis": axis},
attrs={"overwrite": False},
outputs={"Out": out})
return out
......
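
With the Python API change above, a plain int axis is now lowered to the axis attribute (so the static output shape can be inferred correctly), while a Variable axis still goes through the optional Axis input. A hedged usage sketch, assuming Paddle 2.x dygraph:

import paddle

x = paddle.rand([3, 4, 5])
index = paddle.to_tensor([1, 3], dtype='int64')

# Int axis: emitted with the `axis` attribute (no Axis input).
out_attr = paddle.gather(x, index, axis=1)

# Tensor axis: still supported; in dygraph it is converted via .item(),
# in static graph it is fed through the optional Axis input.
axis_t = paddle.to_tensor([1], dtype='int64')
out_input = paddle.gather(x, index, axis=axis_t)

print(out_attr.shape, out_input.shape)  # [3, 2, 5] [3, 2, 5]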