未验证 提交 7470889f 编写于 作者: K kangguangli 提交者: GitHub

[NewIR]support attribute hook and fix_reduce_all (#55553)

* support attribute hook and fix_reduce_all

* resolve merge conflicts

* fix coverage ci

* trigger CI

* trigger CI

* fix coverage ci
上级 84a56b4a
...@@ -60,12 +60,14 @@ using OpAttributeInfo = dialect::OpAttributeInfo; ...@@ -60,12 +60,14 @@ using OpAttributeInfo = dialect::OpAttributeInfo;
using OpAttributeInfoList = std::vector<dialect::OpAttributeInfo>; using OpAttributeInfoList = std::vector<dialect::OpAttributeInfo>;
using OpOutputInfo = dialect::OpOutputInfo; using OpOutputInfo = dialect::OpOutputInfo;
using OpOutputInfoList = std::vector<dialect::OpOutputInfo>; using OpOutputInfoList = std::vector<dialect::OpOutputInfo>;
using InputHandleFn = std::function<ir::OpResult(ir::IrContext*, using InputHandlerFn = std::function<ir::OpResult(ir::IrContext*,
TranslationContext*, TranslationContext*,
const OpDesc&, const OpDesc&,
const std::string&, const std::string&,
const OpInputInfo&, const OpInputInfo&,
ir::Program*)>; ir::Program*)>;
using AttributeHandlerFn = std::function<ir::Attribute(
ir::IrContext*, const OpDesc&, const OpAttributeInfo&)>;
constexpr char kTargetDialectPrefix[] = "pd."; constexpr char kTargetDialectPrefix[] = "pd.";
constexpr char kEmptyVarName[] = "@EMPTY@"; constexpr char kEmptyVarName[] = "@EMPTY@";
...@@ -291,7 +293,12 @@ struct OpTranscriber { ...@@ -291,7 +293,12 @@ struct OpTranscriber {
const OpOutputMapping& arg_to_idx); const OpOutputMapping& arg_to_idx);
public: public:
virtual InputHandleFn GetSpecialInputHandlers(std::string input_name) { virtual InputHandlerFn GetSpecialInputHandlers(
const std::string& input_name) {
return nullptr;
}
virtual AttributeHandlerFn GetSpecialAttributeHandlers(
const std::string& input_name) {
return nullptr; return nullptr;
} }
}; };
...@@ -558,6 +565,12 @@ ir::AttributeMap OpTranscriber::TranslateOpAttribute( ...@@ -558,6 +565,12 @@ ir::AttributeMap OpTranscriber::TranslateOpAttribute(
ir::AttributeMap attribute_map = {}; ir::AttributeMap attribute_map = {};
for (const auto& info : op_attr_infos) { for (const auto& info : op_attr_infos) {
if (auto handler = this->GetSpecialAttributeHandlers(info.name)) {
auto new_attr = handler(ctx, op_desc, info);
attribute_map[info.name] = new_attr;
continue;
}
auto legacy_attr_name = auto legacy_attr_name =
op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name); op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
VLOG(10) << "[op: " << op_desc.Type() VLOG(10) << "[op: " << op_desc.Type()
...@@ -885,7 +898,8 @@ ir::OpResult TranslateDropOutStateIn(ir::IrContext* ctx, ...@@ -885,7 +898,8 @@ ir::OpResult TranslateDropOutStateIn(ir::IrContext* ctx,
// `rnn` has an aditional input in dynamic graph // `rnn` has an aditional input in dynamic graph
struct RnnOpTranscriber : public OpTranscriber { struct RnnOpTranscriber : public OpTranscriber {
InputHandleFn GetSpecialInputHandlers(std::string input_name) override { InputHandlerFn GetSpecialInputHandlers(
const std::string& input_name) override {
if (input_name != "dropout_state_in") { if (input_name != "dropout_state_in") {
return nullptr; return nullptr;
} }
...@@ -1207,7 +1221,8 @@ ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx, ...@@ -1207,7 +1221,8 @@ ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx,
} }
struct OneHotTranscriber : public OpTranscriber { struct OneHotTranscriber : public OpTranscriber {
InputHandleFn GetSpecialInputHandlers(std::string input_name) override { InputHandlerFn GetSpecialInputHandlers(
const std::string& input_name) override {
if (input_name != "num_classes") { if (input_name != "num_classes") {
return nullptr; return nullptr;
} }
...@@ -1215,21 +1230,53 @@ struct OneHotTranscriber : public OpTranscriber { ...@@ -1215,21 +1230,53 @@ struct OneHotTranscriber : public OpTranscriber {
}; };
}; };
ir::Attribute TranslateReduceAll(ir::IrContext* ctx,
const OpDesc& op_desc,
const OpAttributeInfo& attr_info) {
bool reduce_all = false;
if (op_desc.HasAttr("reduce_all")) {
reduce_all = paddle::get<bool>(op_desc.GetAttr("reduce_all"));
}
if (reduce_all) {
return ir::ArrayAttribute::get(ctx, std::vector<ir::Attribute>{});
}
auto& attribute_translator = AttributeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();
auto legacy_attr_name =
op_normalizer.GetLegacyAttrName(op_desc.Type(), attr_info.name);
paddle::framework::Attribute dims = op_desc.GetAttr(legacy_attr_name);
return attribute_translator(attr_info.type_name, dims);
}
struct ReduceOpTranscriber : public OpTranscriber {
AttributeHandlerFn GetSpecialAttributeHandlers(
const std::string& input_name) override {
if (input_name != "axis") {
return nullptr;
}
return TranslateReduceAll;
}
};
OpTranslator::OpTranslator() { OpTranslator::OpTranslator() {
general_handler = OpTranscriber(); general_handler = OpTranscriber();
special_handlers["add_n"] = AddNOpTranscriber();
special_handlers["assign_value"] = AssignValueOpTranscriber();
special_handlers["cast"] = CastOpTranscriber();
special_handlers["feed"] = FeedOpTranscriber(); special_handlers["feed"] = FeedOpTranscriber();
special_handlers["feed_with_place"] = FeedWithPlaceOpTranscriber(); special_handlers["feed_with_place"] = FeedWithPlaceOpTranscriber();
special_handlers["fetch_v2"] = FetchOpTranscriber(); special_handlers["fetch_v2"] = FetchOpTranscriber();
special_handlers["cast"] = CastOpTranscriber(); special_handlers["increment"] = IncrementOpTranscriber();
special_handlers["split"] = SplitOpTranscriber();
special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber(); special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber(); special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber();
special_handlers["assign_value"] = AssignValueOpTranscriber(); special_handlers["one_hot_v2"] = OneHotTranscriber();
special_handlers["increment"] = IncrementOpTranscriber(); special_handlers["reduce_all"] = ReduceOpTranscriber();
special_handlers["reduce_any"] = ReduceOpTranscriber();
special_handlers["rnn"] = RnnOpTranscriber(); special_handlers["rnn"] = RnnOpTranscriber();
special_handlers["shaddow_output"] = ShaddowOutputOpTranscriber(); special_handlers["shaddow_output"] = ShaddowOutputOpTranscriber();
special_handlers["one_hot_v2"] = OneHotTranscriber(); special_handlers["split"] = SplitOpTranscriber();
special_handlers["add_n"] = AddNOpTranscriber();
special_handlers["sum"] = AddNOpTranscriber(); special_handlers["sum"] = AddNOpTranscriber();
} }
......
...@@ -160,5 +160,39 @@ class TestOneHotOpTranscriber(unittest.TestCase): ...@@ -160,5 +160,39 @@ class TestOneHotOpTranscriber(unittest.TestCase):
_ = ir.translate_to_new_ir(main_program.desc) _ = ir.translate_to_new_ir(main_program.desc)
class TestReduceOpTranscriber(unittest.TestCase):
def test_reduce_all(self):
place = core.Place()
place.set_place(paddle.CPUPlace())
exe = paddle.static.Executor(place)
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
arr = np.ones([2, 2], dtype="float32")
x = paddle.to_tensor(arr, dtype='int32')
out1 = paddle.all(x)
out = exe.run(main_program, {}, fetch_list=[out1.name])
np.testing.assert_array_equal(out[0], np.all(arr))
def test_with_axis(self):
place = core.Place()
place.set_place(paddle.CPUPlace())
exe = paddle.static.Executor(place)
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
arr = np.ones([2, 2], dtype="float32")
x = paddle.to_tensor(arr, dtype='int32')
out1 = paddle.all(x, axis=0)
out = exe.run(main_program, {}, fetch_list=[out1.name])
np.testing.assert_array_equal(out[0], np.all(arr, axis=0))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -139,6 +139,7 @@ test_prior_box_op ...@@ -139,6 +139,7 @@ test_prior_box_op
test_psroi_pool_op test_psroi_pool_op
test_put_along_axis_op test_put_along_axis_op
test_range test_range
test_reduce_op
test_reverse_op test_reverse_op
test_roi_align_op test_roi_align_op
test_roi_pool_op test_roi_pool_op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册