提交 2b35fca1 编写于 作者: Z Zhuoyuan

gather modify

上级 08021979
...@@ -21,44 +21,41 @@ limitations under the License. */ ...@@ -21,44 +21,41 @@ limitations under the License. */
/** /**
* Return a new tensor from source tensor, gathered according to index * Return a new tensor from source tensor, gathered according to index
* input[src]: type-T source Tensor * input[src]: type-T source Tensor
* input[Index]: type-int index Tensor (1-D) * input[index]: type-int index Tensor (1-D)
* return: output tensor * return: output tensor
*/ */
template <typename place, typename T> template <typename Place, typename T>
Tensor* Gather_func(Tensor* Src, Tensor* Index) { Tensor* Gather(Tensor* src, Tensor* index) {
// assert index is an int-type tensor? // check index of shape 1-D
// assert(Index->istype(int)); PADDLE_ENFORCE(index->dims().size()==1);
int index_size = index->dims()[0];
// check index of shape 1-D // Source shape
assert(Index->dims().size()==1); auto src_dims = src->dims();
int index_size = Index->dims()[0]; DDim output_dims(dims_src);
// Create a tensor of shape [index_size, dim_src[1:]]
output_dims[0] = index_size;
// Source shape Tensor* New_tensor;
auto src_dims = Src->dims(); float* output = nullptr;
DDim output_dims(dims_src);
// Create a tensor of shape [index_size, dim_src[1:]]
output_dims[0] = index_size;
Tensor* New_tensor; /* slice size */
float* output = nullptr; int slice_size = 1;
for(unsigned int i = 0; i < src_dims.size(); ++i)
slice_size *= src_dims[i];
/* slice size */ /* Gathering */
int slice_size = 1; if (place == CPUPlace()) {
for(unsigned int i = 0; i < src_dims.size(); ++i) // init for CPU
slice_size *= src_dims[i]; output = New_tensor.mutable_data<T>(output_dims, CPUPlace());
CPUGather(src->data(), index->data(), slice_size, new_tensor->mutable_data());
/* Gathering */ } else { // GPU
if (place == CPUPlace()) { // init for GPU
// init for CPU output = New_tensor.mutable_data<T>(output_dims, GPUPlace());
output = New_tensor.mutable_data<T>(output_dims, CPUPlace()); /* how to specialize device??*/
CPUGather(Src->data(), Index->data(), slice_size, new_tensor->mutable_data()); GPUGather(d, src->data(), index->data(), slice_size, new_tensor->mutable_data());
} else { // GPU }
// init for GPU return New_tensor;
output = New_tensor.mutable_data<T>(output_dims, GPUPlace());
/* how to specialize device??*/
GPUGather(d, Src->data(), Index->data(), slice_size, new_tensor->mutable_data());
}
return New_tensor;
} }
/* Implementation of CPU copy */
...@@ -82,15 +79,15 @@ void CPUGather(const T* params, const int* indices, ...@@ -82,15 +79,15 @@ void CPUGather(const T* params, const int* indices,
*/ */
template<typename T> template<typename T>
void GPUGather(const GPUDevice& d, void GPUGather(const GPUDevice& d,
const T* src, const int* Index, const T* src, const int* index,
const int slice_size, const int index_size, const int slice_size, const int index_size,
T* output) { T* output) {
int block_count = slice_size * index_size; int block_count = slice_size * index_size;
int thread_per_block = 1024; int thread_per_block = 1024;
GatherOpKernel<T> GatherOpKernel<T>
<<<block_count, thread_per_block, 0, d.stream()>>>( <<<block_count, thread_per_block, 0, d.stream()>>>(
src, Index, output, slice_size, src, index, output, slice_size,
indices_size, slice_size, out_size); indices_size, slice_size, out_size);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册