Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5ad3228c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5ad3228c
编写于
8月 03, 2022
作者:
Z
zhiboniu
提交者:
GitHub
8月 03, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Phi edit distance (#44447)
* phi_edit_distance * fix
上级
7eb37a7e
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
493 addition
and
395 deletion
+493
-395
paddle/fluid/operators/edit_distance_op.cc
paddle/fluid/operators/edit_distance_op.cc
+9
-70
paddle/fluid/operators/edit_distance_op.cu
paddle/fluid/operators/edit_distance_op.cu
+0
-196
paddle/fluid/operators/edit_distance_op.h
paddle/fluid/operators/edit_distance_op.h
+0
-126
paddle/phi/api/yaml/legacy_api.yaml
paddle/phi/api/yaml/legacy_api.yaml
+10
-0
paddle/phi/infermeta/multiary.cc
paddle/phi/infermeta/multiary.cc
+69
-0
paddle/phi/infermeta/multiary.h
paddle/phi/infermeta/multiary.h
+8
-0
paddle/phi/kernels/cpu/edit_distance_kernel.cc
paddle/phi/kernels/cpu/edit_distance_kernel.cc
+124
-0
paddle/phi/kernels/edit_distance_kernel.h
paddle/phi/kernels/edit_distance_kernel.h
+31
-0
paddle/phi/kernels/gpu/edit_distance_kernel.cu
paddle/phi/kernels/gpu/edit_distance_kernel.cu
+186
-0
paddle/phi/ops/compat/edit_distance_sig.cc
paddle/phi/ops/compat/edit_distance_sig.cc
+29
-0
python/paddle/fluid/tests/unittests/test_edit_distance_op.py
python/paddle/fluid/tests/unittests/test_edit_distance_op.py
+23
-3
python/paddle/nn/functional/loss.py
python/paddle/nn/functional/loss.py
+4
-0
未找到文件。
paddle/fluid/operators/edit_distance_op.cc
浏览文件 @
5ad3228c
...
...
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/edit_distance_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/multiary.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -21,72 +23,6 @@ class EditDistanceOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Hyps"
),
"Input"
,
"Hyps"
,
"EditDistance"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Refs"
),
"Input"
,
"Refs"
,
"EditDistance"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"EditDistance"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SequenceNum"
),
"Output"
,
"SequenceNum"
,
"EditDistance"
);
auto
hyp_dims
=
ctx
->
GetInputDim
(
"Hyps"
);
auto
ref_dims
=
ctx
->
GetInputDim
(
"Refs"
);
if
(
ctx
->
HasInput
(
"HypsLength"
)
&&
ctx
->
HasInput
(
"RefsLength"
))
{
auto
hyp_length_dims
=
ctx
->
GetInputDim
(
"HypsLength"
);
auto
ref_length_dims
=
ctx
->
GetInputDim
(
"RefsLength"
);
PADDLE_ENFORCE_EQ
(
hyp_dims
.
size
()
==
2
&&
ref_dims
.
size
()
==
2
&&
hyp_dims
[
0
]
==
ref_dims
[
0
],
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Hyps) and Input(Refs) must be 2-D Tensors with "
"identical first dimension. But received Input(Hyps): "
"input rank %u, input shape [%s]; received Input(Refs): "
"input rank %u, input shape [%s]"
,
hyp_dims
.
size
(),
hyp_dims
,
ref_dims
.
size
(),
ref_dims
));
PADDLE_ENFORCE_EQ
(
hyp_length_dims
[
0
]
==
ref_length_dims
[
0
]
&&
hyp_length_dims
[
0
]
==
hyp_dims
[
0
],
true
,
platform
::
errors
::
InvalidArgument
(
"Input(HypsLength), Input(RefsLength) and Input(Hyps) "
"should have identical first dimension. But received "
"Input(HypsLength): input rank %u, input shape [%s]; "
"received Input(RefsLength): input rank %u, input shape "
"[%s]; received Input(Hyps): input rank %u, input shape "
"[%s]."
,
hyp_length_dims
.
size
(),
hyp_length_dims
,
ref_length_dims
.
size
(),
ref_length_dims
,
hyp_dims
.
size
(),
hyp_dims
));
}
else
{
PADDLE_ENFORCE_EQ
(
hyp_dims
.
size
()
==
2
&&
hyp_dims
[
1
]
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1. But received: input rank %u, input shape [%s]."
,
hyp_dims
.
size
(),
hyp_dims
));
PADDLE_ENFORCE_EQ
(
ref_dims
.
size
()
==
2
&&
ref_dims
[
1
]
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1. But received: input rank %u, input shape [%s]."
,
ref_dims
.
size
(),
ref_dims
));
}
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"Refs"
));
ctx
->
SetOutputDim
(
"SequenceNum"
,
{
1
});
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
...
@@ -153,6 +89,10 @@ will be divided by the length of reference string.
}
// namespace operators
}
// namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR
(
edit_distance
,
EditDistanceShapeFunctor
,
PD_INFER_META
(
phi
::
EditDistanceInferMeta
));
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
...
...
@@ -160,6 +100,5 @@ REGISTER_OPERATOR(
ops
::
EditDistanceOp
,
ops
::
EditDistanceOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
edit_distance
,
ops
::
EditDistanceKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
EditDistanceShapeFunctor
);
paddle/fluid/operators/edit_distance_op.cu
已删除
100644 → 0
浏览文件 @
7eb37a7e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/edit_distance_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
operators
{
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
template
<
typename
T
>
__global__
void
FillFirstRow
(
T
*
dist
,
const
int
N
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
if
(
idx
<
N
+
1
)
{
dist
[
idx
]
=
idx
;
}
}
template
<
typename
T
>
__global__
void
FillFirstColumn
(
T
*
dist
,
const
int
M
,
const
int
N
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
if
(
idx
<
M
+
1
)
{
dist
[
idx
*
(
N
+
1
)]
=
idx
;
}
}
template
<
typename
T
>
__global__
void
Levenshtein
(
T
*
dist
,
const
int64_t
*
x1
,
const
int64_t
*
x2
,
const
int
M
,
const
int
N
,
const
int
start
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
offset
=
N
;
int
index
=
start
+
idx
*
offset
;
int
row
=
index
/
(
N
+
1
);
int
col
=
index
%
(
N
+
1
);
if
(
row
>
0
&&
col
>
0
&&
row
<
M
+
1
&&
col
<
N
+
1
)
{
int
cost
=
x1
[
row
-
1
]
==
x2
[
col
-
1
]
?
0
:
1
;
int
dels
=
dist
[(
row
-
1
)
*
(
N
+
1
)
+
col
]
+
1
;
int
ins
=
dist
[
row
*
(
N
+
1
)
+
col
-
1
]
+
1
;
int
subs
=
dist
[(
row
-
1
)
*
(
N
+
1
)
+
(
col
-
1
)]
+
cost
;
dist
[
index
]
=
min
(
dels
,
min
(
ins
,
subs
));
}
}
template
<
typename
T
>
__global__
void
SetOutput
(
T
*
out
,
const
T
*
dist
,
const
int
M
,
const
int
N
,
bool
normalized
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
if
(
idx
==
0
)
{
out
[
0
]
=
normalized
?
dist
[
M
*
(
N
+
1
)
+
N
]
/
N
:
dist
[
M
*
(
N
+
1
)
+
N
];
}
}
template
<
typename
Place
,
typename
T
>
class
EditDistanceGPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
out_t
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
x1_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Hyps"
);
auto
*
x2_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Refs"
);
auto
*
sequence_num
=
ctx
.
Output
<
framework
::
Tensor
>
(
"SequenceNum"
);
sequence_num
->
mutable_data
<
int64_t
>
(
ctx
.
GetPlace
());
auto
batch_size
=
x1_t
->
dims
()[
0
];
auto
normalized
=
ctx
.
Attr
<
bool
>
(
"normalized"
);
auto
stream
=
reinterpret_cast
<
const
phi
::
GPUContext
&>
(
ctx
.
device_context
()).
stream
();
framework
::
Vector
<
size_t
>
hyp_lod
(
batch_size
+
1
);
framework
::
Vector
<
size_t
>
ref_lod
(
batch_size
+
1
);
bool
use_length
=
ctx
.
HasInput
(
"HypsLength"
);
if
(
use_length
)
{
// build lod when using padding
auto
*
hyp_length
=
ctx
.
Input
<
framework
::
Tensor
>
(
"HypsLength"
);
auto
*
ref_length
=
ctx
.
Input
<
framework
::
Tensor
>
(
"RefsLength"
);
framework
::
Tensor
hyp_length_cpu
;
framework
::
Tensor
ref_length_cpu
;
framework
::
TensorCopy
(
*
hyp_length
,
platform
::
CPUPlace
(),
&
hyp_length_cpu
);
framework
::
TensorCopy
(
*
ref_length
,
platform
::
CPUPlace
(),
&
ref_length_cpu
);
for
(
auto
i
=
0
;
i
<
batch_size
;
i
++
)
{
hyp_lod
[
i
+
1
]
=
hyp_lod
[
i
]
+
hyp_length_cpu
.
data
<
int64_t
>
()[
i
];
ref_lod
[
i
+
1
]
=
ref_lod
[
i
]
+
ref_length_cpu
.
data
<
int64_t
>
()[
i
];
}
}
else
{
hyp_lod
=
x1_t
->
lod
()[
0
];
ref_lod
=
x2_t
->
lod
()[
0
];
}
if
(
normalized
)
{
for
(
size_t
i
=
1
;
i
<
ref_lod
.
size
();
++
i
)
{
PADDLE_ENFORCE_GT
(
ref_lod
[
i
],
ref_lod
[
i
-
1
],
platform
::
errors
::
InvalidArgument
(
"Reference string %d is empty."
,
i
));
}
}
const
size_t
num_strs
=
hyp_lod
.
size
()
-
1
;
phi
::
funcs
::
SetConstant
<
phi
::
GPUContext
,
int64_t
>
set_constant
;
set_constant
(
ctx
.
template
device_context
<
phi
::
GPUContext
>(),
sequence_num
,
static_cast
<
int64_t
>
(
num_strs
));
out_t
->
Resize
({
static_cast
<
int64_t
>
(
num_strs
),
1
});
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
out
=
out_t
->
data
<
T
>
();
T
distance
=
0.0
;
for
(
size_t
num
=
0
;
num
<
num_strs
;
num
++
)
{
auto
m
=
static_cast
<
int64_t
>
(
hyp_lod
[
num
+
1
]
-
hyp_lod
[
num
]);
auto
n
=
static_cast
<
int64_t
>
(
ref_lod
[
num
+
1
]
-
ref_lod
[
num
]);
if
(
m
==
0
||
n
==
0
)
{
distance
=
std
::
max
(
m
,
n
);
if
(
normalized
)
{
distance
=
distance
/
n
;
}
memory
::
Copy
(
ctx
.
GetPlace
(),
out
+
num
,
platform
::
CPUPlace
(),
&
distance
,
sizeof
(
T
),
stream
);
}
else
{
framework
::
Tensor
dist_t
;
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
dist_t
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dist
=
dist_t
.
data
<
T
>
();
auto
hyp_offset
=
use_length
?
num
*
x1_t
->
dims
()[
1
]
:
hyp_lod
[
num
];
auto
ref_offset
=
use_length
?
num
*
x2_t
->
dims
()[
1
]
:
ref_lod
[
num
];
auto
x1
=
x1_t
->
data
<
int64_t
>
()
+
hyp_offset
;
auto
x2
=
x2_t
->
data
<
int64_t
>
()
+
ref_offset
;
FillFirstColumn
<
T
><<<
1
+
m
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
m
,
n
);
FillFirstRow
<
T
><<<
1
+
n
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
n
);
// Compute the elements of distance matrix in the anti-diagonal diretion
for
(
int64_t
slice
=
2
;
slice
<
m
+
n
+
1
;
++
slice
)
{
int
z_m
=
slice
<
m
+
1
?
0
:
slice
-
m
;
int
z_n
=
slice
<
n
+
1
?
0
:
slice
-
n
;
int
size
=
slice
-
(
z_m
+
z_n
)
+
1
;
// number of elments in the same
// anti-diagonal line to update
// the start index at which computes from
int
start
=
slice
<
n
+
1
?
slice
:
(
z_n
+
1
)
*
(
n
+
1
)
-
1
;
Levenshtein
<
T
><<<
1
+
(
size
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
x1
,
x2
,
m
,
n
,
start
);
}
SetOutput
<
T
><<<
1
,
1
,
0
,
stream
>>>
(
out
+
num
,
dist
,
m
,
n
,
normalized
);
}
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
edit_distance
,
ops
::
EditDistanceGPUKernel
<
paddle
::
platform
::
CUDAPlace
,
float
>
);
paddle/fluid/operators/edit_distance_op.h
已删除
100644 → 0
浏览文件 @
7eb37a7e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
Place
,
typename
T
>
class
EditDistanceKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
out_t
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
x1_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Hyps"
);
auto
*
x2_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Refs"
);
auto
*
sequence_num
=
ctx
.
Output
<
framework
::
Tensor
>
(
"SequenceNum"
);
int64_t
*
seq_num_data
=
sequence_num
->
mutable_data
<
int64_t
>
(
ctx
.
GetPlace
());
auto
batch_size
=
x1_t
->
dims
()[
0
];
auto
normalized
=
ctx
.
Attr
<
bool
>
(
"normalized"
);
framework
::
Vector
<
size_t
>
hyp_lod
(
batch_size
+
1
);
framework
::
Vector
<
size_t
>
ref_lod
(
batch_size
+
1
);
bool
use_length
=
ctx
.
HasInput
(
"HypsLength"
);
if
(
use_length
)
{
// build lod when using padding
auto
hyp_length_ptr
=
ctx
.
Input
<
framework
::
Tensor
>
(
"HypsLength"
)
->
data
<
int64_t
>
();
auto
ref_length_ptr
=
ctx
.
Input
<
framework
::
Tensor
>
(
"RefsLength"
)
->
data
<
int64_t
>
();
for
(
auto
i
=
0
;
i
<
batch_size
;
i
++
)
{
hyp_lod
[
i
+
1
]
=
hyp_lod
[
i
]
+
hyp_length_ptr
[
i
];
ref_lod
[
i
+
1
]
=
ref_lod
[
i
]
+
ref_length_ptr
[
i
];
}
}
else
{
hyp_lod
=
x1_t
->
lod
()[
0
];
ref_lod
=
x2_t
->
lod
()[
0
];
}
if
(
normalized
)
{
for
(
size_t
i
=
1
;
i
<
ref_lod
.
size
();
++
i
)
{
PADDLE_ENFORCE_GT
(
ref_lod
[
i
],
ref_lod
[
i
-
1
],
platform
::
errors
::
InvalidArgument
(
"Reference string %d is empty."
,
i
));
}
}
auto
num_strs
=
hyp_lod
.
size
()
-
1
;
*
seq_num_data
=
static_cast
<
int64_t
>
(
num_strs
);
out_t
->
Resize
({
static_cast
<
int64_t
>
(
num_strs
),
1
});
out_t
->
mutable_data
<
float
>
(
ctx
.
GetPlace
());
auto
out
=
out_t
->
data
<
T
>
();
T
distance
=
0.0
;
for
(
size_t
num
=
0
;
num
<
num_strs
;
++
num
)
{
auto
m
=
static_cast
<
int64_t
>
(
hyp_lod
[
num
+
1
]
-
hyp_lod
[
num
]);
auto
n
=
static_cast
<
int64_t
>
(
ref_lod
[
num
+
1
]
-
ref_lod
[
num
]);
if
(
m
==
0
)
{
distance
=
n
;
}
else
if
(
n
==
0
)
{
distance
=
m
;
}
else
{
framework
::
Tensor
dist_t
;
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
dist_t
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dist
=
dist_t
.
data
<
T
>
();
auto
hyp_offset
=
use_length
?
num
*
x1_t
->
dims
()[
1
]
:
hyp_lod
[
num
];
auto
ref_offset
=
use_length
?
num
*
x2_t
->
dims
()[
1
]
:
ref_lod
[
num
];
auto
x1
=
x1_t
->
data
<
int64_t
>
()
+
hyp_offset
;
auto
x2
=
x2_t
->
data
<
int64_t
>
()
+
ref_offset
;
for
(
int64_t
i
=
0
;
i
<
m
+
1
;
++
i
)
{
dist
[
i
*
(
n
+
1
)]
=
i
;
}
for
(
int64_t
j
=
0
;
j
<
n
+
1
;
++
j
)
{
dist
[
j
]
=
j
;
}
for
(
int64_t
i
=
1
;
i
<
m
+
1
;
++
i
)
{
for
(
int64_t
j
=
1
;
j
<
n
+
1
;
++
j
)
{
int
cost
=
x1
[
i
-
1
]
==
x2
[
j
-
1
]
?
0
:
1
;
int
dels
=
dist
[(
i
-
1
)
*
(
n
+
1
)
+
j
]
+
1
;
int
ins
=
dist
[
i
*
(
n
+
1
)
+
(
j
-
1
)]
+
1
;
int
subs
=
dist
[(
i
-
1
)
*
(
n
+
1
)
+
(
j
-
1
)]
+
cost
;
dist
[
i
*
(
n
+
1
)
+
j
]
=
std
::
min
(
dels
,
std
::
min
(
ins
,
subs
));
}
}
distance
=
dist
[
m
*
(
n
+
1
)
+
n
];
}
if
(
normalized
)
{
PADDLE_ENFORCE_GT
(
n
,
0UL
,
platform
::
errors
::
InvalidArgument
(
"The reference string (#%d) cannot be empty "
"when Attr(normalized) is enabled."
,
n
));
distance
=
distance
/
n
;
}
out
[
num
]
=
distance
;
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/phi/api/yaml/legacy_api.yaml
浏览文件 @
5ad3228c
...
...
@@ -654,6 +654,16 @@
optional
:
seed_tensor
backward
:
dropout_grad
-
api
:
edit_distance
args
:
(Tensor hyps, Tensor refs, Tensor hypslength, Tensor refslength, bool normalized =
false
)
output
:
Tensor(sequencenum), Tensor(out)
infer_meta
:
func
:
EditDistanceInferMeta
kernel
:
func
:
edit_distance
data_type
:
DataType::FLOAT32
optional
:
hypslength, refslength
# eigh
-
api
:
eigh
args
:
(Tensor x, str uplo)
...
...
paddle/phi/infermeta/multiary.cc
浏览文件 @
5ad3228c
...
...
@@ -1019,6 +1019,75 @@ void DeformableConvInferMeta(const MetaTensor& x,
out
->
set_dtype
(
x
.
dtype
());
}
void
EditDistanceInferMeta
(
const
MetaTensor
&
hyps
,
const
MetaTensor
&
refs
,
const
MetaTensor
&
hypslength
,
const
MetaTensor
&
refslength
,
bool
normalized
,
MetaTensor
*
sequencenum
,
MetaTensor
*
out
)
{
auto
hyp_dims
=
hyps
.
dims
();
auto
ref_dims
=
refs
.
dims
();
if
(
hypslength
&&
refslength
)
{
auto
hyp_length_dims
=
hypslength
.
dims
();
auto
ref_length_dims
=
refslength
.
dims
();
PADDLE_ENFORCE_EQ
(
hyp_dims
.
size
()
==
2
&&
ref_dims
.
size
()
==
2
&&
hyp_dims
[
0
]
==
ref_dims
[
0
],
true
,
errors
::
InvalidArgument
(
"Input(hyps) and Input(refs) must be 2-D Tensors with "
"identical first dimension. But received Input(Hyps): "
"input rank %u, input shape [%s]; received Input(Refs): "
"input rank %u, input shape [%s]"
,
hyp_dims
.
size
(),
hyp_dims
,
ref_dims
.
size
(),
ref_dims
));
PADDLE_ENFORCE_EQ
(
hyp_length_dims
[
0
]
==
ref_length_dims
[
0
]
&&
hyp_length_dims
[
0
]
==
hyp_dims
[
0
],
true
,
errors
::
InvalidArgument
(
"Input(hypslength), Input(refslength) and Input(hyps) "
"should have identical first dimension. But received "
"Input(hypslength): input rank %u, input shape [%s]; "
"received Input(refslength): input rank %u, input shape "
"[%s]; received Input(hyps): input rank %u, input shape "
"[%s]."
,
hyp_length_dims
.
size
(),
hyp_length_dims
,
ref_length_dims
.
size
(),
ref_length_dims
,
hyp_dims
.
size
(),
hyp_dims
));
}
else
{
PADDLE_ENFORCE_EQ
(
hyp_dims
.
size
()
==
2
&&
hyp_dims
[
1
]
==
1
,
true
,
errors
::
InvalidArgument
(
"Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1. But received: input rank %u, input shape [%s]."
,
hyp_dims
.
size
(),
hyp_dims
));
PADDLE_ENFORCE_EQ
(
ref_dims
.
size
()
==
2
&&
ref_dims
[
1
]
==
1
,
true
,
errors
::
InvalidArgument
(
"Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1. But received: input rank %u, input shape [%s]."
,
ref_dims
.
size
(),
ref_dims
));
}
out
->
set_dims
(
refs
.
dims
());
out
->
set_dtype
(
DataType
::
FLOAT32
);
sequencenum
->
set_dims
(
phi
::
make_ddim
({
1
}));
sequencenum
->
set_dtype
(
DataType
::
FLOAT32
);
}
void
HierarchicalSigmoidInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
w
,
const
MetaTensor
&
label
,
...
...
paddle/phi/infermeta/multiary.h
浏览文件 @
5ad3228c
...
...
@@ -212,6 +212,14 @@ void DeformableConvInferMeta(const MetaTensor& x,
MetaTensor
*
out
,
MetaConfig
config
=
MetaConfig
());
void
EditDistanceInferMeta
(
const
MetaTensor
&
hyps
,
const
MetaTensor
&
refs
,
const
MetaTensor
&
hypslength
,
const
MetaTensor
&
refslength
,
bool
normalized
,
MetaTensor
*
sequencenum
,
MetaTensor
*
out
);
void
HierarchicalSigmoidInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
w
,
const
MetaTensor
&
label
,
...
...
paddle/phi/kernels/cpu/edit_distance_kernel.cc
0 → 100644
浏览文件 @
5ad3228c
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/edit_distance_kernel.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
EditDistanceKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
hyps
,
const
DenseTensor
&
refs
,
const
paddle
::
optional
<
DenseTensor
>&
hypslength
,
const
paddle
::
optional
<
DenseTensor
>&
refslength
,
bool
normalized
,
DenseTensor
*
sequencenum
,
DenseTensor
*
out
)
{
int64_t
*
seq_num_data
=
ctx
.
template
Alloc
<
int64_t
>(
sequencenum
);
auto
batch_size
=
hyps
.
dims
()[
0
];
paddle
::
framework
::
Vector
<
size_t
>
hyp_lod
(
batch_size
+
1
);
paddle
::
framework
::
Vector
<
size_t
>
ref_lod
(
batch_size
+
1
);
bool
use_length
=
hypslength
.
get_ptr
()
!=
nullptr
;
if
(
use_length
)
{
// build lod when using padding
auto
hyp_length_ptr
=
hypslength
.
get_ptr
()
->
data
<
int64_t
>
();
auto
ref_length_ptr
=
refslength
.
get_ptr
()
->
data
<
int64_t
>
();
for
(
auto
i
=
0
;
i
<
batch_size
;
i
++
)
{
hyp_lod
[
i
+
1
]
=
hyp_lod
[
i
]
+
hyp_length_ptr
[
i
];
ref_lod
[
i
+
1
]
=
ref_lod
[
i
]
+
ref_length_ptr
[
i
];
}
}
else
{
hyp_lod
=
hyps
.
lod
()[
0
];
ref_lod
=
refs
.
lod
()[
0
];
}
if
(
normalized
)
{
for
(
size_t
i
=
1
;
i
<
ref_lod
.
size
();
++
i
)
{
PADDLE_ENFORCE_GT
(
ref_lod
[
i
],
ref_lod
[
i
-
1
],
errors
::
InvalidArgument
(
"Reference string %d is empty."
,
i
));
}
}
auto
num_strs
=
hyp_lod
.
size
()
-
1
;
*
seq_num_data
=
static_cast
<
int64_t
>
(
num_strs
);
out
->
Resize
({
static_cast
<
int64_t
>
(
num_strs
),
1
});
ctx
.
template
Alloc
<
T
>(
out
);
auto
outdata
=
out
->
data
<
T
>
();
T
distance
=
0.0
;
for
(
size_t
num
=
0
;
num
<
num_strs
;
++
num
)
{
auto
m
=
static_cast
<
int64_t
>
(
hyp_lod
[
num
+
1
]
-
hyp_lod
[
num
]);
auto
n
=
static_cast
<
int64_t
>
(
ref_lod
[
num
+
1
]
-
ref_lod
[
num
]);
if
(
m
==
0
)
{
distance
=
n
;
}
else
if
(
n
==
0
)
{
distance
=
m
;
}
else
{
DenseTensor
dist_t
;
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
ctx
.
template
Alloc
<
T
>(
&
dist_t
);
auto
dist
=
dist_t
.
data
<
T
>
();
auto
hyp_offset
=
use_length
?
num
*
hyps
.
dims
()[
1
]
:
hyp_lod
[
num
];
auto
ref_offset
=
use_length
?
num
*
refs
.
dims
()[
1
]
:
ref_lod
[
num
];
auto
x1
=
hyps
.
data
<
int64_t
>
()
+
hyp_offset
;
auto
x2
=
refs
.
data
<
int64_t
>
()
+
ref_offset
;
for
(
int64_t
i
=
0
;
i
<
m
+
1
;
++
i
)
{
dist
[
i
*
(
n
+
1
)]
=
i
;
}
for
(
int64_t
j
=
0
;
j
<
n
+
1
;
++
j
)
{
dist
[
j
]
=
j
;
}
for
(
int64_t
i
=
1
;
i
<
m
+
1
;
++
i
)
{
for
(
int64_t
j
=
1
;
j
<
n
+
1
;
++
j
)
{
int
cost
=
x1
[
i
-
1
]
==
x2
[
j
-
1
]
?
0
:
1
;
int
dels
=
dist
[(
i
-
1
)
*
(
n
+
1
)
+
j
]
+
1
;
int
ins
=
dist
[
i
*
(
n
+
1
)
+
(
j
-
1
)]
+
1
;
int
subs
=
dist
[(
i
-
1
)
*
(
n
+
1
)
+
(
j
-
1
)]
+
cost
;
dist
[
i
*
(
n
+
1
)
+
j
]
=
std
::
min
(
dels
,
std
::
min
(
ins
,
subs
));
}
}
distance
=
dist
[
m
*
(
n
+
1
)
+
n
];
}
if
(
normalized
)
{
PADDLE_ENFORCE_GT
(
n
,
0UL
,
errors
::
InvalidArgument
(
"The reference string (#%d) cannot be empty "
"when Attr(normalized) is enabled."
,
n
));
distance
=
distance
/
n
;
}
outdata
[
num
]
=
distance
;
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
edit_distance
,
CPU
,
ALL_LAYOUT
,
phi
::
EditDistanceKernel
,
float
)
{}
paddle/phi/kernels/edit_distance_kernel.h
0 → 100644
浏览文件 @
5ad3228c
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
EditDistanceKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
hyps
,
const
DenseTensor
&
refs
,
const
paddle
::
optional
<
DenseTensor
>&
hypslength
,
const
paddle
::
optional
<
DenseTensor
>&
refslength
,
bool
normalized
,
DenseTensor
*
sequencenum
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/gpu/edit_distance_kernel.cu
0 → 100644
浏览文件 @
5ad3228c
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/edit_distance_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
using
paddle
::
platform
::
PADDLE_CUDA_NUM_THREADS
;
template
<
typename
T
>
__global__
void
FillFirstRow
(
T
*
dist
,
const
int
N
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
if
(
idx
<
N
+
1
)
{
dist
[
idx
]
=
idx
;
}
}
template
<
typename
T
>
__global__
void
FillFirstColumn
(
T
*
dist
,
const
int
M
,
const
int
N
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
if
(
idx
<
M
+
1
)
{
dist
[
idx
*
(
N
+
1
)]
=
idx
;
}
}
template
<
typename
T
>
__global__
void
Levenshtein
(
T
*
dist
,
const
int64_t
*
x1
,
const
int64_t
*
x2
,
const
int
M
,
const
int
N
,
const
int
start
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
offset
=
N
;
int
index
=
start
+
idx
*
offset
;
int
row
=
index
/
(
N
+
1
);
int
col
=
index
%
(
N
+
1
);
if
(
row
>
0
&&
col
>
0
&&
row
<
M
+
1
&&
col
<
N
+
1
)
{
int
cost
=
x1
[
row
-
1
]
==
x2
[
col
-
1
]
?
0
:
1
;
int
dels
=
dist
[(
row
-
1
)
*
(
N
+
1
)
+
col
]
+
1
;
int
ins
=
dist
[
row
*
(
N
+
1
)
+
col
-
1
]
+
1
;
int
subs
=
dist
[(
row
-
1
)
*
(
N
+
1
)
+
(
col
-
1
)]
+
cost
;
dist
[
index
]
=
min
(
dels
,
min
(
ins
,
subs
));
}
}
template
<
typename
T
>
__global__
void
SetOutput
(
T
*
out
,
const
T
*
dist
,
const
int
M
,
const
int
N
,
bool
normalized
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
if
(
idx
==
0
)
{
out
[
0
]
=
normalized
?
dist
[
M
*
(
N
+
1
)
+
N
]
/
N
:
dist
[
M
*
(
N
+
1
)
+
N
];
}
}
template
<
typename
T
,
typename
Context
>
void
EditDistanceKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
hyps
,
const
DenseTensor
&
refs
,
const
paddle
::
optional
<
DenseTensor
>&
hypslength
,
const
paddle
::
optional
<
DenseTensor
>&
refslength
,
bool
normalized
,
DenseTensor
*
sequencenum
,
DenseTensor
*
out
)
{
ctx
.
template
Alloc
<
int64_t
>(
sequencenum
);
auto
batch_size
=
hyps
.
dims
()[
0
];
auto
stream
=
reinterpret_cast
<
const
phi
::
GPUContext
&>
(
ctx
).
stream
();
paddle
::
framework
::
Vector
<
size_t
>
hyp_lod
(
batch_size
+
1
);
paddle
::
framework
::
Vector
<
size_t
>
ref_lod
(
batch_size
+
1
);
bool
use_length
=
hypslength
.
get_ptr
()
!=
nullptr
;
if
(
use_length
)
{
DenseTensor
hyp_length_cpu
;
DenseTensor
ref_length_cpu
;
phi
::
Copy
(
ctx
,
*
(
hypslength
.
get_ptr
()),
phi
::
CPUPlace
(),
false
,
&
hyp_length_cpu
);
phi
::
Copy
(
ctx
,
*
(
refslength
.
get_ptr
()),
phi
::
CPUPlace
(),
false
,
&
ref_length_cpu
);
for
(
auto
i
=
0
;
i
<
batch_size
;
i
++
)
{
hyp_lod
[
i
+
1
]
=
hyp_lod
[
i
]
+
hyp_length_cpu
.
data
<
int64_t
>
()[
i
];
ref_lod
[
i
+
1
]
=
ref_lod
[
i
]
+
ref_length_cpu
.
data
<
int64_t
>
()[
i
];
}
}
else
{
hyp_lod
=
hyps
.
lod
()[
0
];
ref_lod
=
refs
.
lod
()[
0
];
}
if
(
normalized
)
{
for
(
size_t
i
=
1
;
i
<
ref_lod
.
size
();
++
i
)
{
PADDLE_ENFORCE_GT
(
ref_lod
[
i
],
ref_lod
[
i
-
1
],
errors
::
InvalidArgument
(
"Reference string %d is empty."
,
i
));
}
}
const
size_t
num_strs
=
hyp_lod
.
size
()
-
1
;
phi
::
funcs
::
SetConstant
<
GPUContext
,
int64_t
>
set_constant
;
set_constant
(
ctx
,
sequencenum
,
static_cast
<
int64_t
>
(
num_strs
));
out
->
Resize
({
static_cast
<
int64_t
>
(
num_strs
),
1
});
ctx
.
template
Alloc
<
T
>(
out
);
auto
out_data
=
out
->
data
<
T
>
();
T
distance
=
0.0
;
for
(
size_t
num
=
0
;
num
<
num_strs
;
num
++
)
{
auto
m
=
static_cast
<
int64_t
>
(
hyp_lod
[
num
+
1
]
-
hyp_lod
[
num
]);
auto
n
=
static_cast
<
int64_t
>
(
ref_lod
[
num
+
1
]
-
ref_lod
[
num
]);
if
(
m
==
0
||
n
==
0
)
{
distance
=
std
::
max
(
m
,
n
);
if
(
normalized
)
{
distance
=
distance
/
n
;
}
paddle
::
memory
::
Copy
(
ctx
.
GetPlace
(),
out_data
+
num
,
CPUPlace
(),
&
distance
,
sizeof
(
T
),
stream
);
}
else
{
DenseTensor
dist_t
;
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
ctx
.
template
Alloc
<
T
>(
&
dist_t
);
auto
dist
=
dist_t
.
data
<
T
>
();
auto
hyp_offset
=
use_length
?
num
*
hyps
.
dims
()[
1
]
:
hyp_lod
[
num
];
auto
ref_offset
=
use_length
?
num
*
refs
.
dims
()[
1
]
:
ref_lod
[
num
];
auto
x1
=
hyps
.
data
<
int64_t
>
()
+
hyp_offset
;
auto
x2
=
refs
.
data
<
int64_t
>
()
+
ref_offset
;
FillFirstColumn
<
T
><<<
1
+
m
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
m
,
n
);
FillFirstRow
<
T
><<<
1
+
n
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
n
);
// Compute the elements of distance matrix in the anti-diagonal diretion
for
(
int64_t
slice
=
2
;
slice
<
m
+
n
+
1
;
++
slice
)
{
int
z_m
=
slice
<
m
+
1
?
0
:
slice
-
m
;
int
z_n
=
slice
<
n
+
1
?
0
:
slice
-
n
;
int
size
=
slice
-
(
z_m
+
z_n
)
+
1
;
// number of elments in the same
// anti-diagonal line to update
// the start index at which computes from
int
start
=
slice
<
n
+
1
?
slice
:
(
z_n
+
1
)
*
(
n
+
1
)
-
1
;
Levenshtein
<
T
><<<
1
+
(
size
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
x1
,
x2
,
m
,
n
,
start
);
}
SetOutput
<
T
><<<
1
,
1
,
0
,
stream
>>>
(
out_data
+
num
,
dist
,
m
,
n
,
normalized
);
}
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
edit_distance
,
GPU
,
ALL_LAYOUT
,
phi
::
EditDistanceKernel
,
float
)
{}
paddle/phi/ops/compat/edit_distance_sig.cc
0 → 100644
浏览文件 @
5ad3228c
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
EditDistanceOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"edit_distance"
,
{
"Hyps"
,
"Refs"
,
"HypsLength"
,
"RefsLength"
},
{
"normalized"
},
{
"SequenceNum"
,
"Out"
});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
edit_distance
,
phi
::
EditDistanceOpArgumentMapping
);
python/paddle/fluid/tests/unittests/test_edit_distance_op.py
浏览文件 @
5ad3228c
...
...
@@ -17,6 +17,22 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle
def
python_edit_distance
(
input
,
label
,
input_length
=
None
,
label_length
=
None
,
normalized
=
True
,
ignored_tokens
=
None
):
return
paddle
.
nn
.
functional
.
loss
.
edit_distance
(
input
,
label
,
normalized
=
normalized
,
ignored_tokens
=
ignored_tokens
,
input_length
=
input_length
,
label_length
=
label_length
)
def
Levenshtein
(
hyp
,
ref
):
...
...
@@ -54,6 +70,7 @@ class TestEditDistanceOp(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"edit_distance"
self
.
python_api
=
python_edit_distance
normalized
=
False
x1
=
np
.
array
([[
12
,
3
,
5
,
8
,
2
]]).
astype
(
"int64"
)
x2
=
np
.
array
([[
12
,
4
,
7
,
8
]]).
astype
(
"int64"
)
...
...
@@ -83,7 +100,7 @@ class TestEditDistanceOp(OpTest):
self
.
outputs
=
{
'Out'
:
distance
,
'SequenceNum'
:
sequence_num
}
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
check_eager
=
True
)
class
TestEditDistanceOpNormalizedCase0
(
OpTest
):
...
...
@@ -96,6 +113,7 @@ class TestEditDistanceOpNormalizedCase0(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"edit_distance"
self
.
python_api
=
python_edit_distance
normalized
=
True
self
.
x1
=
np
.
array
([[
10
,
3
,
6
,
5
,
8
,
2
]]).
astype
(
"int64"
)
self
.
x2
=
np
.
array
([[
10
,
4
,
6
,
7
,
8
]]).
astype
(
"int64"
)
...
...
@@ -132,7 +150,7 @@ class TestEditDistanceOpNormalizedCase0(OpTest):
self
.
post_config
()
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
check_eager
=
True
)
class
TestEditDistanceOpNormalizedCase1
(
TestEditDistanceOpNormalizedCase0
):
...
...
@@ -159,6 +177,7 @@ class TestEditDistanceOpNormalizedTensor(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"edit_distance"
self
.
python_api
=
python_edit_distance
normalized
=
True
self
.
reset_config
()
...
...
@@ -184,8 +203,9 @@ class TestEditDistanceOpNormalizedTensor(OpTest):
self
.
outputs
=
{
'Out'
:
distance
,
'SequenceNum'
:
sequence_num
}
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
check_eager
=
True
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
unittest
.
main
()
python/paddle/nn/functional/loss.py
浏览文件 @
5ad3228c
...
...
@@ -532,6 +532,10 @@ def edit_distance(input,
attrs
=
{
"tokens"
:
ignored_tokens
})
label
=
erased_label
if
in_dygraph_mode
():
return
_C_ops
.
final_state_edit_distance
(
input
,
label
,
input_length
,
label_length
,
normalized
)
this_inputs
=
{
"Hyps"
:
[
input
],
"Refs"
:
[
label
]}
if
input_length
is
not
None
and
label_length
is
not
None
:
this_inputs
[
'HypsLength'
]
=
[
input_length
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录