Unverified · Commit 8425c2c8, authored by dzhwinter, committed by GitHub

Speed/sequence op1 (#9217)

* "add functors"

* "remove old code"

* "fix"

* "fix ci"

* "add details"

* "fix ci"

* "fix ci"

* "fix ci"

* "fix ci"

* "remove unused code"
Parent d21ab2e2
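In short: this change folds the per-pooltype branches that previously lived in the sequence_pool op kernels into a pair of device-dispatched functors, math::SequencePoolFunctor and math::SequencePoolGradFunctor, selected by the pooltype string ("MAX", "AVERAGE", "SUM", "SQRT", "LAST", "FIRST") and, for MAX only, carrying an index tensor that records the argmax rows. On CUDA, the previous MAX-only kernel is replaced by one generic kernel parameterized by small per-pooltype functors, and the 1-D unit tests are regrouped to match.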
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -19,8 +19,17 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
 template <typename T>
-class MaxSeqPoolFunctor<platform::CPUDeviceContext, T> {
+class MaxSeqPoolFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::LoDTensor& input, framework::Tensor* output,
@@ -60,7 +69,7 @@ class MaxSeqPoolFunctor<platform::CPUDeviceContext, T> {
 };
 
 template <typename T>
-class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, T> {
+class MaxSeqPoolGradFunctor {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& out_grad,
@@ -93,10 +102,101 @@ class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, T> {
   }
 };
 
-template class MaxSeqPoolFunctor<platform::CPUDeviceContext, float>;
-template class MaxSeqPoolFunctor<platform::CPUDeviceContext, double>;
-template class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, float>;
-template class MaxSeqPoolGradFunctor<platform::CPUDeviceContext, double>;
+template <typename T>
+class SequencePoolFunctor<platform::CPUDeviceContext, T> {
+ public:
+  /* max pool has index output */
+  void operator()(const platform::CPUDeviceContext& context,
+                  const std::string pooltype, const framework::LoDTensor& input,
+                  framework::Tensor* output,
+                  framework::Tensor* index = nullptr) {
+    if (pooltype == "MAX") {
+      math::MaxSeqPoolFunctor<T> max_pool;
+      max_pool(context, input, output, index);
+      return;
+    }
+    auto lod = input.lod()[0];
+    auto& place = *context.eigen_device();
+    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+      Tensor in_t =
+          input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
+      Tensor out_t = output->Slice(i, i + 1);
+      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      int64_t w = input.numel() / input.dims()[0];
+      auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
+      auto out_e = EigenVector<T>::Flatten(out_t);
+      if (pooltype == "AVERAGE") {
+        out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
+      } else if (pooltype == "SUM") {
+        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
+      } else if (pooltype == "SQRT") {
+        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
+                              std::sqrt(static_cast<T>(h));
+      } else if (pooltype == "LAST") {
+        out_e.device(place) = in_e.chip(h - 1, 0);
+      } else if (pooltype == "FIRST") {
+        out_e.device(place) = in_e.chip(0, 0);
+      } else {
+        PADDLE_THROW("unsupported pooling pooltype");
+      }
+    }
+  }
+};
+
+template <typename T>
+class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const std::string pooltype, const framework::Tensor& out_grad,
+                  framework::LoDTensor* in_grad,
+                  /* max pool has index */
+                  const framework::Tensor* index = nullptr) {
+    if (pooltype == "MAX") {
+      math::MaxSeqPoolGradFunctor<T> max_pool_grad;
+      max_pool_grad(context, out_grad, *index, in_grad);
+      return;
+    }
+    if (pooltype == "LAST" || pooltype == "FIRST") {
+      // set X@Grad be zero at first when pooltype is LAST/FIRST
+      math::SetConstant<platform::CPUDeviceContext, T> functor;
+      functor(context, in_grad, 0);
+    }
+    auto lod = in_grad->lod()[0];
+    auto& place = *context.eigen_device();
+    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+      auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
+                                   static_cast<int>(lod[i + 1]));
+      auto out_g_t = out_grad.Slice(i, i + 1);
+      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      int64_t w = in_grad->numel() / in_grad->dims()[0];
+      auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
+      auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
+      auto out_g_e_v = EigenVector<T>::Flatten(out_g_t);
+      Eigen::DSizes<int, 2> bcast(h, 1);
+      if (pooltype == "AVERAGE") {
+        in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
+      } else if (pooltype == "SUM") {
+        in_g_e.device(place) = (out_g_e).broadcast(bcast);
+      } else if (pooltype == "SQRT") {
+        in_g_e.device(place) =
+            (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
+      } else if (pooltype == "LAST") {
+        in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
+      } else if (pooltype == "FIRST") {
+        in_g_e.chip(0, 0).device(place) = out_g_e_v;
+      } else {
+        PADDLE_THROW("unsupported pooling pooltype");
+      }
+    }
+  }
+};
+
+template class SequencePoolFunctor<platform::CPUDeviceContext, float>;
+template class SequencePoolFunctor<platform::CPUDeviceContext, double>;
+template class SequencePoolGradFunctor<platform::CPUDeviceContext, float>;
+template class SequencePoolGradFunctor<platform::CPUDeviceContext, double>;
 
 }  // namespace math
 }  // namespace operators
...
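For the non-MAX pooltypes, the CPU functor slices each sequence (rows [lod[i], lod[i+1]) of the input) into an h x w Eigen matrix and reduces along dimension 0. Below is a standalone restatement of that arithmetic with Eigen swapped for plain loops, so it compiles on its own; names are illustrative, not Paddle's. MAX is omitted because it also has to record the argmax row per column (the index tensor in the real code).

#include <cmath>
#include <cstddef>
#include <string>
#include <vector>

// `lod` holds offsets, so sequence i occupies rows [lod[i], lod[i+1]) of the
// row-major (rows x width) input.
std::vector<float> SeqPoolRef(const std::string& pooltype,
                              const std::vector<float>& in,
                              const std::vector<std::size_t>& lod,
                              std::size_t width) {
  std::vector<float> out((lod.size() - 1) * width, 0.f);
  for (std::size_t i = 0; i + 1 < lod.size(); ++i) {
    const std::size_t h = lod[i + 1] - lod[i];  // sequence length
    for (std::size_t c = 0; c < width; ++c) {
      float v = 0.f;
      for (std::size_t r = lod[i]; r < lod[i + 1]; ++r) v += in[r * width + c];
      if (pooltype == "AVERAGE") v /= static_cast<float>(h);
      if (pooltype == "SQRT") v /= std::sqrt(static_cast<float>(h));
      if (pooltype == "LAST") v = in[(lod[i + 1] - 1) * width + c];
      if (pooltype == "FIRST") v = in[lod[i] * width + c];
      out[i * width + c] = v;  // "SUM" keeps the raw sum
    }
  }
  return out;
}

With lod = {0, 4, 5, 8, 13} (the layout the tests below use), SeqPoolRef pools four sequences of lengths 4, 1, 3, and 5 into four output rows.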
--- a/paddle/fluid/operators/math/sequence_pooling.cu
+++ b/paddle/fluid/operators/math/sequence_pooling.cu
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence_pooling.h"
+#include "paddle/fluid/platform/cuda_helper.h"
 
 namespace paddle {
 namespace operators {
@@ -22,113 +23,331 @@ namespace math {
 
 #define FLT_MAX __FLT_MAX__
 
 template <typename T>
-__global__ void KeMaxSequencePool(const T* input, const size_t* starts,
-                                  T* output, int* index, int64_t num_seq,
-                                  int64_t dim) {
-  int dim_idx = threadIdx.x;
-  int seq_id = blockIdx.x;
-  if (seq_id >= num_seq) return;
-  size_t start = starts[seq_id];
-  size_t end = starts[seq_id + 1];
-
-  for (int64_t i = dim_idx; i < dim; i += blockDim.x) {
-    T max_val = static_cast<T>(-FLT_MAX);
-    int max_id = -1;
-    for (size_t step_id = start; step_id < end; step_id++) {
-      if (max_val < input[step_id * dim + i]) {
-        max_val = input[step_id * dim + i];
-        max_id = step_id;
-      }
-    }
-    output[seq_id * dim + i] = max_val;
-    index[seq_id * dim + i] = max_id;
-  }
-}
+struct MaxPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      T max_val = static_cast<T>(-FLT_MAX);
+      int max_index = -1;
+      for (int i = start; i < end; ++i) {
+        if (max_val < input[item_dim * i + tid]) {
+          max_val = input[item_dim * i + tid];
+          max_index = i;
+        }
+      }
+      output[tid] = max_val;
+      index[tid] = max_index;
+    }
+  }
+};
+
+template <typename T>
+struct AvgPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      T val = static_cast<T>(0);
+      for (int i = start; i < end; ++i) {
+        val += input[item_dim * i + tid];
+      }
+      // end, start is lod, so end - start != 0
+      output[tid] = val / static_cast<T>(end - start);
+    }
+  }
+};
+
+template <typename T>
+struct SumPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      T val = static_cast<T>(0);
+      for (int i = start; i < end; ++i) {
+        val += input[item_dim * i + tid];
+      }
+      output[tid] = val;
+    }
+  }
+};
+
+template <typename T>
+struct SqrtPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      T val = static_cast<T>(0);
+      for (int i = start; i < end; ++i) {
+        val += input[item_dim * i + tid];
+      }
+      // end, start is lod, so end - start != 0
+      output[tid] = val / sqrt(end - start);
+    }
+  }
+};
+
+template <typename T>
+struct LastPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      output[tid] = input[item_dim * (end - 1) + tid];
+    }
+  }
+};
+
+template <typename T>
+struct FirstPoolFunctor {
+  HOSTDEVICE void operator()(const T* input, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* output, int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      output[tid] = input[item_dim * start + tid];
+    }
+  }
+};
+
+template <typename T, typename Range_OP>
+__global__ void sequence_pool_kernel(Range_OP op, const T* input,
+                                     const size_t* lod, const size_t lod_size,
+                                     const size_t item_dim, T* output,
+                                     int* index) {
+  int bid = blockIdx.x;
+  if (bid >= lod_size - 1) return;
+  size_t start = lod[bid];
+  size_t end = lod[bid + 1];
+  int* index_offset = nullptr;
+  if (index != nullptr) {
+    index_offset = &index[bid * item_dim];
+  }
+  op(input, start, end, item_dim, &output[bid * item_dim], index_offset);
+}
 
 template <typename T>
-class MaxSeqPoolFunctor<platform::CUDADeviceContext, T> {
+class SequencePoolFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::LoDTensor& input, framework::Tensor* output,
-                  framework::Tensor* index) {
-    auto in_dims = input.dims();
-    auto out_dims = output->dims();
-    auto idx_dims = index->dims();
-    PADDLE_ENFORCE_GT(in_dims.size(), static_cast<int64_t>(1));
-    PADDLE_ENFORCE_GT(out_dims.size(), 1);
-    for (int64_t i = 1; i < in_dims.size(); ++i) {
-      PADDLE_ENFORCE_EQ(in_dims[i], out_dims[i]);
-    }
-    PADDLE_ENFORCE_EQ(idx_dims, out_dims);
-
-    auto starts = input.lod()[0];
-    const T* in_data = input.data<T>();
-    T* out_data = output->data<T>();
-    int* max_index = index->data<int>();
-
-    int64_t num_seq = out_dims[0];
-    int64_t dim = output->numel() / num_seq;
-
-    dim3 threads(256, 1);
-    dim3 grid(num_seq, 1);
-    auto stream = context.stream();
-    KeMaxSequencePool<T><<<grid, threads, 0, stream>>>(
-        in_data, starts.CUDAData(context.GetPlace()), out_data, max_index,
-        num_seq, dim);
+                  const std::string pooltype, const framework::LoDTensor& input,
+                  framework::Tensor* output,
+                  framework::Tensor* index = nullptr) {
+    auto lod = input.lod()[0];
+    const size_t item_dim = output->numel() / output->dims()[0];
+    dim3 threads(1024, 1);
+    dim3 grid(lod.size(), 1);
+    if (pooltype == "MAX") {
+      sequence_pool_kernel<
+          T, MaxPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          MaxPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), index->data<int>());
+    } else if (pooltype == "AVERAGE") {
+      sequence_pool_kernel<
+          T, AvgPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          AvgPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "SUM") {
+      sequence_pool_kernel<
+          T, SumPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          SumPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "SQRT") {
+      sequence_pool_kernel<
+          T, SqrtPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          SqrtPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "LAST") {
+      sequence_pool_kernel<
+          T, LastPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          LastPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "FIRST") {
+      sequence_pool_kernel<
+          T, FirstPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          FirstPoolFunctor<T>(), input.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          output->mutable_data<T>(context.GetPlace()), nullptr);
+    } else {
+      PADDLE_THROW("unsupported pooling pooltype");
+    }
   }
 };
 
 template <typename T>
-__global__ void KeMaxSequencePoolGrad(const T* out_grad, const int* max_index,
-                                      T* in_grad, int64_t num_seq,
-                                      int64_t dim) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  int col_idx = idx % dim;
-  if (idx < num_seq * dim) {
-    int step_id = max_index[idx];
-    in_grad[step_id * dim + col_idx] = out_grad[idx];
-  }
-}
+struct MaxPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        if (i == index[tid]) {
+          in_grad[item_dim * i + tid] = out_grad[tid];
+        } else {
+          in_grad[item_dim * i + tid] = static_cast<T>(0);
+        }
+      }
+    }
+  }
+};
+
+template <typename T>
+struct AvgPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        in_grad[item_dim * i + tid] = out_grad[tid] / (end - start);
+      }
+    }
+  }
+};
+
+template <typename T>
+struct SumPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        in_grad[item_dim * i + tid] = out_grad[tid];
+      }
+    }
+  }
+};
+
+template <typename T>
+struct SqrtPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        in_grad[item_dim * i + tid] =
+            out_grad[tid] / (sqrt(static_cast<T>(end - start)));
+      }
+    }
+  }
+};
+
+template <typename T>
+struct LastPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        if (i == end - 1) {
+          in_grad[item_dim * i + tid] = out_grad[tid];
+        } else {
+          in_grad[item_dim * i + tid] = static_cast<T>(0);
+        }
+      }
+    }
+  }
+};
+
+template <typename T>
+struct FirstPoolGradFunctor {
+  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
+                             const size_t end, const size_t item_dim,
+                             T* in_grad, const int* index) {
+    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
+      for (int i = start; i < end; ++i) {
+        if (i == start) {
+          in_grad[item_dim * i + tid] = out_grad[tid];
+        } else {
+          in_grad[item_dim * i + tid] = static_cast<T>(0);
+        }
+      }
+    }
+  }
+};
+
+template <typename T, typename Range_OP>
+__global__ void sequence_pool_grad_kernel(Range_OP op, const T* out_grad,
+                                          const size_t* lod,
+                                          const size_t lod_size,
+                                          const size_t item_dim, T* in_grad,
+                                          const int* index) {
+  int bid = blockIdx.x;
+  if (bid >= lod_size - 1) return;
+  size_t start = lod[bid];
+  size_t end = lod[bid + 1];
+  const int* index_offset = nullptr;
+  if (index != nullptr) {
+    index_offset = &index[bid * item_dim];
+  }
+  op(&out_grad[bid * item_dim], start, end, item_dim, in_grad, index_offset);
+}
 
 template <typename T>
-class MaxSeqPoolGradFunctor<platform::CUDADeviceContext, T> {
+class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& out_grad,
-                  const framework::Tensor& index,
-                  framework::LoDTensor* in_grad) {
-    auto og_dims = out_grad.dims();
-    auto idx_dims = index.dims();
-    auto ig_dims = in_grad->dims();
-    PADDLE_ENFORCE_GT(og_dims.size(), static_cast<int64_t>(1));
-    PADDLE_ENFORCE_GT(ig_dims.size(), static_cast<int64_t>(1));
-    for (int64_t i = 1; i < og_dims.size(); ++i) {
-      PADDLE_ENFORCE_EQ(og_dims[i], ig_dims[i]);
-    }
-    PADDLE_ENFORCE_EQ(idx_dims, og_dims);
-
-    const T* og_data = out_grad.data<T>();
-    const int* max_index = index.data<int>();
-    T* ig_data = in_grad->data<T>();
-
-    SetConstant<platform::CUDADeviceContext, T> set_zero;
-    set_zero(context, in_grad, static_cast<T>(0.0));
-    int64_t num_seq = og_dims[0];
-    int64_t dim = out_grad.numel() / num_seq;
-
-    unsigned int blocks = (num_seq * dim + 128 - 1) / 128;
-    dim3 threads(128, 1);
-    dim3 grid(blocks, 1);
-    auto stream = context.stream();
-    KeMaxSequencePoolGrad<T><<<grid, threads, 0, stream>>>(
-        og_data, max_index, ig_data, num_seq, dim);
+                  const std::string pooltype, const framework::Tensor& out_grad,
+                  framework::LoDTensor* in_grad,
+                  /* max pool has index */
+                  const framework::Tensor* index = nullptr) {
+    auto lod = in_grad->lod()[0];
+    const size_t item_dim = in_grad->numel() / in_grad->dims()[0];
+    dim3 threads(1024, 1);
+    dim3 grid(lod.size(), 1);
+    if (pooltype == "MAX") {
+      sequence_pool_grad_kernel<
+          T, MaxPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          MaxPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), index->data<int>());
+    } else if (pooltype == "AVERAGE") {
+      sequence_pool_grad_kernel<
+          T, AvgPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          AvgPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "SUM") {
+      sequence_pool_grad_kernel<
+          T, SumPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          SumPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "SQRT") {
+      sequence_pool_grad_kernel<
+          T, SqrtPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          SqrtPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "LAST") {
+      sequence_pool_grad_kernel<
+          T, LastPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          LastPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
+    } else if (pooltype == "FIRST") {
+      sequence_pool_grad_kernel<
+          T, FirstPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
+          FirstPoolGradFunctor<T>(), out_grad.data<T>(),
+          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
+          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
+    } else {
+      PADDLE_THROW("unsupported pooling pooltype");
+    }
   }
 };
 
-template class MaxSeqPoolFunctor<platform::CUDADeviceContext, float>;
-template class MaxSeqPoolFunctor<platform::CUDADeviceContext, double>;
-template class MaxSeqPoolGradFunctor<platform::CUDADeviceContext, float>;
-template class MaxSeqPoolGradFunctor<platform::CUDADeviceContext, double>;
+// sequence pooling
+template class SequencePoolFunctor<platform::CUDADeviceContext, float>;
+template class SequencePoolFunctor<platform::CUDADeviceContext, double>;
+template class SequencePoolGradFunctor<platform::CUDADeviceContext, float>;
+template class SequencePoolGradFunctor<platform::CUDADeviceContext, double>;
 
 }  // namespace math
 }  // namespace operators
...
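The CUDA design is: one generic kernel per direction, one small functor type per pooltype, one thread block per sequence, with each thread striding over the feature dimension (item_dim). The following self-contained sketch (nvcc-compilable; all names here are illustrative, not Paddle's) shows the same dispatch pattern with a SUM functor only:

#include <cstdio>
#include <cuda_runtime.h>

struct SumPool {
  template <typename T>
  __device__ void operator()(const T* in, size_t start, size_t end, size_t dim,
                             T* out) const {
    // Each thread owns a stripe of columns; reduce over the sequence rows.
    for (size_t tid = threadIdx.x; tid < dim; tid += blockDim.x) {
      T v = 0;
      for (size_t r = start; r < end; ++r) v += in[r * dim + tid];
      out[tid] = v;
    }
  }
};

template <typename T, typename Pool>
__global__ void SeqPoolKernel(Pool pool, const T* in, const size_t* lod,
                              size_t num_seq, size_t dim, T* out) {
  size_t seq = blockIdx.x;  // one block per sequence
  if (seq >= num_seq) return;
  pool(in, lod[seq], lod[seq + 1], dim, &out[seq * dim]);
}

int main() {
  const size_t lod_h[] = {0, 2, 5};       // two sequences: rows [0,2), [2,5)
  const float in_h[5] = {1, 2, 3, 4, 5};  // dim == 1
  float *in_d, *out_d;
  size_t* lod_d;
  cudaMalloc(&in_d, sizeof(in_h));
  cudaMalloc(&out_d, 2 * sizeof(float));
  cudaMalloc(&lod_d, sizeof(lod_h));
  cudaMemcpy(in_d, in_h, sizeof(in_h), cudaMemcpyHostToDevice);
  cudaMemcpy(lod_d, lod_h, sizeof(lod_h), cudaMemcpyHostToDevice);
  SeqPoolKernel<float, SumPool><<<2, 32>>>(SumPool(), in_d, lod_d, 2, 1, out_d);
  float out_h[2];
  cudaMemcpy(out_h, out_d, sizeof(out_h), cudaMemcpyDeviceToHost);
  printf("%f %f\n", out_h[0], out_h[1]);  // expect 3 and 12
  cudaFree(in_d);
  cudaFree(out_d);
  cudaFree(lod_d);
  return 0;
}

Passing the functor by value into the kernel lets nvcc inline the per-pooltype body, so the six pooltypes share one launch path without a per-element branch.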
--- a/paddle/fluid/operators/math/sequence_pooling.h
+++ b/paddle/fluid/operators/math/sequence_pooling.h
@@ -21,23 +21,23 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-#define FLT_MAX __FLT_MAX__
-
 template <typename DeviceContext, typename T>
-class MaxSeqPoolFunctor {
+class SequencePoolFunctor {
  public:
-  void operator()(const DeviceContext& context,
+  /* max pool has index output */
+  void operator()(const DeviceContext& context, const std::string pooltype,
                   const framework::LoDTensor& input, framework::Tensor* output,
-                  framework::Tensor* index);
+                  framework::Tensor* index = nullptr);
 };
 
-template <typename DeviceContext, class T>
-class MaxSeqPoolGradFunctor {
+template <typename DeviceContext, typename T>
+class SequencePoolGradFunctor {
  public:
-  void operator()(const DeviceContext& context,
+  void operator()(const DeviceContext& context, const std::string pooltype,
                   const framework::Tensor& out_grad,
-                  const framework::Tensor& index,
-                  framework::LoDTensor* in_grad);
+                  framework::LoDTensor* in_grad,
+                  /* max pool has index */
+                  const framework::Tensor* index = nullptr);
 };
 
 }  // namespace math
...
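Hypothetical call sites under the new header (a sketch assuming a CPU device context and tensors set up elsewhere, not a verbatim excerpt): only the MAX path supplies the index tensor, the rest rely on the nullptr default.

// cpu_ctx, input, output, max_index, out_grad, in_grad are assumed locals.
math::SequencePoolFunctor<platform::CPUDeviceContext, float> pool;
pool(cpu_ctx, "SUM", input, &output);               // index defaults to nullptr
pool(cpu_ctx, "MAX", input, &output, &max_index);   // MAX records argmax rows

math::SequencePoolGradFunctor<platform::CPUDeviceContext, float> pool_grad;
pool_grad(cpu_ctx, "MAX", out_grad, &in_grad, &max_index);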
--- a/paddle/fluid/operators/sequence_pool_op.h
+++ b/paddle/fluid/operators/sequence_pool_op.h
@@ -23,12 +23,6 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename DeviceContext, typename T>
 class SequencePoolKernel : public framework::OpKernel<T> {
@@ -37,11 +31,13 @@ class SequencePoolKernel : public framework::OpKernel<T> {
     auto* in = context.Input<LoDTensor>("X");
     auto* out = context.Output<Tensor>("Out");
     std::string pooltype = context.Attr<std::string>("pooltype");
+    Tensor* index = nullptr;
+    if (pooltype == "MAX") {
+      index = context.Output<Tensor>("MaxIndex");
+    }
 
     auto dims = in->dims();
     auto lod = in->lod();
-    int64_t w = in->numel() / dims[0];
-
     // InferShape by lod
     PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
     PADDLE_ENFORCE_GE(
@@ -50,45 +46,14 @@ class SequencePoolKernel : public framework::OpKernel<T> {
         "The first dimension of Input(X) must be large than batch size.");
     dims[0] = lod[0].size() - 1;
     out->Resize({dims});
-    auto lod_level_0 = lod[0];
 
     out->mutable_data<T>(context.GetPlace());
-    auto& dev_ctx = context.template device_context<DeviceContext>();
     if (pooltype == "MAX") {
-      math::MaxSeqPoolFunctor<DeviceContext, T> max_pool;
-      auto* index = context.Output<Tensor>("MaxIndex");
       index->Resize({dims});
       index->mutable_data<int>(context.GetPlace());
-      max_pool(dev_ctx, *in, out, index);
-      return;
     }
-
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
-      Tensor in_t = in->Slice(static_cast<int>(lod_level_0[i]),
-                              static_cast<int>(lod_level_0[i + 1]));
-      Tensor out_t = out->Slice(i, i + 1);
-      int64_t h = static_cast<int64_t>(lod_level_0[i + 1] - lod_level_0[i]);
-      auto in_e = EigenMatrix<T>::From(in_t, framework::make_ddim({h, w}));
-      auto out_e = EigenVector<T>::Flatten(out_t);
-
-      if (pooltype == "AVERAGE") {
-        out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
-      } else if (pooltype == "SUM") {
-        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
-      } else if (pooltype == "SQRT") {
-        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
-                              std::sqrt(static_cast<T>(h));
-      } else if (pooltype == "LAST") {
-        out_e.device(place) = in_e.chip(h - 1, 0);
-      } else if (pooltype == "FIRST") {
-        out_e.device(place) = in_e.chip(0, 0);
-      } else {
-        PADDLE_THROW("unsupported pooling pooltype");
-      }
-    }
+    math::SequencePoolFunctor<DeviceContext, T> pool;
+    pool(context.template device_context<DeviceContext>(), pooltype, *in, out,
+         index);
   }
 };
 
@@ -96,58 +61,17 @@ template <typename DeviceContext, typename T>
 class SequencePoolGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* in = context.Input<LoDTensor>("X");
     auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
     std::string pooltype = context.Attr<std::string>("pooltype");
-
-    auto dims = in->dims();
-    auto lod = in->lod()[0];
-    int64_t w = in->numel() / dims[0];
-
-    in_g->mutable_data<T>(context.GetPlace());
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-
+    const Tensor* index = nullptr;
     if (pooltype == "MAX") {
-      math::MaxSeqPoolGradFunctor<DeviceContext, T> max_pool_grad;
-      auto* index = context.Input<Tensor>("MaxIndex");
-      max_pool_grad(dev_ctx, *out_g, *index, in_g);
-      return;
+      index = context.Input<Tensor>("MaxIndex");
     }
-
-    if (pooltype == "LAST" || pooltype == "FIRST") {
-      // set X@Grad be zero at first when pooltype is LAST/FIRST
-      math::SetConstant<DeviceContext, T> functor;
-      functor(dev_ctx, in_g, 0);
-    }
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
-      auto in_g_t =
-          in_g->Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
-      auto out_g_t = out_g->Slice(i, i + 1);
-      int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
-      auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
-      auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
-      auto out_g_e_v = EigenVector<T>::Flatten(out_g_t);
-      Eigen::DSizes<int, 2> bcast(h, 1);
-
-      if (pooltype == "AVERAGE") {
-        in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
-      } else if (pooltype == "SUM") {
-        in_g_e.device(place) = (out_g_e).broadcast(bcast);
-      } else if (pooltype == "SQRT") {
-        in_g_e.device(place) =
-            (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
-      } else if (pooltype == "LAST") {
-        in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
-      } else if (pooltype == "FIRST") {
-        in_g_e.chip(0, 0).device(place) = out_g_e_v;
-      } else {
-        PADDLE_THROW("unsupported pooling pooltype");
-      }
-    }
+    in_g->mutable_data<T>(context.GetPlace());
+    math::SequencePoolGradFunctor<DeviceContext, T> pool;
+    pool(context.template device_context<DeviceContext>(), pooltype, *out_g,
+         in_g, index);
   }
 };
...
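For reference, the backward rules that both the Eigen path and the CUDA grad functors implement, written out for a single sequence of h rows and width w. This is a standalone sketch with illustrative names; note the real MAX path stores absolute batch rows in the index tensor, whereas `argmax` here is assumed to hold within-sequence row offsets:

#include <cmath>
#include <cstddef>
#include <string>
#include <vector>

// SUM:     dX[r][c] = dout[c]                 AVERAGE: dX[r][c] = dout[c] / h
// SQRT:    dX[r][c] = dout[c] / sqrt(h)       FIRST:   dX[0][c] = dout[c]
// LAST:    dX[h-1][c] = dout[c]               MAX:     dX[argmax[c]][c] = dout[c]
// (rows not named above receive zero)
void SeqPoolGradOneSeq(const std::string& pooltype,
                       const std::vector<float>& dout,  // width w
                       const std::vector<int>& argmax,  // used by MAX only
                       std::size_t h, std::vector<float>* dx) {
  const std::size_t w = dout.size();
  dx->assign(h * w, 0.f);  // zero-fill, matching the SetConstant pre-pass
  for (std::size_t c = 0; c < w; ++c) {
    if (pooltype == "FIRST") {
      (*dx)[c] = dout[c];
    } else if (pooltype == "LAST") {
      (*dx)[(h - 1) * w + c] = dout[c];
    } else if (pooltype == "MAX") {
      (*dx)[argmax[c] * w + c] = dout[c];
    } else {  // SUM / AVERAGE / SQRT broadcast a (scaled) copy to every row
      float scale = 1.f;
      if (pooltype == "AVERAGE") scale = 1.f / h;
      if (pooltype == "SQRT") scale = 1.f / std::sqrt(static_cast<float>(h));
      for (std::size_t r = 0; r < h; ++r) (*dx)[r * w + c] = scale * dout[c];
    }
  }
}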
--- a/python/paddle/fluid/tests/unittests/test_seq_pool.py
+++ b/python/paddle/fluid/tests/unittests/test_seq_pool.py
@@ -49,6 +49,61 @@ class TestSeqAvgPool(OpTest):
         self.check_grad(["X"], "Out")
 
 
+class TestSeqSumPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "SUM"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = sub_x.sum(axis=0)
+
+
+class TestSeqMaxPool(TestSeqAvgPool):
+    def set_data(self):
+        self.op_type = 'sequence_pool'
+        x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
+        lod = [[0, 4, 5, 8, 13]]
+        for i in range(4):
+            l = lod[0][i + 1] - lod[0][i]
+            x[lod[0][i] + np.random.randint(l), :] += 2.0
+
+        self.inputs = {'X': (x, lod)}
+
+        out = np.zeros((4, 23)).astype('float32')
+        self.outputs = {'Out': out}
+        return x, lod, out
+
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "MAX"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = np.amax(sub_x, axis=0)
+
+
+class TestSeqSqrtPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "SQRT"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            len = lod[0][i + 1] - lod[0][i]
+            out[i] = sub_x.sum(axis=0) / np.sqrt(len)
+
+
+class TestSeqLastPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "LAST"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = sub_x[-1, :]
+
+
+class TestSeqFirstPool(TestSeqAvgPool):
+    def compute(self, x, lod, out):
+        self.attrs = {'pooltype': "FIRST"}
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            out[i] = sub_x[0, :]
+
+
 class TestSeqAvgPool2D(TestSeqAvgPool):
     def set_data(self):
         self.op_type = 'sequence_pool'
@@ -68,14 +123,6 @@ class TestSeqAvgPool2D(TestSeqAvgPool):
         out[i] = np.reshape(sub_x.mean(axis=0), (3, 17))
 
 
-class TestSeqSumPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "SUM"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = sub_x.sum(axis=0)
-
-
 class TestSeqSumPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "SUM"}
@@ -84,15 +131,6 @@ class TestSeqSumPool2D(TestSeqAvgPool2D):
         out[i] = np.reshape(sub_x.sum(axis=0), (3, 17))
 
 
-class TestSeqSqrtPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "SQRT"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            len = lod[0][i + 1] - lod[0][i]
-            out[i] = sub_x.sum(axis=0) / np.sqrt(len)
-
-
 class TestSeqSqrtPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "SQRT"}
@@ -108,28 +146,6 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D):
         self.check_grad(["X"], "Out", max_relative_error=0.06)
 
 
-class TestSeqMaxPool(TestSeqAvgPool):
-    def set_data(self):
-        self.op_type = 'sequence_pool'
-        x = np.random.uniform(0.1, 1, [13, 23]).astype('float32')
-        lod = [[0, 4, 5, 8, 13]]
-        for i in range(4):
-            l = lod[0][i + 1] - lod[0][i]
-            x[lod[0][i] + np.random.randint(l), :] += 2.0
-
-        self.inputs = {'X': (x, lod)}
-
-        out = np.zeros((4, 23)).astype('float32')
-        self.outputs = {'Out': out}
-        return x, lod, out
-
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "MAX"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = np.amax(sub_x, axis=0)
-
-
 class TestSeqMaxPool2D(TestSeqAvgPool2D):
     def set_data(self):
         self.op_type = 'sequence_pool'
@@ -151,14 +167,6 @@ class TestSeqMaxPool2D(TestSeqAvgPool2D):
         out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 11))
 
 
-class TestSeqLastPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "LAST"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = sub_x[-1, :]
-
-
 class TestSeqLastPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "LAST"}
@@ -167,14 +175,6 @@ class TestSeqLastPool2D(TestSeqAvgPool2D):
         out[i] = np.reshape(sub_x[-1, :], (3, 17))
 
 
-class TestSeqFirstPool(TestSeqAvgPool):
-    def compute(self, x, lod, out):
-        self.attrs = {'pooltype': "FIRST"}
-        for i in range(4):
-            sub_x = x[lod[0][i]:lod[0][i + 1], :]
-            out[i] = sub_x[0, :]
-
-
 class TestSeqFirstPool2D(TestSeqAvgPool2D):
     def compute(self, x, lod, out):
         self.attrs = {'pooltype': "FIRST"}
...
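Note the data trick in TestSeqMaxPool: one randomly chosen row per sequence is bumped by 2.0, presumably so that the per-column maximum is attained at a single unambiguous row, which keeps the MAX gradient (nonzero only at the argmax) well defined for the finite-difference gradient check.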