提交 eef55ca7 编写于 作者: Z Zhuoyuan

remodify

上级 2b35fca1
......@@ -14,9 +14,9 @@ limitations under the License. */
#pragma once
#include <cstring>
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
#include "paddle/framework/ddim.h"
/**
* Return a new tensor from source tensor, gathered according to index
......@@ -27,7 +27,7 @@ limitations under the License. */
template <typename Place, typename T>
Tensor* Gather(Tensor* src, Tensor* index) {
// check index of shape 1-D
PADDLE_ENFORCE(index->dims().size()==1);
PADDLE_ENFORCE(index->dims().size() == 1);
int index_size = index->dims()[0];
// Source shape
......@@ -41,61 +41,67 @@ Tensor* Gather(Tensor* src, Tensor* index) {
/* slice size */
int slice_size = 1;
for(unsigned int i = 0; i < src_dims.size(); ++i)
slice_size *= src_dims[i];
for (size_t i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
/* Gathering */
if (place == CPUPlace()) {
// init for CPU
output = New_tensor.mutable_data<T>(output_dims, CPUPlace());
CPUGather(src->data(), index->data(), slice_size, new_tensor->mutable_data());
} else { // GPU
// init for GPU
output = New_tensor.mutable_data<T>(output_dims, GPUPlace());
/* how to specialize device??*/
GPUGather(d, src->data(), index->data(), slice_size, new_tensor->mutable_data());
// init for CPU
output = New_tensor.mutable_data<T>(output_dims, CPUPlace());
CPUGather(
src->data(), index->data(), slice_size, new_tensor->mutable_data());
} else { // GPU
// init for GPU
output = New_tensor.mutable_data<T>(output_dims, GPUPlace());
/* how to specialize device??*/
GPUGather(
d, src->data(), index->data(), slice_size, new_tensor->mutable_data());
}
return New_tensor;
}
/* Implementation of CPU copy */
template<typename T>
void CPUGather(const T* params, const int* indices,
const int slice_size, const int index_size,
T* output) {
template <typename T>
void CPUGather(const T* params,
const int* indices,
const int slice_size,
const int index_size,
T* output) {
const size_t slice_bytes = slice_size * sizeof(T);
for(int i = 0; i < index_size; ++i)
int index_ = indices[i];
/* copy src[index_] to output[i] */
memcpy(output + i * slice_bytes,
params + index_ * slice_bytes,
slice_bytes);
for (size_t i = 0; i < index_size; ++i) {
int index_ = indices[i];
/* copy src[index_] to output[i] */
memcpy(
output + i * slice_bytes, params + index_ * slice_bytes, slice_bytes);
}
}
/* Implementation of GPU copy:
I suppose the GPUDevice& d, contains gpu_id and thread_id
d = cuda_stream(gpu_id_, stream_id_);
*/
template<typename T>
template <typename T>
void GPUGather(const GPUDevice& d,
const T* src, const int* index,
const int slice_size, const int index_size,
T* output) {
const T* src,
const int* index,
const int slice_size,
const int index_size,
T* output) {
int block_count = slice_size * index_size;
int thread_per_block = 1024;
GatherOpKernel<T>
<<<block_count, thread_per_block, 0, d.stream()>>>(
src, index, output, slice_size,
indices_size, slice_size, out_size);
GatherOpKernel<T><<<block_count, thread_per_block, 0, d.stream()>>>(
src, index, output, slice_size, indices_size, slice_size, out_size);
}
template <typename T>
__global__ void GatherOpKernel(const T* params, const int* indices, T* out,
__global__ void GatherOpKernel(const T* params,
const int* indices,
T* out,
int64 indices_size,
int64 slice_size, int64 out_size) {
/* I suppose we have the following macro,
int64 slice_size,
int64 out_size) {
/* I suppose we have the following macro,
which I strongly suggest that we should put in cuda:
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
......@@ -103,9 +109,9 @@ __global__ void GatherOpKernel(const T* params, const int* indices, T* out,
*/
CUDA_1D_KERNEL_LOOP(i, out_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
int slice_i = i - indices_i * slice_size; // offset inside the slice
int gather_i = indices[indices_i];
int params_i = gather_i * slice_size + slice_i;
out[i] = *(params + params_i);
}
}
}
......@@ -14,96 +14,93 @@ limitations under the License. */
#pragma once
#include <cstring>
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
#include "paddle/framework/ddim.h"
/**
* Return a updated tensor from source tensor, scattered according to index:
* dst[i] += src[index[i]]
* input[src]: type-T source Tensor
* input[Index]: type-int index Tensor (1-D)
* input[index]: type-int index Tensor (1-D)
* return: output tensor
*/
template <typename place, typename T>
void ScatterUpdate_func(Tensor* Src, Tensor* Dst, Tensor* Index) {
// assert index is an int-type tensor
assert(Index->istype(int));
// Source shape
auto src_dims = Src->dims();
auto dst_dims = Dst->dims();
DDim output_dims(dims_src);
// check Src shape and Dst shape should match
for(int i = 1; i < src_dims.size(); i++)
assert(src_dims[i]==dst_dims[i]);
int index_size = Index->dims()[0];
/* slice size */
int slice_size = 1;
for(unsigned int i = 0; i < src_dims.size(); ++i)
slice_size *= src_dims[i];
if (place == CPUPlace()) {
// init
output = new_tensor.mutable_data<T>(output_dims, CPUPlace());
CPUScatterUpdate(src->data(), index->data(), slice_size, new_tensor->mutable_data());
} else { // GPU
// init
output = new_tensor.mutable_data<T>(output_dims, GPUPlace());
/* how to specialize device??*/
GPUScatterUpdate(d, src->data(), index->data(), slice_size, new_tensor->mutable_data());
}
template <typename Place, typename T>
void ScatterUpdate(Tensor* src, Tensor* dst, Tensor* index) {
// Source shape
auto src_dims = src->dims();
auto dst_dims = dst->dims();
DDim output_dims(dims_src);
// check src shape and dst shape should match
for (size_t i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
int index_size = index->dims()[0];
/* slice size */
int slice_size = 1;
for (size_t i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
if (place == CPUPlace()) {
// init
output = new_tensor.mutable_data<T>(output_dims, CPUPlace());
CPUScatterUpdate(
src->data(), index->data(), slice_size, new_tensor->mutable_data());
} else { // GPU
// init
output = new_tensor.mutable_data<T>(output_dims, GPUPlace());
/* how to specialize device??*/
GPUScatterUpdate(
d, src->data(), index->data(), slice_size, new_tensor->mutable_data());
}
}
/* Implementation of CPU copy */
template<typename T>
void CPUScatterUpdate(const T* src, const int* Index,
const int slice_size, const int index_size,
T* output) {
//const size_t slice_bytes = slice_size * sizeof(T);
for(int i = 0; i < index_size; ++i)
int index_ = index[i];
/* dst[index_] += src[index_]
add operation size: slice_size
*/
math::vAdd<T>(slice_size, src + index_ * slice_bytes,
output + i * slice_bytes,
output + i * slice_bytes);
/* Scatter update, not just assign
memcpy(output + i * slice_bytes,
src + index_ * slice_bytes,
slice_bytes);
*/
template <typename T>
void CPUScatterUpdate(const T* src,
const int* index,
const int slice_size,
const int index_size,
T* output) {
// const size_t slice_bytes = slice_size * sizeof(T);
for (size_t i = 0; i < index_size; ++i) {
int index_ = index[i];
math::vAdd<T>(slice_size,
src + index_ * slice_bytes,
output + i * slice_bytes,
output + i * slice_bytes);
}
}
/* Implementation of GPU scatter:
I suppose the GPUDevice& d, contains gpu_id and thread_id
d = cuda_stream(gpu_id_, stream_id_);
*/
template<typename T>
template <typename T>
void GPUScatterUpdate(const GPUDevice& d,
const T* src, const int* Index,
const int slice_size, const int index_size,
T* output) {
int block_count = slice_size * index_size;
int thread_per_block = 1024;
ScatterOpKernel<T>
<<<block_count, thread_per_block, 0, d.stream()>>>(
src, Index, output, slice_size,
indices_size, slice_size, out_size);
const T* src,
const int* index,
const int slice_size,
const int index_size,
T* output) {
int block_count = slice_size * index_size;
int thread_per_block = 1024;
ScatterOpKernel<T><<<block_count, thread_per_block, 0, d.stream()>>>(
src, index, output, slice_size, indices_size, slice_size, out_size);
}
template <typename T>
__global__ void ScatterOpKernel(const T* params, const int* indices, T* out,
int64 indices_size,
int64 slice_size, int64 out_size) {
/* I suppose we have the following macro,
__global__ void ScatterOpKernel(const T* params,
const int* indices,
T* out,
int64 indices_size,
int64 slice_size,
int64 out_size) {
/* I suppose we have the following macro,
which I strongly suggest that we should put in cuda:
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
......@@ -111,9 +108,9 @@ __global__ void ScatterOpKernel(const T* params, const int* indices, T* out,
*/
CUDA_1D_KERNEL_LOOP(i, out_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
int slice_i = i - indices_i * slice_size; // offset inside the slice
int scatter_i = indices[indices_i];
int params_i = scatter_i * slice_size + slice_i;
out[i] += *(params + params_i);
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册