Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4ed6f3bc
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4ed6f3bc
编写于
9月 01, 2022
作者:
X
Xiaoxu Chen
提交者:
GitHub
9月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add gelu and erf primitive operators for new autograd (#45338)
* add erf_p primitive operators * add gelu orig2prim rule
上级
1a0ef45e
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
313 addition
and
21 deletion
+313
-21
paddle/fluid/operators/prim_ops/CMakeLists.txt
paddle/fluid/operators/prim_ops/CMakeLists.txt
+5
-1
paddle/fluid/operators/prim_ops/erf_p_op.cc
paddle/fluid/operators/prim_ops/erf_p_op.cc
+78
-0
paddle/fluid/operators/prim_ops/prim_op_test.cc
paddle/fluid/operators/prim_ops/prim_op_test.cc
+20
-0
python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
.../fluid/tests/unittests/autograd/test_jvp_and_transpose.py
+36
-0
python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
...n/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
+65
-0
python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
...n/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
+20
-0
python/paddle/fluid/tests/unittests/autograd/test_primapi.py
python/paddle/fluid/tests/unittests/autograd/test_primapi.py
+38
-17
python/paddle/fluid/tests/unittests/autograd/test_primops.py
python/paddle/fluid/tests/unittests/autograd/test_primops.py
+1
-0
python/paddle/incubate/autograd/primops.py
python/paddle/incubate/autograd/primops.py
+5
-0
python/paddle/incubate/autograd/primrules.py
python/paddle/incubate/autograd/primrules.py
+45
-3
未找到文件。
paddle/fluid/operators/prim_ops/CMakeLists.txt
浏览文件 @
4ed6f3bc
...
...
@@ -22,13 +22,17 @@ set(PRIM_OP_SRCS
div_p_op.cc
sqrt_p_op.cc
tanh_p_op.cc
sin_p_op.cc
cos_p_op.cc
exp_p_op.cc
matmul_p_op.cc
fill_constant_p_op.cc
log_p_op.cc
select_p_op.cc
eq_p_op.cc
pow_p_op.cc
max_p_op.cc
)
max_p_op.cc
erf_p_op.cc
)
cc_test
(
prim_op_test
...
...
paddle/fluid/operators/prim_ops/erf_p_op.cc
0 → 100644
浏览文件 @
4ed6f3bc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
framework
{
class
InferShapeContext
;
class
VarDesc
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
operators
{
class
ErfPrimOp
:
public
framework
::
OperatorBase
{
public:
ErfPrimOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Prim operator erf_p should not be excuted directly"
));
}
};
class
ErfPrimOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor), The input tensor of erf_p op."
);
AddOutput
(
"Y"
,
"(Tensor), The output tensor of erf_p op."
);
AddComment
(
R"DOC(Autograd primitive erf_p operator.)DOC"
);
}
};
class
ErfPrimOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
framework
::
InferShapeVarPtr
x_var_ptr
=
ctx
->
GetInputVarPtrs
(
"X"
)[
0
];
framework
::
InferShapeVarPtr
y_var_ptr
=
ctx
->
GetOutputVarPtrs
(
"Y"
)[
0
];
framework
::
VarDesc
*
x_var
=
PADDLE_GET
(
framework
::
VarDesc
*
,
x_var_ptr
);
PADDLE_GET
(
framework
::
VarDesc
*
,
y_var_ptr
)
->
SetShape
(
x_var
->
GetShape
());
}
};
class
ErfPrimOpVarTypeInference
:
public
framework
::
StaticGraphVarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
x_name
=
Input
(
ctx
,
"X"
)[
0
];
auto
y_name
=
Output
(
ctx
,
"Y"
)[
0
];
SetType
(
ctx
,
y_name
,
GetType
(
ctx
,
x_name
));
SetDataType
(
ctx
,
y_name
,
GetDataType
(
ctx
,
x_name
));
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OPERATOR
(
erf_p
,
paddle
::
operators
::
ErfPrimOp
,
paddle
::
operators
::
ErfPrimOpMaker
,
paddle
::
operators
::
ErfPrimOpShapeInference
,
paddle
::
operators
::
ErfPrimOpVarTypeInference
);
paddle/fluid/operators/prim_ops/prim_op_test.cc
浏览文件 @
4ed6f3bc
...
...
@@ -39,6 +39,7 @@ USE_OP_ITSELF(select_p);
USE_OP_ITSELF
(
eq_p
);
USE_OP_ITSELF
(
pow_p
);
USE_OP_ITSELF
(
max_p
);
USE_OP_ITSELF
(
erf_p
);
namespace
paddle
{
namespace
framework
{
...
...
@@ -710,5 +711,24 @@ TEST(PrimOp, max_p) {
ASSERT_EQ
(
shapes
[
2
],
4L
);
}
TEST
(
PrimOp
,
erf_p
)
{
ProgramDesc
program
;
auto
*
block
=
program
.
MutableBlock
(
0
);
std
::
vector
<
int64_t
>
shape
{
3
,
4
,
5
};
std
::
string
x0
=
"x0"
;
std
::
string
x1
=
"x1"
;
NewVar
(
block
,
x0
,
shape
);
AppendOp
(
block
,
"erf_p"
,
{{
"X"
,
{
x0
}}},
{{
"Y"
,
{
x1
}}},
{});
ASSERT_EQ
(
block
->
Var
(
"x1"
)
->
GetType
(),
proto
::
VarType
::
LOD_TENSOR
);
ASSERT_EQ
(
block
->
Var
(
"x1"
)
->
GetDataType
(),
proto
::
VarType_Type_FP32
);
auto
shapes
=
block
->
Var
(
"x1"
)
->
GetShape
();
ASSERT_EQ
(
shapes
.
size
(),
3UL
);
ASSERT_EQ
(
shapes
[
0
],
3L
);
ASSERT_EQ
(
shapes
[
1
],
4L
);
ASSERT_EQ
(
shapes
[
2
],
5L
);
}
}
// namespace framework
}
// namespace paddle
python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
浏览文件 @
4ed6f3bc
...
...
@@ -364,6 +364,42 @@ class TestExpPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class
TestErfPJVPAndTranspose
(
TestAddPJVPAndTranspose
):
def
init_data
(
self
):
# Set prim op
self
.
op_type
=
'erf_p'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
5
,
6
],
dtype
=
'int64'
)
self
.
prim_input
=
{
'X'
:
X
,
}
self
.
prim_output
=
{
'Y'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
prim_attrs
=
{}
# Set JVP
X_DOT
=
paddle
.
static
.
data
(
name
=
'X_DOT'
,
shape
=
[
5
,
6
],
dtype
=
'int64'
)
self
.
jvp_args
=
(
X_DOT
,
)
self
.
jvp_out_shape_map
=
{
0
:
self
.
prim_output
[
'Y'
]}
self
.
all_ops
=
[
# prim op:
'erf_p'
,
# jvp op:
'exp_p'
,
'fill_constant_p'
,
'fill_constant_p'
,
'fill_constant_p'
,
'mul_p'
,
'mul_p'
,
'pow_p'
,
'sub_p'
,
# transpose op:
]
class
TestLogPJVPAndTranspose
(
TestAddPJVPAndTranspose
):
def
init_data
(
self
):
...
...
python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
浏览文件 @
4ed6f3bc
...
...
@@ -208,6 +208,26 @@ class TestExpOrig2Prim(TestElementWiseAddOrig2Prim):
self
.
out_map
=
{
0
:
self
.
output
[
'Out'
]}
class
TestErfOrig2Prim
(
TestElementWiseAddOrig2Prim
):
def
init_data
(
self
):
self
.
op_type
=
'erf'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
3
,
4
],
dtype
=
'float'
)
self
.
input
=
{
'X'
:
X
,
}
self
.
output
=
{
'Out'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{}
self
.
orig2prim_args
=
(
X
,
)
self
.
all_ops
=
[
'erf'
,
'erf_p'
]
self
.
out_map
=
{
0
:
self
.
output
[
'Out'
]}
class
TestLogOrig2Prim
(
TestElementWiseAddOrig2Prim
):
def
init_data
(
self
):
...
...
@@ -559,5 +579,50 @@ class TestMaxOrig2Prim(TestElementWiseAddOrig2Prim):
self
.
out_map
=
{
0
:
self
.
output
[
'Out'
]}
class
TestGeluOrig2Prim
(
TestElementWiseAddOrig2Prim
):
def
init_data
(
self
):
self
.
op_type
=
'gelu'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
5
,
8
],
dtype
=
'float'
)
self
.
input
=
{
'X'
:
X
}
self
.
output
=
{
'Out'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{
'approximate'
:
False
}
self
.
orig2prim_args
=
(
X
,
)
self
.
all_ops
=
[
'gelu'
,
'add_p'
,
'erf_p'
,
'fill_constant_p'
,
'fill_constant_p'
,
'fill_constant_p'
,
'mul_p'
,
'mul_p'
,
'mul_p'
]
# { prim_op_output_index: orig_op_output_var }
self
.
out_map
=
{
0
:
self
.
output
[
'Out'
]}
class
TestGeluApproximateOrig2Prim
(
TestElementWiseAddOrig2Prim
):
def
init_data
(
self
):
self
.
op_type
=
'gelu'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
5
,
8
],
dtype
=
'float'
)
self
.
input
=
{
'X'
:
X
}
self
.
output
=
{
'Out'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{
'approximate'
:
True
}
self
.
orig2prim_args
=
(
X
,
)
self
.
all_ops
=
[
'add_p'
,
'add_p'
,
'fill_constant_p'
,
'fill_constant_p'
,
'fill_constant_p'
,
'fill_constant_p'
,
'fill_constant_p'
,
'gelu'
,
'mul_p'
,
'mul_p'
,
'mul_p'
,
'mul_p'
,
'pow_p'
,
'tanh_p'
]
# { prim_op_output_index: orig_op_output_var }
self
.
out_map
=
{
0
:
self
.
output
[
'Out'
]}
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
浏览文件 @
4ed6f3bc
...
...
@@ -224,6 +224,26 @@ class TestExpPPrim2Orig(TestAddPPrim2Orig):
self
.
out_map
=
{
self
.
output
[
'Y'
]:
0
}
class
TestErfPPrim2Orig
(
TestAddPPrim2Orig
):
def
init_data
(
self
):
self
.
op_type
=
'erf_p'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
7
,
8
],
dtype
=
'float64'
)
self
.
input
=
{
'X'
:
X
,
}
self
.
output
=
{
'Y'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{}
self
.
prim2orig_args
=
(
X
,
)
self
.
all_ops
=
[
'erf_p'
,
'erf'
]
self
.
out_map
=
{
self
.
output
[
'Y'
]:
0
}
class
TestLogPPrim2Orig
(
TestAddPPrim2Orig
):
def
init_data
(
self
):
...
...
python/paddle/fluid/tests/unittests/autograd/test_primapi.py
浏览文件 @
4ed6f3bc
...
...
@@ -16,11 +16,11 @@ import typing
import
unittest
import
numpy
as
np
import
autograd
import
autograd.numpy
as
np_autograd
import
paddle
import
autograd
import
autograd.numpy
as
anp
import
autograd.scipy
as
ascipy
import
config
import
utils
...
...
@@ -278,6 +278,11 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
np
.
array
([
1
,
2
,
3
]),
np
.
array
([
2
,
2
,
2
]),
),
None
,
'float32'
),
(
'erf'
,
paddle
.
erf
,
(
np
.
random
.
rand
(
300
,
288
),
),
None
,
'float32'
),
(
'gelu'
,
paddle
.
nn
.
functional
.
gelu
,
(
np
.
random
.
rand
(
200
,
189
),
),
None
,
'float32'
),
(
'gelu_approximate'
,
lambda
x
:
paddle
.
nn
.
functional
.
gelu
(
x
,
True
),
(
np
.
random
.
rand
(
200
,
189
),
),
None
,
'float32'
),
))
class
TestGrad
(
unittest
.
TestCase
):
...
...
@@ -397,17 +402,27 @@ def multiply_pd(x):
multiply_ag
=
lambda
xs
:
xs
[
0
]
*
xs
[
0
]
*
xs
[
0
]
*
xs
[
0
]
*
xs
[
0
]
sin_ag
=
lambda
xs
:
np_autograd
.
sin
(
xs
[
0
])
cos_ag
=
lambda
xs
:
np_autograd
.
cos
(
xs
[
0
])
exp_ag
=
lambda
xs
:
np_autograd
.
exp
(
xs
[
0
])
sin_ag
=
lambda
xs
:
anp
.
sin
(
xs
[
0
])
cos_ag
=
lambda
xs
:
anp
.
cos
(
xs
[
0
])
exp_ag
=
lambda
xs
:
anp
.
exp
(
xs
[
0
])
pow_ag
=
lambda
xs
:
xs
[
0
]
**
xs
[
1
]
log_ag
=
lambda
xs
:
np_autograd
.
log
(
xs
[
0
])
log_ag
=
lambda
xs
:
anp
.
log
(
xs
[
0
])
erf_ag
=
lambda
xs
:
ascipy
.
special
.
erf
(
xs
[
0
])
def
gelu_ag
(
x
,
approximate
=
False
):
if
approximate
:
sqrt_2_over_pi
=
np
.
sqrt
(
2
/
np
.
pi
).
astype
(
x
.
dtype
)
cdf
=
0.5
*
(
1.0
+
anp
.
tanh
(
sqrt_2_over_pi
*
(
x
+
0.044715
*
(
x
**
3
))))
return
x
*
cdf
else
:
return
x
*
(
ascipy
.
special
.
erf
(
x
/
np
.
sqrt
(
2
))
+
1
)
/
2
@
utils
.
place
(
config
.
DEVICES
)
@
utils
.
parameterize
(
(
utils
.
TEST_CASE_NAME
,
'fun_pd'
,
'fun_ag'
,
'xs'
,
'v'
,
'dtype'
),
(
(
'multiply'
,
multiply_pd
,
multiply_ag
,
(
utils
.
TEST_CASE_NAME
,
'fun_pd'
,
'fun_ag'
,
'xs'
,
'v'
,
'dtype'
),
(
(
'multiply'
,
multiply_pd
,
multiply_ag
,
(
np
.
random
.
rand
(
3
,
5
),
),
None
,
'float32'
),
(
'sin'
,
paddle
.
sin
,
sin_ag
,
(
np
.
random
.
rand
(
2
,
3
),
),
None
,
'float32'
),
(
'cos'
,
paddle
.
cos
,
cos_ag
,
(
np
.
random
.
rand
(
3
,
4
),
),
None
,
'float32'
),
...
...
@@ -415,7 +430,13 @@ log_ag = lambda xs: np_autograd.log(xs[0])
(
'pow'
,
paddle
.
pow
,
pow_ag
,
(
np
.
random
.
rand
(
2
,
3
),
np
.
random
.
rand
(
2
,
3
)),
None
,
'float32'
),
(
'log'
,
paddle
.
log
,
log_ag
,
(
np
.
random
.
rand
(
3
,
8
),
),
None
,
'float32'
),
))
(
'erf'
,
paddle
.
erf
,
erf_ag
,
(
np
.
random
.
rand
(
100
,
200
),
),
None
,
'float32'
),
(
'gelu'
,
paddle
.
nn
.
functional
.
gelu
,
lambda
xs
:
gelu_ag
(
xs
[
0
]),
(
np
.
random
.
rand
(
10
,
20
,
30
),
),
None
,
'float32'
),
(
'gelu_approximate'
,
lambda
x
:
paddle
.
nn
.
functional
.
gelu
(
x
,
approximate
=
True
),
lambda
xs
:
gelu_ag
(
xs
[
0
],
approximate
=
True
),
(
np
.
random
.
rand
(
10
,
20
,
30
),
),
None
,
'float32'
)))
class
TestGradWithHigherOrder
(
unittest
.
TestCase
):
def
setUp
(
self
):
...
...
python/paddle/fluid/tests/unittests/autograd/test_primops.py
浏览文件 @
4ed6f3bc
...
...
@@ -41,6 +41,7 @@ paddle.enable_static()
(
'sin'
,
primops
.
sin
,
randn
(
2
,
3
),
{},
(
2
,
3
),
'float64'
),
(
'cos'
,
primops
.
cos
,
randn
(
2
,
3
),
{},
(
2
,
3
),
'float64'
),
(
'exp'
,
primops
.
exp
,
randn
(
2
,
3
),
{},
(
2
,
3
),
'float64'
),
(
'erf'
,
primops
.
erf
,
randn
(
2
,
3
),
{},
(
2
,
3
),
'float64'
),
(
'log'
,
primops
.
log
,
randn
(
2
,
3
),
{},
(
2
,
3
),
'float64'
),
(
'reshape'
,
primops
.
reshape
,
randn
(
2
,
3
),
{
'shape'
:
(
3
,
2
)
...
...
python/paddle/incubate/autograd/primops.py
浏览文件 @
4ed6f3bc
...
...
@@ -355,3 +355,8 @@ def pow(x, y, out=None):
@
REGISTER_FN
(
'max_p'
,
'X'
,
'Y'
,
'Z'
)
def
max
(
x
,
y
,
out
=
None
):
return
_simple_binop
(
LayerHelper
(
'max_p'
,
**
locals
()))
@
REGISTER_FN
(
'erf_p'
,
'X'
,
'Y'
)
def
erf
(
x
,
out
=
None
):
return
_simple_unop
(
LayerHelper
(
'erf_p'
,
**
locals
()))
python/paddle/incubate/autograd/primrules.py
浏览文件 @
4ed6f3bc
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
typing
import
math
import
paddle
...
...
@@ -19,7 +20,7 @@ from . import primops
from
.primops
import
(
add
,
broadcast
,
concat
,
cos
,
div
,
exp
,
fill_const
,
gather
,
matmul
,
mul
,
neg
,
reduce
,
reshape
,
scatter_add
,
set_value
,
sin
,
slice_assign
,
slice_select
,
split
,
sqrt
,
sub
,
tanh
,
transpose
,
log
,
select
,
eq
,
max
)
transpose
,
log
,
select
,
eq
,
max
,
erf
)
from
.primreg
import
(
REGISTER_JVP
,
REGISTER_ORIG2PRIM
,
REGISTER_PRIM2ORIG
,
REGISTER_TRANSPOSE
,
lookup_fn
,
lookup_jvp
,
lookup_orig2prim
,
lookup_prim2orig
,
lookup_transpose
,
...
...
@@ -171,6 +172,11 @@ def exp_orig2prim(op, x):
return
exp
(
x
)
@
REGISTER_ORIG2PRIM
(
'erf'
)
def
erf_orig2prim
(
op
,
x
):
return
erf
(
x
)
@
REGISTER_ORIG2PRIM
(
'log'
)
def
log_orig2prim
(
op
,
x
):
return
log
(
x
)
...
...
@@ -321,13 +327,34 @@ def elementwise_pow_orig2prim(op, x, y):
def
elementwise_max_orig2prim
(
op
,
x
,
y
):
if
x
.
shape
!=
y
.
shape
:
y
=
broadcast
(
y
,
shape
=
x
.
shape
)
return
primops
.
max
(
x
,
y
)
## Register prim2orig lower rules
@
REGISTER_ORIG2PRIM
(
'gelu'
)
def
gelu_orig2prim
(
op
,
x
):
if
op
.
attr
(
'approximate'
):
cdf
=
mul
(
fill_const
(
0.5
,
x
.
shape
,
x
.
dtype
),
add
(
fill_const
(
1.0
,
x
.
shape
,
x
.
dtype
),
tanh
(
mul
(
fill_const
(
math
.
sqrt
(
2
/
math
.
pi
),
x
.
shape
,
x
.
dtype
),
add
(
x
,
mul
(
fill_const
(
0.044715
,
x
.
shape
,
x
.
dtype
),
primops
.
pow
(
x
,
fill_const
(
3.
,
x
.
shape
,
x
.
dtype
))))))))
return
mul
(
x
,
cdf
)
else
:
return
mul
(
mul
(
fill_const
(
0.5
,
x
.
shape
,
x
.
dtype
),
x
),
add
(
fill_const
(
1.0
,
x
.
shape
,
x
.
dtype
),
erf
(
mul
(
x
,
fill_const
(
1
/
math
.
sqrt
(
2.
),
x
.
shape
,
x
.
dtype
)))))
## Register prim2orig lower rules
@
REGISTER_PRIM2ORIG
(
'add_p'
)
def
add_prim2orig
(
op
,
x
,
y
):
return
paddle
.
add
(
x
,
y
)
...
...
@@ -373,6 +400,11 @@ def exp_prim2orig(op, x):
return
paddle
.
exp
(
x
)
@
REGISTER_PRIM2ORIG
(
'erf_p'
)
def
erf_prim2orig
(
op
,
x
):
return
paddle
.
erf
(
x
)
@
REGISTER_PRIM2ORIG
(
'log_p'
)
def
log_prim2orig
(
op
,
x
):
return
paddle
.
log
(
x
)
...
...
@@ -574,6 +606,16 @@ def exp_jvp(op, x_dot):
return
mul
(
x_dot
,
y
)
@
REGISTER_JVP
(
'erf_p'
)
def
erf_jvp
(
op
,
x_dot
):
if
x_dot
is
None
:
return
None
x
,
=
op_position_inputs
(
op
)
return
mul
(
fill_const
(
2.
/
math
.
sqrt
(
math
.
pi
),
x
.
shape
,
x
.
dtype
),
mul
(
x_dot
,
exp
(
neg
(
primops
.
pow
(
x
,
fill_const
(
2.
,
x
.
shape
,
x
.
dtype
))))))
@
REGISTER_JVP
(
'log_p'
)
def
log_jvp
(
op
,
x_dot
):
if
x_dot
is
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录