Unverified commit fb2a9cdf, authored by chengduo, committed by GitHub

Add fp16 support for pad and split (#19881)

* make pad and split support fp16
test=develop
Parent c9ea317b
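The hunks below register float16 (alongside float and double) CUDA kernels for pad and split, extend the CPU split kernel list, and add float16 variants of the existing unit tests. As a rough illustration of what this enables, here is a minimal sketch (not part of the commit) that runs both operators on float16 input through the fluid 1.x static-graph API; the variable names and shapes, and the assumption that the Python wrappers (fluid.layers.data/pad/split) accept a 'float16' dtype, are illustrative only.

import numpy as np
import paddle.fluid as fluid

# Build a small static-graph program that pads and splits a float16 tensor.
x = fluid.layers.data(
    name='x', shape=[4, 5, 6], dtype='float16', append_batch_size=False)
padded = fluid.layers.pad(x, paddings=[0, 0, 0, 0, 1, 1], pad_value=0.0)
parts = fluid.layers.split(x, num_or_sections=[2, 1, 2], dim=1)

# The float16 pad kernel in this commit is registered for CUDA only, so run on a GPU place.
exe = fluid.Executor(fluid.CUDAPlace(0))
feed = {'x': np.random.random((4, 5, 6)).astype('float16')}
outs = exe.run(fluid.default_main_program(), feed=feed,
               fetch_list=[padded] + parts)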
@@ -58,7 +58,7 @@ void PadGradFunction(const framework::ExecutionContext& context,
   auto src_tensor = EigenTensor<T, D>::From(src);
   auto& place =
       *context.template device_context<DeviceContext>().eigen_device();
-  d_out_tensor.device(place) = src_tensor.pad(paddings, 0);
+  d_out_tensor.device(place) = src_tensor.pad(paddings, static_cast<T>(0));
 }
 
 template <typename DeviceContext, typename T>
@@ -14,7 +14,12 @@ limitations under the License. */
 #include "paddle/fluid/operators/pad_op.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
-    pad, ops::PadKernel<paddle::platform::CUDADeviceContext, float>);
+    pad, ops::PadKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::PadKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::PadKernel<paddle::platform::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
-    pad_grad, ops::PadGradKernel<paddle::platform::CUDADeviceContext, float>);
+    pad_grad, ops::PadGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::PadGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::PadGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);
@@ -30,14 +30,14 @@ class PadKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto pads = context.Attr<std::vector<int>>("paddings");
-    T pad_value = context.Attr<T>("pad_value");
+    float pad_value = context.Attr<float>("pad_value");
     auto* x = context.Input<Tensor>("X");
     auto* out = context.Output<Tensor>("Out");
     out->mutable_data<T>(context.GetPlace());
 
     int rank = x->dims().size();
-    math::PaddingFunctor<DeviceContext, T>(rank, context, pads, pad_value, *x,
-                                           out);
+    math::PaddingFunctor<DeviceContext, T>(rank, context, pads,
+                                           static_cast<T>(pad_value), *x, out);
   }
 };
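In the PadKernel hunk above, "pad_value" is now read as a plain float attribute and converted to the kernel's element type T via static_cast, so a single attribute serves the float, double, and float16 kernels. The numpy-only sketch below (not part of the patch) shows the resulting behaviour for float16 inputs: the pad value is rounded to float16, exactly as the reference np.pad call used in the pad unit tests does.

import numpy as np

pad_value = 0.9                                  # stored as a float attribute
x = np.random.random((2, 3)).astype(np.float16)  # float16 input tensor
padded = np.pad(x, ((0, 1), (0, 0)), mode='constant',
                constant_values=pad_value)
# np.pad keeps the input dtype, so the constant is rounded to float16,
# mirroring the static_cast<T>(pad_value) on the C++ side.
assert padded.dtype == np.float16
assert padded[-1, 0] == np.float16(pad_value)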
@@ -119,10 +119,11 @@ Example:
 } // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker, ops::SplitGradMaker);
 REGISTER_OP_CPU_KERNEL(
-    split, ops::SplitOpKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::SplitOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SplitOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::SplitOpKernel<paddle::platform::CPUDeviceContext, int>);
+    split, ops::SplitOpKernel<plat::CPUDeviceContext, double>,
+    ops::SplitOpKernel<plat::CPUDeviceContext, float>,
+    ops::SplitOpKernel<plat::CPUDeviceContext, int64_t>,
+    ops::SplitOpKernel<plat::CPUDeviceContext, int>,
+    ops::SplitOpKernel<plat::CPUDeviceContext, plat::float16>);
@@ -14,8 +14,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/split_op.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
-    split, ops::SplitOpKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::SplitOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::SplitOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::SplitOpKernel<paddle::platform::CUDADeviceContext, int>);
+    split, ops::SplitOpKernel<plat::CUDADeviceContext, double>,
+    ops::SplitOpKernel<plat::CUDADeviceContext, float>,
+    ops::SplitOpKernel<plat::CUDADeviceContext, int64_t>,
+    ops::SplitOpKernel<plat::CUDADeviceContext, int>,
+    ops::SplitOpKernel<plat::CUDADeviceContext, plat::float16>);
@@ -22,6 +22,7 @@ from op_test import OpTest
 class TestConcatOp(OpTest):
     def setUp(self):
         self.op_type = "concat"
+        self.dtype = self.get_dtype()
         self.init_test_data()
         self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
         self.attrs = {'axis': self.axis}
@@ -36,6 +37,9 @@ class TestConcatOp(OpTest):
                 (self.x0, self.x1, self.x2), axis=self.actual_axis)
         }
 
+    def get_dtype(self):
+        return "float32"
+
     def test_check_output(self):
         self.check_output()
@@ -45,25 +49,25 @@ class TestConcatOp(OpTest):
         self.check_grad(['x2'], 'Out')
 
     def init_test_data(self):
-        self.x0 = np.random.random((2, 1, 4, 5)).astype('float32')
-        self.x1 = np.random.random((2, 2, 4, 5)).astype('float32')
-        self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.x0 = np.random.random((2, 1, 4, 5)).astype(self.dtype)
+        self.x1 = np.random.random((2, 2, 4, 5)).astype(self.dtype)
+        self.x2 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
         self.axis = 1
 
 
 class TestConcatOp2(TestConcatOp):
     def init_test_data(self):
-        self.x0 = np.random.random((2, 3, 4, 5)).astype('float32')
-        self.x1 = np.random.random((2, 3, 4, 5)).astype('float32')
-        self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.x0 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
+        self.x1 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
+        self.x2 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
         self.axis = 1
 
 
 class TestConcatOp3(TestConcatOp):
     def init_test_data(self):
-        self.x0 = np.random.random((1, 256, 170, 256)).astype('float32')
-        self.x1 = np.random.random((1, 128, 170, 256)).astype('float32')
-        self.x2 = np.random.random((1, 128, 170, 256)).astype('float32')
+        self.x0 = np.random.random((1, 256, 170, 256)).astype(self.dtype)
+        self.x1 = np.random.random((1, 128, 170, 256)).astype(self.dtype)
+        self.x2 = np.random.random((1, 128, 170, 256)).astype(self.dtype)
         self.axis = 1
 
     def test_check_grad(self):
@@ -72,9 +76,9 @@ class TestConcatOp3(TestConcatOp):
 class TestConcatOp4(TestConcatOp):
     def init_test_data(self):
-        self.x0 = np.random.random((2, 3, 4, 5)).astype('float32')
-        self.x1 = np.random.random((2, 3, 4, 5)).astype('float32')
-        self.x2 = np.random.random((0, 3, 4, 5)).astype('float32')
+        self.x0 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
+        self.x1 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
+        self.x2 = np.random.random((0, 3, 4, 5)).astype(self.dtype)
         self.axis = 0
 
     def test_check_grad(self):
@@ -83,11 +87,30 @@ class TestConcatOp4(TestConcatOp):
 class TestConcatOp5(TestConcatOp):
     def init_test_data(self):
-        self.x0 = np.random.random((2, 1, 4, 5)).astype('float32')
-        self.x1 = np.random.random((2, 2, 4, 5)).astype('float32')
-        self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.x0 = np.random.random((2, 1, 4, 5)).astype(self.dtype)
+        self.x1 = np.random.random((2, 2, 4, 5)).astype(self.dtype)
+        self.x2 = np.random.random((2, 3, 4, 5)).astype(self.dtype)
         self.axis = -3
 
 
+#----------------Concat Fp16----------------
+def create_test_fp16(parent):
+    class TestConcatFp16(parent):
+        def get_dtype(self):
+            return np.float16
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
+    TestConcatFp16.__name__ = cls_name
+    globals()[cls_name] = TestConcatFp16
+
+
+create_test_fp16(TestConcatOp)
+create_test_fp16(TestConcatOp2)
+create_test_fp16(TestConcatOp3)
+create_test_fp16(TestConcatOp4)
+create_test_fp16(TestConcatOp5)
+
 if __name__ == '__main__':
     unittest.main()
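The create_test_fp16 helper above is a small subclass factory: for each existing test case it derives a subclass whose get_dtype returns np.float16, renames it, and publishes it in the module's globals() so that unittest discovery treats it as a separate test. A stripped-down, self-contained sketch of the same pattern (the names here are illustrative, not taken from the test files):

import unittest
import numpy as np


class TestBase(unittest.TestCase):
    def get_dtype(self):
        return np.float32

    def test_dtype(self):
        data = np.zeros(4, dtype=self.get_dtype())
        self.assertEqual(data.dtype, self.get_dtype())


def create_fp16_variant(parent):
    class Fp16Variant(parent):
        def get_dtype(self):
            return np.float16

    # Renaming and registering the class at module level lets unittest
    # discover it as its own test case alongside the parent.
    cls_name = "{0}_Fp16".format(parent.__name__)
    Fp16Variant.__name__ = cls_name
    globals()[cls_name] = Fp16Variant


create_fp16_variant(TestBase)

if __name__ == '__main__':
    unittest.main()  # runs both TestBase and TestBase_Fp16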
@@ -22,8 +22,9 @@ from op_test import OpTest
 class TestPadOp(OpTest):
     def setUp(self):
         self.initTestCase()
+        self.dtype = self.get_dtype()
         self.op_type = "pad"
-        self.inputs = {'X': np.random.random(self.shape).astype("float32"), }
+        self.inputs = {'X': np.random.random(self.shape).astype(self.dtype), }
         self.attrs = {}
         self.attrs['paddings'] = np.array(self.paddings).flatten()
         self.attrs['pad_value'] = self.pad_value
@@ -34,6 +35,9 @@ class TestPadOp(OpTest):
                 constant_values=self.pad_value)
         }
 
+    def get_dtype(self):
+        return np.float32
+
     def test_check_output(self):
         self.check_output()
@@ -67,5 +71,26 @@ class TestCase3(TestPadOp):
         self.pad_value = 0.9
 
 
+#----------------Pad Fp16----------------
+def create_test_fp16(parent):
+    class TestPadFp16(parent):
+        def get_dtype(self):
+            return np.float16
+
+        def test_check_grad_normal(self):
+            self.check_grad(['X'], 'Out', max_relative_error=0.3)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
+    TestPadFp16.__name__ = cls_name
+    globals()[cls_name] = TestPadFp16
+
+
+create_test_fp16(TestPadOp)
+create_test_fp16(TestCase1)
+create_test_fp16(TestCase2)
+create_test_fp16(TestCase3)
+
 if __name__ == '__main__':
     unittest.main()
@@ -22,14 +22,18 @@ from op_test import OpTest
 class TestSplitOp(OpTest):
     def setUp(self):
         self._set_op_type()
+        self.dtype = self.get_dtype()
         axis = 1
-        x = np.random.random((4, 5, 6)).astype('float32')
+        x = np.random.random((4, 5, 6)).astype(self.dtype)
         out = np.split(x, [2, 3], axis)
         self.inputs = {'X': x}
         self.attrs = {'axis': axis, 'sections': [2, 1, 2]}
         self.outputs = {'Out': [('out%d' % i, out[i]) \
                                 for i in range(len(out))]}
 
+    def get_dtype(self):
+        return "float32"
+
     def _set_op_type(self):
         self.op_type = "split"
@@ -45,5 +49,23 @@ class TestSplitByrefOp(OpTest):
         self.op_type = "split_byref"
 
 
+#----------------Split Fp16----------------
+def create_test_fp16(parent):
+    class TestSplitFp16(parent):
+        def get_dtype(self):
+            return np.float16
+
+        def test_check_grad(self):
+            pass
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
+    TestSplitFp16.__name__ = cls_name
+    globals()[cls_name] = TestSplitFp16
+
+
+create_test_fp16(TestSplitOp)
+
 if __name__ == '__main__':
     unittest.main()
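The fp16 test variants above loosen the gradient checks: the pad tests raise max_relative_error to 0.3 and the split test skips its gradient check entirely. This reflects float16's limited precision, which makes numeric gradients noisy; a quick numpy illustration of the scale involved:

import numpy as np

# Unit roundoff: float16 has only 10 mantissa bits, so relative rounding
# error is roughly a thousand times larger than float32's.
print(np.finfo(np.float16).eps)   # ~9.77e-04
print(np.finfo(np.float32).eps)   # ~1.19e-07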