Unverified · Commit 4bd5b695 · authored by zhangbo9674, committed by GitHub

[IR] Support static build function for op builder (#54197)

* add build

* add build

* refine code

* refine code

* refine code

* refine code

* refine interface

* fix bug

* fix bug

* fix bug

* refine yaml
Parent 4f25604e
This diff is collapsed.
@@ -35,13 +35,13 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter *parameter) {
       std::make_shared<paddle::framework::Variable>();
   phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
   // Init DenseTensor
-  auto dim = parameter->type().dyn_cast<DenseTensorType>().dim();
+  auto dim = parameter->type().dyn_cast<DenseTensorType>().dims();
   phi::DenseTensorMeta meta(
       TransToPhiDataType(
           parameter->type().dyn_cast<DenseTensorType>().dtype()),
-      phi::DDim(dim.data(), dim.size()),
-      TransToPhiDataLayout(
-          parameter->type().dyn_cast<DenseTensorType>().data_layout()),
+      dim,
+      parameter->type().dyn_cast<DenseTensorType>().data_layout(),
       parameter->type().dyn_cast<DenseTensorType>().lod(),
       parameter->type().dyn_cast<DenseTensorType>().offset());
   tensor->set_meta(meta);
@@ -67,17 +67,13 @@ std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
   // Get Meta
   ir::IrContext *ctx = ir::IrContext::Instance();
   ir::Type data_type = TransToIrDataType(tensor->dtype(), ctx);
-  DenseTensorTypeStorage::Dim dims(tensor->dims().size());
-  std::copy(tensor->dims().Get(),
-            tensor->dims().Get() + tensor->dims().size(),
-            dims.data());
-  DenseTensorTypeStorage::DataLayout data_layout =
-      TransToIrDataLayout(tensor->layout());
-  DenseTensorTypeStorage::LoD lod = tensor->lod();
-  size_t offset = tensor->meta().offset;
   void *data = tensor->data();
-  ir::Type dense_tensor_type =
-      DenseTensorType::get(ctx, data_type, dims, data_layout, lod, offset);
+  ir::Type dense_tensor_type = DenseTensorType::get(ctx,
+                                                    data_type,
+                                                    tensor->dims(),
+                                                    tensor->layout(),
+                                                    tensor->lod(),
+                                                    tensor->meta().offset);
   return std::make_unique<ir::Parameter>(
       data,
       tensor->numel() * phi::SizeOf(tensor->dtype()),
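Taken together, the two hunks above remove every intermediate conversion from the variable/parameter round trip. A minimal sketch of that round trip, assuming a ParameterConvertInterface pointer and an initialized tensor; the setup and variable names are illustrative, not from this diff:

#include <memory>

void RoundTrip(paddle::dialect::ParameterConvertInterface *iface) {
  // `var` holds an initialized phi::DenseTensor (allocation and data
  // filling omitted); dims/layout/lod now travel as phi types end to end.
  paddle::framework::Variable var;
  phi::DenseTensor *src = var.GetMutable<phi::DenseTensor>();
  std::unique_ptr<ir::Parameter> param = iface->VariableToParameter(&var);
  std::shared_ptr<paddle::framework::Variable> round_tripped =
      iface->ParameterToVariable(param.get());
  // Meta survives unchanged because both directions share phi::DDim,
  // phi::DataLayout and phi::LoD:
  //   round_tripped->Get<phi::DenseTensor>().dims() == src->dims(), etc.
}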
@@ -116,8 +112,7 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
   DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
   os << "tensor<";
-  auto &dims = tensor_type.dim();
-  for (auto d : dims) {
+  for (auto d : phi::vectorize(tensor_type.dims())) {
     os << d;
     os << "x";
   }
...
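The printer now leans on phi::vectorize instead of a custom Dim alias. A standalone illustration of that helper (the values are made up; assumes the phi headers):

#include <cstdint>
#include <vector>
#include "paddle/phi/core/ddim.h"

std::vector<int64_t> DimsOf() {
  // phi::vectorize flattens a phi::DDim into std::vector<int64_t>, which is
  // what makes the range-for in PrintType iterate the dimensions directly.
  phi::DDim ddim = phi::make_ddim({2, 2});
  return phi::vectorize(ddim);  // {2, 2}
}

For a 2x2 tensor, PrintType would therefore emit something like tensor<2x2x...> followed by the element dtype.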
@@ -19,25 +19,22 @@
 using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>,
                                std::vector<paddle::dialect::OpAttributeInfo>,
-                               std::vector<paddle::dialect::OpOutputInfo>>;
+                               std::vector<paddle::dialect::OpOutputInfo>,
+                               paddle::dialect::OpRunTimeInfo>;

 namespace paddle {
 namespace dialect {
 class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> {
  public:
   struct Concept {
-    explicit Concept(OpInfoTuple (*get_op_info)(ir::Operation *))
+    explicit Concept(OpInfoTuple (*get_op_info)())
         : get_op_info_(get_op_info) {}
-    OpInfoTuple (*get_op_info_)(ir::Operation *);
+    OpInfoTuple (*get_op_info_)();
   };

   template <class ConcreteOp>
   struct Model : public Concept {
-    static OpInfoTuple GetOpInfo(ir::Operation *op) {
-      ConcreteOp concret_op = op->dyn_cast<ConcreteOp>();
-      if (concret_op == nullptr) throw("concret_op is nullptr");
-      return concret_op.GetOpInfo();
-    }
+    static OpInfoTuple GetOpInfo() { return ConcreteOp::GetOpInfo(); }
     Model() : Concept(GetOpInfo) {}
   };

@@ -45,7 +42,7 @@ class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> {
   GetOpInfoInterface(ir::Operation *op, Concept *impl)
       : ir::OpInterfaceBase<GetOpInfoInterface>(op), impl_(impl) {}

-  OpInfoTuple GetOpInfo() { return impl_->get_op_info_(operation()); }
+  OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); }

  private:
  Concept *impl_;
...
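The Model no longer needs an ir::Operation* to recover the concrete op: GetOpInfo is now a static member of the op class, so it can be queried at build time, before any operation instance exists. A sketch of what a concrete op would supply, using the OpRunTimeInfo struct added later in this commit; the op name, infer-meta/kernel functions, and parameter lists are illustrative, not generated code from this diff:

#include <tuple>
#include <vector>

// Hypothetical generated op; only the pieces relevant to GetOpInfoInterface.
class MeanOp : public ir::Op<MeanOp, paddle::dialect::GetOpInfoInterface> {
 public:
  static const char *name() { return "pd.mean"; }
  // Static: callable without an instantiated operation (the "static build").
  static OpInfoTuple GetOpInfo() {
    std::vector<paddle::dialect::OpInputInfo> inputs;        // filled by codegen
    std::vector<paddle::dialect::OpAttributeInfo> attributes;
    std::vector<paddle::dialect::OpOutputInfo> outputs;
    paddle::dialect::OpRunTimeInfo run_time_info(
        "MeanInferMeta", {"x", "axis"}, {"mean"}, {"x", "axis"});
    return std::make_tuple(inputs, attributes, outputs, run_time_info);
  }
};

With this shape, interface dispatch reduces to a plain function-pointer call (no dyn_cast, no failure path), which is exactly what the simplified Model::GetOpInfo above expresses.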
@@ -11,17 +11,6 @@
   - {typename: Tensor, name: out, optional: false, intermediate: false}
   no_need_buffer: null
   data_transform: null
-  infer_meta:
-    func: null
-    param: null
-  kernel:
-    func: null
-    param: null
-    backend: null
-    layout: null
-    data_type: null
-    dispatch: null
-    force_backend: null
   inplace: null
   backward: null
 - name: fetch
@@ -37,16 +26,5 @@
   - {typename: 'Tensor[]', name: out, optional: false, intermediate: false}
   no_need_buffer: null
   data_transform: null
-  infer_meta:
-    func: null
-    param: null
-  kernel:
-    func: null
-    param: null
-    backend: null
-    layout: null
-    data_type: null
-    dispatch: null
-    force_backend: null
   inplace: null
   backward: null
@@ -18,20 +18,13 @@ namespace paddle {
 namespace dialect {
 const ir::Type& DenseTensorType::dtype() const { return storage()->dtype_; }

-const paddle::dialect::DenseTensorTypeStorage::Dim& DenseTensorType::dim()
-    const {
-  return storage()->dims_;
-}
+const phi::DDim& DenseTensorType::dims() const { return storage()->dims_; }

-const paddle::dialect::DenseTensorTypeStorage::DataLayout&
-DenseTensorType::data_layout() const {
+const phi::DataLayout& DenseTensorType::data_layout() const {
   return storage()->layout_;
 }

-const paddle::dialect::DenseTensorTypeStorage::LoD& DenseTensorType::lod()
-    const {
-  return storage()->lod_;
-}
+const phi::LoD& DenseTensorType::lod() const { return storage()->lod_; }

 const size_t& DenseTensorType::offset() const { return storage()->offset_; }
...
@@ -30,12 +30,11 @@ class DenseTensorType : public ir::Type {
   const ir::Type &dtype() const;

-  const paddle::dialect::DenseTensorTypeStorage::Dim &dim() const;
+  const phi::DDim &dims() const;

-  const paddle::dialect::DenseTensorTypeStorage::DataLayout &data_layout()
-      const;
+  const phi::DataLayout &data_layout() const;

-  const paddle::dialect::DenseTensorTypeStorage::LoD &lod() const;
+  const phi::LoD &lod() const;

   const size_t &offset() const;
 };
...
@@ -18,6 +18,7 @@
 #include "paddle/ir/core/type.h"
 #include "paddle/ir/core/utils.h"
+#include "paddle/phi/core/tensor_meta.h"

 namespace std {
 ///
@@ -46,46 +47,20 @@ namespace dialect {
 /// (3)define HashValue method, (4)overload operator==.
 ///
 struct DenseTensorTypeStorage : public ir::TypeStorage {
-  ///
-  /// \brief It is consistent with the DataLayout defined by Phi operator
-  /// library. See the file for details: paddle/phi/common/layout.h.
-  ///
-  enum class DataLayout : unsigned int {
-    UNDEFINED = 0,
-    NHWC,
-    NCHW,
-    NCDHW,
-    NDHWC,
-    ONEDNN,
-    SPARSE_COO,
-    SPARSE_CSR,
-    PSTRING_UNION,
-    NUM_DATA_LAYOUTS,
-    // See Note [ Why we need ALL in basic kernel key member? ]
-    ALL_LAYOUT = UNDEFINED,
-    // Note: Unify phi DataLayout and fluid::framework::DataLayout,
-    // for compatible with fluid DataLayout, here need prefix `k`
-    kNHWC = NHWC,
-    kNCHW = NCHW,
-    kMKLDNN = ONEDNN,  // all layouts supported by ONEDNN internally
-    kNDHWC = NDHWC,
-    kNCDHW = NCDHW,
-  };
-
-  using Dim = std::vector<int64_t>;
+  using DataLayout = phi::DataLayout;
+  using Dim = phi::DDim;
   using LoD = std::vector<std::vector<size_t>>;

   ///
   /// \brief Declare ParamKey according to parameter type.
   ///
-  using ParamKey = std::tuple<ir::Type, Dim, DataLayout, LoD, size_t>;
+  using ParamKey =
+      std::tuple<ir::Type, phi::DDim, phi::DataLayout, phi::LoD, size_t>;

-  DenseTensorTypeStorage(
-      ir::Type dtype, Dim dims, DataLayout layout, LoD lod, size_t offset)
+  DenseTensorTypeStorage(ir::Type dtype,
+                         phi::DDim dims,
+                         phi::DataLayout layout,
+                         phi::LoD lod,
+                         size_t offset)
       : dtype_(dtype),
         dims_(dims),
         layout_(layout),
@@ -114,16 +89,16 @@ struct DenseTensorTypeStorage : public ir::TypeStorage {
         ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
     // hash dims
     hash_value =
-        ir::hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
+        ir::hash_combine(hash_value, std::hash<phi::DDim>()(std::get<1>(key)));
     // hash layout
     hash_value = ir::hash_combine(
         hash_value,
-        std::hash<std::underlying_type<DataLayout>::type>()(
-            static_cast<std::underlying_type<DataLayout>::type>(
+        std::hash<std::underlying_type<phi::DataLayout>::type>()(
+            static_cast<std::underlying_type<phi::DataLayout>::type>(
                 std::get<2>(key))));
     // hash lod
     hash_value =
-        ir::hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
+        ir::hash_combine(hash_value, std::hash<phi::LoD>()(std::get<3>(key)));
     // hash offset
     hash_value =
         ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
@@ -146,9 +121,9 @@ struct DenseTensorTypeStorage : public ir::TypeStorage {
   /// layout, lod, offset.
   ///
   ir::Type dtype_;
-  Dim dims_;
-  DataLayout layout_;
-  LoD lod_;
+  phi::DDim dims_;
+  phi::DataLayout layout_;
+  phi::LoD lod_;
   size_t offset_;
 };
...
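Since ParamKey now hashes the phi types directly, type uniquing keeps working unchanged: two get() calls with equal parameters return the same interned storage. A sketch under stated assumptions (Float32Type is assumed from the ir builtin types; the values are illustrative):

#include <cassert>

void InternedTypes(ir::IrContext *ctx) {
  ir::Type fp32 = ir::Float32Type::get(ctx);  // assumed builtin ir type
  phi::DDim dims = phi::make_ddim({2, 2});
  phi::LoD lod = {{0, 2}};
  auto t1 = paddle::dialect::DenseTensorType::get(
      ctx, fp32, dims, phi::DataLayout::NCHW, lod, /*offset=*/0);
  auto t2 = paddle::dialect::DenseTensorType::get(
      ctx, fp32, dims, phi::DataLayout::NCHW, lod, /*offset=*/0);
  // Equal ParamKeys hash identically (via the HashValue above) and compare
  // equal, so the context hands back the same interned storage object.
  assert(t1 == t2);
}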
@@ -70,67 +70,76 @@ inline ir::Type TransToIrDataType(phi::DataType dtype,
   }
 }

-inline phi::DataLayout TransToPhiDataLayout(
-    DenseTensorTypeStorage::DataLayout data_layout) {
-  switch (data_layout) {
-    case DenseTensorTypeStorage::DataLayout::NHWC:
-      return phi::DataLayout::NHWC;
-    case DenseTensorTypeStorage::DataLayout::NCHW:
-      return phi::DataLayout::NCHW;
-    case DenseTensorTypeStorage::DataLayout::NCDHW:
-      return phi::DataLayout::NCDHW;
-    case DenseTensorTypeStorage::DataLayout::NDHWC:
-      return phi::DataLayout::NDHWC;
-    case DenseTensorTypeStorage::DataLayout::ONEDNN:
-      return phi::DataLayout::ONEDNN;
-    case DenseTensorTypeStorage::DataLayout::SPARSE_COO:
-      return phi::DataLayout::SPARSE_COO;
-    case DenseTensorTypeStorage::DataLayout::SPARSE_CSR:
-      return phi::DataLayout::SPARSE_CSR;
-    case DenseTensorTypeStorage::DataLayout::PSTRING_UNION:
-      return phi::DataLayout::PSTRING_UNION;
-    case DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS:
-      return phi::DataLayout::NUM_DATA_LAYOUTS;
-    case DenseTensorTypeStorage::DataLayout::ALL_LAYOUT:
-      return phi::DataLayout::ALL_LAYOUT;
-    default:
-      PADDLE_THROW(phi::errors::Unimplemented(
-          "Unsupported ir data layout `%s` when casting it into "
-          "phi data type.",
-          static_cast<int>(data_layout)));
-  }
-}
+// inline phi::DataLayout TransToPhiDataLayout(
+//     DenseTensorTypeStorage::DataLayout data_layout) {
+//   switch (data_layout) {
+//     case DenseTensorTypeStorage::DataLayout::NHWC:
+//       return phi::DataLayout::NHWC;
+//     case DenseTensorTypeStorage::DataLayout::NCHW:
+//       return phi::DataLayout::NCHW;
+//     case DenseTensorTypeStorage::DataLayout::NCDHW:
+//       return phi::DataLayout::NCDHW;
+//     case DenseTensorTypeStorage::DataLayout::NDHWC:
+//       return phi::DataLayout::NDHWC;
+//     case DenseTensorTypeStorage::DataLayout::ONEDNN:
+//       return phi::DataLayout::ONEDNN;
+//     case DenseTensorTypeStorage::DataLayout::SPARSE_COO:
+//       return phi::DataLayout::SPARSE_COO;
+//     case DenseTensorTypeStorage::DataLayout::SPARSE_CSR:
+//       return phi::DataLayout::SPARSE_CSR;
+//     case DenseTensorTypeStorage::DataLayout::PSTRING_UNION:
+//       return phi::DataLayout::PSTRING_UNION;
+//     case DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS:
+//       return phi::DataLayout::NUM_DATA_LAYOUTS;
+//     case DenseTensorTypeStorage::DataLayout::ALL_LAYOUT:
+//       return phi::DataLayout::ALL_LAYOUT;
+//     default:
+//       PADDLE_THROW(phi::errors::Unimplemented(
+//           "Unsupported ir data layout `%s` when casting it into "
+//           "phi data type.",
+//           static_cast<int>(data_layout)));
+//   }
+// }

-inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout(
-    phi::DataLayout data_layout) {
-  switch (data_layout) {
-    case phi::DataLayout::NHWC:
-      return DenseTensorTypeStorage::DataLayout::NHWC;
-    case phi::DataLayout::NCHW:
-      return DenseTensorTypeStorage::DataLayout::NCHW;
-    case phi::DataLayout::NCDHW:
-      return DenseTensorTypeStorage::DataLayout::NCDHW;
-    case phi::DataLayout::NDHWC:
-      return DenseTensorTypeStorage::DataLayout::NDHWC;
-    case phi::DataLayout::ONEDNN:
-      return DenseTensorTypeStorage::DataLayout::ONEDNN;
-    case phi::DataLayout::SPARSE_COO:
-      return DenseTensorTypeStorage::DataLayout::SPARSE_COO;
-    case phi::DataLayout::SPARSE_CSR:
-      return DenseTensorTypeStorage::DataLayout::SPARSE_CSR;
-    case phi::DataLayout::PSTRING_UNION:
-      return DenseTensorTypeStorage::DataLayout::PSTRING_UNION;
-    case phi::DataLayout::NUM_DATA_LAYOUTS:
-      return DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS;
-    case phi::DataLayout::ALL_LAYOUT:
-      return DenseTensorTypeStorage::DataLayout::ALL_LAYOUT;
-    default:
-      PADDLE_THROW(phi::errors::Unimplemented(
-          "Unsupported phi data layout `%s` when casting it into "
-          "ir data type.",
-          static_cast<int>(data_layout)));
-  }
-}
+// inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout(
+//     phi::DataLayout data_layout) {
+//   switch (data_layout) {
+//     case phi::DataLayout::NHWC:
+//       return DenseTensorTypeStorage::DataLayout::NHWC;
+//     case phi::DataLayout::NCHW:
+//       return DenseTensorTypeStorage::DataLayout::NCHW;
+//     case phi::DataLayout::NCDHW:
+//       return DenseTensorTypeStorage::DataLayout::NCDHW;
+//     case phi::DataLayout::NDHWC:
+//       return DenseTensorTypeStorage::DataLayout::NDHWC;
+//     case phi::DataLayout::ONEDNN:
+//       return DenseTensorTypeStorage::DataLayout::ONEDNN;
+//     case phi::DataLayout::SPARSE_COO:
+//       return DenseTensorTypeStorage::DataLayout::SPARSE_COO;
+//     case phi::DataLayout::SPARSE_CSR:
+//       return DenseTensorTypeStorage::DataLayout::SPARSE_CSR;
+//     case phi::DataLayout::PSTRING_UNION:
+//       return DenseTensorTypeStorage::DataLayout::PSTRING_UNION;
+//     case phi::DataLayout::NUM_DATA_LAYOUTS:
+//       return DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS;
+//     case phi::DataLayout::ALL_LAYOUT:
+//       return DenseTensorTypeStorage::DataLayout::ALL_LAYOUT;
+//     default:
+//       PADDLE_THROW(phi::errors::Unimplemented(
+//           "Unsupported phi data layout `%s` when casting it into "
+//           "ir data type.",
+//           static_cast<int>(data_layout)));
+//   }
+// }
+
+// inline phi::DenseTensorMeta TransToDenseTensorMeta(
+//     paddle::dialect::DenseTensorType type) {
+//   return phi::DenseTensorMeta(TransToPhiDataType(type.dtype()),
+//                               type.dim(),
+//                               type.data_layout(),
+//                               type.lod(),
+//                               type.offset());
+// }

 struct OpInputInfo {
   std::string name;
@@ -172,5 +181,20 @@ struct OpAttributeInfo {
       : name(name), type_name(type_name), data_type(data_type) {}
 };

+struct OpRunTimeInfo {
+  std::string infer_meta_func;
+  std::vector<std::string> infer_meta_param;
+  std::vector<std::string> kernel_func;
+  std::vector<std::string> kernel_param;
+  OpRunTimeInfo(std::string infer_meta_func,
+                std::vector<std::string> infer_meta_param,
+                std::vector<std::string> kernel_func,
+                std::vector<std::string> kernel_param)
+      : infer_meta_func(infer_meta_func),
+        infer_meta_param(infer_meta_param),
+        kernel_func(kernel_func),
+        kernel_param(kernel_param) {}
+};
+
 }  // namespace dialect
 }  // namespace paddle
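The commented-out TransToDenseTensorMeta above hints at the intended follow-up: with dims, layout, and lod stored as phi types, a DenseTensorType maps onto phi::DenseTensorMeta in a single constructor call, mirroring what the first hunk of this commit already does inline. A sketch of that helper if it were enabled, updated for the renamed dims() accessor (an assumption, not code from this commit):

inline phi::DenseTensorMeta ToDenseTensorMeta(
    paddle::dialect::DenseTensorType type) {
  // dtype still needs one translation (ir::Type -> phi::DataType); the
  // remaining fields are already phi types and pass through untouched.
  return phi::DenseTensorMeta(
      paddle::dialect::TransToPhiDataType(type.dtype()),
      type.dims(),
      type.data_layout(),
      type.lod(),
      type.offset());
}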
@@ -50,7 +50,7 @@ TypeTranslator::TypeTranslator() {
       ir::Type dtype =
           this->operator[](var_desc.GetDataType())(ctx, var_desc);

-      DenseTensorTypeStorage::Dim dim = var_desc.GetShape();
+      DenseTensorTypeStorage::Dim dim = phi::make_ddim(var_desc.GetShape());
       DenseTensorTypeStorage::DataLayout layout =
           DenseTensorTypeStorage::DataLayout::UNDEFINED;
       DenseTensorTypeStorage::LoD lod = {};
...
@@ -25,7 +25,7 @@ namespace ir {
 #define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(concrete_storage, base_type) \
   struct concrete_storage : public ir::AttributeStorage {                \
-    using ParamKey = bool;                                               \
+    using ParamKey = base_type;                                          \
                                                                          \
     explicit concrete_storage(const ParamKey &key) { data_ = key; }      \
...
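The one-word fix matters: ParamKey was hard-wired to bool, so every storage declared through this macro keyed on a bool regardless of the base_type argument. A small illustration of the behavior difference; the instantiation name below is in the spirit of the builtin attribute storages, not taken from this diff:

// Expands to a storage whose ParamKey is now double rather than bool.
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);

void KeyPrecision() {
  // Before the fix: ParamKey was bool, so a key of 0.5 collapsed to true and
  // 0.25 / 0.5 could not be told apart when uniquing attributes.
  // After the fix: the key keeps full double precision.
  DoubleAttributeStorage s(0.5);  // data_ holds 0.5, not true
  (void)s;
}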
@@ -221,7 +221,7 @@
 - backward_op : broadcast_tensors_grad
   forward : broadcast_tensors (Tensor[] input) -> Tensor[](out)
   args : (Tensor[] input, Tensor[] out_grad)
-  output : Tensor[](input_grad)
+  output : Tensor[](input_grad){input.size()}
   infer_meta :
     func : UnchangedMultiInferMeta
     param : [input]
...
@@ -235,7 +235,7 @@
 - backward_op : einsum_grad
   forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape)
   args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation)
-  output : Tensor[](x_grad){x.size()}
+  output : Tensor[](x_grad){x_shape.size()}
   infer_meta :
     func : UnchangedMultiInferMeta
     param : [x_shape]
...
@@ -107,10 +107,9 @@ TEST(program_test, program) {
       a_interface->ParameterToVariable(program.GetParameter("a"));
   const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>();
   EXPECT_EQ(a_tensor.numel(), 4);
-  EXPECT_EQ(a_tensor.dims(), phi::DDim(dims.data(), dims.size()));
+  EXPECT_EQ(a_tensor.dims(), dims);
   EXPECT_EQ(a_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
-  EXPECT_EQ(a_tensor.layout(),
-            paddle::dialect::TransToPhiDataLayout(data_layout));
+  EXPECT_EQ(a_tensor.layout(), data_layout);
   EXPECT_EQ(a_tensor.lod(), lod);
   EXPECT_EQ(a_tensor.offset(), offset);
   for (int64_t i = 0; i < a_tensor.numel(); i++) {
@@ -137,10 +136,9 @@ TEST(program_test, program) {
       b_interface->ParameterToVariable(program.GetParameter("b"));
   const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>();
   EXPECT_EQ(b_tensor.numel(), 4);
-  EXPECT_EQ(b_tensor.dims(), phi::DDim(dims.data(), dims.size()));
+  EXPECT_EQ(b_tensor.dims(), dims);
   EXPECT_EQ(b_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
-  EXPECT_EQ(b_tensor.layout(),
-            paddle::dialect::TransToPhiDataLayout(data_layout));
+  EXPECT_EQ(b_tensor.layout(), data_layout);
   EXPECT_EQ(b_tensor.lod(), lod);
   EXPECT_EQ(b_tensor.offset(), offset);
   for (int64_t i = 0; i < b_tensor.numel(); i++) {
...