Unverified commit 4abea956, authored by Charles-hit and committed by GitHub

[PRIM][IR]Complete IR vjp code gen for more vjp code gen (#56798)

* Fix attr type error like concat axis

* Fix None input error

* Fix intermediate output

* support vjp code gen

---------
Co-authored-by: 0x45f <wangzhen45@baidu.com>
Parent: 0b608393
@@ -156,13 +156,26 @@ class CodeGen:
     def _gen_ret_type(self, op_info):
         type_list = op_info.output_type_list
-        if len(type_list) > 1:
+        intermediate_list = op_info.output_intermediate_list
+        assert len(type_list) == len(intermediate_list)
+
+        output_num = len(type_list) - intermediate_list.count('true')
+        if output_num > 1:
             return 'std::tuple<{}>'.format(
-                ', '.join([self._type_map[type] for type in type_list])
+                ', '.join(
+                    [
+                        self._type_map[type]
+                        for type, intermediate in zip(
+                            type_list, intermediate_list
+                        )
+                        if intermediate == 'false'
+                    ]
+                )
             )
-        elif len(type_list) == 1:
-            return self._type_map[type_list[0]]
-        elif len(type_list) == 0:
+        elif output_num == 1:
+            index = intermediate_list.index('false')
+            return self._type_map[type_list[index]]
+        elif output_num == 0:
             return 'void'

     def _gen_one_declare(self, op_info, op_name, is_mutable_attr):
@@ -252,10 +265,16 @@ class CodeGen:
     def _gen_out_split_and_ret_list(self, op_info, op_inst_name):
         name_list = op_info.output_name_list
         type_list = op_info.output_type_list
+        intermediate_list = op_info.output_intermediate_list
+        assert len(name_list) == len(type_list) == len(intermediate_list)
         split_op_str = ''
         ret_list = []
-        for i, (name, type) in enumerate(zip(name_list, type_list)):
+        for i, (name, type, intermediate) in enumerate(
+            zip(name_list, type_list, intermediate_list)
+        ):
+            if intermediate == 'true':
+                continue
             if VECTOR_TYPE in type:
                 split_op_name = f'{name}_split_op'
                 split_op_str += SPLIT_OP_TEMPLATE.format(
......
@@ -129,6 +129,7 @@ mutable_attribute_phi_type_maps = {
     'float': 'phi::DataType::FLOAT32',
     'std::vector<int64_t>': 'phi::DataType::INT64',
     'const std::vector<int64_t>&': 'phi::DataType::INT64',
+    'bool': 'phi::DataType::BOOL',
 }
......
@@ -598,6 +598,19 @@ class OpInfoParser:
             if 'Scalar' in temp_type:
                 if 'data_type' in attribute_info:
                     temp_type = attribute_info['data_type']
+                op_name = self.op_yaml_item['name']
+                attr_name = attribute_info['name']
+                if (
+                    op_name not in ["isclose", "allclose"]
+                    and self.op_compat_item is not None
+                    and 'scalar' in self.op_compat_item.keys()
+                    and attr_name in self.op_compat_item['scalar'].keys()
+                    and 'data_type'
+                    in self.op_compat_item['scalar'][attr_name].keys()
+                ):
+                    temp_type = self.op_compat_item['scalar'][attr_name][
+                        'data_type'
+                    ]
             if 'IntArray' in temp_type:
                 if 'data_type' in attribute_info:
                     temp_type = "const " + attribute_info['data_type'] + "&"
......
@@ -29,5 +29,15 @@ ir::OpResult split_grad(std::vector<ir::OpResult> out_grads,
   return split_grad_op.x_grad();
 }

+ir::OpResult split_grad(std::vector<ir::OpResult> out_grads, int axis) {
+  auto combine_op =
+      APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>(out_grads);
+  paddle::dialect::SplitGradOp split_grad_op =
+      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
+          combine_op.out(), axis);
+
+  return split_grad_op.x_grad();
+}
+
 } // namespace dialect
 } // namespace paddle
@@ -25,5 +25,6 @@ namespace dialect {
 ir::OpResult split_grad(std::vector<ir::OpResult> out_grads, ir::OpResult axis);

+ir::OpResult split_grad(std::vector<ir::OpResult> out_grads, int axis);
 } // namespace dialect
 } // namespace paddle
@@ -35,11 +35,8 @@ std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
   Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));
   Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));

-  IntArray axis = op_obj.axis()
-                      .GetDefiningOp()
-                      ->attribute("value")
-                      .dyn_cast<paddle::dialect::IntArrayAttribute>()
-                      .data();
+  Tensor axis(std::make_shared<primitive::LazyTensor>(op_obj.axis()));
   bool keepdim = op->attribute("keepdim").dyn_cast<ir::BoolAttribute>().data();
   bool reduce_all = false;
   std::vector<std::vector<Tensor>> tensor_res = primitive::sum_vjp(
......
@@ -84,3 +84,21 @@ def supports_no_need_buffer(op):

 def is_tensor_list(s):
     return s == 'Tensor[]'
+
+
+def exist_mutable_attribute(attrs):
+    for attr in attrs:
+        if (
+            attr['typename'] in ['Scalar', 'IntArray']
+            and attr['support_tensor'] is True
+        ):
+            return True
+    else:
+        return False
+
+
+def is_mutable_attribute(attr):
+    return (
+        attr['typename'] in ['Scalar', 'IntArray']
+        and attr['support_tensor'] is True
+    )
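These helpers back the `exist_mutable_attribute` / `mutable_attribute` Jinja tests registered in gen.py below. A standalone sketch of how they classify a toy attrs list (field values are made up for the example):

# Mirrors is_mutable_attribute above so the snippet runs on its own.
def is_mutable_attribute(attr):
    return (
        attr['typename'] in ['Scalar', 'IntArray']
        and attr['support_tensor'] is True
    )

# Toy attrs in the parsed-YAML shape after the compat info has been merged in.
attrs = [
    {'name': 'axis', 'typename': 'Scalar', 'support_tensor': True},
    {'name': 'keepdim', 'typename': 'bool'},
]

print([is_mutable_attribute(a) for a in attrs])     # [True, False]
print(any(is_mutable_attribute(a) for a in attrs))  # True; exist_mutable_attribute(attrs) reports the same here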
@@ -28,14 +28,6 @@ using Scalar = paddle::experimental::Scalar;
 using IntArray = paddle::experimental::IntArray;
 using DataType = phi::DataType;

-template <typename T>
-std::vector<Tensor> concat_grad(const std::vector<Tensor>& x,
-                                const Tensor& out_grad,
-                                const Tensor& axis);
-
-template <typename T>
-Tensor split_grad(const std::vector<Tensor>& out_grads, const Tensor& axis);
-
 } // namespace backend
 } // namespace primitive
 } // namespace paddle
@@ -23,54 +23,6 @@ namespace backend {

 using LazyTensor = paddle::primitive::LazyTensor;

-template <>
-std::vector<Tensor> concat_grad<LazyTensor>(const std::vector<Tensor>& x,
-                                            const Tensor& out_grad,
-                                            const Tensor& axis) {
-  std::vector<ir::OpResult> x_res;
-  for (uint64_t idx = 0; idx < x.size(); idx++) {
-    x_res.emplace_back(std::static_pointer_cast<LazyTensor>(x[idx].impl())
-                           ->getValue()
-                           .dyn_cast<ir::OpResult>());
-  }
-
-  ir::OpResult out_grad_res =
-      std::static_pointer_cast<LazyTensor>(out_grad.impl())
-          ->getValue()
-          .dyn_cast<ir::OpResult>();
-
-  ir::OpResult axis_res = std::static_pointer_cast<LazyTensor>(axis.impl())
-                              ->getValue()
-                              .dyn_cast<ir::OpResult>();
-
-  std::vector<ir::OpResult> op_res =
-      paddle::dialect::concat_grad(x_res, out_grad_res, axis_res);
-
-  std::vector<Tensor> op_result;
-  for (uint64_t idx = 0; idx < op_res.size(); idx++) {
-    op_result.emplace_back(
-        std::make_shared<primitive::LazyTensor>(op_res[idx]));
-  }
-  return op_result;
-}
-
-template <>
-Tensor split_grad<LazyTensor>(const std::vector<Tensor>& out_grads,
-                              const Tensor& axis) {
-  std::vector<ir::OpResult> out_grads_res;
-  for (uint64_t idx = 0; idx < out_grads.size(); idx++) {
-    out_grads_res.emplace_back(
-        std::static_pointer_cast<LazyTensor>(out_grads[idx].impl())
-            ->getValue()
-            .dyn_cast<ir::OpResult>());
-  }
-  ir::OpResult axis_res = std::static_pointer_cast<LazyTensor>(axis.impl())
-                              ->getValue()
-                              .dyn_cast<ir::OpResult>();
-  ir::OpResult op_res = paddle::dialect::split_grad(out_grads_res, axis_res);
-  return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
-}
-
 } // namespace backend
 } // namespace primitive
 } // namespace paddle
@@ -7,6 +7,7 @@ set(rev_legacy_path ${parsed_yaml_path}/legacy_backward_ops.parsed.yaml)
 set(prim_path "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/primitive.yaml")
 set(templates_dir
     "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/templates/")
+set(compat_path "${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml")
 set(destination_dir "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/")
 set(scripts "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/gen.py")
@@ -17,7 +18,7 @@ execute_process(
     ${PYTHON_EXECUTABLE} ${scripts} --fwd_path ${fwd_path} --fwd_legacy_path
     ${fwd_legacy_path} --rev_path ${rev_path} --rev_legacy_path
     ${rev_legacy_path} --prim_path ${prim_path} --templates_dir ${templates_dir}
-    --destination_dir ${destination_dir}
+    --compat_path ${compat_path} --destination_dir ${destination_dir}
   RESULT_VARIABLE _result)
 if(${_result})
   message(
......
@@ -16,6 +16,7 @@ import argparse
 import hashlib
 import pathlib
 import sys
+from typing import Dict, List

 import jinja2
 import yaml
@@ -36,7 +37,15 @@ sys.path.append(
 # fmt: on

-VJPS = ['tanh_grad', 'mean_grad', 'add_grad', 'divide_grad', 'sum_grad']
+VJPS = [
+    'tanh_grad',
+    'mean_grad',
+    'add_grad',
+    'divide_grad',
+    'sum_grad',
+    'concat_grad',
+    'split_grad',
+]
 VJP_COMPS = ['divide_grad', 'sum_grad']
 BACKENDS = [
     'add_n',
@@ -57,6 +66,8 @@ BACKENDS = [
     'add_grad',
     'divide_grad',
     'sum_grad',
+    'concat_grad',
+    'split_grad',
 ]
@@ -99,6 +110,7 @@ def render(src_dir: pathlib.Path, dst_dir: pathlib.Path, *args, **kwargs):
             'to_paddle_attr_type': op_gen_filters.to_paddle_attr_type,
             'to_paddle_input_type': op_gen_filters.to_paddle_input_type,
             'to_paddle_output_type': op_gen_filters.to_paddle_output_type,
+            'trip_intermediate': op_gen_filters.filter_intermediate,
         }
     )
     env.tests.update(
@@ -106,6 +118,8 @@ def render(src_dir: pathlib.Path, dst_dir: pathlib.Path, *args, **kwargs):
             'scalar': op_gen_tests.is_scalar,
             'intarray': op_gen_tests.is_intarray,
             'datatype': op_gen_tests.is_datatype,
+            'exist_mutable_attribute': op_gen_tests.exist_mutable_attribute,
+            'mutable_attribute': op_gen_tests.is_mutable_attribute,
         }
     )
     for tpl in env.list_templates(
@@ -143,12 +157,114 @@ def save(content: str, path: pathlib.Path):
         print(f"Generate source file {path}")


+def filter_compat_info(items):
+    for item in items:
+        item['op'] = item['op'].split('(')[0].strip()
+        if 'backward' in item:
+            item_backwards = item['backward'].split(',')
+            for idx, item_backward in enumerate(item_backwards):
+                item_backward = item_backward.split('(')[0].strip()
+                item_backwards[idx] = item_backward
+            item['backward'] = (
+                ','.join(item_backwards)
+                if len(item_backwards) > 0
+                else item_backwards[0]
+            )
+
+
+def to_compat_dict(items: List[Dict]) -> Dict[str, Dict]:
+    compat_dict = {}
+    for item in items:
+        name = item["op"]
+        compat_dict[name] = item
+    return compat_dict
+
+
+def to_apis_dict(apis):
+    apis_dict = {}
+    for api in apis:
+        apis_dict[api['name']] = api
+    return apis_dict
+
+
+def get_inplace_api(apis):
+    inplace_apis = []
+    for api in apis:
+        if (
+            'inplace' in api
+            and api['inplace'] is not None
+            and not api['name'].endswith('_')
+        ):
+            inplace_api = api.copy()
+            inplace_api['name'] = api['name'] + '_'
+            inplace_apis.append(inplace_api)
+    return inplace_apis
+
+
+def extend_compat_info(apis, compats):
+    for api in apis:
+        attrs = api["attrs"]
+        for attr in attrs:
+            if attr['typename'] in ["Scalar", "IntArray"]:
+                attr["support_tensor"] = False
+    apis_dict = to_apis_dict(apis)
+    for compat_item in compats:
+        fwd_op_name = compat_item["op"]
+        if fwd_op_name not in apis_dict:
+            continue
+        fwd_api = apis_dict[fwd_op_name]
+        backward_op_names = []
+        while fwd_op_name is not None and fwd_op_name in apis_dict:
+            backward_op_names.append(apis_dict[fwd_op_name]['backward'])
+            fwd_op_name = apis_dict[fwd_op_name]['backward']
+        backward_apis = []
+        for backward_op_name in backward_op_names:
+            if backward_op_name in apis_dict:
+                backward_apis.append(apis_dict[backward_op_name])
+        support_tensor_attrs_names = []
+        compat_attrs_data_type = {}
+        if 'scalar' in compat_item:
+            for attr_name, attr_info in compat_item['scalar'].items():
+                if (
+                    'support_tensor' in attr_info
+                    and attr_info['support_tensor'] is True
+                    or 'tensor_name' in attr_info
+                ):
+                    support_tensor_attrs_names.append(attr_name)
+                if 'data_type' in attr_info:
+                    compat_attrs_data_type.update(
+                        {attr_name: attr_info['data_type']}
+                    )
+        if 'int_array' in compat_item:
+            for attr_name, attr_info in compat_item['int_array'].items():
+                if (
+                    'support_tensor' in attr_info
+                    and attr_info['support_tensor'] is True
+                    or 'tensor_name' in attr_info
+                    or 'tensors_name' in attr_info
+                ):
+                    support_tensor_attrs_names.append(attr_name)
+        if len(support_tensor_attrs_names) > 0:
+            for api in [fwd_api] + backward_apis:
+                attrs = api["attrs"]
+                for attr in attrs:
+                    if attr['name'] in support_tensor_attrs_names:
+                        attr['support_tensor'] = True
+        for api in [fwd_api] + backward_apis:
+            attrs = api["attrs"]
+            for attr in attrs:
+                if attr['name'] in compat_attrs_data_type:
+                    attr['data_type'] = compat_attrs_data_type[attr['name']]
+    return apis
+
+
 def gen(
     prim_path: pathlib.Path,
     fwd_path: pathlib.Path,
     fwd_legacy_path: pathlib.Path,
     rev_path: pathlib.Path,
     rev_legacy_path: pathlib.Path,
+    compat_path: pathlib.Path,
     templates_dir: pathlib.Path,
     destination_dir: pathlib.Path,
 ):
@@ -163,20 +279,22 @@ def gen(
         rev_path (pathlib.Path): The YAML file path of the backward API.
         rev_legacy_path (pathlib.Path): The YAML file path of the legacy
             backward API.
+        compat_path: (pathlib.Path): The YAML file path of the ops compat.
         templates_dir (pathlib.Path): The directory of the templates.
         destination_dir (pathlib.Path): The Directory of the generated file.

     Returns:
         None
     """
-    prims, fwds, legacy_fwds, revs, legacy_revs = (
+    prims, fwds, legacy_fwds, revs, legacy_revs, compats = (
         load(prim_path),
         load(fwd_path),
         load(fwd_legacy_path),
         load(rev_path),
         load(rev_legacy_path),
+        load(compat_path),
     )
+    filter_compat_info(compats)
     apis = [{**api, **{'is_fwd': True}} for api in fwds + legacy_fwds]
     apis = apis + [{**api, **{'is_fwd': False}} for api in revs + legacy_revs]
     apis = [
@@ -185,7 +303,8 @@ def gen(
         else {**api, **{'is_prim': False}}
         for api in apis
     ]
-
+    apis = extend_compat_info(apis, compats)
+    apis = apis + get_inplace_api(apis)
     render(
         templates_dir,
         destination_dir,
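As a rough illustration of one of the new helpers, here is a standalone sketch of what `get_inplace_api` does to the parsed API list; the entries below are toy data, the real dicts carry many more fields.

# Copy of get_inplace_api above so the example runs on its own.
def get_inplace_api(apis):
    inplace_apis = []
    for api in apis:
        if (
            'inplace' in api
            and api['inplace'] is not None
            and not api['name'].endswith('_')
        ):
            inplace_api = api.copy()
            inplace_api['name'] = api['name'] + '_'
            inplace_apis.append(inplace_api)
    return inplace_apis


# Toy parsed-YAML entries; the shape of the 'inplace' value is illustrative.
apis = [
    {'name': 'scale', 'inplace': {'out': 'x'}},  # gains a 'scale_' variant
    {'name': 'concat', 'inplace': None},         # skipped: no inplace mapping
]
apis = apis + get_inplace_api(apis)
print([api['name'] for api in apis])  # ['scale', 'concat', 'scale_']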
@@ -221,6 +340,11 @@ if __name__ == "__main__":
         type=str,
         help='The parsed ops yaml file.',
     )
+    parser.add_argument(
+        '--compat_path',
+        type=str,
+        help='The parsed ops compat yaml file.',
+    )
     parser.add_argument(
         '--templates_dir',
         type=str,
@@ -239,6 +363,7 @@ if __name__ == "__main__":
         pathlib.Path(args.fwd_legacy_path),
         pathlib.Path(args.rev_path),
         pathlib.Path(args.rev_legacy_path),
+        pathlib.Path(args.compat_path),
         pathlib.Path(args.templates_dir),
         pathlib.Path(args.destination_dir),
     )
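With the new argument wired through CMake and argparse, a direct call to `gen` would look roughly like the sketch below; all paths are placeholders for a local checkout, and in the build they come from the CMake variables set above.

import pathlib

# Placeholder locations; only op_compat.yaml, primitive.yaml, templates_dir and
# destination_dir are shown verbatim in the CMake file above.
parsed = pathlib.Path("path/to/parsed_yaml")
gen(
    prim_path=pathlib.Path("paddle/fluid/primitive/primitive.yaml"),
    fwd_path=parsed / "ops.parsed.yaml",
    fwd_legacy_path=parsed / "legacy_ops.parsed.yaml",
    rev_path=parsed / "backward_ops.parsed.yaml",
    rev_legacy_path=parsed / "legacy_backward_ops.parsed.yaml",
    compat_path=pathlib.Path("paddle/phi/api/yaml/op_compat.yaml"),
    templates_dir=pathlib.Path("paddle/fluid/primitive/codegen/templates/"),
    destination_dir=pathlib.Path("paddle/fluid/primitive/"),
)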
@@ -20,7 +20,11 @@ using DataType = phi::DataType;

 {% for api in apis %}
 {%- if api.name in backend_white_list -%}
-{{common.sig(api.name, api.inputs, api.outputs, api.attrs, True)}};
+{% if api.attrs is exist_mutable_attribute %}
+{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, True, True)}};
+{% endif %}
+{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, False, True)}};
 {% endif %}
 {% endfor %}
......
@@ -18,7 +18,7 @@ namespace backend {

 {%- macro sig(name, inputs, attrs, outputs) -%}
 template <>
-{{common.ret(outputs)}} {{name}}<Tensor>({{common.params(inputs, attrs)}})
+{{common.ret(outputs)}} {{name}}<Tensor>({{common.params(inputs, attrs, False)}})
 {%- endmacro -%}

 {% macro body(name, inputs, attrs, outputs) %}
@@ -34,10 +34,9 @@ return ::{{name}}_ad_func({{common.args(input_names, attr_names)}});

 {% for api in apis %}
-{#- TODO(cxxly): codegen for reshape -#}
-{%- if api.is_prim and api.name in backend_white_list and api.name != 'reshape' -%}
-{{sig(api.name, api.inputs, api.attrs, api.outputs)}} {
-{{body(api.name, api.inputs, api.attrs, api.outputs)}}
+{%- if api.is_prim and api.name in backend_white_list -%}
+{{sig(api.name, api.inputs, api.attrs, api.outputs | trip_intermediate)}} {
+{{body(api.name, api.inputs, api.attrs, api.outputs | trip_intermediate)}}
 }
 {% endif %}
......
@@ -12,12 +12,12 @@ namespace backend {

 using LazyTensor = paddle::primitive::LazyTensor;

-{%- macro sig(name, inputs, outputs, attrs) -%}
+{%- macro sig(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False) -%}
 template <>
-{{common.ret(outputs)}} {{name}}<LazyTensor>({{common.params(inputs, attrs)}})
+{{common.ret(outputs)}} {{name}}<LazyTensor>({{common.params(inputs, attrs, mutable_attribute_as_inputs, False)}})
 {%- endmacro -%}

-{% macro body(name, inputs, outputs, attrs) %}
+{% macro body(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False) %}
 {%- set output_names = [] -%}
 {%- for o in outputs -%} {%- do output_names.append(o.name) -%} {%-endfor-%}
 {%- for input in inputs -%}
@@ -30,20 +30,55 @@ template <>
 ir::OpResult {{input.name}}_res = std::static_pointer_cast<LazyTensor>({{input.name}}.impl())->getValue().dyn_cast<ir::OpResult>();
 {% endif %}
 {% endfor %}
+{%- for attr in attrs -%}
+{% if mutable_attribute_as_inputs and attr is mutable_attribute %}
+ir::OpResult {{attr.name}}_res = std::static_pointer_cast<LazyTensor>({{attr.name~'_'}}.impl())->getValue().dyn_cast<ir::OpResult>();
+{% endif %}
+{% endfor %}
 {%- set input_names = [] -%}
-{%- for i in inputs -%} {%- do input_names.append(i.name~'_res') -%} {%- endfor -%}
+{%- for i in inputs -%}
+{%- do input_names.append(i.name~'_res') -%}
+{%- endfor -%}
+{%- if mutable_attribute_as_inputs -%}
+{%- for i in attrs -%}
+{%- if i is mutable_attribute -%}
+{%- do input_names.append(i.name~'_res') -%}
+{%- endif -%}
+{%- endfor -%}
+{%- endif -%}
 {%- set attr_names = [] -%}
-{%- for i in attrs -%} {%- do attr_names.append(common.phi2ir_attr(i)) -%} {% endfor %}
+{%- for i in attrs -%}
+{%- if not mutable_attribute_as_inputs or mutable_attribute_as_inputs and i is not mutable_attribute -%}{#- do nothing -#}
+{%- do attr_names.append(common.phi2ir_attr(i)) -%}
+{%- endif -%}
+{% endfor %}
 auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}});
-{% if outputs|length > 1 %}
-return std::make_tuple(
+{% if outputs|length == 1 %}
+{% if outputs[0].typename == 'Tensor' %}
+Tensor {{outputs[0].name}}(std::make_shared<LazyTensor>(op_res));
+return {{outputs[0].name}};
+{% elif outputs[0].typename == 'Tensor[]' %}
+std::vector<Tensor> {{outputs[0].name}}(op_res.size());
+std::transform(op_res.begin(), op_res.end(), {{outputs[0].name}}.begin(), [](const ir::OpResult& res) {
+return Tensor(std::make_shared<LazyTensor>(res));
+});
+return {{outputs[0].name}};
+{% else %} {#- render nothing -#}
+{% endif %}
+{% elif outputs|length > 1 %}
 {% for i in range(outputs|length) %}
-Tensor(std::make_shared<LazyTensor>(std::get<{{i}}>(op_res))){%- if i!=outputs|length - 1 -%}, {% endif %}
+auto op_res_{{i}} = std::get<{{i}}>(op_res);
+{% if outputs[i].typename == 'Tensor' %}
+Tensor {{outputs[i].name}}(std::make_shared<LazyTensor>(op_res_{{i}}));
+{% elif outputs[i].typename == 'Tensor[]' %}
+std::vector<Tensor> {{outputs[i].name}}(op_res_{{i}}.size());
+std::transform(op_res_{{i}}.begin(), op_res_{{i}}.end(), {{outputs[i].name}}.begin(), [](const ir::OpResult& res) {
+return Tensor(std::make_shared<LazyTensor>(res));
+});
+{% else %} {#- render nothing -#}
+{% endif %}
 {% endfor %}
-);
-{% elif outputs|length == 1 %}
-return Tensor(std::make_shared<LazyTensor>(op_res));
+return std::make_tuple({% for i in range(outputs|length) %}{{outputs[i].name}}{%- if i!=outputs|length - 1 -%}, {% endif %}{% endfor %});
 {% else %} {#- render nothing -#}
 {% endif %}
 {% endmacro %}
@@ -51,10 +86,17 @@ template <>

 {% for api in apis %}
 {% if api.name in backend_white_list %}
-{{sig(api.name, api.inputs, api.outputs, api.attrs)}} {
-{{body(api.name, api.inputs, api.outputs, api.attrs)}}
+{% set api_outputs = api.outputs | trip_intermediate %}
+{{sig(api.name, api.inputs, api_outputs, api.attrs)}} {
+{{body(api.name, api.inputs, api_outputs, api.attrs)}}
 }
+{% if api.attrs is exist_mutable_attribute %}
+{{sig(api.name, api.inputs, api_outputs, api.attrs, True)}} {
+{{body(api.name, api.inputs, api_outputs, api.attrs, True)}}
+}
+{% endif %}
 {% endif %}
 {% endfor %}
......
-{%- macro sig(name, inputs, outputs, attrs, default=False) -%}
+{%- macro sig(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False, default=False) -%}
 template <typename T>
-{{ret(outputs)}} {{name}}({{params(inputs, attrs, default)}})
+{{ret(outputs)}} {{name}}({{params(inputs, attrs, mutable_attribute_as_inputs, default)}})
 {%- endmacro %}

-{%- macro params(inputs, attrs, default=False) -%}
+{%- macro params(inputs, attrs, mutable_attribute_as_inputs=False, default=False) -%}
 {%- set input_params = [] -%}
 {%- for i in inputs -%} {%- do input_params.append(i.typename|to_paddle_input_type(i.optional)~' '~i.name) -%} {%- endfor -%}
 {%- set attr_params = [] -%}
 {%- for i in attrs -%}
+{%- if not mutable_attribute_as_inputs or i is not mutable_attribute -%}
 {%- if default -%}
 {%- do attr_params.append(i.typename|to_paddle_attr_type~' '~i.name~default_value(i)) -%}
 {%- else -%}
 {%- do attr_params.append(i.typename|to_paddle_attr_type~' '~i.name) -%}
 {%- endif -%}
+{%- else -%}
+{%- do input_params.append('const Tensor&'~' '~i.name~'_') -%}
+{%- endif -%}
 {%- endfor -%}
 {{sequence('', '', ', ', input_params)}}
 {%- if input_params|length>0 and attr_params|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#}
......
@@ -13,12 +13,12 @@ using Tensor = paddle::Tensor;
 using IntArray = paddle::experimental::IntArray;

 {% for api in apis %}
-{%- if api.is_prim and api.name in backend_white_list -%}
+{%- if api.is_prim and api.name in backend_white_list and api.name[-1] != '_' -%}
 {%- set input_names = [] -%}
 {%- for i in api.inputs -%} {%- do input_names.append(i.name) -%} {%- endfor -%}
 {%- set attr_names = [] -%}
 {%- for i in api.attrs -%} {%- do attr_names.append(i.name) -%} {% endfor %}
-{{common.sig(api.name, api.inputs, api.outputs, api.attrs, True)}} {
+{{common.sig(api.name, api.inputs, api.outputs | trip_intermediate, api.attrs, False, True)}} {
 return backend::{{api.name}}<T>({{common.args(input_names, attr_names)}});
 }
......
@@ -3,6 +3,7 @@
 #include "paddle/fluid/primitive/rule/vjp/generated/generated_vjp.h"

 #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h"
+#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
 #include "paddle/fluid/prim/utils/static/static_global_utils.h"
 #include "paddle/fluid/primitive/backend/backend.h"
 #include "paddle/fluid/primitive/rule/vjp/details.h"
@@ -14,7 +15,7 @@
 namespace paddle {
 namespace primitive {

 {% macro sig(fwd_name, name, inputs, attrs, outputs) -%}
-std::vector<std::vector<paddle::Tensor>> {{fwd_name}}_vjp({{common.params(inputs, attrs)}}, const std::vector<std::vector<bool>>& stop_gradients)
+std::vector<std::vector<paddle::Tensor>> {{fwd_name}}_vjp({{common.params(inputs, attrs, attrs is exist_mutable_attribute)}}, const std::vector<std::vector<bool>>& stop_gradients)
 {%- endmacro -%}
 {% macro body(api) %}
@@ -34,16 +35,46 @@ if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) {
 return vjp_res;
 {% endmacro %}

+{% macro get_mutable_attribute(attrs, api_name) %}
+{% for i in attrs %}
+{%- if i is mutable_attribute -%}
+auto* {{i.name}}_define_op = std::static_pointer_cast<primitive::LazyTensor>({{i.name~'_'}}.impl())->getValue().dyn_cast<ir::OpResult>().GetDefiningOp();
+{% if i.typename is scalar %}
+if({{i.name}}_define_op->name() != "pd.full") {
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "We don't support dynamic tensors attribute {{i.name}} for {{api_name}} composite "
+      "for now. "));
+}
+auto {{i.name}} = {{i.name}}_define_op->attribute("value").dyn_cast<paddle::dialect::ScalarAttribute>().data();
+{% elif i.typename is intarray %}
+if({{i.name}}_define_op->name() != "pd.full_int_array"){
+  PADDLE_THROW(platform::errors::Unimplemented(
+      "We don't support dynamic tensors attribute {{i.name}} for {{api_name}} composite "
+      "for now. "));
+}
+auto {{i.name}} = {{i.name}}_define_op->attribute("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data();
+{% endif %}
+{% endif %}
+{% endfor %}
+{% endmacro %}
+
 {% macro body_unprim(api) %}
 {%- set input_names=[] -%}
 {%- for i in api.inputs -%} {%- do input_names.append(i.name) -%} {%- endfor -%}
 {%- set attr_names=[] -%}
-{%- for i in api.attrs -%} {%- do attr_names.append(i.name) -%} {%- endfor %}
+{%- for i in api.attrs -%}
+{%- if i is mutable_attribute -%}
+{%- do input_names.append(i.name~'_') -%}
+{%- else -%}
+{%- do attr_names.append(i.name) -%}
+{%- endif -%}
+{%- endfor %}
 auto op_res = backend::{{api.name}}<LazyTensor>({{common.args(input_names, attr_names)}});
-{% if api.outputs|length > 1 %}
-{% for i in range(api.outputs|length) %}
+{% set outputs = api.outputs|trip_intermediate %} {#- ignore intermediate output -#}
+{% if outputs|length > 1 %}
+{% for i in range(outputs|length) %}
 auto out{{i}} = std::get<{{i}}>(op_res);
-{% if api.outputs[i].typename=='Tensor' %}
+{% if outputs[i].typename=='Tensor' %}
 vjp_res[{{i}}][0] = !stop_gradients[{{i}}][0] ? out{{i}} : vjp_res[{{i}}][0];
 {% else %}
 for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) {
@@ -51,8 +82,8 @@ for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) {
 }
 {% endif %}
 {% endfor %}
-{% elif api.outputs|length == 1 %}
-{% if api.outputs[0].typename=='Tensor' %}
+{% elif outputs|length == 1 %}
+{% if outputs[0].typename=='Tensor' %}
 vjp_res[0][0] = !stop_gradients[0][0] ? op_res : vjp_res[0][0];
 {% else %}
 for (size_t i=0; i< stop_gradients[0].size(); i++ ) {
@@ -74,6 +105,7 @@ for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) {
 }
 {% endif %}
 {% endfor %}
+{{get_mutable_attribute(api.attrs, api.name)}}
 details::{{api.composite.func_name}}<LazyTensor>({{api.composite.func_args}});
 {% endmacro %}
......
@@ -14,7 +14,7 @@ namespace primitive {
 using IntArray = paddle::experimental::IntArray;

 {% macro sig(fwd_name, name, inputs, attrs, outputs) %}
-std::vector<std::vector<paddle::Tensor>> {{fwd_name}}_vjp({{common.params(inputs, attrs)}}, const std::vector<std::vector<bool>>& stop_gradients);
+std::vector<std::vector<paddle::Tensor>> {{fwd_name}}_vjp({{common.params(inputs, attrs, attrs is exist_mutable_attribute)}}, const std::vector<std::vector<bool>>& stop_gradients);
 {% endmacro %}

 {%- set api_map = {} -%}
......
@@ -5,4 +5,4 @@ cc_library(
   primitive_vjp_experimental
   SRCS ${VJP_SRCS}
   DEPS primitive_backend_static_experimental static_global_utils
-       primitive_static_utils_experimental)
+       primitive_static_utils_experimental pd_dialect_core)
@@ -51,8 +51,7 @@ void divide_grad(const Tensor& x,
     } else {
       auto dy_reduce_res =
           sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
-      auto reshape_res = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-      auto dy_tmp = std::get<0>(reshape_res);
+      auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
       set_output<T>(dy_tmp, dy);
     }
   } else {
@@ -71,9 +70,7 @@ void divide_grad(const Tensor& x,
     } else {
       auto dx_reduce_res =
           sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
-      auto dx_reduce_reshape_res =
-          reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-      auto dx_tmp = std::get<0>(dx_reduce_reshape_res);
+      auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
       set_output<T>(dx_tmp, dx);
     }
@@ -121,9 +118,8 @@ void sum_grad(const Tensor& x,
       }
     }
     auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
-    auto out_grad_reshape_res = reshape<T>(out_grad, out_grad_shape);
-    auto out_grad_ = std::get<0>(out_grad_reshape_res);
-    x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
+    auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
+    x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
   } else {
     x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
   }
......
@@ -24,49 +24,5 @@
 #include "paddle/ir/core/operation.h"

 namespace paddle {
-namespace primitive {
-
-std::vector<std::vector<paddle::Tensor>> concat_vjp(
-    const std::vector<Tensor>& x,
-    const Tensor& out_grad,
-    const Tensor& axis,
-    const std::vector<std::vector<bool>>& stop_gradients) {
-  std::vector<std::vector<paddle::Tensor>> vjp_res(2, std::vector<Tensor>());
-  // get concat_grad res.
-  std::vector<Tensor> op_res =
-      backend::concat_grad<primitive::LazyTensor>(x, out_grad, axis);
-
-  // construct vjp result by op result and stop_gradients info
-  vjp_res[0].resize(op_res.size());
-  for (uint64_t idx = 0; idx < op_res.size(); idx++) {
-    if (!stop_gradients[0][idx]) {
-      vjp_res[0][idx] = op_res[idx];
-    }
-  }
-
-  // vjp_res[1] is axis's grad which is attribute (no grad).
-  vjp_res[1].resize(1);
-  return vjp_res;
-}
-
-std::vector<std::vector<paddle::Tensor>> split_vjp(
-    const std::vector<Tensor>& out_grads,
-    const Tensor& axis,
-    const std::vector<std::vector<bool>>& stop_gradients) {
-  std::vector<std::vector<paddle::Tensor>> vjp_res(3, std::vector<Tensor>(1));
-  // get concat_grad res.
-  Tensor op_res = backend::split_grad<primitive::LazyTensor>(out_grads, axis);
-
-  // construct vjp result by op result and stop_gradients info
-  if (!stop_gradients[0][0]) {
-    vjp_res[0][0] = op_res;
-  }
-
-  // vjp_res[1] is sections's grad which is attribute (no grad).
-  // vjp_res[2] is axis's grad which is attribute (no grad).
-  vjp_res[1].resize(stop_gradients[1].size());
-  vjp_res[2].resize(stop_gradients[2].size());
-  return vjp_res;
-}
-
-} // namespace primitive
+namespace primitive {} // namespace primitive
 } // namespace paddle
@@ -23,17 +23,5 @@ namespace paddle {
 namespace primitive {

 using IntArray = paddle::experimental::IntArray;

-std::vector<std::vector<paddle::Tensor>> concat_vjp(
-    const std::vector<Tensor>& x,
-    const Tensor& out_grad,
-    const Tensor& axis,
-    const std::vector<std::vector<bool>>& stop_gradients);
-
-std::vector<std::vector<paddle::Tensor>> split_vjp(
-    const std::vector<Tensor>& out_grads,
-    const Tensor& axis,
-    const std::vector<std::vector<bool>>& stop_gradients);
-
 } // namespace primitive
 } // namespace paddle
@@ -1490,6 +1490,8 @@ ir::OpResult CastPyArg2OpResult(PyObject* obj,
                                 size_t arg_pos) {
   if (PyObject_TypeCheck(obj, g_ir_opresult_pytype)) {
     return ::pybind11::handle(obj).cast<ir::OpResult>();
+  } else if (obj == nullptr || obj == Py_None) {
+    return ir::OpResult();
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "%s(): argument (position %d) must be "
......
@@ -32,11 +32,11 @@ def get_ir_program_0():
     )
     x.stop_gradient = False
     y = paddle.tensor.fill_constant(shape=[4], dtype='float32', value=1.0)
-    y.stop_gradiable = False
+    y.stop_gradient = False
     dout = paddle.tensor.fill_constant(
         shape=[1, 4], dtype='float32', value=1.0
     )
-    dout.stop_gradiable = False
+    dout.stop_gradient = False
     out = paddle.divide(x, y)
     newir_program = ir.translate_to_new_ir(main_program.desc)
     return newir_program
@@ -52,10 +52,8 @@ def get_ir_program_1():
         shape=[4, 5], dtype='float32', value=2.0
     )
     x.stop_gradient = False
-    dout = paddle.tensor.fill_constant(
-        shape=[1], dtype='float32', value=1.0
-    )
-    dout.stop_gradiable = False
+    dout = paddle.tensor.fill_constant(shape=[], dtype='float32', value=1.0)
+    dout.stop_gradient = False
     out = paddle.sum(x)
     newir_program = ir.translate_to_new_ir(main_program.desc)
     return newir_program
@@ -124,7 +122,7 @@ class TestVjpPrim(unittest.TestCase):
     def test_sum_grad_prim(self):
        newir_program = get_ir_program_1()
        paddle.fluid.core._set_prim_backward_enabled(True)
-       dout = newir_program.block().ops[-2].result(0)
+       dout = newir_program.block().ops[-3].result(0)
        out_grads = [[dout]]
        stop_gradients = [[False], [True]]
        sum_op = newir_program.block().ops[-1]
@@ -162,7 +160,7 @@ class TestVjpPrim(unittest.TestCase):
            grad_outs[0][0].get_defining_op().name(), "pd.sum_grad"
        )
        self.assertEqual(grad_outs[1][0], None)
-       self.assertEqual(len(newir_program.block().ops), 6)
+       self.assertEqual(len(newir_program.block().ops), 5)


 if __name__ == "__main__":
......