提交 6ef1fbb6 编写于 作者: Y yaoxuefeng6

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into mod_dataset_v2

......@@ -19,6 +19,7 @@
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
......@@ -334,3 +335,8 @@ void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
REGISTER_PASS(embedding_eltwise_layernorm_fuse_pass,
paddle::framework::ir::EmbeddingEltwiseLayerNormFusePass);
REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("lookup_table", 0)
.EQ("elementweise_add", 0));
......@@ -16,12 +16,13 @@ limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(SkipLayerNormFusePass, basic) {
TEST(EmbeddingElewiseLayernormFusePass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (x, y) elementwise_add -> elementwise_out
......@@ -91,6 +92,12 @@ TEST(SkipLayerNormFusePass, basic) {
"The number of fusion nodes does not meet expectations after fuse"));
}
TEST(EmbeddingElewiseLayernormFusePass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("embedding_eltwise_layernorm_fuse_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -17,6 +17,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
......@@ -84,6 +85,19 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "do not perform " + type() + "+bias fuse";
return;
}
if (conv->Op()->HasAttr("dilations")) {
auto dilations =
BOOST_GET_CONST(std::vector<int>, conv->Op()->GetAttr("dilations"));
for (const auto& d : dilations) {
if (d != 1) {
LOG(WARNING)
<< "dilation conv not supported in MKLDNN, fuse not apply "
<< "and set conv attribute use_mkldnn = false";
conv->Op()->SetAttr("use_mkldnn", false);
return;
}
}
}
auto* eltwise_bias_tensor =
scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();
......@@ -151,3 +165,8 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DTransposeBiasFusePass);
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
paddle::framework::ir::Conv3DBiasFusePass);
REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0));
......@@ -18,6 +18,7 @@
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/imperative/type_defs.h"
namespace paddle {
......@@ -149,6 +150,12 @@ TEST(ConvBiasFusePass, conv2d_transpose) {
ASSERT_EQ(pass.type(), std::string("conv2d_transpose"));
}
TEST(ConvBiasFusePass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("conv_bias_mkldnn_fuse_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -19,6 +19,7 @@
#include <memory>
#include <tuple>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
......@@ -341,3 +342,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
paddle::framework::ir::ResidualConnectionMKLDNNFusePass);
REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("elementwise_add", 0));
......@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
......@@ -267,6 +268,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
AssertOpsCount(graph, 2, 1);
}
TEST(ConvElementwiseAddMKLDNNFusePass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("conv_elementwise_add_mkldnn_fuse_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
......@@ -57,3 +58,7 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(depthwise_conv_mkldnn_pass,
paddle::framework::ir::DepthwiseConvMKLDNNPass);
REGISTER_PASS_CAPABILITY(depthwise_conv_mkldnn_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"depthwise_conv2d", 0));
......@@ -16,6 +16,8 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
......@@ -70,6 +72,12 @@ ProgramDesc BuildProgramDesc() {
return prog;
}
TEST(DepthwiseConvMKLDNNPass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("depthwise_conv_mkldnn_pass"));
}
TEST(DepthwiseConvMKLDNNPass, basic) {
auto prog = BuildProgramDesc();
......
......@@ -19,6 +19,7 @@
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
......@@ -707,3 +708,13 @@ REGISTER_PASS(multihead_matmul_fuse_pass,
REGISTER_PASS(multihead_matmul_fuse_pass_v2,
paddle::framework::ir::MultiHeadMatmulV2FusePass);
REGISTER_PASS_CAPABILITY(multihead_matmul_fuse_pass_v2)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("mul", 0)
.EQ("elementwise_add", 0)
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("scale", 0)
.EQ("matmul", 0)
.EQ("softmax", 0));
......@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
......@@ -133,6 +134,12 @@ TEST(MultiHeadMatmulFusePass, basic) {
num_fused_nodes_after));
}
TEST(MultiHeadMatmulFusePass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("multihead_matmul_fuse_pass_v2"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
......@@ -180,3 +181,8 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
REGISTER_PASS(skip_layernorm_fuse_pass,
paddle::framework::ir::SkipLayerNormFusePass);
REGISTER_PASS_CAPABILITY(skip_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("elementwise_add", 0)
.EQ("layer_norm", 0));
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
......@@ -54,6 +55,12 @@ TEST(SkipLayerNormFusePass, basic) {
"The number of fusion nodes does not meet expectations after fuse"));
}
TEST(SkipLayerNormFusePass, pass_op_version_check) {
ASSERT_TRUE(
paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
.IsPassCompatible("skip_layernorm_fuse_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -54,9 +54,13 @@ class AverageAccumulatesKernel : public framework::OpKernel<T> {
float average_window = ctx.Attr<float>("average_window");
int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
int64_t min_average_window = ctx.Attr<int64_t>("min_average_window");
PADDLE_ENFORCE_LE(min_average_window, max_average_window,
"min_average_window shouldn't be larger than "
"max_average_window");
PADDLE_ENFORCE_LE(
min_average_window, max_average_window,
platform::errors::InvalidArgument(
"The min_average_window > "
"max_average_window is not right, min_average_window is %ld, "
"max_average_window is %ld.",
min_average_window, max_average_window));
// Get inputs
auto* param = ctx.Input<Tensor>("param");
......
......@@ -55,31 +55,38 @@ class EmptyOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "empty");
if (context->HasInput("ShapeTensor")) {
auto dims = context->GetInputDim("ShapeTensor");
auto shape_dims = context->GetInputDim("ShapeTensor");
int num_ele = 1;
for (int i = 0; i < dims.size(); ++i) {
num_ele *= dims[i];
for (int i = 0; i < shape_dims.size(); ++i) {
num_ele *= shape_dims[i];
}
context->SetOutputDim("Out", framework::make_ddim({num_ele}));
auto vec_dims = std::vector<int>(num_ele, -1);
context->SetOutputDim("Out", framework::make_ddim(vec_dims));
} else if (context->HasInputs("ShapeTensorList")) {
std::vector<int> out_dims;
auto dims_list = context->GetInputsDim("ShapeTensorList");
for (size_t i = 0; i < dims_list.size(); ++i) {
auto& dims = dims_list[i];
PADDLE_ENFORCE_EQ(
dims, framework::make_ddim({1}),
"ShapeError: The shape of Tensor in list must be [1]. "
"But received the shape "
"is [%s]",
dims);
out_dims.push_back(dims[0]);
PADDLE_ENFORCE_EQ(dims, framework::make_ddim({1}),
platform::errors::InvalidArgument(
"The shape of Tensor in list must be [1]. "
"But received the shape is [%s]",
dims));
out_dims.push_back(-1);
}
context->SetOutputDim("Out", framework::make_ddim(out_dims));
} else {
auto& shape = context->Attrs().Get<std::vector<int64_t>>("shape");
for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE_GE(
shape[i], 0,
platform::errors::InvalidArgument(
"Each value of attribute 'shape' is expected to be no less "
"than 0. But recieved: shape[%u] = %d; shape = [%s].",
i, shape[i], framework::make_ddim(shape)));
}
context->SetOutputDim("Out", framework::make_ddim(shape));
}
}
......
......@@ -87,7 +87,10 @@ class BeamSearchFunctor<platform::CPUDeviceContext, T> {
lod[0].assign(high_level.begin(), high_level.end());
lod[1].assign(low_level.begin(), low_level.end());
if (!framework::CheckLoD(lod)) {
PADDLE_THROW("lod %s is not right", framework::LoDToString(lod));
PADDLE_THROW(platform::errors::InvalidArgument(
"lod %s is not right in"
" beam_search, please check your code.",
framework::LoDToString(lod)));
}
selected_ids->set_lod(lod);
selected_scores->set_lod(lod);
......
......@@ -400,7 +400,10 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
context.Wait();
if (!framework::CheckLoD(selected_lod)) {
PADDLE_THROW("lod %s is not right", framework::LoDToString(selected_lod));
PADDLE_THROW(platform::errors::InvalidArgument(
"lod %s is not right in"
" beam_search, please check your code.",
framework::LoDToString(selected_lod)));
}
selected_ids->set_lod(selected_lod);
......
......@@ -20,7 +20,11 @@ namespace operators {
namespace math {
MatDescriptor CreateMatrixDescriptor(const framework::DDim &tensor_dim,
int num_flatten_cols, bool trans) {
PADDLE_ENFORCE_GT(tensor_dim.size(), 1);
PADDLE_ENFORCE_GT(
tensor_dim.size(), 1,
platform::errors::InvalidArgument("The tensor dim size should be greater "
"than 1, but reveived dim size is %d",
tensor_dim.size()));
MatDescriptor retv;
if (num_flatten_cols > 1) {
auto flatten_dim = framework::flatten_to_2d(tensor_dim, num_flatten_cols);
......
......@@ -60,7 +60,8 @@ struct CUBlas<float> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasSgemmStridedBatched(args...));
#else
PADDLE_THROW("SgemmStridedBatched is not supported on cuda <= 7.5");
PADDLE_THROW(platform::errors::Unimplemented(
"SgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
}
......@@ -85,7 +86,8 @@ struct CUBlas<float> {
beta, C, Ctype, ldc));
});
#else
PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
PADDLE_THROW(platform::errors::Unimplemented(
"cublasSgemmEx is not supported on cuda <= 7.5"));
#endif
}
......@@ -146,13 +148,15 @@ struct CUBlas<double> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasDgemmStridedBatched(args...));
#else
PADDLE_THROW("DgemmStridedBatched is not supported on cuda <= 7.5");
PADDLE_THROW(platform::errors::Unimplemented(
"DgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
}
template <typename... ARGS>
static void GEMM_EX(ARGS... args) {
PADDLE_THROW("Currently there are not cublasDgemmEx.");
PADDLE_THROW(platform::errors::Unimplemented(
"Currently there are not cublasDgemmEx."));
}
template <typename... ARGS>
......@@ -216,7 +220,8 @@ struct CUBlas<platform::float16> {
reinterpret_cast<const __half *>(beta), reinterpret_cast<__half *>(C),
ldc, strideC, batchCount));
#else
PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
PADDLE_THROW(platform::errors::Unimplemented(
"HgemmStridedBatched is not supported on cuda <= 7.5"));
#endif
}
......@@ -247,7 +252,8 @@ struct CUBlas<platform::float16> {
beta, C, Ctype, ldc, computeType, algo));
});
#else
PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx is not supported on cuda <= 7.5"));
#endif
}
};
......@@ -302,8 +308,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53,
"cublas fp16 gemm requires GPU compute capability >= 53");
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 53,
platform::errors::InvalidArgument(
"cublas fp16 gemm requires GPU compute capability >= 53,"
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
......
......@@ -29,7 +29,8 @@ template <>
struct CBlas<int8_t> {
template <typename... ARGS>
static void VCOPY(ARGS... args) {
PADDLE_THROW("Blas VCOPY don't support int8_t");
PADDLE_THROW(platform::errors::Unimplemented(
"Blas VCOPY do not supported on CPU, please check your code"));
}
};
......@@ -347,22 +348,47 @@ struct CBlas<double> {
template <>
struct CBlas<platform::float16> {
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
static void GEMM(...) {
PADDLE_THROW(platform::errors::Unimplemented(
"float16 GEMM not supported on CPU, please check your code"));
}
static void SMM_GEMM(...) {
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
PADDLE_THROW(platform::errors::Unimplemented(
"float16 SMM_GEMM not supported on CPU, please check your code"));
}
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
static void VSQUARE(...) {
PADDLE_THROW("float16 VSQUARE not supported on CPU");
static void VMUL(...) {
PADDLE_THROW(platform::errors::Unimplemented(
"float16 VMUL not supported on CPU, please check your code"));
}
static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
static void ASUM(...) { PADDLE_THROW("float16 ASUM not supported on CPU"); };
static void VEXP(...) {
PADDLE_THROW(platform::errors::Unimplemented(
"float16 VEXP not supported on CPU, please check your code"));
}
static void VSQUARE(...) {
PADDLE_THROW(platform::errors::Unimplemented(
"float16 VSQUARE not supported on CPU, please check your code"));
}
static void VPOW(...) {
PADDLE_THROW(platform::errors::Unimplemented(
"float16 VPOW not supported on CPU, please check your code"));
}
static void DOT(...) {
PADDLE_THROW(platform::errors::Unimplemented(
"float16 DOT not supported on CPU, please check your code"));
};
static void SCAL(...) {
PADDLE_THROW(platform::errors::Unimplemented(
"float16 SCAL not supported on CPU, please check your code"));
};
static void ASUM(...) {
PADDLE_THROW(platform::errors::Unimplemented(
"float16 ASUM not supported on CPU, please check your code"));
};
#ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
PADDLE_THROW(platform::errors::Unimplemented(
"float16 GEMM_BATCH not supported on CPU, please check your code"));
}
#endif
};
......@@ -446,11 +472,18 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a, bool trans_a,
auto dim_a = mat_a.dims();
auto dim_b = mat_b.dims();
auto dim_out = mat_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(
mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(),
"The places of matrices must be same");
PADDLE_ENFORCE_EQ(
dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, true,
platform::errors::InvalidArgument(
"The input and output of matmul should be matrix, the dim size must "
"be 2,"
"but received dim size input_a:%d, input_b:%d, output:%d",
dim_a.size(), dim_b.size(), dim_out.size()));
PADDLE_ENFORCE_EQ(
mat_a.place() == mat_b.place() && mat_a.place() == mat_out->place(), true,
platform::errors::InvalidArgument("The places of matrices in the matmul "
"should be same, please check your "
"code."));
int M = dim_out[0];
int N = dim_out[1];
......@@ -715,7 +748,13 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
}
} else {
PADDLE_ENFORCE_EQ(W1, H2);
PADDLE_ENFORCE_EQ(
W1, H2,
platform::errors::InvalidArgument(
"The fisrt matrix width should be same as second matrix height,"
"but received fisrt matrix width %d"
", second matrix height %d",
W1, H2));
int ldc = W2 * head_number;
int sub_width = W1 / head_number;
......@@ -785,7 +824,14 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
const framework::Tensor &mat_b,
const MatDescriptor &dim_b, T alpha,
framework::Tensor *mat_out, T beta) const {
PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
PADDLE_ENFORCE_EQ(
dim_a.width_, dim_b.height_,
platform::errors::InvalidArgument(
"The fisrt matrix width should be same as second matrix height,"
"but received fisrt matrix width %d"
", second matrix height %d",
dim_a.width_, dim_b.height_));
CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
......@@ -793,12 +839,14 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
dim_a.width_, alpha, mat_a.data<T>(),
mat_b.data<T>(), beta, mat_out->data<T>());
} else {
PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0,
"dim_a.batch_size should be equal to dim_b.batch_size, or "
"one of dim_a.batch_size and dim_b.batch_size should be 0. "
"But got dim_a.batch_size = %d, dim_b.batch_size = %d.",
dim_a.batch_size_, dim_b.batch_size_);
PADDLE_ENFORCE_EQ(
dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 ||
dim_b.batch_size_ == 0,
true, platform::errors::InvalidArgument(
"dim_a.batch_size should be equal to dim_b.batch_size, or "
"one of dim_a.batch_size and dim_b.batch_size should be 0. "
"But got dim_a.batch_size = %d, dim_b.batch_size = %d.",
dim_a.batch_size_, dim_b.batch_size_));
this->template BatchedGEMM<T>(
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
......@@ -834,15 +882,42 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
int head_number,
framework::Tensor *mat_out, T beta,
bool mat_b_split_vertical) const {
PADDLE_ENFORCE_EQ(dim_a.width_ % head_number, 0);
PADDLE_ENFORCE_GE(head_number, 1);
PADDLE_ENFORCE_LE(head_number, dim_a.width_);
PADDLE_ENFORCE_EQ(
dim_a.width_ % head_number, 0,
platform::errors::InvalidArgument(
"The first input width must be some times the head number"
"but received first input width %d"
", head_number %d",
dim_a.width_, head_number));
PADDLE_ENFORCE_GE(head_number, 1,
platform::errors::InvalidArgument(
"The head number should be greater equal 1,"
"but received head number %d",
head_number));
PADDLE_ENFORCE_LE(
head_number, dim_a.width_,
platform::errors::InvalidArgument(
"The head number should be less equal first input width,"
"but received first input width %d"
", head_number %d",
dim_a.width_, head_number));
CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
if (mat_b_split_vertical) {
PADDLE_ENFORCE_EQ(dim_b.height_, dim_a.width_ / head_number);
PADDLE_ENFORCE_EQ(dim_b.width_ % head_number, 0);
PADDLE_ENFORCE_EQ(
dim_b.height_, dim_a.width_ / head_number,
platform::errors::InvalidArgument(
"The second input height should be equal than first input width,"
"but received second input height %d, first input width %d",
dim_b.height_, dim_a.width_ / head_number));
PADDLE_ENFORCE_EQ(
dim_a.width_ % head_number, 0,
platform::errors::InvalidArgument(
"The second input width should be some times the head number"
"but received second input width %d"
", head_number %d",
dim_b.width_, head_number));
}
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
......@@ -888,9 +963,16 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
mat_out->data<T>() + sub_matC_offset, ldc);
}
} else {
PADDLE_ENFORCE_EQ((dim_a.batch_size_ == dim_b.batch_size_ ||
dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0),
true);
PADDLE_ENFORCE_EQ(
(dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 ||
dim_b.batch_size_ == 0),
true,
platform::errors::InvalidArgument(
"The first input batch size should be equal than second input,"
"either two input batch size is 0, but received first input batch "
"size"
" %d, second input batch size %d",
dim_a.batch_size_, dim_b.batch_size_));
this->template BatchedGEMMWithHead<T>(
transA, transB, dim_a.width_, dim_a.height_, dim_b.width_,
......
......@@ -68,6 +68,6 @@ REGISTER_OPERATOR(
shape, ops::ShapeOp, ops::ShapeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<int>, ops::ShapeKernel<int32_t>,
REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<bool>, ops::ShapeKernel<int>,
ops::ShapeKernel<int64_t>, ops::ShapeKernel<float>,
ops::ShapeKernel<double>);
......@@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/fluid/operators/shape_op.h"
REGISTER_OP_CUDA_KERNEL(
shape, paddle::operators::ShapeKernel<int>,
paddle::operators::ShapeKernel<int32_t>,
shape, paddle::operators::ShapeKernel<bool>,
paddle::operators::ShapeKernel<int>,
paddle::operators::ShapeKernel<int64_t>,
paddle::operators::ShapeKernel<float>,
paddle::operators::ShapeKernel<double>,
......
......@@ -77,6 +77,7 @@ from .tensor.creation import triu #DEFINE_ALIAS
from .tensor.creation import tril #DEFINE_ALIAS
from .tensor.creation import meshgrid #DEFINE_ALIAS
from .tensor.creation import empty #DEFINE_ALIAS
from .tensor.creation import empty_like #DEFINE_ALIAS
from .tensor.linalg import matmul #DEFINE_ALIAS
from .tensor.linalg import dot #DEFINE_ALIAS
# from .tensor.linalg import einsum #DEFINE_ALIAS
......
......@@ -67,6 +67,7 @@ class ImperativeQuantAware(object):
Examples:
.. code-block:: python
import paddle
from paddle.fluid.contrib.slim.quantization \
import ImperativeQuantAware
from paddle.vision.models \
......@@ -86,13 +87,12 @@ class ImperativeQuantAware(object):
# ...
# Save quant model for the inference.
imperative_qat.save_quantized_model(
dirname="./resnet50_qat",
model=model,
input_shape=[(3, 224, 224)],
input_dtype=['float32'],
feed=[0],
fetch=[0])
paddle.jit.save(
layer=model,
model_path="./resnet50_qat",
input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype='float32')])
"""
super(ImperativeQuantAware, self).__init__()
self._weight_bits = weight_bits
......@@ -148,75 +148,6 @@ class ImperativeQuantAware(object):
quant_layer = self._get_quantized_counterpart(layer)
setattr(obj, target, quant_layer)
def save_quantized_model(self,
dirname,
model,
input_shape,
input_dtype,
feed,
fetch,
append_batch_size=True):
"""
Save the quantized model for the inference.
Args:
dirname (str): the directory to save the quantized model.
model(fluid.dygraph.Layer): the quantized model to be saved.
input_shape(list[tuple(int)]): The shape value for each input,
e.g. [(3, 224, 224)].
input_dtype(list[str]): The dtype value for each input,
e.g. ['float32'].
feed(list[int]): the indices of the input variables of the
imperative functions which will be saved as input variables in
inference model.
fetch(list[int]): the indices of the returned variable of the
imperative functions which will be saved as output variables in
inference model.
append_batch_size(bool, optional):
If true, it prepends an extra axis to the input_shape, meanwhile,
the input_shape shouldn't contain the batch size dimension.
Otherwise, it just uses the input_shape. Default True.
Returns:
None
"""
assert isinstance(
input_shape, list), "The parameter `input_shape` shoubld be a list."
assert isinstance(
input_dtype, list), "The parameter `input_dtype` shoubld be a list."
assert isinstance(feed, list), "The parameter `feed` shoubld be a list."
assert isinstance(fetch,
list), "The parameter `fetch` shoubld be a list."
assert len(input_shape) == len(
input_dtype
), "The length of input_shape should be equal to input_dtype's."
assert len(input_dtype) == len(
feed), "The length of input_shape should be equal to feed's."
with dygraph.guard():
model.eval()
input_vars = []
for i, (shape, dtype) in enumerate(zip(input_shape, input_dtype)):
if append_batch_size:
shape = [None] + list(shape)
# Note(Aurelius84): need a elegant way to name this.
in_spec = paddle.static.InputSpec(shape, dtype, 'feed_%d' % i)
input_vars.append(in_spec)
# use `declarative` to convert dygraph into static program
model.forward = dygraph.jit.declarative(
model.forward, input_spec=input_vars)
outputs = model.forward.concrete_program.outputs
input_spec = [input_vars[i] for i in feed]
configs = dygraph.jit.SaveLoadConfig()
configs.separate_params = True
if not isinstance(outputs, (tuple, list)):
outputs = [outputs]
configs.output_spec = [outputs[i] for i in fetch]
dygraph.jit.save(
layer=model,
model_path=dirname,
input_spec=input_spec,
configs=configs)
def _get_quantized_counterpart(self, layer):
quant_layers = tuple(self._quant_layers_map.values())
quantized_counterpart = tuple('Quantized' + k
......
......@@ -221,7 +221,7 @@ class TestImperativeQat(unittest.TestCase):
model_dict = lenet.state_dict()
fluid.save_dygraph(model_dict, "save_temp")
# test the correctness of `save_quantized_model`
# test the correctness of `paddle.jit.save`
data = next(test_reader())
test_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
......@@ -231,13 +231,14 @@ class TestImperativeQat(unittest.TestCase):
# save inference quantized model
path = "./mnist_infer_model"
imperative_qat.save_quantized_model(
dirname=path,
model=lenet,
input_shape=[(1, 28, 28)],
input_dtype=['float32'],
feed=[0],
fetch=[0])
paddle.jit.save(
layer=lenet,
model_path=path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
......@@ -245,7 +246,10 @@ class TestImperativeQat(unittest.TestCase):
exe = fluid.Executor(place)
[inference_program, feed_target_names, fetch_targets] = (
fluid.io.load_inference_model(
dirname=path, executor=exe))
dirname=path,
executor=exe,
model_filename="__model__",
params_filename="__variables__"))
after_save, = exe.run(inference_program,
feed={feed_target_names[0]: test_data},
fetch_list=fetch_targets)
......@@ -332,13 +336,13 @@ class TestImperativeQat(unittest.TestCase):
if batch_id % 100 == 0:
_logger.info('{}: {}'.format('loss', avg_loss.numpy()))
imperative_qat.save_quantized_model(
dirname="./dynamic_mnist",
model=lenet,
input_shape=[(1, 28, 28)],
input_dtype=['float32'],
feed=[0],
fetch=[0])
paddle.jit.save(
layer=lenet,
model_path="./dynamic_mnist",
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
# static graph train
_logger.info(
......
......@@ -60,7 +60,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
def transfer_from_node_type(self, node_wrapper):
translator_logger = logging_utils.TranslatorLogger()
translator_logger.log(
1, " Source code: \n{}".format(ast_to_source_code(self.root)))
1, "Source code: \n{}".format(ast_to_source_code(self.root)))
# Generic transformation
self.visit(node_wrapper.node)
......
......@@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import six
import inspect
import numpy as np
import collections
import paddle
from paddle.fluid import core
from paddle.fluid.dygraph import layers
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs
from paddle.fluid.dygraph.dygraph_to_static.utils import type_name
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
......@@ -291,7 +292,7 @@ def convert_to_input_spec(inputs, input_spec):
if len(inputs) > len(input_spec):
for rest_input in inputs[len(input_spec):]:
if isinstance(rest_input, (core.VarBase, np.ndarray)):
logging.warning(
logging_utils.warn(
"The inputs constain `{}` without specificing InputSpec, its shape and dtype will be treated immutable. "
"Please specific InputSpec information in `@declarative` if you expect them as mutable inputs.".
format(type_name(rest_input)))
......
......@@ -26,6 +26,8 @@ CODE_LEVEL_ENV_NAME = 'TRANSLATOR_CODE_LEVEL'
DEFAULT_VERBOSITY = -1
DEFAULT_CODE_LEVEL = -1
LOG_AllTransformer = 100
def synchronized(func):
def wrapper(*args, **kwargs):
......@@ -53,10 +55,15 @@ class TranslatorLogger(object):
return
self._initialized = True
self.logger_name = "Dynamic-to-Static"
self._logger = log_helper.get_logger(
__name__, 1, fmt='%(asctime)s-%(levelname)s: %(message)s')
self.logger_name,
1,
fmt='%(asctime)s %(name)s %(levelname)s: %(message)s')
self._verbosity_level = None
self._transformed_code_level = None
self._need_to_echo_log_to_stdout = None
self._need_to_echo_code_to_stdout = None
@property
def logger(self):
......@@ -86,6 +93,28 @@ class TranslatorLogger(object):
self.check_level(level)
self._transformed_code_level = level
@property
def need_to_echo_log_to_stdout(self):
if self._need_to_echo_log_to_stdout is not None:
return self._need_to_echo_log_to_stdout
return False
@need_to_echo_log_to_stdout.setter
def need_to_echo_log_to_stdout(self, log_to_stdout):
assert isinstance(log_to_stdout, (bool, type(None)))
self._need_to_echo_log_to_stdout = log_to_stdout
@property
def need_to_echo_code_to_stdout(self):
if self._need_to_echo_code_to_stdout is not None:
return self._need_to_echo_code_to_stdout
return False
@need_to_echo_code_to_stdout.setter
def need_to_echo_code_to_stdout(self, code_to_stdout):
assert isinstance(code_to_stdout, (bool, type(None)))
self._need_to_echo_code_to_stdout = code_to_stdout
def check_level(self, level):
if isinstance(level, (six.integer_types, type(None))):
rv = level
......@@ -110,34 +139,56 @@ class TranslatorLogger(object):
def error(self, msg, *args, **kwargs):
self.logger.error(msg, *args, **kwargs)
if self.need_to_echo_log_to_stdout:
self._output_to_stdout('ERROR: ' + msg, *args)
def warn(self, msg, *args, **kwargs):
self.logger.warn(msg, *args, **kwargs)
self.logger.warning(msg, *args, **kwargs)
if self.need_to_echo_log_to_stdout:
self._output_to_stdout('WARNING: ' + msg, *args)
def log(self, level, msg, *args, **kwargs):
if self.has_verbosity(level):
self.logger.log(level, msg, *args, **kwargs)
msg_with_level = '(Level {}) {}'.format(level, msg)
self.logger.info(msg_with_level, *args, **kwargs)
if self.need_to_echo_log_to_stdout:
self._output_to_stdout('INFO: ' + msg_with_level, *args)
def log_transformed_code(self, level, ast_node, transformer_name, *args,
**kwargs):
if self.has_code_level(level):
source_code = ast_to_source_code(ast_node)
header_msg = "After the level {} ast transformer: '{}', the transformed code:\n"\
.format(level, transformer_name)
if level == LOG_AllTransformer:
header_msg = "After the last level ast transformer: '{}', the transformed code:\n" \
.format(transformer_name)
else:
header_msg = "After the level {} ast transformer: '{}', the transformed code:\n"\
.format(level, transformer_name)
msg = header_msg + source_code
self.logger.info(msg, *args, **kwargs)
if self.need_to_echo_code_to_stdout:
self._output_to_stdout('INFO: ' + msg, *args)
def _output_to_stdout(self, msg, *args):
msg = self.logger_name + ' ' + msg
print(msg % args)
_TRANSLATOR_LOGGER = TranslatorLogger()
def set_verbosity(level=0):
def set_verbosity(level=0, also_to_stdout=False):
"""
Sets the verbosity level of log for dygraph to static graph.
Sets the verbosity level of log for dygraph to static graph. Logs can be output to stdout by setting `also_to_stdout`.
There are two means to set the logging verbosity:
1. Call function `set_verbosity`
2. Set environment variable `TRANSLATOR_VERBOSITY`
1. Call function `set_verbosity`
2. Set environment variable `TRANSLATOR_VERBOSITY`
**Note**:
`set_verbosity` has a higher priority than the environment variable.
......@@ -145,6 +196,7 @@ def set_verbosity(level=0):
Args:
level(int): The verbosity level. The larger value idicates more verbosity.
The default value is 0, which means no logging.
also_to_stdout(bool): Whether to also output log messages to `sys.stdout`.
Examples:
.. code-block:: python
......@@ -159,27 +211,30 @@ def set_verbosity(level=0):
# The verbosity level is now 3, but it has no effect because it has a lower priority than `set_verbosity`
"""
_TRANSLATOR_LOGGER.verbosity_level = level
_TRANSLATOR_LOGGER.need_to_echo_log_to_stdout = also_to_stdout
def get_verbosity():
return _TRANSLATOR_LOGGER.verbosity_level
LOG_AllTransformer = 100
def set_code_level(level=LOG_AllTransformer):
def set_code_level(level=LOG_AllTransformer, also_to_stdout=False):
"""
Sets the level to print code from specific level of Ast Transformer.
Sets the level to print code from specific level Ast Transformer. Code can be output to stdout by setting `also_to_stdout`.
There are two means to set the code level:
1. Call function `set_code_level`
2. Set environment variable `TRANSLATOR_CODE_LEVEL`
1. Call function `set_code_level`
2. Set environment variable `TRANSLATOR_CODE_LEVEL`
**Note**:
`set_code_level` has a higher priority than the environment variable.
Args:
level(int): The level to print code. Default is 100, which means to print the code after all AST Transformers.
also_to_stdout(bool): Whether to also output code to `sys.stdout`.
Examples:
.. code-block:: python
......@@ -195,6 +250,7 @@ def set_code_level(level=LOG_AllTransformer):
"""
_TRANSLATOR_LOGGER.transformed_code_level = level
_TRANSLATOR_LOGGER.need_to_echo_code_to_stdout = also_to_stdout
def get_code_level():
......
......@@ -14,21 +14,17 @@
from __future__ import print_function
import numpy as np
import logging
import six
from paddle.fluid import log_helper
from paddle.fluid import framework, backward, core
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
import paddle.compat as cpt
_logger = log_helper.get_logger(
__name__, logging.WARNING, fmt='%(asctime)s-%(levelname)s: %(message)s')
class NestSequence(object):
"""
......@@ -72,7 +68,7 @@ class NestSequence(object):
if not isinstance(var, (framework.Variable, core.VarBase)):
warning_types.add(type(var))
if warning_types:
_logger.warning(
logging_utils.warn(
"Output of traced function contains non-tensor type values: {}. "
"Currently, We don't support to update them while training and will return "
"what we first saw. Please try to return them as tensor.".
......
......@@ -15,14 +15,8 @@
from __future__ import print_function
import gast
import logging
from paddle.fluid import log_helper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
_logger = log_helper.get_logger(
__name__, logging.WARNING, fmt='%(asctime)s-%(levelname)s: %(message)s')
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor
class PrintTransformer(gast.NodeTransformer):
......
......@@ -13,17 +13,15 @@
# limitations under the License.
from __future__ import print_function
import gast
import collections
import logging
import gast
import inspect
import six
import textwrap
import threading
import warnings
import weakref
import gast
from paddle.fluid import framework
from paddle.fluid import in_dygraph_mode
from paddle.fluid.dygraph import layers
......@@ -451,7 +449,7 @@ class StaticLayer(object):
format(self._function_spec))
# If more than one programs have been cached, return the recent converted program by default.
elif cached_program_len > 1:
logging.warning(
logging_utils.warn(
"Current {} has more than one cached programs: {}, the last traced progam will be return by default.".
format(self._function_spec, cached_program_len))
......@@ -632,7 +630,7 @@ class ProgramCache(object):
# Note: raise warnings if number of traced program is more than `max_tracing_count`
current_tracing_count = len(self._caches)
if current_tracing_count > MAX_TRACED_PROGRAM_COUNT:
logging.warning(
logging_utils.warn(
"Current traced program number: {} > `max_tracing_count`:{}. Too much cached programs will bring expensive overhead. "
"The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors.".
format(current_tracing_count, MAX_TRACED_PROGRAM_COUNT))
......@@ -804,8 +802,9 @@ class ProgramTranslator(object):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
if not self.enable_to_static:
warnings.warn(
logging_utils.warn(
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
......@@ -879,8 +878,9 @@ class ProgramTranslator(object):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
if not self.enable_to_static:
warnings.warn(
logging_utils.warn(
"The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will "
"just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output."
)
......@@ -933,8 +933,9 @@ class ProgramTranslator(object):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
if not self.enable_to_static:
warnings.warn(
logging_utils.warn(
"The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False."
"We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
......
......@@ -26,6 +26,7 @@ from paddle.fluid import core
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticLayer, unwrap_decorators
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer
......@@ -120,7 +121,7 @@ def _dygraph_to_static_func_(dygraph_func):
def __impl__(*args, **kwargs):
program_translator = ProgramTranslator()
if in_dygraph_mode() or not program_translator.enable_to_static:
warnings.warn(
logging_utils.warn(
"The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode or set ProgramTranslator.enable to False. "
"We will just return dygraph output.")
......@@ -215,7 +216,7 @@ def declarative(function=None, input_spec=None):
if isinstance(function, Layer):
if isinstance(function.forward, StaticLayer):
class_name = function.__class__.__name__
warnings.warn(
logging_utils.warn(
"`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.".
format(class_name))
function.forward = decorated(function.forward)
......
......@@ -11229,7 +11229,7 @@ def shape(input):
input.shape = [3, 2]
Args:
input (Variable): The input can be N-D Tensor or SelectedRows with data type float16, float32, float64, int32, int64.
input (Variable): The input can be N-D Tensor or SelectedRows with data type bool, float16, float32, float64, int32, int64.
If input variable is type of SelectedRows, returns the shape of it's inner tensor.
Returns:
......@@ -11253,8 +11253,8 @@ def shape(input):
print(res) # [array([ 3, 100, 100], dtype=int32)]
"""
check_variable_and_dtype(
input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'],
'shape')
input, 'input',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'shape')
helper = LayerHelper('shape', **locals())
out = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op(
......
......@@ -56,8 +56,30 @@ class TestLoggingUtils(unittest.TestCase):
with self.assertRaises(TypeError):
paddle.jit.set_verbosity(3.3)
def test_code_level(self):
def test_also_to_stdout(self):
logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout = None
self.assertEqual(
logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout, False)
paddle.jit.set_verbosity(also_to_stdout=False)
self.assertEqual(
logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout, False)
logging_utils._TRANSLATOR_LOGGER.need_to_echo_node_to_stdout = None
self.assertEqual(
logging_utils._TRANSLATOR_LOGGER.need_to_echo_code_to_stdout, False)
paddle.jit.set_code_level(also_to_stdout=True)
self.assertEqual(
logging_utils._TRANSLATOR_LOGGER.need_to_echo_code_to_stdout, True)
with self.assertRaises(AssertionError):
paddle.jit.set_verbosity(also_to_stdout=1)
with self.assertRaises(AssertionError):
paddle.jit.set_code_level(also_to_stdout=1)
def test_set_code_level(self):
paddle.jit.set_code_level(None)
os.environ[logging_utils.CODE_LEVEL_ENV_NAME] = '2'
self.assertEqual(logging_utils.get_code_level(), 2)
......@@ -71,7 +93,25 @@ class TestLoggingUtils(unittest.TestCase):
with self.assertRaises(TypeError):
paddle.jit.set_code_level(3.3)
def test_log(self):
def test_log_api(self):
# test api for CI Converage
logging_utils.set_verbosity(1, True)
logging_utils.warn("warn")
logging_utils.error("error")
logging_utils.log(1, "log level 1")
logging_utils.log(2, "log level 2")
source_code = "x = 3"
ast_code = gast.parse(source_code)
logging_utils.set_code_level(1, True)
logging_utils.log_transformed_code(1, ast_code, "TestTransformer")
logging_utils.set_code_level(logging_utils.LOG_AllTransformer, True)
logging_utils.log_transformed_code(logging_utils.LOG_AllTransformer,
ast_code, "TestTransformer")
def test_log_message(self):
stream = io.BytesIO() if six.PY2 else io.StringIO()
log = self.translator_logger.logger
stdout_handler = logging.StreamHandler(stream)
......@@ -84,13 +124,14 @@ class TestLoggingUtils(unittest.TestCase):
if six.PY3:
with mock.patch.object(sys, 'stdout', stream):
logging_utils.set_verbosity(1, False)
logging_utils.warn(warn_msg)
logging_utils.error(error_msg)
self.translator_logger.verbosity_level = 1
logging_utils.log(1, log_msg_1)
logging_utils.log(2, log_msg_2)
result_msg = '\n'.join([warn_msg, error_msg, log_msg_1, ""])
result_msg = '\n'.join(
[warn_msg, error_msg, "(Level 1) " + log_msg_1, ""])
self.assertEqual(result_msg, stream.getvalue())
def test_log_transformed_code(self):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import AnalysisConfig
"""Test for fusion of conv and bias."""
#padding SAME
class ConvBiasMkldnnFusePassTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 100, 100], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
padding="SAME",
bias_attr=param_attr)
self.feeds = {
"data": np.random.random((1, 3, 100, 100)).astype("float32")
}
self.fetch_list = [conv_out]
self.enable_mkldnn = True
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
#padding VALID
class ConvBiasMkldnnFusePassTest1(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 100, 100], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
padding="VALID",
bias_attr=param_attr)
self.feeds = {
"data": np.random.random((1, 3, 100, 100)).astype("float32")
}
self.fetch_list = [conv_out]
self.enable_mkldnn = True
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
#padding number
class ConvBiasMkldnnFusePassTest2(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 100, 100], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
padding=[2, 4, 6, 8],
bias_attr=param_attr)
self.feeds = {
"data": np.random.random((1, 3, 100, 100)).astype("float32")
}
self.fetch_list = [conv_out]
self.enable_mkldnn = True
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
#dilation not supported yet, just print warning log and does not fuse
class ConvBiasMkldnnFusePassTest3(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 100, 100], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
padding="VALID",
dilation=2,
groups=3,
bias_attr=param_attr,
use_cudnn=False,
act="softmax",
data_format="NCHW")
self.feeds = {
"data": np.random.random((1, 3, 100, 100)).astype("float32")
}
self.fetch_list = [conv_out]
self.enable_mkldnn = True
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
#all conv params except for dilation
class ConvBiasMkldnnFusePassTest4(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 100, 100], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv_out = fluid.layers.conv2d(
input=data,
num_filters=3,
filter_size=3,
padding="VALID",
groups=3,
bias_attr=param_attr,
use_cudnn=False,
act="softmax",
data_format="NCHW")
self.feeds = {
"data": np.random.random((1, 3, 100, 100)).astype("float32")
}
self.fetch_list = [conv_out]
self.enable_mkldnn = True
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
if __name__ == "__main__":
unittest.main()
......@@ -26,7 +26,7 @@ def stable_softmax(x):
return exps / np.sum(exps)
def log_softmax(x, axis=-1):
def log_softmax(x, axis=1):
softmax_out = np.apply_along_axis(stable_softmax, axis, x)
return np.log(softmax_out)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.data_feeder import convert_dtype
import paddle.fluid.core as core
from paddle.static import program_guard, Program
class TestEmptyLikeAPICommon(unittest.TestCase):
def __check_out__(self, out):
data_type = convert_dtype(out.dtype)
self.assertEqual(data_type, self.dst_dtype,
'dtype should be %s, but get %s' %
(self.dst_dtype, data_type))
shape = out.shape
self.assertTupleEqual(shape, self.dst_shape,
'shape should be %s, but get %s' %
(self.dst_shape, shape))
if data_type in ['float32', 'float64', 'int32', 'int64']:
max_value = np.nanmax(out)
min_value = np.nanmin(out)
always_non_full_zero = max_value > min_value
always_full_zero = max_value == 0.0 and min_value == 0.0
self.assertTrue(always_full_zero or always_non_full_zero,
'always_full_zero or always_non_full_zero.')
elif data_type in ['bool']:
total_num = out.size
true_num = np.sum(out == True)
false_num = np.sum(out == False)
self.assertTrue(total_num == true_num + false_num,
'The value should always be True or False.')
else:
self.assertTrue(False, 'invalid data type')
class TestEmptyLikeAPI(TestEmptyLikeAPICommon):
def setUp(self):
self.init_config()
def test_dygraph_api_out(self):
paddle.disable_static()
out = paddle.empty_like(self.x, self.dtype)
self.__check_out__(out.numpy())
paddle.enable_static()
def init_config(self):
self.x = np.random.random((200, 3)).astype("float32")
self.dtype = self.x.dtype
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI2(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("float64")
self.dtype = self.x.dtype
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI3(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("int")
self.dtype = self.x.dtype
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI4(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("int64")
self.dtype = self.x.dtype
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI5(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("bool")
self.dtype = self.x.dtype
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI6(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("float64")
self.dtype = "float32"
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI7(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("int")
self.dtype = "float32"
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI8(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("int64")
self.dtype = "float32"
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI9(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("bool")
self.dtype = "float32"
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI10(TestEmptyLikeAPI):
def init_config(self):
self.x = np.random.random((200, 3)).astype("float32")
self.dtype = "bool"
self.dst_shape = self.x.shape
self.dst_dtype = self.dtype
class TestEmptyLikeAPI_Static(TestEmptyLikeAPICommon):
def setUp(self):
self.init_config()
def test_static_graph(self):
dtype = 'float32'
train_program = Program()
startup_program = Program()
with program_guard(train_program, startup_program):
x = np.random.random(self.x_shape).astype(dtype)
data_x = paddle.static.data(
'x', shape=self.data_x_shape, dtype=dtype)
out = paddle.empty_like(data_x)
place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
res = exe.run(train_program, feed={'x': x}, fetch_list=[out])
self.dst_dtype = dtype
self.dst_shape = x.shape
self.__check_out__(res[0])
def init_config(self):
self.x_shape = (200, 3)
self.data_x_shape = [200, 3]
class TestEmptyLikeAPI_Static2(TestEmptyLikeAPI_Static):
def init_config(self):
self.x_shape = (3, 200, 3)
self.data_x_shape = [-1, 200, 3]
class TestEmptyError(unittest.TestCase):
def test_attr(self):
def test_dtype():
x = np.random.random((200, 3)).astype("float64")
dtype = 'uint8'
result = paddle.empty_like(x, dtype=dtype)
self.assertRaises(TypeError, test_dtype)
if __name__ == '__main__':
unittest.main()
......@@ -1093,7 +1093,7 @@ def cross_entropy(input,
" 'none', but received %s, which is not allowed." % reduction)
#step 1. log_softmax
log_softmax_out = paddle.nn.functional.log_softmax(input)
log_softmax_out = paddle.nn.functional.log_softmax(input, axis=1)
if weight is not None and not isinstance(weight, Variable):
raise ValueError(
"The weight' is not a Variable, please convert to Variable.")
......
......@@ -41,6 +41,7 @@ from .creation import triu #DEFINE_ALIAS
from .creation import tril #DEFINE_ALIAS
from .creation import meshgrid #DEFINE_ALIAS
from .creation import empty #DEFINE_ALIAS
from .creation import empty_like #DEFINE_ALIAS
from .io import save #DEFINE_ALIAS
from .io import load #DEFINE_ALIAS
from .linalg import matmul #DEFINE_ALIAS
......
......@@ -49,6 +49,7 @@ __all__ = [
'full',
'full_like',
'empty',
'empty_like',
'triu',
'tril',
'meshgrid'
......@@ -1068,3 +1069,70 @@ def empty(shape, dtype=None, name=None):
stop_gradient=True)
out.stop_gradient = True
return out
def empty_like(x, dtype=None, name=None):
"""
This Op returns a Tensor with uninitialized data which has identical shape of ``x`` and ``dtype``.
If the ``dtype`` is None, the data type of Tensor is same with ``x``.
Args:
x(Tensor): The input tensor which specifies shape and data type. The data type can be bool, float16, float32, float64, int32, int64.
dtype(np.dtype|str, optional): The data type of output. The data type can be one
of bool, float16, float32, float64, int32, int64. The default value is None, which means the output
data type is the same as input.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: Tensor which is created according to ``x`` and ``dtype``, and is uninitialized.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static() # Now we are in imperative mode
paddle.set_device("cpu") # and use cpu device
x = paddle.randn([2, 3], 'float32')
output = paddle.empty_like(x)
#[[1.8491974e+20 1.8037303e+28 1.7443726e+28] # uninitialized
# [4.9640171e+28 3.0186127e+32 5.6715899e-11]] # uninitialized
"""
if dtype is None:
dtype = x.dtype
dtype = convert_dtype(dtype)
if in_dygraph_mode():
out = core.ops.empty('shape', x.shape, 'dtype',
convert_np_dtype_to_dtype_(dtype))
out.stop_gradient = True
return out
helper = LayerHelper("empty_like", **locals())
check_variable_and_dtype(
x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'empty_like')
check_dtype(dtype, 'dtype',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'empty_like')
out = helper.create_variable_for_type_inference(dtype=dtype)
inputs = {}
attrs = {}
attrs['dtype'] = convert_np_dtype_to_dtype_(dtype)
shape = paddle.shape(x)
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='empty_like')
helper.append_op(
type='empty',
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs,
stop_gradient=True)
out.stop_gradient = True
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册