From 39f14f1dd6fd6810472fd100ad59a1d1cdb661f1 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 9 Aug 2017 15:24:32 -0700 Subject: [PATCH] scatter update implemented --- paddle/operators/CMakeLists.txt | 2 + paddle/operators/scatter.h | 92 ++++++++++++++++++++++++++++++++ paddle/operators/scatter_test.cc | 52 ++++++++++++++++++ 3 files changed, 146 insertions(+) create mode 100644 paddle/operators/scatter.h create mode 100644 paddle/operators/scatter_test.cc diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index e018a112a4a..7ba9384fa80 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -43,6 +43,8 @@ endfunction() cc_test(gather_test SRCS gather_test.cc DEPS tensor) +cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) + cc_library(net_op SRCS net_op.cc DEPS op_registry) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) diff --git a/paddle/operators/scatter.h b/paddle/operators/scatter.h new file mode 100644 index 00000000000..714c022c02d --- /dev/null +++ b/paddle/operators/scatter.h @@ -0,0 +1,92 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include + +#include "paddle/framework/ddim.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/place.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +// Implementation of CPU copy +template +void CPUScatterUpdate(const paddle::framework::Tensor* src, const int* index, + const size_t index_size, + paddle::framework::Tensor* output) { + paddle::framework::DDim output_dims = output->dims(); + + for (size_t i = 0; i < index_size; ++i) { + int index_ = index[i]; + + paddle::framework::Tensor src_ = *src; + paddle::framework::Tensor output_ = *output; + if (index_size > 1) src_ = src->Slice(i, i + 1); + if (output_dims[0] > 1) output_ = output->Slice(index_, index_ + 1); + + auto X = EigenVector::Flatten(src_); + auto Y = EigenVector::Flatten(output_); + + Y = X + Y; + } +} + +// Implementation of GPU scatter: +template +void GPUScatterUpdate(const T* src, const int* index, const int slice_size, + const int index_size, T* output); + +/** + * Return a updated tensor from source tensor, scattered according to index: + * dst[i] += src[index[i]] + * input[src]: type-T source Tensor + * input[index]: type-int index Tensor (1-D) + * return: output tensor + */ +template +void ScatterUpdate(const platform::Place& place, + const paddle::framework::Tensor* src, + const paddle::framework::Tensor* index, + paddle::framework::Tensor* output) { + // check index of shape 1-D + PADDLE_ENFORCE(index->dims().size() == 1); + int index_size = index->dims()[0]; + + auto src_dims = src->dims(); + auto dst_dims = output->dims(); + + // check src shape and dst shape should match + for (size_t i = 1; i < src_dims.size(); i++) + PADDLE_ENFORCE(src_dims[i] == dst_dims[i]); + + // slice size + size_t slice_size = 1; + for (size_t i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + + if (platform::is_cpu_place(place)) { + CPUScatterUpdate(src, index->data(), index_size, output); + } else { + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/scatter_test.cc b/paddle/operators/scatter_test.cc new file mode 100644 index 00000000000..4449ce65643 --- /dev/null +++ b/paddle/operators/scatter_test.cc @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/scatter.h" +#include "paddle/framework/ddim.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/place.h" + +#include +#include +#include + +TEST(scatter, ScatterUpdate) { + using namespace paddle::framework; + using namespace paddle::platform; + using namespace paddle::operators; + + Tensor* src = new Tensor(); + Tensor* index = new Tensor(); + Tensor* output = new Tensor(); + + float* p_src = nullptr; + int* p_index = nullptr; + p_src = src->mutable_data(make_ddim({1, 4}), CPUPlace()); + p_index = index->mutable_data(make_ddim({1}), CPUPlace()); + + for (size_t i = 0; i < 4; ++i) p_src[i] = float(i); + p_index[0] = 1; + + float* p_output = output->mutable_data(make_ddim({4, 4}), CPUPlace()); + + ScatterUpdate(CPUPlace(), src, index, output); + + for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0)); + for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data()[i], float(0)); + for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], float(i - 4)); + for (size_t i = 4; i < 8; ++i) + EXPECT_EQ(output->data()[i], float(i - 4)); + for (size_t i = 8; i < 16; ++i) EXPECT_EQ(p_output[i], float(0)); + for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output->data()[i], float(0)); +} -- GitLab