Unverified commit 0d956e17, authored by HongyuJia, committed by GitHub

[Tensor Operants & Prim] Tensor arithmetic operants support left scalar type (#50840)

Parent: 44a32fbd
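What the change enables, in short: the four arithmetic operants (and the matching C++ operators) now accept a Scalar on the left-hand side, so an expression such as 1.0 - out.pow(2.0) compiles directly instead of requiring the out.pow(2.0) * -1.0 + 1.0 workaround. A minimal usage sketch follows; it is illustrative only, assumes Paddle's generated C++ tensor API is linked, and the helper name left_scalar_demo is hypothetical.

#include "paddle/phi/api/include/tensor.h"

using paddle::experimental::Tensor;

// Hypothetical helper exercising the new Scalar-on-the-left overloads.
Tensor left_scalar_demo(const Tensor& out) {
  Tensor a = 1.0 - out.pow(2.0);   // Scalar - Tensor (new in this patch)
  Tensor b = 0.5 * out;            // Scalar * Tensor (new in this patch)
  Tensor c = 2.0 / (1.0 + out);    // Scalar / Tensor and Scalar + Tensor
  return a + b + c;                // Tensor + Tensor (unchanged)
}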
......@@ -56,6 +56,14 @@ class EagerTensorOperants : public TensorOperantsBase {
Tensor divide(const Tensor& x, const Scalar& y);
Tensor add(const Scalar& x, const Tensor& y);
Tensor subtract(const Scalar& x, const Tensor& y);
Tensor multiply(const Scalar& x, const Tensor& y);
Tensor divide(const Scalar& x, const Tensor& y);
"""
......@@ -97,6 +105,22 @@ Tensor EagerTensorOperants::divide(const Tensor& x, const Scalar& y) {
return ::divide_ad_func(x, ::full_like_ad_func(x, y));
}
Tensor EagerTensorOperants::add(const Scalar& x, const Tensor& y) {
return ::add_ad_func(::full_like_ad_func(y, x), y);
}
Tensor EagerTensorOperants::subtract(const Scalar& x, const Tensor& y) {
return ::subtract_ad_func(::full_like_ad_func(y, x), y);
}
Tensor EagerTensorOperants::multiply(const Scalar& x, const Tensor& y) {
return ::multiply_ad_func(::full_like_ad_func(y, x), y);
}
Tensor EagerTensorOperants::divide(const Scalar& x, const Tensor& y) {
return ::divide_ad_func(::full_like_ad_func(y, x), y);
}
"""
......@@ -144,6 +168,14 @@ class StaticTensorOperants : public TensorOperantsBase {
Tensor divide(const Tensor& x, const Scalar& y);
Tensor add(const Scalar& x, const Tensor& y);
Tensor subtract(const Scalar& x, const Tensor& y);
Tensor multiply(const Scalar& x, const Tensor& y);
Tensor divide(const Scalar& x, const Tensor& y);
"""
......@@ -188,6 +220,22 @@ Tensor StaticTensorOperants::divide(const Tensor& x, const Scalar& y) {
return paddle::prim::divide<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
}
Tensor StaticTensorOperants::add(const Scalar& x, const Tensor& y) {
return paddle::prim::add<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
}
Tensor StaticTensorOperants::subtract(const Scalar& x, const Tensor& y) {
return paddle::prim::subtract<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
}
Tensor StaticTensorOperants::multiply(const Scalar& x, const Tensor& y) {
return paddle::prim::multiply<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
}
Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) {
return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
}
"""
......
......@@ -61,7 +61,7 @@ void gather_grad(const Tensor& x,
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
if (!grad_x) return;
auto grad_x_tmp = grad_out * (out.pow(2.0) * -1.0 + 1.0);
auto grad_x_tmp = grad_out * (1.0 - out.pow(2.0));
set_output<T>(grad_x_tmp, grad_x);
}
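For reference, the rewritten composite uses the identity $\tanh'(x) = 1 - \tanh^2(x)$, i.e. grad_x = grad_out * (1 - out^2); the 1.0 - out.pow(2.0) form only becomes expressible once a Scalar may appear on the left of the subtraction.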
......@@ -201,7 +201,7 @@ void divide_grad(const Tensor& x,
Tensor* dy) {
if (dy) {
// dy = -(x/y^2) * dout
auto dy_res = x / y.pow(2.0) * -1.0 * out_grad;
auto dy_res = -(x / y.pow(2.0)) * out_grad;
if (x.dims() != y.dims()) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
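The divide_grad rewrite is the same quantity written with its natural sign: $\partial(x/y)/\partial y = -x/y^2$, hence dy = -(x / y^2) * dout. The leading negation relies on the unary Tensor::operator-() that this patch also adds.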
......@@ -242,7 +242,7 @@ void divide_grad(const Tensor& x,
template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
auto x_grad_tmp = out_grad / 2.0 / out;
auto x_grad_tmp = out_grad * 0.5 / out;
set_output<T>(x_grad_tmp, x_grad);
}
}
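sqrt_grad is mathematically unchanged, since $d\sqrt{x}/dx = 1/(2\sqrt{x}) = 0.5/\mathrm{out}$; the new form simply trades the extra scalar division (/ 2.0) for a scalar multiplication (* 0.5).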
......
......@@ -29,7 +29,7 @@ cc_test_old(
prim_utils
operator
elementwise_mul_op
elementwise_add_op
elementwise_sub_op
fill_constant_op
activation_op
phi_api
......
......@@ -35,7 +35,7 @@ PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -44,7 +44,7 @@ PD_DECLARE_KERNEL(tanh, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(subtract, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT);
#endif
......@@ -194,7 +194,7 @@ TEST(StaticPrim, TanhBackwardComposite) {
target_block,
grad_sub_block));
ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(6));
ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(4));
ASSERT_EQ(target_block->AllOps()[0]->Type(), "tanh");
ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
static_cast<std::size_t>(1));
......@@ -218,36 +218,22 @@ TEST(StaticPrim, TanhBackwardComposite) {
ASSERT_EQ(grad_ops[1]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[2]->Type(), "elementwise_sub");
ASSERT_EQ(grad_ops[2]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[2]->Inputs().at("Y")[0],
ASSERT_EQ(grad_ops[2]->Inputs().at("X")[0],
grad_ops[1]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[2]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[3]->Type(), "fill_constant");
ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[3]->GetAttr("dtype")),
static_cast<int>(5)); // ProtoDataType::FP32
ASSERT_EQ(grad_ops[3]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[3]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[3]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[3]->Inputs().at("Y")[0],
grad_ops[2]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0], "b@GRAD");
ASSERT_EQ(grad_ops[3]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Type(), "elementwise_add");
ASSERT_EQ(grad_ops[4]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[4]->Inputs().at("Y")[0],
grad_ops[3]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[4]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[5]->Type(), "elementwise_mul");
ASSERT_EQ(grad_ops[5]->Inputs().at("X").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[5]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
ASSERT_EQ(grad_ops[5]->Inputs().at("Y")[0],
grad_ops[4]->Outputs().at("Out")[0]);
ASSERT_EQ(grad_ops[5]->Inputs().at("X")[0], "b@GRAD");
ASSERT_EQ(grad_ops[5]->Outputs().at("Out").size(),
static_cast<std::size_t>(1));
}
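The expected op count drops from 6 to 4 because the composite tanh backward no longer lowers out^2 * -1 + 1 into pow, fill_constant, elementwise_mul, fill_constant, elementwise_add; with left-scalar subtraction the generated sequence is pow, fill_constant, elementwise_sub, elementwise_mul, which is exactly what the updated Python-side expectations at the end of this diff assert.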
TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
......@@ -392,5 +378,5 @@ USE_OP_ITSELF(tanh);
USE_OP_ITSELF(tanh_grad);
USE_OP_ITSELF(pow);
USE_OP_ITSELF(elementwise_mul);
USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(elementwise_sub);
USE_OP_ITSELF(scale);
......@@ -556,6 +556,8 @@ class PADDLE_API Tensor final {
Tensor operator/(const Scalar& other) const;
Tensor operator-() const;
/* Part 8: Autograd methods */
/**
......@@ -699,5 +701,13 @@ class PADDLE_API Tensor final {
Tensor tile(const IntArray& repeat_times) const;
};
PADDLE_API Tensor operator+(const Scalar& x, const Tensor& y);
PADDLE_API Tensor operator-(const Scalar& x, const Tensor& y);
PADDLE_API Tensor operator*(const Scalar& x, const Tensor& y);
PADDLE_API Tensor operator/(const Scalar& x, const Tensor& y);
} // namespace experimental
} // namespace paddle
......@@ -60,6 +60,14 @@ class TensorOperantsBase {
virtual Tensor multiply(const Tensor& x, const Scalar& y) = 0;
virtual Tensor subtract(const Tensor& x, const Scalar& y) = 0;
virtual Tensor add(const Scalar& x, const Tensor& y) = 0;
virtual Tensor divide(const Scalar& x, const Tensor& y) = 0;
virtual Tensor multiply(const Scalar& x, const Tensor& y) = 0;
virtual Tensor subtract(const Scalar& x, const Tensor& y) = 0;
"""
......@@ -130,6 +138,26 @@ Tensor Tensor::multiply(const Scalar& y) const {
Tensor Tensor::subtract(const Scalar& y) const {
return paddle::OperantsManager::Instance().subtract(static_cast<const Tensor &>(*this), y);
}
Tensor Tensor::operator-() const {
return scale(-1.0, 0.0, true);
}
PADDLE_API Tensor operator+(const Scalar& x, const Tensor& y) {
return paddle::OperantsManager::Instance().add(x, y);
}
PADDLE_API Tensor operator-(const Scalar& x, const Tensor& y) {
return paddle::OperantsManager::Instance().subtract(x, y);
}
PADDLE_API Tensor operator*(const Scalar& x, const Tensor& y) {
return paddle::OperantsManager::Instance().multiply(x, y);
}
PADDLE_API Tensor operator/(const Scalar& x, const Tensor& y) {
return paddle::OperantsManager::Instance().divide(x, y);
}
"""
......@@ -175,6 +203,14 @@ class PhiTensorOperants : public TensorOperantsBase {
Tensor divide(const Tensor& x, const Scalar& y);
Tensor add(const Scalar& x, const Tensor& y);
Tensor subtract(const Scalar& x, const Tensor& y);
Tensor multiply(const Scalar& x, const Tensor& y);
Tensor divide(const Scalar& x, const Tensor& y);
"""
......@@ -215,6 +251,22 @@ Tensor PhiTensorOperants::multiply(const Tensor& x, const Scalar& y) {
Tensor PhiTensorOperants::divide(const Tensor& x, const Scalar& y) {
return paddle::experimental::divide(x, paddle::experimental::full_like(x, y));
}
Tensor PhiTensorOperants::add(const Scalar& x, const Tensor& y) {
return paddle::experimental::add(paddle::experimental::full_like(y, x), y);
}
Tensor PhiTensorOperants::subtract(const Scalar& x, const Tensor& y) {
return paddle::experimental::subtract(paddle::experimental::full_like(y, x), y);
}
Tensor PhiTensorOperants::multiply(const Scalar& x, const Tensor& y) {
return paddle::experimental::multiply(paddle::experimental::full_like(y, x), y);
}
Tensor PhiTensorOperants::divide(const Scalar& x, const Tensor& y) {
return paddle::experimental::divide(paddle::experimental::full_like(y, x), y);
}
"""
......@@ -299,6 +351,14 @@ class OperantsManager {
Tensor divide(const Tensor& x, const Scalar& y);
Tensor add(const Scalar& x, const Tensor& y);
Tensor subtract(const Scalar& x, const Tensor& y);
Tensor multiply(const Scalar& x, const Tensor& y);
Tensor divide(const Scalar& x, const Tensor& y);
"""
......@@ -357,6 +417,28 @@ class OperantsAPI(ForwardAPI):
{indent}virtual {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_declare_args(inplace_flag=True)}) = 0;
"""
def get_declare_args_without_first_tensor(self, inplace_flag=False):
func_name = self.get_api_func_name()
declare_args = self.get_input_tensor_args(inplace_flag)
assert len(declare_args) >= 1, (
"Error! Api %s has no Tensor inputs" % func_name
)
first_input_type = " ".join(declare_args[0].split(" ")[:-1])
# NOTE(HongyuJia): Do not consider "const paddle::optional<Tensor>&"
assert first_input_type == "const Tensor&", (
"Error! The first argument of Tensor Api %s must be Tensor, but received %s"
% (func_name, first_input_type)
)
for name in self.attrs['names']:
default_value = ''
if self.attrs['attr_info'][name][1] is not None:
default_value = ' = ' + self.attrs['attr_info'][name][1]
declare_args.append(
self.attrs['attr_info'][name][0] + ' ' + name + default_value
)
# remove first Tensor argument
return ", ".join(declare_args[1:])
def get_define_args_without_first_tensor(self, inplace_flag=False):
func_name = self.get_api_func_name()
define_args = self.get_input_tensor_args(inplace_flag)
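What the new helper does, shown on a hypothetical yaml entry (not part of this patch): for pow(Tensor x, Scalar y=1.0), get_declare_args_without_first_tensor checks that the first input is a plain const Tensor&, appends the attribute arguments with their defaults, and returns the declaration list minus that leading Tensor — roughly "const Scalar& y = 1.0" — presumably so the generator can emit signatures that omit the leading Tensor operand.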
......@@ -473,6 +555,8 @@ class OperantsAPI(ForwardAPI):
if func_name in ["add", "subtract", "multiply", "divide"]:
final_code += f"""
{self.get_return_type()} OperantsManager::{func_name}(const Tensor& x, const Scalar& y) {{{self.gene_operants_manager_code()}}}
{self.get_return_type()} OperantsManager::{func_name}(const Scalar& x, const Tensor& y) {{{self.gene_operants_manager_code()}}}
"""
# func declaration
if func_name[-1] != '_':
......
......@@ -44,9 +44,7 @@ from paddle.fluid import core, framework
(
'pow',
'fill_constant',
'elementwise_mul',
'fill_constant',
'elementwise_add',
'elementwise_sub',
'elementwise_mul',
),
),
......
......@@ -30,9 +30,7 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase):
self.desired_ops_no_skip = (
'pow',
'fill_constant',
'elementwise_mul',
'fill_constant',
'elementwise_add',
'elementwise_sub',
'elementwise_mul',
)
paddle.enable_static()
......