提交 6717925d 编写于 作者: Y Yibing Liu

Merge branch 'develop' of upstream into fix_cross_entropy_doc

...@@ -252,6 +252,11 @@ first_seq ...@@ -252,6 +252,11 @@ first_seq
.. autoclass:: paddle.v2.layer.first_seq .. autoclass:: paddle.v2.layer.first_seq
:noindex: :noindex:
sub_seq
---------
.. autoclass:: paddle.v2.layer.sub_seq
:noindex:
concat concat
------ ------
.. autoclass:: paddle.v2.layer.concat .. autoclass:: paddle.v2.layer.concat
......
...@@ -68,12 +68,6 @@ scale ...@@ -68,12 +68,6 @@ scale
:noindex: :noindex:
reshape
---------
.. autofunction:: paddle.v2.fluid.layers.reshape
:noindex:
transpose transpose
--------- ---------
.. autofunction:: paddle.v2.fluid.layers.transpose .. autofunction:: paddle.v2.fluid.layers.transpose
......
...@@ -27,7 +27,7 @@ limitations under the License. */ ...@@ -27,7 +27,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using DataTransformFN = using DataTransformFn =
std::function<void(const std::vector<platform::DeviceContext*> ctx, std::function<void(const std::vector<platform::DeviceContext*> ctx,
const Variable& in, Variable* out)>; const Variable& in, Variable* out)>;
using KernelTypePair = std::pair<OpKernelType, OpKernelType>; using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
...@@ -47,7 +47,7 @@ struct KernelTypePairHash { ...@@ -47,7 +47,7 @@ struct KernelTypePairHash {
}; };
using DataTransformMap = using DataTransformMap =
std::unordered_map<KernelTypePair, DataTransformFN, KernelTypePairHash>; std::unordered_map<KernelTypePair, DataTransformFn, KernelTypePairHash>;
class DataTransformFnMap { class DataTransformFnMap {
public: public:
...@@ -58,25 +58,25 @@ class DataTransformFnMap { ...@@ -58,25 +58,25 @@ class DataTransformFnMap {
} }
void Insert(const OpKernelType& left, const OpKernelType& right, void Insert(const OpKernelType& left, const OpKernelType& right,
const DataTransformFN& data_tranform_fn) { const DataTransformFn& data_tranform_fn) {
Insert(std::make_pair(left, right), data_tranform_fn); Insert(std::make_pair(left, right), data_tranform_fn);
} }
void Insert(const KernelTypePair& kernel_type_pair, void Insert(const KernelTypePair& kernel_type_pair,
const DataTransformFN& data_tranform_fn) { const DataTransformFn& data_tranform_fn) {
PADDLE_ENFORCE(!Has(kernel_type_pair), PADDLE_ENFORCE(!Has(kernel_type_pair),
"KernelTypePair %s has been registered", ""); "KernelTypePair %s has been registered", "");
map_.insert({kernel_type_pair, data_tranform_fn}); map_.insert({kernel_type_pair, data_tranform_fn});
} }
const DataTransformFN& Get(const KernelTypePair& key_pair) const { const DataTransformFn& Get(const KernelTypePair& key_pair) const {
auto data_transformer = GetNullable(key_pair); auto data_transformer = GetNullable(key_pair);
PADDLE_ENFORCE_NOT_NULL(data_transformer, PADDLE_ENFORCE_NOT_NULL(data_transformer,
"DataTransformFN should not be NULL"); "DataTransformFn should not be NULL");
return *data_transformer; return *data_transformer;
} }
const DataTransformFN* GetNullable(const KernelTypePair& key_pair) const { const DataTransformFn* GetNullable(const KernelTypePair& key_pair) const {
auto it = map_.find(key_pair); auto it = map_.find(key_pair);
if (it == map_.end()) { if (it == map_.end()) {
return nullptr; return nullptr;
......
...@@ -68,6 +68,8 @@ struct OpKernelType { ...@@ -68,6 +68,8 @@ struct OpKernelType {
data_type_ == o.data_type_ && data_layout_ == o.data_layout_ && data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
library_type_ == o.library_type_; library_type_ == o.library_type_;
} }
bool operator!=(const OpKernelType& o) const { return !(*this == o); }
}; };
inline std::ostream& operator<<(std::ostream& os, inline std::ostream& operator<<(std::ostream& os,
...@@ -78,5 +80,11 @@ inline std::ostream& operator<<(std::ostream& os, ...@@ -78,5 +80,11 @@ inline std::ostream& operator<<(std::ostream& os,
return os; return os;
} }
inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
std::ostringstream stream;
stream << kernel_key;
return stream.str();
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -26,10 +26,8 @@ TEST(OpKernelType, ToString) { ...@@ -26,10 +26,8 @@ TEST(OpKernelType, ToString) {
OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW, OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
LibraryType::kCUDNN); LibraryType::kCUDNN);
std::ostringstream stream;
stream << op_kernel_type;
ASSERT_EQ( ASSERT_EQ(
stream.str(), paddle::framework::KernelTypeToString(op_kernel_type),
"data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]"); "data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]");
} }
......
...@@ -413,37 +413,51 @@ void OperatorWithKernel::Run(const Scope& scope, ...@@ -413,37 +413,51 @@ void OperatorWithKernel::Run(const Scope& scope,
} }
if (actual_kernel_key == expected_kernel_key) { if (actual_kernel_key == expected_kernel_key) {
kernel_iter->second->Compute(ctx); PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_,
"Currently, model parallelism is only supported between "
"CPU and other devices. For example, multi-GPU model "
"parallelism will failed.");
} else { } else {
Scope& op_scope = scope.NewScope(); const DataTransformFn* trans_fun =
auto input_vars = this->InputVars(); DataTransformFnMap::Instance().GetNullable(
for (auto var_name : input_vars) { std::make_pair(actual_kernel_key, expected_kernel_key));
op_scope.Var(var_name); if (trans_fun) {
} auto input_vars = this->InputVars();
// TODO(qijun) filter the input vars that do not need to be transformed
// TODO(qijun) get appropriate DeviceContext from DeviceContext pool
platform::DeviceContext* trans_dev_ctx = nullptr; // filter vars that has been transformed
std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx}; std::vector<std::string> need_trans;
for (auto var_name : input_vars) {
auto var_name_trans =
var_name + framework::KernelTypeToString(expected_kernel_key);
if (!scope.FindVar(var_name_trans)) {
const_cast<Scope&>(scope).Var(var_name_trans);
need_trans.push_back(var_name);
}
}
// TODO(qijun) get appropriate DataTransformFN from global map if (!need_trans.empty()) {
framework::DataTransformFN trans_fun = nullptr; // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
platform::DeviceContext* trans_dev_ctx = nullptr;
std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
// Wait for transform starting // Wait for transform starting
dev_ctx->Wait(); dev_ctx->Wait();
for (auto var_name : input_vars) { for (auto var_name : need_trans) {
trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)), (*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
op_scope.FindVar(var_name)); scope.FindVar(var_name + framework::KernelTypeToString(
} expected_kernel_key)));
// Wait for data transform finishing }
for (auto ctx : trans_dev_ctx_vec) { // Wait for data transform finishing
ctx->Wait(); for (auto ctx : trans_dev_ctx_vec) {
ctx->Wait();
}
}
} }
// Create a new ExecutionContext
ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
kernel_iter->second->Compute(op_ctx);
} }
kernel_iter->second->Compute(ctx);
} }
OpKernelType OperatorWithKernel::GetActualKernelType( OpKernelType OperatorWithKernel::GetActualKernelType(
......
file(GLOB GENERAL_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") file(GLOB GENERAL_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
string(REPLACE ".cc" "" GENERAL_OPS "${GENERAL_OPS}") string(REPLACE ".cc" "" GENERAL_OPS "${GENERAL_OPS}")
set(DEPS_OPS "")
set(pybind_file ${PADDLE_SOURCE_DIR}/paddle/pybind/pybind.h) set(pybind_file ${PADDLE_SOURCE_DIR}/paddle/pybind/pybind.h)
file(WRITE ${pybind_file} "// Generated by the paddle/operator/CMakeLists.txt. DO NOT EDIT!\n\n") file(WRITE ${pybind_file} "// Generated by the paddle/operator/CMakeLists.txt. DO NOT EDIT!\n\n")
function(op_library TARGET) function(op_library TARGET)
...@@ -48,6 +49,11 @@ function(op_library TARGET) ...@@ -48,6 +49,11 @@ function(op_library TARGET)
message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file") message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
endif() endif()
list(LENGTH op_library_DEPS op_library_DEPS_len)
if (${op_library_DEPS_len} GREATER 0)
set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
endif()
if (WITH_GPU) if (WITH_GPU)
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
${op_common_deps}) ${op_common_deps})
...@@ -181,55 +187,26 @@ endfunction() ...@@ -181,55 +187,26 @@ endfunction()
add_subdirectory(math) add_subdirectory(math)
add_subdirectory(nccl) add_subdirectory(nccl)
set(DEPS_OPS if(WITH_GPU)
cond_op op_library(nccl_op DEPS nccl_common)
cross_entropy_op else()
recurrent_op set(DEPS_OPS ${DEPS_OPS} nccl_op)
softmax_with_cross_entropy_op endif()
softmax_op
sequence_softmax_op
sum_op
pool_op
maxout_op
unpool_op
pool_with_index_op
conv_op
conv_transpose_op
nccl_op
sequence_conv_op
sequence_pool_op
lod_rank_table_op
lod_tensor_to_array_op
array_to_lod_tensor_op
max_sequence_len_op
lstm_op
tensor_array_read_write_op
gru_op
adagrad_op
sgd_op
save_op
load_op
send_op
recv_op)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
add_subdirectory(detail) add_subdirectory(detail)
op_library(send_op SRCS send_op.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib_target protobuf) set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib_target protobuf)
set_source_files_properties( set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
send_op.cc op_library(send_op DEPS ${DISTRIBUTE_DEPS})
PROPERTIES set_source_files_properties(send_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") op_library(recv_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(recv_op SRCS recv_op.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib_target protobuf) cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS send_op recv_op sum_op executor)
set_source_files_properties( else()
recv_op.cc set(DEPS_OPS ${DEPS_OPS} send_op recv_op)
PROPERTIES
COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS send_op recv_op sum_op executor)
endif() endif()
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) op_library(cond_op DEPS framework_proto tensor net_op)
op_library(cross_entropy_op DEPS cross_entropy) op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(softmax_op DEPS softmax) op_library(softmax_op DEPS softmax)
...@@ -242,21 +219,16 @@ op_library(pool_op DEPS pooling) ...@@ -242,21 +219,16 @@ op_library(pool_op DEPS pooling)
op_library(maxout_op DEPS maxouting) op_library(maxout_op DEPS maxouting)
op_library(unpool_op DEPS unpooling) op_library(unpool_op DEPS unpooling)
op_library(pool_with_index_op DEPS pooling) op_library(pool_with_index_op DEPS pooling)
op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table) op_library(lod_rank_table_op DEPS lod_rank_table)
op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op) op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
op_library(array_to_lod_tensor_op SRCS array_to_lod_tensor_op.cc DEPS lod_rank_table_op) op_library(array_to_lod_tensor_op DEPS lod_rank_table_op)
op_library(max_sequence_len_op SRCS max_sequence_len_op.cc DEPS lod_rank_table) op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(tensor_array_read_write_op SRCS tensor_array_read_write_op.cc)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
endif()
op_library(sequence_conv_op DEPS context_project) op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling) op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute) op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(conv_transpose_op DEPS vol2col) op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute) op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op SRCS recurrent_op.cc DEPS executor) op_library(recurrent_op DEPS executor)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions # FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor) op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor)
...@@ -269,13 +241,12 @@ endforeach() ...@@ -269,13 +241,12 @@ endforeach()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor) cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
if(WITH_GPU) if(WITH_GPU)
cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif() endif()
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册