提交 0f3b9e41 编写于 作者: D dangqingqing

lookup table op, cuda helper and set functor

1. finish lookup table CPU and GPU kernel
2. Add some cuda helper
3. Add some math functor
上级 a683a56f
......@@ -42,6 +42,7 @@ USE_OP(fill_zeros_like);
USE_OP_ITSELF(recurrent_op);
USE_OP(gaussian_random);
USE_OP(uniform_random);
USE_OP(lookup_table);
namespace paddle {
namespace framework {
......
......@@ -42,6 +42,8 @@ function(op_library TARGET)
endfunction()
add_subdirectory(math)
add_subdirectory(functor)
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
......@@ -66,5 +68,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op)
op_library(uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu)
op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu)
op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu DEPS math_functor)
# Build the math_functor library, which depends on device_context.
# With WITH_GPU set, the CUDA source math_functor.cu is compiled alongside
# math_functor.cc through nv_library.
if(WITH_GPU)
nv_library(math_functor SRCS math_functor.cc math_functor.cu DEPS device_context)
else()
# Without WITH_GPU only the C++ source is built, through cc_library.
cc_library(math_functor SRCS math_functor.cc DEPS device_context)
endif()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/functor/math_functor.h"
#include "paddle/framework/eigen.h"
namespace paddle {
namespace operators {
namespace functor {
// CPU specialization of Set: writes `alpha` into every element of tensor Y.
//
//   alpha   - value assigned to each element.
//   Y       - tensor to fill; its buffer is obtained through mutable_data<T>
//             on the place reported by `context`.
//   context - device context that supplies the place for the buffer.
template <typename T>
struct Set<platform::CPUPlace, T> {
  void operator()(const T alpha, framework::Tensor* Y,
                  platform::DeviceContext* context) {
    // Hold the element count in size_t so the byte count passed to memset is
    // not truncated through a 32-bit int for very large tensors.
    const size_t numel = static_cast<size_t>(product(Y->dims()));
    T* YData = Y->mutable_data<T>(context->GetPlace());
    if (alpha == static_cast<T>(0)) {
      // Zero fill: clear the raw bytes directly.
      memset(YData, 0, numel * sizeof(T));
    } else {
      // Non-zero fill: view the tensor as a flat Eigen vector and broadcast.
      framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten(*Y)
          .setConstant(alpha);
    }
  }
};

// Explicit instantiations for the element types provided by this file.
template struct Set<platform::CPUPlace, float>;
template struct Set<platform::CPUPlace, double>;
} // namespace functor
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/functor/math_functor.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace functor {
// Device kernel that stores `alpha` into Y[0..N).
// Every thread begins at its global 1-D thread index and then advances by the
// total number of launched threads until the index reaches N.
template <typename T>
__global__ void SetKernel(const int N, const T alpha, T* Y) {
  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < N;
       pos += blockDim.x * gridDim.x) {
    Y[pos] = alpha;
  }
}
// GPU specialization of Set: writes `alpha` into every element of tensor Y by
// launching SetKernel.
//
//   alpha   - value assigned to each element.
//   Y       - tensor to fill; its buffer is obtained through mutable_data<T>
//             on the place reported by `context`.
//   context - device context that supplies the place for the buffer.
//
// The kernel is launched without an explicit stream argument.
template <typename T>
struct Set<platform::GPUPlace, T> {
  void operator()(const T alpha, framework::Tensor* Y,
                  platform::DeviceContext* context) {
    const int N = product(Y->dims());
    T* YData = Y->mutable_data<T>(context->GetPlace());
    // An empty tensor would produce a grid of zero blocks, which is an invalid
    // launch configuration; there is nothing to fill in that case.
    if (N <= 0) {
      return;
    }
    const int kThreadsPerBlock = 512;
    // Ceil-divide so the tail elements are covered by the last block.
    const int num_blocks = (N + kThreadsPerBlock - 1) / kThreadsPerBlock;
    SetKernel<T><<<num_blocks, kThreadsPerBlock>>>(N, alpha, YData);
  }
};

// Explicit instantiations for the element types provided by this file.
template struct Set<platform::GPUPlace, float>;
template struct Set<platform::GPUPlace, double>;
} // namespace functor
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace functor {
template <typename Place, typename T>
struct Set {
void operator()(const T alpha, paddle::framework::Tensor* Y,
paddle::platform::DeviceContext* context);
};
} // namespace functor
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lookup_table_op.h"
namespace paddle {
namespace operators {
// Forward operator for lookup_table: selects rows of "W" indexed by "Ids".
class LookupTableOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Sizes "Out" as (number of ids) x (second dimension of W).
  void InferShape(const framework::InferShapeContext &context) const override {
    auto table_t = context.Input<Tensor>("W");
    auto ids_t = context.Input<Tensor>("Ids");
    auto output_t = context.Output<Tensor>("Out");
    // The kernels process product(Ids.dims()) ids and emit one row for each,
    // so the row count must be the total number of ids rather than only the
    // first dimension; otherwise a multi-dimensional Ids tensor would make the
    // kernels write past the end of Out.
    output_t->Resize(
        {static_cast<int>(product(ids_t->dims())), table_t->dims()[1]});
  }
};
// Declares the inputs, output and description of the lookup_table operator.
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LookupTableOpMaker(framework::OpProto *proto,
                     framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("W",
             "An input represents embedding tensors,"
             " which is a learnable parameter.");
    // Adjacent string literals are concatenated verbatim, so each continuation
    // piece starts with a space to keep the words separated.
    AddInput("Ids",
             "An input with type int32 or int64"
             " contains the ids to be looked up in W.")
        .NotInGradient();
    AddOutput("Out", "The lookup results, which have the same type with W.");
    AddComment(
        "This operator is used to perform lookups on the parameter W,"
        " then concatenated into a dense tensor.");
  }
};
// Gradient operator for lookup_table.
class LookupTableOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Gives the output at index 0 the same dimensions as the input at index 0.
  void InferShape(const framework::InferShapeContext &context) const override {
    auto *grad_target = context.Output<Tensor>(0);
    auto *shape_source = context.Input<Tensor>(0);
    grad_target->Resize(shape_source->dims());
  }
};
} // namespace operators
} // namespace paddle
// Short alias used by the registration macros below.
namespace ops = paddle::operators;
// Register the lookup_table operator with its proto maker, together with the
// lookup_table_grad gradient operator.
REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
lookup_table_grad, ops::LookupTableOpGrad);
// CPU kernels are registered for float only.
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/functor/math_functor.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Forward lookup kernel: for each of the K ids, copies row ids[k] of `table`
// (rows of D elements, N rows in total) into row k of `output`.
//
// Launch layout: the template parameters are used as the loop strides, so the
// kernel is expected to be launched with a (blockDimX, blockDimY) thread block
// and a grid of gridDimX blocks. threadIdx.x walks the D elements of one row;
// blockIdx.x and threadIdx.y together select the id being processed.
template <typename T, int blockDimX, int blockDimY, int gridDimX>
__global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
                            const int N, const int K, const int D) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * gridDimX;

  while (idy < K) {
    int id = ids[idy];
    PADDLE_ASSERT(id >= 0);
    PADDLE_ASSERT(id < N);
    // Both buffers hold rows of D elements, so row offsets must be scaled by
    // D. The multiplication is done in size_t to avoid 32-bit overflow.
    T* out = output + static_cast<size_t>(idy) * D;
    const T* tab = table + static_cast<size_t>(id) * D;
    for (int i = idx; i < D; i += blockDimX) {
      out[i] = tab[i];
    }
    idy += blockDimY * gridDimX;
  }
}
// Backward lookup kernel: for each of the K ids, adds row k of `output` (the
// incoming gradient, rows of D elements) into row ids[k] of `table` (N rows).
// The addition goes through CudaAtomicAdd because several ids may refer to the
// same table row and would otherwise race on its elements.
//
// Launch layout: the template parameters are used as the loop strides, so the
// kernel is expected to be launched with a (blockDimX, blockDimY) thread block
// and a grid of gridDimX blocks. The kernel only accumulates; it does not
// clear `table` first.
template <typename T, int blockDimX, int blockDimY, int gridDimX>
__global__ void LookupTableGradKernel(T* table, const T* output,
                                      const uint32_t* ids, const int N,
                                      const int K, const int D) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * gridDimX;

  while (idy < K) {
    int id = ids[idy];
    PADDLE_ASSERT(id >= 0);
    PADDLE_ASSERT(id < N);
    // Both buffers hold rows of D elements, so row offsets must be scaled by
    // D. The multiplication is done in size_t to avoid 32-bit overflow.
    const T* out = output + static_cast<size_t>(idy) * D;
    T* tab = table + static_cast<size_t>(id) * D;
    for (int i = idx; i < D; i += blockDimX) {
      paddle::platform::CudaAtomicAdd(tab + i, out[i]);
    }
    idy += blockDimY * gridDimX;
  }
}
// GPU forward kernel wrapper for lookup_table: reads "W" and "Ids", allocates
// "Out" on the execution place and launches the LookupTable device kernel.
template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto weights = context.Input<Tensor>("W");
    auto indices = context.Input<Tensor>("Ids");
    auto result = context.Output<Tensor>("Out");

    // Table geometry and the total number of ids to look up.
    size_t N = weights->dims()[0];
    size_t D = weights->dims()[1];
    size_t K = product(indices->dims());

    auto ids = indices->data<uint32_t>();
    auto table = weights->data<T>();
    auto output = result->mutable_data<T>(context.GetPlace());

    // The launch shape mirrors the kernel's template arguments:
    // 128 x 8 threads per block and 8 blocks.
    const dim3 block_shape(128, 8);
    const dim3 grid_shape(8, 1);
    LookupTable<T, 128, 8, 8><<<grid_shape, block_shape>>>(output, table, ids,
                                                           N, K, D);
  }
};
// GPU backward kernel wrapper for lookup_table: clears the gradient of "W"
// and accumulates into it the rows of the gradient of "Out" selected by "Ids".
template <typename T>
class LookupTableGrad : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto indices = context.Input<Tensor>("Ids");
    auto grad_out = context.Input<Tensor>(framework::GradVarName("Out"));
    auto grad_table = context.Output<Tensor>(framework::GradVarName("W"));

    // Table geometry and the total number of ids.
    int N = grad_table->dims()[0];
    int D = grad_table->dims()[1];
    int K = product(indices->dims());

    const uint32_t* ids = indices->data<uint32_t>();
    T* d_table = grad_table->mutable_data<T>(context.GetPlace());
    const T* d_output = grad_out->data<T>();

    // Zero the table gradient first; the device kernel only adds to it.
    auto* dev_ctx =
        const_cast<platform::DeviceContext*>(context.device_context_);
    functor::Set<paddle::platform::GPUPlace, T> fill;
    fill(static_cast<T>(0), grad_table, dev_ctx);

    // The launch shape mirrors the kernel's template arguments:
    // 128 x 8 threads per block and 8 blocks.
    const dim3 block_shape(128, 8);
    const dim3 grid_shape(8, 1);
    LookupTableGradKernel<T, 128, 8, 8><<<grid_shape, block_shape>>>(
        d_table, d_output, ids, N, K, D);
  }
};
} // namespace operators
} // namespace paddle
// Short alias used by the registration macros below.
namespace ops = paddle::operators;
// GPU kernels for lookup_table and its gradient are registered for float only.
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGrad<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/functor/math_functor.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// CPU forward kernel for lookup_table: for every id in "Ids", copies the
// matching row of "W" (D elements) into the next row of "Out".
template <typename T>
class LookupTableKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto table_t = context.Input<Tensor>("W");      // float tensor
    auto ids_t = context.Input<Tensor>("Ids");      // int tensor
    auto output_t = context.Output<Tensor>("Out");  // float tensor

    size_t N = table_t->dims()[0];
    size_t D = table_t->dims()[1];
    // The id count does not change while looping: compute it once instead of
    // on every iteration, and as size_t so the loop compares like with like.
    const size_t K = static_cast<size_t>(product(ids_t->dims()));

    auto ids = ids_t->data<uint32_t>();
    auto table = table_t->data<T>();
    auto output = output_t->mutable_data<T>(context.GetPlace());

    for (size_t i = 0; i < K; ++i) {
      // Reject ids that fall outside the table before touching memory.
      PADDLE_ENFORCE_LT(ids[i], N);
      PADDLE_ENFORCE_GE(ids[i], 0);
      memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
    }
  }
};
// CPU backward kernel for lookup_table: clears the gradient of "W" and then,
// for every id in "Ids", adds the matching row of the gradient of "Out"
// (D elements) onto row `id` of the gradient of "W".
template <typename T>
class LookupTableGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto ids_t = context.Input<Tensor>("Ids");
    auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
    auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));

    size_t N = d_table_t->dims()[0];
    size_t D = d_table_t->dims()[1];
    // The id count does not change while looping: compute it once instead of
    // on every iteration, and as size_t so the loop compares like with like.
    const size_t K = static_cast<size_t>(product(ids_t->dims()));

    auto ids = ids_t->data<uint32_t>();
    T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
    const T* d_output = d_output_t->data<T>();

    // Start from an all-zero gradient; the loop below only accumulates.
    auto* device_context =
        const_cast<platform::DeviceContext*>(context.device_context_);
    functor::Set<paddle::platform::CPUPlace, T>()(static_cast<T>(0), d_table_t,
                                                  device_context);

    for (size_t i = 0; i < K; ++i) {
      // Reject ids that fall outside the table before touching memory.
      PADDLE_ENFORCE_LT(ids[i], N);
      PADDLE_ENFORCE_GE(ids[i], 0);
      // Resolve both row pointers once per id rather than once per element.
      T* table_row = d_table + ids[i] * D;
      const T* grad_row = d_output + i * D;
      for (size_t j = 0; j < D; ++j) {
        table_row[j] += grad_row[j];
      }
    }
  }
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
namespace paddle {
namespace platform {
// Expands to the header of a for-loop that declares `int i`, starts it at the
// calling thread's global 1-D index and advances it by the total number of
// launched threads while it is below `n`. Follow it with the loop body.
// `n` is parenthesised so that an expression argument (for example `a + b` or
// `c ? x : y`) is compared as a whole.
#define CUDA_1D_KERNEL_LOOP(i, n)                                \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);   \
       i += blockDim.x * gridDim.x)
// Expands to the signature of a device function named CudaAtomic<op> that
// takes an address and a value of type T and returns T. The function body must
// follow the macro invocation.
#define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic<gr_nop/>##op(T* address, const T val)
// Defines CudaAtomic<op> for type T as a direct forward to the atomic<op>
// intrinsic with the same arguments.
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
// For atomicAdd.
USE_CUDA_ATOMIC(Add, float);

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
// For __CUDA_ARCH__ >= 600 the double overload forwards to atomicAdd directly.
USE_CUDA_ATOMIC(Add, double);
#else
// Custom implementation of atomicAdd for double.
// This implementation is copied from CUDA manual.
// It retries a compare-and-swap on the 64-bit pattern until no other thread
// has modified the value in between, and returns the previous value.
CUDA_ATOMIC_WRAPPER(Add, double) {
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull, assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN
  } while (assumed != old);

