diff --git a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc index 7612df9ab915011001b57b37e4fd559d393302a7..3f88a460d140f6d7389194a29e37128f5ba5b458 100644 --- a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc @@ -19,6 +19,7 @@ #include #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -334,3 +335,8 @@ void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { REGISTER_PASS(embedding_eltwise_layernorm_fuse_pass, paddle::framework::ir::EmbeddingEltwiseLayerNormFusePass); +REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("lookup_table", 0) + .EQ("elementweise_add", 0)); diff --git a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc index 71c9dbae1a46af1ecae0aaff3fde52de8142d4bb..727e42629f9fab9183668ae0cc84ae54eb01982c 100644 --- a/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc @@ -16,12 +16,13 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { namespace ir { -TEST(SkipLayerNormFusePass, basic) { +TEST(EmbeddingElewiseLayernormFusePass, basic) { // inputs operator output // -------------------------------------------------------------------- // (x, y) elementwise_add -> elementwise_out @@ -91,6 +92,12 @@ TEST(SkipLayerNormFusePass, basic) { "The number of fusion nodes does not meet expectations after fuse")); } +TEST(EmbeddingElewiseLayernormFusePass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("embedding_eltwise_layernorm_fuse_pass")); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 82e0af3c198750296032769f2f3b04658871adb7..f7a8e3e3f6c3c77e978c57eeb7515d8cfce86471 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -17,6 +17,7 @@ #include #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -84,6 +85,19 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "do not perform " + type() + "+bias fuse"; return; } + if (conv->Op()->HasAttr("dilations")) { + auto dilations = + BOOST_GET_CONST(std::vector, conv->Op()->GetAttr("dilations")); + for (const auto& d : dilations) { + if (d != 1) { + LOG(WARNING) + << "dilation conv not supported in MKLDNN, fuse not apply " + << "and set conv attribute use_mkldnn = false"; + conv->Op()->SetAttr("use_mkldnn", false); + return; + } + } + } auto* eltwise_bias_tensor = scope->FindVar(eltwise_bias->Name())->GetMutable(); @@ -151,3 +165,8 @@ 
REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass, paddle::framework::ir::Conv2DTransposeBiasFusePass); REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass, paddle::framework::ir::Conv3DBiasFusePass); +REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("elementwise_add", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc index 88aac001a93ae836d62fe3bf3fc502960eebe70f..455350d2f703c52a9ef3e5714a60573408310080 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/imperative/type_defs.h" namespace paddle { @@ -149,6 +150,12 @@ TEST(ConvBiasFusePass, conv2d_transpose) { ASSERT_EQ(pass.type(), std::string("conv2d_transpose")); } +TEST(ConvBiasFusePass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("conv_bias_mkldnn_fuse_pass")); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index af2b1308e084ee937f26bf90caf2df6fb44e044b..2fb131aceaad28a365e8202dca35cfe53f8f54da 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -19,6 +19,7 @@ #include #include #include "paddle/fluid/framework/ir/graph_traits.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -341,3 +342,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, paddle::framework::ir::ResidualConnectionMKLDNNFusePass); +REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("elementwise_add", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc index 8a13596cd50087475bf12b6cfa5920b82e24de31..fd4910fc8e95cd98fe9feaba51e70d5a143ad443 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -17,6 +17,7 @@ #include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -267,6 +268,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) { AssertOpsCount(graph, 2, 1); } +TEST(ConvElementwiseAddMKLDNNFusePass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("conv_elementwise_add_mkldnn_fuse_pass")); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git 
a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc index c5965701a53d4312d89f1e09f17840b09f1bd5f5..df5ba3314e637fefe930d4c45f431314dd7d8493 100644 --- a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -57,3 +58,7 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const { REGISTER_PASS(depthwise_conv_mkldnn_pass, paddle::framework::ir::DepthwiseConvMKLDNNPass); +REGISTER_PASS_CAPABILITY(depthwise_conv_mkldnn_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "depthwise_conv2d", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc index a37565236cd440bd803184d038ad4deb3c0b6150..c6c72ba33d6295d90c502ab88d7d712d76a11aad 100644 --- a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc @@ -16,6 +16,8 @@ #include +#include "paddle/fluid/framework/op_version_registry.h" + namespace paddle { namespace framework { namespace ir { @@ -70,6 +72,12 @@ ProgramDesc BuildProgramDesc() { return prog; } +TEST(DepthwiseConvMKLDNNPass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("depthwise_conv_mkldnn_pass")); +} + TEST(DepthwiseConvMKLDNNPass, basic) { auto prog = BuildProgramDesc(); diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index 198107ea082dc86d9e65a926bf9befe2fc4abfa4..9d2b4ebaf8ccf33e175e46c08657e7eeed467055 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -19,6 +19,7 @@ #include #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/errors.h" namespace paddle { @@ -707,3 +708,13 @@ REGISTER_PASS(multihead_matmul_fuse_pass, REGISTER_PASS(multihead_matmul_fuse_pass_v2, paddle::framework::ir::MultiHeadMatmulV2FusePass); +REGISTER_PASS_CAPABILITY(multihead_matmul_fuse_pass_v2) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("mul", 0) + .EQ("elementwise_add", 0) + .EQ("reshape2", 0) + .EQ("transpose2", 0) + .EQ("scale", 0) + .EQ("matmul", 0) + .EQ("softmax", 0)); diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc index d8a06b037bdefbe8776c9b95b36be80afb988393..2eda643d4e53aa061908f02c9d31b765241c318b 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc @@ -12,6 +12,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT #include #include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -133,6 +134,12 @@ TEST(MultiHeadMatmulFusePass, basic) { num_fused_nodes_after)); } +TEST(MultiHeadMatmulFusePass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("multihead_matmul_fuse_pass_v2")); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc index 9dddc9154f8fc39144b38535824999b933a92106..2e3cd16d5ce49fdd6186f98c72d77c75c4053559 100644 --- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -180,3 +181,8 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { REGISTER_PASS(skip_layernorm_fuse_pass, paddle::framework::ir::SkipLayerNormFusePass); +REGISTER_PASS_CAPABILITY(skip_layernorm_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("elementwise_add", 0) + .EQ("layer_norm", 0)); diff --git a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc index d2d7469872857a070294520a589fee4ca383f065..eff5dcddf54ee49be5b14a7bdfa609079f925036 100644 --- a/paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace framework { @@ -54,6 +55,12 @@ TEST(SkipLayerNormFusePass, basic) { "The number of fusion nodes does not meet expectations after fuse")); } +TEST(SkipLayerNormFusePass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("skip_layernorm_fuse_pass")); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/average_accumulates_op.h b/paddle/fluid/operators/average_accumulates_op.h index 3958d3f685470f2505abf0e8bfd269d3834970ae..338e46111fca83230aca1c7877578e557cef5a31 100644 --- a/paddle/fluid/operators/average_accumulates_op.h +++ b/paddle/fluid/operators/average_accumulates_op.h @@ -54,9 +54,13 @@ class AverageAccumulatesKernel : public framework::OpKernel { float average_window = ctx.Attr("average_window"); int64_t max_average_window = ctx.Attr("max_average_window"); int64_t min_average_window = ctx.Attr("min_average_window"); - PADDLE_ENFORCE_LE(min_average_window, max_average_window, - "min_average_window shouldn't be larger than " - "max_average_window"); + PADDLE_ENFORCE_LE( + min_average_window, max_average_window, + platform::errors::InvalidArgument( + "The min_average_window > " + "max_average_window is not right, min_average_window is %ld, " + "max_average_window is %ld.", + min_average_window, max_average_window)); // Get inputs auto* param = ctx.Input("param"); diff --git a/paddle/fluid/operators/empty_op.cc b/paddle/fluid/operators/empty_op.cc index f539e2e6f6d2d6faa084d1e62ec894b4b65e96bf..3d28ca90a5a15fd53a57034a4722a21842dc4b1c 100644 --- a/paddle/fluid/operators/empty_op.cc +++ b/paddle/fluid/operators/empty_op.cc @@ -55,31 +55,38 @@ class EmptyOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "empty"); if (context->HasInput("ShapeTensor")) { - auto dims = context->GetInputDim("ShapeTensor"); + auto shape_dims = context->GetInputDim("ShapeTensor"); int num_ele = 1; - for (int i = 0; i < dims.size(); ++i) { - num_ele *= dims[i]; + for (int i = 0; i < shape_dims.size(); ++i) { + num_ele *= shape_dims[i]; } - - context->SetOutputDim("Out", framework::make_ddim({num_ele})); + auto vec_dims = std::vector(num_ele, -1); + context->SetOutputDim("Out", framework::make_ddim(vec_dims)); } else if (context->HasInputs("ShapeTensorList")) { std::vector out_dims; auto dims_list = context->GetInputsDim("ShapeTensorList"); for (size_t i = 0; i < dims_list.size(); ++i) { auto& dims = dims_list[i]; - PADDLE_ENFORCE_EQ( - dims, framework::make_ddim({1}), - "ShapeError: The shape of Tensor in list must be [1]. " - "But received the shape " - "is [%s]", - dims); - - out_dims.push_back(dims[0]); + PADDLE_ENFORCE_EQ(dims, framework::make_ddim({1}), + platform::errors::InvalidArgument( + "The shape of Tensor in list must be [1]. " + "But received the shape is [%s]", + dims)); + + out_dims.push_back(-1); } context->SetOutputDim("Out", framework::make_ddim(out_dims)); } else { auto& shape = context->Attrs().Get>("shape"); + for (size_t i = 0; i < shape.size(); ++i) { + PADDLE_ENFORCE_GE( + shape[i], 0, + platform::errors::InvalidArgument( + "Each value of attribute 'shape' is expected to be no less " + "than 0. 
But recieved: shape[%u] = %d; shape = [%s].", + i, shape[i], framework::make_ddim(shape))); + } context->SetOutputDim("Out", framework::make_ddim(shape)); } } diff --git a/paddle/fluid/operators/math/beam_search.cc b/paddle/fluid/operators/math/beam_search.cc index 0155ef188ef967fbf67505d28beeeaf956bb3a70..550de1aadde2935fae34226dba78cc06d82cd1f3 100644 --- a/paddle/fluid/operators/math/beam_search.cc +++ b/paddle/fluid/operators/math/beam_search.cc @@ -87,7 +87,10 @@ class BeamSearchFunctor { lod[0].assign(high_level.begin(), high_level.end()); lod[1].assign(low_level.begin(), low_level.end()); if (!framework::CheckLoD(lod)) { - PADDLE_THROW("lod %s is not right", framework::LoDToString(lod)); + PADDLE_THROW(platform::errors::InvalidArgument( + "lod %s is not right in" + " beam_search, please check your code.", + framework::LoDToString(lod))); } selected_ids->set_lod(lod); selected_scores->set_lod(lod); diff --git a/paddle/fluid/operators/math/beam_search.cu b/paddle/fluid/operators/math/beam_search.cu index cf6d44c1abc531da9d00738bba22f70a4c68bbab..ed3ead47d171efb4128a294c7d7a24324c7187b7 100644 --- a/paddle/fluid/operators/math/beam_search.cu +++ b/paddle/fluid/operators/math/beam_search.cu @@ -400,7 +400,10 @@ class BeamSearchFunctor { context.Wait(); if (!framework::CheckLoD(selected_lod)) { - PADDLE_THROW("lod %s is not right", framework::LoDToString(selected_lod)); + PADDLE_THROW(platform::errors::InvalidArgument( + "lod %s is not right in" + " beam_search, please check your code.", + framework::LoDToString(selected_lod))); } selected_ids->set_lod(selected_lod); diff --git a/paddle/fluid/operators/math/blas.cc b/paddle/fluid/operators/math/blas.cc index 6a143b3c056455595fdedc131b0c5f4ee756e1e0..2a7ce83967f0f74f4c2178dd4277e6a1687b5ec7 100644 --- a/paddle/fluid/operators/math/blas.cc +++ b/paddle/fluid/operators/math/blas.cc @@ -20,7 +20,11 @@ namespace operators { namespace math { MatDescriptor CreateMatrixDescriptor(const framework::DDim &tensor_dim, int num_flatten_cols, bool trans) { - PADDLE_ENFORCE_GT(tensor_dim.size(), 1); + PADDLE_ENFORCE_GT( + tensor_dim.size(), 1, + platform::errors::InvalidArgument("The tensor dim size should be greater " + "than 1, but reveived dim size is %d", + tensor_dim.size())); MatDescriptor retv; if (num_flatten_cols > 1) { auto flatten_dim = framework::flatten_to_2d(tensor_dim, num_flatten_cols); diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index d0c5f74d4efb8248b41d8b2af285e8dd7ec4d479..a0464cf70e2dcc44c42fc2ca7440680ef8a53e6e 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -60,7 +60,8 @@ struct CUBlas { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cublasSgemmStridedBatched(args...)); #else - PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5"); + PADDLE_THROW(platform::errors::Unimplemented( + "SgemmStridedBatched is not supported on cuda <= 7.5")); #endif } @@ -85,7 +86,8 @@ struct CUBlas { beta, C, Ctype, ldc)); }); #else - PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); + PADDLE_THROW(platform::errors::Unimplemented( + "cublasSgemmEx is not supported on cuda <= 7.5")); #endif } @@ -146,13 +148,15 @@ struct CUBlas { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cublasDgemmStridedBatched(args...)); #else - PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5"); + PADDLE_THROW(platform::errors::Unimplemented( + "DgemmStridedBatched is not supported on cuda <= 7.5")); #endif } template 
static void GEMM_EX(ARGS... args) { - PADDLE_THROW("Currently there are not cublasDgemmEx."); + PADDLE_THROW(platform::errors::Unimplemented( + "Currently there are not cublasDgemmEx.")); } template @@ -216,7 +220,8 @@ struct CUBlas { reinterpret_cast(beta), reinterpret_cast<__half *>(C), ldc, strideC, batchCount)); #else - PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5"); + PADDLE_THROW(platform::errors::Unimplemented( + "HgemmStridedBatched is not supported on cuda <= 7.5")); #endif } @@ -247,7 +252,8 @@ struct CUBlas { beta, C, Ctype, ldc, computeType, algo)); }); #else - PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); + PADDLE_THROW(platform::errors::Unimplemented( + "cublasGemmEx is not supported on cuda <= 7.5")); #endif } }; @@ -302,8 +308,12 @@ inline void Blas::GEMM( (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53, - "cublas fp16 gemm requires GPU compute capability >= 53"); + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), 53, + platform::errors::InvalidArgument( + "cublas fp16 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); float h_alpha = static_cast(alpha); float h_beta = static_cast(beta); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 892bf15738141bfbb7e75fa6b37c0cda53a8e098..515d6a2435e86fe07ffe1309628ef2fbeefdc6f0 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -29,7 +29,8 @@ template <> struct CBlas { template static void VCOPY(ARGS... args) { - PADDLE_THROW("Blas VCOPY don't support int8_t"); + PADDLE_THROW(platform::errors::Unimplemented( + "Blas VCOPY do not supported on CPU, please check your code")); } }; @@ -347,22 +348,47 @@ struct CBlas { template <> struct CBlas { - static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); } + static void GEMM(...) { + PADDLE_THROW(platform::errors::Unimplemented( + "float16 GEMM not supported on CPU, please check your code")); + } + static void SMM_GEMM(...) { - PADDLE_THROW("float16 SMM_GEMM not supported on CPU"); + PADDLE_THROW(platform::errors::Unimplemented( + "float16 SMM_GEMM not supported on CPU, please check your code")); } - static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } - static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); } - static void VSQUARE(...) { - PADDLE_THROW("float16 VSQUARE not supported on CPU"); + static void VMUL(...) { + PADDLE_THROW(platform::errors::Unimplemented( + "float16 VMUL not supported on CPU, please check your code")); } - static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); } - static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; - static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; - static void ASUM(...) { PADDLE_THROW("float16 ASUM not supported on CPU"); }; + static void VEXP(...) { + PADDLE_THROW(platform::errors::Unimplemented( + "float16 VEXP not supported on CPU, please check your code")); + } + static void VSQUARE(...) { + PADDLE_THROW(platform::errors::Unimplemented( + "float16 VSQUARE not supported on CPU, please check your code")); + } + static void VPOW(...) { + PADDLE_THROW(platform::errors::Unimplemented( + "float16 VPOW not supported on CPU, please check your code")); + } + static void DOT(...) 
{ + PADDLE_THROW(platform::errors::Unimplemented( + "float16 DOT not supported on CPU, please check your code")); + }; + static void SCAL(...) { + PADDLE_THROW(platform::errors::Unimplemented( + "float16 SCAL not supported on CPU, please check your code")); + }; + static void ASUM(...) { + PADDLE_THROW(platform::errors::Unimplemented( + "float16 ASUM not supported on CPU, please check your code")); + }; #ifdef PADDLE_WITH_MKLML static void GEMM_BATCH(...) { - PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); + PADDLE_THROW(platform::errors::Unimplemented( + "float16 GEMM_BATCH not supported on CPU, please check your code")); } #endif }; @@ -446,11 +472,18 @@ void Blas::MatMul(const framework::Tensor &mat_a, bool trans_a, auto dim_a = mat_a.dims(); auto dim_b = mat_b.dims(); auto dim_out = mat_out->dims(); - PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, - "The input and output of matmul be matrix"); - PADDLE_ENFORCE( - mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), - "The places of matrices must be same"); + PADDLE_ENFORCE_EQ( + dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, true, + platform::errors::InvalidArgument( + "The input and output of matmul should be matrix, the dim size must " + "be 2," + "but received dim size input_a:%d, input_b:%d, output:%d", + dim_a.size(), dim_b.size(), dim_out.size())); + PADDLE_ENFORCE_EQ( + mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), true, + platform::errors::InvalidArgument("The places of matrices in the matmul " + "should be same, please check your " + "code.")); int M = dim_out[0]; int N = dim_out[1]; @@ -715,7 +748,13 @@ void Blas::BatchedGEMMWithHead( } } else { - PADDLE_ENFORCE_EQ(W1, H2); + PADDLE_ENFORCE_EQ( + W1, H2, + platform::errors::InvalidArgument( + "The fisrt matrix width should be same as second matrix height," + "but received fisrt matrix width %d" + ", second matrix height %d", + W1, H2)); int ldc = W2 * head_number; int sub_width = W1 / head_number; @@ -785,7 +824,14 @@ void Blas::MatMul(const framework::Tensor &mat_a, const framework::Tensor &mat_b, const MatDescriptor &dim_b, T alpha, framework::Tensor *mat_out, T beta) const { - PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_); + PADDLE_ENFORCE_EQ( + dim_a.width_, dim_b.height_, + platform::errors::InvalidArgument( + "The fisrt matrix width should be same as second matrix height," + "but received fisrt matrix width %d" + ", second matrix height %d", + dim_a.width_, dim_b.height_)); + CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { @@ -793,12 +839,14 @@ void Blas::MatMul(const framework::Tensor &mat_a, dim_a.width_, alpha, mat_a.data(), mat_b.data(), beta, mat_out->data()); } else { - PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ || - dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0, - "dim_a.batch_size should be equal to dim_b.batch_size, or " - "one of dim_a.batch_size and dim_b.batch_size should be 0. " - "But got dim_a.batch_size = %d, dim_b.batch_size = %d.", - dim_a.batch_size_, dim_b.batch_size_); + PADDLE_ENFORCE_EQ( + dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || + dim_b.batch_size_ == 0, + true, platform::errors::InvalidArgument( + "dim_a.batch_size should be equal to dim_b.batch_size, or " + "one of dim_a.batch_size and dim_b.batch_size should be 0. 
" + "But got dim_a.batch_size = %d, dim_b.batch_size = %d.", + dim_a.batch_size_, dim_b.batch_size_)); this->template BatchedGEMM( transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, mat_a.data(), mat_b.data(), beta, mat_out->data(), @@ -834,15 +882,42 @@ void Blas::MatMulWithHead(const framework::Tensor &mat_a, int head_number, framework::Tensor *mat_out, T beta, bool mat_b_split_vertical) const { - PADDLE_ENFORCE_EQ(dim_a.width_ % head_number, 0); - PADDLE_ENFORCE_GE(head_number, 1); - PADDLE_ENFORCE_LE(head_number, dim_a.width_); + PADDLE_ENFORCE_EQ( + dim_a.width_ % head_number, 0, + platform::errors::InvalidArgument( + "The first input width must be some times the head number" + "but received first input width %d" + ", head_number %d", + dim_a.width_, head_number)); + PADDLE_ENFORCE_GE(head_number, 1, + platform::errors::InvalidArgument( + "The head number should be greater equal 1," + "but received head number %d", + head_number)); + PADDLE_ENFORCE_LE( + head_number, dim_a.width_, + platform::errors::InvalidArgument( + "The head number should be less equal first input width," + "but received first input width %d" + ", head_number %d", + dim_a.width_, head_number)); CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; if (mat_b_split_vertical) { - PADDLE_ENFORCE_EQ(dim_b.height_, dim_a.width_ / head_number); - PADDLE_ENFORCE_EQ(dim_b.width_ % head_number, 0); + PADDLE_ENFORCE_EQ( + dim_b.height_, dim_a.width_ / head_number, + platform::errors::InvalidArgument( + "The second input height should be equal than first input width," + "but received second input height %d, first input width %d", + dim_b.height_, dim_a.width_ / head_number)); + PADDLE_ENFORCE_EQ( + dim_a.width_ % head_number, 0, + platform::errors::InvalidArgument( + "The second input width should be some times the head number" + "but received second input width %d" + ", head_number %d", + dim_b.width_, head_number)); } if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { @@ -888,9 +963,16 @@ void Blas::MatMulWithHead(const framework::Tensor &mat_a, mat_out->data() + sub_matC_offset, ldc); } } else { - PADDLE_ENFORCE_EQ((dim_a.batch_size_ == dim_b.batch_size_ || - dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0), - true); + PADDLE_ENFORCE_EQ( + (dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || + dim_b.batch_size_ == 0), + true, + platform::errors::InvalidArgument( + "The first input batch size should be equal than second input," + "either two input batch size is 0, but received first input batch " + "size" + " %d, second input batch size %d", + dim_a.batch_size_, dim_b.batch_size_)); this->template BatchedGEMMWithHead( transA, transB, dim_a.width_, dim_a.height_, dim_b.width_, diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc index 62bffe630484e3ab30bedcf2324f6516bca3b27e..0ecf9bfb5d8c0ccadbdd8b7a0b8f6193d4dc5310 100644 --- a/paddle/fluid/operators/shape_op.cc +++ b/paddle/fluid/operators/shape_op.cc @@ -68,6 +68,6 @@ REGISTER_OPERATOR( shape, ops::ShapeOp, ops::ShapeOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel, ops::ShapeKernel, +REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel, ops::ShapeKernel, ops::ShapeKernel, ops::ShapeKernel, ops::ShapeKernel); diff --git a/paddle/fluid/operators/shape_op.cu b/paddle/fluid/operators/shape_op.cu index 
4b9dca0d4028be36ad8ba46ebe35db101e003ee9..5d50b17818cbb8068db2fded1f3f4e76bad44430 100644 --- a/paddle/fluid/operators/shape_op.cu +++ b/paddle/fluid/operators/shape_op.cu @@ -15,8 +15,8 @@ limitations under the License. */ #include "paddle/fluid/operators/shape_op.h" REGISTER_OP_CUDA_KERNEL( - shape, paddle::operators::ShapeKernel, - paddle::operators::ShapeKernel, + shape, paddle::operators::ShapeKernel, + paddle::operators::ShapeKernel, paddle::operators::ShapeKernel, paddle::operators::ShapeKernel, paddle::operators::ShapeKernel, diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 016726633ea355ed20149e94833ca7e1657c3f7d..661471599cb080da7a65c11fecc339830f2c00ee 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -77,6 +77,7 @@ from .tensor.creation import triu #DEFINE_ALIAS from .tensor.creation import tril #DEFINE_ALIAS from .tensor.creation import meshgrid #DEFINE_ALIAS from .tensor.creation import empty #DEFINE_ALIAS +from .tensor.creation import empty_like #DEFINE_ALIAS from .tensor.linalg import matmul #DEFINE_ALIAS from .tensor.linalg import dot #DEFINE_ALIAS # from .tensor.linalg import einsum #DEFINE_ALIAS diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 8d399c929018f08eb3d02e50981566705536bbf5..7b276293638189d304e5c33b2cd4497bb4256bab 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -67,6 +67,7 @@ class ImperativeQuantAware(object): Examples: .. code-block:: python + import paddle from paddle.fluid.contrib.slim.quantization \ import ImperativeQuantAware from paddle.vision.models \ @@ -86,13 +87,12 @@ class ImperativeQuantAware(object): # ... # Save quant model for the inference. - imperative_qat.save_quantized_model( - dirname="./resnet50_qat", - model=model, - input_shape=[(3, 224, 224)], - input_dtype=['float32'], - feed=[0], - fetch=[0]) + paddle.jit.save( + layer=model, + model_path="./resnet50_qat", + input_spec=[ + paddle.static.InputSpec( + shape=[None, 3, 224, 224], dtype='float32')]) """ super(ImperativeQuantAware, self).__init__() self._weight_bits = weight_bits @@ -148,75 +148,6 @@ class ImperativeQuantAware(object): quant_layer = self._get_quantized_counterpart(layer) setattr(obj, target, quant_layer) - def save_quantized_model(self, - dirname, - model, - input_shape, - input_dtype, - feed, - fetch, - append_batch_size=True): - """ - Save the quantized model for the inference. - - Args: - dirname (str): the directory to save the quantized model. - model(fluid.dygraph.Layer): the quantized model to be saved. - input_shape(list[tuple(int)]): The shape value for each input, - e.g. [(3, 224, 224)]. - input_dtype(list[str]): The dtype value for each input, - e.g. ['float32']. - feed(list[int]): the indices of the input variables of the - imperative functions which will be saved as input variables in - inference model. - fetch(list[int]): the indices of the returned variable of the - imperative functions which will be saved as output variables in - inference model. - append_batch_size(bool, optional): - If true, it prepends an extra axis to the input_shape, meanwhile, - the input_shape shouldn't contain the batch size dimension. - Otherwise, it just uses the input_shape. Default True. - Returns: - None - """ - assert isinstance( - input_shape, list), "The parameter `input_shape` shoubld be a list." 
- assert isinstance( - input_dtype, list), "The parameter `input_dtype` shoubld be a list." - assert isinstance(feed, list), "The parameter `feed` shoubld be a list." - assert isinstance(fetch, - list), "The parameter `fetch` shoubld be a list." - assert len(input_shape) == len( - input_dtype - ), "The length of input_shape should be equal to input_dtype's." - assert len(input_dtype) == len( - feed), "The length of input_shape should be equal to feed's." - - with dygraph.guard(): - model.eval() - input_vars = [] - for i, (shape, dtype) in enumerate(zip(input_shape, input_dtype)): - if append_batch_size: - shape = [None] + list(shape) - # Note(Aurelius84): need a elegant way to name this. - in_spec = paddle.static.InputSpec(shape, dtype, 'feed_%d' % i) - input_vars.append(in_spec) - # use `declarative` to convert dygraph into static program - model.forward = dygraph.jit.declarative( - model.forward, input_spec=input_vars) - outputs = model.forward.concrete_program.outputs - input_spec = [input_vars[i] for i in feed] - configs = dygraph.jit.SaveLoadConfig() - configs.separate_params = True - if not isinstance(outputs, (tuple, list)): - outputs = [outputs] - configs.output_spec = [outputs[i] for i in fetch] - dygraph.jit.save( - layer=model, - model_path=dirname, - input_spec=input_spec, - configs=configs) - def _get_quantized_counterpart(self, layer): quant_layers = tuple(self._quant_layers_map.values()) quantized_counterpart = tuple('Quantized' + k diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py index 79b0bbd6a4dd3850f49aa0b5124e9be86d4e6ee3..f076d274b643367a2703910dfa6899c5bfd1317c 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py @@ -221,7 +221,7 @@ class TestImperativeQat(unittest.TestCase): model_dict = lenet.state_dict() fluid.save_dygraph(model_dict, "save_temp") - # test the correctness of `save_quantized_model` + # test the correctness of `paddle.jit.save` data = next(test_reader()) test_data = np.array([x[0].reshape(1, 28, 28) for x in data]).astype('float32') @@ -231,13 +231,14 @@ class TestImperativeQat(unittest.TestCase): # save inference quantized model path = "./mnist_infer_model" - imperative_qat.save_quantized_model( - dirname=path, - model=lenet, - input_shape=[(1, 28, 28)], - input_dtype=['float32'], - feed=[0], - fetch=[0]) + paddle.jit.save( + layer=lenet, + model_path=path, + input_spec=[ + paddle.static.InputSpec( + shape=[None, 1, 28, 28], dtype='float32') + ]) + if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) else: @@ -245,7 +246,10 @@ class TestImperativeQat(unittest.TestCase): exe = fluid.Executor(place) [inference_program, feed_target_names, fetch_targets] = ( fluid.io.load_inference_model( - dirname=path, executor=exe)) + dirname=path, + executor=exe, + model_filename="__model__", + params_filename="__variables__")) after_save, = exe.run(inference_program, feed={feed_target_names[0]: test_data}, fetch_list=fetch_targets) @@ -332,13 +336,13 @@ class TestImperativeQat(unittest.TestCase): if batch_id % 100 == 0: _logger.info('{}: {}'.format('loss', avg_loss.numpy())) - imperative_qat.save_quantized_model( - dirname="./dynamic_mnist", - model=lenet, - input_shape=[(1, 28, 28)], - input_dtype=['float32'], - feed=[0], - fetch=[0]) + paddle.jit.save( + layer=lenet, + model_path="./dynamic_mnist", + input_spec=[ + paddle.static.InputSpec( + shape=[None, 1, 28, 28], 
dtype='float32') + ]) # static graph train _logger.info( diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 5152799ca72f1461d6fbfc3a619a6aa9b9477934..5050067e48a1b147d43abd955c64c7fbb8cf6068 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -60,7 +60,7 @@ class DygraphToStaticAst(gast.NodeTransformer): def transfer_from_node_type(self, node_wrapper): translator_logger = logging_utils.TranslatorLogger() translator_logger.log( - 1, " Source code: \n{}".format(ast_to_source_code(self.root))) + 1, "Source code: \n{}".format(ast_to_source_code(self.root))) # Generic transformation self.visit(node_wrapper.node) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py index 37ce8b0a152ff8e258e8aee2a54ed7215f77c146..3d1ed836ff1aca38d796f3a247e9a5d6f6cf3add 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py @@ -12,17 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import six import inspect import numpy as np import collections + import paddle from paddle.fluid import core from paddle.fluid.dygraph import layers from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import pack_sequence_as from paddle.fluid.dygraph.base import switch_to_static_graph +from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs from paddle.fluid.dygraph.dygraph_to_static.utils import type_name from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code @@ -291,7 +292,7 @@ def convert_to_input_spec(inputs, input_spec): if len(inputs) > len(input_spec): for rest_input in inputs[len(input_spec):]: if isinstance(rest_input, (core.VarBase, np.ndarray)): - logging.warning( + logging_utils.warn( "The inputs constain `{}` without specificing InputSpec, its shape and dtype will be treated immutable. " "Please specific InputSpec information in `@declarative` if you expect them as mutable inputs.". 
format(type_name(rest_input))) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/logging_utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/logging_utils.py index c52872b15016169504359b54ad5a40360e244ce0..4d9ed5916adfd79013be1d8d1bb90f3c44428b49 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/logging_utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/logging_utils.py @@ -26,6 +26,8 @@ CODE_LEVEL_ENV_NAME = 'TRANSLATOR_CODE_LEVEL' DEFAULT_VERBOSITY = -1 DEFAULT_CODE_LEVEL = -1 +LOG_AllTransformer = 100 + def synchronized(func): def wrapper(*args, **kwargs): @@ -53,10 +55,15 @@ class TranslatorLogger(object): return self._initialized = True + self.logger_name = "Dynamic-to-Static" self._logger = log_helper.get_logger( - __name__, 1, fmt='%(asctime)s-%(levelname)s: %(message)s') + self.logger_name, + 1, + fmt='%(asctime)s %(name)s %(levelname)s: %(message)s') self._verbosity_level = None self._transformed_code_level = None + self._need_to_echo_log_to_stdout = None + self._need_to_echo_code_to_stdout = None @property def logger(self): @@ -86,6 +93,28 @@ class TranslatorLogger(object): self.check_level(level) self._transformed_code_level = level + @property + def need_to_echo_log_to_stdout(self): + if self._need_to_echo_log_to_stdout is not None: + return self._need_to_echo_log_to_stdout + return False + + @need_to_echo_log_to_stdout.setter + def need_to_echo_log_to_stdout(self, log_to_stdout): + assert isinstance(log_to_stdout, (bool, type(None))) + self._need_to_echo_log_to_stdout = log_to_stdout + + @property + def need_to_echo_code_to_stdout(self): + if self._need_to_echo_code_to_stdout is not None: + return self._need_to_echo_code_to_stdout + return False + + @need_to_echo_code_to_stdout.setter + def need_to_echo_code_to_stdout(self, code_to_stdout): + assert isinstance(code_to_stdout, (bool, type(None))) + self._need_to_echo_code_to_stdout = code_to_stdout + def check_level(self, level): if isinstance(level, (six.integer_types, type(None))): rv = level @@ -110,34 +139,56 @@ class TranslatorLogger(object): def error(self, msg, *args, **kwargs): self.logger.error(msg, *args, **kwargs) + if self.need_to_echo_log_to_stdout: + self._output_to_stdout('ERROR: ' + msg, *args) def warn(self, msg, *args, **kwargs): - self.logger.warn(msg, *args, **kwargs) + self.logger.warning(msg, *args, **kwargs) + if self.need_to_echo_log_to_stdout: + self._output_to_stdout('WARNING: ' + msg, *args) def log(self, level, msg, *args, **kwargs): if self.has_verbosity(level): - self.logger.log(level, msg, *args, **kwargs) + msg_with_level = '(Level {}) {}'.format(level, msg) + self.logger.info(msg_with_level, *args, **kwargs) + if self.need_to_echo_log_to_stdout: + self._output_to_stdout('INFO: ' + msg_with_level, *args) def log_transformed_code(self, level, ast_node, transformer_name, *args, **kwargs): if self.has_code_level(level): source_code = ast_to_source_code(ast_node) - header_msg = "After the level {} ast transformer: '{}', the transformed code:\n"\ - .format(level, transformer_name) + if level == LOG_AllTransformer: + header_msg = "After the last level ast transformer: '{}', the transformed code:\n" \ + .format(transformer_name) + else: + header_msg = "After the level {} ast transformer: '{}', the transformed code:\n"\ + .format(level, transformer_name) msg = header_msg + source_code self.logger.info(msg, *args, **kwargs) + if self.need_to_echo_code_to_stdout: + self._output_to_stdout('INFO: ' + msg, *args) + + def _output_to_stdout(self, msg, *args): + msg = 
self.logger_name + ' ' + msg + print(msg % args) + _TRANSLATOR_LOGGER = TranslatorLogger() -def set_verbosity(level=0): +def set_verbosity(level=0, also_to_stdout=False): """ - Sets the verbosity level of log for dygraph to static graph. + Sets the verbosity level of log for dygraph to static graph. Logs can be output to stdout by setting `also_to_stdout`. + There are two means to set the logging verbosity: - 1. Call function `set_verbosity` - 2. Set environment variable `TRANSLATOR_VERBOSITY` + + 1. Call function `set_verbosity` + + 2. Set environment variable `TRANSLATOR_VERBOSITY` + **Note**: `set_verbosity` has a higher priority than the environment variable. @@ -145,6 +196,7 @@ def set_verbosity(level=0): Args: level(int): The verbosity level. The larger value idicates more verbosity. The default value is 0, which means no logging. + also_to_stdout(bool): Whether to also output log messages to `sys.stdout`. Examples: .. code-block:: python @@ -159,27 +211,30 @@ def set_verbosity(level=0): # The verbosity level is now 3, but it has no effect because it has a lower priority than `set_verbosity` """ _TRANSLATOR_LOGGER.verbosity_level = level + _TRANSLATOR_LOGGER.need_to_echo_log_to_stdout = also_to_stdout def get_verbosity(): return _TRANSLATOR_LOGGER.verbosity_level -LOG_AllTransformer = 100 - - -def set_code_level(level=LOG_AllTransformer): +def set_code_level(level=LOG_AllTransformer, also_to_stdout=False): """ - Sets the level to print code from specific level of Ast Transformer. + Sets the level to print code from specific level Ast Transformer. Code can be output to stdout by setting `also_to_stdout`. + There are two means to set the code level: - 1. Call function `set_code_level` - 2. Set environment variable `TRANSLATOR_CODE_LEVEL` + + 1. Call function `set_code_level` + + 2. Set environment variable `TRANSLATOR_CODE_LEVEL` + **Note**: `set_code_level` has a higher priority than the environment variable. Args: level(int): The level to print code. Default is 100, which means to print the code after all AST Transformers. + also_to_stdout(bool): Whether to also output code to `sys.stdout`. Examples: .. 
code-block:: python @@ -195,6 +250,7 @@ def set_code_level(level=LOG_AllTransformer): """ _TRANSLATOR_LOGGER.transformed_code_level = level + _TRANSLATOR_LOGGER.need_to_echo_code_to_stdout = also_to_stdout def get_code_level(): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py index 59cb5fb144eb50f4616c94ed78348d56a4029834..1004665ca15fbc2458c1626735f161c7f4904596 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py @@ -14,21 +14,17 @@ from __future__ import print_function import numpy as np -import logging import six -from paddle.fluid import log_helper from paddle.fluid import framework, backward, core from paddle.fluid.dygraph import layers from paddle.fluid.dygraph.base import switch_to_static_graph +from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import pack_sequence_as import paddle.compat as cpt -_logger = log_helper.get_logger( - __name__, logging.WARNING, fmt='%(asctime)s-%(levelname)s: %(message)s') - class NestSequence(object): """ @@ -72,7 +68,7 @@ class NestSequence(object): if not isinstance(var, (framework.Variable, core.VarBase)): warning_types.add(type(var)) if warning_types: - _logger.warning( + logging_utils.warn( "Output of traced function contains non-tensor type values: {}. " "Currently, We don't support to update them while training and will return " "what we first saw. Please try to return them as tensor.". diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py index d555c8ed28f358a43e53966dd30d76d85a03dde5..efde2481721f4f16f0ecadb1200b4abcde73e2b8 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py @@ -15,14 +15,8 @@ from __future__ import print_function import gast -import logging -from paddle.fluid import log_helper -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor -from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code - -_logger = log_helper.get_logger( - __name__, logging.WARNING, fmt='%(asctime)s-%(levelname)s: %(message)s') +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor class PrintTransformer(gast.NodeTransformer): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index dbf030ccda16fb102af4e00e2a2a4d1fd7983a06..5218c0aac957422a665513b5eb2a0391c5c7a01f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -13,17 +13,15 @@ # limitations under the License. 
from __future__ import print_function -import gast + import collections -import logging +import gast import inspect import six import textwrap import threading -import warnings import weakref -import gast from paddle.fluid import framework from paddle.fluid import in_dygraph_mode from paddle.fluid.dygraph import layers @@ -451,7 +449,7 @@ class StaticLayer(object): format(self._function_spec)) # If more than one programs have been cached, return the recent converted program by default. elif cached_program_len > 1: - logging.warning( + logging_utils.warn( "Current {} has more than one cached programs: {}, the last traced progam will be return by default.". format(self._function_spec, cached_program_len)) @@ -632,7 +630,7 @@ class ProgramCache(object): # Note: raise warnings if number of traced program is more than `max_tracing_count` current_tracing_count = len(self._caches) if current_tracing_count > MAX_TRACED_PROGRAM_COUNT: - logging.warning( + logging_utils.warn( "Current traced program number: {} > `max_tracing_count`:{}. Too much cached programs will bring expensive overhead. " "The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors.". format(current_tracing_count, MAX_TRACED_PROGRAM_COUNT)) @@ -804,8 +802,9 @@ class ProgramTranslator(object): assert callable( dygraph_func ), "Input dygraph_func is not a callable in ProgramTranslator.get_output" + if not self.enable_to_static: - warnings.warn( + logging_utils.warn( "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. " "We will just return dygraph output. " "Please call ProgramTranslator.enable(True) if you would like to get static output." @@ -879,8 +878,9 @@ class ProgramTranslator(object): assert callable( dygraph_func ), "Input dygraph_func is not a callable in ProgramTranslator.get_func" + if not self.enable_to_static: - warnings.warn( + logging_utils.warn( "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will " "just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output." ) @@ -933,8 +933,9 @@ class ProgramTranslator(object): assert callable( dygraph_func ), "Input dygraph_func is not a callable in ProgramTranslator.get_program" + if not self.enable_to_static: - warnings.warn( + logging_utils.warn( "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False." "We will just return dygraph output. " "Please call ProgramTranslator.enable(True) if you would like to get static output." 
diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 834c1a737d73bdb8ec3dd89eb5ccd6c0780a211d..10819e4b320dd0630c7ac43fdf89b84252823a94 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -26,6 +26,7 @@ from paddle.fluid import core from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy from paddle.fluid.data_feeder import check_type from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph +from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticLayer, unwrap_decorators from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer @@ -120,7 +121,7 @@ def _dygraph_to_static_func_(dygraph_func): def __impl__(*args, **kwargs): program_translator = ProgramTranslator() if in_dygraph_mode() or not program_translator.enable_to_static: - warnings.warn( + logging_utils.warn( "The decorator 'dygraph_to_static_func' doesn't work in " "dygraph mode or set ProgramTranslator.enable to False. " "We will just return dygraph output.") @@ -215,7 +216,7 @@ def declarative(function=None, input_spec=None): if isinstance(function, Layer): if isinstance(function.forward, StaticLayer): class_name = function.__class__.__name__ - warnings.warn( + logging_utils.warn( "`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.". format(class_name)) function.forward = decorated(function.forward) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index bc9f182d95e3b728fbc0866e1c79f5508d3a04aa..4a750f301a02c1a8f90e4103103c174baf32ead9 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -11229,7 +11229,7 @@ def shape(input): input.shape = [3, 2] Args: - input (Variable): The input can be N-D Tensor or SelectedRows with data type float16, float32, float64, int32, int64. + input (Variable): The input can be N-D Tensor or SelectedRows with data type bool, float16, float32, float64, int32, int64. If input variable is type of SelectedRows, returns the shape of it's inner tensor. 
Returns: @@ -11253,8 +11253,8 @@ def shape(input): print(res) # [array([ 3, 100, 100], dtype=int32)] """ check_variable_and_dtype( - input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'], - 'shape') + input, 'input', + ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'shape') helper = LayerHelper('shape', **locals()) out = helper.create_variable_for_type_inference(dtype='int32') helper.append_op( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py index 510b615654751500c33dc3311353ba7e2f8baf40..b8a18179742df108d44dbf527adba819eefe91cd 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py @@ -56,8 +56,30 @@ class TestLoggingUtils(unittest.TestCase): with self.assertRaises(TypeError): paddle.jit.set_verbosity(3.3) - def test_code_level(self): + def test_also_to_stdout(self): + logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout = None + self.assertEqual( + logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout, False) + paddle.jit.set_verbosity(also_to_stdout=False) + self.assertEqual( + logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout, False) + + logging_utils._TRANSLATOR_LOGGER.need_to_echo_node_to_stdout = None + self.assertEqual( + logging_utils._TRANSLATOR_LOGGER.need_to_echo_code_to_stdout, False) + + paddle.jit.set_code_level(also_to_stdout=True) + self.assertEqual( + logging_utils._TRANSLATOR_LOGGER.need_to_echo_code_to_stdout, True) + + with self.assertRaises(AssertionError): + paddle.jit.set_verbosity(also_to_stdout=1) + + with self.assertRaises(AssertionError): + paddle.jit.set_code_level(also_to_stdout=1) + + def test_set_code_level(self): paddle.jit.set_code_level(None) os.environ[logging_utils.CODE_LEVEL_ENV_NAME] = '2' self.assertEqual(logging_utils.get_code_level(), 2) @@ -71,7 +93,25 @@ class TestLoggingUtils(unittest.TestCase): with self.assertRaises(TypeError): paddle.jit.set_code_level(3.3) - def test_log(self): + def test_log_api(self): + # test api for CI Converage + logging_utils.set_verbosity(1, True) + + logging_utils.warn("warn") + logging_utils.error("error") + + logging_utils.log(1, "log level 1") + logging_utils.log(2, "log level 2") + + source_code = "x = 3" + ast_code = gast.parse(source_code) + logging_utils.set_code_level(1, True) + logging_utils.log_transformed_code(1, ast_code, "TestTransformer") + logging_utils.set_code_level(logging_utils.LOG_AllTransformer, True) + logging_utils.log_transformed_code(logging_utils.LOG_AllTransformer, + ast_code, "TestTransformer") + + def test_log_message(self): stream = io.BytesIO() if six.PY2 else io.StringIO() log = self.translator_logger.logger stdout_handler = logging.StreamHandler(stream) @@ -84,13 +124,14 @@ class TestLoggingUtils(unittest.TestCase): if six.PY3: with mock.patch.object(sys, 'stdout', stream): + logging_utils.set_verbosity(1, False) logging_utils.warn(warn_msg) logging_utils.error(error_msg) - self.translator_logger.verbosity_level = 1 logging_utils.log(1, log_msg_1) logging_utils.log(2, log_msg_2) - result_msg = '\n'.join([warn_msg, error_msg, log_msg_1, ""]) + result_msg = '\n'.join( + [warn_msg, error_msg, "(Level 1) " + log_msg_1, ""]) self.assertEqual(result_msg, stream.getvalue()) def test_log_transformed_code(self): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py 
b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb397b5a95b240dcaff9dee3758646b35ab5022 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py @@ -0,0 +1,171 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig +"""Test for fusion of conv and bias.""" + + +#padding SAME +class ConvBiasMkldnnFusePassTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 100, 100], dtype="float32") + param_attr = fluid.ParamAttr( + initializer=fluid.initializer.Xavier(uniform=False), + learning_rate=0.001) + conv_out = fluid.layers.conv2d( + input=data, + num_filters=3, + filter_size=3, + padding="SAME", + bias_attr=param_attr) + + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.fetch_list = [conv_out] + self.enable_mkldnn = True + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + +#padding VALID +class ConvBiasMkldnnFusePassTest1(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 100, 100], dtype="float32") + param_attr = fluid.ParamAttr( + initializer=fluid.initializer.Xavier(uniform=False), + learning_rate=0.001) + conv_out = fluid.layers.conv2d( + input=data, + num_filters=3, + filter_size=3, + padding="VALID", + bias_attr=param_attr) + + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.fetch_list = [conv_out] + self.enable_mkldnn = True + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + +#padding number +class ConvBiasMkldnnFusePassTest2(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 100, 100], dtype="float32") + param_attr = fluid.ParamAttr( + initializer=fluid.initializer.Xavier(uniform=False), + learning_rate=0.001) + conv_out = fluid.layers.conv2d( + input=data, + num_filters=3, + filter_size=3, + padding=[2, 4, 6, 8], + bias_attr=param_attr) + + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.fetch_list = [conv_out] + self.enable_mkldnn = True + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + +#dilation not supported yet, just print warning log and does not fuse +class ConvBiasMkldnnFusePassTest3(InferencePassTest): + def setUp(self): + with 
fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 100, 100], dtype="float32") + param_attr = fluid.ParamAttr( + initializer=fluid.initializer.Xavier(uniform=False), + learning_rate=0.001) + conv_out = fluid.layers.conv2d( + input=data, + num_filters=3, + filter_size=3, + padding="VALID", + dilation=2, + groups=3, + bias_attr=param_attr, + use_cudnn=False, + act="softmax", + data_format="NCHW") + + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.fetch_list = [conv_out] + self.enable_mkldnn = True + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + +#all conv params except for dilation +class ConvBiasMkldnnFusePassTest4(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 100, 100], dtype="float32") + param_attr = fluid.ParamAttr( + initializer=fluid.initializer.Xavier(uniform=False), + learning_rate=0.001) + conv_out = fluid.layers.conv2d( + input=data, + num_filters=3, + filter_size=3, + padding="VALID", + groups=3, + bias_attr=param_attr, + use_cudnn=False, + act="softmax", + data_format="NCHW") + + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.fetch_list = [conv_out] + self.enable_mkldnn = True + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index 4982cd195820811b9a8ec3fe6d01955234032120..c6190590108876ba97feb4dad0c31884727ec978 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -26,7 +26,7 @@ def stable_softmax(x): return exps / np.sum(exps) -def log_softmax(x, axis=-1): +def log_softmax(x, axis=1): softmax_out = np.apply_along_axis(stable_softmax, axis, x) return np.log(softmax_out) diff --git a/python/paddle/fluid/tests/unittests/test_empty_like_op.py b/python/paddle/fluid/tests/unittests/test_empty_like_op.py new file mode 100644 index 0000000000000000000000000000000000000000..32d732d9a809950ade5484431b833056336acd54 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_empty_like_op.py @@ -0,0 +1,192 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.fluid.data_feeder import convert_dtype +import paddle.fluid.core as core +from paddle.static import program_guard, Program + + +class TestEmptyLikeAPICommon(unittest.TestCase): + def __check_out__(self, out): + data_type = convert_dtype(out.dtype) + self.assertEqual(data_type, self.dst_dtype, + 'dtype should be %s, but get %s' % + (self.dst_dtype, data_type)) + + shape = out.shape + self.assertTupleEqual(shape, self.dst_shape, + 'shape should be %s, but get %s' % + (self.dst_shape, shape)) + + if data_type in ['float32', 'float64', 'int32', 'int64']: + max_value = np.nanmax(out) + min_value = np.nanmin(out) + always_non_full_zero = max_value > min_value + always_full_zero = max_value == 0.0 and min_value == 0.0 + self.assertTrue(always_full_zero or always_non_full_zero, + 'always_full_zero or always_non_full_zero.') + elif data_type in ['bool']: + total_num = out.size + true_num = np.sum(out == True) + false_num = np.sum(out == False) + self.assertTrue(total_num == true_num + false_num, + 'The value should always be True or False.') + else: + self.assertTrue(False, 'invalid data type') + + +class TestEmptyLikeAPI(TestEmptyLikeAPICommon): + def setUp(self): + self.init_config() + + def test_dygraph_api_out(self): + paddle.disable_static() + out = paddle.empty_like(self.x, self.dtype) + self.__check_out__(out.numpy()) + paddle.enable_static() + + def init_config(self): + self.x = np.random.random((200, 3)).astype("float32") + self.dtype = self.x.dtype + self.dst_shape = self.x.shape + self.dst_dtype = self.dtype + + +class TestEmptyLikeAPI2(TestEmptyLikeAPI): + def init_config(self): + self.x = np.random.random((200, 3)).astype("float64") + self.dtype = self.x.dtype + self.dst_shape = self.x.shape + self.dst_dtype = self.dtype + + +class TestEmptyLikeAPI3(TestEmptyLikeAPI): + def init_config(self): + self.x = np.random.random((200, 3)).astype("int") + self.dtype = self.x.dtype + self.dst_shape = self.x.shape + self.dst_dtype = self.dtype + + +class TestEmptyLikeAPI4(TestEmptyLikeAPI): + def init_config(self): + self.x = np.random.random((200, 3)).astype("int64") + self.dtype = self.x.dtype + self.dst_shape = self.x.shape + self.dst_dtype = self.dtype + + +class TestEmptyLikeAPI5(TestEmptyLikeAPI): + def init_config(self): + self.x = np.random.random((200, 3)).astype("bool") + self.dtype = self.x.dtype + self.dst_shape = self.x.shape + self.dst_dtype = self.dtype + + +class TestEmptyLikeAPI6(TestEmptyLikeAPI): + def init_config(self): + self.x = np.random.random((200, 3)).astype("float64") + self.dtype = "float32" + self.dst_shape = self.x.shape + self.dst_dtype = self.dtype + + +class TestEmptyLikeAPI7(TestEmptyLikeAPI): + def init_config(self): + self.x = np.random.random((200, 3)).astype("int") + self.dtype = "float32" + self.dst_shape = self.x.shape + self.dst_dtype = self.dtype + + +class TestEmptyLikeAPI8(TestEmptyLikeAPI): + def init_config(self): + self.x = np.random.random((200, 3)).astype("int64") + self.dtype = "float32" + self.dst_shape = self.x.shape + self.dst_dtype = self.dtype + + +class TestEmptyLikeAPI9(TestEmptyLikeAPI): + def init_config(self): + self.x = np.random.random((200, 3)).astype("bool") + self.dtype = "float32" + self.dst_shape = self.x.shape + self.dst_dtype = self.dtype + + +class TestEmptyLikeAPI10(TestEmptyLikeAPI): + def init_config(self): + self.x = np.random.random((200, 3)).astype("float32") + self.dtype = "bool" + 
self.dst_shape = self.x.shape + self.dst_dtype = self.dtype + + +class TestEmptyLikeAPI_Static(TestEmptyLikeAPICommon): + def setUp(self): + self.init_config() + + def test_static_graph(self): + dtype = 'float32' + + train_program = Program() + startup_program = Program() + + with program_guard(train_program, startup_program): + x = np.random.random(self.x_shape).astype(dtype) + data_x = paddle.static.data( + 'x', shape=self.data_x_shape, dtype=dtype) + + out = paddle.empty_like(data_x) + + place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else paddle.CPUPlace() + exe = paddle.static.Executor(place) + res = exe.run(train_program, feed={'x': x}, fetch_list=[out]) + + self.dst_dtype = dtype + self.dst_shape = x.shape + self.__check_out__(res[0]) + + def init_config(self): + self.x_shape = (200, 3) + self.data_x_shape = [200, 3] + + +class TestEmptyLikeAPI_Static2(TestEmptyLikeAPI_Static): + def init_config(self): + self.x_shape = (3, 200, 3) + self.data_x_shape = [-1, 200, 3] + + +class TestEmptyError(unittest.TestCase): + def test_attr(self): + def test_dtype(): + x = np.random.random((200, 3)).astype("float64") + dtype = 'uint8' + result = paddle.empty_like(x, dtype=dtype) + + self.assertRaises(TypeError, test_dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index da086c0955e849619ccbce17a297ca4615a3f3d0..4395520eec70e8483cb61097a166576f4040cb4d 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1093,7 +1093,7 @@ def cross_entropy(input, " 'none', but received %s, which is not allowed." % reduction) #step 1. log_softmax - log_softmax_out = paddle.nn.functional.log_softmax(input) + log_softmax_out = paddle.nn.functional.log_softmax(input, axis=1) if weight is not None and not isinstance(weight, Variable): raise ValueError( "The weight' is not a Variable, please convert to Variable.") diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 8bb584be2362e7b02bc5b7c5603b148d37499c2d..a713663e1822d4af2d09efb2986aeb513930bbc0 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -41,6 +41,7 @@ from .creation import triu #DEFINE_ALIAS from .creation import tril #DEFINE_ALIAS from .creation import meshgrid #DEFINE_ALIAS from .creation import empty #DEFINE_ALIAS +from .creation import empty_like #DEFINE_ALIAS from .io import save #DEFINE_ALIAS from .io import load #DEFINE_ALIAS from .linalg import matmul #DEFINE_ALIAS diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 8011b92964b7e21fd930f19cec954b27f470e0c6..9aee911e568d1b2cd7aac0cf45e44f2886612a5a 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -49,6 +49,7 @@ __all__ = [ 'full', 'full_like', 'empty', + 'empty_like', 'triu', 'tril', 'meshgrid' @@ -1068,3 +1069,70 @@ def empty(shape, dtype=None, name=None): stop_gradient=True) out.stop_gradient = True return out + + +def empty_like(x, dtype=None, name=None): + """ + This Op returns a Tensor with uninitialized data which has identical shape of ``x`` and ``dtype``. + If the ``dtype`` is None, the data type of Tensor is same with ``x``. + + Args: + x(Tensor): The input tensor which specifies shape and data type. The data type can be bool, float16, float32, float64, int32, int64. + dtype(np.dtype|str, optional): The data type of output. The data type can be one + of bool, float16, float32, float64, int32, int64. 
The default value is None, which means the output + data type is the same as input. + name(str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: Tensor which is created according to ``x`` and ``dtype``, and is uninitialized. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() # Now we are in imperative mode + paddle.set_device("cpu") # and use cpu device + + x = paddle.randn([2, 3], 'float32') + output = paddle.empty_like(x) + #[[1.8491974e+20 1.8037303e+28 1.7443726e+28] # uninitialized + # [4.9640171e+28 3.0186127e+32 5.6715899e-11]] # uninitialized + """ + + if dtype is None: + dtype = x.dtype + dtype = convert_dtype(dtype) + + if in_dygraph_mode(): + out = core.ops.empty('shape', x.shape, 'dtype', + convert_np_dtype_to_dtype_(dtype)) + out.stop_gradient = True + return out + + helper = LayerHelper("empty_like", **locals()) + check_variable_and_dtype( + x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'empty_like') + check_dtype(dtype, 'dtype', + ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'empty_like') + out = helper.create_variable_for_type_inference(dtype=dtype) + + inputs = {} + attrs = {} + attrs['dtype'] = convert_np_dtype_to_dtype_(dtype) + shape = paddle.shape(x) + utils.get_shape_tensor_inputs( + inputs=inputs, attrs=attrs, shape=shape, op_type='empty_like') + + helper.append_op( + type='empty', + inputs=inputs, + outputs={'Out': [out]}, + attrs=attrs, + stop_gradient=True) + out.stop_gradient = True + return out
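A minimal usage sketch of the new paddle.empty_like API introduced in creation.py above, in both dygraph and static modes. It follows the docstring example and the TestEmptyLikeAPI_Static pattern; the concrete shapes, feed values, variable names, and the dtype override are illustrative only, not taken from the patch. The static-graph path builds the output shape from paddle.shape(x), which is presumably why this patch also adds 'bool' to the shape op's dtype whitelist.

import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.static import Program, program_guard

# Dygraph: dtype defaults to x.dtype; the returned values are uninitialized.
paddle.disable_static()
x = paddle.randn([2, 3], 'float32')
out = paddle.empty_like(x)                      # shape (2, 3), float32
out_f64 = paddle.empty_like(x, dtype='float64')  # same shape, overridden dtype
paddle.enable_static()

# Static graph: the shape comes from paddle.shape(x) at run time,
# so the batch dimension can stay -1 in the data declaration.
main_prog, startup_prog = Program(), Program()
with program_guard(main_prog, startup_prog):
    data_x = paddle.static.data('x', shape=[-1, 200, 3], dtype='float32')
    out_static = paddle.empty_like(data_x)

place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)
res = exe.run(main_prog,
              feed={'x': np.random.random((4, 200, 3)).astype('float32')},
              fetch_list=[out_static])
print(res[0].shape)  # (4, 200, 3); the contents are uninitialized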
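A small NumPy check, not part of the patch, of why cross_entropy now calls log_softmax(input, axis=1): the class axis of an (N, C, d1, ...) input is axis 1, and the previous default (the last axis) only coincides with it for plain (N, C) inputs. The helpers mirror the stable_softmax/log_softmax reference functions in test_cross_entropy_loss.py; the shapes are arbitrary.

import numpy as np

def stable_softmax(x):
    # subtract the max for numerical stability, as in the unit-test helper
    shiftx = x - np.max(x)
    exps = np.exp(shiftx)
    return exps / np.sum(exps)

def log_softmax(x, axis=1):
    return np.log(np.apply_along_axis(stable_softmax, axis, x))

logits_2d = np.random.rand(4, 5)      # (N, C): axis 1 and axis -1 agree
logits_3d = np.random.rand(4, 5, 7)   # (N, C, d1): they generally do not

print(np.allclose(log_softmax(logits_2d, axis=1), log_softmax(logits_2d, axis=-1)))  # True
print(np.allclose(log_softmax(logits_3d, axis=1), log_softmax(logits_3d, axis=-1)))  # False in general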
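Finally, a hedged sketch of the also_to_stdout switch exercised by test_also_to_stdout and test_log_api in the dygraph-to-static logging tests above. It assumes the verbosity/code level can be passed as the first positional argument of the paddle.jit wrappers, as the tests suggest; the printed message is illustrative.

import paddle

# Echo translator logs to stdout in addition to the translator logger.
paddle.jit.set_verbosity(1, also_to_stdout=True)

# Echo the transformed code of the selected transformers as well.
paddle.jit.set_code_level(2, also_to_stdout=True)

# also_to_stdout must be a bool; a non-bool value raises AssertionError.
try:
    paddle.jit.set_verbosity(also_to_stdout=1)
except AssertionError:
    print("also_to_stdout must be True or False")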