未验证 提交 fb9bec5d 编写于 作者: H hong 提交者: GitHub

[NewIR]new ir dygraph to static supoort gpu (#55620)

* add kernel dialect

* change DenseTensorTypeStorage to DenseTensorType

* add test case`

* add first pd_op to kernel dialect

* lower pd op to kernel dialect

* update

* update

* remove useless code

* add attrite print test

* fix bug

* update

* update

* update

* update

* polish code

* fix bug

* polish  code  and add python test

* add test

* fix test error

* relax constraint when inserting get_parameter

* add env flag

* fix bug

* dygraph2static support new ir

* fix bug

* revert test env

* change cc_test_old to cc_test

* update

* fix build_static bug

* update test

* fix type test error

* udpate cmake

* disable test in windows

* fix inference compile

* fix program translator error

* only run on cpu, not support gpu yet

* fix conflict

* polish code

* fix bug

* add feed with place op

* update

* remove useless unitest

* udpate mkldnn

* update

* update

* align mkldnn version

* new ir support builtin slice op

* fix bug

* fix phi kernel adaptor bug

* add enable static

* add enable_static

* remove useless test case

* change feed list to single variable

* update

* add feed with place and shaddow output op

* fix bug

* remove usless code

* support gpu

* fix bug

* fix bug

* remove template

* add more data type

* fix cimpile bug

* udpate

* remove useless code

* revert dygraph2st test

* remove usless code

* revert op

* fix bug

* new ir dygraph2static support gpu

* remove usless code

* code polish

* add const

* revert code and remove useless code

* revert code

* revert legacy op yaml

* remove useless code

* delete std::move

---------
Co-authored-by: Nkangguangli <kangguangli@hotmail.com>
上级 05720257
......@@ -19,12 +19,16 @@
#include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/value.h"
PHI_DECLARE_bool(enable_new_ir_in_executor);
namespace details {
using Tensor = paddle::Tensor;
......@@ -367,16 +371,32 @@ inline void RunProgramAPI(
details::ShareTensorsIntoScope(x, global_inner_scope);
details::ShareTensorsIntoScope(params, global_inner_scope);
// Step 2. create new interpretercore
interpreter_core =
paddle::framework::CreateInterpreterCoreInfoToCache(*forward_program,
place,
/*is_grad=*/false,
program_id,
global_inner_scope);
if (FLAGS_enable_new_ir_in_executor) {
// build new ir program
auto ir_program = paddle::framework::ConstructFowardIrProgram(
forward_global_block, backward_global_block, output_names, x);
interpreter_core =
paddle::framework::CreateNewIRInterpreterCoreInfoToCache(
std::move(ir_program),
place,
/*is_grad=*/false,
program_id,
global_inner_scope);
} else {
interpreter_core =
paddle::framework::CreateProgramInterpreterCoreInfoToCache(
*forward_program,
place,
/*is_grad=*/false,
program_id,
global_inner_scope);
}
// Step 3. get all eager gc vars
std::set<std::string> skip_eager_delete_vars =
paddle::framework::details::ParseSafeEagerDeletionSkipVarsSet(
*backward_program);
// all out_vars are skip_eager_var
skip_eager_delete_vars.insert(output_names.begin(), output_names.end());
skip_eager_delete_vars.insert(dout_names.begin(), dout_names.end());
......@@ -504,12 +524,27 @@ inline void RunProgramGradAPI(
1);
VLOG(2) << "No interpretercore cahce, so create a new interpretercore";
details::ShareTensorsIntoScope(out_grad, global_inner_scope);
interpreter_core =
paddle::framework::CreateInterpreterCoreInfoToCache(*backward_program,
place,
/*is_grad=*/true,
program_id,
global_inner_scope);
if (FLAGS_enable_new_ir_in_executor) {
auto res = paddle::framework::ConstructBackwardIrProgram(
backward_global_block, out_grad, x_grad, params_grad);
interpreter_core =
paddle::framework::CreateNewIRInterpreterCoreInfoToCache(
std::move(res),
place,
/*is_grad=*/true,
program_id,
global_inner_scope);
} else {
interpreter_core =
paddle::framework::CreateProgramInterpreterCoreInfoToCache(
*backward_program,
place,
/*is_grad=*/true,
program_id,
global_inner_scope);
}
// share threadpool
// NOTE(zhiqiu): this only works interpreter_core is executed strictly
......
......@@ -1033,7 +1033,8 @@ cc_library(
cc_library(
executor_cache
SRCS executor_cache.cc
DEPS parallel_executor standalone_executor)
DEPS parallel_executor standalone_executor phi_kernel_adaptor
pd_op_to_kernel_pass ir)
if(WITH_PSCORE)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
if(WITH_HETERPS)
......
......@@ -15,6 +15,8 @@
#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/value.h"
......@@ -288,7 +290,7 @@ InterpreterCoreInfoCache &InterpreterCoreInfoCache::Instance() {
return g_info_cache;
}
std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
std::shared_ptr<InterpreterCore> CreateProgramInterpreterCoreInfoToCache(
const ProgramDesc &program_desc,
const platform::Place &place,
bool is_grad,
......@@ -304,13 +306,172 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
execution_config.used_for_jit = true;
auto core = std::make_shared<InterpreterCore>(
place, program_desc.Block(0), scope, execution_config);
std::shared_ptr<InterpreterCore> core = nullptr;
core.reset(new InterpreterCore(
place, program_desc.Block(0), scope, execution_config));
auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, is_grad);
cached_value.core_ = core;
return core;
}
std::shared_ptr<InterpreterCore> CreateNewIRInterpreterCoreInfoToCache(
std::unique_ptr<::ir::Program> ir_program,
const platform::Place &place,
bool is_grad,
int64_t program_id,
framework::Scope *scope) {
auto &interpretercore_info_cache =
framework::InterpreterCoreInfoCache::Instance();
if (interpretercore_info_cache.Size() > 10u /* max_cached_size*/) {
VLOG(2) << "The cached info size has exceeded max_cached_size: 4, clear "
"all cache!";
interpretercore_info_cache.Finalize();
}
interpreter::ExecutionConfig execution_config;
execution_config.create_local_scope = false;
execution_config.used_for_jit = true;
std::shared_ptr<InterpreterCore> core = nullptr;
core.reset(new InterpreterCore(
place, std::move(ir_program), scope, execution_config));
auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, is_grad);
cached_value.core_ = core;
return core;
}
std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
const paddle::framework::BlockDesc *forward_global_block,
const paddle::framework::BlockDesc *backward_global_block,
const std::vector<std::string> output_names,
const std::vector<paddle::Tensor> &x) {
auto ir_ctx = ::ir::IrContext::Instance();
auto program = std::make_unique<::ir::Program>(ir_ctx);
std::set<std::string> set_output_names;
auto local_program =
paddle::framework::ProgramDesc(*(forward_global_block->Program()));
for (auto op_desc : local_program.Block(0).AllOps()) {
for (const auto &n : op_desc->Outputs()) {
const auto &input_var_names = n.second;
for (const auto &var_name : input_var_names) {
set_output_names.insert(var_name);
}
}
}
// add fetch with place op to program
for (auto &in_t : x) {
auto name = in_t.name();
auto place = in_t.place().GetType();
auto op_desc = local_program.MutableBlock(0)->PrependOp();
op_desc->SetType("feed_with_place");
op_desc->SetAttr("index", 0);
// TODO(phlrain) : using tensor dtype
op_desc->SetAttr("dtype", 0);
op_desc->SetAttr("place", static_cast<int>(place));
op_desc->SetAttr("name", name);
op_desc->SetOutput("out", {name});
}
std::set<std::string> set_parameter_names;
for (auto op_desc : backward_global_block->Program()->Block(0).AllOps()) {
for (const auto &n : op_desc->Inputs()) {
const auto &input_var_names = n.second;
for (const auto &var_name : input_var_names) {
set_parameter_names.insert(var_name);
}
}
}
for (auto &t : output_names) {
set_parameter_names.insert(t);
}
for (auto &name : set_parameter_names) {
if (!set_output_names.count(name)) {
continue;
}
auto op_desc = local_program.MutableBlock(0)->AppendOp();
op_desc->SetType("shaddow_output");
op_desc->SetAttr("name", name);
op_desc->SetInput("x", {name});
op_desc->SetOutput("out", {"@EMPTY@"});
}
paddle::translator::ProgramTranslator program_translator(&local_program,
program.get());
program_translator.Translate();
auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program.get());
return ir_res;
}
std::unique_ptr<::ir::Program> ConstructBackwardIrProgram(
const paddle::framework::BlockDesc *backward_global_block,
const std::vector<paddle::Tensor> &out_grad,
const std::vector<paddle::Tensor *> &x_grad,
const std::vector<paddle::Tensor *> &params_grad) {
auto ir_ctx = ::ir::IrContext::Instance();
auto program = std::make_unique<::ir::Program>(ir_ctx);
auto local_program =
paddle::framework::ProgramDesc(*(backward_global_block->Program()));
// add feed kernel
for (auto &out_grad_t : out_grad) {
auto name = out_grad_t.name();
auto place = out_grad_t.place().GetType();
if (name == "@EMPTY@") {
continue;
}
auto op_desc = local_program.MutableBlock(0)->PrependOp();
op_desc->SetType("feed_with_place");
op_desc->SetAttr("index", 0);
// TODO(phlrain) : using tensor dtype
op_desc->SetAttr("dtype", 0);
op_desc->SetAttr("place", static_cast<int>(place));
op_desc->SetAttr("name", name);
op_desc->SetOutput("out", {name});
}
std::vector<std::string> param_grad_names;
for (auto &p_g : params_grad) {
param_grad_names.push_back(p_g->name());
}
for (auto &t : x_grad) {
param_grad_names.push_back(t->name());
}
for (auto &name : param_grad_names) {
if (name == "@EMPTY@") {
continue;
}
auto op_desc = local_program.MutableBlock(0)->AppendOp();
op_desc->SetType("shaddow_output");
op_desc->SetAttr("name", name);
op_desc->SetInput("x", {name});
op_desc->SetOutput("out", {"@EMPTY@"});
}
paddle::translator::ProgramTranslator program_translator(&local_program,
program.get());
program_translator.Translate();
auto res = paddle::dialect::PdOpLowerToKernelPass(program.get());
return res;
}
} // namespace framework
} // namespace paddle
......@@ -29,6 +29,11 @@
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
namespace paddle {
namespace framework {
namespace ir {
......@@ -218,12 +223,31 @@ class InterpreterCoreInfoCache {
std::unordered_map<int64_t, InterpreterCoreInfo> info_map_;
};
std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
std::shared_ptr<InterpreterCore> CreateProgramInterpreterCoreInfoToCache(
const ProgramDesc& program_desc,
const platform::Place& place,
bool is_grad,
int64_t program_id,
framework::Scope* scope);
std::shared_ptr<InterpreterCore> CreateNewIRInterpreterCoreInfoToCache(
std::unique_ptr<::ir::Program> ir_prog,
const platform::Place& place,
bool is_grad,
int64_t program_id,
framework::Scope* scope);
std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
const paddle::framework::BlockDesc* forward_global_block,
const paddle::framework::BlockDesc* backward_global_block,
const std::vector<std::string> output_names,
const std::vector<paddle::Tensor>& x);
std::unique_ptr<::ir::Program> ConstructBackwardIrProgram(
const paddle::framework::BlockDesc* backward_global_block,
const std::vector<paddle::Tensor>& out_grad,
const std::vector<paddle::Tensor*>& x_grad,
const std::vector<paddle::Tensor*>& params_grad);
} // namespace framework
} // namespace paddle
......@@ -958,7 +958,8 @@ void BuildOpFuncList(
if (op_name == "builtin.combine" || op_name == "pd.feed" ||
op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter" || op_name == "builtin.slice") {
op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
op_name == "pd.feed_with_place" || op_name == "pd.shaddow_output") {
VLOG(6) << "skip process " << op_name;
continue;
}
......
......@@ -984,7 +984,7 @@ std::ostream& operator<<(std::ostream& os, const phi::DenseTensor& t) {
do { \
if (paddle::framework::TransToProtoVarType(tensor.dtype()) == \
proto_type) { \
os << " - dtype: " << proto_type << "\n"; \
os << " - dtype: " << tensor.dtype() << "\n"; \
paddle::framework::print_tensor<cpp_type>(os, tensor); \
return os; \
} \
......
......@@ -66,8 +66,10 @@ paddle::framework::Variable* CreateVar(
}
paddle::framework::Variable* var = nullptr;
std::string name = var_name_prefix + "_inner_var_" +
std::to_string(variable_2_var_name->size());
if (force_persisable || is_persisable) {
VLOG(6) << "Create var: " << name << " in scope " << inner_scope->root();
var = const_cast<paddle::framework::Scope*>(inner_scope->root())->Var(name);
......@@ -202,6 +204,15 @@ void HandleForSpecialOp(
value_2_var_name->emplace(value, feed_var_name);
}
if (op_name == "pd.feed_with_place") {
VLOG(6) << "Handle for pd.feed_with_place";
auto var_name =
op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
auto value = op->result(0);
value_2_var_name->emplace(value, var_name);
}
if (op_name == "builtin.combine") {
auto out_value = op->result(0);
......@@ -252,6 +263,22 @@ void HandleForSpecialOp(
(*value_2_var_name)[value] = param_name;
}
if (op_name == "pd.shaddow_output") {
VLOG(6) << "Handle for pd.shaddow_ouptut";
auto var_name =
op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
auto value = op->operand(0);
// change opreand name to param_name
auto orig_name = value_2_var_name->at(value);
if (inner_scope->root()->FindVar(var_name) == nullptr) {
const_cast<paddle::framework::Scope*>(inner_scope->root())
->Rename(orig_name, var_name);
}
(*value_2_var_name)[value] = var_name;
}
if (op_name == "builtin.get_parameter") {
VLOG(6) << "Handle for builtin.get_parameter:";
auto param_name = op->attributes()
......@@ -362,7 +389,8 @@ void BuildScope(const ir::Block& block,
if (op_name == "pd.feed" || op_name == "pd.fetch" ||
op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter" || op_name == "builtin.slice") {
op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
op_name == "pd.feed_with_place" || op_name == "pd.shaddow_output") {
HandleForSpecialOp(op,
inner_scope,
var_name_prefix,
......
......@@ -62,6 +62,20 @@ phi::KernelKey GetKernelKey(
TransToPhiDataType(
op->result(0).type().dyn_cast<DenseTensorType>().dtype())};
}
if (op->name() == "pd.feed_with_place") {
// NOTE, for now feed op don't need a kernel, so the data type from Op
// Result the next op use base program datatype
auto t =
op->attributes().at("place").dyn_cast<dialect::PlaceAttribute>().data();
auto backend = paddle::experimental::ParseBackend(t);
return {backend,
phi::DataLayout::ANY,
TransToPhiDataType(
op->result(0).type().dyn_cast<DenseTensorType>().dtype())};
}
phi::Backend kernel_backend = phi::Backend::UNDEFINED;
phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
phi::DataType kernel_data_type = phi::DataType::UNDEFINED;
......
......@@ -954,6 +954,39 @@ struct FeedOpTranscriber : public OpTranscriber {
}
};
struct FeedWithPlaceOpTranscriber : public OpTranscriber {
ir::AttributeMap TranslateOpAttribute(
ir::IrContext* ctx,
const std::string& normalized_op_name,
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc) override {
int allocate_type = paddle::get<int>(op_desc.GetAttr("place"));
ir::AttributeMap attribute_map = {
{"name",
ir::StrAttribute::get(ctx,
op_desc.GetAttrIfExists<std::string>("name"))},
{"index", ir::Int64Attribute::get(ctx, 0)},
{"dtype",
paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32)},
{"place",
paddle::dialect::PlaceAttribute::get(
ctx, phi::Place(static_cast<phi::AllocationType>(allocate_type)))},
};
return attribute_map;
}
std::vector<ir::OpResult> GenerateOperationInput(
ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
ir::Program* program) override {
return {};
}
};
struct SplitOpTranscriber : public OpTranscriber {
std::vector<ir::OpResult> GenerateOperationInput(
ir::IrContext* ctx,
......@@ -1087,6 +1120,32 @@ struct FetchOpTranscriber : public OpTranscriber {
}
};
struct ShaddowOutputOpTranscriber : public OpTranscriber {
ir::Operation* operator()(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Program* program) override {
std::vector<ir::OpResult> op_inputs;
auto legacy_input_vars = op_desc.Input("x", true);
auto defining_info = (*param_map)[legacy_input_vars[0]];
op_inputs.push_back(defining_info.value);
ir::AttributeMap attribute_map = {
{"parameter_name",
ir::StrAttribute::get(ctx,
op_desc.GetAttrIfExists<std::string>("name"))},
};
auto create_op_info = ctx->GetRegisteredOpInfo(ir::SetParameterOp::name());
ir::Operation* operation =
ir::Operation::Create(op_inputs, attribute_map, {}, create_op_info);
program->block()->push_back(operation);
return operation;
}
};
// NOTE, add_n op in legacy ops don't have a kernel, so we use a new op for now
struct AddNOpTranscriber : public OpTranscriber {
ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
......@@ -1159,6 +1218,7 @@ struct OneHotTranscriber : public OpTranscriber {
OpTranslator::OpTranslator() {
general_handler = OpTranscriber();
special_handlers["feed"] = FeedOpTranscriber();
special_handlers["feed_with_place"] = FeedWithPlaceOpTranscriber();
special_handlers["fetch_v2"] = FetchOpTranscriber();
special_handlers["cast"] = CastOpTranscriber();
special_handlers["split"] = SplitOpTranscriber();
......@@ -1167,8 +1227,10 @@ OpTranslator::OpTranslator() {
special_handlers["assign_value"] = AssignValueOpTranscriber();
special_handlers["increment"] = IncrementOpTranscriber();
special_handlers["rnn"] = RnnOpTranscriber();
special_handlers["shaddow_output"] = ShaddowOutputOpTranscriber();
special_handlers["one_hot_v2"] = OneHotTranscriber();
special_handlers["add_n"] = AddNOpTranscriber();
special_handlers["sum"] = AddNOpTranscriber();
}
} // namespace translator
......
......@@ -217,7 +217,15 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
continue;
}
ir::OpResult value = value_info.value;
if (!value) {
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Value of [%s] can not ber None", var_name));
}
auto* defining_op = value.owner();
PADDLE_ENFORCE_NOT_NULL(
defining_op,
phi::errors::PreconditionNotMet(
"Defining operator of [%s] can not be nullptr", var_name));
VLOG(8) << "[op translated][stop gradient]" << var_name
<< " from: " << defining_op->name();
std::vector<ir::Attribute> stop_gradients;
......
......@@ -1029,6 +1029,9 @@
- op : feed
outputs: {out: Out}
- op : feed_with_place
outputs: {out: out}
- op : fft_c2c
inputs: {x: X}
outputs: {out: Out}
......@@ -2461,6 +2464,10 @@
extra :
attrs : [bool use_mkldnn=false]
- op : shaddow_output
inputs: {x: x}
outputs: {out: out}
- op : shape
inputs :
input : Input
......
......@@ -826,6 +826,18 @@
inplace: (x -> out)
backward : expm1_grad
- op : feed_with_place
args : (int64_t index, DataType dtype, str name, Place place)
output : Tensor(out)
infer_meta :
func : FeedWithPlaceInferMeta
param : [index, dtype]
kernel:
func : feed_with_place
param : [index, dtype]
data_type : dtype
backend : place
- op : fft_c2c
args : (Tensor x, int64_t[] axes, str normalization, bool forward)
output : Tensor
......@@ -2212,6 +2224,16 @@
optional : master_param, master_param_out
inplace : (param -> param_out), (master_param -> master_param_out)
- op : shaddow_output
args : (Tensor x, str name)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel:
func : shaddow_output
param : [x]
- op : shape
args : (Tensor input)
output : Tensor(out)
......
......@@ -244,18 +244,6 @@
param : [num_rows, num_columns, dtype]
data_type : dtype
- op : feed_with_place
args : (int64_t index, DataType dtype, Place place)
output : Tensor(out)
infer_meta :
func : FeedWithPlaceInferMeta
param : [index, dtype]
kernel:
func : feed_with_place
param : [index, dtype]
data_type : dtype
backend : place
- op : floor_divide
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor(out)
......
......@@ -26,6 +26,11 @@ void FeedWithPlaceKernel(const Context& ctx,
phi::DataType data_type,
DenseTensor* out) {}
template <typename T, typename Context>
void ShaddowOutputKernel(const Context& ctx,
const DenseTensor& x,
DenseTensor* out) {}
} // namespace phi
PD_REGISTER_KERNEL(
......@@ -44,3 +49,6 @@ PD_REGISTER_KERNEL(shaddow_feed,
phi::bfloat16,
phi::complex64,
phi::complex128) {}
PD_REGISTER_KERNEL(
shaddow_output, CPU, ALL_LAYOUT, phi::ShaddowOutputKernel, float) {}
......@@ -22,6 +22,12 @@ template <typename T, typename Context>
void FeedWithPlaceKernel(const Context& ctx,
int64_t index,
phi::DataType data_type,
// std::string name,
DenseTensor* out);
template <typename T, typename Context>
void ShaddowOutputKernel(const Context& ctx,
const DenseTensor& x,
DenseTensor* out);
template <typename T, typename Context>
......
......@@ -30,6 +30,7 @@ def feed_with_place():
'index': 0,
'dtype': 0,
'place': 0,
'name': "x",
},
)
return out
......
......@@ -19,11 +19,10 @@ import numpy as np
import paddle
paddle.enable_static()
class TestNewIr(unittest.TestCase):
def test_with_new_ir(self):
paddle.enable_static()
place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
......@@ -48,6 +47,7 @@ class TestNewIr(unittest.TestCase):
class TestCombineOp(unittest.TestCase):
def test_with_new_ir(self):
paddle.enable_static()
place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
......@@ -72,6 +72,7 @@ class TestCombineOp(unittest.TestCase):
class TestFeedOp(unittest.TestCase):
def test_with_new_ir(self):
paddle.enable_static()
place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
......@@ -103,6 +104,7 @@ class TestFeedOp(unittest.TestCase):
class TestSelectedRows(unittest.TestCase):
def test_with_new_ir(self):
paddle.enable_static()
# TODO(phlrain): support selected rows in GPU
# place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
place = paddle.CPUPlace()
......@@ -127,6 +129,7 @@ class TestSelectedRows(unittest.TestCase):
class TestAddGradOp(unittest.TestCase):
def test_with_new_ir(self):
paddle.enable_static()
place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
......@@ -141,11 +144,9 @@ class TestAddGradOp(unittest.TestCase):
x = paddle.static.data("x", [2, 2], dtype="float32")
y = paddle.static.data("y", [2, 2], dtype="float32")
x.stop_gradient = False
z = x * y
paddle.static.gradients(z, x)
np_a = np.random.rand(2, 2).astype("float32")
np_b = np.random.rand(2, 2).astype("float32")
out = exe.run(
......@@ -159,8 +160,63 @@ class TestAddGradOp(unittest.TestCase):
np.testing.assert_array_equal(out[0], gold_res)
class TestNewIrDygraph(unittest.TestCase):
def test_with_new_ir(self):
paddle.disable_static()
# paddle.device.set_device("cpu")
@paddle.jit.to_static
def func(x, y):
return x + y
x = paddle.ones([2, 2], dtype='float32')
y = paddle.ones([2, 2], dtype='float32')
z = func(x, y)
gold_res = np.ones([2, 2], dtype="float32") * 2
self.assertEqual(
np.array_equal(
z.numpy(),
gold_res,
),
True,
)
class TestNewIrBackwardDygraph(unittest.TestCase):
def test_with_new_ir(self):
paddle.disable_static()
build_strategy = paddle.static.BuildStrategy()
build_strategy.enable_inplace = False
@paddle.jit.to_static(build_strategy=build_strategy)
def func(x, y):
return x * y
x = paddle.ones([2, 2], dtype='float32')
y = paddle.ones([2, 2], dtype='float32')
x.stop_gradient = False
y.stop_gradient = False
z = func(x, y)
loss = z.mean()
loss.backward()
gold_res = np.ones([2, 2], dtype="float32")
self.assertEqual(
np.array_equal(
z.numpy(),
gold_res,
),
True,
)
gold_res = np.ones([2, 2], dtype="float32") * 0.25
np.testing.assert_array_equal(x.gradient(), gold_res)
np.testing.assert_array_equal(y.gradient(), gold_res)
class TestSplitOp(unittest.TestCase):
def test_with_new_ir(self):
paddle.enable_static()
place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
......@@ -186,4 +242,5 @@ class TestSplitOp(unittest.TestCase):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册