  return __longlong_as_double(old);
}  // The closing brace belongs to the fallback and must stay inside #else;
   // after #endif it would be a stray brace in the >= 600 branch.
#endif
} // namespace platform
} // namespace paddle
......@@ -27,3 +27,4 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py)
py_test(test_sgd_op SRCS test_sgd_op.py)
py_test(test_gradient_checker SRCS test_gradient_checker.py)
py_test(test_lookup_table SRCS test_lookup_table.py)
import unittest
import numpy as np
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
class TestSigmoidOp(unittest.TestCase):
    """Forward test data for the ``lookup_table`` operator.

    NOTE(review): despite the class name, this test configures the
    lookup_table operator (see ``self.type``), not sigmoid. The name is kept
    so existing test selection by name continues to work.
    """
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = 'lookup_table'
        # Embedding table with 17 rows of 31 float32 values.
        table = np.random.random((17, 31)).astype('float32')
        # Four row ids in [0, 17). The dtype is pinned to a 4-byte integer
        # because the lookup_table kernels read ids as 32-bit values, while
        # NumPy's default integer width is platform dependent.
        ids = np.random.randint(0, 17, 4).astype('int32')
        self.inputs = {'W': table, 'Ids': ids}
        # Expected output: the table rows selected by ids, in order.
        self.outputs = {'Out': table[ids]}
class TestSigmoidGradOp(GradientChecker):
    """Gradient test for the ``lookup_table`` operator.

    NOTE(review): despite the class name, this test exercises the
    lookup_table operator, not sigmoid. The name is kept so existing test
    selection by name continues to work.
    """

    def test_grad(self):
        op = create_op('lookup_table')
        # Embedding table with 17 rows of 31 float32 values.
        table = np.random.random((17, 31)).astype('float32')
        # Four row ids in [0, 17). The dtype is pinned to a 4-byte integer
        # because the lookup_table kernels read ids as 32-bit values, while
        # NumPy's default integer width is platform dependent.
        ids = np.random.randint(0, 17, 4).astype('int32')
        inputs = {'W': table, 'Ids': ids}
        # compare gradients between cpu and gpu
        self.compare_grad(op, inputs)
        # check gradients
        # set(['W']) states the one-name set explicitly; set('W') only works
        # because the name is a single character.
        self.check_grad(op, inputs, set(['W']), 'Out')
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册