未验证 提交 2d0e8c3b 编写于 作者: H HongyuJia 提交者: GitHub

[Tensor Operants & Prim-Relevant] Multiply operants replace by scale (#51469)

上级 300f36c0
...@@ -102,7 +102,7 @@ Tensor EagerTensorOperants::subtract(const Tensor& x, const Scalar& y) { ...@@ -102,7 +102,7 @@ Tensor EagerTensorOperants::subtract(const Tensor& x, const Scalar& y) {
} }
Tensor EagerTensorOperants::multiply(const Tensor& x, const Scalar& y) { Tensor EagerTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return ::multiply_ad_func(x, ::full_like_ad_func(x, y)); return ::scale_ad_func(x, y, 0.0f, true);
} }
Tensor EagerTensorOperants::divide(const Tensor& x, const Scalar& y) { Tensor EagerTensorOperants::divide(const Tensor& x, const Scalar& y) {
...@@ -118,7 +118,7 @@ Tensor EagerTensorOperants::subtract(const Scalar& x, const Tensor& y) { ...@@ -118,7 +118,7 @@ Tensor EagerTensorOperants::subtract(const Scalar& x, const Tensor& y) {
} }
Tensor EagerTensorOperants::multiply(const Scalar& x, const Tensor& y) { Tensor EagerTensorOperants::multiply(const Scalar& x, const Tensor& y) {
return ::multiply_ad_func(::full_like_ad_func(y, x), y); return ::scale_ad_func(y, x, 0.0f, true);
} }
Tensor EagerTensorOperants::divide(const Scalar& x, const Tensor& y) { Tensor EagerTensorOperants::divide(const Scalar& x, const Tensor& y) {
...@@ -229,7 +229,7 @@ Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) { ...@@ -229,7 +229,7 @@ Tensor StaticTensorOperants::subtract(const Tensor& x, const Scalar& y) {
} }
Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) { Tensor StaticTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return paddle::prim::multiply<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place())); return paddle::prim::scale<DescTensor>(x, y, 0.0f, true);
} }
Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) { Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
...@@ -245,7 +245,7 @@ Tensor StaticTensorOperants::subtract(const Scalar& x, const Tensor& y) { ...@@ -245,7 +245,7 @@ Tensor StaticTensorOperants::subtract(const Scalar& x, const Tensor& y) {
} }
Tensor StaticTensorOperants::multiply(const Scalar& x, const Tensor& y) { Tensor StaticTensorOperants::multiply(const Scalar& x, const Tensor& y) {
return paddle::prim::multiply<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y); return paddle::prim::scale<DescTensor>(y, x, 0.0f, true);
} }
Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) { Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) {
......
...@@ -277,7 +277,7 @@ Tensor PhiTensorOperants::subtract(const Tensor& x, const Scalar& y) { ...@@ -277,7 +277,7 @@ Tensor PhiTensorOperants::subtract(const Tensor& x, const Scalar& y) {
} }
Tensor PhiTensorOperants::multiply(const Tensor& x, const Scalar& y) { Tensor PhiTensorOperants::multiply(const Tensor& x, const Scalar& y) {
return paddle::experimental::multiply(x, paddle::experimental::full_like(x, y)); return paddle::experimental::scale(x, y, 0.0f, true);
} }
Tensor PhiTensorOperants::divide(const Tensor& x, const Scalar& y) { Tensor PhiTensorOperants::divide(const Tensor& x, const Scalar& y) {
...@@ -293,7 +293,7 @@ Tensor PhiTensorOperants::subtract(const Scalar& x, const Tensor& y) { ...@@ -293,7 +293,7 @@ Tensor PhiTensorOperants::subtract(const Scalar& x, const Tensor& y) {
} }
Tensor PhiTensorOperants::multiply(const Scalar& x, const Tensor& y) { Tensor PhiTensorOperants::multiply(const Scalar& x, const Tensor& y) {
return paddle::experimental::multiply(paddle::experimental::full_like(y, x), y); return paddle::experimental::scale(y, x, 0.0f, true);
} }
Tensor PhiTensorOperants::divide(const Scalar& x, const Tensor& y) { Tensor PhiTensorOperants::divide(const Scalar& x, const Tensor& y) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册