diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index 7278d80ce9ba195a1c83a3ba67dcb449d7f81e59..4853e5324c30f56e90cec9a7c75a48686b58a4b8 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -105,6 +105,11 @@ class CastOp : public framework::OperatorWithKernel { #endif return framework::OpKernelType(tensor->type(), tensor_place); } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext &ctx) const override { + return framework::KernelSignature("cast", {"X"}, {"out_dtype"}, {"Out"}); + } }; } // namespace operators diff --git a/paddle/fluid/operators/cast_op.h b/paddle/fluid/operators/cast_op.h index bf0e81a23bf90a805909e9d651848a6fc834a849..466adfa5f3672853dd3b9e8aa6f1cfae3b0b41aa 100644 --- a/paddle/fluid/operators/cast_op.h +++ b/paddle/fluid/operators/cast_op.h @@ -59,8 +59,6 @@ class CastOpKernel : public framework::OpKernel { auto* out = context.Output("Out"); auto out_dtype = context.Attr("out_dtype"); - // todo: not used in_dtype - auto in_dtype = context.Attr("in_dtype"); auto& dev_ctx = context.device_context(); out->mutable_data(dev_ctx.GetPlace(), @@ -71,12 +69,9 @@ class CastOpKernel : public framework::OpKernel { auto pt_out_dtype = pten::TransToPtenDataType( static_cast(out_dtype)); - auto pt_in_dtype = pten::TransToPtenDataType( - static_cast(in_dtype)); // call new kernel - pten::Cast(dev_ctx, *pt_x.get(), pt_out_dtype, pt_in_dtype, - pt_out.get()); + pten::Cast(dev_ctx, *pt_x.get(), pt_out_dtype, pt_out.get()); } }; diff --git a/paddle/pten/api/include/kernel_signature.h b/paddle/pten/api/include/kernel_signature.h index 1c120edb9ab7bc188a7721f14fc16ccbf3b6d1f0..0b17415a6a98de623be97316b84ac50f6eddea03 100644 --- a/paddle/pten/api/include/kernel_signature.h +++ b/paddle/pten/api/include/kernel_signature.h @@ -33,8 +33,10 @@ using add_kernel = void (*)(const DeviceContext&, int, DenseTensor*); -using cast_kernel = void (*)( - const DeviceContext&, const DenseTensor&, DataType, DataType, DenseTensor*); +using cast_kernel = void (*)(const DeviceContext&, + const DenseTensor&, + DataType, + DenseTensor*); using divide_kernel = void (*)(const DeviceContext&, const DenseTensor&, diff --git a/paddle/pten/include/manipulation.h b/paddle/pten/include/manipulation.h index 7015dce95d02448ec16eee126071a476cc6b1d76..e317964dd1e236023ac5bc41c981ea64b84fcbcc 100644 --- a/paddle/pten/include/manipulation.h +++ b/paddle/pten/include/manipulation.h @@ -40,14 +40,13 @@ DenseTensor Flatten(const ContextT& dev_ctx, template DenseTensor Cast(const ContextT& dev_ctx, const DenseTensor& x, - DataType out_dtype, - DataType in_dtype) { + DataType out_dtype) { auto out_meta = CastInferMeta(x.meta(), out_dtype); pten::DenseTensor dense_out( pten::make_intrusive( dev_ctx.GetPlace()), std::move(out_meta)); - Cast(dev_ctx, x, out_dtype, in_dtype, &dense_out); + Cast(dev_ctx, x, out_dtype, &dense_out); return dense_out; } diff --git a/paddle/pten/kernels/cast_kernel.h b/paddle/pten/kernels/cast_kernel.h index 968139f6f17636abe303502035dd8e15c0b48dff..5243fa05fac154b16f85c80b9305424c3a43f2cd 100644 --- a/paddle/pten/kernels/cast_kernel.h +++ b/paddle/pten/kernels/cast_kernel.h @@ -22,7 +22,6 @@ template void Cast(const ContextT& dev_ctx, const DenseTensor& x, DataType out_dtype, - DataType in_dtype, DenseTensor* out); } // namespace pten diff --git a/paddle/pten/kernels/cpu/cast_kernel.cc b/paddle/pten/kernels/cpu/cast_kernel.cc index db57da1d41d339d97559648b3077f912e9b441a7..a9964d99eef07e637d2b857b32c13bfb2f6d5875 100644 --- a/paddle/pten/kernels/cpu/cast_kernel.cc +++ b/paddle/pten/kernels/cpu/cast_kernel.cc @@ -50,7 +50,6 @@ template void Cast(const ContextT& dev_ctx, const DenseTensor& x, DataType out_dtype, - DataType in_dtype, DenseTensor* out) { PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] { CastKernelImpl(dev_ctx, x, out); diff --git a/paddle/pten/kernels/gpu/cast_kernel.cu b/paddle/pten/kernels/gpu/cast_kernel.cu index 011fc9077dca9041c7828a69fc996757042124ba..84816664164fab910b016a8315c68684360a49ac 100644 --- a/paddle/pten/kernels/gpu/cast_kernel.cu +++ b/paddle/pten/kernels/gpu/cast_kernel.cu @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - #include "paddle/pten/kernels/cast_kernel.h" #include "paddle/pten/api/ext/dispatch.h" @@ -84,7 +82,6 @@ template void Cast(const ContextT& dev_ctx, const DenseTensor& x, DataType out_dtype, - DataType in_dtype, DenseTensor* out) { PD_VISIT_ALL_TYPES(out_dtype, "CastCUDAKernelImpl", ([&] { CastCUDAKernelImpl(dev_ctx, x, out); diff --git a/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h b/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h index 2680f2d3fddca53d56ee1f8931aeac2e13dcc66e..d4fdb477633a3bccaaf62e701c40573dd9a181e4 100644 --- a/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h +++ b/paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h @@ -1112,7 +1112,7 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x, AsyncCopy(x, y); y->Resize(out_dims); } else { - pten::Cast(*dev_ctx, x, y->dtype(), x.dtype(), y); + pten::Cast(*dev_ctx, x, y->dtype(), y); } return; } diff --git a/paddle/pten/kernels/hybird/general/reduce_impl.h b/paddle/pten/kernels/hybird/general/reduce_impl.h index a8e0bc5de2196ea8890d299a5065af9d7cb93aef..daa23456178d71bc0173821f1833d90f5c28e0ba 100644 --- a/paddle/pten/kernels/hybird/general/reduce_impl.h +++ b/paddle/pten/kernels/hybird/general/reduce_impl.h @@ -59,7 +59,7 @@ void Reduce(const DeviceContext& dev_ctx, pten::DenseTensorMeta(out_dtype, x.dims(), x.layout())); // cast x tensor to out_dtype - pten::Cast(dev_ctx, x, out_dtype, x.dtype(), &tmp_tensor); + pten::Cast(dev_ctx, x, out_dtype, &tmp_tensor); // do reduce sum PD_VISIT_ALL_TYPES( diff --git a/paddle/pten/tests/kernels/test_cast_dev_api.cc b/paddle/pten/tests/kernels/test_cast_dev_api.cc index 5bbaf2a2c373d2462cafa58eb7cef6a281a31615..dc3cff150b47b3b85e0d9f8cffe4da508dc3469a 100644 --- a/paddle/pten/tests/kernels/test_cast_dev_api.cc +++ b/paddle/pten/tests/kernels/test_cast_dev_api.cc @@ -49,13 +49,11 @@ TEST(DEV_API, cast) { auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); pten::DataType out_dtype = pten::DataType::FLOAT64; - pten::DataType in_dtype = pten::DataType::FLOAT32; // 2. test API auto out = pten::Cast( *(static_cast(dev_ctx)), dense_x, - out_dtype, - in_dtype); + out_dtype); // 3. check result ASSERT_EQ(out.dims().size(), 2); diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 8aacaec69c8f3bc53ed89895885ab8c66443624f..562a726aa29f27bb6b7017c72ba86dc9f33372c3 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -15,9 +15,17 @@ func : CastInferMeta kernel : func : cast - param : [x, out_dtype, x.dtype()] + param : [x, out_dtype] data_type : x +- api : conj + args : (const Tensor& x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : conj + - api : divide args : (const Tensor& x, const Tensor& y) output : Tensor @@ -171,11 +179,3 @@ args : (const Tensor& x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED) output : Tensor invoke : full_like(x, 0, dtype, place, layout) - -- api : conj - args : (const Tensor& x) - output : Tensor - infer_meta : - func : UnchangedInferMeta - kernel : - func : conj