PaddlePaddle / Paddle · commit 22342d51 (unverified)

add sin,cos,exp primitive operators (#44345)

Authored by Xiaoxu Chen on Jul 26, 2022; committed via GitHub on Jul 26, 2022.
Parent: 0d51fcf1
Showing 10 changed files with 765 additions and 205 deletions (+765, −205)
paddle/fluid/operators/prim_ops/cos_p_op.cc                              +71   −0
paddle/fluid/operators/prim_ops/exp_p_op.cc                              +71   −0
paddle/fluid/operators/prim_ops/sin_p_op.cc                              +71   −0
python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py   +91   −0
python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py           +60   −0
python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py           +60   −0
python/paddle/fluid/tests/unittests/autograd/test_primapi.py             +146  −55
python/paddle/fluid/tests/unittests/autograd/test_primops.py             +117  −143
python/paddle/incubate/autograd/primops.py                               +15   −0
python/paddle/incubate/autograd/primrules.py                             +63   −7
paddle/fluid/operators/prim_ops/cos_p_op.cc
new file mode 100644
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {

class CosPrimOp : public framework::OperatorBase {
 public:
  CosPrimOp(const std::string &type,
            const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Prim operator cos_p should not be executed directly"));
  }
};

class CosPrimOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of cos_p op.");
    AddOutput("Y", "(Tensor), The output tensor of cos_p op.");
    AddComment(R"DOC(Autograd primitive cos_p operator.)DOC");
  }
};

class CosPrimOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
    framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
    framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
    PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
  }
};

class CosPrimOpVarTypeInference
    : public framework::StaticGraphVarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto x_name = Input(ctx, "X")[0];
    auto y_name = Output(ctx, "Y")[0];
    SetType(ctx, y_name, GetType(ctx, x_name));
    SetDataType(ctx, y_name, GetDataType(ctx, x_name));
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(cos_p,
                  paddle::operators::CosPrimOp,
                  paddle::operators::CosPrimOpMaker,
                  paddle::operators::CosPrimOpShapeInference,
                  paddle::operators::CosPrimOpVarTypeInference);
paddle/fluid/operators/prim_ops/exp_p_op.cc
new file mode 100644
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {

class ExpPrimOp : public framework::OperatorBase {
 public:
  ExpPrimOp(const std::string &type,
            const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Prim operator exp_p should not be executed directly"));
  }
};

class ExpPrimOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of exp_p op.");
    AddOutput("Y", "(Tensor), The output tensor of exp_p op.");
    AddComment(R"DOC(Autograd primitive exp_p operator.)DOC");
  }
};

class ExpPrimOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
    framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
    framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
    PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
  }
};

class ExpPrimOpVarTypeInference
    : public framework::StaticGraphVarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto x_name = Input(ctx, "X")[0];
    auto y_name = Output(ctx, "Y")[0];
    SetType(ctx, y_name, GetType(ctx, x_name));
    SetDataType(ctx, y_name, GetDataType(ctx, x_name));
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(exp_p,
                  paddle::operators::ExpPrimOp,
                  paddle::operators::ExpPrimOpMaker,
                  paddle::operators::ExpPrimOpShapeInference,
                  paddle::operators::ExpPrimOpVarTypeInference);
paddle/fluid/operators/prim_ops/sin_p_op.cc
new file mode 100644
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {

class SinPrimOp : public framework::OperatorBase {
 public:
  SinPrimOp(const std::string &type,
            const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
      : framework::OperatorBase(type, inputs, outputs, attrs) {}
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Prim operator sin_p should not be executed directly"));
  }
};

class SinPrimOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of sin_p op.");
    AddOutput("Y", "(Tensor), The output tensor of sin_p op.");
    AddComment(R"DOC(Autograd primitive sin_p operator.)DOC");
  }
};

class SinPrimOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
    framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
    framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
    PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
  }
};

class SinPrimOpVarTypeInference
    : public framework::StaticGraphVarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto x_name = Input(ctx, "X")[0];
    auto y_name = Output(ctx, "Y")[0];
    SetType(ctx, y_name, GetType(ctx, x_name));
    SetDataType(ctx, y_name, GetDataType(ctx, x_name));
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(sin_p,
                  paddle::operators::SinPrimOp,
                  paddle::operators::SinPrimOpMaker,
                  paddle::operators::SinPrimOpShapeInference,
                  paddle::operators::SinPrimOpVarTypeInference);
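
All three operators deliberately throw Unimplemented from RunImpl: *_p primitives are graph-level nodes that must be lowered to ordinary operators before a program can run. A minimal sketch of that lowering flow in static-graph mode, modeled on calls the tests below actually make (the shape, variable names, and feed are illustrative, and exact preconditions such as enable_prim may differ):

    import numpy as np
    import paddle
    from paddle.incubate.autograd import primops

    paddle.enable_static()
    main, startup = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data(name='x', shape=[3, 4], dtype='float64')
        y = primops.sin(x)  # appends a sin_p op; executing it now would raise Unimplemented
        paddle.incubate.autograd.prim2orig(main.block(0))  # lower sin_p to the ordinary sin op

    exe = paddle.static.Executor()
    exe.run(startup)
    out, = exe.run(main, feed={'x': np.random.rand(3, 4)}, fetch_list=[y])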
python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
...
...
@@ -273,6 +273,97 @@ class TestTanhPJVPAndTranspose(TestAddPJVPAndTranspose):
        ]


class TestSinPJVPAndTranspose(TestAddPJVPAndTranspose):

    def init_data(self):
        # Set prim op
        self.op_type = 'sin_p'
        X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
        self.prim_input = {'X': X, }
        self.prim_output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.prim_attrs = {}

        # Set JVP
        X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
        self.jvp_args = (X_DOT, )
        self.jvp_out_shape_map = {0: self.prim_output['Y']}

        self.all_ops = [
            # prim op:
            'sin_p',
            # jvp op:
            'mul_p',
            'cos_p',
            # transpose op:
        ]


class TestCosPJVPAndTranspose(TestAddPJVPAndTranspose):

    def init_data(self):
        # Set prim op
        self.op_type = 'cos_p'
        X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
        self.prim_input = {'X': X, }
        self.prim_output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.prim_attrs = {}

        # Set JVP
        X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
        self.jvp_args = (X_DOT, )
        self.jvp_out_shape_map = {0: self.prim_output['Y']}

        self.all_ops = [
            # prim op:
            'cos_p',
            # jvp op:
            'mul_p',
            'sin_p',
            'fill_constant_p',
            'sub_p'
            # transpose op:
        ]


class TestExpPJVPAndTranspose(TestAddPJVPAndTranspose):

    def init_data(self):
        # Set prim op
        self.op_type = 'exp_p'
        X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
        self.prim_input = {'X': X, }
        self.prim_output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.prim_attrs = {}

        # Set JVP
        X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
        self.jvp_args = (X_DOT, )
        self.jvp_out_shape_map = {0: self.prim_output['Y']}

        self.all_ops = [
            # prim op:
            'exp_p',
            # jvp op:
            'mul_p',
            # transpose op:
        ]


class TestReshapePJVPAndTranspose(TestAddPJVPAndTranspose):

    def init_data(self):
...
...
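
The expected op lists above follow directly from the JVP rules added in primrules.py (later in this diff), which implement the standard derivative identities

    d/dx sin(x) = cos(x),    d/dx cos(x) = -sin(x),    d/dx exp(x) = exp(x)

so sin_p's JVP needs mul_p and cos_p, while exp_p's needs only mul_p because it reuses the op's own output. The extra fill_constant_p and sub_p expected in the cos_p case suggest that primops.neg is itself lowered to 0 - sin(x); that reading is inferred from the expected op list, not stated in the diff.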
python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
...
...
@@ -148,6 +148,66 @@ class TestTanhOrig2Prim(TestElementWiseAddOrig2Prim):
        self.out_map = {0: self.output['Out']}


class TestSinOrig2Prim(TestElementWiseAddOrig2Prim):

    def init_data(self):
        self.op_type = 'sin'
        X = paddle.static.data(name='X', shape=[3, 4], dtype='float')

        self.input = {'X': X, }
        self.output = {
            'Out':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}
        self.orig2prim_args = (X, )
        self.all_ops = ['sin', 'sin_p']
        self.out_map = {0: self.output['Out']}


class TestCosOrig2Prim(TestElementWiseAddOrig2Prim):

    def init_data(self):
        self.op_type = 'cos'
        X = paddle.static.data(name='X', shape=[3, 4], dtype='float')

        self.input = {'X': X, }
        self.output = {
            'Out':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}
        self.orig2prim_args = (X, )
        self.all_ops = ['cos', 'cos_p']
        self.out_map = {0: self.output['Out']}


class TestExpOrig2Prim(TestElementWiseAddOrig2Prim):

    def init_data(self):
        self.op_type = 'exp'
        X = paddle.static.data(name='X', shape=[3, 4], dtype='float')

        self.input = {'X': X, }
        self.output = {
            'Out':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}
        self.orig2prim_args = (X, )
        self.all_ops = ['exp', 'exp_p']
        self.out_map = {0: self.output['Out']}


class TestReshape2Orig2Prim(TestElementWiseAddOrig2Prim):

    def init_data(self):
...
...
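
Each case above asserts that one original op is rewritten into its *_p counterpart. A condensed sketch of the transform these tests exercise (this assumes orig2prim accepts a block, mirroring the prim2orig(mp.block(0)) call in test_primapi.py; shape and names are illustrative):

    import paddle
    from paddle.incubate.autograd.primx import orig2prim

    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data(name='x', shape=[3, 4], dtype='float32')
        y = paddle.sin(x)         # original 'sin' op
        orig2prim(main.block(0))  # rewrite it as the 'sin_p' primitive
    print([op.type for op in main.block(0).ops])  # expected to list 'sin_p' rather than 'sin'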
python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
...
...
@@ -164,6 +164,66 @@ class TestTanhPPrim2Orig(TestAddPPrim2Orig):
        self.out_map = {self.output['Y']: 0}


class TestSinPPrim2Orig(TestAddPPrim2Orig):

    def init_data(self):
        self.op_type = 'sin_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {'X': X, }
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}
        self.prim2orig_args = (X, )
        self.all_ops = ['sin_p', 'sin']
        self.out_map = {self.output['Y']: 0}


class TestCosPPrim2Orig(TestAddPPrim2Orig):

    def init_data(self):
        self.op_type = 'cos_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {'X': X, }
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}
        self.prim2orig_args = (X, )
        self.all_ops = ['cos_p', 'cos']
        self.out_map = {self.output['Y']: 0}


class TestExpPPrim2Orig(TestAddPPrim2Orig):

    def init_data(self):
        self.op_type = 'exp_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {'X': X, }
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}
        self.prim2orig_args = (X, )
        self.all_ops = ['exp_p', 'exp']
        self.out_map = {self.output['Y']: 0}


class TestReshapePPrim2Orig(TestAddPPrim2Orig):

    def init_data(self):
...
...
python/paddle/fluid/tests/unittests/autograd/test_primapi.py
...
...
@@ -17,7 +17,6 @@ import unittest
import numpy as np
import paddle
from paddle.incubate.autograd import primapi
import config
import utils
...
...
@@ -135,19 +134,19 @@ class TestWithoutProgramGuard(unittest.TestCase):
@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
    (('matmul', paddle.matmul,
      (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
     ('multiply', paddle.multiply,
      (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
     ('add', paddle.add,
      (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
     ('input_not_sequence', paddle.tanh,
      (np.random.rand(5, 5), ), None, 'float64'),
     ('input_gradients_not_none', paddle.matmul,
      (np.random.rand(3, 3), np.random.rand(3, 3)),
      (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64')))
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
    (
        ('matmul', paddle.matmul,
         (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
        ('multiply', paddle.multiply,
         (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
        ('add', paddle.add,
         (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
        ('input_not_sequence', paddle.tanh,
         (np.random.rand(5, 5), ), None, 'float64'),
        ('input_gradients_not_none', paddle.matmul,
         (np.random.rand(3, 3), np.random.rand(3, 3)),
         (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'),
    ))
class TestForwardGrad(unittest.TestCase):

    @classmethod
...
...
@@ -219,7 +218,8 @@ class TestForwardGrad(unittest.TestCase):
                self.xs, self.v, stop_gradient=False)
            ys = self.fun(*static_xs) if isinstance(
                static_xs, typing.Sequence) else self.fun(static_xs)
            ys_grad = primapi.forward_grad(ys, static_xs, static_v)
            ys_grad = paddle.incubate.autograd.forward_grad(
                ys, static_xs, static_v)
            paddle.incubate.autograd.prim2orig(mp.block(0))
        exe = paddle.static.Executor()
        exe.run(sp)
...
...
@@ -229,15 +229,144 @@ class TestForwardGrad(unittest.TestCase):
    def test_illegal_param(self):
        paddle.incubate.autograd.enable_prim()
        with self.assertRaises(TypeError):
            primapi.forward_grad(1, paddle.static.data('inputs', shape=[1]))
            paddle.incubate.autograd.forward_grad(
                1, paddle.static.data('inputs', shape=[1]))

        with self.assertRaises(TypeError):
            primapi.forward_grad(paddle.static.data('targets', shape=[1]), 1)
            paddle.incubate.autograd.forward_grad(
                paddle.static.data('targets', shape=[1]), 1)
        paddle.incubate.autograd.disable_prim()


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
    (
        ('matmul', paddle.matmul,
         (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
        ('multiply', paddle.multiply,
         (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
        ('add', paddle.add,
         (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
        ('input_not_sequence', paddle.tanh,
         (np.random.rand(5, 5), ), None, 'float64'),
        ('input_gradients_not_none', paddle.matmul,
         (np.random.rand(3, 3), np.random.rand(3, 3)),
         (np.random.rand(3, 3), ), 'float64'),
        ('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'),
        ('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'),
        ('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'),
    ))
class TestGrad(unittest.TestCase):

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()

    def tearDown(self):
        paddle.incubate.autograd.disable_prim()
        paddle.disable_static()

    @classmethod
    def setUpClass(cls):
        cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs)
        cls._rtol = config.TOLERANCE.get(str(
            cls.dtype)).get("first_order_grad").get("rtol")
        cls._atol = config.TOLERANCE.get(str(
            cls.dtype)).get("first_order_grad").get("atol")

    def test_grad(self):

        def expected():
            paddle.incubate.autograd.disable_prim()
            sp = paddle.static.Program()
            mp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                feed, static_xs, static_v = utils.gen_static_data_and_feed(
                    self.xs, self.v, stop_gradient=False)
                _, ys_grad = paddle.incubate.autograd.vjp(
                    self.fun, static_xs, static_v)
            exe = paddle.static.Executor()
            exe.run(sp)
            out = exe.run(mp, feed=feed, fetch_list=ys_grad)
            paddle.incubate.autograd.enable_prim()
            return out

        def actual():
            paddle.incubate.autograd.enable_prim()
            sp = paddle.static.Program()
            mp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                feed, static_xs, static_v = utils.gen_static_data_and_feed(
                    self.xs, self.v, stop_gradient=False)
                ys = self.fun(*static_xs) if isinstance(
                    static_xs, typing.Sequence) else self.fun(static_xs)
                ys_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v)
                paddle.incubate.autograd.prim2orig(mp.block(0))
            exe = paddle.static.Executor()
            exe.run(sp)
            out = exe.run(mp, feed=feed, fetch_list=ys_grad)
            paddle.incubate.autograd.disable_prim()
            return out

        actual = actual()
        expected = expected()
        self.assertEqual(type(actual), type(expected))
        for i, j in zip(actual, expected):
            np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)

    def test_illegal_param(self):
        paddle.incubate.autograd.enable_prim()
        with self.assertRaises(TypeError):
            paddle.incubate.autograd.grad(
                1, paddle.static.data('inputs', shape=[1]))

        with self.assertRaises(TypeError):
            paddle.incubate.autograd.grad(
                paddle.static.data('targets', shape=[1]), 1)
        paddle.incubate.autograd.disable_prim()

    def test_disable_prim(self):

        def expected():
            paddle.incubate.autograd.disable_prim()
            sp = paddle.static.Program()
            mp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                feed, static_xs, static_v = utils.gen_static_data_and_feed(
                    self.xs, self.v, stop_gradient=False)
                ys = self.fun(*static_xs) if isinstance(
                    static_xs, typing.Sequence) else self.fun(static_xs)
                ys_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v)
            exe = paddle.static.Executor()
            exe.run(sp)
            out = exe.run(mp, feed=feed, fetch_list=ys_grad)
            paddle.incubate.autograd.enable_prim()
            return out

        def actual():
            paddle.incubate.autograd.disable_prim()
            sp = paddle.static.Program()
            mp = paddle.static.Program()
            with paddle.static.program_guard(mp, sp):
                feed, static_xs, static_v = utils.gen_static_data_and_feed(
                    self.xs, self.v, stop_gradient=False)
                ys = self.fun(*static_xs) if isinstance(
                    static_xs, typing.Sequence) else self.fun(static_xs)
                ys_grad = paddle.static.gradients(ys, static_xs, static_v)
            exe = paddle.static.Executor()
            exe.run(sp)
            out = exe.run(mp, feed=feed, fetch_list=ys_grad)
            paddle.incubate.autograd.enable_prim()
            return out

        actual = actual()
        expected = expected()
        self.assertEqual(type(actual), type(expected))
        for i, j in zip(actual, expected):
            np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)


class TestGradWithHigherOrder(unittest.TestCase):

    def setUp(self):
        paddle.enable_static()
        paddle.incubate.autograd.enable_prim()
...
...
@@ -346,44 +475,6 @@ class TestGrad(unittest.TestCase):
            np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5)
        paddle.incubate.autograd.disable_prim()

    def test_disable_prim(self):

        def actual(x: np.array):
            paddle.incubate.autograd.disable_prim()
            main = paddle.static.Program()
            startup = paddle.static.Program()
            with paddle.static.program_guard(main, startup):
                var_x = paddle.static.data('x', shape=x.shape, dtype=x.dtype)
                var_x.stop_gradient = False
                y = paddle.tanh(var_x)
                y_grad = paddle.incubate.autograd.grad(y, var_x)
                y_second_grad = paddle.incubate.autograd.grad(y_grad, var_x)
            exe = paddle.static.Executor()
            exe.run(startup)
            return exe.run(main,
                           feed={'x': x},
                           fetch_list=[y_grad, y_second_grad])

        def expect(x: np.array):
            paddle.incubate.autograd.disable_prim()
            main = paddle.static.Program()
            startup = paddle.static.Program()
            with paddle.static.program_guard(main, startup):
                var_x = paddle.static.data('x', shape=x.shape, dtype=x.dtype)
                var_x.stop_gradient = False
                y = paddle.tanh(var_x)
                y_grad = paddle.static.gradients(y, var_x)
                y_second_grad = paddle.static.gradients(y_grad, var_x)
            exe = paddle.static.Executor()
            exe.run(startup)
            return exe.run(main,
                           feed={'x': x},
                           fetch_list=[y_grad, y_second_grad])

        x = np.random.randn(100, 200)
        for i, j in zip(actual(x), expect(x)):
            np.testing.assert_allclose(i, j)


if __name__ == '__main__':
    unittest.main()
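
The new 'sin', 'cos', and 'exp' cases make TestGrad exercise the whole pipeline (orig2prim, the JVP rules, prim2orig) end to end against paddle.incubate.autograd.vjp. A condensed sketch of the same flow for a single function, checked against the analytic derivative (this compresses the test's expected/actual pair into one direct check and assumes the initial output gradient defaults to all ones, as with paddle.static.gradients):

    import numpy as np
    import paddle

    paddle.enable_static()
    paddle.incubate.autograd.enable_prim()

    sp, mp = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(mp, sp):
        x = paddle.static.data('x', shape=[100, 200], dtype='float32')
        x.stop_gradient = False
        y = paddle.sin(x)
        x_grad = paddle.incubate.autograd.grad(y, x)     # built from primitive ops
        paddle.incubate.autograd.prim2orig(mp.block(0))  # lower primitives before running

    x_np = np.random.rand(100, 200).astype('float32')
    exe = paddle.static.Executor()
    exe.run(sp)
    out, = exe.run(mp, feed={'x': x_np}, fetch_list=[x_grad])
    np.testing.assert_allclose(out, np.cos(x_np), rtol=1e-5, atol=1e-5)  # d(sin x)/dx = cos x
    paddle.incubate.autograd.disable_prim()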
python/paddle/fluid/tests/unittests/autograd/test_primops.py
...
...
@@ -11,154 +11,128 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import uuid

import numpy as np
import paddle
from paddle.incubate.autograd.primops import (
    neg, set_value, add, sub, mul, div, sqrt, tanh, reshape, broadcast,
    transpose, split, concat, reduce, matmul, slice_select, slice_assign,
    gather, scatter_add, fill_const)
from paddle.incubate.autograd.primx import Transform, topo_path, orig2prim, prim2orig
from paddle.incubate.autograd.utils import enable_prim, disable_prim, prim_enabled


class TestPyPrimOps(unittest.TestCase):
    """ Test Python wrappers of primitive ops. """

    def setUp(self):
from numpy.random import randint, randn
from paddle.incubate.autograd import primops, primx
from paddle.incubate.autograd import utils as prim_utils

import config
import utils

        paddle.enable_static()


@utils.place(config.DEVICES)
@utils.parameterize(
    (utils.TEST_CASE_NAME, 'op', 'args', 'kwargs', 'expected_shape',
     'expected_dtype'),
    (
        ('add', primops.add, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
        ('sub', primops.sub, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
        ('mul', primops.mul, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
        ('div', primops.div, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
        ('sub', primops.sub, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
        ('sqrt', primops.sqrt, randn(2, 3), {}, (2, 3), 'float64'),
        ('tanh', primops.tanh, randn(2, 3), {}, (2, 3), 'float64'),
        ('sin', primops.sin, randn(2, 3), {}, (2, 3), 'float64'),
        ('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'),
        ('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'),
        ('reshape', primops.reshape, randn(2, 3), {'shape': (3, 2)}, (3, 2), 'float64'),
        ('broadcast', primops.broadcast, randn(2), {'shape': (3, 2)}, (3, 2), 'float64'),
        ('transpose', primops.transpose, randn(2, 3), {'axis': (1, 0)}, (3, 2), 'float64'),
        ('concat_axis0', primops.concat, ((randn(2, 3), randn(2, 3)), ), {'axis': 0}, (4, 3), 'float64'),
        ('concat_axis1', primops.concat, ((randn(2, 3), randn(2, 3)), ), {'axis': 1}, (2, 6), 'float64'),
        ('reduce_axis1', primops.reduce, randn(2, 3), {'axis': (1, )}, (2, ), 'float64'),
        ('reduce_axis01', primops.reduce, randn(2, 3), {'axis': (0, 1)}, (1, ), 'float64'),
        ('split', primops.split, randn(2, 3), {'num_or_sections': [1, 2], 'axis': 1}, ((2, 1), (2, 2)), ('float64', 'float64')),
        ('matmul', primops.matmul, (randn(2, 3), randn(3, 2)), {}, (2, 2), 'float64'),
        ('slice_select', primops.slice_select, randn(3, 2), {'axis': [0], 'starts': [0], 'ends': [2], 'strides': [1]}, (2, 2), 'float64'),
        ('slice_assign', primops.slice_assign, (randn(2, 3), randn(2, 2)), {'axis': [1], 'starts': [1], 'ends': [3], 'strides': [1]}, (2, 3), 'float64'),
        ('gather', primops.gather, (randn(3, 2), randint(0, 2, (5, ), np.int32)), {'axis': 0}, (5, 2), 'float64'),
        ('scatter_add', primops.scatter_add, (randn(3, 2), randn(5, 2), randint(0, 2, (5, ), np.int32)), {'axis': 0}, (3, 2), 'float64'),
        ('fill_const', primops.fill_const, (), {'value': 10, 'shape': (3, 2), 'dtype': paddle.float32}, (3, 2), 'float32'),
        ('neg', primops.neg, randn(2, 3), {}, (2, 3), 'float64'),
    ))
class TestPrimops(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        paddle.enable_static()

    def test_ops(self):
        A = np.random.rand(1)
        B = np.random.rand(2)
        C = np.random.rand(2, 3)
        D = np.random.rand(2, 3)
        E = np.random.rand(3, 2)

        a = paddle.static.data(name='A', shape=A.shape, dtype='float32')
        b = paddle.static.data(name='B', shape=B.shape, dtype='float32')
        c = paddle.static.data(name='C', shape=C.shape, dtype='float32')
        d = paddle.static.data(name='D', shape=D.shape, dtype='float32')
        e = paddle.static.data(name='E', shape=E.shape, dtype='float32')

        add_1 = add(a, a)
        self.assertEqual(add_1.dtype, a.dtype)
        self.assertEqual(add_1.shape, a.shape)

        add_2 = add(c, d)
        self.assertEqual(add_2.dtype, c.dtype)
        self.assertEqual(add_2.shape, c.shape)

        sub_1 = sub(c, d)
        self.assertEqual(sub_1.dtype, c.dtype)
        self.assertEqual(sub_1.shape, c.shape)

        mul_1 = mul(c, d)
        self.assertEqual(mul_1.dtype, c.dtype)
        self.assertEqual(mul_1.shape, c.shape)

        div_1 = div(c, d)
        self.assertEqual(div_1.dtype, c.dtype)
        self.assertEqual(div_1.shape, c.shape)

        sqrt_1 = sqrt(b)
        self.assertEqual(sqrt_1.dtype, b.dtype)
        self.assertEqual(sqrt_1.shape, b.shape)

        tanh_1 = tanh(d)
        self.assertEqual(tanh_1.dtype, d.dtype)
        self.assertEqual(tanh_1.shape, d.shape)

        reshape_1 = reshape(c, d.shape)
        self.assertEqual(reshape_1.dtype, c.dtype)
        self.assertEqual(reshape_1.shape, d.shape)

        broadcast_1 = broadcast(b, e.shape)
        self.assertEqual(broadcast_1.dtype, b.dtype)
        self.assertEqual(broadcast_1.shape, e.shape)

        transpose_1 = transpose(c, axis=[1, 0])
        self.assertEqual(transpose_1.dtype, c.dtype)
        self.assertEqual(transpose_1.shape, e.shape)

        split_1_0, split_1_1 = split(c, num_or_sections=[1, 2], axis=1)
        self.assertEqual(split_1_0.dtype, c.dtype)
        self.assertEqual(split_1_0.shape, (2, 1))
        self.assertEqual(split_1_1.shape, (2, 2))

        concat_1 = concat([c, d], axis=0)
        self.assertEqual(concat_1.dtype, c.dtype)
        self.assertEqual(concat_1.shape, (4, 3))

        reduce_1 = reduce(d, axis=[1])
        self.assertEqual(reduce_1.dtype, d.dtype)
        self.assertEqual(reduce_1.shape, (2, ))

        reduce_2 = reduce(c, axis=[0, 1])
        self.assertEqual(reduce_2.dtype, c.dtype)
        self.assertEqual(reduce_2.shape, (1, ))
        # TODO: reduce + keepdim

        matmul_1 = matmul(d, e)
        self.assertEqual(matmul_1.dtype, d.dtype)
        self.assertEqual(matmul_1.shape, (2, 2))

        slice_select_1 = slice_select(e, axis=[0], starts=[0], ends=[2],
                                      strides=[1])
        self.assertEqual(slice_select_1.dtype, e.dtype)
        self.assertEqual(slice_select_1.shape, (2, 2))

        slice_select_2 = slice_select(d, axis=[0, 1], starts=[0, 1],
                                      ends=[2, 3], strides=[1, 2])
        self.assertEqual(slice_select_2.dtype, d.dtype)
        self.assertEqual(slice_select_2.shape, (2, 1))

        y = broadcast(b, [2, 2])
        slice_assign_1 = slice_assign(d, y, axis=[1], starts=[1], ends=[3],
                                      strides=[1])
        self.assertEqual(slice_assign_1.dtype, d.dtype)
        self.assertEqual(slice_assign_1.shape, d.shape)

        index = paddle.static.data('index', shape=[5], dtype='int32')
        gather_1 = gather(e, index, axis=0)
        self.assertEqual(gather_1.dtype, e.dtype)
        self.assertEqual(gather_1.shape, (5, 2))

        y = paddle.rand([5, 2], dtype='float32')
        scatter_add_1 = scatter_add(e, y, index, axis=0)
        self.assertEqual(scatter_add_1.dtype, e.dtype)
        self.assertEqual(scatter_add_1.shape, e.shape)

        fill_const_1 = fill_const(value=10, shape=a.shape, dtype=a.dtype)
        self.assertEqual(fill_const_1.shape, a.shape)
        self.assertEqual(fill_const_1.dtype, a.dtype)

        neg_1 = neg(x=b)
        self.assertEqual(neg_1.shape, b.shape)
        self.assertEqual(neg_1.dtype, b.dtype)

        set_value_1 = set_value(d, a, axis=[1], starts=[1], ends=[3],
                                strides=[1], out=d)
        self.assertEqual(set_value_1.shape, d.shape)
        self.assertEqual(set_value_1.dtype, d.dtype)

    @classmethod
    def tearDownClass(cls):
        paddle.disable_static()

    def test_prim_ops(self):
        program = paddle.static.Program()
        with paddle.static.program_guard(program):
            args = self._as_tuple(self.args)
            args = self.arr2var(args)
            results = self.op(*args, **self.kwargs)
            results = self._as_tuple(results)
            expected_shape = self._as_tuple(self.expected_shape)
            expected_dtype = self._as_tuple(self.expected_dtype)

            for r, shape, dtype in zip(results, expected_shape,
                                       expected_dtype):
                self.assertEqual(r.shape, shape)
                self.assertEqual(str(r.dtype).split('.')[1], dtype)

    def arr2var(self, arr):
        """convert numpy ndarray to paddle Variable recursively."""
        return [
            paddle.static.data(f'x{uuid.uuid4()}', v.shape, v.dtype)
            if isinstance(v, np.ndarray) else self.arr2var(v) for v in arr
        ]

    def _as_tuple(self, input):
        if isinstance(input, (tuple, list)) and len(input) == 0:
            return input
        if not isinstance(input, (tuple, list)) or all(
                isinstance(i, int) for i in input):
            return (input, )
        return input


if __name__ == '__main__':
...
...
python/paddle/incubate/autograd/primops.py
...
...
@@ -122,6 +122,21 @@ def tanh(x, out=None):
    return _simple_unop(LayerHelper('tanh_p', **locals()))


@REGISTER_FN('sin_p', 'X', 'Y')
def sin(x, out=None):
    return _simple_unop(LayerHelper('sin_p', **locals()))


@REGISTER_FN('cos_p', 'X', 'Y')
def cos(x, out=None):
    return _simple_unop(LayerHelper('cos_p', **locals()))


@REGISTER_FN('exp_p', 'X', 'Y')
def exp(x, out=None):
    return _simple_unop(LayerHelper('exp_p', **locals()))


@REGISTER_FN('reshape_p', 'X', 'Y')
def reshape(x, shape, out=None):
    return _manipulation_unop(LayerHelper('reshape_p', **locals()))
...
...
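
Each new wrapper is a one-liner over _simple_unop, so a call appends exactly one primitive op to the current block. A small sketch (static mode; the commented op list is the expected outcome, not output captured from the diff):

    import paddle
    from paddle.incubate.autograd import primops

    paddle.enable_static()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data(name='x', shape=[2, 3], dtype='float64')
        y = primops.sin(x)  # one sin_p op
        z = primops.exp(y)  # one exp_p op
    print([op.type for op in main.block(0).ops])  # expected: ['sin_p', 'exp_p']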
python/paddle/incubate/autograd/primrules.py
...
...
@@ -15,13 +15,15 @@ import typing
import paddle

from .primreg import REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, REGISTER_JVP, REGISTER_TRANSPOSE
from .primreg import (lookup_fn, lookup_orig2prim, lookup_prim2orig, lookup_jvp,
                      lookup_transpose, op_position_inputs, op_position_output)
from .primops import (neg, add, sub, mul, div, sqrt, tanh, reshape, broadcast,
                      transpose, split, concat, reduce, matmul, slice_select,
                      slice_assign, gather, scatter_add, fill_const, set_value)
from .utils import get_input_var_list, get_output_var_list, INT_DTYPE_2_STRING
from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather,
                      matmul, mul, neg, reduce, reshape, scatter_add, set_value,
                      sin, slice_assign, slice_select, split, sqrt, sub, tanh,
                      transpose)
from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
                      REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
                      lookup_orig2prim, lookup_prim2orig, lookup_transpose,
                      op_position_inputs, op_position_output)
from .utils import INT_DTYPE_2_STRING, get_input_var_list, get_output_var_list


def _orig2prim(op, *args):
...
...
@@ -149,6 +151,21 @@ def tanh_orig2prim(op, x):
    return tanh(x)


@REGISTER_ORIG2PRIM('sin')
def sin_orig2prim(op, x):
    return sin(x)


@REGISTER_ORIG2PRIM('cos')
def cos_orig2prim(op, x):
    return cos(x)


@REGISTER_ORIG2PRIM('exp')
def exp_orig2prim(op, x):
    return exp(x)


@REGISTER_ORIG2PRIM('fill_zeros_like')
def fill_zeros_like_orig2prim(op, x):
    return fill_const(value=0.0, shape=x.shape, dtype=x.dtype)
...
...
@@ -301,6 +318,21 @@ def tanh_prim2orig(op, x):
    return paddle.tanh(x)


@REGISTER_PRIM2ORIG('sin_p')
def sin_prim2orig(op, x):
    return paddle.sin(x)


@REGISTER_PRIM2ORIG('cos_p')
def cos_prim2orig(op, x):
    return paddle.cos(x)


@REGISTER_PRIM2ORIG('exp_p')
def exp_prim2orig(op, x):
    return paddle.exp(x)


@REGISTER_PRIM2ORIG('reshape_p')
def reshape_prim2orig(op, x):
    return paddle.reshape(x, shape=op.attr('shape'))
...
...
@@ -453,6 +485,30 @@ def tanh_jvp(op, x_dot):
    return y_dot


@REGISTER_JVP('sin_p')
def sin_jvp(op, x_dot):
    if x_dot is None:
        return None
    x, = op_position_inputs(op)
    return mul(x_dot, cos(x))


@REGISTER_JVP('cos_p')
def cos_jvp(op, x_dot):
    if x_dot is None:
        return None
    x, = op_position_inputs(op)
    return mul(x_dot, neg(sin(x)))


@REGISTER_JVP('exp_p')
def exp_jvp(op, x_dot):
    if x_dot is None:
        return None
    y = op_position_output(op)
    return mul(x_dot, y)


@REGISTER_JVP('reshape_p')
def reshape_jvp(op, x_dot):
    if x_dot is None:
...
...
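
Taken together, the rules this file registers for the new primitives are:

    primitive   orig2prim        prim2orig             JVP rule
    sin_p       sin  -> sin_p    sin_p -> paddle.sin   mul(x_dot, cos(x))
    cos_p       cos  -> cos_p    cos_p -> paddle.cos   mul(x_dot, neg(sin(x)))
    exp_p       exp  -> exp_p    exp_p -> paddle.exp   mul(x_dot, y), where y is the op's own output

No transpose rules are added for these three ops, which matches the empty '# transpose op:' sections in test_jvp_and_transpose.py.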