Unverified commit 8a097399, authored by HongyuJia, committed by GitHub

[Tensor Operants & Prim] Tensor pow API uses elementwise_pow (#50886)

* [Tensor Operants & Prim] Tensor pow API uses elementwise_pow

* unittest change to fill_constant+elementwise_pow
Parent: 659cede0
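At the C++ API level, this commit adds two `pow` overloads to the `Tensor` class that dispatch through `OperantsManager` and lower to `elementwise_pow`; a scalar exponent is first materialized with `full_like`/`fill_constant`. A minimal usage sketch (illustration only; `pow_examples` is a hypothetical helper, and only the declarations added to `tensor.h` in this diff are assumed):

    #include "paddle/phi/api/include/tensor.h"

    // Hypothetical helper showing the two overloads introduced by this PR.
    paddle::Tensor pow_examples(const paddle::Tensor& x, const paddle::Tensor& y) {
      // Tensor exponent: dispatched to OperantsManager::pow(Tensor, Tensor),
      // which lowers to elementwise_pow.
      paddle::Tensor a = x.pow(y);
      // Scalar exponent: the scalar is materialized (full_like / fill_constant)
      // and elementwise_pow is applied, so graphs now record
      // fill_constant + elementwise_pow instead of a single pow op.
      paddle::Tensor b = x.pow(2.0);
      return a * b;  // elementwise multiply via the existing operator overloads
    }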
@@ -3,7 +3,6 @@
 - multiply
 - divide
 - unsqueeze
-- pow
 - exp
 - scale
 - matmul
@@ -25,5 +24,4 @@
 - scatter_nd_add
 - tile
 - transpose
-- subtract
 - pad
@@ -64,6 +64,10 @@ class EagerTensorOperants : public TensorOperantsBase {
   Tensor divide(const Scalar& x, const Tensor& y);
+
+  Tensor pow(const Tensor& x, const Tensor& y);
+
+  Tensor pow(const Tensor& x, const Scalar& y);
 """
@@ -121,6 +125,14 @@ Tensor EagerTensorOperants::divide(const Scalar& x, const Tensor& y) {
   return ::divide_ad_func(::full_like_ad_func(y, x), y);
 }
+
+Tensor EagerTensorOperants::pow(const Tensor& x, const Tensor& y) {
+  return ::elementwise_pow_ad_func(x, y);
+}
+
+Tensor EagerTensorOperants::pow(const Tensor& x, const Scalar& y) {
+  return ::elementwise_pow_ad_func(x, ::full_like_ad_func(x, y));
+}
 """
@@ -176,6 +188,10 @@ class StaticTensorOperants : public TensorOperantsBase {
   Tensor divide(const Scalar& x, const Tensor& y);
+
+  Tensor pow(const Tensor& x, const Tensor& y);
+
+  Tensor pow(const Tensor& x, const Scalar& y);
 """
@@ -236,6 +252,14 @@ Tensor StaticTensorOperants::divide(const Scalar& x, const Tensor& y) {
   return paddle::prim::divide<DescTensor>(paddle::prim::full<DescTensor>(y.shape(), x, y.dtype(), y.place()), y);
 }
+
+Tensor StaticTensorOperants::pow(const Tensor& x, const Tensor& y) {
+  return paddle::prim::elementwise_pow<DescTensor>(x, y);
+}
+
+Tensor StaticTensorOperants::pow(const Tensor& x, const Scalar& y) {
+  return paddle::prim::elementwise_pow<DescTensor>(x, paddle::prim::full<DescTensor>(x.shape(), y, x.dtype(), x.place()));
+}
 """
......
@@ -30,6 +30,7 @@ cc_test_old(
   operator
   elementwise_mul_op
   elementwise_sub_op
+  elementwise_pow_op
   fill_constant_op
   activation_op
   phi_api
......
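The unit-test changes below follow directly from the new decomposition. Composite tanh backward computes grad_x = grad_out * (1 - out^2); with this commit, out^2 is expressed as fill_constant(2.0) followed by elementwise_pow instead of a single pow op, so the expected grad-op sequence becomes fill_constant, elementwise_pow, fill_constant, elementwise_sub, elementwise_mul, i.e. five ops instead of four. A rough sketch of that computation written with the tensor operants (illustrative only; the real composite rule lives in Paddle's composite backward API, and `tanh_grad_sketch` is a hypothetical name):

    // grad_x = grad_out * (1 - out^2), spelled with the Tensor operants so the
    // recorded ops mirror what the test asserts below.
    paddle::Tensor tanh_grad_sketch(const paddle::Tensor& out,
                                    const paddle::Tensor& grad_out) {
      paddle::Tensor out_sq = out.pow(2.0);  // fill_constant + elementwise_pow
      paddle::Tensor one = paddle::experimental::full_like(out, 1.0);  // fill_constant
      return grad_out * (one - out_sq);      // elementwise_sub + elementwise_mul
    }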
@@ -194,7 +194,7 @@ TEST(StaticPrim, TanhBackwardComposite) {
       target_block,
       grad_sub_block));
   ASSERT_EQ(target_block->AllOps().size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(4));
+  ASSERT_EQ(grad_ops.size(), static_cast<std::size_t>(5));
   ASSERT_EQ(target_block->AllOps()[0]->Type(), "tanh");
   ASSERT_EQ(target_block->AllOps()[0]->Inputs().at("X").size(),
             static_cast<std::size_t>(1));
@@ -204,36 +204,41 @@ TEST(StaticPrim, TanhBackwardComposite) {
   ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out")[0], "b");
   ASSERT_EQ(target_block->AllOps()[0]->Outputs().at("Out")[0], "b");
-  ASSERT_EQ(grad_ops[0]->Type(), "pow");
-  ASSERT_EQ(grad_ops[0]->Inputs().at("X").size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[0]->Inputs().at("X")[0], "b");
-  ASSERT_EQ(PADDLE_GET_CONST(float, grad_ops[0]->GetAttr("factor")),
-            static_cast<float>(2.0));
+  ASSERT_EQ(grad_ops[0]->Type(), "fill_constant");
+  ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[0]->GetAttr("dtype")),
+            static_cast<int>(5));  // ProtoDataType::FP32
   ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
             static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[1]->Type(), "fill_constant");
-  ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[1]->GetAttr("dtype")),
-            static_cast<int>(5));  // ProtoDataType::FP32
-  ASSERT_EQ(grad_ops[1]->Outputs().at("Out").size(),
+  ASSERT_EQ(grad_ops[1]->Type(), "elementwise_pow");
+  ASSERT_EQ(grad_ops[1]->Inputs().at("X").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[1]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[1]->Inputs().at("X")[0], "b");
+  ASSERT_EQ(grad_ops[0]->Outputs().at("Out").size(),
             static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[2]->Type(), "elementwise_sub");
-  ASSERT_EQ(grad_ops[2]->Inputs().at("X").size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[2]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[2]->Inputs().at("X")[0],
-            grad_ops[1]->Outputs().at("Out")[0]);
+  ASSERT_EQ(grad_ops[2]->Type(), "fill_constant");
+  ASSERT_EQ(PADDLE_GET_CONST(int, grad_ops[2]->GetAttr("dtype")),
+            static_cast<int>(5));  // ProtoDataType::FP32
   ASSERT_EQ(grad_ops[2]->Outputs().at("Out").size(),
             static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[3]->Type(), "elementwise_mul");
+  ASSERT_EQ(grad_ops[3]->Type(), "elementwise_sub");
   ASSERT_EQ(grad_ops[3]->Inputs().at("X").size(), static_cast<std::size_t>(1));
   ASSERT_EQ(grad_ops[3]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
-  ASSERT_EQ(grad_ops[3]->Inputs().at("Y")[0],
+  ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0],
             grad_ops[2]->Outputs().at("Out")[0]);
-  ASSERT_EQ(grad_ops[3]->Inputs().at("X")[0], "b@GRAD");
   ASSERT_EQ(grad_ops[3]->Outputs().at("Out").size(),
             static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[4]->Type(), "elementwise_mul");
+  ASSERT_EQ(grad_ops[4]->Inputs().at("X").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[4]->Inputs().at("Y").size(), static_cast<std::size_t>(1));
+  ASSERT_EQ(grad_ops[4]->Inputs().at("Y")[0],
+            grad_ops[3]->Outputs().at("Out")[0]);
+  ASSERT_EQ(grad_ops[4]->Inputs().at("X")[0], "b@GRAD");
+  ASSERT_EQ(grad_ops[4]->Outputs().at("Out").size(),
+            static_cast<std::size_t>(1));
 }

 TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
@@ -376,7 +381,7 @@ TEST(StaticPrim, TestFlags) {
 USE_OP_ITSELF(fill_constant);
 USE_OP_ITSELF(tanh);
 USE_OP_ITSELF(tanh_grad);
-USE_OP_ITSELF(pow);
 USE_OP_ITSELF(elementwise_mul);
 USE_OP_ITSELF(elementwise_sub);
+USE_OP_ITSELF(elementwise_pow);
 USE_OP_ITSELF(scale);
@@ -677,12 +677,13 @@ class PADDLE_API Tensor final {
   Tensor divide(const Scalar& y) const;
   Tensor multiply(const Scalar& y) const;
   Tensor subtract(const Scalar& y) const;
+  Tensor pow(const Tensor& y) const;
+  Tensor pow(const Scalar& y) const;
   Tensor exp() const;
   Tensor floor() const;
   Tensor gather_nd(const Tensor& index) const;
   Tensor log() const;
-  Tensor pow(const Scalar& y) const;
   Tensor roll(const IntArray& shifts, const std::vector<int64_t>& axis) const;
   Tensor scatter(const Tensor& index,
                  const Tensor& updates,
......
@@ -29,6 +29,8 @@ inplace_optional_out_type_map = {
 indent = " "
+specific_ops_map = {"elementwise_pow": "pow"}
+
 operants_base_include = """// Generated by paddle/phi/api/yaml/generator/tensor_operants_gen.py
@@ -68,6 +70,10 @@ class TensorOperantsBase {
   virtual Tensor multiply(const Scalar& x, const Tensor& y) = 0;
   virtual Tensor subtract(const Scalar& x, const Tensor& y) = 0;
+
+  virtual Tensor pow(const Tensor& x, const Tensor& y) = 0;
+
+  virtual Tensor pow(const Tensor& x, const Scalar& y) = 0;
 """
@@ -143,6 +149,14 @@ Tensor Tensor::operator-() const {
   return scale(-1.0, 0.0, true);
 }
+
+Tensor Tensor::pow(const Tensor& y) const {
+  return paddle::OperantsManager::Instance().pow(static_cast<const Tensor &>(*this), y);
+}
+
+Tensor Tensor::pow(const Scalar& y) const {
+  return paddle::OperantsManager::Instance().pow(static_cast<const Tensor &>(*this), y);
+}
+
 PADDLE_API Tensor operator+(const Scalar& x, const Tensor& y) {
   return paddle::OperantsManager::Instance().add(x, y);
 }
@@ -211,6 +225,10 @@ class PhiTensorOperants : public TensorOperantsBase {
   Tensor divide(const Scalar& x, const Tensor& y);
+
+  Tensor pow(const Tensor& x, const Tensor& y);
+
+  Tensor pow(const Tensor& x, const Scalar& y);
 """
@@ -267,6 +285,14 @@ Tensor PhiTensorOperants::multiply(const Scalar& x, const Tensor& y) {
 Tensor PhiTensorOperants::divide(const Scalar& x, const Tensor& y) {
   return paddle::experimental::divide(paddle::experimental::full_like(y, x), y);
 }
+
+Tensor PhiTensorOperants::pow(const Tensor& x, const Tensor& y) {
+  return paddle::experimental::elementwise_pow(x, y);
+}
+
+Tensor PhiTensorOperants::pow(const Tensor& x, const Scalar& y) {
+  return paddle::experimental::elementwise_pow(x, paddle::experimental::full_like(x, y));
+}
 """
@@ -359,6 +385,10 @@ class OperantsManager {
   Tensor divide(const Scalar& x, const Tensor& y);
+
+  Tensor pow(const Tensor& x, const Tensor& y);
+
+  Tensor pow(const Tensor& x, const Scalar& y);
 """
@@ -512,8 +542,10 @@ class OperantsAPI(ForwardAPI):
 """

-    def gene_operants_manager_code(self):
+    def gene_operants_manager_code(self, is_specific_op=False):
         func_name = self.get_api_func_name()
+        if is_specific_op:
+            func_name = specific_ops_map[func_name]
         func_args = self.inputs['names'] + self.attrs['names']
         func_args_code = ", ".join(func_args)
         return f"""
@@ -552,11 +584,19 @@ class OperantsAPI(ForwardAPI):
     def gene_operants_manager_implementation(self):
         func_name = self.get_api_func_name()
         final_code = ""
+        # Codes for arithmetic operants
         if func_name in ["add", "subtract", "multiply", "divide"]:
             final_code += f"""
 {self.get_return_type()} OperantsManager::{func_name}(const Tensor& x, const Scalar& y) {{{self.gene_operants_manager_code()}}}
 {self.get_return_type()} OperantsManager::{func_name}(const Scalar& x, const Tensor& y) {{{self.gene_operants_manager_code()}}}
+"""
+        # Codes for specific operants
+        if func_name in specific_ops_map.keys():
+            final_code += f"""
+{self.get_return_type()} OperantsManager::{specific_ops_map[func_name]}(const Tensor& x, const Tensor& y) {{{self.gene_operants_manager_code(is_specific_op=True)}}}
+{self.get_return_type()} OperantsManager::{specific_ops_map[func_name]}(const Tensor& x, const Scalar& y) {{{self.gene_operants_manager_code(is_specific_op=True)}}}
 """
         # func declaration
         if func_name[-1] != '_':
......
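In the generator, `specific_ops_map = {"elementwise_pow": "pow"}` lets the public method keep the name `pow` while the dispatched kernel is `elementwise_pow`. The emitted `OperantsManager::pow` overloads simply forward to whichever operants backend is selected at runtime. A hand-written approximation of what the generated dispatch looks like (an assumption for illustration: the member names `eager_operants`/`static_operants`/`phi_operants` and the `FLAGS_tensor_operants_mode` flag follow `operants_manager.h`, and the exact generated body may differ):

    // Approximate shape of the generated code, not the literal generator output.
    Tensor OperantsManager::pow(const Tensor& x, const Tensor& y) {
      if (FLAGS_tensor_operants_mode == "eager") {
        return this->eager_operants->pow(x, y);   // -> elementwise_pow_ad_func
      } else if (FLAGS_tensor_operants_mode == "static") {
        return this->static_operants->pow(x, y);  // -> paddle::prim::elementwise_pow<DescTensor>
      } else if (FLAGS_tensor_operants_mode == "phi") {
        return this->phi_operants->pow(x, y);     // -> paddle::experimental::elementwise_pow
      }
      PADDLE_THROW(phi::errors::Unimplemented(
          "FLAGS_tensor_operants_mode is not initialized properly."));
    }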
+# Attach operants to Tensor, this file should be consistent with the declaration in `tensor.h`
+- add
+- subtract
+- multiply
+- divide
 - unsqueeze
-- pow
 - exp
 - scale
-- multiply
 - matmul
 - expand
-- divide
 - sum
-- add
 - abs
 - assign
 - elementwise_pow
@@ -22,4 +22,3 @@
 - scatter
 - scatter_nd_add
 - tile
-- subtract
@@ -42,7 +42,8 @@ from paddle.fluid import core, framework
             set(),
             tuple(),
             (
-                'pow',
+                'fill_constant',
+                'elementwise_pow',
                 'fill_constant',
                 'elementwise_sub',
                 'elementwise_mul',
......
@@ -28,7 +28,8 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase):
         self.grad_sub_block = tuple()
         self.desired_ops = 'tanh_grad'
         self.desired_ops_no_skip = (
-            'pow',
+            'fill_constant',
+            'elementwise_pow',
             'fill_constant',
             'elementwise_sub',
             'elementwise_mul',
......