未验证 提交 59c7aea5 编写于 作者: Z zhangbo9674 提交者: GitHub

[bf16] add bf16 kernel: squeeze & unsqueeze & stack (#39402)

* add squeeze unsqueeze stack

* add unittest

* add cpu kernel
上级 e8ac7fc3
...@@ -393,7 +393,9 @@ REGISTER_OP_CPU_KERNEL( ...@@ -393,7 +393,9 @@ REGISTER_OP_CPU_KERNEL(
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext, ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>); paddle::platform::complex<double>>,
ops::SqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
squeeze_grad, squeeze_grad,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, float>, ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
...@@ -406,7 +408,9 @@ REGISTER_OP_CPU_KERNEL( ...@@ -406,7 +408,9 @@ REGISTER_OP_CPU_KERNEL(
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>); paddle::platform::complex<double>>,
ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, float>, squeeze2, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, float>,
...@@ -419,7 +423,9 @@ REGISTER_OP_CPU_KERNEL( ...@@ -419,7 +423,9 @@ REGISTER_OP_CPU_KERNEL(
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>); paddle::platform::complex<double>>,
ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
squeeze2_grad, squeeze2_grad,
...@@ -433,4 +439,6 @@ REGISTER_OP_CPU_KERNEL( ...@@ -433,4 +439,6 @@ REGISTER_OP_CPU_KERNEL(
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>); paddle::platform::complex<double>>,
ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
...@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
squeeze, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, float>, squeeze, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, double>, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, bool>, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::SqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
...@@ -35,6 +36,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -35,6 +36,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, float>, ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, double>, ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>, ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>, ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
...@@ -48,6 +50,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -48,6 +50,7 @@ REGISTER_OP_CUDA_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>, squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, bool>, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, bool>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
...@@ -62,6 +65,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -62,6 +65,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>, ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, double>, ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, bool>, ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>, ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
......
...@@ -173,13 +173,16 @@ REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker, ...@@ -173,13 +173,16 @@ REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker,
ops::StackGradOpMaker<paddle::imperative::OpBase>); ops::StackGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(stack_grad, ops::StackOpGrad); REGISTER_OPERATOR(stack_grad, ops::StackOpGrad);
REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel<plat::CPUDeviceContext, float>, REGISTER_OP_CPU_KERNEL(
ops::StackKernel<plat::CPUDeviceContext, double>, stack, ops::StackKernel<plat::CPUDeviceContext, float>,
ops::StackKernel<plat::CPUDeviceContext, int>, ops::StackKernel<plat::CPUDeviceContext, double>,
ops::StackKernel<plat::CPUDeviceContext, int64_t>); ops::StackKernel<plat::CPUDeviceContext, int>,
ops::StackKernel<plat::CPUDeviceContext, int64_t>,
REGISTER_OP_CPU_KERNEL(stack_grad, ops::StackKernel<plat::CPUDeviceContext, paddle::platform::bfloat16>);
ops::StackGradKernel<plat::CPUDeviceContext, float>,
ops::StackGradKernel<plat::CPUDeviceContext, double>, REGISTER_OP_CPU_KERNEL(
ops::StackGradKernel<plat::CPUDeviceContext, int>, stack_grad, ops::StackGradKernel<plat::CPUDeviceContext, float>,
ops::StackGradKernel<plat::CPUDeviceContext, int64_t>); ops::StackGradKernel<plat::CPUDeviceContext, double>,
ops::StackGradKernel<plat::CPUDeviceContext, int>,
ops::StackGradKernel<plat::CPUDeviceContext, int64_t>,
ops::StackGradKernel<plat::CPUDeviceContext, paddle::platform::bfloat16>);
...@@ -196,10 +196,12 @@ class StackGradGPUKernel : public framework::OpKernel<T> { ...@@ -196,10 +196,12 @@ class StackGradGPUKernel : public framework::OpKernel<T> {
REGISTER_OP_CUDA_KERNEL(stack, ops::StackGPUKernel<float>, REGISTER_OP_CUDA_KERNEL(stack, ops::StackGPUKernel<float>,
ops::StackGPUKernel<double>, ops::StackGPUKernel<int>, ops::StackGPUKernel<double>, ops::StackGPUKernel<int>,
ops::StackGPUKernel<int64_t>, ops::StackGPUKernel<int64_t>,
ops::StackGPUKernel<plat::float16>); ops::StackGPUKernel<plat::float16>,
ops::StackGPUKernel<plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(stack_grad, ops::StackGradGPUKernel<float>, REGISTER_OP_CUDA_KERNEL(stack_grad, ops::StackGradGPUKernel<float>,
ops::StackGradGPUKernel<double>, ops::StackGradGPUKernel<double>,
ops::StackGradGPUKernel<int>, ops::StackGradGPUKernel<int>,
ops::StackGradGPUKernel<int64_t>, ops::StackGradGPUKernel<int64_t>,
ops::StackGradGPUKernel<plat::float16>); ops::StackGradGPUKernel<plat::float16>,
ops::StackGradGPUKernel<plat::bfloat16>);
...@@ -366,7 +366,9 @@ REGISTER_OP_CPU_KERNEL( ...@@ -366,7 +366,9 @@ REGISTER_OP_CPU_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>); paddle::platform::complex<double>>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
unsqueeze_grad, unsqueeze_grad,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
...@@ -379,7 +381,9 @@ REGISTER_OP_CPU_KERNEL( ...@@ -379,7 +381,9 @@ REGISTER_OP_CPU_KERNEL(
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>); paddle::platform::complex<double>>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>, unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
...@@ -391,7 +395,9 @@ REGISTER_OP_CPU_KERNEL( ...@@ -391,7 +395,9 @@ REGISTER_OP_CPU_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>); paddle::platform::complex<double>>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
unsqueeze2_grad, unsqueeze2_grad,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>, ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
...@@ -404,4 +410,6 @@ REGISTER_OP_CPU_KERNEL( ...@@ -404,4 +410,6 @@ REGISTER_OP_CPU_KERNEL(
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>, paddle::platform::complex<float>>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>); paddle::platform::complex<double>>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
...@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
unsqueeze, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>, unsqueeze, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
...@@ -36,6 +37,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -36,6 +37,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>, plat::float16>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
...@@ -50,6 +53,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -50,6 +53,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
...@@ -65,6 +69,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -65,6 +69,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::float16>, plat::float16>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, bool>, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
......
...@@ -20,7 +20,8 @@ import numpy as np ...@@ -20,7 +20,8 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard from paddle.fluid import compiler, Program, program_guard
from op_test import OpTest from op_test import OpTest, convert_float_to_uint16
import paddle.fluid.core as core
paddle.enable_static() paddle.enable_static()
...@@ -49,6 +50,32 @@ class TestSqueezeOp(OpTest): ...@@ -49,6 +50,32 @@ class TestSqueezeOp(OpTest):
self.attrs = {"axes": self.axes} self.attrs = {"axes": self.axes}
class TestSqueezeBF16Op(OpTest):
def setUp(self):
self.op_type = "squeeze"
self.dtype = np.uint16
self.init_test_case()
x = np.random.random(self.ori_shape).astype("float32")
out = x.reshape(self.new_shape)
self.inputs = {"X": convert_float_to_uint16(x)}
self.init_attrs()
self.outputs = {"Out": convert_float_to_uint16(out)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.ori_shape = (1, 3, 1, 40)
self.axes = (0, 2)
self.new_shape = (3, 40)
def init_attrs(self):
self.attrs = {"axes": self.axes}
# Correct: There is mins axis. # Correct: There is mins axis.
class TestSqueezeOp1(TestSqueezeOp): class TestSqueezeOp1(TestSqueezeOp):
def init_test_case(self): def init_test_case(self):
......
...@@ -16,7 +16,8 @@ import numpy as np ...@@ -16,7 +16,8 @@ import numpy as np
import unittest import unittest
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from op_test import OpTest from op_test import OpTest, convert_float_to_uint16
import paddle.fluid.core as core
class TestStackOpBase(OpTest): class TestStackOpBase(OpTest):
...@@ -90,6 +91,49 @@ class TestStackOp6(TestStackOpBase): ...@@ -90,6 +91,49 @@ class TestStackOp6(TestStackOpBase):
self.axis = 3 self.axis = 3
class TestStackBF16Op(OpTest):
def initDefaultParameters(self):
self.num_inputs = 4
self.input_dim = (5, 6, 7)
self.axis = 0
self.dtype = np.uint16
def initParameters(self):
pass
def get_x_names(self):
x_names = []
for i in range(self.num_inputs):
x_names.append('x{}'.format(i))
return x_names
def setUp(self):
self.initDefaultParameters()
self.initParameters()
self.op_type = 'stack'
self.x = []
for i in range(self.num_inputs):
self.x.append(
np.random.random(size=self.input_dim).astype(np.float32))
out = np.stack(self.x, axis=self.axis)
tmp = []
x_names = self.get_x_names()
for i in range(self.num_inputs):
tmp.append((x_names[i], convert_float_to_uint16(self.x[i])))
self.inputs = {'X': tmp}
self.outputs = {'Y': convert_float_to_uint16(out)}
self.attrs = {'axis': self.axis}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(self.get_x_names(), 'Y')
class TestStackAPIWithLoDTensorArray(unittest.TestCase): class TestStackAPIWithLoDTensorArray(unittest.TestCase):
""" """
Test stack api when the input(x) is a LoDTensorArray. Test stack api when the input(x) is a LoDTensorArray.
......
...@@ -19,7 +19,8 @@ import numpy as np ...@@ -19,7 +19,8 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from op_test import OpTest from op_test import OpTest, convert_float_to_uint16
import paddle.fluid.core as core
paddle.enable_static() paddle.enable_static()
...@@ -48,6 +49,32 @@ class TestUnsqueezeOp(OpTest): ...@@ -48,6 +49,32 @@ class TestUnsqueezeOp(OpTest):
self.attrs = {"axes": self.axes} self.attrs = {"axes": self.axes}
class TestUnsqueezeBF16Op(OpTest):
def setUp(self):
self.init_test_case()
self.op_type = "unsqueeze"
self.dtype = np.uint16
x = np.random.random(self.ori_shape).astype("float32")
out = x.reshape(self.new_shape)
self.inputs = {"X": convert_float_to_uint16(x)}
self.init_attrs()
self.outputs = {"Out": convert_float_to_uint16(out)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.ori_shape = (3, 40)
self.axes = (1, 2)
self.new_shape = (3, 1, 1, 40)
def init_attrs(self):
self.attrs = {"axes": self.axes}
# Correct: Single input index. # Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp): class TestUnsqueezeOp1(TestUnsqueezeOp):
def init_test_case(self): def init_test_case(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册