未验证 提交 aceb25e1 编写于 作者: Z zyfncg 提交者: GitHub

[Pten] Support optional param for C++ API (#39760)

* fix selected_rows bug in C++ API

* add optional for C++ APIO

* data transform support optional

* remove data transform for optional vector<Tensor>

* adjust some format of funtcion

* fix empyt bug
上级 bd9b9460
......@@ -34,7 +34,7 @@ namespace experimental {
Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy", kernel_key);
......@@ -67,7 +67,7 @@ std::vector<Tensor> split_impl(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
Backend kernel_backend = kernel_key.backend();
DataLayout kernel_layout = kernel_key.layout();
......
......@@ -31,6 +31,14 @@ inline std::shared_ptr<phi::DenseTensor> TensorToDenseTensor(
return std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
}
inline std::shared_ptr<phi::DenseTensor> TensorToDenseTensor(
const paddle::optional<Tensor>& tensor) {
if (tensor) {
return std::dynamic_pointer_cast<phi::DenseTensor>(tensor->impl());
}
return nullptr;
}
inline std::unique_ptr<std::vector<phi::DenseTensor>> TensorToDenseTensor(
const std::vector<Tensor>& tensors) {
auto pt_tensors = std::make_unique<std::vector<phi::DenseTensor>>();
......@@ -49,12 +57,28 @@ inline std::shared_ptr<phi::SelectedRows> TensorToSelectedRows(
return std::dynamic_pointer_cast<phi::SelectedRows>(tensor.impl());
}
inline std::shared_ptr<phi::SelectedRows> TensorToSelectedRows(
const paddle::optional<Tensor>& tensor) {
if (tensor) {
return std::dynamic_pointer_cast<phi::SelectedRows>(tensor->impl());
}
return nullptr;
}
/* ----------------- for infer_meta --------------------- */
inline phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor) {
return phi::MetaTensor(tensor);
}
inline paddle::optional<phi::MetaTensor> MakeMetaTensor(
const paddle::optional<const phi::DenseTensor&>& tensor) {
if (tensor) {
return {phi::MetaTensor(*tensor)};
}
return {paddle::none};
}
inline std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<phi::DenseTensor>& tensors) {
std::vector<phi::MetaTensor> meta_tensors;
......@@ -69,6 +93,14 @@ inline phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor) {
return phi::MetaTensor(tensor);
}
inline paddle::optional<phi::MetaTensor> MakeMetaTensor(
const paddle::optional<const phi::SelectedRows&>& tensor) {
if (tensor) {
return {phi::MetaTensor(*tensor)};
}
return {paddle::none};
}
/* ------------------ for output ----------------------- */
inline phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
......
......@@ -199,6 +199,16 @@ std::shared_ptr<phi::DenseTensor> PrepareData(
return std::make_shared<phi::DenseTensor>(out);
}
std::shared_ptr<phi::DenseTensor> PrepareData(
const paddle::optional<Tensor>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag) {
if (input) {
return PrepareData(*input, target_args_def, transform_flag);
}
return {nullptr};
}
std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
const std::vector<Tensor>& inputs,
const phi::TensorArgDef& target_args_def,
......
......@@ -66,6 +66,11 @@ std::shared_ptr<phi::DenseTensor> PrepareData(
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
std::shared_ptr<phi::DenseTensor> PrepareData(
const paddle::optional<Tensor>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);
std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
const std::vector<Tensor>& inputs,
const phi::TensorArgDef& target_args_def,
......
......@@ -51,7 +51,7 @@ struct KernelKeySet {
DataType dtype{DataType::UNDEFINED};
// TODO(chenweihang): iterate all kernelkey for kernel selection
phi::KernelKey GetHigestPriorityKernelKey() {
phi::KernelKey GetHighestPriorityKernelKey() {
return phi::KernelKey(static_cast<Backend>(64 - detail::CountLeadingZeros(
backend_set.bitset())),
layout,
......
......@@ -51,7 +51,7 @@ PADDLE_API Tensor to_sparse_coo(const Tensor& x,
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "dense_to_sparse_coo";
if (x.layout() == phi::DataLayout::SPARSE_CSR) {
kernel_name = "sparse_csr_to_coo";
......@@ -112,7 +112,7 @@ PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "dense_to_sparse_csr";
if (x.layout() == phi::DataLayout::SPARSE_COO) {
kernel_name = "sparse_coo_to_csr";
......@@ -179,7 +179,7 @@ PADDLE_API Tensor to_dense(const Tensor& x, Backend backend) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "sparse_coo_to_dense";
if (x.layout() == phi::DataLayout::SPARSE_CSR) {
kernel_name = "sparse_csr_to_dense";
......
......@@ -76,6 +76,23 @@ void GeneralBinaryGradInferMeta(const MetaTensor& x,
}
}
void GeneralTernaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz) {
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
if (dz) {
dz->share_meta(z);
}
}
void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
const MetaTensor& dout,
int axis,
......
......@@ -34,6 +34,13 @@ void GeneralBinaryGradInferMeta(const MetaTensor& x,
MetaTensor* dx,
MetaTensor* dy);
void GeneralTernaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz);
void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
const MetaTensor& dout,
int axis,
......
......@@ -25,7 +25,8 @@ void EmptyKernel(const Context& dev_ctx,
const ScalarArray& shape,
DataType dtype,
DenseTensor* out) {
out->ResizeAndAllocate(phi::make_ddim(shape.GetData()));
out->Resize(phi::make_ddim(shape.GetData()));
dev_ctx.template Alloc<T>(out);
}
template <typename T, typename Context>
......
......@@ -596,7 +596,6 @@ void MatmulDoubleGradKernel(const Context& dev_ctx,
ddout_flag = true;
}
}
if (ddy) {
auto ddy_mat = ddy.get();
if (ddy_mat.dims() != y_help.dims()) {
......
......@@ -42,7 +42,7 @@ PADDLE_API Tensor scale_kernel_context(const Tensor& x,
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
......@@ -215,7 +215,7 @@ Tensor scale_switch_case(const Tensor& x,
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/lib/utils/allocator.h"
......@@ -161,5 +162,31 @@ TEST(API, matmul_cuda) {
#endif
TEST(API, matmul_double_grad) {
// 1. create tensor
auto x = paddle::experimental::full({3, 3}, 1.0);
auto y = paddle::experimental::full({3, 3}, 2.0);
auto out_grad = paddle::experimental::full({3, 3}, 2.0);
auto dx_grad = paddle::experimental::full({3, 3}, 2.0);
// 2. test API
const auto out = paddle::experimental::matmul_double_grad(
x, y, out_grad, dx_grad, {}, false, false);
// 3. check result
ASSERT_EQ(out.size(), 3UL);
ASSERT_EQ(out[0].size(), 1UL);
ASSERT_EQ(out[1].size(), 1UL);
ASSERT_EQ(out[2].size(), 1UL);
ASSERT_EQ(out[0][0].dims()[1], 3);
ASSERT_EQ(out[0][0].numel(), 9);
ASSERT_EQ(out[1][0].numel(), 9);
ASSERT_EQ(out[2][0].numel(), 9);
ASSERT_EQ(out[0][0].type(), phi::DataType::FLOAT32);
ASSERT_EQ(out[0][0].layout(), phi::DataLayout::NCHW);
ASSERT_EQ(out[1][0].initialized(), true);
ASSERT_EQ(out[2][0].initialized(), true);
}
} // namespace tests
} // namespace paddle
......@@ -20,6 +20,7 @@
#pragma once
#include <algorithm>
#include <cassert>
#include <functional>
#include <new>
#include <type_traits>
......
......@@ -35,7 +35,7 @@ class BaseAPI(object):
# args_str:
# args_declare : "str" // str of function params with default value. Example: (..., bool flag=false)
# args_define : "str" // str of function params without default value. Example: (..., bool flag)
self.inputs, self.attrs, self.outputs, self.args_str = self.parse_args(
self.inputs, self.attrs, self.outputs, self.args_str, self.optional_vars = self.parse_args(
self.api, api_item_yaml)
self.is_base_api = True
......@@ -57,17 +57,22 @@ class BaseAPI(object):
return self.api
def parse_args(self, api_name, api_item_yaml):
optional_vars = []
if 'optional' in api_item_yaml:
optional_vars = [
item.strip() for item in api_item_yaml['optional'].split(',')
]
inputs, attrs, args_str = self.parse_input_and_attr(
api_name, api_item_yaml['args'])
api_name, api_item_yaml['args'], optional_vars)
output_type_list, output_names, return_type = self.parse_output(
api_name, api_item_yaml['output'])
return inputs, attrs, {
'names': output_names,
'types': output_type_list,
'return_type': return_type
}, args_str
}, args_str, optional_vars
def parse_input_and_attr(self, api_name, args_config):
def parse_input_and_attr(self, api_name, args_config, optional_vars=[]):
inputs = {'names': [], 'input_info': {}}
attrs = {'names': [], 'attr_info': {}}
args_str = args_config.strip()
......@@ -79,11 +84,43 @@ class BaseAPI(object):
'Tensor': 'const Tensor&',
'Tensor[]': 'const std::vector<Tensor>&'
}
attr_types_map = {'ScalarArray' : 'const ScalarArray&', 'Scalar' : 'const Scalar&', \
'int' : 'int', 'int32_t' : 'int32_t', 'int64_t' : 'int64_t', 'size_t' : 'size_t', \
'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
'Backend' : 'Backend', 'DataLayout' : 'DataLayout', 'DataType' : 'DataType', \
'int64_t[]' : 'const std::vector<int64_t>&', 'int[]' : 'const std::vector<int>&'}
attr_types_map = {
'ScalarArray': 'const ScalarArray&',
'Scalar': 'const Scalar&',
'int': 'int',
'int32_t': 'int32_t',
'int64_t': 'int64_t',
'long': 'long',
'size_t': 'size_t',
'float': 'float',
'double': 'double',
'bool': 'bool',
'Backend': 'Backend',
'DataLayout': 'DataLayout',
'DataType': 'DataType',
'int64_t[]': 'const std::vector<int64_t>&',
'int[]': 'const std::vector<int>&',
'long[]': 'const std::vector<int64_t>&'
}
optional_types_trans = {
'Tensor': 'const paddle::optional<Tensor>&',
'Tensor[]': 'const paddle::optional<std::vector<Tensor>>&',
'ScalarArray': 'const paddle::optional<ScalarArray>&',
'Scalar': 'const paddle::optional<Scalar>&',
'int': 'paddle::optional<int>',
'int32_t': 'paddle::optional<int32_t>',
'int64_t': 'paddle::optional<int64_t>',
'size_t': 'paddle::optional<size_t>',
'float': 'paddle::optional<float>',
'double': 'paddle::optional<double>',
'bool': 'paddle::optional<bool>',
'Backend': 'paddle::optional<Backend>',
'DataLayout': 'paddle::optional<DataLayout>',
'DataType': 'paddle::optional<DataType>',
'int64_t[]': 'paddle::optional<std::vector<int64_t>>',
'int[]': 'paddle::optional<std::vector<int>>'
}
args_declare_str = ""
args_define_str = ""
......@@ -100,6 +137,9 @@ class BaseAPI(object):
assert len(attrs['names']) == 0, \
f"The input Tensor should appear before attributes. please check the position of {api_name}:input({input_name}) in yaml"
if input_name in optional_vars:
in_type = optional_types_trans[in_type_symbol]
inputs['names'].append(input_name)
inputs['input_info'][input_name] = in_type
args_declare_str = args_declare_str + in_type + ' ' + input_name + ', '
......@@ -121,6 +161,9 @@ class BaseAPI(object):
attr_name = attr_infos[0].strip()
default_value = attr_infos[1].strip()
if attr_name in optional_vars:
attr_type = optional_types_trans[attr_type_symbol]
default_value_str = "" if default_value is None else '=' + default_value
args_declare_str = args_declare_str + attr_type + ' ' + attr_name + default_value_str + ', '
args_define_str = args_define_str + attr_type + ' ' + attr_name + ', '
......@@ -381,7 +424,7 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
|| kernel_layout == DataLayout::UNDEFINED
|| kernel_data_type == DataType::UNDEFINED ) {{
auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {{
kernel_backend = kernel_key.backend();
}}
......@@ -408,7 +451,17 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
param_code = ""
for param in infer_meta_params:
if param in input_names:
param_code = param_code + "MakeMetaTensor(*" + PREFIX_TENSOR_NAME + param + "), "
if param in self.optional_vars:
meta_tensor_code = meta_tensor_code + f"""
{code_indent} paddle::optional<const phi::MetaTensor&> {PREFIX_TENSOR_NAME}meta_ref_{param}(paddle::none);
{code_indent} auto {PREFIX_TENSOR_NAME}meta_{param} = MakeMetaTensor({PREFIX_TENSOR_NAME}{param});
{code_indent} if ({PREFIX_TENSOR_NAME}meta_{param}) {{
{code_indent} {PREFIX_TENSOR_NAME}meta_ref_{param} = paddle::make_optional<const phi::MetaTensor&>(*{PREFIX_TENSOR_NAME}meta_{param});
{code_indent} }}"""
param_code = param_code + f"{PREFIX_TENSOR_NAME}meta_ref_{param}, "
else:
param_code = param_code + "MakeMetaTensor(*" + PREFIX_TENSOR_NAME + param + "), "
elif param in kernel_output_names:
meta_tensor_code = meta_tensor_code + code_indent + " phi::MetaTensor " + param.replace(
'kernel_', PREFIX_META_TENSOR_NAME) + "(" + param + ");\n"
......@@ -435,7 +488,11 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
'const std::vector<Tensor>&':
'const std::vector<phi::DenseTensor>&',
'const std::vector<Tensor> &':
'const std::vector<phi::DenseTensor>&'
'const std::vector<phi::DenseTensor>&',
'const paddle::optional<Tensor>&':
'paddle::optional<const phi::DenseTensor&>',
'const paddle::optional<std::vector<Tensor>>&':
'paddle::optional<const std::vector<phi::DenseTensor>&>'
}
out_trans_map = {
'Tensor': 'phi::DenseTensor*',
......@@ -459,19 +516,40 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
trans_flag = "{true}"
elif input_name in self.data_transform['support_trans_dtype']:
trans_flag = "{false, true}"
input_tensor_code = input_tensor_code + f"""
if input_name in self.optional_vars:
input_tensor_code = input_tensor_code + f"""
{code_indent} {input_trans_map[input_infos[input_name]]} {PREFIX_TENSOR_NAME}{input_name}(paddle::none);
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_ptr = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});
{code_indent} if ({PREFIX_TENSOR_NAME}{input_name}_ptr) {{
{code_indent} {PREFIX_TENSOR_NAME}{input_name} = paddle::make_optional<const phi::DenseTensor&>(*{PREFIX_TENSOR_NAME}{input_name}_ptr);
{code_indent} }}"""
else:
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
else:
input_tensor_code = input_tensor_code + f"""
if input_name in self.optional_vars:
input_tensor_code = input_tensor_code + f"""
{code_indent} {input_trans_map[input_infos[input_name]]} {PREFIX_TENSOR_NAME}{input_name}(paddle::none);
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_ptr = TensorToDenseTensor({input_name});
{code_indent} if ({PREFIX_TENSOR_NAME}{input_name}_ptr) {{
{code_indent} {PREFIX_TENSOR_NAME}{input_name} = paddle::make_optional<const phi::DenseTensor&>(*{PREFIX_TENSOR_NAME}{input_name}_ptr);
{code_indent} }}"""
else:
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""
kernel_args = "*dev_ctx, "
for param in kernel_param:
if param in input_names:
kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
kernel_args_type_list.append(input_trans_map[input_infos[
param]])
if param in self.optional_vars:
kernel_args = kernel_args + PREFIX_TENSOR_NAME + param + ", "
else:
kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
kernel_in_type = input_trans_map[input_infos[param]]
kernel_args_type_list.append(kernel_in_type)
elif param in attr_names:
# set attr for kernel_context
if 'ScalarArray' in self.attrs['attr_info'][param][0]:
......@@ -499,21 +577,16 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
def get_selected_rows_kernel_args(self, code_indent):
input_trans_map = {
'const Tensor&': 'const phi::SelectedRows&',
'const Tensor &': 'const phi::SelectedRows&'
'const Tensor &': 'const phi::SelectedRows&',
'const paddle::optional<Tensor>&':
'paddle::optional<const phi::SelectedRows&>'
}
out_trans_map = {'Tensor': 'phi::SelectedRows*'}
input_names = self.inputs['names']
input_infos = self.inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&']
input_tensor_code = ""
for input_name in input_names:
# set input code
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});"""
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
......@@ -521,15 +594,28 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
input_tensor_code = ""
for i, input_name in enumerate(input_names):
# set input code
input_tensor_code = input_tensor_code + f"""
if input_name in self.optional_vars:
input_tensor_code = input_tensor_code + f"""
{code_indent} {input_trans_map[input_infos[input_name]]} {PREFIX_TENSOR_NAME}{input_name}(paddle::none);
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_ptr = TensorToSelectedRows({input_name});
{code_indent} if ({PREFIX_TENSOR_NAME}{input_name}_ptr) {{
{code_indent} {PREFIX_TENSOR_NAME}{input_name} = paddle::make_optional<const phi::SelectedRows&>(*{PREFIX_TENSOR_NAME}{input_name}_ptr);
{code_indent} }}"""
else:
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});"""
kernel_args = "*dev_ctx, "
for param in kernel_param:
if param in input_names:
kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
kernel_args_type_list.append(input_trans_map[input_infos[
param]])
if param in self.optional_vars:
kernel_args = kernel_args + PREFIX_TENSOR_NAME + param + ", "
else:
kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
kernel_in_type = input_trans_map[input_infos[param]]
kernel_args_type_list.append(kernel_in_type)
elif param in attr_names:
# set attr for kernel_context
if 'ScalarArray' in self.attrs['attr_info'][param][0]:
......
......@@ -92,6 +92,7 @@ def header_include():
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/utils/optional.h"
"""
......
......@@ -8,6 +8,17 @@
kernel :
func : matmul_grad
- backward_api : matmul_double_grad
forward : matmul_grad (Tensor x, Tensor y, Tensor out_grad, bool transpose_x, bool transpose_y) -> Tensor(dx), Tensor(dy)
args : (Tensor x, Tensor y, Tensor out_grad, Tensor dx_grad, Tensor dy_grad, bool transpose_x, bool transpose_y)
output : Tensor(d2x), Tensor(d2y), Tensor(dout_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [x, y, out_grad]
kernel :
func : matmul_double_grad
optional : dx_grad, dy_grad
- backward_api : scale_grad
forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out)
args : (Tensor out_grad, Scalar scale, float bias=0.0, bool bias_after_scale=true)
......@@ -15,15 +26,6 @@
invoke : scale(out_grad, scale, bias, bias_after_scale)
# TODO(zhangyunfei) The config of double grad and triple grad will be supported in the future.
#
# - backward_api : matmul_double_grad
# forward : matmul_grad (Tensor x, Tensor y, Tensor out_grad, bool transpose_x, bool transpose_y) -> Tensor(dx), Tensor>(dy)
# args : (Tensor x, Tensor y, Tensor out_grad, Tensor dx_grad, Tensor dy_grad, bool transpose_x, bool transpose_y)
# output : Tensor(d2x), Tensor(d2y), Tensor(dout_grad)
# infer_meta :
# func : MatmulDoubleGradInferMeta
# kernel :
# func : matmul_double_grad
# - backward_api : matmul_triple_grad
# forward : matmul_double_grad (Tensor x, Tensor y, Tensor out_grad, Tensor dx_grad, Tensor dy_grad, bool transpose_x, bool transpose_y) -> Tensor(d2x), Tensor(d2y), Tensor(dout_grad)
......
......@@ -31,10 +31,10 @@ class BackwardAPI(BaseAPI):
def parse_forward_config(self, forward_config):
# api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
result = re.search(
r"(?P<api>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->[^\(]*\((?P<outputs>[^\)]+)\)",
r"(?P<api>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
forward_config)
api = result.group('api')
outputs = [item.strip() for item in result.group('outputs').split(',')]
_, outputs, _ = self.parse_output(self.api, result.group('outputs'))
fw_inputs, fw_attrs, _, = self.parse_input_and_attr(
api, result.group('args'))
......@@ -47,7 +47,7 @@ class BackwardAPI(BaseAPI):
# check the inputs of backward
for input in self.inputs['names']:
if input not in fw_inputs and input not in fw_outputs:
if input not in fw_inputs['names'] and input not in fw_outputs:
if input.endswith('_grad'):
original_name = input[:-5]
assert original_name in fw_outputs, \
......@@ -132,6 +132,7 @@ def header_include():
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/utils/optional.h"
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册