提交 32d881be 编写于 作者: Q qingqing01

Optimize the rowwise add function.

上级 c3fd2c28
...@@ -302,8 +302,40 @@ void set_constant(const platform::DeviceContext& context, ...@@ -302,8 +302,40 @@ void set_constant(const platform::DeviceContext& context,
#endif #endif
} }
template <typename T>
struct RowwiseAdd<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& vector, framework::Tensor* output) {
auto in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector.numel(), size);
PADDLE_ENFORCE_EQ(output->dims(), in_dims);
// auto in = framework::EigenMatrix<T>::From(input);
// auto vec = framework::EigenVector<T>::Flatten(vector);
// auto out = framework::EigenMatrix<T>::From(*output);
// for (int64_t i = 0; i < in_dims[0]; ++i) {
// out.chip(i, 0) = in.chip(i, 0) + vec;
// }
auto* in = input.data<T>();
auto* vec = vector.data<T>();
auto* out = output->data<T>();
int64_t h = in_dims[0];
int64_t w = in_dims[1];
for (int64_t i = 0; i < h; ++i) {
for (int64_t j = 0; j < w; ++j) {
out[i * w + j] = in[i * w + j] + vec[j];
}
}
}
};
template struct RowwiseAdd<platform::CPUDeviceContext, float>; template struct RowwiseAdd<platform::CPUDeviceContext, float>;
template struct RowwiseAdd<platform::CPUDeviceContext, double>; template struct RowwiseAdd<platform::CPUDeviceContext, double>;
template struct ColwiseSum<platform::CPUDeviceContext, float>; template struct ColwiseSum<platform::CPUDeviceContext, float>;
template struct ColwiseSum<platform::CPUDeviceContext, double>; template struct ColwiseSum<platform::CPUDeviceContext, double>;
......
...@@ -273,6 +273,33 @@ void set_constant_with_place<platform::CUDAPlace>( ...@@ -273,6 +273,33 @@ void set_constant_with_place<platform::CUDAPlace>(
TensorSetConstantGPU(context, tensor, value)); TensorSetConstantGPU(context, tensor, value));
} }
template <typename T>
__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int64_t height,
int64_t width) {
int64_t num = height * width;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
int h = i / width;
int w = i % width;
int idx = h * width + w;
c[idx] = a[idx] + b[w];
}
}
template <typename T>
struct RowwiseAdd<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& vector, framework::Tensor* output) {
auto in_dims = input.dims();
int blocks = 512;
int grids = (input.numel() + blocks - 1) / blocks;
RowwiseAddKernel<T><<<grids, blocks, 0, context.stream()>>>(
input.data<T>(), vector.data<T>(), output->data<T>(), in_dims[0],
in_dims[1]);
}
};
template struct RowwiseAdd<platform::CUDADeviceContext, float>; template struct RowwiseAdd<platform::CUDADeviceContext, float>;
template struct RowwiseAdd<platform::CUDADeviceContext, double>; template struct RowwiseAdd<platform::CUDADeviceContext, double>;
template struct ColwiseSum<platform::CUDADeviceContext, float>; template struct ColwiseSum<platform::CUDADeviceContext, float>;
......
...@@ -45,25 +45,6 @@ void Transpose<DeviceContext, T, Rank>::operator()( ...@@ -45,25 +45,6 @@ void Transpose<DeviceContext, T, Rank>::operator()(
eigen_out.device(*dev) = eigen_in.shuffle(permute); eigen_out.device(*dev) = eigen_in.shuffle(permute);
} }
template <typename DeviceContext, typename T>
void RowwiseAdd<DeviceContext, T>::operator()(const DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& vector,
framework::Tensor* output) {
auto in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector.numel(), size);
PADDLE_ENFORCE_EQ(output->dims(), in_dims);
auto in = framework::EigenMatrix<T>::From(input);
auto vec = framework::EigenMatrix<T>::From(vector);
auto out = framework::EigenMatrix<T>::From(*output);
Eigen::array<int, 2> shape({{1, static_cast<int>(size)}});
Eigen::array<int, 2> bcast({{static_cast<int>(in_dims[0]), 1}});
out.device(*context.eigen_device()) =
in + vec.reshape(shape).broadcast(bcast);
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void ColwiseSum<DeviceContext, T>::operator()(const DeviceContext& context, void ColwiseSum<DeviceContext, T>::operator()(const DeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册