提交 494a633c 编写于 作者: D dangqingqing

Merge branch 'develop' of https://github.com/baidu/Paddle into bug_fix

......@@ -14,11 +14,11 @@ limitations under the License. */
#include "hl_cuda_cudnn.h"
#include <cudnn.h>
#include <gflags/gflags.h>
#include <mutex>
#include "hl_cuda_cudnn.ph"
#include "hl_dso_loader.h"
#include "hl_thread.ph"
#include "paddle/utils/CommandLineParser.h"
#include "paddle/utils/Logging.h"
DEFINE_int32(cudnn_conv_workspace_limit_in_mb,
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_dso_loader.h"
#include "paddle/utils/CommandLineParser.h"
#include <gflags/gflags.h>
#include "paddle/utils/Logging.h"
DEFINE_string(cudnn_dir,
......
......@@ -50,7 +50,7 @@ public:
class ResponseNormLayer : public NormLayer {
protected:
size_t channels_, size_, outputX_, imgSize_, outputY_, imgSizeY_;
float scale_, pow_;
real scale_, pow_;
MatrixPtr denoms_;
public:
......
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include "Layer.h"
#include "SequenceToBatch.h"
#include "paddle/utils/CommandLineParser.h"
#include "paddle/utils/Stat.h"
DEFINE_bool(rnn_use_batch, false, "Using the batch method for calculation.");
......
......@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "TestUtil.h"
#include <gflags/gflags.h>
#include "paddle/math/SparseMatrix.h"
#include "paddle/utils/CommandLineParser.h"
DEFINE_int32(fixed_seq_length, 0, "Produce some sequence of fixed length");
......
......@@ -17,7 +17,7 @@ import random
from paddle.trainer.PyDataProvider2 import *
@provider(input_types=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)])
@provider(slots=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)])
def test_dense_no_seq(setting, filename):
for i in xrange(200):
yield [(float(j - 100) * float(i + 1)) / 200.0 for j in xrange(200)]
......
......@@ -14,10 +14,10 @@ limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <string.h>
#include <algorithm>
#include "Matrix.h"
#include "paddle/utils/CommandLineParser.h"
#include "paddle/utils/Util.h"
DECLARE_bool(allow_inefficient_sparse_update);
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "Parameter.h"
#include <gflags/gflags.h>
#include <fstream>
#include "AverageOptimizer.h"
#include "FirstOrderOptimizer.h"
......@@ -23,7 +24,6 @@ limitations under the License. */
#include "paddle/math/CpuSparseMatrix.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/SparseRowMatrix.h"
#include "paddle/utils/CommandLineParser.h"
#include "paddle/utils/Logging.h"
DEFINE_int32(enable_grad_share,
......
......@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "BaseClient.h"
#include <gflags/gflags.h>
#include <string.h>
#include <vector>
#include "paddle/utils/CommandLineParser.h"
#include "paddle/utils/Stat.h"
DECLARE_string(pservers);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "CommandLineParser.h"
namespace paddle {
#ifndef GFLAGS_NS
#define GFLAGS_NS google
#endif
namespace gflags_ns = GFLAGS_NS;
void ParseCommandLineFlags(int* argc, char** argv, bool withHelp) {
if (withHelp) {
gflags_ns::ParseCommandLineFlags(argc, &argv, true);
} else {
gflags_ns::ParseCommandLineNonHelpFlags(argc, &argv, true);
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <gflags/gflags.h>
namespace paddle {
void ParseCommandLineFlags(int* argc, char** argv, bool withHelp = true);
} // namespace paddle
......@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "CustomStackTrace.h"
#include <gflags/gflags.h>
#include <iostream>
#include "CommandLineParser.h"
DEFINE_bool(
layer_stack_error_only_current_thread,
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#include "CommandLineParser.h"
#include <gflags/gflags.h>
DECLARE_bool(parallel_nn);
DECLARE_int32(async_count);
......
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "ThreadLocal.h"
#include "CommandLineParser.h"
#include <gflags/gflags.h>
#include "Util.h"
DEFINE_bool(thread_local_rand_use_global_seed,
......
......@@ -24,10 +24,10 @@ limitations under the License. */
#include <fstream>
#include <mutex>
#include "paddle/utils/Logging.h"
#include <gflags/gflags.h>
#include "CommandLineParser.h"
#include "CustomStackTrace.h"
#include "Logging.h"
#include "StringUtil.h"
#include "Thread.h"
#include "ThreadLocal.h"
......@@ -152,7 +152,12 @@ void initMain(int argc, char** argv) {
line += ' ';
}
LOG(INFO) << "commandline: " << line;
ParseCommandLineFlags(&argc, argv, true);
#ifndef GFLAGS_GFLAGS_H_
namespace gflags = google;
#endif
gflags::ParseCommandLineFlags(&argc, &argv, true);
CHECK_EQ(argc, 1) << "Unknown commandline argument: " << argv[1];
installProfilerSwitch();
......
......@@ -26,7 +26,6 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "CommandLineParser.h"
#include "DisableCopy.h"
#include "Logging.h"
#include "TrainerConfig.pb.h"
......
......@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <chrono>
#include "paddle/utils/CommandLineParser.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/utils/CustomStackTrace.h"
#include "paddle/utils/Locks.h"
#include "paddle/utils/Util.h"
......
......@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <vector>
#include "paddle/utils/CommandLineParser.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/utils/Locks.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Util.h"
......
......@@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <set>
#include <vector>
#include "paddle/utils/CommandLineParser.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/utils/Locks.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Util.h"
......
......@@ -232,7 +232,7 @@ def provider(input_types=None,
check=False,
check_fail_continue=False,
init_hook=None,
**kwargs):
**outter_kwargs):
"""
Provider decorator. Use it to make a function into PyDataProvider2 object.
In this function, user only need to get each sample for some train/test
......@@ -318,11 +318,6 @@ def provider(input_types=None,
self.logger = logging.getLogger("")
self.logger.setLevel(logging.INFO)
self.input_types = None
if 'slots' in kwargs:
self.logger.warning('setting slots value is deprecated, '
'please use input_types instead.')
self.slots = kwargs['slots']
self.slots = input_types
self.should_shuffle = should_shuffle
true_table = [1, 't', 'true', 'on']
......@@ -358,9 +353,19 @@ def provider(input_types=None,
self.check = check
if init_hook is not None:
init_hook(self, file_list=file_list, **kwargs)
if 'slots' in outter_kwargs:
self.logger.warning('setting slots value is deprecated, '
'please use input_types instead.')
self.slots = outter_kwargs['slots']
if input_types is not None:
self.slots = input_types
if self.input_types is not None:
self.slots = self.input_types
assert self.slots is not None
assert self.slots is not None, \
"Data Provider's input_types must be set"
assert self.generator is not None
use_dynamic_order = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册