Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6ef1fbb6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6ef1fbb6
编写于
9月 18, 2020
作者:
Y
yaoxuefeng6
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into mod_dataset_v2
上级
2e2074ef
fef94eac
变更
39
隐藏空白更改
内联
并排
Showing
39 changed file
with
875 addition
and
208 deletion
+875
-208
paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
...uid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
+6
-0
paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc
...mework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc
+8
-1
paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
...e/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
+19
-0
paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
.../framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
+7
-0
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
...mework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
+6
-0
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
...ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
+7
-0
paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc
...e/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc
+5
-0
paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc
.../framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc
+8
-0
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
+11
-0
paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc
...e/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc
+7
-0
paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc
paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc
+6
-0
paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc
paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc
+7
-0
paddle/fluid/operators/average_accumulates_op.h
paddle/fluid/operators/average_accumulates_op.h
+7
-3
paddle/fluid/operators/empty_op.cc
paddle/fluid/operators/empty_op.cc
+20
-13
paddle/fluid/operators/math/beam_search.cc
paddle/fluid/operators/math/beam_search.cc
+4
-1
paddle/fluid/operators/math/beam_search.cu
paddle/fluid/operators/math/beam_search.cu
+4
-1
paddle/fluid/operators/math/blas.cc
paddle/fluid/operators/math/blas.cc
+5
-1
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+18
-8
paddle/fluid/operators/math/blas_impl.h
paddle/fluid/operators/math/blas_impl.h
+115
-33
paddle/fluid/operators/shape_op.cc
paddle/fluid/operators/shape_op.cc
+1
-1
paddle/fluid/operators/shape_op.cu
paddle/fluid/operators/shape_op.cu
+2
-2
python/paddle/__init__.py
python/paddle/__init__.py
+1
-0
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
.../paddle/fluid/contrib/slim/quantization/imperative/qat.py
+7
-76
python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
...on/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
+20
-16
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
...paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+1
-1
python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
...n/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
+3
-2
python/paddle/fluid/dygraph/dygraph_to_static/logging_utils.py
...n/paddle/fluid/dygraph/dygraph_to_static/logging_utils.py
+72
-16
python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
...paddle/fluid/dygraph/dygraph_to_static/partial_program.py
+2
-6
python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py
...ddle/fluid/dygraph/dygraph_to_static/print_transformer.py
+1
-7
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
...dle/fluid/dygraph/dygraph_to_static/program_translator.py
+10
-9
python/paddle/fluid/dygraph/jit.py
python/paddle/fluid/dygraph/jit.py
+3
-2
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+3
-3
python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py
...d/tests/unittests/dygraph_to_static/test_logging_utils.py
+45
-4
python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
...unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
+171
-0
python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
...n/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
+1
-1
python/paddle/fluid/tests/unittests/test_empty_like_op.py
python/paddle/fluid/tests/unittests/test_empty_like_op.py
+192
-0
python/paddle/nn/functional/loss.py
python/paddle/nn/functional/loss.py
+1
-1
python/paddle/tensor/__init__.py
python/paddle/tensor/__init__.py
+1
-0
python/paddle/tensor/creation.py
python/paddle/tensor/creation.py
+68
-0
未找到文件。
paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
浏览文件 @
6ef1fbb6
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -334,3 +335,8 @@ void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
...
@@ -334,3 +335,8 @@ void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
REGISTER_PASS
(
embedding_eltwise_layernorm_fuse_pass
,
REGISTER_PASS
(
embedding_eltwise_layernorm_fuse_pass
,
paddle
::
framework
::
ir
::
EmbeddingEltwiseLayerNormFusePass
);
paddle
::
framework
::
ir
::
EmbeddingEltwiseLayerNormFusePass
);
REGISTER_PASS_CAPABILITY
(
embedding_eltwise_layernorm_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"lookup_table"
,
0
)
.
EQ
(
"elementweise_add"
,
0
));
paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc
浏览文件 @
6ef1fbb6
...
@@ -16,12 +16,13 @@ limitations under the License. */
...
@@ -16,12 +16,13 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
TEST
(
SkipLayerN
ormFusePass
,
basic
)
{
TEST
(
EmbeddingElewiseLayern
ormFusePass
,
basic
)
{
// inputs operator output
// inputs operator output
// --------------------------------------------------------------------
// --------------------------------------------------------------------
// (x, y) elementwise_add -> elementwise_out
// (x, y) elementwise_add -> elementwise_out
...
@@ -91,6 +92,12 @@ TEST(SkipLayerNormFusePass, basic) {
...
@@ -91,6 +92,12 @@ TEST(SkipLayerNormFusePass, basic) {
"The number of fusion nodes does not meet expectations after fuse"
));
"The number of fusion nodes does not meet expectations after fuse"
));
}
}
TEST
(
EmbeddingElewiseLayernormFusePass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"embedding_eltwise_layernorm_fuse_pass"
));
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
浏览文件 @
6ef1fbb6
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -84,6 +85,19 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
...
@@ -84,6 +85,19 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
VLOG
(
3
)
<<
"do not perform "
+
type
()
+
"+bias fuse"
;
VLOG
(
3
)
<<
"do not perform "
+
type
()
+
"+bias fuse"
;
return
;
return
;
}
}
if
(
conv
->
Op
()
->
HasAttr
(
"dilations"
))
{
auto
dilations
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
conv
->
Op
()
->
GetAttr
(
"dilations"
));
for
(
const
auto
&
d
:
dilations
)
{
if
(
d
!=
1
)
{
LOG
(
WARNING
)
<<
"dilation conv not supported in MKLDNN, fuse not apply "
<<
"and set conv attribute use_mkldnn = false"
;
conv
->
Op
()
->
SetAttr
(
"use_mkldnn"
,
false
);
return
;
}
}
}
auto
*
eltwise_bias_tensor
=
auto
*
eltwise_bias_tensor
=
scope
->
FindVar
(
eltwise_bias
->
Name
())
->
GetMutable
<
LoDTensor
>
();
scope
->
FindVar
(
eltwise_bias
->
Name
())
->
GetMutable
<
LoDTensor
>
();
...
@@ -151,3 +165,8 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
...
@@ -151,3 +165,8 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
paddle
::
framework
::
ir
::
Conv2DTransposeBiasFusePass
);
paddle
::
framework
::
ir
::
Conv2DTransposeBiasFusePass
);
REGISTER_PASS
(
conv3d_bias_mkldnn_fuse_pass
,
REGISTER_PASS
(
conv3d_bias_mkldnn_fuse_pass
,
paddle
::
framework
::
ir
::
Conv3DBiasFusePass
);
paddle
::
framework
::
ir
::
Conv3DBiasFusePass
);
REGISTER_PASS_CAPABILITY
(
conv_bias_mkldnn_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"conv2d"
,
0
)
.
EQ
(
"elementwise_add"
,
0
));
paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
浏览文件 @
6ef1fbb6
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -149,6 +150,12 @@ TEST(ConvBiasFusePass, conv2d_transpose) {
...
@@ -149,6 +150,12 @@ TEST(ConvBiasFusePass, conv2d_transpose) {
ASSERT_EQ
(
pass
.
type
(),
std
::
string
(
"conv2d_transpose"
));
ASSERT_EQ
(
pass
.
type
(),
std
::
string
(
"conv2d_transpose"
));
}
}
TEST
(
ConvBiasFusePass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"conv_bias_mkldnn_fuse_pass"
));
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
浏览文件 @
6ef1fbb6
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <memory>
#include <memory>
#include <tuple>
#include <tuple>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -341,3 +342,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
...
@@ -341,3 +342,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
REGISTER_PASS
(
conv_elementwise_add_mkldnn_fuse_pass
,
REGISTER_PASS
(
conv_elementwise_add_mkldnn_fuse_pass
,
paddle
::
framework
::
ir
::
ResidualConnectionMKLDNNFusePass
);
paddle
::
framework
::
ir
::
ResidualConnectionMKLDNNFusePass
);
REGISTER_PASS_CAPABILITY
(
conv_elementwise_add_mkldnn_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"conv2d"
,
0
)
.
EQ
(
"elementwise_add"
,
0
));
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
浏览文件 @
6ef1fbb6
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -267,6 +268,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
...
@@ -267,6 +268,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
AssertOpsCount
(
graph
,
2
,
1
);
AssertOpsCount
(
graph
,
2
,
1
);
}
}
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"conv_elementwise_add_mkldnn_fuse_pass"
));
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc
浏览文件 @
6ef1fbb6
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -57,3 +58,7 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -57,3 +58,7 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS
(
depthwise_conv_mkldnn_pass
,
REGISTER_PASS
(
depthwise_conv_mkldnn_pass
,
paddle
::
framework
::
ir
::
DepthwiseConvMKLDNNPass
);
paddle
::
framework
::
ir
::
DepthwiseConvMKLDNNPass
);
REGISTER_PASS_CAPABILITY
(
depthwise_conv_mkldnn_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
().
EQ
(
"depthwise_conv2d"
,
0
));
paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc
浏览文件 @
6ef1fbb6
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
...
@@ -70,6 +72,12 @@ ProgramDesc BuildProgramDesc() {
...
@@ -70,6 +72,12 @@ ProgramDesc BuildProgramDesc() {
return
prog
;
return
prog
;
}
}
TEST
(
DepthwiseConvMKLDNNPass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"depthwise_conv_mkldnn_pass"
));
}
TEST
(
DepthwiseConvMKLDNNPass
,
basic
)
{
TEST
(
DepthwiseConvMKLDNNPass
,
basic
)
{
auto
prog
=
BuildProgramDesc
();
auto
prog
=
BuildProgramDesc
();
...
...
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
浏览文件 @
6ef1fbb6
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -707,3 +708,13 @@ REGISTER_PASS(multihead_matmul_fuse_pass,
...
@@ -707,3 +708,13 @@ REGISTER_PASS(multihead_matmul_fuse_pass,
REGISTER_PASS
(
multihead_matmul_fuse_pass_v2
,
REGISTER_PASS
(
multihead_matmul_fuse_pass_v2
,
paddle
::
framework
::
ir
::
MultiHeadMatmulV2FusePass
);
paddle
::
framework
::
ir
::
MultiHeadMatmulV2FusePass
);
REGISTER_PASS_CAPABILITY
(
multihead_matmul_fuse_pass_v2
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"mul"
,
0
)
.
EQ
(
"elementwise_add"
,
0
)
.
EQ
(
"reshape2"
,
0
)
.
EQ
(
"transpose2"
,
0
)
.
EQ
(
"scale"
,
0
)
.
EQ
(
"matmul"
,
0
)
.
EQ
(
"softmax"
,
0
));
paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc
浏览文件 @
6ef1fbb6
...
@@ -12,6 +12,7 @@ limitations under the License. */
...
@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT
#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -133,6 +134,12 @@ TEST(MultiHeadMatmulFusePass, basic) {
...
@@ -133,6 +134,12 @@ TEST(MultiHeadMatmulFusePass, basic) {
num_fused_nodes_after
));
num_fused_nodes_after
));
}
}
TEST
(
MultiHeadMatmulFusePass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"multihead_matmul_fuse_pass_v2"
));
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc
浏览文件 @
6ef1fbb6
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -180,3 +181,8 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
...
@@ -180,3 +181,8 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
REGISTER_PASS
(
skip_layernorm_fuse_pass
,
REGISTER_PASS
(
skip_layernorm_fuse_pass
,
paddle
::
framework
::
ir
::
SkipLayerNormFusePass
);
paddle
::
framework
::
ir
::
SkipLayerNormFusePass
);
REGISTER_PASS_CAPABILITY
(
skip_layernorm_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"elementwise_add"
,
0
)
.
EQ
(
"layer_norm"
,
0
));
paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc
浏览文件 @
6ef1fbb6
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -54,6 +55,12 @@ TEST(SkipLayerNormFusePass, basic) {
...
@@ -54,6 +55,12 @@ TEST(SkipLayerNormFusePass, basic) {
"The number of fusion nodes does not meet expectations after fuse"
));
"The number of fusion nodes does not meet expectations after fuse"
));
}
}
TEST
(
SkipLayerNormFusePass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"skip_layernorm_fuse_pass"
));
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/operators/average_accumulates_op.h
浏览文件 @
6ef1fbb6
...
@@ -54,9 +54,13 @@ class AverageAccumulatesKernel : public framework::OpKernel<T> {
...
@@ -54,9 +54,13 @@ class AverageAccumulatesKernel : public framework::OpKernel<T> {
float
average_window
=
ctx
.
Attr
<
float
>
(
"average_window"
);
float
average_window
=
ctx
.
Attr
<
float
>
(
"average_window"
);
int64_t
max_average_window
=
ctx
.
Attr
<
int64_t
>
(
"max_average_window"
);
int64_t
max_average_window
=
ctx
.
Attr
<
int64_t
>
(
"max_average_window"
);
int64_t
min_average_window
=
ctx
.
Attr
<
int64_t
>
(
"min_average_window"
);
int64_t
min_average_window
=
ctx
.
Attr
<
int64_t
>
(
"min_average_window"
);
PADDLE_ENFORCE_LE
(
min_average_window
,
max_average_window
,
PADDLE_ENFORCE_LE
(
"min_average_window shouldn't be larger than "
min_average_window
,
max_average_window
,
"max_average_window"
);
platform
::
errors
::
InvalidArgument
(
"The min_average_window > "
"max_average_window is not right, min_average_window is %ld, "
"max_average_window is %ld."
,
min_average_window
,
max_average_window
));
// Get inputs
// Get inputs
auto
*
param
=
ctx
.
Input
<
Tensor
>
(
"param"
);
auto
*
param
=
ctx
.
Input
<
Tensor
>
(
"param"
);
...
...
paddle/fluid/operators/empty_op.cc
浏览文件 @
6ef1fbb6
...
@@ -55,31 +55,38 @@ class EmptyOp : public framework::OperatorWithKernel {
...
@@ -55,31 +55,38 @@ class EmptyOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"empty"
);
OP_INOUT_CHECK
(
context
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"empty"
);
if
(
context
->
HasInput
(
"ShapeTensor"
))
{
if
(
context
->
HasInput
(
"ShapeTensor"
))
{
auto
dims
=
context
->
GetInputDim
(
"ShapeTensor"
);
auto
shape_
dims
=
context
->
GetInputDim
(
"ShapeTensor"
);
int
num_ele
=
1
;
int
num_ele
=
1
;
for
(
int
i
=
0
;
i
<
dims
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
shape_
dims
.
size
();
++
i
)
{
num_ele
*=
dims
[
i
];
num_ele
*=
shape_
dims
[
i
];
}
}
auto
vec_dims
=
std
::
vector
<
int
>
(
num_ele
,
-
1
);
context
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
{
num_ele
}
));
context
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
vec_dims
));
}
else
if
(
context
->
HasInputs
(
"ShapeTensorList"
))
{
}
else
if
(
context
->
HasInputs
(
"ShapeTensorList"
))
{
std
::
vector
<
int
>
out_dims
;
std
::
vector
<
int
>
out_dims
;
auto
dims_list
=
context
->
GetInputsDim
(
"ShapeTensorList"
);
auto
dims_list
=
context
->
GetInputsDim
(
"ShapeTensorList"
);
for
(
size_t
i
=
0
;
i
<
dims_list
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dims_list
.
size
();
++
i
)
{
auto
&
dims
=
dims_list
[
i
];
auto
&
dims
=
dims_list
[
i
];
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
dims
,
framework
::
make_ddim
({
1
}),
dims
,
framework
::
make_ddim
({
1
}),
platform
::
errors
::
InvalidArgument
(
"ShapeError: The shape of Tensor in list must be [1]. "
"The shape of Tensor in list must be [1]. "
"But received the shape "
"But received the shape is [%s]"
,
"is [%s]"
,
dims
));
dims
);
out_dims
.
push_back
(
-
1
);
out_dims
.
push_back
(
dims
[
0
]);
}
}
context
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
out_dims
));
context
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
out_dims
));
}
else
{
}
else
{
auto
&
shape
=
context
->
Attrs
().
Get
<
std
::
vector
<
int64_t
>>
(
"shape"
);
auto
&
shape
=
context
->
Attrs
().
Get
<
std
::
vector
<
int64_t
>>
(
"shape"
);
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
PADDLE_ENFORCE_GE
(
shape
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
"Each value of attribute 'shape' is expected to be no less "
"than 0. But recieved: shape[%u] = %d; shape = [%s]."
,
i
,
shape
[
i
],
framework
::
make_ddim
(
shape
)));
}
context
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
shape
));
context
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
shape
));
}
}
}
}
...
...
paddle/fluid/operators/math/beam_search.cc
浏览文件 @
6ef1fbb6
...
@@ -87,7 +87,10 @@ class BeamSearchFunctor<platform::CPUDeviceContext, T> {
...
@@ -87,7 +87,10 @@ class BeamSearchFunctor<platform::CPUDeviceContext, T> {
lod
[
0
].
assign
(
high_level
.
begin
(),
high_level
.
end
());
lod
[
0
].
assign
(
high_level
.
begin
(),
high_level
.
end
());
lod
[
1
].
assign
(
low_level
.
begin
(),
low_level
.
end
());
lod
[
1
].
assign
(
low_level
.
begin
(),
low_level
.
end
());
if
(
!
framework
::
CheckLoD
(
lod
))
{
if
(
!
framework
::
CheckLoD
(
lod
))
{
PADDLE_THROW
(
"lod %s is not right"
,
framework
::
LoDToString
(
lod
));
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"lod %s is not right in"
" beam_search, please check your code."
,
framework
::
LoDToString
(
lod
)));
}
}
selected_ids
->
set_lod
(
lod
);
selected_ids
->
set_lod
(
lod
);
selected_scores
->
set_lod
(
lod
);
selected_scores
->
set_lod
(
lod
);
...
...
paddle/fluid/operators/math/beam_search.cu
浏览文件 @
6ef1fbb6
...
@@ -400,7 +400,10 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
...
@@ -400,7 +400,10 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
context
.
Wait
();
context
.
Wait
();
if
(
!
framework
::
CheckLoD
(
selected_lod
))
{
if
(
!
framework
::
CheckLoD
(
selected_lod
))
{
PADDLE_THROW
(
"lod %s is not right"
,
framework
::
LoDToString
(
selected_lod
));
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"lod %s is not right in"
" beam_search, please check your code."
,
framework
::
LoDToString
(
selected_lod
)));
}
}
selected_ids
->
set_lod
(
selected_lod
);
selected_ids
->
set_lod
(
selected_lod
);
...
...
paddle/fluid/operators/math/blas.cc
浏览文件 @
6ef1fbb6
...
@@ -20,7 +20,11 @@ namespace operators {
...
@@ -20,7 +20,11 @@ namespace operators {
namespace
math
{
namespace
math
{
MatDescriptor
CreateMatrixDescriptor
(
const
framework
::
DDim
&
tensor_dim
,
MatDescriptor
CreateMatrixDescriptor
(
const
framework
::
DDim
&
tensor_dim
,
int
num_flatten_cols
,
bool
trans
)
{
int
num_flatten_cols
,
bool
trans
)
{
PADDLE_ENFORCE_GT
(
tensor_dim
.
size
(),
1
);
PADDLE_ENFORCE_GT
(
tensor_dim
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The tensor dim size should be greater "
"than 1, but reveived dim size is %d"
,
tensor_dim
.
size
()));
MatDescriptor
retv
;
MatDescriptor
retv
;
if
(
num_flatten_cols
>
1
)
{
if
(
num_flatten_cols
>
1
)
{
auto
flatten_dim
=
framework
::
flatten_to_2d
(
tensor_dim
,
num_flatten_cols
);
auto
flatten_dim
=
framework
::
flatten_to_2d
(
tensor_dim
,
num_flatten_cols
);
...
...
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
6ef1fbb6
...
@@ -60,7 +60,8 @@ struct CUBlas<float> {
...
@@ -60,7 +60,8 @@ struct CUBlas<float> {
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasSgemmStridedBatched
(
args
...));
platform
::
dynload
::
cublasSgemmStridedBatched
(
args
...));
#else
#else
PADDLE_THROW
(
"SgemmStridedBatched is not supported on cuda <= 7.5"
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"SgemmStridedBatched is not supported on cuda <= 7.5"
));
#endif
#endif
}
}
...
@@ -85,7 +86,8 @@ struct CUBlas<float> {
...
@@ -85,7 +86,8 @@ struct CUBlas<float> {
beta
,
C
,
Ctype
,
ldc
));
beta
,
C
,
Ctype
,
ldc
));
});
});
#else
#else
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"cublasSgemmEx is not supported on cuda <= 7.5"
));
#endif
#endif
}
}
...
@@ -146,13 +148,15 @@ struct CUBlas<double> {
...
@@ -146,13 +148,15 @@ struct CUBlas<double> {
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasDgemmStridedBatched
(
args
...));
platform
::
dynload
::
cublasDgemmStridedBatched
(
args
...));
#else
#else
PADDLE_THROW
(
"DgemmStridedBatched is not supported on cuda <= 7.5"
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"DgemmStridedBatched is not supported on cuda <= 7.5"
));
#endif
#endif
}
}
template
<
typename
...
ARGS
>
template
<
typename
...
ARGS
>
static
void
GEMM_EX
(
ARGS
...
args
)
{
static
void
GEMM_EX
(
ARGS
...
args
)
{
PADDLE_THROW
(
"Currently there are not cublasDgemmEx."
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Currently there are not cublasDgemmEx."
));
}
}
template
<
typename
...
ARGS
>
template
<
typename
...
ARGS
>
...
@@ -216,7 +220,8 @@ struct CUBlas<platform::float16> {
...
@@ -216,7 +220,8 @@ struct CUBlas<platform::float16> {
reinterpret_cast
<
const
__half
*>
(
beta
),
reinterpret_cast
<
__half
*>
(
C
),
reinterpret_cast
<
const
__half
*>
(
beta
),
reinterpret_cast
<
__half
*>
(
C
),
ldc
,
strideC
,
batchCount
));
ldc
,
strideC
,
batchCount
));
#else
#else
PADDLE_THROW
(
"HgemmStridedBatched is not supported on cuda <= 7.5"
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"HgemmStridedBatched is not supported on cuda <= 7.5"
));
#endif
#endif
}
}
...
@@ -247,7 +252,8 @@ struct CUBlas<platform::float16> {
...
@@ -247,7 +252,8 @@ struct CUBlas<platform::float16> {
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
});
});
#else
#else
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"cublasGemmEx is not supported on cuda <= 7.5"
));
#endif
#endif
}
}
};
};
...
@@ -302,8 +308,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
...
@@ -302,8 +308,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
// TODO(kexinzhao): add processing code for compute capability < 53 case
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE
(
context_
.
GetComputeCapability
(),
53
,
PADDLE_ENFORCE_GE
(
"cublas fp16 gemm requires GPU compute capability >= 53"
);
context_
.
GetComputeCapability
(),
53
,
platform
::
errors
::
InvalidArgument
(
"cublas fp16 gemm requires GPU compute capability >= 53,"
"but received %d"
,
context_
.
GetComputeCapability
()));
float
h_alpha
=
static_cast
<
float
>
(
alpha
);
float
h_alpha
=
static_cast
<
float
>
(
alpha
);
float
h_beta
=
static_cast
<
float
>
(
beta
);
float
h_beta
=
static_cast
<
float
>
(
beta
);
...
...
paddle/fluid/operators/math/blas_impl.h
浏览文件 @
6ef1fbb6
...
@@ -29,7 +29,8 @@ template <>
...
@@ -29,7 +29,8 @@ template <>
struct
CBlas
<
int8_t
>
{
struct
CBlas
<
int8_t
>
{
template
<
typename
...
ARGS
>
template
<
typename
...
ARGS
>
static
void
VCOPY
(
ARGS
...
args
)
{
static
void
VCOPY
(
ARGS
...
args
)
{
PADDLE_THROW
(
"Blas VCOPY don't support int8_t"
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Blas VCOPY do not supported on CPU, please check your code"
));
}
}
};
};
...
@@ -347,22 +348,47 @@ struct CBlas<double> {
...
@@ -347,22 +348,47 @@ struct CBlas<double> {
template
<
>
template
<
>
struct
CBlas
<
platform
::
float16
>
{
struct
CBlas
<
platform
::
float16
>
{
static
void
GEMM
(...)
{
PADDLE_THROW
(
"float16 GEMM not supported on CPU"
);
}
static
void
GEMM
(...)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"float16 GEMM not supported on CPU, please check your code"
));
}
static
void
SMM_GEMM
(...)
{
static
void
SMM_GEMM
(...)
{
PADDLE_THROW
(
"float16 SMM_GEMM not supported on CPU"
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"float16 SMM_GEMM not supported on CPU, please check your code"
));
}
}
static
void
VMUL
(...)
{
PADDLE_THROW
(
"float16 VMUL not supported on CPU"
);
}
static
void
VMUL
(...)
{
static
void
VEXP
(...)
{
PADDLE_THROW
(
"float16 VEXP not supported on CPU"
);
}
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
static
void
VSQUARE
(...)
{
"float16 VMUL not supported on CPU, please check your code"
));
PADDLE_THROW
(
"float16 VSQUARE not supported on CPU"
);
}
}
static
void
VPOW
(...)
{
PADDLE_THROW
(
"float16 VPOW not supported on CPU"
);
}
static
void
VEXP
(...)
{
static
void
DOT
(...)
{
PADDLE_THROW
(
"float16 DOT not supported on CPU"
);
};
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
static
void
SCAL
(...)
{
PADDLE_THROW
(
"float16 SCAL not supported on CPU"
);
};
"float16 VEXP not supported on CPU, please check your code"
));
static
void
ASUM
(...)
{
PADDLE_THROW
(
"float16 ASUM not supported on CPU"
);
};
}
static
void
VSQUARE
(...)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"float16 VSQUARE not supported on CPU, please check your code"
));
}
static
void
VPOW
(...)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"float16 VPOW not supported on CPU, please check your code"
));
}
static
void
DOT
(...)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"float16 DOT not supported on CPU, please check your code"
));
};
static
void
SCAL
(...)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"float16 SCAL not supported on CPU, please check your code"
));
};
static
void
ASUM
(...)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"float16 ASUM not supported on CPU, please check your code"
));
};
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
static
void
GEMM_BATCH
(...)
{
static
void
GEMM_BATCH
(...)
{
PADDLE_THROW
(
"float16 GEMM_BATCH not supported on CPU"
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"float16 GEMM_BATCH not supported on CPU, please check your code"
));
}
}
#endif
#endif
};
};
...
@@ -446,11 +472,18 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a, bool trans_a,
...
@@ -446,11 +472,18 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a, bool trans_a,
auto
dim_a
=
mat_a
.
dims
();
auto
dim_a
=
mat_a
.
dims
();
auto
dim_b
=
mat_b
.
dims
();
auto
dim_b
=
mat_b
.
dims
();
auto
dim_out
=
mat_out
->
dims
();
auto
dim_out
=
mat_out
->
dims
();
PADDLE_ENFORCE
(
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
PADDLE_ENFORCE_EQ
(
"The input and output of matmul be matrix"
);
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
true
,
PADDLE_ENFORCE
(
platform
::
errors
::
InvalidArgument
(
mat_a
.
place
()
==
mat_b
.
place
()
&&
mat_a
.
place
()
==
mat_out
->
place
(),
"The input and output of matmul should be matrix, the dim size must "
"The places of matrices must be same"
);
"be 2,"
"but received dim size input_a:%d, input_b:%d, output:%d"
,
dim_a
.
size
(),
dim_b
.
size
(),
dim_out
.
size
()));
PADDLE_ENFORCE_EQ
(
mat_a
.
place
()
==
mat_b
.
place
()
&&
mat_a
.
place
()
==
mat_out
->
place
(),
true
,
platform
::
errors
::
InvalidArgument
(
"The places of matrices in the matmul "
"should be same, please check your "
"code."
));
int
M
=
dim_out
[
0
];
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
N
=
dim_out
[
1
];
...
@@ -715,7 +748,13 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
...
@@ -715,7 +748,13 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
}
}
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
W1
,
H2
);
PADDLE_ENFORCE_EQ
(
W1
,
H2
,
platform
::
errors
::
InvalidArgument
(
"The fisrt matrix width should be same as second matrix height,"
"but received fisrt matrix width %d"
", second matrix height %d"
,
W1
,
H2
));
int
ldc
=
W2
*
head_number
;
int
ldc
=
W2
*
head_number
;
int
sub_width
=
W1
/
head_number
;
int
sub_width
=
W1
/
head_number
;
...
@@ -785,7 +824,14 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
...
@@ -785,7 +824,14 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
const
framework
::
Tensor
&
mat_b
,
const
framework
::
Tensor
&
mat_b
,
const
MatDescriptor
&
dim_b
,
T
alpha
,
const
MatDescriptor
&
dim_b
,
T
alpha
,
framework
::
Tensor
*
mat_out
,
T
beta
)
const
{
framework
::
Tensor
*
mat_out
,
T
beta
)
const
{
PADDLE_ENFORCE_EQ
(
dim_a
.
width_
,
dim_b
.
height_
);
PADDLE_ENFORCE_EQ
(
dim_a
.
width_
,
dim_b
.
height_
,
platform
::
errors
::
InvalidArgument
(
"The fisrt matrix width should be same as second matrix height,"
"but received fisrt matrix width %d"
", second matrix height %d"
,
dim_a
.
width_
,
dim_b
.
height_
));
CBLAS_TRANSPOSE
transA
=
!
dim_a
.
trans_
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transA
=
!
dim_a
.
trans_
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
!
dim_b
.
trans_
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
!
dim_b
.
trans_
?
CblasNoTrans
:
CblasTrans
;
if
(
dim_a
.
batch_size_
==
0
&&
dim_b
.
batch_size_
==
0
)
{
if
(
dim_a
.
batch_size_
==
0
&&
dim_b
.
batch_size_
==
0
)
{
...
@@ -793,12 +839,14 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
...
@@ -793,12 +839,14 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
dim_a
.
width_
,
alpha
,
mat_a
.
data
<
T
>
(),
dim_a
.
width_
,
alpha
,
mat_a
.
data
<
T
>
(),
mat_b
.
data
<
T
>
(),
beta
,
mat_out
->
data
<
T
>
());
mat_b
.
data
<
T
>
(),
beta
,
mat_out
->
data
<
T
>
());
}
else
{
}
else
{
PADDLE_ENFORCE
(
dim_a
.
batch_size_
==
dim_b
.
batch_size_
||
PADDLE_ENFORCE_EQ
(
dim_a
.
batch_size_
==
0
||
dim_b
.
batch_size_
==
0
,
dim_a
.
batch_size_
==
dim_b
.
batch_size_
||
dim_a
.
batch_size_
==
0
||
"dim_a.batch_size should be equal to dim_b.batch_size, or "
dim_b
.
batch_size_
==
0
,
"one of dim_a.batch_size and dim_b.batch_size should be 0. "
true
,
platform
::
errors
::
InvalidArgument
(
"But got dim_a.batch_size = %d, dim_b.batch_size = %d."
,
"dim_a.batch_size should be equal to dim_b.batch_size, or "
dim_a
.
batch_size_
,
dim_b
.
batch_size_
);
"one of dim_a.batch_size and dim_b.batch_size should be 0. "
"But got dim_a.batch_size = %d, dim_b.batch_size = %d."
,
dim_a
.
batch_size_
,
dim_b
.
batch_size_
));
this
->
template
BatchedGEMM
<
T
>(
this
->
template
BatchedGEMM
<
T
>(
transA
,
transB
,
dim_a
.
height_
,
dim_b
.
width_
,
dim_a
.
width_
,
alpha
,
transA
,
transB
,
dim_a
.
height_
,
dim_b
.
width_
,
dim_a
.
width_
,
alpha
,
mat_a
.
data
<
T
>
(),
mat_b
.
data
<
T
>
(),
beta
,
mat_out
->
data
<
T
>
(),
mat_a
.
data
<
T
>
(),
mat_b
.
data
<
T
>
(),
beta
,
mat_out
->
data
<
T
>
(),
...
@@ -834,15 +882,42 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
...
@@ -834,15 +882,42 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
int
head_number
,
int
head_number
,
framework
::
Tensor
*
mat_out
,
T
beta
,
framework
::
Tensor
*
mat_out
,
T
beta
,
bool
mat_b_split_vertical
)
const
{
bool
mat_b_split_vertical
)
const
{
PADDLE_ENFORCE_EQ
(
dim_a
.
width_
%
head_number
,
0
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_GE
(
head_number
,
1
);
dim_a
.
width_
%
head_number
,
0
,
PADDLE_ENFORCE_LE
(
head_number
,
dim_a
.
width_
);
platform
::
errors
::
InvalidArgument
(
"The first input width must be some times the head number"
"but received first input width %d"
", head_number %d"
,
dim_a
.
width_
,
head_number
));
PADDLE_ENFORCE_GE
(
head_number
,
1
,
platform
::
errors
::
InvalidArgument
(
"The head number should be greater equal 1,"
"but received head number %d"
,
head_number
));
PADDLE_ENFORCE_LE
(
head_number
,
dim_a
.
width_
,
platform
::
errors
::
InvalidArgument
(
"The head number should be less equal first input width,"
"but received first input width %d"
", head_number %d"
,
dim_a
.
width_
,
head_number
));
CBLAS_TRANSPOSE
transA
=
!
dim_a
.
trans_
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transA
=
!
dim_a
.
trans_
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
!
dim_b
.
trans_
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
!
dim_b
.
trans_
?
CblasNoTrans
:
CblasTrans
;
if
(
mat_b_split_vertical
)
{
if
(
mat_b_split_vertical
)
{
PADDLE_ENFORCE_EQ
(
dim_b
.
height_
,
dim_a
.
width_
/
head_number
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
dim_b
.
width_
%
head_number
,
0
);
dim_b
.
height_
,
dim_a
.
width_
/
head_number
,
platform
::
errors
::
InvalidArgument
(
"The second input height should be equal than first input width,"
"but received second input height %d, first input width %d"
,
dim_b
.
height_
,
dim_a
.
width_
/
head_number
));
PADDLE_ENFORCE_EQ
(
dim_a
.
width_
%
head_number
,
0
,
platform
::
errors
::
InvalidArgument
(
"The second input width should be some times the head number"
"but received second input width %d"
", head_number %d"
,
dim_b
.
width_
,
head_number
));
}
}
if
(
dim_a
.
batch_size_
==
0
&&
dim_b
.
batch_size_
==
0
)
{
if
(
dim_a
.
batch_size_
==
0
&&
dim_b
.
batch_size_
==
0
)
{
...
@@ -888,9 +963,16 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
...
@@ -888,9 +963,16 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
mat_out
->
data
<
T
>
()
+
sub_matC_offset
,
ldc
);
mat_out
->
data
<
T
>
()
+
sub_matC_offset
,
ldc
);
}
}
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
((
dim_a
.
batch_size_
==
dim_b
.
batch_size_
||
PADDLE_ENFORCE_EQ
(
dim_a
.
batch_size_
==
0
||
dim_b
.
batch_size_
==
0
),
(
dim_a
.
batch_size_
==
dim_b
.
batch_size_
||
dim_a
.
batch_size_
==
0
||
true
);
dim_b
.
batch_size_
==
0
),
true
,
platform
::
errors
::
InvalidArgument
(
"The first input batch size should be equal than second input,"
"either two input batch size is 0, but received first input batch "
"size"
" %d, second input batch size %d"
,
dim_a
.
batch_size_
,
dim_b
.
batch_size_
));
this
->
template
BatchedGEMMWithHead
<
T
>(
this
->
template
BatchedGEMMWithHead
<
T
>(
transA
,
transB
,
dim_a
.
width_
,
dim_a
.
height_
,
dim_b
.
width_
,
transA
,
transB
,
dim_a
.
width_
,
dim_a
.
height_
,
dim_b
.
width_
,
...
...
paddle/fluid/operators/shape_op.cc
浏览文件 @
6ef1fbb6
...
@@ -68,6 +68,6 @@ REGISTER_OPERATOR(
...
@@ -68,6 +68,6 @@ REGISTER_OPERATOR(
shape
,
ops
::
ShapeOp
,
ops
::
ShapeOpMaker
,
shape
,
ops
::
ShapeOp
,
ops
::
ShapeOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
shape
,
ops
::
ShapeKernel
<
int
>
,
ops
::
ShapeKernel
<
int32_
t
>
,
REGISTER_OP_CPU_KERNEL
(
shape
,
ops
::
ShapeKernel
<
bool
>
,
ops
::
ShapeKernel
<
in
t
>
,
ops
::
ShapeKernel
<
int64_t
>
,
ops
::
ShapeKernel
<
float
>
,
ops
::
ShapeKernel
<
int64_t
>
,
ops
::
ShapeKernel
<
float
>
,
ops
::
ShapeKernel
<
double
>
);
ops
::
ShapeKernel
<
double
>
);
paddle/fluid/operators/shape_op.cu
浏览文件 @
6ef1fbb6
...
@@ -15,8 +15,8 @@ limitations under the License. */
...
@@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/fluid/operators/shape_op.h"
#include "paddle/fluid/operators/shape_op.h"
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
shape
,
paddle
::
operators
::
ShapeKernel
<
int
>
,
shape
,
paddle
::
operators
::
ShapeKernel
<
bool
>
,
paddle
::
operators
::
ShapeKernel
<
int
32_t
>
,
paddle
::
operators
::
ShapeKernel
<
int
>
,
paddle
::
operators
::
ShapeKernel
<
int64_t
>
,
paddle
::
operators
::
ShapeKernel
<
int64_t
>
,
paddle
::
operators
::
ShapeKernel
<
float
>
,
paddle
::
operators
::
ShapeKernel
<
float
>
,
paddle
::
operators
::
ShapeKernel
<
double
>
,
paddle
::
operators
::
ShapeKernel
<
double
>
,
...
...
python/paddle/__init__.py
浏览文件 @
6ef1fbb6
...
@@ -77,6 +77,7 @@ from .tensor.creation import triu #DEFINE_ALIAS
...
@@ -77,6 +77,7 @@ from .tensor.creation import triu #DEFINE_ALIAS
from
.tensor.creation
import
tril
#DEFINE_ALIAS
from
.tensor.creation
import
tril
#DEFINE_ALIAS
from
.tensor.creation
import
meshgrid
#DEFINE_ALIAS
from
.tensor.creation
import
meshgrid
#DEFINE_ALIAS
from
.tensor.creation
import
empty
#DEFINE_ALIAS
from
.tensor.creation
import
empty
#DEFINE_ALIAS
from
.tensor.creation
import
empty_like
#DEFINE_ALIAS
from
.tensor.linalg
import
matmul
#DEFINE_ALIAS
from
.tensor.linalg
import
matmul
#DEFINE_ALIAS
from
.tensor.linalg
import
dot
#DEFINE_ALIAS
from
.tensor.linalg
import
dot
#DEFINE_ALIAS
# from .tensor.linalg import einsum #DEFINE_ALIAS
# from .tensor.linalg import einsum #DEFINE_ALIAS
...
...
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
浏览文件 @
6ef1fbb6
...
@@ -67,6 +67,7 @@ class ImperativeQuantAware(object):
...
@@ -67,6 +67,7 @@ class ImperativeQuantAware(object):
Examples:
Examples:
.. code-block:: python
.. code-block:: python
import paddle
from paddle.fluid.contrib.slim.quantization
\
from paddle.fluid.contrib.slim.quantization
\
import ImperativeQuantAware
import ImperativeQuantAware
from paddle.vision.models
\
from paddle.vision.models
\
...
@@ -86,13 +87,12 @@ class ImperativeQuantAware(object):
...
@@ -86,13 +87,12 @@ class ImperativeQuantAware(object):
# ...
# ...
# Save quant model for the inference.
# Save quant model for the inference.
imperative_qat.save_quantized_model(
paddle.jit.save(
dirname="./resnet50_qat",
layer=model,
model=model,
model_path="./resnet50_qat",
input_shape=[(3, 224, 224)],
input_spec=[
input_dtype=['float32'],
paddle.static.InputSpec(
feed=[0],
shape=[None, 3, 224, 224], dtype='float32')])
fetch=[0])
"""
"""
super
(
ImperativeQuantAware
,
self
).
__init__
()
super
(
ImperativeQuantAware
,
self
).
__init__
()
self
.
_weight_bits
=
weight_bits
self
.
_weight_bits
=
weight_bits
...
@@ -148,75 +148,6 @@ class ImperativeQuantAware(object):
...
@@ -148,75 +148,6 @@ class ImperativeQuantAware(object):
quant_layer
=
self
.
_get_quantized_counterpart
(
layer
)
quant_layer
=
self
.
_get_quantized_counterpart
(
layer
)
setattr
(
obj
,
target
,
quant_layer
)
setattr
(
obj
,
target
,
quant_layer
)
def
save_quantized_model
(
self
,
dirname
,
model
,
input_shape
,
input_dtype
,
feed
,
fetch
,
append_batch_size
=
True
):
"""
Save the quantized model for the inference.
Args:
dirname (str): the directory to save the quantized model.
model(fluid.dygraph.Layer): the quantized model to be saved.
input_shape(list[tuple(int)]): The shape value for each input,
e.g. [(3, 224, 224)].
input_dtype(list[str]): The dtype value for each input,
e.g. ['float32'].
feed(list[int]): the indices of the input variables of the
imperative functions which will be saved as input variables in
inference model.
fetch(list[int]): the indices of the returned variable of the
imperative functions which will be saved as output variables in
inference model.
append_batch_size(bool, optional):
If true, it prepends an extra axis to the input_shape, meanwhile,
the input_shape shouldn't contain the batch size dimension.
Otherwise, it just uses the input_shape. Default True.
Returns:
None
"""
assert
isinstance
(
input_shape
,
list
),
"The parameter `input_shape` shoubld be a list."
assert
isinstance
(
input_dtype
,
list
),
"The parameter `input_dtype` shoubld be a list."
assert
isinstance
(
feed
,
list
),
"The parameter `feed` shoubld be a list."
assert
isinstance
(
fetch
,
list
),
"The parameter `fetch` shoubld be a list."
assert
len
(
input_shape
)
==
len
(
input_dtype
),
"The length of input_shape should be equal to input_dtype's."
assert
len
(
input_dtype
)
==
len
(
feed
),
"The length of input_shape should be equal to feed's."
with
dygraph
.
guard
():
model
.
eval
()
input_vars
=
[]
for
i
,
(
shape
,
dtype
)
in
enumerate
(
zip
(
input_shape
,
input_dtype
)):
if
append_batch_size
:
shape
=
[
None
]
+
list
(
shape
)
# Note(Aurelius84): need a elegant way to name this.
in_spec
=
paddle
.
static
.
InputSpec
(
shape
,
dtype
,
'feed_%d'
%
i
)
input_vars
.
append
(
in_spec
)
# use `declarative` to convert dygraph into static program
model
.
forward
=
dygraph
.
jit
.
declarative
(
model
.
forward
,
input_spec
=
input_vars
)
outputs
=
model
.
forward
.
concrete_program
.
outputs
input_spec
=
[
input_vars
[
i
]
for
i
in
feed
]
configs
=
dygraph
.
jit
.
SaveLoadConfig
()
configs
.
separate_params
=
True
if
not
isinstance
(
outputs
,
(
tuple
,
list
)):
outputs
=
[
outputs
]
configs
.
output_spec
=
[
outputs
[
i
]
for
i
in
fetch
]
dygraph
.
jit
.
save
(
layer
=
model
,
model_path
=
dirname
,
input_spec
=
input_spec
,
configs
=
configs
)
def
_get_quantized_counterpart
(
self
,
layer
):
def
_get_quantized_counterpart
(
self
,
layer
):
quant_layers
=
tuple
(
self
.
_quant_layers_map
.
values
())
quant_layers
=
tuple
(
self
.
_quant_layers_map
.
values
())
quantized_counterpart
=
tuple
(
'Quantized'
+
k
quantized_counterpart
=
tuple
(
'Quantized'
+
k
...
...
python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
浏览文件 @
6ef1fbb6
...
@@ -221,7 +221,7 @@ class TestImperativeQat(unittest.TestCase):
...
@@ -221,7 +221,7 @@ class TestImperativeQat(unittest.TestCase):
model_dict
=
lenet
.
state_dict
()
model_dict
=
lenet
.
state_dict
()
fluid
.
save_dygraph
(
model_dict
,
"save_temp"
)
fluid
.
save_dygraph
(
model_dict
,
"save_temp"
)
# test the correctness of `
save_quantized_model
`
# test the correctness of `
paddle.jit.save
`
data
=
next
(
test_reader
())
data
=
next
(
test_reader
())
test_data
=
np
.
array
([
x
[
0
].
reshape
(
1
,
28
,
28
)
test_data
=
np
.
array
([
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
for
x
in
data
]).
astype
(
'float32'
)
...
@@ -231,13 +231,14 @@ class TestImperativeQat(unittest.TestCase):
...
@@ -231,13 +231,14 @@ class TestImperativeQat(unittest.TestCase):
# save inference quantized model
# save inference quantized model
path
=
"./mnist_infer_model"
path
=
"./mnist_infer_model"
imperative_qat
.
save_quantized_model
(
paddle
.
jit
.
save
(
dirname
=
path
,
layer
=
lenet
,
model
=
lenet
,
model_path
=
path
,
input_shape
=
[(
1
,
28
,
28
)],
input_spec
=
[
input_dtype
=
[
'float32'
],
paddle
.
static
.
InputSpec
(
feed
=
[
0
],
shape
=
[
None
,
1
,
28
,
28
],
dtype
=
'float32'
)
fetch
=
[
0
])
])
if
core
.
is_compiled_with_cuda
():
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
place
=
core
.
CUDAPlace
(
0
)
else
:
else
:
...
@@ -245,7 +246,10 @@ class TestImperativeQat(unittest.TestCase):
...
@@ -245,7 +246,10 @@ class TestImperativeQat(unittest.TestCase):
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
[
inference_program
,
feed_target_names
,
fetch_targets
]
=
(
[
inference_program
,
feed_target_names
,
fetch_targets
]
=
(
fluid
.
io
.
load_inference_model
(
fluid
.
io
.
load_inference_model
(
dirname
=
path
,
executor
=
exe
))
dirname
=
path
,
executor
=
exe
,
model_filename
=
"__model__"
,
params_filename
=
"__variables__"
))
after_save
,
=
exe
.
run
(
inference_program
,
after_save
,
=
exe
.
run
(
inference_program
,
feed
=
{
feed_target_names
[
0
]:
test_data
},
feed
=
{
feed_target_names
[
0
]:
test_data
},
fetch_list
=
fetch_targets
)
fetch_list
=
fetch_targets
)
...
@@ -332,13 +336,13 @@ class TestImperativeQat(unittest.TestCase):
...
@@ -332,13 +336,13 @@ class TestImperativeQat(unittest.TestCase):
if
batch_id
%
100
==
0
:
if
batch_id
%
100
==
0
:
_logger
.
info
(
'{}: {}'
.
format
(
'loss'
,
avg_loss
.
numpy
()))
_logger
.
info
(
'{}: {}'
.
format
(
'loss'
,
avg_loss
.
numpy
()))
imperative_qat
.
save_quantized_model
(
paddle
.
jit
.
save
(
dirname
=
"./dynamic_mnist"
,
layer
=
lenet
,
model
=
lenet
,
model
_path
=
"./dynamic_mnist"
,
input_s
hape
=
[(
1
,
28
,
28
)],
input_s
pec
=
[
input_dtype
=
[
'float32'
],
paddle
.
static
.
InputSpec
(
feed
=
[
0
],
shape
=
[
None
,
1
,
28
,
28
],
dtype
=
'float32'
)
fetch
=
[
0
])
])
# static graph train
# static graph train
_logger
.
info
(
_logger
.
info
(
...
...
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
浏览文件 @
6ef1fbb6
...
@@ -60,7 +60,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
...
@@ -60,7 +60,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
def
transfer_from_node_type
(
self
,
node_wrapper
):
def
transfer_from_node_type
(
self
,
node_wrapper
):
translator_logger
=
logging_utils
.
TranslatorLogger
()
translator_logger
=
logging_utils
.
TranslatorLogger
()
translator_logger
.
log
(
translator_logger
.
log
(
1
,
"
Source code:
\n
{}"
.
format
(
ast_to_source_code
(
self
.
root
)))
1
,
"Source code:
\n
{}"
.
format
(
ast_to_source_code
(
self
.
root
)))
# Generic transformation
# Generic transformation
self
.
visit
(
node_wrapper
.
node
)
self
.
visit
(
node_wrapper
.
node
)
...
...
python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py
浏览文件 @
6ef1fbb6
...
@@ -12,17 +12,18 @@
...
@@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
logging
import
six
import
six
import
inspect
import
inspect
import
numpy
as
np
import
numpy
as
np
import
collections
import
collections
import
paddle
import
paddle
from
paddle.fluid
import
core
from
paddle.fluid
import
core
from
paddle.fluid.dygraph
import
layers
from
paddle.fluid.dygraph
import
layers
from
paddle.fluid.layers.utils
import
flatten
from
paddle.fluid.layers.utils
import
flatten
from
paddle.fluid.layers.utils
import
pack_sequence_as
from
paddle.fluid.layers.utils
import
pack_sequence_as
from
paddle.fluid.dygraph.base
import
switch_to_static_graph
from
paddle.fluid.dygraph.base
import
switch_to_static_graph
from
paddle.fluid.dygraph.dygraph_to_static
import
logging_utils
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
parse_arg_and_kwargs
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
parse_arg_and_kwargs
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
type_name
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
type_name
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
func_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
func_to_source_code
...
@@ -291,7 +292,7 @@ def convert_to_input_spec(inputs, input_spec):
...
@@ -291,7 +292,7 @@ def convert_to_input_spec(inputs, input_spec):
if
len
(
inputs
)
>
len
(
input_spec
):
if
len
(
inputs
)
>
len
(
input_spec
):
for
rest_input
in
inputs
[
len
(
input_spec
):]:
for
rest_input
in
inputs
[
len
(
input_spec
):]:
if
isinstance
(
rest_input
,
(
core
.
VarBase
,
np
.
ndarray
)):
if
isinstance
(
rest_input
,
(
core
.
VarBase
,
np
.
ndarray
)):
logging
.
warning
(
logging
_utils
.
warn
(
"The inputs constain `{}` without specificing InputSpec, its shape and dtype will be treated immutable. "
"The inputs constain `{}` without specificing InputSpec, its shape and dtype will be treated immutable. "
"Please specific InputSpec information in `@declarative` if you expect them as mutable inputs."
.
"Please specific InputSpec information in `@declarative` if you expect them as mutable inputs."
.
format
(
type_name
(
rest_input
)))
format
(
type_name
(
rest_input
)))
...
...
python/paddle/fluid/dygraph/dygraph_to_static/logging_utils.py
浏览文件 @
6ef1fbb6
...
@@ -26,6 +26,8 @@ CODE_LEVEL_ENV_NAME = 'TRANSLATOR_CODE_LEVEL'
...
@@ -26,6 +26,8 @@ CODE_LEVEL_ENV_NAME = 'TRANSLATOR_CODE_LEVEL'
DEFAULT_VERBOSITY
=
-
1
DEFAULT_VERBOSITY
=
-
1
DEFAULT_CODE_LEVEL
=
-
1
DEFAULT_CODE_LEVEL
=
-
1
LOG_AllTransformer
=
100
def
synchronized
(
func
):
def
synchronized
(
func
):
def
wrapper
(
*
args
,
**
kwargs
):
def
wrapper
(
*
args
,
**
kwargs
):
...
@@ -53,10 +55,15 @@ class TranslatorLogger(object):
...
@@ -53,10 +55,15 @@ class TranslatorLogger(object):
return
return
self
.
_initialized
=
True
self
.
_initialized
=
True
self
.
logger_name
=
"Dynamic-to-Static"
self
.
_logger
=
log_helper
.
get_logger
(
self
.
_logger
=
log_helper
.
get_logger
(
__name__
,
1
,
fmt
=
'%(asctime)s-%(levelname)s: %(message)s'
)
self
.
logger_name
,
1
,
fmt
=
'%(asctime)s %(name)s %(levelname)s: %(message)s'
)
self
.
_verbosity_level
=
None
self
.
_verbosity_level
=
None
self
.
_transformed_code_level
=
None
self
.
_transformed_code_level
=
None
self
.
_need_to_echo_log_to_stdout
=
None
self
.
_need_to_echo_code_to_stdout
=
None
@
property
@
property
def
logger
(
self
):
def
logger
(
self
):
...
@@ -86,6 +93,28 @@ class TranslatorLogger(object):
...
@@ -86,6 +93,28 @@ class TranslatorLogger(object):
self
.
check_level
(
level
)
self
.
check_level
(
level
)
self
.
_transformed_code_level
=
level
self
.
_transformed_code_level
=
level
@
property
def
need_to_echo_log_to_stdout
(
self
):
if
self
.
_need_to_echo_log_to_stdout
is
not
None
:
return
self
.
_need_to_echo_log_to_stdout
return
False
@
need_to_echo_log_to_stdout
.
setter
def
need_to_echo_log_to_stdout
(
self
,
log_to_stdout
):
assert
isinstance
(
log_to_stdout
,
(
bool
,
type
(
None
)))
self
.
_need_to_echo_log_to_stdout
=
log_to_stdout
@
property
def
need_to_echo_code_to_stdout
(
self
):
if
self
.
_need_to_echo_code_to_stdout
is
not
None
:
return
self
.
_need_to_echo_code_to_stdout
return
False
@
need_to_echo_code_to_stdout
.
setter
def
need_to_echo_code_to_stdout
(
self
,
code_to_stdout
):
assert
isinstance
(
code_to_stdout
,
(
bool
,
type
(
None
)))
self
.
_need_to_echo_code_to_stdout
=
code_to_stdout
def
check_level
(
self
,
level
):
def
check_level
(
self
,
level
):
if
isinstance
(
level
,
(
six
.
integer_types
,
type
(
None
))):
if
isinstance
(
level
,
(
six
.
integer_types
,
type
(
None
))):
rv
=
level
rv
=
level
...
@@ -110,34 +139,56 @@ class TranslatorLogger(object):
...
@@ -110,34 +139,56 @@ class TranslatorLogger(object):
def
error
(
self
,
msg
,
*
args
,
**
kwargs
):
def
error
(
self
,
msg
,
*
args
,
**
kwargs
):
self
.
logger
.
error
(
msg
,
*
args
,
**
kwargs
)
self
.
logger
.
error
(
msg
,
*
args
,
**
kwargs
)
if
self
.
need_to_echo_log_to_stdout
:
self
.
_output_to_stdout
(
'ERROR: '
+
msg
,
*
args
)
def
warn
(
self
,
msg
,
*
args
,
**
kwargs
):
def
warn
(
self
,
msg
,
*
args
,
**
kwargs
):
self
.
logger
.
warn
(
msg
,
*
args
,
**
kwargs
)
self
.
logger
.
warning
(
msg
,
*
args
,
**
kwargs
)
if
self
.
need_to_echo_log_to_stdout
:
self
.
_output_to_stdout
(
'WARNING: '
+
msg
,
*
args
)
def
log
(
self
,
level
,
msg
,
*
args
,
**
kwargs
):
def
log
(
self
,
level
,
msg
,
*
args
,
**
kwargs
):
if
self
.
has_verbosity
(
level
):
if
self
.
has_verbosity
(
level
):
self
.
logger
.
log
(
level
,
msg
,
*
args
,
**
kwargs
)
msg_with_level
=
'(Level {}) {}'
.
format
(
level
,
msg
)
self
.
logger
.
info
(
msg_with_level
,
*
args
,
**
kwargs
)
if
self
.
need_to_echo_log_to_stdout
:
self
.
_output_to_stdout
(
'INFO: '
+
msg_with_level
,
*
args
)
def
log_transformed_code
(
self
,
level
,
ast_node
,
transformer_name
,
*
args
,
def
log_transformed_code
(
self
,
level
,
ast_node
,
transformer_name
,
*
args
,
**
kwargs
):
**
kwargs
):
if
self
.
has_code_level
(
level
):
if
self
.
has_code_level
(
level
):
source_code
=
ast_to_source_code
(
ast_node
)
source_code
=
ast_to_source_code
(
ast_node
)
header_msg
=
"After the level {} ast transformer: '{}', the transformed code:
\n
"
\
if
level
==
LOG_AllTransformer
:
.
format
(
level
,
transformer_name
)
header_msg
=
"After the last level ast transformer: '{}', the transformed code:
\n
"
\
.
format
(
transformer_name
)
else
:
header_msg
=
"After the level {} ast transformer: '{}', the transformed code:
\n
"
\
.
format
(
level
,
transformer_name
)
msg
=
header_msg
+
source_code
msg
=
header_msg
+
source_code
self
.
logger
.
info
(
msg
,
*
args
,
**
kwargs
)
self
.
logger
.
info
(
msg
,
*
args
,
**
kwargs
)
if
self
.
need_to_echo_code_to_stdout
:
self
.
_output_to_stdout
(
'INFO: '
+
msg
,
*
args
)
def
_output_to_stdout
(
self
,
msg
,
*
args
):
msg
=
self
.
logger_name
+
' '
+
msg
print
(
msg
%
args
)
_TRANSLATOR_LOGGER
=
TranslatorLogger
()
_TRANSLATOR_LOGGER
=
TranslatorLogger
()
def
set_verbosity
(
level
=
0
):
def
set_verbosity
(
level
=
0
,
also_to_stdout
=
False
):
"""
"""
Sets the verbosity level of log for dygraph to static graph.
Sets the verbosity level of log for dygraph to static graph. Logs can be output to stdout by setting `also_to_stdout`.
There are two means to set the logging verbosity:
There are two means to set the logging verbosity:
1. Call function `set_verbosity`
2. Set environment variable `TRANSLATOR_VERBOSITY`
1. Call function `set_verbosity`
2. Set environment variable `TRANSLATOR_VERBOSITY`
**Note**:
**Note**:
`set_verbosity` has a higher priority than the environment variable.
`set_verbosity` has a higher priority than the environment variable.
...
@@ -145,6 +196,7 @@ def set_verbosity(level=0):
...
@@ -145,6 +196,7 @@ def set_verbosity(level=0):
Args:
Args:
level(int): The verbosity level. The larger value idicates more verbosity.
level(int): The verbosity level. The larger value idicates more verbosity.
The default value is 0, which means no logging.
The default value is 0, which means no logging.
also_to_stdout(bool): Whether to also output log messages to `sys.stdout`.
Examples:
Examples:
.. code-block:: python
.. code-block:: python
...
@@ -159,27 +211,30 @@ def set_verbosity(level=0):
...
@@ -159,27 +211,30 @@ def set_verbosity(level=0):
# The verbosity level is now 3, but it has no effect because it has a lower priority than `set_verbosity`
# The verbosity level is now 3, but it has no effect because it has a lower priority than `set_verbosity`
"""
"""
_TRANSLATOR_LOGGER
.
verbosity_level
=
level
_TRANSLATOR_LOGGER
.
verbosity_level
=
level
_TRANSLATOR_LOGGER
.
need_to_echo_log_to_stdout
=
also_to_stdout
def
get_verbosity
():
def
get_verbosity
():
return
_TRANSLATOR_LOGGER
.
verbosity_level
return
_TRANSLATOR_LOGGER
.
verbosity_level
LOG_AllTransformer
=
100
def
set_code_level
(
level
=
LOG_AllTransformer
,
also_to_stdout
=
False
):
def
set_code_level
(
level
=
LOG_AllTransformer
):
"""
"""
Sets the level to print code from specific level of Ast Transformer.
Sets the level to print code from specific level Ast Transformer. Code can be output to stdout by setting `also_to_stdout`.
There are two means to set the code level:
There are two means to set the code level:
1. Call function `set_code_level`
2. Set environment variable `TRANSLATOR_CODE_LEVEL`
1. Call function `set_code_level`
2. Set environment variable `TRANSLATOR_CODE_LEVEL`
**Note**:
**Note**:
`set_code_level` has a higher priority than the environment variable.
`set_code_level` has a higher priority than the environment variable.
Args:
Args:
level(int): The level to print code. Default is 100, which means to print the code after all AST Transformers.
level(int): The level to print code. Default is 100, which means to print the code after all AST Transformers.
also_to_stdout(bool): Whether to also output code to `sys.stdout`.
Examples:
Examples:
.. code-block:: python
.. code-block:: python
...
@@ -195,6 +250,7 @@ def set_code_level(level=LOG_AllTransformer):
...
@@ -195,6 +250,7 @@ def set_code_level(level=LOG_AllTransformer):
"""
"""
_TRANSLATOR_LOGGER
.
transformed_code_level
=
level
_TRANSLATOR_LOGGER
.
transformed_code_level
=
level
_TRANSLATOR_LOGGER
.
need_to_echo_code_to_stdout
=
also_to_stdout
def
get_code_level
():
def
get_code_level
():
...
...
python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
浏览文件 @
6ef1fbb6
...
@@ -14,21 +14,17 @@
...
@@ -14,21 +14,17 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
numpy
as
np
import
numpy
as
np
import
logging
import
six
import
six
from
paddle.fluid
import
log_helper
from
paddle.fluid
import
framework
,
backward
,
core
from
paddle.fluid
import
framework
,
backward
,
core
from
paddle.fluid.dygraph
import
layers
from
paddle.fluid.dygraph
import
layers
from
paddle.fluid.dygraph.base
import
switch_to_static_graph
from
paddle.fluid.dygraph.base
import
switch_to_static_graph
from
paddle.fluid.dygraph.dygraph_to_static
import
logging_utils
from
paddle.fluid.dygraph.dygraph_to_static.return_transformer
import
RETURN_NO_VALUE_MAGIC_NUM
from
paddle.fluid.dygraph.dygraph_to_static.return_transformer
import
RETURN_NO_VALUE_MAGIC_NUM
from
paddle.fluid.layers.utils
import
flatten
from
paddle.fluid.layers.utils
import
flatten
from
paddle.fluid.layers.utils
import
pack_sequence_as
from
paddle.fluid.layers.utils
import
pack_sequence_as
import
paddle.compat
as
cpt
import
paddle.compat
as
cpt
_logger
=
log_helper
.
get_logger
(
__name__
,
logging
.
WARNING
,
fmt
=
'%(asctime)s-%(levelname)s: %(message)s'
)
class
NestSequence
(
object
):
class
NestSequence
(
object
):
"""
"""
...
@@ -72,7 +68,7 @@ class NestSequence(object):
...
@@ -72,7 +68,7 @@ class NestSequence(object):
if
not
isinstance
(
var
,
(
framework
.
Variable
,
core
.
VarBase
)):
if
not
isinstance
(
var
,
(
framework
.
Variable
,
core
.
VarBase
)):
warning_types
.
add
(
type
(
var
))
warning_types
.
add
(
type
(
var
))
if
warning_types
:
if
warning_types
:
_logger
.
warning
(
logging_utils
.
warn
(
"Output of traced function contains non-tensor type values: {}. "
"Output of traced function contains non-tensor type values: {}. "
"Currently, We don't support to update them while training and will return "
"Currently, We don't support to update them while training and will return "
"what we first saw. Please try to return them as tensor."
.
"what we first saw. Please try to return them as tensor."
.
...
...
python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py
浏览文件 @
6ef1fbb6
...
@@ -15,14 +15,8 @@
...
@@ -15,14 +15,8 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
gast
import
gast
import
logging
from
paddle.fluid
import
log_helper
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
,
StaticAnalysisVisitor
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
,
NodeVarType
,
StaticAnalysisVisitor
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
_logger
=
log_helper
.
get_logger
(
__name__
,
logging
.
WARNING
,
fmt
=
'%(asctime)s-%(levelname)s: %(message)s'
)
class
PrintTransformer
(
gast
.
NodeTransformer
):
class
PrintTransformer
(
gast
.
NodeTransformer
):
...
...
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
浏览文件 @
6ef1fbb6
...
@@ -13,17 +13,15 @@
...
@@ -13,17 +13,15 @@
# limitations under the License.
# limitations under the License.
from
__future__
import
print_function
from
__future__
import
print_function
import
gast
import
collections
import
collections
import
logging
import
gast
import
inspect
import
inspect
import
six
import
six
import
textwrap
import
textwrap
import
threading
import
threading
import
warnings
import
weakref
import
weakref
import
gast
from
paddle.fluid
import
framework
from
paddle.fluid
import
framework
from
paddle.fluid
import
in_dygraph_mode
from
paddle.fluid
import
in_dygraph_mode
from
paddle.fluid.dygraph
import
layers
from
paddle.fluid.dygraph
import
layers
...
@@ -451,7 +449,7 @@ class StaticLayer(object):
...
@@ -451,7 +449,7 @@ class StaticLayer(object):
format
(
self
.
_function_spec
))
format
(
self
.
_function_spec
))
# If more than one programs have been cached, return the recent converted program by default.
# If more than one programs have been cached, return the recent converted program by default.
elif
cached_program_len
>
1
:
elif
cached_program_len
>
1
:
logging
.
warning
(
logging
_utils
.
warn
(
"Current {} has more than one cached programs: {}, the last traced progam will be return by default."
.
"Current {} has more than one cached programs: {}, the last traced progam will be return by default."
.
format
(
self
.
_function_spec
,
cached_program_len
))
format
(
self
.
_function_spec
,
cached_program_len
))
...
@@ -632,7 +630,7 @@ class ProgramCache(object):
...
@@ -632,7 +630,7 @@ class ProgramCache(object):
# Note: raise warnings if number of traced program is more than `max_tracing_count`
# Note: raise warnings if number of traced program is more than `max_tracing_count`
current_tracing_count
=
len
(
self
.
_caches
)
current_tracing_count
=
len
(
self
.
_caches
)
if
current_tracing_count
>
MAX_TRACED_PROGRAM_COUNT
:
if
current_tracing_count
>
MAX_TRACED_PROGRAM_COUNT
:
logging
.
warning
(
logging
_utils
.
warn
(
"Current traced program number: {} > `max_tracing_count`:{}. Too much cached programs will bring expensive overhead. "
"Current traced program number: {} > `max_tracing_count`:{}. Too much cached programs will bring expensive overhead. "
"The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors."
.
"The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors."
.
format
(
current_tracing_count
,
MAX_TRACED_PROGRAM_COUNT
))
format
(
current_tracing_count
,
MAX_TRACED_PROGRAM_COUNT
))
...
@@ -804,8 +802,9 @@ class ProgramTranslator(object):
...
@@ -804,8 +802,9 @@ class ProgramTranslator(object):
assert
callable
(
assert
callable
(
dygraph_func
dygraph_func
),
"Input dygraph_func is not a callable in ProgramTranslator.get_output"
),
"Input dygraph_func is not a callable in ProgramTranslator.get_output"
if
not
self
.
enable_to_static
:
if
not
self
.
enable_to_static
:
warning
s
.
warn
(
logging_util
s
.
warn
(
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. "
"We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
"Please call ProgramTranslator.enable(True) if you would like to get static output."
...
@@ -879,8 +878,9 @@ class ProgramTranslator(object):
...
@@ -879,8 +878,9 @@ class ProgramTranslator(object):
assert
callable
(
assert
callable
(
dygraph_func
dygraph_func
),
"Input dygraph_func is not a callable in ProgramTranslator.get_func"
),
"Input dygraph_func is not a callable in ProgramTranslator.get_func"
if
not
self
.
enable_to_static
:
if
not
self
.
enable_to_static
:
warning
s
.
warn
(
logging_util
s
.
warn
(
"The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will "
"The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will "
"just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output."
"just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output."
)
)
...
@@ -933,8 +933,9 @@ class ProgramTranslator(object):
...
@@ -933,8 +933,9 @@ class ProgramTranslator(object):
assert
callable
(
assert
callable
(
dygraph_func
dygraph_func
),
"Input dygraph_func is not a callable in ProgramTranslator.get_program"
),
"Input dygraph_func is not a callable in ProgramTranslator.get_program"
if
not
self
.
enable_to_static
:
if
not
self
.
enable_to_static
:
warning
s
.
warn
(
logging_util
s
.
warn
(
"The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False."
"The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False."
"We will just return dygraph output. "
"We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
"Please call ProgramTranslator.enable(True) if you would like to get static output."
...
...
python/paddle/fluid/dygraph/jit.py
浏览文件 @
6ef1fbb6
...
@@ -26,6 +26,7 @@ from paddle.fluid import core
...
@@ -26,6 +26,7 @@ from paddle.fluid import core
from
paddle.fluid.compiler
import
BuildStrategy
,
CompiledProgram
,
ExecutionStrategy
from
paddle.fluid.compiler
import
BuildStrategy
,
CompiledProgram
,
ExecutionStrategy
from
paddle.fluid.data_feeder
import
check_type
from
paddle.fluid.data_feeder
import
check_type
from
paddle.fluid.dygraph.base
import
program_desc_tracing_guard
,
switch_to_static_graph
from
paddle.fluid.dygraph.base
import
program_desc_tracing_guard
,
switch_to_static_graph
from
paddle.fluid.dygraph.dygraph_to_static
import
logging_utils
from
paddle.fluid.dygraph.dygraph_to_static.logging_utils
import
set_code_level
,
set_verbosity
from
paddle.fluid.dygraph.dygraph_to_static.logging_utils
import
set_code_level
,
set_verbosity
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
ProgramTranslator
,
StaticLayer
,
unwrap_decorators
from
paddle.fluid.dygraph.dygraph_to_static.program_translator
import
ProgramTranslator
,
StaticLayer
,
unwrap_decorators
from
paddle.fluid.dygraph.io
import
EXTRA_VAR_INFO_FILENAME
,
VARIABLE_FILENAME
,
TranslatedLayer
from
paddle.fluid.dygraph.io
import
EXTRA_VAR_INFO_FILENAME
,
VARIABLE_FILENAME
,
TranslatedLayer
...
@@ -120,7 +121,7 @@ def _dygraph_to_static_func_(dygraph_func):
...
@@ -120,7 +121,7 @@ def _dygraph_to_static_func_(dygraph_func):
def
__impl__
(
*
args
,
**
kwargs
):
def
__impl__
(
*
args
,
**
kwargs
):
program_translator
=
ProgramTranslator
()
program_translator
=
ProgramTranslator
()
if
in_dygraph_mode
()
or
not
program_translator
.
enable_to_static
:
if
in_dygraph_mode
()
or
not
program_translator
.
enable_to_static
:
warning
s
.
warn
(
logging_util
s
.
warn
(
"The decorator 'dygraph_to_static_func' doesn't work in "
"The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode or set ProgramTranslator.enable to False. "
"dygraph mode or set ProgramTranslator.enable to False. "
"We will just return dygraph output."
)
"We will just return dygraph output."
)
...
@@ -215,7 +216,7 @@ def declarative(function=None, input_spec=None):
...
@@ -215,7 +216,7 @@ def declarative(function=None, input_spec=None):
if
isinstance
(
function
,
Layer
):
if
isinstance
(
function
,
Layer
):
if
isinstance
(
function
.
forward
,
StaticLayer
):
if
isinstance
(
function
.
forward
,
StaticLayer
):
class_name
=
function
.
__class__
.
__name__
class_name
=
function
.
__class__
.
__name__
warning
s
.
warn
(
logging_util
s
.
warn
(
"`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one."
.
"`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one."
.
format
(
class_name
))
format
(
class_name
))
function
.
forward
=
decorated
(
function
.
forward
)
function
.
forward
=
decorated
(
function
.
forward
)
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
6ef1fbb6
...
@@ -11229,7 +11229,7 @@ def shape(input):
...
@@ -11229,7 +11229,7 @@ def shape(input):
input.shape = [3, 2]
input.shape = [3, 2]
Args:
Args:
input (Variable): The input can be N-D Tensor or SelectedRows with data type float16, float32, float64, int32, int64.
input (Variable): The input can be N-D Tensor or SelectedRows with data type
bool,
float16, float32, float64, int32, int64.
If input variable is type of SelectedRows, returns the shape of it's inner tensor.
If input variable is type of SelectedRows, returns the shape of it's inner tensor.
Returns:
Returns:
...
@@ -11253,8 +11253,8 @@ def shape(input):
...
@@ -11253,8 +11253,8 @@ def shape(input):
print(res) # [array([ 3, 100, 100], dtype=int32)]
print(res) # [array([ 3, 100, 100], dtype=int32)]
"""
"""
check_variable_and_dtype(
check_variable_and_dtype(
input, 'input',
['float16', 'float32', 'float64', 'int32', 'int64'],
input, 'input',
'shape')
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'shape')
helper = LayerHelper('shape', **locals())
helper = LayerHelper('shape', **locals())
out = helper.create_variable_for_type_inference(dtype='int32')
out = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op(
helper.append_op(
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py
浏览文件 @
6ef1fbb6
...
@@ -56,8 +56,30 @@ class TestLoggingUtils(unittest.TestCase):
...
@@ -56,8 +56,30 @@ class TestLoggingUtils(unittest.TestCase):
with
self
.
assertRaises
(
TypeError
):
with
self
.
assertRaises
(
TypeError
):
paddle
.
jit
.
set_verbosity
(
3.3
)
paddle
.
jit
.
set_verbosity
(
3.3
)
def
test_code_level
(
self
):
def
test_also_to_stdout
(
self
):
logging_utils
.
_TRANSLATOR_LOGGER
.
need_to_echo_log_to_stdout
=
None
self
.
assertEqual
(
logging_utils
.
_TRANSLATOR_LOGGER
.
need_to_echo_log_to_stdout
,
False
)
paddle
.
jit
.
set_verbosity
(
also_to_stdout
=
False
)
self
.
assertEqual
(
logging_utils
.
_TRANSLATOR_LOGGER
.
need_to_echo_log_to_stdout
,
False
)
logging_utils
.
_TRANSLATOR_LOGGER
.
need_to_echo_node_to_stdout
=
None
self
.
assertEqual
(
logging_utils
.
_TRANSLATOR_LOGGER
.
need_to_echo_code_to_stdout
,
False
)
paddle
.
jit
.
set_code_level
(
also_to_stdout
=
True
)
self
.
assertEqual
(
logging_utils
.
_TRANSLATOR_LOGGER
.
need_to_echo_code_to_stdout
,
True
)
with
self
.
assertRaises
(
AssertionError
):
paddle
.
jit
.
set_verbosity
(
also_to_stdout
=
1
)
with
self
.
assertRaises
(
AssertionError
):
paddle
.
jit
.
set_code_level
(
also_to_stdout
=
1
)
def
test_set_code_level
(
self
):
paddle
.
jit
.
set_code_level
(
None
)
paddle
.
jit
.
set_code_level
(
None
)
os
.
environ
[
logging_utils
.
CODE_LEVEL_ENV_NAME
]
=
'2'
os
.
environ
[
logging_utils
.
CODE_LEVEL_ENV_NAME
]
=
'2'
self
.
assertEqual
(
logging_utils
.
get_code_level
(),
2
)
self
.
assertEqual
(
logging_utils
.
get_code_level
(),
2
)
...
@@ -71,7 +93,25 @@ class TestLoggingUtils(unittest.TestCase):
...
@@ -71,7 +93,25 @@ class TestLoggingUtils(unittest.TestCase):
with
self
.
assertRaises
(
TypeError
):
with
self
.
assertRaises
(
TypeError
):
paddle
.
jit
.
set_code_level
(
3.3
)
paddle
.
jit
.
set_code_level
(
3.3
)
def
test_log
(
self
):
def
test_log_api
(
self
):
# test api for CI Converage
logging_utils
.
set_verbosity
(
1
,
True
)
logging_utils
.
warn
(
"warn"
)
logging_utils
.
error
(
"error"
)
logging_utils
.
log
(
1
,
"log level 1"
)
logging_utils
.
log
(
2
,
"log level 2"
)
source_code
=
"x = 3"
ast_code
=
gast
.
parse
(
source_code
)
logging_utils
.
set_code_level
(
1
,
True
)
logging_utils
.
log_transformed_code
(
1
,
ast_code
,
"TestTransformer"
)
logging_utils
.
set_code_level
(
logging_utils
.
LOG_AllTransformer
,
True
)
logging_utils
.
log_transformed_code
(
logging_utils
.
LOG_AllTransformer
,
ast_code
,
"TestTransformer"
)
def
test_log_message
(
self
):
stream
=
io
.
BytesIO
()
if
six
.
PY2
else
io
.
StringIO
()
stream
=
io
.
BytesIO
()
if
six
.
PY2
else
io
.
StringIO
()
log
=
self
.
translator_logger
.
logger
log
=
self
.
translator_logger
.
logger
stdout_handler
=
logging
.
StreamHandler
(
stream
)
stdout_handler
=
logging
.
StreamHandler
(
stream
)
...
@@ -84,13 +124,14 @@ class TestLoggingUtils(unittest.TestCase):
...
@@ -84,13 +124,14 @@ class TestLoggingUtils(unittest.TestCase):
if
six
.
PY3
:
if
six
.
PY3
:
with
mock
.
patch
.
object
(
sys
,
'stdout'
,
stream
):
with
mock
.
patch
.
object
(
sys
,
'stdout'
,
stream
):
logging_utils
.
set_verbosity
(
1
,
False
)
logging_utils
.
warn
(
warn_msg
)
logging_utils
.
warn
(
warn_msg
)
logging_utils
.
error
(
error_msg
)
logging_utils
.
error
(
error_msg
)
self
.
translator_logger
.
verbosity_level
=
1
logging_utils
.
log
(
1
,
log_msg_1
)
logging_utils
.
log
(
1
,
log_msg_1
)
logging_utils
.
log
(
2
,
log_msg_2
)
logging_utils
.
log
(
2
,
log_msg_2
)
result_msg
=
'
\n
'
.
join
([
warn_msg
,
error_msg
,
log_msg_1
,
""
])
result_msg
=
'
\n
'
.
join
(
[
warn_msg
,
error_msg
,
"(Level 1) "
+
log_msg_1
,
""
])
self
.
assertEqual
(
result_msg
,
stream
.
getvalue
())
self
.
assertEqual
(
result_msg
,
stream
.
getvalue
())
def
test_log_transformed_code
(
self
):
def
test_log_transformed_code
(
self
):
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
0 → 100644
浏览文件 @
6ef1fbb6
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
inference_pass_test
import
InferencePassTest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.core
import
AnalysisConfig
"""Test for fusion of conv and bias."""
#padding SAME
class
ConvBiasMkldnnFusePassTest
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
conv_out
=
fluid
.
layers
.
conv2d
(
input
=
data
,
num_filters
=
3
,
filter_size
=
3
,
padding
=
"SAME"
,
bias_attr
=
param_attr
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
conv_out
]
self
.
enable_mkldnn
=
True
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
#padding VALID
class
ConvBiasMkldnnFusePassTest1
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
conv_out
=
fluid
.
layers
.
conv2d
(
input
=
data
,
num_filters
=
3
,
filter_size
=
3
,
padding
=
"VALID"
,
bias_attr
=
param_attr
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
conv_out
]
self
.
enable_mkldnn
=
True
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
#padding number
class
ConvBiasMkldnnFusePassTest2
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
conv_out
=
fluid
.
layers
.
conv2d
(
input
=
data
,
num_filters
=
3
,
filter_size
=
3
,
padding
=
[
2
,
4
,
6
,
8
],
bias_attr
=
param_attr
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
conv_out
]
self
.
enable_mkldnn
=
True
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
#dilation not supported yet, just print warning log and does not fuse
class
ConvBiasMkldnnFusePassTest3
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
conv_out
=
fluid
.
layers
.
conv2d
(
input
=
data
,
num_filters
=
3
,
filter_size
=
3
,
padding
=
"VALID"
,
dilation
=
2
,
groups
=
3
,
bias_attr
=
param_attr
,
use_cudnn
=
False
,
act
=
"softmax"
,
data_format
=
"NCHW"
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
conv_out
]
self
.
enable_mkldnn
=
True
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
#all conv params except for dilation
class
ConvBiasMkldnnFusePassTest4
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
conv_out
=
fluid
.
layers
.
conv2d
(
input
=
data
,
num_filters
=
3
,
filter_size
=
3
,
padding
=
"VALID"
,
groups
=
3
,
bias_attr
=
param_attr
,
use_cudnn
=
False
,
act
=
"softmax"
,
data_format
=
"NCHW"
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
conv_out
]
self
.
enable_mkldnn
=
True
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py
浏览文件 @
6ef1fbb6
...
@@ -26,7 +26,7 @@ def stable_softmax(x):
...
@@ -26,7 +26,7 @@ def stable_softmax(x):
return
exps
/
np
.
sum
(
exps
)
return
exps
/
np
.
sum
(
exps
)
def
log_softmax
(
x
,
axis
=
-
1
):
def
log_softmax
(
x
,
axis
=
1
):
softmax_out
=
np
.
apply_along_axis
(
stable_softmax
,
axis
,
x
)
softmax_out
=
np
.
apply_along_axis
(
stable_softmax
,
axis
,
x
)
return
np
.
log
(
softmax_out
)
return
np
.
log
(
softmax_out
)
...
...
python/paddle/fluid/tests/unittests/test_empty_like_op.py
0 → 100644
浏览文件 @
6ef1fbb6
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.data_feeder
import
convert_dtype
import
paddle.fluid.core
as
core
from
paddle.static
import
program_guard
,
Program
class
TestEmptyLikeAPICommon
(
unittest
.
TestCase
):
def
__check_out__
(
self
,
out
):
data_type
=
convert_dtype
(
out
.
dtype
)
self
.
assertEqual
(
data_type
,
self
.
dst_dtype
,
'dtype should be %s, but get %s'
%
(
self
.
dst_dtype
,
data_type
))
shape
=
out
.
shape
self
.
assertTupleEqual
(
shape
,
self
.
dst_shape
,
'shape should be %s, but get %s'
%
(
self
.
dst_shape
,
shape
))
if
data_type
in
[
'float32'
,
'float64'
,
'int32'
,
'int64'
]:
max_value
=
np
.
nanmax
(
out
)
min_value
=
np
.
nanmin
(
out
)
always_non_full_zero
=
max_value
>
min_value
always_full_zero
=
max_value
==
0.0
and
min_value
==
0.0
self
.
assertTrue
(
always_full_zero
or
always_non_full_zero
,
'always_full_zero or always_non_full_zero.'
)
elif
data_type
in
[
'bool'
]:
total_num
=
out
.
size
true_num
=
np
.
sum
(
out
==
True
)
false_num
=
np
.
sum
(
out
==
False
)
self
.
assertTrue
(
total_num
==
true_num
+
false_num
,
'The value should always be True or False.'
)
else
:
self
.
assertTrue
(
False
,
'invalid data type'
)
class
TestEmptyLikeAPI
(
TestEmptyLikeAPICommon
):
def
setUp
(
self
):
self
.
init_config
()
def
test_dygraph_api_out
(
self
):
paddle
.
disable_static
()
out
=
paddle
.
empty_like
(
self
.
x
,
self
.
dtype
)
self
.
__check_out__
(
out
.
numpy
())
paddle
.
enable_static
()
def
init_config
(
self
):
self
.
x
=
np
.
random
.
random
((
200
,
3
)).
astype
(
"float32"
)
self
.
dtype
=
self
.
x
.
dtype
self
.
dst_shape
=
self
.
x
.
shape
self
.
dst_dtype
=
self
.
dtype
class
TestEmptyLikeAPI2
(
TestEmptyLikeAPI
):
def
init_config
(
self
):
self
.
x
=
np
.
random
.
random
((
200
,
3
)).
astype
(
"float64"
)
self
.
dtype
=
self
.
x
.
dtype
self
.
dst_shape
=
self
.
x
.
shape
self
.
dst_dtype
=
self
.
dtype
class
TestEmptyLikeAPI3
(
TestEmptyLikeAPI
):
def
init_config
(
self
):
self
.
x
=
np
.
random
.
random
((
200
,
3
)).
astype
(
"int"
)
self
.
dtype
=
self
.
x
.
dtype
self
.
dst_shape
=
self
.
x
.
shape
self
.
dst_dtype
=
self
.
dtype
class
TestEmptyLikeAPI4
(
TestEmptyLikeAPI
):
def
init_config
(
self
):
self
.
x
=
np
.
random
.
random
((
200
,
3
)).
astype
(
"int64"
)
self
.
dtype
=
self
.
x
.
dtype
self
.
dst_shape
=
self
.
x
.
shape
self
.
dst_dtype
=
self
.
dtype
class
TestEmptyLikeAPI5
(
TestEmptyLikeAPI
):
def
init_config
(
self
):
self
.
x
=
np
.
random
.
random
((
200
,
3
)).
astype
(
"bool"
)
self
.
dtype
=
self
.
x
.
dtype
self
.
dst_shape
=
self
.
x
.
shape
self
.
dst_dtype
=
self
.
dtype
class
TestEmptyLikeAPI6
(
TestEmptyLikeAPI
):
def
init_config
(
self
):
self
.
x
=
np
.
random
.
random
((
200
,
3
)).
astype
(
"float64"
)
self
.
dtype
=
"float32"
self
.
dst_shape
=
self
.
x
.
shape
self
.
dst_dtype
=
self
.
dtype
class
TestEmptyLikeAPI7
(
TestEmptyLikeAPI
):
def
init_config
(
self
):
self
.
x
=
np
.
random
.
random
((
200
,
3
)).
astype
(
"int"
)
self
.
dtype
=
"float32"
self
.
dst_shape
=
self
.
x
.
shape
self
.
dst_dtype
=
self
.
dtype
class
TestEmptyLikeAPI8
(
TestEmptyLikeAPI
):
def
init_config
(
self
):
self
.
x
=
np
.
random
.
random
((
200
,
3
)).
astype
(
"int64"
)
self
.
dtype
=
"float32"
self
.
dst_shape
=
self
.
x
.
shape
self
.
dst_dtype
=
self
.
dtype
class
TestEmptyLikeAPI9
(
TestEmptyLikeAPI
):
def
init_config
(
self
):
self
.
x
=
np
.
random
.
random
((
200
,
3
)).
astype
(
"bool"
)
self
.
dtype
=
"float32"
self
.
dst_shape
=
self
.
x
.
shape
self
.
dst_dtype
=
self
.
dtype
class
TestEmptyLikeAPI10
(
TestEmptyLikeAPI
):
def
init_config
(
self
):
self
.
x
=
np
.
random
.
random
((
200
,
3
)).
astype
(
"float32"
)
self
.
dtype
=
"bool"
self
.
dst_shape
=
self
.
x
.
shape
self
.
dst_dtype
=
self
.
dtype
class
TestEmptyLikeAPI_Static
(
TestEmptyLikeAPICommon
):
def
setUp
(
self
):
self
.
init_config
()
def
test_static_graph
(
self
):
dtype
=
'float32'
train_program
=
Program
()
startup_program
=
Program
()
with
program_guard
(
train_program
,
startup_program
):
x
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
dtype
)
data_x
=
paddle
.
static
.
data
(
'x'
,
shape
=
self
.
data_x_shape
,
dtype
=
dtype
)
out
=
paddle
.
empty_like
(
data_x
)
place
=
paddle
.
CUDAPlace
(
0
)
if
core
.
is_compiled_with_cuda
(
)
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
res
=
exe
.
run
(
train_program
,
feed
=
{
'x'
:
x
},
fetch_list
=
[
out
])
self
.
dst_dtype
=
dtype
self
.
dst_shape
=
x
.
shape
self
.
__check_out__
(
res
[
0
])
def
init_config
(
self
):
self
.
x_shape
=
(
200
,
3
)
self
.
data_x_shape
=
[
200
,
3
]
class
TestEmptyLikeAPI_Static2
(
TestEmptyLikeAPI_Static
):
def
init_config
(
self
):
self
.
x_shape
=
(
3
,
200
,
3
)
self
.
data_x_shape
=
[
-
1
,
200
,
3
]
class
TestEmptyError
(
unittest
.
TestCase
):
def
test_attr
(
self
):
def
test_dtype
():
x
=
np
.
random
.
random
((
200
,
3
)).
astype
(
"float64"
)
dtype
=
'uint8'
result
=
paddle
.
empty_like
(
x
,
dtype
=
dtype
)
self
.
assertRaises
(
TypeError
,
test_dtype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/nn/functional/loss.py
浏览文件 @
6ef1fbb6
...
@@ -1093,7 +1093,7 @@ def cross_entropy(input,
...
@@ -1093,7 +1093,7 @@ def cross_entropy(input,
" 'none', but received %s, which is not allowed."
%
reduction
)
" 'none', but received %s, which is not allowed."
%
reduction
)
#step 1. log_softmax
#step 1. log_softmax
log_softmax_out
=
paddle
.
nn
.
functional
.
log_softmax
(
input
)
log_softmax_out
=
paddle
.
nn
.
functional
.
log_softmax
(
input
,
axis
=
1
)
if
weight
is
not
None
and
not
isinstance
(
weight
,
Variable
):
if
weight
is
not
None
and
not
isinstance
(
weight
,
Variable
):
raise
ValueError
(
raise
ValueError
(
"The weight' is not a Variable, please convert to Variable."
)
"The weight' is not a Variable, please convert to Variable."
)
...
...
python/paddle/tensor/__init__.py
浏览文件 @
6ef1fbb6
...
@@ -41,6 +41,7 @@ from .creation import triu #DEFINE_ALIAS
...
@@ -41,6 +41,7 @@ from .creation import triu #DEFINE_ALIAS
from
.creation
import
tril
#DEFINE_ALIAS
from
.creation
import
tril
#DEFINE_ALIAS
from
.creation
import
meshgrid
#DEFINE_ALIAS
from
.creation
import
meshgrid
#DEFINE_ALIAS
from
.creation
import
empty
#DEFINE_ALIAS
from
.creation
import
empty
#DEFINE_ALIAS
from
.creation
import
empty_like
#DEFINE_ALIAS
from
.io
import
save
#DEFINE_ALIAS
from
.io
import
save
#DEFINE_ALIAS
from
.io
import
load
#DEFINE_ALIAS
from
.io
import
load
#DEFINE_ALIAS
from
.linalg
import
matmul
#DEFINE_ALIAS
from
.linalg
import
matmul
#DEFINE_ALIAS
...
...
python/paddle/tensor/creation.py
浏览文件 @
6ef1fbb6
...
@@ -49,6 +49,7 @@ __all__ = [
...
@@ -49,6 +49,7 @@ __all__ = [
'full'
,
'full'
,
'full_like'
,
'full_like'
,
'empty'
,
'empty'
,
'empty_like'
,
'triu'
,
'triu'
,
'tril'
,
'tril'
,
'meshgrid'
'meshgrid'
...
@@ -1068,3 +1069,70 @@ def empty(shape, dtype=None, name=None):
...
@@ -1068,3 +1069,70 @@ def empty(shape, dtype=None, name=None):
stop_gradient
=
True
)
stop_gradient
=
True
)
out
.
stop_gradient
=
True
out
.
stop_gradient
=
True
return
out
return
out
def
empty_like
(
x
,
dtype
=
None
,
name
=
None
):
"""
This Op returns a Tensor with uninitialized data which has identical shape of ``x`` and ``dtype``.
If the ``dtype`` is None, the data type of Tensor is same with ``x``.
Args:
x(Tensor): The input tensor which specifies shape and data type. The data type can be bool, float16, float32, float64, int32, int64.
dtype(np.dtype|str, optional): The data type of output. The data type can be one
of bool, float16, float32, float64, int32, int64. The default value is None, which means the output
data type is the same as input.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: Tensor which is created according to ``x`` and ``dtype``, and is uninitialized.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static() # Now we are in imperative mode
paddle.set_device("cpu") # and use cpu device
x = paddle.randn([2, 3], 'float32')
output = paddle.empty_like(x)
#[[1.8491974e+20 1.8037303e+28 1.7443726e+28] # uninitialized
# [4.9640171e+28 3.0186127e+32 5.6715899e-11]] # uninitialized
"""
if
dtype
is
None
:
dtype
=
x
.
dtype
dtype
=
convert_dtype
(
dtype
)
if
in_dygraph_mode
():
out
=
core
.
ops
.
empty
(
'shape'
,
x
.
shape
,
'dtype'
,
convert_np_dtype_to_dtype_
(
dtype
))
out
.
stop_gradient
=
True
return
out
helper
=
LayerHelper
(
"empty_like"
,
**
locals
())
check_variable_and_dtype
(
x
,
'x'
,
[
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'empty_like'
)
check_dtype
(
dtype
,
'dtype'
,
[
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'empty_like'
)
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
inputs
=
{}
attrs
=
{}
attrs
[
'dtype'
]
=
convert_np_dtype_to_dtype_
(
dtype
)
shape
=
paddle
.
shape
(
x
)
utils
.
get_shape_tensor_inputs
(
inputs
=
inputs
,
attrs
=
attrs
,
shape
=
shape
,
op_type
=
'empty_like'
)
helper
.
append_op
(
type
=
'empty'
,
inputs
=
inputs
,
outputs
=
{
'Out'
:
[
out
]},
attrs
=
attrs
,
stop_gradient
=
True
)
out
.
stop_gradient
=
True
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录