提交 26822bd7 编写于 作者: D dzhwinter

"add sequence kernel"

上级 4ee1c9e6
...@@ -21,48 +21,89 @@ namespace operators { ...@@ -21,48 +21,89 @@ namespace operators {
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
template <typename T> template <typename T>
__global__ sequence_expand_kernel(const T* x_data, T* out_data, size_t* lod, __global__ void sequence_expand_kernel(const T* x_data, T* out_data,
const size_t* lod, size_t lod_size,
size_t element_len) { size_t element_len) {
int BLOCK_SIZE = 1024; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
__shared__ T shm_lod[BLOCK_SIZE]; for (; tid_x < static_cast<int>(lod_size - 1);
for (int idx = threadIdx.x; idx < BLOCK_SIZE; ++idx) { tid_x += blockDim.x * gridDim.x) {
shm_lod[idx] = lod[idx]; int scale = lod[tid_x + 1] - lod[tid_x];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < scale; tid_y += blockDim.y * gridDim.y) {
int tid_z = blockIdx.z * blockDim.z + threadIdx.z;
int item_start = tid_x / element_len;
for (; tid_z < element_len; tid_z += blockDim.z * gridDim.z) {
out_data[item_start * scale + tid_z] = x_data[item_start + tid_z];
}
} }
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < lod.size();
idx += blockDim.x * gridDim.x) {
int scale = lod[i]
} }
} }
template <typename T> template <typename T>
void SequenceExpandFunctor<platform::CPUDeviceContext, T>::operator()( __global__ void sequence_expand_grad_kernel(const T* dout_data, T* dx_data,
const platform::CPUDeviceContext& context, const LoDTensor& x, const size_t* lod, size_t lod_size,
LoDTensor* out) { size_t element_len,
x_dims = x.dims(); size_t dout_size) {
extern __shared__ T shm[];
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid_x < static_cast<int>(lod_size - 1);
tid_x += blockDim.x * gridDim.x) {
int scale = lod[tid_x + 1] - lod[tid_x];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < scale; tid_y += blockDim.y * gridDim.y) {
int tid_z = blockIdx.z * blockDim.z + threadIdx.z;
int item_start = tid_x / element_len;
for (; tid_z < element_len; tid_z += blockDim.z * gridDim.z) {
shm[item_start + tid_z] += doutx_data[item_start * scale + tid_z];
}
}
}
// synchronize before write to dx
__syncthreads();
for (int idx = blockDimx * blockIdx.x + threadIdx.x;
idx < static_cast<int>(dout_size); idx += blockDim.x * gridDim.x) {
dx_data[idx] = shm[idx;]
}
}
template <typename T>
struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const LoDTensor& x, LoDTensor* out) {
auto x_dims = x.dims();
size_t element_len = framework::product(x_dims) / x_dims[0]; size_t element_len = framework::product(x_dims) / x_dims[0];
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_starts = out->lod().back(); auto out_starts = out->lod().back();
const int kThreadsPerBlock = 1024; dim3 block_size(16, 32, element_len);
int block_cols = kThreadsPerBlock; dim3 grid_size(10, 10);
if (out_cols < kThreadsPerBlock) { // block_cols is aligned by 32. sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>(
block_cols = ((out_cols + 31) >> 5) << 5; x.data<T>(), out->mutable_data<T>(context.GetPlace()),
out_starts.CUDAData(context.GetPlace()), out_starts.size(),
element_len);
} }
int block_rows = kThreadsPerBlock / block_cols; };
dim3 block_size = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount(); template <typename T>
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, const LoDTensor& x,
const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) {
auto x_dims = x.dims();
size_t element_len = framework::product(x_dims) / x_dims[0];
const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_starts = out->lod().back();
int grid_cols = dim3 block_size(16, 32, element_len);
std::min((out_cols + block_cols - 1) / block_cols, max_blocks); dim3 grid_size(10, 10);
int grid_rows = size_t out_size = framework::product(dx->dims());
std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1)); sequence_expand_kernel<<<grid_size, block_size, out_size * sizeof(T),
dim3 grid_size = dim3(grid_cols, grid_rows, 1); context.stream()>>>(
sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>( dout.data<T>(), dx->mutable_data<T>(context.GetPlace()),
x.data<T>(), out->mutable_data<T>(context.GetPlace()), out_starts.CUDAData(context.GetPlace()), out_starts.size(), element_len,
out_starts.CUDAData(context.GetPlace()), element_len); out_size);
} }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -28,15 +28,19 @@ struct SequenceExpandFunctor { ...@@ -28,15 +28,19 @@ struct SequenceExpandFunctor {
void operator()(const DeviceContext& ctx, const LoDTensor& x, LoDTensor* out); void operator()(const DeviceContext& ctx, const LoDTensor& x, LoDTensor* out);
}; };
// template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
// struct SequenceExpandGradFunctor {}; struct SequenceExpandGradFunctor {
void operator()(const DeviceContext& ctx, const LoDTensor& x,
const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx);
};
template <typename T> template <typename T>
void SequenceExpandFunctor<platform::CPUDeviceContext, T>::operator()( struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
const platform::CPUDeviceContext& context, const LoDTensor& x, void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x,
LoDTensor* out) { LoDTensor* out) {
x_dims = x.dims(); auto x_dims = x.dims();
size_t element_len = framework::product(x_dims) / x_dims[0]; size_t element_len = framework::product(x_dims) / x_dims[0];
const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_starts = out->lod().back(); auto out_starts = out->lod().back();
...@@ -52,7 +56,8 @@ void SequenceExpandFunctor<platform::CPUDeviceContext, T>::operator()( ...@@ -52,7 +56,8 @@ void SequenceExpandFunctor<platform::CPUDeviceContext, T>::operator()(
x_data += element_len; x_data += element_len;
out_data += element_len * scale; out_data += element_len * scale;
} }
} }
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SequenceExpandKernel : public framework::OpKernel<T> { class SequenceExpandKernel : public framework::OpKernel<T> {
...@@ -60,7 +65,6 @@ class SequenceExpandKernel : public framework::OpKernel<T> { ...@@ -60,7 +65,6 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<LoDTensor>("X"); auto* x = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out"); auto* out = context.Output<LoDTensor>("Out");
const T* x_data = x->data<T>();
auto x_dims = x->dims(); auto x_dims = x->dims();
auto* y = context.Input<LoDTensor>("Y"); auto* y = context.Input<LoDTensor>("Y");
PADDLE_ENFORCE(!y->lod().empty(), "y should have lod"); PADDLE_ENFORCE(!y->lod().empty(), "y should have lod");
...@@ -86,19 +90,14 @@ class SequenceExpandKernel : public framework::OpKernel<T> { ...@@ -86,19 +90,14 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
* Grad(X).lod = Input(X).lod * Grad(X).lod = Input(X).lod
* *
* */ * */
template <typename DeviceContext, typename T> template <typename T>
class SequenceExpandGradKernel : public framework::OpKernel<T> { struct SequenceExpandGradFunctor<platform::CPUDeviceContext, T> {
public: void operator()(const platform::CPUDeviceContext& ctx, const LoDTensor& x,
void Compute(const framework::ExecutionContext& context) const override { const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) {
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out")); auto out_last_level = out.lod().back();
auto* x = context.Input<LoDTensor>("X"); const T* d_out_data = d_out.data<T>();
auto* out = context.Input<LoDTensor>("Out");
auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
auto out_last_level = out->lod().back();
d_x->set_lod(x->lod());
const T* d_out_data = d_out->data<T>();
T* d_x_data = d_x->mutable_data<T>(context.GetPlace()); T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
size_t element_len = d_out->numel() / d_out->dims()[0]; size_t element_len = d_out.numel() / d_out.dims()[0];
for (size_t i = 0; i < out_last_level.size() - 1; ++i) { for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
size_t repeat = out_last_level[i + 1] - out_last_level[i]; size_t repeat = out_last_level[i + 1] - out_last_level[i];
Eigen::TensorMap< Eigen::TensorMap<
...@@ -106,14 +105,27 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> { ...@@ -106,14 +105,27 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> {
d_out_t(d_out_data, static_cast<int>(repeat), element_len); d_out_t(d_out_data, static_cast<int>(repeat), element_len);
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>> Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>
d_x_t(d_x_data, static_cast<int>(element_len)); d_x_t(d_x_data, static_cast<int>(element_len));
auto place = d_x_t.device(*context.eigen_device()) =
context.template device_context<DeviceContext>().eigen_device(); d_out_t.sum(Eigen::array<int, 1>({{0}}));
d_x_t.device(*place) = d_out_t.sum(Eigen::array<int, 1>({{0}}));
d_out_data += (repeat * element_len); d_out_data += (repeat * element_len);
d_x_data += element_len; d_x_data += element_len;
} }
} }
}; };
template <typename DeviceContext, typename T>
class SequenceExpandGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = context.Input<LoDTensor>("X");
auto* out = context.Input<LoDTensor>("Out");
auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
d_x->set_lod(x->lod());
SequenceExpandGradFunctor(context.template device_context(), *x, *out,
d_out, d_x);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册