未验证 提交 fc66b5d7 编写于 作者: H hong 提交者: GitHub

Support selected rows new ir (#54987)

* refine program translator

* fix warning: not override

* fix bug

* merge new modifications

* modify by reviews

* resolve conflicts

* resolve conflicts

* fix

* fix

* update

* support selected rows

* update

* add selectrows

* fix bug

* add ut

* refine code

* refien code

* update

* update

* support selected rows

* support selected rows

* support dense tensor

* remove useless code

* polish code

* remote standalone executor test

---------
Co-authored-by: Nkangguangli <kangguangli@hotmail.com>
Co-authored-by: Nzhangbo9674 <zhangbo54@baidu.com>
上级 cfa513f7
......@@ -41,6 +41,7 @@ PaddleKernelDialect::PaddleKernelDialect(ir::IrContext *context)
void PaddleKernelDialect::initialize() {
RegisterTypes<paddle::dialect::AllocatedDenseTensorType>();
RegisterTypes<paddle::dialect::AllocatedSelectedRowsType>();
RegisterOps<dialect::PhiKernelOp>();
RegisterAttributes<paddle::dialect::KernelAttribute>();
......
......@@ -41,7 +41,32 @@ const size_t& AllocatedDenseTensorType::offset() const {
return storage()->dense_tensor_type_.offset();
}
const phi::Place& AllocatedSelectedRowsType::place() const {
return storage()->place_;
}
const ir::Type& AllocatedSelectedRowsType::dtype() const {
return storage()->selected_rows_type_.dtype();
}
const phi::DDim& AllocatedSelectedRowsType::dims() const {
return storage()->selected_rows_type_.dims();
}
const phi::DataLayout& AllocatedSelectedRowsType::data_layout() const {
return storage()->selected_rows_type_.data_layout();
}
const phi::LoD& AllocatedSelectedRowsType::lod() const {
return storage()->selected_rows_type_.lod();
}
const size_t& AllocatedSelectedRowsType::offset() const {
return storage()->selected_rows_type_.offset();
}
} // namespace dialect
} // namespace paddle
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AllocatedDenseTensorType)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AllocatedSelectedRowsType)
......@@ -20,9 +20,7 @@
namespace paddle {
namespace dialect {
///
/// \brief Define built-in parametric types.
///
class AllocatedDenseTensorType : public ir::Type {
public:
using Type::Type;
......@@ -64,7 +62,49 @@ class AllocatedDenseTensorType : public ir::Type {
const size_t &offset() const;
};
class AllocatedSelectedRowsType : public ir::Type {
public:
using Type::Type;
DECLARE_TYPE_UTILITY_FUNCTOR(AllocatedSelectedRowsType,
AllocatedSelectedRowsTypeStorage);
static AllocatedSelectedRowsType get(ir::IrContext *ctx,
const phi::Place &place,
dialect::SelectedRowsType type) {
return ir::TypeManager::template get<AllocatedSelectedRowsType>(
ctx, place, type);
}
static AllocatedSelectedRowsType get(ir::IrContext *ctx,
const phi::Place &place,
const ir::Type &dtype,
const phi::DDim &dims,
const phi::DataLayout &layout,
const phi::LoD &lod,
size_t offset) {
dialect::SelectedRowsType dense_tensor_type =
dialect::SelectedRowsType::get(ctx, dtype, dims, layout, lod, offset);
return ir::TypeManager::template get<AllocatedSelectedRowsType>(
ctx, place, dense_tensor_type);
}
const phi::Place &place() const;
const ir::Type &dtype() const;
const phi::DDim &dims() const;
const phi::DataLayout &data_layout() const;
const phi::LoD &lod() const;
const size_t &offset() const;
};
} // namespace dialect
} // namespace paddle
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AllocatedDenseTensorType)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AllocatedSelectedRowsType)
......@@ -88,5 +88,67 @@ struct AllocatedDenseTensorTypeStorage : public ir::TypeStorage {
dialect::DenseTensorType dense_tensor_type_;
};
///
/// \brief Define Parametric TypeStorage for AllocatedSelectedRowsTypeStorage.
///
///
struct AllocatedSelectedRowsTypeStorage : public ir::TypeStorage {
using Place = phi::Place;
///
/// \brief Declare ParamKey according to parameter type.
///
using ParamKey = std::tuple<phi::Place, dialect::SelectedRowsType>;
AllocatedSelectedRowsTypeStorage(const phi::Place& place,
const dialect::SelectedRowsType& type)
: place_(place), selected_rows_type_(type) {}
///
/// \brief Each derived TypeStorage must define a Construct method, which
/// StorageManager uses to construct a derived TypeStorage.
///
static AllocatedSelectedRowsTypeStorage* Construct(const ParamKey& key) {
return new AllocatedSelectedRowsTypeStorage(std::get<0>(key),
std::get<1>(key));
}
///
/// \brief Each derived TypeStorage must provide a HashValue method.
///
static std::size_t HashValue(const ParamKey& key) {
std::size_t hash_value = 791;
// hash place
hash_value = ir::hash_combine(hash_value, std::get<0>(key).HashValue());
// hash dtype
auto selected_rows_type = std::get<1>(key);
hash_value = ir::hash_combine(hash_value,
dialect::DenseTensorTypeStorage::HashValue(
dialect::DenseTensorTypeStorage::ParamKey(
selected_rows_type.dtype(),
selected_rows_type.dims(),
selected_rows_type.data_layout(),
selected_rows_type.lod(),
selected_rows_type.offset())));
return hash_value;
}
///
/// \brief Each derived TypeStorage needs to overload operator==.
///
bool operator==(const ParamKey& key) const {
return ParamKey(place_, selected_rows_type_) == key;
}
ParamKey GetAsKey() const { return ParamKey(place_, selected_rows_type_); }
///
/// \brief AllocatedSelectedRowsTypeStorage include five parameters: place,
/// SelectedRowsType
///
phi::Place place_;
dialect::SelectedRowsType selected_rows_type_;
};
} // namespace dialect
} // namespace paddle
......@@ -111,15 +111,23 @@ void PaddleDialect::initialize() {
}
void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const {
DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
os << "tensor<";
for (auto d : phi::vectorize(tensor_type.dims())) {
os << d;
os << "x";
if (auto tensor_type = type.dyn_cast<DenseTensorType>()) {
os << "tensor<";
for (auto d : phi::vectorize(tensor_type.dims())) {
os << d;
os << "x";
}
tensor_type.dtype().Print(os);
os << ">";
} else if (auto selected_rows_type = type.dyn_cast<SelectedRowsType>()) {
os << "selectedrows<";
for (auto d : phi::vectorize(selected_rows_type.dims())) {
os << d;
os << "x";
}
selected_rows_type.dtype().Print(os);
os << ">";
}
tensor_type.dtype().Print(os);
os << ">";
}
void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const {
......
......@@ -187,10 +187,10 @@
no_need_buffer: null
data_transform: null
infer_meta:
func: UnchangedInferMeta
func: EmbeddingGradSparseInferMeta
param: [weight]
kernel:
func: [embedding_grad_sparse]
func: [embedding_sparse_grad]
param: [x, weight, out_grad, padding_idx, sparse]
backend: null
layout: null
......@@ -198,7 +198,7 @@
ordered: false
candidates: [weight]
to_complex_flag: [false]
dispatch: {embedding_grad_sparse: null}
dispatch: {embedding_sparse_grad: null}
force_backend: null
inplace: null
view: null
......
......@@ -104,6 +104,8 @@ void BuildValue(ir::Value value,
var->GetMutable<phi::DenseTensor>();
} else if (value.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
var->GetMutable<phi::DenseTensor>();
} else if (value.type().isa<paddle::dialect::AllocatedSelectedRowsType>()) {
var->GetMutable<phi::SelectedRows>();
} else if (value.type().isa<ir::VectorType>()) {
auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
for (size_t i = 0; i < value.type().dyn_cast<ir::VectorType>().size();
......
......@@ -258,6 +258,23 @@ void BuildPhiContext(
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
std::vector<int64_t> vec_res;
if (array_list.size() > 0) {
PADDLE_ENFORCE_EQ(
array_list[0].isa<ir::Int64Attribute>(),
true,
phi::errors::PreconditionNotMet(
"Element in array list MUST be ir::Int64Attribute "));
for (size_t i = 0; i < array_list.size(); ++i) {
vec_res.push_back(
array_list[i].dyn_cast<ir::Int64Attribute>().data());
}
}
ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
std::vector<int64_t> vec_res;
if (array_list.size() > 0) {
PADDLE_ENFORCE_EQ(
......@@ -300,13 +317,19 @@ void BuildPhiContext(
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value out_ptr = op->result(i);
auto name = name_map.at(out_ptr);
if (out_ptr.type()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
&(inner_scope->FindVar(name)->Get<phi::DenseTensor>()))));
} else {
auto out_type = out_ptr.type();
if (!out_type) {
phi::DenseTensor* ptr = nullptr;
OutType out_ptr(ptr);
ctx->EmplaceBackOutput(out_ptr);
} else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
&(scope->Var(name)->Get<phi::DenseTensor>()))));
} else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
&(scope->Var(name)->Get<phi::SelectedRows>()))));
} else {
PADDLE_THROW("not support type");
}
if (output_map != nullptr) {
......
......@@ -252,6 +252,13 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
op_output_types.push_back(t1);
} else if (result_type.isa<dialect::SelectedRowsType>()) {
auto allocated_selected_rows_dtype =
paddle::dialect::AllocatedSelectedRowsType::get(
ctx,
phi::TransToPhiPlace(kernel_key.backend()),
result_type.dyn_cast<dialect::SelectedRowsType>());
op_output_types.push_back(allocated_selected_rows_dtype);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Result type only support DenseTensorType and VectorType"));
......@@ -322,6 +329,8 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
}
} else if (new_in_type.isa<ir::VectorType>()) {
// [ todo need update here, support combine data transfomer]
} else if (new_in_type.isa<dialect::AllocatedSelectedRowsType>()) {
// do nothing here
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"only support allocated dense tensor type for now"));
......
......@@ -894,6 +894,41 @@ struct RnnOpTranscriber : public OpTranscriber {
};
};
struct EmbeddingGradOpTranscriber : public OpTranscriber {
void HandleNonexistentAttribute(ir::IrContext* ctx,
ir::AttributeMap* attribute_map,
const OpAttributeInfo& info) override {
if (info.name == "padding_idx") {
(*attribute_map)[info.name] = ir::Int64Attribute::get(ctx, -1);
} else if (info.name == "sparse") {
(*attribute_map)[info.name] = ir::BoolAttribute::get(ctx, false);
}
}
ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
std::string target_op_name =
kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type());
bool is_sparse = paddle::get<bool>(op_desc.GetAttr("is_sparse"));
if (is_sparse) {
target_op_name = "pd.embedding_grad_sparse";
} else {
target_op_name = "pd.embedding_grad_dense";
}
VLOG(6) << "[op name normalizing: " << op_desc.Type() << " to "
<< target_op_name;
auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
IR_THROW("Op %d should have corresponding OpInfo %d",
op_desc.Type(),
target_op_name);
}
return op_info;
}
};
struct FeedOpTranscriber : public OpTranscriber {
ir::AttributeMap TranslateOpAttribute(
ir::IrContext* ctx,
......@@ -960,6 +995,7 @@ OpTranslator::OpTranslator() {
special_handlers["fetch_v2"] = FetchOpTranscriber();
special_handlers["cast"] = CastOpTranscriber();
special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber();
special_handlers["assign_value"] = AssignValueOpTranscriber();
special_handlers["increment"] = IncrementOpTranscriber();
special_handlers["rnn"] = RnnOpTranscriber();
......
......@@ -59,7 +59,6 @@ void ProgramTranslator::Translate() {
platform::errors::PreconditionNotMet(
"Not support multi block ProgramDesc translated, now has %d blocks",
legacy_program_->Size()));
for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) {
const BlockDesc& block = legacy_program_->Block(block_idx);
GetParameterForSingleBlock(block);
......
......@@ -32,7 +32,6 @@ std::unique_ptr<Program> TranslateLegacyProgramToProgram(
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<dialect::PaddleDialect>();
auto program = std::make_unique<Program>(ctx);
translator::ProgramTranslator program_translator(&legacy_program,
program.get());
program_translator.Translate();
......
......@@ -28,6 +28,8 @@ using VarDesc = paddle::framework::VarDesc;
using VarType = paddle::framework::proto::VarType;
using DenseTensorType = paddle::dialect::DenseTensorType;
using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage;
using SelectedRowsType = paddle::dialect::SelectedRowsType;
using SelectedRowsTypeStorage = paddle::dialect::SelectedRowsTypeStorage;
TypeTranslator::TypeTranslator() {
handlers = {
......@@ -105,7 +107,17 @@ TypeTranslator::TypeTranslator() {
VLOG(10) << "[vartype translating]"
<< "[" << var_desc.Name() << "] from SELECTED_ROWS";
return this->operator[](VarType::LOD_TENSOR)(ctx, var_desc);
ir::Type dtype =
this->operator[](var_desc.GetDataType())(ctx, var_desc);
SelectedRowsTypeStorage::Dim dim = phi::make_ddim(var_desc.GetShape());
SelectedRowsTypeStorage::DataLayout layout =
SelectedRowsTypeStorage::DataLayout::UNDEFINED;
SelectedRowsTypeStorage::LoD lod = {};
size_t offset = 0;
ir::Type SelectedRows =
SelectedRowsType::get(ctx, dtype, dim, layout, lod, offset);
return SelectedRows;
}},
};
}
......
......@@ -253,6 +253,16 @@
data_type : weight
backward : embedding_grad
- op : embedding_grad_dense
args : (Tensor x, Tensor weight, Tensor out_grad, int64_t padding_idx=-1, bool sparse=false)
output : Tensor(weight_grad)
infer_meta :
func : UnchangedInferMeta
param : [weight]
kernel :
func : embedding_grad
data_type : weight
- op : empty
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor(out)
......
......@@ -882,6 +882,8 @@
{x : Ids, weight : W}
outputs :
out : Out
attrs :
sparse : is_sparse
manual_signature : [embedding_grad]
extra :
attrs : [bool is_sparse = false, bool is_distributed = false, bool remote_prefetch = false,
......
......@@ -884,6 +884,11 @@ void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) {
out_v->set_dtype(out_dtype);
}
void EmbeddingGradSparseInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(x.dtype());
}
void EighInferMeta(const MetaTensor& x,
const std::string& uplo,
MetaTensor* out_w,
......
......@@ -149,6 +149,8 @@ void DistBroadcastInferMeta(const MetaTensor& x, MetaTensor* out);
void DistReduceInferMeta(const MetaTensor& x, MetaTensor* out);
void EmbeddingGradSparseInferMeta(const MetaTensor& x, MetaTensor* out);
void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v);
void EighInferMeta(const MetaTensor& x,
......
......@@ -227,14 +227,6 @@ if(TARGET standalone_executor_test)
endif()
endif()
if(TARGET standalone_executor_new_ir_test)
if(NOT WIN32)
set_tests_properties(
standalone_executor_new_ir_test
PROPERTIES ENVIRONMENT "FLAGS_enable_new_ir_in_executor=true")
endif()
endif()
if(TARGET layer_test)
add_dependencies(layer_test jit_download_program)
add_dependencies(layer_test_new jit_download_program)
......
if(NOT WIN32)
cc_test(
standalone_executor_new_ir_test
SRCS standalone_executor_new_ir_test.cc
DEPS phi_kernel_adaptor pd_dialect ir)
endif()
# skip win32 since wget is not installed by default on windows machine.
set(OPS
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
#include <gtest/gtest.h>
#include <chrono>
#include <iostream>
#include <string>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/platform/init_phi.h"
DECLARE_FILE_SYMBOLS(kernel_dialect);
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sqrt, CPU, ALL_LAYOUT);
bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; }
namespace paddle {
namespace framework {
TEST(StandaloneExecutor, run) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Builder builder = ir::Builder(ctx, program.block());
paddle::dialect::FullOp op1 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::FullOp op2 = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
builder.Build<paddle::dialect::AddOp>(op1->result(0), op2->result(0));
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
kernel_program->Print(std::cout);
auto place = platform::CPUPlace();
Scope scope;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
test_core.Run({});
auto out_tensor = test_core.local_scope() == nullptr
? scope.FindVar("inner_var_2")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar("inner_var_2")
->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
bool res2 = simple_cmp(out_tensor.data<float>()[2], 2.0);
bool res3 = simple_cmp(out_tensor.data<float>()[3], 2.0);
EXPECT_EQ(res0, true);
EXPECT_EQ(res1, true);
EXPECT_EQ(res2, true);
EXPECT_EQ(res3, true);
}
TEST(StandaloneExecutor, run_2) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder(ctx, program.block());
ir::Block* block = program.block();
// Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape,
// phi::DataType dtype, float min, float max, int seed, phi::Place place)
paddle::dialect::UniformOp uniform1 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform1->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 4u);
// Def: B = paddle::dialect::UniformOp(...)
paddle::dialect::UniformOp uniform2 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform2->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 8u);
// Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_)
paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>(
uniform1->result(0), uniform2->result(0));
EXPECT_EQ(add->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 9u);
paddle::dialect::ScaleOp scale =
builder.Build<paddle::dialect::ScaleOp>(add->result(0), 1.0, 0.0, true);
EXPECT_EQ(scale->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
auto place = platform::CPUPlace();
Scope scope;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
test_core.Run({});
auto out_tensor = test_core.local_scope() == nullptr
? scope.FindVar("inner_var_10")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar("inner_var_10")
->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 1.80721);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 1.70047);
bool res2 = simple_cmp(out_tensor.data<float>()[2], 1.56764);
bool res3 = simple_cmp(out_tensor.data<float>()[3], 1.85063);
std::cerr << out_tensor.data<float>()[0] << "\t"
<< out_tensor.data<float>()[1] << "\t"
<< out_tensor.data<float>()[2] << "\t"
<< out_tensor.data<float>()[3] << std::endl;
EXPECT_EQ(res0, true);
EXPECT_EQ(res1, true);
EXPECT_EQ(res2, true);
EXPECT_EQ(res3, true);
}
#ifdef PADDLE_WITH_CUDA
TEST(StandaloneExecutor, data_transfer) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder(ctx, program.block());
ir::Block* block = program.block();
// Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape,
// phi::DataType dtype, float min, float max, int seed, phi::Place place)
paddle::dialect::UniformOp uniform1 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{1},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform1->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 4u);
// Def: B = paddle::dialect::UniformOp(...)
paddle::dialect::UniformOp uniform2 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{100, 100},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform2->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 8u);
// Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_)
paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>(
uniform1->result(0), uniform2->result(0));
EXPECT_EQ(add->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 9u);
program.Print(std::cout);
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
kernel_program->Print(std::cout);
auto place = platform::CPUPlace();
Scope scope;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
test_core.Run({});
auto out_tensor = test_core.local_scope() == nullptr
? scope.FindVar("inner_var_9")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar("inner_var_9")
->Get<phi::DenseTensor>();
auto& pool = phi::DeviceContextPool::Instance();
phi::DenseTensor out;
phi::DeviceContext* dev_ctx = pool.Get(out_tensor.place());
phi::Copy(*dev_ctx, out_tensor, place, true, &out);
bool res0 = simple_cmp(out.data<float>()[0], 0.903649);
bool res1 = simple_cmp(out.data<float>()[1], 1.07367);
bool res2 = simple_cmp(out.data<float>()[2], 1.10631);
bool res3 = simple_cmp(out.data<float>()[3], 1.68683);
std::cerr << out.data<float>()[0] << "\t" << out.data<float>()[1] << "\t"
<< out.data<float>()[2] << "\t" << out.data<float>()[3]
<< std::endl;
EXPECT_EQ(res0, true);
EXPECT_EQ(res1, true);
EXPECT_EQ(res2, true);
EXPECT_EQ(res3, true);
}
#endif
TEST(StandaloneExecutor, run_inplace_sqrt) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Builder builder = ir::Builder(ctx, program.block());
paddle::dialect::FullOp full = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 4.0, phi::DataType::FLOAT32, phi::CPUPlace());
builder.Build<paddle::dialect::Sqrt_Op>(full->result(0));
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
kernel_program->Print(std::cout);
auto place = platform::CPUPlace();
Scope scope;
InterpreterCore test_core(place, std::move(kernel_program), &scope);
test_core.Run({});
auto out_tensor = test_core.local_scope() == nullptr
? scope.FindVar("inner_var_0")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar("inner_var_0")
->Get<phi::DenseTensor>();
bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
bool res2 = simple_cmp(out_tensor.data<float>()[2], 2.0);
bool res3 = simple_cmp(out_tensor.data<float>()[3], 2.0);
EXPECT_EQ(scope.kids().size(), 1u);
EXPECT_EQ(scope.kids().front()->Size(), 1u);
EXPECT_EQ(res0, true);
EXPECT_EQ(res1, true);
EXPECT_EQ(res2, true);
EXPECT_EQ(res3, true);
}
} // namespace framework
} // namespace paddle
......@@ -89,6 +89,28 @@ class TestFeedOp(unittest.TestCase):
np.testing.assert_array_equal(out[0], gold_res)
class TestSelectedRows(unittest.TestCase):
def test_with_new_ir(self):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
main_program = paddle.static.Program()
new_scope = paddle.static.Scope()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
w = paddle.uniform([10, 10], dtype="float32")
w.stop_gradient = False
id = paddle.ones([2], dtype="int32")
t = paddle.nn.functional.embedding(id, w, sparse=True)
loss = paddle.mean(t)
paddle.static.gradients(loss, w)
out = exe.run(
main_program,
fetch_list=[loss.name],
)
class TestAddGradOp(unittest.TestCase):
def test_with_new_ir(self):
place = paddle.CPUPlace()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册