提交 9dde5640 编写于 作者: L Liufang Sang 提交者: whs

change var name padding_num to padding_value (#19498)

上级 5b5379b3
......@@ -57,7 +57,7 @@ class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
"merge repeated elements between two blanks. ")
.SetDefault(true);
// add attr padding number for tensor input
AddAttr<int>("padding_num",
AddAttr<int>("padding_value",
"(int, default: 0), padding number "
"use to padding tensor. ")
.SetDefault(0);
......@@ -89,7 +89,7 @@ or Given:
And:
blank = 0
merge_repeated = True
padding_num = 0
padding_value = 0
Then:
Output.data = [[1, 2, 4, 0, 0, 0],
......
......@@ -46,7 +46,7 @@ template <typename T>
__global__ void PaddingMergeAndDelCudaKernel(const int64_t num_token,
const T* tokens, const int blank,
const int merge_repeated,
const int padding_num,
const int padding_value,
const int64_t batch_size,
T* output) {
int ind = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -62,7 +62,7 @@ __global__ void PaddingMergeAndDelCudaKernel(const int64_t num_token,
prev_token = tokens[i];
}
for (int i = output_idx; i < ind * num_token + num_token; i++) {
output[i] = padding_num;
output[i] = padding_value;
}
}
......@@ -82,13 +82,13 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
// tensor input which has no lod
if (input->lod().empty()) {
const int padding_num = ctx.Attr<int>("padding_num");
const int padding_value = ctx.Attr<int>("padding_value");
auto input_dims = input->dims();
T* output_data = output->mutable_data<T>({input_dims[0], input_dims[1]},
ctx.GetPlace());
PaddingMergeAndDelCudaKernel<
T><<<32, (input_dims[0] + 32 - 1) / 32, 0, stream>>>(
input_dims[1], tokens, blank, merge_repeated, padding_num,
input_dims[1], tokens, blank, merge_repeated, padding_value,
input_dims[0], output_data);
} else {
const size_t level = 0;
......
......@@ -39,7 +39,8 @@ class CTCAlignKernel : public framework::OpKernel<T> {
// support tensor input, no lod information
if (input->lod().empty()) {
size_t padding_num = static_cast<size_t>(ctx.Attr<int>("padding_num"));
size_t padding_value =
static_cast<size_t>(ctx.Attr<int>("padding_value"));
for (size_t batch_id = 0; batch_id < (unsigned)input_dims[0];
batch_id++) {
T prev_token = -1;
......@@ -55,7 +56,7 @@ class CTCAlignKernel : public framework::OpKernel<T> {
prev_token = input_data[input_ind];
}
for (size_t j = output_idx; j < (unsigned)input_dims[1]; j++)
output_data[batch_id * input_dims[1] + j] = padding_num;
output_data[batch_id * input_dims[1] + j] = padding_value;
}
} else {
const size_t level = 0;
......
......@@ -109,7 +109,7 @@ class TestCTCAlignPaddingOp(OpTest):
self.op_type = "ctc_align"
self.input_lod = []
self.blank = 0
self.padding_num = 0
self.padding_value = 0
self.merge_repeated = True
self.input = np.array([[0, 2, 4, 4, 0, 6, 3, 6, 6, 0, 0],
[1, 1, 3, 0, 0, 4, 5, 6, 0, 0, 0]]).reshape(
......@@ -118,13 +118,13 @@ class TestCTCAlignPaddingOp(OpTest):
def setUp(self):
self.config()
output = CTCAlign(self.input, self.input_lod, self.blank,
self.merge_repeated, self.padding_num)
self.merge_repeated, self.padding_value)
self.inputs = {"Input": (self.input, self.input_lod), }
self.outputs = {"Output": output}
self.attrs = {
"blank": self.blank,
"merge_repeated": self.merge_repeated,
"padding_num": self.padding_num
"padding_value": self.padding_value
}
def test_check_output(self):
......@@ -138,7 +138,7 @@ class TestCTCAlignOpCase3(TestCTCAlignPaddingOp):
self.blank = 0
self.input_lod = []
self.merge_repeated = True
self.padding_num = 0
self.padding_value = 0
self.input = np.array([[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0],
[0, 7, 7, 7, 0, 0]]).reshape(
[3, 6]).astype("int32")
......@@ -146,7 +146,7 @@ class TestCTCAlignOpCase3(TestCTCAlignPaddingOp):
class TestCTCAlignOpCase4(TestCTCAlignPaddingOp):
'''
# test tensor input which has attr input padding_num
# test tensor input which has attr input padding_value
'''
def config(self):
......@@ -154,7 +154,7 @@ class TestCTCAlignOpCase4(TestCTCAlignPaddingOp):
self.blank = 0
self.input_lod = []
self.merge_repeated = False
self.padding_num = 0
self.padding_value = 0
self.input = np.array([[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0],
[0, 7, 7, 7, 0, 0]]).reshape(
[3, 6]).astype("int32")
......@@ -166,7 +166,7 @@ class TestCTCAlignOpCase5(TestCTCAlignPaddingOp):
self.blank = 0
self.input_lod = []
self.merge_repeated = False
self.padding_num = 1
self.padding_value = 1
self.input = np.array([[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0],
[0, 7, 1, 7, 0, 0]]).reshape(
[3, 6]).astype("int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册