Unverified · Commit cc8a7858 authored by Jiabin Yang, committed by GitHub

【Prim】Blacklist bwd comp (#50148)

* refactor dir for prim

* support blacklist for bwd comp

* fix type error

* remove additional file

* fix git ignore

* add more test

* merge develop
Parent b76594a0
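At a glance, this change does two things: it renames the prim API directories (manual → manual_prim, generated → generated_prim) and fixes every include and CMake path accordingly, and it adds a backward-composite blacklist (skip_comp_ops_) threaded from StaticCompositeContext through PrimCommonUtils and pybind up to Python. A hedged sketch of the new user-facing switches, using only the binding names added in this diff (illustrative, not official documentation):

    from paddle.fluid import core

    core._set_prim_backward_enabled(True)   # lower grad ops to composite prim ops
    core._add_skip_comp_ops("tanh")         # ...but keep the handwritten tanh_grad
    core._remove_skip_comp_ops("tanh")      # opt tanh back into composite lowering
    core._set_prim_backward_enabled(False)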
@@ -330,7 +330,7 @@ NODE_CC_FILE_TEMPLATE = """
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h"
-#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/api/all.h"
#include "paddle/fluid/prim/utils/utils.h"
DECLARE_bool(check_nan_inf);
......
@@ -15,7 +15,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
-#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle {
......
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex.h"
-#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle {
......
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex.h"
-#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
......
@@ -15,7 +15,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
-#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle {
......
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
......
@@ -5,7 +5,7 @@
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
......
@@ -17,7 +17,7 @@
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
......
-generated/prim_api/eager_prim_api.cc
-generated/prim_api/tmp_eager_prim_api.cc
-generated/prim_api/*.h
+generated_prim/*.cc
+generated_prim/*.h
add_subdirectory(auto_code_generated)
-add_subdirectory(manual)
-add_subdirectory(generated)
+add_subdirectory(manual_prim)
+add_subdirectory(generated_prim)
if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(
......
@@ -13,6 +13,6 @@
// limitations under the License.
#pragma once
-#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
-#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
-#include "paddle/fluid/prim/api/manual/utils/utils.h"
+#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
+#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
+#include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
@@ -5,16 +5,17 @@ set(legacy_api_yaml_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml"
)
set(tmp_eager_prim_api_cc_path
-"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/tmp_eager_prim_api.cc"
+"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_eager_prim_api.cc"
)
set(tmp_prim_api_h_path
-"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/tmp_prim_api.h"
+"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_prim_generated_api.h"
)
set(eager_prim_api_cc_path
-"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/eager_prim_api.cc"
+"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/eager_prim_api.cc"
)
set(prim_api_h_path
-"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/prim_api.h")
+"${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
+)
set(prim_api_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/prim_gen.py)
......
@@ -28,11 +28,11 @@ def header_include():
"""
-def eager_source_include(header_file_path):
+def eager_source_include():
return """
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
-#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
+#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
"""
@@ -73,10 +73,7 @@ def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path):
header_file.write(header_include())
header_file.write(namespace[0])
header_file.write(namespace[1])
-include_header_file = (
-"#include paddle/fluid/prim/api/generated/prim_api/prim_api.h"
-)
-eager_prim_source_file.write(eager_source_include(include_header_file))
+eager_prim_source_file.write(eager_source_include())
eager_prim_source_file.write(namespace[0])
for api in apis:
@@ -106,13 +103,13 @@ def main():
parser.add_argument(
'--prim_api_header_path',
help='output of generated prim_api header code file',
-default='paddle/fluid/prim/api/generated/prim_api/prim_api.h',
+default='paddle/fluid/prim/api/generated_prim/prim_generated_api.h',
)
parser.add_argument(
'--eager_prim_api_source_path',
help='output of generated eager_prim_api source code file',
-default='paddle/fluid/prim/api/generated/prim_api/eager_prim_api.cc',
+default='paddle/fluid/prim/api/generated_prim/eager_prim_api.cc',
)
options = parser.parse_args()
......
@@ -13,9 +13,7 @@
// limitations under the License.
#pragma once
-#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
-#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
-#include "paddle/fluid/prim/api/manual/utils/utils.h"
+#include "paddle/fluid/prim/api/all.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h"
......
cc_library(
static_prim_api
SRCS static_prim_api.cc
DEPS proto_desc static_utils)
if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_library(
eager_prim_api
......
add_subdirectory(utils)
cc_library(
static_prim_api
SRCS static_prim_api.cc
DEPS proto_desc static_utils)
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,15 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// prim api which can't be generated
#pragma once
+#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/optional.h"
+// TODO(jiabin): Make this Header only for handwritten api, instead of include
+// prim_generated_api.h
namespace paddle {
namespace prim {} // namespace prim
} // namespace paddle
@@ -26,9 +26,8 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
-#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
-#include "paddle/fluid/prim/api/manual/utils/utils.h"
+#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
+#include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/api/include/tensor.h"
......
@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
-#include "paddle/fluid/prim/api/manual/utils/utils.h"
+#include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
......
@@ -17,7 +17,7 @@
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/prim/api/manual/utils/utils.h"
+#include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/phi/api/include/tensor.h"
......
@@ -19,7 +19,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/prim/api/manual/utils/utils.h"
+#include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/prim/utils/utils.h"
#include "paddle/phi/core/enforce.h"
......
@@ -69,6 +69,18 @@ class StaticCompositeContext {
enable_bwd_prim_ = enable_prim;
}
+size_t CheckSkipCompOps(const std::string& op_type) const {
+return skip_comp_ops_.count(op_type);
+}
+void AddSkipCompOps(const std::string& op_type) {
+skip_comp_ops_.insert(op_type);
+}
+void RemoveSkipCompOps(const std::string& op_type) {
+skip_comp_ops_.erase(op_type);
+}
void SetTargetGradName(const std::map<std::string, std::string>& m) {
target_grad_name_ = m;
}
@@ -79,10 +91,13 @@
private:
StaticCompositeContext()
-: current_block_desc_(nullptr), generator_(new UniqueNameGenerator()) {}
+: current_block_desc_(nullptr),
+generator_(new UniqueNameGenerator()),
+skip_comp_ops_({"matmul_v2"}) {}
framework::BlockDesc* current_block_desc_;
std::unique_ptr<UniqueNameGenerator> generator_;
+std::unordered_set<std::string> skip_comp_ops_;
std::map<std::string, std::string> target_grad_name_;
static thread_local bool enable_bwd_prim_;
static thread_local bool enable_fwd_prim_;
......
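Worth flagging: the constructor above seeds skip_comp_ops_ with "matmul_v2", so matmul_v2 keeps its original grad op even when backward prim is enabled. A minimal sketch of opting it back in from Python, using the bindings added later in this diff (whether the composite matmul_v2 grad is preferable is workload-dependent):

    from paddle.fluid import core

    core._set_prim_backward_enabled(True)
    core._remove_skip_comp_ops("matmul_v2")  # drop the shipped default blacklist entry
    # ... build the program and generate grad op descs here ...
    core._add_skip_comp_ops("matmul_v2")     # restore the shipped default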
@@ -24,7 +24,7 @@ bool PrimCommonUtils::IsBwdPrimEnabled() {
}
void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
-return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
+StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
}
bool PrimCommonUtils::IsFwdPrimEnabled() {
@@ -32,11 +32,23 @@ bool PrimCommonUtils::IsFwdPrimEnabled() {
}
void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
-return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim);
+StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim);
}
void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) {
-return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim);
+StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim);
}
+size_t PrimCommonUtils::CheckSkipCompOps(const std::string& op_type) {
+return StaticCompositeContext::Instance().CheckSkipCompOps(op_type);
+}
+void PrimCommonUtils::AddSkipCompOps(const std::string& op_type) {
+StaticCompositeContext::Instance().AddSkipCompOps(op_type);
+}
+void PrimCommonUtils::RemoveSkipCompOps(const std::string& op_type) {
+StaticCompositeContext::Instance().RemoveSkipCompOps(op_type);
+}
void PrimCommonUtils::SetTargetGradName(
......
@@ -13,9 +13,9 @@
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <unordered_set>
namespace paddle {
namespace prim {
@@ -26,6 +26,9 @@ class PrimCommonUtils {
static bool IsFwdPrimEnabled();
static void SetFwdPrimEnabled(bool enabled);
static void SetAllPrimEnabled(bool enabled);
+static size_t CheckSkipCompOps(const std::string& op_type);
+static void AddSkipCompOps(const std::string& op_type);
+static void RemoveSkipCompOps(const std::string& op_type);
static void SetTargetGradName(const std::map<std::string, std::string>& m);
};
} // namespace prim
......
@@ -1246,6 +1246,9 @@ All parameter, weight, gradient are variables in Paddle.
return static_cast<paddle::framework::proto::AttrType>(
defalut_val.index() - 1);
});
+m.def("_add_skip_comp_ops", &paddle::prim::PrimCommonUtils::AddSkipCompOps);
+m.def("_remove_skip_comp_ops",
+&paddle::prim::PrimCommonUtils::RemoveSkipCompOps);
m.def("get_grad_op_desc",
[](const OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set,
@@ -1277,8 +1280,11 @@
// priority of CompGradOpMaker is less than GradCompMaker for better
// performance.
std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
+auto need_skip =
+paddle::prim::PrimCommonUtils::CheckSkipCompOps(op_desc.Type());
+VLOG(3) << "need skip: " << need_skip << std::endl;
if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {
-if (grad_comp_op_maker != nullptr) {
+if ((grad_comp_op_maker != nullptr) && (!need_skip)) {
VLOG(3) << "Runing composite fun for " << op_desc.Type();
grad_op_descs = grad_comp_op_maker(op_desc,
no_grad_set,
......
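The effect of the need_skip gate above is observable through core.get_grad_op_desc: with backward prim enabled, a blacklisted op falls back to its original grad op instead of the composite decomposition. A sketch mirroring the new unit test at the bottom of this diff (same block construction; the expected op lists are taken from that test):

    import paddle
    from paddle.fluid import core, framework

    paddle.enable_static()
    block = framework.Block(framework.Program(), 0)
    block.append_op(
        type='tanh',
        inputs={'X': [block.create_var(name='x', stop_gradient=False)]},
        outputs={'Out': [block.create_var(name='y', stop_gradient=False)]},
    )
    fwd = block.ops[0].desc

    core._set_prim_backward_enabled(True)
    grads = tuple(d.type() for d in core.get_grad_op_desc(fwd, set(), tuple())[0])
    print(grads)  # composite path: ('pow', 'scale', 'elementwise_mul')

    core._add_skip_comp_ops("tanh")
    grads = tuple(d.type() for d in core.get_grad_op_desc(fwd, set(), tuple())[0])
    print(grads[0])  # blacklisted: falls back to 'tanh_grad'
    core._remove_skip_comp_ops("tanh")
    core._set_prim_backward_enabled(False)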
@@ -306,6 +306,8 @@ try:
from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent
from .libpaddle import _set_current_stream
from .libpaddle import _get_phi_kernel_name
+from .libpaddle import _add_skip_comp_ops
+from .libpaddle import _remove_skip_comp_ops
# prim controller flags
from .libpaddle import __set_bwd_prim_enabled
@@ -409,7 +411,7 @@ def __sync_stat_with_flag(flag):
__set_fwd_prim_enabled(True)
else:
raise TypeError(f"flag {flag} should be true or false.")
-logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
+print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
elif flag is "FLAGS_prim_backward":
flag_value = os.getenv("FLAGS_prim_backward")
assert flag_value is not None
@@ -420,7 +422,7 @@
__set_bwd_prim_enabled(True)
else:
raise TypeError(f"flag {flag} should be true or false.")
-logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
+print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
elif flag is "FLAGS_prim_all":
flag_value = os.getenv("FLAGS_prim_all")
assert flag_value is not None
@@ -431,7 +433,7 @@
__set_all_prim_enabled(True)
else:
raise TypeError(f"flag {flag} should be true or false.")
-logging.debug(
+print(
"all prim enabled: ",
bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
)
@@ -441,19 +443,24 @@ def __sync_stat_with_flag(flag):
)
# Alert!!! This method is only for test coveraget, user should never use it directly, this may cause serious system errors.
def _test_use_sync(value):
__sync_stat_with_flag(value)
def _set_prim_backward_enabled(value):
__set_bwd_prim_enabled(bool(value))
-logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
+print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
def _set_prim_forward_enabled(value):
__set_fwd_prim_enabled(bool(value))
-logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
+print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
def _set_prim_all_enabled(value):
__set_all_prim_enabled(bool(value))
-logging.debug(
+print(
"all prim enabled: ",
bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
)
@@ -462,7 +469,7 @@ def _set_prim_all_enabled(value):
def __sync_prim_backward_status():
flag_value = os.getenv("FLAGS_prim_backward")
if flag_value is None:
-logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
+print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
else:
__sync_stat_with_flag("FLAGS_prim_backward")
@@ -470,7 +477,7 @@ def __sync_prim_backward_status():
def __sync_prim_forward_status():
flag_value = os.getenv("FLAGS_prim_forward")
if flag_value is None:
-logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
+print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
else:
__sync_stat_with_flag("FLAGS_prim_forward")
......
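For completeness, the FLAGS_prim_* environment variables are folded into these switches by core.check_and_set_prim_all_enabled, exercised by the TestPrimFlags additions just below. A hedged sketch, assuming the "true"/"false" string values handled by __sync_stat_with_flag above:

    import os
    from paddle.fluid import core

    os.environ['FLAGS_prim_backward'] = 'false'
    core.check_and_set_prim_all_enabled()
    print("bwd prim:", core._is_bwd_prim_enabled())   # False

    del os.environ['FLAGS_prim_backward']
    core.check_and_set_prim_all_enabled()
    print("bwd prim:", core._is_bwd_prim_enabled())   # still False, per the test below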
@@ -47,6 +47,16 @@ class TestPrimFlags(unittest.TestCase):
core.check_and_set_prim_all_enabled()
self.assertFalse(core._is_fwd_prim_enabled())
+del os.environ['FLAGS_prim_backward']
+core.check_and_set_prim_all_enabled()
+self.assertFalse(core._is_bwd_prim_enabled())
+del os.environ['FLAGS_prim_forward']
+core.check_and_set_prim_all_enabled()
+self.assertFalse(core._is_fwd_prim_enabled())
+with self.assertRaises(TypeError):
+core._test_use_sync("aaaa")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle.fluid import core, framework
class TestGetGradOpDescPrimEnabled(unittest.TestCase):
def setUp(self):
self.fwd_type = 'tanh'
self.inputs = {'X': ['x']}
self.outputs = {'Out': ['y']}
self.no_grad_var = set()
self.grad_sub_block = tuple()
self.desired_ops = 'tanh_grad'
self.desired_ops_no_skip = ('pow', 'scale', 'elementwise_mul')
paddle.enable_static()
block = framework.Block(framework.Program(), 0)
block.append_op(
type=self.fwd_type,
inputs={
n: [block.create_var(name=v, stop_gradient=False) for v in vs]
for n, vs in self.inputs.items()
},
outputs={
n: [block.create_var(name=v, stop_gradient=False) for v in vs]
for n, vs in self.outputs.items()
},
)
self.fwd = block.ops[0].desc
def tearDown(self):
paddle.disable_static()
def test_get_grad_op_desc_without_skip(self):
core._set_prim_backward_enabled(True)
actual = tuple(
desc.type()
for desc in core.get_grad_op_desc(
self.fwd, self.no_grad_var, self.grad_sub_block
)[0]
)
self.assertEquals(actual, self.desired_ops_no_skip)
core._set_prim_backward_enabled(False)
def test_get_grad_op_desc_with_skip(self):
core._set_prim_backward_enabled(True)
core._add_skip_comp_ops("tanh")
actual = tuple(
desc.type()
for desc in core.get_grad_op_desc(
self.fwd, self.no_grad_var, self.grad_sub_block
)[0]
)
core._remove_skip_comp_ops("tanh")
self.assertEquals(actual[0], self.desired_ops)
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
unittest.main()