提交 d214dff1 编写于 作者: M minqiyang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_anakin_in_manylinux1

......@@ -204,6 +204,7 @@ include(external/snappy) # download snappy
include(external/snappystream)
include(external/threadpool)
include(flags) # set paddle compile flags
include(cudnn) # set cudnn libraries, must before configure
include(cupti)
include(configure) # add paddle env configuration
......@@ -221,7 +222,6 @@ include(package) # set paddle packages
include(ccache) # set ccache for compilation
include(util) # set unittest and link libs
include(rdma) # set rdma libraries
include(flags) # set paddle compile flags
include(version) # set PADDLE_VERSION
include(coveralls) # set code coverage
include(inference_lib) # add paddle fluid inference libraries
......
......@@ -50,16 +50,16 @@ if(NOT WITH_PROFILER)
endif(NOT WITH_PROFILER)
if(NOT CMAKE_CROSSCOMPILING)
if(WITH_AVX AND AVX_FOUND)
if(WITH_AVX AND AVX512F_FOUND)
set(SIMD_FLAG ${AVX512F_FLAG})
elseif(WITH_AVX AND AVX2_FOUND)
set(SIMD_FLAG ${AVX2_FLAG})
elseif(WITH_AVX AND AVX_FOUND)
set(SIMD_FLAG ${AVX_FLAG})
elseif(SSE3_FOUND)
set(SIMD_FLAG ${SSE3_FLAG})
endif()
endif()
if(UNIX AND NOT APPLE)
# except apple from nix*Os family
set(LINUX TRUE)
endif(UNIX AND NOT APPLE)
if(NOT WITH_GOLANG)
add_definitions(-DPADDLE_WITHOUT_GOLANG)
......
......@@ -25,8 +25,25 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS
$ENV{CUDNN_ROOT}
$ENV{CUDNN_ROOT}/lib64
$ENV{CUDNN_ROOT}/lib
/usr/lib)
find_library(CUDNN_LIBRARY NAMES libcudnn.so libcudnn.dylib # libcudnn_static.a
/usr/lib
${CUDA_TOOLKIT_ROOT_DIR}
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
)
set(CUDNN_LIB_NAME "")
if (LINUX)
set(CUDNN_LIB_NAME "libcudnn.so")
endif(LINUX)
if(WIN32)
# only support cudnn7
set(CUDNN_LIB_NAME "cudnn.lib" "cudnn64_7.dll")
endif(WIN32)
if(Apple)
set(CUDNN_LIB_NAME "libcudnn.dylib" "libcudnn.so")
endif(Apple)
find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} # libcudnn_static.a
PATHS ${CUDNN_CHECK_LIBRARY_DIRS} ${CUDNN_INCLUDE_DIR} ${__libpath_hist}
NO_DEFAULT_PATH
DOC "Path to cuDNN library.")
......
......@@ -142,6 +142,11 @@ else()
${GPU_COMMON_FLAGS})
endif()
if(UNIX AND NOT APPLE)
# except apple from nix*Os family
set(LINUX TRUE)
endif(UNIX AND NOT APPLE)
foreach(flag ${COMMON_FLAGS})
safe_set_cflag(CMAKE_C_FLAGS ${flag})
......
......@@ -10,6 +10,7 @@ if(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID
set(SSE3_FLAG "-msse3")
set(AVX_FLAG "-mavx")
set(AVX2_FLAG "-mavx2")
set(AVX512F_FLAG "-mavx512f")
elseif(MSVC)
set(MMX_FLAG "/arch:MMX")
set(SSE2_FLAG "/arch:SSE2")
......@@ -81,5 +82,16 @@ int main()
return 0;
}" AVX2_FOUND)
# Check AVX512F
set(CMAKE_REQUIRED_FLAGS ${AVX512F_FLAG})
set(AVX512F_FOUND_EXITCODE 1 CACHE STRING "Result from TRY_RUN" FORCE)
CHECK_CXX_SOURCE_RUNS("
#include <immintrin.h>
int main()
{
__m512i a = _mm512_undefined_epi32();
return 0;
}" AVX512F_FOUND)
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_RETAINED})
mark_as_advanced(MMX_FOUND SSE2_FOUND SSE3_FOUND AVX_FOUND AVX2_FOUND)
mark_as_advanced(MMX_FOUND SSE2_FOUND SSE3_FOUND AVX_FOUND AVX2_FOUND AVX512F_FOUND)
......@@ -99,12 +99,13 @@ else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
endif()
if (NOT WIN32)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
graph graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fast_threaded_ssa_graph_executor)
endif() # NOT WIN32
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
......@@ -55,11 +55,20 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
auto all_ops = blocks_[block_id]->AllOps();
for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) {
auto &op = all_ops[op_id];
for (const std::string &attr_name : op->AttrNames()) {
if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) {
int sub_block_id =
o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name);
op->SetBlockAttr(attr_name, MutableBlock(sub_block_id));
} else if (op->GetAttrType(attr_name) == proto::AttrType::BLOCKS) {
std::vector<int> sub_block_ids =
o.Block(block_id).Op(op_id)->GetBlocksAttrIds(attr_name);
std::vector<BlockDesc *> block_descs;
for (int block_id : sub_block_ids) {
block_descs.push_back(MutableBlock(block_id));
}
op->SetBlocksAttr(attr_name, block_descs);
}
}
}
......@@ -68,24 +77,16 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
desc_ = desc;
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDesc(this, &block_desc));
}
for (auto &block : blocks_) {
for (auto *op : block->AllOps()) {
for (const auto &attr : op->Proto()->attrs()) {
if (attr.type() == proto::AttrType::BLOCK) {
size_t blk_idx = attr.block_idx();
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
}
}
}
}
InitFromProto();
}
ProgramDesc::ProgramDesc(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string.");
InitFromProto();
}
void ProgramDesc::InitFromProto() {
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDesc(this, &block_desc));
}
......@@ -95,6 +96,13 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
if (attr.type() == proto::AttrType::BLOCK) {
size_t blk_idx = attr.block_idx();
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
} else if (attr.type() == proto::AttrType::BLOCKS) {
auto blks_idx = attr.blocks_idx();
std::vector<BlockDesc *> block_descs;
for (int blk_idx : blks_idx) {
block_descs.push_back(this->MutableBlock(blk_idx));
}
op->SetBlocksAttr(attr.name(), block_descs);
}
}
}
......
......@@ -76,6 +76,8 @@ class ProgramDesc {
void SetFetchHolderName(const std::string &fetch_holder_name);
private:
void InitFromProto();
proto::ProgramDesc desc_;
std::vector<std::unique_ptr<BlockDesc>> blocks_;
......
......@@ -42,6 +42,19 @@ TEST(ProgramDesc, copy_ctor) {
out->SetType(proto::VarType::LOD_TENSOR);
op->SetOutput("Y", {out->Name()});
BlockDesc* new_block = program.AppendBlock(*global_block);
op = new_block->AppendOp();
op->SetType("mul");
op = global_block->AppendOp();
op->SetType("op_with_subblock");
op->SetAttr("sub_block", new_block);
std::vector<BlockDesc*> sub_blocks;
sub_blocks.push_back(program.AppendBlock(*global_block));
sub_blocks.push_back(program.AppendBlock(*global_block));
op->SetAttr("sub_blocks", sub_blocks);
ProgramDesc program_copy(program);
auto* global_block_copy = program_copy.MutableBlock(0);
......@@ -64,6 +77,8 @@ TEST(ProgramDesc, copy_ctor) {
assert_same_var("Y", y);
assert_same_var("Out", out);
bool found_sub_block = false;
bool found_sub_blocks = false;
for (size_t i = 0; i < global_block->OpSize(); ++i) {
auto op_origin = global_block->Op(i);
auto op_copy = global_block_copy->Op(i);
......@@ -74,8 +89,17 @@ TEST(ProgramDesc, copy_ctor) {
ASSERT_EQ(op_copy->Proto()->SerializeAsString(),
op_origin->Proto()->SerializeAsString());
}
if (op->Type() == "op_with_subblock") {
ASSERT_EQ(1, op->GetBlockAttrId("sub_block"));
found_sub_block = true;
ASSERT_EQ(2, op->GetBlocksAttrIds("sub_blocks").size());
found_sub_blocks = true;
}
}
ASSERT_TRUE(found_sub_block);
ASSERT_TRUE(found_sub_blocks);
// Not check block's protostr are same it because the order of vars could be
// different and it is correct.
}
......
......@@ -84,6 +84,15 @@ function(op_library TARGET)
message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
endif()
#remove windows unsupported op
if (WIN32)
foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op")
if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
return()
endif()
endforeach()
endif(WIN32)
list(LENGTH op_library_DEPS op_library_DEPS_len)
if (${op_library_DEPS_len} GREATER 0)
set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
......@@ -181,19 +190,19 @@ function(op_library TARGET)
endfunction()
add_subdirectory(math)
if (NOT WIN32)
add_subdirectory(nccl)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n")
else()
set(DEPS_OPS ${DEPS_OPS} nccl_op)
endif()
endif() # NOT WIN32
set(DISTRIBUTE_DEPS "")
if(WITH_DISTRIBUTE)
add_subdirectory(distributed)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
......@@ -222,7 +231,7 @@ if(WITH_DISTRIBUTE)
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
# listen_and_serv_op sum_op executor SERIAL)
if(WITH_GPU)
if(WITH_GPU AND NOT WIN32)
set_source_files_properties(test_send_nccl_id.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS listen_and_serv_op ${DISTRIBUTE_DEPS} executor SERIAL)
if(WITH_GRPC)
......@@ -233,7 +242,7 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif()
endif() # WITH_GPU AND NOT WIN32
else()
set(DEPS_OPS ${DEPS_OPS} checkpoint_notify_op prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif()
......@@ -331,5 +340,7 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
if(NOT WIN32)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif()
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
......@@ -29,9 +29,9 @@ class ConditionalOp : public framework::OperatorBase {
protected:
std::vector<const framework::LoDTensor *> InputTensors(
const framework::Scope &scope) const {
const framework::Scope &scope, const std::string &in_name) const {
std::vector<const framework::LoDTensor *> retv;
auto xs = Inputs("X");
auto xs = Inputs(in_name);
retv.resize(xs.size(), nullptr);
std::transform(
xs.begin(), xs.end(), retv.begin(),
......@@ -81,12 +81,18 @@ class ConditionalBlockOp : public ConditionalOp {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto xs = InputTensors(scope);
bool need_run;
if (Attr<bool>("is_scalar_condition")) {
// When is_scalar_condition is True, the conditional variable is a scalar,
// whether need to execute the operators in sub-block depends on the
// conditional variable (Cond).
auto xs = InputTensors(scope, "Cond");
need_run = ScalarCondition(xs);
} else {
// When is_scalar_condition is False, the conditional variable maybe a
// vector or tensor, whether need to execute the operators in sub-block
// depends on the input variables (Input).
auto xs = InputTensors(scope, "Input");
need_run = std::all_of(
xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; });
......@@ -110,11 +116,11 @@ class ConditionalBlockOp : public ConditionalOp {
class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The conditional variable of this operator. If X is empty, the "
AddInput("Cond",
"The conditional variable of this operator. If Cond is empty, the "
"whole sub-block will not be executed.")
.AsDuplicable();
AddInput("Params", "The input variables of the sub-block.").AsDuplicable();
AddInput("Input", "The input variables of the sub-block.").AsDuplicable();
AddOutput("Out", "The output variables of the sub-block.").AsDuplicable();
AddOutput("Scope",
"(std::vector<Scope*>) The step scope of conditional block. To "
......@@ -123,13 +129,18 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<framework::BlockDesc *>(
"sub_block", "The step block of conditional block operator");
AddAttr<bool>("is_scalar_condition",
"the input X is used as scalar "
"condition")
"The conditional variable (Cond) is used as scalar "
"condition.")
.SetDefault(false);
AddComment(R"DOC(Conditional block operator
Run the sub-block if X is not empty. Params is the other inputs and Out is the
outputs of the sub-block.
If `is_scalar_condition` is True, the conditional variable (Cond) is a scalar,
run the operators in sub-block if Cond is True.
If `is_scalar_condition` is False, the conditional variable (Cond) is a vector or
tensor, run the operators in sub-block if all of input variables are not empty.
)DOC");
}
};
......@@ -145,12 +156,12 @@ class ConditionalBlockGradOp : public ConditionalOp {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto xs = this->InputTensors(scope);
bool need_run;
if (Attr<bool>("is_scalar_condition")) {
auto xs = this->InputTensors(scope, "Cond");
need_run = ScalarCondition(xs);
} else {
auto xs = this->InputTensors(scope, "Input");
need_run = std::all_of(
xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; });
......@@ -166,11 +177,11 @@ class ConditionalBlockGradOp : public ConditionalOp {
auto *block = Attr<framework::BlockDesc *>("sub_block");
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Params"),
Outputs(framework::GradVarName("Params")));
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Input"),
Outputs(framework::GradVarName("Input")));
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("X"),
Outputs(framework::GradVarName("X")));
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Cond"),
Outputs(framework::GradVarName("Cond")));
}
}
......@@ -199,15 +210,15 @@ class ConditionalBlockGradOp : public ConditionalOp {
class ConditionalBlockGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInputs("X"));
if (context->HasInputs("Params")) {
PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Params")));
context->SetOutputsDim(framework::GradVarName("Params"),
context->GetInputsDim("Params"));
PADDLE_ENFORCE(context->HasInputs("Cond"));
if (context->HasInputs("Input")) {
PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Input")));
context->SetOutputsDim(framework::GradVarName("Input"),
context->GetInputsDim("Input"));
}
if (context->HasOutputs(framework::GradVarName("X"))) {
context->SetOutputsDim(framework::GradVarName("X"),
context->GetInputsDim("X"));
if (context->HasOutputs(framework::GradVarName("Cond"))) {
context->SetOutputsDim(framework::GradVarName("Cond"),
context->GetInputsDim("Cond"));
}
}
};
......@@ -220,14 +231,15 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> Apply() const override {
auto grad_op = new framework::OpDesc();
grad_op->SetType("conditional_block_grad");
grad_op->SetInput("X", Input("X"));
grad_op->SetInput("Params", Input("Params"));
grad_op->SetInput("Cond", Input("Cond"));
grad_op->SetInput("Input", Input("Input"));
grad_op->SetInput("Out", Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetInput("Scope", Output("Scope"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
grad_op->SetOutput(framework::GradVarName("Params"),
InputGrad("Params", false));
grad_op->SetOutput(framework::GradVarName("Cond"),
InputGrad("Cond", false));
grad_op->SetOutput(framework::GradVarName("Input"),
InputGrad("Input", false));
grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition"));
return std::unique_ptr<framework::OpDesc>(grad_op);
......
......@@ -85,6 +85,199 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace());
#ifdef __AVX__
// It use the AVX or AVX512 instruction to deal the data as the vector of 8 or
// 16 elements per iteration. Then it can implement the parallel processing.
// Only optimize for float type.
#ifdef __AVX512F__
size_t step_size = 16;
#else
size_t step_size = 8;
#endif
if (std::is_same<T, float>::value && (tag_num >= step_size)) {
size_t steps = tag_num / step_size;
size_t remain = tag_num % step_size;
int last_offset = static_cast<int>(remain) - static_cast<int>(step_size);
// Setup the alpha initial value.
size_t i_offset = 0;
for (size_t i = 0; i <= steps; ++i) {
#ifdef __AVX512F__
// Declare the variable for the content of weights, input and alpha
// values.
__m512 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm512_loadu_ps((const float*)(w + i_offset));
x_content = _mm512_loadu_ps((const float*)(x + i_offset));
alpha_content = _mm512_add_ps(w_content, x_content);
// Save the alpha value.
_mm512_storeu_ps(reinterpret_cast<float*>(alpha_value + i_offset),
alpha_content);
#else
// Declare the variable for the content of weights, input and alpha
// values.
__m256 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm256_loadu_ps((const float*)(w + i_offset));
x_content = _mm256_loadu_ps((const float*)(x + i_offset));
alpha_content = _mm256_add_ps(w_content, x_content);
// Save the alpha value.
_mm256_storeu_ps(reinterpret_cast<float*>(alpha_value + i_offset),
alpha_content);
#endif
i_offset += step_size;
if (i == steps - 1) {
if (remain > 0) {
i_offset += last_offset;
} else {
break;
}
}
}
// Use the column-major strategy to get the location of maximum score.
size_t seq_offset = 0;
for (size_t k = 1; k < seq_len; ++k) {
size_t j_offset = 0;
for (size_t j = 0; j <= steps; ++j) {
#ifdef __AVX512F__
// Initialize the variables of maximum score and location.
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<T>::max());
__m512i max_j = _mm512_setzero_si512();
#else
// Initialize the variables of maximum score and location.
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<T>::max());
__m256i max_j = _mm256_set1_epi32(0);
#endif
// Calculate the offset of transition_weights.
size_t trans_offset = state_trans_base_idx * tag_num + j_offset;
for (size_t i = 0; i < tag_num; ++i) {
#ifdef __AVX512F__
// Initalize the content of alpha variable with related offset.
__m512 alpha_content =
_mm512_set1_ps(*(const float*)(alpha_value + seq_offset + i));
// Obtain the content of weights from un-aligned address.
__m512 w_content =
_mm512_loadu_ps((const float*)(w + trans_offset));
__m512 score_v = _mm512_add_ps(alpha_content, w_content);
__mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS);
// According to the mask value, it update the index of the max_score
// location.
max_j = _mm512_mask_set1_epi32(max_j, mask, i);
// Update the max_score value.
max_score = _mm512_max_ps(max_score, score_v);
#else
// Initalize the content of alpha variable with related offset.
__m256 alpha_content = _mm256_broadcast_ss(
(const float*)(alpha_value + seq_offset + i));
// Obtain the content of weights from un-aligned address.
__m256 w_content =
_mm256_loadu_ps((const float*)(w + trans_offset));
__m256 score_v = _mm256_add_ps(alpha_content, w_content);
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS);
#ifdef __AVX2__
// According to the mask value, it update the index of the max_score
// location.
max_j = _mm256_or_si256(
_mm256_andnot_si256((__m256i)mask, max_j),
_mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i)));
#else
__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0);
__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1);
__m128i lo_mask = _mm256_extractf128_si256((__m256i)mask, 0);
__m128i hi_mask = _mm256_extractf128_si256((__m256i)mask, 1);
lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j);
hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j);
lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i));
hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i));
lo_max_j = _mm_or_si128(lo_mask, lo_max_j);
hi_max_j = _mm_or_si128(hi_mask, hi_max_j);
// According to the mask value, it update the index of the max_score
// location.
max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0);
max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1);
#endif
// Update the max_score value.
max_score = _mm256_max_ps(max_score, score_v);
#endif
trans_offset += tag_num;
}
#ifdef __AVX512F__
// Update the alpha and track values.
__m512 x_content = _mm512_loadu_ps(
(const float*)(x + seq_offset + tag_num + j_offset));
max_score = _mm512_add_ps(max_score, x_content);
_mm512_storeu_ps(reinterpret_cast<float*>(alpha_value + seq_offset +
tag_num + j_offset),
max_score);
_mm512_storeu_si512(
reinterpret_cast<__m512i*>(track_value + seq_offset + tag_num +
j_offset),
max_j);
#else
// Update the alpha and track values.
__m256 x_content = _mm256_loadu_ps(
(const float*)(x + seq_offset + tag_num + j_offset));
max_score = _mm256_add_ps(max_score, x_content);
_mm256_storeu_ps(reinterpret_cast<float*>(alpha_value + seq_offset +
tag_num + j_offset),
max_score);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(track_value + seq_offset + tag_num +
j_offset),
max_j);
#endif
// Calculate the offset of next step
j_offset += step_size;
if (j == steps - 1) {
if (remain > 0) {
j_offset += last_offset;
} else {
break;
}
}
}
seq_offset += tag_num;
}
} else {
for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
for (size_t k = 1; k < seq_len; ++k) {
for (size_t i = 0; i < tag_num; ++i) {
T max_score = -std::numeric_limits<T>::max();
int max_j = 0;
for (size_t j = 0; j < tag_num; ++j) {
T score = alpha_value[(k - 1) * tag_num + j] +
w[(j + state_trans_base_idx) * tag_num + i];
if (score > max_score) {
max_score = score;
max_j = j;
}
}
alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
track_value[k * tag_num + i] = max_j;
}
}
}
#else
for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
for (size_t k = 1; k < seq_len; ++k) {
......@@ -105,6 +298,7 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
}
}
#endif
T max_score = -std::numeric_limits<T>::max();
int max_i = 0;
for (size_t i = 0; i < tag_num; ++i) {
......
if(WITH_GPU)
if(WITH_GPU AND NOT WIN32)
nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator )
endif()
......@@ -3,7 +3,7 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc)
# There is no macOS version of NCCL.
if (NOT APPLE)
if (NOT APPLE AND NOT WIN32)
list(APPEND CUDA_SRCS nccl.cc)
endif()
......
......@@ -44,7 +44,7 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/curand.h"
#ifndef __APPLE__
#if !defined(__APPLE__) and !defined(_WIN32)
#include "paddle/fluid/platform/dynload/nccl.h"
#endif // __APPLE__
#endif // PADDLE_WITH_CUDA
......@@ -205,7 +205,7 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
#endif
}
#ifndef __APPLE__
#if !defined(__APPLE__) and !defined(_WIN32)
template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
ncclResult_t stat, const Args&... args) {
......@@ -221,7 +221,7 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
#endif
}
}
#endif // __APPLE__
#endif // __APPLE__ and windows
#endif // PADDLE_WITH_CUDA
template <typename T>
......
......@@ -1272,8 +1272,8 @@ class ConditionalBlock(object):
parent_block.append_op(
type='conditional_block',
inputs={
'X': self.inputs,
'Params': param_list,
'Cond': self.inputs,
'Input': param_list,
},
outputs={'Out': out_list,
'Scope': [step_scope]},
......
......@@ -30,7 +30,8 @@ import numpy as np
class TestMNISTIfElseOp(unittest.TestCase):
def test_raw_api(self):
# FIXME: https://github.com/PaddlePaddle/Paddle/issues/12245#issuecomment-406462379
def not_test_raw_api(self):
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
......@@ -91,7 +92,8 @@ class TestMNISTIfElseOp(unittest.TestCase):
return
self.assertFalse(True)
def test_ifelse(self):
# FIXME: https://github.com/PaddlePaddle/Paddle/issues/12245#issuecomment-406462379
def not_test_ifelse(self):
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
......@@ -153,6 +155,13 @@ class TestIfElse(unittest.TestCase):
self.cond_value = 0.5
self.data = np.random.rand(25, 1).astype(np.float32)
def numpy_cal(self):
s1 = self.data[np.where(self.data < self.cond_value)]
res = np.sum(np.exp(s1))
s2 = self.data[np.where(self.data >= self.cond_value)]
res += np.sum(np.tanh(s2))
return res
def compare_ifelse_op_and_numpy(self, place):
self.set_test_case()
......@@ -166,10 +175,12 @@ class TestIfElse(unittest.TestCase):
ie = layers.IfElse(ifcond)
with ie.true_block():
true_target = ie.input(src)
true_target = fluid.layers.exp(true_target)
ie.output(true_target)
with ie.false_block():
false_target = ie.input(src)
false_target = fluid.layers.tanh(false_target)
ie.output(false_target)
if_out = ie()
out = layers.reduce_sum(if_out)
......@@ -180,7 +191,8 @@ class TestIfElse(unittest.TestCase):
o1, = exe.run(fluid.default_main_program(),
feed={'data': self.data},
fetch_list=[out])
o2 = np.sum(self.data)
o2 = self.numpy_cal()
self.assertTrue(
np.allclose(
o1, o2, atol=1e-8),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册