Unverified · Commit 080740b3 · Authored by Qiao Longfei · Committed by GitHub

Merge pull request #14300 from jacquesqiao/dist-table-support-optimizer-regular

Distributed lookup table: support other optimizers and regularization config (the table itself is still updated by a plain SGD op).
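In short, this change lets a model that contains a distributed lookup table use an optimizer other than SGD, plus regularization, for its remaining parameters, while the table parameter itself is updated by a separately appended `sgd` op (see `_process_distribute_lookuptable` below). A minimal sketch of the kind of setup this enables is shown below; the network, sizes, and hyperparameters are illustrative assumptions, not part of this PR:

```python
import paddle.fluid as fluid

# Hypothetical toy model: a distributed embedding table feeding a small classifier.
word = fluid.layers.data(name='word_id', shape=[1], dtype='int64')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

emb = fluid.layers.embedding(
    input=word,
    size=[100000, 64],            # assumed vocabulary / embedding sizes
    is_sparse=True,
    is_distributed=True,          # marks this parameter as the distributed lookup table
    param_attr=fluid.ParamAttr(name='shared_w'))

predict = fluid.layers.fc(input=emb, size=10, act='softmax')
avg_loss = fluid.layers.mean(
    fluid.layers.cross_entropy(input=predict, label=label))

# After this PR the non-table parameters may use Adam with L2 regularization,
# while 'shared_w' is still optimized by a plain sgd op.
optimizer = fluid.optimizer.Adam(
    learning_rate=0.001,
    regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(avg_loss)
```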
...@@ -34,6 +34,7 @@ from . import regularizer
from . import average
from . import metrics
from . import transpiler
from . import distribute_lookup_table
from .param_attr import ParamAttr, WeightNormParamAttr
from .data_feeder import DataFeeder
from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
...
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
LOOKUP_TABLE_TYPE = "lookup_table"

def find_distributed_lookup_table(program):
    """
    Find the distributed lookup table in a program.
    Only one distributed table is supported for now.
    :param program: the program to search.
    :return: table_name or None.
    """
    table_name = None

    for op in program.global_block().ops:
        if op.type == LOOKUP_TABLE_TYPE:
            if op.attr('is_distributed') is True:
                if table_name is None:
                    table_name = op.input("W")[0]
                if table_name != op.input("W")[0]:
                    raise RuntimeError("all distributed lookup_table_ops"
                                       " should have only one table")
            else:
                if table_name is not None:
                    assert op.input("W")[0] != table_name

    return table_name
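As a hedged usage sketch (assuming a program like the one in the sketch above has already been built in the default main program), the helper can be called directly:

```python
import paddle.fluid as fluid
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table

# Returns the shared table parameter name, or None when no lookup_table op
# with is_distributed=True exists in the program.
table_name = find_distributed_lookup_table(fluid.default_main_program())
print("distributed lookup table:", table_name)
```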
...@@ -13,21 +13,23 @@
# limitations under the License.
from __future__ import print_function
import re
import sys
from collections import defaultdict
from contextlib import contextmanager
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from . import framework
from . import layers
from . import unique_name
from .backward import append_backward
from .clip import append_gradient_clip_ops, error_clip_callback
from .framework import program_guard
from .initializer import Constant
from .layer_helper import LayerHelper
from .regularizer import append_regularization_ops
from .layers import ops
__all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
...@@ -85,7 +87,7 @@ class Optimizer(object):
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype='float32' if self._dtype is None else self._dtype,
persistable=True)
def _global_learning_rate(self, program=None):
...@@ -245,6 +247,50 @@ class Optimizer(object):
end = len(global_block.ops)
return global_block._slice_ops(start, end)
def _process_distribute_lookuptable(self, param_grads, loss,
                                    startup_program):
    """
    Because the distributed lookup table only supports the SGD optimizer for
    now, and supports neither other optimizers nor regularization, we pick
    the table parameter out of param_grads, skip regularization and the
    regular optimize ops for it, and append an independent sgd op for it.
    :param param_grads(list((Var, Var))): list of (param, grad) pairs.
    :param loss: the loss variable.
    :param startup_program: the startup program.
    """
    program = loss.block.program
    table_name = find_distributed_lookup_table(program)
    table_param = None
    table_grad = None
    new_param_grads = []
    for p, g in param_grads:
        if p.name == table_name:
            if table_param is not None:
                raise RuntimeError(
                    "multi dist table var found, only support one now!")
            table_param = p
            table_grad = g
        else:
            new_param_grads.append((p, g))
    sgd_op = None
    if table_param is not None:
        with program_guard(program, startup_program):
            param_and_grad = [table_param, table_grad]
            with table_param.block.program._optimized_guard(param_and_grad), \
                    framework.name_scope("optimizer"):
                self._create_global_learning_rate()
                # create the sgd optimize op for the table parameter
                sgd_op = loss.block.append_op(
                    type='sgd',
                    inputs={
                        "Param": table_param,
                        "Grad": table_grad,
                        "LearningRate": self._create_param_lr(param_and_grad)
                    },
                    outputs={"ParamOut": param_and_grad[0]})
    return new_param_grads, (table_param, table_grad), sgd_op
def minimize(self,
             loss,
             startup_program=None,
...@@ -260,6 +306,9 @@ class Optimizer(object):
params_grads = sorted(params_grads, key=lambda x: x[0].name)
params_grads, table_param_and_grad, table_optimize_op = \
    self._process_distribute_lookuptable(params_grads, loss, startup_program)
params_grads = append_gradient_clip_ops(params_grads)
# Add regularization if any
...@@ -268,6 +317,9 @@ class Optimizer(object):
optimize_ops = self._create_optimization_pass(params_grads, loss,
                                              startup_program)
if table_optimize_op is not None:
    optimize_ops.append(table_optimize_op)
    params_grads.append(table_param_and_grad)
return optimize_ops, params_grads
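For illustration, a hedged sketch of the resulting behavior, reusing the assumed `optimizer` and `avg_loss` from the sketch near the top of this page: when a distributed table is present, `minimize` returns the regular optimizer ops for the other parameters plus one separate `sgd` op for the table, and the table (param, grad) pair is appended back onto `params_grads`.

```python
optimize_ops, params_grads = optimizer.minimize(avg_loss)

# The table parameter is updated by its own plain sgd op; everything else
# goes through the configured optimizer (Adam in the sketch above).
print([op.type for op in optimize_ops])
print([p.name for p, _ in params_grads])
```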
...
...@@ -567,7 +567,6 @@ class TestDistLookupTable(TestDistLookupTableBase):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'uniform_random',
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
'fake_init'
...@@ -639,7 +638,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
# 5 save table
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, trainer_startup = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1)
ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
...@@ -653,6 +652,16 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
'recv', 'concat'
]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
startup_ops = [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'uniform_random',
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
'fake_init'
]
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
startup_ops)
class TestDistLookupTableSliceSize(TestDistLookupTableBase):
...
...@@ -31,18 +31,17 @@ Steps to transpile pserver:
"""
import math
import sys
import numpy as np
import collections
import six
import logging
from .ps_dispatcher import RoundRobin, PSDispatcher
from .. import core, framework, unique_name
from ..framework import Program, default_main_program, \
    default_startup_program, Block, \
    Parameter, grad_var_name
from .details import *
from ..distribute_lookup_table import find_distributed_lookup_table
from functools import reduce
LOOKUP_TABLE_TYPE = "lookup_table"
...@@ -292,7 +291,8 @@ class DistributeTranspiler(object):
self.optimize_ops, self.params_grads = self._get_optimize_pass()
ps_dispatcher = self.config.split_method(self.pserver_endpoints)
self.table_name = find_distributed_lookup_table(self.origin_program)
self.has_distributed_lookup_table = self.table_name != None
self.param_name_to_grad_name = dict()
self.grad_name_to_param_name = dict()
for param_var, grad_var in self.params_grads:
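On the transpiler side, detection now goes through the shared helper instead of the private `_has_distributed_lookup_table` method removed below. A minimal, assumed usage sketch (endpoints and trainer count are placeholders):

```python
import paddle.fluid as fluid

config = fluid.DistributeTranspilerConfig()
t = fluid.DistributeTranspiler(config=config)
t.transpile(trainer_id=0,
            pservers="127.0.0.1:6174,127.0.0.1:6175",
            trainers=2)

# After transpile(), the transpiler records whether the program contains a
# distributed lookup table and, if so, which parameter backs it.
print(t.has_distributed_lookup_table, t.table_name)
```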
...@@ -966,28 +966,6 @@ to transpile() call.")
# ====================== private transpiler functions =====================
def _has_distributed_lookup_table(self):
    # process lookup_table_op
    # 1. check all lookup_table_op is distributed
    # 2. check all lookup_table_op share the same table.
    distributed_lookup_table_ops = []
    # support only one distributed_lookup_table now
    self.table_name = None
    for op in self.origin_program.global_block().ops:
        if op.type == LOOKUP_TABLE_TYPE:
            if op.attr('is_distributed') is True:
                if self.table_name is None:
                    self.table_name = op.input("W")[0]
                if self.table_name != op.input("W")[0]:
                    raise RuntimeError("all distributed lookup_table_ops"
                                       " should have only one table")
                distributed_lookup_table_ops.append(op)
            else:
                if self.table_name is not None:
                    assert op.input("W")[0] != self.table_name
    return len(distributed_lookup_table_ops) > 0
def _update_dist_lookup_table_vars(self, param_list, grad_list,
                                   params_grads):
    # TODO(wuyi): put find a way to put dist lookup table stuff all together.
...@@ -1341,7 +1319,6 @@ to transpile() call.")
"""
create a new block to handle save checkpoint.
"""
import os
pserver_program.global_block().create_var(
name="kLookupTablePath",
...