提交 85914f7a 编写于 作者: S ShenLiang 提交者: Yi Liu

add gather_nd op and unit test (#19366)

* fixed the code for coverage

* fixed the document,test=document_preview test=develop
上级 ecd9f330
......@@ -194,6 +194,7 @@ paddle.fluid.layers.resize_bilinear (ArgSpec(args=['input', 'out_shape', 'scale'
paddle.fluid.layers.resize_trilinear (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners', 'align_mode'], varargs=None, keywords=None, defaults=(None, None, None, None, True, 1)), ('document', '4836e98a634f6fbea26d0cdaa303f867'))
paddle.fluid.layers.resize_nearest (ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape', 'align_corners'], varargs=None, keywords=None, defaults=(None, None, None, None, True)), ('document', '32ffc0e8818d7319ed1bf63a791e985d'))
paddle.fluid.layers.gather (ArgSpec(args=['input', 'index', 'overwrite'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f985c9b66e3aec96fa753a8eb44c991c'))
paddle.fluid.layers.gather_nd (ArgSpec(args=['input', 'index', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '3cc24f9cf135770aa6263dba25b457f9'))
paddle.fluid.layers.scatter (ArgSpec(args=['input', 'index', 'updates', 'name', 'overwrite'], varargs=None, keywords=None, defaults=(None, True)), ('document', '69b22affd4a6326502af166f04c095ab'))
paddle.fluid.layers.sequence_scatter (ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'abe3f714120117a5a3d3e639853932bf'))
paddle.fluid.layers.random_crop (ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,)), ('document', '042af0b8abea96b40c22f6e70d99e042'))
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -13,7 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
......@@ -39,6 +43,27 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
}
}
template <typename T, typename IndexT = int>
__global__ void GatherNdCUDAKernel(const T* input, const int* input_dims,
const IndexT* indices, T* output,
size_t remain_size, size_t slice_size,
size_t end_size) {
CUDA_1D_KERNEL_LOOP(i, remain_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = 0;
int64_t temp = slice_size;
for (int64_t j = end_size - 1; j >= 0; --j) {
auto index_value = indices[indices_i * end_size + j];
assert(index_value >= 0 && index_value < input_dims[j]);
gather_i += (index_value * temp);
temp *= input_dims[j];
}
IndexT input_i = gather_i + slice_i;
*(output + i) = *(input + input_i);
}
}
/**
* A thin wrapper on gpu tensor
* Return a new tensor from source tensor, gathered according to index
......@@ -84,5 +109,56 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
p_src, p_index, p_output, index_size, slice_size);
}
template <typename DeviceContext, typename T, typename IndexT = int>
void GPUGatherNd(const framework::ExecutionContext& context,
const Tensor& input, const Tensor& index, Tensor* output) {
const auto& ctx = context.template device_context<DeviceContext>();
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
auto cplace = platform::CPUPlace();
auto index_dims = index.dims();
auto index_dims_size = index_dims.size();
auto input_dims = input.dims();
auto input_dims_size = input_dims.size();
const T* p_input = input.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
// final dim
int64_t end_size = index_dims[index_dims_size - 1];
// remain dim
auto remain_ddim = framework::slice_ddim(index_dims, 0, index_dims_size - 1);
int64_t remain_numel = framework::product(remain_ddim);
// slice size
int64_t slice_size = 1;
for (int64_t i = end_size; i < input_dims_size; ++i) {
slice_size *= input_dims[i];
}
// source dim
std::vector<int> v_input_dims(input_dims_size);
for (int i = 0; i < input_dims_size; ++i) {
v_input_dims[i] = static_cast<int>(input_dims[i]);
}
auto& dev_ctx = context.cuda_device_context();
auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = input_dims_size * sizeof(int);
auto p_input_dims = allocator.Allocate(bytes);
int* g_input_dims = reinterpret_cast<int*>(p_input_dims->ptr());
memory::Copy(gplace, g_input_dims, cplace, v_input_dims.data(), bytes,
ctx.stream());
int block = 512;
int n = slice_size * remain_numel;
int grid = (n + block - 1) / block;
GatherNdCUDAKernel<T, IndexT><<<
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_input, g_input_dims, p_index, p_output, remain_numel, slice_size,
end_size);
}
} // namespace operators
} // namespace paddle
......@@ -60,5 +60,51 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
}
}
template <typename T, typename IndexT = int>
void CPUGatherNd(const platform::DeviceContext& ctx, const Tensor& input,
const Tensor& index, Tensor* output) {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
"It should be running on the CPU");
auto index_dims = index.dims();
auto index_dims_size = index_dims.size();
auto input_dims = input.dims();
auto input_dims_size = input_dims.size();
const T* p_input = input.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
// final dim
int64_t end_size = index_dims[index_dims_size - 1];
// remain dim
auto remain_ddim = framework::slice_ddim(index_dims, 0, index_dims_size - 1);
int64_t remain_numel = framework::product(remain_ddim);
// slice size
int64_t slice_size = 1;
for (int64_t i = end_size; i < input_dims_size; ++i) {
slice_size *= input_dims[i];
}
const size_t slice_bytes = slice_size * sizeof(T);
for (int64_t i = 0; i < remain_numel; ++i) {
int64_t index_ = 0;
int64_t temp = 1;
for (int64_t j = end_size - 1; j >= 0; --j) {
IndexT index_value = p_index[i * end_size + j];
PADDLE_ENFORCE_LT(index_value, input_dims[j],
"Input(index[-1)] has wrong value, it is %d",
index_value);
PADDLE_ENFORCE_GE(index_value, 0UL,
"The value of Input(index) must be no less than 0");
index_ += (index_value * temp);
temp *= input_dims[j];
}
memcpy(p_output + i * slice_size, p_input + index_ * slice_size,
slice_bytes);
}
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gather_nd_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
namespace paddle {
namespace operators {
class GatherNdOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of GatherNdOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
"Input(Index) of GatherNdOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of GatherNdOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto x_dims_size = x_dims.size();
auto index_dims = ctx->GetInputDim("Index");
auto index_dims_size = index_dims.size();
PADDLE_ENFORCE_LE(index_dims[index_dims_size - 1], x_dims_size,
"Input(Index).shape[-1] <= Input(X).rank");
PADDLE_ENFORCE_GE(index_dims_size, 2UL,
"The rank of Input(Index) should be greater than 1");
std::vector<int64_t> result_dims;
// The result dims is
// Index.shape[:-1] + X.shape[Index.shape[-1]:]
for (int i = 0; i < index_dims_size - 1; ++i) {
result_dims.emplace_back(index_dims[i]);
}
for (int i = index_dims[index_dims_size - 1]; i < x_dims_size; ++i) {
result_dims.emplace_back(x_dims[i]);
}
ctx->SetOutputDim("Out", framework::make_ddim(result_dims));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
};
class GatherNdGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
class GatherNdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The source input of gather_nd op");
AddInput("Index", "The index input of gather_nd op");
AddOutput("Out", "The output of gather_nd op");
AddComment(R"DOC(
Gather_Nd Operator.
This function is actually a high-dimensional extension of gather
and supports for simultaneous indexing by multiple axes. Out is
obtained by gathering slices from X into a tensor with shape
Index.shape[:-1] + X.shape[Index.shape[-1]:].
Example:
Given:
X = [[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]]
X.shape = (2, 3, 4)
*Case 1:
Index = [[1]]
we get:
Out =
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]
*Case 2:
Index = [[0,2]]
we get:
Out = [8, 9, 10, 11]
*Case 3:
Index = [[1, 2, 3]]
we get:
Out = [23]
)DOC");
}
};
class GatherNdGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("gather_nd_grad");
op->SetInput("Index", Input("Index"));
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GatherNdGradNoNeedBufferVarInference,
"X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gather_nd, ops::GatherNdOp, ops::GatherNdOpMaker,
ops::GatherNdGradOpDescMaker);
REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp,
ops::GatherNdGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel<float>,
ops::GatherNdOpKernel<double>,
ops::GatherNdOpKernel<int64_t>,
ops::GatherNdOpKernel<int>,
ops::GatherNdOpKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(gather_nd_grad, ops::GatherNdGradOpKernel<float>,
ops::GatherNdGradOpKernel<double>,
ops::GatherNdGradOpKernel<int64_t>,
ops::GatherNdGradOpKernel<int>,
ops::GatherNdGradOpKernel<uint8_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/gather_nd_op.h"
#include "paddle/fluid/operators/scatter.cu.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class GatherNdOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"This kernel only runs on GPU device.");
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return;
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) {
GPUGatherNd<DeviceContext, T, int>(ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
GPUGatherNd<DeviceContext, T, int64_t>(ctx, *x, *index, output);
}
}
};
template <typename DeviceContext, typename T>
class GatherNdGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"This kernel only runs on GPU device.");
auto *index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) {
GPUScatterNdAdd<DeviceContext, T, int>(ctx, *dO, *index, dX);
} else if (index_type == framework::proto::VarType::INT64) {
GPUScatterNdAdd<DeviceContext, T, int64_t>(ctx, *dO, *index, dX);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<CUDA, float>,
ops::GatherNdOpCUDAKernel<CUDA, double>,
ops::GatherNdOpCUDAKernel<CUDA, int64_t>,
ops::GatherNdOpCUDAKernel<CUDA, int>,
ops::GatherNdOpCUDAKernel<CUDA, plat::float16>);
REGISTER_OP_CUDA_KERNEL(gather_nd_grad,
ops::GatherNdGradOpCUDAKernel<CUDA, float>,
ops::GatherNdGradOpCUDAKernel<CUDA, double>,
ops::GatherNdGradOpCUDAKernel<CUDA, int64_t>,
ops::GatherNdGradOpCUDAKernel<CUDA, int>,
ops::GatherNdGradOpCUDAKernel<CUDA, plat::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/scatter.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class GatherNdOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
"This kernel only runs on CPU.");
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return;
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) {
CPUGatherNd<T, int>(ctx.device_context(), *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
CPUGatherNd<T, int64_t>(ctx.device_context(), *x, *index, output);
}
}
};
template <typename T>
class GatherNdGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
"This kernel only runs on CPU.");
auto *index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;
const auto &index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) {
ScatterNdAdd<T, int32_t>(ctx, *dO, *index, dX);
} else if (index_type == framework::proto::VarType::INT64) {
ScatterNdAdd<T, int64_t>(ctx, *dO, *index, dX);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <unordered_set>
#include <vector>
#include "math/math_function.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -57,6 +58,26 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
}
}
template <typename T, typename IndexT = int>
__global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices,
T* output, const int* output_dims,
size_t remain_size, size_t slice_size,
size_t end_size) {
CUDA_1D_KERNEL_LOOP(i, remain_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = 0;
int64_t temp = slice_size;
for (int64_t j = end_size - 1; j >= 0; --j) {
IndexT index_value = indices[indices_i * end_size + j];
gather_i += (index_value * temp);
temp *= output_dims[j];
}
IndexT output_i = gather_i + slice_i;
paddle::platform::CudaAtomicAdd(output + output_i, *(update + i));
}
}
/**
* A thin wrapper on gpu tensor
* Return a new updated tensor from source tensor, scatter-assigned according to
......@@ -109,5 +130,59 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
p_src, p_index, p_output, index_size, slice_size, overwrite);
}
template <typename DeviceContext, typename T, typename IndexT = int>
void GPUScatterNdAdd(const framework::ExecutionContext& context,
const Tensor& update, const Tensor& index,
Tensor* output) {
auto index_dims = index.dims();
auto index_dims_size = index_dims.size();
auto output_dims = output->dims();
auto output_dims_size = output_dims.size();
const T* p_update = update.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
// final dim
int64_t end_size = index_dims[index_dims_size - 1];
// remain dim
auto remain_ddim = framework::slice_ddim(index_dims, 0, index_dims_size - 1);
int64_t remain_numel = framework::product(remain_ddim);
// slice size
int64_t slice_size = 1;
for (int64_t i = end_size; i < output_dims_size; ++i) {
slice_size *= output_dims[i];
}
const size_t slice_bytes = slice_size * sizeof(T);
// put output_dims int CUDA
// gplace and cplace
const auto& ctx = context.template device_context<DeviceContext>();
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
auto cplace = platform::CPUPlace();
std::vector<int> v_output_dims(output_dims_size);
for (int i = 0; i < output_dims_size; ++i) {
v_output_dims[i] = static_cast<int>(output_dims[i]);
}
auto& dev_ctx = context.cuda_device_context();
auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = output_dims_size * sizeof(int);
auto output_dims_ptr = allocator.Allocate(bytes);
int* g_output_dims = reinterpret_cast<int*>(output_dims_ptr->ptr());
memory::Copy(gplace, g_output_dims, cplace, v_output_dims.data(), bytes,
ctx.stream());
int block = 512;
int n = slice_size * remain_numel;
int grid = (n + block - 1) / block;
ScatterNdCUDAKernel<T, IndexT><<<
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_update, p_index, p_output, g_output_dims, remain_numel, slice_size,
end_size);
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -144,5 +144,49 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
}
}
template <typename T, typename IndexT = int>
void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
const Tensor& index, Tensor* output) {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()),
true, "It should be running on the CPU");
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
auto index_dims = index.dims();
auto index_dims_size = index_dims.size();
auto output_dims = output->dims();
auto output_dims_size = output_dims.size();
const T* p_update = update.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* result_p_output = output->data<T>();
const T* p_output = output->data<T>();
// final dim
int64_t end_size = index_dims[index_dims_size - 1];
// remain dim
auto remain_ddim = framework::slice_ddim(index_dims, 0, index_dims_size - 1);
int64_t remain_numel = framework::product(remain_ddim);
// slice size
int64_t slice_size = 1;
for (int64_t i = end_size; i < output_dims_size; ++i) {
slice_size *= output_dims[i];
}
const size_t slice_bytes = slice_size * sizeof(T);
for (int64_t i = 0; i < remain_numel; ++i) {
IndexT index_ = 0;
IndexT temp = 1;
for (int64_t j = end_size - 1; j >= 0; --j) {
IndexT index_value = p_index[i * end_size + j];
index_ += (index_value * temp);
temp *= output_dims[j];
}
elementwise_inner_add<T, IndexT>(ctx, p_update, p_output, result_p_output,
update, output, i, index_, slice_size,
slice_bytes);
}
}
} // namespace operators
} // namespace paddle
......@@ -122,6 +122,7 @@ __all__ = [
'resize_trilinear',
'resize_nearest',
'gather',
'gather_nd',
'scatter',
'sequence_scatter',
'random_crop',
......@@ -8449,6 +8450,91 @@ def gather(input, index, overwrite=True):
return out
def gather_nd(input, index, name=None):
"""
**Gather Nd Layer**
This function is actually a high-dimensional extension of :code:`gather`
and supports for simultaneous indexing by multiple axes. :attr:`index` is a
K-dimensional integer tensor, which is regarded as a (K-1)-dimensional
tensor of :attr:`index` into :attr:`input`, where each element defines
a slice of params:
.. math::
output[(i_0, ..., i_{K-2})] = input[index[(i_0, ..., i_{K-2})]]
Obviously, :code:`index.shape[-1] <= input.rank` . And, the output tensor has
shape :code:`index.shape[:-1] + input.shape[index.shape[-1]:]` .
.. code-block:: text
Given:
input = [[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]]
input.shape = (2, 3, 4)
* Case 1:
index = [[1]]
gather_nd(input, index)
= [input[1, :, :]]
= [[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]
* Case 2:
index = [[0,2]]
gather_nd(input, index)
= [input[0, 2, :]]
= [8, 9, 10, 11]
* Case 3:
index = [[1, 2, 3]]
gather_nd(input, index)
= [input[1, 2, 3]]
= [23]
Args:
input (Variable): The source input
index (Variable): The index input with rank > 1, index.shape[-1] <= input.rank
name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically
Returns:
output (Variable): A tensor with the shape index.shape[:-1] + input.shape[index.shape[-1]:]
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[3, 4, 5], dtype='float32')
index = fluid.layers.data(name='index', shape=[2, 2], dtype='int32')
output = fluid.layers.gather_nd(x, index)
"""
helper = LayerHelper('gather_nd', **locals())
dtype = helper.input_dtype()
if name is None:
output = helper.create_variable_for_type_inference(dtype)
else:
output = helper.create_variable(
name=name, dtype=dtype, persistable=False)
helper.append_op(
type="gather_nd",
inputs={"X": input,
"Index": index},
outputs={"Out": output})
return output
def scatter(input, index, updates, name=None, overwrite=True):
"""
**Scatter Layer**
......
......@@ -185,6 +185,7 @@ set(TEST_OPS_WITH_GC
test_fill_constant_batch_size_like_op
test_fill_zeros_like2_op
test_gather_op
test_gather_nd_op
test_gaussian_random_batch_size_like_op
test_linear_chain_crf_op
test_lod_reset_op
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
class TestGatherNdOpWithEmptyIndex(OpTest):
"""
Index has empty element, which means copy entire tensor
"""
def setUp(self):
self.op_type = "gather_nd"
xnp = np.array(
[[65, 17, 2], [-14, -25, -1], [76, 22, 3]]).astype("float32")
self.inputs = {'X': xnp, 'Index': np.array([[], []]).astype("int32")}
self.outputs = {
'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :]))
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestGatherNdOpWithLowIndex(OpTest):
"""
Index has low rank, X has high rank
"""
def setUp(self):
self.op_type = "gather_nd"
xnp = np.array(
[[65, 17, 2], [14, 25, 1], [76, 22, 3]]).astype("float32")
index = np.array([[1], [2]]).astype("int64")
self.inputs = {'X': xnp, 'Index': index}
self.outputs = {'Out': xnp[tuple(index.T)]} #[[14, 25, 1], [76, 22, 3]]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestGatherNdOpWithSameIndexAsX(OpTest):
"""
Index has same rank as X's rank
"""
def setUp(self):
self.op_type = "gather_nd"
xnp = np.array(
[[65, 17, 2], [14, 25, 1], [76, 22, 3]]).astype("float64")
index = np.array([[1, 1], [2, 1]]).astype("int64")
self.inputs = {'X': xnp, 'Index': index}
self.outputs = {'Out': xnp[tuple(index.T)]} #[25, 22]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestGatherNdOpWithHighRankSame(OpTest):
"""
Both Index and X have high rank, and Rank(Index) = Rank(X)
"""
def setUp(self):
self.op_type = "gather_nd"
shape = (20, 9, 8, 1, 31)
xnp = np.random.rand(*shape)
index = np.vstack([np.random.randint(0, s, size=150) for s in shape]).T
self.inputs = {'X': xnp, 'Index': index.astype("int32")}
self.outputs = {'Out': xnp[tuple(index.T)]}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestGatherNdOpWithHighRankDiff(OpTest):
"""
Both Index and X have high rank, and Rank(Index) < Rank(X)
"""
def setUp(self):
self.op_type = "gather_nd"
shape = (20, 9, 8, 1, 31)
xnp = np.random.rand(*shape).astype("double")
index = np.vstack([np.random.randint(0, s, size=1000) for s in shape]).T
index_re = index.reshape([10, 5, 20, 5])
self.inputs = {'X': xnp, 'Index': index_re.astype("int32")}
self.outputs = {'Out': xnp[tuple(index.T)].reshape([10, 5, 20])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
#Test Python API
class TestGatherNdOpAPI(OpTest):
def test_case1(self):
x1 = fluid.layers.data(
name='x1', shape=[30, 40, 50, 60], dtype='float32')
index1 = fluid.layers.data(name='index1', shape=[2, 4], dtype='int32')
output1 = fluid.layers.gather_nd(x1, index1)
def test_case2(self):
x2 = fluid.layers.data(name='x2', shape=[30, 40, 50], dtype='float32')
index2 = fluid.layers.data(name='index2', shape=[2, 2], dtype='int64')
output2 = fluid.layers.gather_nd(x2, index2)
def test_case3(self):
x3 = fluid.layers.data(name='x3', shape=[3, 4, 5], dtype='float32')
index3 = fluid.layers.data(name='index3', shape=[2, 1], dtype='int32')
output3 = fluid.layers.gather_nd(x3, index3, name="gather_nd_layer")
#Test Raise Index Error
class TestGatherNdOpRaise(OpTest):
def test_check_raise(self):
def check_raise_is_test():
try:
x = fluid.layers.data(
name='x', shape=[3, 4, 5], dtype='float32')
index = fluid.layers.data(
name='index', shape=[2, 10], dtype='int32')
output = fluid.layers.gather_nd(x, index)
except Exception as e:
t = \
"Input(Index).shape[-1] <= Input(X).rank"
if t in str(e):
raise IndexError
self.assertRaises(IndexError, check_raise_is_test)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册