Unverified · Commit 14b81d5a · Author: Charles-hit · Committer: GitHub

[Prim] [NewIR]Automatic code generation for vjp rules (#56512)

* support ir api form prim

* support ir api for prim

* support vjp prim mode in new ir

* remove useless code

* remove useless code

* auto code generator for primitive vjp methods

* add vjp and backend manual and fix segment fault

---------
Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Co-authored-by: cxxly <chenxx_id@163.com>
Parent f02261b0
backend/generated/*.cc
backend/generated/*.h
primitive/primitive.h
rule/vjp/generated/generated_vjp.h
rule/vjp/generated/generated_vjp.cc
add_subdirectory(utils)
add_subdirectory(backend)
add_subdirectory(rule)
add_subdirectory(codegen)
set(eager_backend_files
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/generated/generated_eager_backend.cc
)
if(WITH_PYTHON OR NOT ON_INFER)
cc_library(
primitive_backend_eager_experimental
SRCS eager_backend.cc
SRCS ${eager_backend_files}
DEPS final_dygraph_function eager_utils phi)
endif()
set(static_backend_files
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/generated/generated_static_backend.cc
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
)
cc_library(
primitive_backend_static_experimental
SRCS static_backend.cc
SRCS ${static_backend_files}
DEPS pd_dialect_api)
@@ -14,13 +14,5 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
namespace primitive {
namespace backend {} // namespace backend
} // namespace primitive
} // namespace paddle
#include "paddle/fluid/primitive/backend/generated/generated_backend.h"
#include "paddle/fluid/primitive/backend/manual/manual_backend.h"
@@ -18,89 +18,21 @@
#include <vector>
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
namespace paddle {
namespace primitive {
namespace backend {
using Tensor = paddle::Tensor;
using Scalar = paddle::experimental::Scalar;
using IntArray = paddle::experimental::IntArray;
template <typename T>
Tensor tanh_grad(const Tensor& out, const Tensor& grad_out);
template <typename T>
Tensor mean_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& axis = {},
bool keepdim = false,
bool reduce_all = false);
using DataType = phi::DataType;
template <typename T>
std::vector<Tensor> concat_grad(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis);
template <typename T>
std::tuple<Tensor, Tensor> add_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis);
template <typename T>
Tensor divide(const Tensor& x, const Tensor& y);
template <typename T>
Tensor add(const Tensor& x, const Tensor& y);
template <typename T>
Tensor multiply(const Tensor& x, const Tensor& y);
template <typename T>
Tensor elementwise_pow(const Tensor& x, const Tensor& y);
template <typename T>
Tensor scale(const Tensor& x,
const Scalar& scale = 1.0,
float bias = 0.0,
bool bias_after_scale = true);
template <typename T>
Tensor sum(const Tensor& x,
const IntArray& axis = {},
phi::DataType dtype = phi::DataType::UNDEFINED,
bool keepdim = false);
template <typename T>
Tensor full(const IntArray& shape,
const Scalar& value,
phi::DataType dtype = phi::DataType::FLOAT32,
phi::Place place = phi::CPUPlace());
template <typename T>
std::tuple<Tensor, Tensor> reshape(const Tensor& x, const IntArray& shape);
template <typename T>
Tensor expand(const Tensor& x, const IntArray& shape);
template <typename T>
Tensor tile(const Tensor& x, const IntArray& repeat_times = {});
template <typename T>
std::tuple<Tensor, Tensor> divide_grad(const Tensor& x,
const Tensor& y,
const Tensor& out,
const Tensor& out_grad,
int axis);
template <typename T>
Tensor sum_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all);
} // namespace backend
} // namespace primitive
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h"
#include "paddle/fluid/primitive/backend/manual/manual_backend.h"
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
namespace paddle {
namespace primitive {
namespace backend {
using LazyTensor = paddle::primitive::LazyTensor;
template <>
std::vector<Tensor> concat_grad<LazyTensor>(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis) {
std::vector<ir::OpResult> x_res;
for (uint64_t idx = 0; idx < x.size(); idx++) {
x_res.emplace_back(std::static_pointer_cast<LazyTensor>(x[idx].impl())
->getValue()
.dyn_cast<ir::OpResult>());
}
ir::OpResult out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult axis_res = std::static_pointer_cast<LazyTensor>(axis.impl())
->getValue()
.dyn_cast<ir::OpResult>();
std::vector<ir::OpResult> op_res =
paddle::dialect::concat_grad(x_res, out_grad_res, axis_res);
std::vector<Tensor> op_result;
for (uint64_t idx = 0; idx < op_res.size(); idx++) {
op_result.emplace_back(
std::make_shared<primitive::LazyTensor>(op_res[idx]));
}
return op_result;
}
} // namespace backend
} // namespace primitive
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/primitive/backend/static_backend.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h"
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
namespace paddle {
namespace primitive {
namespace backend {
using LazyTensor = paddle::primitive::LazyTensor;
template <>
Tensor tanh_grad<LazyTensor>(const Tensor& out, const Tensor& grad_out) {
ir::OpResult out_res = std::static_pointer_cast<LazyTensor>(out.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult grad_out_res =
std::static_pointer_cast<LazyTensor>(grad_out.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::tanh_grad(out_res, grad_out_res);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
Tensor mean_grad<LazyTensor>(const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::mean_grad(
x_res, out_grad_res, axis.GetData(), keepdim, reduce_all);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
Tensor divide<LazyTensor>(const Tensor& x, const Tensor& y) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult y_res = std::static_pointer_cast<LazyTensor>(y.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::divide(x_res, y_res);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
Tensor add<LazyTensor>(const Tensor& x, const Tensor& y) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult y_res = std::static_pointer_cast<LazyTensor>(y.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::add(x_res, y_res);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
Tensor multiply<LazyTensor>(const Tensor& x, const Tensor& y) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult y_res = std::static_pointer_cast<LazyTensor>(y.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::multiply(x_res, y_res);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
Tensor elementwise_pow<LazyTensor>(const Tensor& x, const Tensor& y) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult y_res = std::static_pointer_cast<LazyTensor>(y.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::elementwise_pow(x_res, y_res);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
Tensor scale<LazyTensor>(const Tensor& x,
const Scalar& scale,
float bias,
bool bias_after_scale) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res =
paddle::dialect::scale(x_res, scale.to<float>(), bias, bias_after_scale);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
Tensor sum<LazyTensor>(const Tensor& x,
const IntArray& axis,
phi::DataType dtype,
bool keepdim) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res =
paddle::dialect::sum(x_res, axis.GetData(), dtype, keepdim);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
Tensor full<LazyTensor>(const IntArray& shape,
const Scalar& value,
phi::DataType dtype,
phi::Place place) {
ir::OpResult op_res =
paddle::dialect::full(shape.GetData(), value.to<float>(), dtype, place);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
std::tuple<Tensor, Tensor> reshape<LazyTensor>(const Tensor& x,
const IntArray& shape) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
std::tuple<ir::OpResult, ir::OpResult> op_res =
paddle::dialect::reshape(x_res, shape.GetData());
return std::make_tuple(
Tensor(std::make_shared<primitive::LazyTensor>(std::get<0>(op_res))),
Tensor(std::make_shared<primitive::LazyTensor>(std::get<1>(op_res))));
}
template <>
Tensor expand<LazyTensor>(const Tensor& x, const IntArray& shape) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::expand(x_res, shape.GetData());
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
Tensor tile<LazyTensor>(const Tensor& x, const IntArray& repeat_times) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::tile(x_res, repeat_times.GetData());
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
std::tuple<Tensor, Tensor> add_grad<LazyTensor>(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult y_res = std::static_pointer_cast<LazyTensor>(y.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();
std::tuple<ir::OpResult, ir::OpResult> op_res =
paddle::dialect::add_grad(x_res, y_res, out_grad_res, axis);
return std::make_tuple(
Tensor(std::make_shared<primitive::LazyTensor>(std::get<0>(op_res))),
Tensor(std::make_shared<primitive::LazyTensor>(std::get<1>(op_res))));
}
template <>
std::tuple<Tensor, Tensor> divide_grad<LazyTensor>(const Tensor& x,
const Tensor& y,
const Tensor& out,
const Tensor& out_grad,
int axis) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult y_res = std::static_pointer_cast<LazyTensor>(y.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult out_res = std::static_pointer_cast<LazyTensor>(out.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();
std::tuple<ir::OpResult, ir::OpResult> op_res =
paddle::dialect::divide_grad(x_res, y_res, out_res, out_grad_res, axis);
return std::make_tuple(
Tensor(std::make_shared<LazyTensor>(std::get<0>(op_res))),
Tensor(std::make_shared<LazyTensor>(std::get<1>(op_res))));
}
template <>
Tensor sum_grad<LazyTensor>(const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::sum_grad(
x_res, out_grad_res, axis.GetData(), keepdim, reduce_all);
return Tensor(std::make_shared<LazyTensor>(op_res));
}
template <>
std::vector<Tensor> concat_grad<LazyTensor>(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis) {
std::vector<ir::OpResult> x_res;
for (uint64_t idx = 0; idx < x.size(); idx++) {
x_res.emplace_back(std::static_pointer_cast<LazyTensor>(x[idx].impl())
->getValue()
.dyn_cast<ir::OpResult>());
}
ir::OpResult out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult axis_res = std::static_pointer_cast<LazyTensor>(axis.impl())
->getValue()
.dyn_cast<ir::OpResult>();
std::vector<ir::OpResult> op_res =
paddle::dialect::concat_grad(x_res, out_grad_res, axis_res);
std::vector<Tensor> op_result;
for (uint64_t idx = 0; idx < op_res.size(); idx++) {
op_result.emplace_back(
std::make_shared<primitive::LazyTensor>(op_res[idx]));
}
return op_result;
}
} // namespace backend
} // namespace primitive
} // namespace paddle
set(fwd_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml"
)
set(fwd_legacy_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml"
)
set(rev_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml"
)
set(rev_legacy_path
"${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml"
)
set(prim_path "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/primitive.yaml")
set(templates_dir
"${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/templates/")
set(destination_dir "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/")
set(scripts "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/gen.py")
message("Automatic code generation for paddle/fluid/primitive")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/primitive/codegen
COMMAND
${PYTHON_EXECUTABLE} ${scripts} --fwd_path ${fwd_path} --fwd_legacy_path
${fwd_legacy_path} --rev_path ${rev_path} --rev_legacy_path
${rev_legacy_path} --prim_path ${prim_path} --templates_dir ${templates_dir}
--destination_dir ${destination_dir}
RESULT_VARIABLE _result)
if(${_result})
message(
FATAL_ERROR
"Automatic code generation for paddle/fluid/primitive failed, exiting.")
endif()
message("Automatic code generation for paddle/fluid/primitive succeed.")
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import hashlib
import pathlib
import sys
import jinja2
import yaml
# fmt: off
# import from paddle/fluid/operators/generator
sys.path.append(
str(pathlib.Path(__file__).resolve().parents[2] / 'operators/generator')
)
import filters as op_gen_filters
import tests_utils as op_gen_tests
# import from paddle/fluid/ir/dialect/op_generator/api_gen.py
sys.path.append(
str(pathlib.Path(__file__).resolve().parents[2] / 'ir/dialect/op_generator')
)
# fmt: on
VJPS = ['tanh_grad', 'mean_grad', 'add_grad', 'divide_grad', 'sum_grad']
VJP_COMPS = ['divide_grad', 'sum_grad']
BACKENDS = [
'add_n',
'mean',
'sum',
'divide',
'full',
'tanh_grad',
'mean_grad',
'concat',
'add',
'multiply',
'elementwise_pow',
'scale',
'reshape',
'expand',
'tile',
'add_grad',
'divide_grad',
'sum_grad',
]
def load(path: pathlib.Path):
"""Load config from yaml file.
Args:
path (pathlib.Path): The path of yaml config.
Returns:
dict: The config info.
"""
with open(path, 'rt') as f:
return yaml.safe_load(f)
def render(src_dir: pathlib.Path, dst_dir: pathlib.Path, *args, **kwargs):
"""Render and save Jinja2 templates to the destination directory.
Args:
src_dir (pathlib.Path): The source directory containing Jinja2 templates.
dst_dir (pathlib.Path): The destination directory to save rendered files.
*args: Additional positional arguments forwarded to jinja2's Template.render().
**kwargs: Additional keyword arguments forwarded to jinja2's Template.render().
Returns:
None
"""
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(src_dir),
keep_trailing_newline=True,
trim_blocks=True,
lstrip_blocks=True,
undefined=jinja2.StrictUndefined,
extensions=['jinja2.ext.do'],
)
env.filters.update(
{
'to_paddle_attr_type': op_gen_filters.to_paddle_attr_type,
'to_paddle_input_type': op_gen_filters.to_paddle_input_type,
'to_paddle_output_type': op_gen_filters.to_paddle_output_type,
}
)
env.tests.update(
{
'scalar': op_gen_tests.is_scalar,
'intarray': op_gen_tests.is_intarray,
'datatype': op_gen_tests.is_datatype,
}
)
for tpl in env.list_templates(
filter_func=lambda name: ".h" in name or ".cc" in name
):
save(
env.get_template(tpl).render(*args, **kwargs),
dst_dir / tpl.rstrip('.j2'),
)
def save(content: str, path: pathlib.Path):
"""Saves the given string contents to a file in the specified path.
Args:
content (str): The string content that needs to be saved.
path (pathlib.Path): The path to save the file, a Pathlib path object
Returns:
None
"""
path.parent.mkdir(parents=True, exist_ok=True)
dst_content = ''
if path.is_file():
with open(path, 'r') as f:
dst_content = f.read()
if (
hashlib.md5(content.encode("UTF-8")).hexdigest()
!= hashlib.md5(dst_content.encode("UTF-8")).hexdigest()
):
with open(path, 'w') as f:
f.write(content)
print(f"Generate source file {path}")
def gen(
prim_path: pathlib.Path,
fwd_path: pathlib.Path,
fwd_legacy_path: pathlib.Path,
rev_path: pathlib.Path,
rev_legacy_path: pathlib.Path,
templates_dir: pathlib.Path,
destination_dir: pathlib.Path,
):
"""The `gen` load jinja2 templates and relative config info, use jinja2
templating engine to generate c++ code, and save the code into destination.
Args:
prim_path (pathlib.Path): The YAML file path of the primitive API.
fwd_path (pathlib.Path): The YAML file path of the forward API.
fwd_legacy_path (pathlib.Path): The YAML file path of the legacy
forward API.
rev_path (pathlib.Path): The YAML file path of the backward API.
rev_legacy_path (pathlib.Path): The YAML file path of the legacy
backward API.
templates_dir (pathlib.Path): The directory of the templates.
destination_dir (pathlib.Path): The directory of the generated files.
Returns:
None
"""
prims, fwds, legacy_fwds, revs, legacy_revs = (
load(prim_path),
load(fwd_path),
load(fwd_legacy_path),
load(rev_path),
load(rev_legacy_path),
)
apis = [{**api, **{'is_fwd': True}} for api in fwds + legacy_fwds]
apis = apis + [{**api, **{'is_fwd': False}} for api in revs + legacy_revs]
apis = [
{**api, **{'is_prim': True}}
if api['name'] in prims
else {**api, **{'is_prim': False}}
for api in apis
]
render(
templates_dir,
destination_dir,
apis=apis,
backend_white_list=BACKENDS,
vjp_white_list=VJPS,
vjp_comp_white_list=VJP_COMPS,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Generate Static Primitive API'
)
parser.add_argument(
'--prim_path',
type=str,
help='The primitive API yaml file.',
)
parser.add_argument(
'--fwd_path', type=str, help='The parsed ops yaml file.'
)
parser.add_argument(
'--fwd_legacy_path',
type=str,
help='The parsed ops yaml file.',
)
parser.add_argument(
'--rev_path', type=str, help='The parsed ops yaml file.'
)
parser.add_argument(
'--rev_legacy_path',
type=str,
help='The parsed ops yaml file.',
)
parser.add_argument(
'--templates_dir',
type=str,
help='Jinja2 templates base directory.',
)
parser.add_argument(
'--destination_dir',
type=str,
help='Destination base directory for generated files.',
)
args = parser.parse_args()
gen(
pathlib.Path(args.prim_path),
pathlib.Path(args.fwd_path),
pathlib.Path(args.fwd_legacy_path),
pathlib.Path(args.rev_path),
pathlib.Path(args.rev_legacy_path),
pathlib.Path(args.templates_dir),
pathlib.Path(args.destination_dir),
)
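To make the tagging step inside gen() concrete, here is a small self-contained sketch with made-up records; 'demo_op' and its fields are illustrative, not real parsed-ops entries:

# Toy version of the apis construction in gen(); all names are invented.
fwds = [{'name': 'demo_op', 'backward': 'demo_op_grad'}]
revs = [{'name': 'demo_op_grad'}]
prims = ['demo_op_grad']  # would come from primitive.yaml

apis = [{**api, 'is_fwd': True} for api in fwds]
apis += [{**api, 'is_fwd': False} for api in revs]
apis = [{**api, 'is_prim': api['name'] in prims} for api in apis]

print(apis)
# [{'name': 'demo_op', 'backward': 'demo_op_grad', 'is_fwd': True, 'is_prim': False},
#  {'name': 'demo_op_grad', 'is_fwd': False, 'is_prim': True}]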
{% import "common.j2" as common %}
// Auto Generated, DO NOT EDIT!
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
namespace primitive {
namespace backend {
using Tensor = paddle::Tensor;
using Scalar = paddle::experimental::Scalar;
using IntArray = paddle::experimental::IntArray;
using DataType = phi::DataType;
{% for api in apis %}
{%- if api.name in backend_white_list -%}
{{common.sig(api.name, api.inputs, api.outputs, api.attrs, True)}};
{% endif %}
{% endfor %}
} // namespace backend
} // namespace primitive
} // namespace paddle
{% import "common.j2" as common %}
// Auto Generated, DO NOT EDIT!
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
#include "paddle/fluid/primitive/backend/generated/generated_backend.h"
namespace paddle {
namespace primitive {
namespace backend {
{%- macro args(inputs, attrs) -%} {#- Arguments are the variables passed into the method -#}
{{common.sequence('', '', ', ', inputs)}}
{%- if attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#}
{{common.sequence('', '', ', ', attrs)}}
{%- endmacro -%}
{%- macro sig(name, inputs, attrs, outputs) -%}
template <>
{{common.ret(outputs)}} {{name}}<Tensor>({{common.params(inputs, attrs)}})
{%- endmacro -%}
{% macro body(name, inputs, attrs, outputs) %}
{%- set input_names = [] -%}
{%- for i in inputs -%} {%- do input_names.append(i.name) -%} {%-endfor-%}
{%- set attr_names = [] -%}
{%- for i in attrs -%} {%- do attr_names.append(i.name) -%} {%-endfor-%}
{% filter indent(2, True) %}
VLOG(4) << "Eager Prim API {{name}}_ad_func call";
return ::{{name}}_ad_func({{common.args(input_names, attr_names)}});
{% endfilter %}
{% endmacro %}
{% for api in apis %}
{#- TODO(cxxly): codegen for reshape -#}
{%- if api.is_prim and api.name in backend_white_list and api.name != 'reshape' -%}
{{sig(api.name, api.inputs, api.attrs, api.outputs)}} {
{{body(api.name, api.inputs, api.attrs, api.outputs)}}
}
{% endif %}
{% endfor %}
} // namespace backend
} // namespace primitive
} // namespace paddle
{% import "common.j2" as common %}
// Auto Generated, DO NOT EDIT!
#include "paddle/fluid/primitive/backend/generated/generated_backend.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h"
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
namespace paddle {
namespace primitive {
namespace backend {
using LazyTensor = paddle::primitive::LazyTensor;
{%- macro sig(name, inputs, outputs, attrs) -%}
template <>
{{common.ret(outputs)}} {{name}}<LazyTensor>({{common.params(inputs, attrs)}})
{%- endmacro -%}
{% macro body(name, inputs, outputs, attrs) %}
{%- set output_names = [] -%}
{%- for o in outputs -%} {%- do output_names.append(o.name) -%} {%-endfor-%}
{%- for input in inputs -%}
{% if input.typename=='Tensor[]' %}
std::vector<ir::OpResult> {{input.name}}_res({{input.name}}.size());
std::transform({{input.name}}.begin(), {{input.name}}.end(), {{input.name}}_res.begin(), [](const Tensor& t) {
return std::static_pointer_cast<LazyTensor>(t.impl())->getValue().dyn_cast<ir::OpResult>();
});
{% else %}
ir::OpResult {{input.name}}_res = std::static_pointer_cast<LazyTensor>({{input.name}}.impl())->getValue().dyn_cast<ir::OpResult>();
{% endif %}
{% endfor %}
{%- set input_names = [] -%}
{%- for i in inputs -%} {%- do input_names.append(i.name~'_res') -%} {%- endfor -%}
{%- set attr_names = [] -%}
{%- for i in attrs -%} {%- do attr_names.append(common.phi2ir_attr(i)) -%} {% endfor %}
auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}});
{% if outputs|length > 1 %}
return std::make_tuple(
{% for i in range(outputs|length) %}
Tensor(std::make_shared<LazyTensor>(std::get<{{i}}>(op_res))){%- if i!=outputs|length - 1 -%}, {% endif %}
{% endfor %}
);
{% elif outputs|length == 1 %}
return Tensor(std::make_shared<LazyTensor>(op_res));
{% else %} {#- render nothing -#}
{% endif %}
{% endmacro %}
{% for api in apis %}
{% if api.name in backend_white_list %}
{{sig(api.name, api.inputs, api.outputs, api.attrs)}} {
{{body(api.name, api.inputs, api.outputs, api.attrs)}}
}
{% endif %}
{% endfor %}
} // namespace backend
} // namespace primitive
} // namespace paddle
{%- macro sig(name, inputs, outputs, attrs, default=False) -%}
template <typename T>
{{ret(outputs)}} {{name}}({{params(inputs, attrs, default)}})
{%- endmacro %}
{%- macro params(inputs, attrs, default=False) -%}
{%- set input_params = [] -%}
{%- for i in inputs -%} {%- do input_params.append(i.typename|to_paddle_input_type(i.optional)~' '~i.name) -%} {%- endfor -%}
{%- set attr_params = [] -%}
{%- for i in attrs -%}
{%- if default -%}
{%- do attr_params.append(i.typename|to_paddle_attr_type~' '~i.name~default_value(i)) -%}
{%- else -%}
{%- do attr_params.append(i.typename|to_paddle_attr_type~' '~i.name) -%}
{%- endif -%}
{%- endfor -%}
{{sequence('', '', ', ', input_params)}}
{%- if input_params|length>0 and attr_params|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#}
{{sequence('', '', ', ', attr_params)}}
{%- endmacro -%}
{%- macro default_value(attr) -%}
{%- if 'default_value' in attr %}
= {{attr.default_value}}
{%- else -%} {#- render nothing -#}
{%- endif -%}
{%- endmacro -%}
{%- macro args(inputs, attrs) -%} {#- Arguments are the variables passed into the method -#}
{{sequence('', '', ', ', inputs)}}
{%- if inputs|length>0 and attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#}
{{sequence('', '', ', ', attrs)}}
{%- endmacro -%}
{%- macro ret(outputs) -%}
{%- set names = [] -%}
{%- for i in outputs -%} {%- do names.append(i.typename|to_paddle_output_type) -%} {%- endfor -%}
{%- if names|length > 1 -%}
std::tuple<{{sequence('', '', ', ', names)}}>
{%- else -%}
{{names[0]}}
{%- endif -%}
{%- endmacro -%}
{%- macro sequence(lsymbol, rsymbol, delimiter, items) -%}
{{lsymbol}}{%- for item in items -%}{{item}}{{delimiter if not loop.last else "" }}{%- endfor -%}{{rsymbol}}
{%- endmacro -%}
{%- macro phi2ir_attr(attr) -%}
{%- if attr.typename is intarray -%}
{{intarray2ir(attr.name)}}
{%- elif attr.typename is scalar -%}
{{scalar2ir(attr.name, attr.data_type)}}
{%- else -%}
{{attr.name}}
{%- endif -%}
{%- endmacro %}
{%- macro intarray2ir(name) -%}
{{name}}.GetData()
{%- endmacro -%}
{%- macro scalar2ir(name, data_type) -%}
{{name}}.to<{{data_type}}>()
{%- endmacro -%}
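As an illustration of how these macros expand, a minimal sketch using plain jinja2 with an inline copy of the sequence macro; the op-result names are just examples:

import jinja2

env = jinja2.Environment(extensions=['jinja2.ext.do'])
tpl = env.from_string(
    "{%- macro sequence(lsymbol, rsymbol, delimiter, items) -%}"
    "{{lsymbol}}{%- for item in items -%}"
    "{{item}}{{delimiter if not loop.last else ''}}"
    "{%- endfor -%}{{rsymbol}}"
    "{%- endmacro -%}"
    "{{ sequence('(', ')', ', ', names) }}"
)
print(tpl.render(names=['x_res', 'out_grad_res', 'axis_res']))
# -> (x_res, out_grad_res, axis_res)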
{% import "common.j2" as common %}
// Auto Generated, DO NOT EDIT!
#pragma once
#include "paddle/fluid/primitive/backend/backend.h"
namespace paddle {
namespace primitive {
using Tensor = paddle::Tensor;
using IntArray = paddle::experimental::IntArray;
{% for api in apis %}
{%- if api.is_prim and api.name in backend_white_list -%}
{%- set input_names = [] -%}
{%- for i in api.inputs -%} {%- do input_names.append(i.name) -%} {%- endfor -%}
{%- set attr_names = [] -%}
{%- for i in api.attrs -%} {%- do attr_names.append(i.name) -%} {% endfor %}
{{common.sig(api.name, api.inputs, api.outputs, api.attrs, True)}} {
return backend::{{api.name}}<T>({{common.args(input_names, attr_names)}});
}
{% endif %}
{% endfor %}
} // namespace primitive
} // namespace paddle
{% import "common.j2" as common %}
// Auto Generated, DO NOT EDIT!
#include "paddle/fluid/primitive/rule/vjp/generated/generated_vjp.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/fluid/primitive/backend/backend.h"
#include "paddle/fluid/primitive/rule/vjp/details.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/fluid/primitive/utils/utils.h"
#include "paddle/ir/core/operation.h"
namespace paddle {
namespace primitive {
{% macro sig(fwd_name, name, inputs, attrs, outputs) -%}
std::vector<std::vector<paddle::Tensor>> {{fwd_name}}_vjp({{common.params(inputs, attrs)}}, const std::vector<std::vector<bool>>& stop_gradients)
{%- endmacro -%}
{% macro body(api) %}
std::vector<std::vector<paddle::Tensor>> vjp_res;
for (auto arg: stop_gradients) {
vjp_res.push_back(std::vector<paddle::Tensor>(arg.size()));
}
{% if 'composite' in api and api.name in vjp_comp_white_list %}
if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) {
{% filter indent(2, True) %}{{body_prim(api)}}{% endfilter %}
} else {
{% filter indent(2, True) %}{{body_unprim(api)}}{% endfilter %}
}
{% else %}
{{body_unprim(api)}}
{% endif %}
return vjp_res;
{% endmacro %}
{% macro body_unprim(api) %}
{%- set input_names=[] -%}
{%- for i in api.inputs -%} {%- do input_names.append(i.name) -%} {%- endfor -%}
{%- set attr_names=[] -%}
{%- for i in api.attrs -%} {%- do attr_names.append(i.name) -%} {%- endfor %}
auto op_res = backend::{{api.name}}<LazyTensor>({{common.args(input_names, attr_names)}});
{% if api.outputs|length > 1 %}
{% for i in range(api.outputs|length) %}
auto out{{i}} = std::get<{{i}}>(op_res);
{% if api.outputs[i].typename=='Tensor' %}
vjp_res[{{i}}][0] = !stop_gradients[{{i}}][0] ? out{{i}} : vjp_res[{{i}}][0];
{% else %}
for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) {
vjp_res[{{i}}][i] = !stop_gradients[{{i}}][i] ? out{{i}}[i] : vjp_res[{{i}}][i];
}
{% endif %}
{% endfor %}
{% elif api.outputs|length == 1 %}
{% if api.outputs[0].typename=='Tensor' %}
vjp_res[0][0] = !stop_gradients[0][0] ? op_res : vjp_res[0][0];
{% else %}
for (size_t i=0; i< stop_gradients[0].size(); i++ ) {
vjp_res[0][i] = !stop_gradients[0][i] ? op_res[i] : vjp_res[0][i];
}
{% endif %}
{% else %} {#- render nothing -#}
{% endif %}
{% endmacro %}
{% macro body_prim(api) %}
{% for i in range(api.outputs|length) %}
{% if api.outputs[i].typename=='Tensor' %}
paddle::Tensor* {{api.outputs[i].name}} = !stop_gradients[{{i}}][0] ? &vjp_res[{{i}}][0] : nullptr;
{% else %}
std::vector<paddle::Tensor*> {{api.outputs[i].name}}(stop_gradients[{{i}}].size(), nullptr);
for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) {
{{api.outputs[i].name}}[i] = !stop_gradients[{{i}}][i] ? &vjp_res[{{i}}][i] : nullptr;
}
{% endif %}
{% endfor %}
details::{{api.composite.func_name}}<LazyTensor>({{api.composite.func_args}});
{% endmacro %}
{%- set api_map = {} -%}
{%- for api in apis -%} {%- do api_map.update({api.name: api}) -%} {%- endfor -%}
{%- for api in apis %}
{%- if api.backward and api.backward in api_map and api.backward in vjp_white_list -%}
{%- set backward_api = api_map[api.backward] %}
{{sig(api.name, backward_api.name, backward_api.inputs, backward_api.attrs, backward_api.outputs)}} {
{% filter indent(2, True) %}
{{body(backward_api)}}
{% endfilter %}
}
{% endif %}
{% endfor %}
} // namespace primitive
} // namespace paddle
{% import "common.j2" as common %}
// Auto Generated, DO NOT EDIT!
#pragma once
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
namespace paddle {
namespace primitive {
using IntArray = paddle::experimental::IntArray;
{% macro sig(fwd_name, name, inputs, attrs, outputs) %}
std::vector<std::vector<paddle::Tensor>> {{fwd_name}}_vjp({{common.params(inputs, attrs)}}, const std::vector<std::vector<bool>>& stop_gradients);
{% endmacro %}
{%- set api_map = {} -%}
{%- for api in apis -%} {%- do api_map.update({api.name: api}) -%} {%- endfor -%}
{% for api in apis %}
{%- if api.backward and api.backward in api_map and api.backward in vjp_white_list -%}
{%- set backward_api = api_map[api.backward] -%}
{{sig(api.name, backward_api.name, backward_api.inputs, backward_api.attrs, backward_api.outputs)}}
{% endif %}
{% endfor %}
} // namespace primitive
} // namespace paddle
- add
- subtract
- multiply
- divide
- less_equal
- less_than
- equal
- not_equal
- greater_equal
- greater_than
- bitwise_and
- bitwise_not
- bitwise_or
- bitwise_xor
- exp
- scale
- matmul
- expand
- sum
- abs
- assign
- concat
- elementwise_pow
- floor
- gather
- gather_nd
- log
- max
- maximum
- minimum
- prod
- roll
- scatter
- scatter_nd_add
- tile
- transpose
- pad
- sqrt
- cumsum
- put_along_axis
- equal
- greater_than
- less_equal
- sin
- cos
- where
- split
- reshape
- erf
- tanh
- full
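For reference, a small sketch of how this whitelist is consumed: gen.py loads it into a plain list of op names, and membership decides the is_prim flag (the path mirrors prim_path in the CMake file above):

import yaml

with open('paddle/fluid/primitive/primitive.yaml') as f:
    prims = yaml.safe_load(f)

print('divide' in prims)     # True  -> rendered as a primitive wrapper
print('tanh_grad' in prims)  # False -> only a backend API is generated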
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/primitive/backend/eager_backend.h"
#include "paddle/fluid/primitive/backend/static_backend.h"
namespace paddle {
namespace primitive {
// Why does this file exist?
// It is provided to carve the primitive op
// set out of the backend.
// It is called by the vjp composite rules
// and the composite op rules.
using Tensor = paddle::Tensor;
using IntArray = paddle::experimental::IntArray;
template <typename T>
Tensor divide(const Tensor& x, const Tensor& y) {
return backend::divide<T>(x, y);
}
template <typename T>
Tensor add(const Tensor& x, const Tensor& y) {
return backend::add<T>(x, y);
}
template <typename T>
Tensor multiply(const Tensor& x, const Tensor& y) {
return backend::multiply<T>(x, y);
}
template <typename T>
Tensor elementwise_pow(const Tensor& x, const Tensor& y) {
return backend::elementwise_pow<T>(x, y);
}
template <typename T>
Tensor scale(const Tensor& x,
const Scalar& scale = 1.0,
float bias = 0.0,
bool bias_after_scale = true) {
return backend::scale<T>(x, scale, bias, bias_after_scale);
}
template <typename T>
Tensor sum(const Tensor& x,
const IntArray& axis = {},
phi::DataType dtype = phi::DataType::UNDEFINED,
bool keepdim = false) {
return backend::sum<T>(x, axis, dtype, keepdim);
}
template <typename T>
Tensor full(const IntArray& shape,
const Scalar& value,
phi::DataType dtype = phi::DataType::FLOAT32,
phi::Place place = phi::CPUPlace()) {
return backend::full<T>(shape, value, dtype, place);
}
template <typename T>
std::tuple<Tensor, Tensor> reshape(const Tensor& x, const IntArray& shape) {
return backend::reshape<T>(x, shape);
}
template <typename T>
Tensor expand(const Tensor& x, const IntArray& shape) {
return backend::expand<T>(x, shape);
}
template <typename T>
Tensor tile(const Tensor& x, const IntArray& repeat_times = {}) {
return backend::tile<T>(x, repeat_times);
}
} // namespace primitive
} // namespace paddle
file(GLOB VJP_SRCS "vjp.cc")
set(VJP_SRCS
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/rule/vjp/generated/generated_vjp.cc
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc)
cc_library(
primitive_vjp_experimental
SRCS ${VJP_SRCS}
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Auto Generated, DO NOT EDIT!
#include "paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/fluid/primitive/backend/backend.h"
#include "paddle/fluid/primitive/rule/vjp/details.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/fluid/primitive/utils/utils.h"
#include "paddle/ir/core/operation.h"
namespace paddle {
namespace primitive {
std::vector<std::vector<paddle::Tensor>> concat_vjp(
const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(2, std::vector<Tensor>());
// get concat_grad res.
std::vector<Tensor> op_res =
backend::concat_grad<primitive::LazyTensor>(x, out_grad, axis);
// construct vjp result by op result and stop_gradients info
vjp_res[0].resize(op_res.size());
for (uint64_t idx = 0; idx < op_res.size(); idx++) {
if (!stop_gradients[0][idx]) {
vjp_res[0][idx] = op_res[idx];
}
}
// vjp_res[1] is the grad of axis, which is an attribute and therefore has no grad.
vjp_res[1].resize(1);
return vjp_res;
}
} // namespace primitive
} // namespace paddle
@@ -12,13 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/primitive/backend/eager_backend.h"
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#pragma once
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
namespace paddle {
namespace primitive {
namespace backend {} // namespace backend
using IntArray = paddle::experimental::IntArray;
std::vector<std::vector<paddle::Tensor>> concat_vjp(
const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis,
const std::vector<std::vector<bool>>& stop_gradients);
} // namespace primitive
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/fluid/primitive/backend/static_backend.h"
#include "paddle/fluid/primitive/rule/vjp/details.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/fluid/primitive/utils/utils.h"
#include "paddle/ir/core/operation.h"
// TODO(wanghao107):
// op's vjp will be auto generated.
namespace paddle {
namespace primitive {
std::vector<std::vector<paddle::Tensor>> tanh_vjp(
const Tensor& out,
const Tensor& grad_out,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(
1, std::vector<paddle::Tensor>(1));
// get tanh_grad res.
Tensor op_res = backend::tanh_grad<primitive::LazyTensor>(out, grad_out);
// set op stop_gradient info
// TODO(wanghao107): Replace with more generic code.
// Support setting stop_gradients for all ops.
ir::Operation* grad_op =
std::static_pointer_cast<primitive::LazyTensor>(op_res.impl())
->getValue()
.dyn_cast<ir::OpResult>()
.owner();
uint32_t num_res = grad_op->num_results();
std::vector<ir::Attribute> ir_stop_gradients(num_res);
for (size_t i = 0; i < num_res; i++) {
if (stop_gradients[0][i]) {
ir_stop_gradients[i] =
ir::BoolAttribute::get(ir::IrContext::Instance(), true);
} else {
ir_stop_gradients[i] =
ir::BoolAttribute::get(ir::IrContext::Instance(), false);
}
}
grad_op->set_attribute(
"stop_gradient",
ir::ArrayAttribute::get(ir::IrContext::Instance(), ir_stop_gradients));
// construct vjp result by op result and stop_gradients info
if (!stop_gradients[0][0]) {
vjp_res[0][0] = op_res;
}
return vjp_res;
}
std::vector<std::vector<paddle::Tensor>> mean_vjp(
const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(
1, std::vector<paddle::Tensor>(1));
// get mean_grad res.
Tensor op_res = backend::mean_grad<primitive::LazyTensor>(
x, out_grad, axis, keepdim, reduce_all);
// set op stop_gradient info
// TODO(wanghao107): Replace with more generic code.
// Support setting stop_gradients for all ops.
ir::Operation* grad_op =
std::static_pointer_cast<primitive::LazyTensor>(op_res.impl())
->getValue()
.dyn_cast<ir::OpResult>()
.owner();
uint32_t num_res = grad_op->num_results();
std::vector<ir::Attribute> ir_stop_gradients(num_res);
for (size_t i = 0; i < num_res; i++) {
if (stop_gradients[0][i]) {
ir_stop_gradients[i] =
ir::BoolAttribute::get(ir::IrContext::Instance(), true);
} else {
ir_stop_gradients[i] =
ir::BoolAttribute::get(ir::IrContext::Instance(), false);
}
}
grad_op->set_attribute(
"stop_gradient",
ir::ArrayAttribute::get(ir::IrContext::Instance(), ir_stop_gradients));
// construct vjp result by op result and stop_gradients info
if (!stop_gradients[0][0]) {
vjp_res[0][0] = op_res;
}
return vjp_res;
}
std::vector<std::vector<paddle::Tensor>> add_vjp(
const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(
2, std::vector<paddle::Tensor>(1));
// get add_grad res.
std::tuple<Tensor, Tensor> op_res =
backend::add_grad<primitive::LazyTensor>(x, y, out_grad, axis);
// set op stop_gradient info
// TODO(wanghao107): Replace with more generic code.
// Support setting stop_gradients for all ops.
ir::Operation* grad_op = std::static_pointer_cast<primitive::LazyTensor>(
std::get<0>(op_res).impl())
->getValue()
.dyn_cast<ir::OpResult>()
.owner();
std::vector<ir::Attribute> ir_stop_gradients(2);
for (size_t i = 0; i < 2; i++) {
if (stop_gradients[i][0]) {
ir_stop_gradients[i] =
ir::BoolAttribute::get(ir::IrContext::Instance(), true);
} else {
ir_stop_gradients[i] =
ir::BoolAttribute::get(ir::IrContext::Instance(), false);
}
}
grad_op->set_attribute(
"stop_gradient",
ir::ArrayAttribute::get(ir::IrContext::Instance(), ir_stop_gradients));
// construct vjp result by op result and stop_gradients info
vjp_res[0][0] = !stop_gradients[0][0] ? std::get<0>(op_res) : vjp_res[0][0];
vjp_res[1][0] = !stop_gradients[1][0] ? std::get<1>(op_res) : vjp_res[1][0];
return vjp_res;
}
std::vector<std::vector<paddle::Tensor>> concat_vjp(
const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(2, std::vector<Tensor>());
// get concat_grad res.
std::vector<Tensor> op_res =
backend::concat_grad<primitive::LazyTensor>(x, out_grad, axis);
// construct vjp result by op result and stop_gradients info
vjp_res[0].resize(op_res.size());
for (uint64_t idx = 0; idx < op_res.size(); idx++) {
if (!stop_gradients[0][idx]) {
vjp_res[0][idx] = op_res[idx];
}
}
// vjp_res[1] is the grad of axis, which is an attribute and therefore has no grad.
vjp_res[1].resize(1);
return vjp_res;
}
std::vector<std::vector<paddle::Tensor>> divide_vjp(
const Tensor& x,
const Tensor& y,
const Tensor& out,
const Tensor& out_grad,
int axis,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(
2, std::vector<paddle::Tensor>(1));
if (!paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) {
// get divide_grad res.
std::tuple<Tensor, Tensor> op_res =
backend::divide_grad<primitive::LazyTensor>(x, y, out, out_grad, axis);
// construct vjp result by op result and stop_gradients info
vjp_res[0][0] = !stop_gradients[0][0] ? std::get<0>(op_res) : vjp_res[0][0];
vjp_res[1][0] = !stop_gradients[1][0] ? std::get<1>(op_res) : vjp_res[1][0];
} else {
// get divide_grad prim mode res.
Tensor* dx = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr;
Tensor* dy = !stop_gradients[1][0] ? &vjp_res[1][0] : nullptr;
details::divide_grad<LazyTensor>(x, y, out, out_grad, axis, dx, dy);
}
return vjp_res;
}
std::vector<std::vector<paddle::Tensor>> sum_vjp(
const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(
1, std::vector<paddle::Tensor>(1));
if (!paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) {
// get sum_grad res.
Tensor op_res = backend::sum_grad<primitive::LazyTensor>(
x, out_grad, axis, keepdim, reduce_all);
// construct vjp result by op result and stop_gradients info
if (!stop_gradients[0][0]) {
vjp_res[0][0] = op_res;
}
} else {
// get sum_grad prim mode res.
Tensor* x_grad = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr;
details::sum_grad<LazyTensor>(
x, out_grad, axis, keepdim, reduce_all, x_grad);
}
return vjp_res;
}
} // namespace primitive
} // namespace paddle
@@ -14,58 +14,5 @@
#pragma once
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
namespace paddle {
namespace primitive {
using IntArray = paddle::experimental::IntArray;
// TODO(wanghao107):
// op's vjp will be auto generated.
std::vector<std::vector<paddle::Tensor>> tanh_vjp(
const Tensor& out,
const Tensor& grad_out,
const std::vector<std::vector<bool>>& stop_gradients);
std::vector<std::vector<paddle::Tensor>> mean_vjp(
const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
const std::vector<std::vector<bool>>& stop_gradients);
std::vector<std::vector<paddle::Tensor>> concat_vjp(
const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis,
const std::vector<std::vector<bool>>& stop_gradients);
std::vector<std::vector<paddle::Tensor>> add_vjp(
const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis,
const std::vector<std::vector<bool>>& stop_gradients);
std::vector<std::vector<paddle::Tensor>> divide_vjp(
const Tensor& x,
const Tensor& y,
const Tensor& out,
const Tensor& out_grad,
int axis,
const std::vector<std::vector<bool>>& stop_gradients);
std::vector<std::vector<paddle::Tensor>> sum_vjp(
const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
const std::vector<std::vector<bool>>& stop_gradients);
} // namespace primitive
} // namespace paddle
#include "paddle/fluid/primitive/rule/vjp/generated/generated_vjp.h"
#include "paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h"