未验证 提交 6e1e036a 编写于 作者: J JamesLim 提交者: GitHub

Implement cuda kernel for index_sample. (#30380)

上级 666efc23
......@@ -12,7 +12,170 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_sample_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, typename IndexT = int>
__global__ void IndexSampleForward(const IndexT* index, const T* in_data,
T* out_data, size_t index_length,
size_t input_length, size_t batch_size) {
int index_i = blockDim.x * blockIdx.x + threadIdx.x;
int index_j = blockDim.y * blockIdx.y + threadIdx.y;
int index_idx = index_j * index_length + index_i;
int in_idx = index_j * input_length + index_i;
if (index_i < index_length & index_j < batch_size) {
IndexT sample_idx = index[index_idx];
out_data[index_idx] = in_data[in_idx - index_i + sample_idx];
}
}
template <typename T, typename IndexT = int>
__global__ void IndexSampleGrad(const IndexT* index, T* in_grad,
const T* out_grad, size_t index_length,
size_t input_length, size_t batch_size,
bool same_data_in_row = true) {
int index_i = blockDim.x * blockIdx.x + threadIdx.x;
int index_j = blockDim.y * blockIdx.y + threadIdx.y;
int index_idx = index_j * index_length + index_i;
int in_idx = index_j * input_length + index_i;
if (index_i < index_length & index_j < batch_size) {
IndexT sample_idx = index[index_idx];
if (same_data_in_row) {
platform::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]),
out_grad[sample_idx]);
} else {
in_grad[in_idx - index_i + sample_idx] = out_grad[sample_idx];
}
}
}
template <typename T>
class IndexSampleKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("X");
auto* index = ctx.Input<LoDTensor>("Index");
auto* output = ctx.Output<LoDTensor>("Out");
const auto& index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
const auto* in_data = input->data<T>();
auto* out_data = output->mutable_data<T>(ctx.GetPlace());
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
auto input_dim = input->dims();
auto index_dim = index->dims();
size_t batch_size = input_dim[0];
size_t input_length = input_dim[1];
size_t index_length = index_dim[1];
auto block_width = platform::RoundToPowerOfTwo(index_length);
int block_height =
platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;
dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);
if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
IndexSampleForward<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, in_data, out_data, index_length, input_length,
batch_size);
} else if (index_type == framework::proto::VarType::INT32) {
const int* index_data = index->data<int>();
IndexSampleForward<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, in_data, out_data, index_length, input_length,
batch_size);
}
}
};
template <typename T>
class IndexSampleGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* output_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* input_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto* index = ctx.Input<LoDTensor>("Index");
const auto* output_grad_data = output_grad->data<T>();
auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
const auto& index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
auto stream =
ctx.template device_context<platform::CUDADeviceContext>().stream();
auto input_num = input_grad->numel();
auto input_dim = input_grad->dims();
auto index_dim = index->dims();
size_t batch_size = index_dim[0];
size_t input_length = input_dim[1];
size_t index_length = index_dim[1];
bool same_data_in_index_row = index_length == 1 ? false : true;
auto block_width = platform::RoundToPowerOfTwo(index_length);
auto block_height =
platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;
dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
set_zero(dev_ctx, input_grad, static_cast<T>(0));
if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
IndexSampleGrad<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size, same_data_in_index_row);
} else if (index_type == framework::proto::VarType::INT32) {
const int* index_data = index->data<int>();
IndexSampleGrad<T, int><<<grid_dim, block_dim, 0, stream>>>(
index_data, input_grad_data, output_grad_data, index_length,
input_length, batch_size, same_data_in_index_row);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
......
......@@ -92,9 +92,9 @@ class TestCase4(TestIndexSampleOp):
"""
For int64 index type
"""
self.x_shape = (10, 100)
self.x_shape = (10, 128)
self.x_type = "float64"
self.index_shape = (10, 10)
self.index_shape = (10, 64)
self.index_type = "int64"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册