未验证 提交 0637b9a6 编写于 作者: C chentianyu03 提交者: GitHub

[pten] remove in_type arg in cast kernel (#38486)

* remove intype arg in cast kernel

* modify conj config in api.yaml by dictionary order

* rm unused code in cast_kernel.cu
上级 78836bb7
......@@ -105,6 +105,11 @@ class CastOp : public framework::OperatorWithKernel {
#endif
return framework::OpKernelType(tensor->type(), tensor_place);
}
// Maps this fluid CastOp onto the new pten "cast" kernel:
// input "X", attribute "out_dtype", output "Out".
// NOTE: "in_dtype" is intentionally not forwarded — the pten kernel
// derives the input type from the tensor itself (see this commit).
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("cast", {"X"}, {"out_dtype"}, {"Out"});
}
};
} // namespace operators
......
......@@ -59,8 +59,6 @@ class CastOpKernel : public framework::OpKernel<InT> {
auto* out = context.Output<framework::Tensor>("Out");
auto out_dtype = context.Attr<int>("out_dtype");
// NOTE(review): the "in_dtype" attr is no longer read here — the pten
// cast kernel takes the input dtype from the tensor itself.
auto& dev_ctx = context.device_context<DeviceContext>();
out->mutable_data(dev_ctx.GetPlace(),
......@@ -71,12 +69,9 @@ class CastOpKernel : public framework::OpKernel<InT> {
// Only the output dtype needs an explicit conversion; the input dtype
// comes from pt_x itself, so no pt_in_dtype is computed anymore.
auto pt_out_dtype = pten::TransToPtenDataType(
    static_cast<framework::proto::VarType::Type>(out_dtype));
// call new kernel
pten::Cast<InT>(dev_ctx, *pt_x.get(), pt_out_dtype, pt_out.get());
}
};
......
......@@ -33,8 +33,10 @@ using add_kernel = void (*)(const DeviceContext&,
int,
DenseTensor*);
// Function-pointer alias for the pten cast kernel:
// (dev_ctx, x, out_dtype, out). The former in_dtype parameter was
// removed, so only one alias definition must exist.
using cast_kernel = void (*)(const DeviceContext&,
                             const DenseTensor&,
                             DataType,
                             DenseTensor*);
using divide_kernel = void (*)(const DeviceContext&,
const DenseTensor&,
......
......@@ -40,14 +40,13 @@ DenseTensor Flatten(const ContextT& dev_ctx,
// Casts `x` to `out_dtype` and returns the result in a freshly
// allocated DenseTensor whose meta comes from CastInferMeta.
// The input dtype is taken from `x`, so it is no longer a parameter.
template <typename T, typename ContextT>
DenseTensor Cast(const ContextT& dev_ctx,
                 const DenseTensor& x,
                 DataType out_dtype) {
  auto out_meta = CastInferMeta(x.meta(), out_dtype);
  pten::DenseTensor dense_out(
      pten::make_intrusive<paddle::experimental::SharedStorage>(
          dev_ctx.GetPlace()),
      std::move(out_meta));
  Cast<T, ContextT>(dev_ctx, x, out_dtype, &dense_out);
  return dense_out;
}
......
......@@ -22,7 +22,6 @@ template <typename T, typename ContextT>
// Forward declaration of the cast kernel. The source dtype is implicit
// in `x` / template parameter T, so only out_dtype is passed.
void Cast(const ContextT& dev_ctx,
          const DenseTensor& x,
          DataType out_dtype,
          DenseTensor* out);
} // namespace pten
......@@ -50,7 +50,6 @@ template <typename T, typename ContextT>
// CPU cast kernel: dispatches CastKernelImpl on the runtime out_dtype.
// The former in_dtype parameter is gone — T already fixes the source type.
void Cast(const ContextT& dev_ctx,
          const DenseTensor& x,
          DataType out_dtype,
          DenseTensor* out) {
  PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] {
                       CastKernelImpl<T, data_t>(dev_ctx, x, out);
......
......@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/pten/api/ext/dispatch.h"
......@@ -84,7 +82,6 @@ template <typename T, typename ContextT>
// CUDA cast kernel: dispatches CastCUDAKernelImpl on the runtime
// out_dtype. As on CPU, the source dtype comes from T / x, not an arg.
void Cast(const ContextT& dev_ctx,
          const DenseTensor& x,
          DataType out_dtype,
          DenseTensor* out) {
  PD_VISIT_ALL_TYPES(out_dtype, "CastCUDAKernelImpl", ([&] {
                       CastCUDAKernelImpl<T, data_t>(dev_ctx, x, out);
......
......@@ -1112,7 +1112,7 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
AsyncCopy(x, y);
y->Resize(out_dims);
} else {
  // dtype differs: cast x into y; x's dtype is implicit in the tensor,
  // so only the destination dtype is passed.
  pten::Cast<Tx>(*dev_ctx, x, y->dtype(), y);
}
return;
}
......
......@@ -59,7 +59,7 @@ void Reduce(const DeviceContext& dev_ctx,
pten::DenseTensorMeta(out_dtype, x.dims(), x.layout()));
// cast x tensor to out_dtype (source dtype is taken from x itself)
pten::Cast<T, DeviceContext>(dev_ctx, x, out_dtype, &tmp_tensor);
// do reduce sum
PD_VISIT_ALL_TYPES(
......
......@@ -49,13 +49,11 @@ TEST(DEV_API, cast) {
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
pten::DataType out_dtype = pten::DataType::FLOAT64;
// NOTE(review): the in_dtype local was removed with the kernel arg.
// 2. test API — Cast now takes only the target dtype
auto out = pten::Cast<float>(
    *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
    dense_x,
    out_dtype);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......
......@@ -15,9 +15,17 @@
func : CastInferMeta
kernel :
func : cast
param : [x, out_dtype]
data_type : x
- api : conj
args : (const Tensor& x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : conj
- api : divide
args : (const Tensor& x, const Tensor& y)
output : Tensor
......@@ -171,11 +179,3 @@
args : (const Tensor& x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED)
output : Tensor
invoke : full_like(x, 0, dtype, place, layout)
- api : conj
args : (const Tensor& x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : conj
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册