Unverified · commit c1bf06f9 · authored by fengjiayi · committed by GitHub

Merge pull request #9289 from dzhwinter/speed/sequence_expand

Speed/sequence expand
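This PR moves the expand logic of the sequence_expand operator into device-specific SequenceExpandFunctor / SequenceExpandGradFunctor specializations: the CPU path keeps the Eigen broadcast, while the GPU path gets hand-written CUDA kernels instead of running Eigen on the device. The Python test's reference implementation is also fixed to tile whole sequences rather than repeat individual rows.

For orientation, sequence_expand tiles each sequence of X according to the segment lengths of the reference LoD level of Y. A minimal NumPy sketch of the forward semantics (illustrative only; the helper name and example values are not from this commit):

import numpy as np

def sequence_expand_ref(x, x_lod, ref_lod):
    # tile sequence i (rows x_lod[i-1]:x_lod[i]) ref_lod[i] - ref_lod[i-1] times
    out = []
    for i in range(1, len(ref_lod)):
        repeat_num = ref_lod[i] - ref_lod[i - 1]
        x_sub = x[x_lod[i - 1]:x_lod[i], :]
        for _ in range(repeat_num):
            out.append(x_sub)  # whole-sequence tiling, not row-wise repetition
    return np.vstack(out) if out else np.empty((0, x.shape[1]), x.dtype)

x = np.arange(6, dtype='float32').reshape(3, 2)
# x_lod [0, 1, 3]: seq0 = row 0, seq1 = rows 1..2; ref_lod [0, 2, 3]: tile seq0 twice
print(sequence_expand_ref(x, [0, 1, 3], [0, 2, 3]))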
paddle/fluid/operators/sequence_expand_op.cc
@@ -84,12 +84,11 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
        }
      }
      out_dims[0] = out_first_dim;
-      ctx->SetOutputDim("Out", out_dims);
    } else {
      out_dims[0] = -1;
-      ctx->SetOutputDim("Out", out_dims);
-      ctx->ShareLoD("X", /*->*/ "Out");
    }
+    ctx->SetOutputDim("Out", out_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
  }
};
paddle/fluid/operators/sequence_expand_op.cu
@@ -12,8 +12,135 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#define EIGEN_USE_GPU
+#include <algorithm>
#include "paddle/fluid/operators/sequence_expand_op.h"
+#include "paddle/fluid/platform/cuda_helper.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
template <typename T>
__global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
const size_t* ref_lod,
const size_t* offset,
const size_t lod_size,
/* default=1,
the instance length*/
const int x_item_length, T* out_data) {
int bid = blockIdx.x;
if (bid >= lod_size - 1) return;
int x_item_count = x_lod[bid + 1] - x_lod[bid];
int repeats = ref_lod[bid + 1] - ref_lod[bid];
int out_offset = static_cast<int>(offset[bid]);
int x_offset = x_lod[bid];
for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
for (int tid_x = threadIdx.x; tid_x < x_item_length;
tid_x += blockDim.x) {
out_data[(out_offset + tid_z * x_item_count + tid_y) * x_item_length +
tid_x] = x_data[(x_offset + tid_y) * x_item_length + tid_x];
}
}
}
}
template <typename T>
__global__ void sequence_expand_grad_kernel(
const T* dout_data, const size_t* ref_lod, const size_t* dx_lod,
const size_t* offset, const size_t lod_size,
/* default=1,
the instance length*/
const int x_item_length, T* dx_data) {
int bid = blockIdx.x;
if (bid >= lod_size - 1) return;
int x_item_count = dx_lod[bid + 1] - dx_lod[bid];
int repeats = ref_lod[bid + 1] - ref_lod[bid];
int out_offset = static_cast<int>(offset[bid]);
int x_offset = dx_lod[bid];
for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
for (int tid_x = threadIdx.x; tid_x < x_item_length;
tid_x += blockDim.x) {
platform::CudaAtomicAdd(
&dx_data[(x_offset + tid_y) * x_item_length + tid_x],
dout_data[(out_offset + tid_z * x_item_count + tid_y) *
x_item_length +
tid_x]);
}
}
}
}
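A note on the two kernels above: each CUDA block handles one source sequence (bid), and the thread dimensions stride over repeats (threadIdx.z), rows of the sequence (threadIdx.y), and elements within a row (threadIdx.x). The sketch below replays the same index arithmetic on the host (illustrative NumPy, not code from this commit). It also shows why the backward kernel needs platform::CudaAtomicAdd: every source element receives repeat_num gradient contributions, one per tid_z.

def expand_copy_plan(x_lod, ref_lod, offset, x_item_length):
    # host-side replica of the kernels' (bid, tid_z, tid_y, tid_x) index arithmetic
    plan = []  # (out_flat_index, x_flat_index) copy pairs
    for bid in range(len(x_lod) - 1):            # one CUDA block per sequence
        x_item_count = x_lod[bid + 1] - x_lod[bid]
        repeats = ref_lod[bid + 1] - ref_lod[bid]
        for z in range(repeats):                 # threadIdx.z loop
            for y in range(x_item_count):        # threadIdx.y loop
                for t in range(x_item_length):   # threadIdx.x loop
                    out_i = (offset[bid] + z * x_item_count + y) * x_item_length + t
                    x_i = (x_lod[bid] + y) * x_item_length + t
                    plan.append((out_i, x_i))
    return plan

plan = expand_copy_plan([0, 1, 3], [0, 2, 3], [0, 2, 4], 2)
print(sorted(o for o, _ in plan) == list(range(8)))  # each output element written once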
void GetOutputOffset(const framework::Vector<size_t>& x_lod,
const framework::Vector<size_t>& ref_lod,
framework::Vector<size_t>* out_offset) {
size_t offset = 0;
int lod_size = static_cast<int>(x_lod.size());
for (int i = 0; i < static_cast<int>(x_lod.size()); ++i) {
(*out_offset)[i] = offset;
if (i < lod_size - 1) {
offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
}
}
}
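GetOutputOffset computes where each source sequence's expanded copies begin in the flattened output: a running sum of repeats_i * seq_len_i. A Python sketch with made-up example values:

def get_output_offset(x_lod, ref_lod):
    # prefix sum of (repeats of seq i) * (length of seq i), as in GetOutputOffset
    out, offset = [], 0
    for i in range(len(x_lod)):
        out.append(offset)
        if i < len(x_lod) - 1:
            offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i])
    return out

# seq0 (1 row) tiled twice, seq1 (2 rows) tiled once:
print(get_output_offset([0, 1, 3], [0, 2, 3]))  # [0, 2, 4]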
template <typename T>
struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
void operator()(
const platform::CUDADeviceContext& context, const LoDTensor& x,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* out) {
int x_item_length = x.numel() / x.dims()[0];
framework::Vector<size_t> out_offset(x_lod.size());
GetOutputOffset(x_lod, ref_lod, &out_offset);
int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
int thread_y = 16;
int thread_z = 1024 / thread_x / thread_y;
int block_x = static_cast<int>(ref_lod.size());
dim3 block_size(thread_x, thread_y, thread_z);
dim3 grid_size(block_x, 1);
sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>(
x.data<T>(), x_lod.CUDAData(context.GetPlace()),
ref_lod.CUDAData(context.GetPlace()),
out_offset.CUDAData(context.GetPlace()), x_lod.size(), x_item_length,
out->mutable_data<T>(context.GetPlace()));
}
};
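On the launch configuration above: thread_x is clamped to the range [16, 32], thread_z is chosen so the block stays within 1024 threads, and the grid launches one block per LoD offset; since the kernel returns early for bid >= lod_size - 1, the last block is a no-op. A quick sanity check (illustrative Python, assumed ref_lod size):

ref_lod_size = 4                             # 3 sequences -> 4 offsets
thread_x = min(32, max(ref_lod_size, 16))    # clamped to [16, 32]
thread_y = 16
thread_z = 1024 // thread_x // thread_y      # thread_x * thread_y * thread_z <= 1024
block_x = ref_lod_size                       # one block per offset
print((thread_x, thread_y, thread_z), block_x)  # (16, 16, 4) 4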
template <typename T>
struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const LoDTensor& dout,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand based lod*/
LoDTensor* dx) {
int x_item_length = framework::product(dx->dims()) / dx->dims()[0];
framework::Vector<size_t> out_offset(x_lod.size());
GetOutputOffset(x_lod, ref_lod, &out_offset);
int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
int thread_y = 16;
int thread_z = 1024 / thread_x / thread_y;
int block_x = static_cast<int>(ref_lod.size());
dim3 block_size(thread_x, thread_y, thread_z);
dim3 grid_size(block_x, 1);
sequence_expand_grad_kernel<<<grid_size, block_size, 0, context.stream()>>>(
dout.data<T>(), ref_lod.CUDAData(context.GetPlace()),
x_lod.CUDAData(context.GetPlace()),
out_offset.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
dx->mutable_data<T>(context.GetPlace()));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
paddle/fluid/operators/sequence_expand_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

+#include <numeric> // std::iota
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
@@ -26,6 +27,57 @@ template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
struct SequenceExpandFunctor {
void operator()(
const DeviceContext& ctx, const LoDTensor& x,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* out);
};
template <typename DeviceContext, typename T>
struct SequenceExpandGradFunctor {
void operator()(
const DeviceContext& ctx, const LoDTensor& dout,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* dx);
};
template <typename T>
struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
void operator()(
const platform::CPUDeviceContext& context, const LoDTensor& x,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* out) {
int out_offset = 0;
auto& eigen_place = *context.eigen_device();
for (size_t i = 1; i < ref_lod.size(); ++i) {
int repeat_num = ref_lod[i] - ref_lod[i - 1];
int x_start = x_lod[i - 1];
int x_end = x_lod[i];
int x_seq_len = x_end - x_start;
if (repeat_num > 0) {
auto x_sub_tensor = x.Slice(x_start, x_end);
x_sub_tensor.Resize({1, x_sub_tensor.numel()});
int out_start = out_offset;
if (out->lod().size() == 1) {
out_start = out->lod()[0][out_offset];
}
auto out_sub_tensor =
out->Slice(out_start, out_start + x_seq_len * repeat_num);
out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]});
EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) =
EigenMatrix<T>::From(x_sub_tensor)
.broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
}
out_offset += repeat_num;
}
}
};
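The CPU specialization above flattens each source sequence to a single row and uses an Eigen row-broadcast to write repeat_num copies at once. In NumPy terms this is a whole-sequence tile (sketch, not from the commit):

import numpy as np

x_sub = np.arange(6, dtype='float32').reshape(3, 2)  # one source sequence, 3 rows
repeat_num = 2
flat = x_sub.reshape(1, -1)                  # x_sub_tensor.Resize({1, numel})
tiled = np.tile(flat, (repeat_num, 1))       # Eigen .broadcast({repeat_num, 1})
out_sub = tiled.reshape(repeat_num * 3, 2)   # rows: seq, seq (whole-sequence copies)
print(out_sub)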
template <typename DeviceContext, typename T>
class SequenceExpandKernel : public framework::OpKernel<T> {
 public:
@@ -47,45 +99,36 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
      return;
    }
-    auto& out_lod = *out->mutable_lod();
-    if (x_lod.size() == 1) {
-      out_lod.resize(1);
-      out_lod[0] = {0};
-    }
-
-    int out_offset = 0;
-    auto& eigen_place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
-      int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
-      int x_start = i - 1;
-      int x_end = i;
-      if (x_lod.size() == 1) {
-        x_start = x_lod[0][i - 1];
-        x_end = x_lod[0][i];
-      }
-      int x_seq_len = x_end - x_start;
-      if (repeat_num > 0) {
-        auto x_sub_tensor = x->Slice(x_start, x_end);
-        x_sub_tensor.Resize({1, x_sub_tensor.numel()});
-        int out_start = out_offset;
-        if (x_lod.size() == 1) {
-          out_start = out_lod[0][out_offset];
-        }
-        auto out_sub_tensor =
-            out->Slice(out_start, out_start + x_seq_len * repeat_num);
-        out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]});
-        EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) =
-            EigenMatrix<T>::From(x_sub_tensor)
-                .broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
-      }
-      for (int j = 0; j < repeat_num; ++j) {
-        if (x_lod.size() == 1) {
-          out_lod[0].push_back(out_lod[0].back() + x_seq_len);
-        }
-        out_offset++;
-      }
-    }
+    // x lod level is at most 1.
+    framework::Vector<size_t> out_lod;
+    if (x_lod.size() == 1) {
+      out_lod.push_back(0);
+      int out_offset = 0;
+      for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
+        int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
+        int x_start = x_lod[0][i - 1];
+        int x_end = x_lod[0][i];
+        int x_seq_len = x_end - x_start;
+        for (int j = 0; j < repeat_num; ++j) {
+          out_lod.push_back(out_lod.back() + x_seq_len);
+          out_offset++;
+        }
+      }
+      // write lod to out if x has lod
+      auto& ref_lod = *out->mutable_lod();
+      ref_lod[0] = out_lod;
+    }
+    framework::Vector<size_t> ref_x_lod;
+    if (x->lod().size() == 1) {
+      ref_x_lod = x->lod()[0];
+    } else {
+      // x doesn't have lod, use a fake x lod, level = 0
+      ref_x_lod.resize(x->dims()[0] + 1);
+      std::iota(ref_x_lod.begin(), ref_x_lod.end(), 0);
+    }
+    SequenceExpandFunctor<DeviceContext, T> functor;
+    functor(context.template device_context<DeviceContext>(), *x, ref_x_lod,
+            y_lod[ref_level], out);
  }
};
@@ -101,6 +144,36 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
 * Grad(X).lod = Input(X).lod
 *
 * */
template <typename T>
struct SequenceExpandGradFunctor<platform::CPUDeviceContext, T> {
void operator()(
const platform::CPUDeviceContext& context, const LoDTensor& dout,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* dx) {
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
set_zero(context, dx, static_cast<T>(0));
int dout_offset = 0;
for (size_t i = 1; i < ref_lod.size(); ++i) {
int repeat_num = ref_lod[i] - ref_lod[i - 1];
if (repeat_num > 0) {
int x_start = x_lod[i - 1];
int x_end = x_lod[i];
int x_seq_len = x_end - x_start;
auto dx_sub = dx->Slice(x_start, x_end);
dx_sub.Resize(flatten_to_1d(dx_sub.dims()));
int dout_end = dout_offset + repeat_num * x_seq_len;
auto dout_sub = dout.Slice(dout_offset, dout_end);
dout_sub.Resize({repeat_num, dx_sub.dims()[0]});
math::ColwiseSum<platform::CPUDeviceContext, T> col_sum;
col_sum(context, dout_sub, &dx_sub);
dout_offset += repeat_num * x_seq_len;
}
}
}
};
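The CPU gradient reverses the tiling by summation: the slice of dout produced from one source sequence is viewed as {repeat_num, seq_numel} and column-summed back into dx. Equivalent NumPy (illustrative values):

import numpy as np

repeat_num, x_seq_len, width = 2, 3, 2
dout_sub = np.arange(12, dtype='float32').reshape(repeat_num * x_seq_len, width)
# dout_sub.Resize({repeat_num, numel}) followed by math::ColwiseSum:
dx_sub = dout_sub.reshape(repeat_num, x_seq_len * width).sum(axis=0)
print(dx_sub.reshape(x_seq_len, width))  # each source row sums its repeat_num copies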
template <typename DeviceContext, typename T>
class SequenceExpandGradKernel : public framework::OpKernel<T> {
 public:
@@ -114,43 +187,26 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> {
    g_x->mutable_data<T>(context.GetPlace());
    g_x->set_lod(x->lod());
-    auto& x_lod = x->lod();
    auto& y_lod = y->lod();
    if (ref_level == -1) ref_level = y_lod.size() - 1;
    // just copy the gradient
    if (y_lod[ref_level].size() <= 1) {
      framework::TensorCopy(*g_out, context.GetPlace(), g_x);
      return;
    }
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    math::SetConstant<DeviceContext, T> set_zero;
-    set_zero(dev_ctx, g_x, static_cast<T>(0));
-    int g_out_offset = 0;
-    for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
-      int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
-      if (repeat_num > 0) {
-        int x_start = i - 1;
-        int x_end = i;
-        if (x_lod.size() == 1) {
-          x_start = x_lod[0][i - 1];
-          x_end = x_lod[0][i];
-        }
-        int x_seq_len = x_end - x_start;
-        auto g_x_sub = g_x->Slice(x_start, x_end);
-        g_x_sub.Resize(flatten_to_1d(g_x_sub.dims()));
-        int g_out_end = g_out_offset + repeat_num * x_seq_len;
-        auto g_out_sub = g_out->Slice(g_out_offset, g_out_end);
-        g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]});
-        math::ColwiseSum<DeviceContext, T> col_sum;
-        col_sum(dev_ctx, g_out_sub, &g_x_sub);
-        g_out_offset += repeat_num * x_seq_len;
-      }
-    }
+    framework::Vector<size_t> ref_x_lod;
+    framework::Vector<size_t> ref_lod = y_lod[ref_level];
+    if (x->lod().size() == 1) {
+      ref_x_lod = x->lod()[0];
+    } else {
+      // x doesn't have lod, use a fake x lod, level = 0
+      ref_x_lod.resize(x->dims()[0] + 1);
+      std::iota(ref_x_lod.begin(), ref_x_lod.end(), 0);
+    }
+    SequenceExpandGradFunctor<DeviceContext, T> functor;
+    functor(context.template device_context<DeviceContext>(), *g_out, ref_x_lod,
+            ref_lod, g_x);
  }
};
python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -34,6 +34,8 @@ function(py_test_modules TARGET_NAME)
  endif()
endfunction()

+list(REMOVE_ITEM TEST_OPS test_sequence_expand)
+
# test time-consuming OPs in a separate process to exploit parallelism
list(REMOVE_ITEM TEST_OPS test_parallel_executor)
list(REMOVE_ITEM TEST_OPS test_warpctc_op)
@@ -70,6 +72,8 @@ else()
  endforeach(TEST_OP)
endif(WITH_FAST_BUNDLE_TEST)

+#
+py_test_modules(test_sequence_expand MODULES test_sequence_expand)
# tests with high overhead
py_test_modules(test_parallel_executor MODULES test_parallel_executor)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR})
python/paddle/fluid/tests/unittests/test_sequence_expand.py
@@ -47,8 +47,10 @@ class TestSequenceExpand(OpTest):
                x_len = x_idx[i] - x_idx[i - 1]
                if repeat_num > 0:
                    x_sub = x_data[x_idx[i - 1]:x_idx[i], :]
-                    x_sub = np.repeat(x_sub, repeat_num, axis=0)
-                    out = np.vstack((out, x_sub))
+                    stacked_x_sub = x_sub
+                    for r in range(repeat_num - 1):
+                        stacked_x_sub = np.vstack((stacked_x_sub, x_sub))
+                    out = np.vstack((out, stacked_x_sub))
                if x_lod is not None:
                    for j in xrange(repeat_num):
                        out_lod[0].append(out_lod[0][-1] + x_len)
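The change above fixes the test's reference implementation: np.repeat(x_sub, n, axis=0) repeats each row in place (r0, r0, r1, r1), whereas the operator tiles the whole sequence (r0, r1, r0, r1). The new loop builds the tiled version, equivalent to np.tile. A quick comparison (illustrative):

import numpy as np

x_sub = np.array([[0., 1.], [2., 3.]])
print(np.repeat(x_sub, 2, axis=0))  # rows r0, r0, r1, r1: row-wise, the old (wrong) reference
print(np.tile(x_sub, (2, 1)))       # rows r0, r1, r0, r1: matches sequence_expand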
@@ -101,11 +103,11 @@ class TestSequenceExpandCase3(TestSequenceExpand):
class TestSequenceExpandCase4(TestSequenceExpand):
    def set_data(self):
-        data = [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]
+        data = np.random.uniform(0.1, 1, [5 * 2, 1])
        x_data = np.array(data).reshape([5, 2]).astype('float32')
        x_lod = [[0, 2, 5]]
-        y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32')
-        y_lod = [[0, 1, 2], [0, 1, 2]]
+        y_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
+        y_lod = [[0, 1, 3], [0, 1, 3]]
        self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}