提交 3663bd88 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #3620 from qingqing01/lookup_table

Add a lookup table op and a CUDA helper.
......@@ -42,6 +42,7 @@ function(op_library TARGET)
endfunction()
add_subdirectory(math)
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
op_library(gather_op SRCS gather_op.cc gather_op.cu)
......@@ -67,7 +68,7 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op)
op_library(uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu)
op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu)
op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu)
op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op)
......@@ -26,7 +26,7 @@ class FillZerosLikeKernel : public framework::OpKernel {
auto* output = context.Output<framework::Tensor>("Dst");
output->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*output);
t.device(context.GetEigenDevice<Place>()) = t.constant(T(0));
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lookup_table_op.h"
namespace paddle {
namespace operators {
class LookupTableOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &context) const override {
auto table_t = context.Input<Tensor>("W");
auto ids_t = context.Input<Tensor>("Ids");
auto output_t = context.Output<Tensor>("Out");
output_t->Resize({ids_t->dims()[0], table_t->dims()[1]});
}
};
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LookupTableOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W",
"An input represents embedding tensors,"
" which is a learnable parameter.");
AddInput("Ids",
"An input with type int32 or int64"
"contains the ids to be looked up in W.");
AddOutput("Out", "The lookup results, which have the same type with W.");
AddComment(
"This operator is used to perform lookups on the parameter W,"
"then concatenated into a dense tensor.");
}
};
class LookupTableOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &context) const override {
auto table = context.Input<Tensor>("W");
auto d_table = context.Output<Tensor>(framework::GradVarName("W"));
d_table->Resize(table->dims());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
lookup_table_grad, ops::LookupTableOpGrad);
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTable(T* output, const T* table, const int32_t* ids,
const int N, const int K, const int D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) {
int id = ids[idy];
PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N);
T* out = output + idy * D;
const T* tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
out[i] = tab[i];
}
idy += BlockDimY * GridDimX;
}
}
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
const int N, const int K, const int D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) {
int id = ids[idy];
PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N);
const T* out = output + idy * D;
T* tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
}
idy += BlockDimY * GridDimX;
}
}
template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W");
auto ids_t = context.Input<Tensor>("Ids");
auto output_t = context.Output<Tensor>("Out");
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = product(ids_t->dims());
auto ids = ids_t->data<int32_t>();
auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTable<T, 128, 8, 8><<<grids, threads>>>(output, table, ids, N, K, D);
}
};
template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = product(ids_t->dims());
const int32_t* ids = ids_t->data<int32_t>();
const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(context.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTableGrad<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output, ids, N,
K, D);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lookup_table_grad,
ops::LookupTableGradCUDAKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class LookupTableKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W"); // float tensor
auto ids_t = context.Input<Tensor>("Ids"); // int tensor
auto output_t = context.Output<Tensor>("Out"); // float tensor
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
auto ids = ids_t->data<int32_t>();
auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < product(ids_t->dims()); ++i) {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
}
}
};
template <typename T>
class LookupTableGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
size_t N = d_table_t->dims()[0];
size_t D = d_table_t->dims()[1];
auto ids = ids_t->data<int32_t>();
const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(context.GetEigenDevice<platform::CPUPlace>()) =
t.constant(static_cast<T>(0));
for (size_t i = 0; i < product(ids_t->dims()); ++i) {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
for (size_t j = 0; j < D; ++j) {
d_table[ids[i] * D + j] += d_output[i * D + j];
}
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
namespace paddle {
namespace platform {
#define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic##op(T* address, const T val)
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
// For atomicAdd.
USE_CUDA_ATOMIC(Add, float);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
USE_CUDA_ATOMIC(Add, double);
#else
CUDA_ATOMIC_WRAPPER(Add, double) {
unsigned long long int* address_as_ull =
reinterpret_cast<unsigned long long int*>(address);
unsigned long long int old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
// Note: uses integer comparison to avoid hang in case of NaN
} while (assumed != old);
return __longlong_as_double(old);
}
#endif
} // namespace platform
} // namespace paddle
......@@ -15,6 +15,7 @@ cc_library(paddle_pybind SHARED
uniform_random_op
gaussian_random_op
fill_zeros_like_op
lookup_table_op
scale_op
minus_op)
endif(WITH_PYTHON)
......@@ -42,6 +42,7 @@ USE_OP(fill_zeros_like);
USE_OP_ITSELF(recurrent_op);
USE_OP(gaussian_random);
USE_OP(uniform_random);
USE_OP(lookup_table);
USE_OP(scale);
USE_OP_ITSELF(identity);
USE_OP(minus);
......
......@@ -28,5 +28,6 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py)
py_test(test_sgd_op SRCS test_sgd_op.py)
py_test(test_gradient_checker SRCS test_gradient_checker.py)
py_test(test_lookup_table SRCS test_lookup_table.py)
py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
py_test(mnist SRCS mnist.py)
......@@ -23,6 +23,10 @@ def grad_var_name(var_name):
return var_name + "@GRAD"
def empty_var_name():
return "@EMPTY@"
def get_numeric_gradient(op,
input_values,
output_name,
......@@ -176,7 +180,7 @@ class GradientChecker(unittest.TestCase):
]
return outs
def compare_grad(self, forward_op, input_value):
def compare_grad(self, forward_op, input_value, no_grad_set=None):
""" Compare the input gradients between CPU and GPU for the given forward
operator.
......@@ -184,15 +188,20 @@ class GradientChecker(unittest.TestCase):
:type forward_op: Operator
:param input_value: input values.
:type input_value: dict{string:numpy.array}
:param no_grad_set: the set of variables names without gradients.
:type no_grad_set: a set of string
:raises: AssertionError, there is different gradient value.
"""
backward_op = core.Operator.backward(forward_op, set())
if no_grad_set is None:
no_grad_set = set()
backward_op = core.Operator.backward(forward_op, no_grad_set)
# return if not compile with GPU or not implementing GPU kernel
if not (core.is_compile_gpu() and backward_op.support_gpu()):
return
outputs = backward_op.outputs()
out_names = [item for k in outputs for item in outputs[k]]
out_names = filter(lambda x: x != empty_var_name(), out_names)
cpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
out_names, core.CPUPlace())
gpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
......
import unittest
import numpy as np
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
class TestSigmoidOp(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = 'lookup_table'
table = np.random.random((17, 31)).astype('float32')
ids = np.random.randint(0, 17, 4).astype('int32')
self.inputs = {'W': table, 'Ids': ids}
self.outputs = {'Out': table[ids]}
class TestSigmoidGradOp(GradientChecker):
def test_grad(self):
op = create_op('lookup_table')
table = np.random.random((17, 31)).astype('float32')
ids = np.random.randint(0, 17, 4).astype('int32')
inputs = {'W': table, 'Ids': ids}
# comapre gradients
self.compare_grad(op, inputs, set(['Ids']))
# check gradients
self.check_grad(op, inputs, set('W'), 'Out')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册