Commit 13bfee1f authored by: P peizhilin

Merge branch 'windows/build' into windows/online

test=develop
......@@ -77,7 +77,6 @@ option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface"
option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON)
option(WITH_PREBUILD_OPENBLAS "Make use of the pre-built openblas library" ${WIN32})
# PY_VERSION
if(NOT PY_VERSION)
......
......@@ -31,64 +31,66 @@ IF(NOT ${CBLAS_FOUND})
ADD_DEFINITIONS(-DPADDLE_USE_OPENBLAS)
IF (WITH_PREBUILD_OPENBLAS)
IF (WIN32)
SET(CBLAS_FOUND true)
MESSAGE(STATUS "Use prebuilt openblas, please put it at " ${CBLAS_INSTALL_DIR})
ELSE(WITH_PREBUILD_OPENBLAS)
SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -Wno-unused-but-set-variable -Wno-unused-variable")
SET(OPENBLAS_COMMIT "v0.2.20")
MESSAGE(WARNING "On Windows, openblas only supports msvc build; please build it manually and put it at " ${CBLAS_INSTALL_DIR})
ENDIF(WIN32)
IF(CMAKE_CROSSCOMPILING)
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER})
GET_FILENAME_COMPONENT(CROSS_SUFFIX ${CMAKE_C_COMPILER} DIRECTORY)
SET(CROSS_SUFFIX ${CROSS_SUFFIX}/)
IF(ANDROID)
IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
# use softfp
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0)
ELSEIF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV8 BINARY=64 USE_THREAD=0)
ENDIF()
ELSEIF(IOS)
IF(CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
SET(OPENBLAS_CC "${OPENBLAS_CC} ${CMAKE_C_FLAGS} -isysroot ${CMAKE_OSX_SYSROOT}")
SET(OPENBLAS_CC "${OPENBLAS_CC} -arch arm64")
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV8 BINARY=64 USE_THREAD=0 CROSS_SUFFIX=${CROSS_SUFFIX})
ELSE()
MESSAGE(FATAL_ERROR "OpenBLAS only supports the arm64 architecture on iOS. "
"You can set IOS_USE_VECLIB_FOR_BLAS=ON or USE_EIGEN_FOR_BLAS=ON to use another blas library instead.")
ENDIF()
ELSEIF(RPI)
# use hardfp
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV7 USE_THREAD=0)
ENDIF()
ELSE()
IF(APPLE)
SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -isysroot ${CMAKE_OSX_SYSROOT}")
IF (NOT WIN32)
SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -Wno-unused-but-set-variable -Wno-unused-variable")
SET(OPENBLAS_COMMIT "v0.2.20")
IF(CMAKE_CROSSCOMPILING)
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER})
GET_FILENAME_COMPONENT(CROSS_SUFFIX ${CMAKE_C_COMPILER} DIRECTORY)
SET(CROSS_SUFFIX ${CROSS_SUFFIX}/)
IF(ANDROID)
IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
# use softfp
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0)
ELSEIF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV8 BINARY=64 USE_THREAD=0)
ENDIF()
SET(OPTIONAL_ARGS "")
IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^x86(_64)?$")
SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 NUM_THREADS=64)
ELSEIF(IOS)
IF(CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
SET(OPENBLAS_CC "${OPENBLAS_CC} ${CMAKE_C_FLAGS} -isysroot ${CMAKE_OSX_SYSROOT}")
SET(OPENBLAS_CC "${OPENBLAS_CC} -arch arm64")
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV8 BINARY=64 USE_THREAD=0 CROSS_SUFFIX=${CROSS_SUFFIX})
ELSE()
MESSAGE(FATAL_ERROR "OpenBLAS only supports the arm64 architecture on iOS. "
"You can set IOS_USE_VECLIB_FOR_BLAS=ON or USE_EIGEN_FOR_BLAS=ON to use another blas library instead.")
ENDIF()
ELSEIF(RPI)
# use hardfp
SET(OPTIONAL_ARGS ${OPTIONAL_ARGS} TARGET=ARMV7 USE_THREAD=0)
ENDIF()
ELSE()
IF(APPLE)
SET(OPENBLAS_CC "${CMAKE_C_COMPILER} -isysroot ${CMAKE_OSX_SYSROOT}")
ENDIF()
SET(OPTIONAL_ARGS "")
IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^x86(_64)?$")
SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 NUM_THREADS=64)
ENDIF()
ENDIF()
SET(COMMON_ARGS CC=${OPENBLAS_CC} NO_SHARED=1 NO_LAPACK=1 libs)
ExternalProject_Add(
extern_openblas
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git
GIT_TAG ${OPENBLAS_COMMIT}
PREFIX ${CBLAS_SOURCES_DIR}
INSTALL_DIR ${CBLAS_INSTALL_DIR}
BUILD_IN_SOURCE 1
BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} ${COMMON_ARGS} ${OPTIONAL_ARGS}
INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX=<INSTALL_DIR>
&& rm -r ${CBLAS_INSTALL_DIR}/lib/cmake ${CBLAS_INSTALL_DIR}/lib/pkgconfig
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
)
ENDIF (WITH_PREBUILD_OPENBLAS)
SET(COMMON_ARGS CC=${OPENBLAS_CC} NO_SHARED=1 NO_LAPACK=1 libs)
ExternalProject_Add(
extern_openblas
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY https://github.com/xianyi/OpenBLAS.git
GIT_TAG ${OPENBLAS_COMMIT}
PREFIX ${CBLAS_SOURCES_DIR}
INSTALL_DIR ${CBLAS_INSTALL_DIR}
BUILD_IN_SOURCE 1
BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} ${COMMON_ARGS} ${OPTIONAL_ARGS}
INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX=<INSTALL_DIR>
&& rm -r ${CBLAS_INSTALL_DIR}/lib/cmake ${CBLAS_INSTALL_DIR}/lib/pkgconfig
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
)
ELSE()
ENDIF(NOT WIN32)
SET(CBLAS_PROVIDER openblas)
IF(WITH_C_API)
INSTALL(DIRECTORY ${CBLAS_INC_DIR} DESTINATION third_party/openblas)
......
../../../CONTRIBUTING.md
\ No newline at end of file
../../../CONTRIBUTING.md
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
data_type_transform.cc
\ No newline at end of file
......@@ -211,12 +211,12 @@ void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
VLOG(30) << "LSTMWeight resized to " << out->dims();
float* out_data = out->mutable_data<float>(platform::CPUPlace());
std::array<const float*, 4> tensors =
std::array<const float*, 4> tensors{
{W_forget_w0.data<float>(), W_input_w0.data<float>(),
W_output_w0.data<float>(), W_cell_w0.data<float>()};
std::array<const float*, 4> tensors1 =
W_output_w0.data<float>(), W_cell_w0.data<float>()}};
std::array<const float*, 4> tensors1{
{W_forget_w1.data<float>(), W_input_w1.data<float>(),
W_output_w1.data<float>(), W_cell_w1.data<float>()};
W_output_w1.data<float>(), W_cell_w1.data<float>()}};
for (int row = 0; row < D; row++) {
for (int col = 0; col < 4; col++) {
......@@ -238,9 +238,9 @@ void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
const LoDTensor& B_output, const LoDTensor& B_cell,
LoDTensor* out) {
std::array<const float*, 4> tensors =
std::array<const float*, 4> tensors{
{B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
B_cell.data<float>()};
B_cell.data<float>()}};
PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
int D = B_forget.dims()[0];
......
......@@ -207,7 +207,7 @@ struct PassRegistrar : public Registrar {
return 0; \
} \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
&__pass_tmp_registrar_##pass_type##__ __UNUSED__() = \
&__pass_tmp_registrar_##pass_type##__ UNUSED = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \
......@@ -215,7 +215,7 @@ struct PassRegistrar : public Registrar {
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ __UNUSED__() = \
static int use_pass_itself_##pass_type##_ UNUSED = \
TouchPassRegistrar_##pass_type()
} // namespace ir
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
tensor_util.cc
\ No newline at end of file
......@@ -113,7 +113,9 @@ void Analyzer::Run(Argument* argument) {
passes.push_back("infer_clean_graph_pass");
passes.push_back("graph_viz_pass"); // add graphviz for debug.
for (auto& pass : ir_passes_) {
if (!disabled_ir_passes_.count(pass)) {
// skip mkldnn passes when use_mkldnn_ is false
bool skip_pass = (!use_mkldnn_) && pass.find("mkldnn") != std::string::npos;
if (!disabled_ir_passes_.count(pass) && !skip_pass) {
passes.push_back(pass);
passes.push_back("graph_viz_pass"); // add graphviz for debug.
}
......
......@@ -150,4 +150,4 @@ struct NCCLContextMap {
} // namespace platform
} // namespace paddle
#endif
\ No newline at end of file
#endif
......@@ -24,42 +24,38 @@
#include "glog/logging.h"
#if !defined(_WIN32)
#define UNUSED __attribute__((unused))
#include <dlfcn.h> // dladdr
#include <execinfo.h> // backtrace
#include <sys/stat.h>
#include <algorithm> // std::accumulate
#include <dlfcn.h> // dladdr
#include <execinfo.h> // backtrace
#include <sys/stat.h>
#include <algorithm> // std::accumulate
#else
#include <stdio.h>
#include <io.h> // _popen, _pclose
#include <windows.h>
#include <numeric> // std::accumulate in msvc
// windows version of __attribute__((unused))
#define UNUSED __pragma(warning(suppress : 4100))
#ifndef S_ISDIR // windows port for sys/stat.h
#define S_ISDIR(mode) (((mode)&S_IFMT) == S_IFDIR)
#endif // S_ISDIR
static void *dlsym(void *handle, const char *symbol_name) {
FARPROC found_symbol;
found_symbol = GetProcAddress((HMODULE)handle, symbol_name);
if (found_symbol == NULL) {
throw std::runtime_error(std::string(symbol_name) + " not found.");
}
return reinterpret_cast<void *>(found_symbol);
#include <stdio.h>
#include <io.h> // _popen, _pclose
#include <windows.h>
#include <numeric> // std::accumulate in msvc
#ifndef S_ISDIR // windows port for sys/stat.h
#define S_ISDIR(mode) (((mode)&S_IFMT) == S_IFDIR)
#endif // S_ISDIR
static void *dlsym(void *handle, const char *symbol_name) {
FARPROC found_symbol;
found_symbol = GetProcAddress((HMODULE)handle, symbol_name);
if (found_symbol == NULL) {
throw std::runtime_error(std::string(symbol_name) + " not found.");
}
return reinterpret_cast<void *>(found_symbol);
}
static void *dlopen(const char *filename, int flag) {
std::string file_name(filename);
file_name.replace(0, file_name.size() - 1, '/', '\\');
HMODULE hModule = LoadLibrary(file_name.c_str());
if (!hModule) {
throw std::runtime_error(file_name + " not found.");
}
return reinterpret_cast<void *>(hModule);
static void *dlopen(const char *filename, int flag) {
std::string file_name(filename);
file_name.replace(0, file_name.size() - 1, '/', '\\');
HMODULE hModule = LoadLibrary(file_name.c_str());
if (!hModule) {
throw std::runtime_error(file_name + " not found.");
}
return reinterpret_cast<void *>(hModule);
}
#endif // !_WIN32
......
......@@ -18,8 +18,8 @@
#include <cuda_runtime.h>
#include <functional>
#include <memory>
#include "ThreadPool.h"
#include "paddle/fluid/platform/enforce.h"
#include "third_party/threadpool/src/extern_threadpool/ThreadPool.h"
namespace paddle {
namespace platform {
......
......@@ -45,8 +45,8 @@ limitations under the License. */
// some platform-independent definitions
#if defined(_WIN32)
#define __UNUSED__()
#define UNUSED
#define __builtin_expect(EXP, C) (EXP)
#else
#define __UNUSED__() __attribute__((unused))
#endif
\ No newline at end of file
#define UNUSED __attribute__((unused))
#endif
......@@ -35,6 +35,7 @@ from . import regularizer
from . import average
from . import metrics
from . import transpiler
from . import distribute_lookup_table
from .param_attr import ParamAttr, WeightNormParamAttr
from .data_feeder import DataFeeder
from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
......@@ -111,11 +112,10 @@ def __bootstrap__():
os.environ['OMP_NUM_THREADS'] = str(num_threads)
read_env_flags = [
'use_pinned_memory', 'check_nan_inf', 'benchmark',
'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb',
'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads',
'dist_threadpool_size', 'eager_delete_tensor_gb',
'reader_queue_speed_test_mode'
'use_pinned_memory', 'check_nan_inf', 'benchmark', 'eager_delete_scope',
'use_mkldnn', 'initial_cpu_memory_in_mb', 'init_allocated_mem',
'free_idle_memory', 'paddle_num_threads', 'dist_threadpool_size',
'eager_delete_tensor_gb', 'reader_queue_speed_test_mode'
]
if os.name != 'nt':
read_env_flags.append('warpctc_dir')
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
LOOKUP_TABLE_TYPE = "lookup_table"
def find_distributed_lookup_table(program):
"""
Find the distributed lookup table in the program.
Only one distributed table is supported for now.
:param program:
:return: table_name or None
"""
table_name = None
for op in program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE:
if op.attr('is_distributed') is True:
if table_name is None:
table_name = op.input("W")[0]
if table_name != op.input("W")[0]:
raise RuntimeError("all distributed lookup_table_ops"
" should have only one table")
else:
if table_name is not None:
assert op.input("W")[0] != table_name
return table_name
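For context, a minimal usage sketch of this new helper (hypothetical code, assuming the fluid.layers.embedding API of this release with its is_distributed flag; not part of the commit):

import paddle.fluid as fluid
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    ids = fluid.layers.data(name='ids', shape=[1], dtype='int64')
    # is_distributed=True marks the underlying lookup_table op as distributed
    emb = fluid.layers.embedding(input=ids, size=[100000, 64], is_distributed=True)

# Returns the name of the table parameter (the 'W' input of the lookup_table op), or None.
table_name = find_distributed_lookup_table(main)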
......@@ -348,6 +348,7 @@ def _copy_reader_create_op_(block, op):
if os.name != 'nt':
@templatedoc(op_type='create_recordio_file_reader')
def open_recordio_file(filename,
shapes,
......@@ -405,8 +406,8 @@ if os.name != 'nt':
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
main_prog_var = _copy_reader_var_(
default_main_program().current_block(), startup_var)
if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
......
......@@ -342,6 +342,7 @@ def embedding(input,
if os.name != 'nt':
@templatedoc(op_type="lstm")
def dynamic_lstm(input,
size,
......@@ -961,6 +962,7 @@ def linear_chain_crf(input, label, param_attr=None):
if os.name != 'nt':
@templatedoc()
def crf_decoding(input, param_attr, label=None):
"""
......@@ -988,9 +990,11 @@ if os.name != 'nt':
dtype=helper.input_dtype())
helper.append_op(
type='crf_decoding',
inputs={"Emission": [input],
"Transition": transition,
"Label": label},
inputs={
"Emission": [input],
"Transition": transition,
"Label": label
},
outputs={"ViterbiPath": [viterbi_path]})
return viterbi_path
......@@ -5530,8 +5534,13 @@ def label_smooth(label,
if os.name != 'nt':
@templatedoc()
def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
def roi_pool(input,
rois,
pooled_height=1,
pooled_width=1,
spatial_scale=1.0):
"""
${comment}
......
......@@ -105,7 +105,6 @@ if os.name != 'nt':
_cum_sum_ = generate_layer_fn('cumsum')
def cumsum(x, axis=None, exclusive=None, reverse=None):
locals_var = locals().keys()
kwargs = dict()
......@@ -115,7 +114,6 @@ if os.name != 'nt':
kwargs[name] = val
return _cum_sum_(**kwargs)
cumsum.__doc__ = _cum_sum_.__doc__ + """
Examples:
......
......@@ -13,21 +13,23 @@
# limitations under the License.
from __future__ import print_function
import re
import sys
from collections import defaultdict
from contextlib import contextmanager
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from . import framework
from . import layers
from . import unique_name
from .backward import append_backward
from .clip import append_gradient_clip_ops, error_clip_callback
from .framework import program_guard
from . import unique_name
from .initializer import Constant
from .layer_helper import LayerHelper
from .regularizer import append_regularization_ops
from .clip import append_gradient_clip_ops, error_clip_callback
from contextlib import contextmanager
from .layers import ops
from .regularizer import append_regularization_ops
__all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
......@@ -85,7 +87,7 @@ class Optimizer(object):
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype='float32' if self._dtype == None else self._dtype,
dtype='float32' if self._dtype is None else self._dtype,
persistable=True)
def _global_learning_rate(self, program=None):
......@@ -245,6 +247,50 @@ class Optimizer(object):
end = len(global_block.ops)
return global_block._slice_ops(start, end)
def _process_distribute_lookuptable(self, param_grads, loss,
startup_program):
"""
Because the distributed lookup table only supports the SGD optimizer for now, and
does not support other optimizers or regularization, we should find the table
parameter, avoid adding regularization and other ops for it, and add the sgd
optimize op for it independently.
:param param_grads(list((Var, Var))): list of (param, grad) pairs.
:param loss: the loss variable.
:param startup_program: the startup program
"""
program = loss.block.program
table_name = find_distributed_lookup_table(program)
table_param = None
table_grad = None
new_param_grads = []
for p, g in param_grads:
if p.name == table_name:
if table_param is not None:
raise RuntimeError(
"multi dist table var found, only support one now!")
table_param = p
table_grad = g
else:
new_param_grads.append((p, g))
sgd_op = None
if table_param is not None:
with program_guard(program, startup_program):
param_and_grad = [table_param, table_grad]
with table_param.block.program._optimized_guard(param_and_grad), \
framework.name_scope("optimizer"):
self._create_global_learning_rate()
# create the optimize op
sgd_op = loss.block.append_op(
type='sgd',
inputs={
"Param": table_param,
"Grad": table_grad,
"LearningRate":
self._create_param_lr(param_and_grad)
},
outputs={"ParamOut": param_and_grad[0]})
return new_param_grads, (table_param, table_grad), sgd_op
def minimize(self,
loss,
startup_program=None,
......@@ -260,6 +306,9 @@ class Optimizer(object):
params_grads = sorted(params_grads, key=lambda x: x[0].name)
params_grads, table_param_and_grad, table_optimize_op = \
self._process_distribute_lookuptable(params_grads, loss, startup_program)
params_grads = append_gradient_clip_ops(params_grads)
# Add regularization if any
......@@ -268,6 +317,9 @@ class Optimizer(object):
optimize_ops = self._create_optimization_pass(params_grads, loss,
startup_program)
if table_optimize_op is not None:
optimize_ops.append(table_optimize_op)
params_grads.append(table_param_and_grad)
return optimize_ops, params_grads
......
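The net effect on Optimizer.minimize() is that the distributed table parameter bypasses gradient clipping and regularization and is updated by a dedicated sgd op. A hedged caller-side sketch (hypothetical toy network, not the project's test code):

import paddle.fluid as fluid

ids = fluid.layers.data(name='ids', shape=[1], dtype='int64')
# The weight of a distributed embedding is the distributed lookup table.
emb = fluid.layers.embedding(input=ids, size=[100000, 64], is_distributed=True)
loss = fluid.layers.mean(fluid.layers.fc(input=emb, size=1))

optimize_ops, params_grads = fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
# _process_distribute_lookuptable() removed the (table_param, table_grad) pair from
# params_grads before clipping/regularization ran, appended a plain 'sgd' op for the
# table to optimize_ops, and added the pair back to params_grads at the end.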
......@@ -38,7 +38,7 @@ depth = 8
mix_hidden_lr = 1e-3
IS_SPARSE = True
PASS_NUM = 10
PASS_NUM = 1
BATCH_SIZE = 10
embedding_name = 'emb'
......
......@@ -567,7 +567,6 @@ class TestDistLookupTable(TestDistLookupTableBase):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'uniform_random',
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
'fake_init'
......@@ -639,7 +638,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
# 5 save table
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, _ = self.get_trainer(config)
trainer, trainer_startup = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1)
ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
......@@ -653,6 +652,16 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
'recv', 'concat'
]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
startup_ops = [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'uniform_random',
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
'fake_init'
]
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
startup_ops)
class TestDistLookupTableSliceSize(TestDistLookupTableBase):
......
......@@ -31,18 +31,17 @@ Steps to transpile pserver:
"""
import math
import sys
import numpy as np
import collections
import six
import logging
from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
from .ps_dispatcher import RoundRobin, PSDispatcher
from .. import core, framework, unique_name
from ..framework import Program, default_main_program, \
default_startup_program, Block, \
Parameter, grad_var_name
from .details import *
from ..distribute_lookup_table import find_distributed_lookup_table
from functools import reduce
LOOKUP_TABLE_TYPE = "lookup_table"
......@@ -292,7 +291,8 @@ class DistributeTranspiler(object):
self.optimize_ops, self.params_grads = self._get_optimize_pass()
ps_dispatcher = self.config.split_method(self.pserver_endpoints)
self.has_distributed_lookup_table = self._has_distributed_lookup_table()
self.table_name = find_distributed_lookup_table(self.origin_program)
self.has_distributed_lookup_table = self.table_name is not None
self.param_name_to_grad_name = dict()
self.grad_name_to_param_name = dict()
for param_var, grad_var in self.params_grads:
......@@ -966,28 +966,6 @@ to transpile() call.")
# ====================== private transpiler functions =====================
def _has_distributed_lookup_table(self):
# process lookup_table_op
# 1. check all lookup_table_op is distributed
# 2. check all lookup_table_op share the same table.
distributed_lookup_table_ops = []
# support only one distributed_lookup_table now
self.table_name = None
for op in self.origin_program.global_block().ops:
if op.type == LOOKUP_TABLE_TYPE:
if op.attr('is_distributed') is True:
if self.table_name is None:
self.table_name = op.input("W")[0]
if self.table_name != op.input("W")[0]:
raise RuntimeError("all distributed lookup_table_ops"
" should have only one table")
distributed_lookup_table_ops.append(op)
else:
if self.table_name is not None:
assert op.input("W")[0] != self.table_name
return len(distributed_lookup_table_ops) > 0
def _update_dist_lookup_table_vars(self, param_list, grad_list,
params_grads):
# TODO(wuyi): find a way to put all the dist lookup table stuff together.
......@@ -1341,7 +1319,6 @@ to transpile() call.")
"""
create a new block to handle save checkpoint.
"""
import os
pserver_program.global_block().create_var(
name="kLookupTablePath",
......
......@@ -1719,7 +1719,7 @@ def inputs(layers, *args):
if len(args) != 0:
layers.extend(args)
Inputs(*[l.name for l in layers])
Inputs(* [l.name for l in layers])
def outputs(layers, *args):
......@@ -1769,7 +1769,7 @@ def outputs(layers, *args):
assert len(layers) > 0
if HasInputsSet(): # input already set
Outputs(*[l.name for l in layers])
Outputs(* [l.name for l in layers])
return # just return outputs.
if len(layers) != 1:
......