Commit 2817ca03 authored by Zhuoyuan, committed by GitHub

Merge pull request #4483 from zchen0211/develop

gather, scatter with gpu support, passed python test
@@ -126,8 +126,7 @@ void CondOp::PrepareDataForSubnet(
     dim[0] = index_tensors[i].dims()[0];
     tensor_child->mutable_data<float>(dim, platform::CPUPlace());
-    Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
-                  tensor_child);
+    CPUGather<float>(dev_ctx, *tensor_parent, index_tensors[i], tensor_child);
   }
 }
@@ -188,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
     Variable* var_child = sub_scopes[i]->FindVar(output);
     PADDLE_ENFORCE_NOT_NULL(var_child);
     auto* tensor_child = &var_child->Get<LoDTensor>();
-    ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
-                         tensor_parent);
+    ScatterAssign<float>(dev_ctx, *tensor_child, index_tensors[i],
+                         tensor_parent);
   }
 }
...
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using platform::Place;
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
size_t index_size, size_t slice_size) {
CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
int gather_i = indices[indices_i];
int params_i = gather_i * slice_size + slice_i;
*(output + i) = *(params + params_i);
}
}
/**
* A thin wrapper on gpu tensor
* Return a new tensor from source tensor, gathered according to index
* input[src]: type-T source Tensor
* input[index]: type-int index Tensor (1-D)
* return: output tensor
*/
template <typename T>
void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1);
int index_size = index.dims()[0];
auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
const T* p_src = src.data<T>();
const int* p_index = index.data<int>();
T* p_output = output->data<T>();
int block = 512;
int n = slice_size * index_size;
int grid = (n + block - 1) / block;
GatherCUDAKernel<T><<<
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_src, p_index, p_output, index_size, slice_size);
}
} // namespace operators
} // namespace paddle
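A note on the kernel in gather.cu.h above: it flattens the output into index_size * slice_size elements and lets each thread recover its (row, offset-in-slice) pair from the flat index. The following is a minimal host-side C++ sketch of the same arithmetic, for illustration only; it is not part of the commit and the function name is made up.

#include <cstddef>

// Host-side reference of GatherCUDAKernel's indexing: output row i is
// params row indices[i], copied element by element.
template <typename T>
void GatherReference(const T* params, const int* indices, T* output,
                     std::size_t index_size, std::size_t slice_size) {
  for (std::size_t i = 0; i < index_size * slice_size; ++i) {
    std::size_t indices_i = i / slice_size;  // which output row
    std::size_t slice_i = i % slice_size;    // offset inside the slice
    std::size_t params_i = indices[indices_i] * slice_size + slice_i;
    output[i] = params[params_i];
  }
}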
@@ -24,49 +24,40 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-// Implementation of CPU copy
-template <typename T>
-void CPUGather(const T* src, const int* indices, const int slice_size,
-               const int index_size, T* output) {
-  const size_t slice_bytes = slice_size * sizeof(T);
-
-  for (int i = 0; i < index_size; ++i) {
-    int index_ = indices[i];
-    memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes);
-  }
-}
-
-// Implementation of GPU copy:
-template <typename T>
-void GPUGather(const T* src, const int* index, const int slice_size,
-               const int index_size, T* output);
+using framework::Tensor;
 
 /**
+ * A thin wrapper for gathering on cpu tensor
  * Return a new tensor from source tensor, gathered according to index
  * input[src]: type-T source Tensor
  * input[index]: type-int index Tensor (1-D)
  * return: output tensor
  */
 template <typename T>
-void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
-            const paddle::framework::Tensor* index,
-            paddle::framework::Tensor* output) {
+void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
+               const Tensor& index, Tensor* output) {
+  PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];
 
-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   framework::DDim output_dims(src_dims);
   output_dims[0] = index_size;
 
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
+  T* p_output = output->data<T>();
+
   // slice size
   int slice_size = 1;
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
 
-  // Gathering
-  if (platform::is_cpu_place(place)) {
-    CPUGather<T>(src->data<T>(), index->data<int>(), slice_size, index_size,
-                 output->data<T>());
-  }
+  const size_t slice_bytes = slice_size * sizeof(T);
+
+  for (int i = 0; i < index_size; ++i) {
+    int index_ = p_index[i];
+    memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
+  }
 }
...
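To make the row-copy behaviour of the new CPUGather concrete, here is a small self-contained sketch on plain arrays. It is an assumed example, not from the commit; it mimics the memcpy loop above for a 3x2 row-major source and index [2, 0].

#include <cstring>
#include <iostream>

int main() {
  const int cols = 2;
  float src[3 * cols] = {0, 1, 2, 3, 4, 5};  // rows: [0,1], [2,3], [4,5]
  int index[2] = {2, 0};                     // gather rows 2 and 0
  float out[2 * cols];
  for (int i = 0; i < 2; ++i)
    std::memcpy(out + i * cols, src + index[i] * cols, cols * sizeof(float));
  for (float v : out) std::cout << v << ' ';  // prints: 4 5 0 1
  std::cout << '\n';
  return 0;
}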
@@ -31,6 +31,8 @@ class GatherOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of GatherOp should not be null.");
 
+    auto index_dims = ctx->GetInputDim("Index");
+    PADDLE_ENFORCE(index_dims.size() == 1);
     int batch_size = ctx->GetInputDim("Index")[0];
     PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
     framework::DDim output_dims(ctx->GetInputDim("X"));
@@ -79,8 +81,5 @@ Out = X[Index]
 namespace ops = paddle::operators;
 REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
             ops::GatherGradOp);
-REGISTER_OP_CPU_KERNEL(gather,
-                       ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    gather_grad,
-    ops::GatherGradientOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gather.cu.h"
#include "paddle/framework/eigen.h"
#include "paddle/operators/gather_op.h"
#include "scatter.cu.h"
namespace paddle {
namespace operators {
template <typename T>
class GatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
output->mutable_data<T>(ctx.GetPlace());
GPUGather<T>(ctx.device_context(), *x, *index, output);
}
};
template <typename T>
class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *Index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *x = ctx.Input<Tensor>("X");
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto place = ctx.GetEigenDevice<platform::GPUPlace>();
dxt.device(place) = dxt.constant(static_cast<T>(0));
GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gather, ops::GatherOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>);
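GatherGradOpCUDAKernel above builds dX by zero-filling it with Eigen and then scatter-assigning the rows of dO to the positions named by Index. A hedged plain-array sketch of that gradient follows (names are hypothetical; note that with assignment, duplicate indices overwrite rather than accumulate).

#include <cstring>

// dX has dx_rows rows of slice_size elements; dO has index_size such rows.
void GatherGradReference(const float* d_out, const int* index, int index_size,
                         int slice_size, float* d_x, int dx_rows) {
  std::memset(d_x, 0, sizeof(float) * dx_rows * slice_size);  // zero dX
  for (int i = 0; i < index_size; ++i)                        // scatter-assign dO rows
    std::memcpy(d_x + index[i] * slice_size, d_out + i * slice_size,
                sizeof(float) * slice_size);
}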
@@ -23,29 +23,40 @@ namespace operators {
 using Tensor = framework::Tensor;
 
-template <typename Place, typename T>
+template <typename T>
 class GatherOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *X = ctx.Input<Tensor>("X");
-    auto *Index = ctx.Input<Tensor>("Index");
-    auto *Y = ctx.Output<Tensor>("Out");
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
+
+    auto *x = ctx.Input<Tensor>("X");
+    auto *index = ctx.Input<Tensor>("Index");
+    auto *output = ctx.Output<Tensor>("Out");
+
+    output->mutable_data<T>(ctx.GetPlace());
 
-    Y->mutable_data<T>(ctx.GetPlace());
-    Gather<T>(ctx.GetPlace(), X, Index, Y);
+    CPUGather<T>(ctx.device_context(), *x, *index, output);
   }
 };
 
-template <typename Place, typename T>
+template <typename T>
 class GatherGradientOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
+
     auto *Index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
     dX->mutable_data<T>(ctx.GetPlace());
-    ScatterUpdate<T>(ctx.GetPlace(), dO, Index, dX);
+
+    auto dxt = framework::EigenVector<T>::Flatten(*dX);
+    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
+    dxt.device(place) = dxt.constant(static_cast<T>(0));
+
+    ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
   }
 };
...
@@ -41,7 +41,9 @@ TEST(Gather, GatherData) {
   int* p_output = output->mutable_data<int>(make_ddim({2, 4}), CPUPlace());
 
-  Gather<int>(CPUPlace(), src, index, output);
+  auto* cpu_place = new paddle::platform::CPUPlace();
+  paddle::platform::CPUDeviceContext ctx(*cpu_place);
+  CPUGather<int>(ctx, *src, *index, output);
 
   for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
   for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
...
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void ScatterCUDAKernel(const T* params, const int* indices,
T* output, size_t index_size,
size_t slice_size) {
CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
int scatter_i = indices[indices_i];
int out_i = scatter_i * slice_size + slice_i;
*(output + out_i) = *(params + i);
}
}
/**
* A thin wrapper on gpu tensor
* Return a new updated tensor from source tensor, scatter-assigned according to
* index
* input[src]: type-T source Tensor
* input[index]: type-int index Tensor (1-D)
* return: output tensor
*/
template <typename T>
void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1);
int index_size = index.dims()[0];
auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
const T* p_src = src.data<T>();
const int* p_index = index.data<int>();
T* p_output = output->data<T>();
int block = 512;
int n = slice_size * index_size;
int grid = (n + block - 1) / block;
ScatterCUDAKernel<T><<<
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_src, p_index, p_output, index_size, slice_size);
}
} // namespace operators
} // namespace paddle
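ScatterCUDAKernel in scatter.cu.h is the mirror image of the gather kernel: thread i reads params[i] and writes it to row indices[i / slice_size] of the output. A host-side C++ reference of that mapping, illustrative only and not part of the commit:

#include <cstddef>

// Reference of ScatterCUDAKernel's indexing: source row i lands in
// output row indices[i]; values are assigned, not accumulated.
template <typename T>
void ScatterAssignReference(const T* params, const int* indices, T* output,
                            std::size_t index_size, std::size_t slice_size) {
  for (std::size_t i = 0; i < index_size * slice_size; ++i) {
    std::size_t indices_i = i / slice_size;  // which source row
    std::size_t slice_i = i % slice_size;    // offset inside the slice
    std::size_t out_i = indices[indices_i] * slice_size + slice_i;
    output[out_i] = params[i];
  }
}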
@@ -24,63 +24,42 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-
-// Implementation of CPU copy
-template <typename T>
-void CPUScatterUpdate(const paddle::framework::Tensor* src, const int* index,
-                      const size_t index_size,
-                      paddle::framework::Tensor* output) {
-  paddle::framework::DDim output_dims = output->dims();
-
-  for (size_t i = 0; i < index_size; ++i) {
-    int index_ = index[i];
-
-    paddle::framework::Tensor src_ = *src;
-    paddle::framework::Tensor output_ = *output;
-    if (index_size > 1) src_ = src->Slice<T>(i, i + 1);
-    if (output_dims[0] > 1) output_ = output->Slice<T>(index_, index_ + 1);
-
-    auto X = EigenVector<T>::Flatten(src_);
-    auto Y = EigenVector<T>::Flatten(output_);
-
-    Y = X + Y;
-  }
-}
-
-// Implementation of GPU scatter:
-template <typename T>
-void GPUScatterUpdate(const T* src, const int* index, const int slice_size,
-                      const int index_size, T* output);
-
 /**
  * Return a updated tensor from source tensor, scattered according to index:
- * dst[i] += src[index[i]]
+ * dst[i] = src[index[i]]
  * input[src]: type-T source Tensor
  * input[index]: type-int index Tensor (1-D)
  * return: output tensor
  */
 template <typename T>
-void ScatterUpdate(const platform::Place& place,
-                   const paddle::framework::Tensor* src,
-                   const paddle::framework::Tensor* index,
-                   paddle::framework::Tensor* output) {
+void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
+                   const Tensor& index, Tensor* output) {
+  PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];
 
-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   auto dst_dims = output->dims();
 
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
+  T* p_output = output->data<T>();
+
   // check src shape and dst shape should match
   for (int i = 1; i < src_dims.size(); i++)
     PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
 
-  if (platform::is_cpu_place(place)) {
-    CPUScatterUpdate<T>(src, index->data<int>(), index_size, output);
-  } else {
+  // slice size
+  size_t slice_size = 1;
+  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
+
+  const size_t slice_bytes = slice_size * sizeof(T);
+
+  for (int i = 0; i < index_size; ++i) {
+    int index_ = p_index[i];
+    memcpy(p_output + index_ * slice_size, p_src + i * slice_size, slice_bytes);
   }
 }
...
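The semantic change in this file is worth spelling out: the old CPUScatterUpdate accumulated into the destination row via Eigen, while the new ScatterAssign overwrites it via memcpy. A tiny sketch of the difference on concrete numbers, an assumed example rather than code from the commit:

#include <iostream>

int main() {
  float dst[3] = {10, 10, 10};
  float src[2] = {1, 2};
  int index[2] = {1, 2};
  // Old ScatterUpdate semantics (accumulate): dst[index[i]] += src[i] -> {10, 11, 12}
  // New ScatterAssign semantics (overwrite):  dst[index[i]] = src[i]
  for (int i = 0; i < 2; ++i) dst[index[i]] = src[i];
  for (float v : dst) std::cout << v << ' ';  // prints: 10 1 2
  std::cout << '\n';
  return 0;
}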
@@ -97,8 +97,5 @@ Out[Index] = Ref[Index] + Updates
 namespace ops = paddle::operators;
 REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad,
             ops::ScatterGradOp);
-REGISTER_OP_CPU_KERNEL(scatter,
-                       ops::ScatterOpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    scatter_grad,
-    ops::ScatterGradientOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(scatter_grad, ops::ScatterGradientOpKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gather.cu.h"
#include "paddle/operators/gather_op.h"
#include "scatter.cu.h"
namespace paddle {
namespace operators {
template <typename T>
class ScatterOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *Ref = ctx.Input<Tensor>("Ref");
auto *Index = ctx.Input<Tensor>("Index");
auto *Updates = ctx.Input<Tensor>("Updates");
auto *Out = ctx.Output<Tensor>("Out");
Out->ShareDataWith<T>(*Ref);
GPUScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
}
};
template <typename T>
class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Index = ctx.Input<Tensor>("Index");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
// In place gradient: dRef = dO
dRef->ShareDataWith<T>(*dOut);
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Index]
GPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(scatter, ops::ScatterOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(scatter_grad, ops::ScatterGradOpCUDAKernel<float>);
@@ -23,10 +23,12 @@ namespace operators {
 using Tensor = framework::Tensor;
 
-template <typename Place, typename T>
+template <typename T>
 class ScatterOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
     auto *Ref = ctx.Input<Tensor>("Ref");
     auto *Index = ctx.Input<Tensor>("Index");
     auto *Updates = ctx.Input<Tensor>("Updates");
@@ -35,14 +37,16 @@ class ScatterOpKernel : public framework::OpKernel<T> {
     // In place output: Out = Ref, Out[Index] += Updates
     Out->ShareDataWith<T>(*Ref);
     // Apply ScatterUpdate: Out[index] += Updates[:]
-    ScatterUpdate<T>(ctx.GetPlace(), Updates, Index, Out);
+    ScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
   }
 };
 
-template <typename Place, typename T>
+template <typename T>
 class ScatterGradientOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
     auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Index = ctx.Input<Tensor>("Index");
@@ -52,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
     dRef->ShareDataWith<T>(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates += dO[Index]
-    Gather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
+    CPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
   }
 };
...
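For the scatter gradient above, dRef simply reuses dOut's buffer (ShareDataWith), and dUpdates is the gather of dOut at Index. A hedged plain-array sketch of that second step, with hypothetical names:

#include <cstring>

// dUpdates row i is dOut row index[i]; dRef is just an alias of dOut.
void ScatterGradReference(const float* d_out, const int* index, int index_size,
                          int slice_size, float* d_updates) {
  for (int i = 0; i < index_size; ++i)
    std::memcpy(d_updates + i * slice_size, d_out + index[i] * slice_size,
                sizeof(float) * slice_size);
}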
@@ -40,7 +40,9 @@ TEST(scatter, ScatterUpdate) {
   float* p_output = output->mutable_data<float>(make_ddim({4, 4}), CPUPlace());
 
-  ScatterUpdate<float>(CPUPlace(), src, index, output);
+  auto* cpu_place = new paddle::platform::CPUPlace();
+  paddle::platform::CPUDeviceContext ctx(*cpu_place);
+  ScatterAssign<float>(ctx, *src, *index, output);
 
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0));
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
...
@@ -10,7 +10,7 @@ class TestScatterOp(OpTest):
         index_np = np.array([1, 2]).astype("int32")
         updates_np = np.random.random((2, 3)).astype("float32")
         output_np = np.copy(ref_np)
-        output_np[index_np] += updates_np
+        output_np[index_np] = updates_np
         self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}
@@ -18,7 +18,7 @@ class TestScatterOp(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates', 'Ref'], 'Out', in_place=True)
+        self.check_grad(['Updates'], 'Out', in_place=True)
 
 if __name__ == "__main__":
...