Commit 64bfd814 authored by: Q qiaolongfei

fix style problem

Parent: bad503ff
@@ -19,9 +19,9 @@ limitations under the License. */
 #include <stdexcept>
 #include <string>
 #include <vector>
+#include "paddle/gserver/gradientmachines/GradientMachine.h"
 #include "paddle/utils/Common.h"
 #include "paddle/utils/GlobalConstants.h"
-#include "paddle/gserver/gradientmachines/GradientMachine.h"
 
 /// Import PaddlePaddle's enumeration into global namespace.
 using namespace paddle::enumeration_wrapper;  // NOLINT
@@ -470,7 +470,8 @@ private:
 enum GradientMatchineCreateMode {
   CREATE_MODE_NORMAL = paddle::GradientMachine::kNormal,
-  CREATE_MODE_SGD_SPARSE_CPU_TRAINING = paddle::GradientMachine::kSgdSparseCpuTraining,
+  CREATE_MODE_SGD_SPARSE_CPU_TRAINING =
+      paddle::GradientMachine::kSgdSparseCpuTraining,
   CREATE_MODE_TESTING = paddle::GradientMachine::kTesting
 };
@@ -819,7 +820,8 @@ private:
 public:
   static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
   static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
-                                               int passCount, bool userSparseUpdater);
+                                               int passCount,
+                                               bool userSparseUpdater);
   ~ParameterUpdater();
 
   /**
@@ -32,12 +32,16 @@ ParameterUpdater *ParameterUpdater::createRemoteUpdater(
     OptimizationConfig *config, int passCount, bool userSparseUpdater) {
   auto updater = new ParameterUpdater();
   auto remoteUpdater = new paddle::RemoteParameterUpdater(
-      config->m->getConfig(), passCount, nullptr);
+      config->m->getConfig(), passCount, nullptr);
   if (userSparseUpdater) {
     std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr;
     remoteUpdaterPtr.reset(remoteUpdater);
-    auto sparseRemoteUpdater = new paddle::SparseRemoteParameterUpdaterComposite(
-        config->m->getConfig(), passCount, false, std::move(remoteUpdaterPtr));
+    auto sparseRemoteUpdater =
+        new paddle::SparseRemoteParameterUpdaterComposite(
+            config->m->getConfig(),
+            passCount,
+            false,
+            std::move(remoteUpdaterPtr));
     updater->m->updater.reset(sparseRemoteUpdater);
   } else {
     updater->m->updater.reset(remoteUpdater);
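For orientation, the factory above is what the v2 Python API calls into through SWIG. A minimal sketch, assuming an already-built OptimizationConfig and the py_paddle swig_api module (both appear in the Python hunks below); the local variable names here are illustrative only:

    # Sketch only: how the C++ factory above is reached from Python.
    # `opt_conf`, `pass_num` and `use_sparse` are assumed to be prepared
    # by the caller, as in optimizer.py below.
    updater = swig_api.ParameterUpdater.createRemoteUpdater(
        opt_conf,     # OptimizationConfig holding the optimizer settings
        pass_num,     # number of training passes
        use_sparse)   # True wraps the dense RemoteParameterUpdater in a
                      # SparseRemoteParameterUpdaterComposite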
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "BufferArg.h"
 #include <gtest/gtest.h>
-#include "BufferArg.h"
 #include "paddle/math/MemoryHandle.h"
 
 namespace paddle {
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "Function.h"
 #include <gtest/gtest.h>
-#include "Function.h"
 #include "paddle/math/SparseMatrix.h"
 
 namespace paddle {
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "TensorShape.h"
 #include <gtest/gtest.h>
-#include "TensorShape.h"
 
 namespace paddle {
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "TensorType.h"
 #include <gtest/gtest.h>
-#include "TensorType.h"
 
 namespace paddle {
@@ -42,8 +42,8 @@ class Optimizer(object):
         return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
 
     def create_remote_updater(self, pass_num, use_sparse_updater):
-        return swig_api.ParameterUpdater.createRemoteUpdater(self.__opt_conf__,
-                                                             pass_num, use_sparse_updater)
+        return swig_api.ParameterUpdater.createRemoteUpdater(
+            self.__opt_conf__, pass_num, use_sparse_updater)
 
 
 class Momentum(Optimizer):
@@ -42,7 +42,12 @@ class SGD(object):
     :type extra_layers: paddle.v2.config_base.Layer
     """
 
-    def __init__(self, cost, parameters, update_equation, extra_layers=None, is_local=True):
+    def __init__(self,
+                 cost,
+                 parameters,
+                 update_equation,
+                 extra_layers=None,
+                 is_local=True):
         if not isinstance(parameters, v2_parameters.Parameters):
             raise TypeError('parameters should be parameters')
@@ -97,8 +102,8 @@ class SGD(object):
         if self.__is_local__:
             updater = self.__optimizer__.create_local_updater()
         else:
-            updater = self.__optimizer__.create_remote_updater(num_passes,
-                self.__use_sparse_updater__)
+            updater = self.__optimizer__.create_remote_updater(
+                num_passes, self.__use_sparse_updater__)
         updater.init(self.__gradient_machine__)
         self.__gradient_machine__.start()