Unverified commit 3edd8331 authored by Tao Luo, committed by GitHub

Merge pull request #5573 from qingqing01/cmake_speed

[Speed Compiling]: Reduce NVCC compiling files.
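
The speed-up comes from two kinds of changes, both visible in the diff below. Operator GPU files that contain no device code are renamed from `*.cu` to `*.cu.cc` so the host C++ compiler, rather than NVCC, builds them, and the Eigen-heavy math helpers (`SetConstant`, `Transpose`, `RowwiseAdd`, `ColwiseSum`, the softmax functors) are split into declaration headers, `*_impl.h` definitions, and explicit instantiations in `math_function.cc`/`math_function.cu`, so operator translation units no longer instantiate the Eigen expressions themselves. A minimal, self-contained sketch of that second pattern (illustrative names only, not the real Paddle files):

```cpp
// functor.h -- what every operator translation unit would include:
// only a declaration, so including it is cheap.
#include <vector>

template <typename T>
struct FillConstant {
  void operator()(std::vector<T>* data, T value);  // declared, not defined
};

// functor_impl.h -- the heavy definition, included only by the files
// that explicitly instantiate the template (stands in for the Eigen code).
template <typename T>
void FillConstant<T>::operator()(std::vector<T>* data, T value) {
  for (auto& x : *data) x = value;
}

// functor.cc / functor.cu -- explicit instantiation, compiled exactly once.
template struct FillConstant<float>;
template struct FillConstant<double>;

// Any other file can use FillConstant<float> through the declaration alone
// and link against the single instantiation above.
int main() {
  std::vector<float> v(4);
  FillConstant<float> fill;
  fill(&v, 1.5f);
  return v[0] == 1.5f ? 0 : 1;
}
```

In the actual change, `math_function_impl.h` and `softmax_impl.h` play the `functor_impl.h` role, and sources with no device code move to `.cu.cc` (for example `nccl_op_test.cu.cc` below).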
......@@ -9,6 +9,7 @@ function(op_library TARGET)
set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE)
set(cc_srcs)
set(cu_srcs)
set(cu_cc_srcs)
set(op_common_deps operator op_registry math_function)
set(options "")
set(oneValueArgs "")
......@@ -22,6 +23,9 @@ function(op_library TARGET)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
list(APPEND cc_srcs ${TARGET}.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
list(APPEND cu_cc_srcs ${TARGET}.cu.cc)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
list(APPEND cu_srcs ${TARGET}.cu)
endif()
......@@ -29,6 +33,8 @@ function(op_library TARGET)
foreach(src ${op_library_SRCS})
if (${src} MATCHES ".*\\.cu$")
list(APPEND cu_srcs ${src})
elseif(${src} MATCHES ".*\\.cu.cc$")
list(APPEND cu_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cc$")
list(APPEND cc_srcs ${src})
else()
......@@ -43,7 +49,7 @@ function(op_library TARGET)
endif()
if (WITH_GPU)
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
else()
cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
......@@ -140,7 +146,9 @@ function(op_library TARGET)
# pybind USE_CPU_ONLY_OP
list(LENGTH cu_srcs cu_srcs_len)
if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0)
list(LENGTH cu_cc_srcs cu_cc_srcs_len)
if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0)
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
set(pybind_flag 1)
endif()
......@@ -160,11 +168,12 @@ set(DEPS_OPS
recurrent_op
dynamic_recurrent_op
softmax_with_cross_entropy_op
softmax_op
sequence_softmax_op
sum_op
pool_op
pool_with_index_op
conv_op
lstm_op
conv_transpose_op
nccl_op
sequence_conv_op
......@@ -182,6 +191,8 @@ set(DEPS_OPS
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(softmax_op DEPS softmax)
op_library(sequence_softmax_op DEPS softmax)
op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor)
op_library(adagrad_op DEPS selected_rows_functor)
......@@ -225,6 +236,6 @@ cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc
rnn/recurrent_op_utils.cc
DEPS dynamic_recurrent_op)
if(WITH_GPU)
nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context)
cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif()
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
......@@ -200,9 +200,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
math::set_constant(ctx.device_context(), input_grad, 0);
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc, output_grad_data,
......@@ -214,9 +212,8 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward filter ---------------------
if (filter_grad) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*filter_grad);
t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
t.constant(static_cast<T>(0));
math::set_constant(ctx.device_context(), filter_grad, 0);
// Gradient with respect to the filter
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
......
......@@ -23,8 +23,6 @@ template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const int64_t* label, const int N,
const int D) {
// TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
int idx = i * D + label[i];
......
......@@ -12,8 +12,8 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_constant_batch_size_like_op.h"
#include "paddle/framework/op_registry.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
......
......@@ -12,8 +12,8 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_zeros_like_op.h"
#include "paddle/framework/op_registry.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
......
......@@ -12,7 +12,6 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/gru_op.h"
namespace ops = paddle::operators;
......
......@@ -27,10 +27,6 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
......@@ -57,19 +53,15 @@ class GRUKernel : public framework::OpKernel<T> {
bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch;
to_batch(context.device_context(), *input, *batch_gate, true, is_reverse);
auto& dev_ctx = context.device_context();
to_batch(dev_ctx, *input, *batch_gate, true, is_reverse);
int frame_size = hidden_dims[1];
int batch_size = hidden_dims[0];
auto g = EigenMatrix<T>::From(*batch_gate);
auto place = context.GetEigenDevice<Place>();
if (bias) {
auto b = EigenMatrix<T>::From(*bias);
g.device(place) = g +
b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
.broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
math::RowwiseAdd<Place, T> add_bias;
add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
}
int frame_size = hidden_dims[1];
math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data);
gru_value.stateWeight =
......@@ -89,7 +81,7 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.gateValue = gate_t.data<T>();
gru_value.resetOutputValue = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<Place, T>::compute(
context.device_context(), gru_value, frame_size, cur_batch_size,
dev_ctx, gru_value, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
gru_value.prevOutValue = gru_value.outputValue;
......@@ -97,7 +89,7 @@ class GRUKernel : public framework::OpKernel<T> {
math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(context.device_context(), *batch_hidden, *hidden);
to_seq(dev_ctx, *batch_hidden, *hidden);
}
void Compute(const framework::ExecutionContext& context) const override {
......@@ -138,15 +130,14 @@ class GRUGradKernel : public framework::OpKernel<T> {
batch_reset_hidden_prev_grad.mutable_data<T>(hidden_dims,
context.GetPlace());
math::SetConstant<Place, T> zero;
zero(context.device_context(), &batch_hidden_grad, static_cast<T>(0.0));
zero(context.device_context(), &batch_gate_grad, static_cast<T>(0.0));
zero(context.device_context(), &batch_reset_hidden_prev_grad,
static_cast<T>(0.0));
auto& dev_ctx = context.device_context();
zero(dev_ctx, &batch_hidden_grad, static_cast<T>(0.0));
zero(dev_ctx, &batch_gate_grad, static_cast<T>(0.0));
zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0));
bool is_reverse = context.Attr<bool>("is_reverse");
batch_hidden_grad.set_lod(batch_hidden->lod());
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
is_reverse);
to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data);
......@@ -157,7 +148,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
if (weight_grad) {
gru_grad.gateWeightGrad =
weight_grad->mutable_data<T>(context.GetPlace());
zero(context.device_context(), weight_grad, static_cast<T>(0.0));
zero(dev_ctx, weight_grad, static_cast<T>(0.0));
gru_grad.stateWeightGrad =
weight_grad->data<T>() + 2 * frame_size * frame_size;
} else {
......@@ -188,7 +179,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
gru_value.prevOutValue = const_cast<T*>(h0_data);
if (h0_grad) {
T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
zero(context.device_context(), h0_grad, static_cast<T>(0.0));
zero(dev_ctx, h0_grad, static_cast<T>(0.0));
gru_grad.prevOutGrad = h0_grad_data;
} else {
gru_grad.prevOutGrad = nullptr;
......@@ -202,8 +193,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
}
math::GRUUnitGradFunctor<Place, T>::compute(
context.device_context(), gru_value, gru_grad, frame_size,
cur_batch_size,
dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
}
......@@ -211,14 +201,12 @@ class GRUGradKernel : public framework::OpKernel<T> {
input_grad->mutable_data<T>(context.GetPlace());
math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_gate_grad.set_lod(batch_gate->lod());
to_seq(context.device_context(), batch_gate_grad, *input_grad);
to_seq(dev_ctx, batch_gate_grad, *input_grad);
}
if (bias_grad) {
bias_grad->mutable_data<T>(context.GetPlace());
auto d_b = EigenMatrix<T>::From(*bias_grad);
auto d_g = EigenMatrix<T>::From(batch_gate_grad);
auto place = context.GetEigenDevice<Place>();
d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
math::ColwiseSum<Place, T> col_sum;
col_sum(dev_ctx, batch_gate_grad, bias_grad);
}
}
......
......@@ -12,7 +12,6 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/lstm_op.h"
namespace ops = paddle::operators;
......
......@@ -24,10 +24,6 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
inline void ReorderInitState(const platform::DeviceContext& ctx,
const framework::Tensor& src, const size_t* index,
......@@ -65,16 +61,11 @@ class LSTMKernel : public framework::OpKernel<T> {
framework::DDim dims({in_dims[0], frame_size});
if (bias) {
Eigen::array<int, 2> extents({{1, 4 * frame_size}});
Eigen::array<int, 2> offsets({{0, 0}});
auto b = EigenMatrix<T>::From(*bias);
auto gate = EigenMatrix<T>::From(*batch_gate);
gate.device(ctx.GetEigenDevice<Place>()) =
gate +
b.slice(offsets, extents)
.reshape(Eigen::array<int, 2>({{1, frame_size * 4}}))
.broadcast(
Eigen::array<int, 2>({{static_cast<int>(in_dims[0]), 1}}));
Tensor b = *bias;
b.Resize({bias->numel(), 1});
Tensor gate_bias = b.Slice(0, 4 * frame_size);
math::RowwiseAdd<Place, T> add_bias;
add_bias(device_ctx, *batch_gate, gate_bias, batch_gate);
}
math::LstmMetaValue<T> lstm_value;
......@@ -350,16 +341,11 @@ class LSTMGradKernel : public framework::OpKernel<T> {
}
if (bias && bias_g) {
/* backward bias */
int m = static_cast<int>(batch_gate_g.dims()[0]);
int n = static_cast<int>(batch_gate_g.dims()[1]);
Tensor ones;
ones.mutable_data<T>({m}, ctx.GetPlace());
math::SetConstant<Place, T> set;
set(device_ctx, &ones, static_cast<T>(1.0));
math::gemv<Place, T>(device_ctx, true, m, n, 1., batch_gate_g.data<T>(),
ones.data<T>(), 0., bias_g->data<T>());
Tensor b_g = *bias_g;
b_g.Resize({bias_g->numel(), 1});
Tensor gate_bias_g = b_g.Slice(0, 4 * frame_size);
math::ColwiseSum<Place, T> col_sum;
col_sum(device_ctx, batch_gate_g, &gate_bias_g);
}
if (h0 && h0_g) {
......
add_subdirectory(detail)
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context)
nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor)
nv_library(selected_rows_functor SRCS selected_rows_functor.cc selected_rows_functor.cu DEPS selected_rows math_function)
nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS device_context)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(softmax SRCS softmax.cc DEPS device_context)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS device_context)
cc_library(pooling SRCS pooling.cc DEPS device_context)
cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function)
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(context_project SRCS context_project.cc DEPS device_context)
cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
......
......@@ -14,9 +14,9 @@ limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
......@@ -24,9 +24,6 @@ namespace math {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
/*
* \brief Context projection concatenates features in adjacent time-steps in
......@@ -152,9 +149,7 @@ class ContextProjectFunctor {
Tensor out_t_sub = out_t.Slice(k * context_length,
k * context_length + padding_size);
Tensor w_sub = padding_data.Slice(k, k + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
out_t_sub_e.device(*context.GetEigenDevice<Place>()) = w_sub_e;
out_t_sub.CopyFrom(w_sub, context.GetPlace(), context);
}
}
if (down_pad > 0) { // add down pad
......@@ -184,9 +179,7 @@ class ContextProjectFunctor {
(down_pad_begin_row + t) * context_length);
Tensor w_sub = padding_data.Slice(
up_pad + padding_idx, up_pad + padding_idx + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
out_t_sub_e.device(*context.GetEigenDevice<Place>()) = w_sub_e;
out_t_sub.CopyFrom(w_sub, context.GetPlace(), context);
}
}
out_t.Resize({sequence_height, context_length * sequence_width});
......@@ -265,10 +258,8 @@ class ContextProjectGradFunctor {
Tensor out_t_sub = out_t.Slice(k * context_length,
k * context_length + padding_size);
Tensor w_sub = padding_data->Slice(k, k + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
w_sub_e.device(*context.GetEigenDevice<Place>()) =
w_sub_e + out_t_sub_e;
axpy<Place, T>(context, w_sub.numel(), static_cast<T>(1),
out_t_sub.data<T>(), w_sub.data<T>());
}
}
if (down_pad > 0) {
......@@ -299,10 +290,8 @@ class ContextProjectGradFunctor {
(down_pad_begin_row + t) * context_length);
Tensor w_sub = padding_data->Slice(
up_pad + padding_idx, up_pad + padding_idx + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
w_sub_e.device(*context.GetEigenDevice<Place>()) =
w_sub_e + out_t_sub_e;
axpy<Place, T>(context, w_sub.numel(), static_cast<T>(1),
out_t_sub.data<T>(), w_sub.data<T>());
}
}
out_t.Resize({sequence_height, context_length * sequence_width});
......
......@@ -14,7 +14,6 @@
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/hostdevice.h"
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/operators/math/math_function.h"
#include "paddle/framework/data_type.h"
#include "paddle/operators/math/math_function_impl.h"
namespace paddle {
namespace operators {
......@@ -232,7 +233,34 @@ void gemv<platform::CPUPlace, double>(const platform::DeviceContext& context,
cblas_dgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1);
}
template <>
void axpy<platform::CPUPlace, float>(const platform::DeviceContext& context,
const int n, const float alpha,
const float* x, float* y) {
cblas_saxpy(n, alpha, x, 1, y, 1);
}
template <>
void axpy<platform::CPUPlace, double>(const platform::DeviceContext& context,
const int n, const double alpha,
const double* x, double* y) {
cblas_daxpy(n, alpha, x, 1, y, 1);
}
template struct SetConstant<platform::CPUPlace, float>;
template struct SetConstant<platform::CPUPlace, double>;
template struct SetConstant<platform::CPUPlace, int>;
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUPlace, float, RANK>; \
template struct Transpose<platform::CPUPlace, double, RANK>;
DEFINE_CPU_TRANS(1);
DEFINE_CPU_TRANS(2);
DEFINE_CPU_TRANS(3);
DEFINE_CPU_TRANS(4);
DEFINE_CPU_TRANS(5);
DEFINE_CPU_TRANS(6);
struct TensorSetConstantCPU {
TensorSetConstantCPU(framework::Tensor* tensor, float value)
......@@ -280,6 +308,11 @@ void set_constant(const platform::DeviceContext& context,
#endif
}
template struct RowwiseAdd<platform::CPUPlace, float>;
template struct RowwiseAdd<platform::CPUPlace, double>;
template struct ColwiseSum<platform::CPUPlace, float>;
template struct ColwiseSum<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/framework/data_type.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/math_function_impl.h"
namespace paddle {
namespace operators {
......@@ -231,11 +233,44 @@ void gemv<platform::GPUPlace, double>(const platform::DeviceContext& context,
cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1));
}
template <>
void axpy<platform::GPUPlace, float>(const platform::DeviceContext& context,
const int n, const float alpha,
const float* x, float* y) {
PADDLE_ENFORCE(platform::dynload::cublasSaxpy(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
n, &alpha, x, 1, y, 1));
}
template <>
void axpy<platform::GPUPlace, double>(const platform::DeviceContext& context,
const int n, const double alpha,
const double* x, double* y) {
PADDLE_ENFORCE(platform::dynload::cublasDaxpy(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
n, &alpha, x, 1, y, 1));
}
template struct SetConstant<platform::GPUPlace, float>;
template struct SetConstant<platform::GPUPlace, double>;
template struct SetConstant<platform::GPUPlace, int>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::GPUPlace, float, RANK>; \
template struct Transpose<platform::GPUPlace, double, RANK>;
DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
DEFINE_GPU_TRANS(3);
DEFINE_GPU_TRANS(4);
DEFINE_GPU_TRANS(5);
DEFINE_GPU_TRANS(6);
struct TensorSetConstantGPU {
TensorSetConstantGPU(const platform::DeviceContext& context,
framework::Tensor* tensor, float value)
framework::Tensor* tensor, float value)
: context_(context), tensor_(tensor), value_(value) {}
template <typename T>
......@@ -257,6 +292,11 @@ void set_constant_with_place<platform::GPUPlace>(
TensorSetConstantGPU(context, tensor, value));
}
template struct RowwiseAdd<platform::GPUPlace, float>;
template struct RowwiseAdd<platform::GPUPlace, double>;
template struct ColwiseSum<platform::GPUPlace, float>;
template struct ColwiseSum<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -93,14 +93,21 @@ void gemv(const platform::DeviceContext& context, const bool trans_a,
const int M, const int N, const T alpha, const T* A, const T* B,
const T beta, T* C);
template <typename Place, typename T>
void axpy(const platform::DeviceContext& context, const int n, const T alpha,
const T* x, T* y);
template <typename Place, typename T, int Rank>
struct Transpose {
void operator()(const platform::DeviceContext& context,
const framework::Tensor& in, framework::Tensor* out,
const std::vector<int>& axis);
};
template <typename Place, typename T>
struct SetConstant {
void operator()(const platform::DeviceContext& context,
framework::Tensor* tensor, T num) {
auto t = framework::EigenVector<T>::Flatten(*tensor);
t.device(*context.GetEigenDevice<Place>()) =
t.constant(static_cast<T>(num));
}
framework::Tensor* tensor, T num);
};
template <typename Place>
......@@ -110,6 +117,19 @@ void set_constant_with_place(const platform::DeviceContext& context,
void set_constant(const platform::DeviceContext& context,
framework::Tensor* tensor, float value);
template <typename Place, typename T>
struct RowwiseAdd {
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, const framework::Tensor& vec,
framework::Tensor* output);
};
template <typename Place, typename T>
struct ColwiseSum {
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor* vec);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/data_type.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
void SetConstant<Place, T>::operator()(const platform::DeviceContext& context,
framework::Tensor* tensor, T num) {
auto t = framework::EigenVector<T>::Flatten(*tensor);
t.device(*context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(num));
}
template <typename Place, typename T, int Rank>
void Transpose<Place, T, Rank>::operator()(
const platform::DeviceContext& context, const framework::Tensor& in,
framework::Tensor* out, const std::vector<int>& axis) {
Eigen::array<int, Rank> permute;
for (int i = 0; i < Rank; i++) {
permute[i] = axis[i];
}
auto in_dim = in.dims();
auto out_dim = out->dims();
auto eigen_in = framework::EigenTensor<T, Rank>::From(in);
auto eigen_out = framework::EigenTensor<T, Rank>::From(*out);
auto* dev = context.GetEigenDevice<Place>();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
}
template <typename Place, typename T>
void RowwiseAdd<Place, T>::operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& vector,
framework::Tensor* output) {
auto in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector.numel(), size);
PADDLE_ENFORCE_EQ(output->dims(), in_dims);
auto in = framework::EigenMatrix<T>::From(input);
auto vec = framework::EigenMatrix<T>::From(vector);
auto out = framework::EigenMatrix<T>::From(*output);
Eigen::array<int, 2> shape({{1, static_cast<int>(size)}});
Eigen::array<int, 2> bcast({{static_cast<int>(in_dims[0]), 1}});
out.device(*context.GetEigenDevice<Place>()) =
in + vec.reshape(shape).broadcast(bcast);
}
template <typename Place, typename T>
void ColwiseSum<Place, T>::operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor* vector) {
auto in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector->numel(), size);
auto vec = framework::EigenMatrix<T>::From(*vector);
auto in = framework::EigenMatrix<T>::From(input);
Eigen::array<int, 2> shape({{1, static_cast<int>(size)}});
vec.reshape(shape).device(*context.GetEigenDevice<Place>()) =
in.sum(Eigen::array<int, 1>({{0}})).reshape(shape);
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/math/sequence2batch.h"
namespace paddle {
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
......@@ -21,6 +22,10 @@ namespace paddle {
namespace operators {
namespace math {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class CopyMatrixRowsFunctor {
public:
......
......@@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/softmax.h"
#include "paddle/operators/math/softmax_impl.h"
namespace paddle {
namespace operators {
namespace math {
template class SoftmaxFunctor<platform::CPUPlace, float>;
template class SoftmaxFunctor<platform::CPUPlace, double>;
template class SoftmaxGradFunctor<platform::CPUPlace, float>;
template class SoftmaxGradFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
......
......@@ -15,13 +15,16 @@ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/math/softmax.h"
#include "paddle/operators/math/softmax_impl.h"
namespace paddle {
namespace operators {
namespace math {
template class SoftmaxFunctor<platform::GPUPlace, float>;
template class SoftmaxFunctor<platform::GPUPlace, double>;
template class SoftmaxGradFunctor<platform::GPUPlace, float>;
template class SoftmaxGradFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
......
......@@ -13,60 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
struct ValueClip {
HOSTDEVICE T operator()(const T& x) const {
const T kThreshold = -64.;
return x < kThreshold ? kThreshold : x;
}
};
template <typename Place, typename T>
class SoftmaxFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor* X, framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
const int kBatchDim = 0;
const int kClassDim = 1;
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto shifted_logits = (logits -
logits.maximum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class))
.unaryExpr(ValueClip<T>());
softmax.device(*context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(*context.GetEigenDevice<Place>()) =
(softmax *
softmax.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
}
const framework::Tensor* X, framework::Tensor* Y);
};
template <typename Place, typename T>
......@@ -74,29 +31,7 @@ class SoftmaxGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor* y, const framework::Tensor* y_grad,
framework::Tensor* x_grad) {
auto softmax = EigenMatrix<T>::From(*y);
auto softmax_grad = EigenMatrix<T>::From(*y_grad);
auto logits_grad = EigenMatrix<T>::From(*x_grad);
const int kBatchDim = 0;
const int kClassDim = 1;
const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto dot = (softmax * softmax_grad)
.sum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class);
logits_grad.device(*context.GetEigenDevice<Place>()) =
(softmax_grad - dot) * softmax;
}
framework::Tensor* x_grad);
};
} // namespace math
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
struct ValueClip {
HOSTDEVICE T operator()(const T& x) const {
const T kThreshold = -64.;
return x < kThreshold ? kThreshold : x;
}
};
template <typename Place, typename T>
void SoftmaxFunctor<Place, T>::operator()(
const platform::DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
const int kBatchDim = 0;
const int kClassDim = 1;
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto shifted_logits = (logits -
logits.maximum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class))
.unaryExpr(ValueClip<T>());
softmax.device(*context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(*context.GetEigenDevice<Place>()) =
(softmax *
softmax.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
}
template <typename Place, typename T>
void SoftmaxGradFunctor<Place, T>::operator()(
const platform::DeviceContext& context, const framework::Tensor* y,
const framework::Tensor* y_grad, framework::Tensor* x_grad) {
auto softmax = EigenMatrix<T>::From(*y);
auto softmax_grad = EigenMatrix<T>::From(*y_grad);
auto logits_grad = EigenMatrix<T>::From(*x_grad);
const int kBatchDim = 0;
const int kClassDim = 1;
const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto dot = (softmax * softmax_grad)
.sum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class);
logits_grad.device(*context.GetEigenDevice<Place>()) =
(softmax_grad - dot) * softmax;
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -15,8 +15,8 @@
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/matmul.h"
#include "paddle/operators/transpose_op.h"
namespace paddle {
namespace operators {
......@@ -76,7 +76,10 @@ Tensor CombineBatchAndN(const framework::ExecutionContext& context,
if (in_dims.size() == 3) {
output.Resize({in_dims[1], in_dims[0], in_dims[2]});
output.mutable_data<T>(context.GetPlace());
EigenTranspose<Place, T, 3>(context, input, output, {1, 0, 2});
std::vector<int> axis = {1, 0, 2};
math::Transpose<Place, T, 3> trans;
trans(context.device_context(), input, &output, axis);
std::vector<int64_t> out_dims = {in_dims[1], in_dims[0] * in_dims[2]};
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
} else {
output.ShareDataWith(input);
......
......@@ -81,22 +81,21 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
temp.device(context.GetEigenDevice<Place>()) =
temp.constant(static_cast<T>(0));
auto& device_ctx = context.device_context();
math::set_constant(device_ctx, in_x_grad, 0);
switch (ksize.size()) {
case 2: {
paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T>
pool2d_backward;
pool2d_backward(context.device_context(), *out_grad, *mask, ksize,
strides, paddings, in_x_grad);
pool2d_backward(device_ctx, *out_grad, *mask, ksize, strides,
paddings, in_x_grad);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T>
pool3d_backward;
pool3d_backward(context.device_context(), *out_grad, *mask, ksize,
strides, paddings, in_x_grad);
pool3d_backward(device_ctx, *out_grad, *mask, ksize, strides,
paddings, in_x_grad);
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
}
......
......@@ -12,8 +12,6 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sequence_conv_op.h"
namespace ops = paddle::operators;
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/context_project.h"
#include "paddle/operators/math/math_function.h"
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/operators/softmax_with_cross_entropy_op.h"
#include <paddle/function/TensorType.h>
#include <iostream>
namespace paddle {
namespace operators {
......
......@@ -14,27 +14,44 @@
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename Place, typename T, int Rank>
void EigenTranspose(const framework::ExecutionContext& context,
const framework::Tensor& in, framework::Tensor& out,
std::vector<int> axis) {
Eigen::array<int, Rank> permute;
for (int i = 0; i < Rank; i++) {
permute[i] = axis[i];
template <typename Place, typename T>
inline void TransCompute(const int dim, const platform::DeviceContext& dev_ctx,
const framework::Tensor& in, framework::Tensor* out,
const std::vector<int>& axis) {
switch (dim) {
case 1:
math::Transpose<Place, T, 1> trans1;
trans1(dev_ctx, in, out, axis);
break;
case 2:
math::Transpose<Place, T, 2> trans2;
trans2(dev_ctx, in, out, axis);
break;
case 3:
math::Transpose<Place, T, 3> trans3;
trans3(dev_ctx, in, out, axis);
break;
case 4:
math::Transpose<Place, T, 4> trans4;
trans4(dev_ctx, in, out, axis);
break;
case 5:
math::Transpose<Place, T, 5> trans5;
trans5(dev_ctx, in, out, axis);
break;
case 6:
math::Transpose<Place, T, 6> trans6;
trans6(dev_ctx, in, out, axis);
break;
default:
PADDLE_THROW("Tensors with rank at most 6 are supported");
}
auto in_dim = in.dims();
auto out_dim = out.dims();
auto eigen_in = framework::EigenTensor<T, Rank>::From(in);
auto eigen_out = framework::EigenTensor<T, Rank>::From(out);
auto& dev = context.GetEigenDevice<Place>();
eigen_out.device(dev) = eigen_in.shuffle(permute);
}
template <typename Place, typename T>
......@@ -47,28 +64,8 @@ class TransposeKernel : public framework::OpKernel<T> {
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
int ndims = axis.size();
switch (ndims) {
case 1:
EigenTranspose<Place, T, 1>(context, *x, *out, axis);
break;
case 2:
EigenTranspose<Place, T, 2>(context, *x, *out, axis);
break;
case 3:
EigenTranspose<Place, T, 3>(context, *x, *out, axis);
break;
case 4:
EigenTranspose<Place, T, 4>(context, *x, *out, axis);
break;
case 5:
EigenTranspose<Place, T, 5>(context, *x, *out, axis);
break;
case 6:
EigenTranspose<Place, T, 6>(context, *x, *out, axis);
break;
default:
PADDLE_THROW("Tensors with rank at most 6 are supported");
}
auto& dev_ctx = context.device_context();
TransCompute<Place, T>(ndims, dev_ctx, *x, out, axis);
}
};
......@@ -80,47 +77,19 @@ class TransposeGradKernel : public framework::OpKernel<T> {
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
if (x_grad) {
x_grad->mutable_data<T>(context.GetPlace());
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
std::vector<int> reversed_axis(axis);
if (!x_grad) return;
for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}
int ndims = axis.size();
x_grad->mutable_data<T>(context.GetPlace());
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
std::vector<int> reversed_axis(axis);
switch (ndims) {
case 1:
EigenTranspose<Place, T, 1>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 2:
EigenTranspose<Place, T, 2>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 3:
EigenTranspose<Place, T, 3>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 4:
EigenTranspose<Place, T, 4>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 5:
EigenTranspose<Place, T, 5>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 6:
EigenTranspose<Place, T, 6>(context, *out_grad, *x_grad,
reversed_axis);
break;
default:
PADDLE_THROW("Tensors with rank at most 6 are supported");
}
for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}
int ndims = axis.size();
auto& dev_ctx = context.device_context();
TransCompute<Place, T>(ndims, dev_ctx, *out_grad, x_grad, reversed_axis);
}
};
......
......@@ -62,6 +62,8 @@ extern void *cublas_dso_handle;
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasSaxpy_v2); \
__macro(cublasDaxpy_v2); \
__macro(cublasSgemv_v2); \
__macro(cublasDgemv_v2); \
__macro(cublasSgemm_v2); \
......