Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
4c438d30
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4c438d30
编写于
9月 26, 2022
作者:
J
Jiabin Yang
提交者:
GitHub
9月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support rsqrt_p (#46369)
* support rsqrt_p * refine code and ut * add_prim_rsqrt * fix ut
上级
9a291685
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
204 addition
and
7 deletion
+204
-7
paddle/fluid/operators/prim_ops/CMakeLists.txt
paddle/fluid/operators/prim_ops/CMakeLists.txt
+2
-1
paddle/fluid/operators/prim_ops/rsqrt_p_op.cc
paddle/fluid/operators/prim_ops/rsqrt_p_op.cc
+82
-0
python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
.../fluid/tests/unittests/autograd/test_jvp_and_transpose.py
+33
-0
python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
...n/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
+21
-0
python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
...n/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
+20
-0
python/paddle/fluid/tests/unittests/autograd/test_primapi.py
python/paddle/fluid/tests/unittests/autograd/test_primapi.py
+2
-0
python/paddle/fluid/tests/unittests/autograd/test_transform.py
...n/paddle/fluid/tests/unittests/autograd/test_transform.py
+17
-5
python/paddle/incubate/autograd/primops.py
python/paddle/incubate/autograd/primops.py
+5
-0
python/paddle/incubate/autograd/primrules.py
python/paddle/incubate/autograd/primrules.py
+22
-1
未找到文件。
paddle/fluid/operators/prim_ops/CMakeLists.txt
浏览文件 @
4c438d30
...
...
@@ -37,7 +37,8 @@ set(PRIM_OP_SRCS
max_p_op.cc
erf_p_op.cc
abs_p_op.cc
cast_p_op.cc
)
cast_p_op.cc
rsqrt_p_op.cc
)
cc_test
(
prim_op_test
...
...
paddle/fluid/operators/prim_ops/rsqrt_p_op.cc
0 → 100644
浏览文件 @
4c438d30
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
framework
{
class
InferShapeContext
;
class
VarDesc
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
operators
{
class
RsqrtPrimOp
:
public
framework
::
OperatorBase
{
public:
RsqrtPrimOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Prim operator rsqrt_p should not be excuted directly"
));
}
};
class
RsqrtPrimOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor), The input tensor of rsqrt_p op."
);
AddOutput
(
"Y"
,
"(Tensor), The output tensor of rsqrt_p op."
);
AddComment
(
R"DOC(
Autograd primitive rsqrt_p operator.
)DOC"
);
}
};
class
RsqrtPrimOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
framework
::
InferShapeVarPtr
x_var_ptr
=
ctx
->
GetInputVarPtrs
(
"X"
)[
0
];
framework
::
InferShapeVarPtr
y_var_ptr
=
ctx
->
GetOutputVarPtrs
(
"Y"
)[
0
];
framework
::
VarDesc
*
x_var
=
PADDLE_GET
(
framework
::
VarDesc
*
,
x_var_ptr
);
PADDLE_GET
(
framework
::
VarDesc
*
,
y_var_ptr
)
->
SetShape
(
x_var
->
GetShape
());
}
};
class
RsqrtPrimOpVarTypeInference
:
public
framework
::
StaticGraphVarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
x_name
=
Input
(
ctx
,
"X"
)[
0
];
auto
y_name
=
Output
(
ctx
,
"Y"
)[
0
];
SetType
(
ctx
,
y_name
,
GetType
(
ctx
,
x_name
));
SetDataType
(
ctx
,
y_name
,
GetDataType
(
ctx
,
x_name
));
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OPERATOR
(
rsqrt_p
,
paddle
::
operators
::
RsqrtPrimOp
,
paddle
::
operators
::
RsqrtPrimOpMaker
,
paddle
::
operators
::
RsqrtPrimOpShapeInference
,
paddle
::
operators
::
RsqrtPrimOpVarTypeInference
);
python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
浏览文件 @
4c438d30
...
...
@@ -241,6 +241,39 @@ class TestSqrtPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class
TestRSqrtPJVPAndTranspose
(
TestAddPJVPAndTranspose
):
def
init_data
(
self
):
# Set prim op
self
.
op_type
=
'rsqrt_p'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
5
,
6
],
dtype
=
'int64'
)
self
.
prim_input
=
{
'X'
:
X
,
}
self
.
prim_output
=
{
'Y'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
prim_attrs
=
{}
# Set JVP
X_DOT
=
paddle
.
static
.
data
(
name
=
'X_DOT'
,
shape
=
[
5
,
6
],
dtype
=
'int64'
)
self
.
jvp_args
=
(
X_DOT
,
)
self
.
jvp_out_shape_map
=
{
0
:
self
.
prim_output
[
'Y'
]}
self
.
all_ops
=
[
# prim op:
'rsqrt_p'
,
# jvp op:
'div_p'
,
'div_p'
,
'mul_p'
,
'fill_constant_p'
,
# 'sqrt_p',
# transpose op:
]
class
TestTanhPJVPAndTranspose
(
TestAddPJVPAndTranspose
):
def
init_data
(
self
):
...
...
python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
浏览文件 @
4c438d30
...
...
@@ -879,5 +879,26 @@ class TestSquareOrig2Prim(TestElementWiseAddOrig2Prim):
self
.
out_map
=
{
0
:
self
.
output
[
'Out'
]}
class
TestRSqrtOrig2Prim
(
TestElementWiseAddOrig2Prim
):
def
init_data
(
self
):
self
.
op_type
=
'rsqrt'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
7
,
8
],
dtype
=
'float64'
)
self
.
input
=
{
'X'
:
X
,
}
self
.
output
=
{
'Out'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{}
self
.
orig2prim_args
=
(
X
,
)
self
.
all_ops
=
[
'rsqrt'
,
'rsqrt_p'
]
# { prim_op_output_index: orig_op_output_var }
self
.
out_map
=
{
0
:
self
.
output
[
'Out'
]}
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
浏览文件 @
4c438d30
...
...
@@ -690,5 +690,25 @@ class TestCastPPrim2Orig(TestAddPPrim2Orig):
self
.
out_map
=
{
self
.
output
[
'Y'
]:
0
}
class
TestRsqrtPrim2Orig
(
TestAddPPrim2Orig
):
def
init_data
(
self
):
self
.
op_type
=
'rsqrt_p'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
7
,
8
],
dtype
=
'float64'
)
self
.
input
=
{
'X'
:
X
,
}
self
.
output
=
{
'Y'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{}
self
.
prim2orig_args
=
(
X
,
)
self
.
all_ops
=
[
'rsqrt_p'
,
'rsqrt'
]
self
.
out_map
=
{
self
.
output
[
'Y'
]:
0
}
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/autograd/test_primapi.py
浏览文件 @
4c438d30
...
...
@@ -152,6 +152,7 @@ class TestWithoutProgramGuard(unittest.TestCase):
(
'log'
,
paddle
.
log
,
(
np
.
random
.
rand
(
3
,
4
),
),
None
,
'float32'
),
(
'abs'
,
paddle
.
abs
,
(
np
.
random
.
uniform
(
-
10
,
10
,
(
10
,
10
)),
),
None
,
'float32'
),
(
'rsqrt'
,
paddle
.
rsqrt
,
(
np
.
random
.
rand
(
100
,
200
),
),
None
,
'float32'
),
))
# paddle.where, paddle.pow, paddle.maximum has no double grad definition,
# can not compute forward grad use double trick
...
...
@@ -267,6 +268,7 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
(
np
.
random
.
rand
(
3
,
3
),
np
.
random
.
rand
(
3
,
3
)),
(
np
.
random
.
rand
(
3
,
3
),
),
'float64'
),
(
'sin'
,
paddle
.
sin
,
(
np
.
random
.
rand
(
100
,
200
),
),
None
,
'float32'
),
(
'rsqrt'
,
paddle
.
rsqrt
,
(
np
.
random
.
rand
(
100
,
200
),
),
None
,
'float32'
),
(
'cos'
,
paddle
.
cos
,
(
np
.
random
.
rand
(
200
,
90
),
),
None
,
'float32'
),
(
'exp'
,
paddle
.
exp
,
(
np
.
random
.
rand
(
299
,
320
),
),
None
,
'float32'
),
# In where op, grad of condition computed by paddle.static.gradients is None,
...
...
python/paddle/fluid/tests/unittests/autograd/test_transform.py
浏览文件 @
4c438d30
...
...
@@ -48,15 +48,16 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
A
=
paddle
.
tanh
(
X0
)
B
=
paddle
.
tanh
(
X1
)
Y
=
paddle
.
add
(
A
,
B
)
C
=
paddle
.
rsqrt
(
B
)
Y
=
paddle
.
add
(
A
,
C
)
self
.
orig_xs
=
[
X0
,
X1
]
self
.
orig_ys
=
[
Y
,
]
self
.
orig_ops
=
[
'tanh'
,
'tanh'
,
'elementwise_add'
]
self
.
orig2prim_ops
=
[
'tanh_p'
,
'tanh_p'
,
'add_p'
]
self
.
orig_ops
=
[
'tanh'
,
'tanh'
,
'elementwise_add'
,
'rsqrt'
]
self
.
orig2prim_ops
=
[
'tanh_p'
,
'tanh_p'
,
'add_p'
,
'rsqrt_p'
]
self
.
linearize_ops
=
self
.
orig2prim_ops
+
[
# call fill_const() in linearize() function
'fill_constant_p'
,
...
...
@@ -71,6 +72,10 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
'fill_constant_p'
,
'mul_p'
,
'add_p'
,
'fill_constant_p'
,
'div_p'
,
'div_p'
,
'mul_p'
,
]
self
.
transpose_ops
=
self
.
orig2prim_ops
+
[
# call fill_const() in transpose() function
...
...
@@ -84,6 +89,10 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
'mul_p'
,
'sub_p'
,
'fill_constant_p'
,
'mul_p'
,
'div_p'
,
'div_p'
,
'fill_constant_p'
,
# transposed op
'mul_p'
,
'mul_p'
...
...
@@ -92,13 +101,16 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
'tanh'
,
'tanh'
,
'add_p'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'sub_p'
,
'fill_constant'
,
'elementwise_mul'
,
'sub_p'
,
'fill_constant'
,
'elementwise_mul'
,
'elementwise_mul'
'elementwise_mul'
,
'rsqrt'
,
'fill_constant'
,
'elementwise_div'
,
'elementwise_div'
,
'elementwise_mul'
]
self
.
prim2orig_ops
=
[
'tanh'
,
'tanh'
,
'elementwise_add'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'elementwise_sub'
,
'fill_constant'
,
'elementwise_mul'
,
'elementwise_sub'
,
'fill_constant'
,
'elementwise_mul'
,
'elementwise_mul'
'fill_constant'
,
'elementwise_mul'
,
'elementwise_mul'
,
'rsqrt'
,
'fill_constant'
,
'elementwise_div'
,
'elementwise_div'
,
'elementwise_mul'
]
def
test_run
(
self
):
...
...
python/paddle/incubate/autograd/primops.py
浏览文件 @
4c438d30
...
...
@@ -394,3 +394,8 @@ def cast(x, dtype, out=None):
outputs
=
{
'Y'
:
out
},
attrs
=
{
'dtype'
:
dtype
})
return
out
@
REGISTER_FN
(
'rsqrt_p'
,
'X'
,
'Y'
)
def
rsqrt
(
x
,
out
=
None
):
return
_simple_unop
(
LayerHelper
(
'rsqrt_p'
,
**
locals
()))
python/paddle/incubate/autograd/primrules.py
浏览文件 @
4c438d30
...
...
@@ -23,7 +23,7 @@ from .primops import (add, broadcast, concat, cos, div, eq, erf, exp,
fill_const
,
gather
,
ge
,
gt
,
log
,
matmul
,
max
,
mul
,
ne
,
neg
,
reduce_sum
,
reshape
,
scatter_add
,
select
,
set_value
,
sin
,
slice_assign
,
slice_select
,
split
,
sqrt
,
sub
,
tanh
,
transpose
)
transpose
,
rsqrt
)
from
.primreg
import
(
REGISTER_JVP
,
REGISTER_ORIG2PRIM
,
REGISTER_PRIM2ORIG
,
REGISTER_TRANSPOSE
,
lookup_fn
,
lookup_jvp
,
lookup_orig2prim
,
lookup_prim2orig
,
lookup_transpose
,
...
...
@@ -252,6 +252,11 @@ def sqrt_orig2prim(op, x):
return
sqrt
(
x
)
@
REGISTER_ORIG2PRIM
(
'rsqrt'
)
def
rsqrt_orig2prim
(
op
,
x
):
return
rsqrt
(
x
)
@
REGISTER_ORIG2PRIM
(
'matmul_v2'
)
def
matmul_v2_orig2prim
(
op
,
x
,
y
):
...
...
@@ -456,6 +461,11 @@ def sub_prim2orig(op, x, y):
return
paddle
.
subtract
(
x
,
y
)
@
REGISTER_PRIM2ORIG
(
'rsqrt_p'
)
def
rsqrt_prim2orig
(
op
,
x
):
return
paddle
.
rsqrt
(
x
)
@
REGISTER_PRIM2ORIG
(
'mul_p'
)
def
mul_prim2orig
(
op
,
x
,
y
):
return
paddle
.
multiply
(
x
,
y
)
...
...
@@ -969,6 +979,17 @@ def cast_jvp(op, x_dot):
return
primops
.
cast
(
x_dot
,
y
.
dtype
)
@
REGISTER_JVP
(
'rsqrt_p'
)
def
rsqrt_jvp
(
op
,
x_dot
):
if
x_dot
is
None
:
return
None
y
=
op_position_output
(
op
)
x
=
op_position_inputs
(
op
)
c2
=
fill_const
(
value
=-
2.0
,
shape
=
y
.
shape
,
dtype
=
y
.
dtype
)
y_dot
=
mul
(
x_dot
,
div
(
div
(
y
,
x
),
c2
))
return
y_dot
## Register transpose rules
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录