未验证 提交 530b6b66 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【newir】Concat optest (#57043)

* add reference of lbfgs

* add reference of lbfgs

* modify concat kernel choose

* modify ci
上级 e98875b0
...@@ -569,6 +569,29 @@ phi::KernelKey GetKernelKey( ...@@ -569,6 +569,29 @@ phi::KernelKey GetKernelKey(
kernel_key_parser.key_set.backend_set = kernel_key_parser.key_set.backend_set =
kernel_key_parser.key_set.backend_set | kernel_key_parser.key_set.backend_set |
paddle::experimental::BackendSet(data_op_backend); paddle::experimental::BackendSet(data_op_backend);
} else if (op->operand_source(i).GetDefiningOp()->name() ==
"builtin.combine") {
auto combine_op = op->operand_source(i).GetDefiningOp();
for (size_t j = 0; j < combine_op->num_operands(); ++j) {
if (combine_op->operand_source(j).GetDefiningOp()->name() ==
"pd.data") {
auto data_op = combine_op->operand_source(j).GetDefiningOp();
auto data_place = data_op->attributes()
.at("place")
.dyn_cast<dialect::PlaceAttribute>()
.data();
auto data_op_backend =
paddle::experimental::ParseBackend(data_place);
if (data_op_backend == phi::Backend::UNDEFINED) {
data_op_backend = paddle::experimental::ParseBackend(place);
}
kernel_key_parser.key_set.backend_set =
kernel_key_parser.key_set.backend_set |
paddle::experimental::BackendSet(data_op_backend);
break;
}
}
} }
} }
......
...@@ -16,7 +16,7 @@ from paddle.base import core ...@@ -16,7 +16,7 @@ from paddle.base import core
__all__ = [] __all__ = []
UNIFIED_APIS = ['mean'] UNIFIED_APIS = ['mean', 'concat', 'add_n', 'scale']
for name in dir(core.eager.ops): for name in dir(core.eager.ops):
globals()[name] = getattr(core.eager.ops, name) globals()[name] = getattr(core.eager.ops, name)
......
...@@ -32,12 +32,14 @@ class IrGuard: ...@@ -32,12 +32,14 @@ class IrGuard:
paddle.base.framework.set_flags(old_flag) paddle.base.framework.set_flags(old_flag)
def __enter__(self): def __enter__(self):
paddle.enable_static()
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
self._switch_to_new_ir() self._switch_to_new_ir()
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
self._switch_to_old_ir() self._switch_to_old_ir()
paddle.disable_static()
def _switch_to_new_ir(self): def _switch_to_new_ir(self):
if paddle.ir.core._use_new_ir_api(): if paddle.ir.core._use_new_ir_api():
......
...@@ -34,6 +34,7 @@ from ..framework import ( ...@@ -34,6 +34,7 @@ from ..framework import (
core, core,
dygraph_only, dygraph_only,
in_dynamic_mode, in_dynamic_mode,
in_new_ir_mode,
) )
from .creation import _complex_to_real_dtype, _real_to_complex_dtype, zeros from .creation import _complex_to_real_dtype, _real_to_complex_dtype, zeros
...@@ -1131,11 +1132,11 @@ def concat(x, axis=0, name=None): ...@@ -1131,11 +1132,11 @@ def concat(x, axis=0, name=None):
if not isinstance(input, Variable): if not isinstance(input, Variable):
input = [t for t in input if t.shape.count(0) == 0] input = [t for t in input if t.shape.count(0) == 0]
return _C_ops.concat(input, axis) return _C_ops.concat(input, axis)
elif in_new_ir_mode():
if not isinstance(input, paddle.ir.Value):
input = [t for t in input if t.shape.count(0) == 0]
return _C_ops.concat(input, axis)
else: else:
if paddle.ir.core._use_new_ir_api():
if not isinstance(input, paddle.ir.Value):
input = [t for t in input if t.shape.count(0) == 0]
return paddle._ir_ops.concat(input, axis)
check_type(input, 'input', (list, tuple, Variable), 'concat') check_type(input, 'input', (list, tuple, Variable), 'concat')
if not isinstance(input, Variable): if not isinstance(input, Variable):
for id, x in enumerate(input): for id, x in enumerate(input):
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
math functions math functions
""" """
import numpy as np import numpy as np
import paddle import paddle
...@@ -34,6 +35,8 @@ from ..framework import ( ...@@ -34,6 +35,8 @@ from ..framework import (
convert_np_dtype_to_dtype_, convert_np_dtype_to_dtype_,
core, core,
in_dynamic_mode, in_dynamic_mode,
in_dynamic_or_new_ir_mode,
in_new_ir_mode,
) )
from .creation import _complex_to_real_dtype from .creation import _complex_to_real_dtype
from .layer_function_generator import generate_layer_fn, templatedoc from .layer_function_generator import generate_layer_fn, templatedoc
...@@ -264,6 +267,10 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): ...@@ -264,6 +267,10 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
return _C_ops.scale(x, scale, float(bias), bias_after_scale) return _C_ops.scale(x, scale, float(bias), bias_after_scale)
out = _C_ops.scale(x, scale, float(bias), bias_after_scale) out = _C_ops.scale(x, scale, float(bias), bias_after_scale)
return dygraph_utils._append_activation_in_dygraph(out, act) return dygraph_utils._append_activation_in_dygraph(out, act)
elif in_new_ir_mode():
if act is None:
return _C_ops.scale(x, scale, float(bias), bias_after_scale)
raise ValueError("act is not implement in new ir of scale api.")
else: else:
check_variable_and_dtype( check_variable_and_dtype(
x, x,
...@@ -1956,14 +1963,11 @@ def add_n(inputs, name=None): ...@@ -1956,14 +1963,11 @@ def add_n(inputs, name=None):
[[8. , 10., 12.], [[8. , 10., 12.],
[14., 16., 18.]]) [14., 16., 18.]])
""" """
if in_dynamic_mode(): if in_dynamic_or_new_ir_mode():
if isinstance(inputs, Variable): if isinstance(inputs, Variable):
inputs = [inputs] inputs = [inputs]
return _C_ops.add_n(inputs) return _C_ops.add_n(inputs)
else: else:
if paddle.ir.core._use_new_ir_api():
return paddle._ir_ops.add_n(inputs)
helper = LayerHelper('add_n', **locals()) helper = LayerHelper('add_n', **locals())
check_type(inputs, 'inputs', (Variable, tuple, list), 'add_n') check_type(inputs, 'inputs', (Variable, tuple, list), 'add_n')
if isinstance(inputs, (list, tuple)): if isinstance(inputs, (list, tuple)):
......
...@@ -2921,7 +2921,6 @@ class OpTest(unittest.TestCase): ...@@ -2921,7 +2921,6 @@ class OpTest(unittest.TestCase):
if user_defined_grads is None and self.is_compared_with_fp32(): if user_defined_grads is None and self.is_compared_with_fp32():
self.enable_cal_ref_output() self.enable_cal_ref_output()
numeric_grads = self._get_gradient( numeric_grads = self._get_gradient(
inputs_to_check, inputs_to_check,
place, place,
...@@ -3164,7 +3163,6 @@ class OpTest(unittest.TestCase): ...@@ -3164,7 +3163,6 @@ class OpTest(unittest.TestCase):
# delete the inputs which no need to calculate grad # delete the inputs which no need to calculate grad
for no_grad_val in no_grad_set: for no_grad_val in no_grad_set:
del inputs[no_grad_val] del inputs[no_grad_val]
grad_inputs = paddle.grad( grad_inputs = paddle.grad(
outputs=paddle.utils.flatten(outputs), outputs=paddle.utils.flatten(outputs),
inputs=paddle.utils.flatten(inputs), inputs=paddle.utils.flatten(inputs),
...@@ -3433,7 +3431,7 @@ class OpTest(unittest.TestCase): ...@@ -3433,7 +3431,7 @@ class OpTest(unittest.TestCase):
( (
static_inputs, static_inputs,
attrs, attrs,
input_dict, inputs_dict,
feed, feed,
) = self.get_ir_input_attr_dict_and_feed(stop_gradient=False) ) = self.get_ir_input_attr_dict_and_feed(stop_gradient=False)
# prepare args # prepare args
...@@ -3480,14 +3478,27 @@ class OpTest(unittest.TestCase): ...@@ -3480,14 +3478,27 @@ class OpTest(unittest.TestCase):
fetch_list = getattr(self, "fetch_list", []) fetch_list = getattr(self, "fetch_list", [])
outputs_valid = outputs outputs_valid = outputs
grad_inputs = inputs_to_check loss_inputs = []
for input_name in inputs_to_check:
loss_inputs.append(inputs_dict[input_name])
if user_defined_grad_outputs is None: if user_defined_grad_outputs is None:
if len(outputs_valid) == 1: if len(outputs_valid) == 1:
for outputs_valid_key in outputs_valid: for outputs_valid_key in outputs_valid:
loss = paddle.mean(outputs_valid[outputs_valid_key][0]) loss = paddle.mean(outputs_valid[outputs_valid_key][0])
else:
avg_sum = []
for cur_loss in outputs_valid:
cur_avg_loss = paddle.mean(outputs_valid[cur_loss][0])
avg_sum.append(cur_avg_loss)
loss_sum = paddle.add_n(avg_sum)
loss = paddle.scale(
loss_sum, scale=1.0 / float(len(avg_sum))
)
grad_inputs = ir_grad( grad_inputs = ir_grad(
outputs=paddle.utils.flatten(loss), outputs=paddle.utils.flatten(loss),
inputs=paddle.utils.flatten(static_inputs), inputs=paddle.utils.flatten(loss_inputs),
grad_outputs=None, grad_outputs=None,
) )
else: else:
......
...@@ -75,6 +75,7 @@ class TestMeanOp_ZeroDim(OpTest): ...@@ -75,6 +75,7 @@ class TestMeanOp_ZeroDim(OpTest):
class TestMeanOpError(unittest.TestCase): class TestMeanOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
# The input type of mean_op must be Variable. # The input type of mean_op must be Variable.
input1 = 12 input1 = 12
...@@ -88,6 +89,7 @@ class TestMeanOpError(unittest.TestCase): ...@@ -88,6 +89,7 @@ class TestMeanOpError(unittest.TestCase):
name='input3', shape=[-1, 4], dtype="float16" name='input3', shape=[-1, 4], dtype="float16"
) )
paddle.nn.functional.softmax(input3) paddle.nn.functional.softmax(input3)
paddle.disable_static()
@unittest.skipIf( @unittest.skipIf(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册