提交 b4b16d1c 编写于 作者: T tangwei12

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into dist_unittest

......@@ -204,6 +204,11 @@ include(external/snappy) # download snappy
include(external/snappystream)
include(external/threadpool)
include(flags) # set paddle compile flags
include(cudnn) # set cudnn libraries, must before configure
include(cupti)
include(configure) # add paddle env configuration
if(WITH_GPU)
include(cuda)
include(tensorrt)
......@@ -212,15 +217,11 @@ elseif()
set(WITH_ANAKIN OFF CACHE STRING "Anakin is used in GPU only now." FORCE)
endif()
include(cudnn) # set cudnn libraries, must before configure
include(cupti)
include(configure) # add paddle env configuration
include(generic) # simplify cmake module
include(package) # set paddle packages
include(ccache) # set ccache for compilation
include(util) # set unittest and link libs
include(rdma) # set rdma libraries
include(flags) # set paddle compile flags
include(version) # set PADDLE_VERSION
include(coveralls) # set code coverage
include(inference_lib) # add paddle fluid inference libraries
......
......@@ -50,16 +50,16 @@ if(NOT WITH_PROFILER)
endif(NOT WITH_PROFILER)
if(NOT CMAKE_CROSSCOMPILING)
if(WITH_AVX AND AVX_FOUND)
if(WITH_AVX AND AVX512F_FOUND)
set(SIMD_FLAG ${AVX512F_FLAG})
elseif(WITH_AVX AND AVX2_FOUND)
set(SIMD_FLAG ${AVX2_FLAG})
elseif(WITH_AVX AND AVX_FOUND)
set(SIMD_FLAG ${AVX_FLAG})
elseif(SSE3_FOUND)
set(SIMD_FLAG ${SSE3_FLAG})
endif()
endif()
if(UNIX AND NOT APPLE)
# except apple from nix*Os family
set(LINUX TRUE)
endif(UNIX AND NOT APPLE)
if(NOT WITH_GOLANG)
add_definitions(-DPADDLE_WITHOUT_GOLANG)
......@@ -112,8 +112,11 @@ if(WITH_GPU)
endif()
endif()
if(WITH_ANAKIN)
set(ENV{CUDNN_INCLUDE_DIR} ${CUDNN_INCLUDE_DIR})
set(ENV{CUDNN_LIBRARY} ${CUDNN_LIBRARY})
# NOTICE(minqiyang): the end slash is important because $CUDNN_INCLUDE_DIR
# is a softlink to real cudnn.h directory
set(ENV{CUDNN_INCLUDE_DIR} "${CUDNN_INCLUDE_DIR}/")
get_filename_component(CUDNN_LIBRARY_DIR ${CUDNN_LIBRARY} DIRECTORY)
set(ENV{CUDNN_LIBRARY} ${CUDNN_LIBRARY_DIR})
endif()
elseif(WITH_AMD_GPU)
add_definitions(-DPADDLE_WITH_HIP)
......
......@@ -25,8 +25,25 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS
$ENV{CUDNN_ROOT}
$ENV{CUDNN_ROOT}/lib64
$ENV{CUDNN_ROOT}/lib
/usr/lib)
find_library(CUDNN_LIBRARY NAMES libcudnn.so libcudnn.dylib # libcudnn_static.a
/usr/lib
${CUDA_TOOLKIT_ROOT_DIR}
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
)
set(CUDNN_LIB_NAME "")
if (LINUX)
set(CUDNN_LIB_NAME "libcudnn.so")
endif(LINUX)
if(WIN32)
# only support cudnn7
set(CUDNN_LIB_NAME "cudnn.lib" "cudnn64_7.dll")
endif(WIN32)
if(Apple)
set(CUDNN_LIB_NAME "libcudnn.dylib" "libcudnn.so")
endif(Apple)
find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} # libcudnn_static.a
PATHS ${CUDNN_CHECK_LIBRARY_DIRS} ${CUDNN_INCLUDE_DIR} ${__libpath_hist}
NO_DEFAULT_PATH
DOC "Path to cuDNN library.")
......
......@@ -19,17 +19,17 @@ execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-cer
include_directories(${ANAKIN_INCLUDE})
include_directories(${ANAKIN_INCLUDE}/saber/)
set(ANAKIN_COMPILE_EXTRA_FLAGS
set(ANAKIN_COMPILE_EXTRA_FLAGS
-Wno-error=unused-but-set-variable -Wno-unused-but-set-variable
-Wno-error=unused-variable -Wno-unused-variable
-Wno-error=unused-variable -Wno-unused-variable
-Wno-error=format-extra-args -Wno-format-extra-args
-Wno-error=comment -Wno-comment
-Wno-error=format -Wno-format
-Wno-error=comment -Wno-comment
-Wno-error=format -Wno-format
-Wno-error=switch -Wno-switch
-Wno-error=return-type -Wno-return-type
-Wno-error=return-type -Wno-return-type
-Wno-error=non-virtual-dtor -Wno-non-virtual-dtor
-Wno-sign-compare
-Wno-reorder
-Wno-reorder
-Wno-error=cpp)
ExternalProject_Add(
......@@ -47,6 +47,7 @@ ExternalProject_Add(
-DPROTOBUF_ROOT=${THIRD_PARTY_PATH}/install/protobuf
-DMKLML_ROOT=${THIRD_PARTY_PATH}/install/mklml
-DCUDNN_ROOT=${CUDNN_ROOT}
-DCUDNN_INCLUDE_DIR=${CUDNN_INCLUDE_DIR}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ANAKIN_INSTALL_DIR}
)
......
......@@ -142,6 +142,11 @@ else()
${GPU_COMMON_FLAGS})
endif()
if(UNIX AND NOT APPLE)
# except apple from nix*Os family
set(LINUX TRUE)
endif(UNIX AND NOT APPLE)
foreach(flag ${COMMON_FLAGS})
safe_set_cflag(CMAKE_C_FLAGS ${flag})
......
......@@ -10,6 +10,7 @@ if(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID
set(SSE3_FLAG "-msse3")
set(AVX_FLAG "-mavx")
set(AVX2_FLAG "-mavx2")
set(AVX512F_FLAG "-mavx512f")
elseif(MSVC)
set(MMX_FLAG "/arch:MMX")
set(SSE2_FLAG "/arch:SSE2")
......@@ -81,5 +82,16 @@ int main()
return 0;
}" AVX2_FOUND)
# Check AVX512F
set(CMAKE_REQUIRED_FLAGS ${AVX512F_FLAG})
set(AVX512F_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
CHECK_CXX_SOURCE_RUNS("
#include <immintrin.h>
int main()
{
__m512i a = _mm512_undefined_epi32();
return 0;
}" AVX512F_FOUND)
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_RETAINED})
mark_as_advanced(MMX_FOUND SSE2_FOUND SSE3_FOUND AVX_FOUND AVX2_FOUND)
mark_as_advanced(MMX_FOUND SSE2_FOUND SSE3_FOUND AVX_FOUND AVX2_FOUND AVX512F_FOUND)
......@@ -78,7 +78,7 @@ paddle.fluid.io.load_vars ArgSpec(args=['executor', 'dirname', 'main_program', '
paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True))
paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False))
paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0))
......@@ -153,6 +153,7 @@ paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'n
paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))
paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
......@@ -250,7 +251,6 @@ paddle.fluid.layers.logical_not ArgSpec(args=[], varargs='args', keywords='kwarg
paddle.fluid.layers.uniform_random_batch_size_like ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.gaussian_random ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.gaussian_random_batch_size_like ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.scatter ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.sum ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.slice ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.shape ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
......
......@@ -99,12 +99,13 @@ else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
endif()
if (NOT WIN32)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
graph graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fast_threaded_ssa_graph_executor)
endif() # NOT WIN32
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
......@@ -200,9 +200,11 @@ TEST(GraphTest, WriteAfterWrite) {
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2);
ASSERT_EQ(control_dep1, control_dep2);
}
}
ASSERT_NE(control_dep1, nullptr);
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);
}
} // namespace framework
} // namespace paddle
......@@ -55,11 +55,20 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
auto all_ops = blocks_[block_id]->AllOps();
for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) {
auto &op = all_ops[op_id];
for (const std::string &attr_name : op->AttrNames()) {
if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) {
int sub_block_id =
o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name);
op->SetBlockAttr(attr_name, MutableBlock(sub_block_id));
} else if (op->GetAttrType(attr_name) == proto::AttrType::BLOCKS) {
std::vector<int> sub_block_ids =
o.Block(block_id).Op(op_id)->GetBlocksAttrIds(attr_name);
std::vector<BlockDesc *> block_descs;
for (int block_id : sub_block_ids) {
block_descs.push_back(MutableBlock(block_id));
}
op->SetBlocksAttr(attr_name, block_descs);
}
}
}
......@@ -68,24 +77,16 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
desc_ = desc;
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDesc(this, &block_desc));
}
for (auto &block : blocks_) {
for (auto *op : block->AllOps()) {
for (const auto &attr : op->Proto()->attrs()) {
if (attr.type() == proto::AttrType::BLOCK) {
size_t blk_idx = attr.block_idx();
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
}
}
}
}
InitFromProto();
}
ProgramDesc::ProgramDesc(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string.");
InitFromProto();
}
void ProgramDesc::InitFromProto() {
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDesc(this, &block_desc));
}
......@@ -95,6 +96,13 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
if (attr.type() == proto::AttrType::BLOCK) {
size_t blk_idx = attr.block_idx();
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
} else if (attr.type() == proto::AttrType::BLOCKS) {
auto blks_idx = attr.blocks_idx();
std::vector<BlockDesc *> block_descs;
for (int blk_idx : blks_idx) {
block_descs.push_back(this->MutableBlock(blk_idx));
}
op->SetBlocksAttr(attr.name(), block_descs);
}
}
}
......
......@@ -76,6 +76,8 @@ class ProgramDesc {
void SetFetchHolderName(const std::string &fetch_holder_name);
private:
void InitFromProto();
proto::ProgramDesc desc_;
std::vector<std::unique_ptr<BlockDesc>> blocks_;
......
......@@ -42,6 +42,19 @@ TEST(ProgramDesc, copy_ctor) {
out->SetType(proto::VarType::LOD_TENSOR);
op->SetOutput("Y", {out->Name()});
BlockDesc* new_block = program.AppendBlock(*global_block);
op = new_block->AppendOp();
op->SetType("mul");
op = global_block->AppendOp();
op->SetType("op_with_subblock");
op->SetAttr("sub_block", new_block);
std::vector<BlockDesc*> sub_blocks;
sub_blocks.push_back(program.AppendBlock(*global_block));
sub_blocks.push_back(program.AppendBlock(*global_block));
op->SetAttr("sub_blocks", sub_blocks);
ProgramDesc program_copy(program);
auto* global_block_copy = program_copy.MutableBlock(0);
......@@ -64,6 +77,8 @@ TEST(ProgramDesc, copy_ctor) {
assert_same_var("Y", y);
assert_same_var("Out", out);
bool found_sub_block = false;
bool found_sub_blocks = false;
for (size_t i = 0; i < global_block->OpSize(); ++i) {
auto op_origin = global_block->Op(i);
auto op_copy = global_block_copy->Op(i);
......@@ -74,8 +89,17 @@ TEST(ProgramDesc, copy_ctor) {
ASSERT_EQ(op_copy->Proto()->SerializeAsString(),
op_origin->Proto()->SerializeAsString());
}
if (op->Type() == "op_with_subblock") {
ASSERT_EQ(1, op->GetBlockAttrId("sub_block"));
found_sub_block = true;
ASSERT_EQ(2, op->GetBlocksAttrIds("sub_blocks").size());
found_sub_blocks = true;
}
}
ASSERT_TRUE(found_sub_block);
ASSERT_TRUE(found_sub_blocks);
// Not check block's protostr are same it because the order of vars could be
// different and it is correct.
}
......
......@@ -65,13 +65,13 @@ config.model_dir = "xxx";
config.use_gpu = false;
// 创建一个原生的 PaddlePredictor
auto predictor =
paddle::CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
paddle::CreatePaddlePredictor<paddle::NativeConfig, paddle::PaddleEngineKind::kNative>(config);
// 创建输入 tensor
int64_t data[4] = {1, 2, 3, 4};
paddle::PaddleTensor tensor{.name = "",
.shape = std::vector<int>({4, 1}),
.data = PaddleBuf(data, sizeof(data)),
.dtype = PaddleDType::INT64};
.data = paddle::PaddleBuf(data, sizeof(data)),
.dtype = paddle::PaddleDType::INT64};
// 创建输出 tensor,输出 tensor 的内存可以复用
std::vector<paddle::PaddleTensor> outputs;
// 执行预测
......
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto)
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
add_subdirectory(convert)
# Add TRT tests
nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
activation_op.cc softmax_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......@@ -24,3 +24,6 @@ nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL)
nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class BatchNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
LOG(INFO) << "convert a fluid batch norm op to tensorrt batch_norm";
framework::OpDesc op_desc(op, nullptr);
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1); // Bias is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1); // Mean is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Scale").size(), 1); // Scale is a weight
PADDLE_ENFORCE_EQ(op_desc.Input("Variance").size(),
1); // Variance is a weight
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1);
auto* X = engine_->GetITensor(op_desc.Input("X").front());
// Declare weights
auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front());
auto* Mean_v = scope.FindVar(op_desc.Input("Mean").front());
auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front());
auto* Variance_v = scope.FindVar(op_desc.Input("Variance").front());
const float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
PADDLE_ENFORCE_NOT_NULL(Bias_v);
PADDLE_ENFORCE_NOT_NULL(Mean_v);
PADDLE_ENFORCE_NOT_NULL(Scale_v);
PADDLE_ENFORCE_NOT_NULL(Variance_v);
// get tensor
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
auto* Mean_t = Mean_v->GetMutable<framework::LoDTensor>();
auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
auto* Variance_t = Variance_v->GetMutable<framework::LoDTensor>();
// create temp tensor for weights
framework::LoDTensor bias_tensor;
framework::LoDTensor mean_tensor;
framework::LoDTensor scale_tensor;
framework::LoDTensor variance_tensor;
bias_tensor.Resize(Bias_t->dims());
mean_tensor.Resize(Mean_t->dims());
scale_tensor.Resize(Scale_t->dims());
variance_tensor.Resize(Variance_t->dims());
platform::CPUPlace cpu_place;
// copy data from gpu to cpu
TensorCopySync((*Bias_t), cpu_place, &bias_tensor);
TensorCopySync((*Mean_t), cpu_place, &mean_tensor);
TensorCopySync((*Scale_t), cpu_place, &scale_tensor);
TensorCopySync((*Variance_t), cpu_place, &variance_tensor);
auto* bias_data = bias_tensor.mutable_data<float>(platform::CPUPlace());
auto* mean_data = mean_tensor.mutable_data<float>(platform::CPUPlace());
auto* scale_data = scale_tensor.mutable_data<float>(platform::CPUPlace());
auto* variance_data =
variance_tensor.mutable_data<float>(platform::CPUPlace());
std::unique_ptr<framework::LoDTensor> combile_scale_tensor(
new framework::LoDTensor());
std::unique_ptr<framework::LoDTensor> combile_bias_tensor(
new framework::LoDTensor());
combile_scale_tensor->Resize(scale_tensor.dims());
combile_bias_tensor->Resize(bias_tensor.dims());
auto* combile_scale_data =
combile_scale_tensor->mutable_data<float>(platform::CPUPlace());
auto* combile_bias_data =
combile_bias_tensor->mutable_data<float>(platform::CPUPlace());
size_t ele_num = combile_scale_tensor->memory_size() / sizeof(float);
for (size_t i = 0; i < ele_num; i++) {
float scale = scale_data[i];
float bias = bias_data[i];
float mean = mean_data[i];
float variance = variance_data[i];
combile_scale_data[i] = scale / sqrtf(variance + eps);
combile_bias_data[i] = bias - mean * combile_scale_data[i];
}
TensorRTEngine::Weight scale_weights{
nvinfer1::DataType::kFLOAT, static_cast<void*>(combile_scale_data),
combile_scale_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight shift_weights{
nvinfer1::DataType::kFLOAT, static_cast<void*>(combile_bias_data),
combile_bias_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};
nvinfer1::IScaleLayer* layer =
TRT_ENGINE_ADD_LAYER(engine_, Scale, *const_cast<nvinfer1::ITensor*>(X),
nvinfer1::ScaleMode::kCHANNEL, shift_weights.get(),
scale_weights.get(), power_weights.get());
auto output_name = op_desc.Output("Y").front();
engine_->weight_map[op_desc.Input("Bias").front()] =
std::move(combile_bias_tensor);
engine_->weight_map[op_desc.Input("Scale").front()] =
std::move(combile_scale_tensor);
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) {
engine_->DeclareOutput(output_name);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(batch_norm, BatchNormOpConverter);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(batch_norm_op, test) {
std::unordered_set<std::string> parameters(
{"batch_norm_scale", "batch_norm_bias", "batch_norm_mean",
"batch_norm_variance"});
framework::Scope scope;
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
std::vector<int> param_shape{2};
validator.DeclInputVar("batch_norm_X", nvinfer1::DimsCHW(2, 5, 5));
validator.DeclParamVar("batch_norm_scale", param_shape);
validator.DeclParamVar("batch_norm_bias", param_shape);
validator.DeclParamVar("batch_norm_mean", param_shape);
validator.DeclParamVar("batch_norm_variance", param_shape);
validator.DeclOutputVar("batch_norm_Y", nvinfer1::DimsCHW(2, 5, 5));
validator.DeclOutputVar("batch_norm_save_mean", param_shape);
validator.DeclOutputVar("batch_norm_save_variance", param_shape);
// Prepare Op description
framework::OpDesc desc;
desc.SetType("batch_norm");
desc.SetInput("X", {"batch_norm_X"});
desc.SetInput("Scale", {"batch_norm_scale"});
desc.SetInput("Bias", {"batch_norm_bias"});
desc.SetInput("Mean", {"batch_norm_mean"});
desc.SetInput("Variance", {"batch_norm_variance"});
desc.SetOutput("Y", {"batch_norm_Y"});
desc.SetOutput("MeanOut", {"batch_norm_mean"});
desc.SetOutput("VarianceOut", {"batch_norm_variance"});
desc.SetOutput("SavedMean", {"batch_norm_save_mean"});
desc.SetOutput("SavedVariance", {"batch_norm_save_variance"});
float eps = 1e-5f;
bool is_test = true;
desc.SetAttr("epsilon", eps);
desc.SetAttr("is_test", is_test);
validator.SetOp(*desc.Proto());
std::unordered_set<std::string> neglected_output = {
"batch_norm_save_mean", "batch_norm_save_variance", "batch_norm_mean",
"batch_norm_variance"};
validator.Execute(3, neglected_output);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(batch_norm);
......@@ -98,11 +98,19 @@ class TRTConvertValidation {
engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims);
}
void DeclParamVar(const std::string& name, const std::vector<int> dim_vec) {
DeclVar(name, dim_vec);
}
// Declare a parameter varaible in the scope.
void DeclParamVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims, true);
}
void DeclOutputVar(const std::string& name, const std::vector<int> dim_vec) {
DeclVar(name, dim_vec);
}
void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims);
}
......@@ -155,7 +163,11 @@ class TRTConvertValidation {
}
}
void Execute(int batch_size) {
// We use the set 'neglected_output' here, because some Ops like batch norm,
// the outputs specified in the op des are only used during training,
// so we should neglect those output during inference.
void Execute(int batch_size,
std::unordered_set<std::string> neglected_output = {}) {
// Execute Fluid Op
PADDLE_ENFORCE_LE(batch_size, max_batch_size_);
platform::CUDAPlace place;
......@@ -168,6 +180,7 @@ class TRTConvertValidation {
ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
const size_t output_space_size = 3000;
for (const auto& output : op_desc_->OutputArgumentNames()) {
if (neglected_output.count(output)) continue;
std::vector<float> fluid_out;
std::vector<float> trt_out(output_space_size);
engine_->GetOutputInCPU(output, &trt_out[0], output_space_size);
......
......@@ -84,6 +84,15 @@ function(op_library TARGET)
message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
endif()
#remove windows unsupported op
if (WIN32)
foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op")
if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
return()
endif()
endforeach()
endif(WIN32)
list(LENGTH op_library_DEPS op_library_DEPS_len)
if (${op_library_DEPS_len} GREATER 0)
set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
......@@ -181,19 +190,19 @@ function(op_library TARGET)
endfunction()
add_subdirectory(math)
if (NOT WIN32)
add_subdirectory(nccl)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n")
else()
set(DEPS_OPS ${DEPS_OPS} nccl_op)
endif()
endif() # NOT WIN32
set(DISTRIBUTE_DEPS "")
if(WITH_DISTRIBUTE)
add_subdirectory(distributed)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
......@@ -222,7 +231,7 @@ if(WITH_DISTRIBUTE)
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
# listen_and_serv_op sum_op executor SERIAL)
if(WITH_GPU)
if(WITH_GPU AND NOT WIN32)
set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op ${DISTRIBUTE_DEPS} executor SERIAL)
if(WITH_GRPC)
......@@ -233,7 +242,7 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif()
endif() # WITH_GPU AND NOT WIN32
else()
set(DEPS_OPS ${DEPS_OPS} checkpoint_notify_op prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif()
......@@ -331,5 +340,7 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
if(NOT WIN32)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif()
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
......@@ -29,9 +29,9 @@ class ConditionalOp : public framework::OperatorBase {
protected:
std::vector<const framework::LoDTensor *> InputTensors(
const framework::Scope &scope) const {
const framework::Scope &scope, const std::string &in_name) const {
std::vector<const framework::LoDTensor *> retv;
auto xs = Inputs("X");
auto xs = Inputs(in_name);
retv.resize(xs.size(), nullptr);
std::transform(
xs.begin(), xs.end(), retv.begin(),
......@@ -81,12 +81,18 @@ class ConditionalBlockOp : public ConditionalOp {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto xs = InputTensors(scope);
bool need_run;
if (Attr<bool>("is_scalar_condition")) {
// When is_scalar_condition is True, the conditional variable is a scalar,
// whether need to execute the operators in sub-block depends on the
// conditional variable (Cond).
auto xs = InputTensors(scope, "Cond");
need_run = ScalarCondition(xs);
} else {
// When is_scalar_condition is False, the conditional variable maybe a
// vector or tensor, whether need to execute the operators in sub-block
// depends on the input variables (Input).
auto xs = InputTensors(scope, "Input");
need_run = std::all_of(
xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; });
......@@ -110,11 +116,11 @@ class ConditionalBlockOp : public ConditionalOp {
class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The conditional variable of this operator. If X is empty, the "
AddInput("Cond",
"The conditional variable of this operator. If Cond is empty, the "
"whole sub-block will not be executed.")
.AsDuplicable();
AddInput("Params", "The input variables of the sub-block.").AsDuplicable();
AddInput("Input", "The input variables of the sub-block.").AsDuplicable();
AddOutput("Out", "The output variables of the sub-block.").AsDuplicable();
AddOutput("Scope",
"(std::vector<Scope*>) The step scope of conditional block. To "
......@@ -123,13 +129,18 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<framework::BlockDesc *>(
"sub_block", "The step block of conditional block operator");
AddAttr<bool>("is_scalar_condition",
"the input X is used as scalar "
"condition")
"The conditional variable (Cond) is used as scalar "
"condition.")
.SetDefault(false);
AddComment(R"DOC(Conditional block operator
Run the sub-block if X is not empty. Params is the other inputs and Out is the
outputs of the sub-block.
If `is_scalar_condition` is True, the conditional variable (Cond) is a scalar,
run the operators in sub-block if Cond is True.
If `is_scalar_condition` is False, the conditional variable (Cond) is a vector or
tensor, run the operators in sub-block if all of input variables are not empty.
)DOC");
}
};
......@@ -145,12 +156,12 @@ class ConditionalBlockGradOp : public ConditionalOp {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto xs = this->InputTensors(scope);
bool need_run;
if (Attr<bool>("is_scalar_condition")) {
auto xs = this->InputTensors(scope, "Cond");
need_run = ScalarCondition(xs);
} else {
auto xs = this->InputTensors(scope, "Input");
need_run = std::all_of(
xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; });
......@@ -166,11 +177,11 @@ class ConditionalBlockGradOp : public ConditionalOp {
auto *block = Attr<framework::BlockDesc *>("sub_block");
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Params"),
Outputs(framework::GradVarName("Params")));
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Input"),
Outputs(framework::GradVarName("Input")));
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("X"),
Outputs(framework::GradVarName("X")));
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Cond"),
Outputs(framework::GradVarName("Cond")));
}
}
......@@ -199,15 +210,15 @@ class ConditionalBlockGradOp : public ConditionalOp {
class ConditionalBlockGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInputs("X"));
if (context->HasInputs("Params")) {
PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Params")));
context->SetOutputsDim(framework::GradVarName("Params"),
context->GetInputsDim("Params"));
PADDLE_ENFORCE(context->HasInputs("Cond"));
if (context->HasInputs("Input")) {
PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Input")));
context->SetOutputsDim(framework::GradVarName("Input"),
context->GetInputsDim("Input"));
}
if (context->HasOutputs(framework::GradVarName("X"))) {
context->SetOutputsDim(framework::GradVarName("X"),
context->GetInputsDim("X"));
if (context->HasOutputs(framework::GradVarName("Cond"))) {
context->SetOutputsDim(framework::GradVarName("Cond"),
context->GetInputsDim("Cond"));
}
}
};
......@@ -220,14 +231,15 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> Apply() const override {
auto grad_op = new framework::OpDesc();
grad_op->SetType("conditional_block_grad");
grad_op->SetInput("X", Input("X"));
grad_op->SetInput("Params", Input("Params"));
grad_op->SetInput("Cond", Input("Cond"));
grad_op->SetInput("Input", Input("Input"));
grad_op->SetInput("Out", Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetInput("Scope", Output("Scope"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
grad_op->SetOutput(framework::GradVarName("Params"),
InputGrad("Params", false));
grad_op->SetOutput(framework::GradVarName("Cond"),
InputGrad("Cond", false));
grad_op->SetOutput(framework::GradVarName("Input"),
InputGrad("Input", false));
grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition"));
return std::unique_ptr<framework::OpDesc>(grad_op);
......
......@@ -85,6 +85,199 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace());
#ifdef __AVX__
// It use the AVX or AVX512 instruction to deal the data as the vector of 8 or
// 16 elements per iteration. Then it can implement the parallel processing.
// Only optimize for float type.
#ifdef __AVX512F__
size_t step_size = 16;
#else
size_t step_size = 8;
#endif
if (std::is_same<T, float>::value && (tag_num >= step_size)) {
size_t steps = tag_num / step_size;
size_t remain = tag_num % step_size;
int last_offset = static_cast<int>(remain) - static_cast<int>(step_size);
// Setup the alpha initial value.
size_t i_offset = 0;
for (size_t i = 0; i <= steps; ++i) {
#ifdef __AVX512F__
// Declare the variable for the content of weights, input and alpha
// values.
__m512 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm512_loadu_ps((const float*)(w + i_offset));
x_content = _mm512_loadu_ps((const float*)(x + i_offset));
alpha_content = _mm512_add_ps(w_content, x_content);
// Save the alpha value.
_mm512_storeu_ps(reinterpret_cast<float*>(alpha_value + i_offset),
alpha_content);
#else
// Declare the variable for the content of weights, input and alpha
// values.
__m256 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm256_loadu_ps((const float*)(w + i_offset));
x_content = _mm256_loadu_ps((const float*)(x + i_offset));
alpha_content = _mm256_add_ps(w_content, x_content);
// Save the alpha value.
_mm256_storeu_ps(reinterpret_cast<float*>(alpha_value + i_offset),
alpha_content);
#endif
i_offset += step_size;
if (i == steps - 1) {
if (remain > 0) {
i_offset += last_offset;
} else {
break;
}
}
}
// Use the column-major strategy to get the location of maximum score.
size_t seq_offset = 0;
for (size_t k = 1; k < seq_len; ++k) {
size_t j_offset = 0;
for (size_t j = 0; j <= steps; ++j) {
#ifdef __AVX512F__
// Initialize the variables of maximum score and location.
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<T>::max());
__m512i max_j = _mm512_setzero_si512();
#else
// Initialize the variables of maximum score and location.
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<T>::max());
__m256i max_j = _mm256_set1_epi32(0);
#endif
// Calculate the offset of transition_weights.
size_t trans_offset = state_trans_base_idx * tag_num + j_offset;
for (size_t i = 0; i < tag_num; ++i) {
#ifdef __AVX512F__
// Initalize the content of alpha variable with related offset.
__m512 alpha_content =
_mm512_set1_ps(*(const float*)(alpha_value + seq_offset + i));
// Obtain the content of weights from un-aligned address.
__m512 w_content =
_mm512_loadu_ps((const float*)(w + trans_offset));
__m512 score_v = _mm512_add_ps(alpha_content, w_content);
__mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS);
// According to the mask value, it update the index of the max_score
// location.
max_j = _mm512_mask_set1_epi32(max_j, mask, i);
// Update the max_score value.
max_score = _mm512_max_ps(max_score, score_v);
#else
// Initalize the content of alpha variable with related offset.
__m256 alpha_content = _mm256_broadcast_ss(
(const float*)(alpha_value + seq_offset + i));
// Obtain the content of weights from un-aligned address.
__m256 w_content =
_mm256_loadu_ps((const float*)(w + trans_offset));
__m256 score_v = _mm256_add_ps(alpha_content, w_content);
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS);
#ifdef __AVX2__
// According to the mask value, it update the index of the max_score
// location.
max_j = _mm256_or_si256(
_mm256_andnot_si256((__m256i)mask, max_j),
_mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i)));
#else
__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0);
__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1);
__m128i lo_mask = _mm256_extractf128_si256((__m256i)mask, 0);
__m128i hi_mask = _mm256_extractf128_si256((__m256i)mask, 1);
lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j);
hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j);
lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i));
hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i));
lo_max_j = _mm_or_si128(lo_mask, lo_max_j);
hi_max_j = _mm_or_si128(hi_mask, hi_max_j);
// According to the mask value, it update the index of the max_score
// location.
max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0);
max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1);
#endif
// Update the max_score value.
max_score = _mm256_max_ps(max_score, score_v);
#endif
trans_offset += tag_num;
}
#ifdef __AVX512F__
// Update the alpha and track values.
__m512 x_content = _mm512_loadu_ps(
(const float*)(x + seq_offset + tag_num + j_offset));
max_score = _mm512_add_ps(max_score, x_content);
_mm512_storeu_ps(reinterpret_cast<float*>(alpha_value + seq_offset +
tag_num + j_offset),
max_score);
_mm512_storeu_si512(
reinterpret_cast<__m512i*>(track_value + seq_offset + tag_num +
j_offset),
max_j);
#else
// Update the alpha and track values.
__m256 x_content = _mm256_loadu_ps(
(const float*)(x + seq_offset + tag_num + j_offset));
max_score = _mm256_add_ps(max_score, x_content);
_mm256_storeu_ps(reinterpret_cast<float*>(alpha_value + seq_offset +
tag_num + j_offset),
max_score);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(track_value + seq_offset + tag_num +
j_offset),
max_j);
#endif
// Calculate the offset of next step
j_offset += step_size;
if (j == steps - 1) {
if (remain > 0) {
j_offset += last_offset;
} else {
break;
}
}
}
seq_offset += tag_num;
}
} else {
for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
for (size_t k = 1; k < seq_len; ++k) {
for (size_t i = 0; i < tag_num; ++i) {
T max_score = -std::numeric_limits<T>::max();
int max_j = 0;
for (size_t j = 0; j < tag_num; ++j) {
T score = alpha_value[(k - 1) * tag_num + j] +
w[(j + state_trans_base_idx) * tag_num + i];
if (score > max_score) {
max_score = score;
max_j = j;
}
}
alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
track_value[k * tag_num + i] = max_j;
}
}
}
#else
for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
for (size_t k = 1; k < seq_len; ++k) {
......@@ -105,6 +298,7 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
}
}
#endif
T max_score = -std::numeric_limits<T>::max();
int max_i = 0;
for (size_t i = 0; i < tag_num; ++i) {
......
......@@ -130,12 +130,13 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
// TODO(tangwei12): find out why scope will be error.
auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
<< out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope_);
return true;
}
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
......@@ -23,6 +24,37 @@ struct MulFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};
template <typename DeviceContext, typename T>
void default_elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
MulFunctor<T>(), z);
}
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
blas.VMUL(x->numel(), x->data<T>(), y->data<T>(),
z->mutable_data<T>(ctx.GetPlace()));
}
template <typename DeviceContext, typename T>
typename std::enable_if<
!std::is_floating_point<T>::value ||
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
}
template <typename DeviceContext, typename T>
class ElementwiseMulKernel : public framework::OpKernel<T> {
public:
......@@ -33,9 +65,11 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
MulFunctor<T>(), z);
if (x->numel() == y->numel()) {
elementwise_mul<DeviceContext, T>(ctx, x, y, z);
} else {
default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
}
}
};
......
......@@ -80,6 +80,9 @@ inline framework::DDim trim_trailing_singular_dims(
for (int i = 0; i < actual_dims_size; ++i) {
trim_dims[i] = dims[i];
}
if (trim_dims.size() == 0) {
return framework::DDim(framework::make_dim());
}
framework::DDim actual_dims = framework::make_ddim(trim_dims);
return actual_dims;
}
......
......@@ -15,8 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fc_op.h"
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
DECLARE_int32(paddle_num_threads);
#include "paddle/fluid/operators/math/fc_compute.h"
namespace paddle {
namespace operators {
......@@ -110,13 +109,8 @@ void FCOpMaker::Make() {
AddComment(R"DOC(
Fully Connected Operator.
The fully connected operation calculates the output based on the input, weights and bias attribute.
The fully connected operation calculates the output based on the input, weights and bias.
The size of each dimension of the parameters checked in the infer-shape.
The matrix of bias is generated by the mkldnn framework, when the bias_attr is True.
Additional parametrs are use_mkldnn and bias_attr.
The input(X) size and output(Out) size may be diffrent.
The fully connected layer only supports MKLDNN version
)DOC");
}
......@@ -133,26 +127,15 @@ class FCOpKernel : public framework::OpKernel<T> {
auto in_dims = input->dims();
auto w_dims = w->dims();
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
math::FCCompute<platform::CPUDeviceContext, T>(
blas, in_dims[0], w_dims[1], w_dims[0], input_data, w_data, output_data,
bias ? bias->data<T>() : NULL);
blas.GEMM(CblasNoTrans, CblasNoTrans, in_dims[0], w_dims[1], w_dims[0],
static_cast<T>(1), input_data, w_data, static_cast<T>(0),
output_data);
if (bias) {
const T* bias_data = bias->data<T>();
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int bs = 0; bs < in_dims[0]; bs++) {
blas.AXPY(w_dims[1], static_cast<T>(1), bias_data,
output_data + bs * w_dims[1]);
}
}
// TODO(TJ): fuse act
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fusion_lstm_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle {
namespace operators {
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"),
"Output(XX) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
"Output(BatchedGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchedGate) of LSTM should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(Cell) and Input(Hidden) of LSTM should not "
"be null at the same time.");
auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
}
auto wx_dims = ctx->GetInputDim("WeightX");
PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
"The rank of Input(WeightX) should be 2.");
PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
"The first dimension of Input(WeightX) "
"should be %d.",
x_dims[1]);
int frame_size = wx_dims[1] / 4;
auto wh_dims = ctx->GetInputDim("WeightH");
PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
"The rank of Input(WeightH) should be 2.");
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
"The first dimension of Input(WeightH) "
"should be %d.",
frame_size);
PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
"The second dimension of Input(WeightH) "
"should be 4 * %d.",
frame_size);
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_peepholes"),
"Do not support peephole yet.");
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection",
frame_size);
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchCellPreAct", out_dims);
ctx->ShareLoD("X", "Hidden");
ctx->ShareLoD("X", "Cell");
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX");
}
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
void FusionLSTMOpMaker::Make() {
AddInput("X",
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x.");
AddInput("WeightX",
"(Tensor) the learnable weights of X."
" - The shape is (M x 4D), where M is the dim size of x, D is the "
"hidden size. "
" - Weight = {W_cx, W_ix, W_fx, W_ox}");
AddInput("WeightH",
"(Tensor) same as LSTMOp, the learnable hidden-hidden weights."
" - The shape is (D x 4D), where D is the hidden size. "
" - Weight = {W_ch, W_ih, W_fh, W_oh}");
AddInput("Bias",
"(Tensor) the learnable weights. Almost same as LSTMOp"
"Note: we should add the fc bias into this (1x4D) in bias."
"input-hidden bias weight and peephole connections weight if "
"setting `use_peepholes` True. "
"1. `use_peepholes = False` "
" - The shape is (1 x 4D). "
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `use_peepholes = True` "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
AddInput("H0",
"(Tensor, optional) (same as LSTMOp) the initial hidden state is an "
"optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size and D is the hidden size.")
.AsDispensable();
AddInput("C0",
"(Tensor, optional) (same as LSTMOp) (the initial cell state is an "
"optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size. `H0` and `C0` can be NULL but only at the same time.")
.AsDispensable();
AddOutput("Hidden",
"(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("XX",
"(LoDTensor) the result after X * WeightX (size is T x 4D)"
" or batched_X (size is T x M), this will be automatically chosen,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size, M is the dim size of x input.")
.AsIntermediate();
AddOutput("BatchedGate", "(LoDTensor) (same as LSTMOp).").AsIntermediate();
AddOutput("BatchCellPreAct", "(LoDTensor) (same as LSTMOp).")
.AsIntermediate();
AddAttr<bool>("use_peepholes",
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
.SetDefault(true);
AddAttr<bool>("is_reverse",
"(bool, defalut: False) "
"whether to compute reversed LSTM.")
.SetDefault(false);
AddAttr<std::string>("gate_activation",
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by default.")
.SetDefault("sigmoid")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("cell_activation",
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut.")
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("candidate_activation",
"(string, default: tanh)"
"The activation for candidate hidden state, "
"`tanh` by default.")
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddComment(R"DOC(
Fusion Long-Short Term Memory (LSTM) Operator.
This operator fuse the X into LSTM, more details can refer to LSTM op.
)DOC");
}
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
const framework::Tensor& src,
framework::Vector<size_t> index_lod,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
// TODO(TJ): check mem copy perf
row_shuffle(ctx, src, index_lod, dst, indexed_src);
}
template <typename DeviceContext, typename T>
class FuisonLSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
auto* bias = ctx.Input<Tensor>("Bias");
auto* hidden_t0 = ctx.Input<Tensor>("H0");
auto* cell_t0 = ctx.Input<Tensor>("C0");
auto* xx = ctx.Output<LoDTensor>("XX");
auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
bool is_reverse = ctx.Attr<bool>("is_reverse");
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
cell_out->mutable_data<T>(ctx.GetPlace());
const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>();
auto x_dims = x->dims();
auto wx_dims = wx->dims();
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (x_dims[1] > wx_dims[1]) {
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
x_data, wx_data, xx_data,
bias->data<T>());
to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
} else {
to_batch(dev_ctx, *x, xx, true, is_reverse);
batched_gate->set_lod(xx->lod());
math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
xx_data, wx_data, batched_gate_data,
bias->data<T>());
}
int frame_size = static_cast<int>(wx_dims[1] / 4);
framework::DDim out_dims({x_dims[0], frame_size});
math::LstmMetaValue<T> lstm_value;
// no peephole
lstm_value.check_ig = nullptr;
lstm_value.check_fg = nullptr;
lstm_value.check_og = nullptr;
lstm_value.prev_state_value = nullptr;
Tensor ordered_c0;
framework::Vector<size_t> order(batched_gate->lod()[2]);
if (cell_t0) {
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState<DeviceContext, T>(dev_ctx, *cell_t0, order, &ordered_c0,
true);
lstm_value.prev_state_value = ordered_c0.data<T>();
}
// Use the local variable as here.
LoDTensor batch_hidden, batch_cell;
auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell_pre_act->mutable_data<T>(out_dims, ctx.GetPlace());
auto batch_starts = batched_gate->lod()[0];
size_t max_seq_len = batch_starts.size() - 1;
auto gate_act = math::detail::GetActivationType(
ctx.Attr<std::string>("gate_activation"));
auto cell_act = math::detail::GetActivationType(
ctx.Attr<std::string>("cell_activation"));
auto cand_act = math::detail::GetActivationType(
ctx.Attr<std::string>("candidate_activation"));
for (size_t n = 0; n < max_seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate_t = batched_gate->Slice(bstart, bend);
Tensor out_t = batch_hidden.Slice(bstart, bend);
Tensor cell_t = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);
int cur_batch_size = bend - bstart;
if (n > 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
// TODO(TJ): use gemm directly
blas.MatMul(pre_hidden_t, false, *wh, false, static_cast<T>(1.0),
&gate_t, static_cast<T>(1.0));
} else if (hidden_t0) {
// TODO(TJ): move h0 outside for
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
// If n == 0 and there is initialized hidden state, calculate W_h * H0.
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized hidden state also needs
// to reorder.
Tensor ordered_h0;
ReorderInitState<DeviceContext, T>(dev_ctx, *hidden_t0, order,
&ordered_h0, true);
// TODO(TJ): use gemm directly
blas.MatMul(ordered_h0, false, *wh, false, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
}
lstm_value.gate_value = gate_t.data<T>();
lstm_value.output_value = out_t.data<T>();
lstm_value.state_value = cell_t.data<T>();
lstm_value.state_active_value = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<DeviceContext, T>::compute(
dev_ctx, lstm_value, frame_size, cur_batch_size, gate_act, cell_act,
cand_act);
lstm_value.prev_state_value = lstm_value.state_value;
}
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
batch_hidden.set_lod(batched_gate->lod());
// restore the output hidden in LoDTensor from the batch hidden
to_seq(dev_ctx, batch_hidden, hidden_out);
batch_cell.set_lod(batched_gate->lod());
// restore the output cell state in LoDTensor from the batch cell
to_seq(dev_ctx, batch_cell, cell_out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(
fusion_lstm,
ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, float>,
ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// #include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
class FusionLSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class FusionLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
......@@ -92,6 +92,7 @@ class LoadOp : public framework::OperatorBase {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
selectedRows->SyncIndex();
}
};
......
......@@ -134,6 +134,9 @@ class Blas {
template <typename T>
void VADD(int n, const T* x, const T* y, T* z) const;
template <typename T>
void VMUL(int n, const T* x, const T* y, T* z) const;
template <typename T>
void VCOPY(int n, const T* x, T* y) const;
......@@ -202,6 +205,11 @@ class BlasT : private Blas<DeviceContext> {
Base()->template VADD<T>(args...);
}
template <typename... ARGS>
void VMUL(ARGS... args) const {
Base()->template VMUL<T>(args...);
}
template <typename... ARGS>
void VCOPY(ARGS... args) const {
Base()->template VCOPY<T>(args...);
......
......@@ -82,6 +82,11 @@ struct CBlas<float> {
static void VADD(ARGS... args) {
platform::dynload::vsAdd(args...);
}
template <typename... ARGS>
static void VMUL(ARGS... args) {
platform::dynload::vsMul(args...);
}
};
template <>
......@@ -142,6 +147,11 @@ struct CBlas<double> {
static void VADD(ARGS... args) {
platform::dynload::vdAdd(args...);
}
template <typename... ARGS>
static void VMUL(ARGS... args) {
platform::dynload::vdMul(args...);
}
};
#else
......@@ -199,6 +209,7 @@ struct CBlas<platform::float16> {
static void SMM_GEMM(...) {
PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
}
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
#ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
......@@ -374,6 +385,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
T *z) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VMUL(n, x, y, z);
#else
// try to find if openblas support vmul
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/blas.h"
DECLARE_int32(paddle_num_threads);
namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext, typename T>
inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
const int N, const int K, const T* X, const T* W, T* Y,
const T* B = NULL) {
blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<T>(1), X, W,
static_cast<T>(0), Y);
if (B) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
#endif
for (int i = 0; i < M; i++) {
blas.AXPY(N, static_cast<T>(1), B, Y + i * N);
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle
if(WITH_GPU)
if(WITH_GPU AND NOT WIN32)
nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator )
endif()
......@@ -142,6 +142,8 @@ class SaveOp : public framework::OperatorBase {
std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename;
MkDirRecursively(DirName(filename).c_str());
auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool
......
......@@ -81,8 +81,8 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "The source input of scatter op");
AddInput("Ids", "The index input of scatter op where X will be updated");
AddInput("Updates", "The updated value of updates op");
AddOutput("Out", "The output of add op");
AddInput("Updates", "The updated value of scatter op");
AddOutput("Out", "The output of scatter op");
AddComment(R"DOC(
Scatter Operator.
......@@ -90,7 +90,7 @@ This operator obtains output by updating the input on selected indices on the fi
$$
Out = X \\
Out[Ids] = X[Ids] + Updates
Out[Ids] = Updates
$$
)DOC");
......
......@@ -34,9 +34,9 @@ class ScatterOpKernel : public framework::OpKernel<T> {
auto *Updates = ctx.Input<Tensor>("Updates");
auto *Out = ctx.Output<Tensor>("Out");
// In place output: Out = X, Out[Ids] += Updates
// In place output: Out = X, Out[Ids] = Updates
framework::TensorCopySync(*X, ctx.GetPlace(), Out);
// Apply ScatterUpdate: Out[index] += Updates[:]
// Apply ScatterUpdate: Out[index] = Updates[:]
ScatterAssign<T>(ctx.device_context(), *Updates, *Ids, Out);
}
};
......@@ -55,7 +55,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
// In place gradient: dX = dO
framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates += dO[Ids]
// Gradient by Gather: dUpdates = dO[Ids]
CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
}
};
......
......@@ -23,9 +23,9 @@ class SqueezeOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SqueezeOp should not be null.");
"Input(X) of Squeeze operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SqueezeOp should not be null.");
"Output(Out) of Squeeze operator should not be null.");
const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit.
......@@ -107,7 +107,6 @@ class SqueezeOp : public framework::OperatorBase {
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
attrs["inplace"] = Attr<bool>("inplace");
// Invoke Reshape Op
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
......@@ -125,12 +124,6 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
"(std::vector<int>). List of integers,"
" indicating the dimensions to squeeze.")
.SetDefault({});
AddAttr<bool>("inplace",
"(default: false) Squeeze the source tensor's shape without "
"memory copy. When Attr(inplace) is set true, the output "
"tensor shares memory with Input(X), otherwise, a new output "
"tensor is created, and its data are copied from Input(x).")
.SetDefault(false);
AddComment(R"DOC(
Squeeze Operator.
......@@ -180,7 +173,6 @@ class SqueezeGradOp : public framework::OperatorBase {
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = Attr<bool>("inplace");
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
......
......@@ -23,9 +23,9 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UnsqueezeOp should not be null.");
"Input(X) of Unsqueeze operator should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UnsqueezeOp should not be null.");
"Output(Out) of Unsqueeze operator should not be null.");
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
const auto &x_dims = ctx->GetInputDim("X");
......@@ -95,7 +95,6 @@ class UnsqueezeOp : public framework::OperatorBase {
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(out_dims);
attrs["inplace"] = Attr<bool>("inplace");
// Invoke Reshape op.
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
......@@ -126,13 +125,6 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
" within [1, 6] dimensions (Eigen limit).");
}
});
AddAttr<bool>(
"inplace",
"(default: false) Unsqueeze the source tensor's shape without "
"memory copy. When Attr(inplace) is set true, the output "
"tensor shares memory with Input(X), otherwise, a new output "
"tensor is created, and its data are copied from Input(x).")
.SetDefault(false);
AddComment(R"DOC(
Unsqueeze Operator.
......@@ -168,7 +160,6 @@ class UnsqueezeGradOp : public framework::OperatorBase {
framework::AttributeMap attrs;
attrs["shape"] = framework::vectorize2int(x_dims);
attrs["inplace"] = Attr<bool>("inplace");
auto reshape_op = framework::OpRegistry::CreateOp(
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
......
......@@ -3,7 +3,7 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc)
# There is no macOS version of NCCL.
if (NOT APPLE)
if (NOT APPLE AND NOT WIN32)
list(APPEND CUDA_SRCS nccl.cc)
endif()
......
......@@ -49,25 +49,27 @@ extern void* mklml_dso_handle;
#define MKLML_ROUTINE_EACH(__macro) \
__macro(cblas_sgemm); \
__macro(cblas_saxpy); \
__macro(cblas_scopy); \
__macro(cblas_sgemv); \
__macro(cblas_sgemm_batch); \
__macro(cblas_dgemm); \
__macro(cblas_saxpy); \
__macro(cblas_daxpy); \
__macro(cblas_scopy); \
__macro(cblas_dcopy); \
__macro(cblas_sgemv); \
__macro(cblas_dgemv); \
__macro(cblas_dgemm_batch); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(cblas_sgemm_alloc); \
__macro(cblas_sgemm_pack); \
__macro(cblas_sgemm_compute); \
__macro(cblas_sgemm_free); \
__macro(cblas_dgemm_alloc); \
__macro(cblas_sgemm_pack); \
__macro(cblas_dgemm_pack); \
__macro(cblas_sgemm_compute); \
__macro(cblas_dgemm_compute); \
__macro(cblas_sgemm_free); \
__macro(cblas_dgemm_free); \
__macro(cblas_sgemm_batch); \
__macro(cblas_dgemm_batch); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(vsMul); \
__macro(vdMul); \
__macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......
......@@ -44,7 +44,7 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/curand.h"
#ifndef __APPLE__
#if !defined(__APPLE__) and !defined(_WIN32)
#include "paddle/fluid/platform/dynload/nccl.h"
#endif // __APPLE__
#endif // PADDLE_WITH_CUDA
......@@ -205,7 +205,7 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
#endif
}
#ifndef __APPLE__
#if !defined(__APPLE__) and !defined(_WIN32)
template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
ncclResult_t stat, const Args&... args) {
......@@ -221,7 +221,7 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
#endif
}
}
#endif // __APPLE__
#endif // __APPLE__ and windows
#endif // PADDLE_WITH_CUDA
template <typename T>
......
......@@ -1363,6 +1363,13 @@ class Program(object):
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = []
# for distribute
self._is_distributed = False
self._is_chief = False
self._slice_vars_and_attrs = []
self._endpoints = []
self._distributed_lookup_table = None
@property
def op_role(self):
"""
......
......@@ -372,6 +372,7 @@ def load_vars(executor,
load_vars(
executor,
dirname=dirname,
main_program=main_program,
vars=list(filter(predicate, main_program.list_vars())),
filename=filename)
else:
......@@ -403,9 +404,12 @@ def load_vars(executor,
inputs={},
outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, filename)})
executor.run(load_prog)
# load slice vars on pserver, if have it.
_load_slice_up_vars(executor, dirname,
main_program._slice_vars_and_attrs)
def load_params(executor, dirname, main_program=None, filename=None):
"""
......@@ -659,11 +663,19 @@ def save_inference_model(dirname,
save_persistables(executor, dirname, inference_program, params_filename)
# if there is lookup table, the trainer 0 will notify all pserver to save.
if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
lookup_table_filename = os.path.join(dirname, "__lookup_table__")
_save_lookup_tables_by_notify(executor, lookup_table_filename,
main_program._distributed_lookup_table,
main_program._endpoints)
def load_inference_model(dirname,
executor,
model_filename=None,
params_filename=None):
params_filename=None,
pserver_endpoints=None):
"""
Load inference model from a directory
......@@ -679,6 +691,10 @@ def load_inference_model(dirname,
parameters were saved in a single binary
file. If parameters were saved in separate
files, set it as 'None'.
pserver_endpoints(list|None): This only need by distributed inference.
When use distributed look up table in training,
We also need it in inference.The parameter is
a list of pserver endpoints.
Returns:
tuple: The return of this function is a tuple with three elements:
......@@ -697,12 +713,16 @@ def load_inference_model(dirname,
exe = fluid.Executor(fluid.CPUPlace())
path = "./infer_model"
endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
[inference_program, feed_target_names, fetch_targets] =
fluid.io.load_inference_model(dirname=path, executor=exe)
results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
# if we need lookup table, we will use:
fluid.io.load_inference_model(dirname=path, executor=exe, pserver_endpoints=endpoints)
# In this exsample, the inference program was saved in the
# "./infer_model/__model__" and parameters were saved in
# separate files in ""./infer_model".
......@@ -729,6 +749,9 @@ def load_inference_model(dirname,
program = Program.parse_from_string(program_desc_str)
load_persistables(executor, dirname, program, params_filename)
if pserver_endpoints:
program = _endpoints_replacement(program, pserver_endpoints)
feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names()
fetch_targets = [
......@@ -738,6 +761,61 @@ def load_inference_model(dirname,
return [program, feed_target_names, fetch_targets]
def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
pserver_endpoints):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
pserver_endpoints=ps_endpoints)
"""
pserver_notify_program = Program()
pserver_notify_block = pserver_notify_program.global_block()
attrs = {}
attrs['epmap'] = pserver_endpoints
attrs['dir'] = dirname
attrs['lookup_table'] = lookup_table
pserver_notify_block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(pserver_notify_program)
def _endpoints_replacement(program, endpoints):
ENDPOINT_MAP = "epmap"
for op in program.global_block().ops:
if op.has_attr(ENDPOINT_MAP):
op.set_attr(ENDPOINT_MAP, endpoints)
program._sync_with_cpp()
return program
def get_parameter_value(para, executor):
"""
Get the LoDTensor value of the given parameter.
......@@ -799,3 +877,46 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program()
var = program.global_block().var(name)
return get_parameter_value(var, executor)
def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
if not slice_vars_and_attrs:
return
load_prog = Program()
load_block = load_prog.global_block()
for var_tuple in slice_vars_and_attrs:
orig_var = var_tuple[0]
start = var_tuple[1]
slice_var = var_tuple[2]
end = start + reduce(lambda x, y: x * y, slice_var.shape)
clone_orig_var = load_block.create_var(
name=orig_var.name,
type=orig_var.type,
shape=orig_var.shape,
dtype=orig_var.dtype,
persistable=True)
clone_slice_var = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [clone_orig_var]},
attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
load_block.append_op(
type="slice",
inputs={'Input': clone_orig_var},
outputs={'Out': clone_slice_var},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
executor.run(load_prog)
......@@ -1272,8 +1272,8 @@ class ConditionalBlock(object):
parent_block.append_op(
type='conditional_block',
inputs={
'X': self.inputs,
'Params': param_list,
'Cond': self.inputs,
'Input': param_list,
},
outputs={'Out': out_list,
'Scope': [step_scope]},
......
......@@ -94,6 +94,7 @@ __all__ = [
'image_resize_short',
'resize_bilinear',
'gather',
'scatter',
'random_crop',
'mean_iou',
'relu',
......@@ -5036,6 +5037,47 @@ def gather(input, index):
return out
def scatter(input, index, updates, name=None):
"""
**Scatter Layer**
Output is obtained by updating the input on selected indices on the first
axis.
.. math::
Out = X
Out[Ids] = Updates
Args:
input (Variable): The source input with rank>=1.
index (Variable): The index input with rank=1. Its dtype should be
int32 or int64 as it is used as indexes.
updates (Variable): The updated value of scatter op.
name (str|None): The output variable name. Default None.
Returns:
output (Variable): The output is a tensor with the same shape as input.
Examples:
.. code-block:: python
output = fluid.layers.scatter(input, index, updates)
"""
helper = LayerHelper('scatter', **locals())
dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype)
helper.append_op(
type="scatter",
inputs={"X": input,
"Ids": index,
"Updates": updates},
outputs={"Out": out})
return out
@templatedoc()
def random_crop(x, shape, seed=None):
"""
......
......@@ -65,7 +65,6 @@ __all__ = [
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
'scatter',
'sum',
'slice',
'shape',
......
......@@ -30,7 +30,8 @@ import numpy as np
class TestMNISTIfElseOp(unittest.TestCase):
def test_raw_api(self):
# FIXME: https://github.com/PaddlePaddle/Paddle/issues/12245#issuecomment-406462379
def not_test_raw_api(self):
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
......@@ -91,7 +92,8 @@ class TestMNISTIfElseOp(unittest.TestCase):
return
self.assertFalse(True)
def test_ifelse(self):
# FIXME: https://github.com/PaddlePaddle/Paddle/issues/12245#issuecomment-406462379
def not_test_ifelse(self):
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
......@@ -153,6 +155,13 @@ class TestIfElse(unittest.TestCase):
self.cond_value = 0.5
self.data = np.random.rand(25, 1).astype(np.float32)
def numpy_cal(self):
s1 = self.data[np.where(self.data < self.cond_value)]
res = np.sum(np.exp(s1))
s2 = self.data[np.where(self.data >= self.cond_value)]
res += np.sum(np.tanh(s2))
return res
def compare_ifelse_op_and_numpy(self, place):
self.set_test_case()
......@@ -166,10 +175,12 @@ class TestIfElse(unittest.TestCase):
ie = layers.IfElse(ifcond)
with ie.true_block():
true_target = ie.input(src)
true_target = fluid.layers.exp(true_target)
ie.output(true_target)
with ie.false_block():
false_target = ie.input(src)
false_target = fluid.layers.tanh(false_target)
ie.output(false_target)
if_out = ie()
out = layers.reduce_sum(if_out)
......@@ -180,7 +191,8 @@ class TestIfElse(unittest.TestCase):
o1, = exe.run(fluid.default_main_program(),
feed={'data': self.data},
fetch_list=[out])
o2 = np.sum(self.data)
o2 = self.numpy_cal()
self.assertTrue(
np.allclose(
o1, o2, atol=1e-8),
......
......@@ -46,7 +46,8 @@ def cnn_model(data):
pool_size=2,
pool_stride=2,
act="relu",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant()))
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.3)))
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
......@@ -54,7 +55,8 @@ def cnn_model(data):
pool_size=2,
pool_stride=2,
act="relu",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant()))
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.2)))
SIZE = 10
input_shape = conv_pool_2.shape
......@@ -66,8 +68,7 @@ def cnn_model(data):
size=SIZE,
act="softmax",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale, seed=1)))
initializer=fluid.initializer.Constant(value=0.1)))
return predict
......
......@@ -129,7 +129,12 @@ class SE_ResNeXt():
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
drop = fluid.layers.dropout(x=pool, dropout_prob=0.2)
stdv = 1.0 / math.sqrt(drop.shape[1] * 1.0)
out = fluid.layers.fc(input=drop, size=class_dim, act='softmax')
out = fluid.layers.fc(
input=drop,
size=class_dim,
act='softmax',
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.2)))
return out
def shortcut(self, input, ch_out, stride):
......@@ -179,7 +184,7 @@ class SE_ResNeXt():
act=None,
# avoid pserver CPU init differs from GPU
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant()),
initializer=fluid.initializer.Constant(value=0.2)),
bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act)
......@@ -228,10 +233,8 @@ class DistSeResneXt2x2(TestDistRunnerBase):
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum(
# FIXME(typhoonzero): add back LR decay once ParallelExecutor fixed.
#learning_rate=fluid.layers.piecewise_decay(
# boundaries=bd, values=lr),
learning_rate=base_lr,
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(avg_cost)
......
......@@ -265,9 +265,9 @@ def main(role="pserver",
if __name__ == "__main__":
if len(sys.argv) != 7:
if len(sys.argv) != 8:
print(
"Usage: python dist_transformer.py [pserver/trainer] [endpoints] [trainer_id] [current_endpoint] [trainers] [is_dist]"
"Usage: python dist_transformer.py [pserver/trainer] [endpoints] [trainer_id] [current_endpoint] [trainers] [is_dist] [sync_mode]"
)
role = sys.argv[1]
endpoints = sys.argv[2]
......@@ -275,6 +275,8 @@ if __name__ == "__main__":
current_endpoint = sys.argv[4]
trainers = int(sys.argv[5])
is_dist = True if sys.argv[6] == "TRUE" else False
# FIXME(typhoonzero): refine this test.
is_async = True if sys.argv[7] == "TRUE" else False
main(
role=role,
endpoints=endpoints,
......
......@@ -27,6 +27,7 @@ import unittest
from multiprocessing import Process
import os
import signal
import six
import collections
SEED = 1
......@@ -55,7 +56,8 @@ def cnn_model(data):
# TODO(dzhwinter) : refine the initializer and random seed settting
SIZE = 10
input_shape = conv_pool_2.shape
param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
param_shape = [six.moves.reduce(lambda a, b: a * b, input_shape[1:], 1)
] + [SIZE]
scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5
predict = fluid.layers.fc(
......@@ -108,7 +110,7 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
def operator_equal(a, b):
for k, v in a.__dict__.iteritems():
for k, v in six.iteritems(a.__dict__):
if isinstance(v, fluid.framework.Program) or \
isinstance(v, fluid.framework.Block):
continue
......@@ -118,8 +120,8 @@ def operator_equal(a, b):
raise ValueError("In operator_equal not equal:{0}\n".format(k))
elif isinstance(v, collections.OrderedDict):
v0 = sorted(v.iteritems(), key=lambda x: x[0])
v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0])
v0 = sorted(six.iteritems(v), key=lambda x: x[0])
v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0])
if v0 != v1:
raise ValueError("In operator_equal not equal:{0}\n".format(k))
......@@ -131,7 +133,7 @@ def operator_equal(a, b):
def block_equal(a, b):
for k, v in a.__dict__.iteritems():
for k, v in six.iteritems(a.__dict__):
if isinstance(v, core.ProgramDesc) or isinstance(
v, fluid.framework.Program) or isinstance(v, core.BlockDesc):
continue
......@@ -143,8 +145,8 @@ def block_equal(a, b):
assert (len(a.ops) == len(b.ops))
elif isinstance(v, collections.OrderedDict):
v0 = sorted(v.iteritems(), key=lambda x: x[0])
v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0])
v0 = sorted(six.iteritems(v), key=lambda x: x[0])
v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0])
if v0 != v1:
raise ValueError("In block_equal not equal:{0}\n".format(k))
......@@ -156,7 +158,7 @@ def block_equal(a, b):
def program_equal(a, b):
for k, v in a.__dict__.iteritems():
for k, v in six.iteritems(a.__dict__):
if isinstance(v, core.ProgramDesc):
continue
......
......@@ -30,7 +30,7 @@ class TestDistRunnerBase(object):
"get_model should be implemented by child classes.")
def get_transpiler(self, trainer_id, main_program, pserver_endpoints,
trainers):
trainers, sync_mode):
# NOTE: import fluid until runtime, or else forking processes will cause error.
import paddle
import paddle.fluid as fluid
......@@ -39,17 +39,22 @@ class TestDistRunnerBase(object):
trainer_id=trainer_id,
program=main_program,
pservers=pserver_endpoints,
trainers=trainers)
trainers=trainers,
sync_mode=sync_mode)
return t
def run_pserver(self, pserver_endpoints, trainers, current_endpoint,
trainer_id):
def run_pserver(self,
pserver_endpoints,
trainers,
current_endpoint,
trainer_id,
sync_mode=True):
import paddle
import paddle.fluid as fluid
self.get_model(batch_size=2)
t = self.get_transpiler(trainer_id,
fluid.default_main_program(), pserver_endpoints,
trainers)
trainers, sync_mode)
pserver_prog = t.get_pserver_program(current_endpoint)
startup_prog = t.get_startup_program(current_endpoint, pserver_prog)
place = fluid.CPUPlace()
......@@ -57,7 +62,13 @@ class TestDistRunnerBase(object):
exe.run(startup_prog)
exe.run(pserver_prog)
def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
def run_trainer(self,
place,
endpoints,
trainer_id,
trainers,
is_dist=True,
sync_mode=True):
import paddle
import paddle.fluid as fluid
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
......@@ -65,7 +76,7 @@ class TestDistRunnerBase(object):
if is_dist:
t = self.get_transpiler(trainer_id,
fluid.default_main_program(), endpoints,
trainers)
trainers, sync_mode)
trainer_prog = t.get_trainer_program()
else:
trainer_prog = fluid.default_main_program()
......@@ -106,9 +117,9 @@ def runtime_main(test_class):
import paddle.fluid as fluid
import paddle.fluid.core as core
if len(sys.argv) != 7:
if len(sys.argv) != 8:
print(
"Usage: python dist_se_resnext.py [pserver/trainer] [endpoints] [trainer_id] [current_endpoint] [trainers] [is_dist]"
"Usage: python dist_se_resnext.py [pserver/trainer] [endpoints] [trainer_id] [current_endpoint] [trainers] [is_dist] [sync_mode]"
)
role = sys.argv[1]
endpoints = sys.argv[2]
......@@ -116,34 +127,43 @@ def runtime_main(test_class):
current_endpoint = sys.argv[4]
trainers = int(sys.argv[5])
is_dist = True if sys.argv[6] == "TRUE" else False
sync_mode = True if sys.argv[7] == "TRUE" else False
model = test_class()
if role == "pserver":
model.run_pserver(endpoints, trainers, current_endpoint, trainer_id)
model.run_pserver(endpoints, trainers, current_endpoint, trainer_id,
sync_mode)
else:
p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
model.run_trainer(p, endpoints, trainer_id, trainers, is_dist)
model.run_trainer(p, endpoints, trainer_id, trainers, is_dist,
sync_mode)
import paddle.compat as cpt
class TestDistBase(unittest.TestCase):
def _setup_config(self):
raise NotImplementedError("tests should have _setup_config implemented")
def setUp(self):
self._trainers = 2
self._pservers = 2
self._ps_endpoints = "127.0.0.1:9123,127.0.0.1:9124"
self._python_interp = "python"
self._sync_mode = True
self._setup_config()
def start_pserver(self, model_file, check_error_log):
sync_mode_str = "TRUE" if self._sync_mode else "FALSE"
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
ps0_cmd = "%s %s pserver %s 0 %s %d TRUE" % \
ps0_cmd = "%s %s pserver %s 0 %s %d TRUE %s" % \
(self._python_interp, model_file, self._ps_endpoints, ps0_ep,
self._trainers)
ps1_cmd = "%s %s pserver %s 0 %s %d TRUE" % \
self._trainers, sync_mode_str)
ps1_cmd = "%s %s pserver %s 0 %s %d TRUE %s" % \
(self._python_interp, model_file, self._ps_endpoints, ps1_ep,
self._trainers)
self._trainers, sync_mode_str)
ps0_pipe = subprocess.PIPE
ps1_pipe = subprocess.PIPE
......@@ -195,9 +215,10 @@ class TestDistBase(unittest.TestCase):
# Run local to get a base line
env_local = {"CUDA_VISIBLE_DEVICES": "0"}
env_local.update(required_envs)
local_cmd = "%s %s trainer %s 0 %s %d FLASE" % \
sync_mode_str = "TRUE" if self._sync_mode else "FALSE"
local_cmd = "%s %s trainer %s 0 %s %d FLASE %s" % \
(self._python_interp, model_file,
"127.0.0.1:1234", "127.0.0.1:1234", 1)
"127.0.0.1:1234", "127.0.0.1:1234", 1, sync_mode_str)
if not check_error_log:
local_proc = subprocess.Popen(
local_cmd.split(" "),
......@@ -226,12 +247,12 @@ class TestDistBase(unittest.TestCase):
self._wait_ps_ready(ps1.pid)
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
tr0_cmd = "%s %s trainer %s 0 %s %d TRUE" % \
tr0_cmd = "%s %s trainer %s 0 %s %d TRUE %s" % \
(self._python_interp, model_file, self._ps_endpoints, ps0_ep,
self._trainers)
tr1_cmd = "%s %s trainer %s 1 %s %d TRUE" % \
self._trainers, sync_mode_str)
tr1_cmd = "%s %s trainer %s 1 %s %d TRUE %s" % \
(self._python_interp, model_file, self._ps_endpoints, ps1_ep,
self._trainers)
self._trainers, sync_mode_str)
env0 = {"CUDA_VISIBLE_DEVICES": "0"}
env1 = {"CUDA_VISIBLE_DEVICES": "1"}
......
......@@ -17,10 +17,21 @@ import unittest
from test_dist_base import TestDistBase
class TestDistSeResneXt2x2(TestDistBase):
class TestDistMnist2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
def test_se_resnext(self):
self.check_with_place("dist_mnist.py", delta=1e-7)
class TestDistMnistAsync(TestDistBase):
def _setup_config(self):
self._sync_mode = False
def test_se_resnext(self):
self.check_with_place("dist_mnist.py", delta=200)
if __name__ == "__main__":
unittest.main()
......@@ -18,9 +18,20 @@ from test_dist_base import TestDistBase
class TestDistSeResneXt2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
def test_se_resnext(self):
self.check_with_place("dist_se_resnext.py", delta=1e-7)
class TestDistSeResneXt2x2Async(TestDistBase):
def _setup_config(self):
self._sync_mode = False
def test_se_resnext(self):
self.check_with_place("dist_se_resnext.py", delta=100)
if __name__ == "__main__":
unittest.main()
......@@ -19,6 +19,9 @@ from test_dist_base import TestDistBase
class TestDistTransformer2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
def test_transformer(self):
# TODO(paddle-dev): check if the delta is OK.
# Usually start around ~8000 and converge to ~5000
......
......@@ -47,7 +47,6 @@ class TranspilerTest(unittest.TestCase):
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
sgd_optimizer.minimize(avg_cost)
return
def get_main_program(self):
main = fluid.Program()
......@@ -95,8 +94,9 @@ class TranspilerTest(unittest.TestCase):
def test_transpiler(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
self.transpiler_test_impl()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.transpiler_test_impl()
class TestBasicModel(TranspilerTest):
......@@ -249,7 +249,6 @@ class TestLRDecay(TranspilerTest):
decay_rate=0.1,
staircase=True))
sgd_optimizer.minimize(avg_cost)
return
def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep)
......@@ -279,7 +278,6 @@ class TestLRDecayConditional(TranspilerTest):
learning_rate=fluid.layers.piecewise_decay([10000, 20000],
[1.0, 0.5, 1.0]))
sgd_optimizer.minimize(avg_cost)
return
def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep)
......@@ -328,7 +326,6 @@ class TestL2Decay(TranspilerTest):
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
sgd_optimizer.minimize(avg_cost)
return
def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep)
......@@ -363,7 +360,6 @@ class TestL2DecayWithPiecewise(TranspilerTest):
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
sgd_optimizer.minimize(avg_cost)
return
def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep)
......@@ -393,13 +389,14 @@ class TestDistLookupTableBase(TranspilerTest):
def network_with_table(self, is_sparse, is_distributed):
self.table_size = 1000
self.emb_size = 64
self.lookup_table_name = 'shared_w'
def emb_pool(ids):
emb = fluid.layers.embedding(
input=ids,
size=[self.table_size, self.emb_size],
dtype='float32',
param_attr='shared_w', # share parameter
param_attr=self.lookup_table_name, # share parameter
is_sparse=is_sparse,
is_distributed=is_distributed)
pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
......@@ -572,7 +569,7 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase):
def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig()
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
pserver1, _ = self.get_pserver(self.pserver1_ep, config)
self.assertTrue(self.transpiler.has_distributed_lookup_table)
lookup_table_var = pserver1.global_block().vars[
......@@ -582,6 +579,21 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase):
self.assertEqual(row_size, calc_row_size)
class TestDistArgsInProgram(TestDistLookupTableBase):
def net_conf(self):
self.network_with_table(is_sparse=True, is_distributed=True)
def transpiler_test_impl(self):
trainer, _ = self.get_trainer()
self.assertTrue(trainer._is_distributed)
self.assertTrue(trainer._is_chief)
self.assertEqual(trainer._distributed_lookup_table,
self.lookup_table_name)
self.assertEqual(trainer._endpoints,
[self.pserver1_ep, self.pserver2_ep])
class TestRMSPropOptimizer(TranspilerTest):
def net_conf(self):
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
......@@ -595,7 +607,6 @@ class TestRMSPropOptimizer(TranspilerTest):
avg_cost = fluid.layers.mean(cost)
optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
optimizer.minimize(avg_cost)
return
def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep)
......@@ -612,5 +623,40 @@ class TestRMSPropOptimizer(TranspilerTest):
self.assertEqual(moment_var.shape, (500, 1000))
class TestLoadSliceVar(TranspilerTest):
def net_conf(self):
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
y_predict = fluid.layers.fc(input=x,
size=1000,
act=None,
param_attr=fluid.ParamAttr(name='fc_w'),
bias_attr=fluid.ParamAttr(name='fc_b'))
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
optimizer.minimize(avg_cost)
def transpiler_test_impl(self):
pserver, _ = self.get_pserver(self.pserver1_ep)
pserver2, _ = self.get_pserver(self.pserver2_ep)
self.assertTrue(pserver._slice_vars_and_attrs)
self.assertTrue(pserver2._slice_vars_and_attrs)
for idx in xrange(len(pserver._slice_vars_and_attrs)):
self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
pserver2._slice_vars_and_attrs[idx][0])
total_numel = reduce(lambda x, y: x * y,
pserver._slice_vars_and_attrs[idx][0].shape)
self.assertEqual(
total_numel,
reduce(lambda x, y: x * y,
pserver._slice_vars_and_attrs[idx][2].shape) + reduce(
lambda x, y: x * y,
pserver2._slice_vars_and_attrs[idx][2].shape))
if __name__ == "__main__":
unittest.main()
......@@ -18,9 +18,20 @@ from test_dist_base import TestDistBase
class TestDistSeResneXt2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
def test_se_resnext(self):
self.check_with_place("dist_word2vec.py", delta=1e-7)
class TestDistSeResneXt2x2Async(TestDistBase):
def _setup_config(self):
self._sync_mode = False
def test_se_resnext(self):
self.check_with_place("dist_word2vec.py", delta=1)
if __name__ == "__main__":
unittest.main()
......@@ -64,27 +64,47 @@ class TestFCOp(OpTest):
self.check_output()
class TestFCOpBiasBoth(TestFCOp):
class TestFCOpNoBias(TestFCOp):
def init_shapes(self, mb, ic, oc, h, w):
for with_bias in {True, False}:
self.with_bias = with_bias
self.matrix = MatrixGenerate(mb, ic, oc, h, w)
self.with_bias = False
self.matrix = MatrixGenerate(mb, ic, oc, h, w)
class TestFCOp1(TestFCOpBiasBoth):
class TestFCOpWithBias(TestFCOp):
def init_shapes(self, mb, ic, oc, h, w):
self.with_bias = True
self.matrix = MatrixGenerate(mb, ic, oc, h, w)
class TestFCOp1(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(2, 8, 10, 1, 1)
class TestFCOp2(TestFCOpBiasBoth):
class TestFCOp2(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(4, 5, 6, 2, 2)
class TestFCOp4(TestFCOpBiasBoth):
class TestFCOp4(TestFCOpNoBias):
def init_op_type(self):
self.init_shapes(1, 32, 64, 3, 3)
class TestFCOpWithBias1(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(3, 8, 10, 2, 1)
class TestFCOpWithBias2(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(4, 5, 6, 2, 2)
class TestFCOpWithBias3(TestFCOpWithBias):
def init_op_type(self):
self.init_shapes(1, 64, 32, 3, 3)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from test_lstm_op import lstm, ACTIVATION
def fc(x, w, b):
return np.dot(x, w) + b
def fusion_lstm(
x, # T x M
lod, # 1 x N
wx=None, # M x 4D
bx=None, # 1 x 4D
h0=None, # N x D
c0=None, # N x D
w_h=None, # D x 4D
w_b=None, # 1 x 4D
w_c=None, # 1 x 3D
is_reverse=False,
act_gate=None,
act_cell=None,
act_cand=None):
return lstm(
fc(x, wx, bx), lod, h0, c0, w_h, w_b, w_c, is_reverse, act_gate,
act_cell, act_cand)
class TestLstmOp(OpTest):
def set_argument(self):
self.lod = [[2, 3, 2]]
def setUp(self):
self.op_type = 'fusion_lstm'
self.lod = [[2, 3, 2]]
self.M = 8
self.D = 16
self.has_initial_state = False
self.is_reverse = False
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.use_peepholes = False
self.set_argument()
T = sum(self.lod[0])
bs = len(self.lod[0])
x = np.random.normal(size=(T, self.M)).astype('float64')
if self.has_initial_state:
h0 = np.random.normal(size=(bs, self.D)).astype('float64')
c0 = np.random.normal(size=(bs, self.D)).astype('float64')
else:
h0 = np.zeros((bs, self.D)).astype('float64')
c0 = np.zeros((bs, self.D)).astype('float64')
wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
if self.use_peepholes:
b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
else:
b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
w_b = np.copy(b[:, 0:4 * self.D])
w_c = b[:, 4 * self.D:] if self.use_peepholes else None
# this is the weight of fc
wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64')
# this is the bias of fc
# and it should be manually added into the bias of this fusion LSTM
bx = np.random.normal(size=(1, 4 * self.D)).astype('float64')
b[0, 0:4 * self.D] += bx[0, :]
h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c,
self.is_reverse, ACTIVATION[self.act_gate],
ACTIVATION[self.act_cell], ACTIVATION[self.act_cand])
self.inputs = {
'X': (x, self.lod),
'WeightX': wx,
'WeightH': wh,
'Bias': b
}
if self.has_initial_state:
self.inputs['H0'] = h0
self.inputs['C0'] = c0
self.outputs = {
'Hidden': (h, self.lod),
'Cell': (c, self.lod),
}
self.attrs = {
'use_peepholes': self.use_peepholes,
'is_reverse': self.is_reverse,
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
'candidate_activation': self.act_cand
}
def test_check_output(self):
self.check_output(atol=1e-8)
class TestLstmOpInitReverse(TestLstmOp):
def set_argument(self):
self.has_initial_state = True
self.is_reverse = True
class TestLstmOpMD1(TestLstmOp):
def set_argument(self):
self.M = 36
self.D = 8
class TestLstmOpMD2(TestLstmOp):
def set_argument(self):
self.M = 8
self.D = 8
class TestLstmOpMD3(TestLstmOp):
def set_argument(self):
self.M = 15
self.D = 3
class TestLstmOpBS1(TestLstmOp):
def set_argument(self):
self.lod = [[3]]
self.D = 16
if __name__ == '__main__':
unittest.main()
......@@ -347,6 +347,25 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(loss)
print(str(program))
def test_scatter(self):
program = Program()
with program_guard(program):
x = layers.data(
name='x',
shape=[3, 3],
append_batch_size=False,
dtype='float32')
idx = layers.data(
name='idx', shape=[2], append_batch_size=False, dtype='int32')
updates = layers.data(
name='updates',
shape=[2, 3],
append_batch_size=False,
dtype='float32')
out = layers.scatter(input=x, index=idx, updates=updates)
self.assertIsNotNone(out)
print(str(program))
def test_lod_reset(self):
program = Program()
with program_guard(program):
......
......@@ -41,7 +41,7 @@ class TestSqueezeOp(OpTest):
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": False}
self.attrs = {"axes": self.axes}
# Correct: There is mins axis.
......@@ -68,49 +68,5 @@ class TestSqueezeOp3(TestSqueezeOp):
self.new_shape = (3, 5, 1, 4)
# Correct: Inplace.
class TestSqueezeOpInplace1(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, 2)
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is mins axis.
class TestSqueezeOpInplace2(TestSqueezeOp):
def inti_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = (0, -2)
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. No axes input.
class TestSqueezeOpInplace3(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 3, 1, 5)
self.axes = ()
self.new_shape = (3, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inpalce. Just part of axes be squeezed.
class TestSqueezeOpInplace4(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 1, 5, 1, 4, 1)
self.axes = (1, -1)
self.new_shape = (3, 5, 1, 4)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
if __name__ == "__main__":
unittest.main()
......@@ -41,7 +41,7 @@ class TestUnsqueezeOp(OpTest):
self.new_shape = (3, 1, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": False}
self.attrs = {"axes": self.axes}
# Correct: Single input index.
......@@ -76,38 +76,5 @@ class TestUnsqueezeOp4(TestUnsqueezeOp):
self.new_shape = (3, 1, 1, 2, 5, 1)
# Correct: Inplace.
class TestUnsqueezeOpInplace1(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (0, 2)
self.new_shape = (1, 3, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is mins index.
class TestUnsqueezeOpInplace2(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 5)
self.axes = (0, -2)
self.new_shape = (1, 3, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
# Correct: Inplace. There is duplicated axis.
class TestUnsqueezeOpInplace3(TestUnsqueezeOp):
def init_test_case(self):
self.ori_shape = (3, 2, 5)
self.axes = (0, 3, 3)
self.new_shape = (1, 3, 2, 1, 1, 5)
def init_attrs(self):
self.attrs = {"axes": self.axes, "inplace": True}
if __name__ == "__main__":
unittest.main()
......@@ -285,11 +285,12 @@ class Trainer(object):
self._load_checkpoint()
if param_path and os.path.isdir(param_path):
# load params from param_path into scope
io.load_persistables(
executor=exe,
dirname=param_path,
main_program=self.startup_program)
with self._prog_and_scope_guard():
# load params from param_path into scope
io.load_persistables(
executor=exe,
dirname=param_path,
main_program=self.startup_program)
def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS
......
......@@ -215,6 +215,13 @@ class DistributeTranspiler(object):
for param_var, grad_var in self.params_grads:
self.param_name_to_grad_name[param_var.name] = grad_var.name
# add distributed attrs to program
self.origin_program._is_distributed = True
self.origin_program._endpoints = self.pserver_endpoints
self.origin_program._is_chief = self.trainer_id == 0
self.origin_program._distributed_lookup_table = self.table_name if self.table_name else None
# split and create vars, then put splited vars in dicts for later use.
# step 1: split and create vars, then put splited vars in dicts for later use.
self._init_splited_vars()
......@@ -369,7 +376,7 @@ class DistributeTranspiler(object):
# FIXME(gongwb): delete not need ops.
# note that: some parameter is not trainable and those ops can't be deleted.
for varname, splited_var in self.param_var_mapping.iteritems():
for varname, splited_var in six.iteritems(self.param_var_mapping):
# Get the eplist of recv vars
eps = []
for var in splited_var:
......@@ -406,7 +413,7 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for varname, splited_var in self.param_var_mapping.iteritems():
for varname, splited_var in six.iteritems(self.param_var_mapping):
#add concat ops to merge splited parameters received from parameter servers.
if len(splited_var) <= 1:
continue
......@@ -590,6 +597,8 @@ class DistributeTranspiler(object):
checkpoint_block_id = self._create_checkpoint_save_block(
pserver_program, table_opt_block.idx)
pserver_program._distributed_lookup_table = self.table_name
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table:
......@@ -616,6 +625,10 @@ class DistributeTranspiler(object):
outputs={},
attrs=attrs)
# add distributed attrs
pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs(
endpoint)
pserver_program._sync_with_cpp()
return pserver_program
......@@ -689,8 +702,31 @@ class DistributeTranspiler(object):
inputs=new_inputs,
outputs=new_outputs,
attrs=op.all_attrs())
# add slice vars
s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint)
return s_prog
def _get_slice_vars_and_attrs(self, endpoint):
slice_vars_and_attrs = []
block_suffix = "block"
for param in self.param_grad_ep_mapping[endpoint]["params"]:
orig_var_name, block_name, _ = self._get_varname_parts(param.name)
if not block_name:
continue
block_idx = int(block_name.split(block_suffix)[1])
orig_var = self.origin_program.global_block().vars[orig_var_name]
skip_numel = 0
slice_vars = self.param_var_mapping[orig_var_name]
for slice_var in slice_vars[:block_idx]:
skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
slice_vars_and_attrs.append([orig_var, skip_numel, param])
return slice_vars_and_attrs
# ====================== private transpiler functions =====================
def _has_distributed_lookup_table(self):
......@@ -1209,8 +1245,8 @@ class DistributeTranspiler(object):
elif op_type == "momentum":
if varkey == "Velocity":
return param_shape
elif op_type == "":
if varkey == "Moment":
elif op_type == "rmsprop":
if varkey in ["Moment", "MeanSquare"]:
return param_shape
elif op_type == "sgd":
pass
......@@ -1289,8 +1325,6 @@ class DistributeTranspiler(object):
pserver_block = program.global_block()
new_inputs = collections.OrderedDict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
def _get_param_block(opt_op):
# param is already created on global program
param_block = None
......@@ -1303,22 +1337,6 @@ class DistributeTranspiler(object):
for key in opt_op.input_names:
if key == "Grad":
new_inputs[key] = merged_var
# For RMSProp optimizer
elif key == "Moment" or key == "MeanSquare":
param_block = _get_param_block(opt_op)
if not param_block:
return
moment_var = origin_program.global_block().vars[opt_op.input(
key)[0]]
tmpvar = pserver_block.create_var(
name=moment_var.name,
persistable=moment_var.persistable,
dtype=moment_var.dtype,
# change to use same shape as param
# TODO(typhoonzero): didn't append .block in the var name,
# may affect checkpoint saving? Need to verify.
shape=param_block.shape)
new_inputs[key] = tmpvar
elif key == "Param":
param_block = _get_param_block(opt_op)
if not param_block:
......@@ -1346,7 +1364,7 @@ class DistributeTranspiler(object):
for key in opt_op.input_names:
new_shape = None
if key in ["Param", "Grad", "LearningRate", "Moment", "MeanSquare"]:
if key in ["Param", "Grad", "LearningRate"]:
continue
var = self.origin_program.global_block().vars[opt_op.input(key)[0]]
# update accumulator variable shape
......
......@@ -159,18 +159,20 @@ if '${WITH_MKL}' == 'ON':
shutil.copy('${MKLML_LIB}', libs_path)
shutil.copy('${MKLML_IOMP_LIB}', libs_path)
package_data['paddle.libs']+=['libmklml_intel.so','libiomp5.so']
if '${WITH_MKLDNN}' == 'ON':
# TODO(typhoonzero): use install_name_tool to patch mkl libs once
# we can support mkl on mac.
#
# change rpath of libmkldnn.so.0, add $ORIGIN/ to it.
# The reason is that all thirdparty libraries in the same directory,
# thus, libmkldnn.so.0 will find libmklml_intel.so and libiomp5.so.
command = "patchelf --set-rpath '$ORIGIN/' ${MKLDNN_SHARED_LIB}"
if os.system(command) != 0:
raise Exception("patch libmkldnn.so failed, command: %s" % command)
package_data['paddle.libs']+=['libmkldnn.so.0']
shutil.copy('${MKLDNN_SHARED_LIB}', libs_path)
if '${CMAKE_BUILD_TYPE}' == 'Release':
# only change rpath in Release mode.
if '${WITH_MKLDNN}' == 'ON':
# TODO(typhoonzero): use install_name_tool to patch mkl libs once
# we can support mkl on mac.
#
# change rpath of libmkldnn.so.0, add $ORIGIN/ to it.
# The reason is that all thirdparty libraries in the same directory,
# thus, libmkldnn.so.0 will find libmklml_intel.so and libiomp5.so.
command = "patchelf --set-rpath '$ORIGIN/' ${MKLDNN_SHARED_LIB}"
if os.system(command) != 0:
raise Exception("patch libmkldnn.so failed, command: %s" % command)
package_data['paddle.libs']+=['libmkldnn.so.0']
shutil.copy('${MKLDNN_SHARED_LIB}', libs_path)
# remove unused paddle/libs/__init__.py
os.remove(libs_path+'/__init__.py')
package_dir['paddle.libs']=libs_path
......@@ -179,20 +181,22 @@ package_dir['paddle.libs']=libs_path
# The reason is that libwarpctc.so, libiomp5.so etc are in paddle.libs, and
# core.so is in paddle.fluid, thus paddle/fluid/../libs will pointer to above libraries.
# This operation will fix https://github.com/PaddlePaddle/Paddle/issues/3213
if "@APPLE@" == "1":
command = "install_name_tool -id \"@loader_path/../libs/\" ${PADDLE_BINARY_DIR}/python/paddle/fluid/core.so"
else:
command = "patchelf --set-rpath '$ORIGIN/../libs/' ${PADDLE_BINARY_DIR}/python/paddle/fluid/core.so"
if os.system(command) != 0:
raise Exception("patch core.so failed, command: %s" % command)
if '${WITH_FLUID_ONLY}'== 'OFF':
# change rpath of _swig_paddle.so.
if '${CMAKE_BUILD_TYPE}' == 'Release':
# only change rpath in Release mode, since in Debug mode, core.so is too large to be changed.
if "@APPLE@" == "1":
command = "install_name_tool -id \"@loader_path/../paddle/libs/\" ${PADDLE_BINARY_DIR}/python/py_paddle/_swig_paddle.so"
command = "install_name_tool -id \"@loader_path/../libs/\" ${PADDLE_BINARY_DIR}/python/paddle/fluid/core.so"
else:
command = "patchelf --set-rpath '$ORIGIN/../paddle/libs/' ${PADDLE_BINARY_DIR}/python/py_paddle/_swig_paddle.so"
command = "patchelf --set-rpath '$ORIGIN/../libs/' ${PADDLE_BINARY_DIR}/python/paddle/fluid/core.so"
if os.system(command) != 0:
raise Exception("patch _swig_paddle.so failed, command: %s" % command)
raise Exception("patch core.so failed, command: %s" % command)
if '${WITH_FLUID_ONLY}'== 'OFF':
# change rpath of _swig_paddle.so.
if "@APPLE@" == "1":
command = "install_name_tool -id \"@loader_path/../paddle/libs/\" ${PADDLE_BINARY_DIR}/python/py_paddle/_swig_paddle.so"
else:
command = "patchelf --set-rpath '$ORIGIN/../paddle/libs/' ${PADDLE_BINARY_DIR}/python/py_paddle/_swig_paddle.so"
if os.system(command) != 0:
raise Exception("patch _swig_paddle.so failed, command: %s" % command)
setup(name='${PACKAGE_NAME}',
version='${PADDLE_VERSION}',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册