diff --git a/doc/fluid/design/motivation/fluid.md b/doc/fluid/design/motivation/fluid.md index 5e147f8263e685a4665b5793f7127178cbc3cfdd..4b7696cc1bbf57ace72c4d31ffc2bfe6c1071939 100644 --- a/doc/fluid/design/motivation/fluid.md +++ b/doc/fluid/design/motivation/fluid.md @@ -119,7 +119,7 @@ An actual Fluid example is described [here](https://github.com/PaddlePaddle/Pad From the example, the Fluid programs look very similar to their PyTorch equivalent programs, except that Fluid's loop structure, wrapped with Python's `with` statement, could run much faster than just a Python loop. -We have more examples of the [`if-then-else`](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/if_else_op.md) structure of Fluid. +We have more examples of the [`if-then-else`](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/execution/if_else_op.md) structure of Fluid. ## Turing Completeness diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc index eecb58e11ef57b550c79c040e6933ed6e52e2e87..cb1927bc0f2eb735f0a3184df5f0f8fada2f9dca 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu.cc @@ -114,23 +114,11 @@ class BatchNormKernel const auto *bias = ctx.Input("Bias"); auto *y = ctx.Output("Y"); - auto *mean_out = ctx.Output("MeanOut"); - auto *variance_out = ctx.Output("VarianceOut"); - auto *saved_mean = ctx.Output("SavedMean"); - auto *saved_variance = ctx.Output("SavedVariance"); // alloc memory y->mutable_data(ctx.GetPlace()); - mean_out->mutable_data>(ctx.GetPlace()); - variance_out->mutable_data>(ctx.GetPlace()); - saved_mean->mutable_data>(ctx.GetPlace()); - saved_variance->mutable_data>(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context(); - math::SetConstant> - functor; - functor(dev_ctx, saved_mean, static_cast>(0)); - functor(dev_ctx, saved_variance, static_cast>(0)); auto handle = dev_ctx.cudnn_handle(); @@ -159,6 +147,21 @@ class 
BatchNormKernel // Run training mode. // obtain running mean and running inv var, and see if we need to // initialize them. + + auto *mean_out = ctx.Output("MeanOut"); + auto *variance_out = ctx.Output("VarianceOut"); + mean_out->mutable_data>(ctx.GetPlace()); + variance_out->mutable_data>(ctx.GetPlace()); + + auto *saved_mean = ctx.Output("SavedMean"); + auto *saved_variance = ctx.Output("SavedVariance"); + saved_mean->mutable_data>(ctx.GetPlace()); + saved_variance->mutable_data>(ctx.GetPlace()); + math::SetConstant> + functor; + functor(dev_ctx, saved_mean, static_cast>(0)); + functor(dev_ctx, saved_variance, static_cast>(0)); + double this_factor = 1. - momentum; CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining( diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index e53183603fec54ceef68873cfd97b4b985b0d437..c28047e6e915280eed6886f99cd6d55704e3f4ad 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -288,9 +288,14 @@ void batched_gemm( // TODO(kexinzhao): add processing code for compute capability < 53 case PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53, "cublas Hgemm requires GPU compute capability >= 53"); + +#if CUDA_VERSION >= 8000 PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched( context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb, strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount)); +#else + PADDLE_ENFORCE(false, "HgemmStridedBatched is not supported on cuda <= 7.5"); +#endif } template <> @@ -310,9 +315,13 @@ void batched_gemm( (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; const int strideC = M * N; +#if CUDA_VERSION >= 8000 PADDLE_ENFORCE(platform::dynload::cublasSgemmStridedBatched( context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount)); +#else + PADDLE_ENFORCE(false, "SgemmStridedBatched is not supported on cuda <= 7.5"); +#endif } template <> @@ -332,9 +341,13 @@ void batched_gemm( (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; const int strideC = M * N; +#if CUDA_VERSION >= 8000 PADDLE_ENFORCE(platform::dynload::cublasDgemmStridedBatched( context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, strideB, A, lda, strideA, &beta, C, ldc, strideC, batchCount)); +#else + PADDLE_ENFORCE(false, "DgemmStridedBatched is not supported on cuda <= 7.5"); +#endif } template <> diff --git a/paddle/fluid/operators/pad_op.h b/paddle/fluid/operators/pad_op.h index a36abe3789574cb64f05001e34d534cf352a60b2..c93c096575a30dd9344894ead4b81acc16930e21 100644 --- a/paddle/fluid/operators/pad_op.h +++ b/paddle/fluid/operators/pad_op.h @@ -14,6 +14,8 @@ limitations under the License. 
*/ #pragma once +#include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" diff --git a/paddle/fluid/operators/pool_mkldnn_op.cc b/paddle/fluid/operators/pool_mkldnn_op.cc index c88578570c1acdecaa97dd8b12a702778fef2b7e..63eaaedcd5fc3df17902511dc02b25bf43ccd241 100644 --- a/paddle/fluid/operators/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/pool_mkldnn_op.cc @@ -83,9 +83,11 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory); auto src_memory = - mkldnn::memory({src_md, mkldnn_engine}, (void*)input_data); + mkldnn::memory({src_md, mkldnn_engine}, + static_cast(const_cast(input_data))); auto dst_memory = - mkldnn::memory({dst_md, mkldnn_engine}, (void*)output_data); + mkldnn::memory({dst_md, mkldnn_engine}, + static_cast(const_cast(output_data))); auto pool_prim = mkldnn::pooling_forward(*pool_pd, src_memory, dst_memory, *workspace_memory); @@ -195,9 +197,11 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { pool_bwd_desc, mkldnn_engine, *pool_pd); auto diff_src_memory = - mkldnn::memory({diff_src_md, mkldnn_engine}, (void*)in_x_grad_data); + mkldnn::memory({diff_src_md, mkldnn_engine}, + static_cast(const_cast(in_x_grad_data))); auto diff_dst_memory = - mkldnn::memory({diff_dst_md, mkldnn_engine}, (void*)out_grad_data); + mkldnn::memory({diff_dst_md, mkldnn_engine}, + static_cast(const_cast(out_grad_data))); auto bwd_prim = mkldnn::pooling_backward( pool_bwd_pd, diff_dst_memory, *workspace_memory, diff_src_memory); diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h index 2fec50ef25e0d2621a87963acdf142d24970329d..a48127ea6983d3d4ea12ec4925f30af233002ef2 100644 --- a/paddle/fluid/operators/pool_op.h +++ b/paddle/fluid/operators/pool_op.h @@ -14,6 +14,8 @@ limitations under the License. 
*/ #pragma once +#include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" diff --git a/paddle/fluid/operators/pool_with_index_op.h b/paddle/fluid/operators/pool_with_index_op.h index 83e7bd138ae25c6d3e09c3d01178d6887205bf98..b55fa76eae34c3179d40f31ed6a57d3ecbbaaccf 100644 --- a/paddle/fluid/operators/pool_with_index_op.h +++ b/paddle/fluid/operators/pool_with_index_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index 7fb45bd19da3a7f0c51d8e98a52efe62c15c1c55..8eaa12a4a6cfc09fd4e2c3642bc8825fe2af6d6b 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/prelu_op.h" - #include namespace paddle { diff --git a/paddle/fluid/operators/prior_box_op.cc b/paddle/fluid/operators/prior_box_op.cc index 82e54139c8c1f42b1d8f74811a6793ec5c66473e..058b13eeb872aaa77a88da37db64a6d59fbdd1cf 100644 --- a/paddle/fluid/operators/prior_box_op.cc +++ b/paddle/fluid/operators/prior_box_op.cc @@ -45,7 +45,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { bool flip = ctx->Attrs().Get("flip"); std::vector aspect_ratios_vec; - ExpandAspectRatios(aspect_ratios, flip, aspect_ratios_vec); + ExpandAspectRatios(aspect_ratios, flip, &aspect_ratios_vec); size_t num_priors = aspect_ratios_vec.size() * min_sizes.size(); if (max_sizes.size() > 0) { diff --git a/paddle/fluid/operators/prior_box_op.cu b/paddle/fluid/operators/prior_box_op.cu index 76bf2b3b7de7a24c80e927c16199f89c5b7fb794..0ea8909296f8f52d252b0ec258666cf32d69a8bb 100644 --- a/paddle/fluid/operators/prior_box_op.cu +++ b/paddle/fluid/operators/prior_box_op.cu @@ -96,7 +96,7 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel { auto clip = ctx.Attr("clip"); std::vector aspect_ratios; - ExpandAspectRatios(input_aspect_ratio, flip, aspect_ratios); + ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios); T step_w = static_cast(ctx.Attr("step_w")); T step_h = static_cast(ctx.Attr("step_h")); diff --git a/paddle/fluid/operators/prior_box_op.h b/paddle/fluid/operators/prior_box_op.h index 1e4a12aac1c5f1c3b7e2e1bc83170de9ad590fc3..1c62fd8d2c4d4e4deba4ca6442efbaff83e36c35 100644 --- a/paddle/fluid/operators/prior_box_op.h +++ b/paddle/fluid/operators/prior_box_op.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#include +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/transform.h" @@ -22,23 +24,23 @@ namespace operators { inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, bool flip, - std::vector& output_aspect_ratior) { + std::vector* output_aspect_ratior) { constexpr float epsilon = 1e-6; - output_aspect_ratior.clear(); - output_aspect_ratior.push_back(1.0f); + output_aspect_ratior->clear(); + output_aspect_ratior->push_back(1.0f); for (size_t i = 0; i < input_aspect_ratior.size(); ++i) { float ar = input_aspect_ratior[i]; bool already_exist = false; - for (size_t j = 0; j < output_aspect_ratior.size(); ++j) { - if (fabs(ar - output_aspect_ratior[j]) < epsilon) { + for (size_t j = 0; j < output_aspect_ratior->size(); ++j) { + if (fabs(ar - output_aspect_ratior->at(j)) < epsilon) { already_exist = true; break; } } if (!already_exist) { - output_aspect_ratior.push_back(ar); + output_aspect_ratior->push_back(ar); if (flip) { - output_aspect_ratior.push_back(1.0f / ar); + output_aspect_ratior->push_back(1.0f / ar); } } } @@ -68,7 +70,7 @@ class PriorBoxOpKernel : public framework::OpKernel { auto clip = ctx.Attr("clip"); std::vector aspect_ratios; - ExpandAspectRatios(input_aspect_ratio, flip, aspect_ratios); + ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios); T step_w = static_cast(ctx.Attr("step_w")); T step_h = static_cast(ctx.Attr("step_h")); diff --git a/paddle/fluid/operators/rank_loss_op.cc b/paddle/fluid/operators/rank_loss_op.cc index 767eef56861ea075ec2450b1456e7c5c807ce25d..a1127f11a75e54168ca9682a0189255d37ee8571 100644 --- a/paddle/fluid/operators/rank_loss_op.cc +++ b/paddle/fluid/operators/rank_loss_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/rank_loss_op.h" +#include namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc index 083c1fae5e2016ada6309aba78bdfa6ad7fef89c..a4dcf704a63ae3bad6567ddb042ea23513bccff7 100644 --- a/paddle/fluid/operators/recv_op.cc +++ b/paddle/fluid/operators/recv_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include // NOLINT #include #include "paddle/fluid/framework/data_type.h" @@ -19,7 +20,6 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" -#include #include "paddle/fluid/operators/detail/grpc_client.h" namespace paddle { diff --git a/paddle/fluid/operators/roi_pool_op.h b/paddle/fluid/operators/roi_pool_op.h index f38c5a3c0c9952b37f7db468ea00470a00b5ff6f..54e07490319cf1da749bd33449a7b51efd6c3d65 100644 --- a/paddle/fluid/operators/roi_pool_op.h +++ b/paddle/fluid/operators/roi_pool_op.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#include +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" diff --git a/paddle/fluid/operators/strided_memcpy.h b/paddle/fluid/operators/strided_memcpy.h index 22c1db82e9f5aff6aa9a311cd1093b33fa7e6db9..7a10218e1556698f3e0a1828db5de8851dd1c90b 100644 --- a/paddle/fluid/operators/strided_memcpy.h +++ b/paddle/fluid/operators/strided_memcpy.h @@ -37,8 +37,8 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src, const framework::DDim& src_stride, const framework::DDim& dst_dim, const framework::DDim& dst_stride, T* dst) { - using namespace detail; - StridedCopyDimVisitor func(dev_ctx, src, src_stride, dst_stride, dst); + paddle::operators::detail::StridedCopyDimVisitor func( + dev_ctx, src, src_stride, dst_stride, dst); boost::apply_visitor(func, dst_dim); } diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index bfd26c2f2294f954adc81a1719650c46372098c4..d7f4d383ce0d9e1ff42fc12c96aaf0ceb532e5db 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/platform/assert.h" namespace paddle { @@ -133,71 +134,71 @@ __device__ __forceinline__ void GetTopK(Pair topk[], const T* val, int* col, } template -__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int& beam, +__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, int beam_size, const T* src, - bool& firstStep, bool& is_empty, - Pair& max, int dim, + bool* firstStep, bool* is_empty, + Pair* max, int dim, const int tid) { - if (beam > 0) { - int length = beam < beam_size ? beam : beam_size; - if (firstStep) { - firstStep = false; + if (*beam > 0) { + int length = (*beam) < beam_size ? 
*beam : beam_size; + if (*firstStep) { + *firstStep = false; GetTopK(topk, src, tid, dim, length); } else { for (int k = 0; k < MaxLength; k++) { - if (k < MaxLength - beam) { - topk[k] = topk[k + beam]; + if (k < MaxLength - (*beam)) { + topk[k] = topk[k + *beam]; } else { topk[k].set(-INFINITY, -1); } } - if (!is_empty) { - GetTopK(topk + MaxLength - beam, src, tid, dim, max, + if (!(*is_empty)) { + GetTopK(topk + MaxLength - *beam, src, tid, dim, *max, length); } } - max = topk[MaxLength - 1]; - if (max.v == -1) is_empty = true; - beam = 0; + *max = topk[MaxLength - 1]; + if ((*max).v == -1) *is_empty = true; + *beam = 0; } } template -__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int& beam, +__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int* beam, int beam_size, const T* val, - int* col, bool& firstStep, - bool& is_empty, Pair& max, + int* col, bool* firstStep, + bool* is_empty, Pair* max, int dim, const int tid) { - if (beam > 0) { - int length = beam < beam_size ? beam : beam_size; - if (firstStep) { - firstStep = false; + if (*beam > 0) { + int length = (*beam) < beam_size ? 
*beam : beam_size; + if (*firstStep) { + *firstStep = false; GetTopK(topk, val, col, tid, dim, length); } else { for (int k = 0; k < MaxLength; k++) { - if (k < MaxLength - beam) { - topk[k] = topk[k + beam]; + if (k < MaxLength - *beam) { + topk[k] = topk[k + *beam]; } else { topk[k].set(-INFINITY, -1); } } - if (!is_empty) { - GetTopK(topk + MaxLength - beam, val, col, tid, dim, max, + if (!(*is_empty)) { /* max is a pointer in this overload: hand GetTopK the Pair it points to, as the sibling overload does */ + GetTopK(topk + MaxLength - *beam, val, col, tid, dim, *max, length); } } - max = topk[MaxLength - 1]; - if (max.v == -1) is_empty = true; - beam = 0; + *max = topk[MaxLength - 1]; + if ((*max).v == -1) *is_empty = true; + *beam = 0; } } template __device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid, Pair topk[], T** topVal, - int64_t** topIds, int& beam, int& k, + int64_t** topIds, int* beam, int* k, const int tid, const int warp) { while (true) { __syncthreads(); @@ -225,17 +226,17 @@ __device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid, (*topVal)++; (*topIds)++; } - if (tid == maxid[0]) beam++; - if (--k == 0) break; + if (tid == maxid[0]) (*beam)++; + if (--(*k) == 0) break; __syncthreads(); if (tid == maxid[0]) { - if (beam < MaxLength) { - sh_topk[tid] = topk[beam]; + if (*beam < MaxLength) { + sh_topk[tid] = topk[*beam]; } } if (maxid[0] / 32 == warp) { - if (__shfl(beam, (maxid[0]) % 32, 32) == MaxLength) break; + if (__shfl(*beam, (maxid[0]) % 32, 32) == MaxLength) break; } } } @@ -268,13 +269,13 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, topk[k].set(-INFINITY, -1); } while (k) { - ThreadGetTopK(topk, beam, k, - src + blockIdx.x * lds, firststep, - is_empty, max, dim, tid); + ThreadGetTopK(topk, &beam, k, + src + blockIdx.x * lds, &firststep, + &is_empty, &max, dim, tid); sh_topk[tid] = topk[0]; BlockReduce(sh_topk, maxid, topk, &output, - &indices, beam, k, tid, warp); + &indices, &beam, &k, tid, warp); } } @@ -308,9 +309,9 @@ class TopkOpCUDAKernel : public framework::OpKernel { 
KeMatrixTopK<<< grid, threads, 0, reinterpret_cast( ctx.device_context()) - .stream()>>>(output_data, output->dims()[1], - indices_data, input_data, - input_width, input_width, int(k)); + .stream()>>>( + output_data, output->dims()[1], indices_data, input_data, input_width, + input_width, static_cast(k)); } }; diff --git a/python/paddle/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/fluid/tests/book/test_label_semantic_roles.py index c0a6df831acbfe2654a5941cf95c91343992ef13..4d8bca4d2430a248ccf421572bdafdffc3a3003a 100644 --- a/python/paddle/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/fluid/tests/book/test_label_semantic_roles.py @@ -37,7 +37,7 @@ depth = 8 mix_hidden_lr = 1e-3 IS_SPARSE = True -PASS_NUM = 10 +PASS_NUM = 100 BATCH_SIZE = 10 embedding_name = 'emb' @@ -77,7 +77,8 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, emb_layers.append(mark_embedding) hidden_0_layers = [ - fluid.layers.fc(input=emb, size=hidden_dim) for emb in emb_layers + fluid.layers.fc(input=emb, size=hidden_dim, act='tanh') + for emb in emb_layers ] hidden_0 = fluid.layers.sums(input=hidden_0_layers) @@ -94,8 +95,8 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, for i in range(1, depth): mix_hidden = fluid.layers.sums(input=[ - fluid.layers.fc(input=input_tmp[0], size=hidden_dim), - fluid.layers.fc(input=input_tmp[1], size=hidden_dim) + fluid.layers.fc(input=input_tmp[0], size=hidden_dim, act='tanh'), + fluid.layers.fc(input=input_tmp[1], size=hidden_dim, act='tanh') ]) lstm = fluid.layers.dynamic_lstm( @@ -109,8 +110,8 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, input_tmp = [mix_hidden, lstm] feature_out = fluid.layers.sums(input=[ - fluid.layers.fc(input=input_tmp[0], size=label_dict_len), - fluid.layers.fc(input=input_tmp[1], size=label_dict_len) + fluid.layers.fc(input=input_tmp[0], size=label_dict_len, act='tanh'), + fluid.layers.fc(input=input_tmp[1], 
size=label_dict_len, act='tanh') ]) return feature_out @@ -171,7 +172,7 @@ def train(use_cuda, save_dirname=None, is_local=True): # check other optimizers and check why out will be NAN sgd_optimizer = fluid.optimizer.SGD( learning_rate=fluid.layers.exponential_decay( - learning_rate=0.0001, + learning_rate=0.01, decay_steps=100000, decay_rate=0.5, staircase=True)) @@ -233,7 +234,7 @@ def train(use_cuda, save_dirname=None, is_local=True): print("second per batch: " + str((time.time( ) - start_time) / batch_id)) # Set the threshold low to speed up the CI test - if float(pass_precision) > 0.05: + if float(pass_precision) > 0.01: if save_dirname is not None: # TODO(liuyiqun): Change the target to crf_decode fluid.io.save_inference_model(save_dirname, [