Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
22342d51
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
22342d51
编写于
7月 26, 2022
作者:
X
Xiaoxu Chen
提交者:
GitHub
7月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sin,cos,exp primitive operators (#44345)
上级
0d51fcf1
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
765 addition
and
205 deletion
+765
-205
paddle/fluid/operators/prim_ops/cos_p_op.cc
paddle/fluid/operators/prim_ops/cos_p_op.cc
+71
-0
paddle/fluid/operators/prim_ops/exp_p_op.cc
paddle/fluid/operators/prim_ops/exp_p_op.cc
+71
-0
paddle/fluid/operators/prim_ops/sin_p_op.cc
paddle/fluid/operators/prim_ops/sin_p_op.cc
+71
-0
python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
.../fluid/tests/unittests/autograd/test_jvp_and_transpose.py
+91
-0
python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
...n/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
+60
-0
python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
...n/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
+60
-0
python/paddle/fluid/tests/unittests/autograd/test_primapi.py
python/paddle/fluid/tests/unittests/autograd/test_primapi.py
+146
-55
python/paddle/fluid/tests/unittests/autograd/test_primops.py
python/paddle/fluid/tests/unittests/autograd/test_primops.py
+117
-143
python/paddle/incubate/autograd/primops.py
python/paddle/incubate/autograd/primops.py
+15
-0
python/paddle/incubate/autograd/primrules.py
python/paddle/incubate/autograd/primrules.py
+63
-7
未找到文件。
paddle/fluid/operators/prim_ops/cos_p_op.cc
0 → 100644
浏览文件 @
22342d51
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
operators
{
class
CosPrimOp
:
public
framework
::
OperatorBase
{
public:
CosPrimOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Prim operator cos_p should not be excuted directly"
));
}
};
class
CosPrimOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor), The input tensor of cos_p op."
);
AddOutput
(
"Y"
,
"(Tensor), The output tensor of cos_p op."
);
AddComment
(
R"DOC(Autograd primitive cos_p operator.)DOC"
);
}
};
class
CosPrimOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
framework
::
InferShapeVarPtr
x_var_ptr
=
ctx
->
GetInputVarPtrs
(
"X"
)[
0
];
framework
::
InferShapeVarPtr
y_var_ptr
=
ctx
->
GetOutputVarPtrs
(
"Y"
)[
0
];
framework
::
VarDesc
*
x_var
=
PADDLE_GET
(
framework
::
VarDesc
*
,
x_var_ptr
);
PADDLE_GET
(
framework
::
VarDesc
*
,
y_var_ptr
)
->
SetShape
(
x_var
->
GetShape
());
}
};
class
CosPrimOpVarTypeInference
:
public
framework
::
StaticGraphVarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
x_name
=
Input
(
ctx
,
"X"
)[
0
];
auto
y_name
=
Output
(
ctx
,
"Y"
)[
0
];
SetType
(
ctx
,
y_name
,
GetType
(
ctx
,
x_name
));
SetDataType
(
ctx
,
y_name
,
GetDataType
(
ctx
,
x_name
));
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OPERATOR
(
cos_p
,
paddle
::
operators
::
CosPrimOp
,
paddle
::
operators
::
CosPrimOpMaker
,
paddle
::
operators
::
CosPrimOpShapeInference
,
paddle
::
operators
::
CosPrimOpVarTypeInference
);
paddle/fluid/operators/prim_ops/exp_p_op.cc
0 → 100644
浏览文件 @
22342d51
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
operators
{
class
ExpPrimOp
:
public
framework
::
OperatorBase
{
public:
ExpPrimOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Prim operator exp_p should not be excuted directly"
));
}
};
class
ExpPrimOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor), The input tensor of exp_p op."
);
AddOutput
(
"Y"
,
"(Tensor), The output tensor of exp_p op."
);
AddComment
(
R"DOC(Autograd primitive exp_p operator.)DOC"
);
}
};
class
ExpPrimOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
framework
::
InferShapeVarPtr
x_var_ptr
=
ctx
->
GetInputVarPtrs
(
"X"
)[
0
];
framework
::
InferShapeVarPtr
y_var_ptr
=
ctx
->
GetOutputVarPtrs
(
"Y"
)[
0
];
framework
::
VarDesc
*
x_var
=
PADDLE_GET
(
framework
::
VarDesc
*
,
x_var_ptr
);
PADDLE_GET
(
framework
::
VarDesc
*
,
y_var_ptr
)
->
SetShape
(
x_var
->
GetShape
());
}
};
class
ExpPrimOpVarTypeInference
:
public
framework
::
StaticGraphVarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
x_name
=
Input
(
ctx
,
"X"
)[
0
];
auto
y_name
=
Output
(
ctx
,
"Y"
)[
0
];
SetType
(
ctx
,
y_name
,
GetType
(
ctx
,
x_name
));
SetDataType
(
ctx
,
y_name
,
GetDataType
(
ctx
,
x_name
));
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OPERATOR
(
exp_p
,
paddle
::
operators
::
ExpPrimOp
,
paddle
::
operators
::
ExpPrimOpMaker
,
paddle
::
operators
::
ExpPrimOpShapeInference
,
paddle
::
operators
::
ExpPrimOpVarTypeInference
);
paddle/fluid/operators/prim_ops/sin_p_op.cc
0 → 100644
浏览文件 @
22342d51
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
operators
{
class
SinPrimOp
:
public
framework
::
OperatorBase
{
public:
SinPrimOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Prim operator sin_p should not be excuted directly"
));
}
};
class
SinPrimOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor), The input tensor of sin_p op."
);
AddOutput
(
"Y"
,
"(Tensor), The output tensor of sin_p op."
);
AddComment
(
R"DOC(Autograd primitive sin_p operator.)DOC"
);
}
};
class
SinPrimOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
framework
::
InferShapeVarPtr
x_var_ptr
=
ctx
->
GetInputVarPtrs
(
"X"
)[
0
];
framework
::
InferShapeVarPtr
y_var_ptr
=
ctx
->
GetOutputVarPtrs
(
"Y"
)[
0
];
framework
::
VarDesc
*
x_var
=
PADDLE_GET
(
framework
::
VarDesc
*
,
x_var_ptr
);
PADDLE_GET
(
framework
::
VarDesc
*
,
y_var_ptr
)
->
SetShape
(
x_var
->
GetShape
());
}
};
class
SinPrimOpVarTypeInference
:
public
framework
::
StaticGraphVarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
x_name
=
Input
(
ctx
,
"X"
)[
0
];
auto
y_name
=
Output
(
ctx
,
"Y"
)[
0
];
SetType
(
ctx
,
y_name
,
GetType
(
ctx
,
x_name
));
SetDataType
(
ctx
,
y_name
,
GetDataType
(
ctx
,
x_name
));
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OPERATOR
(
sin_p
,
paddle
::
operators
::
SinPrimOp
,
paddle
::
operators
::
SinPrimOpMaker
,
paddle
::
operators
::
SinPrimOpShapeInference
,
paddle
::
operators
::
SinPrimOpVarTypeInference
);
python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
浏览文件 @
22342d51
...
...
@@ -273,6 +273,97 @@ class TestTanhPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class
TestSinPJVPAndTranspose
(
TestAddPJVPAndTranspose
):
def
init_data
(
self
):
# Set prim op
self
.
op_type
=
'sin_p'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
5
,
6
],
dtype
=
'int64'
)
self
.
prim_input
=
{
'X'
:
X
,
}
self
.
prim_output
=
{
'Y'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
prim_attrs
=
{}
# Set JVP
X_DOT
=
paddle
.
static
.
data
(
name
=
'X_DOT'
,
shape
=
[
5
,
6
],
dtype
=
'int64'
)
self
.
jvp_args
=
(
X_DOT
,
)
self
.
jvp_out_shape_map
=
{
0
:
self
.
prim_output
[
'Y'
]}
self
.
all_ops
=
[
# prim op:
'sin_p'
,
# jvp op:
'mul_p'
,
'cos_p'
,
# transpose op:
]
class
TestCosPJVPAndTranspose
(
TestAddPJVPAndTranspose
):
def
init_data
(
self
):
# Set prim op
self
.
op_type
=
'cos_p'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
5
,
6
],
dtype
=
'int64'
)
self
.
prim_input
=
{
'X'
:
X
,
}
self
.
prim_output
=
{
'Y'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
prim_attrs
=
{}
# Set JVP
X_DOT
=
paddle
.
static
.
data
(
name
=
'X_DOT'
,
shape
=
[
5
,
6
],
dtype
=
'int64'
)
self
.
jvp_args
=
(
X_DOT
,
)
self
.
jvp_out_shape_map
=
{
0
:
self
.
prim_output
[
'Y'
]}
self
.
all_ops
=
[
# prim op:
'cos_p'
,
# jvp op:
'mul_p'
,
'sin_p'
,
'fill_constant_p'
,
'sub_p'
# transpose op:
]
class
TestExpPJVPAndTranspose
(
TestAddPJVPAndTranspose
):
def
init_data
(
self
):
# Set prim op
self
.
op_type
=
'exp_p'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
5
,
6
],
dtype
=
'int64'
)
self
.
prim_input
=
{
'X'
:
X
,
}
self
.
prim_output
=
{
'Y'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
prim_attrs
=
{}
# Set JVP
X_DOT
=
paddle
.
static
.
data
(
name
=
'X_DOT'
,
shape
=
[
5
,
6
],
dtype
=
'int64'
)
self
.
jvp_args
=
(
X_DOT
,
)
self
.
jvp_out_shape_map
=
{
0
:
self
.
prim_output
[
'Y'
]}
self
.
all_ops
=
[
# prim op:
'exp_p'
,
# jvp op:
'mul_p'
,
# transpose op:
]
class
TestReshapePJVPAndTranspose
(
TestAddPJVPAndTranspose
):
def
init_data
(
self
):
...
...
python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
浏览文件 @
22342d51
...
...
@@ -148,6 +148,66 @@ class TestTanhOrig2Prim(TestElementWiseAddOrig2Prim):
self
.
out_map
=
{
0
:
self
.
output
[
'Out'
]}
class
TestSinOrig2Prim
(
TestElementWiseAddOrig2Prim
):
def
init_data
(
self
):
self
.
op_type
=
'sin'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
3
,
4
],
dtype
=
'float'
)
self
.
input
=
{
'X'
:
X
,
}
self
.
output
=
{
'Out'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{}
self
.
orig2prim_args
=
(
X
,
)
self
.
all_ops
=
[
'sin'
,
'sin_p'
]
self
.
out_map
=
{
0
:
self
.
output
[
'Out'
]}
class
TestCosOrig2Prim
(
TestElementWiseAddOrig2Prim
):
def
init_data
(
self
):
self
.
op_type
=
'cos'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
3
,
4
],
dtype
=
'float'
)
self
.
input
=
{
'X'
:
X
,
}
self
.
output
=
{
'Out'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{}
self
.
orig2prim_args
=
(
X
,
)
self
.
all_ops
=
[
'cos'
,
'cos_p'
]
self
.
out_map
=
{
0
:
self
.
output
[
'Out'
]}
class
TestExpOrig2Prim
(
TestElementWiseAddOrig2Prim
):
def
init_data
(
self
):
self
.
op_type
=
'exp'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
3
,
4
],
dtype
=
'float'
)
self
.
input
=
{
'X'
:
X
,
}
self
.
output
=
{
'Out'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{}
self
.
orig2prim_args
=
(
X
,
)
self
.
all_ops
=
[
'exp'
,
'exp_p'
]
self
.
out_map
=
{
0
:
self
.
output
[
'Out'
]}
class
TestReshape2Orig2Prim
(
TestElementWiseAddOrig2Prim
):
def
init_data
(
self
):
...
...
python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
浏览文件 @
22342d51
...
...
@@ -164,6 +164,66 @@ class TestTanhPPrim2Orig(TestAddPPrim2Orig):
self
.
out_map
=
{
self
.
output
[
'Y'
]:
0
}
class
TestSinPPrim2Orig
(
TestAddPPrim2Orig
):
def
init_data
(
self
):
self
.
op_type
=
'sin_p'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
7
,
8
],
dtype
=
'float64'
)
self
.
input
=
{
'X'
:
X
,
}
self
.
output
=
{
'Y'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{}
self
.
prim2orig_args
=
(
X
,
)
self
.
all_ops
=
[
'sin_p'
,
'sin'
]
self
.
out_map
=
{
self
.
output
[
'Y'
]:
0
}
class
TestCosPPrim2Orig
(
TestAddPPrim2Orig
):
def
init_data
(
self
):
self
.
op_type
=
'cos_p'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
7
,
8
],
dtype
=
'float64'
)
self
.
input
=
{
'X'
:
X
,
}
self
.
output
=
{
'Y'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{}
self
.
prim2orig_args
=
(
X
,
)
self
.
all_ops
=
[
'cos_p'
,
'cos'
]
self
.
out_map
=
{
self
.
output
[
'Y'
]:
0
}
class
TestExpPPrim2Orig
(
TestAddPPrim2Orig
):
def
init_data
(
self
):
self
.
op_type
=
'exp_p'
X
=
paddle
.
static
.
data
(
name
=
'X'
,
shape
=
[
7
,
8
],
dtype
=
'float64'
)
self
.
input
=
{
'X'
:
X
,
}
self
.
output
=
{
'Y'
:
self
.
layer_help
.
create_variable_for_type_inference
(
dtype
=
X
.
dtype
)
}
self
.
attrs
=
{}
self
.
prim2orig_args
=
(
X
,
)
self
.
all_ops
=
[
'exp_p'
,
'exp'
]
self
.
out_map
=
{
self
.
output
[
'Y'
]:
0
}
class
TestReshapePPrim2Orig
(
TestAddPPrim2Orig
):
def
init_data
(
self
):
...
...
python/paddle/fluid/tests/unittests/autograd/test_primapi.py
浏览文件 @
22342d51
...
...
@@ -17,7 +17,6 @@ import unittest
import
numpy
as
np
import
paddle
from
paddle.incubate.autograd
import
primapi
import
config
import
utils
...
...
@@ -135,19 +134,19 @@ class TestWithoutProgramGuard(unittest.TestCase):
@
utils
.
place
(
config
.
DEVICES
)
@
utils
.
parameterize
(
(
utils
.
TEST_CASE_NAME
,
'fun'
,
'xs'
,
'v'
,
'dtype'
)
,
((
'matmul'
,
paddle
.
matmul
,
(
np
.
random
.
rand
(
2
,
3
),
np
.
random
.
rand
(
3
,
2
)),
None
,
'float32'
)
,
(
'multiply'
,
paddle
.
multiply
,
(
np
.
random
.
rand
(
2
,
3
),
np
.
random
.
rand
(
2
,
3
)),
None
,
'float64'
)
,
(
'add'
,
paddle
.
add
,
(
np
.
random
.
rand
(
2
,
3
),
np
.
random
.
rand
(
2
,
3
)),
None
,
'float32'
)
,
(
'input_not_sequence'
,
paddle
.
tanh
,
(
np
.
random
.
rand
(
5
,
5
),
),
None
,
'float64'
)
,
(
'input_gradients_not_none'
,
paddle
.
matmul
,
(
np
.
random
.
rand
(
3
,
3
),
np
.
random
.
rand
(
3
,
3
)
),
(
np
.
random
.
rand
(
3
,
3
),
np
.
random
.
rand
(
3
,
3
)),
'float64'
)
))
@
utils
.
parameterize
(
(
utils
.
TEST_CASE_NAME
,
'fun'
,
'xs'
,
'v'
,
'dtype'
),
(
(
'matmul'
,
paddle
.
matmul
,
(
np
.
random
.
rand
(
2
,
3
),
np
.
random
.
rand
(
3
,
2
)),
None
,
'float32'
)
,
(
'multiply'
,
paddle
.
multiply
,
(
np
.
random
.
rand
(
2
,
3
),
np
.
random
.
rand
(
2
,
3
)),
None
,
'float64'
)
,
(
'add'
,
paddle
.
add
,
(
np
.
random
.
rand
(
2
,
3
),
np
.
random
.
rand
(
2
,
3
)),
None
,
'float32'
)
,
(
'input_not_sequence'
,
paddle
.
tanh
,
(
np
.
random
.
rand
(
5
,
5
),
),
None
,
'float64'
)
,
(
'input_gradients_not_none'
,
paddle
.
matmul
,
(
np
.
random
.
rand
(
3
,
3
),
np
.
random
.
rand
(
3
,
3
))
,
(
np
.
random
.
rand
(
3
,
3
),
np
.
random
.
rand
(
3
,
3
)),
'float64'
),
))
class
TestForwardGrad
(
unittest
.
TestCase
):
@
classmethod
...
...
@@ -219,7 +218,8 @@ class TestForwardGrad(unittest.TestCase):
self
.
xs
,
self
.
v
,
stop_gradient
=
False
)
ys
=
self
.
fun
(
*
static_xs
)
if
isinstance
(
static_xs
,
typing
.
Sequence
)
else
self
.
fun
(
static_xs
)
ys_grad
=
primapi
.
forward_grad
(
ys
,
static_xs
,
static_v
)
ys_grad
=
paddle
.
incubate
.
autograd
.
forward_grad
(
ys
,
static_xs
,
static_v
)
paddle
.
incubate
.
autograd
.
prim2orig
(
mp
.
block
(
0
))
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
sp
)
...
...
@@ -229,15 +229,144 @@ class TestForwardGrad(unittest.TestCase):
def
test_illegal_param
(
self
):
paddle
.
incubate
.
autograd
.
enable_prim
()
with
self
.
assertRaises
(
TypeError
):
primapi
.
forward_grad
(
1
,
paddle
.
static
.
data
(
'inputs'
,
shape
=
[
1
]))
paddle
.
incubate
.
autograd
.
forward_grad
(
1
,
paddle
.
static
.
data
(
'inputs'
,
shape
=
[
1
]))
with
self
.
assertRaises
(
TypeError
):
primapi
.
forward_grad
(
paddle
.
static
.
data
(
'targets'
,
shape
=
[
1
]),
1
)
paddle
.
incubate
.
autograd
.
forward_grad
(
paddle
.
static
.
data
(
'targets'
,
shape
=
[
1
]),
1
)
paddle
.
incubate
.
autograd
.
disable_prim
()
@
utils
.
place
(
config
.
DEVICES
)
@
utils
.
parameterize
((
utils
.
TEST_CASE_NAME
,
'fun'
,
'xs'
,
'v'
,
'dtype'
),
(
(
'matmul'
,
paddle
.
matmul
,
(
np
.
random
.
rand
(
2
,
3
),
np
.
random
.
rand
(
3
,
2
)),
None
,
'float32'
),
(
'multiply'
,
paddle
.
multiply
,
(
np
.
random
.
rand
(
2
,
3
),
np
.
random
.
rand
(
2
,
3
)),
None
,
'float64'
),
(
'add'
,
paddle
.
add
,
(
np
.
random
.
rand
(
2
,
3
),
np
.
random
.
rand
(
2
,
3
)),
None
,
'float32'
),
(
'input_not_sequence'
,
paddle
.
tanh
,
(
np
.
random
.
rand
(
5
,
5
),
),
None
,
'float64'
),
(
'input_gradients_not_none'
,
paddle
.
matmul
,
(
np
.
random
.
rand
(
3
,
3
),
np
.
random
.
rand
(
3
,
3
)),
(
np
.
random
.
rand
(
3
,
3
),
),
'float64'
),
(
'sin'
,
paddle
.
sin
,
(
np
.
random
.
rand
(
100
,
200
),
),
None
,
'float32'
),
(
'cos'
,
paddle
.
cos
,
(
np
.
random
.
rand
(
200
,
90
),
),
None
,
'float32'
),
(
'exp'
,
paddle
.
exp
,
(
np
.
random
.
rand
(
299
,
320
),
),
None
,
'float32'
),
))
class
TestGrad
(
unittest
.
TestCase
):
def
setUp
(
self
):
paddle
.
enable_static
()
paddle
.
incubate
.
autograd
.
enable_prim
()
def
tearDown
(
self
):
paddle
.
incubate
.
autograd
.
disable_prim
()
paddle
.
disable_static
()
@
classmethod
def
setUpClass
(
cls
):
cls
.
xs
=
tuple
(
x
.
astype
(
cls
.
dtype
)
for
x
in
cls
.
xs
)
cls
.
_rtol
=
config
.
TOLERANCE
.
get
(
str
(
cls
.
dtype
)).
get
(
"first_order_grad"
).
get
(
"rtol"
)
cls
.
_atol
=
config
.
TOLERANCE
.
get
(
str
(
cls
.
dtype
)).
get
(
"first_order_grad"
).
get
(
"atol"
)
def
test_grad
(
self
):
def
expected
():
paddle
.
incubate
.
autograd
.
disable_prim
()
sp
=
paddle
.
static
.
Program
()
mp
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
feed
,
static_xs
,
static_v
=
utils
.
gen_static_data_and_feed
(
self
.
xs
,
self
.
v
,
stop_gradient
=
False
)
_
,
ys_grad
=
paddle
.
incubate
.
autograd
.
vjp
(
self
.
fun
,
static_xs
,
static_v
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
sp
)
out
=
exe
.
run
(
mp
,
feed
=
feed
,
fetch_list
=
ys_grad
)
paddle
.
incubate
.
autograd
.
enable_prim
()
return
out
def
actual
():
paddle
.
incubate
.
autograd
.
enable_prim
()
sp
=
paddle
.
static
.
Program
()
mp
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
feed
,
static_xs
,
static_v
=
utils
.
gen_static_data_and_feed
(
self
.
xs
,
self
.
v
,
stop_gradient
=
False
)
ys
=
self
.
fun
(
*
static_xs
)
if
isinstance
(
static_xs
,
typing
.
Sequence
)
else
self
.
fun
(
static_xs
)
ys_grad
=
paddle
.
incubate
.
autograd
.
grad
(
ys
,
static_xs
,
static_v
)
paddle
.
incubate
.
autograd
.
prim2orig
(
mp
.
block
(
0
))
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
sp
)
out
=
exe
.
run
(
mp
,
feed
=
feed
,
fetch_list
=
ys_grad
)
paddle
.
incubate
.
autograd
.
disable_prim
()
return
out
actual
=
actual
()
expected
=
expected
()
self
.
assertEqual
(
type
(
actual
),
type
(
expected
))
for
i
,
j
in
zip
(
actual
,
expected
):
np
.
testing
.
assert_allclose
(
i
,
j
,
rtol
=
self
.
_rtol
,
atol
=
self
.
_atol
)
def
test_illegal_param
(
self
):
paddle
.
incubate
.
autograd
.
enable_prim
()
with
self
.
assertRaises
(
TypeError
):
paddle
.
incubate
.
autograd
.
grad
(
1
,
paddle
.
static
.
data
(
'inputs'
,
shape
=
[
1
]))
with
self
.
assertRaises
(
TypeError
):
paddle
.
incubate
.
autograd
.
grad
(
paddle
.
static
.
data
(
'targets'
,
shape
=
[
1
]),
1
)
paddle
.
incubate
.
autograd
.
disable_prim
()
def
test_disable_prim
(
self
):
def
expected
():
paddle
.
incubate
.
autograd
.
disable_prim
()
sp
=
paddle
.
static
.
Program
()
mp
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
feed
,
static_xs
,
static_v
=
utils
.
gen_static_data_and_feed
(
self
.
xs
,
self
.
v
,
stop_gradient
=
False
)
ys
=
self
.
fun
(
*
static_xs
)
if
isinstance
(
static_xs
,
typing
.
Sequence
)
else
self
.
fun
(
static_xs
)
ys_grad
=
paddle
.
incubate
.
autograd
.
grad
(
ys
,
static_xs
,
static_v
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
sp
)
out
=
exe
.
run
(
mp
,
feed
=
feed
,
fetch_list
=
ys_grad
)
paddle
.
incubate
.
autograd
.
enable_prim
()
return
out
def
actual
():
paddle
.
incubate
.
autograd
.
disable_prim
()
sp
=
paddle
.
static
.
Program
()
mp
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
feed
,
static_xs
,
static_v
=
utils
.
gen_static_data_and_feed
(
self
.
xs
,
self
.
v
,
stop_gradient
=
False
)
ys
=
self
.
fun
(
*
static_xs
)
if
isinstance
(
static_xs
,
typing
.
Sequence
)
else
self
.
fun
(
static_xs
)
ys_grad
=
paddle
.
static
.
gradients
(
ys
,
static_xs
,
static_v
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
sp
)
out
=
exe
.
run
(
mp
,
feed
=
feed
,
fetch_list
=
ys_grad
)
paddle
.
incubate
.
autograd
.
enable_prim
()
return
out
actual
=
actual
()
expected
=
expected
()
self
.
assertEqual
(
type
(
actual
),
type
(
expected
))
for
i
,
j
in
zip
(
actual
,
expected
):
np
.
testing
.
assert_allclose
(
i
,
j
,
rtol
=
self
.
_rtol
,
atol
=
self
.
_atol
)
class
TestGradWithHigherOrder
(
unittest
.
TestCase
):
def
setUp
(
self
):
paddle
.
enable_static
()
paddle
.
incubate
.
autograd
.
enable_prim
()
...
...
@@ -346,44 +475,6 @@ class TestGrad(unittest.TestCase):
np
.
testing
.
assert_allclose
(
outs
,
result
,
rtol
=
1e-5
,
atol
=
1e-5
)
paddle
.
incubate
.
autograd
.
disable_prim
()
def
test_disable_prim
(
self
):
def
actual
(
x
:
np
.
array
):
paddle
.
incubate
.
autograd
.
disable_prim
()
main
=
paddle
.
static
.
Program
()
startup
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main
,
startup
):
var_x
=
paddle
.
static
.
data
(
'x'
,
shape
=
x
.
shape
,
dtype
=
x
.
dtype
)
var_x
.
stop_gradient
=
False
y
=
paddle
.
tanh
(
var_x
)
y_grad
=
paddle
.
incubate
.
autograd
.
grad
(
y
,
var_x
)
y_second_grad
=
paddle
.
incubate
.
autograd
.
grad
(
y_grad
,
var_x
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
startup
)
return
exe
.
run
(
main
,
feed
=
{
'x'
:
x
},
fetch_list
=
[
y_grad
,
y_second_grad
])
def
expect
(
x
:
np
.
array
):
paddle
.
incubate
.
autograd
.
disable_prim
()
main
=
paddle
.
static
.
Program
()
startup
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main
,
startup
):
var_x
=
paddle
.
static
.
data
(
'x'
,
shape
=
x
.
shape
,
dtype
=
x
.
dtype
)
var_x
.
stop_gradient
=
False
y
=
paddle
.
tanh
(
var_x
)
y_grad
=
paddle
.
static
.
gradients
(
y
,
var_x
)
y_second_grad
=
paddle
.
static
.
gradients
(
y_grad
,
var_x
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
startup
)
return
exe
.
run
(
main
,
feed
=
{
'x'
:
x
},
fetch_list
=
[
y_grad
,
y_second_grad
])
x
=
np
.
random
.
randn
(
100
,
200
)
for
i
,
j
in
zip
(
actual
(
x
),
expect
(
x
)):
np
.
testing
.
assert_allclose
(
i
,
j
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/autograd/test_primops.py
浏览文件 @
22342d51
...
...
@@ -11,154 +11,128 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
uuid
import
numpy
as
np
import
paddle
from
paddle.incubate.autograd.primops
import
(
neg
,
set_value
,
add
,
sub
,
mul
,
div
,
sqrt
,
tanh
,
reshape
,
broadcast
,
transpose
,
split
,
concat
,
reduce
,
matmul
,
slice_select
,
slice_assign
,
gather
,
scatter_add
,
fill_const
)
from
paddle.incubate.autograd.primx
import
Transform
,
topo_path
,
orig2prim
,
prim2orig
from
paddle.incubate.autograd.utils
import
enable_prim
,
disable_prim
,
prim_enabled
class
TestPyPrimOps
(
unittest
.
TestCase
):
""" Test Python wrappers of primitive ops. """
def
setUp
(
self
):
from
numpy.random
import
randint
,
randn
from
paddle.incubate.autograd
import
primops
,
primx
from
paddle.incubate.autograd
import
utils
as
prim_utils
import
config
import
utils
paddle
.
enable_static
()
@
utils
.
place
(
config
.
DEVICES
)
@
utils
.
parameterize
(
(
utils
.
TEST_CASE_NAME
,
'op'
,
'args'
,
'kwargs'
,
'expected_shape'
,
'expected_dtype'
),
(
(
'add'
,
primops
.
add
,
(
randn
(
2
,
3
),
randn
(
2
,
3
)),
{},
(
2
,
3
),
'float64'
),
(
'sub'
,
primops
.
sub
,
(
randn
(
2
,
3
),
randn
(
2
,
3
)),
{},
(
2
,
3
),
'float64'
),
(
'mul'
,
primops
.
mul
,
(
randn
(
2
,
3
),
randn
(
2
,
3
)),
{},
(
2
,
3
),
'float64'
),
(
'div'
,
primops
.
div
,
(
randn
(
2
,
3
),
randn
(
2
,
3
)),
{},
(
2
,
3
),
'float64'
),
(
'sub'
,
primops
.
sub
,
(
randn
(
2
,
3
),
randn
(
2
,
3
)),
{},
(
2
,
3
),
'float64'
),
(
'sqrt'
,
primops
.
sqrt
,
randn
(
2
,
3
),
{},
(
2
,
3
),
'float64'
),
(
'tanh'
,
primops
.
tanh
,
randn
(
2
,
3
),
{},
(
2
,
3
),
'float64'
),
(
'sin'
,
primops
.
sin
,
randn
(
2
,
3
),
{},
(
2
,
3
),
'float64'
),
(
'cos'
,
primops
.
cos
,
randn
(
2
,
3
),
{},
(
2
,
3
),
'float64'
),
(
'exp'
,
primops
.
exp
,
randn
(
2
,
3
),
{},
(
2
,
3
),
'float64'
),
(
'reshape'
,
primops
.
reshape
,
randn
(
2
,
3
),
{
'shape'
:
(
3
,
2
)
},
(
3
,
2
),
'float64'
),
(
'broadcast'
,
primops
.
broadcast
,
randn
(
2
),
{
'shape'
:
(
3
,
2
)
},
(
3
,
2
),
'float64'
),
(
'transpose'
,
primops
.
transpose
,
randn
(
2
,
3
),
{
'axis'
:
(
1
,
0
)
},
(
3
,
2
),
'float64'
),
(
'concat_axis0'
,
primops
.
concat
,
((
randn
(
2
,
3
),
randn
(
2
,
3
)),
),
{
'axis'
:
0
},
(
4
,
3
),
'float64'
),
(
'concat_axis1'
,
primops
.
concat
,
((
randn
(
2
,
3
),
randn
(
2
,
3
)),
),
{
'axis'
:
1
},
(
2
,
6
),
'float64'
),
(
'reduce_axis1'
,
primops
.
reduce
,
randn
(
2
,
3
),
{
'axis'
:
(
1
,
)
},
(
2
,
),
'float64'
),
(
'reduce_axis01'
,
primops
.
reduce
,
randn
(
2
,
3
),
{
'axis'
:
(
0
,
1
)
},
(
1
,
),
'float64'
),
(
'split'
,
primops
.
split
,
randn
(
2
,
3
),
{
'num_or_sections'
:
[
1
,
2
],
'axis'
:
1
},
((
2
,
1
),
(
2
,
2
)),
(
'float64'
,
'float64'
)),
(
'matmul'
,
primops
.
matmul
,
(
randn
(
2
,
3
),
randn
(
3
,
2
)),
{},
(
2
,
2
),
'float64'
),
(
'slice_select'
,
primops
.
slice_select
,
randn
(
3
,
2
),
{
'axis'
:
[
0
],
'starts'
:
[
0
],
'ends'
:
[
2
],
'strides'
:
[
1
]
},
(
2
,
2
),
'float64'
),
(
'slice_assign'
,
primops
.
slice_assign
,
(
randn
(
2
,
3
),
randn
(
2
,
2
)),
{
'axis'
:
[
1
],
'starts'
:
[
1
],
'ends'
:
[
3
],
'strides'
:
[
1
]
},
(
2
,
3
),
'float64'
),
(
'gather'
,
primops
.
gather
,
(
randn
(
3
,
2
),
randint
(
0
,
2
,
(
5
,
),
np
.
int32
)),
{
'axis'
:
0
},
(
5
,
2
),
'float64'
),
(
'scatter_add'
,
primops
.
scatter_add
,
(
randn
(
3
,
2
),
randn
(
5
,
2
),
randint
(
0
,
2
,
(
5
,
),
np
.
int32
)),
{
'axis'
:
0
},
(
3
,
2
),
'float64'
),
(
'fill_const'
,
primops
.
fill_const
,
(),
{
'value'
:
10
,
'shape'
:
(
3
,
2
),
'dtype'
:
paddle
.
float32
},
(
3
,
2
),
'float32'
),
(
'neg'
,
primops
.
neg
,
randn
(
2
,
3
),
{},
(
2
,
3
),
'float64'
),
))
class
TestPrimops
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
paddle
.
enable_static
()
def
test_ops
(
self
):
A
=
np
.
random
.
rand
(
1
)
B
=
np
.
random
.
rand
(
2
)
C
=
np
.
random
.
rand
(
2
,
3
)
D
=
np
.
random
.
rand
(
2
,
3
)
E
=
np
.
random
.
rand
(
3
,
2
)
a
=
paddle
.
static
.
data
(
name
=
'A'
,
shape
=
A
.
shape
,
dtype
=
'float32'
)
b
=
paddle
.
static
.
data
(
name
=
'B'
,
shape
=
B
.
shape
,
dtype
=
'float32'
)
c
=
paddle
.
static
.
data
(
name
=
'C'
,
shape
=
C
.
shape
,
dtype
=
'float32'
)
d
=
paddle
.
static
.
data
(
name
=
'D'
,
shape
=
D
.
shape
,
dtype
=
'float32'
)
e
=
paddle
.
static
.
data
(
name
=
'E'
,
shape
=
E
.
shape
,
dtype
=
'float32'
)
add_1
=
add
(
a
,
a
)
self
.
assertEqual
(
add_1
.
dtype
,
a
.
dtype
)
self
.
assertEqual
(
add_1
.
shape
,
a
.
shape
)
add_2
=
add
(
c
,
d
)
self
.
assertEqual
(
add_2
.
dtype
,
c
.
dtype
)
self
.
assertEqual
(
add_2
.
shape
,
c
.
shape
)
sub_1
=
sub
(
c
,
d
)
self
.
assertEqual
(
sub_1
.
dtype
,
c
.
dtype
)
self
.
assertEqual
(
sub_1
.
shape
,
c
.
shape
)
mul_1
=
mul
(
c
,
d
)
self
.
assertEqual
(
mul_1
.
dtype
,
c
.
dtype
)
self
.
assertEqual
(
mul_1
.
shape
,
c
.
shape
)
div_1
=
div
(
c
,
d
)
self
.
assertEqual
(
div_1
.
dtype
,
c
.
dtype
)
self
.
assertEqual
(
div_1
.
shape
,
c
.
shape
)
sqrt_1
=
sqrt
(
b
)
self
.
assertEqual
(
sqrt_1
.
dtype
,
b
.
dtype
)
self
.
assertEqual
(
sqrt_1
.
shape
,
b
.
shape
)
tanh_1
=
tanh
(
d
)
self
.
assertEqual
(
tanh_1
.
dtype
,
d
.
dtype
)
self
.
assertEqual
(
tanh_1
.
shape
,
d
.
shape
)
reshape_1
=
reshape
(
c
,
d
.
shape
)
self
.
assertEqual
(
reshape_1
.
dtype
,
c
.
dtype
)
self
.
assertEqual
(
reshape_1
.
shape
,
d
.
shape
)
broadcast_1
=
broadcast
(
b
,
e
.
shape
)
self
.
assertEqual
(
broadcast_1
.
dtype
,
b
.
dtype
)
self
.
assertEqual
(
broadcast_1
.
shape
,
e
.
shape
)
transpose_1
=
transpose
(
c
,
axis
=
[
1
,
0
])
self
.
assertEqual
(
transpose_1
.
dtype
,
c
.
dtype
)
self
.
assertEqual
(
transpose_1
.
shape
,
e
.
shape
)
split_1_0
,
split_1_1
=
split
(
c
,
num_or_sections
=
[
1
,
2
],
axis
=
1
)
self
.
assertEqual
(
split_1_0
.
dtype
,
c
.
dtype
)
self
.
assertEqual
(
split_1_0
.
shape
,
(
2
,
1
))
self
.
assertEqual
(
split_1_1
.
shape
,
(
2
,
2
))
concat_1
=
concat
([
c
,
d
],
axis
=
0
)
self
.
assertEqual
(
concat_1
.
dtype
,
c
.
dtype
)
self
.
assertEqual
(
concat_1
.
shape
,
(
4
,
3
))
reduce_1
=
reduce
(
d
,
axis
=
[
1
])
self
.
assertEqual
(
reduce_1
.
dtype
,
d
.
dtype
)
self
.
assertEqual
(
reduce_1
.
shape
,
(
2
,
))
reduce_2
=
reduce
(
c
,
axis
=
[
0
,
1
])
self
.
assertEqual
(
reduce_2
.
dtype
,
c
.
dtype
)
self
.
assertEqual
(
reduce_2
.
shape
,
(
1
,
))
# TODO: reduce + keepdim
matmul_1
=
matmul
(
d
,
e
)
self
.
assertEqual
(
matmul_1
.
dtype
,
d
.
dtype
)
self
.
assertEqual
(
matmul_1
.
shape
,
(
2
,
2
))
slice_select_1
=
slice_select
(
e
,
axis
=
[
0
],
starts
=
[
0
],
ends
=
[
2
],
strides
=
[
1
])
self
.
assertEqual
(
slice_select_1
.
dtype
,
e
.
dtype
)
self
.
assertEqual
(
slice_select_1
.
shape
,
(
2
,
2
))
slice_select_2
=
slice_select
(
d
,
axis
=
[
0
,
1
],
starts
=
[
0
,
1
],
ends
=
[
2
,
3
],
strides
=
[
1
,
2
])
self
.
assertEqual
(
slice_select_2
.
dtype
,
d
.
dtype
)
self
.
assertEqual
(
slice_select_2
.
shape
,
(
2
,
1
))
y
=
broadcast
(
b
,
[
2
,
2
])
slice_assign_1
=
slice_assign
(
d
,
y
,
axis
=
[
1
],
starts
=
[
1
],
ends
=
[
3
],
strides
=
[
1
])
self
.
assertEqual
(
slice_assign_1
.
dtype
,
d
.
dtype
)
self
.
assertEqual
(
slice_assign_1
.
shape
,
d
.
shape
)
index
=
paddle
.
static
.
data
(
'index'
,
shape
=
[
5
],
dtype
=
'int32'
)
gather_1
=
gather
(
e
,
index
,
axis
=
0
)
self
.
assertEqual
(
gather_1
.
dtype
,
e
.
dtype
)
self
.
assertEqual
(
gather_1
.
shape
,
(
5
,
2
))
y
=
paddle
.
rand
([
5
,
2
],
dtype
=
'float32'
)
scatter_add_1
=
scatter_add
(
e
,
y
,
index
,
axis
=
0
)
self
.
assertEqual
(
scatter_add_1
.
dtype
,
e
.
dtype
)
self
.
assertEqual
(
scatter_add_1
.
shape
,
e
.
shape
)
fill_const_1
=
fill_const
(
value
=
10
,
shape
=
a
.
shape
,
dtype
=
a
.
dtype
)
self
.
assertEqual
(
fill_const_1
.
shape
,
a
.
shape
)
self
.
assertEqual
(
fill_const_1
.
dtype
,
a
.
dtype
)
neg_1
=
neg
(
x
=
b
)
self
.
assertEqual
(
neg_1
.
shape
,
b
.
shape
)
self
.
assertEqual
(
neg_1
.
dtype
,
b
.
dtype
)
set_value_1
=
set_value
(
d
,
a
,
axis
=
[
1
],
starts
=
[
1
],
ends
=
[
3
],
strides
=
[
1
],
out
=
d
)
self
.
assertEqual
(
set_value_1
.
shape
,
d
.
shape
)
self
.
assertEqual
(
set_value_1
.
dtype
,
d
.
dtype
)
@
classmethod
def
tearDownClass
(
cls
):
paddle
.
disable_static
()
def
test_prim_ops
(
self
):
program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
program
):
args
=
self
.
_as_tuple
(
self
.
args
)
args
=
self
.
arr2var
(
args
)
results
=
self
.
op
(
*
args
,
**
self
.
kwargs
)
results
=
self
.
_as_tuple
(
results
)
expected_shape
=
self
.
_as_tuple
(
self
.
expected_shape
)
expected_dtype
=
self
.
_as_tuple
(
self
.
expected_dtype
)
for
r
,
shape
,
dtype
in
zip
(
results
,
expected_shape
,
expected_dtype
):
self
.
assertEqual
(
r
.
shape
,
shape
)
self
.
assertEqual
(
str
(
r
.
dtype
).
split
(
'.'
)[
1
],
dtype
)
def
arr2var
(
self
,
arr
):
"""convert numpy ndarray to paddle Variable recursively."""
return
[
paddle
.
static
.
data
(
f
'x
{
uuid
.
uuid4
()
}
'
,
v
.
shape
,
v
.
dtype
)
if
isinstance
(
v
,
np
.
ndarray
)
else
self
.
arr2var
(
v
)
for
v
in
arr
]
def
_as_tuple
(
self
,
input
):
if
isinstance
(
input
,
(
tuple
,
list
))
and
len
(
input
)
==
0
:
return
input
if
not
isinstance
(
input
,
(
tuple
,
list
))
or
all
(
isinstance
(
i
,
int
)
for
i
in
input
):
return
(
input
,
)
return
input
if
__name__
==
'__main__'
:
...
...
python/paddle/incubate/autograd/primops.py
浏览文件 @
22342d51
...
...
@@ -122,6 +122,21 @@ def tanh(x, out=None):
return
_simple_unop
(
LayerHelper
(
'tanh_p'
,
**
locals
()))
@
REGISTER_FN
(
'sin_p'
,
'X'
,
'Y'
)
def
sin
(
x
,
out
=
None
):
return
_simple_unop
(
LayerHelper
(
'sin_p'
,
**
locals
()))
@
REGISTER_FN
(
'cos_p'
,
'X'
,
'Y'
)
def
cos
(
x
,
out
=
None
):
return
_simple_unop
(
LayerHelper
(
'cos_p'
,
**
locals
()))
@
REGISTER_FN
(
'exp_p'
,
'X'
,
'Y'
)
def
exp
(
x
,
out
=
None
):
return
_simple_unop
(
LayerHelper
(
'exp_p'
,
**
locals
()))
@
REGISTER_FN
(
'reshape_p'
,
'X'
,
'Y'
)
def
reshape
(
x
,
shape
,
out
=
None
):
return
_manipulation_unop
(
LayerHelper
(
'reshape_p'
,
**
locals
()))
...
...
python/paddle/incubate/autograd/primrules.py
浏览文件 @
22342d51
...
...
@@ -15,13 +15,15 @@ import typing
import
paddle
from
.primreg
import
REGISTER_ORIG2PRIM
,
REGISTER_PRIM2ORIG
,
REGISTER_JVP
,
REGISTER_TRANSPOSE
from
.primreg
import
(
lookup_fn
,
lookup_orig2prim
,
lookup_prim2orig
,
lookup_jvp
,
lookup_transpose
,
op_position_inputs
,
op_position_output
)
from
.primops
import
(
neg
,
add
,
sub
,
mul
,
div
,
sqrt
,
tanh
,
reshape
,
broadcast
,
transpose
,
split
,
concat
,
reduce
,
matmul
,
slice_select
,
slice_assign
,
gather
,
scatter_add
,
fill_const
,
set_value
)
from
.utils
import
get_input_var_list
,
get_output_var_list
,
INT_DTYPE_2_STRING
from
.primops
import
(
add
,
broadcast
,
concat
,
cos
,
div
,
exp
,
fill_const
,
gather
,
matmul
,
mul
,
neg
,
reduce
,
reshape
,
scatter_add
,
set_value
,
sin
,
slice_assign
,
slice_select
,
split
,
sqrt
,
sub
,
tanh
,
transpose
)
from
.primreg
import
(
REGISTER_JVP
,
REGISTER_ORIG2PRIM
,
REGISTER_PRIM2ORIG
,
REGISTER_TRANSPOSE
,
lookup_fn
,
lookup_jvp
,
lookup_orig2prim
,
lookup_prim2orig
,
lookup_transpose
,
op_position_inputs
,
op_position_output
)
from
.utils
import
INT_DTYPE_2_STRING
,
get_input_var_list
,
get_output_var_list
def
_orig2prim
(
op
,
*
args
):
...
...
@@ -149,6 +151,21 @@ def tanh_orig2prim(op, x):
return
tanh
(
x
)
@
REGISTER_ORIG2PRIM
(
'sin'
)
def
sin_orig2prim
(
op
,
x
):
return
sin
(
x
)
@
REGISTER_ORIG2PRIM
(
'cos'
)
def
cos_orig2prim
(
op
,
x
):
return
cos
(
x
)
@
REGISTER_ORIG2PRIM
(
'exp'
)
def
exp_orig2prim
(
op
,
x
):
return
exp
(
x
)
@
REGISTER_ORIG2PRIM
(
'fill_zeros_like'
)
def
fill_zeros_like_orig2prim
(
op
,
x
):
return
fill_const
(
value
=
0.0
,
shape
=
x
.
shape
,
dtype
=
x
.
dtype
)
...
...
@@ -301,6 +318,21 @@ def tanh_prim2orig(op, x):
return
paddle
.
tanh
(
x
)
@
REGISTER_PRIM2ORIG
(
'sin_p'
)
def
sin_prim2orig
(
op
,
x
):
return
paddle
.
sin
(
x
)
@
REGISTER_PRIM2ORIG
(
'cos_p'
)
def
cos_prim2orig
(
op
,
x
):
return
paddle
.
cos
(
x
)
@
REGISTER_PRIM2ORIG
(
'exp_p'
)
def
exp_prim2orig
(
op
,
x
):
return
paddle
.
exp
(
x
)
@
REGISTER_PRIM2ORIG
(
'reshape_p'
)
def
reshape_prim2orig
(
op
,
x
):
return
paddle
.
reshape
(
x
,
shape
=
op
.
attr
(
'shape'
))
...
...
@@ -453,6 +485,30 @@ def tanh_jvp(op, x_dot):
return
y_dot
@
REGISTER_JVP
(
'sin_p'
)
def
sin_jvp
(
op
,
x_dot
):
if
x_dot
is
None
:
return
None
x
,
=
op_position_inputs
(
op
)
return
mul
(
x_dot
,
cos
(
x
))
@
REGISTER_JVP
(
'cos_p'
)
def
cos_jvp
(
op
,
x_dot
):
if
x_dot
is
None
:
return
None
x
,
=
op_position_inputs
(
op
)
return
mul
(
x_dot
,
neg
(
sin
(
x
)))
@
REGISTER_JVP
(
'exp_p'
)
def
exp_jvp
(
op
,
x_dot
):
if
x_dot
is
None
:
return
None
y
=
op_position_output
(
op
)
return
mul
(
x_dot
,
y
)
@
REGISTER_JVP
(
'reshape_p'
)
def
reshape_jvp
(
op
,
x_dot
):
if
x_dot
is
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录