Commit 88a8eedd authored by zchen0211

scatter gather gpu

gather scatter gpu
Parent 184768e0
...@@ -169,8 +169,8 @@ void CondOp::Run(const Scope& scope,
       tensor_child->Resize(dim);
       tensor_child->mutable_data<float>(dim, platform::CPUPlace());
-      Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
-                    tensor_child);
+      CPUTGather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
+                        tensor_child);
     }
   }
...@@ -194,7 +194,7 @@ void CondOp::Run(const Scope& scope,
     PADDLE_ENFORCE_NOT_NULL(v);
     LoDTensor* tensor_child = v->GetMutable<LoDTensor>();
-    ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
-                         tensor_parent);
+    ScatterAssign<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
+                         tensor_parent);
   }
 }
......
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using platform::Place;
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
size_t index_size, size_t slice_size) {
CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
int gather_i = indices[indices_i];
int params_i = gather_i * slice_size + slice_i;
*(output + i) = *(params + params_i);
}
}
// Implementation of GPU copy:
template <typename T>
struct GPUGather {
void operator()(const T* src, const int* index, const int slice_size,
const int index_size, T* output) {
int block = 512;
int n = slice_size * index_size;
int grid = (n + block - 1) / block;
GatherCUDAKernel<T><<<grid, block>>>(src, index, output, index_size,
slice_size);
}
};
/**
* A thin wrapper on gpu tensor
* Return a new tensor from source tensor, gathered according to index
* input[src]: type-T source Tensor
* input[index]: type-int index Tensor (1-D)
* return: output tensor
*/
template <typename T>
void GPUTGather(const Place& place, const Tensor* src, const Tensor* index,
Tensor* output) {
PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
PADDLE_ENFORCE(index->dims().size() == 1);
int index_size = index->dims()[0];
auto src_dims = src->dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// Gathering
GPUGather<T> gather_functor;
gather_functor(src->data<T>(), index->data<int>(), slice_size, index_size,
output->data<T>());
}
} // namespace operators
} // namespace paddle
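For orientation, here is a minimal sketch of how the new GPUTGather wrapper above is meant to be driven from an op kernel (mirroring gather_op.cu further down). The helper and variable names in this sketch (ExampleGather, x_gpu, index_gpu, out_gpu) are invented for illustration and are not part of the commit.

// Sketch only: assumes an in-tree .cu translation unit so the header shown
// above can be included the same way gather_op.cu includes it.
#include "gather.cu.h"

namespace paddle {
namespace operators {

// Hypothetical helper: gathers the rows of x_gpu selected by index_gpu into
// out_gpu. All three tensors are assumed to live on the GPU, and out_gpu is
// assumed to already have its memory allocated (mutable_data called).
void ExampleGather(const framework::Tensor& x_gpu,      // shape [N, D]
                   const framework::Tensor& index_gpu,  // shape [M], int32
                   framework::Tensor* out_gpu,          // shape [M, D]
                   const platform::Place& place) {
  // GPUTGather enforces is_gpu_place(place) and a 1-D index, then launches
  // GatherCUDAKernel with one thread per output element.
  GPUTGather<float>(place, &x_gpu, &index_gpu, out_gpu);
}

}  // namespace operators
}  // namespace paddle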
...@@ -26,31 +26,31 @@ namespace operators {
 // Implementation of CPU copy
 template <typename T>
-void CPUGather(const T* src, const int* indices, const int slice_size,
-               const int index_size, T* output) {
-  const size_t slice_bytes = slice_size * sizeof(T);
-  for (int i = 0; i < index_size; ++i) {
-    int index_ = indices[i];
-    memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes);
-  }
-}
+struct CPUGather {
+  void operator()(const T* src, const int* indices, const int slice_size,
+                  const int index_size, T* output) {
+    const size_t slice_bytes = slice_size * sizeof(T);
+    for (int i = 0; i < index_size; ++i) {
+      int index_ = indices[i];
+      memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes);
+    }
+  }
+};
-// Implementation of GPU copy:
-template <typename T>
-void GPUGather(const T* src, const int* index, const int slice_size,
-               const int index_size, T* output);
 /**
+ * A thin wrapper on cpu tensor
  * Return a new tensor from source tensor, gathered according to index
  * input[src]: type-T source Tensor
  * input[index]: type-int index Tensor (1-D)
  * return: output tensor
  */
 template <typename T>
-void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
-            const paddle::framework::Tensor* index,
-            paddle::framework::Tensor* output) {
+void CPUTGather(const platform::Place& place,
+                const paddle::framework::Tensor* src,
+                const paddle::framework::Tensor* index,
+                paddle::framework::Tensor* output) {
+  PADDLE_ENFORCE(platform::is_cpu_place(place));
   // check index of shape 1-D
   PADDLE_ENFORCE(index->dims().size() == 1);
   int index_size = index->dims()[0];
...@@ -64,10 +64,9 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
   // Gathering
-  if (platform::is_cpu_place(place)) {
-    CPUGather<T>(src->data<T>(), index->data<int>(), slice_size, index_size,
-                 output->data<T>());
-  }
+  CPUGather<T> gather_functor;
+  gather_functor(src->data<T>(), index->data<int>(), slice_size, index_size,
+                 output->data<T>());
 }
 } // namespace operators
......
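To make the slice arithmetic in the CPUGather functor above concrete, here is a small self-contained sketch of the same row-wise memcpy on plain arrays; the sizes and values are made up for the example.

// Standalone illustration of the copy loop in CPUGather: row i of the
// output is row index[i] of the source, copied slice_bytes at a time.
#include <cstdio>
#include <cstring>

int main() {
  const int slice_size = 2;               // elements per row
  float src[3 * 2] = {0, 1, 2, 3, 4, 5};  // 3 rows of width 2
  int index[2] = {2, 0};                  // gather rows 2 and 0
  float out[2 * 2];

  for (int i = 0; i < 2; ++i)
    std::memcpy(out + i * slice_size, src + index[i] * slice_size,
                slice_size * sizeof(float));

  // Prints "4 5 0 1": out row 0 is src row 2, out row 1 is src row 0.
  for (int j = 0; j < 4; ++j) std::printf("%g ", out[j]);
  std::printf("\n");
  return 0;
}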
...@@ -31,6 +31,8 @@ class GatherOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of GatherOp should not be null.");
+    auto index_dims = ctx->GetInputDim("Index");
+    PADDLE_ENFORCE(index_dims.size() == 1);
     int batch_size = ctx->GetInputDim("Index")[0];
     PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
     framework::DDim output_dims(ctx->GetInputDim("X"));
...@@ -79,8 +81,5 @@ Out = X[Index]
 namespace ops = paddle::operators;
 REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
             ops::GatherGradOp);
-REGISTER_OP_CPU_KERNEL(gather,
-                       ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    gather_grad,
-    ops::GatherGradientOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gather.cu.h"
#include "paddle/framework/eigen.h"
#include "paddle/operators/gather_op.h"
#include "scatter.cu.h"
namespace paddle {
namespace operators {
// template <typename T>
__global__ void print_arr(const float *params, const int N) {
CUDA_1D_KERNEL_LOOP(i, N) { printf("device: %d, %f\n", i, params[i]); }
}
template <typename T>
class GatherOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
output->mutable_data<T>(ctx.GetPlace());
GPUTGather<T>(ctx.GetPlace(), x, index, output);
}
};
template <typename T>
class GatherGradOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
LOG(INFO) << "Gather grad here";
auto *Index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *x = ctx.Input<Tensor>("X");
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto place = ctx.GetEigenDevice<platform::GPUPlace>();
dxt.device(place) = dxt.constant(static_cast<T>(0));
GPUTScatter<T>(ctx.GetPlace(), dO, Index, dX);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gather, ops::GatherOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>);
...@@ -23,29 +23,40 @@ namespace operators {
 using Tensor = framework::Tensor;
-template <typename Place, typename T>
+template <typename T>
 class GatherOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *X = ctx.Input<Tensor>("X");
-    auto *Index = ctx.Input<Tensor>("Index");
-    auto *Y = ctx.Output<Tensor>("Out");
-    Y->mutable_data<T>(ctx.GetPlace());
-    Gather<T>(ctx.GetPlace(), X, Index, Y);
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
+    auto *x = ctx.Input<Tensor>("X");
+    auto *index = ctx.Input<Tensor>("Index");
+    auto *output = ctx.Output<Tensor>("Out");
+    output->mutable_data<T>(ctx.GetPlace());
+    CPUTGather<T>(ctx.GetPlace(), x, index, output);
   }
 };
-template <typename Place, typename T>
+template <typename T>
 class GatherGradientOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
     auto *Index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
     dX->mutable_data<T>(ctx.GetPlace());
-    ScatterUpdate<T>(ctx.GetPlace(), dO, Index, dX);
+    auto dxt = framework::EigenVector<T>::Flatten(*dX);
+    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
+    dxt.device(place) = dxt.constant(static_cast<T>(0));
+    ScatterAssign<T>(ctx.GetPlace(), dO, Index, dX);
   }
 };
......
...@@ -41,7 +41,7 @@ TEST(Gather, GatherData) {
   int* p_output = output->mutable_data<int>(make_ddim({2, 4}), CPUPlace());
-  Gather<int>(CPUPlace(), src, index, output);
+  CPUTGather<int>(CPUPlace(), src, index, output);
   for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
   for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
......
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
namespace paddle {
namespace operators {
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void ScatterCUDAKernel(const T* params, const int* indices,
T* output, size_t index_size,
size_t slice_size) {
CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
int scatter_i = indices[indices_i];
int out_i = scatter_i * slice_size + slice_i;
*(output + out_i) = *(params + i);
}
}
// Implementation of GPU copy:
template <typename T>
struct GPUScatterAssign {
void operator()(const T* src, const int* index, const int slice_size,
const int index_size, T* output) {
int block = 512;
int n = slice_size * index_size;
int grid = (n + block - 1) / block;
// printf("grid, block: %d %d\n", grid, block);
ScatterCUDAKernel<T><<<grid, block>>>(src, index, output, index_size,
slice_size);
}
};
/**
* A thin wrapper on gpu tensor
* Return a new updated tensor from source tensor, scatter-assigned according to
* index
* input[src]: type-T source Tensor
* input[index]: type-int index Tensor (1-D)
* return: output tensor
*/
template <typename T>
void GPUTScatter(const platform::Place& place,
const paddle::framework::Tensor* src,
const paddle::framework::Tensor* index,
paddle::framework::Tensor* output) {
PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
PADDLE_ENFORCE(index->dims().size() == 1);
int index_size = index->dims()[0];
auto src_dims = src->dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// Scatter Assign
GPUScatterAssign<T> scatter_functor;
scatter_functor(src->data<T>(), index->data<int>(), slice_size, index_size,
output->data<T>());
}
} // namespace operators
} // namespace paddle
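The per-thread index math in ScatterCUDAKernel above is easier to see on the CPU. The loop below is a plain-array reference of that mapping, not code from the commit; sizes and values are invented for illustration.

// CPU reference of ScatterCUDAKernel's mapping: input element i belongs to
// slice i / slice_size and is written to row indices[i / slice_size] of the
// output, at offset i % slice_size inside that row.
#include <cstdio>

int main() {
  const int slice_size = 3, index_size = 2;
  float params[2 * 3] = {1, 2, 3, 4, 5, 6};  // "src": 2 slices of width 3
  int indices[2] = {3, 0};                   // slice 0 -> row 3, slice 1 -> row 0
  float output[4 * 3] = {0};                 // destination with 4 rows

  for (int i = 0; i < index_size * slice_size; ++i) {
    int indices_i = i / slice_size;            // which input slice
    int slice_i = i - indices_i * slice_size;  // offset inside the slice
    int scatter_i = indices[indices_i];        // destination row
    output[scatter_i * slice_size + slice_i] = params[i];
  }

  // Row 3 of output is now {1, 2, 3} and row 0 is {4, 5, 6}.
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < slice_size; ++c)
      std::printf("%g ", output[r * slice_size + c]);
    std::printf("\n");
  }
  return 0;
}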
...@@ -24,49 +24,33 @@ namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 // Implementation of CPU copy
 template <typename T>
-void CPUScatterUpdate(const paddle::framework::Tensor* src, const int* index,
-                      const size_t index_size,
-                      paddle::framework::Tensor* output) {
-  paddle::framework::DDim output_dims = output->dims();
-  for (size_t i = 0; i < index_size; ++i) {
-    int index_ = index[i];
-    paddle::framework::Tensor src_ = *src;
-    paddle::framework::Tensor output_ = *output;
-    if (index_size > 1) src_ = src->Slice<T>(i, i + 1);
-    if (output_dims[0] > 1) output_ = output->Slice<T>(index_, index_ + 1);
-    auto X = EigenVector<T>::Flatten(src_);
-    auto Y = EigenVector<T>::Flatten(output_);
-    Y = X + Y;
-  }
-}
+void CPUScatterAssign(const T* src, const int* index, const int slice_size,
+                      const int index_size, T* output) {
+  // paddle::framework::DDim output_dims = output->dims();
+  const size_t slice_bytes = slice_size * sizeof(T);
+  for (int i = 0; i < index_size; ++i) {
+    int index_ = index[i];
+    memcpy(output + index_ * slice_size, src + i * slice_size, slice_bytes);
+  }
+}
-// Implementation of GPU scatter:
-template <typename T>
-void GPUScatterUpdate(const T* src, const int* index, const int slice_size,
-                      const int index_size, T* output);
 /**
  * Return a updated tensor from source tensor, scattered according to index:
- * dst[i] += src[index[i]]
+ * dst[i] = src[index[i]]
  * input[src]: type-T source Tensor
  * input[index]: type-int index Tensor (1-D)
  * return: output tensor
  */
 template <typename T>
-void ScatterUpdate(const platform::Place& place,
+void ScatterAssign(const platform::Place& place,
                    const paddle::framework::Tensor* src,
                    const paddle::framework::Tensor* index,
                    paddle::framework::Tensor* output) {
+  PADDLE_ENFORCE(platform::is_cpu_place(place));
   // check index of shape 1-D
   PADDLE_ENFORCE(index->dims().size() == 1);
   int index_size = index->dims()[0];
...@@ -74,18 +58,19 @@ void ScatterUpdate(const platform::Place& place,
   auto src_dims = src->dims();
   auto dst_dims = output->dims();
+  const T* p_src = src->data<T>();
+  const int* p_index = index->data<int>();
+  T* p_output = output->data<T>();
   // check src shape and dst shape should match
   for (int i = 1; i < src_dims.size(); i++)
     PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
   // slice size
   size_t slice_size = 1;
-  for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
+  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
-  if (platform::is_cpu_place(place)) {
-    CPUScatterUpdate<T>(src, index->data<int>(), index_size, output);
-  } else {
-  }
+  CPUScatterAssign<T>(p_src, p_index, slice_size, index_size, p_output);
 }
 } // namespace operators
......
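A brief usage sketch of the renamed CPU-side ScatterAssign wrapper follows; the helper name, tensor names, and the include path are assumptions made for illustration, not part of the commit.

// Sketch only: assumes an in-tree translation unit; the header path below is
// assumed from the hunk above (the commit page does not show the file name).
#include "paddle/operators/scatter.h"

namespace paddle {
namespace operators {

// Hypothetical helper: overwrites row index[i] of out with row i of updates.
// All tensors are on the CPU and out already has its data allocated.
void ExampleScatterAssign(const framework::Tensor& updates,  // shape [M, D]
                          const framework::Tensor& index,    // shape [M], int32
                          framework::Tensor* out,            // shape [N, D]
                          const platform::Place& place) {
  // Assignment rather than accumulation, which is why the commit renames
  // ScatterUpdate to ScatterAssign.
  ScatterAssign<float>(place, &updates, &index, out);
}

}  // namespace operators
}  // namespace paddle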
...@@ -97,8 +97,5 @@ Out[Index] = Ref[Index] + Updates
 namespace ops = paddle::operators;
 REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad,
             ops::ScatterGradOp);
-REGISTER_OP_CPU_KERNEL(scatter,
-                       ops::ScatterOpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    scatter_grad,
-    ops::ScatterGradientOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(scatter_grad, ops::ScatterGradientOpKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gather.cu.h"
#include "paddle/operators/gather_op.h"
#include "scatter.cu.h"
namespace paddle {
namespace operators {
template <typename T>
class ScatterOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *Ref = ctx.Input<Tensor>("Ref");
auto *Index = ctx.Input<Tensor>("Index");
auto *Updates = ctx.Input<Tensor>("Updates");
auto *Out = ctx.Output<Tensor>("Out");
Out->ShareDataWith<T>(*Ref);
GPUTScatter<T>(ctx.GetPlace(), Updates, Index, Out);
}
};
template <typename T>
class ScatterGradOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Index = ctx.Input<Tensor>("Index");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
// In place gradient: dRef = dO
dRef->ShareDataWith<T>(*dOut);
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Index]
GPUTGather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(scatter, ops::ScatterOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(scatter_grad, ops::ScatterGradOpCUDAKernel<float>);
...@@ -23,10 +23,12 @@ namespace operators {
 using Tensor = framework::Tensor;
-template <typename Place, typename T>
+template <typename T>
 class ScatterOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
     auto *Ref = ctx.Input<Tensor>("Ref");
     auto *Index = ctx.Input<Tensor>("Index");
     auto *Updates = ctx.Input<Tensor>("Updates");
...@@ -35,14 +37,16 @@ class ScatterOpKernel : public framework::OpKernel<T> {
     // In place output: Out = Ref, Out[Index] += Updates
     Out->ShareDataWith<T>(*Ref);
     // Apply ScatterUpdate: Out[index] += Updates[:]
-    ScatterUpdate<T>(ctx.GetPlace(), Updates, Index, Out);
+    ScatterAssign<T>(ctx.GetPlace(), Updates, Index, Out);
   }
 };
-template <typename Place, typename T>
+template <typename T>
 class ScatterGradientOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
     auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Index = ctx.Input<Tensor>("Index");
...@@ -52,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
     dRef->ShareDataWith<T>(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates += dO[Index]
-    Gather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
+    CPUTGather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
   }
 };
......
...@@ -40,7 +40,7 @@ TEST(scatter, ScatterUpdate) {
   float* p_output = output->mutable_data<float>(make_ddim({4, 4}), CPUPlace());
-  ScatterUpdate<float>(CPUPlace(), src, index, output);
+  ScatterAssign<float>(CPUPlace(), src, index, output);
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0));
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
......
...@@ -10,7 +10,7 @@ class TestScatterOp(OpTest):
         index_np = np.array([1, 2]).astype("int32")
         updates_np = np.random.random((2, 3)).astype("float32")
         output_np = np.copy(ref_np)
-        output_np[index_np] += updates_np
+        output_np[index_np] = updates_np
         self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}
...@@ -18,7 +18,7 @@ class TestScatterOp(OpTest):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['Updates', 'Ref'], 'Out', in_place=True)
+        self.check_grad(['Updates'], 'Out', in_place=True)
 if __name__ == "__main__":
......