Commit 866d6bfe authored by Qiao Longfei

dist table: support other optimizer and regularization configs

Parent: 6449faec
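In short: Optimizer.minimize now splits the distributed lookup table out of
params_grads before gradient clipping and regularization run, and attaches a
dedicated SGD op for the table, so the remaining parameters are free to use any
optimizer and regularization config. A minimal usage sketch (the network and
variable names are illustrative, not part of this commit):

import paddle.fluid as fluid

ids = fluid.layers.data(name='ids', shape=[1], dtype='int64')
# is_distributed=True marks this embedding's W as a distributed lookup table
emb = fluid.layers.embedding(
    input=ids,
    size=[100000, 64],
    is_distributed=True,
    param_attr=fluid.ParamAttr(name='dist_table'))
loss = fluid.layers.mean(fluid.layers.fc(input=emb, size=1))

# Previously the dist table could not be combined with other optimizer or
# regularization settings; after this commit any config is accepted for the
# dense parameters, while the table is updated by an internally created SGD op.
adam = fluid.optimizer.Adam(
    learning_rate=1e-3,
    regularization=fluid.regularizer.L2Decay(1e-4))
adam.minimize(loss)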
python/paddle/fluid/optimizer.py
@@ -13,21 +13,23 @@
# limitations under the License.
from __future__ import print_function
import re
import sys
from collections import defaultdict
from contextlib import contextmanager
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
import paddle.fluid.transpiler.details.distribute_lookuptable_utils as distribute_lookuptable_utils
from . import framework
from . import layers
from . import unique_name
from .backward import append_backward
from .clip import append_gradient_clip_ops, error_clip_callback
from .framework import program_guard
from .initializer import Constant
from .layer_helper import LayerHelper
from .regularizer import append_regularization_ops
from .layers import ops
__all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
@@ -260,6 +262,9 @@ class Optimizer(object):
        params_grads = sorted(params_grads, key=lambda x: x[0].name)

        params_grads, table_param_and_grad, table_optimize_op = \
            distribute_lookuptable_utils.process_distribute_lookuptable(
                loss.block.program, params_grads, self._learning_rate)

        params_grads = append_gradient_clip_ops(params_grads)

        # Add regularization if any
@@ -268,6 +273,8 @@
        optimize_ops = self._create_optimization_pass(params_grads, loss,
                                                      startup_program)
        if table_optimize_op is not None:
            optimize_ops.append(table_optimize_op)
            params_grads.append(table_param_and_grad)
        return optimize_ops, params_grads
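The resulting minimize flow, as a simplified sketch (not the verbatim method;
it only illustrates the ordering that lets the dist table skip clipping and
regularization):

from paddle.fluid.backward import append_backward
from paddle.fluid.clip import append_gradient_clip_ops
from paddle.fluid.regularizer import append_regularization_ops
import paddle.fluid.transpiler.details.distribute_lookuptable_utils as \
    distribute_lookuptable_utils

def minimize_sketch(self, loss, startup_program=None):
    # 1. build all (param, grad) pairs as before
    params_grads = append_backward(loss)
    # 2. NEW: split the dist table out and get its dedicated SGD op back
    params_grads, table_param_and_grad, table_optimize_op = \
        distribute_lookuptable_utils.process_distribute_lookuptable(
            loss.block.program, params_grads, self._learning_rate)
    # 3. clip and regularize only the remaining (dense) parameters
    params_grads = append_gradient_clip_ops(params_grads)
    params_grads = append_regularization_ops(params_grads, self.regularization)
    # 4. create the user-chosen optimizer's ops, then re-attach the table's op
    optimize_ops = self._create_optimization_pass(params_grads, loss,
                                                  startup_program)
    if table_optimize_op is not None:
        optimize_ops.append(table_optimize_op)
        params_grads.append(table_param_and_grad)
    return optimize_ops, params_grads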
python/paddle/fluid/transpiler/details/distribute_lookuptable_utils.py (new file)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.optimizer as optimizer
import paddle.fluid.framework as framework
LOOKUP_TABLE_TYPE = "lookup_table"
def find_distributed_lookup_table(program):
    # process lookup_table_op:
    # 1. check that all distributed lookup_table_ops share the same table
    # 2. check that no non-distributed lookup_table_op uses that table
    distributed_lookup_table_ops = []
    # only one distributed lookup table is supported for now
    table_name = None

    for op in program.global_block().ops:
        if op.type == LOOKUP_TABLE_TYPE:
            if op.attr('is_distributed') is True:
                if table_name is None:
                    table_name = op.input("W")[0]
                if table_name != op.input("W")[0]:
                    raise RuntimeError("all distributed lookup_table_ops"
                                       " should have only one table")
                distributed_lookup_table_ops.append(op)
            else:
                if table_name is not None:
                    assert op.input("W")[0] != table_name

    return table_name
def process_distribute_lookuptable(program, param_grads, learning_rate):
    table_name = find_distributed_lookup_table(program)
    table_param = None
    table_grad = None
    new_param_grads = []
    for p, g in param_grads:
        if p.name == table_name:
            if table_param is not None:
                raise RuntimeError(
                    "multiple dist table vars found, only one is supported now!")
            table_param = p
            table_grad = g
        else:
            new_param_grads.append((p, g))
    sgd_op = None
    if table_param is not None:
        with table_param.block.program._optimized_guard(
                [table_param, table_grad]), framework.name_scope("optimizer"):
            sgd_optimizer = optimizer.SGD(learning_rate)
            sgd_op = sgd_optimizer._append_optimize_op(
                table_param.block, (table_param, table_grad))
    return new_param_grads, (table_param, table_grad), sgd_op
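A hedged usage sketch of the helper above; main_program and params_grads are
assumed to already exist in the caller:

# Illustration of the helper's contract (not from this commit): it returns the
# (param, grad) list minus the dist table, the table's own (param, grad) pair,
# and a plain SGD op that updates the table.
new_pgs, (table_p, table_g), table_sgd_op = process_distribute_lookuptable(
    main_program, params_grads, learning_rate=0.01)

if table_sgd_op is None:
    # the program has no distributed lookup table, so nothing was split out
    assert (table_p, table_g) == (None, None)
    assert new_pgs == params_grads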
python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -31,18 +31,17 @@ Steps to transpile pserver:
"""
import math
import sys
import numpy as np
import collections
import six
import logging
from .ps_dispatcher import RoundRobin, PSDispatcher
from .. import core, framework, unique_name
from ..framework import Program, default_main_program, \
    default_startup_program, Block, \
    Parameter, grad_var_name
from .details import *
from .details.distribute_lookuptable_utils import find_distributed_lookup_table
from functools import reduce
LOOKUP_TABLE_TYPE = "lookup_table"
@@ -292,7 +291,8 @@ class DistributeTranspiler(object):
        self.optimize_ops, self.params_grads = self._get_optimize_pass()

        ps_dispatcher = self.config.split_method(self.pserver_endpoints)
        self.table_name = find_distributed_lookup_table(self.origin_program)
        self.has_distributed_lookup_table = self.table_name is not None
        self.param_name_to_grad_name = dict()
        self.grad_name_to_param_name = dict()
        for param_var, grad_var in self.params_grads:
@@ -966,28 +966,6 @@ to transpile() call.")
    # ====================== private transpiler functions =====================
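    # NOTE: the method below is deleted by this commit; its logic moves to
    # details.distribute_lookuptable_utils.find_distributed_lookup_table.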
    def _has_distributed_lookup_table(self):
        # process lookup_table_op:
        # 1. check that all distributed lookup_table_ops share the same table
        # 2. check that no non-distributed lookup_table_op uses that table
        distributed_lookup_table_ops = []
        # only one distributed lookup table is supported for now
        self.table_name = None
        for op in self.origin_program.global_block().ops:
            if op.type == LOOKUP_TABLE_TYPE:
                if op.attr('is_distributed') is True:
                    if self.table_name is None:
                        self.table_name = op.input("W")[0]
                    if self.table_name != op.input("W")[0]:
                        raise RuntimeError("all distributed lookup_table_ops"
                                           " should have only one table")
                    distributed_lookup_table_ops.append(op)
                else:
                    if self.table_name is not None:
                        assert op.input("W")[0] != self.table_name
        return len(distributed_lookup_table_ops) > 0
    def _update_dist_lookup_table_vars(self, param_list, grad_list,
                                       params_grads):
        # TODO(wuyi): find a way to put dist lookup table stuff all together
@@ -1259,9 +1237,8 @@ to transpile() call.")
        # create table param and grad var in pserver program
        # create table optimize block in pserver program
        table_opt_op = [
            op for op in self.optimize_ops if 'Param' in op.input_names and
            op.input("Param")[0] == self.table_name
        ][0]
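        # exactly one op matches here: process_distribute_lookuptable creates a
        # single SGD op whose "Param" input is the dist table, hence the [0]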
        origin_param_var = self.origin_program.global_block().vars[
@@ -1341,7 +1318,6 @@ to transpile() call.")
"""
create a new block to handle save checkpoint.
"""
import os
pserver_program.global_block().create_var(
name="kLookupTablePath",