Unverified commit 7bf2aa38, authored by TTerror, committed by GitHub

add fill_any_like/flatten ops to train ssd on kunlun (#36550)

* add some ops to train ssd on kunlun

* update test_fill_any_like_op_xpu.py
Parent b6e7f8e9
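
For context, a minimal sketch of the Python-level APIs that these new XPU kernels back. It assumes a Paddle build with Kunlun/XPU support; the device string and shapes are illustrative only.

```python
import paddle

# Assumes Paddle was built with XPU support; otherwise set_device raises.
paddle.set_device("xpu:0")

x = paddle.rand([2, 3, 4, 4], dtype="float32")

# fill_any_like: a tensor shaped like `x`, filled with a constant value.
y = paddle.full_like(x, fill_value=0.5)

# flatten (flatten_contiguous_range): collapse axes 1..-1 into one axis.
z = paddle.flatten(x, start_axis=1, stop_axis=-1)  # shape [2, 48]

print(y.shape, z.shape)
```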
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/fill_any_like_op.h"
namespace paddle {
namespace operators {
template <typename T>
class FillAnyLikeXPUKernel : public framework::OpKernel<T> {
public:
using CommonType = typename std::common_type<
float,
typename std::conditional<std::is_same<T, platform::float16>::value,
float, T>::type>::type;
using XPUInTDType = typename XPUTypeTrait<T>::Type;
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
float value = context.Attr<float>("value");
auto common_type_value = static_cast<CommonType>(value);
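// Reject fill values that are not representable in the output type T.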
PADDLE_ENFORCE_EQ(
(common_type_value >=
static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max())),
true,
platform::errors::InvalidArgument(
"The filled value is out of range for target type, "
"current kernel type is %s, the range should between %f "
"and %f, but now value is %f.",
typeid(T).name(),
static_cast<CommonType>(std::numeric_limits<T>::lowest()),
static_cast<CommonType>(std::numeric_limits<T>::max()), value));
PADDLE_ENFORCE_EQ(
std::isnan(value), false,
platform::errors::InvalidArgument("The filled value is NaN."));
auto& dev_ctx =
context.template device_context<paddle::platform::XPUDeviceContext>();
auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
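// Fill every element of the output with `value` via the XPU constant kernel.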
int ret = xpu::constant(dev_ctx.x_context(), out_data, out->numel(),
static_cast<XPUInTDType>(value));
PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
platform::errors::External(
"XPU CONSTANT API return wrong value[%d %s].", ret,
XPUAPIErrorMsg[ret]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(fill_any_like, ops::FillAnyLikeXPUKernel<int>,
ops::FillAnyLikeXPUKernel<int64_t>,
ops::FillAnyLikeXPUKernel<float>,
ops::FillAnyLikeXPUKernel<paddle::platform::float16>);
#endif
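
The kernel above rejects NaN and any fill value outside the representable range of the output type (the comparison is done in CommonType, so float16 limits are compared as float). A rough Python analogue of that guard, purely illustrative and not part of the commit:

```python
import math
import numpy as np

def check_fill_value(value, dtype):
    """Mimic the kernel's range/NaN checks on the `value` attribute."""
    if math.isnan(value):
        raise ValueError("The filled value is NaN.")
    info = np.finfo(dtype) if np.issubdtype(dtype, np.floating) else np.iinfo(dtype)
    if not (float(info.min) <= value <= float(info.max)):
        raise ValueError("The filled value %s is out of range [%s, %s] for %s."
                         % (value, info.min, info.max, np.dtype(dtype)))

check_fill_value(0.05, np.float16)    # passes
# check_fill_value(1e9, np.float16)   # raises: 1e9 > 65504 (float16 max)
```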
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/flatten_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(
flatten, ops::FlattenKernel<paddle::platform::XPUDeviceContext, float>,
ops::FlattenKernel<paddle::platform::XPUDeviceContext, int>,
ops::FlattenKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::FlattenKernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(
flatten_grad,
ops::FlattenGradKernel<paddle::platform::XPUDeviceContext, float>,
ops::FlattenGradKernel<paddle::platform::XPUDeviceContext, int>,
ops::FlattenGradKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::FlattenGradKernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(
flatten2, ops::Flatten2Kernel<paddle::platform::XPUDeviceContext, float>,
ops::Flatten2Kernel<paddle::platform::XPUDeviceContext, int>,
ops::Flatten2Kernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::Flatten2Kernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(
flatten2_grad,
ops::Flatten2GradKernel<paddle::platform::XPUDeviceContext, float>,
ops::Flatten2GradKernel<paddle::platform::XPUDeviceContext, int>,
ops::Flatten2GradKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::Flatten2GradKernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(
flatten_contiguous_range,
ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
float>,
ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
plat::float16>,
ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext, int>,
ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
int8_t>,
ops::FlattenContiguousRangeKernel<paddle::platform::XPUDeviceContext,
int64_t>);
REGISTER_OP_XPU_KERNEL(
flatten_contiguous_range_grad,
ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
float>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
plat::float16>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
int>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
int8_t>,
ops::FlattenContiguousRangeGradKernel<paddle::platform::XPUDeviceContext,
int64_t>);
#endif
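
The registrations above cover the legacy flatten/flatten2 ops as well as flatten_contiguous_range, which backs paddle.flatten in the 2.x API. A short sketch of the start_axis/stop_axis semantics (illustrative; assumes an XPU-enabled build):

```python
import paddle

paddle.set_device("xpu:0")  # assumes an XPU-enabled Paddle build
x = paddle.ones([3, 2, 5, 4], dtype="float32")

# flatten_contiguous_range merges the axes in [start_axis, stop_axis].
print(paddle.flatten(x, start_axis=0, stop_axis=-1).shape)   # [120]
print(paddle.flatten(x, start_axis=1, stop_axis=2).shape)    # [3, 10, 4]
print(paddle.flatten(x, start_axis=-2, stop_axis=-1).shape)  # [3, 2, 20]
```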
@@ -119,6 +119,42 @@ XPUOpMap& get_kl2_ops() {
{"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"fill_any_like",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten_grad",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten2", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten2_grad",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten_contiguous_range",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"flatten_contiguous_range_grad",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
// AddMore
};
......
@@ -91,11 +91,31 @@ class XPUOpTest(OpTest):
# case in NO_FP64_CHECK_GRAD_CASES and op in NO_FP64_CHECK_GRAD_OP_LIST should be fixed
if not hasattr(cls, "no_need_check_grad") \
and not is_empty_grad_op(cls.op_type):
if cls.dtype is not None and \
cls.dtype != np.float32:
if cls.dtype is None or \
(cls.dtype == np.float16 \
and cls.op_type not in op_accuracy_white_list.NO_FP16_CHECK_GRAD_OP_LIST \
and not hasattr(cls, "exist_check_grad")):
raise AssertionError("This test of %s op needs check_grad." %
cls.op_type)
# check for op test with fp64 precision, but not check mkldnn op test for now
if cls.dtype in [np.float32, np.float64] \
and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
and not hasattr(cls, 'exist_fp64_check_grad') \
and not is_xpu_op_test() \
and not is_mkldnn_op_test() \
and not is_rocm_op_test() \
and not is_npu_op_test():
raise AssertionError(
"This test of %s op needs check_grad with fp64 precision." %
cls.op_type)
if not cls.input_shape_is_large \
and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST:
raise AssertionError(
"Input's shape should be large than or equal to 100 for " +
cls.op_type + " Op.")
def try_call_once(self, data_type):
if not self.call_once:
self.call_once = True
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
import paddle.compat as cpt
import unittest
import numpy as np
from op_test import OpTest
from op_test_xpu import XPUOpTest
paddle.enable_static()
class TestFillAnyLikeOp(OpTest):
def setUp(self):
self.op_type = "fill_any_like"
self.dtype = np.float32
self.use_xpu = True
self.use_mkldnn = False
self.value = 0.0
self.init()
self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
self.attrs = {'value': self.value, 'use_xpu': True}
self.outputs = {'Out': self.value * np.ones_like(self.inputs["X"])}
def init(self):
pass
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
def init(self):
self.dtype = np.float32
self.value = 0.0
class TestFillAnyLikeOpValue1(TestFillAnyLikeOp):
def init(self):
self.value = 1.0
class TestFillAnyLikeOpValue2(TestFillAnyLikeOp):
def init(self):
self.value = 1e-9
class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp):
def init(self):
self.dtype = np.float16
self.value = 0.05
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import numpy as np
import paddle
import paddle.fluid as fluid
from op_test import OpTest
from op_test_xpu import XPUOpTest
paddle.enable_static()
class TestFlatten2Op(XPUOpTest):
def setUp(self):
self.set_xpu()
self.op_type = "flatten2"
self.place = paddle.XPUPlace(0)
self.init_test_case()
self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.in_shape).astype("float32")
}
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place, no_check_set=["XShape"])
def test_check_grad(self):
self.check_grad_with_place(self.place, ["X"], "Out")
def init_test_case(self):
self.in_shape = (3, 2, 4, 5)
self.axis = 1
self.new_shape = (3, 40)
def init_attrs(self):
self.attrs = {"axis": self.axis}
class TestFlatten2OpWithCornerAxis(TestFlatten2Op):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.axis = 0
self.new_shape = (1, 120)
class TestFlatten2OpWithDefaultAxis(TestFlatten2Op):
def init_test_case(self):
self.in_shape = (10, 2, 2, 3)
self.new_shape = (10, 12)
def init_attrs(self):
self.attrs = {}
class TestFlatten2OpSixDims(TestFlatten2Op):
def init_test_case(self):
self.in_shape = (3, 2, 3, 2, 4, 4)
self.axis = 4
self.new_shape = (36, 16)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import numpy as np
import unittest
from op_test import OpTest
from op_test_xpu import XPUOpTest
import paddle
import paddle.fluid as fluid
paddle.enable_static()
class TestFlattenOp(XPUOpTest):
def setUp(self):
self.set_xpu()
self.op_type = "flatten_contiguous_range"
self.place = paddle.XPUPlace(0)
self.use_xpu = True
self.use_mkldnn = False
self.start_axis = 0
self.stop_axis = -1
self.dtype = np.float32
self.init_test_case()
self.inputs = {"X": np.random.random(self.in_shape).astype(self.dtype)}
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.in_shape).astype("float32")
}
def set_xpu(self):
self.__class__.use_xpu = True
def test_check_output(self):
self.check_output_with_place(self.place, no_check_set=["XShape"])
def test_check_grad(self):
self.check_grad_with_place(self.place, ["X"], "Out")
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.start_axis = 0
self.stop_axis = -1
self.new_shape = (120)
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis,
'use_xpu': True,
}
class TestFlattenOp_1(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.start_axis = 1
self.stop_axis = 2
self.new_shape = (3, 10, 4)
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis
}
class TestFlattenOp_2(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.start_axis = 0
self.stop_axis = 1
self.new_shape = (6, 5, 4)
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis
}
class TestFlattenOp_3(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.start_axis = 0
self.stop_axis = 2
self.new_shape = (30, 4)
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis
}
class TestFlattenOp_4(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.start_axis = -2
self.stop_axis = -1
self.new_shape = (3, 2, 20)
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis
}
class TestFlattenOp_5(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.start_axis = 2
self.stop_axis = 2
self.new_shape = (3, 2, 5, 4)
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis
}
class TestFlattenOpSixDims(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 3, 2, 4, 4)
self.start_axis = 3
self.stop_axis = 5
self.new_shape = (3, 2, 3, 32)
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis
}
class TestFlattenOp_Float32(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.start_axis = 0
self.stop_axis = 1
self.new_shape = (6, 5, 4)
self.dtype = np.float32
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis
}
class TestFlattenOp_int32(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.start_axis = 0
self.stop_axis = 1
self.new_shape = (6, 5, 4)
self.dtype = np.int32
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis,
'use_xpu': True
}
def test_check_grad(self):
pass
class TestFlattenOp_int8(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.start_axis = 0
self.stop_axis = 1
self.new_shape = (6, 5, 4)
self.dtype = np.int8
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis
}
def test_check_grad(self):
pass
class TestFlattenOp_int64(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
self.start_axis = 0
self.stop_axis = 1
self.new_shape = (6, 5, 4)
self.dtype = np.int64
def init_attrs(self):
self.attrs = {
"start_axis": self.start_axis,
"stop_axis": self.stop_axis
}
def test_check_grad(self):
pass
class TestFlatten2OpError(unittest.TestCase):
def test_errors(self):
image_shape = (2, 3, 4, 4)
x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
image_shape[3]).reshape(image_shape) / 100.
x = x.astype('float32')
def test_ValueError1():
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32')
out = paddle.flatten(x_var, start_axis=2, stop_axis=1)
self.assertRaises(ValueError, test_ValueError1)
def test_ValueError2():
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32')
paddle.flatten(x_var, start_axis=10, stop_axis=1)
self.assertRaises(ValueError, test_ValueError2)
def test_ValueError3():
x_var = paddle.static.data(
name="x", shape=image_shape, dtype='float32')
paddle.flatten(x_var, start_axis=2, stop_axis=10)
self.assertRaises(ValueError, test_ValueError3)
def test_type():
# dtype must be float32, float64, int8, int32, int64
x2 = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
image_shape[3]).reshape(image_shape) / 100.
x2 = x2.astype('float16')
x2_var = paddle.fluid.data(
name='x2', shape=[3, 2, 4, 5], dtype='float16')
paddle.flatten(x2_var)
self.assertRaises(TypeError, test_type)
def test_InputError():
out = paddle.flatten(x)
self.assertRaises(ValueError, test_InputError)
class TestStaticFlattenPythonAPI(unittest.TestCase):
def execute_api(self, x, start_axis=0, stop_axis=-1):
return paddle.flatten(x, start_axis, stop_axis)
def test_static_api(self):
paddle.enable_static()
np_x = np.random.rand(2, 3, 4, 4).astype('float32')
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.static.data(
name="x", shape=[2, 3, 4, 4], dtype='float32')
out = self.execute_api(x, start_axis=-2, stop_axis=-1)
exe = paddle.static.Executor(place=paddle.XPUPlace(0))
fetch_out = exe.run(main_prog, feed={"x": np_x}, fetch_list=[out])
self.assertTrue((2, 3, 16) == fetch_out[0].shape)
class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI):
def execute_api(self, x, start_axis=0, stop_axis=-1):
return x.flatten_(start_axis, stop_axis)
class TestFlattenPython(unittest.TestCase):
def test_python_api(self):
image_shape = (2, 3, 4, 4)
x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
image_shape[3]).reshape(image_shape) / 100.
x = x.astype('float32')
def test_InputError():
out = paddle.flatten(x)
self.assertRaises(ValueError, test_InputError)
def test_Negative():
paddle.disable_static(paddle.XPUPlace(0))
img = paddle.to_tensor(x)
out = paddle.flatten(img, start_axis=-2, stop_axis=-1)
return out.numpy().shape
res_shape = test_Negative()
self.assertTrue((2, 3, 16) == res_shape)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import numpy as np
import paddle
import paddle.fluid as fluid
from op_test import OpTest
from op_test_xpu import XPUOpTest
paddle.enable_static()
class TestFlattenOp(XPUOpTest):
def setUp(self):
self.op_type = "flatten"
self.use_xpu = True
self.place = paddle.XPUPlace(0)
self.init_test_case()
self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
self.init_attrs()
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ["X"], "Out")
def init_test_case(self):
self.in_shape = (3, 2, 2, 10)
self.axis = 1
self.new_shape = (3, 40)
def init_attrs(self):
self.attrs = {"axis": self.axis}
class TestFlattenOp1(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 2, 10)
self.axis = 0
self.new_shape = (1, 120)
class TestFlattenOpWithDefaultAxis(TestFlattenOp):
def init_test_case(self):
self.in_shape = (10, 2, 2, 3)
self.new_shape = (10, 12)
def init_attrs(self):
self.attrs = {}
class TestFlattenOpSixDims(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 3, 2, 4, 4)
self.axis = 4
self.new_shape = (36, 16)
if __name__ == "__main__":
unittest.main()