diff --git a/cmake/external/gflags.cmake b/cmake/external/gflags.cmake index 4e98e4bf889bc13938931be7f6cb204c83250a5c..95ca16f57f2704eaded85aa5f5c0546310fba3a7 100644 --- a/cmake/external/gflags.cmake +++ b/cmake/external/gflags.cmake @@ -63,6 +63,15 @@ ADD_DEPENDENCIES(gflags extern_gflags) LIST(APPEND external_project_dependencies gflags) +# On Windows (including MinGW), the Shlwapi library is used by gflags if available. +if (WIN32) + include(CheckIncludeFileCXX) + check_include_file_cxx("shlwapi.h" HAVE_SHLWAPI) + if (HAVE_SHLWAPI) + set_property(GLOBAL PROPERTY OS_DEPENDENCY_MODULES shlwapi.lib) + endif(HAVE_SHLWAPI) +endif (WIN32) + IF(WITH_C_API) INSTALL(DIRECTORY ${GFLAGS_INCLUDE_DIR} DESTINATION third_party/gflags) IF(ANDROID) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 05293b8b06b55bb0b83a30c7eb059efe0b61e57e..63820fd4f0ad1718beda71048e4333596de80dbe 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -359,6 +359,8 @@ function(cc_binary TARGET_NAME) add_dependencies(${TARGET_NAME} ${cc_binary_DEPS}) common_link(${TARGET_NAME}) endif() + get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) + target_link_libraries(${TARGET_NAME} ${os_dependency_modules}) endfunction(cc_binary) function(cc_test TARGET_NAME) @@ -367,18 +369,15 @@ function(cc_test TARGET_NAME) set(oneValueArgs "") set(multiValueArgs SRCS DEPS ARGS) cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + add_executable(${TARGET_NAME} ${cc_test_SRCS}) if(WIN32) - list(APPEND win32_deps shlwapi) if("${cc_test_DEPS};" MATCHES "python;") list(REMOVE_ITEM cc_test_DEPS python) - list(APPEND win32_deps ${PYTHON_LIBRARIES}) + target_link_libraries(${TARGET_NAME} ${PYTHON_LIBRARIES}) endif() endif(WIN32) - add_executable(${TARGET_NAME} ${cc_test_SRCS}) - target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) - if(WIN32) - target_link_libraries(${TARGET_NAME} ${win32_deps}) - endif(WIN32) + get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) + target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} paddle_gtest_main lod_tensor memory gtest gflags glog) add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) common_link(${TARGET_NAME}) add_test(NAME ${TARGET_NAME} @@ -451,7 +450,8 @@ function(nv_test TARGET_NAME) set(multiValueArgs SRCS DEPS) cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) - target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) + get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) + target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog ${os_dependency_modules}) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) common_link(${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME}) @@ -538,7 +538,8 @@ function(hip_test TARGET_NAME) endif() add_executable(${TARGET_NAME} ${_cmake_options} ${_generated_files} ${_sources}) set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP) - target_link_libraries(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags) + get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) + target_link_libraries(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags ${os_dependency_modules}) add_dependencies(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags) common_link(${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME}) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 16d43f82d6e6fa398afde33d168a92b9916d5b83..50ffef72baa1c5f210fd6e92de05d24a39ac86b4 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -88,6 +88,7 @@ paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'poo paddle.fluid.layers.adaptive_pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)) paddle.fluid.layers.adaptive_pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)) paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False)) +paddle.fluid.layers.data_norm ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'use_mkldnn', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, False, None, None, None, False)) paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)) @@ -210,6 +211,7 @@ paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], va paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)) paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.teacher_student_sigmoid_loss ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)) paddle.fluid.layers.huber_loss ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a595a8ab4299298f625b8322a0adbed6d0b4fda3..42fb6a1aa5375bfbb266454cfbc7f0fb756f779c 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -48,6 +48,17 @@ pass_library(conv_elementwise_add_act_fuse_pass inference) pass_library(conv_elementwise_add2_act_fuse_pass inference) pass_library(conv_elementwise_add_fuse_pass inference) pass_library(conv_affine_channel_fuse_pass inference) +pass_library(transpose_flatten_concat_fuse_pass inference) + +# There may be many transpose-flatten structures in a model, and the output of +# these structures will be used as inputs to the concat Op. This pattern will +# be detected by our pass. The index here represents the number of structures in the +# pattern. We use index 3 ~ 6, because these quantities of structures are +# common in the models. +foreach (index RANGE 3 6) + file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n") +endforeach() + if(WITH_MKLDNN) pass_library(mkldnn_placement_pass base) pass_library(depthwise_conv_mkldnn_pass base) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index c513fe2dd8f5733c87802f6fa9980ad885dfd865..6282ced1e47329915bb3626b410e55ad8251071d 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1306,6 +1306,69 @@ PDNode *patterns::ConvAffineChannel::operator()( return ac_out_var; } +// a -> transpose_op(1) -> transpose_out_a -> flatten_op(1) -> flatten_out_a +// b -> transpose_op(2) -> transpose_out_b -> flatten_op(2) -> flatten_out_b +// ... +// z -> transpose_op(n) -> transpose_out_z -> flatten_op(n) -> flatten_out_z +// flatten_out_a -> concat_op flatten_out_b -> concat_op ... flatten_out_z -> +// concat_op +PDNode *patterns::TransposeFlattenConcat::operator()( + std::vector conv_in, int times) { + // The times represents the repeat times of the + // {trans, trans_out, flatten, flatten_out} + const int kNumFields = 4; + const int kTransOutOffset = 1; + const int kFlattenOffset = 2; + const int kFlattenOutOffset = 3; + + std::vector nodes; + + for (int i = 0; i < times; i++) { + nodes.push_back( + pattern->NewNode(GetNodeName("transpose" + std::to_string(i))) + ->assert_is_op("transpose2")); + nodes.push_back( + pattern->NewNode(GetNodeName("transpose_out" + std::to_string(i))) + ->assert_is_op_output("transpose2") + ->assert_is_op_input("flatten2", "X") + ->AsIntermediate()); + nodes.push_back(pattern->NewNode(GetNodeName("flatten" + std::to_string(i))) + ->assert_is_op("flatten2")); + + nodes.push_back( + pattern->NewNode(GetNodeName("flatten_out" + std::to_string(i))) + ->assert_is_op_output("flatten2") + ->assert_is_op_nth_input("concat", "X", i) + ->AsIntermediate()); + } + + auto concat_op = pattern->NewNode(GetNodeName("concat")) + ->assert_is_op("concat") + ->assert_op_has_n_inputs("concat", times); + auto concat_out = pattern->NewNode(GetNodeName("concat_out")) + ->assert_is_op_output("concat") + ->AsOutput(); + + std::vector flatten_outs; + for (int i = 0; i < times; i++) { + conv_in[i]->AsInput(); + // trans + nodes[i * kNumFields]->LinksFrom({conv_in[i]}); + // trans_out + nodes[i * kNumFields + kTransOutOffset]->LinksFrom({nodes[i * kNumFields]}); + // flatten + nodes[i * kNumFields + kFlattenOffset]->LinksFrom( + {nodes[i * kNumFields + kTransOutOffset]}); + // flatten_out + nodes[i * kNumFields + kFlattenOutOffset]->LinksFrom( + {nodes[i * kNumFields + kFlattenOffset]}); + flatten_outs.push_back(nodes[i * kNumFields + kFlattenOutOffset]); + } + + concat_op->LinksFrom(flatten_outs).LinksTo({concat_out}); + return concat_out; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 61a53003449710da2a52c90197c9f2f3ac56c7bb..c8be586f546dc604375401b13a801841efbf08d2 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -766,6 +766,21 @@ struct ConvAffineChannel : public PatternBase { PATTERN_DECL_NODE(ac_out); // Out }; +struct TransposeFlattenConcat : public PatternBase { + TransposeFlattenConcat(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "transpose_flatten_concat") {} + + PDNode* operator()(std::vector conv_inputs, int times); + + std::string GetNodeName(const std::string& op_type) { + return PDNodeName(name_scope_, repr_, id_, op_type); + } + + PDNode* GetPDNode(const std::string& op_type) { + return pattern->RetrieveNode(GetNodeName(op_type)); + } +}; + } // namespace patterns // Link two ir::Nodes from each other. diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc index 96a3b7ee058647156258b946c1301138c185fa31..fa75e3b4aa7feb7ff856dc26338d089f90efa2e2 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc @@ -50,7 +50,7 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, // the other one should be unused empty var. if (is_nth_input_var_of_concat(x->outputs[0], idx)) { satisfied_all = satisfied_all && x->outputs[1]->IsVar() && - x->outputs[1]->outputs.size() == 0; + x->outputs[1]->outputs.empty(); } else { satisfied_all = satisfied_all && is_nth_input_var_of_concat(x->outputs[1], idx) && diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..fda43948d567689103815e3ad7ba285719dae80f --- /dev/null +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/ir/graph_viz_pass.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +template +std::unique_ptr TransposeFlattenConcatFusePass::ApplyImpl( + std::unique_ptr graph) const { + const std::string pattern_name = + "transpose_flatten" + std::to_string(times) + "_concat_fuse"; + FusePassBase::Init(pattern_name, graph.get()); + + GraphPatternDetector gpd; + std::vector input_nodes; + for (int i = 0; i < times; i++) { + input_nodes.push_back(gpd.mutable_pattern() + ->NewNode("x" + std::to_string(i)) + ->assert_is_op_input("transpose2", "X") + ->AsInput()); + } + + patterns::TransposeFlattenConcat pattern(gpd.mutable_pattern(), pattern_name); + pattern(input_nodes, times); + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + const int kNumFields = 5; + const int kTransOffset = 1; + const int kTransOutOffset = 2; + const int kFlattenOffset = 3; + const int kFlattenOutOffset = 4; + std::vector nodes; + + for (int i = 0; i < times; i++) { + PADDLE_ENFORCE( + subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i)))); + PADDLE_ENFORCE( + subgraph.at(pattern.GetPDNode("transpose_out" + std::to_string(i)))); + PADDLE_ENFORCE( + subgraph.at(pattern.GetPDNode("flatten" + std::to_string(i)))); + PADDLE_ENFORCE( + subgraph.at(pattern.GetPDNode("flatten_out" + std::to_string(i)))); + PADDLE_ENFORCE(subgraph.at(input_nodes[i])); + + nodes.push_back(subgraph.at(input_nodes[i])); + nodes.push_back( + subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i)))); + nodes.push_back( + subgraph.at(pattern.GetPDNode("transpose_out" + std::to_string(i)))); + nodes.push_back( + subgraph.at(pattern.GetPDNode("flatten" + std::to_string(i)))); + nodes.push_back( + subgraph.at(pattern.GetPDNode("flatten_out" + std::to_string(i)))); + } + + Node *concat_op = subgraph.at(pattern.GetPDNode("concat")); + Node *concat_out = subgraph.at(pattern.GetPDNode("concat_out")); + std::vector input_names; + std::vector trans_axis = boost::get>( + nodes[kTransOffset]->Op()->GetAttr("axis")); + int flatten_axis = + boost::get(nodes[kFlattenOffset]->Op()->GetAttr("axis")); + int concat_axis = boost::get(concat_op->Op()->GetAttr("axis")); + std::string output_name = concat_out->Name(); + + for (int i = 0; i < times; i++) { + input_names.push_back(nodes[i * kNumFields]->Name()); + } + + framework::OpDesc new_op_desc; + new_op_desc.SetType("fusion_transpose_flatten_concat"); + new_op_desc.SetInput("X", input_names); + new_op_desc.SetAttr("trans_axis", trans_axis); + new_op_desc.SetAttr("flatten_axis", flatten_axis); + new_op_desc.SetAttr("concat_axis", concat_axis); + new_op_desc.SetOutput("Out", {output_name}); + new_op_desc.Flush(); + + // Create a new node for the fused op. + auto *new_conv_op = graph->CreateOpNode(&new_op_desc); + + std::unordered_set delete_nodes; + + for (int i = 0; i < times; i++) { + nodes[i * kNumFields]->outputs.push_back(new_conv_op); + new_conv_op->inputs.push_back(nodes[i * kNumFields]); + delete_nodes.insert(nodes[i * kNumFields + kTransOffset]); + delete_nodes.insert(nodes[i * kNumFields + kTransOutOffset]); + delete_nodes.insert(nodes[i * kNumFields + kFlattenOffset]); + delete_nodes.insert(nodes[i * kNumFields + kFlattenOutOffset]); + } + delete_nodes.insert(concat_op); + + new_conv_op->outputs.push_back(concat_out); + concat_out->inputs.push_back(new_conv_op); + + // Delete the unneeded nodes. + GraphSafeRemoveNodes(graph.get(), delete_nodes); + }; + + gpd(graph.get(), handler); + return graph; +} + +template class TransposeFlattenConcatFusePass<1>; +template class TransposeFlattenConcatFusePass<3>; +template class TransposeFlattenConcatFusePass<4>; +template class TransposeFlattenConcatFusePass<5>; +template class TransposeFlattenConcatFusePass<6>; + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(transpose_flatten_concat_fuse_pass, + paddle::framework::ir::TransposeFlattenConcatFusePass<1>); + +REGISTER_PASS(transpose_flatten3_concat_fuse_pass, + paddle::framework::ir::TransposeFlattenConcatFusePass<3>); + +REGISTER_PASS(transpose_flatten4_concat_fuse_pass, + paddle::framework::ir::TransposeFlattenConcatFusePass<4>); + +REGISTER_PASS(transpose_flatten5_concat_fuse_pass, + paddle::framework::ir::TransposeFlattenConcatFusePass<5>); + +REGISTER_PASS(transpose_flatten6_concat_fuse_pass, + paddle::framework::ir::TransposeFlattenConcatFusePass<6>); diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..fb0f0ae9efdc5a25a799d6123fa658a99860cd86 --- /dev/null +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h @@ -0,0 +1,38 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +// There may be many transpose-flatten structures in a model, and the output of +// these structures will be used as inputs to the concat Op. This pattern will +// be detected by our pass. The times here represents the repeat times of this +// structure. +template +class TransposeFlattenConcatFusePass : public FusePassBase { + public: + virtual ~TransposeFlattenConcatFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 041187665af6ad0d75a7c55fe6ed451fe6f45b73..4d29564aeed74558b7f0ec580568f70dad0b40cc 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -391,7 +391,7 @@ class ExecutionContext { PADDLE_ENFORCE( dynamic_cast(allocation_ptr) != nullptr, "The AllocationPtr must be TemporaryAllocation."); - PADDLE_ENFORCE_GE(allocation_ptr->size(), + PADDLE_ENFORCE_EQ(allocation_ptr->size(), framework::product(dim) * sizeof(T)); paddle::framework::Tensor temp_tensor( diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index a79f501673d0d55aa2c5f2a40333a900d3415bb4..7594670cd2608802bdf41682ef5724a7a965d754 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -27,6 +27,8 @@ namespace paddle { namespace imperative { +std::map py_funcs_; + using framework::Variable; void AddTo(Variable* src, Variable* dst) { @@ -55,6 +57,7 @@ class Autograd { if (var->stop_gradient_) { return; } + VLOG(3) << "start autograd"; std::deque ready; ready.push_back(var->pre_op_); @@ -120,51 +123,57 @@ framework::LoDTensor& VarBase::GradValue() { } std::map> OpBase::ApplyGrad() { - if (!grad_op_desc_) { + if (!grad_op_desc_ && backward_id_ <= 0) { LOG(WARNING) << "op with no grad: " << op_desc_->Type(); return {}; } - VLOG(3) << "op grad " << grad_op_desc_->Type(); - std::vector> tmp_vars; std::map> grad_outputs; - for (auto it : grad_output_vars_) { - auto& outputs = grad_outputs[it.first]; - for (size_t i = 0; i < it.second.size(); ++i) { - // Allocate a new variable - Variable* tmp_var = new framework::Variable(); - tmp_var->GetMutable(); - - tmp_vars.emplace_back(tmp_var); - outputs.push_back(tmp_var); + if (backward_id_ > 0) { + VLOG(3) << "py_layer_grad"; + grad_outputs["Out@GRAD"] = + PyLayer::ApplyGrad(backward_id_, grad_input_vars_["X@GRAD"]); + } else { + VLOG(3) << "op grad " << grad_op_desc_->Type(); + for (auto it : grad_output_vars_) { + auto& outputs = grad_outputs[it.first]; + for (size_t i = 0; i < it.second.size(); ++i) { + // Allocate a new variable + Variable* tmp_var = new framework::Variable(); + tmp_var->GetMutable(); + outputs.push_back(tmp_var); + } } - } - framework::RuntimeContext ctx(grad_input_vars_, grad_outputs); + framework::RuntimeContext ctx(grad_input_vars_, grad_outputs); - // No need to do compile time infer shape here. - // grad_op_desc_->InferShape(*block_); - grad_op_desc_->InferVarType(block_); + // No need to do compile time infer shape here. + // grad_op_desc_->InferShape(*block_); + grad_op_desc_->InferVarType(block_); - std::unique_ptr opbase = - framework::OpRegistry::CreateOp(*grad_op_desc_); - framework::OperatorWithKernel* op_kernel = - dynamic_cast(opbase.get()); - PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel"); + std::unique_ptr opbase = + framework::OpRegistry::CreateOp(*grad_op_desc_); + framework::OperatorWithKernel* op_kernel = + dynamic_cast(opbase.get()); + PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel"); - framework::Scope scope; - platform::CPUPlace place; - PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place); - p.op.RuntimeInferShape(scope, place, ctx); - p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx)); + framework::Scope scope; + platform::CPUPlace place; + PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place); + p.op.RuntimeInferShape(scope, place, ctx); + p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx)); + } for (auto it : grad_output_vars_) { auto& outputs = grad_outputs[it.first]; auto& origin_outputs = it.second; + PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size()); for (size_t i = 0; i < outputs.size(); ++i) { + framework::Variable* grad = outputs[i]; framework::Variable* orig_grad = origin_outputs[i]; - AddTo(outputs[i], orig_grad); + AddTo(grad, orig_grad); + delete grad; } } return input_vars_; @@ -173,6 +182,7 @@ std::map> OpBase::ApplyGrad() { void VarBase::RunBackward() { if (!pre_op_) return; + VLOG(3) << "start backward"; auto grads_t = grads_->var_->GetMutable(); float* data = grads_t->mutable_data(platform::CPUPlace()); std::fill(data, data + grads_t->numel(), 1.0); @@ -183,5 +193,65 @@ void VarBase::RunBackward() { Autograd().RunBackward(this); } +void PyLayer::RegisterFunc(int func_id, const py::object& py_func) { + py_funcs_[func_id] = py_func; +} + +int PyLayer::NumFuncs() { return py_funcs_.size(); } + +std::vector PyLayer::Apply(int func_id, + const std::vector& inputs) { + std::vector invars; + for (const VarBase* in : inputs) { + invars.push_back(in->var_); + } + PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end()); + std::vector outvars = CallPythonFunc(py_funcs_[func_id], invars); + std::vector ret; + for (Variable* v : outvars) { + ret.push_back(new VarBase(v, new VarBase(true))); + } + return ret; +} + +std::vector PyLayer::ApplyGrad( + int func_id, const std::vector& inputs) { + PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end()); + return CallPythonFunc(py_funcs_[func_id], inputs); +} + +std::vector PyLayer::CallPythonFunc( + const py::object& callable, const std::vector& ins) { + py::gil_scoped_acquire guard; + py::tuple in_args(ins.size()); + for (size_t i = 0; i < ins.size(); ++i) { + const framework::LoDTensor& t = ins[i]->Get(); + in_args[i] = t.IsInitialized() ? py::cast(t) : py::cast(nullptr); + } + VLOG(3) << "pyfunc in " << py::len(in_args); + + // TODO(panyx0718): Who owns the returned LoDTensor. + auto ret = callable(in_args); + auto ret_tuple = py::cast(ret); + size_t ret_num = py::len(ret_tuple); + std::vector outs; + VLOG(3) << "pyfunc out " << ret_num; + for (size_t i = 0; i < ret_num; ++i) { + try { + auto* py_out_tensor = py::cast(ret_tuple[i]); + PADDLE_ENFORCE_NOT_NULL(py_out_tensor, + "Output tensor %d should not be nullptr", i); + auto* var = new framework::Variable(); + auto* tensor = var->GetMutable(); + tensor->ShareDataWith(*py_out_tensor); + tensor->set_lod(py_out_tensor->lod()); + outs.push_back(var); + } catch (py::cast_error&) { + PADDLE_THROW("The %d-th output must be LoDTensor", i); + } + } + return outs; +} + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index d441b3445a0f2423e4d8e626765e4d086b2c8d74..86c2dc3fa4a7d03aa8f0a89a25c17656e1cd708c 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -22,12 +22,15 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/platform/enforce.h" +#include "pybind11/pybind11.h" #include "paddle/fluid/imperative/type_defs.h" namespace paddle { namespace imperative { +namespace py = ::pybind11; + class PreparedOp { public: PreparedOp(const framework::OperatorBase& op, @@ -90,16 +93,21 @@ class OpBase; */ class VarBase { public: - VarBase() + VarBase() : VarBase(new framework::Variable(), new VarBase(true)) {} + + // Owns `var` and `grad` + VarBase(framework::Variable* var, VarBase* grad) : pre_op_(nullptr), + pre_op_out_name_(), pre_op_out_idx_(-1), var_desc_(nullptr), - var_(new framework::Variable()), - grads_(new VarBase(true)), + var_(var), + grads_(grad), stop_gradient_(false) {} explicit VarBase(bool stop_gradient) : pre_op_(nullptr), + pre_op_out_name_(), pre_op_out_idx_(-1), var_desc_(nullptr), var_(new framework::Variable()), @@ -144,7 +152,11 @@ class VarBase { */ class OpBase { public: - OpBase() : op_desc_(nullptr), grad_op_desc_(nullptr) {} + OpBase() + : op_desc_(nullptr), + forward_id_(-1), + grad_op_desc_(nullptr), + backward_id_(-1) {} virtual ~OpBase() { if (grad_op_desc_) delete grad_op_desc_; @@ -152,8 +164,14 @@ class OpBase { std::map> ApplyGrad(); + // One of `op_desc_` or `forward_id_` is set, not both. + // For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_. framework::OpDesc* op_desc_; + int forward_id_; + // When has backward, one of `grad_op_desc_` or `backward_id_` is set, + // not both. framework::OpDesc* grad_op_desc_; + int backward_id_; VarBasePtrMap input_vars_; VarBasePtrMap output_vars_; @@ -173,8 +191,25 @@ class Layer { std::vector vars; return vars; } +}; + +class PyLayer { + public: + virtual ~PyLayer() {} + + static void RegisterFunc(int func_id, const py::object& py_func); + + static int NumFuncs(); + + static std::vector Apply(int func_id, + const std::vector& inputs); + + static std::vector ApplyGrad( + int func_id, const std::vector& inputs); - virtual void Backward() { LOG(ERROR) << "To support customize"; } + private: + static std::vector CallPythonFunc( + const py::object& callable, const std::vector& ins); }; } // namespace imperative diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index ead1ed5e3f82a1fc46adf93d19db9361ccbae13a..78e95f672232e82332fefd2a9eef30756fa0cae6 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -115,8 +115,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, if (!stop_gradient) { framework::OpDesc* grad_op_desc; - auto grad_to_var = new std::unordered_map(); - CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var); + // TODO(panyx): Is this leaked? + std::unique_ptr> grad_to_var( + new std::unordered_map()); + CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var.get()); op->grad_op_desc_ = grad_op_desc; for (auto it : grad_op_desc->Inputs()) { @@ -127,13 +129,15 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, if (var_it == grad_to_var->end()) { auto fwd_var_it = vars.find(grad_invar); PADDLE_ENFORCE(fwd_var_it != vars.end()); + // Forward inputs or outputs. grad_in_vars.push_back(fwd_var_it->second->var_); } else { VarBase* var = vars[var_it->second]; - if (!var->grads_->var_->IsInitialized()) { - InitVar(var->var_, var->grads_->var_); + if (!var->grads_->IsInitialized()) { + InitVar(var->var_, var->grads_); } - grad_in_vars.push_back(var->grads_->var_); + // Douts. + grad_in_vars.push_back(var->grads_); } } } @@ -145,10 +149,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, auto var_it = grad_to_var->find(grad_outvar); PADDLE_ENFORCE(var_it != grad_to_var->end()); VarBase* var = vars[var_it->second]; - if (!var->grads_->var_->IsInitialized()) { - InitVar(var->var_, var->grads_->var_); + if (!var->grads_->IsInitialized()) { + InitVar(var->var_, var->grads_); } - grad_out_vars.push_back(var->grads_->var_); + grad_out_vars.push_back(var->grads_); } } } @@ -156,5 +160,54 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, op->block_ = block; } +std::vector Tracer::PyTrace(OpBase* op, + const std::vector& inputs, + bool stop_gradient) { + VLOG(3) << "py_trace"; + op->input_vars_["X"] = inputs; + op->output_vars_["Out"] = PyLayer::Apply(op->forward_id_, inputs); + for (VarBase* inp : inputs) { + if (inp->pre_op_) { + op->pre_ops_["X"].push_back(inp->pre_op_); + op->pre_ops_out_idx_["X"].push_back(inp->pre_op_out_idx_); + } else { + op->pre_ops_["X"].push_back(nullptr); + } + } + + auto& outputs = op->output_vars_["Out"]; + for (size_t i = 0; i < outputs.size(); ++i) { + VarBase* out = outputs[i]; + out->stop_gradient_ = stop_gradient; + out->pre_op_ = op; + out->pre_op_out_name_ = "Out"; + out->pre_op_out_idx_ = i; + } + if (!stop_gradient) { + auto& grad_input_vars = op->grad_input_vars_["X@GRAD"]; + auto& grad_output_vars = op->grad_output_vars_["Out@GRAD"]; + + for (const VarBase* inp : inputs) { + grad_input_vars.push_back(inp->var_); + } + for (VarBase* out : outputs) { + grad_input_vars.push_back(out->var_); + } + for (VarBase* out : outputs) { + grad_input_vars.push_back(out->grads_); + if (!grad_input_vars.back()->IsInitialized()) { + InitVar(out->var_, grad_input_vars.back()); + } + } + for (const VarBase* inp : inputs) { + grad_output_vars.push_back(inp->grads_); + if (!grad_output_vars.back()->IsInitialized()) { + InitVar(inp->var_, grad_output_vars.back()); + } + } + } + return outputs; +} + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 7d484c291fc40eded26f559f01754e662ad167f8..f225d8abe6c0635d2bdd8dba0b12c7fc3a4110db 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -45,6 +45,9 @@ class Tracer { const std::map>& outputs, framework::BlockDesc* block, const bool stop_gradient = false); + std::vector PyTrace(OpBase* op, const std::vector& inputs, + bool stop_gradient = false); + private: framework::BlockDesc* root_block_; }; diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 211c691504de2c0bd8ff50f34b92cbc01397d5c9..336ab426c21d9de93693c44d8fc6bc5b37b58864 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -127,6 +127,7 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size, use_tensorrt_ = true; tensorrt_workspace_size_ = workspace_size; tensorrt_max_batchsize_ = max_batch_size; + Update(); } void contrib::AnalysisConfig::Update() { diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index fa2752e9158d8136c3c7aee34651b37096f009d1..19ef402d6fd78d6a65bdb0bbd22198f36b872a27 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -128,8 +128,8 @@ else() ${CMAKE_STATIC_LIBRARY_PREFIX}glog ${CMAKE_STATIC_LIBRARY_PREFIX}gflags ${CMAKE_STATIC_LIBRARY_PREFIX}protobuf ${CMAKE_STATIC_LIBRARY_PREFIX}snappy ${CMAKE_STATIC_LIBRARY_PREFIX}z ${CMAKE_STATIC_LIBRARY_PREFIX}xxhash snappystream ${EXTERNAL_LIB}) - # NOTE(dzhwinter) shlwapi is deprecated. - set(DEPS ${DEPS} libcmt shlwapi) + get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) + set(DEPS ${DEPS} libcmt ${os_dependency_modules}) endif(NOT WIN32) if(WITH_GPU) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index 1e5712e1638ea802dfa9c3b41ab1d3f7f62f090b..de9650735adfe158e72213d4f6d5d3569aa90d55 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -141,6 +141,10 @@ class GpuPassStrategy : public PassStrategy { "conv_elementwise_add_fuse_pass", // }); + for (int i = 6; i >= 3; i--) { + passes_.push_back("transpose_flatten" + std::to_string(i) + + "_concat_fuse_pass"); + } use_gpu_ = true; } diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 6975086193d991dc9f53b2d9d988f960c8ad118d..79362f9677010247dffa4fbaa155a7a56eed6f85 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -39,6 +39,7 @@ class ElementwiseWeightOpConverter : public OpConverter { const framework::Scope& scope, bool test_mode) override { // Here the two nullptr looks strange, that's because the // framework::OpDesc's constructor is strange. + nvinfer1::ILayer* layer = nullptr; framework::OpDesc op_desc(op, nullptr); VLOG(3) << "Convert a fluid elementwise op to TensorRT IScaleLayer"; @@ -98,13 +99,21 @@ class ElementwiseWeightOpConverter : public OpConverter { 0}; TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, 0}; + if (op_type_ == "add") { + nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *X, scale_mode, shift_weights.get(), + scale_weights.get(), power_weights.get()); + layer = scale_layer; + } else if (op_type_ == "mul") { + nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *X, scale_mode, scale_weights.get(), + shift_weights.get(), power_weights.get()); + layer = scale_layer; + } - nvinfer1::IScaleLayer* layer = TRT_ENGINE_ADD_LAYER( - engine_, Scale, *const_cast(X), scale_mode, - shift_weights.get(), scale_weights.get(), power_weights.get()); auto output_name = op_desc.Output("Out")[0]; - - layer->setName(("elementwise_add (Output: " + output_name + ")").c_str()); + layer->setName( + ("elementwise_" + op_type_ + "(Output: " + output_name + ")").c_str()); layer->getOutput(0)->setName(output_name.c_str()); engine_->weight_map[op_desc.Input("Y").front()] = std::move(weight_tensor); engine_->SetITensor(output_name, layer->getOutput(0)); @@ -113,6 +122,9 @@ class ElementwiseWeightOpConverter : public OpConverter { engine_->DeclareOutput(output_name); } } + + protected: + std::string op_type_; }; class ElementwiseTensorOpConverter : public OpConverter { @@ -188,6 +200,16 @@ const std::unordered_map {"max", nvinfer1::ElementWiseOperation::kMAX}, }; +class ElementwiseWeightAddOpConverter : public ElementwiseWeightOpConverter { + public: + ElementwiseWeightAddOpConverter() { op_type_ = "add"; } +}; + +class ElementwiseWeightMulOpConverter : public ElementwiseWeightOpConverter { + public: + ElementwiseWeightMulOpConverter() { op_type_ = "mul"; } +}; + class ElementwiseTensorAddOpConverter : public ElementwiseTensorOpConverter { public: ElementwiseTensorAddOpConverter() { op_type_ = "add"; } @@ -227,7 +249,10 @@ class ElementwiseTensorPowOpConverter : public ElementwiseTensorOpConverter { } // namespace inference } // namespace paddle -REGISTER_TRT_OP_CONVERTER(elementwise_add_weight, ElementwiseWeightOpConverter); +REGISTER_TRT_OP_CONVERTER(elementwise_add_weight, + ElementwiseWeightAddOpConverter); +REGISTER_TRT_OP_CONVERTER(elementwise_mul_weight, + ElementwiseWeightMulOpConverter); REGISTER_TRT_OP_CONVERTER(elementwise_add_tensor, ElementwiseTensorAddOpConverter); diff --git a/paddle/fluid/inference/utils/CMakeLists.txt b/paddle/fluid/inference/utils/CMakeLists.txt index cfb80fe6ec11a55a887c7552ec4e6a8a0c6a2fce..c43eaf7f9849ee4a88ed95bdb8b6966da8760435 100644 --- a/paddle/fluid/inference/utils/CMakeLists.txt +++ b/paddle/fluid/inference/utils/CMakeLists.txt @@ -2,6 +2,3 @@ cc_library(benchmark SRCS benchmark.cc DEPS enforce) cc_test(test_benchmark SRCS benchmark_tester.cc DEPS benchmark) cc_binary(visualizer SRCS visualizer.cc DEPS analysis paddle_pass_builder ir_pass_manager pass graph_viz_pass analysis_passes) -if(WIN32) - target_link_libraries(visualizer shlwapi) -endif(WIN32) diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index f5208e7a601f4dd33b486e5840178022f66431e5..25a723fc07948888ef3dc61320fb9bec026390de 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -137,6 +137,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { // ------------------- cudnn conv algorithm --------------------- cudnnConvolutionFwdAlgo_t algo; auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); bool half_float = false; #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) @@ -157,8 +158,6 @@ class CUDNNConvOpKernel : public framework::OpKernel { VLOG(5) << "NOT use cudnn_tensor_op_math"; } #endif - Tensor cudnn_workspace; - void* cudnn_workspace_ptr = nullptr; auto x_dims = framework::vectorize(input->dims()); auto f_dims = framework::vectorize(filter->dims()); @@ -181,26 +180,21 @@ class CUDNNConvOpKernel : public framework::OpKernel { .Var(kCUDNNFwdAlgoCache) ->GetMutable>(); } - cudnn_workspace = - ctx.AllocateTmpTensor( - framework::make_ddim( - {static_cast(workspace_size_limit)}), - dev_ctx); - cudnn_workspace_ptr = static_cast(cudnn_workspace.data()); - algo = algo_cache->GetAlgorithm( x_dims, f_dims, strides, paddings, dilations, 0, [&]() { int returned_algo_count; std::array fwd_perf_stat; - - CUDNN_ENFORCE( - platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( - handle, cudnn_input_desc, input_data, cudnn_filter_desc, - filter_data, cudnn_conv_desc, cudnn_output_desc, - output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, - fwd_perf_stat.data(), cudnn_workspace_ptr, - workspace_size_limit)); + auto cudnn_find_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE( + platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( + handle, cudnn_input_desc, input_data, cudnn_filter_desc, + filter_data, cudnn_conv_desc, cudnn_output_desc, + output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, + fwd_perf_stat.data(), cudnn_workspace, + workspace_size_limit)); + }; + workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit); VLOG(3) << "Perf result: (algo: stat, time, memory)"; for (int i = 0; i < returned_algo_count; ++i) { @@ -225,23 +219,17 @@ class CUDNNConvOpKernel : public framework::OpKernel { PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, "workspace_size to be allocated exceeds the limit"); - // Allocate on GPU memory - if (!cudnn_workspace_ptr) { - cudnn_workspace = - ctx.AllocateTmpTensor( - framework::make_ddim( - {static_cast(workspace_size_in_bytes)}), - dev_ctx); - cudnn_workspace_ptr = static_cast(cudnn_workspace.data()); - } // ------------------- cudnn conv forward --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; for (int i = 0; i < groups; i++) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( - handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, - cudnn_filter_desc, filter_data + i * group_offset_filter, - cudnn_conv_desc, algo, cudnn_workspace_ptr, workspace_size_in_bytes, - &beta, cudnn_output_desc, output_data + i * group_offset_out)); + auto cudnn_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, + cudnn_filter_desc, filter_data + i * group_offset_filter, + cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes, + &beta, cudnn_output_desc, output_data + i * group_offset_out)); + }; + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } }; @@ -365,20 +353,10 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { workspace_size_limit = max_user_size * 1024 * 1024; } - Tensor cudnn_workspace; - void* cudnn_workspace_ptr = nullptr; - if ((input_data || filter_data) && exhaustive_search) { - cudnn_workspace = - ctx.AllocateTmpTensor( - framework::make_ddim( - {static_cast(workspace_size_limit)}), - dev_ctx); - cudnn_workspace_ptr = static_cast(cudnn_workspace.data()); - } - auto x_dims = framework::vectorize(input->dims()); auto f_dims = framework::vectorize(filter->dims()); auto handle = dev_ctx.cudnn_handle(); + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); if (exhaustive_search) { @@ -396,22 +374,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { ->GetMutable< AlgorithmsCache>(); } - data_algo = data_algo_cache->GetAlgorithm( x_dims, f_dims, strides, paddings, dilations, 0, [&]() { int returned_algo_count; std::array data_perf_stat; - - CUDNN_ENFORCE(platform::dynload:: - cudnnFindConvolutionBackwardDataAlgorithmEx( - handle, cudnn_filter_desc, filter_data, - cudnn_output_grad_desc, output_grad_data, - cudnn_conv_desc, cudnn_input_desc, - input_grad_data, kNUM_CUDNN_BWD_DATA_ALGS, - &returned_algo_count, data_perf_stat.data(), - cudnn_workspace_ptr, workspace_size_limit)); + auto cudnn_find_bd_data_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE( + platform::dynload:: + cudnnFindConvolutionBackwardDataAlgorithmEx( + handle, cudnn_filter_desc, filter_data, + cudnn_output_grad_desc, output_grad_data, + cudnn_conv_desc, cudnn_input_desc, input_grad_data, + kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count, + data_perf_stat.data(), cudnn_workspace, + workspace_size_limit)); + }; + workspace_handle.RunFunc(cudnn_find_bd_data_func, + workspace_size_limit); VLOG(3) << "Perf result: (algo: stat, time, memory)"; for (int i = 0; i < returned_algo_count; ++i) { @@ -462,23 +443,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { ->GetMutable< AlgorithmsCache>(); } - filter_algo = f_algo_cache->GetAlgorithm( x_dims, f_dims, strides, paddings, dilations, 0, [&]() { int returned_algo_count; std::array filter_perf_stat; - - CUDNN_ENFORCE( - platform::dynload:: - cudnnFindConvolutionBackwardFilterAlgorithmEx( - handle, cudnn_input_desc, input_data, - cudnn_output_grad_desc, output_grad_data, - cudnn_conv_desc, cudnn_filter_desc, filter_grad_data, - kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count, - filter_perf_stat.data(), cudnn_workspace_ptr, - workspace_size_limit)); + auto cudnn_find_bd_f_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE( + platform::dynload:: + cudnnFindConvolutionBackwardFilterAlgorithmEx( + handle, cudnn_input_desc, input_data, + cudnn_output_grad_desc, output_grad_data, + cudnn_conv_desc, cudnn_filter_desc, + filter_grad_data, kNUM_CUDNN_BWD_FILTER_ALGS, + &returned_algo_count, filter_perf_stat.data(), + cudnn_workspace, workspace_size_limit)); + }; + workspace_handle.RunFunc(cudnn_find_bd_f_func, + workspace_size_limit); return filter_perf_stat[0].algo; }); VLOG(3) << "cuDNN backward filter algo " << filter_algo; @@ -499,16 +482,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size); } - // ------------------- cudnn conv workspace --------------------- - if (!cudnn_workspace_ptr) { - cudnn_workspace = - ctx.AllocateTmpTensor( - framework::make_ddim( - {static_cast(workspace_size_in_bytes)}), - dev_ctx); - cudnn_workspace_ptr = static_cast(cudnn_workspace.data()); - } - // ------------------- cudnn conv backward data --------------------- ScalingParamType alpha = 1.0f, beta = 0.0f; if (input_grad) { @@ -516,12 +489,15 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { // Because beta is zero, it is unnecessary to reset input_grad. for (int i = 0; i < groups; i++) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( - handle, &alpha, cudnn_filter_desc, - filter_data + i * group_offset_filter, cudnn_output_grad_desc, - output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo, - cudnn_workspace_ptr, workspace_size_in_bytes, &beta, - cudnn_input_desc, input_grad_data + i * group_offset_in)); + auto cudnn_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, + filter_data + i * group_offset_filter, cudnn_output_grad_desc, + output_grad_data + i * group_offset_out, cudnn_conv_desc, + data_algo, cudnn_workspace, workspace_size_in_bytes, &beta, + cudnn_input_desc, input_grad_data + i * group_offset_in)); + }; + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } // ------------------- cudnn conv backward filter --------------------- @@ -529,12 +505,15 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset filter_grad. for (int i = 0; i < groups; i++) { - CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( - handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, - cudnn_output_grad_desc, output_grad_data + i * group_offset_out, - cudnn_conv_desc, filter_algo, cudnn_workspace_ptr, - workspace_size_in_bytes, &beta, cudnn_filter_desc, - filter_grad_data + i * group_offset_filter)); + auto cudnn_func = [&](void* cudnn_workspace) { + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_input_desc, + input_data + i * group_offset_in, cudnn_output_grad_desc, + output_grad_data + i * group_offset_out, cudnn_conv_desc, + filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta, + cudnn_filter_desc, filter_grad_data + i * group_offset_filter)); + }; + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); } } } diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d5bc25d19cba4de6f059612e3e8c4a65b2edd0f9 --- /dev/null +++ b/paddle/fluid/operators/data_norm_op.cc @@ -0,0 +1,409 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/data_norm_op.h" +#include +#include "paddle/fluid/framework/data_layout.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using DataLayout = framework::DataLayout; + +template +using EigenArrayMap = + Eigen::Map>; +template +using ConstEigenArrayMap = + Eigen::Map>; +template +using EigenVectorArrayMap = Eigen::Map>; +template +using ConstEigenVectorArrayMap = + Eigen::Map>; + +class DataNormOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), ""); + PADDLE_ENFORCE(ctx->HasInput("BatchSize"), ""); + PADDLE_ENFORCE(ctx->HasInput("BatchSum"), ""); + PADDLE_ENFORCE(ctx->HasInput("BatchSquareSum"), ""); + PADDLE_ENFORCE(ctx->HasOutput("Means"), ""); + PADDLE_ENFORCE(ctx->HasOutput("Scales"), ""); + PADDLE_ENFORCE(ctx->HasOutput("Y"), ""); + + const auto x_dims = ctx->GetInputDim("X"); + const DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); + + PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, + "Input X must have 2 to 5 dimensions."); + + const int64_t C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + + PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize").size(), 1UL); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum").size(), 1UL); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum").size(), 1UL); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C); + + ctx->SetOutputDim("Y", x_dims); + ctx->SetOutputDim("Means", {C}); + ctx->SetOutputDim("Scales", {C}); + ctx->ShareLoD("X", "Y"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = ctx.Input("X")->type(); + // By default, the type of the scale, bias, mean, + // and var tensors should both be float. (For float or float16 input tensor) + // or double (For double input tensor). + auto dn_param_type = framework::proto::VarType::FP32; + if (input_data_type == framework::proto::VarType::FP64) { + dn_param_type = framework::proto::VarType::FP64; + } + PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input("BatchSize")->type(), + "BatchSize input should be of float type"); + PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input("BatchSum")->type(), + "BatchSum input should be of float type"); + PADDLE_ENFORCE_EQ(dn_param_type, + ctx.Input("BatchSquareSum")->type(), + "BatchSquareSum input should be of float type"); + + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; +#ifdef PADDLE_WITH_MKLDNN + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } +#endif + + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, + library); + } +}; + +class DataNormOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + // AddAttr("is_test", "").SetDefault(false); + AddAttr("epsilon", "") + .SetDefault(1e-4) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f, + "'epsilon' should be between 0.0 and 0.001."); + }); + AddAttr("data_layout", "").SetDefault("NCHW"); + AddInput("X", "The input tensor"); + AddInput("BatchSize", + "BatchSize is a 1-dimensional tensor of size C " + "that is applied to the output"); + AddInput("BatchSum", + "BatchSum is a 1-dimensional tensor of size C " + "that is applied to the output"); + AddInput("BatchSquareSum", + "The global BatchSquareSum (for training) or " + "estimated BatchSquareSum (for testing)"); + AddOutput("Y", "result after normalization"); + AddOutput("Means", + "Mean of the history data batch, " + "will apply to output when training") + .AsIntermediate(); + AddOutput("Scales", + "Scales of the history data batch, " + "will apply to output when training") + .AsIntermediate(); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddComment(R"DOC( +Data Normalization. + +Can be used as a normalizer function for data +The required data format for this layer is one of the following: +1. NHWC `[batch, in_height, in_width, in_channels]` +2. NCHW `[batch, in_channels, in_height, in_width]` + +)DOC"); + } +}; + +template +class DataNormKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + // const bool is_test = ctx.Attr("is_test"); + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + + const auto *x = ctx.Input("X"); + const auto &x_dims = x->dims(); + PADDLE_ENFORCE(x_dims.size() == 2, "The Input dim size should be 2"); + const int N = x_dims[0]; + const int C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + auto *y = ctx.Output("Y"); + auto *mean_out = ctx.Output("Means"); + auto *scales = ctx.Output("Scales"); + + // alloc memory + y->mutable_data(ctx.GetPlace()); + + Eigen::Array inv_std(C); + ConstEigenVectorArrayMap b_size_arr( + ctx.Input("BatchSize")->data(), C); + ConstEigenVectorArrayMap b_sum_arr( + ctx.Input("BatchSum")->data(), C); + ConstEigenVectorArrayMap b_square_sum_arr( + ctx.Input("BatchSquareSum")->data(), C); + EigenVectorArrayMap means_arr(mean_out->mutable_data(ctx.GetPlace()), + C); + EigenVectorArrayMap scales_arr(scales->mutable_data(ctx.GetPlace()), + C); + means_arr = b_sum_arr / b_size_arr; + scales_arr = (b_size_arr / b_square_sum_arr).sqrt(); + + switch (data_layout) { + case DataLayout::kNCHW: // because it's two dimensions, so make no + // difference + case DataLayout::kNHWC: { + EigenArrayMap(y->mutable_data(ctx.GetPlace()), C, N) = + (ConstEigenArrayMap(x->data(), C, N).colwise() - means_arr) + .colwise() * + scales_arr; + break; + } + default: + PADDLE_THROW("Unknown storage order: %d", data_layout); + } + } +}; + +class DataNormGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + // check input + PADDLE_ENFORCE(ctx->HasInput("X")); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), ""); + PADDLE_ENFORCE(ctx->HasInput("BatchSize"), ""); + PADDLE_ENFORCE(ctx->HasInput("BatchSum"), ""); + PADDLE_ENFORCE(ctx->HasInput("BatchSquareSum"), ""); + PADDLE_ENFORCE(ctx->HasInput("Means"), ""); + PADDLE_ENFORCE(ctx->HasInput("Scales"), ""); + + // check output + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), ""); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSize")), ""); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSum")), ""); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("BatchSquareSum")), + ""); + + const auto x_dims = ctx->GetInputDim("X"); + const DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); + const int C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->SetOutputDim(framework::GradVarName("BatchSize"), {C}); + ctx->SetOutputDim(framework::GradVarName("BatchSum"), {C}); + ctx->SetOutputDim(framework::GradVarName("BatchSquareSum"), {C}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + const auto *var = ctx.InputVar(framework::GradVarName("Y")); + if (var == nullptr) { + PADDLE_THROW("can't find Y@GRAD"); + } + const Tensor *t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } + if (t == nullptr) { + PADDLE_THROW("can't find Y@GRAD"); + } + + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + +#ifdef PADDLE_WITH_MKLDNN + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } +#endif + + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace(), layout, library); + } +}; + +template +class DataNormGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *x = ctx.Input("X"); + const auto *d_y = ctx.Input(framework::GradVarName("Y")); + const auto *batch_size = ctx.Input("BatchSize"); + const auto *batch_sum = ctx.Input("BatchSum"); + const auto *batch_square_sum = ctx.Input("BatchSquareSum"); + const auto *scales = ctx.Input("Scales"); + const auto *means = ctx.Input("Means"); + + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + + // Get the size for each dimension. + // NCHW [batch_size, in_channels, in_height, in_width] + const auto &x_dims = x->dims(); + PADDLE_ENFORCE(x_dims.size() == 2, "The Input dim size should be 2"); + const int N = x_dims[0]; + const int C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + + // init output + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_batch_size = + ctx.Output(framework::GradVarName("BatchSize")); + auto *d_batch_sum = ctx.Output(framework::GradVarName("BatchSum")); + auto *d_batch_square_sum = + ctx.Output(framework::GradVarName("BatchSquareSum")); + + EigenVectorArrayMap d_batch_size_arr( + d_batch_size->mutable_data(ctx.GetPlace()), C); + EigenVectorArrayMap d_batch_sum_arr( + d_batch_sum->mutable_data(ctx.GetPlace()), C); + EigenVectorArrayMap d_batch_square_sum_arr( + d_batch_square_sum->mutable_data(ctx.GetPlace()), C); + + d_batch_size_arr.setZero(); + d_batch_sum_arr.setZero(); + d_batch_square_sum_arr.setZero(); + + const float epsilon = ctx.Attr("epsilon"); + switch ( + data_layout) { // because it's two dimensions, so make no difference + case DataLayout::kNCHW: + case DataLayout::kNHWC: { + ConstEigenVectorArrayMap scales_arr(scales->data(), C); + ConstEigenVectorArrayMap means_arr(means->data(), C); + ConstEigenArrayMap x_arr(x->data(), C, N); + ConstEigenArrayMap d_y_arr(d_y->data(), C, N); + EigenArrayMap d_x_arr(d_x->mutable_data(ctx.GetPlace()), C, N); + d_x_arr.setZero(); + for (int nc = 0; nc < N; ++nc) { + d_x_arr.col(nc) = d_y_arr.col(nc) * scales_arr; + } + + // calculate data sum and squre sum + ConstEigenVectorArrayMap batch_size_arr(batch_size->data(), C); + ConstEigenVectorArrayMap batch_sum_arr(batch_sum->data(), C); + ConstEigenVectorArrayMap batch_square_sum_arr( + batch_square_sum->data(), C); + Eigen::Array sample_sum(C); + Eigen::Array sample_square_sum(C); + // calculate data sample sum and square sum + sample_sum.setZero(); + sample_square_sum.setZero(); + for (int nc = 0; nc < N; ++nc) { + sample_sum += x_arr.col(nc); + sample_square_sum += (x_arr.col(nc) - means_arr).square(); + } + // calculate gradient + d_batch_size_arr.setConstant(N); + d_batch_sum_arr = sample_sum; + d_batch_square_sum_arr = sample_square_sum + d_batch_size_arr * epsilon; + break; + } + default: + PADDLE_THROW("Unknown storage order: %s", data_layout_str); + } + } +}; + +class DataNormGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *op = new framework::OpDesc(); + op->SetType("data_norm_grad"); + op->SetInput("X", Input("X")); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + + op->SetInput("BatchSize", Input("BatchSize")); + op->SetInput("BatchSum", Input("BatchSum")); + op->SetInput("BatchSquareSum", Input("BatchSquareSum")); + op->SetInput("Scales", Output("Scales")); + op->SetInput("Means", Output("Means")); + + op->SetAttrMap(Attrs()); + + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetOutput(framework::GradVarName("BatchSize"), InputGrad("BatchSize")); + op->SetOutput(framework::GradVarName("BatchSum"), InputGrad("BatchSum")); + op->SetOutput(framework::GradVarName("BatchSquareSum"), + InputGrad("BatchSquareSum")); + + return std::unique_ptr(op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(data_norm, ops::DataNormOp, ops::DataNormOpMaker, + ops::DataNormGradMaker); +REGISTER_OPERATOR(data_norm_grad, ops::DataNormGradOp); + +REGISTER_OP_CPU_KERNEL( + data_norm, ops::DataNormKernel, + ops::DataNormKernel); +REGISTER_OP_CPU_KERNEL( + data_norm_grad, + ops::DataNormGradKernel, + ops::DataNormGradKernel); diff --git a/paddle/fluid/operators/data_norm_op.h b/paddle/fluid/operators/data_norm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..63451214bcf649d0a7a949f391db9b651d237d22 --- /dev/null +++ b/paddle/fluid/operators/data_norm_op.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class DataNormKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override; +}; + +template +class DataNormGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index bde2791add4075be6949703dfbea634966d25c1c..4b4ce07fa78b97e636173566fa104cb8a18c914e 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -52,11 +52,11 @@ struct BenchFunc { for (int i = 0; i < FLAGS_burning; ++i) { tgt(args...); } - auto start = paddle::platform::PosixInNsec() / 1e-3; + auto start = paddle::platform::PosixInNsec() * 1e-3; for (int i = 0; i < FLAGS_repeat; ++i) { tgt(args...); } - auto end = paddle::platform::PosixInNsec() / 1e-3; + auto end = paddle::platform::PosixInNsec() * 1e-3; return static_cast(end - start) / FLAGS_repeat; } }; diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c8ee13875c5ae772de3c09f97fded8f70c5698e6 --- /dev/null +++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc @@ -0,0 +1,162 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/teacher_student_sigmoid_loss_op.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, + "Input(Label)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], + "The 1st dimension of Input(X) and Input(Label) should " + "be equal."); + PADDLE_ENFORCE_EQ(label_dims[1], 1UL, + "The 2nd dimension of " + "Input(Label) should be 1."); + ctx->SetOutputDim("Y", {x_dims[0], 1}); + ctx->ShareLoD("X", /*->*/ "Y"); + } + + protected: + // Explicitly set that the data type of computation kernel of + // teacher_student_sigmoid_loss + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class TeacherStudentSigmoidLossGradientOp + : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2."); + PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); + PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], + "The 1st dimension of Input(X) and Input(Label) should " + "be equal."); + PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0], + "The 1st dimension of Input(X) and Input(Y@Grad) should " + "be equal."); + PADDLE_ENFORCE_EQ(dy_dims[1], 1, + "The 2nd dimension of Input(Y@Grad) should be 1."); + PADDLE_ENFORCE_EQ(label_dims[1], 1, + "When Attr(soft_label) == false, the 2nd dimension of " + "Input(Label) should be 1."); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD("X", framework::GradVarName("X")); + } + + protected: + // Explicitly set that the data type of computation kernel of + // teacher_student_sigmoid_loss + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class TeacherStudentSigmoidLossOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), a 2-D tensor with shape [N x 1]," + " where N is the batch size and D is the output. " + "This input is a probability computed by the previous operator, " + "which is almost always the result of a softmax operator."); + AddInput("Label", + "(Tensor), the ground truth which is a 2-D tensor. " + "Label is a Tensor with shape [N x 1]. "); + AddOutput("Y", + "(Tensor, default Tensor), a 2-D tensor with shape " + "[N x 1]. The teacher student sigmoid loss."); + AddAttr( + "soft_max_up_bound", + "fp32, if input > soft_max_up_bound, will be bound, default 15.0") + .SetDefault(15.0); + AddAttr( + "soft_max_lower_bound", + "fp32, if input < soft_max_lower_bound, will be bound, default -15.0") + .SetDefault(-15.0); + AddComment(R"DOC( +TeacherStudentSigmoidLoss Operator. + +It's similarity to SigmoidCrossEntropyWithLogits Operator. The difference is that +we add another label(z') to original. + loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + max(x, 0) - x * z' + log(1 + exp(-abs(x))) + z is click or not + z' is teacher value + label = {-2, -1, [0, 2]} + when z' is not exist, clk = 0 : label = -2; + when z' is not exist, clk = 1 : label = -1; + when z' is exist , clk = 0 : label = 0 + z'; + when z' is exist , clk = 1 : label = 1 + z'; + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(teacher_student_sigmoid_loss, + ops::TeacherStudentSigmoidLossOp, + ops::TeacherStudentSigmoidLossOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OPERATOR(teacher_student_sigmoid_loss_grad, + ops::TeacherStudentSigmoidLossGradientOp); + +REGISTER_OP_CPU_KERNEL(teacher_student_sigmoid_loss, + ops::TeacherStudentSigmoidLossOpKernel, + ops::TeacherStudentSigmoidLossOpKernel); + +REGISTER_OP_CPU_KERNEL(teacher_student_sigmoid_loss_grad, + ops::TeacherStudentSigmoidLossGradOpKernel, + ops::TeacherStudentSigmoidLossGradOpKernel); diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.h b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..41d2662ae2a4d37222323d6a536ed3af1ab7e056 --- /dev/null +++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.h @@ -0,0 +1,118 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +class TeacherStudentSigmoidLossOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + Tensor* y = context.Output("Y"); + const Tensor* x = context.Input("X"); + const Tensor* labels = context.Input("Label"); + T* y_data = y->mutable_data(context.GetPlace()); + const T* x_data = x->data(); + const T* label_data = labels->data(); + int64_t batch_size = x->dims()[0]; + // loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + max(x, 0) - x * z' + + // log(1 + exp(-abs(x))) + // z is click or not + // z' is value q of feed_fine + // label = {-2, -1, [0, 2]} + // when z' is not exist, clk = 0 : label = -2; + // when z' is not exist, clk = 1 : label = -1; + // when z' is exist , clk = 0 : label = 0 + z'; + // when z' is exist , clk = 1 : label = 1 + z'; + for (int i = 0; i < batch_size; ++i) { + if (label_data[i] < -1.0) { + y_data[i] = (x_data[i] > 0 ? x_data[i] : 0.0) + + log(1.0 + exp(-fabs(x_data[i]))); + } else if (label_data[i] < 0.0) { + y_data[i] = (x_data[i] > 0 ? x_data[i] : 0.0) - x_data[i] + + log(1.0 + exp(-fabs(x_data[i]))); + } else if (label_data[i] < 1.0) { + y_data[i] = (x_data[i] > 0 ? x_data[i] : 0.0) + + log(1.0 + exp(-fabs(x_data[i]))) + + (x_data[i] > 0 ? x_data[i] : 0.0) - + x_data[i] * label_data[i] + + log(1.0 + exp(-fabs(x_data[i]))); + } else { + y_data[i] = (x_data[i] > 0 ? x_data[i] : 0.0) - x_data[i] + + log(1.0 + exp(-fabs(x_data[i]))) + + (x_data[i] > 0 ? x_data[i] : 0.0) - + x_data[i] * (label_data[i] - 1.0) + + log(1.0 + exp(-fabs(x_data[i]))); + } + } + } +}; + +template +class TeacherStudentSigmoidLossGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input("X"); + const T* x_data = x->data(); + + Tensor* dx = context.Output(framework::GradVarName("X")); + T* dx_data = dx->mutable_data(context.GetPlace()); + + const Tensor* labels = context.Input("Label"); + const T* label_data = labels->data(); + + T soft_max_up_bound = + static_cast(context.Attr("soft_max_up_bound")); + T soft_max_lower_bound = + static_cast(context.Attr("soft_max_lower_bound")); + + int64_t batch_size = x->dims()[0]; + + const framework::Tensor* dOut = + context.Input(framework::GradVarName("Y")); + + const T* dout_data = dOut->data(); + + for (int i = 0; i < batch_size; ++i) { + T sum_val = x_data[i]; + if (sum_val > soft_max_up_bound) { + sum_val = soft_max_up_bound; + } else { + if (sum_val < soft_max_lower_bound) { + sum_val = soft_max_lower_bound; + } + } + + T pred = 1.0 / (1.0 + exp(-sum_val)); + if (label_data[i] < -1.0) { + dx_data[i] = 0.0 - pred; + } else if (label_data[i] < 0.0) { + dx_data[i] = 1.0 - pred; + } else { + dx_data[i] = label_data[i] - 2.0 * pred; + } + if (sum_val >= soft_max_up_bound || sum_val <= soft_max_lower_bound) { + dx_data[i] = 0; + } + dx_data[i] *= dout_data[i] * -1; + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/cuda_helper_test.cu b/paddle/fluid/platform/cuda_helper_test.cu index 466bf90c63c1496883995819cdcb19f846e4a302..9e3025bf30b8849472e33a71228eb16814157b21 100644 --- a/paddle/fluid/platform/cuda_helper_test.cu +++ b/paddle/fluid/platform/cuda_helper_test.cu @@ -15,6 +15,9 @@ #include #include #include +#ifdef _WIN32 +#include +#endif #include #define PADDLE_CUDA_FP16 diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index d376f90ad5754d70f3b9f30957eb2e2f584f8da9..c81d17380cf894631d06588c007c2e11ce5c7836 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -61,7 +61,7 @@ namespace platform { * the allocations of temp_allocation_queue: * - when the Stream calls cudaStreamSynchronize; * - when the allocation size of opportunities exceeds a certain threshold - * (defined by FLAGS_limit_of_tmp_allocation). + * (defined by FLAGS_limit_of_temporary_allocation). * * */ class DeviceTemporaryAllocator { diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index 98afe843c0035ec14ad874508dc02b8d1d3d359c..c203f4e04a28452807a42bbdaf75e89977772a04 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -59,7 +59,7 @@ limitations under the License. */ #if !defined(_WIN32) #define PADDLE_ALIGN(x) __attribute__((aligned(x))) #else -#define PADDLE_ALIGN(x) /*do nothing*/ +#define PADDLE_ALIGN(x) __declspec(align(x)) #endif namespace paddle { diff --git a/paddle/fluid/platform/float16_test.cu b/paddle/fluid/platform/float16_test.cu index b1b51d804e02f233bcd16149005092dc80e9c79d..14cad927f06551ebbfbf1d166ae250c18591dd6b 100644 --- a/paddle/fluid/platform/float16_test.cu +++ b/paddle/fluid/platform/float16_test.cu @@ -271,11 +271,13 @@ TEST(float16, isinf) { float16 b = float16(INFINITY); // underflow to 0 float16 native_a(5e-40f); - // overflow to inf - float16 native_b(5e40f); EXPECT_EQ(std::isinf(a), true); EXPECT_EQ(std::isinf(b), true); +#ifndef _WIN32 + // overflow to inf + float16 native_b(5e40f); EXPECT_EQ(std::isinf(native_b), true); +#endif EXPECT_EQ(native_a, float16(0)); } diff --git a/paddle/fluid/platform/temporary_allocator.cc b/paddle/fluid/platform/temporary_allocator.cc index 9cbdfe46e78dc84e58eae6929c887221d9562c69..0be017f75bcc8aff5073ebb2c5179cf7250be8b9 100644 --- a/paddle/fluid/platform/temporary_allocator.cc +++ b/paddle/fluid/platform/temporary_allocator.cc @@ -15,15 +15,8 @@ #include "paddle/fluid/platform/temporary_allocator.h" #include "paddle/fluid/memory/allocation/allocator_facade.h" -DEFINE_int64(limit_of_tmp_allocation, -1, - "The up limit of temporary_allocation size."); -DEFINE_double(times_excess_than_required_tmp_allocation, 2, - "times_excess_than_required_tmp_allocation indicates the " - "max size the TemporaryAllocator can return. For example, " - "if the required memory size is N, and " - "times_excess_than_required_tmp_allocation is 2.0, " - "the TemporaryAllocator will return the available allocation " - "that the range of size is N ~ 2*N."); +DEFINE_double(limit_of_temporary_allocation, -1, + "The up limit of temporary_allocation size."); namespace paddle { namespace platform { @@ -36,25 +29,24 @@ TemporaryAllocation::TemporaryAllocation( underlying_allocation_(std::move(underlying_allocation)) {} TemporaryAllocator::TemporaryAllocator(platform::Place place) : place_(place) { - temp_mem_map_.reset(new std::multimap()); + temp_mem_queue_.reset(new std::deque()); } bool TemporaryAllocator::IsAllocThreadSafe() const { return true; } void TemporaryAllocator::Release(const std::function &callback) { - std::unique_ptr> t_allocations; + std::shared_ptr> t_allocations; { std::unique_lock lock(mtx_); callback(); - t_allocations.swap(temp_mem_map_); - temp_mem_map_.reset(new std::multimap()); + t_allocations = temp_mem_queue_; + temp_mem_queue_.reset(new std::deque()); wait_delete_mem_ = 0; } - for (auto tmp : *t_allocations) { - VLOG(10) << "Delete temporary allocation " << tmp.second->ptr() - << " size: " << tmp.second->size(); - delete tmp.second; + VLOG(10) << "Delete temporary allocation " << tmp->ptr() + << " size: " << tmp->size(); + delete tmp; } } @@ -62,34 +54,28 @@ void TemporaryAllocator::Free(alloc::Allocation *allocation) { auto *temp_allocation = dynamic_cast(allocation); PADDLE_ENFORCE_NOT_NULL(temp_allocation); if (platform::is_gpu_place(temp_allocation->place())) { - PADDLE_ENFORCE(platform::is_same_place(temp_allocation->place(), place_), - "The place should be the same."); size_t wait_delete_mem = 0; { std::unique_lock lock(mtx_); - temp_mem_map_->emplace(temp_allocation->size(), temp_allocation); + temp_mem_queue_->emplace_back(temp_allocation); wait_delete_mem_ += temp_allocation->size(); wait_delete_mem = wait_delete_mem_; VLOG(10) << "Move temporary allocation: " << temp_allocation->ptr() << " to delete queue: " << temp_allocation->size() << "; " - << "wait_delete_mem: " << wait_delete_mem; + << "wait_delete_mem: " << wait_delete_mem_; } - - if (FLAGS_limit_of_tmp_allocation > 0 && - wait_delete_mem > static_cast(FLAGS_limit_of_tmp_allocation)) { - PADDLE_ENFORCE(callback_ != nullptr, "The callback is non-initialized."); + if (FLAGS_limit_of_temporary_allocation > 0 && + wait_delete_mem > FLAGS_limit_of_temporary_allocation) { Release(callback_); } return; } - VLOG(10) << "Delete temporary allocation " << temp_allocation->ptr() - << " size: " << temp_allocation->size(); delete temp_allocation; } size_t TemporaryAllocator::TemporaryAllocationQueueSize() { std::unique_lock lock(mtx_); - return temp_mem_map_ ? temp_mem_map_->size() : 0; + return temp_mem_queue_ ? temp_mem_queue_->size() : 0; } void TemporaryAllocator::SetCallback(const std::function &callback) { @@ -98,27 +84,6 @@ void TemporaryAllocator::SetCallback(const std::function &callback) { alloc::Allocation *TemporaryAllocator::AllocateImpl( size_t size, alloc::Allocator::Attr attr) { - { - // Find available allocation in temp_mem_map. - std::unique_lock lock(mtx_); - if (temp_mem_map_->size()) { - auto it = temp_mem_map_->lower_bound(size); - // FIXME(zcd): Not sure the best value of excess fraction. - if (it != temp_mem_map_->end() && - it->first < - static_cast( - size * FLAGS_times_excess_than_required_tmp_allocation)) { - auto tmp_ptr = it->second; - temp_mem_map_->erase(it); - wait_delete_mem_ -= tmp_ptr->size(); - VLOG(10) << "Reuse temporary allocation: " << tmp_ptr->ptr() << ": " - << tmp_ptr->size(); - return tmp_ptr; - } - } - } - // If not find the the available allocation, get allocation from - // AllocatorFacadeInstance. auto raw_allocation = alloc::AllocatorFacade::Instance().Alloc(place_, size, attr); auto temp_mem = new TemporaryAllocation(std::move(raw_allocation)); diff --git a/paddle/fluid/platform/temporary_allocator.h b/paddle/fluid/platform/temporary_allocator.h index d657a14223326aa1e2cb5b154a10a56ae742f95c..812c4a333189d8c432be398ca0ebbce11f957561 100644 --- a/paddle/fluid/platform/temporary_allocator.h +++ b/paddle/fluid/platform/temporary_allocator.h @@ -15,7 +15,6 @@ #pragma once #include // NOLINT #include -#include #include // NOLINT #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/platform/lock_guard_ptr.h" @@ -40,7 +39,7 @@ class TemporaryAllocation : public memory::allocation::Allocation { * * There is one opportunity to free the allocations of temp_allocation_queue: * - when the allocation size of opportunities exceeds a certain threshold - * (defined by FLAGS_limit_of_tmp_allocation). + * (defined by FLAGS_limit_of_temporary_allocation). * * */ class TemporaryAllocator : public memory::allocation::Allocator { @@ -63,10 +62,11 @@ class TemporaryAllocator : public memory::allocation::Allocator { private: platform::Place place_; + // When the allocation is not held by any variable, it should be placed - // to temp_mem_map immediately. - std::unique_ptr> temp_mem_map_{ - nullptr}; + // to temp_mem_queue immediately. + std::shared_ptr> temp_mem_queue_{nullptr}; + std::mutex mtx_; size_t wait_delete_mem_{0}; std::function callback_; diff --git a/paddle/fluid/platform/temporary_allocator_test.cc b/paddle/fluid/platform/temporary_allocator_test.cc index 3879cd540017ea22b0cf4eee794a172e56716b74..35d1d929819c41b213bc51ec24ac725021a76c88 100644 --- a/paddle/fluid/platform/temporary_allocator_test.cc +++ b/paddle/fluid/platform/temporary_allocator_test.cc @@ -18,8 +18,7 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor_util.h" -DECLARE_int64(limit_of_tmp_allocation); -DECLARE_double(times_excess_than_required_tmp_allocation); +DECLARE_double(limit_of_temporary_allocation); namespace paddle { namespace platform { @@ -36,7 +35,7 @@ class DummyOp : public framework::OperatorBase { const platform::Place& place) const override {} }; -TEST(temporary_allocator, test_base_function) { +TEST(temporary_allocator, temporary_allocator) { platform::CPUPlace cpu_place; TemporaryAllocator alloc(cpu_place); alloc.Allocate(100); @@ -60,10 +59,10 @@ TEST(temporary_allocator, test_base_function) { #endif } -TEST(temporary_allocator, test_flags_function) { +TEST(temporary_allocator, add_callback) { #ifdef PADDLE_WITH_CUDA - const int64_t limit = FLAGS_limit_of_tmp_allocation; - FLAGS_limit_of_tmp_allocation = 10; + const double limit = FLAGS_limit_of_temporary_allocation; + FLAGS_limit_of_temporary_allocation = 10; platform::CUDAPlace gpu_place(0); TemporaryAllocator gpu_alloc(gpu_place); @@ -79,52 +78,7 @@ TEST(temporary_allocator, test_flags_function) { }); { gpu_alloc.Allocate(100); } PADDLE_ENFORCE(deleted); - FLAGS_limit_of_tmp_allocation = limit; -#endif -} - -TEST(temporary_allocator, test_reuse_tmp_allocation) { -#ifdef PADDLE_WITH_CUDA - platform::CUDAPlace gpu_place(0); - TemporaryAllocator gpu_alloc(gpu_place); - gpu_alloc.SetCallback([]() {}); - - void* tmp_allocation_ptr1 = nullptr; - { - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - auto tmp_allocation1 = gpu_alloc.Allocate(100); - tmp_allocation_ptr1 = tmp_allocation1->ptr(); - } - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1); - auto tmp_allocation2 = gpu_alloc.Allocate(100); - void* tmp_allocation_ptr2 = tmp_allocation2->ptr(); - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2); - - auto tmp_allocation3 = gpu_alloc.Allocate(100); - void* tmp_allocation_ptr3 = tmp_allocation2->ptr(); - PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr3); -#endif -} - -TEST(temporary_allocator, test_times_excess_than_required_tmp_allocation) { -#ifdef PADDLE_WITH_CUDA - platform::CUDAPlace gpu_place(0); - TemporaryAllocator gpu_alloc(gpu_place); - gpu_alloc.SetCallback([]() {}); - double excess_fraction = FLAGS_times_excess_than_required_tmp_allocation; - void* tmp_allocation_ptr1 = nullptr; - { - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - auto tmp_allocation1 = - gpu_alloc.Allocate(static_cast(100 * excess_fraction - 1)); - tmp_allocation_ptr1 = tmp_allocation1->ptr(); - } - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1); - auto tmp_allocation2 = gpu_alloc.Allocate(100); - void* tmp_allocation_ptr2 = tmp_allocation2->ptr(); - PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0); - PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2); + FLAGS_limit_of_temporary_allocation = limit; #endif } diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 7388a2d0e94bfe805cbc339fb171ab16611b107b..9a91ea38caef50b9a7ad970a3d08ca28c497e419 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -22,9 +22,8 @@ if(WITH_PYTHON) endif(NOT APPLE AND NOT ANDROID AND NOT WIN32) endif(WITH_AMD_GPU) - if(WIN32) - target_link_libraries(paddle_pybind shlwapi) - endif(WIN32) + get_property (os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) + target_link_libraries(paddle_pybind ${os_dependency_modules}) cc_test(tensor_py_test SRCS tensor_py_test.cc DEPS python) endif(WITH_PYTHON) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 5c1c7478f4dbe4c78f5ac2c19f4eae09abbf1c8b..dbc7843caa0c0a39a32cda6050fa99a3ab4c3e22 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -26,7 +26,9 @@ void BindTracer(pybind11::module *m) { [](imperative::Tracer &self, framework::BlockDesc *root_block) { new (&self) imperative::Tracer(root_block); }) - .def("trace", &imperative::Tracer::Trace); + .def("trace", &imperative::Tracer::Trace) + .def("py_trace", &imperative::Tracer::PyTrace, + pybind11::return_value_policy::take_ownership); } } // namespace pybind diff --git a/paddle/fluid/pybind/imperative.h b/paddle/fluid/pybind/imperative.h index 7a9d3a01ea81f11ac85000c3d0153f20e108789a..f947b743f99d5d4994b1a87f89fd6815357d8125 100644 --- a/paddle/fluid/pybind/imperative.h +++ b/paddle/fluid/pybind/imperative.h @@ -22,7 +22,7 @@ limitations under the License. */ namespace paddle { namespace pybind { -class PyLayer : public imperative::Layer { +class Layer : public imperative::Layer { public: using imperative::Layer::Layer; // Inherit constructors @@ -31,10 +31,6 @@ class PyLayer : public imperative::Layer { PYBIND11_OVERLOAD(std::vector, Layer, Forward, inputs); // NOLINT } - - void Backward() override { - PYBIND11_OVERLOAD(void, Layer, Backward, ); // NOLINT - } }; class PyOpBase : public imperative::OpBase { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a15a5b8ea3c502a9320ffde45f53b042f940f2af..f3f4854a9efbcf5ab325e7f6aec81135c018dcd5 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -161,16 +161,44 @@ PYBIND11_MODULE(core, m) { self.op_desc_ = op_desc; } }, + py::return_value_policy::reference) + .def_property( + "forward_id", + [](const imperative::OpBase &self) { return self.forward_id_; }, + [](imperative::OpBase &self, int forward_id) { + self.forward_id_ = forward_id; + }, + py::return_value_policy::reference) + .def_property( + "backward_id", + [](const imperative::OpBase &self) { return self.backward_id_; }, + [](imperative::OpBase &self, int backward_id) { + self.backward_id_ = backward_id; + }, py::return_value_policy::reference); - py::class_ layer(m, "Layer"); + py::class_ layer(m, "Layer"); layer.def(py::init<>()) - .def("forward", - [](imperative::Layer &self, - const std::vector &inputs) { - return self.Forward(inputs); - }) - .def("backward", &imperative::Layer::Backward); + .def("forward", [](imperative::Layer &self, + const std::vector &inputs) { + return self.Forward(inputs); + }); + + py::class_(m, "PyLayer") + .def(py::init<>()) + .def_static( + "apply", + [](int func_id, const std::vector &inputs) + -> std::vector { + return imperative::PyLayer::Apply(func_id, inputs); + }, + py::return_value_policy::take_ownership) + .def_static("register_func", + [](int func_id, const py::object &callable) { + imperative::PyLayer::RegisterFunc(func_id, callable); + }) + .def_static("num_funcs", &imperative::PyLayer::NumFuncs); + BindTracer(&m); py::class_(m, "Tensor", py::buffer_protocol()) diff --git a/python/paddle/dataset/mnist.py b/python/paddle/dataset/mnist.py index 38addd0cfd9bd0afde7eefc57f2111b717b7e636..847ca187206f8932e5454ddad881a94910efb55f 100644 --- a/python/paddle/dataset/mnist.py +++ b/python/paddle/dataset/mnist.py @@ -21,10 +21,9 @@ parse training set and test set into paddle reader creators. from __future__ import print_function import paddle.dataset.common -import subprocess +import gzip import numpy -import platform -import tempfile +import struct from six.moves import range __all__ = ['train', 'test', 'convert'] @@ -41,51 +40,47 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432' def reader_creator(image_filename, label_filename, buffer_size): def reader(): - if platform.system() == 'Darwin': - zcat_cmd = 'gzcat' - elif platform.system() == 'Linux': - zcat_cmd = 'zcat' - else: - raise NotImplementedError() - - # According to http://stackoverflow.com/a/38061619/724872, we - # cannot use standard package gzip here. - tmp_image_file = tempfile.TemporaryFile(prefix='paddle_dataset') - m = subprocess.Popen( - [zcat_cmd, image_filename], stdout=tmp_image_file).communicate() - tmp_image_file.seek(16) # skip some magic bytes - - # Python3 will not take stdout as file - tmp_label_file = tempfile.TemporaryFile(prefix='paddle_dataset') - l = subprocess.Popen( - [zcat_cmd, label_filename], stdout=tmp_label_file).communicate() - tmp_label_file.seek(8) # skip some magic bytes - - try: # reader could be break. - while True: - labels = numpy.fromfile( - tmp_label_file, 'ubyte', count=buffer_size).astype("int") - - if labels.size != buffer_size: - break # numpy.fromfile returns empty slice after EOF. - - images = numpy.fromfile( - tmp_image_file, 'ubyte', count=buffer_size * 28 * - 28).reshape((buffer_size, 28 * 28)).astype('float32') - - images = images / 255.0 * 2.0 - 1.0 - - for i in range(buffer_size): - yield images[i, :], int(labels[i]) - finally: - try: - m.terminate() - except: - pass - try: - l.terminate() - except: - pass + with gzip.GzipFile(image_filename, 'rb') as image_file: + img_buf = image_file.read() + with gzip.GzipFile(label_filename, 'rb') as label_file: + lab_buf = label_file.read() + + step_label = 0 + + offset_img = 0 + # read from Big-endian + # get file info from magic byte + # image file : 16B + magic_byte_img = '>IIII' + magic_img, image_num, rows, cols = struct.unpack_from( + magic_byte_img, img_buf, offset_img) + offset_img += struct.calcsize(magic_byte_img) + + offset_lab = 0 + # label file : 8B + magic_byte_lab = '>II' + magic_lab, label_num = struct.unpack_from(magic_byte_lab, + lab_buf, offset_lab) + offset_lab += struct.calcsize(magic_byte_lab) + + while True: + if step_label >= label_num: + break + fmt_label = '>' + str(buffer_size) + 'B' + labels = struct.unpack_from(fmt_label, lab_buf, offset_lab) + offset_lab += struct.calcsize(fmt_label) + step_label += buffer_size + + fmt_images = '>' + str(buffer_size * rows * cols) + 'B' + images_temp = struct.unpack_from(fmt_images, img_buf, + offset_img) + images = numpy.reshape(images_temp, ( + buffer_size, rows * cols)).astype('float32') + offset_img += struct.calcsize(fmt_images) + + images = images / 255.0 * 2.0 - 1.0 + for i in range(buffer_size): + yield images[i, :], int(labels[i]) return reader diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 686550a3c8d7d55f06b03132124621c5d0db342f..2c17716500ababfab3216a5ec47fecca30065ff1 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -155,8 +155,7 @@ def __bootstrap__(): 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', 'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus', - 'sync_nccl_allreduce', 'limit_of_tmp_allocation', - 'times_excess_than_required_tmp_allocation' + 'sync_nccl_allreduce' ] core.init_gflags([sys.argv[0]] + diff --git a/python/paddle/fluid/imperative/layers.py b/python/paddle/fluid/imperative/layers.py index d78d61eb3f02c27ec44806ae52e134068c2cb9be..8027d9ba3bcf4d37f3573bc928faf574dcde1038 100644 --- a/python/paddle/fluid/imperative/layers.py +++ b/python/paddle/fluid/imperative/layers.py @@ -20,10 +20,12 @@ from paddle.fluid import core from paddle.fluid import framework from paddle.fluid.imperative import base -__all__ = ['PyLayer'] +__all__ = ['Layer', 'PyLayer'] -class PyLayer(core.Layer): +class Layer(core.Layer): + """Layers composed of operators.""" + def __init__(self, dtype=core.VarDesc.VarType.FP32, name=None): self._once_built = False self._dtype = dtype @@ -37,8 +39,56 @@ class PyLayer(core.Layer): self._once_built = True outputs = self.forward(*inputs) - return outputs def forward(self, *inputs): raise NotImplementedError + + def backward(self, *inputs): + raise ValueError("Layer shouldn't implement backward") + + +class PyLayer(core.PyLayer): + """Layers composed of user-defined python codes.""" + + def __init__(self): + super(PyLayer, self).__init__() + + @staticmethod + def forward(inputs): + raise NotImplementedError + + @staticmethod + def backward(douts): + raise NotImplementedError + + @classmethod + def __call__(cls, inputs): + tracer = framework._imperative_tracer() + block = framework.default_main_program().current_block() + inputs = [x._ivar for x in inputs] + + if not hasattr(cls, 'forward_id'): + cls.forward_id = core.PyLayer.num_funcs() + 1 + PyLayer.register_func(cls.forward_id, cls.forward) + cls.backward_id = core.PyLayer.num_funcs() + 1 + PyLayer.register_func(cls.backward_id, cls.backward) + + iop = core.OpBase() + iop.forward_id = cls.forward_id + iop.backward_id = cls.backward_id + block.ops.append(iop) + ivars = tracer.py_trace(iop, inputs, False) + # ivars = core.PyLayer.apply(cls.forward, inputs) + ret = [] + for ivar in ivars: + tensor = ivar.value.get_tensor() + py_var = framework.Variable( + block, + type=core.VarDesc.VarType.LOD_TENSOR, + name=None, + shape=tensor.shape(), + dtype=tensor._dtype(), + ivar=ivar) + ret.append(py_var) + return ret diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py index 4f30417e99d21bcb66dacaab0257816c4d77f932..8754e5d4d0c8c829303f1fe9cd39ead36619ac3b 100644 --- a/python/paddle/fluid/imperative/nn.py +++ b/python/paddle/fluid/imperative/nn.py @@ -30,7 +30,7 @@ __all__ = [ ] -class Conv2D(layers.PyLayer): +class Conv2D(layers.Layer): def __init__(self, num_channels, num_filters, @@ -143,7 +143,7 @@ class Conv2D(layers.PyLayer): return self._helper.append_activation(pre_act) -class Pool2D(layers.PyLayer): +class Pool2D(layers.Layer): def __init__(self, pool_size=-1, pool_type="max", @@ -205,7 +205,7 @@ class Pool2D(layers.PyLayer): return pool_out -class FC(layers.PyLayer): +class FC(layers.Layer): def __init__(self, size, param_attr=None, diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 615a35ba916f813399dc21a87646884b3d01081e..a4787e769f62ebbefd3ea6b70b402e660c02b576 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -58,6 +58,7 @@ __all__ = [ 'adaptive_pool2d', 'adaptive_pool3d', 'batch_norm', + 'data_norm', 'beam_search_decode', 'conv2d_transpose', 'conv3d_transpose', @@ -180,6 +181,7 @@ __all__ = [ 'lstm', 'py_func', 'psroi_pool', + 'teacher_student_sigmoid_loss', 'huber_loss', ] @@ -2896,6 +2898,133 @@ def batch_norm(input, return helper.append_activation(batch_norm_out) +def data_norm(input, + act=None, + epsilon=1e-05, + param_attr=None, + data_layout='NCHW', + in_place=False, + use_mkldnn=False, + name=None, + moving_mean_name=None, + moving_variance_name=None, + do_model_average_for_mean_and_var=False): + """ + **Data Normalization Layer** + + Can be used as a normalizer function for conv2d and fully_connected operations. + The required data format for this layer is one of the following: + + 1. NHWC `[batch, in_height, in_width, in_channels]` + + 2. NCHW `[batch, in_channels, in_height, in_width]` + + :math:`input` is the input features over a mini-batch. + + .. math:: + + \\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\ + \ mini-batch\ mean \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + Args: + input(variable): The input variable which is a LoDTensor. + act(string, Default None): Activation type, linear|relu|prelu|... + epsilon(float, Default 1e-05): + param_attr(ParamAttr): The parameter attribute for Parameter `scale`. + data_layout(string, default NCHW): NCHW|NHWC + in_place(bool, Default False): Make the input and output of batch norm reuse memory. + use_mkldnn(bool, Default false): ${use_mkldnn_comment} + name(string, Default None): A name for this layer(optional). If set None, the layer + will be named automatically. + moving_mean_name(string, Default None): The name of moving_mean which store the global Mean. + moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance. + do_model_average_for_mean_and_var(bool, Default False): Do model average for mean and variance or not. + + Returns: + Variable: A tensor variable which is the result after applying data normalization on the input. + + Examples: + + .. code-block:: python + + data = fluid.layers.data(input=x, size=200, param_attr='fc1.w') + hidden2 = fluid.layers.data_norm(input=hidden1) + """ + helper = LayerHelper('data_norm', **locals()) + dtype = helper.input_dtype() + + input_shape = input.shape + if data_layout == 'NCHW': + channel_num = input_shape[1] + else: + if data_layout == 'NHWC': + channel_num = input_shape[-1] + else: + raise ValueError("unsupported data layout:" + data_layout) + + param_shape = [channel_num] + + batch_size_default = 1e4 + batch_sum_default = 0.0 + batch_square_sum_default = 1e4 + + if param_attr and isinstance(param_attr, dict): + batch_size_default = param_attr.get("batch_size", 1e4) + batch_sum_default = param_attr.get("batch_sum", 0.0) + batch_square_sum_default = param_attr.get("batch_square", 1e4) + + # create parameter + batch_size = helper.create_parameter( + attr=ParamAttr( + name=name + '.batch_size', + initializer=Constant(value=float(batch_size_default)), + trainable=True), + shape=param_shape, + dtype=input.dtype) + + batch_sum = helper.create_parameter( + attr=ParamAttr( + name=name + '.batch_sum', + initializer=Constant(value=float(batch_sum_default)), + trainable=True), + shape=param_shape, + dtype=input.dtype) + + batch_square_sum = helper.create_parameter( + attr=ParamAttr( + name=name + '.batch_square_sum', + initializer=Constant(value=float(batch_square_sum_default)), + trainable=True), + shape=param_shape, + dtype=input.dtype) + + means = helper.create_variable(dtype=dtype, stop_gradient=True) + scales = helper.create_variable(dtype=dtype, stop_gradient=True) + + data_norm_out = input if in_place else helper.create_variable(dtype=dtype) + + helper.append_op( + type="data_norm", + inputs={ + "X": input, + "BatchSize": batch_size, + "BatchSum": batch_sum, + "BatchSquareSum": batch_square_sum + }, + outputs={"Y": data_norm_out, + "Means": means, + "Scales": scales}, + attrs={"epsilon": epsilon, + "use_mkldnn": use_mkldnn}) + + return helper.append_activation(data_norm_out) + + @templatedoc() def layer_norm(input, scale=True, @@ -3064,9 +3193,9 @@ def group_norm(input, inputs['Bias'] = bias # create output - mean_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) - variance_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) - group_norm_out = helper.create_tmp_variable(dtype) + mean_out = helper.create_variable(dtype=dtype, stop_gradient=True) + variance_out = helper.create_variable(dtype=dtype, stop_gradient=True) + group_norm_out = helper.create_variable(dtype) helper.append_op( type="group_norm", @@ -9264,6 +9393,47 @@ def log_loss(input, label, epsilon=1e-4, name=None): return loss +def teacher_student_sigmoid_loss(input, + label, + soft_max_up_bound=15.0, + soft_max_lower_bound=-15.0): + """ + **Teacher Student Log Loss Layer** + + This layer accepts input predictions and target label and returns the + teacher_student loss. + + .. math:: + loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + max(x, 0) - x * z' + log(1 + exp(-abs(x))) + + Args: + input (Variable|list): a 2-D tensor with shape [N x 1], where N is the + batch size. This input is a probability computed + by the previous operator. + label (Variable|list): the ground truth which is a 2-D tensor with + shape [N x 1], where N is the batch size. + soft_max_up_bound (float): if input > soft_max_up_bound, will be bound + soft_max_lower_bound (float): if input < soft_max_lower_bound, will be bound + + Returns: + Variable: A 2-D tensor with shape [N x 1], the teacher_student_sigmoid_loss. + + Examples: + .. code-block:: python + cost = fluid.layers.teacher_student_sigmoid_loss(input=similarity, label=label) + """ + helper = LayerHelper('teacher_student_sigmoid_loss', **locals()) + out = helper.create_variable(dtype=input.dtype) + helper.append_op( + type='teacher_student_sigmoid_loss', + inputs={'X': [input], + 'Label': [label]}, + outputs={'Y': [out]}, + attrs={"soft_max_lower_bound": float(soft_max_lower_bound), \ + "soft_max_up_bound": float(soft_max_up_bound)}) + return out + + def add_position_encoding(input, alpha, beta, name=None): """ **Add Position Encoding Layer** diff --git a/python/paddle/fluid/tests/unittests/test_imperative.py b/python/paddle/fluid/tests/unittests/test_imperative.py index 1dc13ec74e8da1f13d447950b3c7822bbbecb2a7..e3e1ce7ca3127969e9c4430649a18b08e0e71889 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative.py +++ b/python/paddle/fluid/tests/unittests/test_imperative.py @@ -15,6 +15,7 @@ import contextlib import unittest import numpy as np +import sys import paddle.fluid as fluid from paddle.fluid import core @@ -22,7 +23,7 @@ from paddle.fluid.imperative.nn import FC from test_imperative_base import new_program_scope -class MyLayer(fluid.imperative.PyLayer): +class MyLayer(fluid.imperative.Layer): def __init__(self): super(MyLayer, self).__init__() @@ -34,7 +35,35 @@ class MyLayer(fluid.imperative.PyLayer): return [x] -class MLP(fluid.imperative.PyLayer): +class MyPyLayer(fluid.imperative.PyLayer): + def __init__(self): + super(MyPyLayer, self).__init__() + + @staticmethod + def forward(inputs): + sys.stderr.write('before forward\n') + ret = np.tanh(inputs[0]) + sys.stderr.write('after forward: %s\n' % ret) + tensor = core.LoDTensor() + tensor.set(ret, core.CPUPlace()) + return tuple([tensor]) + + @staticmethod + def backward(inputs): + sys.stderr.write('calling into backward: %s\n' % str(inputs)) + inp, out, dout = inputs + inp = np.array(inp) + out = np.array(out) + dout = np.array(dout) + sys.stderr.write('calling into backward: %s, %s, %s\n' % + (inp, out, dout)) + ret = np.array(dout) * (1 - np.square(np.array(out))) + tensor = core.LoDTensor() + tensor.set(ret, core.CPUPlace()) + return tuple([tensor]) + + +class MLP(fluid.imperative.Layer): def __init__(self): super(MLP, self).__init__() self._fc1 = FC(3, @@ -56,9 +85,77 @@ class TestImperative(unittest.TestCase): with fluid.imperative.guard(): cl = core.Layer() cl.forward([]) - l = fluid.imperative.PyLayer() + l = fluid.imperative.Layer() self.assertRaises(NotImplementedError, l.forward, []) + def test_pylayer_func_id(self): + + with fluid.imperative.guard(): + + class PyLayer1(fluid.imperative.PyLayer): + def __init__(self): + super(PyLayer1, self).__init__() + + @staticmethod + def forward(inputs): + return inputs + + @staticmethod + def backward(inputs): + return inputs + + class PyLayer2(fluid.imperative.PyLayer): + def __init__(self): + super(PyLayer2, self).__init__() + + @staticmethod + def forward(inputs): + return inputs + + @staticmethod + def backward(inputs): + return inputs + + py_layer_1 = PyLayer1() + py_layer_2 = PyLayer2() + py_layer_1([fluid.imperative.base.to_variable(np.ones([2, 2]))]) + py_layer_2([fluid.imperative.base.to_variable(np.ones([2, 2]))]) + id = py_layer_1.forward_id + self.assertGreater(id, 0) + self.assertEqual(py_layer_1.backward_id, id + 1) + self.assertEqual(py_layer_2.forward_id, id + 2) + self.assertEqual(py_layer_2.backward_id, id + 3) + py_layer_1([fluid.imperative.base.to_variable(np.ones([2, 2]))]) + self.assertEqual(py_layer_1.forward_id, id) + + def test_pylayer(self): + np_inp = np.ones([2, 2], np.float32) + with fluid.imperative.guard(): + my_py_layer = MyPyLayer() + var_inp = fluid.imperative.base.to_variable(np_inp) + outs = my_py_layer([var_inp]) + dy_out = np.sum(outs[0]._numpy()) + outs[0]._backward() + dy_grad = var_inp._gradient() + + with new_program_scope(): + inp = fluid.layers.data( + name="inp", shape=[2, 2], append_batch_size=False) + # TODO(panyx0718): Paddle doesn't diff against data `inp`. + x1 = inp * 1 + # TODO(panyx0718): If reduce_sum is skipped, the result is wrong. + x = fluid.layers.reduce_sum(fluid.layers.tanh(x1)) + param_grads = fluid.backward.append_backward( + x, parameter_list=[x1.name])[0] + exe = fluid.Executor(fluid.CPUPlace()) + + static_out, static_grad = exe.run( + feed={inp.name: np_inp}, + fetch_list=[x.name, param_grads[1].name]) + + self.assertTrue(np.allclose(dy_out, static_out)) + self.assertTrue(np.allclose(dy_grad, static_grad)) + def test_layer_in_out(self): np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32) with fluid.imperative.guard(): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py index 42896336b5971d54e31ef77a092c8d2d15b5e318..63eeae4b712c2064309b664b91d5f0347b67817d 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer.py @@ -26,7 +26,7 @@ from paddle.fluid.imperative.base import to_variable from test_imperative_base import new_program_scope -class SimpleImgConvPool(fluid.imperative.PyLayer): +class SimpleImgConvPool(fluid.imperative.Layer): def __init__(self, num_channels, num_filters, @@ -72,7 +72,7 @@ class SimpleImgConvPool(fluid.imperative.PyLayer): return x -class MNIST(fluid.imperative.PyLayer): +class MNIST(fluid.imperative.Layer): def __init__(self, param_attr=None, bias_attr=None): super(MNIST, self).__init__() diff --git a/python/paddle/fluid/tests/unittests/test_teacher_student_sigmoid_loss_op.py b/python/paddle/fluid/tests/unittests/test_teacher_student_sigmoid_loss_op.py new file mode 100644 index 0000000000000000000000000000000000000000..26bf0fd88368ed27e142e8515ec57a6c6bebd6fa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_teacher_student_sigmoid_loss_op.py @@ -0,0 +1,59 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from math import log +from math import exp +from op_test import OpTest +from scipy.special import logit +from scipy.special import expit +import unittest + + +class TestTeacherStudentSigmoidLossOp(OpTest): + """ + Test teacher_student_sigmoid_loss with discrete one-hot labels. + """ + + def setUp(self): + self.op_type = "teacher_student_sigmoid_loss" + batch_size = 16 + num_classes = 1 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype("float32")), + 'Label': np.random.uniform(0, 2, (batch_size, num_classes)) + .astype("float32") + } + outs = [] + for index, label in enumerate(self.inputs["Label"]): + x = self.inputs["X"][index] + if label < -1.0: + outs.append(max(x, 0.0) + log(1.0 + exp(-abs(x)))) + elif label < 0.0: + outs.append(max(x, 0.0) - x + log(1.0 + exp(-abs(x)))) + elif label < 1.0: + outs.append(max(x, 0.0) + log(1.0 + exp(-abs(x))) + \ + max(x, 0.0) - x * label + log(1.0 + exp(-abs(x)))) + else: + outs.append(max(x, 0.0) - x + log(1.0 + exp(-abs(x))) + \ + max(x, 0.0) - x * (label - 1.0) + log(1.0 + exp(-abs(x)))) + self.outputs = {'Y': np.array(outs)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Y", numeric_grad_delta=0.005)