diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h
index 574adc1f9d4d4733f16dfcd54bddaf7502c31d52..a8e47953f65dd733ce94b59b1b192d73c3c54ab0 100644
--- a/paddle/fluid/eager/to_static/run_program_op_node.h
+++ b/paddle/fluid/eager/to_static/run_program_op_node.h
@@ -19,12 +19,16 @@
 #include "paddle/fluid/eager/tensor_wrapper.h"
 #include "paddle/fluid/framework/new_executor/interpretercore.h"
 #include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
+#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
 #include "paddle/fluid/operators/run_program_op.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/profiler/event_tracing.h"
 #include "paddle/ir/core/program.h"
 #include "paddle/ir/core/value.h"

+PHI_DECLARE_bool(enable_new_ir_in_executor);
+
 namespace details {
 using Tensor = paddle::Tensor;

@@ -367,16 +371,32 @@ inline void RunProgramAPI(
     details::ShareTensorsIntoScope(x, global_inner_scope);
     details::ShareTensorsIntoScope(params, global_inner_scope);
     // Step 2. create new interpretercore
-    interpreter_core =
-        paddle::framework::CreateInterpreterCoreInfoToCache(*forward_program,
-                                                            place,
-                                                            /*is_grad=*/false,
-                                                            program_id,
-                                                            global_inner_scope);
+
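+    // When FLAGS_enable_new_ir_in_executor is on, translate the legacy
+    // ProgramDesc into a new-IR program and hand it to the new executor;
+    // otherwise fall back to the ProgramDesc-based interpreter.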
+    if (FLAGS_enable_new_ir_in_executor) {
+      // build new ir program
+      auto ir_program = paddle::framework::ConstructFowardIrProgram(
+          forward_global_block, backward_global_block, output_names, x);
+      interpreter_core =
+          paddle::framework::CreateNewIRInterpreterCoreInfoToCache(
+              std::move(ir_program),
+              place,
+              /*is_grad=*/false,
+              program_id,
+              global_inner_scope);
+    } else {
+      interpreter_core =
+          paddle::framework::CreateProgramInterpreterCoreInfoToCache(
+              *forward_program,
+              place,
+              /*is_grad=*/false,
+              program_id,
+              global_inner_scope);
+    }
     // Step 3. get all eager gc vars
     std::set<std::string> skip_eager_delete_vars =
         paddle::framework::details::ParseSafeEagerDeletionSkipVarsSet(
             *backward_program);
+
+    // all out_vars are skip_eager_var
     skip_eager_delete_vars.insert(output_names.begin(), output_names.end());
     skip_eager_delete_vars.insert(dout_names.begin(), dout_names.end());
@@ -504,12 +524,27 @@ inline void RunProgramGradAPI(
         1);
     VLOG(2) << "No interpretercore cache, so create a new interpretercore";
     details::ShareTensorsIntoScope(out_grad, global_inner_scope);
-    interpreter_core =
-        paddle::framework::CreateInterpreterCoreInfoToCache(*backward_program,
-                                                            place,
-                                                            /*is_grad=*/true,
-                                                            program_id,
-                                                            global_inner_scope);
+
+    if (FLAGS_enable_new_ir_in_executor) {
+      auto res = paddle::framework::ConstructBackwardIrProgram(
+          backward_global_block, out_grad, x_grad, params_grad);
+
+      interpreter_core =
+          paddle::framework::CreateNewIRInterpreterCoreInfoToCache(
+              std::move(res),
+              place,
+              /*is_grad=*/true,
+              program_id,
+              global_inner_scope);
+    } else {
+      interpreter_core =
+          paddle::framework::CreateProgramInterpreterCoreInfoToCache(
+              *backward_program,
+              place,
+              /*is_grad=*/true,
+              program_id,
+              global_inner_scope);
+    }

     // share threadpool
     // NOTE(zhiqiu): this only works when interpreter_core is executed strictly
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 4137518cf69d42e039e4c737a9d6d49394c98005..41b681afb540055782ebcef5a80167140c097ee2 100755
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -1033,7 +1033,8 @@ cc_library(
 cc_library(
   executor_cache
   SRCS executor_cache.cc
-  DEPS parallel_executor standalone_executor)
+  DEPS parallel_executor standalone_executor phi_kernel_adaptor
+       pd_op_to_kernel_pass ir)
 if(WITH_PSCORE)
   get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
   if(WITH_HETERPS)
diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc
index 9e8f4a25873d1881246aa7673b0828c8f02a5190..506ce36e47242dbc5b9ff9ceebbc510d902cb853 100644
--- a/paddle/fluid/framework/executor_cache.cc
+++ b/paddle/fluid/framework/executor_cache.cc
@@ -15,6 +15,8 @@
 #include "paddle/fluid/framework/executor_cache.h"
 #include "paddle/fluid/framework/new_executor/interpretercore.h"
 #include "paddle/fluid/framework/op_info.h"
+#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
+#include "paddle/fluid/ir_adaptor/translator/translate.h"
 #include "paddle/ir/core/program.h"
 #include "paddle/ir/core/value.h"

@@ -288,7 +290,7 @@ InterpreterCoreInfoCache &InterpreterCoreInfoCache::Instance() {
   return g_info_cache;
 }

-std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
+std::shared_ptr<InterpreterCore> CreateProgramInterpreterCoreInfoToCache(
     const ProgramDesc &program_desc,
     const platform::Place &place,
     bool is_grad,
@@ -304,13 +306,172 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
   interpreter::ExecutionConfig execution_config;
   execution_config.create_local_scope = false;
   execution_config.used_for_jit = true;
-  auto core = std::make_shared<InterpreterCore>(
-      place, program_desc.Block(0), scope, execution_config);
+
+  std::shared_ptr<InterpreterCore> core = nullptr;
+
+  core.reset(new InterpreterCore(
+      place, program_desc.Block(0), scope, execution_config));
+
+  auto &cached_value =
+      interpretercore_info_cache.GetMutable(program_id, is_grad);
+  cached_value.core_ = core;
+  return core;
+}
+
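+// Same caching scheme as the factory above, but the InterpreterCore is built
+// from a new-IR (::ir::Program) module instead of a ProgramDesc.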
+std::shared_ptr<InterpreterCore> CreateNewIRInterpreterCoreInfoToCache(
+    std::unique_ptr<::ir::Program> ir_program,
+    const platform::Place &place,
+    bool is_grad,
+    int64_t program_id,
+    framework::Scope *scope) {
+  auto &interpretercore_info_cache =
+      framework::InterpreterCoreInfoCache::Instance();
+  if (interpretercore_info_cache.Size() > 10u /* max_cached_size */) {
+    VLOG(2) << "The cached info size has exceeded max_cached_size: 10, clear "
+               "all cache!";
+    interpretercore_info_cache.Finalize();
+  }
+  interpreter::ExecutionConfig execution_config;
+  execution_config.create_local_scope = false;
+  execution_config.used_for_jit = true;
+
+  std::shared_ptr<InterpreterCore> core = nullptr;
+
+  core.reset(new InterpreterCore(
+      place, std::move(ir_program), scope, execution_config));
+
   auto &cached_value =
       interpretercore_info_cache.GetMutable(program_id, is_grad);
   cached_value.core_ = core;
   return core;
 }

+std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
+    const paddle::framework::BlockDesc *forward_global_block,
+    const paddle::framework::BlockDesc *backward_global_block,
+    const std::vector<std::string> output_names,
+    const std::vector<paddle::Tensor> &x) {
+  auto ir_ctx = ::ir::IrContext::Instance();
+  auto program = std::make_unique<::ir::Program>(ir_ctx);
+
+  std::set<std::string> set_output_names;
+  auto local_program =
+      paddle::framework::ProgramDesc(*(forward_global_block->Program()));
+
+  for (auto op_desc : local_program.Block(0).AllOps()) {
+    for (const auto &n : op_desc->Outputs()) {
+      const auto &output_var_names = n.second;
+      for (const auto &var_name : output_var_names) {
+        set_output_names.insert(var_name);
+      }
+    }
+  }
+
+  // add a feed_with_place op for every input tensor
+  for (auto &in_t : x) {
+    auto name = in_t.name();
+    auto place = in_t.place().GetType();
+
+    auto op_desc = local_program.MutableBlock(0)->PrependOp();
+    op_desc->SetType("feed_with_place");
+    op_desc->SetAttr("index", 0);
+    // TODO(phlrain) : use the tensor dtype
+    op_desc->SetAttr("dtype", 0);
+    op_desc->SetAttr("place", static_cast<int>(place));
+    op_desc->SetAttr("name", name);
+    op_desc->SetOutput("out", {name});
+  }
+
+  std::set<std::string> set_parameter_names;
+  for (auto op_desc : backward_global_block->Program()->Block(0).AllOps()) {
+    for (const auto &n : op_desc->Inputs()) {
+      const auto &input_var_names = n.second;
+      for (const auto &var_name : input_var_names) {
+        set_parameter_names.insert(var_name);
+      }
+    }
+  }
+
+  for (auto &t : output_names) {
+    set_parameter_names.insert(t);
+  }
+
+  for (auto &name : set_parameter_names) {
+    if (!set_output_names.count(name)) {
+      continue;
+    }
+
+    auto op_desc = local_program.MutableBlock(0)->AppendOp();
+    op_desc->SetType("shaddow_output");
+    op_desc->SetAttr("name", name);
+    op_desc->SetInput("x", {name});
+    op_desc->SetOutput("out", {"@EMPTY@"});
+  }
+
+  paddle::translator::ProgramTranslator program_translator(&local_program,
+                                                           program.get());
+
+  program_translator.Translate();
+
+  auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program.get());
+
+  return ir_res;
+}
+
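+// Build the backward new-IR program: feed the incoming grad tensors through
+// feed_with_place ops, mark every produced grad with shaddow_output so it
+// survives translation, then lower the result to the kernel dialect.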
+std::unique_ptr<::ir::Program> ConstructBackwardIrProgram(
+    const paddle::framework::BlockDesc *backward_global_block,
+    const std::vector<paddle::Tensor> &out_grad,
+    const std::vector<paddle::Tensor *> &x_grad,
+    const std::vector<paddle::Tensor *> &params_grad) {
+  auto ir_ctx = ::ir::IrContext::Instance();
+  auto program = std::make_unique<::ir::Program>(ir_ctx);
+
+  auto local_program =
+      paddle::framework::ProgramDesc(*(backward_global_block->Program()));
+  // add a feed_with_place op for every non-empty grad input
+  for (auto &out_grad_t : out_grad) {
+    auto name = out_grad_t.name();
+    auto place = out_grad_t.place().GetType();
+    if (name == "@EMPTY@") {
+      continue;
+    }
+    auto op_desc = local_program.MutableBlock(0)->PrependOp();
+    op_desc->SetType("feed_with_place");
+    op_desc->SetAttr("index", 0);
+    // TODO(phlrain) : use the tensor dtype
+    op_desc->SetAttr("dtype", 0);
+    op_desc->SetAttr("place", static_cast<int>(place));
+    op_desc->SetAttr("name", name);
+    op_desc->SetOutput("out", {name});
+  }
+
+  std::vector<std::string> param_grad_names;
+  for (auto &p_g : params_grad) {
+    param_grad_names.push_back(p_g->name());
+  }
+
+  for (auto &t : x_grad) {
+    param_grad_names.push_back(t->name());
+  }
+  for (auto &name : param_grad_names) {
+    if (name == "@EMPTY@") {
+      continue;
+    }
+    auto op_desc = local_program.MutableBlock(0)->AppendOp();
+    op_desc->SetType("shaddow_output");
+    op_desc->SetAttr("name", name);
+    op_desc->SetInput("x", {name});
+    op_desc->SetOutput("out", {"@EMPTY@"});
+  }
+
+  paddle::translator::ProgramTranslator program_translator(&local_program,
+                                                           program.get());
+  program_translator.Translate();
+
+  auto res = paddle::dialect::PdOpLowerToKernelPass(program.get());
+
+  return res;
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h
index f4d926d74c146678f9a28f7b27831644d9fdf997..c639b966286cb37be5f6b628ea00168f537d0f64 100644
--- a/paddle/fluid/framework/executor_cache.h
+++ b/paddle/fluid/framework/executor_cache.h
@@ -29,6 +29,11 @@
 #include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/string/string_helper.h"

+#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
+#include "paddle/ir/core/dialect.h"
+#include "paddle/ir/core/ir_context.h"
+#include "paddle/ir/core/program.h"
+
 namespace paddle {
 namespace framework {
 namespace ir {
@@ -218,12 +223,31 @@ class InterpreterCoreInfoCache {
   std::unordered_map<int64_t, InterpreterCoreInfo> info_map_;
 };

-std::shared_ptr<InterpreterCore> CreateInterpreterCoreInfoToCache(
+std::shared_ptr<InterpreterCore> CreateProgramInterpreterCoreInfoToCache(
     const ProgramDesc& program_desc,
     const platform::Place& place,
     bool is_grad,
     int64_t program_id,
     framework::Scope* scope);

+std::shared_ptr<InterpreterCore> CreateNewIRInterpreterCoreInfoToCache(
+    std::unique_ptr<::ir::Program> ir_prog,
+    const platform::Place& place,
+    bool is_grad,
+    int64_t program_id,
+    framework::Scope* scope);
+
+std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
+    const paddle::framework::BlockDesc* forward_global_block,
+    const paddle::framework::BlockDesc* backward_global_block,
+    const std::vector<std::string> output_names,
+    const std::vector<paddle::Tensor>& x);
+
+std::unique_ptr<::ir::Program> ConstructBackwardIrProgram(
+    const paddle::framework::BlockDesc* backward_global_block,
+    const std::vector<paddle::Tensor>& out_grad,
+    const std::vector<paddle::Tensor*>& x_grad,
+    const std::vector<paddle::Tensor*>& params_grad);
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
index 70be3b9dd035a3c398b43c1cb92a3307b2786488..035f4cd4f16d9bb526ca60af58fa9b470a12e16a 100644
--- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
@@ -958,7 +958,8 @@ void BuildOpFuncList(
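+    // pd.feed_with_place and pd.shaddow_output are wired up when the scope is
+    // built (see BuildScope), so they need no op function here either.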
     if (op_name == "builtin.combine" || op_name == "pd.feed" ||
         op_name == "builtin.set_parameter" ||
-        op_name == "builtin.get_parameter" || op_name == "builtin.slice") {
+        op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
+        op_name == "pd.feed_with_place" || op_name == "pd.shaddow_output") {
       VLOG(6) << "skip process " << op_name;
       continue;
     }
diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc
index 6552a14a03fcc815c8cd5aca1af1463ed96a93c0..d4421ed7ab009c7f137e6f2108a3462ca1f8dabc 100644
--- a/paddle/fluid/framework/tensor_util.cc
+++ b/paddle/fluid/framework/tensor_util.cc
@@ -984,7 +984,7 @@ std::ostream& operator<<(std::ostream& os, const phi::DenseTensor& t) {
   do {                                                                  \
     if (paddle::framework::TransToProtoVarType(tensor.dtype()) ==      \
         proto_type) {                                                  \
-      os << " - dtype: " << proto_type << "\n";                        \
+      os << " - dtype: " << tensor.dtype() << "\n";                    \
       paddle::framework::print_tensor<T>(os, tensor);                  \
       return os;                                                       \
     }                                                                  \
   } while (0)
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index 1a880210afbe10c39c663778e9c8398f58deab8f..95702ac672113474eba679ac9b8cd2ababafc397 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -66,8 +66,10 @@ paddle::framework::Variable* CreateVar(
   }

   paddle::framework::Variable* var = nullptr;
+
   std::string name = var_name_prefix + "_inner_var_" +
                      std::to_string(variable_2_var_name->size());
+
   if (force_persisable || is_persisable) {
     VLOG(6) << "Create var: " << name << " in scope " << inner_scope->root();
     var = const_cast<paddle::framework::Scope*>(inner_scope->root())->Var(name);
@@ -202,6 +204,15 @@ void HandleForSpecialOp(
     value_2_var_name->emplace(value, feed_var_name);
   }

+  if (op_name == "pd.feed_with_place") {
+    VLOG(6) << "Handle for pd.feed_with_place";
+    auto var_name =
+        op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
+
+    auto value = op->result(0);
+    value_2_var_name->emplace(value, var_name);
+  }
+
   if (op_name == "builtin.combine") {
     auto out_value = op->result(0);

@@ -252,6 +263,22 @@ void HandleForSpecialOp(
     (*value_2_var_name)[value] = param_name;
   }

+  if (op_name == "pd.shaddow_output") {
+    VLOG(6) << "Handle for pd.shaddow_output";
+    auto var_name =
+        op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
+
+    auto value = op->operand(0);
+    // rename the operand's variable to the attached name
+    auto orig_name = value_2_var_name->at(value);
+
+    if (inner_scope->root()->FindVar(var_name) == nullptr) {
+      const_cast<paddle::framework::Scope*>(inner_scope->root())
+          ->Rename(orig_name, var_name);
+    }
+    (*value_2_var_name)[value] = var_name;
+  }
+
   if (op_name == "builtin.get_parameter") {
     VLOG(6) << "Handle for builtin.get_parameter:";
     auto param_name = op->attributes()
@@ -362,7 +389,8 @@ void BuildScope(const ir::Block& block,
     if (op_name == "pd.feed" || op_name == "pd.fetch" ||
         op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
-        op_name == "builtin.get_parameter" || op_name == "builtin.slice") {
+        op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
+        op_name == "pd.feed_with_place" || op_name == "pd.shaddow_output") {
       HandleForSpecialOp(op,
                          inner_scope,
                          var_name_prefix,
diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
index d55ce6b24f9cf2f9e4c1cee3afdad10fa2cae20f..beb4635bebba498955b70feda258f8d151fecfca 100644
--- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
+++ b/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc
@@ -62,6 +62,20 @@ phi::KernelKey GetKernelKey(
             TransToPhiDataType(
                 op->result(0).type().dyn_cast<dialect::DenseTensorType>().dtype())};
   }
+
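+  // feed_with_place carries its target place as an attribute, so the kernel
+  // backend can be parsed straight from it.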
+  if (op->name() == "pd.feed_with_place") {
+    // NOTE: for now the feed op doesn't need a kernel; the data type is taken
+    // from the op result, so the next op uses the base program's data type.
+    auto t = op->attributes()
+                 .at("place")
+                 .dyn_cast<paddle::dialect::PlaceAttribute>()
+                 .data();
+
+    auto backend = paddle::experimental::ParseBackend(t);
+    return {backend,
+            phi::DataLayout::ANY,
+            TransToPhiDataType(
+                op->result(0).type().dyn_cast<dialect::DenseTensorType>().dtype())};
+  }
+
   phi::Backend kernel_backend = phi::Backend::UNDEFINED;
   phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
   phi::DataType kernel_data_type = phi::DataType::UNDEFINED;
diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 0aab57af7998a465e1367c558f25d6393bb56512..ee2f66692eda8657142dd35a94e9c5e8a21b323a 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -954,6 +954,39 @@ struct FeedOpTranscriber : public OpTranscriber {
   }
 };

+struct FeedWithPlaceOpTranscriber : public OpTranscriber {
+  ir::AttributeMap TranslateOpAttribute(
+      ir::IrContext* ctx,
+      const std::string& normalized_op_name,
+      const OpAttributeInfoList& op_attr_infos,
+      const OpDesc& op_desc) override {
+    int allocate_type = paddle::get<int>(op_desc.GetAttr("place"));
+    ir::AttributeMap attribute_map = {
+        {"name",
+         ir::StrAttribute::get(ctx,
+                               op_desc.GetAttrIfExists<std::string>("name"))},
+        {"index", ir::Int64Attribute::get(ctx, 0)},
+        {"dtype",
+         paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32)},
+        {"place",
+         paddle::dialect::PlaceAttribute::get(
+             ctx, phi::Place(static_cast<phi::AllocationType>(allocate_type)))},
+    };
+
+    return attribute_map;
+  }
+
+  std::vector<ir::OpResult> GenerateOperationInput(
+      ir::IrContext* ctx,
+      TranslationContext* param_map,
+      const OpDesc& op_desc,
+      const std::string& normalized_op_name,
+      const OpInputInfoList& input_infos,
+      ir::Program* program) override {
+    return {};
+  }
+};
+
 struct SplitOpTranscriber : public OpTranscriber {
   std::vector<ir::OpResult> GenerateOperationInput(
       ir::IrContext* ctx,
@@ -1087,6 +1120,32 @@ struct FetchOpTranscriber : public OpTranscriber {
   }
 };

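+// shaddow_output has no dedicated new-IR op yet; translate it to the existing
+// builtin.set_parameter op so the marked variable is kept alive by name.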
+struct ShaddowOutputOpTranscriber : public OpTranscriber {
+  ir::Operation* operator()(ir::IrContext* ctx,
+                            TranslationContext* param_map,
+                            const OpDesc& op_desc,
+                            ir::Program* program) override {
+    std::vector<ir::OpResult> op_inputs;
+    auto legacy_input_vars = op_desc.Input("x", true);
+
+    auto defining_info = (*param_map)[legacy_input_vars[0]];
+    op_inputs.push_back(defining_info.value);
+
+    ir::AttributeMap attribute_map = {
+        {"parameter_name",
+         ir::StrAttribute::get(ctx,
+                               op_desc.GetAttrIfExists<std::string>("name"))},
+    };
+
+    auto create_op_info = ctx->GetRegisteredOpInfo(ir::SetParameterOp::name());
+    ir::Operation* operation =
+        ir::Operation::Create(op_inputs, attribute_map, {}, create_op_info);
+    program->block()->push_back(operation);
+
+    return operation;
+  }
+};
+
 // NOTE, add_n op in legacy ops don't have a kernel, so we use a new op for now
 struct AddNOpTranscriber : public OpTranscriber {
   ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
@@ -1159,6 +1218,7 @@ struct OneHotTranscriber : public OpTranscriber {
 OpTranslator::OpTranslator() {
   general_handler = OpTranscriber();
   special_handlers["feed"] = FeedOpTranscriber();
+  special_handlers["feed_with_place"] = FeedWithPlaceOpTranscriber();
   special_handlers["fetch_v2"] = FetchOpTranscriber();
   special_handlers["cast"] = CastOpTranscriber();
   special_handlers["split"] = SplitOpTranscriber();
@@ -1167,8 +1227,10 @@ OpTranslator::OpTranslator() {
   special_handlers["assign_value"] = AssignValueOpTranscriber();
   special_handlers["increment"] = IncrementOpTranscriber();
   special_handlers["rnn"] = RnnOpTranscriber();
+  special_handlers["shaddow_output"] = ShaddowOutputOpTranscriber();
   special_handlers["one_hot_v2"] = OneHotTranscriber();
   special_handlers["add_n"] = AddNOpTranscriber();
+  special_handlers["sum"] = AddNOpTranscriber();
 }

 }  // namespace translator
diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc
index b162e8198b9937425289a9491340fbe53b43e494..202cfc61dd304a754e9a7cfe9e24aa9c4c887c46 100644
--- a/paddle/fluid/ir_adaptor/translator/program_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc
@@ -217,7 +217,15 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
       continue;
     }
     ir::OpResult value = value_info.value;
+    if (!value) {
+      PADDLE_THROW(phi::errors::PreconditionNotMet(
+          "Value of [%s] can not be None", var_name));
+    }
     auto* defining_op = value.owner();
+    PADDLE_ENFORCE_NOT_NULL(
+        defining_op,
+        phi::errors::PreconditionNotMet(
+            "Defining operator of [%s] can not be nullptr", var_name));
     VLOG(8) << "[op translated][stop gradient]" << var_name
             << " from: " << defining_op->name();
     std::vector<ir::Attribute> stop_gradients;
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index add6520493e1f51c3e461f021e884e7392c44c81..ed7c9d42373968a0eb31d8cd45f1d8594e4f94af 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -1029,6 +1029,9 @@
 - op : feed
   outputs: {out: Out}

+- op : feed_with_place
+  outputs: {out: out}
+
 - op : fft_c2c
   inputs: {x: X}
   outputs: {out: Out}
@@ -2461,6 +2464,10 @@
   extra :
     attrs : [bool use_mkldnn=false]

+- op : shaddow_output
+  inputs: {x: x}
+  outputs: {out: out}
+
 - op : shape
   inputs :
     input : Input
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 661de64990ee6b8c45b4ac8279814cdd59ebc572..8368184b2839da33446272991c74013fcae0cc1d 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -826,6 +826,18 @@
   inplace: (x -> out)
   backward : expm1_grad

+- op : feed_with_place
+  args : (int64_t index, DataType dtype, str name, Place place)
+  output : Tensor(out)
+  infer_meta :
+    func : FeedWithPlaceInferMeta
+    param : [index, dtype]
+  kernel:
+    func : feed_with_place
+    param : [index, dtype]
+    data_type : dtype
+    backend : place
+
 - op : fft_c2c
   args : (Tensor x, int64_t[] axes, str normalization, bool forward)
   output : Tensor
@@ -2212,6 +2224,16 @@
   optional : master_param, master_param_out
   inplace : (param -> param_out), (master_param -> master_param_out)

+- op : shaddow_output
+  args : (Tensor x, str name)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel:
+    func : shaddow_output
+    param : [x]
+
 - op : shape
   args : (Tensor input)
   output : Tensor(out)
diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml
index 5ac156ff5714d2c33587287fbbfe51fb12b4d1ec..216fca178fde78adc3462019b730f1c850568c34 100755
--- a/paddle/phi/api/yaml/static_ops.yaml
+++ b/paddle/phi/api/yaml/static_ops.yaml
@@ -244,18 +244,6 @@
     param : [num_rows, num_columns, dtype]
     data_type : dtype

-- op : feed_with_place
-  args : (int64_t index, DataType dtype, Place place)
-  output : Tensor(out)
-  infer_meta :
-    func : FeedWithPlaceInferMeta
-    param : [index, dtype]
-  kernel:
-    func : feed_with_place
-    param : [index, dtype]
-    data_type : dtype
-    backend : place
-
 - op : floor_divide
   args : (Tensor x, Tensor y, int axis = -1)
   output : Tensor(out)
diff --git a/paddle/phi/kernels/cpu/feed_with_place_kernel.cc b/paddle/phi/kernels/cpu/feed_with_place_kernel.cc
index 342ad6a334cc303e6de3a2f60227bb295bdfff87..5044bceda26bd09624b4ad6cb566c0a1b8a9c2b7 100644
--- a/paddle/phi/kernels/cpu/feed_with_place_kernel.cc
+++ b/paddle/phi/kernels/cpu/feed_with_place_kernel.cc
@@ -26,6 +26,11 @@ void FeedWithPlaceKernel(const Context& ctx,
                          phi::DataType data_type,
                          DenseTensor* out) {}

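+// ShaddowOutputKernel is intentionally a no-op: the variable is renamed in
+// the scope when the executor handles pd.shaddow_output specially.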
+template <typename T, typename Context>
+void ShaddowOutputKernel(const Context& ctx,
+                         const DenseTensor& x,
+                         DenseTensor* out) {}
+
 }  // namespace phi

 PD_REGISTER_KERNEL(
@@ -44,3 +49,6 @@ PD_REGISTER_KERNEL(shaddow_feed,
                    phi::bfloat16,
                    phi::complex64,
                    phi::complex128) {}
+
+PD_REGISTER_KERNEL(
+    shaddow_output, CPU, ALL_LAYOUT, phi::ShaddowOutputKernel, float) {}
diff --git a/paddle/phi/kernels/feed_with_place_kernel.h b/paddle/phi/kernels/feed_with_place_kernel.h
index 4e8e9063c8d9b927fd7b14a634e2756db8f7471b..725ec0c508af1df4e979fc827265eb7187a86f02 100644
--- a/paddle/phi/kernels/feed_with_place_kernel.h
+++ b/paddle/phi/kernels/feed_with_place_kernel.h
@@ -22,6 +22,12 @@ template <typename T, typename Context>
 void FeedWithPlaceKernel(const Context& ctx,
                          int64_t index,
                          phi::DataType data_type,
+                         // std::string name,
+                         DenseTensor* out);
+
+template <typename T, typename Context>
+void ShaddowOutputKernel(const Context& ctx,
+                         const DenseTensor& x,
                          DenseTensor* out);

 template <typename T, typename Context>
diff --git a/test/ir/new_ir/test_feed_with_place.py b/test/ir/new_ir/test_feed_with_place.py
index 5843fe227b1bfe9391885c4e38403b8488b7a383..222a5a86460b8213cb8291be3b5350d71beeb928 100644
--- a/test/ir/new_ir/test_feed_with_place.py
+++ b/test/ir/new_ir/test_feed_with_place.py
@@ -30,6 +30,7 @@ def feed_with_place():
             'index': 0,
             'dtype': 0,
             'place': 0,
+            'name': "x",
         },
     )
     return out
diff --git a/test/ir/new_ir/test_standalone_new_ir.py b/test/ir/new_ir/test_standalone_new_ir.py
index c67370b2e0a2fcf57ff4c590c51d2cdf3acb3dd8..4a00c2960c286bf90e23ffeaed66b77ae7e621f8 100644
--- a/test/ir/new_ir/test_standalone_new_ir.py
+++ b/test/ir/new_ir/test_standalone_new_ir.py
@@ -19,11 +19,10 @@ import numpy as np

 import paddle

-paddle.enable_static()
-

 class TestNewIr(unittest.TestCase):
     def test_with_new_ir(self):
+        paddle.enable_static()
         place = (
             paddle.CUDAPlace(0)
             if paddle.is_compiled_with_cuda()
@@ -48,6 +47,7 @@ class TestNewIr(unittest.TestCase):

 class TestCombineOp(unittest.TestCase):
     def test_with_new_ir(self):
+        paddle.enable_static()
         place = (
             paddle.CUDAPlace(0)
             if paddle.is_compiled_with_cuda()
@@ -72,6 +72,7 @@ class TestCombineOp(unittest.TestCase):

 class TestFeedOp(unittest.TestCase):
     def test_with_new_ir(self):
+        paddle.enable_static()
         place = (
             paddle.CUDAPlace(0)
             if paddle.is_compiled_with_cuda()
@@ -103,6 +104,7 @@ class TestFeedOp(unittest.TestCase):

 class TestSelectedRows(unittest.TestCase):
     def test_with_new_ir(self):
+        paddle.enable_static()
         # TODO(phlrain): support selected rows in GPU
         # place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
         place = paddle.CPUPlace()
@@ -127,6 +129,7 @@ class TestSelectedRows(unittest.TestCase):

 class TestAddGradOp(unittest.TestCase):
     def test_with_new_ir(self):
+        paddle.enable_static()
         place = (
             paddle.CUDAPlace(0)
             if paddle.is_compiled_with_cuda()
@@ -141,11 +144,9 @@ class TestAddGradOp(unittest.TestCase):
                 x = paddle.static.data("x", [2, 2], dtype="float32")
                 y = paddle.static.data("y", [2, 2], dtype="float32")
                 x.stop_gradient = False
-
                 z = x * y
                 paddle.static.gradients(z, x)
-
                 np_a = np.random.rand(2, 2).astype("float32")
                 np_b = np.random.rand(2, 2).astype("float32")
                 out = exe.run(
@@ -159,8 +160,63 @@ class TestAddGradOp(unittest.TestCase):

         np.testing.assert_array_equal(out[0], gold_res)

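+# The dygraph-to-static cases below drive RunProgramAPI end to end; with
+# FLAGS_enable_new_ir_in_executor set, they execute through the new-IR
+# interpreter path added in this change.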
+class TestNewIrDygraph(unittest.TestCase):
+    def test_with_new_ir(self):
+        paddle.disable_static()
+        # paddle.device.set_device("cpu")
+
+        @paddle.jit.to_static
+        def func(x, y):
+            return x + y
+
+        x = paddle.ones([2, 2], dtype='float32')
+        y = paddle.ones([2, 2], dtype='float32')
+        z = func(x, y)
+
+        gold_res = np.ones([2, 2], dtype="float32") * 2
+        self.assertEqual(
+            np.array_equal(
+                z.numpy(),
+                gold_res,
+            ),
+            True,
+        )
+
+
+class TestNewIrBackwardDygraph(unittest.TestCase):
+    def test_with_new_ir(self):
+        paddle.disable_static()
+        build_strategy = paddle.static.BuildStrategy()
+        build_strategy.enable_inplace = False
+
+        @paddle.jit.to_static(build_strategy=build_strategy)
+        def func(x, y):
+            return x * y
+
+        x = paddle.ones([2, 2], dtype='float32')
+        y = paddle.ones([2, 2], dtype='float32')
+        x.stop_gradient = False
+        y.stop_gradient = False
+        z = func(x, y)
+        loss = z.mean()
+        loss.backward()
+        gold_res = np.ones([2, 2], dtype="float32")
+        self.assertEqual(
+            np.array_equal(
+                z.numpy(),
+                gold_res,
+            ),
+            True,
+        )
+
+        # d(mean(x * y))/dx = y / 4 = 0.25 for 2x2 all-ones inputs
+        gold_res = np.ones([2, 2], dtype="float32") * 0.25
+        np.testing.assert_array_equal(x.gradient(), gold_res)
+        np.testing.assert_array_equal(y.gradient(), gold_res)
+
+
 class TestSplitOp(unittest.TestCase):
     def test_with_new_ir(self):
+        paddle.enable_static()
         place = (
             paddle.CUDAPlace(0)
             if paddle.is_compiled_with_cuda()
@@ -186,4 +242,5 @@ class TestSplitOp(unittest.TestCase):

 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()