未验证 提交 b0ea1346 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] support elementwise operations with axis!=-1 (#55699)

* support elementwise with axis!=-1

* fix coverage ci

* fix bug

* remove print
上级 91040569
......@@ -287,7 +287,8 @@ struct OpTranscriber {
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc);
virtual void RecordOpResultMapping(TranslationContext* param_map,
virtual void RecordOpResultMapping(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Operation* operation,
const OpOutputMapping& arg_to_idx);
......@@ -597,7 +598,8 @@ ir::AttributeMap OpTranscriber::TranslateOpAttribute(
return attribute_map;
}
void OpTranscriber::RecordOpResultMapping(TranslationContext* param_map,
void OpTranscriber::RecordOpResultMapping(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Operation* operation,
const OpOutputMapping& arg_to_idx) {
......@@ -605,7 +607,7 @@ void OpTranscriber::RecordOpResultMapping(TranslationContext* param_map,
auto& name = n.first;
VLOG(10) << "[output recording]"
<< "[" << op_desc.Type() << "]" << name;
auto& args = n.second;
const auto& args = n.second;
size_t idx_in_vector = 0;
for (const auto& arg_name : args) {
if (arg_name == kEmptyVarName) {
......@@ -674,7 +676,7 @@ ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx,
program->block()->push_back(operation);
VLOG(4) << "[general op][" << op_desc.Type() << "] opearation insertion end.";
this->RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
this->RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);
return operation;
}
......@@ -843,7 +845,7 @@ struct AssignValueOpTranscriber : public OpTranscriber {
ir::Operation* operation = ir::Operation::Create(
op_inputs, attribute_map, op_output_types, op_info);
program->block()->push_back(operation);
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);
VLOG(10) << "[op assign_value] translation finished";
......@@ -1260,6 +1262,192 @@ struct ReduceOpTranscriber : public OpTranscriber {
}
};
struct ElementwiseTranscriber : public OpTranscriber {
std::vector<ir::OpResult> GenerateOperationInput(
ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
ir::Program* program) override {
int axis = paddle::get<int>(op_desc.GetAttr("axis"));
if (axis == -1) {
return OpTranscriber::GenerateOperationInput(
ctx, param_map, op_desc, normalized_op_name, input_infos, program);
}
auto x_names = op_desc.Input("X", true);
IR_ENFORCE(x_names.size() == 1,
"Expected op[%s]'s input X has only 1 variable, but got %d",
op_desc.Type(),
x_names.size());
auto x_name = x_names[0];
IR_ENFORCE(param_map->count(x_name) > 0,
"Expected op[%s]'s input %s has been parsed",
op_desc.Type(),
x_name);
auto x_defining_info = param_map->at(x_name);
if (x_defining_info.generated_by_vector) {
InsertSliceOperationForTarget(
ctx, param_map, program, x_defining_info, x_name);
x_defining_info = param_map->at(x_name);
}
ir::OpResult x_value = x_defining_info.value;
IR_ENFORCE(x_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
x_name);
ir::Type x_type = x_value.type();
IR_ENFORCE(x_type.isa<dialect::DenseTensorType>(),
"Expected op[%s]'s input %s is DenseTensor but got %s",
op_desc.Type(),
x_name,
x_type);
dialect::DenseTensorType x_tensor_type =
x_type.dyn_cast<dialect::DenseTensorType>();
std::vector<int64_t> x_shape = phi::vectorize(x_tensor_type.dims());
auto y_names = op_desc.Input("Y", true);
IR_ENFORCE(y_names.size() == 1,
"Expected op[%s]'s input Y has only 1 variable, but got %d",
op_desc.Type(),
y_names.size());
auto y_name = y_names[0];
IR_ENFORCE(param_map->count(y_name) > 0,
"Expected op[%s]'s input %s has been parsed",
op_desc.Type(),
y_name);
auto y_defining_info = param_map->at(y_name);
if (y_defining_info.generated_by_vector) {
InsertSliceOperationForTarget(
ctx, param_map, program, y_defining_info, y_name);
y_defining_info = param_map->at(y_name);
}
ir::OpResult y_value = y_defining_info.value;
IR_ENFORCE(y_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
y_name);
ir::Type y_type = y_value.type();
IR_ENFORCE(y_type.isa<dialect::DenseTensorType>(),
"Expected op[%s]'s input %s is DenseTensor but got %s",
op_desc.Type(),
y_name,
y_type);
dialect::DenseTensorType y_tensor_type =
y_type.dyn_cast<dialect::DenseTensorType>();
std::vector<int64_t> y_shape = phi::vectorize(y_tensor_type.dims());
if (axis < 0) {
axis += x_shape.size();
}
int append_size = x_shape.size() - axis - 1 - y_shape.size();
if (append_size < 0) { // which means x.rank <= y.rank, mostly
// x.rank=y.rank
return {x_value, y_value};
}
IR_ENFORCE(append_size >= 0,
"Expected op[%s] have append size >= 0 with axis=%d but got %d",
op_desc.Type(),
axis,
append_size);
ir::Builder builder(ctx, program->block());
ir::OpResult y_new;
if (std::find(y_shape.begin(), y_shape.end(), -1) == y_shape.end()) {
std::vector<int64_t> y_new_shape(y_shape);
for (int i = 0; i <= append_size; i++) {
y_new_shape.push_back(1);
}
dialect::Reshape_Op reshape_op =
builder.Build<dialect::Reshape_Op>(y_value, y_new_shape);
y_new = reshape_op.out();
VLOG(6) << "[" << op_desc.Type() << "] y_shape change from "
<< y_tensor_type.dims() << " to " << phi::make_ddim(y_new_shape);
} else {
auto shape_op = builder.Build<dialect::ShapeOp>(y_value);
auto append_shape_op = builder.Build<dialect::FullIntArrayOp>(
std::vector<int64_t>(append_size, 1),
phi::DataType::INT64,
phi::CPUPlace());
auto y_true_shape_op = builder.Build<ir::CombineOp>(
std::vector<ir::OpResult>{shape_op.out(), append_shape_op.out()});
auto concat_op =
builder.Build<dialect::ConcatOp>(y_true_shape_op.out(), 0);
auto y_new_shape = concat_op.out();
auto reshape_op =
builder.Build<dialect::Reshape_Op>(y_value, y_new_shape);
y_new = reshape_op.out();
}
return {x_value, y_new};
}
};
struct ElementwiseGradTranscriber : public OpTranscriber {
void RecordOpResultMapping(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Operation* operation,
const OpOutputMapping& arg_to_idx) override {
OpTranscriber::RecordOpResultMapping(
ctx, param_map, op_desc, operation, arg_to_idx);
int axis = paddle::get<int>(op_desc.GetAttr("axis"));
if (axis == -1) {
return;
}
const auto& y_grad_output = op_desc.Output("Y@GRAD");
if (y_grad_output.size() < 1) {
return;
}
IR_ENFORCE(
y_grad_output.size() == 1,
"Expected op[%s]'s output Y@GRAD has only 1 variable, but got %d",
op_desc.Type(),
y_grad_output.size());
const auto& y_grad_var_name = y_grad_output[0];
auto idx_iter = arg_to_idx.find(y_grad_var_name);
if (idx_iter == arg_to_idx.end()) {
IR_THROW("op[%s] should have got its y_grad", op_desc.Type());
}
auto idx = idx_iter->second;
VLOG(10) << "[output recording]"
<< "[" << op_desc.Type() << "]" << y_grad_var_name << " " << idx;
auto y_names = op_desc.Input("Y", true);
auto y_name = y_names[0];
IR_ENFORCE(param_map->count(y_name) > 0,
"Expected op[%s]'s input %s has been parsed",
op_desc.Type(),
y_name);
auto y_defining_info = param_map->at(y_name);
ir::OpResult y_value = y_defining_info.value;
IR_ENFORCE(y_value,
"Expected op[%s]'s input %s is not null",
op_desc.Type(),
y_name);
ir::Type y_type = y_value.type();
IR_ENFORCE(y_type.isa<dialect::DenseTensorType>(),
"Expected op[%s]'s input %s is DenseTensor but got %s",
op_desc.Type(),
y_name,
y_type);
dialect::DenseTensorType y_tensor_type =
y_type.dyn_cast<dialect::DenseTensorType>();
std::vector<int64_t> y_shape = phi::vectorize(y_tensor_type.dims());
ir::OpResult value = operation->result(idx);
ir::Builder builder(ctx, operation->GetParent());
auto reshape_op = builder.Build<dialect::Reshape_Op>(value, y_shape);
(*param_map)[y_grad_var_name] =
VariableDefiningInfo(reshape_op.out(), false, -1);
}
};
OpTranslator::OpTranslator() {
general_handler = OpTranscriber();
special_handlers["add_n"] = AddNOpTranscriber();
......@@ -1278,6 +1466,25 @@ OpTranslator::OpTranslator() {
special_handlers["shaddow_output"] = ShaddowOutputOpTranscriber();
special_handlers["split"] = SplitOpTranscriber();
special_handlers["sum"] = AddNOpTranscriber();
// special handler for elementwise ops with axis != -1
// note(lyk): maybe we should do this by a pass, which seems more reasonable
special_handlers["elementwise_add"] = ElementwiseTranscriber();
special_handlers["elementwise_sub"] = ElementwiseTranscriber();
special_handlers["elementwise_mul"] = ElementwiseTranscriber();
special_handlers["elementwise_div"] = ElementwiseTranscriber();
special_handlers["elementwise_max"] = ElementwiseTranscriber();
special_handlers["elementwise_min"] = ElementwiseTranscriber();
special_handlers["elementwise_mod"] = ElementwiseTranscriber();
special_handlers["elementwise_floordiv"] = ElementwiseTranscriber();
special_handlers["elementwise_add_grad"] = ElementwiseGradTranscriber();
special_handlers["elementwise_sub_grad"] = ElementwiseGradTranscriber();
special_handlers["elementwise_mul_grad"] = ElementwiseGradTranscriber();
special_handlers["elementwise_div_grad"] = ElementwiseGradTranscriber();
special_handlers["elementwise_max_grad"] = ElementwiseGradTranscriber();
special_handlers["elementwise_min_grad"] = ElementwiseGradTranscriber();
special_handlers["elementwise_mod_grad"] = ElementwiseGradTranscriber();
special_handlers["elementwise_floordiv_grad"] = ElementwiseGradTranscriber();
}
} // namespace translator
......
......@@ -93,6 +93,7 @@ class IR_API CombineOp : public ir::Op<CombineOp> {
const std::vector<ir::OpResult> &inputs);
void Verify() const;
ir::OpResult out() { return result(0); }
};
///
......@@ -108,6 +109,7 @@ class IR_API SliceOp : public ir::Op<SliceOp> {
static const char *attributes_name[attributes_num];
void Verify() const;
ir::OpResult out() { return result(0); }
};
class IR_API ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> {
......
......@@ -234,6 +234,8 @@ TEST(program_test, slice_combine_test) {
ir::VectorType::get(ctx, std::vector<ir::Type>({fp32_dtype, fp32_dtype}));
ir::Operation *combine_op = ir::Operation::Create(
{op1->result(0), op2->result(0)}, {}, {output_type}, combine_op_info);
ir::CombineOp combine_op_type = combine_op->dyn_cast<ir::CombineOp>();
EXPECT_TRUE(combine_op_type.out());
program.block()->push_back(combine_op);
// (7) Def slice_op = SliceOp(combine_op, 0)
......
......@@ -19,6 +19,7 @@ import numpy as np
import paddle
from paddle import ir
from paddle.fluid import core
from paddle.framework import LayerHelper
paddle.enable_static()
......@@ -37,6 +38,69 @@ class TestCastOpTranscriber(unittest.TestCase):
_ = ir.translate_to_new_ir(main_program.desc)
class TestElementwiseOpTranscriber(unittest.TestCase):
def test_elementwise_without_y_grad(self):
place = core.Place()
place.set_place(paddle.CPUPlace())
exe = paddle.static.Executor(place)
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
x_data = np.random.rand(100, 2, 3)
y_data = np.random.rand(100)
x = paddle.to_tensor(x_data, dtype='float32')
x.stop_gradient = False
y = paddle.to_tensor(y_data, dtype='float32')
out1 = paddle.tensor.math._elementwise_op(
LayerHelper('elementwise_add', x=x, y=y, axis=0)
)
out1.stop_gradient = False
mean = paddle.mean(out1)
paddle.static.append_backward(mean)
out = exe.run(main_program, {}, fetch_list=[out1.name])
np.testing.assert_allclose(
out[0],
x_data + y_data.reshape(100, 1, 1),
rtol=1e-6,
atol=1e-6,
)
def test_elementwise_with_y_grad(self):
place = core.Place()
place.set_place(paddle.CPUPlace())
exe = paddle.static.Executor(place)
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
x_data = np.random.rand(100, 2, 3)
y_data = np.random.rand(100)
x = paddle.to_tensor(x_data, dtype='float32')
x.stop_gradient = False
y = paddle.to_tensor(y_data, dtype='float32')
y.stop_gradient = False
out1 = paddle.tensor.math._elementwise_op(
LayerHelper('elementwise_add', x=x, y=y, axis=0)
)
out1.stop_gradient = False
mean = paddle.mean(out1)
paddle.static.append_backward(mean)
out = exe.run(main_program, {}, fetch_list=[out1.name])
np.testing.assert_allclose(
out[0],
x_data + y_data.reshape(100, 1, 1),
rtol=1e-6,
atol=1e-6,
)
class TestEmbeddingOpTranscriber(unittest.TestCase):
def test_op(self):
place = core.Place()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册