未验证 提交 7e214b49 编写于 作者: Y Yu Yang 提交者: GitHub

Speed up ColwiseSum in CPU (#6834)

* Remove unnecessary reshape in ColwiseSum

Speed up 12s -> 10s.

* Hand write ColwiseAdd in CPU
上级 1a3d4b0d
...@@ -67,18 +67,45 @@ void RowwiseAdd<DeviceContext, T>::operator()(const DeviceContext& context, ...@@ -67,18 +67,45 @@ void RowwiseAdd<DeviceContext, T>::operator()(const DeviceContext& context,
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void ColwiseSum<DeviceContext, T>::operator()(const DeviceContext& context, void ColwiseSum<DeviceContext, T>::operator()(const DeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
framework::Tensor* vector) { framework::Tensor* out) {
auto in_dims = input.dims(); auto in_dims = input.dims();
auto size = input.numel() / in_dims[0]; auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector->numel(), size); PADDLE_ENFORCE_EQ(out->numel(), size);
auto vec = framework::EigenMatrix<T>::From(*vector);
auto in = framework::EigenMatrix<T>::From(input); auto in = framework::EigenMatrix<T>::From(input);
Eigen::array<int, 2> shape({{1, static_cast<int>(size)}}); auto vec = framework::EigenVector<T>::Flatten(*out);
vec.reshape(shape).device(*context.eigen_device()) =
in.sum(Eigen::array<int, 1>({{0}})).reshape(shape); vec.device(*context.eigen_device()) = in.sum(Eigen::array<int, 1>({{0}}));
} }
// Specialize for CPU, since Eigen implement a general reduce. However,
// colwise-sum can be easily implemented. General reduce has a huge overhead in
// CPU
template <typename T>
class ColwiseSum<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, framework::Tensor* out) {
auto& in_dims = input.dims();
auto height = in_dims[0];
auto size = in_dims[1];
PADDLE_ENFORCE_EQ(out->numel(), size);
T* out_buf = out->mutable_data<T>(out->place());
const T* in_buf = input.data<T>();
for (size_t i = 0; i < height; ++i) {
for (size_t j = 0; j < size; ++j) {
if (i == 0) {
out_buf[j] = in_buf[i * size + j];
} else {
out_buf[j] += in_buf[i * size + j];
}
}
}
}
};
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册