From b0ea134638049b546e57f885d42ab52d8bfccfcf Mon Sep 17 00:00:00 2001
From: kangguangli
Date: Mon, 31 Jul 2023 10:35:37 +0800
Subject: [PATCH] [NewIR] support elementwise operations with axis!=-1 (#55699)

* support elementwise with axis!=-1

* fix coverage ci

* fix bug

* remove print
---
 .../ir_adaptor/translator/op_translator.cc    | 217 +++++++++++++++++-
 paddle/ir/core/builtin_op.h                   |   2 +
 test/cpp/ir/core/ir_program_test.cc           |   2 +
 test/ir/new_ir/test_special_op_translator.py  |  64 ++++++
 4 files changed, 280 insertions(+), 5 deletions(-)
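Background on the legacy axis attribute: Y is aligned with X starting at
dimension axis and broadcast over the remaining trailing dimensions. The
transcribers below lower this to an explicit reshape of Y (appending trailing
1s) followed by an ordinary broadcast. A short NumPy sketch of the intended
semantics, using the same shapes as the new unit test (illustration only, not
part of the diff below):

    import numpy as np

    # x has rank 3, y has rank 1; axis=0 aligns y with x's first dimension.
    x = np.random.rand(100, 2, 3).astype("float32")
    y = np.random.rand(100).astype("float32")
    axis = 0

    # The transcriber appends (x.ndim - axis - y.ndim) trailing 1s to y's
    # shape (append_size + 1 in the C++ code below); broadcasting does the rest.
    y_new = y.reshape(y.shape + (1,) * (x.ndim - axis - y.ndim))  # -> (100, 1, 1)
    out = x + y_new  # matches elementwise_add(x, y, axis=0)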
diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 086a224d5b9..497b002f41e 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -287,7 +287,8 @@ struct OpTranscriber {
                                              const OpAttributeInfoList& op_attr_infos,
                                              const OpDesc& op_desc);
 
-  virtual void RecordOpResultMapping(TranslationContext* param_map,
+  virtual void RecordOpResultMapping(ir::IrContext* ctx,
+                                     TranslationContext* param_map,
                                      const OpDesc& op_desc,
                                      ir::Operation* operation,
                                      const OpOutputMapping& arg_to_idx);
@@ -597,7 +598,8 @@ ir::AttributeMap OpTranscriber::TranslateOpAttribute(
   return attribute_map;
 }
 
-void OpTranscriber::RecordOpResultMapping(TranslationContext* param_map,
+void OpTranscriber::RecordOpResultMapping(ir::IrContext* ctx,
+                                          TranslationContext* param_map,
                                           const OpDesc& op_desc,
                                           ir::Operation* operation,
                                           const OpOutputMapping& arg_to_idx) {
@@ -605,7 +607,7 @@ void OpTranscriber::RecordOpResultMapping(TranslationContext* param_map,
     auto& name = n.first;
     VLOG(10) << "[output recording]"
              << "[" << op_desc.Type() << "]" << name;
-    auto& args = n.second;
+    const auto& args = n.second;
     size_t idx_in_vector = 0;
     for (const auto& arg_name : args) {
       if (arg_name == kEmptyVarName) {
@@ -674,7 +676,7 @@ ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx,
   program->block()->push_back(operation);
   VLOG(4) << "[general op][" << op_desc.Type()
           << "] opearation insertion end.";
-  this->RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
+  this->RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);
 
   return operation;
 }
@@ -843,7 +845,7 @@ struct AssignValueOpTranscriber : public OpTranscriber {
     ir::Operation* operation = ir::Operation::Create(
         op_inputs, attribute_map, op_output_types, op_info);
     program->block()->push_back(operation);
-    RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
+    RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);
 
     VLOG(10) << "[op assign_value] translation finished";
 
@@ -1260,6 +1262,192 @@ struct ReduceOpTranscriber : public OpTranscriber {
   }
 };
 
+struct ElementwiseTranscriber : public OpTranscriber {
+  std::vector<ir::OpResult> GenerateOperationInput(
+      ir::IrContext* ctx,
+      TranslationContext* param_map,
+      const OpDesc& op_desc,
+      const std::string& normalized_op_name,
+      const OpInputInfoList& input_infos,
+      ir::Program* program) override {
+    int axis = paddle::get<int>(op_desc.GetAttr("axis"));
+
+    if (axis == -1) {
+      return OpTranscriber::GenerateOperationInput(
+          ctx, param_map, op_desc, normalized_op_name, input_infos, program);
+    }
+
+    auto x_names = op_desc.Input("X", true);
+    IR_ENFORCE(x_names.size() == 1,
+               "Expected op[%s]'s input X has only 1 variable, but got %d",
+               op_desc.Type(),
+               x_names.size());
+    auto x_name = x_names[0];
+    IR_ENFORCE(param_map->count(x_name) > 0,
+               "Expected op[%s]'s input %s has been parsed",
+               op_desc.Type(),
+               x_name);
+    auto x_defining_info = param_map->at(x_name);
+    if (x_defining_info.generated_by_vector) {
+      InsertSliceOperationForTarget(
+          ctx, param_map, program, x_defining_info, x_name);
+      x_defining_info = param_map->at(x_name);
+    }
+    ir::OpResult x_value = x_defining_info.value;
+    IR_ENFORCE(x_value,
+               "Expected op[%s]'s input %s is not null",
+               op_desc.Type(),
+               x_name);
+    ir::Type x_type = x_value.type();
+    IR_ENFORCE(x_type.isa<dialect::DenseTensorType>(),
+               "Expected op[%s]'s input %s is DenseTensor but got %s",
+               op_desc.Type(),
+               x_name,
+               x_type);
+    dialect::DenseTensorType x_tensor_type =
+        x_type.dyn_cast<dialect::DenseTensorType>();
+    std::vector<int64_t> x_shape = phi::vectorize(x_tensor_type.dims());
+
+    auto y_names = op_desc.Input("Y", true);
+    IR_ENFORCE(y_names.size() == 1,
+               "Expected op[%s]'s input Y has only 1 variable, but got %d",
+               op_desc.Type(),
+               y_names.size());
+    auto y_name = y_names[0];
+    IR_ENFORCE(param_map->count(y_name) > 0,
+               "Expected op[%s]'s input %s has been parsed",
+               op_desc.Type(),
+               y_name);
+    auto y_defining_info = param_map->at(y_name);
+    if (y_defining_info.generated_by_vector) {
+      InsertSliceOperationForTarget(
+          ctx, param_map, program, y_defining_info, y_name);
+      y_defining_info = param_map->at(y_name);
+    }
+    ir::OpResult y_value = y_defining_info.value;
+    IR_ENFORCE(y_value,
+               "Expected op[%s]'s input %s is not null",
+               op_desc.Type(),
+               y_name);
+    ir::Type y_type = y_value.type();
+    IR_ENFORCE(y_type.isa<dialect::DenseTensorType>(),
+               "Expected op[%s]'s input %s is DenseTensor but got %s",
+               op_desc.Type(),
+               y_name,
+               y_type);
+    dialect::DenseTensorType y_tensor_type =
+        y_type.dyn_cast<dialect::DenseTensorType>();
+    std::vector<int64_t> y_shape = phi::vectorize(y_tensor_type.dims());
+
+    if (axis < 0) {
+      axis += x_shape.size();
+    }
+
+    int append_size = x_shape.size() - axis - 1 - y_shape.size();
+    if (append_size < 0) {  // which means x.rank <= y.rank, mostly
+                            // x.rank=y.rank
+      return {x_value, y_value};
+    }
+    IR_ENFORCE(append_size >= 0,
+               "Expected op[%s] have append size >= 0 with axis=%d but got %d",
+               op_desc.Type(),
+               axis,
+               append_size);
+
+    ir::Builder builder(ctx, program->block());
+    ir::OpResult y_new;
+    if (std::find(y_shape.begin(), y_shape.end(), -1) == y_shape.end()) {
+      std::vector<int64_t> y_new_shape(y_shape);
+      for (int i = 0; i <= append_size; i++) {
+        y_new_shape.push_back(1);
+      }
+      dialect::Reshape_Op reshape_op =
+          builder.Build<dialect::Reshape_Op>(y_value, y_new_shape);
+      y_new = reshape_op.out();
+      VLOG(6) << "[" << op_desc.Type() << "] y_shape change from "
+              << y_tensor_type.dims() << " to " << phi::make_ddim(y_new_shape);
+    } else {
+      auto shape_op = builder.Build<dialect::ShapeOp>(y_value);
+      auto append_shape_op = builder.Build<dialect::FullIntArrayOp>(
+          std::vector<int64_t>(append_size, 1),
+          phi::DataType::INT64,
+          phi::CPUPlace());
+      auto y_true_shape_op = builder.Build<ir::CombineOp>(
+          std::vector<ir::OpResult>{shape_op.out(), append_shape_op.out()});
+      auto concat_op =
+          builder.Build<dialect::ConcatOp>(y_true_shape_op.out(), 0);
+      auto y_new_shape = concat_op.out();
+      auto reshape_op =
+          builder.Build<dialect::Reshape_Op>(y_value, y_new_shape);
+      y_new = reshape_op.out();
+    }
+    return {x_value, y_new};
+  }
+};
+
+struct ElementwiseGradTranscriber : public OpTranscriber {
+  void RecordOpResultMapping(ir::IrContext* ctx,
+                             TranslationContext* param_map,
+                             const OpDesc& op_desc,
+                             ir::Operation* operation,
+                             const OpOutputMapping& arg_to_idx) override {
+    OpTranscriber::RecordOpResultMapping(
+        ctx, param_map, op_desc, operation, arg_to_idx);
+
+    int axis = paddle::get<int>(op_desc.GetAttr("axis"));
+    if (axis == -1) {
+      return;
+    }
+
+    const auto& y_grad_output = op_desc.Output("Y@GRAD");
+    if (y_grad_output.size() < 1) {
+      return;
+    }
+    IR_ENFORCE(
+        y_grad_output.size() == 1,
+        "Expected op[%s]'s output Y@GRAD has only 1 variable, but got %d",
+        op_desc.Type(),
+        y_grad_output.size());
+    const auto& y_grad_var_name = y_grad_output[0];
+
+    auto idx_iter = arg_to_idx.find(y_grad_var_name);
+    if (idx_iter == arg_to_idx.end()) {
+      IR_THROW("op[%s] should have got its y_grad", op_desc.Type());
+    }
+    auto idx = idx_iter->second;
+    VLOG(10) << "[output recording]"
+             << "[" << op_desc.Type() << "]" << y_grad_var_name << " " << idx;
+
+    auto y_names = op_desc.Input("Y", true);
+    auto y_name = y_names[0];
+    IR_ENFORCE(param_map->count(y_name) > 0,
+               "Expected op[%s]'s input %s has been parsed",
+               op_desc.Type(),
+               y_name);
+    auto y_defining_info = param_map->at(y_name);
+    ir::OpResult y_value = y_defining_info.value;
+    IR_ENFORCE(y_value,
+               "Expected op[%s]'s input %s is not null",
+               op_desc.Type(),
+               y_name);
+    ir::Type y_type = y_value.type();
+    IR_ENFORCE(y_type.isa<dialect::DenseTensorType>(),
+               "Expected op[%s]'s input %s is DenseTensor but got %s",
+               op_desc.Type(),
+               y_name,
+               y_type);
+    dialect::DenseTensorType y_tensor_type =
+        y_type.dyn_cast<dialect::DenseTensorType>();
+    std::vector<int64_t> y_shape = phi::vectorize(y_tensor_type.dims());
+
+    ir::OpResult value = operation->result(idx);
+    ir::Builder builder(ctx, operation->GetParent());
+    auto reshape_op = builder.Build<dialect::Reshape_Op>(value, y_shape);
+    (*param_map)[y_grad_var_name] =
+        VariableDefiningInfo(reshape_op.out(), false, -1);
+  }
+};
+
 OpTranslator::OpTranslator() {
   general_handler = OpTranscriber();
   special_handlers["add_n"] = AddNOpTranscriber();
@@ -1278,6 +1466,25 @@ OpTranslator::OpTranslator() {
   special_handlers["shaddow_output"] = ShaddowOutputOpTranscriber();
   special_handlers["split"] = SplitOpTranscriber();
   special_handlers["sum"] = AddNOpTranscriber();
+
+  // special handler for elementwise ops with axis != -1
+  // note(lyk): maybe we should do this by a pass, which seems more reasonable
+  special_handlers["elementwise_add"] = ElementwiseTranscriber();
+  special_handlers["elementwise_sub"] = ElementwiseTranscriber();
+  special_handlers["elementwise_mul"] = ElementwiseTranscriber();
+  special_handlers["elementwise_div"] = ElementwiseTranscriber();
+  special_handlers["elementwise_max"] = ElementwiseTranscriber();
+  special_handlers["elementwise_min"] = ElementwiseTranscriber();
+  special_handlers["elementwise_mod"] = ElementwiseTranscriber();
+  special_handlers["elementwise_floordiv"] = ElementwiseTranscriber();
+  special_handlers["elementwise_add_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_sub_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_mul_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_div_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_max_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_min_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_mod_grad"] = ElementwiseGradTranscriber();
+  special_handlers["elementwise_floordiv_grad"] = ElementwiseGradTranscriber();
 }
 
 }  // namespace translator
"Expected op[%s]'s output Y@GRAD has only 1 variable, but got %d", + op_desc.Type(), + y_grad_output.size()); + const auto& y_grad_var_name = y_grad_output[0]; + + auto idx_iter = arg_to_idx.find(y_grad_var_name); + if (idx_iter == arg_to_idx.end()) { + IR_THROW("op[%s] should have got its y_grad", op_desc.Type()); + } + auto idx = idx_iter->second; + VLOG(10) << "[output recording]" + << "[" << op_desc.Type() << "]" << y_grad_var_name << " " << idx; + + auto y_names = op_desc.Input("Y", true); + auto y_name = y_names[0]; + IR_ENFORCE(param_map->count(y_name) > 0, + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + y_name); + auto y_defining_info = param_map->at(y_name); + ir::OpResult y_value = y_defining_info.value; + IR_ENFORCE(y_value, + "Expected op[%s]'s input %s is not null", + op_desc.Type(), + y_name); + ir::Type y_type = y_value.type(); + IR_ENFORCE(y_type.isa(), + "Expected op[%s]'s input %s is DenseTensor but got %s", + op_desc.Type(), + y_name, + y_type); + dialect::DenseTensorType y_tensor_type = + y_type.dyn_cast(); + std::vector y_shape = phi::vectorize(y_tensor_type.dims()); + + ir::OpResult value = operation->result(idx); + ir::Builder builder(ctx, operation->GetParent()); + auto reshape_op = builder.Build(value, y_shape); + (*param_map)[y_grad_var_name] = + VariableDefiningInfo(reshape_op.out(), false, -1); + } +}; + OpTranslator::OpTranslator() { general_handler = OpTranscriber(); special_handlers["add_n"] = AddNOpTranscriber(); @@ -1278,6 +1466,25 @@ OpTranslator::OpTranslator() { special_handlers["shaddow_output"] = ShaddowOutputOpTranscriber(); special_handlers["split"] = SplitOpTranscriber(); special_handlers["sum"] = AddNOpTranscriber(); + + // special handler for elementwise ops with axis != -1 + // note(lyk): maybe we should do this by a pass, which seems more reasonable + special_handlers["elementwise_add"] = ElementwiseTranscriber(); + special_handlers["elementwise_sub"] = ElementwiseTranscriber(); + special_handlers["elementwise_mul"] = ElementwiseTranscriber(); + special_handlers["elementwise_div"] = ElementwiseTranscriber(); + special_handlers["elementwise_max"] = ElementwiseTranscriber(); + special_handlers["elementwise_min"] = ElementwiseTranscriber(); + special_handlers["elementwise_mod"] = ElementwiseTranscriber(); + special_handlers["elementwise_floordiv"] = ElementwiseTranscriber(); + special_handlers["elementwise_add_grad"] = ElementwiseGradTranscriber(); + special_handlers["elementwise_sub_grad"] = ElementwiseGradTranscriber(); + special_handlers["elementwise_mul_grad"] = ElementwiseGradTranscriber(); + special_handlers["elementwise_div_grad"] = ElementwiseGradTranscriber(); + special_handlers["elementwise_max_grad"] = ElementwiseGradTranscriber(); + special_handlers["elementwise_min_grad"] = ElementwiseGradTranscriber(); + special_handlers["elementwise_mod_grad"] = ElementwiseGradTranscriber(); + special_handlers["elementwise_floordiv_grad"] = ElementwiseGradTranscriber(); } } // namespace translator diff --git a/paddle/ir/core/builtin_op.h b/paddle/ir/core/builtin_op.h index 0ab058f5aac..fe5b7116a29 100644 --- a/paddle/ir/core/builtin_op.h +++ b/paddle/ir/core/builtin_op.h @@ -93,6 +93,7 @@ class IR_API CombineOp : public ir::Op { const std::vector &inputs); void Verify() const; + ir::OpResult out() { return result(0); } }; /// @@ -108,6 +109,7 @@ class IR_API SliceOp : public ir::Op { static const char *attributes_name[attributes_num]; void Verify() const; + ir::OpResult out() { return result(0); } }; class IR_API ConstantLikeTrait 
diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc
index d94fb6026c1..2af40feaeec 100644
--- a/test/cpp/ir/core/ir_program_test.cc
+++ b/test/cpp/ir/core/ir_program_test.cc
@@ -234,6 +234,8 @@ TEST(program_test, slice_combine_test) {
       ir::VectorType::get(ctx, std::vector<ir::Type>({fp32_dtype, fp32_dtype}));
   ir::Operation *combine_op = ir::Operation::Create(
       {op1->result(0), op2->result(0)}, {}, {output_type}, combine_op_info);
+  ir::CombineOp combine_op_type = combine_op->dyn_cast<ir::CombineOp>();
+  EXPECT_TRUE(combine_op_type.out());
   program.block()->push_back(combine_op);
 
   // (7) Def slice_op = SliceOp(combine_op, 0)
diff --git a/test/ir/new_ir/test_special_op_translator.py b/test/ir/new_ir/test_special_op_translator.py
index 22b0a82c2ce..ea4837eb478 100644
--- a/test/ir/new_ir/test_special_op_translator.py
+++ b/test/ir/new_ir/test_special_op_translator.py
@@ -19,6 +19,7 @@ import numpy as np
 import paddle
 from paddle import ir
 from paddle.fluid import core
+from paddle.framework import LayerHelper
 
 paddle.enable_static()
 
@@ -37,6 +38,69 @@ class TestCastOpTranscriber(unittest.TestCase):
                 _ = ir.translate_to_new_ir(main_program.desc)
 
 
+class TestElementwiseOpTranscriber(unittest.TestCase):
+    def test_elementwise_without_y_grad(self):
+        place = core.Place()
+        place.set_place(paddle.CPUPlace())
+        exe = paddle.static.Executor(place)
+
+        new_scope = paddle.static.Scope()
+        main_program = paddle.static.Program()
+        with paddle.static.scope_guard(new_scope):
+            with paddle.static.program_guard(main_program):
+                x_data = np.random.rand(100, 2, 3)
+                y_data = np.random.rand(100)
+                x = paddle.to_tensor(x_data, dtype='float32')
+                x.stop_gradient = False
+                y = paddle.to_tensor(y_data, dtype='float32')
+
+                out1 = paddle.tensor.math._elementwise_op(
+                    LayerHelper('elementwise_add', x=x, y=y, axis=0)
+                )
+                out1.stop_gradient = False
+                mean = paddle.mean(out1)
+                paddle.static.append_backward(mean)
+
+                out = exe.run(main_program, {}, fetch_list=[out1.name])
+                np.testing.assert_allclose(
+                    out[0],
+                    x_data + y_data.reshape(100, 1, 1),
+                    rtol=1e-6,
+                    atol=1e-6,
+                )
+
+    def test_elementwise_with_y_grad(self):
+        place = core.Place()
+        place.set_place(paddle.CPUPlace())
+        exe = paddle.static.Executor(place)
+
+        new_scope = paddle.static.Scope()
+        main_program = paddle.static.Program()
+        with paddle.static.scope_guard(new_scope):
+            with paddle.static.program_guard(main_program):
+                x_data = np.random.rand(100, 2, 3)
+                y_data = np.random.rand(100)
+                x = paddle.to_tensor(x_data, dtype='float32')
+                x.stop_gradient = False
+                y = paddle.to_tensor(y_data, dtype='float32')
+                y.stop_gradient = False
+
+                out1 = paddle.tensor.math._elementwise_op(
+                    LayerHelper('elementwise_add', x=x, y=y, axis=0)
+                )
+                out1.stop_gradient = False
+                mean = paddle.mean(out1)
+                paddle.static.append_backward(mean)
+
+                out = exe.run(main_program, {}, fetch_list=[out1.name])
+                np.testing.assert_allclose(
+                    out[0],
+                    x_data + y_data.reshape(100, 1, 1),
+                    rtol=1e-6,
+                    atol=1e-6,
+                )
+
+
 class TestEmbeddingOpTranscriber(unittest.TestCase):
     def test_op(self):
         place = core.Place()
--
GitLab