Unverified commit 6da637e8, authored by qipengh, committed by GitHub

[MLU]add op: cumsum, fill_any_like, unsqueeze (#41791)

Parent: 8f77f8bc
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class CumSumMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int>("axis");
bool exclusive = ctx.Attr<bool>("exclusive");
bool reverse = ctx.Attr<bool>("reverse");
bool flatten = ctx.Attr<bool>("flatten");
out->mutable_data<T>(ctx.GetPlace());
    // Operate on a flattened 1-D view of x when `flatten` is set.
    Tensor* input_ptr = const_cast<Tensor*>(x);
Tensor flat_x(x->type());
if (flatten) {
PADDLE_ENFORCE_EQ(
axis, -1,
platform::errors::InvalidArgument(
"when flatten is true, attr axis must be default %d, but got %d",
-1, axis));
flat_x.ShareDataWith(*x);
flat_x.Resize(phi::make_ddim({x->numel()}));
input_ptr = &flat_x;
}
    // Normalize a negative axis against the (possibly flattened) rank.
    const int true_axis = (axis < 0) ? input_ptr->dims().size() + axis : axis;
MLUCnnlTensorDesc input_desc(*input_ptr);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::Cumsum(ctx, true_axis, exclusive, reverse, input_desc.get(),
GetBasePtr(input_ptr), out_desc.get(), GetBasePtr(out));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(cumsum, ops::CumSumMLUKernel<int>,
ops::CumSumMLUKernel<float>,
ops::CumSumMLUKernel<plat::float16>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
template <typename T>
class FillAnyLikeMLUKernel : public framework::OpKernel<T> {
public:
  // Perform the range check in a sufficiently wide type; float16 values are
  // compared via float.
  using CommonType = typename std::common_type<
      float,
      typename std::conditional<std::is_same<T, platform::float16>::value,
                                float, T>::type>::type;
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
float value = ctx.Attr<float>("value");
auto common_type_value = static_cast<CommonType>(value);
PADDLE_ENFORCE_EQ(
(common_type_value >=
static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max())),
true,
platform::errors::InvalidArgument(
"The filled value is out of range for target type, "
"current kernel type is %s, the range should between %f "
"and %f, but now value is %f.",
typeid(T).name(),
static_cast<CommonType>(std::numeric_limits<T>::lowest()),
static_cast<CommonType>(std::numeric_limits<T>::max()), value));
PADDLE_ENFORCE_EQ(
std::isnan(value), false,
platform::errors::InvalidArgument("The filled value is NaN."));
auto value_t = static_cast<T>(value);
MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value_t, out_desc.get(),
GetBasePtr(out));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(fill_any_like, ops::FillAnyLikeMLUKernel<int>,
ops::FillAnyLikeMLUKernel<int64_t>,
ops::FillAnyLikeMLUKernel<float>,
ops::FillAnyLikeMLUKernel<plat::float16>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_MLU
#include <memory>
#include <string>
#include "paddle/fluid/operators/unsqueeze_op.h"
#include "paddle/fluid/platform/device/mlu/device_context.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(
unsqueeze, ops::UnsqueezeKernel<plat::MLUDeviceContext, float>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, double>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, plat::float16>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, bool>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, int>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, int8_t>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, int64_t>);
REGISTER_OP_MLU_KERNEL(
unsqueeze2, ops::UnsqueezeKernel<plat::MLUDeviceContext, float>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, double>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, plat::float16>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, bool>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, int>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, int8_t>,
ops::UnsqueezeKernel<plat::MLUDeviceContext, int64_t>);
REGISTER_OP_MLU_KERNEL(
unsqueeze_grad, ops::UnsqueezeGradKernel<plat::MLUDeviceContext, float>,
ops::UnsqueezeGradKernel<plat::MLUDeviceContext, double>,
ops::UnsqueezeGradKernel<plat::MLUDeviceContext, plat::float16>,
ops::UnsqueezeGradKernel<plat::MLUDeviceContext, bool>,
ops::UnsqueezeGradKernel<plat::MLUDeviceContext, int>,
ops::UnsqueezeGradKernel<plat::MLUDeviceContext, int8_t>,
ops::UnsqueezeGradKernel<plat::MLUDeviceContext, int64_t>);
REGISTER_OP_MLU_KERNEL(
unsqueeze2_grad, ops::Unsqueeze2GradKernel<plat::MLUDeviceContext, float>,
ops::Unsqueeze2GradKernel<plat::MLUDeviceContext, double>,
ops::Unsqueeze2GradKernel<plat::MLUDeviceContext, plat::float16>,
ops::Unsqueeze2GradKernel<plat::MLUDeviceContext, bool>,
ops::Unsqueeze2GradKernel<plat::MLUDeviceContext, int>,
ops::Unsqueeze2GradKernel<plat::MLUDeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<plat::MLUDeviceContext, int64_t>);
#endif
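
The registrations above reuse the shared UnsqueezeKernel from unsqueeze_op.h. A hedged Python sketch of its axis-insertion rule: each axis is resolved against the rank as it grows, which is why duplicated axes such as (0, 3, 3) in the tests below are legal:

def unsqueeze_shape(in_shape, axes):
    # Sketch of the output-shape rule; negative axes count from the
    # current (already grown) rank, and axes may repeat.
    out = list(in_shape)
    for ax in axes:
        rank = len(out)
        out.insert(ax + rank + 1 if ax < 0 else ax, 1)
    return tuple(out)

assert unsqueeze_shape((10, 2, 5), (0, 3, 3)) == (1, 10, 2, 1, 1, 5)
assert unsqueeze_shape((20, 5), (0, -1)) == (1, 20, 5, 1)
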
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest
import paddle
paddle.enable_static()
class TestMLUCumSumOp(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.set_mlu()
self.init_dtype()
self.init_testcase()
def test_check_output(self):
self.check_output_with_place(self.place)
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def init_dtype(self):
self.dtype = np.float32
def init_testcase(self):
self.attrs = {'axis': 2}
self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
self.outputs = {'Out': self.inputs['X'].cumsum(axis=2)}
class TestMLUCumSumOp2(TestMLUCumSumOp):
def init_testcase(self):
self.attrs = {'axis': -1, 'reverse': True}
self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
self.outputs = {
'Out': np.flip(
np.flip(
self.inputs['X'], axis=2).cumsum(axis=2), axis=2)
}
class TestMLUCumSumOp3(TestMLUCumSumOp):
def init_testcase(self):
self.attrs = {'axis': 1}
self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
self.outputs = {'Out': self.inputs['X'].cumsum(axis=1)}
class TestMLUCumSumOp4(TestMLUCumSumOp):
def init_testcase(self):
self.attrs = {'axis': 0}
self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
self.outputs = {'Out': self.inputs['X'].cumsum(axis=0)}
class TestMLUCumSumOp5(TestMLUCumSumOp):
def init_testcase(self):
self.inputs = {'X': np.random.random((5, 20)).astype(self.dtype)}
self.outputs = {'Out': self.inputs['X'].cumsum(axis=1)}
class TestMLUCumSumOp7(TestMLUCumSumOp):
def init_testcase(self):
        self.inputs = {'X': np.random.random(100).astype(self.dtype)}
self.outputs = {'Out': self.inputs['X'].cumsum(axis=0)}
class TestMLUCumSumExclusive1(TestMLUCumSumOp):
def init_testcase(self):
self.attrs = {'axis': 2, "exclusive": True}
a = np.random.random((4, 5, 65)).astype(self.dtype)
self.inputs = {'X': a}
self.outputs = {
'Out': np.concatenate(
(np.zeros(
(4, 5, 1), dtype=self.dtype), a[:, :, :-1].cumsum(axis=2)),
axis=2)
}
class TestMLUCumSumExclusive2(TestMLUCumSumOp):
def init_testcase(self):
self.attrs = {'axis': 2, "exclusive": True}
a = np.random.random((1, 1, 888)).astype(self.dtype)
self.inputs = {'X': a}
self.outputs = {
'Out': np.concatenate(
(np.zeros(
(1, 1, 1), dtype=self.dtype), a[:, :, :-1].cumsum(axis=2)),
axis=2)
}
class TestMLUCumSumExclusive3(TestMLUCumSumOp):
def init_testcase(self):
self.attrs = {'axis': 2, "exclusive": True}
a = np.random.random((4, 5, 888)).astype(self.dtype)
self.inputs = {'X': a}
self.outputs = {
'Out': np.concatenate(
(np.zeros(
(4, 5, 1), dtype=self.dtype), a[:, :, :-1].cumsum(axis=2)),
axis=2)
}
class TestMLUCumSumExclusive4(TestMLUCumSumOp):
def init_testcase(self):
self.attrs = {'axis': 2, "exclusive": True}
a = np.random.random((1, 1, 3049)).astype(self.dtype)
self.inputs = {'X': a}
self.outputs = {
'Out': np.concatenate(
(np.zeros(
(1, 1, 1), dtype=self.dtype), a[:, :, :-1].cumsum(axis=2)),
axis=2)
}
class TestMLUCumSumExclusive5(TestMLUCumSumOp):
def init_testcase(self):
self.attrs = {'axis': 2, "exclusive": True}
a = np.random.random((4, 5, 3096)).astype(self.dtype)
self.inputs = {'X': a}
self.outputs = {
'Out': np.concatenate(
(np.zeros(
(4, 5, 1), dtype=self.dtype), a[:, :, :-1].cumsum(axis=2)),
axis=2)
}
class TestMLUCumSumReverseExclusive(TestMLUCumSumOp):
def init_testcase(self):
self.attrs = {'axis': 2, 'reverse': True, "exclusive": True}
a = np.random.random((4, 5, 6)).astype(self.dtype)
self.inputs = {'X': a}
a = np.flip(a, axis=2)
self.outputs = {
'Out': np.concatenate(
(np.flip(
a[:, :, :-1].cumsum(axis=2), axis=2), np.zeros(
(4, 5, 1), dtype=self.dtype)),
axis=2)
}
class TestMLUCumSumWithFlatten1(TestMLUCumSumOp):
def init_testcase(self):
self.attrs = {'flatten': True}
self.inputs = {'X': np.random.random((5, 6)).astype(self.dtype)}
self.outputs = {'Out': self.inputs['X'].cumsum()}
class TestMLUCumSumWithFlatten2(TestMLUCumSumOp):
def init_testcase(self):
self.attrs = {'flatten': True}
self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
self.outputs = {'Out': self.inputs['X'].cumsum()}
if __name__ == '__main__':
unittest.main()
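
Aside: the exclusive-cumsum expectation built with np.concatenate in the tests above is equivalent to the simpler identity cumsum(x) - x, which can serve as a quick self-check:

import numpy as np

a = np.random.random((4, 5, 6)).astype(np.float32)
lhs = np.concatenate(
    (np.zeros((4, 5, 1), dtype=np.float32), a[:, :, :-1].cumsum(axis=2)),
    axis=2)
rhs = a.cumsum(axis=2) - a  # sum of strictly preceding elements
assert np.allclose(lhs, rhs)
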
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import paddle
import unittest
import numpy as np
from op_test import OpTest
paddle.enable_static()
class TestFillAnyLikeOp(OpTest):
def setUp(self):
self.init_dtype()
self.set_mlu()
self.op_type = "fill_any_like"
self.set_value()
self.set_input()
self.attrs = {'value': self.value}
self.outputs = {'Out': self.value * np.ones_like(self.inputs["X"])}
def init_dtype(self):
self.dtype = np.float32
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
self.__class__.no_need_check_grad = True
def set_input(self):
self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
def set_value(self):
self.value = 0.0
def test_check_output(self):
self.check_output_with_place(self.place)
class TestFillAnyLikeOp2(TestFillAnyLikeOp):
def set_value(self):
self.value = -0.0
class TestFillAnyLikeOp3(TestFillAnyLikeOp):
def set_value(self):
self.value = 1.0
class TestFillAnyLikeOp4(TestFillAnyLikeOp):
def set_value(self):
self.value = 1e-9
class TestFillAnyLikeOp5(TestFillAnyLikeOp):
def set_value(self):
if self.dtype == "float16":
self.value = 0.05
else:
self.value = 5.0
class TestFillAnyLikeOpInt32(TestFillAnyLikeOp):
def init_dtype(self):
self.dtype = np.int32
def set_value(self):
self.value = -1
class TestFillAnyLikeOpInt64(TestFillAnyLikeOp):
def init_dtype(self):
self.dtype = np.int64
def set_value(self):
self.value = -1
class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
def init_dtype(self):
self.dtype = np.float32
def set_value(self):
self.value = 0.09
if __name__ == "__main__":
unittest.main()
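
For context, fill_any_like is the op behind paddle.full_like. A hedged dygraph sketch; the set_device line assumes a Paddle build with MLU support and is hypothetical on other builds:

import paddle

paddle.disable_static()
# paddle.set_device("mlu:0")  # uncomment only on an MLU build/runtime
x = paddle.ones([219, 232], dtype="float32")
y = paddle.full_like(x, 0.09)  # lowers to the fill_any_like op tested above
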
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import numpy as np
import paddle
from op_test import OpTest
paddle.enable_static()
# Correct: General.
class TestUnsqueezeOp(OpTest):
def setUp(self):
self.init_test_case()
self.set_mlu()
self.op_type = "unsqueeze2"
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float32")
}
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place, no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def init_test_case(self):
self.ori_shape = (3, 40)
self.axes = (1, 2)
self.new_shape = (3, 1, 1, 40)
def init_attrs(self):
self.attrs = {"axes": self.axes}
# Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (20, 5)
self.axes = (-1, )
self.new_shape = (20, 5, 1)
# Correct: Mixed input axis.
class TestUnsqueezeOp2(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (20, 5)
self.axes = (0, -1)
self.new_shape = (1, 20, 5, 1)
# Correct: There is duplicated axis.
class TestUnsqueezeOp3(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (10, 2, 5)
self.axes = (0, 3, 3)
self.new_shape = (1, 10, 2, 1, 1, 5)
# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (10, 2, 5)
self.axes = (3, 1, 1)
self.new_shape = (10, 1, 1, 2, 5, 1)
# axes is a list of tensors (AxesTensorList)
class TestUnsqueezeOp_AxesTensorList(OpTest):
def setUp(self):
self.init_test_case()
self.set_mlu()
self.op_type = "unsqueeze2"
axes_tensor_list = []
for index, ele in enumerate(self.axes):
axes_tensor_list.append(("axes" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {
"X": np.random.random(self.ori_shape).astype("float32"),
"AxesTensorList": axes_tensor_list
}
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float32")
}
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place, no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def init_test_case(self):
self.ori_shape = (20, 5)
self.axes = (1, 2)
self.new_shape = (20, 1, 1, 5)
def init_attrs(self):
self.attrs = {}
class TestUnsqueezeOp1_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
def init_test_case(self):
self.ori_shape = (20, 5)
self.axes = (-1, )
self.new_shape = (20, 5, 1)
class TestUnsqueezeOp2_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
def init_test_case(self):
self.ori_shape = (20, 5)
self.axes = (0, -1)
self.new_shape = (1, 20, 5, 1)
class TestUnsqueezeOp3_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
def init_test_case(self):
self.ori_shape = (10, 2, 5)
self.axes = (0, 3, 3)
self.new_shape = (1, 10, 2, 1, 1, 5)
class TestUnsqueezeOp4_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
def init_test_case(self):
self.ori_shape = (10, 2, 5)
self.axes = (3, 1, 1)
self.new_shape = (10, 1, 1, 2, 5, 1)
# axes is a Tensor
class TestUnsqueezeOp_AxesTensor(OpTest):
def setUp(self):
self.init_test_case()
self.set_mlu()
self.op_type = "unsqueeze2"
self.inputs = {
"X": np.random.random(self.ori_shape).astype("float32"),
"AxesTensor": np.array(self.axes).astype("int32")
}
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float32")
}
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place, no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def init_test_case(self):
self.ori_shape = (20, 5)
self.axes = (1, 2)
self.new_shape = (20, 1, 1, 5)
def init_attrs(self):
self.attrs = {}
class TestUnsqueezeOp1_AxesTensor(TestUnsqueezeOp_AxesTensor):
def init_test_case(self):
self.ori_shape = (20, 5)
self.axes = (-1, )
self.new_shape = (20, 5, 1)
class TestUnsqueezeOp2_AxesTensor(TestUnsqueezeOp_AxesTensor):
def init_test_case(self):
self.ori_shape = (20, 5)
self.axes = (0, -1)
self.new_shape = (1, 20, 5, 1)
class TestUnsqueezeOp3_AxesTensor(TestUnsqueezeOp_AxesTensor):
def init_test_case(self):
self.ori_shape = (10, 2, 5)
self.axes = (0, 3, 3)
self.new_shape = (1, 10, 2, 1, 1, 5)
class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor):
def init_test_case(self):
self.ori_shape = (10, 2, 5)
self.axes = (3, 1, 1)
self.new_shape = (10, 1, 1, 2, 5, 1)
if __name__ == "__main__":
unittest.main()
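
The AxesTensor and AxesTensorList inputs exercised above correspond to passing Tensor-valued axes through the Python API. A hedged dygraph sketch (the tests themselves run in static mode):

import paddle

paddle.disable_static()
x = paddle.rand([20, 5])
axes = paddle.to_tensor([1, 2], dtype="int32")  # Tensor axes -> AxesTensor
y = paddle.unsqueeze(x, axis=axes)
print(y.shape)  # [20, 1, 1, 5]
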
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import numpy as np
import paddle
from op_test import OpTest
paddle.enable_static()
# Correct: General.
class TestUnsqueezeOp(OpTest):
def setUp(self):
self.init_test_case()
self.set_mlu()
self.op_type = "unsqueeze"
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def init_test_case(self):
self.ori_shape = (3, 40)
self.axes = (1, 2)
self.new_shape = (3, 1, 1, 40)
def init_attrs(self):
self.attrs = {"axes": self.axes}
# Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (20, 5)
self.axes = (-1, )
self.new_shape = (20, 5, 1)
# Correct: Mixed input axis.
class TestUnsqueezeOp2(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (20, 5)
self.axes = (0, -1)
self.new_shape = (1, 20, 5, 1)
# Correct: There is duplicated axis.
class TestUnsqueezeOp3(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (10, 2, 5)
self.axes = (0, 3, 3)
self.new_shape = (1, 10, 2, 1, 1, 5)
# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (10, 2, 5)
self.axes = (3, 1, 1)
self.new_shape = (10, 1, 1, 2, 5, 1)
if __name__ == "__main__":
unittest.main()