Commit 4ee1c9e6 authored by dzhwinter

"add sequence expand kernel"

Parent b3f076a6
@@ -15,6 +15,58 @@ limitations under the License. */
#define EIGEN_USE_GPU
#include <algorithm>
#include "paddle/fluid/operators/sequence_expand_op.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename T>
__global__ void sequence_expand_kernel(const T* x_data, T* out_data,
                                       const size_t* lod, size_t lod_size,
                                       size_t element_len) {
  // Cache the cumulative output offsets in shared memory (lod_size is assumed
  // to be at most BLOCK_SIZE, as in the original sketch).
  const int BLOCK_SIZE = 1024;
  __shared__ size_t shm_lod[BLOCK_SIZE];
  for (size_t idx = threadIdx.y * blockDim.x + threadIdx.x; idx < lod_size;
       idx += blockDim.x * blockDim.y) {
    shm_lod[idx] = lod[idx];
  }
  __syncthreads();
  // y indexes input rows, x indexes columns; input row i is replicated
  // (lod[i + 1] - lod[i]) times in the output.
  for (size_t row = blockIdx.y * blockDim.y + threadIdx.y; row + 1 < lod_size;
       row += blockDim.y * gridDim.y) {
    size_t scale = shm_lod[row + 1] - shm_lod[row];
    for (size_t col = blockIdx.x * blockDim.x + threadIdx.x; col < element_len;
         col += blockDim.x * gridDim.x) {
      T value = x_data[row * element_len + col];
      for (size_t r = 0; r < scale; ++r) {
        out_data[(shm_lod[row] + r) * element_len + col] = value;
      }
    }
  }
}
template <typename T>
void SequenceExpandFunctor<platform::CUDADeviceContext, T>::operator()(
    const platform::CUDADeviceContext& context, const LoDTensor& x,
    LoDTensor* out) {
  auto x_dims = x.dims();
  size_t element_len = framework::product(x_dims) / x_dims[0];
  T* out_data = out->mutable_data<T>(context.GetPlace());
  auto out_starts = out->lod().back();
  // Work shape: one column of threads per feature, one row per input sequence.
  int out_cols = static_cast<int>(element_len);
  int out_rows = static_cast<int>(out_starts.size()) - 1;
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (out_cols < kThreadsPerBlock) {  // block_cols is aligned up to 32.
    block_cols = ((out_cols + 31) >> 5) << 5;
  }
  int block_rows = kThreadsPerBlock / block_cols;
  dim3 block_size = dim3(block_cols, block_rows, 1);
  int max_threads = context.GetMaxPhysicalThreadCount();
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
  int grid_cols =
      std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows =
      std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
  dim3 grid_size = dim3(grid_cols, grid_rows, 1);
  sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>(
      x.data<T>(), out_data, out_starts.CUDAData(context.GetPlace()),
      out_starts.size(), element_len);
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    sequence_expand,
...
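As a quick sanity check on the launch-configuration arithmetic above, the following host-only sketch replays it for a made-up case (element_len = 100, 4 input sequences, and an assumed GetMaxPhysicalThreadCount() of 32768); the numbers are illustrative only and are not taken from the commit.

#include <algorithm>
#include <cstdio>

// Replays the block/grid sizing used by the CUDA functor for example sizes.
int main() {
  const int kThreadsPerBlock = 1024;
  int out_cols = 100;           // element_len (features per row)
  int out_rows = 4;             // number of input sequences
  int max_threads = 32768;      // stand-in for GetMaxPhysicalThreadCount()

  int block_cols = kThreadsPerBlock;
  if (out_cols < kThreadsPerBlock) {  // align the column count up to 32
    block_cols = ((out_cols + 31) >> 5) << 5;
  }
  int block_rows = kThreadsPerBlock / block_cols;
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
  int grid_cols =
      std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows =
      std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));

  // Prints: block = 128 x 8, grid = 1 x 1
  std::printf("block = %d x %d, grid = %d x %d\n", block_cols, block_rows,
              grid_cols, grid_rows);
  return 0;
}

With these inputs the whole expansion fits into a single 128 x 8 block, and the grid-stride loops in the kernel cover any leftover rows or columns.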
@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
-#include "unsupported/Eigen/CXX11/Tensor"
+#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
@@ -24,22 +24,18 @@ namespace operators {
using LoDTensor = framework::LoDTensor;

template <typename DeviceContext, typename T>
-class SequenceExpandKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<LoDTensor>("X");
-    auto* out = context.Output<LoDTensor>("Out");
-    const T* x_data = x->data<T>();
-    auto x_dims = x->dims();
-    auto* y = context.Input<LoDTensor>("Y");
-    PADDLE_ENFORCE(!y->lod().empty(), "y should have lod");
-    PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims[0]),
-                      y->lod().back().size() - 1,
-                      "The size of last lod level in Input(Y)"
-                      "must be equal to dims[0] of Input(X).");
-    out->set_lod(y->lod());
-    auto* place =
-        context.template device_context<DeviceContext>().eigen_device();
+struct SequenceExpandFunctor {
+  void operator()(const DeviceContext& ctx, const LoDTensor& x, LoDTensor* out);
+};
+
+// template <typename DeviceContext, typename T>
+// struct SequenceExpandGradFunctor {};
+
+template <typename T>
+void SequenceExpandFunctor<platform::CPUDeviceContext, T>::operator()(
+    const platform::CPUDeviceContext& context, const LoDTensor& x,
+    LoDTensor* out) {
+  auto x_dims = x.dims();
  size_t element_len = framework::product(x_dims) / x_dims[0];
  T* out_data = out->mutable_data<T>(context.GetPlace());
  auto out_starts = out->lod().back();
@@ -52,10 +48,29 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
          out_t(out_data, scale, element_len);
      Eigen::array<int, 2> cast({{scale, 1}});
-      out_t.device(*place) = x_t.broadcast(cast);
+      out_t.device(*context.eigen_device()) = x_t.broadcast(cast);
      x_data += element_len;
      out_data += element_len * scale;
-    }
+  }
+}
+
+template <typename DeviceContext, typename T>
+class SequenceExpandKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<LoDTensor>("X");
+    auto* out = context.Output<LoDTensor>("Out");
+    const T* x_data = x->data<T>();
+    auto x_dims = x->dims();
+    auto* y = context.Input<LoDTensor>("Y");
+    PADDLE_ENFORCE(!y->lod().empty(), "y should have lod");
+    PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims[0]),
+                      y->lod().back().size() - 1,
+                      "The size of last lod level in Input(Y)"
+                      "must be equal to dims[0] of Input(X).");
+    out->set_lod(y->lod());
+    SequenceExpandFunctor<DeviceContext, T> functor;
+    functor(context.template device_context<DeviceContext>(), *x, out);
  }
};
...
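For reference, the expansion rule the functor implements can be checked with a small standalone sketch; x, lod, and element_len below are made-up example values, not data from the commit.

#include <cstdio>
#include <vector>

// Plain-C++ reference of the expansion: each input row i (element_len values)
// is copied (lod[i + 1] - lod[i]) times into the output, starting at output
// row lod[i]. The data here is illustrative only.
int main() {
  const size_t element_len = 3;
  std::vector<float> x = {1, 1, 1,   // row 0
                          2, 2, 2};  // row 1
  std::vector<size_t> lod = {0, 2, 5};  // row 0 -> 2 copies, row 1 -> 3 copies

  std::vector<float> out(lod.back() * element_len);
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    size_t scale = lod[i + 1] - lod[i];
    for (size_t r = 0; r < scale; ++r) {
      for (size_t c = 0; c < element_len; ++c) {
        out[(lod[i] + r) * element_len + c] = x[i * element_len + c];
      }
    }
  }
  // Prints: 1 1 1 / 1 1 1 / 2 2 2 / 2 2 2 / 2 2 2 (one row per line)
  for (size_t r = 0; r < lod.back(); ++r) {
    std::printf("%g %g %g\n", out[r * element_len], out[r * element_len + 1],
                out[r * element_len + 2]);
  }
  return 0;
}

Here row 0 of x is written twice and row 1 three times, giving five output rows, which is what both the Eigen broadcast in the CPU functor and the CUDA kernel above produce.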