未验证 提交 1dc1f927 编写于 作者: G GaoWei8 提交者: GitHub

Fix lod error of concat op for axis = 0 (#22538)

上级 660ff184
......@@ -76,8 +76,8 @@ template <typename DeviceContext, typename T>
class ConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null.");
auto axis = ctx.Attr<int>("axis");
bool need_resize_out_dims = false;
......@@ -102,6 +102,32 @@ class ConcatKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
// If axis is 0, the lod of the output is not the same as inputs.
if (axis == 0 && ins[0]->lod().size() > 0) {
size_t lod_size_0 = ins[0]->lod().size();
size_t lod_size = lod_size_0;
for (size_t i = 1; i < ins.size(); ++i) {
if (ins[i]->lod().size() > 0) {
PADDLE_ENFORCE_EQ(
ins[i]->lod().size(), lod_size_0,
platform::errors::Unimplemented(
"The lod level of all input LoDTensors should be same. "
"Maybe different lod level of input LoDTensors can concat,"
" it is not supported currently."));
} else {
lod_size = 0;
break;
}
}
if (lod_size) {
auto* out_lod = out->mutable_lod();
for (size_t i = 1; i < ins.size(); ++i) {
auto in_lod = ConvertToLengthBasedLoD(ins[i]->lod());
AppendLoD(out_lod, in_lod);
}
}
}
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && ins.size() < 10) {
size_t output_offset = 0;
......
......@@ -100,6 +100,41 @@ class TestConcatOp5(TestConcatOp):
self.axis = -3
class TestConcatOp6(TestConcatOp):
def setUp(self):
self.op_type = "concat"
self.dtype = self.get_dtype()
self.init_test_data()
self.lod = [[20, 80]]
self.out_lod = [[20, 80, 20, 80, 20, 80]]
self.inputs = {
'X': [('x0', (self.x0, self.lod)), ('x1', (self.x1, self.lod)),
('x2', (self.x2, self.lod))]
}
self.attrs = {'axis': self.axis}
if self.axis < 0:
self.actual_axis = self.axis + len(self.x0.shape)
self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0
else:
self.actual_axis = self.axis
out = np.concatenate((self.x0, self.x1, self.x2), axis=self.actual_axis)
self.outputs = {'Out': (out, self.out_lod)}
def test_check_output(self):
self.check_output(check_dygraph=False)
def test_check_grad(self):
self.check_grad(['x0'], 'Out', check_dygraph=False)
self.check_grad(['x1'], 'Out', check_dygraph=False)
self.check_grad(['x2'], 'Out', check_dygraph=False)
def init_test_data(self):
self.x0 = np.random.random([100]).astype(self.dtype)
self.x1 = np.random.random([100]).astype(self.dtype)
self.x2 = np.random.random([100]).astype(self.dtype)
self.axis = 0
def create_test_AxisTensor(parent):
class TestConcatAxisTensor(parent):
def setUp(self):
......@@ -134,6 +169,7 @@ create_test_AxisTensor(TestConcatOp2)
create_test_AxisTensor(TestConcatOp3)
create_test_AxisTensor(TestConcatOp4)
create_test_AxisTensor(TestConcatOp5)
create_test_AxisTensor(TestConcatOp6)
#----------------Concat Fp16----------------
......@@ -155,6 +191,7 @@ create_test_fp16(TestConcatOp2)
create_test_fp16(TestConcatOp3)
create_test_fp16(TestConcatOp4)
create_test_fp16(TestConcatOp5)
create_test_fp16(TestConcatOp6)
class TestConcatOpError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册