diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 530b319a44eac915f0d49eb55bfe5929908eab26..6212e39dfde33c5943958adbd1a0a052262e119e 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -23,8 +23,6 @@ template __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, const int64_t* label, const int N, const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. - // CUDA_1D_KERNEL_LOOP(i, N) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { int idx = i * D + label[i]; diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 57b995f36dda9b9d0627e1b30b6c0d78245e723d..6daec3797e9b91b67c95a878256275d4ead237de 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -240,7 +240,7 @@ void axpy(const platform::DeviceContext& context, PADDLE_ENFORCE(platform::dynload::cublasSaxpy( reinterpret_cast(context) .cublas_handle(), - n, alpha, x, 1, y, 1)); + n, &alpha, x, 1, y, 1)); } template <> @@ -250,7 +250,7 @@ void axpy(const platform::DeviceContext& context, PADDLE_ENFORCE(platform::dynload::cublasDaxpy( reinterpret_cast(context) .cublas_handle(), - n, alpha, x, 1, y, 1)); + n, &alpha, x, 1, y, 1)); } template struct SetConstant; @@ -270,7 +270,7 @@ DEFINE_GPU_TRANS(6); struct TensorSetConstantGPU { TensorSetConstantGPU(const platform::DeviceContext& context, - framework::Tensor* tensor, float value) + framework::Tensor* tensor, float value) : context_(context), tensor_(tensor), value_(value) {} template diff --git a/paddle/operators/sequence_conv_op.h b/paddle/operators/sequence_conv_op.h index 5e7f4f7daf718669cd9637123bf699e9ac6d4f7b..312c9153946d47c12c6592ba780636985e50fcf7 100644 --- a/paddle/operators/sequence_conv_op.h +++ b/paddle/operators/sequence_conv_op.h @@ -65,10 +65,8 @@ class SequenceConvKernel : public framework::OpKernel { padding_trainable, context_start, context_length, context_stride, up_pad, down_pad); - context.device_context().Finish(); math::matmul(context.device_context(), col, false, filter, false, static_cast(1.0), out, static_cast(0.0)); - context.device_context().Finish(); } }; diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index 5c817ba03caefb24756f786ca3728ccfa9018bdc..77f062e8c8870ec9cc56c9566108abe74665ae30 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -180,7 +180,6 @@ class TestLstmOp(OpTest): ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) -""" class TestLstmOpHasInitial(TestLstmOp): def set_argument(self): self.lod = [[0, 2, 5, 7]] @@ -281,7 +280,7 @@ class TestLstmOpNotUsePeepholes(TestLstmOp): self.has_initial_state = False self.is_reverse = True self.use_peepholes = False -""" + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_seq_conv.py b/python/paddle/v2/framework/tests/test_seq_conv.py index 65292a1a20acaf27f64bd30dfc1429b1c2469ab1..14edc5f953022ca05f5620c28bd7276d961dd4d0 100644 --- a/python/paddle/v2/framework/tests/test_seq_conv.py +++ b/python/paddle/v2/framework/tests/test_seq_conv.py @@ -122,7 +122,7 @@ class TestSeqProject(OpTest): max_relative_error=0.05, no_grad_set=set(['X', 'Filter'])) - def not_test_check_grad_Filter(self): + def test_check_grad_Filter(self): self.check_grad( ['Filter'], 'Out', @@ -165,33 +165,34 @@ class TestSeqProject(OpTest): self.output_represention = 8 # output feature size -#class TestSeqProjectCase1(TestSeqProject): -# def init_test_case(self): -# self.input_row = 11 -# self.context_start = -1 -# self.context_length = 3 -# self.padding_trainable = True -# self.context_stride = 1 -# -# self.input_size = [self.input_row, 23] -# self.lod = [[0, 4, 5, 8, self.input_row]] -# self.output_represention = 8 # output feature size -# -# -#class TestSeqProjectCase2(TestSeqProject): -# def init_test_case(self): -# self.input_row = 25 -# self.context_start = 2 -# self.context_length = 3 -# self.padding_trainable = True -# self.context_stride = 1 -# -# self.input_size = [self.input_row, 23] -# idx = range(self.input_size[0]) -# del idx[0] -# self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() + -# [self.input_size[0]]] -# self.output_represention = 8 # output feature size +class TestSeqProjectCase1(TestSeqProject): + def init_test_case(self): + self.input_row = 11 + self.context_start = -1 + self.context_length = 3 + self.padding_trainable = True + self.context_stride = 1 + + self.input_size = [self.input_row, 23] + self.lod = [[0, 4, 5, 8, self.input_row]] + self.output_represention = 8 # output feature size + + +class TestSeqProjectCase2(TestSeqProject): + def init_test_case(self): + self.input_row = 25 + self.context_start = 2 + self.context_length = 3 + self.padding_trainable = True + self.context_stride = 1 + + self.input_size = [self.input_row, 23] + idx = range(self.input_size[0]) + del idx[0] + self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() + + [self.input_size[0]]] + self.output_represention = 8 # output feature size + if __name__ == '__main__': unittest.main()