diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index b9018ecdba8303fd6b37c87edd99e192aa604228..bc0da55cdae3575414e0d9b36471056c20946489 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -522,7 +522,7 @@ ParamGradInfoMap AppendBackward( new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}}, {{"shape", std::vector{1}}, {"value", static_cast(1.0)}, - {"data_type", target.GetDataType()}})); + {"dtype", target.GetDataType()}})); // infer var type of fill_one_op fill_one_op->InferVarType(root_block); diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index adedd8cb0e8504fd6fc924e62a2ede3c1c7ce698..2ffb5b7dbb27b561092856eac0de23d0c3788f75 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -120,7 +120,7 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id, for (auto& op_desc : block.AllOps()) { auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); - VLOG(10) << op->DebugString(); + VLOG(3) << op->DebugString(); op->Run(*local_scope, *device); } if (create_local_scope) { diff --git a/paddle/framework/prune.cc b/paddle/framework/prune.cc index bf3066983cdcf44ae84f236ac72486e5d4fd5b92..da76052eb4d3067214841af72a35cebb26477e7f 100644 --- a/paddle/framework/prune.cc +++ b/paddle/framework/prune.cc @@ -26,6 +26,8 @@ namespace framework { const std::string kFeedOpType = "feed"; const std::string kFetchOpType = "fetch"; +const std::string kDropOutOpType = "dropout"; +const std::string kBatchNormOpType = "batch_norm"; bool HasDependentVar(const OpDesc& op_desc, const std::set& dependent_vars) { @@ -106,5 +108,26 @@ void Prune(const ProgramDesc& input, ProgramDesc* output) { prune_impl(input, output, 0); } +void inference_optimize_impl(const ProgramDesc& input, ProgramDesc* output, + int block_id) { + *output = input; + auto* op_field = output->mutable_blocks(block_id)->mutable_ops(); + for (auto& op_desc : *op_field) { + if (op_desc.type() == kDropOutOpType || + op_desc.type() == kBatchNormOpType) { + for (auto& attr : *op_desc.mutable_attrs()) { + if (attr.name() == "is_test") { + attr.set_b(true); + break; + } + } + } + } +} + +void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output) { + inference_optimize_impl(input, output, 0); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/prune.h b/paddle/framework/prune.h index 8cfb16343aa44dcc8a3349b01adecce33f1c2b5b..23db014894348094a98e043aa744c6f0d27b2640 100644 --- a/paddle/framework/prune.h +++ b/paddle/framework/prune.h @@ -22,5 +22,7 @@ namespace framework { void Prune(const ProgramDesc& input, ProgramDesc* output); +void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output); + } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor_array.cc b/paddle/framework/tensor_array.cc index 0947e33548130a923e998f8bad68db00097af909..6058f1b8b1f78181bc4a47e9e3d4135f5139980f 100644 --- a/paddle/framework/tensor_array.cc +++ b/paddle/framework/tensor_array.cc @@ -302,7 +302,7 @@ LoDTensor TensorArray::Stack() const { const auto& first_dims = values_.front().dims(); // check all the values have the same shape - // TODO(superjom) check the same dtypes + // TODO(superjom) check the same data_type for (size_t idx = 1; idx < size(); idx++) { const auto& value_dims = values_[idx].dims(); PADDLE_ENFORCE_EQ(first_dims, value_dims); diff --git a/paddle/memory/CMakeLists.txt b/paddle/memory/CMakeLists.txt index 
aed5275dbf9be707cc6e19e729133ba8eab58195..8841c14ee083fccfd2271efd0c331805919a09d9 100644 --- a/paddle/memory/CMakeLists.txt +++ b/paddle/memory/CMakeLists.txt @@ -1,6 +1,6 @@ add_subdirectory(detail) -cc_library(memory SRCS memory.cc DEPS place) +cc_library(memory SRCS memory.cc DEPS place enforce) cc_library(memcpy SRCS memcpy.cc) cc_library(paddle_memory diff --git a/paddle/operators/beam_search_decode_op.cc b/paddle/operators/beam_search_decode_op.cc index 3904a97d58166cfeeb2be7d2144700dbd8bc5721..c796a0c5d089499e7858c7a427825fdbeb05cb7f 100644 --- a/paddle/operators/beam_search_decode_op.cc +++ b/paddle/operators/beam_search_decode_op.cc @@ -17,6 +17,36 @@ limitations under the License. */ namespace paddle { namespace operators { +struct BeamSearchDecodeFunctor { + BeamSearchDecodeFunctor(const LoDTensorArray& step_ids, + const LoDTensorArray& step_scores, + LoDTensor* id_tensor, LoDTensor* score_tensor) + : step_ids_(step_ids), + step_scores_(step_scores), + id_tensor_(id_tensor), + score_tensor_(score_tensor) {} + + template <typename T> + void operator()() const; + + const LoDTensorArray& step_ids_; + const LoDTensorArray& step_scores_; + LoDTensor* id_tensor_; + LoDTensor* score_tensor_; +}; + +template <typename T> +void BeamSearchDecodeFunctor::operator()() const { + BeamSearchDecoder<T> beam_search_decoder; + beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_, + score_tensor_); +} + +template <> +void BeamSearchDecodeFunctor::operator()<bool>() const { + PADDLE_THROW("beam search decode op does not support bool!"); +} + class BeamSearchDecodeOp : public framework::OperatorBase { public: BeamSearchDecodeOp(const std::string& type, @@ -45,9 +75,9 @@ class BeamSearchDecodeOp : public framework::OperatorBase { LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds"); LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores"); - BeamSearchDecoder<float> beam_search_decoder; - beam_search_decoder.PackAllSteps(*ids, *scores, sentenceIds, - sentenceScores); + framework::VisitDataType( + framework::ToDataType(scores->at(0).type()), + BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores)); } }; diff --git a/paddle/operators/bilinear_tensor_product_op.cc b/paddle/operators/bilinear_tensor_product_op.cc index c65ba7eb262f3aabe2c00837b79806c0b40b60fd..c88b2c9beb4497b617078c8ac5582d2f246f43fd 100644 --- a/paddle/operators/bilinear_tensor_product_op.cc +++ b/paddle/operators/bilinear_tensor_product_op.cc @@ -77,11 +77,19 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "The output of bilinear_tensor_product operator."); AddComment(R"DOC( Bilinear Tensor Product operator. -Given input X and Y, a 3D tensor weight, and bias. Each column of the -output is computed by one slice i = 1, . . . , k of the tensor: - - M = (X W_i) \cdot Y - Out_i = \sum_i {M_i} + Bias_i +Given input X and Y, a 3D tensor Weight and a Bias, each column of the +Output is computed by one slice $i = 1, . . . 
, k$ of the tensor: + +$$ +M = (X W_i) * Y \\ +Out_i = \sum_j {M_j} + Bias_i +$$ + +Where $W_i$ is the $i$-th slice of Input(Weight); + $M_j$ is the $j$-th column of $M$; + $Out_i$ is the $i$-th column of Output(Out); + $Bias_i$ is a column vector, each element of it is equal to + the $i$-th element of $Bias$; )DOC"); } diff --git a/paddle/operators/cast_op.cc b/paddle/operators/cast_op.cc index 70ee7861bab3a982eae60dd85b10c2e41f5827d0..3082a53ccfbe4f8666cfdfc2efed6b46ffdfede9 100644 --- a/paddle/operators/cast_op.cc +++ b/paddle/operators/cast_op.cc @@ -25,8 +25,8 @@ class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input tensor of cast op"); AddOutput("Out", "The output tensor of cast op"); - AddAttr("out_data_type", "output data type"); - AddAttr("in_data_type", "input data type"); + AddAttr("out_dtype", "output data type"); + AddAttr("in_dtype", "input data type"); AddComment(R"DOC( Cast Operator. @@ -58,8 +58,8 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker { grad->SetType("cast"); grad->SetInput("X", OutputGrad("Out")); grad->SetOutput("Out", InputGrad("X")); - grad->SetAttr("out_data_type", GetAttr("in_data_type")); - grad->SetAttr("in_data_type", GetAttr("out_data_type")); + grad->SetAttr("out_dtype", GetAttr("in_dtype")); + grad->SetAttr("in_dtype", GetAttr("out_dtype")); return std::unique_ptr(grad); } }; diff --git a/paddle/operators/cast_op.h b/paddle/operators/cast_op.h index ffdbff7030afedab2efc06479ac86ad70c185f48..850dc8e3498351e54d41fcd2b6596c6fe668df14 100644 --- a/paddle/operators/cast_op.h +++ b/paddle/operators/cast_op.h @@ -55,7 +55,7 @@ class CastOpKernel : public framework::OpKernel { auto* in = context.Input("X"); auto* out = context.Output("Out"); framework::VisitDataType( - static_cast(context.Attr("out_data_type")), + static_cast(context.Attr("out_dtype")), CastOpFunctor(in, out, context.device_context())); } }; diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc index 818146aca766cb13b93fd024c11c1209655d9e11..932c0bf8fbf6ffdc466516bb7c8578abf0f57209 100644 --- a/paddle/operators/dropout_op.cc +++ b/paddle/operators/dropout_op.cc @@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); ctx->SetOutputDim("Out", x_dims); - if (ctx->Attrs().Get("is_training") == true) { + if (ctx->Attrs().Get("is_test") == false) { ctx->SetOutputDim("Mask", x_dims); } ctx->ShareLoD("X", /*->*/ "Out"); @@ -49,7 +49,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("dropout_prob", "Probability of setting units to zero.") .SetDefault(.5f); - AddAttr("is_training", "True if in training phase.").SetDefault(true); + AddAttr("is_test", "True if in test phase.").SetDefault(false); AddAttr("seed", "Dropout random seed.").SetDefault(0); AddComment(R"DOC( @@ -71,8 +71,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->Attrs().Get("is_training"), true, - "GradOp is only callable when is_training is true"); + PADDLE_ENFORCE_EQ(ctx->Attrs().Get("is_test"), false, + "GradOp is only callable when is_test is false"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null."); diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu index 
30c769000f2b98c69eaa78a4c139630dd0956386..db3578b9bf4c081e431f202f0828ec6392c924b2 100644 --- a/paddle/operators/dropout_op.cu +++ b/paddle/operators/dropout_op.cu @@ -59,7 +59,7 @@ class GPUDropoutKernel : public framework::OpKernel { auto Y = EigenMatrix::Reshape(*y, 1); auto place = context.GetEigenDevice(); - if (context.Attr("is_training")) { + if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); auto* mask_data = mask->mutable_data(context.GetPlace()); int size = framework::product(mask->dims()); diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h index 6000b75fecdff74844605215e9364ac8f8a1525a..d9a130fdc040f745b058c39221f0bb9661473388 100644 --- a/paddle/operators/dropout_op.h +++ b/paddle/operators/dropout_op.h @@ -35,7 +35,7 @@ class CPUDropoutKernel : public framework::OpKernel { auto* y_data = y->mutable_data(context.GetPlace()); float dropout_prob = context.Attr("dropout_prob"); - if (context.Attr("is_training")) { + if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); auto* mask_data = mask->mutable_data(context.GetPlace()); int seed = context.Attr("seed"); @@ -65,8 +65,8 @@ template class DropoutGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - PADDLE_ENFORCE(context.Attr("is_training"), - "GradOp is only callable when is_training is true"); + PADDLE_ENFORCE(!context.Attr("is_test"), + "GradOp is only callable when is_test is false"); auto* grad_x = context.Output(framework::GradVarName("X")); auto* grad_y = context.Input(framework::GradVarName("Out")); diff --git a/paddle/operators/fill_constant_batch_size_like_op.cc b/paddle/operators/fill_constant_batch_size_like_op.cc index 985b5d1e865e513d833bff72dcd20a8f20851d8c..892922cd3aaec8bf8194320c5c3a0dd0365bb589 100644 --- a/paddle/operators/fill_constant_batch_size_like_op.cc +++ b/paddle/operators/fill_constant_batch_size_like_op.cc @@ -52,7 +52,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel { framework::OpKernelType GetKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - static_cast(ctx.Attr("data_type")), + static_cast(ctx.Attr("dtype")), ctx.device_context()); } }; @@ -63,7 +63,7 @@ class FillConstantBatchSizeLikeOpMaker FillConstantBatchSizeLikeOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr("data_type", + AddAttr("dtype", "(int, default 5 (FP32)) " "Output data type") .SetDefault(framework::DataType::FP32); diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc index 818f113b90a4c239a857791fb9957e51d3287b97..3d5f84bc239615797a5cf01a74150fdb7dfc1b80 100644 --- a/paddle/operators/fill_constant_op.cc +++ b/paddle/operators/fill_constant_op.cc @@ -34,7 +34,7 @@ class FillConstantOp : public framework::OperatorBase { using framework::OperatorBase::OperatorBase; void Run(const framework::Scope &scope, const platform::DeviceContext &dev_ctx) const override { - auto data_type = static_cast(Attr("data_type")); + auto data_type = static_cast(Attr("dtype")); auto value = Attr("value"); auto force_cpu = Attr("force_cpu"); auto &out = @@ -55,7 +55,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { FillConstantOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr("data_type", + AddAttr("dtype", 
"(int, default 5 (FP32)) " "Output data type") .SetDefault(framework::DataType::FP32); diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 53ad86c6c48d1868f4495af51661d91b39a84f0b..254c83e1378a121d99c89d9d8705935b5f06edc8 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -60,7 +60,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { framework::OpKernelType GetKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - static_cast(ctx.Attr("data_type")), + static_cast(ctx.Attr("dtype")), ctx.device_context()); } }; @@ -88,7 +88,7 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { "Random seed of generator." "0 means use system wide seed.") .SetDefault(0); - AddAttr("data_type", + AddAttr("dtype", "(int, default 5(FP32)) " "Output data type.") .SetDefault(framework::DataType::FP32); diff --git a/paddle/operators/linear_chain_crf_op.cc b/paddle/operators/linear_chain_crf_op.cc index 066bdf67aa037e9c25cfdfaff7ec8771eb59cde8..8e079a14e0a15e8ff803b6087e6b0b02083479ef 100644 --- a/paddle/operators/linear_chain_crf_op.cc +++ b/paddle/operators/linear_chain_crf_op.cc @@ -32,19 +32,19 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker { "[(D + 2) x D]. The learnable parameter for the linear_chain_crf " "operator. See more details in the operator's comments."); AddInput("Label", - "(LoDTensor, default LoDTensor) A LoDTensor with shape " + "(LoDTensor, default LoDTensor) A LoDTensor with shape " "[N x 1], where N is the total element number in a mini-batch. " "The ground truth."); AddOutput( "Alpha", "(Tensor, default Tensor) A 2-D Tensor with shape [N x D]. " - "The forward vectors for the entire batch. Denote it as \f$\alpha\f$. " - "\f$\alpha$\f is a memo table used to calculate the normalization " - "factor in CRF. \f$\alpha[k, v]$\f stores the unnormalized " + "The forward vectors for the entire batch. Denote it as $\alpha$. " + "$\alpha$ is a memo table used to calculate the normalization " + "factor in CRF. $\alpha[k, v]$ stores the unnormalized " "probabilites of all possible unfinished sequences of tags that end at " - "position \f$k$\f with tag \f$v$\f. For each \f$k$\f, " - "\f$\alpha[k, v]$\f is a vector of length \f$D$\f with a component for " - "each tag value \f$v$\f. This vector is called a forward vecotr and " + "position $k$ with tag $v$. For each $k$, " + "$\alpha[k, v]$ is a vector of length $D$ with a component for " + "each tag value $v$. This vector is called a forward vecotr and " "will also be used in backward computations.") .AsIntermediate(); AddOutput( @@ -73,9 +73,9 @@ LinearChainCRF Operator. Conditional Random Field defines an undirected probabilistic graph with nodes denoting random variables and edges denoting dependencies between these -variables. CRF learns the conditional probability \f$P(Y|X)\f$, where -\f$X = (x_1, x_2, ... , x_n)\f$ are structured inputs and -\f$Y = (y_1, y_2, ... , y_n)\f$ are labels for the inputs. +variables. CRF learns the conditional probability $P(Y|X)$, where +$X = (x_1, x_2, ... , x_n)$ are structured inputs and +$Y = (y_1, y_2, ... , y_n)$ are labels for the inputs. Linear chain CRF is a special case of CRF that is useful for sequence labeling task. Sequence labeling tasks do not assume a lot of conditional @@ -88,21 +88,22 @@ CRF. 
Please refer to http://www.cs.columbia.edu/~mcollins/fb.pdf and http://cseweb.ucsd.edu/~elkan/250Bwinter2012/loglinearCRFs.pdf for details. Equation: -1. Denote Input(Emission) to this operator as \f$x\f$ here. +1. Denote Input(Emission) to this operator as $x$ here. 2. The first D values of Input(Transition) to this operator are for starting -weights, denoted as \f$a\f$ here. +weights, denoted as $a$ here. 3. The next D values of Input(Transition) of this operator are for ending -weights, denoted as \f$b\f$ here. +weights, denoted as $b$ here. 4. The remaning values of Input(Transition) are for transition weights, -denoted as \f$w\f$ here. -5. Denote Input(Label) as \f$s\f$ here. - -The probability of a sequence \f$s\f$ of length \f$L\f$ is defined as: -\f$P(s) = (1/Z) \exp(a_{s_1} + b_{s_L} - + \sum_{l=1}^L x_{s_l} - + \sum_{l=2}^L w_{s_{l-1},s_l})\f$ -where \f$Z\f$ is a normalization value so that the sum of \f$P(s)\f$ over -all possible sequences is \f$1\f$, and \f$x\f$ is the emission feature weight +denoted as $w$ here. +5. Denote Input(Label) as $s$ here. + +The probability of a sequence $s$ of length $L$ is defined as: +$$P(s) = (1/Z) \exp(a_{s_1} + b_{s_L} + + \sum_{l=1}^L x_{s_l} + + \sum_{l=2}^L w_{s_{l-1},s_l})$$ + +where $Z$ is a normalization value so that the sum of $P(s)$ over +all possible sequences is 1, and $x$ is the emission feature weight to the linear chain CRF. Finally, the linear chain CRF operator outputs the logarithm of the conditional diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc index 66fcc09bc877867e66a37adc73230d8dabf4cbed..22a37ff1bbf6b8cfb2cbc3c3dbbb20a87c5ea4e7 100644 --- a/paddle/operators/nccl_op.cc +++ b/paddle/operators/nccl_op.cc @@ -49,7 +49,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Communicator", "Create Communicator for communicating between gpus"); AddAttr>("gpus", "(vector) GPU id lists"); - AddAttr("data_type", + AddAttr("dtype", "(int, default 5 (FP32)) " "Output data type") .SetDefault(framework::DataType::FP32); diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 0075ccd24271bf83f139e121efad00c2316cc11b..ea60665e398978c612321df1f340e5669038e28f 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -401,7 +401,7 @@ class RecurrentGradOp : public RecurrentBase { auto &inside_tensor = cur_scope.FindVar(inside_grad_name) ->Get(); framework::AttributeMap attrs; - attrs["data_type"] = framework::ToDataType(inside_tensor.type()); + attrs["dtype"] = framework::ToDataType(inside_tensor.type()); attrs["shape"] = framework::vectorize2int(inside_tensor.dims()); attrs["value"] = 0.0f; diff --git a/paddle/operators/rnn_memory_helper_op.cc b/paddle/operators/rnn_memory_helper_op.cc index b621c7f1ba3f9e9613dea5bc98ef74c7c6dae9a0..3a035f0b9acb94bab60659938e11b4996b8eaa0f 100644 --- a/paddle/operators/rnn_memory_helper_op.cc +++ b/paddle/operators/rnn_memory_helper_op.cc @@ -62,7 +62,7 @@ class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", ""); AddOutput("Out", ""); - AddAttr("data_type", + AddAttr("dtype", "(int, default 5 (FP32)) " "Output data type") .SetDefault(framework::DataType::FP32); @@ -95,7 +95,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { auto &in_var_tensor = in_var->Get(); framework::AttributeMap attrs; - attrs["data_type"] = framework::ToDataType(in_var_tensor.type()); + attrs["dtype"] = 
framework::ToDataType(in_var_tensor.type()); attrs["shape"] = framework::vectorize2int(in_var_tensor.dims()); attrs["value"] = 0.0f; @@ -121,7 +121,7 @@ class RNNMemoryHelperGradOpInfoMaker AddInput("X", ""); AddInput("Out", ""); AddOutput(framework::GradVarName("X"), ""); - AddAttr<int>("data_type", + AddAttr<int>("dtype", "(int, default 5 (FP32)) " "Output data type") .SetDefault(framework::DataType::FP32); diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 93f89e33a73c5f4c6c0e5a8793a0abe7c692b656..93e0525badc26808f0dca70cc1153ac728f1fe9c 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -59,7 +59,7 @@ Then the ratio of the exponential of the given dimension and the sum of exponential values of all the other dimensions is the output of the softmax operator. -For each row `i` and each column `j` in input X, we have: +For each row $i$ and each column $j$ in Input(X), we have: $$Y[i, j] = \frac{\exp(X[i, j])}{\sum_j \exp(X[i, j])}$$ )DOC"); diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index 3dbb62d2e571eb92025c1b3fc0a6653c7cda007a..fc027d6f95cdbc24af59ef1188b6f16f6a93e85c 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -67,15 +67,15 @@ The equation is as follows: 1) Hard label (one-hot label, so every sample has exactly one class) -$$Loss_j = \f$ -\text{Logit}_{Label_j} + +$$Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), -j = 1, ..., K $\f$$ +j = 1,..., K$$ 2) Soft label (each sample can have a distribution over all classes) -$$Loss_j = \f$ -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i - +$$Loss_j = -\sum_{i=0}^{K}\text{Label}_i \left(\text{Logit}_i - \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), -j = 1,...,K $\f$$ +j = 1,...,K$$ )DOC"); } diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 7975efc7cf134aaf591385a6866254a9c5f2a0bb..fff1dc7ccddf1d8cee0c8311828fd38888283cd1 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -66,7 +66,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { framework::OpKernelType GetKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - static_cast<framework::DataType>(ctx.Attr<int>("data_type")), + static_cast<framework::DataType>(ctx.Attr<int>("dtype")), ctx.device_context()); } }; @@ -99,7 +99,7 @@ uniform distribution. "Random seed used for generating samples. 
" "0 means use a seed generated by the system.") .SetDefault(0); - AddAttr("data_type", "(int, default 5(FP32)) Output tensor data type") + AddAttr("dtype", "(int, default 5(FP32)) Output tensor data type") .SetDefault(framework::DataType::FP32); } }; diff --git a/paddle/operators/while_op.cc b/paddle/operators/while_op.cc index dcc59f5ff2ae3a8ca999d72a20cfd5c759987d89..68b4f7705995e5ecb6c9b8216db7373c1777a31e 100644 --- a/paddle/operators/while_op.cc +++ b/paddle/operators/while_op.cc @@ -180,7 +180,7 @@ class WhileGradOp : public framework::OperatorBase { if (var->IsType()) { auto &inside_tensor = var->Get(); framework::AttributeMap attrs; - attrs["data_type"] = framework::ToDataType(inside_tensor.type()); + attrs["dtype"] = framework::ToDataType(inside_tensor.type()); attrs["shape"] = framework::vectorize2int(inside_tensor.dims()); attrs["value"] = 0.0f; diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index bd86a9fe268c277065cd450f91b544def6c4d32f..88df28a9668e5f354d115ff8ab32cb21e03aefb5 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,15 +1,20 @@ -cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog) +if(WITH_GPU) + cc_library(enforce SRCS enforce.cc DEPS nccl) +else() + cc_library(enforce SRCS enforce.cc) +endif() +cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece enforce) + +cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog enforce) cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info) -nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog) +nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce) -cc_library(place SRCS place.cc) +cc_library(place SRCS place.cc DEPS enforce) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) add_subdirectory(dynload) -cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece) - IF(WITH_GPU) set(GPU_CTX_DEPS dynload_cuda dynamic_loader) ELSE() diff --git a/paddle/platform/dynload/CMakeLists.txt b/paddle/platform/dynload/CMakeLists.txt index bb3fec1be9e811c26cc6851314e960e96fc366b3..f4fda65907dc26e9edb91ee46f3b8bd2de7b3f3a 100644 --- a/paddle/platform/dynload/CMakeLists.txt +++ b/paddle/platform/dynload/CMakeLists.txt @@ -1,3 +1,3 @@ -cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) +cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce) nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc DEPS dynamic_loader nccl) diff --git a/paddle/platform/enforce.cc b/paddle/platform/enforce.cc new file mode 100644 index 0000000000000000000000000000000000000000..e8d31bc782ec3cddd18ceaedf88fe5e7b4aed2cc --- /dev/null +++ b/paddle/platform/enforce.cc @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/platform/enforce.h" + +namespace paddle { +namespace platform {} // namespace platform +} // namespace paddle diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index bfe708748a62ff9ac5d151bc652142e1f4925c83..415020ab965fa976c37870b7ad5794aab947fb4e 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -49,7 +49,6 @@ limitations under the License. */ namespace paddle { namespace platform { -namespace { #ifdef __GNUC__ inline std::string demangle(std::string name) { int status = -4; // some arbitrary value to eliminate the compiler warning @@ -60,7 +59,6 @@ inline std::string demangle(std::string name) { #else inline std::string demangle(std::string name) { return name; } #endif -} struct EnforceNotMet : public std::exception { std::exception_ptr exp_; diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 5a1ff9b7976abbe4a37f8366181d9d1ae78ea4a0..6c8f06cccb92fa9cd22fdb89a9d410e6853895cc 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -202,9 +202,9 @@ void BindVarDsec(py::module &m) { }, py::return_value_policy::reference) .def("set_shape", &VarDescBind::SetShape) - .def("set_data_type", &VarDescBind::SetDataType) + .def("set_dtype", &VarDescBind::SetDataType) .def("shape", &VarDescBind::Shape, py::return_value_policy::reference) - .def("data_type", &VarDescBind::GetDataType) + .def("dtype", &VarDescBind::GetDataType) .def("lod_level", &VarDescBind::GetLodLevel) .def("set_lod_level", &VarDescBind::SetLoDLevel) .def("type", &VarDescBind::GetType) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 3d8d3f1d2fd3977f945928c723db5fcafffeae85..e697739cc6814b24710f7caa173d200f2e5d823d 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -293,6 +293,11 @@ All parameter, weight, gradient are variables in Paddle. 
Prune(*prog_with_targets.Proto(), &pruned_desc); return new ProgramDescBind(pruned_desc); }); + m.def("inference_optimize", [](ProgramDescBind &origin) { + ProgramDesc pruned_desc; + InferenceOptimize(*(origin.Proto()), &pruned_desc); + return new ProgramDescBind(pruned_desc); + }); m.def_submodule( "var_names", "The module will return special predefined variable name in Paddle") diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py index 3a8f1831cf2c44c81aee62c6ee172942db188217..f78d2f814c89aa6b5ee8387f2558a97c754e655c 100644 --- a/python/paddle/v2/fluid/evaluator.py +++ b/python/paddle/v2/fluid/evaluator.py @@ -8,7 +8,7 @@ def _clone_var_in_block_(block, var): return block.create_var( name=var.name, shape=var.shape, - dtype=var.data_type, + dtype=var.dtype, type=var.type, lod_level=var.lod_level, persistable=True) @@ -33,6 +33,9 @@ class Evaluator(object): else: self._main_program = g_main_program + def states(self): + return self._states + def _update_ops(self, *args, **kwargs): """ append update ops to the global states @@ -57,7 +60,7 @@ class Evaluator(object): attrs={ "shape": g_var.shape, "value": .0, - "data_type": 5, + "dtype": 5, }) block.append_op( type="scale", inputs={"X": zeros}, outputs={"Out": g_var}) @@ -93,7 +96,7 @@ class Accuracy(Evaluator): def _update_ops(self, input, label, k=1, **kwargs): block = self._main_program.global_block() - topk_out = block.create_var(dtype=input.data_type) + topk_out = block.create_var(dtype=input.dtype) topk_indices = block.create_var(dtype="int64") block.append_op( type="top_k", @@ -122,16 +125,16 @@ class Accuracy(Evaluator): inputs={"X": [self._states["Total"]]}, outputs={"Out": [self._states["Total"]]}, attrs={ - "in_data_type": 5, # float32 - "out_data_type": 2, #int32 + "in_dtype": 5, # float32 + "out_dtype": 2, # int32 }) block.append_op( type="cast", inputs={"X": [self._states["Correct"]]}, outputs={"Out": [self._states["Correct"]]}, attrs={ - "in_data_type": 5, - "out_data_type": 2, + "in_dtype": 5, + "out_dtype": 2, }) block.append_op( @@ -153,7 +156,7 @@ class Accuracy(Evaluator): else: eval_program = Program() block = eval_program.global_block() - eval_out = block.create_var(dtype=self._states["Total"].data_type) + eval_out = block.create_var(dtype=self._states["Total"].dtype) e_total = _clone_var_in_block_(block, self._states["Total"]) e_correct = _clone_var_in_block_(block, self._states["Correct"]) block.append_op( @@ -161,16 +164,16 @@ class Accuracy(Evaluator): inputs={"X": [e_total]}, outputs={"Out": [e_total]}, attrs={ - "in_data_type": 2, #int32 - "out_data_type": 5, #float32 + "in_dtype": 2, # int32 + "out_dtype": 5, # float32 }) block.append_op( type="cast", inputs={"X": [e_correct]}, outputs={"Out": [e_correct]}, attrs={ - "in_data_type": 2, - "out_data_type": 5, + "in_dtype": 2, + "out_dtype": 5, }) block.append_op( type="elementwise_div", diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 7f7c310ad87f64e5d047ecfc2876d516914c75c8..872c19c2f6f4afbd25a5f7a9df38bd3dd0b61d5f 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -99,9 +99,9 @@ class Variable(object): if not isinstance(dtype, core.DataType): dtype = convert_np_dtype_to_dtype_(dtype) if is_new_var: - self.desc.set_data_type(dtype) + self.desc.set_dtype(dtype) else: - old_dtype = self.data_type + old_dtype = self.dtype if dtype != old_dtype: raise ValueError("Variable {0} has been created before. 
" "The previous data type is {1}; the new " @@ -162,8 +162,8 @@ class Variable(object): return tuple(self.desc.shape()) @property - def data_type(self): - return self.desc.data_type() + def dtype(self): + return self.desc.dtype() @property def lod_level(self): @@ -511,6 +511,13 @@ class Program(object): res.sync_with_cpp() return res + def inference_optimize(self): + res = Program() + res.desc = core.inference_optimize(self.desc) + res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())] + res.sync_with_cpp() + return res + @staticmethod def parse_from_string(binary_str): p = Program() diff --git a/python/paddle/v2/fluid/initializer.py b/python/paddle/v2/fluid/initializer.py index 1a9d804ee7ee8e6463d42fefb809fb45888fd064..9f23e68a7635b6e6ae927603dbcc47d63f9c7f3d 100644 --- a/python/paddle/v2/fluid/initializer.py +++ b/python/paddle/v2/fluid/initializer.py @@ -93,7 +93,7 @@ class ConstantInitializer(Initializer): outputs={"Out": var}, attrs={ "shape": var.shape, - "data_type": int(var.data_type), + "dtype": int(var.dtype), "value": self._value }) var.op = op @@ -140,7 +140,7 @@ class UniformInitializer(Initializer): outputs={"Out": var}, attrs={ "shape": var.shape, - "data_type": int(var.data_type), + "dtype": int(var.dtype), "min": self._low, "max": self._high, "seed": self._seed @@ -188,7 +188,7 @@ class NormalInitializer(Initializer): outputs={"Out": var}, attrs={ "shape": var.shape, - "data_type": int(var.data_type), + "dtype": int(var.dtype), "mean": self._mean, "std": self._std_dev, "seed": self._seed @@ -265,7 +265,7 @@ class XavierInitializer(Initializer): outputs={"Out": var}, attrs={ "shape": var.shape, - "data_type": int(var.data_type), + "dtype": int(var.dtype), "min": -limit, "max": limit, "seed": self._seed @@ -278,7 +278,7 @@ class XavierInitializer(Initializer): outputs={"Out": var}, attrs={ "shape": var.shape, - "data_type": int(var.data_type), + "dtype": int(var.dtype), "mean": 0.0, "std": std, "seed": self._seed @@ -348,7 +348,7 @@ class MSRAInitializer(Initializer): outputs={"Out": var}, attrs={ "shape": var.shape, - "data_type": int(var.data_type), + "dtype": int(var.dtype), "min": -limit, "max": limit, "seed": self._seed @@ -361,7 +361,7 @@ class MSRAInitializer(Initializer): outputs={"Out": var}, attrs={ "shape": var.shape, - "data_type": int(var.data_type), + "dtype": int(var.dtype), "mean": 0.0, "std": std, "seed": self._seed diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index 2d070814eef0b099ba71bef223596e30388ac48a..e5b2aa3b919df4cec1091c0bbd39b7e400cc6867 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -6,7 +6,8 @@ from paddle.v2.fluid.framework import Program, Parameter, g_main_program, \ __all__ = [ 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', - 'load_persistables', "save_inference_model", "load_inference_model" + 'load_persistables', "save_inference_model", "load_inference_model", + "get_inference_program" ] @@ -23,7 +24,7 @@ def _clone_var_in_block_(block, var): return block.create_var( name=var.name, shape=var.shape, - dtype=var.data_type, + dtype=var.dtype, type=var.type, lod_level=var.lod_level, persistable=True) @@ -151,6 +152,17 @@ def load_persistables(executor, dirname, main_program=None): predicate=is_persistable) +def get_inference_program(target_vars, main_program=None): + if main_program is None: + main_program = g_main_program + if not isinstance(target_vars, list): + target_vars = [target_vars] + + pruned_program = 
main_program.prune(targets=target_vars) + inference_program = pruned_program.inference_optimize() + return inference_program + + def save_inference_model(dirname, feeded_var_names, target_vars, @@ -177,13 +189,14 @@ def save_inference_model(dirname, if not os.path.isdir(dirname): os.makedirs(dirname) - pruned_program = main_program.prune(target_vars) + pruned_program = main_program.prune(targets=target_vars) + inference_program = pruned_program.inference_optimize() fetch_var_names = [v.name for v in target_vars] model_file_name = dirname + "/__model__" with open(model_file_name, "w") as f: pickle.dump({ - "program_desc_str": pruned_program.desc.serialize_to_string(), + "program_desc_str": inference_program.desc.serialize_to_string(), "feed_var_names": feeded_var_names, "fetch_var_names": fetch_var_names }, f, -1) diff --git a/python/paddle/v2/fluid/layer_helper.py b/python/paddle/v2/fluid/layer_helper.py index e40551ca73e991edd8e1d1df5b103c36367b7050..e0880354fbc5a09bd49de7ec9c5dffc1e3c6259e 100644 --- a/python/paddle/v2/fluid/layer_helper.py +++ b/python/paddle/v2/fluid/layer_helper.py @@ -108,8 +108,8 @@ class LayerHelper(object): dtype = None for each in inputs: if dtype is None: - dtype = each.data_type - elif dtype != each.data_type: + dtype = each.dtype + elif dtype != each.dtype: raise ValueError("Data Type mismatch") return dtype @@ -149,7 +149,7 @@ class LayerHelper(object): self.startup_program.global_block().create_var( name=var.name, type=var.type, - dtype=var.data_type, + dtype=var.dtype, shape=var.shape, persistable=True, initializer=initializer) @@ -180,10 +180,10 @@ class LayerHelper(object): b = self.create_parameter( attr=bias_attr, shape=size, - dtype=input_var.data_type, + dtype=input_var.dtype, suffix='b', initializer=bias_initializer) - tmp = self.create_tmp_variable(dtype=input_var.data_type) + tmp = self.create_tmp_variable(dtype=input_var.dtype) self.append_op( type='elementwise_add', inputs={'X': [input_var], @@ -198,7 +198,7 @@ class LayerHelper(object): return input_var if isinstance(act, basestring): act = {'type': act} - tmp = self.create_tmp_variable(dtype=input_var.data_type) + tmp = self.create_tmp_variable(dtype=input_var.dtype) act_type = act.pop('type') self.append_op( type=act_type, diff --git a/python/paddle/v2/fluid/layers.py b/python/paddle/v2/fluid/layers.py index fac91aac97267b1ecc867bb9b0b1f8fd40f2f299..d094035fe5cae2e77fc2364e8ccb03c350f1301a 100644 --- a/python/paddle/v2/fluid/layers.py +++ b/python/paddle/v2/fluid/layers.py @@ -114,7 +114,7 @@ def embedding(input, is_sparse=False, param_initializer=None, param_attr=None, - data_type='float32', + dtype='float32', main_program=None, startup_program=None): """ @@ -125,7 +125,7 @@ def embedding(input, size: The size of the layer is_sparse: A flag that decleares whether the input is sparse param_attr: Parameters for this layer - data_type: The type of data : float32, float_16, int etc + dtype: The type of data : float32, float_16, int etc main_program: Name of the main program that calls this startup_program: Name of the startup program @@ -145,9 +145,9 @@ def embedding(input, w = helper.create_parameter( attr=helper.param_attr, shape=size, - dtype=data_type, + dtype=dtype, initializer=param_initializer or _get_default_param_initializer()) - tmp = helper.create_tmp_variable(data_type) + tmp = helper.create_tmp_variable(dtype) helper.append_op( type='lookup_table', inputs={'Ids': input, @@ -167,23 +167,23 @@ def dynamic_lstm(input, gate_activation='sigmoid', cell_activation='tanh', 
candidate_activation='tanh', - data_type='float32', + dtype='float32', main_program=None, startup_program=None): helper = LayerHelper('lstm', **locals()) size = size / 4 weight = helper.create_parameter( - attr=helper.param_attr, shape=[size, 4 * size], dtype=data_type) + attr=helper.param_attr, shape=[size, 4 * size], dtype=dtype) bias_size = [1, 7 * size] if not use_peepholes: bias_size[1] = 4 * size bias = helper.create_parameter( - attr=helper.bias_attr, shape=bias_size, dtype=data_type, suffix='b') + attr=helper.bias_attr, shape=bias_size, dtype=dtype, suffix='b') - hidden = helper.create_tmp_variable(data_type) - cell = helper.create_tmp_variable(data_type) - batch_gate = helper.create_tmp_variable(data_type) - batch_cell_pre_act = helper.create_tmp_variable(data_type) + hidden = helper.create_tmp_variable(dtype) + cell = helper.create_tmp_variable(dtype) + batch_gate = helper.create_tmp_variable(dtype) + batch_cell_pre_act = helper.create_tmp_variable(dtype) helper.append_op( type='lstm', @@ -209,7 +209,7 @@ def dynamic_lstm(input, def data(name, shape, append_batch_size=True, - data_type='float32', + dtype='float32', type=core.VarDesc.VarType.LOD_TENSOR, main_program=None, startup_program=None, @@ -221,7 +221,7 @@ def data(name, name: The name/alias of the function shape: Tuple declaring the shape. append_batch_size: Whether or not to append the data as a batch. - data_type: The type of data : float32, float_16, int etc + dtype: The type of data : float32, float_16, int etc type: The output type. By default it is LOD_TENSOR. main_program: Name of the main program that calls this startup_program: Name of the startup program @@ -251,7 +251,7 @@ def data(name, return helper.create_global_variable( name=name, shape=shape, - dtype=data_type, + dtype=dtype, type=type, stop_gradient=stop_gradient) @@ -362,9 +362,9 @@ def _create_op_func_(op_type): o_name = not_intermediate_outputs[0].name intermediate_output_names = [output.name for output in intermediate_outputs] - def infer_and_check_data_type(op_proto, **kwargs): + def infer_and_check_dtype(op_proto, **kwargs): """ - This function performs the sanity check for data_type and + This function performs the sanity check for dtype and instance type. """ dtype = None @@ -379,8 +379,8 @@ def _create_op_func_(op_type): op_type)) if dtype is None: - dtype = each.data_type - elif dtype != each.data_type: + dtype = each.dtype + elif dtype != each.dtype: raise ValueError( "operator {0} must input same dtype".format(op_type)) @@ -389,7 +389,7 @@ def _create_op_func_(op_type): def func(**kwargs): helper = LayerHelper(op_type, **kwargs) - dtype = infer_and_check_data_type(op_proto, **kwargs) + dtype = infer_and_check_dtype(op_proto, **kwargs) inputs = dict() for ipt in op_proto.inputs: @@ -426,19 +426,19 @@ _create_op_func_('reshape') _create_op_func_('transpose') -def cast(x, data_type, main_program=None): +def cast(x, dtype, main_program=None): """ - This function takes in the input with input_data_type - and casts it to the output_data_type as the output. + This function takes in the input with input_dtype + and casts it to the output_dtype as the output. 
""" helper = LayerHelper('cast', **locals()) - out = helper.create_tmp_variable(dtype=data_type) + out = helper.create_tmp_variable(dtype=dtype) helper.append_op( type='cast', inputs={'X': [x]}, outputs={'Out': [out]}, - attrs={'in_data_type': x.data_type, - 'out_data_type': out.data_type}) + attrs={'in_dtype': x.dtype, + 'out_dtype': out.dtype}) return out @@ -519,8 +519,8 @@ def split_lod_tensor(input, main_program=None, startup_program=None): helper = LayerHelper('split_lod_tensor', **locals()) - out_true = helper.create_tmp_variable(dtype=input.data_type) - out_false = helper.create_tmp_variable(dtype=input.data_type) + out_true = helper.create_tmp_variable(dtype=input.dtype) + out_false = helper.create_tmp_variable(dtype=input.dtype) helper.append_op( type='split_lod_tensor', inputs={ @@ -541,7 +541,7 @@ def merge_lod_tensor(in_true, main_program=None, startup_program=None): helper = LayerHelper('merge_lod_tensor', **locals()) - out = helper.create_tmp_variable(dtype=in_true.data_type) + out = helper.create_tmp_variable(dtype=in_true.dtype) helper.append_op( type='merge_lod_tensor', inputs={'X': x, @@ -559,9 +559,9 @@ def cos_sim(X, Y, **kwargs): X and Y and returns that as the output. """ helper = LayerHelper('cos_sim', **kwargs) - out = helper.create_tmp_variable(dtype=X.data_type) - xnorm = helper.create_tmp_variable(dtype=X.data_type) - ynorm = helper.create_tmp_variable(dtype=X.data_type) + out = helper.create_tmp_variable(dtype=X.dtype) + xnorm = helper.create_tmp_variable(dtype=X.dtype) + ynorm = helper.create_tmp_variable(dtype=X.dtype) helper.append_op( type='cos_sim', inputs={'X': [X], @@ -577,7 +577,7 @@ def cross_entropy(input, label, **kwargs): This function computes cross_entropy using the input and label. """ helper = LayerHelper('cross_entropy', **kwargs) - out = helper.create_tmp_variable(dtype=input.data_type) + out = helper.create_tmp_variable(dtype=input.dtype) helper.append_op( type='cross_entropy', inputs={'X': [input], @@ -593,14 +593,14 @@ def square_error_cost(input, label, **kwargs): The output is appending the op to do the above. """ helper = LayerHelper('square_error_cost', **kwargs) - minus_out = helper.create_tmp_variable(dtype=input.data_type) + minus_out = helper.create_tmp_variable(dtype=input.dtype) helper.append_op( type='elementwise_sub', inputs={'X': [input], 'Y': [label]}, outputs={'Out': [minus_out]}) - square_out = helper.create_tmp_variable(dtype=input.data_type) + square_out = helper.create_tmp_variable(dtype=input.dtype) helper.append_op( type='square', inputs={'X': [minus_out]}, outputs={'Y': [square_out]}) return square_out @@ -612,7 +612,7 @@ def accuracy(input, label, k=1, **kwargs): The output is the top_k inputs and their indices. 
""" helper = LayerHelper("accuracy", **kwargs) - topk_out = helper.create_tmp_variable(dtype=input.data_type) + topk_out = helper.create_tmp_variable(dtype=input.dtype) topk_indices = helper.create_tmp_variable(dtype="int64") helper.append_op( type="top_k", @@ -883,12 +883,12 @@ def batch_norm(input, initializer=ConstantInitializer(0.0)) mean = helper.create_global_variable( - dtype=input.data_type, shape=param_shape, persistable=True) + dtype=input.dtype, shape=param_shape, persistable=True) helper.set_variable_initializer( var=mean, initializer=ConstantInitializer(0.0)) variance = helper.create_global_variable( - dtype=input.data_type, shape=param_shape, persistable=True) + dtype=input.dtype, shape=param_shape, persistable=True) helper.set_variable_initializer( var=variance, initializer=ConstantInitializer(1.0)) @@ -927,8 +927,8 @@ def batch_norm(input, def beam_search_decode(ids, scores, main_program=None, startup_program=None): helper = LayerHelper('beam_search_decode', **locals()) - sentence_ids = helper.create_tmp_variable(dtype=ids.data_type) - sentence_scores = helper.create_tmp_variable(dtype=ids.data_type) + sentence_ids = helper.create_tmp_variable(dtype=ids.dtype) + sentence_scores = helper.create_tmp_variable(dtype=ids.dtype) helper.append_op( type="beam_search_decode", @@ -1066,7 +1066,7 @@ class StaticRNN(object): boot_var = parent_block.create_var( name=var_name, shape=shape, - dtype=batch_ref.data_type, + dtype=batch_ref.dtype, persistable=False) parent_block.append_op( @@ -1076,7 +1076,7 @@ class StaticRNN(object): attrs={ 'value': init_value, 'shape': boot_var.shape, - 'data_type': boot_var.data_type, + 'dtype': boot_var.dtype, 'input_dim_idx': ref_batch_dim_idx, 'output_dim_idx': init_batch_dim_idx }) @@ -1085,7 +1085,7 @@ class StaticRNN(object): else: pre_mem = self.helper.create_variable( name=unique_name("@".join([self.helper.name, "mem"])), - dtype=init.data_type, + dtype=init.dtype, shape=init.shape) self.memories[pre_mem.name] = StaticRNNMemoryLink( init=init, pre_mem=pre_mem) @@ -1101,10 +1101,7 @@ class StaticRNN(object): raise ValueError("Static RNN only take fix seq_len input") ipt = self.helper.create_variable( - name=x.name, - dtype=x.data_type, - shape=list(x.shape[1:]), - type=x.type) + name=x.name, dtype=x.dtype, shape=list(x.shape[1:]), type=x.type) self.inputs.append(ipt) return ipt @@ -1113,17 +1110,17 @@ class StaticRNN(object): if not isinstance(o, Variable): raise TypeError("step output takes a Variable") - tmp_o = self.helper.create_tmp_variable(dtype=o.data_type) + tmp_o = self.helper.create_tmp_variable(dtype=o.dtype) self.helper.append_op( type='rnn_memory_helper', inputs={'X': [o]}, outputs={'Out': tmp_o}, - attrs={'data_type': o.data_type}) + attrs={'dtype': o.dtype}) out_var = self.parent_block().create_var( name=tmp_o.name, shape=[self.seq_len] + list(tmp_o.shape), - dtype=tmp_o.data_type) + dtype=tmp_o.dtype) self.outputs.append(out_var) @@ -1195,13 +1192,13 @@ class StaticRNN(object): pre_memories.append(mem.pre_mem.name) mem_var = rnn_block.var(mem.mem.name) assert isinstance(mem_var, Variable) - new_mem = self.helper.create_tmp_variable(dtype=mem_var.data_type) + new_mem = self.helper.create_tmp_variable(dtype=mem_var.dtype) rnn_block.append_op( type='rnn_memory_helper', inputs={'X': [mem_var]}, outputs={'Out': [new_mem]}, - attrs={'data_type': mem_var.data_type}) + attrs={'dtype': mem_var.dtype}) memories.append(new_mem.name) @@ -1251,7 +1248,7 @@ class While(object): if not isinstance(cond, Variable): raise TypeError("condition should 
be a variable") assert isinstance(cond, Variable) - if cond.data_type != core.DataType.BOOL: + if cond.dtype != core.DataType.BOOL: raise TypeError("condition should be a bool variable") if reduce(lambda a, b: a * b, cond.shape, 1) != 1: raise TypeError("condition should be a bool scalar") @@ -1323,9 +1320,9 @@ def lstm(x, main_program=main_program, startup_program=startup_program) - data_type = x.data_type - c = helper.create_tmp_variable(data_type) - h = helper.create_tmp_variable(data_type) + dtype = x.dtype + c = helper.create_tmp_variable(dtype) + h = helper.create_tmp_variable(dtype) helper.append_op( type='lstm_unit', @@ -1367,7 +1364,7 @@ def lod_tensor_to_array(x, table, main_program=None): array = helper.create_variable( name=unique_name("lod_tensor_to_array"), type=core.VarDesc.VarType.LOD_TENSOR_ARRAY, - dtype=x.data_type) + dtype=x.dtype) helper.append_op( type='lod_tensor_to_array', inputs={'X': x, @@ -1382,7 +1379,7 @@ def array_to_lod_tensor(x, table, main_program=None): LOD_Tensor. """ helper = LayerHelper("array_to_lod_tensor", **locals()) - tmp = helper.create_tmp_variable(dtype=x.data_type) + tmp = helper.create_tmp_variable(dtype=x.dtype) helper.append_op( type="array_to_lod_tensor", inputs={'X': x, @@ -1394,7 +1391,7 @@ def array_to_lod_tensor(x, table, main_program=None): def fill_constant(shape, dtype, value, main_program=None, startup_program=None): """ This function creates a tensor , with shape as mentioned in the input and - specified data_type and fills this up with a constant value that + specified dtype and fills this up with a constant value that comes in the input. It also sets the stop_gradient to be True. """ helper = LayerHelper("fill_constant", **locals()) @@ -1403,11 +1400,9 @@ def fill_constant(shape, dtype, value, main_program=None, startup_program=None): type='fill_constant', inputs={}, outputs={'Out': [out]}, - attrs={ - 'shape': shape, - 'data_type': out.data_type, - 'value': float(value) - }) + attrs={'shape': shape, + 'dtype': out.dtype, + 'value': float(value)}) out.stop_gradient = True return out @@ -1428,7 +1423,7 @@ def fill_constant_batch_size_like(input, outputs={'Out': [out]}, attrs={ 'shape': shape, - 'data_type': out.data_type, + 'dtype': out.dtype, 'value': float(value), 'input_dim_idx': input_dim_idx, 'output_dim_idx': output_dim_idx @@ -1461,7 +1456,7 @@ def increment(x, value=1.0, in_place=True, main_program=None): """ helper = LayerHelper("increment", **locals()) if not in_place: - out = helper.create_tmp_variable(dtype=x.data_type) + out = helper.create_tmp_variable(dtype=x.dtype) else: out = x helper.append_op( @@ -1482,7 +1477,7 @@ def array_write(x, i, array=None, main_program=None): array = helper.create_variable( name="{0}.out".format(helper.name), type=core.VarDesc.VarType.LOD_TENSOR_ARRAY, - dtype=x.data_type) + dtype=x.dtype) helper.append_op( type='write_to_array', inputs={'X': [x], @@ -1521,7 +1516,7 @@ def array_read(array, i, main_program=None): array, Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY: raise TypeError("array should be tensor array vairable") - out = helper.create_tmp_variable(dtype=array.data_type) + out = helper.create_tmp_variable(dtype=array.dtype) helper.append_op( type='read_from_array', inputs={'X': [array], @@ -1536,7 +1531,7 @@ def shrink_memory(x, i, table, main_program=None): as mentioned in the input parameter. 
""" helper = LayerHelper('shrink_memory', **locals()) - out = helper.create_tmp_variable(dtype=x.data_type) + out = helper.create_tmp_variable(dtype=x.dtype) helper.append_op( type='shrink_rnn_memory', inputs={'X': [x], @@ -1698,11 +1693,11 @@ class IfElse(object): parent_block = self.parent_block() out_true = parent_block.create_var( name=unique_name('ifelse_input' + self.helper.name), - dtype=x.data_type) + dtype=x.dtype) out_false = parent_block.create_var( name=unique_name('ifelse_input' + self.helper.name), - dtype=x.data_type) + dtype=x.dtype) parent_block.append_op( type='split_lod_tensor', inputs={ @@ -1744,7 +1739,7 @@ class IfElse(object): # create outside tensor outside_out = parent_block.create_var( name=unique_name("_".join([self.helper.name, 'output'])), - dtype=each_out.data_type) + dtype=each_out.dtype) out_table.append(outside_out) # assign local var to outside diff --git a/python/paddle/v2/fluid/optimizer.py b/python/paddle/v2/fluid/optimizer.py index 87a478c2903b77d955ebde49a4a0e507c9e9ffd3..e82f0f060de6af63f63d5601ae94059192076e6f 100644 --- a/python/paddle/v2/fluid/optimizer.py +++ b/python/paddle/v2/fluid/optimizer.py @@ -92,7 +92,7 @@ class Optimizer(object): var = self.helper.create_global_variable( name=unique_name(name), persistable=True, - dtype=dtype or param.data_type, + dtype=dtype or param.dtype, type=param.type, shape=param.shape) self.helper.set_variable_initializer( @@ -202,7 +202,7 @@ class Optimizer(object): """ params_grads = append_backward_ops(loss, parameter_list, no_grad_set or set()) - # Add regularization if any + # Add regularization if any params_grads = append_regularization_ops(params_grads) optimize_ops = self.create_optimization_pass(params_grads, loss, startup_program) diff --git a/python/paddle/v2/fluid/tests/book/test_fit_a_line.py b/python/paddle/v2/fluid/tests/book/test_fit_a_line.py index a7f3bfc0caf76302674a00c80c2bd9ebf834f872..a899f1088d77c4ca6462cf5306393444ea114e6c 100644 --- a/python/paddle/v2/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/v2/fluid/tests/book/test_fit_a_line.py @@ -7,11 +7,11 @@ from paddle.v2.fluid.executor import Executor from paddle.v2.fluid.io import save_persistables, load_persistables from paddle.v2.fluid.optimizer import SGDOptimizer -x = layers.data(name='x', shape=[13], data_type='float32') +x = layers.data(name='x', shape=[13], dtype='float32') y_predict = layers.fc(input=x, size=1, act=None) -y = layers.data(name='y', shape=[1], data_type='float32') +y = layers.data(name='y', shape=[1], dtype='float32') cost = layers.square_error_cost(input=y_predict, label=y) avg_cost = layers.mean(x=cost) diff --git a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py index efe63a68f0745eb728b569a03d0344877c1484f7..76cbd410f94a4be04ba71d1e3175eaed590ac80a 100644 --- a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py +++ b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py @@ -5,6 +5,7 @@ import paddle.v2.fluid.framework as framework import paddle.v2.fluid.layers as layers import paddle.v2.fluid.nets as nets import paddle.v2.fluid.evaluator as evaluator +from paddle.v2.fluid.io import get_inference_program from paddle.v2.fluid.executor import Executor from paddle.v2.fluid.initializer import XavierInitializer from paddle.v2.fluid.optimizer import AdamOptimizer @@ -90,8 +91,8 @@ def vgg16_bn_drop(input): classdim = 10 data_shape = [3, 32, 32] -images = layers.data(name='pixel', 
shape=data_shape, data_type='float32') -label = layers.data(name='label', shape=[1], data_type='int64') +images = layers.data(name='pixel', shape=data_shape, dtype='float32') +label = layers.data(name='label', shape=[1], dtype='int64') # Add neural network config # option 1. resnet @@ -116,9 +117,11 @@ PASS_NUM = 1 train_reader = paddle.batch( paddle.reader.shuffle( - paddle.dataset.cifar.train10(), buf_size=128 * 10), + paddle.dataset.cifar.train10(), buf_size=BATCH_SIZE * 10), batch_size=BATCH_SIZE) +test_reader = paddle.batch(paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE) + place = core.CPUPlace() exe = Executor(place) @@ -149,10 +152,41 @@ for pass_id in range(PASS_NUM): loss = np.array(outs[0]) acc = np.array(outs[1]) pass_acc = accuracy.eval(exe) + + batch_id = batch_id + 1 + + test_accuracy, test_acc_out = evaluator.accuracy( + input=predict, label=label) + + test_target = [avg_cost, test_acc_out] + test_accuracy.states().values() + inference_program = get_inference_program(test_target) + + test_accuracy.reset(exe) + + for data in test_reader(): + x_data = np.array(map(lambda x: x[0].reshape(data_shape), + data)).astype("float32") + y_data = np.array(map(lambda x: x[1], data)).astype("int64") + y_data = np.expand_dims(y_data, axis=1) + + tensor_x = core.LoDTensor() + tensor_x.set(x_data, place) + + tensor_y = core.LoDTensor() + tensor_y.set(y_data, place) + + outs = exe.run(inference_program, + feed={'pixel': tensor_x, + 'label': tensor_y}, + fetch_list=[avg_cost, test_acc_out]) + out = np.array(outs[0]) + acc = np.array(outs[1]) + + test_pass_acc = test_accuracy.eval(exe) + print("pass_id:" + str(pass_id) + " batch_id:" + str(batch_id) + " loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( - pass_acc)) - batch_id = batch_id + 1 + pass_acc) + " test_pass_acc:" + str(test_pass_acc)) if batch_id > 1: # this model is slow, so if we can train two mini batch, we think it works properly. 
diff --git a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py index f66e6e748b76dec53a9e24b5b352d31395ce6bde..9c9064ba9639829ef3afd8111278b17035bee84a 100644 --- a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py @@ -34,26 +34,26 @@ def load_parameter(file_name, h, w): def db_lstm(): # 8 features - word = layers.data(name='word_data', shape=[1], data_type='int64') - predicate = layers.data(name='verb_data', shape=[1], data_type='int64') - ctx_n2 = layers.data(name='ctx_n2_data', shape=[1], data_type='int64') - ctx_n1 = layers.data(name='ctx_n1_data', shape=[1], data_type='int64') - ctx_0 = layers.data(name='ctx_0_data', shape=[1], data_type='int64') - ctx_p1 = layers.data(name='ctx_p1_data', shape=[1], data_type='int64') - ctx_p2 = layers.data(name='ctx_p2_data', shape=[1], data_type='int64') - mark = layers.data(name='mark_data', shape=[1], data_type='int64') + word = layers.data(name='word_data', shape=[1], dtype='int64') + predicate = layers.data(name='verb_data', shape=[1], dtype='int64') + ctx_n2 = layers.data(name='ctx_n2_data', shape=[1], dtype='int64') + ctx_n1 = layers.data(name='ctx_n1_data', shape=[1], dtype='int64') + ctx_0 = layers.data(name='ctx_0_data', shape=[1], dtype='int64') + ctx_p1 = layers.data(name='ctx_p1_data', shape=[1], dtype='int64') + ctx_p2 = layers.data(name='ctx_p2_data', shape=[1], dtype='int64') + mark = layers.data(name='mark_data', shape=[1], dtype='int64') predicate_embedding = layers.embedding( input=predicate, size=[pred_len, word_dim], - data_type='float32', + dtype='float32', is_sparse=IS_SPARSE, param_attr={'name': 'vemb'}) mark_embedding = layers.embedding( input=mark, size=[mark_dict_len, mark_dim], - data_type='float32', + dtype='float32', is_sparse=IS_SPARSE) word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2] @@ -125,7 +125,7 @@ def to_lodtensor(data, place): def main(): # define network topology feature_out = db_lstm() - target = layers.data(name='target', shape=[1], data_type='int64') + target = layers.data(name='target', shape=[1], dtype='int64') crf_cost = layers.linear_chain_crf( input=feature_out, label=target, diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py index 8f737689609fec4d1819ae58b9665298547a3716..0bea5f95c895b278db86f25f54e2795d3ec0af69 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py @@ -8,8 +8,8 @@ import paddle.v2.fluid.nets as nets from paddle.v2.fluid.executor import Executor from paddle.v2.fluid.optimizer import AdamOptimizer -images = layers.data(name='pixel', shape=[1, 28, 28], data_type='float32') -label = layers.data(name='label', shape=[1], data_type='int64') +images = layers.data(name='pixel', shape=[1, 28, 28], dtype='float32') +label = layers.data(name='label', shape=[1], dtype='int64') conv_pool_1 = nets.simple_img_conv_pool( input=images, filter_size=5, diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index e42e4c9cc0024e193b0732df6d9ca3200df5f0b9..f57a5c8d98cd8b89e1d300b4d1fe00d6b24b0d68 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -4,13 +4,14 @@ import 
paddle.v2.fluid.core as core import paddle.v2.fluid.framework as framework import paddle.v2.fluid.layers as layers import paddle.v2.fluid.evaluator as evaluator +from paddle.v2.fluid.io import get_inference_program from paddle.v2.fluid.executor import Executor from paddle.v2.fluid.initializer import UniformInitializer from paddle.v2.fluid.optimizer import MomentumOptimizer from paddle.v2.fluid.regularizer import L2DecayRegularizer BATCH_SIZE = 128 -image = layers.data(name='x', shape=[784], data_type='float32') +image = layers.data(name='x', shape=[784], dtype='float32') param_attr = { 'name': None, @@ -27,7 +28,7 @@ predict = layers.fc(input=hidden2, act='softmax', param_attr=param_attr) -label = layers.data(name='y', shape=[1], data_type='int64') +label = layers.data(name='y', shape=[1], dtype='int64') cost = layers.cross_entropy(input=predict, label=label) avg_cost = layers.mean(x=cost) @@ -42,6 +43,8 @@ train_reader = paddle.batch( paddle.dataset.mnist.train(), buf_size=8192), batch_size=BATCH_SIZE) +test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + place = core.CPUPlace() exe = Executor(place) @@ -69,8 +72,36 @@ for pass_id in range(PASS_NUM): acc = np.array(outs[1]) pass_acc = accuracy.eval(exe) - if pass_acc > 0.7: + test_accuracy, test_acc_out = evaluator.accuracy( + input=predict, label=label) + + test_target = [avg_cost, test_acc_out] + test_accuracy.states().values() + inference_program = get_inference_program(test_target) + + test_accuracy.reset(exe) + for data in test_reader(): + x_data = np.array(map(lambda x: x[0], data)).astype("float32") + y_data = np.array(map(lambda x: x[1], data)).astype("int64") + y_data = np.expand_dims(y_data, axis=1) + + tensor_x = core.LoDTensor() + tensor_x.set(x_data, place) + + tensor_y = core.LoDTensor() + tensor_y.set(y_data, place) + + outs = exe.run(inference_program, + feed={'x': tensor_x, + 'y': tensor_y}, + fetch_list=[avg_cost, test_acc_out]) + out = np.array(outs[0]) + acc = np.array(outs[1]) + + test_pass_acc = test_accuracy.eval(exe) + print("pass_id=" + str(pass_id) + " train_cost=" + str( + out) + " train_acc=" + str(acc) + " train_pass_acc=" + str(pass_acc) + + " test_acc=" + str(test_pass_acc)) + + if test_pass_acc > 0.7: exit(0) - # print("pass_id=" + str(pass_id) + " auc=" + - # str(acc) + " pass_acc=" + str(pass_acc)) exit(1) diff --git a/python/paddle/v2/fluid/tests/book/test_recommender_system.py b/python/paddle/v2/fluid/tests/book/test_recommender_system.py index 55ded3aed3a23c8cd7795f915dc1cbd512c6d945..f8dc1518579d5a9d7a8d0498dcc5fd8a6d1692c4 100644 --- a/python/paddle/v2/fluid/tests/book/test_recommender_system.py +++ b/python/paddle/v2/fluid/tests/book/test_recommender_system.py @@ -18,11 +18,11 @@ def get_usr_combined_features(): USR_DICT_SIZE = paddle.dataset.movielens.max_user_id() + 1 - uid = layers.data(name='user_id', shape=[1], data_type='int64') + uid = layers.data(name='user_id', shape=[1], dtype='int64') usr_emb = layers.embedding( input=uid, - data_type='float32', + dtype='float32', size=[USR_DICT_SIZE, 32], param_attr={'name': 'user_table'}, is_sparse=IS_SPARSE) @@ -31,7 +31,7 @@ def get_usr_combined_features(): USR_GENDER_DICT_SIZE = 2 - usr_gender_id = layers.data(name='gender_id', shape=[1], data_type='int64') + usr_gender_id = layers.data(name='gender_id', shape=[1], dtype='int64') usr_gender_emb = layers.embedding( input=usr_gender_id, @@ -42,7 +42,7 @@ def get_usr_combined_features(): usr_gender_fc = layers.fc(input=usr_gender_emb, size=16) USR_AGE_DICT_SIZE = 
len(paddle.dataset.movielens.age_table) - usr_age_id = layers.data(name='age_id', shape=[1], data_type="int64") + usr_age_id = layers.data(name='age_id', shape=[1], dtype="int64") usr_age_emb = layers.embedding( input=usr_age_id, @@ -53,7 +53,7 @@ def get_usr_combined_features(): usr_age_fc = layers.fc(input=usr_age_emb, size=16) USR_JOB_DICT_SIZE = paddle.dataset.movielens.max_job_id() + 1 - usr_job_id = layers.data(name='job_id', shape=[1], data_type="int64") + usr_job_id = layers.data(name='job_id', shape=[1], dtype="int64") usr_job_emb = layers.embedding( input=usr_job_id, @@ -75,11 +75,11 @@ def get_mov_combined_features(): MOV_DICT_SIZE = paddle.dataset.movielens.max_movie_id() + 1 - mov_id = layers.data(name='movie_id', shape=[1], data_type='int64') + mov_id = layers.data(name='movie_id', shape=[1], dtype='int64') mov_emb = layers.embedding( input=mov_id, - data_type='float32', + dtype='float32', size=[MOV_DICT_SIZE, 32], param_attr={'name': 'movie_table'}, is_sparse=IS_SPARSE) @@ -88,7 +88,7 @@ def get_mov_combined_features(): CATEGORY_DICT_SIZE = len(paddle.dataset.movielens.movie_categories()) - category_id = layers.data(name='category_id', shape=[1], data_type='int64') + category_id = layers.data(name='category_id', shape=[1], dtype='int64') mov_categories_emb = layers.embedding( input=category_id, size=[CATEGORY_DICT_SIZE, 32], is_sparse=IS_SPARSE) @@ -98,7 +98,7 @@ def get_mov_combined_features(): MOV_TITLE_DICT_SIZE = len(paddle.dataset.movielens.get_movie_title_dict()) - mov_title_id = layers.data(name='movie_title', shape=[1], data_type='int64') + mov_title_id = layers.data(name='movie_title', shape=[1], dtype='int64') mov_title_emb = layers.embedding( input=mov_title_id, size=[MOV_TITLE_DICT_SIZE, 32], is_sparse=IS_SPARSE) @@ -126,7 +126,7 @@ def model(): # need cos sim inference = layers.cos_sim(X=usr_combined_features, Y=mov_combined_features) - label = layers.data(name='score', shape=[1], data_type='float32') + label = layers.data(name='score', shape=[1], dtype='float32') square_cost = layers.square_error_cost(input=inference, label=label) diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py index 4929f7cf615e61de5c4f61ef44c5340e9ac4492a..3103be83a63d64fcba87132ddc5d830b92047b27 100644 --- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py @@ -10,8 +10,8 @@ from paddle.v2.fluid.optimizer import AdamOptimizer def convolution_net(input_dim, class_dim=2, emb_dim=32, hid_dim=32): - data = layers.data(name="words", shape=[1], data_type="int64") - label = layers.data(name="label", shape=[1], data_type="int64") + data = layers.data(name="words", shape=[1], dtype="int64") + label = layers.data(name="label", shape=[1], dtype="int64") emb = layers.embedding(input=data, size=[input_dim, emb_dim]) conv_3 = nets.sequence_conv_pool( diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py index b3ee91938865afb929670a388a561b156aec1fe9..208978224f4e83a23efadae37fbe51d0d59dafe8 100644 --- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py @@ -14,8 +14,8 @@ def stacked_lstm_net(input_dim, hid_dim=512, stacked_num=3): assert stacked_num % 2 == 1 - data = layers.data(name="words", 
shape=[1], data_type="int64") - label = layers.data(name="label", shape=[1], data_type="int64") + data = layers.data(name="words", shape=[1], dtype="int64") + label = layers.data(name="label", shape=[1], dtype="int64") emb = layers.embedding(input=data, size=[input_dim, emb_dim]) # add bias attr diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py index 9a51a2f207ebed340b8e5c60e7ebeb82a611dbc5..8aebeba653cf49438929fa51312b5af33c3b438d 100644 --- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py @@ -12,19 +12,19 @@ def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50): name="words", shape=[seq_len * batch_size, 1], append_batch_size=False, - data_type="int64") + dtype="int64") label = layers.data( name="label", shape=[batch_size, 1], append_batch_size=False, - data_type="int64") + dtype="int64") emb = layers.embedding(input=data, size=[dict_dim, emb_dim]) emb = layers.reshape(x=emb, shape=[batch_size, seq_len, emb_dim]) emb = layers.transpose(x=emb, axis=[1, 0, 2]) c_pre_init = layers.fill_constant( - dtype=emb.data_type, shape=[batch_size, emb_dim], value=0.0) + dtype=emb.dtype, shape=[batch_size, emb_dim], value=0.0) layer_1_out = layers.lstm(emb, c_pre_init=c_pre_init, hidden_dim=emb_dim) layer_1_out = layers.transpose(x=layer_1_out, axis=[1, 0, 2]) diff --git a/python/paddle/v2/fluid/tests/book/test_word2vec.py b/python/paddle/v2/fluid/tests/book/test_word2vec.py index afa7b285198e0349317e123e4bd98e8336217afa..0629e1cab7fd7e501d9cbf3ae8ee22fe9383ad2b 100644 --- a/python/paddle/v2/fluid/tests/book/test_word2vec.py +++ b/python/paddle/v2/fluid/tests/book/test_word2vec.py @@ -16,34 +16,34 @@ IS_SPARSE = True word_dict = paddle.dataset.imikolov.build_dict() dict_size = len(word_dict) -first_word = layers.data(name='firstw', shape=[1], data_type='int64') -second_word = layers.data(name='secondw', shape=[1], data_type='int64') -third_word = layers.data(name='thirdw', shape=[1], data_type='int64') -forth_word = layers.data(name='forthw', shape=[1], data_type='int64') -next_word = layers.data(name='nextw', shape=[1], data_type='int64') +first_word = layers.data(name='firstw', shape=[1], dtype='int64') +second_word = layers.data(name='secondw', shape=[1], dtype='int64') +third_word = layers.data(name='thirdw', shape=[1], dtype='int64') +forth_word = layers.data(name='forthw', shape=[1], dtype='int64') +next_word = layers.data(name='nextw', shape=[1], dtype='int64') embed_first = layers.embedding( input=first_word, size=[dict_size, EMBED_SIZE], - data_type='float32', + dtype='float32', is_sparse=IS_SPARSE, param_attr={'name': 'shared_w'}) embed_second = layers.embedding( input=second_word, size=[dict_size, EMBED_SIZE], - data_type='float32', + dtype='float32', is_sparse=IS_SPARSE, param_attr={'name': 'shared_w'}) embed_third = layers.embedding( input=third_word, size=[dict_size, EMBED_SIZE], - data_type='float32', + dtype='float32', is_sparse=IS_SPARSE, param_attr={'name': 'shared_w'}) embed_forth = layers.embedding( input=forth_word, size=[dict_size, EMBED_SIZE], - data_type='float32', + dtype='float32', is_sparse=IS_SPARSE, param_attr={'name': 'shared_w'}) diff --git a/python/paddle/v2/fluid/tests/op_test.py b/python/paddle/v2/fluid/tests/op_test.py index 90269e308a31d2606b23d741ce0d0fa91a0a6aeb..51023bd19a8326152335eabc9e96600427527f26 100644 --- 
a/python/paddle/v2/fluid/tests/op_test.py +++ b/python/paddle/v2/fluid/tests/op_test.py @@ -458,7 +458,7 @@ class OpTest(unittest.TestCase): mean_inputs = map(block.var, output_names) if len(mean_inputs) == 1: - loss = block.create_var(dtype=mean_inputs[0].data_type, shape=[1]) + loss = block.create_var(dtype=mean_inputs[0].dtype, shape=[1]) op = block.append_op( inputs={"X": mean_inputs}, outputs={"Out": loss}, type='mean') op.desc.infer_var_type(block.desc) @@ -466,8 +466,7 @@ class OpTest(unittest.TestCase): else: avg_sum = [] for cur_loss in mean_inputs: - cur_avg_loss = block.create_var( - dtype=cur_loss.data_type, shape=[1]) + cur_avg_loss = block.create_var(dtype=cur_loss.dtype, shape=[1]) op = block.append_op( inputs={"X": [cur_loss]}, outputs={"Out": [cur_avg_loss]}, @@ -476,13 +475,13 @@ class OpTest(unittest.TestCase): op.desc.infer_shape(block.desc) avg_sum.append(cur_avg_loss) - loss_sum = block.create_var(dtype=avg_sum[0].data_type, shape=[1]) + loss_sum = block.create_var(dtype=avg_sum[0].dtype, shape=[1]) op_sum = block.append_op( inputs={"X": avg_sum}, outputs={"Out": loss_sum}, type='sum') op_sum.desc.infer_var_type(block.desc) op_sum.desc.infer_shape(block.desc) - loss = block.create_var(dtype=loss_sum.data_type, shape=[1]) + loss = block.create_var(dtype=loss_sum.dtype, shape=[1]) op_loss = block.append_op( inputs={"X": loss_sum}, outputs={"Out": loss}, diff --git a/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py b/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py index 8a11820d2aba2dd4d17d925f0e0fe9f324100418..5fad7d8cce5af3677aa77dc0abb64f1ecd380419 100644 --- a/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py +++ b/python/paddle/v2/fluid/tests/test_beam_search_decode_op.py @@ -35,15 +35,15 @@ class TestBeamSearchDecodeOp(unittest.TestCase): self.append_lod_tensor( scores, [[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]], np.array( - [1, 2, 3, 4, 5, 6], dtype="float32")) + [1, 2, 3, 4, 5, 6], dtype="float64")) self.append_lod_tensor( scores, [[0, 3, 6], [0, 1, 1, 3, 5, 5, 6]], np.array( - [0, 1, 2, 3, 4, 5], dtype="float32")) + [0, 1, 2, 3, 4, 5], dtype="float64")) self.append_lod_tensor( scores, [[0, 3, 6], [0, 0, 1, 2, 3, 4, 5]], np.array( - [0, 1, 2, 3, 4], dtype="float32")) + [0, 1, 2, 3, 4], dtype="float64")) sentence_ids = self.scope.var("sentence_ids").get_tensor() sentence_scores = self.scope.var("sentence_scores").get_tensor() diff --git a/python/paddle/v2/fluid/tests/test_cast_op.py b/python/paddle/v2/fluid/tests/test_cast_op.py index 0c4b6310652e84d3dd7f281a8b98ae0435072afb..4e431bb88da6070718d64a68467be20ca87f8fb9 100644 --- a/python/paddle/v2/fluid/tests/test_cast_op.py +++ b/python/paddle/v2/fluid/tests/test_cast_op.py @@ -10,8 +10,8 @@ class TestCastOp(op_test.OpTest): self.inputs = {'X': ipt.astype('float32')} self.outputs = {'Out': ipt.astype('float64')} self.attrs = { - 'in_data_type': int(core.DataType.FP32), - 'out_data_type': int(core.DataType.FP64) + 'in_dtype': int(core.DataType.FP32), + 'out_dtype': int(core.DataType.FP64) } self.op_type = 'cast' diff --git a/python/paddle/v2/fluid/tests/test_conditional_block.py b/python/paddle/v2/fluid/tests/test_conditional_block.py index 293803f004a1513611fba30634d5552e1da84fef..2a30fd107968ce0fa188bda44e731ad760dce1f5 100644 --- a/python/paddle/v2/fluid/tests/test_conditional_block.py +++ b/python/paddle/v2/fluid/tests/test_conditional_block.py @@ -9,7 +9,7 @@ import numpy class ConditionalBlock(unittest.TestCase): def test_forward(self): - data = layers.data(name='X', shape=[1], 
data_type='float32') + data = layers.data(name='X', shape=[1], dtype='float32') data.stop_gradient = False cond = layers.ConditionalBlock(inputs=[data]) out = layers.create_tensor(dtype='float32') diff --git a/python/paddle/v2/fluid/tests/test_dropout_op.py b/python/paddle/v2/fluid/tests/test_dropout_op.py index b14a366fcad7f4bf6968b6013c6cfbb57090071d..4f5ea836b44102e5599a2302efd669291ebe920b 100644 --- a/python/paddle/v2/fluid/tests/test_dropout_op.py +++ b/python/paddle/v2/fluid/tests/test_dropout_op.py @@ -7,7 +7,7 @@ class TestDropoutOp(OpTest): def setUp(self): self.op_type = "dropout" self.inputs = {'X': np.random.random((32, 64)).astype("float32")} - self.attrs = {'dropout_prob': 0.0, 'is_training': True} + self.attrs = {'dropout_prob': 0.0, 'is_test': False} self.outputs = { 'Out': self.inputs['X'], 'Mask': np.ones((32, 64)).astype('float32') @@ -24,7 +24,7 @@ class TestDropoutOp2(TestDropoutOp): def setUp(self): self.op_type = "dropout" self.inputs = {'X': np.random.random((32, 64)).astype("float32")} - self.attrs = {'dropout_prob': 1.0, 'is_training': True} + self.attrs = {'dropout_prob': 1.0, 'is_test': False} self.outputs = { 'Out': np.zeros((32, 64)).astype('float32'), 'Mask': np.zeros((32, 64)).astype('float32') @@ -35,7 +35,7 @@ class TestDropoutOp3(TestDropoutOp): def setUp(self): self.op_type = "dropout" self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} - self.attrs = {'dropout_prob': 0.0, 'is_training': True} + self.attrs = {'dropout_prob': 0.0, 'is_test': False} self.outputs = { 'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2)).astype('float32') @@ -46,7 +46,7 @@ class TestDropoutOp4(OpTest): def setUp(self): self.op_type = "dropout" self.inputs = {'X': np.random.random((32, 64)).astype("float32")} - self.attrs = {'dropout_prob': 0.35, 'is_training': False} + self.attrs = {'dropout_prob': 0.35, 'is_test': True} self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} def test_check_output(self): @@ -57,7 +57,7 @@ class TestDropoutOp5(OpTest): def setUp(self): self.op_type = "dropout" self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} - self.attrs = {'dropout_prob': 0.75, 'is_training': False} + self.attrs = {'dropout_prob': 0.75, 'is_test': True} self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} def test_check_output(self): diff --git a/python/paddle/v2/fluid/tests/test_executor_and_mul.py b/python/paddle/v2/fluid/tests/test_executor_and_mul.py index 709250d0c86dde84ac22c37d8e2385ca4a80a40a..da64739de5eb4eca8db8ac8370276c41692a7242 100644 --- a/python/paddle/v2/fluid/tests/test_executor_and_mul.py +++ b/python/paddle/v2/fluid/tests/test_executor_and_mul.py @@ -8,11 +8,11 @@ import numpy class TestExecutor(unittest.TestCase): def test_mul(self): - a = data(name='a', shape=[784], data_type='float32') + a = data(name='a', shape=[784], dtype='float32') b = data( name='b', shape=[784, 100], - data_type='float32', + dtype='float32', append_batch_size=False) out = mul(x=a, y=b) place = core.CPUPlace() diff --git a/python/paddle/v2/fluid/tests/test_image_classification_layer.py b/python/paddle/v2/fluid/tests/test_image_classification_layer.py index bf5444107fa1609e67b09823b82e5fb92234b0a4..8e8e1b0a8c07a60cb1404462f976d10fe26e87f6 100644 --- a/python/paddle/v2/fluid/tests/test_image_classification_layer.py +++ b/python/paddle/v2/fluid/tests/test_image_classification_layer.py @@ -32,7 +32,7 @@ class TestLayer(unittest.TestCase): images = layers.data( name='pixel', shape=[3, 48, 48], - data_type='float32', + 
dtype='float32', main_program=main_program) layers.batch_norm( input=images, @@ -47,7 +47,7 @@ class TestLayer(unittest.TestCase): images = layers.data( name='pixel', shape=[3, 48, 48], - data_type='float32', + dtype='float32', main_program=main_program) layers.dropout( x=images, @@ -64,7 +64,7 @@ class TestLayer(unittest.TestCase): images = layers.data( name='pixel', shape=[3, 48, 48], - data_type='float32', + dtype='float32', main_program=main_program, startup_program=startup_program) conv1 = conv_block(images, 64, 2, [0.3, 0], main_program, @@ -80,13 +80,13 @@ class TestLayer(unittest.TestCase): image1 = layers.data( name='pixel1', shape=[3, 48, 48], - data_type='float32', + dtype='float32', main_program=main_program, startup_program=startup_program) image2 = layers.data( name='pixel2', shape=[3, 48, 48], - data_type='float32', + dtype='float32', main_program=main_program, startup_program=startup_program) out = layers.elementwise_add( diff --git a/python/paddle/v2/fluid/tests/test_inference_model_io.py b/python/paddle/v2/fluid/tests/test_inference_model_io.py index 98b95713b73e8eba93bd6a58eaaed603cfae7952..74f1ce23262bbc969f9544885a7390534c76cdf6 100644 --- a/python/paddle/v2/fluid/tests/test_inference_model_io.py +++ b/python/paddle/v2/fluid/tests/test_inference_model_io.py @@ -19,13 +19,13 @@ class TestBook(unittest.TestCase): x = layers.data( name='x', shape=[2], - data_type='float32', + dtype='float32', main_program=program, startup_program=init_program) y = layers.data( name='y', shape=[1], - data_type='float32', + dtype='float32', main_program=program, startup_program=init_program) diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index f88e0b4e15f7115be21ef136cbb96ce96af9d99e..87dc6d1a6270e0f8425b56601d04049450c73380 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -9,11 +9,11 @@ class TestBook(unittest.TestCase): def test_fit_a_line(self): program = Program() x = layers.data( - name='x', shape=[13], data_type='float32', main_program=program) + name='x', shape=[13], dtype='float32', main_program=program) y_predict = layers.fc(input=x, size=1, act=None, main_program=program) y = layers.data( - name='y', shape=[1], data_type='float32', main_program=program) + name='y', shape=[1], dtype='float32', main_program=program) cost = layers.square_error_cost( input=y_predict, label=y, main_program=program) @@ -28,12 +28,9 @@ class TestBook(unittest.TestCase): # Change g_program, so the rest layers use `g_program` images = layers.data( - name='pixel', - shape=[784], - data_type='float32', - main_program=program) + name='pixel', shape=[784], dtype='float32', main_program=program) label = layers.data( - name='label', shape=[1], data_type='int32', main_program=program) + name='label', shape=[1], dtype='int32', main_program=program) hidden1 = layers.fc(input=images, size=128, act='relu', @@ -58,7 +55,7 @@ class TestBook(unittest.TestCase): images = layers.data( name='pixel', shape=[3, 48, 48], - data_type='int32', + dtype='int32', main_program=program) layers.conv2d( input=images, @@ -74,10 +71,10 @@ class TestBook(unittest.TestCase): images = layers.data( name='pixel', shape=[1, 28, 28], - data_type='float32', + dtype='float32', main_program=program) label = layers.data( - name='label', shape=[1], data_type='int32', main_program=program) + name='label', shape=[1], dtype='int32', main_program=program) conv_pool_1 = nets.simple_img_conv_pool( input=images, filter_size=5, @@ -112,39 
+109,39 @@ class TestBook(unittest.TestCase): dict_size = 10000 embed_size = 32 first_word = layers.data( - name='firstw', shape=[1], data_type='int64', main_program=program) + name='firstw', shape=[1], dtype='int64', main_program=program) second_word = layers.data( - name='secondw', shape=[1], data_type='int64', main_program=program) + name='secondw', shape=[1], dtype='int64', main_program=program) third_word = layers.data( - name='thirdw', shape=[1], data_type='int64', main_program=program) + name='thirdw', shape=[1], dtype='int64', main_program=program) forth_word = layers.data( - name='forthw', shape=[1], data_type='int64', main_program=program) + name='forthw', shape=[1], dtype='int64', main_program=program) next_word = layers.data( - name='nextw', shape=[1], data_type='int64', main_program=program) + name='nextw', shape=[1], dtype='int64', main_program=program) embed_first = layers.embedding( input=first_word, size=[dict_size, embed_size], - data_type='float32', + dtype='float32', param_attr={'name': 'shared_w'}, main_program=program) embed_second = layers.embedding( input=second_word, size=[dict_size, embed_size], - data_type='float32', + dtype='float32', param_attr={'name': 'shared_w'}, main_program=program) embed_third = layers.embedding( input=third_word, size=[dict_size, embed_size], - data_type='float32', + dtype='float32', param_attr={'name': 'shared_w'}, main_program=program) embed_forth = layers.embedding( input=forth_word, size=[dict_size, embed_size], - data_type='float32', + dtype='float32', param_attr={'name': 'shared_w'}, main_program=program) @@ -173,12 +170,9 @@ class TestBook(unittest.TestCase): # Change g_program, so the rest layers use `g_program` images = layers.data( - name='pixel', - shape=[784], - data_type='float32', - main_program=program) + name='pixel', shape=[784], dtype='float32', main_program=program) label = layers.data( - name='label', shape=[1], data_type='int32', main_program=program) + name='label', shape=[1], dtype='int32', main_program=program) hidden = layers.fc(input=images, size=128, main_program=program) crf = layers.linear_chain_crf( input=hidden, label=label, main_program=program) diff --git a/python/paddle/v2/fluid/tests/test_lod_tensor_array_ops.py b/python/paddle/v2/fluid/tests/test_lod_tensor_array_ops.py index b18cb6b49fa41f26e1b6de1128690507c5a2f099..16e64b8cd52d72a3bbc84e43d772b843dad0129a 100644 --- a/python/paddle/v2/fluid/tests/test_lod_tensor_array_ops.py +++ b/python/paddle/v2/fluid/tests/test_lod_tensor_array_ops.py @@ -132,7 +132,7 @@ class TestCPULoDTensorArrayOpGrad(unittest.TestCase): x = layers.data( name='x', shape=[1], - data_type='float32', + dtype='float32', main_program=program, stop_gradient=False) table = layers.lod_rank_table(x, level=0, main_program=program) diff --git a/python/paddle/v2/fluid/tests/test_mnist_if_else_op.py b/python/paddle/v2/fluid/tests/test_mnist_if_else_op.py index 8af99005dc0b5d50de60ca89c2ddf870b1537edb..e76357a5be07d79eafee4c3a27911efe8a3eaef4 100644 --- a/python/paddle/v2/fluid/tests/test_mnist_if_else_op.py +++ b/python/paddle/v2/fluid/tests/test_mnist_if_else_op.py @@ -11,10 +11,9 @@ import numpy as np class TestMNISTIfElseOp(unittest.TestCase): def test_raw_api(self): kwargs = {'startup_program': Program(), 'main_program': Program()} - image = layers.data( - name='x', shape=[784], data_type='float32', **kwargs) + image = layers.data(name='x', shape=[784], dtype='float32', **kwargs) - label = layers.data(name='y', shape=[1], data_type='int64', **kwargs) + label = layers.data(name='y', 
shape=[1], dtype='int64', **kwargs) limit = layers.fill_constant_batch_size_like( input=label, dtype='int64', shape=[1], value=5.0, **kwargs) @@ -84,10 +83,9 @@ class TestMNISTIfElseOp(unittest.TestCase): def test_ifelse(self): kwargs = {'startup_program': Program(), 'main_program': Program()} - image = layers.data( - name='x', shape=[784], data_type='float32', **kwargs) + image = layers.data(name='x', shape=[784], dtype='float32', **kwargs) - label = layers.data(name='y', shape=[1], data_type='int64', **kwargs) + label = layers.data(name='y', shape=[1], dtype='int64', **kwargs) limit = layers.fill_constant_batch_size_like( input=label, dtype='int64', shape=[1], value=5.0, **kwargs) diff --git a/python/paddle/v2/fluid/tests/test_parameter.py b/python/paddle/v2/fluid/tests/test_parameter.py index a633d22c2b1db2728b6eb767078ce4aec6cce163..d467e4bbb79b291c442c643158ef6c0d630920dd 100644 --- a/python/paddle/v2/fluid/tests/test_parameter.py +++ b/python/paddle/v2/fluid/tests/test_parameter.py @@ -20,7 +20,7 @@ class TestParameter(unittest.TestCase): self.assertIsNotNone(param) self.assertEqual('fc.w', param.name) self.assertEqual((784, 100), param.shape) - self.assertEqual(core.DataType.FP32, param.data_type) + self.assertEqual(core.DataType.FP32, param.dtype) self.assertEqual(0, param.block.idx) exe = Executor(core.CPUPlace()) p = exe.run(g_main_program, fetch_list=[param])[0] diff --git a/python/paddle/v2/fluid/tests/test_protobuf_descs.py b/python/paddle/v2/fluid/tests/test_protobuf_descs.py index 098a9802dfc6763ce2a2356b7267a439145b7939..d8abe17606c4ddb2ff51d5f918b1e5d7e110f7fa 100644 --- a/python/paddle/v2/fluid/tests/test_protobuf_descs.py +++ b/python/paddle/v2/fluid/tests/test_protobuf_descs.py @@ -101,13 +101,13 @@ class TestVarDesc(unittest.TestCase): self.assertEqual(src_shape, res_shape) self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type()) - def test_data_type(self): + def test_dtype(self): program_desc = core.ProgramDesc() block = program_desc.block(0) var = block.var('my_var') var.set_type(core.VarDesc.VarType.LOD_TENSOR) - var.set_data_type(core.DataType.INT32) - self.assertEqual(core.DataType.INT32, var.data_type()) + var.set_dtype(core.DataType.INT32) + self.assertEqual(core.DataType.INT32, var.dtype()) self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type()) diff --git a/python/paddle/v2/fluid/tests/test_recurrent_op.py b/python/paddle/v2/fluid/tests/test_recurrent_op.py index b623d1231838faff9e91c9234befb1f647fe8ec2..88bcdc3e6a21881ace2be53c22a62d78df30a974 100644 --- a/python/paddle/v2/fluid/tests/test_recurrent_op.py +++ b/python/paddle/v2/fluid/tests/test_recurrent_op.py @@ -118,14 +118,14 @@ class RecurrentOpTest1(unittest.TestCase): def create_rnn_op(self): x = layers.data( shape=[self.sent_len, self.batch_size, self.input_dim], - data_type='float32', + dtype='float32', name='x', append_batch_size=False, **self.p_info) x.stop_gradient = False h_boot = layers.data( shape=[self.input_dim], - data_type='float32', + dtype='float32', name='h_boot', **self.p_info) h_boot.stop_gradient = False @@ -251,14 +251,14 @@ class RecurrentOpTest2(RecurrentOpTest1): def create_rnn_op(self): x = layers.data( shape=[self.sent_len, self.batch_size, self.input_dim], - data_type='float32', + dtype='float32', name='x', append_batch_size=False, **self.p_info) x.stop_gradient = False h_boot = layers.data( shape=[self.input_dim], - data_type='float32', + dtype='float32', name='h_boot', **self.p_info) h_boot.stop_gradient = False @@ -350,21 +350,21 @@ class 
RecurrentOpMultipleMemoryTest(RecurrentOpTest1): def create_rnn_op(self): x = layers.data( shape=[self.sent_len, self.batch_size, self.input_dim], - data_type='float32', + dtype='float32', name='x', append_batch_size=False, **self.p_info) x.stop_gradient = False h_boot1 = layers.data( shape=[self.batch_size, self.input_dim], - data_type='float32', + dtype='float32', name='h_boot1', append_batch_size=False, **self.p_info) h_boot1.stop_gradient = False h_boot2 = layers.data( shape=[self.batch_size, self.input_dim], - data_type='float32', + dtype='float32', name='h_boot2', append_batch_size=False, **self.p_info) @@ -435,7 +435,7 @@ class RecurrentOpNoMemBootTest(RecurrentOpTest1): def create_rnn_op(self): x = layers.data( shape=[self.sent_len, self.batch_size, self.input_dim], - data_type='float32', + dtype='float32', name='x', append_batch_size=False, **self.p_info) diff --git a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py index 1a3b88e18e38b88d75ad17a0bb6a2965d1e60406..953629d610e183cdddf97081f94a77951fe979d8 100644 --- a/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py +++ b/python/paddle/v2/fluid/tests/test_shrink_rnn_memory.py @@ -9,7 +9,7 @@ import numpy class TestShrinkRNNMemory(unittest.TestCase): def test_shrink_rnn_memory(self): - x = layers.data('x', shape=[100], data_type='float32') + x = layers.data('x', shape=[100], dtype='float32') x.stop_gradient = False table = layers.lod_rank_table(x=x) i = layers.zeros(dtype='int64', shape=[1]) diff --git a/python/paddle/v2/fluid/tests/test_split_and_merge_lod_tensor_op.py b/python/paddle/v2/fluid/tests/test_split_and_merge_lod_tensor_op.py index 3aed83b2ea3418c54f9540279ae6e2e0045421fa..a98cb3bbab8442886206b59a2b591fee96deeb9f 100644 --- a/python/paddle/v2/fluid/tests/test_split_and_merge_lod_tensor_op.py +++ b/python/paddle/v2/fluid/tests/test_split_and_merge_lod_tensor_op.py @@ -123,13 +123,13 @@ class TestCPUSplitMergeLoDTensorGrad(unittest.TestCase): x = layers.data( name='x', shape=[1], - data_type='float32', + dtype='float32', main_program=program, stop_gradient=False) y = layers.data( name='y', shape=[1], - data_type='bool', + dtype='bool', main_program=program, stop_gradient=False) diff --git a/python/paddle/v2/fluid/tests/test_variable.py b/python/paddle/v2/fluid/tests/test_variable.py index c3e1f9ac0a70e7448fd8d1983b1c04d27af9771c..92ffdceb6c84fb2669f8c1bb556c46fb1c03c411 100644 --- a/python/paddle/v2/fluid/tests/test_variable.py +++ b/python/paddle/v2/fluid/tests/test_variable.py @@ -22,13 +22,13 @@ class TestVariable(unittest.TestCase): w = b.create_var( dtype="float64", shape=[784, 100], lod_level=0, name="fc.w") self.assertNotEqual(str(w), "") - self.assertEqual(core.DataType.FP64, w.data_type) + self.assertEqual(core.DataType.FP64, w.dtype) self.assertEqual((784, 100), w.shape) self.assertEqual("fc.w", w.name) self.assertEqual(0, w.lod_level) w = b.create_var(name='fc.w') - self.assertEqual(core.DataType.FP64, w.data_type) + self.assertEqual(core.DataType.FP64, w.dtype) self.assertEqual((784, 100), w.shape) self.assertEqual("fc.w", w.name) self.assertEqual(0, w.lod_level) diff --git a/python/paddle/v2/fluid/tests/test_while_op.py b/python/paddle/v2/fluid/tests/test_while_op.py index 84b432333f950f754a97bc1a051b59c16fb22aed..fca0cdcc319ff661ced33b6bcd242c894941576c 100644 --- a/python/paddle/v2/fluid/tests/test_while_op.py +++ b/python/paddle/v2/fluid/tests/test_while_op.py @@ -9,11 +9,11 @@ import numpy class TestWhileOp(unittest.TestCase): def 
test_simple_forward(self):
         d0 = layers.data(
-            "d0", shape=[10], append_batch_size=False, data_type='float32')
+            "d0", shape=[10], append_batch_size=False, dtype='float32')
         d1 = layers.data(
-            "d1", shape=[10], append_batch_size=False, data_type='float32')
+            "d1", shape=[10], append_batch_size=False, dtype='float32')
         d2 = layers.data(
-            "d2", shape=[10], append_batch_size=False, data_type='float32')
+            "d2", shape=[10], append_batch_size=False, dtype='float32')
         i = layers.zeros(shape=[1], dtype='int64')
         i.stop_gradient = True
         init = layers.zeros(shape=[10], dtype='float32')
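Note on the test-set evaluation added to the book examples above (test_image_classification_train.py and test_recognize_digits_mlp.py): both follow the same pattern of building a second accuracy evaluator over the existing predict/label layers, pruning the program down to the inference targets with get_inference_program, and running the pruned program over a held-out reader. Below is a condensed sketch of that loop, not part of the patch itself; it assumes predict, label, avg_cost, exe, place and test_reader are the objects already defined in the surrounding training script.

    # Illustrative sketch only (assumes predict, label, avg_cost, exe, place,
    # test_reader from the surrounding training script); mirrors the loop
    # added in test_recognize_digits_mlp.py.
    import numpy as np
    import paddle.v2.fluid.core as core
    import paddle.v2.fluid.evaluator as evaluator
    from paddle.v2.fluid.io import get_inference_program

    # Build a separate accuracy evaluator for the test pass.
    test_accuracy, test_acc_out = evaluator.accuracy(input=predict, label=label)

    # Keep only the ops needed for the test targets: cost, accuracy, and the
    # evaluator's internal state variables.
    test_target = [avg_cost, test_acc_out] + test_accuracy.states().values()
    inference_program = get_inference_program(test_target)

    test_accuracy.reset(exe)
    for data in test_reader():
        x_data = np.array(map(lambda x: x[0], data)).astype("float32")
        y_data = np.expand_dims(
            np.array(map(lambda x: x[1], data)).astype("int64"), axis=1)

        tensor_x = core.LoDTensor()
        tensor_x.set(x_data, place)
        tensor_y = core.LoDTensor()
        tensor_y.set(y_data, place)

        # Feed keys ('x', 'y') match the layers.data names in the MLP example.
        exe.run(inference_program,
                feed={'x': tensor_x, 'y': tensor_y},
                fetch_list=[avg_cost, test_acc_out])

    test_pass_acc = test_accuracy.eval(exe)

The feed keys and the float32/int64 casts mirror the mnist example; the cifar variant additionally reshapes each image to data_shape and feeds 'pixel'/'label' instead. The rest of the patch is mechanical: every data_type keyword becomes dtype (e.g. layers.data(name='x', shape=[13], dtype='float32')), and the dropout tests switch from the is_training attribute to is_test with the inverted boolean.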