diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 1ca5010eaeb14948c5fab419e49d5114410f7c45..8d2d8a1141188ca86c7f904947422eb1bb23f72f 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -42,7 +42,6 @@ function(op_library TARGET)
 endfunction()
 
 add_subdirectory(math)
-add_subdirectory(functor)
 
 cc_test(gather_test SRCS gather_test.cc DEPS tensor)
 
@@ -69,4 +68,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
   DEPS framework_proto tensor op_registry operator net_op)
 op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu)
-op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu DEPS math_functor)
+op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu)
diff --git a/paddle/operators/fill_zeros_like_op.h b/paddle/operators/fill_zeros_like_op.h
index fd380ca8514b0ac50f39613368a4836bd485668b..969998ce2eae02b8ad057c6259703e51559bf98a 100644
--- a/paddle/operators/fill_zeros_like_op.h
+++ b/paddle/operators/fill_zeros_like_op.h
@@ -26,7 +26,7 @@ class FillZerosLikeKernel : public framework::OpKernel {
     auto* output = context.Output<framework::Tensor>("Dst");
     output->mutable_data<T>(context.GetPlace());
     auto t = framework::EigenVector<T>::Flatten(*output);
-    t.device(context.GetEigenDevice<Place>()) = t.constant(T(0));
+    t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
   }
 };
 
diff --git a/paddle/operators/functor/CMakeLists.txt b/paddle/operators/functor/CMakeLists.txt
deleted file mode 100644
index d3b39e5fc2533002218210f0b7a1665a6d4b280f..0000000000000000000000000000000000000000
--- a/paddle/operators/functor/CMakeLists.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-if(WITH_GPU)
-  nv_library(math_functor SRCS math_functor.cc math_functor.cu DEPS device_context)
-else()
-  cc_library(math_functor SRCS math_functor.cc DEPS device_context)
-endif()
diff --git a/paddle/operators/functor/math_functor.cc b/paddle/operators/functor/math_functor.cc
deleted file mode 100644
index 1f2767f171a70d2df308c6e593c9e251b0337577..0000000000000000000000000000000000000000
--- a/paddle/operators/functor/math_functor.cc
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/operators/functor/math_functor.h"
-#include "paddle/framework/eigen.h"
-
-namespace paddle {
-namespace operators {
-namespace functor {
-
-template <typename T>
-struct Set<platform::CPUPlace, T> {
-  void operator()(const T alpha, framework::Tensor* Y,
-                  platform::DeviceContext* context) {
-    int N = product(Y->dims());
-    T* YData = Y->mutable_data<T>(context->GetPlace());
-    if (alpha == static_cast<T>(0)) {
-      memset(YData, 0, N * sizeof(T));
-    } else {
-      framework::EigenVector<T>::Flatten(*Y)
-          .setConstant(alpha);
-    }
-  }
-};
-
-template struct Set<platform::CPUPlace, float>;
-template struct Set<platform::CPUPlace, double>;
-
-}  // namespace functor
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/operators/functor/math_functor.cu b/paddle/operators/functor/math_functor.cu
deleted file mode 100644
index 6dc828c60ac710fd2cd0c35f838915b73a6fea83..0000000000000000000000000000000000000000
--- a/paddle/operators/functor/math_functor.cu
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/operators/functor/math_functor.h"
-#include "paddle/platform/cuda_helper.h"
-
-namespace paddle {
-namespace operators {
-namespace functor {
-
-template <typename T>
-__global__ void SetKernel(const int N, const T alpha, T* Y) {
-  CUDA_1D_KERNEL_LOOP(i, N) { Y[i] = alpha; }
-}
-
-template <typename T>
-struct Set<platform::GPUPlace, T> {
-  void operator()(const T alpha, framework::Tensor* Y,
-                  platform::DeviceContext* context) {
-    int N = product(Y->dims());
-    T* YData = Y->mutable_data<T>(context->GetPlace());
-    SetKernel<<<(N + 512 - 1) / 512, 512>>>(N, alpha, YData);
-  }
-};
-
-template struct Set<platform::GPUPlace, float>;
-template struct Set<platform::GPUPlace, double>;
-
-}  // namespace functor
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/operators/functor/math_functor.h b/paddle/operators/functor/math_functor.h
deleted file mode 100644
index d5c7bd368fee44f18bc55d127ea58647c944114c..0000000000000000000000000000000000000000
--- a/paddle/operators/functor/math_functor.h
+++ /dev/null
@@ -1,32 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/framework/tensor.h"
-#include "paddle/platform/device_context.h"
-
-namespace paddle {
-namespace operators {
-namespace functor {
-
-template <typename Place, typename T>
-struct Set {
-  void operator()(const T alpha, paddle::framework::Tensor* Y,
-                  paddle::platform::DeviceContext* context);
-};
-
-}  // namespace functor
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu
index 99678ef681627d93c35aae724d97812fc24a15c1..27eee3436af8107cef2aa3577ea238be49edf1af 100644
--- a/paddle/operators/lookup_table_op.cu
+++ b/paddle/operators/lookup_table_op.cu
@@ -12,8 +12,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/functor/math_functor.h"
 #include "paddle/platform/assert.h"
 #include "paddle/platform/cuda_helper.h"
 
 namespace paddle {
@@ -22,11 +22,11 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
-template <typename T, int blockDimX, int blockDimY, int gridDimX>
+template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
 __global__ void LookupTable(T* output, const T* table, const int32_t* ids,
                             const int N, const int K, const int D) {
   int idx = threadIdx.x;
-  int idy = blockIdx.x + threadIdx.y * gridDimX;
+  int idy = blockIdx.x + threadIdx.y * GridDimX;
 
   while (idy < K) {
     int id = ids[idy];
@@ -34,18 +34,18 @@ __global__ void LookupTable(T* output, const T* table, const int32_t* ids,
     PADDLE_ASSERT(id < N);
     T* out = output + idy * D;
     const T* tab = table + id * D;
-    for (int i = idx; i < D; i += blockDimX) {
+    for (int i = idx; i < D; i += BlockDimX) {
       out[i] = tab[i];
     }
-    idy += blockDimY * gridDimX;
+    idy += BlockDimY * GridDimX;
   }
 }
 
-template <typename T, int blockDimX, int blockDimY, int gridDimX>
+template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
 __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
                                 const int N, const int K, const int D) {
   int idx = threadIdx.x;
-  int idy = blockIdx.x + threadIdx.y * gridDimX;
+  int idy = blockIdx.x + threadIdx.y * GridDimX;
 
   while (idy < K) {
     int id = ids[idy];
@@ -53,10 +53,10 @@ __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
     PADDLE_ASSERT(id < N);
     const T* out = output + idy * D;
     T* tab = table + id * D;
-    for (int i = idx; i < D; i += blockDimX) {
+    for (int i = idx; i < D; i += BlockDimX) {
       paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
     }
-    idy += blockDimY * gridDimX;
+    idy += BlockDimY * GridDimX;
   }
 }
 
@@ -96,10 +96,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
     const T* d_output = d_output_t->data<T>();
     T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
 
-    auto* device_context =
-        const_cast<platform::DeviceContext*>(context.device_context_);
-    functor::Set<platform::GPUPlace, T>()(static_cast<T>(0), d_table_t,
-                                          device_context);
+    auto t = framework::EigenVector<T>::Flatten(*d_table_t);
+    t.device(context.GetEigenDevice<platform::GPUPlace>()) =
+        t.constant(static_cast<T>(0));
+
     dim3 threads(128, 8);
     dim3 grids(8, 1);
     LookupTableGrad<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output, ids, N,
diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h
index 9254e03a1b7d11b3003fe07b784152ddfa05d8c7..4da8079b91624c3510cae89fd599a7035a4c7477 100644
--- a/paddle/operators/lookup_table_op.h
+++ b/paddle/operators/lookup_table_op.h
@@ -14,8 +14,8 @@
 
 #pragma once
 
+#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/functor/math_functor.h"
 
 namespace paddle {
 namespace operators {
@@ -57,10 +57,10 @@ class LookupTableGradKernel : public framework::OpKernel {
     const T* d_output = d_output_t->data<T>();
     T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
 
-    auto* device_context =
-        const_cast<platform::DeviceContext*>(context.device_context_);
-    functor::Set<platform::CPUPlace, T>()(static_cast<T>(0), d_table_t,
-                                          device_context);
+    auto t = framework::EigenVector<T>::Flatten(*d_table_t);
+    t.device(context.GetEigenDevice<platform::CPUPlace>()) =
+        t.constant(static_cast<T>(0));
+
     for (size_t i = 0; i < product(ids_t->dims()); ++i) {
       PADDLE_ENFORCE_LT(ids[i], N);
       PADDLE_ENFORCE_GE(ids[i], 0);
diff --git a/paddle/platform/cuda_helper.h b/paddle/platform/cuda_helper.h
index 939c3713adb43b04d5e1ce279f20acc03cf82578..6feec0d7f8bd5d32d9e5eedee962fcbeff655f1c 100644
--- a/paddle/platform/cuda_helper.h
+++ b/paddle/platform/cuda_helper.h
@@ -18,10 +18,6 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
-#define CUDA_1D_KERNEL_LOOP(i, n)                            \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
-       i += blockDim.x * gridDim.x)
-
 #define CUDA_ATOMIC_WRAPPER(op, T) \
   __device__ __forceinline__ T CudaAtomic##op(T* address, const T val)
 
diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py
index 8b8e2f444be1169c23784321721c5d8154541fcf..06b82fa2e4f9fe0c81be145c6a9d4a884170802f 100644
--- a/python/paddle/v2/framework/tests/gradient_checker.py
+++ b/python/paddle/v2/framework/tests/gradient_checker.py
@@ -23,6 +23,10 @@ def grad_var_name(var_name):
     return var_name + "@GRAD"
 
 
+def empty_var_name():
+    return "@EMPTY@"
+
+
 def get_numeric_gradient(op,
                          input_values,
                          output_name,
@@ -171,7 +175,7 @@ class GradientChecker(unittest.TestCase):
         ]
         return outs
 
-    def compare_grad(self, forward_op, input_value):
+    def compare_grad(self, forward_op, input_value, no_grad_set=None):
         """
         Compare the input gradients between CPU and GPU for the given forward
         operator.
@@ -179,15 +183,20 @@ class GradientChecker(unittest.TestCase):
         :type forward_op: Operator
         :param input_value: input values.
         :type input_value: dict{string:numpy.array}
+        :param no_grad_set: the set of variable names without gradients.
+        :type no_grad_set: a set of strings
         :raises: AssertionError, there is different gradient value.
         """
-        backward_op = core.Operator.backward(forward_op, set())
+        if no_grad_set is None:
+            no_grad_set = set()
+        backward_op = core.Operator.backward(forward_op, no_grad_set)
         # return if not compile with GPU or not implementing GPU kernel
         if not (core.is_compile_gpu() and backward_op.support_gpu()):
             return
 
         outputs = backward_op.outputs()
         out_names = [item for k in outputs for item in outputs[k]]
+        out_names = filter(lambda x: x != empty_var_name(), out_names)
         cpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
                                         out_names, core.CPUPlace())
         gpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table.py
index 3056bf53e3d23cf004368bbbe9c1616d3a8efa58..19eb464baa555fb67a994f3cfb4d3ed628367c73 100644
--- a/python/paddle/v2/framework/tests/test_lookup_table.py
+++ b/python/paddle/v2/framework/tests/test_lookup_table.py
@@ -21,6 +21,8 @@ class TestSigmoidGradOp(GradientChecker):
         table = np.random.random((17, 31)).astype('float32')
         ids = np.random.randint(0, 17, 4).astype('int32')
         inputs = {'W': table, 'Ids': ids}
+        # compare gradients
+        self.compare_grad(op, inputs, set(['Ids']))
         # check gradients
         self.check_grad(op, inputs, set('W'), 'Out')
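
Editorial note (not part of the patch): the LookupTable kernels above use a fixed launch shape (128x8 threads, 8x1 blocks) and a two-level stride loop: the BlockDimX threads of one thread row copy a single embedding row cooperatively, while the (BlockDimY, GridDimX) pair strides over the K ids so the fixed launch covers any K. Below is a minimal standalone CUDA sketch of that gather pattern; GatherRows and all host-side names are illustrative, not Paddle symbols.

// Standalone sketch of the 2-D grid-stride gather used by LookupTable.
// Build with, e.g.: nvcc gather_sketch.cu -o gather_sketch
#include <cstdio>
#include <cuda_runtime.h>

template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void GatherRows(T* out, const T* table, const int* ids,
                           const int K, const int D) {
  int idx = threadIdx.x;                          // column within one row
  int idy = blockIdx.x + threadIdx.y * GridDimX;  // first id handled here

  while (idy < K) {
    const T* src = table + ids[idy] * D;  // source row in the table
    T* dst = out + idy * D;               // destination row in the output
    for (int i = idx; i < D; i += BlockDimX) {
      dst[i] = src[i];
    }
    idy += BlockDimY * GridDimX;  // stride to the next batch of ids
  }
}

int main() {
  const int N = 8, D = 4, K = 3;  // table rows, row width, number of ids
  float h_table[N * D];
  for (int i = 0; i < N * D; ++i) h_table[i] = static_cast<float>(i);
  int h_ids[K] = {2, 0, 7};

  float *d_table, *d_out;
  int* d_ids;
  cudaMalloc(&d_table, sizeof(h_table));
  cudaMalloc(&d_out, K * D * sizeof(float));
  cudaMalloc(&d_ids, sizeof(h_ids));
  cudaMemcpy(d_table, h_table, sizeof(h_table), cudaMemcpyHostToDevice);
  cudaMemcpy(d_ids, h_ids, sizeof(h_ids), cudaMemcpyHostToDevice);

  // Same fixed launch shape as the operator: 128x8 threads, 8x1 blocks.
  GatherRows<float, 128, 8, 8><<<dim3(8, 1), dim3(128, 8)>>>(d_out, d_table,
                                                             d_ids, K, D);

  float h_out[K * D];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  // Row 0 of the output should equal table row 2: 8 9 10 11.
  printf("out[0] = %g %g %g %g\n", h_out[0], h_out[1], h_out[2], h_out[3]);

  cudaFree(d_table);
  cudaFree(d_out);
  cudaFree(d_ids);
  return 0;
}

The gradient kernel follows the same stride pattern but scatters with CudaAtomicAdd instead of copying, since several ids in a batch may point at the same table row.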