diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index d7145798dde815eb79542b26b9323781f8750d46..a719da2560291dbc7e98aadfae41d4692d8afcad 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -9,6 +9,7 @@ function(op_library TARGET) set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE) set(cc_srcs) set(cu_srcs) + set(cu_cc_srcs) set(op_common_deps operator op_registry math_function) set(options "") set(oneValueArgs "") @@ -22,6 +23,9 @@ function(op_library TARGET) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) list(APPEND cc_srcs ${TARGET}.cc) endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc) + list(APPEND cu_cc_srcs ${TARGET}.cu.cc) + endif() if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) list(APPEND cu_srcs ${TARGET}.cu) endif() @@ -29,6 +33,8 @@ function(op_library TARGET) foreach(src ${op_library_SRCS}) if (${src} MATCHES ".*\\.cu$") list(APPEND cu_srcs ${src}) + elseif(${src} MATCHES ".*\\.cu.cc$") + list(APPEND cu_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cc$") list(APPEND cc_srcs ${src}) else() @@ -43,7 +49,7 @@ function(op_library TARGET) endif() if (WITH_GPU) - nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} + nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) else() cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS} @@ -140,7 +146,9 @@ function(op_library TARGET) # pybind USE_CPU_ONLY_OP list(LENGTH cu_srcs cu_srcs_len) - if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0) + list(LENGTH cu_cc_srcs cu_cc_srcs_len) + + if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0) file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") set(pybind_flag 1) endif() @@ -160,11 +168,12 @@ set(DEPS_OPS recurrent_op dynamic_recurrent_op softmax_with_cross_entropy_op + softmax_op + sequence_softmax_op sum_op pool_op pool_with_index_op conv_op - lstm_op conv_transpose_op nccl_op sequence_conv_op @@ -182,6 +191,8 @@ set(DEPS_OPS op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) +op_library(softmax_op DEPS softmax) +op_library(sequence_softmax_op DEPS softmax) op_library(sum_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor) op_library(adagrad_op DEPS selected_rows_functor) @@ -225,6 +236,6 @@ cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc rnn/recurrent_op_utils.cc DEPS dynamic_recurrent_op) if(WITH_GPU) - nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context) + cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) endif() cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) diff --git a/paddle/operators/batch_norm_op.cu b/paddle/operators/batch_norm_op.cu.cc similarity index 100% rename from paddle/operators/batch_norm_op.cu rename to paddle/operators/batch_norm_op.cu.cc diff --git a/paddle/operators/concat_op.cu b/paddle/operators/concat_op.cu.cc similarity index 100% rename from paddle/operators/concat_op.cu rename to paddle/operators/concat_op.cu.cc diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cu b/paddle/operators/conv2d_transpose_cudnn_op.cu.cc similarity index 96% rename from paddle/operators/conv2d_transpose_cudnn_op.cu rename to paddle/operators/conv2d_transpose_cudnn_op.cu.cc index 694526ec01214acf2ec6a3d68d3cf072739ac185..eff058afc6cc5dacf2a054a33f352824865c1924 100644 --- a/paddle/operators/conv2d_transpose_cudnn_op.cu +++ b/paddle/operators/conv2d_transpose_cudnn_op.cu.cc @@ -200,9 +200,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { T alpha = 1.0f, beta = 0.0f; if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - auto t = framework::EigenVector::Flatten(*input_grad); - t.device(ctx.GetEigenDevice()) = - t.constant(static_cast(0)); + math::set_constant(ctx.device_context(), input_grad, 0); PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( handle, &alpha, cudnn_output_desc, output_grad_data, @@ -214,9 +212,8 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel { // ------------------- cudnn conv backward filter --------------------- if (filter_grad) { T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); - auto t = framework::EigenVector::Flatten(*filter_grad); - t.device(ctx.GetEigenDevice()) = - t.constant(static_cast(0)); + math::set_constant(ctx.device_context(), filter_grad, 0); + // Gradient with respect to the filter PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc, diff --git a/paddle/operators/conv_cudnn_op.cu b/paddle/operators/conv_cudnn_op.cu.cc similarity index 100% rename from paddle/operators/conv_cudnn_op.cu rename to paddle/operators/conv_cudnn_op.cu.cc diff --git a/paddle/operators/conv_op.cu b/paddle/operators/conv_op.cu.cc similarity index 100% rename from paddle/operators/conv_op.cu rename to paddle/operators/conv_op.cu.cc diff --git a/paddle/operators/conv_transpose_op.cu b/paddle/operators/conv_transpose_op.cu.cc similarity index 100% rename from paddle/operators/conv_transpose_op.cu rename to paddle/operators/conv_transpose_op.cu.cc diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 530b319a44eac915f0d49eb55bfe5929908eab26..6212e39dfde33c5943958adbd1a0a052262e119e 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -23,8 +23,6 @@ template __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, const int64_t* label, const int N, const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. - // CUDA_1D_KERNEL_LOOP(i, N) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { int idx = i * D + label[i]; diff --git a/paddle/operators/fill_constant_batch_size_like_op.cu b/paddle/operators/fill_constant_batch_size_like_op.cu.cc similarity index 100% rename from paddle/operators/fill_constant_batch_size_like_op.cu rename to paddle/operators/fill_constant_batch_size_like_op.cu.cc index 298c196f1dfef388640e34153264986bd518a11a..87e3697e2832e7c60a4293fe7126ae4c9c053e4d 100644 --- a/paddle/operators/fill_constant_batch_size_like_op.cu +++ b/paddle/operators/fill_constant_batch_size_like_op.cu.cc @@ -12,8 +12,8 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/framework/op_registry.h" #include "paddle/operators/fill_constant_batch_size_like_op.h" +#include "paddle/framework/op_registry.h" namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( diff --git a/paddle/operators/fill_zeros_like_op.cu b/paddle/operators/fill_zeros_like_op.cu.cc similarity index 100% rename from paddle/operators/fill_zeros_like_op.cu rename to paddle/operators/fill_zeros_like_op.cu.cc index a6d4ba64bde534ea76867c456537b130a45b9496..2adb40cf90b42a5ba608302f7985346c949ff6ed 100644 --- a/paddle/operators/fill_zeros_like_op.cu +++ b/paddle/operators/fill_zeros_like_op.cu.cc @@ -12,8 +12,8 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/framework/op_registry.h" #include "paddle/operators/fill_zeros_like_op.h" +#include "paddle/framework/op_registry.h" namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( diff --git a/paddle/operators/gru_op.cu b/paddle/operators/gru_op.cu.cc similarity index 97% rename from paddle/operators/gru_op.cu rename to paddle/operators/gru_op.cu.cc index 35538c74b4bf678f8068999bfadb2589a1671be0..0ceff94ec3ddaadbd5f0ca4f5a4eebe6cb8ee3a9 100644 --- a/paddle/operators/gru_op.cu +++ b/paddle/operators/gru_op.cu.cc @@ -12,7 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU #include "paddle/operators/gru_op.h" namespace ops = paddle::operators; diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h index ba90ec9816c40a6a49065ac6efcee6b93dffce90..55e9cc4a98bd6d36ce5d6bb4116039d0ec18b485 100644 --- a/paddle/operators/gru_op.h +++ b/paddle/operators/gru_op.h @@ -27,10 +27,6 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -template -using EigenMatrix = framework::EigenMatrix; - template class GRUKernel : public framework::OpKernel { public: @@ -57,19 +53,15 @@ class GRUKernel : public framework::OpKernel { bool is_reverse = context.Attr("is_reverse"); math::LoDTensor2BatchFunctor to_batch; - to_batch(context.device_context(), *input, *batch_gate, true, is_reverse); + auto& dev_ctx = context.device_context(); + to_batch(dev_ctx, *input, *batch_gate, true, is_reverse); - int frame_size = hidden_dims[1]; - int batch_size = hidden_dims[0]; - auto g = EigenMatrix::From(*batch_gate); - auto place = context.GetEigenDevice(); if (bias) { - auto b = EigenMatrix::From(*bias); - g.device(place) = g + - b.reshape(Eigen::array({{1, frame_size * 3}})) - .broadcast(Eigen::array({{batch_size, 1}})); + math::RowwiseAdd add_bias; + add_bias(dev_ctx, *batch_gate, *bias, batch_gate); } + int frame_size = hidden_dims[1]; math::hl_gru_value gru_value; gru_value.gateWeight = const_cast(weight_data); gru_value.stateWeight = @@ -89,7 +81,7 @@ class GRUKernel : public framework::OpKernel { gru_value.gateValue = gate_t.data(); gru_value.resetOutputValue = reset_hidden_prev_t.data(); math::GRUUnitFunctor::compute( - context.device_context(), gru_value, frame_size, cur_batch_size, + dev_ctx, gru_value, frame_size, cur_batch_size, math::ActiveType(context.Attr("activation")), math::ActiveType(context.Attr("gate_activation"))); gru_value.prevOutValue = gru_value.outputValue; @@ -97,7 +89,7 @@ class GRUKernel : public framework::OpKernel { math::Batch2LoDTensorFunctor to_seq; batch_hidden->set_lod(batch_gate->lod()); - to_seq(context.device_context(), *batch_hidden, *hidden); + to_seq(dev_ctx, *batch_hidden, *hidden); } void Compute(const framework::ExecutionContext& context) const override { @@ -138,15 +130,14 @@ class GRUGradKernel : public framework::OpKernel { batch_reset_hidden_prev_grad.mutable_data(hidden_dims, context.GetPlace()); math::SetConstant zero; - zero(context.device_context(), &batch_hidden_grad, static_cast(0.0)); - zero(context.device_context(), &batch_gate_grad, static_cast(0.0)); - zero(context.device_context(), &batch_reset_hidden_prev_grad, - static_cast(0.0)); + auto& dev_ctx = context.device_context(); + zero(dev_ctx, &batch_hidden_grad, static_cast(0.0)); + zero(dev_ctx, &batch_gate_grad, static_cast(0.0)); + zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast(0.0)); bool is_reverse = context.Attr("is_reverse"); batch_hidden_grad.set_lod(batch_hidden->lod()); - to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false, - is_reverse); + to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse); math::hl_gru_value gru_value; gru_value.gateWeight = const_cast(weight_data); @@ -157,7 +148,7 @@ class GRUGradKernel : public framework::OpKernel { if (weight_grad) { gru_grad.gateWeightGrad = weight_grad->mutable_data(context.GetPlace()); - zero(context.device_context(), weight_grad, static_cast(0.0)); + zero(dev_ctx, weight_grad, static_cast(0.0)); gru_grad.stateWeightGrad = weight_grad->data() + 2 * frame_size * frame_size; } else { @@ -188,7 +179,7 @@ class GRUGradKernel : public framework::OpKernel { gru_value.prevOutValue = const_cast(h0_data); if (h0_grad) { T* h0_grad_data = h0_grad->mutable_data(context.GetPlace()); - zero(context.device_context(), h0_grad, static_cast(0.0)); + zero(dev_ctx, h0_grad, static_cast(0.0)); gru_grad.prevOutGrad = h0_grad_data; } else { gru_grad.prevOutGrad = nullptr; @@ -202,8 +193,7 @@ class GRUGradKernel : public framework::OpKernel { } math::GRUUnitGradFunctor::compute( - context.device_context(), gru_value, gru_grad, frame_size, - cur_batch_size, + dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, math::ActiveType(context.Attr("activation")), math::ActiveType(context.Attr("gate_activation"))); } @@ -211,14 +201,12 @@ class GRUGradKernel : public framework::OpKernel { input_grad->mutable_data(context.GetPlace()); math::Batch2LoDTensorFunctor to_seq; batch_gate_grad.set_lod(batch_gate->lod()); - to_seq(context.device_context(), batch_gate_grad, *input_grad); + to_seq(dev_ctx, batch_gate_grad, *input_grad); } if (bias_grad) { bias_grad->mutable_data(context.GetPlace()); - auto d_b = EigenMatrix::From(*bias_grad); - auto d_g = EigenMatrix::From(batch_gate_grad); - auto place = context.GetEigenDevice(); - d_b.device(place) = d_g.sum(Eigen::array({{0}})); + math::ColwiseSum col_sum; + col_sum(dev_ctx, batch_gate_grad, bias_grad); } } diff --git a/paddle/operators/lstm_op.cu b/paddle/operators/lstm_op.cu.cc similarity index 97% rename from paddle/operators/lstm_op.cu rename to paddle/operators/lstm_op.cu.cc index 9ad56941553bf19a56c25f41f76fe20dfa3a106f..610cbb03e890203407b1489800bc17f1a196d12c 100644 --- a/paddle/operators/lstm_op.cu +++ b/paddle/operators/lstm_op.cu.cc @@ -12,7 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU #include "paddle/operators/lstm_op.h" namespace ops = paddle::operators; diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index fca84e2d8fa832a3780eab7e0fa2facceb4d613b..721aa42c92f2926aabbc13d0a9027b2b4e573225 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -24,10 +24,6 @@ namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; -template -using EigenMatrix = framework::EigenMatrix; - template inline void ReorderInitState(const platform::DeviceContext& ctx, const framework::Tensor& src, const size_t* index, @@ -65,16 +61,11 @@ class LSTMKernel : public framework::OpKernel { framework::DDim dims({in_dims[0], frame_size}); if (bias) { - Eigen::array extents({{1, 4 * frame_size}}); - Eigen::array offsets({{0, 0}}); - auto b = EigenMatrix::From(*bias); - auto gate = EigenMatrix::From(*batch_gate); - gate.device(ctx.GetEigenDevice()) = - gate + - b.slice(offsets, extents) - .reshape(Eigen::array({{1, frame_size * 4}})) - .broadcast( - Eigen::array({{static_cast(in_dims[0]), 1}})); + Tensor b = *bias; + b.Resize({bias->numel(), 1}); + Tensor gate_bias = b.Slice(0, 4 * frame_size); + math::RowwiseAdd add_bias; + add_bias(device_ctx, *batch_gate, gate_bias, batch_gate); } math::LstmMetaValue lstm_value; @@ -350,16 +341,11 @@ class LSTMGradKernel : public framework::OpKernel { } if (bias && bias_g) { /* backward bias */ - int m = static_cast(batch_gate_g.dims()[0]); - int n = static_cast(batch_gate_g.dims()[1]); - - Tensor ones; - ones.mutable_data({m}, ctx.GetPlace()); - math::SetConstant set; - set(device_ctx, &ones, static_cast(1.0)); - - math::gemv(device_ctx, true, m, n, 1., batch_gate_g.data(), - ones.data(), 0., bias_g->data()); + Tensor b_g = *bias_g; + b_g.Resize({bias_g->numel(), 1}); + Tensor gate_bias_g = b_g.Slice(0, 4 * frame_size); + math::ColwiseSum col_sum; + col_sum(device_ctx, batch_gate_g, &gate_bias_g); } if (h0 && h0_g) { diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index ab7f23f57043844d45c36acc475422613164bee1..b9417f1d7fdc663fff751328d18239af3dbb1216 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,28 +1,28 @@ add_subdirectory(detail) if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) + nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context) nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor) nv_library(selected_rows_functor SRCS selected_rows_functor.cc selected_rows_functor.cu DEPS selected_rows math_function) nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor) - nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) - nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) + nv_library(softmax SRCS softmax.cc softmax.cu DEPS device_context) + nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) - nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context) + nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) else() - cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) + cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) - cc_library(softmax SRCS softmax.cc DEPS operator) - cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) + cc_library(softmax SRCS softmax.cc DEPS device_context) + cc_library(cross_entropy SRCS cross_entropy.cc DEPS device_context) cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function) cc_library(vol2col SRCS vol2col.cc DEPS device_context) - cc_library(context_project SRCS context_project.cc DEPS device_context) + cc_library(context_project SRCS context_project.cc DEPS device_context math_function) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) diff --git a/paddle/operators/math/context_project.h b/paddle/operators/math/context_project.h index 845de82bbcb33d52184f04ae1594738cb4776eca..72f4202bace4461d2597204feaa2a21e355bd1ac 100644 --- a/paddle/operators/math/context_project.h +++ b/paddle/operators/math/context_project.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/lod_tensor.h" #include "paddle/operators/math/im2col.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { @@ -24,9 +24,6 @@ namespace math { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -template -using EigenMatrix = framework::EigenMatrix; /* * \brief Context projection concatenates features in adjacent time-steps in @@ -152,9 +149,7 @@ class ContextProjectFunctor { Tensor out_t_sub = out_t.Slice(k * context_length, k * context_length + padding_size); Tensor w_sub = padding_data.Slice(k, k + padding_size); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - out_t_sub_e.device(*context.GetEigenDevice()) = w_sub_e; + out_t_sub.CopyFrom(w_sub, context.GetPlace(), context); } } if (down_pad > 0) { // add down pad @@ -184,9 +179,7 @@ class ContextProjectFunctor { (down_pad_begin_row + t) * context_length); Tensor w_sub = padding_data.Slice( up_pad + padding_idx, up_pad + padding_idx + padding_size); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - out_t_sub_e.device(*context.GetEigenDevice()) = w_sub_e; + out_t_sub.CopyFrom(w_sub, context.GetPlace(), context); } } out_t.Resize({sequence_height, context_length * sequence_width}); @@ -265,10 +258,8 @@ class ContextProjectGradFunctor { Tensor out_t_sub = out_t.Slice(k * context_length, k * context_length + padding_size); Tensor w_sub = padding_data->Slice(k, k + padding_size); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - w_sub_e.device(*context.GetEigenDevice()) = - w_sub_e + out_t_sub_e; + axpy(context, w_sub.numel(), static_cast(1), + out_t_sub.data(), w_sub.data()); } } if (down_pad > 0) { @@ -299,10 +290,8 @@ class ContextProjectGradFunctor { (down_pad_begin_row + t) * context_length); Tensor w_sub = padding_data->Slice( up_pad + padding_idx, up_pad + padding_idx + padding_size); - auto out_t_sub_e = EigenMatrix::From(out_t_sub); - auto w_sub_e = EigenMatrix::From(w_sub); - w_sub_e.device(*context.GetEigenDevice()) = - w_sub_e + out_t_sub_e; + axpy(context, w_sub.numel(), static_cast(1), + out_t_sub.data(), w_sub.data()); } } out_t.Resize({sequence_height, context_length * sequence_width}); diff --git a/paddle/operators/math/cross_entropy.h b/paddle/operators/math/cross_entropy.h index 0ab6827ffa8f8b90b432a801607a97206e010cf4..70ed9ddd551bb8cb7989727c02fea870186c9f2e 100644 --- a/paddle/operators/math/cross_entropy.h +++ b/paddle/operators/math/cross_entropy.h @@ -14,7 +14,6 @@ #pragma once #include "paddle/framework/eigen.h" -#include "paddle/framework/operator.h" #include "paddle/framework/tensor.h" #include "paddle/platform/hostdevice.h" diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 1b0d4c8bdc683b5203a4bc4b3838560cffe00bc8..5ee091788687133f6eaef7229d9f95e2025a2daf 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/operators/math/math_function.h" #include "paddle/framework/data_type.h" +#include "paddle/operators/math/math_function_impl.h" namespace paddle { namespace operators { @@ -232,7 +233,34 @@ void gemv(const platform::DeviceContext& context, cblas_dgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1); } +template <> +void axpy(const platform::DeviceContext& context, + const int n, const float alpha, + const float* x, float* y) { + cblas_saxpy(n, alpha, x, 1, y, 1); +} + +template <> +void axpy(const platform::DeviceContext& context, + const int n, const double alpha, + const double* x, double* y) { + cblas_daxpy(n, alpha, x, 1, y, 1); +} + template struct SetConstant; +template struct SetConstant; +template struct SetConstant; + +#define DEFINE_CPU_TRANS(RANK) \ + template struct Transpose; \ + template struct Transpose; + +DEFINE_CPU_TRANS(1); +DEFINE_CPU_TRANS(2); +DEFINE_CPU_TRANS(3); +DEFINE_CPU_TRANS(4); +DEFINE_CPU_TRANS(5); +DEFINE_CPU_TRANS(6); struct TensorSetConstantCPU { TensorSetConstantCPU(framework::Tensor* tensor, float value) @@ -280,6 +308,11 @@ void set_constant(const platform::DeviceContext& context, #endif } +template struct RowwiseAdd; +template struct RowwiseAdd; +template struct ColwiseSum; +template struct ColwiseSum; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 817deec94314bdfd2ed7e4b0ba5212c72b813455..38c04b97f9d07b9cca938b09f46ea81328a35322 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#define EIGEN_USE_GPU #include "paddle/framework/data_type.h" #include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/math_function_impl.h" namespace paddle { namespace operators { @@ -231,11 +233,44 @@ void gemv(const platform::DeviceContext& context, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1)); } +template <> +void axpy(const platform::DeviceContext& context, + const int n, const float alpha, + const float* x, float* y) { + PADDLE_ENFORCE(platform::dynload::cublasSaxpy( + reinterpret_cast(context) + .cublas_handle(), + n, &alpha, x, 1, y, 1)); +} + +template <> +void axpy(const platform::DeviceContext& context, + const int n, const double alpha, + const double* x, double* y) { + PADDLE_ENFORCE(platform::dynload::cublasDaxpy( + reinterpret_cast(context) + .cublas_handle(), + n, &alpha, x, 1, y, 1)); +} + template struct SetConstant; +template struct SetConstant; +template struct SetConstant; + +#define DEFINE_GPU_TRANS(RANK) \ + template struct Transpose; \ + template struct Transpose; + +DEFINE_GPU_TRANS(1); +DEFINE_GPU_TRANS(2); +DEFINE_GPU_TRANS(3); +DEFINE_GPU_TRANS(4); +DEFINE_GPU_TRANS(5); +DEFINE_GPU_TRANS(6); struct TensorSetConstantGPU { TensorSetConstantGPU(const platform::DeviceContext& context, - framework::Tensor* tensor, float value) + framework::Tensor* tensor, float value) : context_(context), tensor_(tensor), value_(value) {} template @@ -257,6 +292,11 @@ void set_constant_with_place( TensorSetConstantGPU(context, tensor, value)); } +template struct RowwiseAdd; +template struct RowwiseAdd; +template struct ColwiseSum; +template struct ColwiseSum; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index c2aaa1d7b7e920c3e6fd9ae4424eae725c3b7c0e..ffb99f53808c4316ede96b04e57aec4dae4134de 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -93,14 +93,21 @@ void gemv(const platform::DeviceContext& context, const bool trans_a, const int M, const int N, const T alpha, const T* A, const T* B, const T beta, T* C); +template +void axpy(const platform::DeviceContext& context, const int n, const T alpha, + const T* x, T* y); + +template +struct Transpose { + void operator()(const platform::DeviceContext& context, + const framework::Tensor& in, framework::Tensor* out, + const std::vector& axis); +}; + template struct SetConstant { void operator()(const platform::DeviceContext& context, - framework::Tensor* tensor, T num) { - auto t = framework::EigenVector::Flatten(*tensor); - t.device(*context.GetEigenDevice()) = - t.constant(static_cast(num)); - } + framework::Tensor* tensor, T num); }; template @@ -110,6 +117,19 @@ void set_constant_with_place(const platform::DeviceContext& context, void set_constant(const platform::DeviceContext& context, framework::Tensor* tensor, float value); +template +struct RowwiseAdd { + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, const framework::Tensor& vec, + framework::Tensor* output); +}; + +template +struct ColwiseSum { + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor* vec); +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..4dc17a4e525c52b8f696277274a7ad00a6b00a08 --- /dev/null +++ b/paddle/operators/math/math_function_impl.h @@ -0,0 +1,83 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/data_type.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template +void SetConstant::operator()(const platform::DeviceContext& context, + framework::Tensor* tensor, T num) { + auto t = framework::EigenVector::Flatten(*tensor); + t.device(*context.GetEigenDevice()) = t.constant(static_cast(num)); +} + +template +void Transpose::operator()( + const platform::DeviceContext& context, const framework::Tensor& in, + framework::Tensor* out, const std::vector& axis) { + Eigen::array permute; + for (int i = 0; i < Rank; i++) { + permute[i] = axis[i]; + } + auto in_dim = in.dims(); + auto out_dim = out->dims(); + + auto eigen_in = framework::EigenTensor::From(in); + auto eigen_out = framework::EigenTensor::From(*out); + auto* dev = context.GetEigenDevice(); + eigen_out.device(*dev) = eigen_in.shuffle(permute); +} + +template +void RowwiseAdd::operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& vector, + framework::Tensor* output) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(vector.numel(), size); + PADDLE_ENFORCE_EQ(output->dims(), in_dims); + + auto in = framework::EigenMatrix::From(input); + auto vec = framework::EigenMatrix::From(vector); + auto out = framework::EigenMatrix::From(*output); + Eigen::array shape({{1, static_cast(size)}}); + Eigen::array bcast({{static_cast(in_dims[0]), 1}}); + out.device(*context.GetEigenDevice()) = + in + vec.reshape(shape).broadcast(bcast); +} + +template +void ColwiseSum::operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + framework::Tensor* vector) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(vector->numel(), size); + + auto vec = framework::EigenMatrix::From(*vector); + auto in = framework::EigenMatrix::From(input); + Eigen::array shape({{1, static_cast(size)}}); + vec.reshape(shape).device(*context.GetEigenDevice()) = + in.sum(Eigen::array({{0}})).reshape(shape); +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu index 8d04653832d58aa048f73e53b8349a08da3145a4..c5d968aeb216bbb3e0e17f138b9e891494d99f75 100644 --- a/paddle/operators/math/sequence2batch.cu +++ b/paddle/operators/math/sequence2batch.cu @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#define EIGEN_USE_GPU #include "paddle/operators/math/sequence2batch.h" namespace paddle { diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h index 794c7d43973924d470124baf8c0c3de66e4ba087..73295ddbcb73fe80be08e732790f0ec75e94b415 100644 --- a/paddle/operators/math/sequence2batch.h +++ b/paddle/operators/math/sequence2batch.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" @@ -21,6 +22,10 @@ namespace paddle { namespace operators { namespace math { +template +using EigenMatrix = framework::EigenMatrix; + template class CopyMatrixRowsFunctor { public: diff --git a/paddle/operators/math/softmax.cc b/paddle/operators/math/softmax.cc index 0ba8197ab8b64649c8adcf67771ba01eca7f1d10..3e2f15d6c27f58818128f32fab0bd4c5f36b0050 100644 --- a/paddle/operators/math/softmax.cc +++ b/paddle/operators/math/softmax.cc @@ -13,13 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/math/softmax.h" +#include "paddle/operators/math/softmax_impl.h" namespace paddle { namespace operators { namespace math { template class SoftmaxFunctor; +template class SoftmaxFunctor; template class SoftmaxGradFunctor; +template class SoftmaxGradFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/softmax.cu b/paddle/operators/math/softmax.cu index 99f988d51e4b16c3f3bfd9c76b411bb53619603e..4dbab51d46bdaaa506a6c242d0958c73687f4eb9 100644 --- a/paddle/operators/math/softmax.cu +++ b/paddle/operators/math/softmax.cu @@ -15,13 +15,16 @@ limitations under the License. */ #define EIGEN_USE_GPU #include "paddle/operators/math/softmax.h" +#include "paddle/operators/math/softmax_impl.h" namespace paddle { namespace operators { namespace math { template class SoftmaxFunctor; +template class SoftmaxFunctor; template class SoftmaxGradFunctor; +template class SoftmaxGradFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h index b7f627eee7f8fe68a83595a3390a55d438c97afb..fe1074650234c5beb5889e7efd713164769ad740 100644 --- a/paddle/operators/math/softmax.h +++ b/paddle/operators/math/softmax.h @@ -13,60 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/framework/eigen.h" -#include "paddle/framework/operator.h" #include "paddle/framework/tensor.h" namespace paddle { namespace operators { namespace math { -template -using EigenMatrix = framework::EigenMatrix; - -template -struct ValueClip { - HOSTDEVICE T operator()(const T& x) const { - const T kThreshold = -64.; - return x < kThreshold ? kThreshold : x; - } -}; - template class SoftmaxFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor* X, framework::Tensor* Y) { - auto logits = EigenMatrix::From(*X); - auto softmax = EigenMatrix::From(*Y); - - const int kBatchDim = 0; - const int kClassDim = 1; - - const int batch_size = logits.dimension(kBatchDim); - const int num_classes = logits.dimension(kClassDim); - - Eigen::DSizes along_class(kClassDim); - Eigen::DSizes batch_by_one(batch_size, 1); - Eigen::DSizes one_by_class(1, num_classes); - - auto shifted_logits = (logits - - logits.maximum(along_class) - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class)) - .unaryExpr(ValueClip()); - - softmax.device(*context.GetEigenDevice()) = shifted_logits.exp(); - softmax.device(*context.GetEigenDevice()) = - (softmax * - softmax.sum(along_class) - .inverse() - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class)); - } + const framework::Tensor* X, framework::Tensor* Y); }; template @@ -74,29 +31,7 @@ class SoftmaxGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor* y, const framework::Tensor* y_grad, - framework::Tensor* x_grad) { - auto softmax = EigenMatrix::From(*y); - auto softmax_grad = EigenMatrix::From(*y_grad); - auto logits_grad = EigenMatrix::From(*x_grad); - - const int kBatchDim = 0; - const int kClassDim = 1; - - const int batch_size = softmax.dimension(kBatchDim); - const int num_classes = softmax.dimension(kClassDim); - - Eigen::DSizes along_class(kClassDim); - Eigen::DSizes batch_by_one(batch_size, 1); - Eigen::DSizes one_by_class(1, num_classes); - - auto dot = (softmax * softmax_grad) - .sum(along_class) - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class); - logits_grad.device(*context.GetEigenDevice()) = - (softmax_grad - dot) * softmax; - } + framework::Tensor* x_grad); }; } // namespace math diff --git a/paddle/operators/math/softmax_impl.h b/paddle/operators/math/softmax_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..05793eeb3eeafaf36c301236197555b7b35e5803 --- /dev/null +++ b/paddle/operators/math/softmax_impl.h @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +template +using EigenMatrix = framework::EigenMatrix; + +template +struct ValueClip { + HOSTDEVICE T operator()(const T& x) const { + const T kThreshold = -64.; + return x < kThreshold ? kThreshold : x; + } +}; + +template +void SoftmaxFunctor::operator()( + const platform::DeviceContext& context, const framework::Tensor* X, + framework::Tensor* Y) { + auto logits = EigenMatrix::From(*X); + auto softmax = EigenMatrix::From(*Y); + + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + + auto shifted_logits = (logits - + logits.maximum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)) + .unaryExpr(ValueClip()); + + softmax.device(*context.GetEigenDevice()) = shifted_logits.exp(); + softmax.device(*context.GetEigenDevice()) = + (softmax * + softmax.sum(along_class) + .inverse() + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); +} + +template +void SoftmaxGradFunctor::operator()( + const platform::DeviceContext& context, const framework::Tensor* y, + const framework::Tensor* y_grad, framework::Tensor* x_grad) { + auto softmax = EigenMatrix::From(*y); + auto softmax_grad = EigenMatrix::From(*y_grad); + auto logits_grad = EigenMatrix::From(*x_grad); + + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = softmax.dimension(kBatchDim); + const int num_classes = softmax.dimension(kClassDim); + + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + + auto dot = (softmax * softmax_grad) + .sum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class); + logits_grad.device(*context.GetEigenDevice()) = + (softmax_grad - dot) * softmax; +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/matmul_op.cu b/paddle/operators/matmul_op.cu.cc similarity index 100% rename from paddle/operators/matmul_op.cu rename to paddle/operators/matmul_op.cu.cc diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h index 4f565946d596b5e5fbf90f16c0c13c780c36886c..1e4aa48b7018d8e3d6f02591fbca2877ddbd3c5d 100644 --- a/paddle/operators/matmul_op.h +++ b/paddle/operators/matmul_op.h @@ -15,8 +15,8 @@ #pragma once #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/matmul.h" -#include "paddle/operators/transpose_op.h" namespace paddle { namespace operators { @@ -76,7 +76,10 @@ Tensor CombineBatchAndN(const framework::ExecutionContext& context, if (in_dims.size() == 3) { output.Resize({in_dims[1], in_dims[0], in_dims[2]}); output.mutable_data(context.GetPlace()); - EigenTranspose(context, input, output, {1, 0, 2}); + std::vector axis = {1, 0, 2}; + math::Transpose trans; + trans(context.device_context(), input, &output, axis); + std::vector out_dims = {in_dims[1], in_dims[0] * in_dims[2]}; output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); } else { output.ShareDataWith(input); diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu.cc similarity index 100% rename from paddle/operators/mul_op.cu rename to paddle/operators/mul_op.cu.cc diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu.cc similarity index 100% rename from paddle/operators/nccl_op.cu rename to paddle/operators/nccl_op.cu.cc diff --git a/paddle/operators/nccl_op_test.cu b/paddle/operators/nccl_op_test.cu.cc similarity index 100% rename from paddle/operators/nccl_op_test.cu rename to paddle/operators/nccl_op_test.cu.cc diff --git a/paddle/operators/pool_cudnn_op.cu b/paddle/operators/pool_cudnn_op.cu.cc similarity index 100% rename from paddle/operators/pool_cudnn_op.cu rename to paddle/operators/pool_cudnn_op.cu.cc diff --git a/paddle/operators/pool_op.cu b/paddle/operators/pool_op.cu.cc similarity index 100% rename from paddle/operators/pool_op.cu rename to paddle/operators/pool_op.cu.cc diff --git a/paddle/operators/pool_with_index_op.cu b/paddle/operators/pool_with_index_op.cu.cc similarity index 100% rename from paddle/operators/pool_with_index_op.cu rename to paddle/operators/pool_with_index_op.cu.cc diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h index c0e3b117dc3ea351b9edfed4d1823de0db27d30a..a081607edce335f0265388ab01238d584bcf3ead 100644 --- a/paddle/operators/pool_with_index_op.h +++ b/paddle/operators/pool_with_index_op.h @@ -81,22 +81,21 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel { if (in_x_grad) { in_x_grad->mutable_data(context.GetPlace()); - auto temp = framework::EigenVector::Flatten(*in_x_grad); - temp.device(context.GetEigenDevice()) = - temp.constant(static_cast(0)); + auto& device_ctx = context.device_context(); + math::set_constant(device_ctx, in_x_grad, 0); switch (ksize.size()) { case 2: { paddle::operators::math::MaxPool2dWithIndexGradFunctor pool2d_backward; - pool2d_backward(context.device_context(), *out_grad, *mask, ksize, - strides, paddings, in_x_grad); + pool2d_backward(device_ctx, *out_grad, *mask, ksize, strides, + paddings, in_x_grad); } break; case 3: { paddle::operators::math::MaxPool3dWithIndexGradFunctor pool3d_backward; - pool3d_backward(context.device_context(), *out_grad, *mask, ksize, - strides, paddings, in_x_grad); + pool3d_backward(device_ctx, *out_grad, *mask, ksize, strides, + paddings, in_x_grad); } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } } diff --git a/paddle/operators/reshape_op.cu b/paddle/operators/reshape_op.cu.cc similarity index 100% rename from paddle/operators/reshape_op.cu rename to paddle/operators/reshape_op.cu.cc diff --git a/paddle/operators/sequence_concat_op.cu b/paddle/operators/sequence_concat_op.cu.cc similarity index 100% rename from paddle/operators/sequence_concat_op.cu rename to paddle/operators/sequence_concat_op.cu.cc diff --git a/paddle/operators/sequence_conv_op.cu b/paddle/operators/sequence_conv_op.cu.cc similarity index 97% rename from paddle/operators/sequence_conv_op.cu rename to paddle/operators/sequence_conv_op.cu.cc index 4c0c673a517c4b05c3abd8bf6b5cf5bbb19cfae0..6106b0e46c0ab96e01dfc344055f23dbf4a1a2c3 100644 --- a/paddle/operators/sequence_conv_op.cu +++ b/paddle/operators/sequence_conv_op.cu.cc @@ -12,8 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU - #include "paddle/operators/sequence_conv_op.h" namespace ops = paddle::operators; diff --git a/paddle/operators/sequence_conv_op.h b/paddle/operators/sequence_conv_op.h index adee8d760e1d6bed852236ecee0951656c458901..b8fbe2647c4338a2fa16aa655ebab64dd8d5417d 100644 --- a/paddle/operators/sequence_conv_op.h +++ b/paddle/operators/sequence_conv_op.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/context_project.h" #include "paddle/operators/math/math_function.h" diff --git a/paddle/operators/sequence_softmax_op.cu b/paddle/operators/sequence_softmax_op.cu.cc similarity index 100% rename from paddle/operators/sequence_softmax_op.cu rename to paddle/operators/sequence_softmax_op.cu.cc diff --git a/paddle/operators/softmax_op.cu b/paddle/operators/softmax_op.cu.cc similarity index 100% rename from paddle/operators/softmax_op.cu rename to paddle/operators/softmax_op.cu.cc diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index ed96e8cee5a78e63ea29ed383d06c1258abdc328..3dbb62d2e571eb92025c1b3fc0a6653c7cda007a 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/operators/softmax_with_cross_entropy_op.h" #include -#include namespace paddle { namespace operators { diff --git a/paddle/operators/split_op.cu b/paddle/operators/split_op.cu.cc similarity index 100% rename from paddle/operators/split_op.cu rename to paddle/operators/split_op.cu.cc diff --git a/paddle/operators/transpose_op.cu b/paddle/operators/transpose_op.cu.cc similarity index 100% rename from paddle/operators/transpose_op.cu rename to paddle/operators/transpose_op.cu.cc diff --git a/paddle/operators/transpose_op.h b/paddle/operators/transpose_op.h index aaa3f47ab5545accd4d1108e0ad6f5a3062186d0..e296032f4147f9f8338148f9e4fef100c7cf816f 100644 --- a/paddle/operators/transpose_op.h +++ b/paddle/operators/transpose_op.h @@ -14,27 +14,44 @@ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { -template -void EigenTranspose(const framework::ExecutionContext& context, - const framework::Tensor& in, framework::Tensor& out, - std::vector axis) { - Eigen::array permute; - for (int i = 0; i < Rank; i++) { - permute[i] = axis[i]; +template +inline void TransCompute(const int dim, const platform::DeviceContext& dev_ctx, + const framework::Tensor& in, framework::Tensor* out, + const std::vector& axis) { + switch (dim) { + case 1: + math::Transpose trans1; + trans1(dev_ctx, in, out, axis); + break; + case 2: + math::Transpose trans2; + trans2(dev_ctx, in, out, axis); + break; + case 3: + math::Transpose trans3; + trans3(dev_ctx, in, out, axis); + break; + case 4: + math::Transpose trans4; + trans4(dev_ctx, in, out, axis); + break; + case 5: + math::Transpose trans5; + trans5(dev_ctx, in, out, axis); + break; + case 6: + math::Transpose trans6; + trans6(dev_ctx, in, out, axis); + break; + default: + PADDLE_THROW("Tensors with rank at most 6 are supported"); } - auto in_dim = in.dims(); - auto out_dim = out.dims(); - - auto eigen_in = framework::EigenTensor::From(in); - auto eigen_out = framework::EigenTensor::From(out); - auto& dev = context.GetEigenDevice(); - eigen_out.device(dev) = eigen_in.shuffle(permute); } template @@ -47,28 +64,8 @@ class TransposeKernel : public framework::OpKernel { std::vector axis = context.Attr>("axis"); int ndims = axis.size(); - switch (ndims) { - case 1: - EigenTranspose(context, *x, *out, axis); - break; - case 2: - EigenTranspose(context, *x, *out, axis); - break; - case 3: - EigenTranspose(context, *x, *out, axis); - break; - case 4: - EigenTranspose(context, *x, *out, axis); - break; - case 5: - EigenTranspose(context, *x, *out, axis); - break; - case 6: - EigenTranspose(context, *x, *out, axis); - break; - default: - PADDLE_THROW("Tensors with rank at most 6 are supported"); - } + auto& dev_ctx = context.device_context(); + TransCompute(ndims, dev_ctx, *x, out, axis); } }; @@ -80,47 +77,19 @@ class TransposeGradKernel : public framework::OpKernel { context.Input(framework::GradVarName("Out")); auto* x_grad = context.Output(framework::GradVarName("X")); - if (x_grad) { - x_grad->mutable_data(context.GetPlace()); - - std::vector axis = context.Attr>("axis"); - std::vector reversed_axis(axis); + if (!x_grad) return; - for (size_t i = 0; i < axis.size(); i++) { - reversed_axis[axis[i]] = i; - } - - int ndims = axis.size(); + x_grad->mutable_data(context.GetPlace()); + std::vector axis = context.Attr>("axis"); + std::vector reversed_axis(axis); - switch (ndims) { - case 1: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - case 2: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - case 3: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - case 4: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - case 5: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - case 6: - EigenTranspose(context, *out_grad, *x_grad, - reversed_axis); - break; - default: - PADDLE_THROW("Tensors with rank at most 6 are supported"); - } + for (size_t i = 0; i < axis.size(); i++) { + reversed_axis[axis[i]] = i; } + + int ndims = axis.size(); + auto& dev_ctx = context.device_context(); + TransCompute(ndims, dev_ctx, *out_grad, x_grad, reversed_axis); } }; diff --git a/paddle/platform/dynload/cublas.h b/paddle/platform/dynload/cublas.h index 6b64539b0a9a4d535a53447fbcc0e458f3ac9129..61a22d9db3e07cbe6fbca0e0b09fedcba232ff6c 100644 --- a/paddle/platform/dynload/cublas.h +++ b/paddle/platform/dynload/cublas.h @@ -62,6 +62,8 @@ extern void *cublas_dso_handle; DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasSaxpy_v2); \ + __macro(cublasDaxpy_v2); \ __macro(cublasSgemv_v2); \ __macro(cublasDgemv_v2); \ __macro(cublasSgemm_v2); \