Unverified commit abc17ef7, authored by ShenLiang, committed by GitHub

Fix gather infer shape using axis (#33413)

* fix gather shape bug

* fix None

* fix topo
Parent 9d8d5317
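For context, this is the output-shape rule the new GatherOp::InferShape implements when the axis is known at compile time: keep the dimensions of X before axis, replace the axis dimension with the number of indices, and keep the trailing dimensions. A minimal sketch (not part of the patch; the helper name gather_out_shape is ours, for illustration only):

    def gather_out_shape(x_shape, index_len, axis):
        # Mirrors the new InferShape branch for a statically known axis.
        return list(x_shape[:axis]) + [index_len] + list(x_shape[axis + 1:])

    assert gather_out_shape([3, 4, 5], index_len=2, axis=1) == [3, 2, 5]
    assert gather_out_shape([3, 4, 5], index_len=2, axis=0) == [2, 4, 5]

When the axis arrives through the optional "Axis" input tensor instead of the attribute, its value is not known at infer-shape time, so the op keeps the old behaviour of only replacing the leading dimension.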
@@ -202,12 +202,11 @@ __global__ void GatherGradGPUKernel(const T* input, const U* index, T* out,
   }
 }
 
-template <typename T, typename U, typename V>
+template <typename T, typename U>
 void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
-                          const Tensor* axis, Tensor* out,
+                          const int axis, Tensor* out,
                           const paddle::platform::Place& place,
                           const framework::ExecutionContext& ctx) {
-  int axis_size = axis->numel();
   int index_size = index->numel();
   int input_size = input->numel();
   auto input_dim = input->dims();
@@ -215,12 +214,8 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
   auto* index_data = index->data<U>();
 
   if (input->numel() == 0) return;
-  PADDLE_ENFORCE_EQ(axis_size, 1,
-                    platform::errors::InvalidArgument(
-                        "Axis size should be 1, but received %d", axis_size));
-  Tensor cpu_axis;
-  framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis);
-  int axis_index = cpu_axis.data<V>()[0];
+  int axis_index = axis;
   int index_dim_size = input_dim[axis_index];
 
   int inner_dim_size = 1;
@@ -251,26 +246,19 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
                                       index_size, index_dim_size, out_size);
 }
 
-template <typename T, typename U, typename V>
+template <typename T, typename U>
 void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
-                              const Tensor* axis, Tensor* out,
+                              const int axis, Tensor* out,
                               const paddle::platform::Place& place,
                               const framework::ExecutionContext& ctx) {
   auto* index_data = index->data<U>();
-  int axis_size = axis->numel();
   int index_size = index->numel();
   int input_size = input->numel();
   auto input_dim = input->dims();
   auto* input_data = input->data<T>();
 
   if (input->numel() == 0) return;
-  PADDLE_ENFORCE_EQ(axis_size, 1,
-                    platform::errors::InvalidArgument(
-                        "Axis size should be 1, but received %d", axis_size));
-  Tensor cpu_axis;
-  framework::TensorCopy(*axis, platform::CPUPlace(), &cpu_axis);
-  int axis_index = cpu_axis.data<V>()[0];
+  int axis_index = axis;
   int input_index_dim_size = input_dim[axis_index];
 
   int inner_dim_size = 1;
...
@@ -126,24 +126,17 @@ void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
   }
 }
 
-template <typename T, typename U, typename V>
-void GatherV2Function(const Tensor* input, const Tensor* index,
-                      const Tensor* axis, Tensor* out,
-                      const paddle::platform::Place& place) {
-  auto* axis_data = axis->data<V>();
+template <typename T, typename U>
+void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
+                      Tensor* out, const paddle::platform::Place& place) {
   auto* index_data = index->data<U>();
-  int axis_size = axis->numel();
   int index_size = index->numel();
   int input_size = input->numel();
   auto input_dim = input->dims();
   auto* input_data = input->data<T>();
 
   if (input->numel() == 0) return;
-  PADDLE_ENFORCE_EQ(axis_size, 1,
-                    platform::errors::InvalidArgument(
-                        "Axis size should be 1, but received %d", axis_size));
-  int axis_index = axis_data[0];
+  int axis_index = axis;
   int input_index_dim_size = input_dim[axis_index];
 
   for (int i = 0; i < index_size; i++) {
@@ -186,22 +179,17 @@ void GatherV2Function(const Tensor* input, const Tensor* index,
   }
 }
 
-template <typename T, typename U, typename V>
+template <typename T, typename U>
 void GatherV2GradFunction(const Tensor* input, const Tensor* index,
-                          const Tensor* axis, Tensor* out,
+                          const int axis, Tensor* out,
                           const paddle::platform::Place& place) {
-  auto* axis_data = axis->data<V>();
   auto* index_data = index->data<U>();
-  int axis_size = axis->numel();
   auto input_dim = input->dims();
   auto* input_data = input->data<T>();
 
   if (input->numel() == 0) return;
-  PADDLE_ENFORCE_EQ(axis_size, 1,
-                    platform::errors::InvalidArgument(
-                        "Axis size should be 1, but received %d", axis_size));
-  int axis_index = axis_data[0];
+  int axis_index = axis;
   int input_index_dim_size = input_dim[axis_index];
 
   int inner_dim_size = 1;
...
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+
 namespace paddle {
 namespace operators {
@@ -52,11 +53,29 @@ class GatherOp : public framework::OperatorWithKernel {
                             index_dims.size()));
     }
 
-    int batch_size = ctx->GetInputDim("Index")[0];
-    framework::DDim output_dims(ctx->GetInputDim("X"));
-    output_dims[0] = batch_size;
-    ctx->SetOutputDim("Out", output_dims);
-    ctx->ShareLoD("X", /*->*/ "Out");
+    auto axis = ctx->Attrs().Get<int>("axis");
+    auto input_dim = ctx->GetInputDim("X");
+    if (ctx->HasInput("Axis") || axis == 0) {
+      // if HasInput("Axis"), we can not obtain correct shape of output
+      int batch_size = index_dims[0];
+      framework::DDim output_dims(input_dim);
+      output_dims[0] = batch_size;
+      ctx->SetOutputDim("Out", output_dims);
+      ctx->ShareLoD("X", /*->*/ "Out");
+    } else {
+      int index_size = index_dims[0];
+      std::vector<int> out_dim_vec;
+      for (int i = 0; i < axis; i++) {
+        out_dim_vec.push_back(input_dim[i]);
+      }
+      out_dim_vec.push_back(index_size);
+      for (int i = axis + 1; i < input_dim.size(); i++) {
+        out_dim_vec.push_back(input_dim[i]);
+      }
+      auto output_dims = framework::make_ddim(out_dim_vec);
+      ctx->SetOutputDim("Out", output_dims);
+      ctx->ShareLoD("X", /*->*/ "Out");
+    }
   }
 
  protected:
@@ -120,6 +139,10 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
               "If true, update the grad using the overwrite mode in same index,"
               "If false, using the accumulate mode in same index.")
         .SetDefault(true);
+    AddAttr<int>(
+        "axis",
+        "The Tensor which contains the axis that we do gather operation.")
+        .SetDefault(0);
     AddComment(R"DOC(
 Gather Operator.
...
@@ -31,47 +31,33 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
     auto *index = ctx.Input<Tensor>("Index");
     auto *output = ctx.Output<Tensor>("Out");
 
+    int axis = ctx.Attr<int>("axis");
+
+    // get axis from tensor
     if (ctx.HasInput("Axis")) {
-      const Tensor *axis = ctx.Input<Tensor>("Axis");
-      const auto &index_type = index->type();
-      const auto &axis_type = axis->type();
-      auto place = ctx.GetPlace();
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2CUDAFunction<T, int32_t, int32_t>(x, index, axis, output, place,
-                                                  ctx);
-      }
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2CUDAFunction<T, int32_t, int64_t>(x, index, axis, output, place,
-                                                  ctx);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2CUDAFunction<T, int64_t, int32_t>(x, index, axis, output, place,
-                                                  ctx);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2CUDAFunction<T, int64_t, int64_t>(x, index, axis, output, place,
-                                                  ctx);
-      }
+      Tensor cpu_axis;
+      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
+      framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
+      const auto &axis_type = axis_tensor->type();
+      if (axis_type == framework::proto::VarType::INT32) {
+        axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
+      } else if (axis_type == framework::proto::VarType::INT64) {
+        axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
+      }
+    }
+
+    const auto &place = ctx.GetPlace();
+    const auto &index_type = index->type();
+    if (axis != 0) {
+      if (index_type == framework::proto::VarType::INT32) {
+        GatherV2CUDAFunction<T, int32_t>(x, index, axis, output, place, ctx);
+      } else if (index_type == framework::proto::VarType::INT64) {
+        GatherV2CUDAFunction<T, int64_t>(x, index, axis, output, place, ctx);
+      }
       return;
     }
 
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    const auto &index_type = index->type();
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       GPUGather<T, int>(ctx.device_context(), *x, *index, output);
     } else if (index_type == framework::proto::VarType::INT64) {
@@ -91,30 +77,27 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
+    int axis = ctx.Attr<int>("axis");
     if (ctx.HasInput("Axis")) {
-      const Tensor *axis = ctx.Input<Tensor>("Axis");
-      const auto &index_type = index->type();
-      const auto &axis_type = axis->type();
-      auto place = ctx.GetPlace();
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2GradCUDAFunction<T, int32_t, int32_t>(dO, index, axis, dX,
-                                                      place, ctx);
-      }
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2GradCUDAFunction<T, int32_t, int64_t>(dO, index, axis, dX,
-                                                      place, ctx);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2GradCUDAFunction<T, int64_t, int32_t>(dO, index, axis, dX,
-                                                      place, ctx);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2GradCUDAFunction<T, int64_t, int64_t>(dO, index, axis, dX,
-                                                      place, ctx);
+      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
+      Tensor cpu_axis;
+      framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
+      const auto &axis_type = axis_tensor->type();
+      if (axis_type == framework::proto::VarType::INT32) {
+        axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
+      } else if (axis_type == framework::proto::VarType::INT64) {
+        axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
+      }
+    }
+
+    const auto &index_type = index->type();
+    if (axis != 0) {
+      if (index_type == framework::proto::VarType::INT32) {
+        GatherV2GradCUDAFunction<T, int32_t>(dO, index, axis, dX,
+                                             ctx.GetPlace(), ctx);
+      } else if (index_type == framework::proto::VarType::INT64) {
+        GatherV2GradCUDAFunction<T, int64_t>(dO, index, axis, dX,
+                                             ctx.GetPlace(), ctx);
       }
       return;
     }
@@ -125,19 +108,6 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
                      .eigen_device();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;
-    const auto &index_type = index->type();
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       GPUScatterAssign<T, int>(ctx, *dO, *index, dX,
                                ctx.Attr<bool>("overwrite"));
...
@@ -35,45 +35,30 @@ class GatherOpKernel : public framework::OpKernel<T> {
     auto *index = ctx.Input<Tensor>("Index");
     auto *output = ctx.Output<Tensor>("Out");
 
+    int axis = ctx.Attr<int>("axis");
+
+    // get axis from tensor
     if (ctx.HasInput("Axis")) {
-      const Tensor *axis = ctx.Input<Tensor>("Axis");
-      const auto &index_type = index->type();
-      const auto &axis_type = axis->type();
-      auto place = ctx.GetPlace();
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2Function<T, int32_t, int32_t>(x, index, axis, output, place);
-      }
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2Function<T, int32_t, int64_t>(x, index, axis, output, place);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2Function<T, int64_t, int32_t>(x, index, axis, output, place);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2Function<T, int64_t, int64_t>(x, index, axis, output, place);
+      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
+      const auto &axis_type = axis_tensor->type();
+      if (axis_type == framework::proto::VarType::INT32) {
+        axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
+      } else if (axis_type == framework::proto::VarType::INT64) {
+        axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
+      }
+    }
+
+    const auto &place = ctx.GetPlace();
+    const auto &index_type = index->type();
+    if (axis != 0) {
+      if (index_type == framework::proto::VarType::INT32) {
+        GatherV2Function<T, int32_t>(x, index, axis, output, place);
+      } else if (index_type == framework::proto::VarType::INT64) {
+        GatherV2Function<T, int64_t>(x, index, axis, output, place);
      }
      return;
    }
 
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    const auto &index_type = index->type();
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       CPUGather<T, int>(ctx.device_context(), *x, *index, output);
     } else if (index_type == framework::proto::VarType::INT64) {
@@ -94,26 +79,23 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
+    int axis = ctx.Attr<int>("axis");
     if (ctx.HasInput("Axis")) {
-      const Tensor *axis = ctx.Input<Tensor>("Axis");
-      const auto &index_type = index->type();
-      const auto &axis_type = axis->type();
-      auto place = ctx.GetPlace();
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2GradFunction<T, int32_t, int32_t>(dO, index, axis, dX, place);
-      }
-      if (index_type == framework::proto::VarType::INT32 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2GradFunction<T, int32_t, int64_t>(dO, index, axis, dX, place);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT32) {
-        GatherV2GradFunction<T, int64_t, int32_t>(dO, index, axis, dX, place);
-      }
-      if (index_type == framework::proto::VarType::INT64 &&
-          axis_type == framework::proto::VarType::INT64) {
-        GatherV2GradFunction<T, int64_t, int64_t>(dO, index, axis, dX, place);
+      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
+      const auto &axis_type = axis_tensor->type();
+      if (axis_type == framework::proto::VarType::INT32) {
+        axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
+      } else if (axis_type == framework::proto::VarType::INT64) {
+        axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
+      }
+    }
+
+    const auto &index_type = index->type();
+    if (axis != 0) {
+      if (index_type == framework::proto::VarType::INT32) {
+        GatherV2GradFunction<T, int32_t>(dO, index, axis, dX, ctx.GetPlace());
+      } else if (index_type == framework::proto::VarType::INT64) {
+        GatherV2GradFunction<T, int64_t>(dO, index, axis, dX, ctx.GetPlace());
       }
       return;
     }
@@ -126,18 +108,6 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
     if (dO->numel() == 0) return;
     bool overwrite = ctx.Attr<bool>("overwrite");
-    const auto &index_type = index->type();
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match, true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
     if (index_type == framework::proto::VarType::INT32) {
       if (overwrite) {
         ScatterAssign<T, int32_t>(ctx.device_context(), *dO, *index, dX);
...
@@ -182,6 +182,7 @@ class TestGatherOp4(TestGatherOp1):
         self.index_type = "int64"
         self.axis = [0]
         self.axis_type = "int32"
+        self.attrs = {'overwrite': False}
 
 class API_TestGather(unittest.TestCase):
...
@@ -862,34 +862,39 @@ def gather(x, index, axis=None, name=None):
     """
     if axis is None:
         axis = 0
-    axis_tensor = axis
-    if not isinstance(axis, Variable) and axis == 0:
-        return paddle.fluid.layers.gather(input=x, index=index, overwrite=False)
-    if not isinstance(axis, Variable):
-        with device_guard("cpu"):
-            axis_tensor = fill_constant(
-                shape=[1], dtype='int64', value=axis, force_cpu=True)
+
     if in_dygraph_mode():
-        return core.ops.gather(x, index, axis_tensor)
+        axis = axis.item() if isinstance(axis, paddle.Tensor) else axis
+        return core.ops.gather(x, index, None, "axis", axis, "overwrite", False)
 
     check_variable_and_dtype(
         x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64', 'uint8'],
         'gather')
     check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather')
 
     if isinstance(axis, Variable):
         check_variable_and_dtype(axis, 'axis', ['int32', 'int64'], 'gather')
+    else:
+        check_type(axis, 'axis', (int), 'gather')
 
     helper = LayerHelper('gather', **locals())
     dtype = helper.input_dtype('x')
     out = helper.create_variable_for_type_inference(dtype)
-    helper.append_op(
-        type="gather",
-        inputs={"X": x,
-                "Index": index,
-                "Axis": axis_tensor},
-        outputs={"Out": out})
+    if not isinstance(axis, Variable):
+        helper.append_op(
+            type="gather",
+            inputs={"X": x,
+                    "Index": index},
+            attrs={'axis': axis,
+                   'overwrite': False},
+            outputs={"Out": out})
+    else:
+        helper.append_op(
+            type="gather",
+            inputs={"X": x,
+                    "Index": index,
+                    "Axis": axis},
+            attrs={"overwrite": False},
+            outputs={"Out": out})
     return out
...
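For reference, a short usage sketch of the updated Python API (not part of the patch; assumes a Paddle 2.x build that includes this fix), showing that a non-zero axis now yields the expected static output shape:

    import paddle

    x = paddle.to_tensor([[1.0, 2.0, 3.0],
                          [4.0, 5.0, 6.0]])
    index = paddle.to_tensor([0, 2])

    # Gather columns 0 and 2 along axis 1; the result has shape [2, 2].
    out = paddle.gather(x, index, axis=1)
    print(out.shape)  # [2, 2]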