未验证 提交 2d0e8c3b 编写于 作者: H HongyuJia 提交者: GitHub

[Tensor Operants & Prim-Relevant] Multiply operants replace by scale (#51469)

上级 300f36c0
......@@ -102,7 +102,7 @@ Tensor EagerTensorOperants::subtract(const Tensor& x, const Scalar& y) {
}
Tensor EagerTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return ::multiply_ad_func(x, ::full_like_ad_func(x, y));
return ::scale_ad_func(x, y, 0.0f, true);
}
Tensor EagerTensorOperants::divide(const Tensor& x, const Scalar& y) {
......@@ -118,7 +118,7 @@ Tensor EagerTensorOperants::subtract(const Scalar& x, const Tensor& y) {
}
Tensor EagerTensorOperants::multiply(const Scalar& x, const Tensor& y) {
return ::multiply_ad_func(::full_like_ad_func(y, x), y);
return ::scale_ad_func(y, x, 0.0f, true);
}
Tensor EagerTensorOperants::divide(const Scalar& x, const Tensor& y) {
......@@ -229,7 +229,7 @@ Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) {
}
Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return paddle::prim::multiply<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
}
Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
......@@ -245,7 +245,7 @@ Tensor StaticTensorOperants::subtract(const Scalar& x, const Tensor& y) {
}
Tensor StaticTensorOperants::multiply(const Scalar& x, const Tensor& y) {
return paddle::prim::multiply<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
}
Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) {
......
......@@ -277,7 +277,7 @@ Tensor PhiTensorOperants::subtract(const Tensor& x, const Scalar& y) {
}
Tensor PhiTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return paddle::experimental::multiply(x, paddle::experimental::full_like(x, y));
return paddle::experimental::scale(x, y, 0.0f, true);
}
Tensor PhiTensorOperants::divide(const Tensor& x, const Scalar& y) {
......@@ -293,7 +293,7 @@ Tensor PhiTensorOperants::subtract(const Scalar& x, const Tensor& y) {
}
Tensor PhiTensorOperants::multiply(const Scalar& x, const Tensor& y) {
return paddle::experimental::multiply(paddle::experimental::full_like(y, x), y);
return paddle::experimental::scale(y, x, 0.0f, true);
}
Tensor PhiTensorOperants::divide(const Scalar& x, const Tensor& y) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册