Unverified commit a90b8dc1, authored by Jiabin Yang, committed via GitHub

Support broadcast tensor in phi system (#44590)

Parent commit: acf07c74
......@@ -31,14 +31,14 @@ paddle::optional<phi::DenseTensor> TensorToDenseTensor(
return nullptr;
}
std::unique_ptr<std::vector<phi::DenseTensor>> TensorToDenseTensor(
std::unique_ptr<std::vector<phi::DenseTensor*>> TensorToDenseTensor(
const std::vector<Tensor>& tensors) {
auto pt_tensors = std::make_unique<std::vector<phi::DenseTensor>>();
auto pt_tensors = std::make_unique<std::vector<phi::DenseTensor*>>();
pt_tensors->reserve(tensors.size());
for (const auto& t : tensors) {
pt_tensors->push_back(
*std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()));
std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()).get());
}
return pt_tensors;
......
......@@ -35,7 +35,7 @@ std::shared_ptr<phi::DenseTensor> TensorToDenseTensor(const Tensor& tensor);
paddle::optional<phi::DenseTensor> TensorToDenseTensor(
const paddle::optional<Tensor>& tensor);
std::unique_ptr<std::vector<phi::DenseTensor>> TensorToDenseTensor(
std::unique_ptr<std::vector<phi::DenseTensor*>> TensorToDenseTensor(
const std::vector<Tensor>& tensors);
std::shared_ptr<phi::SelectedRows> TensorToSelectedRows(const Tensor& tensor);
......
......@@ -582,18 +582,18 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
trans_flag = "{false, true}"
if input_name in self.optional_vars:
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});"""
else:
if self.inputs['input_info'][
input_name] == "const Tensor&":
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});"""
elif self.inputs['input_info'][
input_name] == "const std::vector<Tensor>&":
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});
{code_indent} std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size());
{code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{
{code_indent} {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
......@@ -612,7 +612,13 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} paddle::optional<phi::TensorBase> {PREFIX_TENSOR_NAME}{input_name} = {input_name} ? paddle::optional<phi::TensorBase>(*{input_name}->impl()) : paddle::none;"""
else:
input_tensor_code = input_tensor_code + f"""
if self.inputs['input_info'][
input_name] == "const std::vector<Tensor>&":
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_uq_ptr = TensorToDenseTensor({input_name});
{code_indent} const auto& {PREFIX_TENSOR_NAME}{input_name} = *{PREFIX_TENSOR_NAME}{input_name}_uq_ptr;"""
else:
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = {input_name}.impl();"""
kernel_args = ["*dev_ctx"]
......
......@@ -2513,6 +2513,15 @@
output : Tensor
invoke : full_like(x, 0, dtype, place)
# Forward API spec for broadcast_tensors, consumed by the phi API code generator.
- api: broadcast_tensors
# Input: a variable-length list of tensors to be broadcast against each other.
args: (Tensor[] x)
# Output: one tensor per input — the result list has the same length as x.
output: Tensor[]{x.size()}
infer_meta:
# Shape/dtype inference is delegated to the BroadcastTensorsInferMeta function.
func: BroadcastTensorsInferMeta
kernel:
func: broadcast_tensors
# Links this forward op to its gradient spec (broadcast_tensors_grad below).
backward: broadcast_tensors_grad
# eig
- api: eig
args: (Tensor x)
......
......@@ -280,6 +280,18 @@
func : brelu_grad
inplace : (out_grad -> x_grad)
# Backward API spec for broadcast_tensors, consumed by the phi API code generator.
- backward_api : broadcast_tensors_grad
# Forward signature this gradient corresponds to.
forward : broadcast_tensors (Tensor[] x) -> Tensor[](out)
args : (Tensor[] x, Tensor[] out_grad)
# One gradient tensor per forward input.
output : Tensor[](x_grad)
infer_meta :
# Each x_grad keeps the meta (shape/dtype) of the matching x entry.
func : UnchangedMultiInferMeta
param : [x]
kernel :
func : broadcast_tensors_grad
# The grad kernel only consumes out_grad; x is passed for meta only.
param : [out_grad]
# x's data buffer is not read by the kernel, so it can be released early.
no_need_buffer : x
- backward_api : cast_grad
forward : cast (Tensor x, DataType out_dtype) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
......@@ -99,26 +99,49 @@ class TestCPUBroadcastTensorsOp(OpTest):
]
self.set_place()
self.set_dtypes()
self.python_api = paddle.broadcast_tensors
def run_test(self, test_func, args):
def run_dual_test(self, test_func, args):
for dtype in self.dtypes:
for gen_func in self.test_gen_func_list:
self.inputs, self.outputs = gen_func(dtype)
test_func(**args)
if len(self.outputs["Out"]) < 3:
self.python_out_sig = [
f"out{i}" for i in range(len(self.outputs["Out"]))
]
test_func(**args)
def run_triple_in_test(self, test_func, args):
for dtype in self.dtypes:
self.inputs, self.outputs = self.test_gen_func_list[2](dtype)
self.python_out_sig = [
f"out{i}" for i in range(len(self.outputs["Out"]))
]
test_func(**args)
def test_check_output(self):
self.run_test(self.check_output_with_place, {
self.run_dual_test(self.check_output_with_place, {
"place": self.place,
"atol": 1e-1
"atol": 1e-1,
"check_eager": True
})
def test_check_grad_normal(self):
self.run_test(
self.run_dual_test(
self.check_grad_with_place, {
"place": self.place,
"inputs_to_check": ['x0', 'x1'],
"output_names": ['out0', 'out1'],
"max_relative_error": 0.05,
"check_eager": True
})
self.run_triple_in_test(
self.check_grad_with_place, {
"place": self.place,
"inputs_to_check": ['x0', 'x1', 'x2'],
"output_names": ['out0', 'out1', "out2"],
"max_relative_error": 0.05,
"check_eager": True
})
......
......@@ -1132,7 +1132,9 @@ def broadcast_tensors(input, name=None):
"""
num_inputs = len(input)
if paddle.in_dynamic_mode():
if paddle.framework.in_dygraph_mode():
return _C_ops.final_state_broadcast_tensors(input)
if paddle.framework._non_static_mode():
return _C_ops.broadcast_tensors(input, num_inputs)
check_type(input, 'input', (list, tuple), 'broadcast_tensors')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.