diff --git a/paddle/fluid/operators/cast_op_mlu.cc b/paddle/fluid/operators/cast_op_mlu.cc
index f28889e7acf8773e2c55044037eb6bbde71ce12f..f0df271a8d07e7dd02738641167730afb4ccb3cb 100644
--- a/paddle/fluid/operators/cast_op_mlu.cc
+++ b/paddle/fluid/operators/cast_op_mlu.cc
@@ -44,37 +44,7 @@ class CastMLUKernel : public framework::OpKernel<T> {
                           framework::DataTypeToString(src_type),
                           framework::DataTypeToString(dst_type)));
 
-    switch (dst_type) {
-      case VT::FP32:
-        output->mutable_data<float>(place);
-        break;
-      case VT::FP16:
-        output->mutable_data<paddle::platform::float16>(place);
-        break;
-      case VT::INT32:
-        output->mutable_data<int32_t>(place);
-        break;
-      case VT::INT16:
-        output->mutable_data<int16_t>(place);
-        break;
-      case VT::INT8:
-        output->mutable_data<int8_t>(place);
-        break;
-      case VT::UINT8:
-        output->mutable_data<uint8_t>(place);
-        break;
-      case VT::BOOL:
-        output->mutable_data<bool>(place);
-        break;
-      case VT::INT64:
-        output->mutable_data<int64_t>(place);
-        break;
-      default:
-        PADDLE_THROW(platform::errors::Unavailable(
-            "Not supported cast %d -> %d",
-            framework::DataTypeToString(src_type),
-            framework::DataTypeToString(dst_type)));
-    }
+    output->mutable_data(place, framework::TransToPhiDataType(dst_type));
 
     MLUCnnlTensorDesc input_desc(*input);
     MLUCnnlTensorDesc output_desc(*output);
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index f048ac7c5c3be08e034c7b2a3b163888f9e9e982..c97ee3efd3f566a0f4fbb89eb6fe483494c8855c 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -75,6 +75,9 @@ inline cnnlDataType_t ToCnnlDataType(
     case DataType::FLOAT32:
       type = CNNL_DTYPE_FLOAT;
       break;
+    case DataType::FLOAT64:
+      type = CNNL_DTYPE_DOUBLE;
+      break;
     case DataType::INT8:
       type = CNNL_DTYPE_INT8;
       break;
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_cast_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_cast_op_mlu.py
index 6ba62b11499f460bbe620cc26cf6065a786dea28..88b46af8df2a36f57fb690f71d938a5323654137 100644
--- a/python/paddle/fluid/tests/unittests/mlu/test_cast_op_mlu.py
+++ b/python/paddle/fluid/tests/unittests/mlu/test_cast_op_mlu.py
@@ -61,6 +61,25 @@ class TestCastOpFp16ToFp32(OpTest):
         self.op_type = 'cast'
         self.place = paddle.device.MLUPlace(0)
         self.__class__.use_mlu = True
+        self.__class__.no_need_check_grad = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-3)
+
+
+class TestCastOpFp32ToFp64(OpTest):
+    def setUp(self):
+        ipt = np.random.random(size=[10, 10])
+        self.inputs = {'X': ipt.astype('float32')}
+        self.outputs = {'Out': ipt.astype('float64')}
+        self.attrs = {
+            'in_dtype': int(core.VarDesc.VarType.FP32),
+            'out_dtype': int(core.VarDesc.VarType.FP64)
+        }
+        self.op_type = 'cast'
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.__class__.no_need_check_grad = True
 
     def test_check_output(self):
         self.check_output_with_place(self.place, atol=1e-3)
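A minimal usage sketch (not part of the patch) of the FP32 -> FP64 cast path this change enables, exercised through the public Python API. It assumes a Paddle build compiled with MLU support; the 'mlu:0' device string is an assumption about the local environment and may need adjusting.

import numpy as np
import paddle

# Assumption: this Paddle build has MLU support and exposes the device as 'mlu:0'.
paddle.set_device('mlu:0')

x = paddle.to_tensor(np.random.random([10, 10]).astype('float32'))
y = paddle.cast(x, 'float64')  # runs the cast kernel with the new float64 output dtype

# Compare against a NumPy reference cast, mirroring the tolerance used in the unit test.
np.testing.assert_allclose(y.numpy(), x.numpy().astype('float64'), atol=1e-3)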