From 42dd0f1b7a46d13a59e4d901dcd8499e7643bbce Mon Sep 17 00:00:00 2001 From: qipengh Date: Tue, 7 Jun 2022 14:29:49 +0800 Subject: [PATCH] [MLU]support cast double type (#43058) * [MLU]support cast double type * [MLU]fix cast test --- paddle/fluid/operators/cast_op_mlu.cc | 32 +------------------ paddle/fluid/operators/mlu/mlu_baseop.h | 3 ++ .../tests/unittests/mlu/test_cast_op_mlu.py | 19 +++++++++++ 3 files changed, 23 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/cast_op_mlu.cc b/paddle/fluid/operators/cast_op_mlu.cc index f28889e7acf..f0df271a8d0 100644 --- a/paddle/fluid/operators/cast_op_mlu.cc +++ b/paddle/fluid/operators/cast_op_mlu.cc @@ -44,37 +44,7 @@ class CastMLUKernel : public framework::OpKernel { framework::DataTypeToString(src_type), framework::DataTypeToString(dst_type))); - switch (dst_type) { - case VT::FP32: - output->mutable_data(place); - break; - case VT::FP16: - output->mutable_data(place); - break; - case VT::INT32: - output->mutable_data(place); - break; - case VT::INT16: - output->mutable_data(place); - break; - case VT::INT8: - output->mutable_data(place); - break; - case VT::UINT8: - output->mutable_data(place); - break; - case VT::BOOL: - output->mutable_data(place); - break; - case VT::INT64: - output->mutable_data(place); - break; - default: - PADDLE_THROW(platform::errors::Unavailable( - "Not supported cast %d -> %d", - framework::DataTypeToString(src_type), - framework::DataTypeToString(dst_type))); - } + output->mutable_data(place, framework::TransToPhiDataType(dst_type)); MLUCnnlTensorDesc input_desc(*input); MLUCnnlTensorDesc output_desc(*output); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index f048ac7c5c3..c97ee3efd3f 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -75,6 +75,9 @@ inline cnnlDataType_t ToCnnlDataType( case DataType::FLOAT32: type = CNNL_DTYPE_FLOAT; break; + case DataType::FLOAT64: + type = CNNL_DTYPE_DOUBLE; + break; case DataType::INT8: type = CNNL_DTYPE_INT8; break; diff --git a/python/paddle/fluid/tests/unittests/mlu/test_cast_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_cast_op_mlu.py index 6ba62b11499..88b46af8df2 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_cast_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_cast_op_mlu.py @@ -61,6 +61,25 @@ class TestCastOpFp16ToFp32(OpTest): self.op_type = 'cast' self.place = paddle.device.MLUPlace(0) self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +class TestCastOpFp32ToFp64(OpTest): + def setUp(self): + ipt = np.random.random(size=[10, 10]) + self.inputs = {'X': ipt.astype('float32')} + self.outputs = {'Out': ipt.astype('float64')} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.FP32), + 'out_dtype': int(core.VarDesc.VarType.FP64) + } + self.op_type = 'cast' + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True def test_check_output(self): self.check_output_with_place(self.place, atol=1e-3) -- GitLab