未验证 提交 cc8a7858 编写于 作者: J Jiabin Yang 提交者: GitHub

【Prim】Blacklist bwd comp (#50148)

* refactor dir for prim

* support blacklist for bwd comp

* fix type error

* remove additional file

* fix git ignore

* add more test

* merge develop
上级 b76594a0
...@@ -330,7 +330,7 @@ NODE_CC_FILE_TEMPLATE = """ ...@@ -330,7 +330,7 @@ NODE_CC_FILE_TEMPLATE = """
#include "paddle/fluid/eager/nan_inf_utils.h" #include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/phi/api/include/sparse_api.h" #include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h" #include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/api/all.h" #include "paddle/fluid/prim/api/all.h"
#include "paddle/fluid/prim/utils/utils.h" #include "paddle/fluid/prim/utils/utils.h"
DECLARE_bool(check_nan_inf); DECLARE_bool(check_nan_inf);
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle { namespace paddle {
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle { namespace paddle {
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle { namespace paddle {
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <string> #include <string>
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
......
generated/prim_api/eager_prim_api.cc generated_prim/*.cc
generated/prim_api/tmp_eager_prim_api.cc generated_prim/*.h
generated/prim_api/*.h
add_subdirectory(auto_code_generated) add_subdirectory(auto_code_generated)
add_subdirectory(manual) add_subdirectory(manual_prim)
add_subdirectory(generated) add_subdirectory(generated_prim)
if(NOT (NOT WITH_PYTHON AND ON_INFER)) if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library( cc_library(
......
...@@ -13,6 +13,6 @@ ...@@ -13,6 +13,6 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" #include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" #include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h" #include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
...@@ -5,16 +5,17 @@ set(legacy_api_yaml_path ...@@ -5,16 +5,17 @@ set(legacy_api_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml" "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml"
) )
set(tmp_eager_prim_api_cc_path set(tmp_eager_prim_api_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/tmp_eager_prim_api.cc" "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_eager_prim_api.cc"
) )
set(tmp_prim_api_h_path set(tmp_prim_api_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/tmp_prim_api.h" "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_prim_generated_api.h"
) )
set(eager_prim_api_cc_path set(eager_prim_api_cc_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/eager_prim_api.cc" "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/eager_prim_api.cc"
) )
set(prim_api_h_path set(prim_api_h_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/prim_api.h") "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
)
set(prim_api_gen_file set(prim_api_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/prim_gen.py) ${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/prim_gen.py)
......
...@@ -28,11 +28,11 @@ def header_include(): ...@@ -28,11 +28,11 @@ def header_include():
""" """
def eager_source_include(header_file_path): def eager_source_include():
return """ return """
#include "paddle/fluid/eager/api/all.h" #include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" #include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
""" """
...@@ -73,10 +73,7 @@ def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path): ...@@ -73,10 +73,7 @@ def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path):
header_file.write(header_include()) header_file.write(header_include())
header_file.write(namespace[0]) header_file.write(namespace[0])
header_file.write(namespace[1]) header_file.write(namespace[1])
include_header_file = ( eager_prim_source_file.write(eager_source_include())
"#include paddle/fluid/prim/api/generated/prim_api/prim_api.h"
)
eager_prim_source_file.write(eager_source_include(include_header_file))
eager_prim_source_file.write(namespace[0]) eager_prim_source_file.write(namespace[0])
for api in apis: for api in apis:
...@@ -106,13 +103,13 @@ def main(): ...@@ -106,13 +103,13 @@ def main():
parser.add_argument( parser.add_argument(
'--prim_api_header_path', '--prim_api_header_path',
help='output of generated prim_api header code file', help='output of generated prim_api header code file',
default='paddle/fluid/prim/api/generated/prim_api/prim_api.h', default='paddle/fluid/prim/api/generated_prim/prim_generated_api.h',
) )
parser.add_argument( parser.add_argument(
'--eager_prim_api_source_path', '--eager_prim_api_source_path',
help='output of generated eager_prim_api source code file', help='output of generated eager_prim_api source code file',
default='paddle/fluid/prim/api/generated/prim_api/eager_prim_api.cc', default='paddle/fluid/prim/api/generated_prim/eager_prim_api.cc',
) )
options = parser.parse_args() options = parser.parse_args()
......
...@@ -13,9 +13,7 @@ ...@@ -13,9 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" #include "paddle/fluid/prim/api/all.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
......
cc_library(
static_prim_api
SRCS static_prim_api.cc
DEPS proto_desc static_utils)
if(NOT (NOT WITH_PYTHON AND ON_INFER)) if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library( cc_library(
eager_prim_api eager_prim_api
......
add_subdirectory(utils)
cc_library(
static_prim_api
SRCS static_prim_api.cc
DEPS proto_desc static_utils)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,15 +12,16 @@ ...@@ -12,15 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// prim api which can't be generated
#pragma once #pragma once
#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/utils/optional.h" #include "paddle/utils/optional.h"
// TODO(jiabin): Make this Header only for handwritten api, instead of include
// prim_generated_api.h
namespace paddle { namespace paddle {
namespace prim {} // namespace prim namespace prim {} // namespace prim
} // namespace paddle } // namespace paddle
...@@ -26,9 +26,8 @@ ...@@ -26,9 +26,8 @@
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" #include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h" #include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h" #include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
namespace paddle { namespace paddle {
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h" #include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h" #include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h" #include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/prim/utils/utils.h" #include "paddle/fluid/prim/utils/utils.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
......
...@@ -69,6 +69,18 @@ class StaticCompositeContext { ...@@ -69,6 +69,18 @@ class StaticCompositeContext {
enable_bwd_prim_ = enable_prim; enable_bwd_prim_ = enable_prim;
} }
size_t CheckSkipCompOps(const std::string& op_type) const {
return skip_comp_ops_.count(op_type);
}
void AddSkipCompOps(const std::string& op_type) {
skip_comp_ops_.insert(op_type);
}
void RemoveSkipCompOps(const std::string& op_type) {
skip_comp_ops_.erase(op_type);
}
void SetTargetGradName(const std::map<std::string, std::string>& m) { void SetTargetGradName(const std::map<std::string, std::string>& m) {
target_grad_name_ = m; target_grad_name_ = m;
} }
...@@ -79,10 +91,13 @@ class StaticCompositeContext { ...@@ -79,10 +91,13 @@ class StaticCompositeContext {
private: private:
StaticCompositeContext() StaticCompositeContext()
: current_block_desc_(nullptr), generator_(new UniqueNameGenerator()) {} : current_block_desc_(nullptr),
generator_(new UniqueNameGenerator()),
skip_comp_ops_({"matmul_v2"}) {}
framework::BlockDesc* current_block_desc_; framework::BlockDesc* current_block_desc_;
std::unique_ptr<UniqueNameGenerator> generator_; std::unique_ptr<UniqueNameGenerator> generator_;
std::unordered_set<std::string> skip_comp_ops_;
std::map<std::string, std::string> target_grad_name_; std::map<std::string, std::string> target_grad_name_;
static thread_local bool enable_bwd_prim_; static thread_local bool enable_bwd_prim_;
static thread_local bool enable_fwd_prim_; static thread_local bool enable_fwd_prim_;
......
...@@ -24,7 +24,7 @@ bool PrimCommonUtils::IsBwdPrimEnabled() { ...@@ -24,7 +24,7 @@ bool PrimCommonUtils::IsBwdPrimEnabled() {
} }
void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) { void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim); StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
} }
bool PrimCommonUtils::IsFwdPrimEnabled() { bool PrimCommonUtils::IsFwdPrimEnabled() {
...@@ -32,11 +32,23 @@ bool PrimCommonUtils::IsFwdPrimEnabled() { ...@@ -32,11 +32,23 @@ bool PrimCommonUtils::IsFwdPrimEnabled() {
} }
void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) { void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim); StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim);
} }
void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) { void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim); StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim);
}
size_t PrimCommonUtils::CheckSkipCompOps(const std::string& op_type) {
return StaticCompositeContext::Instance().CheckSkipCompOps(op_type);
}
void PrimCommonUtils::AddSkipCompOps(const std::string& op_type) {
StaticCompositeContext::Instance().AddSkipCompOps(op_type);
}
void PrimCommonUtils::RemoveSkipCompOps(const std::string& op_type) {
StaticCompositeContext::Instance().RemoveSkipCompOps(op_type);
} }
void PrimCommonUtils::SetTargetGradName( void PrimCommonUtils::SetTargetGradName(
......
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <map> #include <map>
#include <string> #include <string>
#include <unordered_set>
namespace paddle { namespace paddle {
namespace prim { namespace prim {
...@@ -26,6 +26,9 @@ class PrimCommonUtils { ...@@ -26,6 +26,9 @@ class PrimCommonUtils {
static bool IsFwdPrimEnabled(); static bool IsFwdPrimEnabled();
static void SetFwdPrimEnabled(bool enabled); static void SetFwdPrimEnabled(bool enabled);
static void SetAllPrimEnabled(bool enabled); static void SetAllPrimEnabled(bool enabled);
static size_t CheckSkipCompOps(const std::string& op_type);
static void AddSkipCompOps(const std::string& op_type);
static void RemoveSkipCompOps(const std::string& op_type);
static void SetTargetGradName(const std::map<std::string, std::string>& m); static void SetTargetGradName(const std::map<std::string, std::string>& m);
}; };
} // namespace prim } // namespace prim
......
...@@ -1246,6 +1246,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1246,6 +1246,9 @@ All parameter, weight, gradient are variables in Paddle.
return static_cast<paddle::framework::proto::AttrType>( return static_cast<paddle::framework::proto::AttrType>(
defalut_val.index() - 1); defalut_val.index() - 1);
}); });
m.def("_add_skip_comp_ops", &paddle::prim::PrimCommonUtils::AddSkipCompOps);
m.def("_remove_skip_comp_ops",
&paddle::prim::PrimCommonUtils::RemoveSkipCompOps);
m.def("get_grad_op_desc", m.def("get_grad_op_desc",
[](const OpDesc &op_desc, [](const OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set, const std::unordered_set<std::string> &no_grad_set,
...@@ -1277,8 +1280,11 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1277,8 +1280,11 @@ All parameter, weight, gradient are variables in Paddle.
// priority of CompGradOpMaker is less than GradCompMaker for better // priority of CompGradOpMaker is less than GradCompMaker for better
// performance. // performance.
std::vector<std::unique_ptr<OpDesc>> grad_op_descs; std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
auto need_skip =
paddle::prim::PrimCommonUtils::CheckSkipCompOps(op_desc.Type());
VLOG(3) << "need skip: " << need_skip << std::endl;
if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) { if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {
if (grad_comp_op_maker != nullptr) { if ((grad_comp_op_maker != nullptr) && (!need_skip)) {
VLOG(3) << "Runing composite fun for " << op_desc.Type(); VLOG(3) << "Runing composite fun for " << op_desc.Type();
grad_op_descs = grad_comp_op_maker(op_desc, grad_op_descs = grad_comp_op_maker(op_desc,
no_grad_set, no_grad_set,
......
...@@ -306,6 +306,8 @@ try: ...@@ -306,6 +306,8 @@ try:
from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent
from .libpaddle import _set_current_stream from .libpaddle import _set_current_stream
from .libpaddle import _get_phi_kernel_name from .libpaddle import _get_phi_kernel_name
from .libpaddle import _add_skip_comp_ops
from .libpaddle import _remove_skip_comp_ops
# prim controller flags # prim controller flags
from .libpaddle import __set_bwd_prim_enabled from .libpaddle import __set_bwd_prim_enabled
...@@ -409,7 +411,7 @@ def __sync_stat_with_flag(flag): ...@@ -409,7 +411,7 @@ def __sync_stat_with_flag(flag):
__set_fwd_prim_enabled(True) __set_fwd_prim_enabled(True)
else: else:
raise TypeError(f"flag {flag} should be true or false.") raise TypeError(f"flag {flag} should be true or false.")
logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled())) print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
elif flag is "FLAGS_prim_backward": elif flag is "FLAGS_prim_backward":
flag_value = os.getenv("FLAGS_prim_backward") flag_value = os.getenv("FLAGS_prim_backward")
assert flag_value is not None assert flag_value is not None
...@@ -420,7 +422,7 @@ def __sync_stat_with_flag(flag): ...@@ -420,7 +422,7 @@ def __sync_stat_with_flag(flag):
__set_bwd_prim_enabled(True) __set_bwd_prim_enabled(True)
else: else:
raise TypeError(f"flag {flag} should be true or false.") raise TypeError(f"flag {flag} should be true or false.")
logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled())) print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
elif flag is "FLAGS_prim_all": elif flag is "FLAGS_prim_all":
flag_value = os.getenv("FLAGS_prim_all") flag_value = os.getenv("FLAGS_prim_all")
assert flag_value is not None assert flag_value is not None
...@@ -431,7 +433,7 @@ def __sync_stat_with_flag(flag): ...@@ -431,7 +433,7 @@ def __sync_stat_with_flag(flag):
__set_all_prim_enabled(True) __set_all_prim_enabled(True)
else: else:
raise TypeError(f"flag {flag} should be true or false.") raise TypeError(f"flag {flag} should be true or false.")
logging.debug( print(
"all prim enabled: ", "all prim enabled: ",
bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()), bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
) )
...@@ -441,19 +443,24 @@ def __sync_stat_with_flag(flag): ...@@ -441,19 +443,24 @@ def __sync_stat_with_flag(flag):
) )
# Alert!!! This method is only for test coveraget, user should never use it directly, this may cause serious system errors.
def _test_use_sync(value):
__sync_stat_with_flag(value)
def _set_prim_backward_enabled(value): def _set_prim_backward_enabled(value):
__set_bwd_prim_enabled(bool(value)) __set_bwd_prim_enabled(bool(value))
logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled())) print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
def _set_prim_forward_enabled(value): def _set_prim_forward_enabled(value):
__set_fwd_prim_enabled(bool(value)) __set_fwd_prim_enabled(bool(value))
logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled())) print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
def _set_prim_all_enabled(value): def _set_prim_all_enabled(value):
__set_all_prim_enabled(bool(value)) __set_all_prim_enabled(bool(value))
logging.debug( print(
"all prim enabled: ", "all prim enabled: ",
bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()), bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
) )
...@@ -462,7 +469,7 @@ def _set_prim_all_enabled(value): ...@@ -462,7 +469,7 @@ def _set_prim_all_enabled(value):
def __sync_prim_backward_status(): def __sync_prim_backward_status():
flag_value = os.getenv("FLAGS_prim_backward") flag_value = os.getenv("FLAGS_prim_backward")
if flag_value is None: if flag_value is None:
logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled())) print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
else: else:
__sync_stat_with_flag("FLAGS_prim_backward") __sync_stat_with_flag("FLAGS_prim_backward")
...@@ -470,7 +477,7 @@ def __sync_prim_backward_status(): ...@@ -470,7 +477,7 @@ def __sync_prim_backward_status():
def __sync_prim_forward_status(): def __sync_prim_forward_status():
flag_value = os.getenv("FLAGS_prim_forward") flag_value = os.getenv("FLAGS_prim_forward")
if flag_value is None: if flag_value is None:
logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled())) print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
else: else:
__sync_stat_with_flag("FLAGS_prim_forward") __sync_stat_with_flag("FLAGS_prim_forward")
......
...@@ -47,6 +47,16 @@ class TestPrimFlags(unittest.TestCase): ...@@ -47,6 +47,16 @@ class TestPrimFlags(unittest.TestCase):
core.check_and_set_prim_all_enabled() core.check_and_set_prim_all_enabled()
self.assertFalse(core._is_fwd_prim_enabled()) self.assertFalse(core._is_fwd_prim_enabled())
del os.environ['FLAGS_prim_backward']
core.check_and_set_prim_all_enabled()
self.assertFalse(core._is_bwd_prim_enabled())
del os.environ['FLAGS_prim_forward']
core.check_and_set_prim_all_enabled()
self.assertFalse(core._is_fwd_prim_enabled())
with self.assertRaises(TypeError):
core._test_use_sync("aaaa")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid import core, framework
class TestGetGradOpDescPrimEnabled(unittest.TestCase):
def setUp(self):
self.fwd_type = 'tanh'
self.inputs = {'X': ['x']}
self.outputs = {'Out': ['y']}
self.no_grad_var = set()
self.grad_sub_block = tuple()
self.desired_ops = 'tanh_grad'
self.desired_ops_no_skip = ('pow', 'scale', 'elementwise_mul')
paddle.enable_static()
block = framework.Block(framework.Program(), 0)
block.append_op(
type=self.fwd_type,
inputs={
n: [block.create_var(name=v, stop_gradient=False) for v in vs]
for n, vs in self.inputs.items()
},
outputs={
n: [block.create_var(name=v, stop_gradient=False) for v in vs]
for n, vs in self.outputs.items()
},
)
self.fwd = block.ops[0].desc
def tearDown(self):
paddle.disable_static()
def test_get_grad_op_desc_without_skip(self):
core._set_prim_backward_enabled(True)
actual = tuple(
desc.type()
for desc in core.get_grad_op_desc(
self.fwd, self.no_grad_var, self.grad_sub_block
)[0]
)
self.assertEquals(actual, self.desired_ops_no_skip)
core._set_prim_backward_enabled(False)
def test_get_grad_op_desc_with_skip(self):
core._set_prim_backward_enabled(True)
core._add_skip_comp_ops("tanh")
actual = tuple(
desc.type()
for desc in core.get_grad_op_desc(
self.fwd, self.no_grad_var, self.grad_sub_block
)[0]
)
core._remove_skip_comp_ops("tanh")
self.assertEquals(actual[0], self.desired_ops)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册