Unverified commit 6a42ddc6, authored by Charles-hit and committed by GitHub

Rename desctensor (#56334)

* add flags for cinn test

* rename DescTensor

* remove useless code

* modify code style

* modify code style

* modify code style
Parent 2a378ff5
@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/raw_tensor.h"
 #include "paddle/fluid/framework/string_array.h"
 #include "paddle/fluid/prim/utils/static/desc_tensor.h"
-#include "paddle/fluid/primitive/type/desc_tensor.h"
+#include "paddle/fluid/primitive/type/static_tensor.h"
 namespace phi {
@@ -42,7 +42,7 @@ template class TypeInfoTraits<phi::TensorBase, paddle::framework::FeedList>;
 template class TypeInfoTraits<phi::TensorBase, egr::VariableCompatTensor>;
 template class TypeInfoTraits<phi::TensorBase, paddle::prim::DescTensor>;
 template class TypeInfoTraits<phi::TensorBase,
-                              paddle::primitive::experimental::DescTensor>;
+                              paddle::primitive::experimental::StaticTensor>;
 template class TypeInfoTraits<phi::TensorBase,
                               paddle::framework::VariableRefArray>;
...
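For context, the explicit instantiation above is what gives the renamed StaticTensor a registered runtime type identity in phi. Below is a minimal sketch of how that identity is typically queried, assuming a Paddle tree that already contains this commit; InspectTensorBase is a made-up helper used only for illustration.

// Sketch only: query the type identity registered by the instantiation above.
#include "paddle/fluid/primitive/type/static_tensor.h"
#include "paddle/phi/core/tensor_base.h"

void InspectTensorBase(const phi::TensorBase& t) {  // hypothetical helper
  // classof() comes from phi::TypeInfoTraits<phi::TensorBase, StaticTensor>,
  // whose definition is emitted by the explicit instantiation in this hunk.
  if (paddle::primitive::experimental::StaticTensor::classof(&t)) {
    // t is backed by an ir::Value from the new static-graph primitive path.
  }
}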
@@ -111,7 +111,7 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g
 #include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/fluid/primitive/rule/vjp/vjp.h"
-#include "paddle/fluid/primitive/type/desc_tensor.h"
+#include "paddle/fluid/primitive/type/static_tensor.h"
 #include "paddle/ir/core/op_base.h"
 {input}
...
@@ -23,15 +23,15 @@ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
 """
 OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """
-    {input_type} {input_name}(std::make_shared<primitive::experimental::DescTensor>(op_obj.{input_name}()));
+    {input_type} {input_name}(std::make_shared<primitive::experimental::StaticTensor>(op_obj.{input_name}()));
 """
 OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """
-    Tensor {output_grad_name}(std::make_shared<primitive::experimental::DescTensor>((out_grads[{idx1}][{idx2}]);
+    Tensor {output_grad_name}(std::make_shared<primitive::experimental::StaticTensor>((out_grads[{idx1}][{idx2}]);
 """
 OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """
-    std::vector<Tensor> {output_grad_name}(std::make_shared<primitive::experimental::DescTensor>((out_grads[{idx1}]);
+    std::vector<Tensor> {output_grad_name}(std::make_shared<primitive::experimental::StaticTensor>((out_grads[{idx1}]);
 """
 OP_VJP_CALL_VJP_TEMPLATE = """
@@ -41,7 +41,7 @@ OP_VJP_CALL_VJP_TEMPLATE = """
 OP_VJP_STOPGRADIENT_TEMPLATE = """
    if(!stop_gradients[{idx1}][{idx2}]){{
-      res[{idx1}][{idx2}] = std::static_pointer_cast<primitive::experimental::DescTensor>(
+      res[{idx1}][{idx2}] = std::static_pointer_cast<primitive::experimental::StaticTensor>(
                                 tensor_res[idx1][idx2].impl())
                                 ->getValue()
                                 .dyn_cast<ir::OpResult>();
...
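To make the templates above concrete: for a single-input, single-output op such as tanh, the generated body looks roughly like the snippet below. This is a sketch only; op_obj, tensor_res, res and stop_gradients come from the other templates in this generator, and the tanh names are just an example. It mirrors the hand-written TanhOp::Vjp shown later in this commit.

// Roughly what OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE and
// OP_VJP_STOPGRADIENT_TEMPLATE expand to with input_name=out, idx1=idx2=0.
Tensor out(
    std::make_shared<primitive::experimental::StaticTensor>(op_obj.out()));
if (!stop_gradients[0][0]) {
  res[0][0] = std::static_pointer_cast<primitive::experimental::StaticTensor>(
                  tensor_res[0][0].impl())
                  ->getValue()
                  .dyn_cast<ir::OpResult>();
}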
@@ -15,7 +15,7 @@
 #include "paddle/fluid/ir/dialect/pd_attribute.h"
 #include "paddle/fluid/ir/dialect/pd_op.h"
 #include "paddle/fluid/primitive/rule/vjp/vjp.h"
-#include "paddle/fluid/primitive/type/desc_tensor.h"
+#include "paddle/fluid/primitive/type/static_tensor.h"
 #include "paddle/ir/core/op_base.h"
 #include "paddle/phi/common/int_array.h"
@@ -32,14 +32,14 @@ std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
     const std::vector<std::vector<bool>>& stop_gradients) {
   TanhOp op_obj = op->dyn_cast<TanhOp>();
   Tensor out(
-      std::make_shared<primitive::experimental::DescTensor>(op_obj.out()));
+      std::make_shared<primitive::experimental::StaticTensor>(op_obj.out()));
   Tensor grad_out(
-      std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
+      std::make_shared<primitive::experimental::StaticTensor>(out_grads[0][0]));
   std::vector<std::vector<Tensor>> tensor_res =
       primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
   std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
   if (tensor_res[0][0].defined()) {
-    res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
+    res[0][0] = std::static_pointer_cast<primitive::experimental::StaticTensor>(
                     tensor_res[0][0].impl())
                     ->getValue()
                     .dyn_cast<ir::OpResult>();
@@ -57,14 +57,14 @@ std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
   // Support inplace in the future.
   Tanh_Op op_obj = op->dyn_cast<Tanh_Op>();
   Tensor out(
-      std::make_shared<primitive::experimental::DescTensor>(op_obj.out()));
+      std::make_shared<primitive::experimental::StaticTensor>(op_obj.out()));
   Tensor grad_out(
-      std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
+      std::make_shared<primitive::experimental::StaticTensor>(out_grads[0][0]));
   std::vector<std::vector<Tensor>> tensor_res =
       primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
   std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
   if (tensor_res[0][0].defined()) {
-    res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
+    res[0][0] = std::static_pointer_cast<primitive::experimental::StaticTensor>(
                     tensor_res[0][0].impl())
                     ->getValue()
                     .dyn_cast<ir::OpResult>();
@@ -77,9 +77,9 @@ std::vector<std::vector<ir::OpResult>> MeanOp::Vjp(
     const std::vector<std::vector<ir::OpResult>>& out_grads,
     const std::vector<std::vector<bool>>& stop_gradients) {
   MeanOp op_obj = op->dyn_cast<MeanOp>();
-  Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
+  Tensor x(std::make_shared<primitive::experimental::StaticTensor>(op_obj.x()));
   Tensor out_grad(
-      std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
+      std::make_shared<primitive::experimental::StaticTensor>(out_grads[0][0]));
   IntArray axis = op->attribute("axis")
                       .dyn_cast<paddle::dialect::IntArrayAttribute>()
@@ -91,7 +91,7 @@ std::vector<std::vector<ir::OpResult>> MeanOp::Vjp(
       x, out_grad, axis, keepdim, reduce_all, stop_gradients);
   std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
   if (tensor_res[0][0].defined()) {
-    res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
+    res[0][0] = std::static_pointer_cast<primitive::experimental::StaticTensor>(
                     tensor_res[0][0].impl())
                     ->getValue()
                     .dyn_cast<ir::OpResult>();
@@ -104,10 +104,10 @@ std::vector<std::vector<ir::OpResult>> AddOp::Vjp(
     const std::vector<std::vector<ir::OpResult>>& out_grads,
     const std::vector<std::vector<bool>>& stop_gradients) {
   AddOp op_obj = op->dyn_cast<AddOp>();
-  Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
-  Tensor y(std::make_shared<primitive::experimental::DescTensor>(op_obj.y()));
+  Tensor x(std::make_shared<primitive::experimental::StaticTensor>(op_obj.x()));
+  Tensor y(std::make_shared<primitive::experimental::StaticTensor>(op_obj.y()));
   Tensor out_grad(
-      std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
+      std::make_shared<primitive::experimental::StaticTensor>(out_grads[0][0]));
   int axis = -1;
   std::vector<std::vector<Tensor>> tensor_res =
@@ -115,10 +115,11 @@ std::vector<std::vector<ir::OpResult>> AddOp::Vjp(
   std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
   for (size_t i = 0; i < 2; ++i) {
     if (tensor_res[i][0].defined()) {
-      res[i][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
-                      tensor_res[i][0].impl())
-                      ->getValue()
-                      .dyn_cast<ir::OpResult>();
+      res[i][0] =
+          std::static_pointer_cast<primitive::experimental::StaticTensor>(
+              tensor_res[i][0].impl())
+              ->getValue()
+              .dyn_cast<ir::OpResult>();
     }
   }
   return res;
@@ -129,10 +130,10 @@ std::vector<std::vector<ir::OpResult>> Add_Op::Vjp(
     const std::vector<std::vector<ir::OpResult>>& out_grads,
     const std::vector<std::vector<bool>>& stop_gradients) {
   Add_Op op_obj = op->dyn_cast<Add_Op>();
-  Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
-  Tensor y(std::make_shared<primitive::experimental::DescTensor>(op_obj.y()));
+  Tensor x(std::make_shared<primitive::experimental::StaticTensor>(op_obj.x()));
+  Tensor y(std::make_shared<primitive::experimental::StaticTensor>(op_obj.y()));
   Tensor out_grad(
-      std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
+      std::make_shared<primitive::experimental::StaticTensor>(out_grads[0][0]));
   int axis = -1;
   std::vector<std::vector<Tensor>> tensor_res =
@@ -140,10 +141,11 @@ std::vector<std::vector<ir::OpResult>> Add_Op::Vjp(
   std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
   for (size_t i = 0; i < 2; ++i) {
     if (tensor_res[i][0].defined()) {
-      res[i][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
-                      tensor_res[i][0].impl())
-                      ->getValue()
-                      .dyn_cast<ir::OpResult>();
+      res[i][0] =
+          std::static_pointer_cast<primitive::experimental::StaticTensor>(
+              tensor_res[i][0].impl())
+              ->getValue()
+              .dyn_cast<ir::OpResult>();
     }
   }
   return res;
...
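Every Vjp override above repeats the same three steps: wrap the op's ir values in StaticTensor, call the primitive vjp rule, then unwrap the resulting tensors back to ir::OpResult. Below is a hedged sketch of how that boilerplate could be factored out; WrapStaticTensor and UnwrapOpResult are hypothetical helpers, not part of this commit, and they assume the same headers as the file above.

// Hypothetical helpers (illustration only) capturing the repeated
// wrap/unwrap pattern used by each Vjp implementation above.
static paddle::Tensor WrapStaticTensor(ir::Value value) {
  return paddle::Tensor(
      std::make_shared<paddle::primitive::experimental::StaticTensor>(value));
}

static ir::OpResult UnwrapOpResult(const paddle::Tensor& t) {
  return std::static_pointer_cast<
             paddle::primitive::experimental::StaticTensor>(t.impl())
      ->getValue()
      .dyn_cast<ir::OpResult>();
}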
@@ -15,63 +15,65 @@
 #include "paddle/fluid/primitive/backend/static_backend.h"
 #include "paddle/fluid/ir/dialect/pd_api.h"
 #include "paddle/fluid/primitive/primitive/primitive.h"
-#include "paddle/fluid/primitive/type/desc_tensor.h"
+#include "paddle/fluid/primitive/type/static_tensor.h"
 namespace paddle {
 namespace primitive {
 namespace backend {
 namespace experimental {
-using DescTensor = paddle::primitive::experimental::DescTensor;
+using StaticTensor = paddle::primitive::experimental::StaticTensor;
 template <>
-Tensor tanh_grad<DescTensor>(const Tensor& out, const Tensor& grad_out) {
-  ir::OpResult out_res = std::static_pointer_cast<DescTensor>(out.impl())
+Tensor tanh_grad<StaticTensor>(const Tensor& out, const Tensor& grad_out) {
+  ir::OpResult out_res = std::static_pointer_cast<StaticTensor>(out.impl())
                              ->getValue()
                              .dyn_cast<ir::OpResult>();
   ir::OpResult grad_out_res =
-      std::static_pointer_cast<DescTensor>(grad_out.impl())
+      std::static_pointer_cast<StaticTensor>(grad_out.impl())
          ->getValue()
          .dyn_cast<ir::OpResult>();
   ir::OpResult op_res = paddle::dialect::tanh_grad(out_res, grad_out_res);
-  return Tensor(std::make_shared<primitive::experimental::DescTensor>(op_res));
+  return Tensor(
+      std::make_shared<primitive::experimental::StaticTensor>(op_res));
 }
 template <>
-Tensor mean_grad<DescTensor>(const Tensor& x,
+Tensor mean_grad<StaticTensor>(const Tensor& x,
                              const Tensor& out_grad,
                              const IntArray& axis,
                              bool keepdim,
                              bool reduce_all) {
-  ir::OpResult x_res = std::static_pointer_cast<DescTensor>(x.impl())
+  ir::OpResult x_res = std::static_pointer_cast<StaticTensor>(x.impl())
                            ->getValue()
                            .dyn_cast<ir::OpResult>();
   ir::OpResult out_grad_res =
-      std::static_pointer_cast<DescTensor>(out_grad.impl())
+      std::static_pointer_cast<StaticTensor>(out_grad.impl())
          ->getValue()
          .dyn_cast<ir::OpResult>();
   ir::OpResult op_res = paddle::dialect::mean_grad(
      x_res, out_grad_res, axis.GetData(), keepdim, reduce_all);
-  return Tensor(std::make_shared<primitive::experimental::DescTensor>(op_res));
+  return Tensor(
+      std::make_shared<primitive::experimental::StaticTensor>(op_res));
 }
 template <>
-std::tuple<Tensor, Tensor> add_grad<DescTensor>(const Tensor& x,
+std::tuple<Tensor, Tensor> add_grad<StaticTensor>(const Tensor& x,
                                                 const Tensor& y,
                                                 const Tensor& out_grad,
                                                 int axis) {
-  ir::OpResult x_res = std::static_pointer_cast<DescTensor>(x.impl())
+  ir::OpResult x_res = std::static_pointer_cast<StaticTensor>(x.impl())
                            ->getValue()
                            .dyn_cast<ir::OpResult>();
-  ir::OpResult y_res = std::static_pointer_cast<DescTensor>(y.impl())
+  ir::OpResult y_res = std::static_pointer_cast<StaticTensor>(y.impl())
                            ->getValue()
                            .dyn_cast<ir::OpResult>();
   ir::OpResult out_grad_res =
-      std::static_pointer_cast<DescTensor>(out_grad.impl())
+      std::static_pointer_cast<StaticTensor>(out_grad.impl())
          ->getValue()
          .dyn_cast<ir::OpResult>();
@@ -79,9 +81,9 @@ std::tuple<Tensor, Tensor> add_grad<DescTensor>(const Tensor& x,
       paddle::dialect::add_grad(x_res, y_res, out_grad_res, axis);
   return std::make_tuple(
-      Tensor(std::make_shared<primitive::experimental::DescTensor>(
+      Tensor(std::make_shared<primitive::experimental::StaticTensor>(
           std::get<0>(op_res))),
-      Tensor(std::make_shared<primitive::experimental::DescTensor>(
+      Tensor(std::make_shared<primitive::experimental::StaticTensor>(
           std::get<1>(op_res))));
 }
 } // namespace experimental
...
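Each specialization above follows the same shape: unwrap every StaticTensor argument to its ir::OpResult, call the matching paddle::dialect builder, and wrap the result back into a Tensor. As a sketch of what adding another unary op would look like, consider the fragment below; relu_grad is only an assumption here (both the primary template in static_backend.h and the paddle::dialect builder), so treat it as illustration rather than part of the commit.

// Sketch only: a hypothetical specialization following the same
// unwrap -> dialect call -> rewrap pattern as tanh_grad above.
template <>
Tensor relu_grad<StaticTensor>(const Tensor& out, const Tensor& grad_out) {
  ir::OpResult out_res = std::static_pointer_cast<StaticTensor>(out.impl())
                             ->getValue()
                             .dyn_cast<ir::OpResult>();
  ir::OpResult grad_out_res =
      std::static_pointer_cast<StaticTensor>(grad_out.impl())
          ->getValue()
          .dyn_cast<ir::OpResult>();
  // Assumed dialect builder; replace with the real one for the op at hand.
  ir::OpResult op_res = paddle::dialect::relu_grad(out_res, grad_out_res);
  return Tensor(
      std::make_shared<primitive::experimental::StaticTensor>(op_res));
}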
@@ -15,7 +15,7 @@
 #include "paddle/fluid/primitive/rule/vjp/vjp.h"
 #include "paddle/fluid/ir/dialect/pd_api.h"
 #include "paddle/fluid/primitive/backend/static_backend.h"
-#include "paddle/fluid/primitive/type/desc_tensor.h"
+#include "paddle/fluid/primitive/type/static_tensor.h"
 #include "paddle/ir/core/operation.h"
 // TODO(wanghao107):
 // op's vjp will be auto generated.
@@ -32,14 +32,14 @@ std::vector<std::vector<paddle::Tensor>> tanh_vjp(
       1, std::vector<paddle::Tensor>(1));
   // get tanh_grad res.
   Tensor op_res =
-      backend::experimental::tanh_grad<primitive::experimental::DescTensor>(
+      backend::experimental::tanh_grad<primitive::experimental::StaticTensor>(
          out, grad_out);
   // set op stop_gradient info
   // TODO(wanghao107): Replace with more generic code.
   // Support set stop_gradients for all ops.
   ir::Operation* grad_op =
-      std::static_pointer_cast<primitive::experimental::DescTensor>(
+      std::static_pointer_cast<primitive::experimental::StaticTensor>(
          op_res.impl())
          ->getValue()
          .dyn_cast<ir::OpResult>()
@@ -77,14 +77,14 @@ std::vector<std::vector<paddle::Tensor>> mean_vjp(
       1, std::vector<paddle::Tensor>(1));
   // get mean_grad res.
   Tensor op_res =
-      backend::experimental::mean_grad<primitive::experimental::DescTensor>(
+      backend::experimental::mean_grad<primitive::experimental::StaticTensor>(
          x, out_grad, axis, keepdim, reduce_all);
   // set op stop_gradient info
   // TODO(wanghao107): Replace with more generic code.
   // Support set stop_gradients for all ops.
   ir::Operation* grad_op =
-      std::static_pointer_cast<primitive::experimental::DescTensor>(
+      std::static_pointer_cast<primitive::experimental::StaticTensor>(
          op_res.impl())
          ->getValue()
          .dyn_cast<ir::OpResult>()
@@ -121,14 +121,14 @@ std::vector<std::vector<paddle::Tensor>> add_vjp(
       2, std::vector<paddle::Tensor>(1));
   // get mean_grad res.
   std::tuple<Tensor, Tensor> op_res =
-      backend::experimental::add_grad<primitive::experimental::DescTensor>(
+      backend::experimental::add_grad<primitive::experimental::StaticTensor>(
          x, y, out_grad, axis);
   // set op stop_gradient info
   // TODO(wanghao107): Replace with more generic code.
   // Support set stop_gradients for all ops.
   ir::Operation* grad_op =
-      std::static_pointer_cast<primitive::experimental::DescTensor>(
+      std::static_pointer_cast<primitive::experimental::StaticTensor>(
          std::get<0>(op_res).impl())
          ->getValue()
          .dyn_cast<ir::OpResult>()
...
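A short usage sketch of the rule above, mirroring how TanhOp::Vjp calls it earlier in this commit; out and grad_out are assumed to already wrap ir values via StaticTensor.

// Sketch: invoking the primitive vjp rule directly for one output whose
// gradient is requested (stop_gradients[0][0] == false).
std::vector<std::vector<bool>> stop_gradients{{false}};
std::vector<std::vector<paddle::Tensor>> grads =
    paddle::primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);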
@@ -24,14 +24,14 @@ namespace paddle {
 namespace primitive {
 namespace experimental {
-class DescTensor : public phi::ExtendedTensor,
-                   public phi::TypeInfoTraits<phi::TensorBase, DescTensor> {
+class StaticTensor : public phi::ExtendedTensor,
+                     public phi::TypeInfoTraits<phi::TensorBase, StaticTensor> {
  public:
-  explicit DescTensor(ir::Value value)
+  explicit StaticTensor(ir::Value value)
       : value_(value),
         dims_(value.type().dyn_cast<dialect::DenseTensorType>().dims()) {}
-  static const char* name() { return "DescTensor"; }
+  static const char* name() { return "StaticTensor"; }
   const phi::DDim& dims() const override { return dims_; }
...
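Reading the renamed class as a whole: StaticTensor is a thin phi::ExtendedTensor wrapper that carries an ir::Value and exposes the dims of its DenseTensorType. A small usage sketch follows, assuming value is an ir::Value whose type dyn_casts to paddle::dialect::DenseTensorType, for example an op result.

// Sketch: wrapping an ir::Value so it can travel through the paddle::Tensor API.
auto holder =
    std::make_shared<paddle::primitive::experimental::StaticTensor>(value);
paddle::Tensor t(holder);                  // Tensor now backed by the ir::Value
const phi::DDim& dims = holder->dims();    // taken from the DenseTensorType
ir::Value recovered = holder->getValue();  // unwrap again for dialect calls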