未验证 提交 3f5705c3 编写于 作者: K Kexin Zhao 提交者: GitHub

Merge pull request #9148 from kexinzhao/cast_op_fp16

Add float16 support for cast op
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/cast_op.h" #include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -88,4 +89,5 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>, ...@@ -88,4 +89,5 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, double>, ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int64_t>, ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, bool>); ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, paddle::platform::float16>);
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/cast_op.h" #include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/float16.h"
template <typename T> template <typename T>
using CastOpKernel = using CastOpKernel =
...@@ -20,4 +21,5 @@ using CastOpKernel = ...@@ -20,4 +21,5 @@ using CastOpKernel =
REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>, REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
CastOpKernel<int>, CastOpKernel<int64_t>, CastOpKernel<int>, CastOpKernel<int64_t>,
CastOpKernel<bool>); CastOpKernel<bool>,
CastOpKernel<paddle::platform::float16>);
...@@ -18,7 +18,7 @@ import numpy as np ...@@ -18,7 +18,7 @@ import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
class TestCastOp(op_test.OpTest): class TestCastOp1(op_test.OpTest):
def setUp(self): def setUp(self):
ipt = np.random.random(size=[10, 10]) ipt = np.random.random(size=[10, 10])
self.inputs = {'X': ipt.astype('float32')} self.inputs = {'X': ipt.astype('float32')}
...@@ -36,5 +36,36 @@ class TestCastOp(op_test.OpTest): ...@@ -36,5 +36,36 @@ class TestCastOp(op_test.OpTest):
self.check_grad(['X'], ['Out']) self.check_grad(['X'], ['Out'])
class TestCastOp2(op_test.OpTest):
def setUp(self):
ipt = np.random.random(size=[10, 10])
# numpy float16 is binded to fluid float16 via uint16
self.inputs = {'X': ipt.astype('float16').view(np.uint16)}
self.outputs = {'Out': ipt.astype('float32')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.FP16),
'out_dtype': int(core.VarDesc.VarType.FP32)
}
self.op_type = 'cast'
def test_check_output(self):
self.check_output(atol=1e-3)
class TestCastOp3(op_test.OpTest):
def setUp(self):
ipt = np.random.random(size=[10, 10])
self.inputs = {'X': ipt.astype('float32')}
self.outputs = {'Out': ipt.astype('float16')}
self.attrs = {
'in_dtype': int(core.VarDesc.VarType.FP32),
'out_dtype': int(core.VarDesc.VarType.FP16)
}
self.op_type = 'cast'
def test_check_output(self):
self.check_output(atol=1e-3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册