未验证 提交 2b5466f9 编写于 作者: C Charles-hit 提交者: GitHub

[PRIM][IR]Support prim in new ir (#56342)

* support ir api form prim

* convert vector of int to intarray

* support ir api for prim

* support vjp prim mode in new ir

* remove useless code

* add test for prim

* modify utils

* remove useless code

---------
Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
上级 a533dae3
......@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/raw_tensor.h"
#include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/primitive/type/static_tensor.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
namespace phi {
......@@ -41,8 +41,7 @@ template class TypeInfoTraits<phi::TensorBase, paddle::framework::Strings>;
template class TypeInfoTraits<phi::TensorBase, paddle::framework::FeedList>;
template class TypeInfoTraits<phi::TensorBase, egr::VariableCompatTensor>;
template class TypeInfoTraits<phi::TensorBase, paddle::prim::DescTensor>;
template class TypeInfoTraits<phi::TensorBase,
paddle::primitive::experimental::StaticTensor>;
template class TypeInfoTraits<phi::TensorBase, paddle::primitive::LazyTensor>;
template class TypeInfoTraits<phi::TensorBase,
paddle::framework::VariableRefArray>;
......
......@@ -95,6 +95,8 @@ API_LIST = [
'expand',
'tile',
'add_grad',
'divide_grad',
'sum_grad',
]
OP_RESULT = 'ir::OpResult'
VECTOR_TYPE = 'ir::VectorType'
......
......@@ -21,10 +21,13 @@ from op_interface_gen import (
gen_exclusive_interface_str,
gen_op_infer_meta_str,
gen_op_vjp_str,
vjp_interface_gen_op_list,
)
from op_member_func_gen import gen_op_get_inputs_outputs_str
from op_verify_gen import gen_verify_func_str
from vjp_interface_gen_op_list import (
vjp_interface_declare_gen_op_list,
vjp_interface_implementation_gen_op_list,
)
# =====================================
# String Template for h file code gen
......@@ -112,7 +115,7 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/static_tensor.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/ir/core/op_base.h"
{input}
......@@ -756,7 +759,7 @@ def OpGenerator(
if (
op_info.backward_name
and op_info.op_phi_name[0] in vjp_interface_gen_op_list
and op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list
):
op_interfaces += ["VjpInterface"]
exclusive_interface_str = gen_exclusive_interface_str(op_info)
......@@ -1055,7 +1058,8 @@ def OpGenerator(
# TODO(chenzhiyang) add vjp gen code
if (
op_info.backward_name
and op_info.op_phi_name[0] in vjp_interface_gen_op_list
and op_info.op_phi_name[0]
in vjp_interface_implementation_gen_op_list
):
op_vjp_str = gen_op_vjp_str(
op_class_name,
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# generator interfaces
from vjp_interface_gen_op_list import vjp_interface_gen_op_list
from vjp_interface_gen_op_list import vjp_interface_declare_gen_op_list
OP_INFER_SHAPE_TEMPLATE = """
void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
......@@ -23,13 +23,13 @@ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
"""
OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """
{input_type} {input_name}(std::make_shared<primitive::experimental::StaticTensor>(op_obj.{input_name}()));"""
{input_type} {input_name}(std::make_shared<primitive::LazyTensor>(op_obj.{input_name}()));"""
OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
Tensor {output_grad_name}(std::make_shared<primitive::experimental::StaticTensor>(out_grads[{idx1}][{idx2}]));"""
Tensor {output_grad_name}(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}][{idx2}]));"""
OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """
std::vector<Tensor> {output_grad_name}(std::make_shared<primitive::experimental::StaticTensor>(out_grads[{idx1}]));"""
std::vector<Tensor> {output_grad_name}(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}]));"""
OP_VJP_ATTRIBUTE_TEMPLATE = """
{attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();"""
......@@ -39,7 +39,7 @@ OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """
OP_VJP_CALL_VJP_TEMPLATE = """ std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::{op_phi_name}_vjp(
primitive::{op_phi_name}_vjp(
{inputs_list}stop_gradients);"""
OP_VJP_STOPGRADIENT_TEMPLATE = """
......@@ -48,7 +48,7 @@ OP_VJP_STOPGRADIENT_TEMPLATE = """
res[i].resize(tensor_res[i].size());
for (size_t j = 0; j < tensor_res[i].size(); ++j) {{
if(tensor_res[i][j].defined()){{
res[i][j] = std::static_pointer_cast<primitive::experimental::StaticTensor>(tensor_res[i][j].impl())->getValue().dyn_cast<ir::OpResult>();
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(tensor_res[i][j].impl())->getValue().dyn_cast<ir::OpResult>();
}}
}}
}}"""
......@@ -166,6 +166,6 @@ def gen_exclusive_interface_str(op_info):
exclusive_interface_str += (
" static void InferMeta( phi::InferMetaContext *infer_meta );"
)
if op_info.op_phi_name[0] in vjp_interface_gen_op_list:
if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list:
exclusive_interface_str += "\n static std::vector<std::vector<ir::OpResult>> Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
return exclusive_interface_str
......@@ -21,4 +21,5 @@
# TODO(wanghao107)
# remove this file and support Vjp methods
# code gen.
vjp_interface_gen_op_list = ["tanh", "mean", "add"]
vjp_interface_declare_gen_op_list = ["tanh", "mean", "divide", "sum", "add"]
vjp_interface_implementation_gen_op_list = ["tanh", "mean", "divide", "add"]
......@@ -15,7 +15,7 @@
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/static_tensor.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/phi/common/int_array.h"
......@@ -23,5 +23,32 @@
// this file will be generated in pd_op.cc
namespace paddle {
namespace dialect {} // namespace dialect
namespace dialect {
std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
    ir::Operation* op,
    const std::vector<std::vector<ir::OpResult>>& out_grads,
    const std::vector<std::vector<bool>>& stop_gradients) {
  // Hand-written VJP (reverse-mode gradient) entry for SumOp in the new IR:
  // wrap the op's IR values as LazyTensors, delegate to the primitive
  // sum_vjp rule, then unwrap the resulting tensors back into ir::OpResult.
  SumOp op_obj = op->dyn_cast<SumOp>();
  // Forward input and incoming output-gradient, wrapped so the primitive
  // rules can treat them as ordinary tensors.
  Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));
  Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));
  // `axis` is modeled as a full-op result; recover the constant IntArray
  // from the defining op's "value" attribute.
  IntArray axis = op_obj.axis()
                      .GetDefiningOp()
                      ->attribute("value")
                      .dyn_cast<paddle::dialect::IntArrayAttribute>()
                      .data();
  bool keepdim = op->attribute("keepdim").dyn_cast<ir::BoolAttribute>().data();
  // NOTE(review): reduce_all is hard-coded false here — presumably the
  // attribute is not modeled on SumOp yet; confirm against the op schema.
  bool reduce_all = false;
  std::vector<std::vector<Tensor>> tensor_res = primitive::sum_vjp(
      x, out_grad, axis, keepdim, reduce_all, stop_gradients);
  // sum has a single differentiable input, hence a 1x1 result grid.
  std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
  if (tensor_res[0][0].defined()) {
    res[0][0] =
        std::static_pointer_cast<primitive::LazyTensor>(tensor_res[0][0].impl())
            ->getValue()
            .dyn_cast<ir::OpResult>();
  }
  return res;
}
} // namespace dialect
} // namespace paddle
add_subdirectory(utils)
add_subdirectory(backend)
add_subdirectory(rule)
......@@ -19,8 +19,6 @@
namespace paddle {
namespace primitive {
namespace backend {
namespace experimental {} // namespace experimental
} // namespace backend
namespace backend {} // namespace backend
} // namespace primitive
} // namespace paddle
......@@ -21,8 +21,6 @@
namespace paddle {
namespace primitive {
namespace backend {
namespace experimental {} // namespace experimental
} // namespace backend
namespace backend {} // namespace backend
} // namespace primitive
} // namespace paddle
......@@ -15,65 +15,177 @@
#include "paddle/fluid/primitive/backend/static_backend.h"
#include "paddle/fluid/ir/dialect/pd_api.h"
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/fluid/primitive/type/static_tensor.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
namespace paddle {
namespace primitive {
namespace backend {
namespace experimental {
using StaticTensor = paddle::primitive::experimental::StaticTensor;
using LazyTensor = paddle::primitive::LazyTensor;
template <>
Tensor tanh_grad<StaticTensor>(const Tensor& out, const Tensor& grad_out) {
ir::OpResult out_res = std::static_pointer_cast<StaticTensor>(out.impl())
Tensor tanh_grad<LazyTensor>(const Tensor& out, const Tensor& grad_out) {
ir::OpResult out_res = std::static_pointer_cast<LazyTensor>(out.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult grad_out_res =
std::static_pointer_cast<StaticTensor>(grad_out.impl())
std::static_pointer_cast<LazyTensor>(grad_out.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::tanh_grad(out_res, grad_out_res);
return Tensor(
std::make_shared<primitive::experimental::StaticTensor>(op_res));
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
Tensor mean_grad<StaticTensor>(const Tensor& x,
Tensor mean_grad<LazyTensor>(const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all) {
ir::OpResult x_res = std::static_pointer_cast<StaticTensor>(x.impl())
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult out_grad_res =
std::static_pointer_cast<StaticTensor>(out_grad.impl())
std::static_pointer_cast<LazyTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::mean_grad(
x_res, out_grad_res, axis.GetData(), keepdim, reduce_all);
return Tensor(
std::make_shared<primitive::experimental::StaticTensor>(op_res));
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
template <>
std::tuple<Tensor, Tensor> add_grad<StaticTensor>(const Tensor& x,
Tensor divide<LazyTensor>(const Tensor& x, const Tensor& y) {
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult y_res = std::static_pointer_cast<LazyTensor>(y.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::divide(x_res, y_res);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
// Static-graph add: extract the IR values backing both operands, emit the
// pd dialect add op, and hand back its result wrapped as a LazyTensor.
template <>
Tensor add<LazyTensor>(const Tensor& x, const Tensor& y) {
  auto lhs = std::static_pointer_cast<LazyTensor>(x.impl())
                 ->getValue()
                 .dyn_cast<ir::OpResult>();
  auto rhs = std::static_pointer_cast<LazyTensor>(y.impl())
                 ->getValue()
                 .dyn_cast<ir::OpResult>();
  return Tensor(
      std::make_shared<primitive::LazyTensor>(paddle::dialect::add(lhs, rhs)));
}
// Static-graph multiply: lower both operands to ir::OpResult and wrap the
// emitted dialect op's output back into a Tensor.
template <>
Tensor multiply<LazyTensor>(const Tensor& x, const Tensor& y) {
  ir::OpResult x_value = std::static_pointer_cast<LazyTensor>(x.impl())
                             ->getValue()
                             .dyn_cast<ir::OpResult>();
  ir::OpResult y_value = std::static_pointer_cast<LazyTensor>(y.impl())
                             ->getValue()
                             .dyn_cast<ir::OpResult>();
  ir::OpResult product = paddle::dialect::multiply(x_value, y_value);
  return Tensor(std::make_shared<primitive::LazyTensor>(product));
}
// Static-graph elementwise power: emits the pd dialect elementwise_pow op
// on the operands' underlying IR values and wraps its single result.
template <>
Tensor elementwise_pow<LazyTensor>(const Tensor& x, const Tensor& y) {
  auto base = std::static_pointer_cast<LazyTensor>(x.impl())
                  ->getValue()
                  .dyn_cast<ir::OpResult>();
  auto exponent = std::static_pointer_cast<LazyTensor>(y.impl())
                      ->getValue()
                      .dyn_cast<ir::OpResult>();
  auto powered = paddle::dialect::elementwise_pow(base, exponent);
  return Tensor(std::make_shared<primitive::LazyTensor>(powered));
}
// Static-graph scale (presumably scale * x + bias when bias_after_scale is
// true — confirm against the dialect op docs). The Scalar factor is
// narrowed to float to match the dialect API's signature.
template <>
Tensor scale<LazyTensor>(const Tensor& x,
                         const Scalar& scale,
                         float bias,
                         bool bias_after_scale) {
  auto input = std::static_pointer_cast<LazyTensor>(x.impl())
                   ->getValue()
                   .dyn_cast<ir::OpResult>();
  auto scaled =
      paddle::dialect::scale(input, scale.to<float>(), bias, bias_after_scale);
  return Tensor(std::make_shared<primitive::LazyTensor>(scaled));
}
// Static-graph reduce-sum over `axis`, forwarding the requested output
// dtype and the keepdim flag straight to the dialect op.
template <>
Tensor sum<LazyTensor>(const Tensor& x,
                       const IntArray& axis,
                       phi::DataType dtype,
                       bool keepdim) {
  auto input = std::static_pointer_cast<LazyTensor>(x.impl())
                   ->getValue()
                   .dyn_cast<ir::OpResult>();
  auto reduced = paddle::dialect::sum(input, axis.GetData(), dtype, keepdim);
  return Tensor(std::make_shared<primitive::LazyTensor>(reduced));
}
// Static-graph full: create a constant-filled tensor. No operand
// unwrapping is needed since the op takes only attributes; the fill value
// is narrowed to float to match the dialect API.
template <>
Tensor full<LazyTensor>(const IntArray& shape,
                        const Scalar& value,
                        phi::DataType dtype,
                        phi::Place place) {
  auto filled =
      paddle::dialect::full(shape.GetData(), value.to<float>(), dtype, place);
  return Tensor(std::make_shared<primitive::LazyTensor>(filled));
}
// Static-graph reshape. The dialect op yields two results (presumably the
// reshaped output plus an auxiliary shape tensor — confirm with pd_api);
// both are wrapped and returned as a tuple, preserving their order.
template <>
std::tuple<Tensor, Tensor> reshape<LazyTensor>(const Tensor& x,
                                               const IntArray& shape) {
  auto input = std::static_pointer_cast<LazyTensor>(x.impl())
                   ->getValue()
                   .dyn_cast<ir::OpResult>();
  auto results = paddle::dialect::reshape(input, shape.GetData());
  Tensor first(std::make_shared<primitive::LazyTensor>(std::get<0>(results)));
  Tensor second(std::make_shared<primitive::LazyTensor>(std::get<1>(results)));
  return std::make_tuple(first, second);
}
// Static-graph expand: broadcast x to `shape` by emitting the pd dialect
// expand op on the tensor's underlying IR value.
template <>
Tensor expand<LazyTensor>(const Tensor& x, const IntArray& shape) {
  auto input = std::static_pointer_cast<LazyTensor>(x.impl())
                   ->getValue()
                   .dyn_cast<ir::OpResult>();
  auto expanded = paddle::dialect::expand(input, shape.GetData());
  return Tensor(std::make_shared<primitive::LazyTensor>(expanded));
}
// Static-graph tile: repeat x along each dimension `repeat_times` times
// via the pd dialect tile op.
template <>
Tensor tile<LazyTensor>(const Tensor& x, const IntArray& repeat_times) {
  auto input = std::static_pointer_cast<LazyTensor>(x.impl())
                   ->getValue()
                   .dyn_cast<ir::OpResult>();
  auto tiled = paddle::dialect::tile(input, repeat_times.GetData());
  return Tensor(std::make_shared<primitive::LazyTensor>(tiled));
}
template <>
std::tuple<Tensor, Tensor> add_grad<LazyTensor>(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis) {
ir::OpResult x_res = std::static_pointer_cast<StaticTensor>(x.impl())
ir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult y_res = std::static_pointer_cast<StaticTensor>(y.impl())
ir::OpResult y_res = std::static_pointer_cast<LazyTensor>(y.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult out_grad_res =
std::static_pointer_cast<StaticTensor>(out_grad.impl())
std::static_pointer_cast<LazyTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();
......@@ -81,12 +193,55 @@ std::tuple<Tensor, Tensor> add_grad<StaticTensor>(const Tensor& x,
paddle::dialect::add_grad(x_res, y_res, out_grad_res, axis);
return std::make_tuple(
Tensor(std::make_shared<primitive::experimental::StaticTensor>(
std::get<0>(op_res))),
Tensor(std::make_shared<primitive::experimental::StaticTensor>(
std::get<1>(op_res))));
Tensor(std::make_shared<primitive::LazyTensor>(std::get<0>(op_res))),
Tensor(std::make_shared<primitive::LazyTensor>(std::get<1>(op_res))));
}
// Static-graph divide_grad: lower all four tensors to their backing IR
// values, emit the fused dialect divide_grad op, and wrap its two results
// (grad w.r.t. x, grad w.r.t. y).
template <>
std::tuple<Tensor, Tensor> divide_grad<LazyTensor>(const Tensor& x,
                                                   const Tensor& y,
                                                   const Tensor& out,
                                                   const Tensor& out_grad,
                                                   int axis) {
  // Pure unwrap helper: LazyTensor impl -> underlying ir::OpResult.
  auto unwrap = [](const Tensor& t) {
    return std::static_pointer_cast<LazyTensor>(t.impl())
        ->getValue()
        .dyn_cast<ir::OpResult>();
  };
  auto grads = paddle::dialect::divide_grad(
      unwrap(x), unwrap(y), unwrap(out), unwrap(out_grad), axis);
  return std::make_tuple(
      Tensor(std::make_shared<LazyTensor>(std::get<0>(grads))),
      Tensor(std::make_shared<LazyTensor>(std::get<1>(grads))));
}
// Static-graph sum_grad: emit the fused dialect sum_grad op from the
// unwrapped forward input and incoming gradient, forwarding the reduction
// attributes unchanged.
template <>
Tensor sum_grad<LazyTensor>(const Tensor& x,
                            const Tensor& out_grad,
                            const IntArray& axis,
                            bool keepdim,
                            bool reduce_all) {
  // Pure unwrap helper: LazyTensor impl -> underlying ir::OpResult.
  auto unwrap = [](const Tensor& t) {
    return std::static_pointer_cast<LazyTensor>(t.impl())
        ->getValue()
        .dyn_cast<ir::OpResult>();
  };
  auto grad = paddle::dialect::sum_grad(
      unwrap(x), unwrap(out_grad), axis.GetData(), keepdim, reduce_all);
  return Tensor(std::make_shared<LazyTensor>(grad));
}
} // namespace experimental
} // namespace backend
} // namespace primitive
} // namespace paddle
......@@ -23,7 +23,6 @@
namespace paddle {
namespace primitive {
namespace backend {
namespace experimental {
using Tensor = paddle::Tensor;
using IntArray = paddle::experimental::IntArray;
......@@ -43,7 +42,60 @@ std::tuple<Tensor, Tensor> add_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis);
} // namespace experimental
template <typename T>
Tensor divide(const Tensor& x, const Tensor& y);
template <typename T>
Tensor add(const Tensor& x, const Tensor& y);
template <typename T>
Tensor multiply(const Tensor& x, const Tensor& y);
template <typename T>
Tensor elementwise_pow(const Tensor& x, const Tensor& y);
template <typename T>
Tensor scale(const Tensor& x,
const Scalar& scale = 1.0,
float bias = 0.0,
bool bias_after_scale = true);
template <typename T>
Tensor sum(const Tensor& x,
const IntArray& axis = {},
phi::DataType dtype = phi::DataType::UNDEFINED,
bool keepdim = false);
template <typename T>
Tensor full(const IntArray& shape,
const Scalar& value,
phi::DataType dtype = phi::DataType::FLOAT32,
phi::Place place = phi::CPUPlace());
template <typename T>
std::tuple<Tensor, Tensor> reshape(const Tensor& x, const IntArray& shape);
template <typename T>
Tensor expand(const Tensor& x, const IntArray& shape);
template <typename T>
Tensor tile(const Tensor& x, const IntArray& repeat_times = {});
template <typename T>
std::tuple<Tensor, Tensor> divide_grad(const Tensor& x,
const Tensor& y,
const Tensor& out,
const Tensor& out_grad,
int axis);
template <typename T>
Tensor sum_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all);
} // namespace backend
} // namespace primitive
} // namespace paddle
......@@ -18,12 +18,71 @@
namespace paddle {
namespace primitive {
namespace experimental {
// why exist this file?
// We provide this file to divide
// the primitive ops set in the backend.
// It will be called by the vjp composite
// rules and composite ops rules.
} // namespace experimental
using Tensor = paddle::Tensor;
using IntArray = paddle::experimental::IntArray;
template <typename T>
Tensor divide(const Tensor& x, const Tensor& y) {
return backend::divide<T>(x, y);
}
template <typename T>
Tensor add(const Tensor& x, const Tensor& y) {
return backend::add<T>(x, y);
}
template <typename T>
Tensor multiply(const Tensor& x, const Tensor& y) {
return backend::multiply<T>(x, y);
}
template <typename T>
Tensor elementwise_pow(const Tensor& x, const Tensor& y) {
return backend::elementwise_pow<T>(x, y);
}
template <typename T>
Tensor scale(const Tensor& x,
const Scalar& scale = 1.0,
float bias = 0.0,
bool bias_after_scale = true) {
return backend::scale<T>(x, scale, bias, bias_after_scale);
}
template <typename T>
Tensor sum(const Tensor& x,
const IntArray& axis = {},
phi::DataType dtype = phi::DataType::UNDEFINED,
bool keepdim = false) {
return backend::sum<T>(x, axis, dtype, keepdim);
}
template <typename T>
Tensor full(const IntArray& shape,
const Scalar& value,
phi::DataType dtype = phi::DataType::FLOAT32,
phi::Place place = phi::CPUPlace()) {
return backend::full<T>(shape, value, dtype, place);
}
template <typename T>
std::tuple<Tensor, Tensor> reshape(const Tensor& x, const IntArray& shape) {
return backend::reshape<T>(x, shape);
}
template <typename T>
Tensor expand(const Tensor& x, const IntArray& shape) {
return backend::expand<T>(x, shape);
}
template <typename T>
Tensor tile(const Tensor& x, const IntArray& repeat_times = {}) {
return backend::tile<T>(x, repeat_times);
}
} // namespace primitive
} // namespace paddle
file(GLOB VJP_SRCS "*.cc")
file(GLOB VJP_SRCS "vjp.cc")
cc_library(
primitive_vjp_experimental
SRCS ${VJP_SRCS}
DEPS primitive_backend_static_experimental)
DEPS primitive_backend_static_experimental static_global_utils
primitive_static_utils_experimental)
add_dependencies(primitive_vjp_experimental pd_dialect)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <math.h>
#include <vector>
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/fluid/primitive/utils/utils.h"
namespace paddle {
namespace primitive {
namespace details {
// Composite (primitive-op) gradient rule for z = x / y:
//   dL/dy = -(x / y^2) * dL/dz      dL/dx = (1 / y) * dL/dz
// Each gradient is reduced back to its input's shape when broadcasting
// occurred in the forward pass. A null dx/dy pointer means that gradient
// was not requested and its branch is skipped entirely.
template <typename T>
void divide_grad(const Tensor& x,
                 const Tensor& y,
                 const Tensor& out,
                 const Tensor& out_grad,
                 int axis,
                 Tensor* dx,
                 Tensor* dy) {
  if (dy) {
    // dy = -(x/y^2) * dout
    auto denominator =
        elementwise_pow<T>(y, full<T>(y.shape(), 2.0, y.dtype(), y.place()));
    auto dy_res = scale<T>(
        multiply<T>(divide<T>(x, denominator), out_grad), -1.0, 0.0, true);
    if (x.dims() != y.dims()) {
      // Maybe need reduce here
      // get_reduce_dims presumably returns the broadcast axes of y relative
      // to x; an empty result means no reduction is needed — confirm.
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
        set_output<T>(dy_res, dy);
      } else {
        // Sum over the broadcast axes, then restore y's exact shape.
        auto dy_reduce_res =
            sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
        auto reshape_res = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
        auto dy_tmp = std::get<0>(reshape_res);
        set_output<T>(dy_tmp, dy);
      }
    } else {
      set_output<T>(dy_res, dy);
    }
  }  // indicate we will compute dy
  if (dx) {
    // dx = (1/y) * dout
    // NOTE(review): unlike the dy branch, this full<T> call omits y.place()
    // and uses the default place — verify this is intentional.
    auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
    auto dx_res = multiply<T>(divide<T>(one_tensor, y), out_grad);
    if (y.dims() != x.dims()) {
      // Maybe need reduce here
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
        set_output<T>(dx_res, dx);
      } else {
        // Sum over the broadcast axes, then restore x's exact shape.
        auto dx_reduce_res =
            sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_reduce_reshape_res =
            reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
        auto dx_tmp = std::get<0>(dx_reduce_reshape_res);
        set_output<T>(dx_tmp, dx);
      }
    } else {
      set_output<T>(dx_res, dx);
    }
  }  // indicate we will compute dx
}
// Composite gradient rule for sum: broadcast out_grad back to x's shape.
// When keepdim is false the reduced axes were squeezed away, so out_grad
// is first reshaped to reinsert size-1 dims before expanding; with keepdim
// (or a rank-1 input) a direct expand suffices.
template <typename T>
void sum_grad(const Tensor& x,
              const Tensor& out_grad,
              const IntArray& axis,
              bool keepdim,
              bool reduce_all,
              Tensor* x_grad) {
  if (!x_grad) {
    return;  // gradient not requested
  }
  std::vector<int64_t> x_dim = phi::vectorize<int64_t>(x.dims());
  int64_t axis_size = axis.size();
  int64_t x_dim_size = x_dim.size();
  // NOTE: the incoming reduce_all flag is deliberately ignored; a
  // full reduction is instead inferred from the axis list itself
  // (empty list or every dimension listed).
  reduce_all = (axis_size == 0 || axis_size == x_dim_size);
  auto x_grad_tmp = Tensor();
  if (x_dim_size == 1) {
    // Rank-1 input: no unsqueeze bookkeeping needed, expand directly.
    x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
  } else {
    if (!keepdim) {
      // Normalize negative axes, materializing the full axis list when
      // every dimension was reduced.
      auto axis_ = std::vector<int64_t>();
      if (reduce_all) {
        for (int64_t i = 0; i < x_dim_size; i++) {
          axis_.push_back(i);
        }
      } else {
        axis_ = axis.GetData();
        for (int64_t i = 0; i < axis_size; i++) {
          if (axis[i] < 0) {
            axis_[i] = axis[i] + x_dim_size;
          }
        }
      }
      auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
      auto out_grad_reshape_res = reshape<T>(out_grad, out_grad_shape);
      auto out_grad_ = std::get<0>(out_grad_reshape_res);
      // Fix: expand the reshaped gradient (out_grad_), not the raw
      // out_grad, so the reinserted size-1 dims line up with x's shape;
      // the original discarded the reshape result entirely.
      x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
    } else {
      // keepdim: out_grad already has x's rank with size-1 reduced dims.
      x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
    }
  }
  set_output<T>(x_grad_tmp, x_grad);
}
} // namespace details
} // namespace primitive
} // namespace paddle
......@@ -14,15 +14,17 @@
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/ir/dialect/pd_api.h"
#include "paddle/fluid/prim/utils/static/static_global_utils.h"
#include "paddle/fluid/primitive/backend/static_backend.h"
#include "paddle/fluid/primitive/type/static_tensor.h"
#include "paddle/fluid/primitive/rule/vjp/details.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/fluid/primitive/utils/utils.h"
#include "paddle/ir/core/operation.h"
// TODO(wanghao107):
// op's vjp will be auto generated.
namespace paddle {
namespace primitive {
namespace experimental {
std::vector<std::vector<paddle::Tensor>> tanh_vjp(
const Tensor& out,
......@@ -31,16 +33,13 @@ std::vector<std::vector<paddle::Tensor>> tanh_vjp(
std::vector<std::vector<paddle::Tensor>> vjp_res(
1, std::vector<paddle::Tensor>(1));
// get tanh_grad res.
Tensor op_res =
backend::experimental::tanh_grad<primitive::experimental::StaticTensor>(
out, grad_out);
Tensor op_res = backend::tanh_grad<primitive::LazyTensor>(out, grad_out);
// set op stop_gradient info
// TODO(wanghao107): Replace with more generic code.
// Support set stop_gradients for all ops.
ir::Operation* grad_op =
std::static_pointer_cast<primitive::experimental::StaticTensor>(
op_res.impl())
std::static_pointer_cast<primitive::LazyTensor>(op_res.impl())
->getValue()
.dyn_cast<ir::OpResult>()
.owner();
......@@ -76,16 +75,14 @@ std::vector<std::vector<paddle::Tensor>> mean_vjp(
std::vector<std::vector<paddle::Tensor>> vjp_res(
1, std::vector<paddle::Tensor>(1));
// get mean_grad res.
Tensor op_res =
backend::experimental::mean_grad<primitive::experimental::StaticTensor>(
Tensor op_res = backend::mean_grad<primitive::LazyTensor>(
x, out_grad, axis, keepdim, reduce_all);
// set op stop_gradient info
// TODO(wanghao107): Replace with more generic code.
// Support set stop_gradients for all ops.
ir::Operation* grad_op =
std::static_pointer_cast<primitive::experimental::StaticTensor>(
op_res.impl())
std::static_pointer_cast<primitive::LazyTensor>(op_res.impl())
->getValue()
.dyn_cast<ir::OpResult>()
.owner();
......@@ -119,16 +116,14 @@ std::vector<std::vector<paddle::Tensor>> add_vjp(
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(
2, std::vector<paddle::Tensor>(1));
// get mean_grad res.
// get add_grad res.
std::tuple<Tensor, Tensor> op_res =
backend::experimental::add_grad<primitive::experimental::StaticTensor>(
x, y, out_grad, axis);
backend::add_grad<primitive::LazyTensor>(x, y, out_grad, axis);
// set op stop_gradient info
// TODO(wanghao107): Replace with more generic code.
// Support set stop_gradients for all ops.
ir::Operation* grad_op =
std::static_pointer_cast<primitive::experimental::StaticTensor>(
ir::Operation* grad_op = std::static_pointer_cast<primitive::LazyTensor>(
std::get<0>(op_res).impl())
->getValue()
.dyn_cast<ir::OpResult>()
......@@ -152,6 +147,57 @@ std::vector<std::vector<paddle::Tensor>> add_vjp(
vjp_res[1][0] = !stop_gradients[1][0] ? std::get<1>(op_res) : vjp_res[1][0];
return vjp_res;
}
} // namespace experimental
// VJP rule for divide. Returns a 2x1 grid of tensors: gradients w.r.t. x
// and y. Entries stay default-constructed (undefined) where the matching
// stop_gradients flag is set.
std::vector<std::vector<paddle::Tensor>> divide_vjp(
    const Tensor& x,
    const Tensor& y,
    const Tensor& out,
    const Tensor& out_grad,
    int axis,
    const std::vector<std::vector<bool>>& stop_gradients) {
  std::vector<std::vector<paddle::Tensor>> vjp_res(
      2, std::vector<paddle::Tensor>(1));
  if (!paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) {
    // Prim mode off: emit the fused divide_grad op directly.
    // get divide_grad res.
    std::tuple<Tensor, Tensor> op_res =
        backend::divide_grad<primitive::LazyTensor>(x, y, out, out_grad, axis);
    // construct vjp result by op result and stop_gradients info
    vjp_res[0][0] = !stop_gradients[0][0] ? std::get<0>(op_res) : vjp_res[0][0];
    vjp_res[1][0] = !stop_gradients[1][0] ? std::get<1>(op_res) : vjp_res[1][0];
  } else {
    // Prim mode on: compose the gradient from primitive ops; a null
    // pointer tells the composite rule to skip that gradient entirely.
    // get divide_grad prim mode res.
    Tensor* dx = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr;
    Tensor* dy = !stop_gradients[1][0] ? &vjp_res[1][0] : nullptr;
    details::divide_grad<LazyTensor>(x, y, out, out_grad, axis, dx, dy);
  }
  return vjp_res;
}
// VJP rule for sum. Returns a 1x1 grid: the gradient w.r.t. x, left
// default-constructed (undefined) when stop_gradients[0][0] is set.
std::vector<std::vector<paddle::Tensor>> sum_vjp(
    const Tensor& x,
    const Tensor& out_grad,
    const IntArray& axis,
    bool keepdim,
    bool reduce_all,
    const std::vector<std::vector<bool>>& stop_gradients) {
  std::vector<std::vector<paddle::Tensor>> vjp_res(
      1, std::vector<paddle::Tensor>(1));
  if (!paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) {
    // Prim mode off: emit the fused sum_grad op directly.
    // get sum_grad res.
    Tensor op_res = backend::sum_grad<primitive::LazyTensor>(
        x, out_grad, axis, keepdim, reduce_all);
    // construct vjp result by op result and stop_gradients info
    if (!stop_gradients[0][0]) {
      vjp_res[0][0] = op_res;
    }
  } else {
    // Prim mode on: compose the gradient from primitive ops; a null
    // pointer tells the composite rule to skip the gradient.
    // get sum_grad prim mode res.
    Tensor* x_grad = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr;
    details::sum_grad<LazyTensor>(
        x, out_grad, axis, keepdim, reduce_all, x_grad);
  }
  return vjp_res;
}
} // namespace primitive
} // namespace paddle
......@@ -14,13 +14,6 @@
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <math.h>
#include <vector>
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/api/include/tensor.h"
......@@ -28,7 +21,6 @@
namespace paddle {
namespace primitive {
namespace experimental {
using IntArray = paddle::experimental::IntArray;
// TODO(wanghao107):
......@@ -53,11 +45,21 @@ std::vector<std::vector<paddle::Tensor>> add_vjp(
int axis,
const std::vector<std::vector<bool>>& stop_gradients);
namespace details {
// NOTE: this namespace will store
// primitive ops grad composite rules.
std::vector<std::vector<paddle::Tensor>> divide_vjp(
const Tensor& x,
const Tensor& y,
const Tensor& out,
const Tensor& out_grad,
int axis,
const std::vector<std::vector<bool>>& stop_gradients);
std::vector<std::vector<paddle::Tensor>> sum_vjp(
const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
const std::vector<std::vector<bool>>& stop_gradients);
} // namespace details
} // namespace experimental
} // namespace primitive
} // namespace paddle
......@@ -22,34 +22,36 @@
namespace paddle {
namespace primitive {
namespace experimental {
class StaticTensor : public phi::ExtendedTensor,
public phi::TypeInfoTraits<phi::TensorBase, StaticTensor> {
class LazyTensor : public phi::ExtendedTensor,
public phi::TypeInfoTraits<phi::TensorBase, LazyTensor> {
public:
explicit StaticTensor(ir::Value value)
explicit LazyTensor(ir::Value value)
: value_(value),
dims_(value.type().dyn_cast<dialect::DenseTensorType>().dims()) {}
static const char* name() { return "StaticTensor"; }
static const char* name() { return "LazyTensor"; }
const phi::DDim& dims() const override { return dims_; }
int64_t numel() const override { return product(dims()); }
DataType dtype() const override {
return paddle::dialect::TransToPhiDataType(value_.type());
return paddle::dialect::TransToPhiDataType(
value_.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype());
}
ir::Value getValue() const { return value_; }
const phi::Place& place() const override { return place_; }
bool initialized() const override { return value_.impl() != nullptr; }
private:
ir::Value value_;
mutable phi::DDim dims_;
phi::Place place_;
};
} // namespace experimental
} // namespace primitive
} // namespace paddle
if(WITH_PYTHON OR NOT ON_INFER)
cc_library(
primitive_eager_utils_experimental
SRCS eager_utils.cc
DEPS phi common_infer_shape_functions)
endif()
cc_library(
primitive_static_utils_experimental
SRCS static_utils.cc
DEPS phi common_infer_shape_functions)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/primitive/utils/utils.h"
namespace paddle {
namespace primitive {

// Eager-mode specialization of set_output: publish the computed tensor into
// the output slot by sharing its implementation AND its autograd meta (the
// grad graph node must follow the value in dygraph mode).
template <>
void set_output<Tensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
  x->set_impl(x_tmp.impl());
  x->set_autograd_meta(x_tmp.mutable_autograd_meta());
}

}  // namespace primitive
}  // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/fluid/primitive/utils/utils.h"
namespace paddle {
namespace primitive {

// Static-graph (new IR) specialization of set_output: only the impl is
// shared — a LazyTensor carries no eager autograd meta, so unlike the
// eager Tensor specialization there is nothing else to propagate.
template <>
void set_output<LazyTensor>(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
  x->set_impl(x_tmp.impl());
}

}  // namespace primitive
}  // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace primitive {
// Publishes the computed tensor `x_tmp` into the output slot `x`.
// Specialized per backend tensor type: the eager Tensor specialization also
// copies autograd meta, while the LazyTensor one copies only the impl.
template <typename T>
void set_output(const Tensor& x_tmp, Tensor* x);
// This fucction compute unsqueeze dims for reshape to replace unsqueeze.
static std::vector<int64_t> get_unsqueeze_dims(
const Tensor& origin, const std::vector<int64_t>& axis) {
auto origin_dims = origin.shape();
auto total_shape_size = origin_dims.size() + axis.size();
std::vector<int64_t> result;
size_t j = 0, k = 0;
for (size_t i = 0; i < total_shape_size; ++i) {
if (j < axis.size() && axis[j] == int64_t(i)) {
result.push_back(1);
j++;
} else {
PADDLE_ENFORCE_LT(
k,
origin_dims.size(),
platform::errors::OutOfRange("Your index [%lu] exceeds the number of "
"elements in origin_dims[%lu].",
k,
origin_dims.size()));
result.push_back(origin_dims[k]);
k++;
}
}
return result;
}
// Returns the axes of `dout_dims` that must be summed away to recover
// `in_dims` after broadcasting: every extra leading axis of dout, plus
// every axis where the input dimension is 1. Matching axes must agree,
// otherwise the shapes were never broadcast-compatible.
static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
                                          const phi::DDim& in_dims) {
  std::vector<int64_t> reduce_axes;
  const int leading = dout_dims.size() - in_dims.size();
  // Extra leading dimensions of dout are always reduced.
  for (int axis = 0; axis < leading; ++axis) {
    reduce_axes.push_back(axis);
  }
  for (int axis = 0; axis < in_dims.size(); ++axis) {
    if (in_dims[axis] == 1) {
      // A size-1 input dim was broadcast, so its dout counterpart is reduced.
      reduce_axes.push_back(axis + leading);
    } else {
      PADDLE_ENFORCE_EQ(
          in_dims[axis],
          dout_dims[axis + leading],
          platform::errors::InvalidArgument(
              "ReduceDims dimension mismatch. Operands could "
              "not be broadcast together with the shape of dout = [%s] and "
              "the shape of in_dims = [%s]. Received [%d] in X is not equal to "
              "[%d] in Y at i:%d.",
              dout_dims,
              in_dims,
              dout_dims[axis + leading],
              in_dims[axis],
              axis));
    }
  }
  return phi::make_ddim(reduce_axes);
}
// Broadcasts x_dims against y_dims, then returns the axes that must be
// reduced to map a tensor of the broadcast shape back onto x_dims.
static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
                                 const phi::DDim& y_dims) {
  const auto broadcast_dims =
      paddle::operators::details::BroadcastTwoDims(x_dims, y_dims);
  return get_reduce_dims_from_out(broadcast_dims, x_dims);
}
} // namespace primitive
} // namespace paddle
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddle import ir
from paddle.fluid.core import call_vjp
paddle.enable_static()
def get_ir_program_0():
    """Build a new-IR program containing ``out = divide(x, y)`` together
    with a precomputed upstream gradient tensor (``dout``) for vjp tests.

    Returns the program translated to the new IR representation.
    """
    main_program, start_program = (
        paddle.static.Program(),
        paddle.static.Program(),
    )
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.tensor.fill_constant(
            shape=[1, 4], dtype='float32', value=2.0
        )
        x.stop_gradient = False
        y = paddle.tensor.fill_constant(shape=[4], dtype='float32', value=1.0)
        # Bug fix: the attribute is ``stop_gradient``; the previous
        # ``stop_gradiable`` typo silently set an unused attribute.
        y.stop_gradient = False
        dout = paddle.tensor.fill_constant(
            shape=[1, 4], dtype='float32', value=1.0
        )
        dout.stop_gradient = False
        out = paddle.divide(x, y)  # noqa: F841 -- op is recorded in program
    newir_program = ir.translate_to_new_ir(main_program.desc)
    return newir_program
def get_ir_program_1():
    """Build a new-IR program containing ``out = sum(x)`` together with a
    precomputed upstream gradient tensor (``dout``) for vjp tests.

    Returns the program translated to the new IR representation.
    """
    main_program, start_program = (
        paddle.static.Program(),
        paddle.static.Program(),
    )
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.tensor.fill_constant(
            shape=[4, 5], dtype='float32', value=2.0
        )
        x.stop_gradient = False
        dout = paddle.tensor.fill_constant(
            shape=[1], dtype='float32', value=1.0
        )
        # Bug fix: the attribute is ``stop_gradient``; the previous
        # ``stop_gradiable`` typo silently set an unused attribute.
        dout.stop_gradient = False
        out = paddle.sum(x)  # noqa: F841 -- op is recorded in program
    newir_program = ir.translate_to_new_ir(main_program.desc)
    return newir_program
class TestVjpPrim(unittest.TestCase):
    """Verify ``call_vjp`` in the new IR.

    When prim backward is enabled, divide/sum gradients must be decomposed
    into elementary ops; when disabled, the fused ``pd.divide_grad`` /
    ``pd.sum_grad`` ops must be emitted instead. The expected op counts and
    name sequences below pin the exact decomposition produced by the
    primitive vjp rules.
    """

    def test_divide_grad_prim_case1(self):
        # Prim enabled: divide_grad decomposes into pow/divide/multiply/
        # sum/reshape chains for both dx and dy.
        newir_program = get_ir_program_0()
        paddle.fluid.core._set_prim_backward_enabled(True)
        # The op before divide is the fill_constant producing dout.
        dout = newir_program.block().ops[-2].result(0)
        out_grads = [[dout]]
        stop_gradients = [[False], [False]]
        divide_op = newir_program.block().ops[-1]
        with paddle.ir.core.program_guard(newir_program):
            grad_outs = call_vjp(divide_op, out_grads, stop_gradients)
        # The last reshape of each decomposed branch is the returned grad.
        reshape_op2 = newir_program.block().ops[-1]
        reshape_op1 = newir_program.block().ops[-8]
        self.assertEqual(len(grad_outs), 2)
        self.assertEqual(len(newir_program.block().ops), 21)
        self.assertEqual(reshape_op2.result(0), grad_outs[0][0])
        self.assertEqual(reshape_op1.result(0), grad_outs[1][0])
        # Full expected op sequence: 4 forward ops + 17 decomposed grad ops.
        all_op_names = [
            "pd.full",
            "pd.full",
            "pd.full",
            "pd.divide",
            "pd.full",
            "pd.elementwise_pow",
            "pd.divide",
            "pd.multiply",
            "pd.full",
            "pd.scale",
            "pd.full_int_array",
            "pd.sum",
            "pd.full_int_array",
            "pd.reshape",
            "pd.full",
            "pd.divide",
            "pd.multiply",
            "pd.full_int_array",
            "pd.sum",
            "pd.full_int_array",
            "pd.reshape",
        ]
        for idx, op in enumerate(newir_program.block().ops):
            self.assertEqual(op.name(), all_op_names[idx])

    def test_divide_grad_no_prim(self):
        # Prim disabled: a single fused pd.divide_grad produces both grads.
        newir_program = get_ir_program_0()
        paddle.fluid.core._set_prim_backward_enabled(False)
        dout = newir_program.block().ops[-2].result(0)
        out_grads = [[dout]]
        stop_gradients = [[False], [False]]
        divide_op = newir_program.block().ops[-1]
        with paddle.ir.core.program_guard(newir_program):
            grad_outs = call_vjp(divide_op, out_grads, stop_gradients)
        self.assertEqual(len(grad_outs), 2)
        self.assertEqual(
            grad_outs[0][0].get_defining_op().name(), "pd.divide_grad"
        )
        self.assertEqual(
            grad_outs[1][0].get_defining_op().name(), "pd.divide_grad"
        )
        # 4 forward ops + 1 divide_grad op.
        self.assertEqual(len(newir_program.block().ops), 5)

    def test_sum_grad_prim(self):
        # Prim enabled: sum_grad decomposes into reshape + expand.
        newir_program = get_ir_program_1()
        paddle.fluid.core._set_prim_backward_enabled(True)
        dout = newir_program.block().ops[-2].result(0)
        out_grads = [[dout]]
        stop_gradients = [[False]]
        sum_op = newir_program.block().ops[-1]
        with paddle.ir.core.program_guard(newir_program):
            grad_outs = call_vjp(sum_op, out_grads, stop_gradients)
        # The trailing expand op yields the input gradient.
        expand_op = newir_program.block().ops[-1]
        self.assertEqual(len(grad_outs), 1)
        self.assertEqual(len(newir_program.block().ops), 8)
        self.assertEqual(expand_op.result(0), grad_outs[0][0])
        # Full expected op sequence: 4 forward ops + 4 decomposed grad ops.
        all_op_names = [
            "pd.full",
            "pd.full",
            "pd.full_int_array",
            "pd.sum",
            "pd.full_int_array",
            "pd.reshape",
            "pd.full_int_array",
            "pd.expand",
        ]
        for idx, op in enumerate(newir_program.block().ops):
            self.assertEqual(op.name(), all_op_names[idx])

    def test_sum_grad_no_prim(self):
        # Prim disabled: a single fused pd.sum_grad produces the grad.
        newir_program = get_ir_program_1()
        paddle.fluid.core._set_prim_backward_enabled(False)
        dout = newir_program.block().ops[-2].result(0)
        out_grads = [[dout]]
        stop_gradients = [[False]]
        sum_op = newir_program.block().ops[-1]
        with paddle.ir.core.program_guard(newir_program):
            grad_outs = call_vjp(sum_op, out_grads, stop_gradients)
        self.assertEqual(len(grad_outs), 1)
        self.assertEqual(
            grad_outs[0][0].get_defining_op().name(), "pd.sum_grad"
        )
        # 4 forward ops + 2 ops for the fused grad path.
        self.assertEqual(len(newir_program.block().ops), 6)
# Run the suite when executed directly (standard unittest entry point).
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册