diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index f0114b9e4908d65b3fddb493230777f9e500b4e1..68c5526bbb9ce39c2bb1cdbaac51eaa75d09e1d1 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -42,6 +42,7 @@ USE_OP(fill_zeros_like); USE_OP_ITSELF(recurrent_op); USE_OP(gaussian_random); USE_OP(uniform_random); +USE_OP(lookup_table); namespace paddle { namespace framework { diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a7c89787e43df6173791bc54b3faffc034867f7d..1ca5010eaeb14948c5fab419e49d5114410f7c45 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -42,6 +42,8 @@ function(op_library TARGET) endfunction() add_subdirectory(math) +add_subdirectory(functor) + cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) @@ -66,5 +68,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor op_registry operator net_op) -op_library(uniform_random_op - SRCS uniform_random_op.cc uniform_random_op.cu) +op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu) +op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu DEPS math_functor) diff --git a/paddle/operators/functor/CMakeLists.txt b/paddle/operators/functor/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d3b39e5fc2533002218210f0b7a1665a6d4b280f --- /dev/null +++ b/paddle/operators/functor/CMakeLists.txt @@ -0,0 +1,5 @@ +if(WITH_GPU) + nv_library(math_functor SRCS math_functor.cc math_functor.cu DEPS device_context) +else() + cc_library(math_functor SRCS math_functor.cc DEPS device_context) +endif() diff --git a/paddle/operators/functor/math_functor.cc b/paddle/operators/functor/math_functor.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f2767f171a70d2df308c6e593c9e251b0337577 --- /dev/null +++ b/paddle/operators/functor/math_functor.cc @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/functor/math_functor.h" +#include "paddle/framework/eigen.h" + +namespace paddle { +namespace operators { +namespace functor { + +template +struct Set { + void operator()(const T alpha, framework::Tensor* Y, + platform::DeviceContext* context) { + int N = product(Y->dims()); + T* YData = Y->mutable_data(context->GetPlace()); + if (alpha == static_cast(0)) { + memset(YData, 0, N * sizeof(T)); + } else { + framework::EigenVector::Flatten(*Y) + .setConstant(alpha); + } + } +}; + +template struct Set; +template struct Set; + +} // namespace functor +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/functor/math_functor.cu b/paddle/operators/functor/math_functor.cu new file mode 100644 index 0000000000000000000000000000000000000000..6dc828c60ac710fd2cd0c35f838915b73a6fea83 --- /dev/null +++ b/paddle/operators/functor/math_functor.cu @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/functor/math_functor.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace functor { + +template +__global__ void SetKernel(const int N, const T alpha, T* Y) { + CUDA_1D_KERNEL_LOOP(i, N) { Y[i] = alpha; } +} + +template +struct Set { + void operator()(const T alpha, framework::Tensor* Y, + platform::DeviceContext* context) { + int N = product(Y->dims()); + T* YData = Y->mutable_data(context->GetPlace()); + SetKernel<<<(N + 512 - 1) / 512, 512>>>(N, alpha, YData); + } +}; + +template struct Set; +template struct Set; + +} // namespace functor +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/functor/math_functor.h b/paddle/operators/functor/math_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..d5c7bd368fee44f18bc55d127ea58647c944114c --- /dev/null +++ b/paddle/operators/functor/math_functor.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace functor { + +template +struct Set { + void operator()(const T alpha, paddle::framework::Tensor* Y, + paddle::platform::DeviceContext* context); +}; + +} // namespace functor +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f70458a87f5d014a692cd35455a997bdcc15776 --- /dev/null +++ b/paddle/operators/lookup_table_op.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/lookup_table_op.h" + +namespace paddle { +namespace operators { + +class LookupTableOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &context) const override { + auto table_t = context.Input("W"); + auto ids_t = context.Input("Ids"); + auto output_t = context.Output("Out"); + + output_t->Resize({ids_t->dims()[0], table_t->dims()[1]}); + } +}; + +class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LookupTableOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("W", + "An input represents embedding tensors," + " which is a learnable parameter."); + AddInput("Ids", + "An input with type int32 or int64" + "contains the ids to be looked up in W.") + .NotInGradient(); + AddOutput("Out", "The lookup results, which have the same type with W."); + AddComment( + "This operator is used to perform lookups on the parameter W," + "then concatenated into a dense tensor."); + } +}; + +class LookupTableOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &context) const override { + context.Output(0)->Resize(context.Input(0)->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker, + lookup_table_grad, ops::LookupTableOpGrad); + +REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel); +REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel); diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..94b440e00e872e67cec9dab57034f088a26e5c0a --- /dev/null +++ b/paddle/operators/lookup_table_op.cu @@ -0,0 +1,116 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/functor/math_functor.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +__global__ void LookupTable(T* output, const T* table, const uint32_t* ids, + const int N, const int K, const int D) { + int idx = threadIdx.x; + int idy = blockIdx.x + threadIdx.y * gridDimX; + + while (idy < K) { + int id = ids[idy]; + PADDLE_ASSERT(id >= 0); + PADDLE_ASSERT(id < N); + T* out = output + idy; + const T* tab = table + id; + for (int i = idx; i < D; i += blockDimX) { + out[i] = tab[i]; + } + idy += blockDimY * gridDimX; + } +} + +template +__global__ void LookupTableGradKernel(T* table, const T* output, + const uint32_t* ids, const int N, + const int K, const int D) { + int idx = threadIdx.x; + int idy = blockIdx.x + threadIdx.y * gridDimX; + + while (idy < K) { + int id = ids[idy]; + PADDLE_ASSERT(id >= 0); + PADDLE_ASSERT(id < N); + const T* out = output + idy; + T* tab = table + id; + for (int i = idx; i < D; i += blockDimX) { + paddle::platform::CudaAtomicAdd(tab + i, out[i]); + } + idy += blockDimY * gridDimX; + } +} + +template +class LookupTableCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto table_t = context.Input("W"); + auto ids_t = context.Input("Ids"); + auto output_t = context.Output("Out"); + + size_t N = table_t->dims()[0]; + size_t D = table_t->dims()[1]; + size_t K = product(ids_t->dims()); + auto ids = ids_t->data(); + auto table = table_t->data(); + auto output = output_t->mutable_data(context.GetPlace()); + + dim3 threads(128, 8); + dim3 grids(8, 1); + LookupTable<<>>(output, table, ids, N, K, D); + } +}; + +template +class LookupTableGrad : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto ids_t = context.Input("Ids"); + auto d_output_t = context.Input(framework::GradVarName("Out")); + auto d_table_t = context.Output(framework::GradVarName("W")); + + int N = d_table_t->dims()[0]; + int D = d_table_t->dims()[1]; + int K = product(ids_t->dims()); + const uint32_t* ids = ids_t->data(); + T* d_table = d_table_t->mutable_data(context.GetPlace()); + const T* d_output = d_output_t->data(); + + auto* device_context = + const_cast(context.device_context_); + functor::Set()(static_cast(0), d_table_t, + device_context); + dim3 threads(128, 8); + dim3 grids(8, 1); + LookupTableGradKernel<<>>(d_table, d_output, + ids, N, K, D); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel); +REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGrad); diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h new file mode 100644 index 0000000000000000000000000000000000000000..790ecab3c66ada68c48d3306a7565430b340f431 --- /dev/null +++ b/paddle/operators/lookup_table_op.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/functor/math_functor.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class LookupTableKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto table_t = context.Input("W"); // float tensor + auto ids_t = context.Input("Ids"); // int tensor + auto output_t = context.Output("Out"); // float tensor + + size_t N = table_t->dims()[0]; + size_t D = table_t->dims()[1]; + auto ids = ids_t->data(); + auto table = table_t->data(); + auto output = output_t->mutable_data(context.GetPlace()); + for (size_t i = 0; i < product(ids_t->dims()); ++i) { + PADDLE_ENFORCE_LT(ids[i], N); + PADDLE_ENFORCE_GE(ids[i], 0); + memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); + } + } +}; + +template +class LookupTableGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto ids_t = context.Input("Ids"); + auto d_output_t = context.Input(framework::GradVarName("Out")); + auto d_table_t = context.Output(framework::GradVarName("W")); + + size_t N = d_table_t->dims()[0]; + size_t D = d_table_t->dims()[1]; + auto ids = ids_t->data(); + T* d_table = d_table_t->mutable_data(context.GetPlace()); + const T* d_output = d_output_t->data(); + + auto* device_context = + const_cast(context.device_context_); + functor::Set()(static_cast(0), d_table_t, + device_context); + for (size_t i = 0; i < product(ids_t->dims()); ++i) { + PADDLE_ENFORCE_LT(ids[i], N); + PADDLE_ENFORCE_GE(ids[i], 0); + for (size_t j = 0; j < D; ++j) { + d_table[ids[i] * D + j] += d_output[i * D + j]; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/platform/cuda_helper.h b/paddle/platform/cuda_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..4346291117861b497e648a9cffb8b5767f3f6eec --- /dev/null +++ b/paddle/platform/cuda_helper.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include + +namespace paddle { +namespace platform { + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + +#define CUDA_ATOMIC_WRAPPER(op, T) \ + __device__ __forceinline__ T CudaAtomic##op(T* address, const T val) + +#define USE_CUDA_ATOMIC(op, T) \ + CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); } + +// For atomicAdd. +USE_CUDA_ATOMIC(Add, float); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 +USE_CUDA_ATOMIC(Add, double); +#else +// Custom implementation of atomicAdd for double. +// This implementation is copied from CUDA manual. +CUDA_ATOMIC_WRAPPER(Add, double) { + unsigned long long int* address_as_ull = + reinterpret_cast(address); + unsigned long long int old = *address_as_ull, assumed; + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + + // Note: uses integer comparison to avoid hang in case of NaN + } while (assumed != old); + + return __longlong_as_double(old); +#endif +} + +} // namespace platform +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index ce57a0713092723b6a99b2416e06ff1a436f043b..65c02f2cfb05a1674ecac5e26da1f34bca93e7d5 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -27,3 +27,4 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_sgd_op SRCS test_sgd_op.py) py_test(test_gradient_checker SRCS test_gradient_checker.py) +py_test(test_lookup_table SRCS test_lookup_table.py) diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table.py new file mode 100644 index 0000000000000000000000000000000000000000..071069768bf754eff20a9ee2e67279a3a61a14fc --- /dev/null +++ b/python/paddle/v2/framework/tests/test_lookup_table.py @@ -0,0 +1,31 @@ +import unittest +import numpy as np +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op + + +class TestSigmoidOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = 'lookup_table' + table = np.random.random((17, 31)).astype('float32') + ids = np.random.randint(0, 17, 4) + self.inputs = {'W': table, 'Ids': ids} + self.outputs = {'Out': table[ids]} + + +class TestSigmoidGradOp(GradientChecker): + def test_grad(self): + op = create_op('lookup_table') + table = np.random.random((17, 31)).astype('float32') + ids = np.random.randint(0, 17, 4) + inputs = {'W': table, 'Ids': ids} + # compare gradients between cpu and gpu + self.compare_grad(op, inputs) + # check gradients + self.check_grad(op, inputs, set('W'), 'Out') + + +if __name__ == '__main__': + unittest.main()