Unverified commit 1945b729 authored by Abhinav Arora, committed by GitHub

Fix CPPLint issues with math/sequence_padding (#10317)

* Fix cpplint issues in sequence_padding

* Fix typo in cu file

* Fix dependencies of sequence_padding

* Add include
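
For background, the lint rule behind these fixes is cpplint's `runtime/references` check, which enforces the Google C++ Style Guide convention that output parameters be passed by pointer rather than by non-const reference, so the mutation is visible at the call site. A minimal toy sketch of the rule (illustrative names, not Paddle code):

```cpp
#include <iostream>
#include <vector>

// cpplint flags this signature ("Is this a non-const reference? If so, make
// const or use a pointer"):
//   void AppendZero(std::vector<int>& v) { v.push_back(0); }

// Style-conforming form: take a pointer, so the caller must write &v.
void AppendZero(std::vector<int>* v) { v->push_back(0); }

int main() {
  std::vector<int> v{1, 2};
  AppendZero(&v);  // the '&' marks v as an output argument
  std::cout << v.size() << '\n';  // prints 3
  return 0;
}
```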
Parent 9bcd9f66
@@ -22,7 +22,7 @@ template <typename T>
class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
-const framework::LoDTensor& seq, framework::Tensor& padding,
+const framework::LoDTensor& seq, framework::Tensor* padding,
bool norm_by_times) {
auto lod = seq.lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
@@ -37,7 +37,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
-auto padding_dims = padding.dims();
+auto padding_dims = padding->dims();
PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width].");
@@ -58,7 +58,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
"width of sequence in LoDTensor seq.");
const T* seq_data = seq.data<T>();
-T* padding_data = padding.data<T>();
+T* padding_data = padding->data<T>();
for (int64_t i = 0; i < max_sequence_length; ++i) {
for (int64_t j = 0; j < num_sequences; ++j) {
int64_t start_pos = abs_offset_lod[level][j];
@@ -84,16 +84,16 @@ template <typename T>
class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
-framework::LoDTensor& seq, const framework::Tensor& padding,
+framework::LoDTensor* seq, const framework::Tensor& padding,
bool norm_by_times) {
-auto lod = seq.lod();
+auto lod = seq->lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
"The LoD of LoDTensor seq should not be null.");
const size_t level = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
-auto seq_dims = seq.dims();
+auto seq_dims = seq->dims();
PADDLE_ENFORCE_EQ(seq_dims[0],
static_cast<int64_t>(abs_offset_lod[level].back()),
"The first dimension of LoDTensor seq should be "
@@ -114,13 +114,13 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq.");
-const int64_t sequence_width = seq.numel() / seq_dims[0];
+const int64_t sequence_width = seq->numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
const T* padding_data = padding.data<T>();
-T* seq_data = seq.data<T>();
+T* seq_data = seq->data<T>();
for (int64_t i = 0; i < num_sequences; ++i) {
int64_t start_pos = abs_offset_lod[level][i];
int64_t sequence_length = abs_offset_lod[level][i + 1] - start_pos;
......
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include <algorithm>
#include "paddle/fluid/operators/math/sequence_padding.h"
namespace paddle {
@@ -61,7 +62,7 @@ template <typename T>
class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
-const framework::LoDTensor& seq, framework::Tensor& padding,
+const framework::LoDTensor& seq, framework::Tensor* padding,
bool norm_by_times) {
auto lod = seq.lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
@@ -76,7 +77,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
-auto padding_dims = padding.dims();
+auto padding_dims = padding->dims();
PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width].");
@@ -97,8 +98,8 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"width of sequence in LoDTensor seq.");
if (!norm_by_times && num_sequences == 1UL) {
-TensorCopy(seq, context.GetPlace(), context, &padding);
-padding.Resize(padding_dims);
+TensorCopy(seq, context.GetPlace(), context, padding);
+padding->Resize(padding_dims);
return;
}
@@ -117,7 +118,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
dim3 grid(grid_dim_x, grid_dim_y);
const T* seq_data = seq.data<T>();
-T* padding_data = padding.data<T>();
+T* padding_data = padding->data<T>();
if (norm_by_times) {
SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>(
padding_data, const_cast<T*>(seq_data),
@@ -136,16 +137,16 @@ template <typename T>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
-framework::LoDTensor& seq, const framework::Tensor& padding,
+framework::LoDTensor* seq, const framework::Tensor& padding,
bool norm_by_times) {
-auto lod = seq.lod();
+auto lod = seq->lod();
PADDLE_ENFORCE_GT(lod.size(), 0UL,
"The lod of LoDTensor seq should not be null.");
const size_t level = 0;
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
-auto seq_dims = seq.dims();
+auto seq_dims = seq->dims();
PADDLE_ENFORCE_EQ(seq_dims[0],
static_cast<int64_t>(abs_offset_lod[level].back()),
"The first dimension of LoDTensor seq should be "
@@ -166,14 +167,14 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq.");
-const int64_t sequence_width = seq.numel() / seq_dims[0];
+const int64_t sequence_width = seq->numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
if (!norm_by_times && num_sequences == 1UL) {
-TensorCopy(padding, context.GetPlace(), context, &seq);
-seq.Resize(seq_dims);
+TensorCopy(padding, context.GetPlace(), context, seq);
+seq->Resize(seq_dims);
return;
}
@@ -192,7 +193,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
dim3 grid(grid_dim_x, grid_dim_y);
const T* padding_data = padding.data<T>();
-T* seq_data = seq.data<T>();
+T* seq_data = seq->data<T>();
if (norm_by_times) {
SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>(
const_cast<T*>(padding_data), seq_data,
......
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include <algorithm>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/device_context.h"
@@ -64,13 +65,13 @@ template <typename DeviceContext, typename T>
class PaddingLoDTensorFunctor {
public:
void operator()(const DeviceContext& context, const framework::LoDTensor& seq,
-framework::Tensor& padding, bool norm_by_times);
+framework::Tensor* padding, bool norm_by_times);
};
template <typename DeviceContext, typename T>
class UnpaddingLoDTensorFunctor {
public:
-void operator()(const DeviceContext& context, framework::LoDTensor& seq,
+void operator()(const DeviceContext& context, framework::LoDTensor* seq,
const framework::Tensor& padding, bool norm_by_times);
};
......
@@ -54,12 +54,12 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
static_cast<int64_t>(sequence_width)});
padding.mutable_data<T>(padding_dims, *place);
paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
-*context, seq, padding, false);
+*context, seq, &padding, false);
seq_back.set_lod(lod);
seq_back.mutable_data<T>(seq_dims, *place);
paddle::operators::math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-*context, seq_back, padding, false);
+*context, &seq_back, padding, false);
if (paddle::platform::is_cpu_place(*place)) {
cpu_seq_back = seq_back;
......
@@ -162,7 +162,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
static_cast<int64_t>(sequence_width)});
warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
-ctx.template device_context<DeviceContext>(), *logits, warpctc_logits,
+ctx.template device_context<DeviceContext>(), *logits, &warpctc_logits,
false);
const T* warpctc_logits_data = warpctc_logits.data<T>();
@@ -217,7 +217,7 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
logits_grad->mutable_data<T>(ctx.GetPlace());
bool norm_by_times = ctx.Attr<bool>("norm_by_times");
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-ctx.template device_context<DeviceContext>(), *logits_grad,
+ctx.template device_context<DeviceContext>(), logits_grad,
*warpctc_grad, norm_by_times);
const T* loss_grad_data = loss_grad->data<T>();
......
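
The diff above converts every mutated `Tensor`/`LoDTensor` parameter of the two functors to a pointer and updates the call sites accordingly. The net effect on callers, sketched with toy stand-in types (a hypothetical `Tensor` and `PaddingFunctor`, not Paddle's), is that the output argument now needs an explicit `&`:

```cpp
#include <cstdio>

struct Tensor {
  int rows = 0;
};

template <typename T>
struct PaddingFunctor {
  // New-style signature: the output tensor is taken by pointer.
  void operator()(const Tensor& seq, Tensor* padding, bool norm_by_times) {
    (void)norm_by_times;       // unused in this toy version
    padding->rows = seq.rows;  // write through the pointer
  }
};

int main() {
  Tensor seq{4};
  Tensor padding;
  // The call site passes &padding, making the mutation explicit.
  PaddingFunctor<float>()(seq, &padding, /*norm_by_times=*/false);
  std::printf("padding.rows = %d\n", padding.rows);  // prints 4
  return 0;
}
```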