未验证 提交 04f042a5 编写于 作者: Z zyfncg 提交者: GitHub

remove MakePtenDenseTensor in op compute (#38910)

上级 c48a9ad5
......@@ -63,14 +63,11 @@ class CastOpKernel : public framework::OpKernel<InT> {
out->mutable_data(dev_ctx.GetPlace(),
static_cast<framework::proto::VarType::Type>(out_dtype));
auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
auto pt_out_dtype = pten::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(out_dtype));
// call new kernel
pten::CastKernel<InT>(dev_ctx, *pt_x.get(), pt_out_dtype, pt_out.get());
pten::CastKernel<InT>(dev_ctx, *in, pt_out_dtype, out);
}
};
......
......@@ -202,10 +202,7 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
commonterm_for_range(commonterm_functor);
commonterm_conj = helper.Transpose(commonterm_conj);
auto pt_x = paddle::experimental::MakePtenDenseTensor(commonterm);
auto pt_y = paddle::experimental::MakePtenDenseTensor(commonterm_conj);
auto pt_z = paddle::experimental::MakePtenDenseTensor(commonterm);
pten::AddKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), -1, pt_z.get());
pten::AddKernel<T>(dev_ctx, commonterm, commonterm_conj, -1, &commonterm);
auto mat_dim_u = math::CreateMatrixDescriptor(u_bst.dims(), 0, false);
auto mat_dim_c =
......
......@@ -34,11 +34,9 @@ class ConjKernel : public framework::OpKernel<T> {
out->mutable_data<T>(context.GetPlace(), size_t(x->numel() * sizeof(T)));
auto& dev_ctx = context.device_context<DeviceContext>();
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
// call new kernel
pten::ConjKernel<T>(dev_ctx, *pt_x.get(), pt_out.get());
pten::ConjKernel<T>(dev_ctx, *x, out);
}
};
......
......@@ -40,13 +40,8 @@ class DotKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>();
out->mutable_data<T>(x->place());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
// call new kernel
pten::DotKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), *pt_y.get(),
pt_out.get());
pten::DotKernel<T, DeviceContext>(dev_ctx, *x, *y, out);
}
};
......@@ -63,17 +58,11 @@ class DotGradKernel : public framework::OpKernel<T> {
if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*tensor_x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*tensor_y);
auto pt_dout = paddle::experimental::MakePtenDenseTensor(*tensor_dout);
auto pt_dx = paddle::experimental::MakePtenDenseTensor(*tensor_dx);
auto pt_dy = paddle::experimental::MakePtenDenseTensor(*tensor_dy);
auto& dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel
pten::DotGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, pt_dx.get(),
pt_dy.get());
pten::DotGradKernel<T>(dev_ctx, *tensor_x, *tensor_y, *tensor_dout,
tensor_dx, tensor_dy);
}
};
......
......@@ -60,11 +60,9 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
std::isnan(value), false,
platform::errors::InvalidArgument("The filled value is NaN."));
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
const auto& dev_ctx = context.template device_context<DeviceContext>();
// call new kernel
pten::FullLikeKernel<T>(dev_ctx, value, pt_out.get());
pten::FullLikeKernel<T>(dev_ctx, value, out);
}
};
......
......@@ -131,12 +131,9 @@ class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
auto &stop_axis = context.Attr<int>("stop_axis");
auto &dev_ctx = context.device_context<DeviceContext>();
auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
// call new kernel
pten::FlattenKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), start_axis,
stop_axis, pt_out.get());
pten::FlattenKernel<T, DeviceContext>(dev_ctx, *in, start_axis, stop_axis,
out);
}
};
......@@ -152,20 +149,8 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> {
d_x->mutable_data(ctx.GetPlace(), d_out->type());
auto &dev_ctx = ctx.device_context<DeviceContext>();
auto pt_d_x = paddle::experimental::MakePtenDenseTensor(*d_x);
auto pt_d_out = paddle::experimental::MakePtenDenseTensor(*d_out);
// Because the holder of xshape may be nullptr, we can't use
// MakePtenDenseTensor.
// So, we create a new DenseTensor to save the dims of xshape.
pten::DenseTensorMeta xshape_meta{pten::TransToPtenDataType(d_x->type()),
xshape->dims(), d_x->layout()};
auto pt_xshape =
pten::Empty<T, DeviceContext>(dev_ctx, std::move(xshape_meta));
// call new kernel
pten::FlattenGradKernel<T, DeviceContext>(dev_ctx, *pt_d_out.get(),
pt_xshape, pt_d_x.get());
pten::FlattenGradKernel<T, DeviceContext>(dev_ctx, *d_out, *xshape, d_x);
}
};
......
......@@ -220,11 +220,8 @@ void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1,
const framework::Tensor& src2, framework::Tensor* out) {
out->Resize(src1.dims());
out->mutable_data<T>(dev_ctx.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(src1);
auto pt_y = paddle::experimental::MakePtenDenseTensor(src2);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*out);
pten::AddKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), *pt_y.get(), -1,
pt_z.get());
pten::AddKernel<T, DeviceContext>(dev_ctx, src1, src2, -1, out);
}
template <typename DeviceContext, typename T>
......@@ -232,11 +229,8 @@ void Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1,
const framework::Tensor& src2, framework::Tensor* out) {
out->Resize(src1.dims());
out->mutable_data<T>(dev_ctx.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(src1);
auto pt_y = paddle::experimental::MakePtenDenseTensor(src2);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*out);
pten::SubtractKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), *pt_y.get(), -1,
pt_z.get());
pten::SubtractKernel<T, DeviceContext>(dev_ctx, src1, src2, -1, out);
}
template <typename DeviceContext, typename T, size_t D>
......
......@@ -52,13 +52,8 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.device_context<DeviceContext>();
Out->mutable_data<T>(X->place());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*X);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*Y);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*Out);
// call new kernel
pten::MatmulKernel<T>(dev_ctx, *pt_x, *pt_y, trans_x, trans_y,
pt_out.get());
pten::MatmulKernel<T>(dev_ctx, *X, *Y, trans_x, trans_y, Out);
}
};
......@@ -151,19 +146,11 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
if (dx) dx->mutable_data<T>(ctx.GetPlace());
if (dy) dy->mutable_data<T>(ctx.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout);
auto pt_dx = dx ? paddle::experimental::MakePtenDenseTensor(*dx)
: std::unique_ptr<pten::DenseTensor>(nullptr);
auto pt_dy = dy ? paddle::experimental::MakePtenDenseTensor(*dy)
: std::unique_ptr<pten::DenseTensor>(nullptr);
auto& dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel
pten::MatmulGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, transpose_x,
transpose_y, pt_dx.get(), pt_dy.get());
pten::MatmulGradKernel<T>(dev_ctx, *x, *y, *dout, transpose_x, transpose_y,
dx, dy);
}
};
......@@ -188,21 +175,11 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel<T> {
if (dy) dy->mutable_data<T>(context.GetPlace());
if (ddout) ddout->mutable_data<T>(context.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout);
auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx);
auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy);
auto pt_dx = paddle::experimental::MakePtenDenseTensor(*dx);
auto pt_dy = paddle::experimental::MakePtenDenseTensor(*dy);
auto pt_ddout = paddle::experimental::MakePtenDenseTensor(*ddout);
auto& dev_ctx = context.device_context<DeviceContext>();
// call new kernel
pten::MatmulDoubleGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx,
*pt_ddy, transpose_x, transpose_y,
pt_dx.get(), pt_dy.get(), pt_ddout.get());
pten::MatmulDoubleGradKernel<T>(dev_ctx, *x, *y, *dout, *ddx, *ddy,
transpose_x, transpose_y, dx, dy, ddout);
}
};
......@@ -238,28 +215,11 @@ class MatMulV2TripleGradKernel : public framework::OpKernel<T> {
if (out_d_ddx) out_d_ddx->mutable_data<T>(context.GetPlace());
if (out_d_ddy) out_d_ddy->mutable_data<T>(context.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout);
auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx);
auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy);
auto pt_d_dx = paddle::experimental::MakePtenDenseTensor(*d_dx);
auto pt_d_dy = paddle::experimental::MakePtenDenseTensor(*d_dy);
auto pt_d_ddout = paddle::experimental::MakePtenDenseTensor(*d_ddout);
auto pt_out_d_x = paddle::experimental::MakePtenDenseTensor(*out_d_x);
auto pt_out_d_y = paddle::experimental::MakePtenDenseTensor(*out_d_y);
auto pt_out_d_dout = paddle::experimental::MakePtenDenseTensor(*out_d_dout);
auto pt_out_d_ddx = paddle::experimental::MakePtenDenseTensor(*out_d_ddx);
auto pt_out_d_ddy = paddle::experimental::MakePtenDenseTensor(*out_d_ddy);
auto& dev_ctx = context.device_context<DeviceContext>();
// call new kernel
pten::MatmulTripleGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx,
*pt_ddy, *pt_d_dx, *pt_d_dy, *pt_d_ddout,
transpose_x, transpose_y, pt_out_d_x.get(),
pt_out_d_y.get(), pt_out_d_dout.get(),
pt_out_d_ddx.get(), pt_out_d_ddy.get());
pten::MatmulTripleGradKernel<T>(
dev_ctx, *x, *y, *dout, *ddx, *ddy, *d_dx, *d_dy, *d_ddout, transpose_x,
transpose_y, out_d_x, out_d_y, out_d_dout, out_d_ddx, out_d_ddy);
}
};
......
......@@ -65,12 +65,8 @@ class ScaleKernel : public framework::OpKernel<T> {
out->mutable_data<T>(in->place());
auto& dev_ctx = ctx.device_context<DeviceContext>();
auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
// call new kernel
pten::ScaleKernel<T>(dev_ctx, *pt_x.get(), scale, bias, bias_after_scale,
pt_out.get());
pten::ScaleKernel<T>(dev_ctx, *in, scale, bias, bias_after_scale, out);
}
};
......
......@@ -34,11 +34,8 @@ class SignKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.device_context<DeviceContext>();
out->mutable_data<T>(x->place());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
// call new kernel
pten::SignKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), pt_out.get());
pten::SignKernel<T, DeviceContext>(dev_ctx, *x, out);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册