diff --git a/paddle/fluid/operators/cast_op.h b/paddle/fluid/operators/cast_op.h
index 91276ba6e8bed620ab53305e72bfa9d52cc4c07b..8fa0416049f8fa128d7ab61f8350b41960f07263 100644
--- a/paddle/fluid/operators/cast_op.h
+++ b/paddle/fluid/operators/cast_op.h
@@ -48,52 +48,17 @@ struct CastOpFunctor {
   }
 };
 
-template <typename DeviceContext, typename InT, typename OutT>
-static void CastFunction(const framework::ExecutionContext& context) {
-  auto* in = context.Input<framework::Tensor>("X");
-  auto* out = context.Output<framework::Tensor>("Out");
-
-  auto in_t = framework::EigenVector<InT>::Flatten(*in);
-  out->mutable_data<OutT>(context.GetPlace());
-  auto out_t = framework::EigenVector<OutT>::Flatten(*out);
-  auto& place =
-      *context.template device_context<DeviceContext>().eigen_device();
-  out_t.device(place) = in_t.template cast<OutT>();
-}
-
 template <typename DeviceContext, typename InT>
 class CastOpKernel : public framework::OpKernel<InT> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto out_type = static_cast<framework::proto::VarType::Type>(
-        context.Attr<int>("out_dtype"));
-
-    if (out_type == paddle::framework::proto::VarType::FP64) {
-      CastFunction<DeviceContext, InT, double>(context);
-    } else if (out_type == paddle::framework::proto::VarType::FP32) {
-      CastFunction<DeviceContext, InT, float>(context);
-    } else if (out_type == paddle::framework::proto::VarType::FP16) {
-      CastFunction<DeviceContext, InT, platform::float16>(context);
-    } else if (out_type == paddle::framework::proto::VarType::INT64) {
-      CastFunction<DeviceContext, InT, int64_t>(context);
-    } else if (out_type == paddle::framework::proto::VarType::INT32) {
-      CastFunction<DeviceContext, InT, int>(context);
-    } else if (out_type == paddle::framework::proto::VarType::UINT8) {
-      CastFunction<DeviceContext, InT, uint8_t>(context);
-    } else if (out_type == paddle::framework::proto::VarType::BOOL) {
-      CastFunction<DeviceContext, InT, bool>(context);
-    } else if (out_type == paddle::framework::proto::VarType::COMPLEX64) {
-      CastFunction<DeviceContext, InT, platform::complex64>(context);
-    } else if (out_type == paddle::framework::proto::VarType::COMPLEX128) {
-      CastFunction<DeviceContext, InT, platform::complex128>(context);
-    } else {
-      // NOTE(chenweihang): if else branch do nothing, the output var will
-      // be non-initialized in dygraph, which will throw error if the
-      // non-initialized var is used as the next op's input
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "Now does not support casting Tensor to `%s` data type.",
-          framework::DataTypeToString(out_type)));
-    }
+    auto* in = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    framework::VisitDataType(
+        static_cast<framework::proto::VarType::Type>(
+            context.Attr<int>("out_dtype")),
+        CastOpFunctor<DeviceContext, InT>(
+            in, out, context.template device_context<DeviceContext>()));
   }
 };
 
diff --git a/python/paddle/fluid/tests/unittests/test_cast_op.py b/python/paddle/fluid/tests/unittests/test_cast_op.py
index 44fdd8c74bf7c4fd7850608e9654ac426dab17f8..0fc3dccab4a64d118b94496d99bdfd0760f3bb6e 100644
--- a/python/paddle/fluid/tests/unittests/test_cast_op.py
+++ b/python/paddle/fluid/tests/unittests/test_cast_op.py
@@ -90,18 +90,6 @@ class TestCastOpError(unittest.TestCase):
         self.assertRaises(TypeError, test_dtype_type)
 
 
-class TestCastOpErrorInDygraph(unittest.TestCase):
-    def test_non_support_out_dtype(self):
-        paddle.disable_static()
-
-        with self.assertRaises(NotImplementedError):
-            tensor = paddle.randn([10, 10], 'float32')
-            core.ops.cast(tensor, 'in_dtype', core.VarDesc.VarType.FP32,
-                          'out_dtype', core.VarDesc.VarType.INT16)
-
-        paddle.enable_static()
-
-
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
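
Note on the kernel change (a reviewer-style aside, not part of the patch): the hand-written if/else ladder over `out_dtype` is replaced by `framework::VisitDataType`, which maps the runtime dtype tag to a concrete C++ type and then invokes the templated `apply<OutT>()` of the existing `CastOpFunctor`, so the per-dtype branching lives in the shared dispatcher instead of in the cast kernel. The snippet below is a minimal, self-contained sketch of that tag-to-template dispatch pattern; `DType`, `VisitDType`, and `CastFunctor` are illustrative stand-ins, not Paddle's actual API.

// Minimal, self-contained sketch (illustrative names, not Paddle's headers):
// a runtime dtype tag is mapped to a concrete C++ type in one switch, and the
// visitor's templated apply<T>() does the per-type work. This mirrors how the
// kernel above hands CastOpFunctor to framework::VisitDataType.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

enum class DType { FP32, FP64, INT32, INT64, BOOL };

// Calls visitor.apply<T>() with T chosen from the runtime tag.
template <typename Visitor>
void VisitDType(DType dtype, Visitor visitor) {
  switch (dtype) {
    case DType::FP32:  visitor.template apply<float>();   break;
    case DType::FP64:  visitor.template apply<double>();  break;
    case DType::INT32: visitor.template apply<int32_t>(); break;
    case DType::INT64: visitor.template apply<int64_t>(); break;
    case DType::BOOL:  visitor.template apply<bool>();    break;
    default: throw std::runtime_error("unsupported dtype");
  }
}

// Stand-in for CastOpFunctor: casts a float buffer to the visited output type.
struct CastFunctor {
  const std::vector<float>& in;

  template <typename OutT>
  void apply() const {
    std::vector<OutT> out(in.size());
    for (std::size_t i = 0; i < in.size(); ++i) {
      out[i] = static_cast<OutT>(in[i]);
    }
    std::cout << "casted " << out.size() << " elements\n";
  }
};

int main() {
  std::vector<float> data = {1.5f, 2.5f, 3.5f};
  // One dispatch call replaces the per-dtype if/else chain in the caller.
  VisitDType(DType::INT64, CastFunctor{data});
  return 0;
}

In the kernel above, `CastOpFunctor<DeviceContext, InT>` plays the role of `CastFunctor` here and `framework::VisitDataType` plays the role of `VisitDType`.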