Unverified commit 04f042a5, authored by zyfncg, committed via GitHub

remove MakePtenDenseTensor in op compute (#38910)

Parent: c48a9ad5
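
Every hunk in this commit applies the same call-site change: the op kernels used to wrap their framework::Tensor inputs and outputs into pten::DenseTensor handles via paddle::experimental::MakePtenDenseTensor before invoking the new pten kernels, and they now pass the framework tensors straight through. Below is a minimal, self-contained C++ sketch of that pattern, modeled loosely on the sign kernel hunk; Tensor, WrappedTensor, MakeWrapped, SignKernelWrapped, and SignKernelDirect are hypothetical stand-ins for illustration, not Paddle or pten APIs.

    #include <cstddef>
    #include <memory>
    #include <vector>

    // Hypothetical stand-ins for framework::Tensor and pten::DenseTensor; none
    // of the names in this sketch are real Paddle/pten APIs.
    struct Tensor {
      std::vector<float> data;
    };

    struct WrappedTensor {
      Tensor* impl;
    };

    // Old-style plumbing: the kernel only accepted the wrapped type, so every
    // call site first built wrappers with a MakePtenDenseTensor-like helper.
    std::unique_ptr<WrappedTensor> MakeWrapped(Tensor& t) {
      return std::unique_ptr<WrappedTensor>(new WrappedTensor{&t});
    }

    void SignKernelWrapped(const WrappedTensor& x, WrappedTensor* out) {
      out->impl->data.resize(x.impl->data.size());
      for (std::size_t i = 0; i < x.impl->data.size(); ++i) {
        out->impl->data[i] =
            static_cast<float>((x.impl->data[i] > 0.f) - (x.impl->data[i] < 0.f));
      }
    }

    // New-style signature: the kernel takes the framework tensor directly, so
    // the wrapper objects and their extra allocations disappear from call sites.
    void SignKernelDirect(const Tensor& x, Tensor* out) {
      out->data.resize(x.data.size());
      for (std::size_t i = 0; i < x.data.size(); ++i) {
        out->data[i] = static_cast<float>((x.data[i] > 0.f) - (x.data[i] < 0.f));
      }
    }

    int main() {
      Tensor x{{-2.f, 0.f, 3.f}};
      Tensor out;

      // Before this commit (mirrors the removed lines in each hunk below):
      auto pt_x = MakeWrapped(x);
      auto pt_out = MakeWrapped(out);
      SignKernelWrapped(*pt_x.get(), pt_out.get());

      // After this commit (mirrors the added lines in each hunk below):
      SignKernelDirect(x, &out);
      return 0;
    }
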
@@ -63,14 +63,11 @@ class CastOpKernel : public framework::OpKernel<InT> {
     out->mutable_data(dev_ctx.GetPlace(),
                       static_cast<framework::proto::VarType::Type>(out_dtype));
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);
-    auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
     auto pt_out_dtype = pten::TransToPtenDataType(
         static_cast<framework::proto::VarType::Type>(out_dtype));
     // call new kernel
-    pten::CastKernel<InT>(dev_ctx, *pt_x.get(), pt_out_dtype, pt_out.get());
+    pten::CastKernel<InT>(dev_ctx, *in, pt_out_dtype, out);
   }
 };
...
@@ -202,10 +202,7 @@ class CholeskySolveGradKernel : public framework::OpKernel<T> {
       commonterm_for_range(commonterm_functor);
       commonterm_conj = helper.Transpose(commonterm_conj);
-      auto pt_x = paddle::experimental::MakePtenDenseTensor(commonterm);
-      auto pt_y = paddle::experimental::MakePtenDenseTensor(commonterm_conj);
-      auto pt_z = paddle::experimental::MakePtenDenseTensor(commonterm);
-      pten::AddKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), -1, pt_z.get());
+      pten::AddKernel<T>(dev_ctx, commonterm, commonterm_conj, -1, &commonterm);
       auto mat_dim_u = math::CreateMatrixDescriptor(u_bst.dims(), 0, false);
       auto mat_dim_c =
...
@@ -34,11 +34,9 @@ class ConjKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(context.GetPlace(), size_t(x->numel() * sizeof(T)));
     auto& dev_ctx = context.device_context<DeviceContext>();
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
-    auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
     // call new kernel
-    pten::ConjKernel<T>(dev_ctx, *pt_x.get(), pt_out.get());
+    pten::ConjKernel<T>(dev_ctx, *x, out);
   }
 };
...
@@ -40,13 +40,8 @@ class DotKernel : public framework::OpKernel<T> {
     auto& dev_ctx = ctx.device_context<DeviceContext>();
     out->mutable_data<T>(x->place());
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
-    auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
-    auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
     // call new kernel
-    pten::DotKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), *pt_y.get(),
-                                      pt_out.get());
+    pten::DotKernel<T, DeviceContext>(dev_ctx, *x, *y, out);
   }
 };
@@ -63,17 +58,11 @@ class DotGradKernel : public framework::OpKernel<T> {
     if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
     if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*tensor_x);
-    auto pt_y = paddle::experimental::MakePtenDenseTensor(*tensor_y);
-    auto pt_dout = paddle::experimental::MakePtenDenseTensor(*tensor_dout);
-    auto pt_dx = paddle::experimental::MakePtenDenseTensor(*tensor_dx);
-    auto pt_dy = paddle::experimental::MakePtenDenseTensor(*tensor_dy);
     auto& dev_ctx = ctx.device_context<DeviceContext>();
     // call new kernel
-    pten::DotGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, pt_dx.get(),
-                           pt_dy.get());
+    pten::DotGradKernel<T>(dev_ctx, *tensor_x, *tensor_y, *tensor_dout,
+                           tensor_dx, tensor_dy);
   }
 };
...
@@ -60,11 +60,9 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
         std::isnan(value), false,
         platform::errors::InvalidArgument("The filled value is NaN."));
-    auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
     const auto& dev_ctx = context.template device_context<DeviceContext>();
     // call new kernel
-    pten::FullLikeKernel<T>(dev_ctx, value, pt_out.get());
+    pten::FullLikeKernel<T>(dev_ctx, value, out);
   }
 };
...
@@ -131,12 +131,9 @@ class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
     auto &stop_axis = context.Attr<int>("stop_axis");
     auto &dev_ctx = context.device_context<DeviceContext>();
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);
-    auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
     // call new kernel
-    pten::FlattenKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), start_axis,
-                                          stop_axis, pt_out.get());
+    pten::FlattenKernel<T, DeviceContext>(dev_ctx, *in, start_axis, stop_axis,
+                                          out);
   }
 };
@@ -152,20 +149,8 @@ class FlattenContiguousRangeGradKernel : public framework::OpKernel<T> {
     d_x->mutable_data(ctx.GetPlace(), d_out->type());
     auto &dev_ctx = ctx.device_context<DeviceContext>();
-    auto pt_d_x = paddle::experimental::MakePtenDenseTensor(*d_x);
-    auto pt_d_out = paddle::experimental::MakePtenDenseTensor(*d_out);
-    // Because the holder of xshape may be nullptr, we can't use
-    // MakePtenDenseTensor.
-    // So, we create a new DenseTensor to save the dims of xshape.
-    pten::DenseTensorMeta xshape_meta{pten::TransToPtenDataType(d_x->type()),
-                                      xshape->dims(), d_x->layout()};
-    auto pt_xshape =
-        pten::Empty<T, DeviceContext>(dev_ctx, std::move(xshape_meta));
     // call new kernel
-    pten::FlattenGradKernel<T, DeviceContext>(dev_ctx, *pt_d_out.get(),
-                                              pt_xshape, pt_d_x.get());
+    pten::FlattenGradKernel<T, DeviceContext>(dev_ctx, *d_out, *xshape, d_x);
   }
 };
...
@@ -220,11 +220,8 @@ void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1,
                 const framework::Tensor& src2, framework::Tensor* out) {
   out->Resize(src1.dims());
   out->mutable_data<T>(dev_ctx.GetPlace());
-  auto pt_x = paddle::experimental::MakePtenDenseTensor(src1);
-  auto pt_y = paddle::experimental::MakePtenDenseTensor(src2);
-  auto pt_z = paddle::experimental::MakePtenDenseTensor(*out);
-  pten::AddKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), *pt_y.get(), -1,
-                                    pt_z.get());
+  pten::AddKernel<T, DeviceContext>(dev_ctx, src1, src2, -1, out);
 }
 
 template <typename DeviceContext, typename T>
@@ -232,11 +229,8 @@ void Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1,
                 const framework::Tensor& src2, framework::Tensor* out) {
   out->Resize(src1.dims());
   out->mutable_data<T>(dev_ctx.GetPlace());
-  auto pt_x = paddle::experimental::MakePtenDenseTensor(src1);
-  auto pt_y = paddle::experimental::MakePtenDenseTensor(src2);
-  auto pt_z = paddle::experimental::MakePtenDenseTensor(*out);
-  pten::SubtractKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), *pt_y.get(), -1,
-                                         pt_z.get());
+  pten::SubtractKernel<T, DeviceContext>(dev_ctx, src1, src2, -1, out);
 }
 
 template <typename DeviceContext, typename T, size_t D>
...
@@ -52,13 +52,8 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
     auto& dev_ctx = ctx.device_context<DeviceContext>();
     Out->mutable_data<T>(X->place());
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*X);
-    auto pt_y = paddle::experimental::MakePtenDenseTensor(*Y);
-    auto pt_out = paddle::experimental::MakePtenDenseTensor(*Out);
     // call new kernel
-    pten::MatmulKernel<T>(dev_ctx, *pt_x, *pt_y, trans_x, trans_y,
-                          pt_out.get());
+    pten::MatmulKernel<T>(dev_ctx, *X, *Y, trans_x, trans_y, Out);
   }
 };
@@ -151,19 +146,11 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
     if (dx) dx->mutable_data<T>(ctx.GetPlace());
     if (dy) dy->mutable_data<T>(ctx.GetPlace());
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
-    auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
-    auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout);
-    auto pt_dx = dx ? paddle::experimental::MakePtenDenseTensor(*dx)
-                    : std::unique_ptr<pten::DenseTensor>(nullptr);
-    auto pt_dy = dy ? paddle::experimental::MakePtenDenseTensor(*dy)
-                    : std::unique_ptr<pten::DenseTensor>(nullptr);
     auto& dev_ctx = ctx.device_context<DeviceContext>();
     // call new kernel
-    pten::MatmulGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, transpose_x,
-                              transpose_y, pt_dx.get(), pt_dy.get());
+    pten::MatmulGradKernel<T>(dev_ctx, *x, *y, *dout, transpose_x, transpose_y,
+                              dx, dy);
   }
 };
@@ -188,21 +175,11 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel<T> {
     if (dy) dy->mutable_data<T>(context.GetPlace());
     if (ddout) ddout->mutable_data<T>(context.GetPlace());
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
-    auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
-    auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout);
-    auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx);
-    auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy);
-    auto pt_dx = paddle::experimental::MakePtenDenseTensor(*dx);
-    auto pt_dy = paddle::experimental::MakePtenDenseTensor(*dy);
-    auto pt_ddout = paddle::experimental::MakePtenDenseTensor(*ddout);
     auto& dev_ctx = context.device_context<DeviceContext>();
     // call new kernel
-    pten::MatmulDoubleGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx,
-                                    *pt_ddy, transpose_x, transpose_y,
-                                    pt_dx.get(), pt_dy.get(), pt_ddout.get());
+    pten::MatmulDoubleGradKernel<T>(dev_ctx, *x, *y, *dout, *ddx, *ddy,
+                                    transpose_x, transpose_y, dx, dy, ddout);
   }
 };
@@ -238,28 +215,11 @@ class MatMulV2TripleGradKernel : public framework::OpKernel<T> {
     if (out_d_ddx) out_d_ddx->mutable_data<T>(context.GetPlace());
     if (out_d_ddy) out_d_ddy->mutable_data<T>(context.GetPlace());
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
-    auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
-    auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout);
-    auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx);
-    auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy);
-    auto pt_d_dx = paddle::experimental::MakePtenDenseTensor(*d_dx);
-    auto pt_d_dy = paddle::experimental::MakePtenDenseTensor(*d_dy);
-    auto pt_d_ddout = paddle::experimental::MakePtenDenseTensor(*d_ddout);
-    auto pt_out_d_x = paddle::experimental::MakePtenDenseTensor(*out_d_x);
-    auto pt_out_d_y = paddle::experimental::MakePtenDenseTensor(*out_d_y);
-    auto pt_out_d_dout = paddle::experimental::MakePtenDenseTensor(*out_d_dout);
-    auto pt_out_d_ddx = paddle::experimental::MakePtenDenseTensor(*out_d_ddx);
-    auto pt_out_d_ddy = paddle::experimental::MakePtenDenseTensor(*out_d_ddy);
     auto& dev_ctx = context.device_context<DeviceContext>();
     // call new kernel
-    pten::MatmulTripleGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx,
-                                    *pt_ddy, *pt_d_dx, *pt_d_dy, *pt_d_ddout,
-                                    transpose_x, transpose_y, pt_out_d_x.get(),
-                                    pt_out_d_y.get(), pt_out_d_dout.get(),
-                                    pt_out_d_ddx.get(), pt_out_d_ddy.get());
+    pten::MatmulTripleGradKernel<T>(
+        dev_ctx, *x, *y, *dout, *ddx, *ddy, *d_dx, *d_dy, *d_ddout, transpose_x,
+        transpose_y, out_d_x, out_d_y, out_d_dout, out_d_ddx, out_d_ddy);
   }
 };
...
@@ -65,12 +65,8 @@ class ScaleKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(in->place());
     auto& dev_ctx = ctx.device_context<DeviceContext>();
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);
-    auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
     // call new kernel
-    pten::ScaleKernel<T>(dev_ctx, *pt_x.get(), scale, bias, bias_after_scale,
-                         pt_out.get());
+    pten::ScaleKernel<T>(dev_ctx, *in, scale, bias, bias_after_scale, out);
   }
 };
...
@@ -34,11 +34,8 @@ class SignKernel : public framework::OpKernel<T> {
     auto& dev_ctx = context.device_context<DeviceContext>();
     out->mutable_data<T>(x->place());
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
-    auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
     // call new kernel
-    pten::SignKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), pt_out.get());
+    pten::SignKernel<T, DeviceContext>(dev_ctx, *x, out);
   }
 };
...