diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 7e061c460682824625b8c6202fdce0f833a5cc11..ebf0aeb35471c2602bf011bcc400a19f318498e2 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -70,11 +70,6 @@ void ExecutorPrepareContext::PrepareUnusedVars( force_disable_gc = true; } #endif - force_disable_gc_ = force_disable_gc; - if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) { - return; - } - // If gc is enabled and block size > 1 if (prog_.Size() > 1) { operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( @@ -84,6 +79,12 @@ void ExecutorPrepareContext::PrepareUnusedVars( operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( prog_, block_id_, ops_); } + + force_disable_gc_ = force_disable_gc; + if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) { + return; + } + unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars); } @@ -412,9 +413,11 @@ std::vector> Executor::Prepare( return result; } -void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, - bool create_local_scope, bool create_vars, - bool keep_kids) { +void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx, + Scope* scope, int64_t start_op_index, + int64_t end_op_index, + bool create_local_scope, + bool create_vars, bool keep_kids) { platform::RecordBlock b(kProgramId); PADDLE_ENFORCE_NOT_NULL(scope); Scope* local_scope = scope; @@ -446,7 +449,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, #endif } - for (auto& op : ctx->ops_) { + for (int64_t i = start_op_index; i < end_op_index; ++i) { + auto& op = ctx->ops_[i]; op->Run(*local_scope, place_); if (gc) { DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get()); @@ -471,6 +475,15 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } } +void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, + bool create_local_scope, bool create_vars, + bool keep_kids) { + int64_t start_op_index = 0; + int64_t end_op_index = ctx->ops_.size(); + RunPartialPreparedContext(ctx, scope, start_op_index, end_op_index, + create_local_scope, create_vars, keep_kids); +} + void Executor::RunPreparedContext( ExecutorPrepareContext* ctx, Scope* scope, std::map* feed_targets, diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index cc663f220955540411eacff2c2b0704784cf0427..aa70bb2d81e7c0a08ca4d35b41fdc70ca3362de6 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -115,6 +115,12 @@ class Executor { void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id); + void RunPartialPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, + int64_t start_op_index, int64_t end_op_index, + bool create_local_scope = true, + bool create_vars = true, + bool keep_kids = false); + void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, bool create_local_scope = true, bool create_vars = true, bool keep_kids = false); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 8e2dc860a4cfc0fa3f560008d4944dbecee55098..470a1070011da44599ec95c2b6406bf7cdd7e3d2 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -64,6 +64,9 @@ constexpr char kZeroVarSuffix[] = "@ZERO"; /// Variables with this suffix are the new Gradient. 
constexpr char kNewGradSuffix[] = "@NEWGRAD@"; +/// Variables with this suffix are loaded from a pre-trained model. +constexpr char kLoadedVarSuffix[] = "@LOADED"; + /// RuntimeContext is used to relate input/output names of Operator with /// the corresponding variables in name scope. /// If an Op has attribute kEnableCacheRuntimeContext, it means that in a same diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index a991ce689af5b0029612a31c551133b6f7dd4e63..60bc88ca7237c44dc63aa98e0064ab59addd707c 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -200,11 +200,12 @@ void BasicEngine::Execute() { iter != accumulators_.end(), true, platform::errors::NotFound("Cannot find gradient of variable %s", var->Name())); + if (!var->OverridedStopGradient() && iter->second->RefCnt() == 1) { continue; } - var = std::make_shared("Gtmp@"); + var = std::make_shared(var->Name()); need_accu_var_list_.emplace_back(iter->second.get(), var); } } diff --git a/paddle/fluid/operators/run_program_op.cc b/paddle/fluid/operators/run_program_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..04559a93c866c72f2d0b309a5005557134355666 --- /dev/null +++ b/paddle/fluid/operators/run_program_op.cc @@ -0,0 +1,185 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/run_program_op.h" + +#include + +namespace paddle { +namespace operators { + +class RunProgramOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true, + platform::errors::NotFound( + "Input(X) of RunProgramOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInputs("Params"), true, + platform::errors::NotFound( + "Input(Params) of RunProgramOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutputs("Out"), true, + platform::errors::NotFound( + "Output(Out) of RunProgramOp should not be null.")); + } + + protected: + /* [Why use single type kernel]: + * + * This op is similar to a control flow op: it does not need + * an op kernel, but in order to make it executable under dynamic + * graph mode, it is implemented with an op kernel. + * + * Whether the kernel data type is int, float or any other type + * has no effect on its execution logic, so a data type is simply + * specified here. + * + * Of course, the concrete data type chosen here is not important.
+ */ + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + return expected_kernel_type; + } +}; + +class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(vector)" + "The input tensors of RunProgram operator, also the feed targets " + "of loaded program.") + .AsDuplicable(); + AddInput("Params", + "(vector)" + "The input parameter of RunProgram operator, also the parameters " + "of the loaded program.") + .AsDuplicable(); + AddOutput("Out", + "(vector)" + "The output tensors of RunProgram operator, also the fetch " + "targets of the loaded program.") + .AsDuplicable(); + AddOutput("OutScope", + "(StepScopeVar)" + "A vector of execution scope in RunProgram operator, which " + "contains at most one scope." + "NOTE: Do not use Scope directly because Scope output is not " + "currently supported."); + AddAttr("global_block", + "(BlockDesc *)" + "The global block of executed program desc."); + AddAttr("start_op_index", + "(int64_t)" + "The index of the op to start execution"); + AddAttr("end_op_index", + "(int64_t)" + "The index of the op to stop execution"); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training.") + .SetDefault(false); + AddComment(R"DOC( +RunProgram operator. + +The RunProgram operator receives a program's feed targets, fetch targets, +and parameters, and receives the forward and backward program desc +as attributes, and then executes the program by executor. + +NOTE: This operator is added so that the inference model stored by +`fluid.io.save_inference_model` under the static graph mode can be loaded +under the dynamic graph mode for fine-tuning or inferencing. 
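+
+The forward pass only executes the ops in [start_op_index, end_op_index) of the
+given global block; the backward ops appended by fluid.backward.gradients are
+executed afterwards by the corresponding run_program_grad operator.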
+ +)DOC"); + } +}; + +class RunProgramGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true, + platform::errors::NotFound( + "Input(X) of RunProgramGradOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInputs("Params"), true, + platform::errors::NotFound( + "Input(Params) of RunProgramGradOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInputs(framework::GradVarName("Out")), true, + platform::errors::NotFound( + "Input(Out@GRAD) of RunProgramGradOp should not be null.")); + // NOTE: The X@GRAD and Params@GRAD may not exist, + // because they can be set stop_gradient = True + } + + protected: + /* see [Why use single type kernel] */ + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + return expected_kernel_type; + } +}; + +template +class RunProgramGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("run_program_grad"); + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput("Params", this->Input("Params")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetInput("OutScope", this->Output("OutScope")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetOutput(framework::GradVarName("Params"), + this->InputGrad("Params")); + grad_op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(run_program, ops::RunProgramOp, ops::RunProgramOpMaker, + ops::RunProgramGradOpMaker, + ops::RunProgramGradOpMaker); +REGISTER_OPERATOR(run_program_grad, ops::RunProgramGradOp); + +/* see [Why use single type kernel] */ +REGISTER_OP_CPU_KERNEL( + run_program, + ops::RunProgramOpKernel) +REGISTER_OP_CPU_KERNEL( + run_program_grad, + ops::RunProgramGradOpKernel) diff --git a/paddle/fluid/operators/run_program_op.cu.cc b/paddle/fluid/operators/run_program_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..19cd354c18f3a01863d0ea6269d9fed343e78d1d --- /dev/null +++ b/paddle/fluid/operators/run_program_op.cu.cc @@ -0,0 +1,28 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/run_program_op.h" + +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +/* see [Why use single type kernel] */ +REGISTER_OP_CUDA_KERNEL( + run_program, + ops::RunProgramOpKernel); +REGISTER_OP_CUDA_KERNEL( + run_program_grad, + ops::RunProgramGradOpKernel); diff --git a/paddle/fluid/operators/run_program_op.h b/paddle/fluid/operators/run_program_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9e099b2e96e793ff7d1cccc4cb1419d80cb92dee --- /dev/null +++ b/paddle/fluid/operators/run_program_op.h @@ -0,0 +1,318 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/var_type_traits.h" +#include "paddle/fluid/framework/variable.h" + +namespace paddle { +namespace operators { + +using StepScopeVar = std::vector; +using BlockDesc = framework::BlockDesc; + +using Variable = framework::Variable; +using LoDTensor = framework::LoDTensor; +using SelectedRows = framework::SelectedRows; + +namespace details { + +// all input vars should be LoDTensor & is initialized +static void CheckInputVarStatus(const Variable &var, + const std::string &var_name) { + PADDLE_ENFORCE_EQ( + var.IsType(), true, + platform::errors::InvalidArgument( + "The input variable %s of " + "RunProgram(Grad)Op(StaticModelRunner) holds " + "wrong type. Expect type is LoDTensor, but receive type is %s.", + var_name, platform::demangle(framework::ToTypeName(var.Type())))); + PADDLE_ENFORCE_EQ( + var.Get().IsInitialized(), true, + platform::errors::InvalidArgument("The tensor in input variable %s of " + "RunProgram(Grad)Op(StaticModelRunner) " + "is not initialized.", + var_name)); +} + +static void CheckOutputVarStatus(const Variable &src_var, + const Variable &dst_var, + const std::string &var_name) { + if (dst_var.IsType()) { + PADDLE_ENFORCE_EQ( + src_var.IsType(), true, + platform::errors::InvalidArgument( + "The output variable %s get from " + "RunProgram(Grad)Op(StaticModelRunner)'s internal scope holds " + "wrong type. Expect type is LoDTensor, but receive type is %s.", + var_name, + platform::demangle(framework::ToTypeName(src_var.Type())))); + PADDLE_ENFORCE_EQ(src_var.Get().IsInitialized(), true, + platform::errors::InvalidArgument( + "The tensor in output variable %s get from " + "RunProgram(Grad)Op(StaticModelRunner)'s internal " + "scope is not initialized.", + var_name)); + } else if (dst_var.IsType()) { + PADDLE_ENFORCE_EQ( + src_var.IsType(), true, + platform::errors::InvalidArgument( + "The output variable %s get from " + "RunProgram(Grad)Op(StaticModelRunner)'s internal scope holds " + "wrong type. 
Expect type is SelectedRows, but receive type is %s.", + var_name, + platform::demangle(framework::ToTypeName(src_var.Type())))); + PADDLE_ENFORCE_EQ(src_var.Get().value().IsInitialized(), true, + platform::errors::InvalidArgument( + "The tensor in output variable %s get from " + "RunProgram(Grad)Op(StaticModelRunner)'s " + "internal scope is not initialized.", + var_name)); + + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The RunProgram(Grad)Op(StaticModelRunner) only support output " + "variable of type LoDTensor or SelectedRows, " + "but received variable %s's type is %s", + var_name, platform::demangle(framework::ToTypeName(dst_var.Type())))); + } +} + +static void VariableShare(const Variable &src_var, Variable *dst_var) { + // The previous check ensures that the variable type can only be LoDTensor + auto *lod_tensor = dst_var->GetMutable(); + lod_tensor->ShareDataWith(src_var.Get()); + lod_tensor->set_lod(src_var.Get().lod()); +} + +static void ShareVarsIntoScope(const std::vector &vars, + const std::vector &var_names, + framework::Scope *scope) { + for (size_t i = 0; i < vars.size(); ++i) { + auto *var = scope->Var(var_names[i]); + CheckInputVarStatus(*vars[i], var_names[i]); + VariableShare(*vars[i], var); + } +} + +static void VariableCopy(const Variable &src_var, + const platform::Place &dst_place, Variable *dst_var) { + // The previous check ensures that the variable type can only be LoDTensor or + // SelectedRows + if (src_var.IsType()) { + auto *lod_tensor = dst_var->GetMutable(); + TensorCopySync(src_var.Get(), dst_place, lod_tensor); + lod_tensor->set_lod(src_var.Get().lod()); + } else if (src_var.IsType()) { + auto *selected_rows = dst_var->GetMutable(); + TensorCopySync(src_var.Get().value(), dst_place, + selected_rows->mutable_value()); + selected_rows->set_rows(src_var.Get().rows()); + selected_rows->set_height(src_var.Get().height()); + } +} + +static void ShareVarsFromScope(const std::vector &vars, + const std::vector &var_names, + framework::Scope *scope) { + for (size_t i = 0; i < vars.size(); ++i) { + auto *var = scope->FindVar(var_names[i]); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("The output variable %s is not in " + "RunProgram(Grad)Op(StaticModelRunner)'" + "s internal scope.", + var_names[i])); + CheckOutputVarStatus(*var, *vars[i], var_names[i]); + VariableShare(*var, vars[i]); + } +} + +static void CopyVarsFromScope(const std::vector &vars, + const std::vector &var_names, + const platform::Place &dst_place, + framework::Scope *scope) { + for (size_t i = 0; i < vars.size(); ++i) { + if (var_names[i] == framework::kEmptyVarName) { + VLOG(2) << "find variable name is " << framework::kEmptyVarName + << ", skip it!"; + continue; + } + auto *var = scope->FindVar(var_names[i]); + // NOTE: Here skip not found var is dangerous, if a bug is caused here, + // the result is grad calculation error, which will be very hidden! 
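+    // A missing variable is therefore treated as an error here rather than
+    // being skipped silently.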
+ PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("The output variable %s is not in " + "RunProgram(Grad)Op(StaticModelRunner)'" + "s internal scope.", + var_names[i])); + CheckOutputVarStatus(*var, *vars[i], var_names[i]); + VariableCopy(*var, dst_place, vars[i]); + } +} + +static void AppendSkipDeletionVars( + std::vector *all_vars, + const std::vector &append_vars) { + for (auto &var : append_vars) { + all_vars->emplace_back(var); + } +} + +} // namespace details + +template +class RunProgramOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + VLOG(2) << "RunProgramOpKernel Compute"; + // Step 1. prepare inputs, outputs, attrs + auto &input_vars = ctx.MultiInputVar("X"); + auto ¶m_vars = ctx.MultiInputVar("Params"); + auto output_vars = ctx.MultiOutputVar("Out"); + + auto input_var_names = ctx.InputNames("X"); + auto param_names = ctx.InputNames("Params"); + auto output_var_names = ctx.OutputNames("Out"); + + auto *block = ctx.Attr("global_block"); + auto *program = block->Program(); + auto start_op_index = ctx.Attr("start_op_index"); + auto end_op_index = ctx.Attr("end_op_index"); + auto is_test = ctx.Attr("is_test"); + + // NOTE(chenweihang): In order not to add new variable type, use vector + // here. Originally, here can use scope directly. + auto *out_scope_vec = ctx.Output("OutScope"); + PADDLE_ENFORCE_EQ( + out_scope_vec->size(), 1, + platform::errors::InvalidArgument( + "The OutScope of RunProgramGradOp should only hold one scope.")); + + // Step 2. prepare executor and init persistable variables + framework::Executor exe(ctx.GetPlace()); + + // skip delete vars + std::vector skip_vars; + details::AppendSkipDeletionVars(&skip_vars, output_var_names); + VLOG(2) << "Prepare to skip " << skip_vars.size() + << " var(s): " << string::join_strings(skip_vars, ' '); + + auto exe_ctx = exe.Prepare(*program, 0, skip_vars); + + framework::Scope &scope = *(out_scope_vec->front()); + // share input_vars & parameters into scope + details::ShareVarsIntoScope(input_vars, input_var_names, &scope); + details::ShareVarsIntoScope(param_vars, param_names, &scope); + + // Step 3. run ops + exe.RunPartialPreparedContext(exe_ctx.get(), &scope, start_op_index, + end_op_index, /*create_local_scope=*/false, + /*create_vars=*/true, /*keep_kids=*/!is_test); + + // Step 4. Get Output + details::ShareVarsFromScope(output_vars, output_var_names, &scope); + + // Debug info: scope info when run end + VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front()); + } +}; + +template +class RunProgramGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + VLOG(2) << "RunProgramGradOpKernel Compute"; + // Step 1. 
prepare inputs and outputs + auto &output_grad_vars = ctx.MultiInputVar(framework::GradVarName("Out")); + auto input_grad_vars = ctx.MultiOutputVar(framework::GradVarName("X")); + auto param_grad_vars = ctx.MultiOutputVar(framework::GradVarName("Params")); + + // if all output vars are set to stop_gradient, grad op no need to executed + if (input_grad_vars.empty() && param_grad_vars.empty()) return; + + auto output_grad_var_names = ctx.InputNames(framework::GradVarName("Out")); + // NOTE: after PR22939 [Add double grad] merged, the grad op maker's + // SetOutput will set to None if the input var stop_gradient=True, + // it will cause an NotFound error when ctx.OutputNames() is called + std::vector input_grad_var_names; + std::vector param_grad_names; + if (!input_grad_vars.empty()) { + input_grad_var_names = ctx.OutputNames(framework::GradVarName("X")); + } + if (!param_grad_vars.empty()) { + param_grad_names = ctx.OutputNames(framework::GradVarName("Params")); + } + + auto *block = ctx.Attr("global_block"); + auto *program = block->Program(); + + auto orig_end_op_index = ctx.Attr("end_op_index"); + // NOTE: skip `shape` and `fill_constant` op created by + // fluid.backward.gradients, + // one forward output will generate one `shape` and `fill_constant` + int64_t start_op_index = orig_end_op_index + (output_grad_vars.size() * 2); + int64_t end_op_index = block->OpSize(); + + auto *out_scope_vec = ctx.Input("OutScope"); + PADDLE_ENFORCE_EQ( + out_scope_vec->size(), 1, + platform::errors::InvalidArgument( + "The OutScope of RunProgramGradOp should only hold one scope.")); + + // Step 2. prepare executor and scope + framework::Executor exe(ctx.GetPlace()); + + // skip delete vars + std::vector skip_vars; + details::AppendSkipDeletionVars(&skip_vars, input_grad_var_names); + details::AppendSkipDeletionVars(&skip_vars, param_grad_names); + VLOG(2) << "Prepare to skip " << skip_vars.size() + << " var(s): " << string::join_strings(skip_vars, ' '); + + auto exe_ctx = exe.Prepare(*program, 0, skip_vars); + + auto &scope = *(out_scope_vec->front()); + details::ShareVarsIntoScope(output_grad_vars, output_grad_var_names, + &scope); + + // Debug info: scope info when run end + VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front()); + + // Step 3. run ops + exe.RunPartialPreparedContext(exe_ctx.get(), &scope, start_op_index, + end_op_index, /*create_local_scope=*/false, + /*create_vars=*/true, /*keep_kids=*/false); + + // Step 4. 
copy outputs + details::CopyVarsFromScope(input_grad_vars, input_grad_var_names, + ctx.GetPlace(), &scope); + details::CopyVarsFromScope(param_grad_vars, param_grad_names, + ctx.GetPlace(), &scope); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 6b466c2639eab5b05ea4b717a6857c7b88c74416..d786031111b18e9b81fa654557a5ac2c9fc6af3d 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -621,6 +621,10 @@ void BindImperative(py::module *m_ptr) { return self.MutableGradVar()->Get(); }, py::return_value_policy::reference) + .def("_set_grad_type", + [](imperative::VarBase &self, framework::proto::VarType::Type type) { + self.MutableGradVarBase()->SetType(type); + }) .def("_grad_ivar", [](const imperative::VarBase &self) { auto &grad_var = self.GradVarBase(); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 62866a492cb61df9bd11df474b206e75ef167426..0d165dc07f92ac56b8d5a53ff75caffbfb7ad61b 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -989,7 +989,11 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_ENFORCE_EQ(self.IsType(), true); return self.GetMutable(); }, - py::return_value_policy::reference); + py::return_value_policy::reference) + .def("set_scope", [](Variable &self, Scope &scope) { + auto scope_vec = self.GetMutable>(); + scope_vec->emplace_back(&scope); + }); BindReader(&m); @@ -1180,6 +1184,8 @@ All parameter, weight, gradient are variables in Paddle. []() { return std::string(framework::kEmptyVarName); }); m.def("grad_var_suffix", []() { return std::string(framework::kGradVarSuffix); }); + m.def("loaded_var_suffix", + []() { return std::string(framework::kLoadedVarSuffix); }); m.def_submodule( "var_names", "The module will return special predefined variable name in Paddle") diff --git a/python/paddle/fluid/dygraph/__init__.py b/python/paddle/fluid/dygraph/__init__.py index 7ce43d900a7fe7302a5dfc72a7565fde933c1f3b..cd539cce90b8ba0413e7cb53846a5c650ec671cc 100644 --- a/python/paddle/fluid/dygraph/__init__.py +++ b/python/paddle/fluid/dygraph/__init__.py @@ -44,6 +44,9 @@ from .backward_strategy import * from . import jit from .jit import * +from . import static_runner +from .static_runner import StaticModelRunner + __all__ = [] __all__ += layers.__all__ __all__ += base.__all__ diff --git a/python/paddle/fluid/dygraph/static_runner.py b/python/paddle/fluid/dygraph/static_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..1f9b8d66fd76394ddd8becd4346a7cf4137dded1 --- /dev/null +++ b/python/paddle/fluid/dygraph/static_runner.py @@ -0,0 +1,538 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import logging +import numpy as np +import os +import six + +from . import layers +from .. import core +from .. import framework +from .. 
import backward + +from .base import switch_to_static_graph +from ... import compat as cpt + +# Set Log level +logging.getLogger().setLevel(logging.WARNING) + +# DESIGN IDEA: Add an special operator, execute static program inside operator. +# +# Op's Inputs: +# - the input variable of the user feed +# - the necessary parameters of the network +# Op's Outputs: +# - the output variable of fetch +# +# This op receives a complete program desc, internally creates scope +# and executor, executes this program. Key points: +# +# 1. Data Sharing: +# The varBase of the dynamic graph is not in the scope, so before the op +# executes the program internally, create persistent variables with the +# same name as feed, parameters, and fetch in the scope, and share the +# LoDTensor of the op input. +# +# 2. Forward and Backward Separation: +# Because the dynamic graph op performs the forward and backward separately, +# the forward program is used as the execution object of the forward op, +# and the reverse program is used as the execution object of the grad op. + + +class StaticModelRunner(layers.Layer): + """ + A Dynamic graph Layer for loading inference program and related parameters, + and then performing fine-tune training or inference. + + The loaded program and parameters are saved by `fluid.io.save_inference_model`. + + .. note:: + **1. Dynamic graph mode do not support LoDTensor. + All original static graph model's feed targets or parametars + that depend on LoD are temporarily unavailable.** + **2. All saved inference model's feed targets need be given.** + **3. The ``stop_gradient`` information is lost and can not be recovered.** + **4. The parameter's ``trainable`` information is lost and can not be recovered.** + **5. Double gradient model is not supported now.** + **6. Now only supports loading models saved by `fluid.io.save_inference_model`.** + + Args: + model_dir(str): The directory path where the model is saved. + model_filename(str, optional): The file name of saved inference program. + If set to None, a default filename is + :code:`__model__`. + The default value is None. + params_filename(str, optional): The file name of saved all related parameters. + If set to None, parameters are saved + in separate files. + The default value is None. + + Returns: + Layer: A Layer object can run loaded program. + + Examples: + .. 
code-block:: python + + import numpy as np + import paddle.fluid as fluid + + BATCH_SIZE = 32 + BATCH_NUM = 20 + SAVE_DIRNAME = "fc.inference.model" + + def random_batch_reader(): + def _get_random_images_and_labels(image_shape, label_shape): + image = np.random.random(size=image_shape).astype('float32') + label = np.random.random(size=label_shape).astype('int64') + return image, label + + def __reader__(): + for _ in range(BATCH_NUM): + batch_image, batch_label = _get_random_images_and_labels( + [BATCH_SIZE, 784], [BATCH_SIZE, 1]) + yield batch_image, batch_label + + return __reader__ + + def train_and_save_static_model(place): + img = fluid.data(name='img', shape=[None, 784], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + + pred = fluid.layers.fc(input=img, size=10, act='softmax') + + loss = fluid.layers.cross_entropy(input=pred, label=label) + avg_loss = fluid.layers.mean(loss) + + optimizer = fluid.optimizer.SGD(learning_rate=0.001) + optimizer.minimize(avg_loss) + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + loader = fluid.io.DataLoader.from_generator( + feed_list=[img, label], capacity=5, iterable=True) + loader.set_batch_generator(random_batch_reader(), places=place) + + for data in loader(): + exe.run( + fluid.default_main_program(), + feed=data, + fetch_list=[avg_loss]) + + # save model by fluid.io.save_inference_model + fluid.io.save_inference_model( + SAVE_DIRNAME, ["img"], [pred], exe) + + + # Step 1. train and save inference model in static graph mode + place = fluid.CPUPlace() + train_and_save_static_model(place) + + # Step 2. load inference model in dygraph and fine-tune + with fluid.dygraph.guard(place): + fc = fluid.dygraph.static_runner.StaticModelRunner(SAVE_DIRNAME) + + sgd = fluid.optimizer.SGD(learning_rate=0.001, + parameter_list=fc.parameters()) + + train_loader = fluid.io.DataLoader.from_generator(capacity=5) + train_loader.set_batch_generator( + random_batch_reader(), places=place) + + for data in train_loader(): + img = data[0] + label = data[1] + label.stop_gradient = True + + cost = fc(inputs=img) + + loss = fluid.layers.cross_entropy(cost, label) + avg_loss = fluid.layers.mean(loss) + + avg_loss.backward() + sgd.minimize(avg_loss) + """ + + def __init__(self, model_dir, model_filename=None, params_filename=None): + super(StaticModelRunner, self).__init__() + + # Step 0. key variable definitions + self._load_program_desc = None + self._program_desc = None + self._inner_scope = core.Scope() + # the layer outputs var desc + self._output_descs = [] + # input, output, params name list + self._input_names = [] + self._output_names = [] + self._param_names = [] + # train or eval flag + self._is_test = False + + # Step 1. load program desc from disk + # the saved model hold feed, fetch & scale op, no need, can be remove + self._load_program_desc = self._load_static_model(model_dir, + model_filename) + + # Step 2. set all `is_test` attributes to False + self._change_is_test_status(False) + + # Step 3. load all parameters + self._load_persisitable_dict(model_dir, params_filename) + + # Step 4. generate backwar program desc + self._program_desc = self._append_backward_desc() + + # Step 5. 
recheck parameters stop gradients + self._recheck_stop_gradients() + + def train(self): + # TODO: remove global train_mode setting + framework._dygraph_tracer().train_mode() + self._is_test = False + self._change_is_test_status(False) + + def eval(self): + # TODO: remove global train_mode setting + framework._dygraph_tracer().eval_mode() + self._is_test = True + self._change_is_test_status(True) + + def forward(self, inputs): + """ + Executed forward part of StaticModelRunner Layer. + Generally execute directly using the Layer object. + + Args: + inputs(np.ndarray|Variable|list[np.ndarray|Variable]): the inputs of StaticModelRunner + + Returns: + Variable|list[Variable]: The forward outputs of StaticModelRunner Layer. + """ + # Step 1. prepare inputs, outputs, attrs + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + + input_vars = [] + for i, value in enumerate(inputs): + if not isinstance(value, (np.ndarray, core.VarBase)): + raise TypeError( + "The type of inputs.value in StaticModelRunner.forward must be numpy array or Variable(VarBase), but received %s." + % type(value)) + # NOTE: In order to unify the API, firstly convert the input to VarBase + if isinstance(value, np.ndarray): + var = core.VarBase( + value=value, + name=self._input_names[i], + persistable=False, + place=framework._current_expected_place(), + zero_copy=True) + else: + var = value + # TODO: here may have important name set by user + var.name = self._input_names[i] + input_vars.append(var) + + params = [] + for param in self._parameters.values(): + params.append(param) + + output_vars = [] + for var_desc in self._output_descs: + var = core.VarBase(var_desc.dtype(), + var_desc.shape(), + var_desc.name(), var_desc.type(), False) + output_vars.append(var) + + # hold forward variables + tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [], + "program_out_scope", + core.VarDesc.VarType.STEP_SCOPES, True) + tmp_scope_vec.value().set_scope(self._inner_scope) + + # Step 2. run prorgam by op + framework._dygraph_tracer().trace_op( + type='run_program', + inputs={'X': input_vars, + 'Params': params}, + outputs={'Out': output_vars, + 'OutScope': tmp_scope_vec}, + attrs={ + 'global_block': self._program_desc.block(0), + 'start_op_index': 0, + 'end_op_index': self._load_program_desc.block(0).op_size(), + 'is_test': self._is_test + }) + + # NOTE: [ why need set param's gradient type here ] + # if user set sparse gradient mode, the param's gradient + # will be SelectedRows, not LoDTensor. But tracer will just + # set param grad VarBase by forward VarBase(LoDTensor) + # If we don't change grad_var type here, RunProgramOp need + # transform SelectedRows to LoDTensor forcely, it may not + # be user wanted result. + for param in params: + grad_name = param.name + core.grad_var_suffix() + grad_var = self._program_desc.block(0).find_var( + cpt.to_bytes(grad_name)) + # NOTE: cannot find var desc maybe no problem, such as in batch_norm + if grad_var is None: + continue + param._set_grad_type(grad_var.type()) + + # Step 3. prepare output, keep same form with inputs + outs = output_vars + if len(output_vars) == 1: + outs = output_vars[0] + return outs + + def _load_static_model(self, model_dir, model_filename=None): + # Step 1. 
dir and filename check + load_dirname = os.path.normpath(model_dir) + if not os.path.isdir(load_dirname): + raise ValueError("There is no directory named '%s'" % load_dirname) + + if model_filename is not None: + model_filename = os.path.basename(model_filename) + else: + model_filename = "__model__" + model_filename = os.path.join(load_dirname, model_filename) + + # Step 2. parse program desc + with open(model_filename, "rb") as f: + program_desc_str = f.read() + + program_desc = core.ProgramDesc(program_desc_str) + if not core._is_program_version_supported(program_desc._version()): + raise ValueError("Unsupported program version: %d\n" % + program_desc._version()) + + # Step 3. + # - remove feed, fetch and useless scale-1 op + # - remove op_callstack attr + ops_to_remove = [] + root_block = program_desc.block(0) + for i in six.moves.range(root_block.op_size()): + op = root_block.op(i) + if op.type() == 'feed': + ops_to_remove.append(i) + feed_var_name = cpt.to_bytes(op.input('X')[0]) + root_block._remove_var(feed_var_name) + self._input_names.append(cpt.to_bytes(op.output('Out')[0])) + elif op.type() == 'scale' and op.output('Out')[0].startswith( + 'save_infer_model/scale_'): + ops_to_remove.append(i) + out_var_name = cpt.to_bytes(op.output('Out')[0]) + root_block._remove_var(out_var_name) + self._output_names.append(cpt.to_bytes(op.input('X')[0])) + self._output_descs.append( + root_block.find_var(cpt.to_bytes(op.input('X')[0]))) + elif op.type() == 'fetch' and op.input('X')[0].startswith( + 'save_infer_model/scale_'): + ops_to_remove.append(i) + fetch_var_name = cpt.to_bytes(op.output('Out')[0]) + root_block._remove_var(fetch_var_name) + else: + if op.has_attr("op_callstack"): + op.remove_attr("op_callstack") + + for op_idx in reversed(ops_to_remove): + root_block._remove_op(op_idx, op_idx + 1) + + return program_desc + + @switch_to_static_graph + def _append_backward_desc(self): + assert self._load_program_desc is not None, "The StaticModelRunner not initialized properly." + program_desc_copy = core.ProgramDesc(self._load_program_desc) + + # Step 1. prepare program and related var + # NOTE: To reuse backward interfaces, build Program firstly. + # Originally, there is no need to build a program, but need to almost + # rewrite a series of methods for append_backward for program_desc. + # Therefore, in order to reuse the method of backward.py, build the program here. + fwd_op_num = program_desc_copy.block(0).op_size() + program = self._build_program_by_desc(program_desc_copy) + + # TODO: could the targets be in sub block? + targets = [] + for out in self._output_descs: + targets.append(program.global_block().var(out.name())) + + # Step 2. append backward + backward.gradients(targets=targets, inputs=[]) + return program.desc + + def _load_persisitable_dict(self, model_dir, params_filename=None): + load_dirname = os.path.normpath(model_dir) + assert self._load_program_desc is not None, "The StaticModelRunner not initialized properly." 
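+        # Each persistable variable is renamed with the @LOADED suffix and then
+        # loaded either from its own file (when params_filename is None) or from
+        # a single combined file via the load_combine op.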
+ + persis_vars = self._get_persis_vars(self._load_program_desc) + load_var_map = {} + for each_var in persis_vars: + orig_each_name = each_var.name() + # append suffix + self._append_loaded_suffix_to_param(each_var) + # create output varbase + new_var = framework.ParamBase( + shape=each_var.shape(), + dtype=each_var.dtype(), + name=each_var.name(), + type=each_var.type(), + persistable=True) + if params_filename is None: + if not self._is_parameter(each_var): + continue + # logging.info("persis var name %s" % each_var.name()) + framework._dygraph_tracer().trace_op( + type='load', + inputs={}, + outputs={'Out': new_var}, + attrs={ + 'file_path': os.path.join(load_dirname, orig_each_name) + }) + new_var.stop_gradient = False + self.add_parameter(name=new_var.name, parameter=new_var) + self._param_names.append(new_var.name) + else: + load_var_map[each_var.name()] = new_var + + if params_filename is not None: + load_var_list = [] + for name in sorted(load_var_map.keys()): + load_var_list.append(load_var_map[name]) + + framework._dygraph_tracer().trace_op( + type='load_combine', + inputs={}, + outputs={'Out': load_var_list}, + attrs={ + 'file_path': os.path.join(load_dirname, params_filename) + }) + + for each_var in persis_vars: + if not self._is_parameter(each_var): + continue + param = load_var_map[each_var.name()] + param.stop_gradient = False + self.add_parameter(name=param.name, parameter=param) + self._param_names.append(param.name) + + def _recheck_stop_gradients(self): + assert self._program_desc is not None, "The StaticModelRunner not initialized properly." + # NOTE: After loading the model, the stop_gradient information + # of the original variable is lost, but if a parameter does not + # have a corresponding @GRAD variable in the backward program, + # it can be said that it is also stop_gradient + all_var_names = self._get_all_var_names(self._program_desc) + for param_name in self._parameters: + param_grad_name = param_name + core.grad_var_suffix() + if param_grad_name not in all_var_names: + logging.info("set %s stop gradient = True" % param_grad_name) + self._parameters[param_name].stop_gradient = True + + def _get_all_var_names(self, program_desc): + all_var_names = set() + for i in six.moves.range(program_desc.num_blocks()): + block = program_desc.block(i) + for var in block.all_vars(): + logging.info(var.name()) + all_var_names.add(var.name()) + return all_var_names + + def _get_persis_vars(self, program_desc): + persis_vars = [] + for i in six.moves.range(program_desc.num_blocks()): + block = program_desc.block(i) + persis_vars.extend( + list(filter(self._is_persistable, block.all_vars()))) + return persis_vars + + @switch_to_static_graph + def _build_program_by_desc(self, program_desc): + prog = framework.Program() + prog.desc = program_desc + prog.blocks = [ + framework.Block(prog, i) + for i in six.moves.range(prog.desc.num_blocks()) + ] + prog._sync_with_cpp() + return prog + + def _is_persistable(self, var_desc): + if var_desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var_desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var_desc.type() == core.VarDesc.VarType.READER or \ + var_desc.type() == core.VarDesc.VarType.RAW: + return False + return var_desc.persistable() + + def _is_parameter(self, persis_var_desc): + assert self._load_program_desc is not None, "The StaticModelRunner not initialized properly." + # 1. 
firstly, param should be input of op + input_ops = [] # op can be repeated + for block_idx in six.moves.range(self._load_program_desc.num_blocks()): + block = self._load_program_desc.block(block_idx) + for op_idx in six.moves.range(block.op_size()): + op = block.op(op_idx) + # NOTE: parameter is the input of a certain op + if persis_var_desc.name() in op.input_arg_names(): + input_ops.append(op) + # 2. secondly, param should not be output of op or be same op's output + for block_idx in six.moves.range(self._load_program_desc.num_blocks()): + block = self._load_program_desc.block(block_idx) + for op_idx in six.moves.range(block.op_size()): + op = block.op(op_idx) + if persis_var_desc.name() in op.output_arg_names(): + # such as batch_norm_op + if op in input_ops: + continue + else: + return False + return True + + def _change_is_test_status(self, is_test): + # change all `is_test` attributes + assert self._load_program_desc is not None, "The StaticModelRunner not initialized properly." + for i in six.moves.range(self._load_program_desc.num_blocks()): + block = self._load_program_desc.block(i) + for j in six.moves.range(block.op_size()): + op = block.op(j) + if op.has_attr('is_test'): + op._set_attr('is_test', is_test) + + def _append_loaded_suffix(self, name): + """ + Append grad suffix to the given variable name + e.g. x ==> x@LOADED + """ + suffix = core.loaded_var_suffix() + name = cpt.to_text(name) + if suffix not in name: + name = name + suffix + return name + + def _append_loaded_suffix_to_param(self, param_desc): + old_name = param_desc.name() + new_name = self._append_loaded_suffix(param_desc.name()) + param_desc.set_name(new_name) + for block_idx in six.moves.range(self._load_program_desc.num_blocks()): + block = self._load_program_desc.block(block_idx) + for op_idx in six.moves.range(block.op_size()): + op = block.op(op_idx) + op._rename_input(old_name, new_name) + op._rename_output(old_name, new_name) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index d8578a56ceaddc27c56426288ebb8dc15df5049f..04554522684d239738ca4cd7db3797f093534de0 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -329,12 +329,12 @@ def _fetch_var(name, scope=None, return_numpy=True): Returns: LodTensor|numpy.ndarray """ - assert isinstance(name, str) + assert isinstance(name, six.string_types) if scope is None: scope = global_scope() assert isinstance(scope, core._Scope) - var = scope.find_var(name) + var = scope.find_var(_to_name_str(name)) assert var is not None, ( "Cannot find " + name + " in scope. Perhaps you need to make the" " variable persistable by using var.persistable = True in your" diff --git a/python/paddle/fluid/op.py b/python/paddle/fluid/op.py index b8bb3db1eedcf25c9b6a02ad3b4f261e8be8efce..ee61ec1c3da3f0343b643b15b719e94dd35a3841 100644 --- a/python/paddle/fluid/op.py +++ b/python/paddle/fluid/op.py @@ -124,11 +124,6 @@ class OpDescCreationMethod(object): new_attr.bools.extend(user_defined_attr) elif attr.type == framework_pb2.LONGS: new_attr.longs.extend(user_defined_attr) - elif attr.type == framework_pb2.INT_PAIRS: - for p in user_defined_attr: - pair = new_attr.int_pairs.add() - pair.first = p[0] - pair.second = p[1] else: raise NotImplementedError( "A not supported attribute type: %s." 
% ( diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 1fc1f87887f6994ae60dac3f0bb272c504567366..6e50108383bfe9a21aff6295162e332f0506e4b1 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -193,6 +193,8 @@ list(REMOVE_ITEM TEST_OPS test_basic_lstm_api) list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op) list(REMOVE_ITEM TEST_OPS test_imperative_debug_string) list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass) +list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_mnist) +list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_while) if (APPLE OR WIN32) list(REMOVE_ITEM TEST_OPS test_dataset) @@ -269,6 +271,10 @@ py_test_modules(test_install_check MODULES test_install_check ENVS FLAGS_cudnn_deterministic=1 SERIAL) set_tests_properties(test_install_check PROPERTIES LABELS "RUN_TYPE=DIST") py_test_modules(test_imperative_debug_string MODULES test_imperative_debug_string ENVS FLAGS_dygraph_debug=1) +py_test_modules(test_imperative_static_runner_mnist MODULES test_imperative_static_runner_mnist ENVS + FLAGS_cudnn_deterministic=1) +py_test_modules(test_imperative_static_runner_while MODULES test_imperative_static_runner_while ENVS + FLAGS_cudnn_deterministic=1) if(WITH_DISTRIBUTE) # FIXME(typhoonzero): add these tests back list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transformer") diff --git a/python/paddle/fluid/tests/unittests/test_imperative_static_runner_mnist.py b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..4907591d4fb80bd90fa5c7f5cf6bae34b874f56e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_mnist.py @@ -0,0 +1,286 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
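+
+# This test trains and saves a static-graph MNIST model with
+# fluid.io.save_inference_model, reloads it with
+# fluid.dygraph.static_runner.StaticModelRunner, fine-tunes it in dygraph mode,
+# and compares the results against an equivalent static-graph run.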
+ +from __future__ import print_function + +import unittest + +import contextlib +import numpy as np +import six + +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +from test_imperative_base import new_program_scope + + +def convolutional_neural_network(img): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + filter_size=5, + num_filters=20, + pool_size=2, + pool_stride=2, + act="relu") + conv_pool_1 = fluid.layers.batch_norm(conv_pool_1) + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + filter_size=5, + num_filters=50, + pool_size=2, + pool_stride=2, + act="relu") + prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax') + return prediction + + +def static_train_net(img, label): + prediction = convolutional_neural_network(img) + + loss = fluid.layers.cross_entropy(input=prediction, label=label) + avg_loss = fluid.layers.mean(loss) + + optimizer = fluid.optimizer.SGD(learning_rate=0.001) + optimizer.minimize(avg_loss) + + return prediction, avg_loss + + +class TestImperativeStaticModelRunnerMnist(unittest.TestCase): + def setUp(self): + self.seed = 90 + self.epoch_num = 1 + self.batch_size = 128 + self.batch_num = 50 + + def reader_decorator(self, reader): + def _reader_impl(): + for item in reader(): + image = np.array(item[0]).reshape(1, 28, 28) + label = np.array(item[1]).astype('int64').reshape(1) + yield image, label + + return _reader_impl + + def train_and_save_model(self): + startup_program = fluid.default_startup_program() + main_program = fluid.default_main_program() + + img = fluid.data(name='img', shape=[None, 1, 28, 28], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + + prediction, avg_loss = static_train_net(img, label) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + + exe = fluid.Executor(place) + + feeder = fluid.DataFeeder(feed_list=[img, label], place=place) + exe.run(startup_program) + + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=100), + batch_size=self.batch_size) + + for _ in range(0, self.epoch_num): + for batch_id, data in enumerate(train_reader()): + exe.run(main_program, + feed=feeder.feed(data), + fetch_list=[avg_loss]) + + if batch_id > self.batch_num: + break + + fluid.io.save_inference_model( + self.save_dirname, ["img"], [prediction], + exe, + model_filename=self.model_filename, + params_filename=self.params_filename) + + def load_and_train_dygraph(self): + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.dygraph.guard(place): + fluid.default_startup_program().random_seed = self.seed + fluid.default_main_program().random_seed = self.seed + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + + mnist = fluid.dygraph.static_runner.StaticModelRunner( + model_dir=self.save_dirname, + model_filename=self.model_filename, + params_filename=self.params_filename) + + dy_param_init_value = {} + for param in mnist.parameters(): + dy_param_init_value[param.name] = param.numpy() + + sgd = fluid.optimizer.SGD(learning_rate=0.001, + parameter_list=mnist.parameters()) + + train_reader = paddle.batch( + self.reader_decorator(paddle.dataset.mnist.train()), + batch_size=self.batch_size, + drop_last=True) + train_loader = fluid.io.DataLoader.from_generator(capacity=10) + train_loader.set_sample_list_generator(train_reader, places=place) + + mnist.train() + + for epoch in 
range(self.epoch_num): + for batch_id, data in enumerate(train_loader()): + img = data[0] + label = data[1] + label.stop_gradient = True + + cost = mnist(inputs=img) + + loss = fluid.layers.cross_entropy(cost, label) + avg_loss = fluid.layers.mean(loss) + + avg_loss.backward(backward_strategy) + sgd.minimize(avg_loss) + mnist.clear_gradients() + + if batch_id >= self.batch_num: + break + + dy_x_data = img.numpy() + dy_out = avg_loss.numpy() + + dy_param_value = {} + for param in mnist.parameters(): + dy_param_value[param.name] = param.numpy() + + return dy_x_data, dy_out, dy_param_init_value, dy_param_value + + def load_and_train_static(self): + with new_program_scope(): + fluid.default_startup_program().random_seed = self.seed + fluid.default_main_program().random_seed = self.seed + + img = fluid.data( + name='img', shape=[None, 1, 28, 28], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + + prediction, avg_loss = static_train_net(img, label) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + fluid.io.load_params( + exe, + self.save_dirname, + main_program=fluid.default_main_program(), + filename=self.params_filename) + + static_param_init_value = {} + static_param_name_list = [] + for param in fluid.default_main_program().all_parameters(): + static_param_name_list.append(param.name) + static_param_init_value[param.name] = fluid.executor._fetch_var( + param.name) + + train_reader = paddle.batch( + self.reader_decorator(paddle.dataset.mnist.train()), + batch_size=self.batch_size, + drop_last=True) + + for epoch in range(self.epoch_num): + for batch_id, data in enumerate(train_reader()): + static_x_data = np.array([x[0] for x in data]) + y_data = np.array([x[1] for x in data]).reshape( + [self.batch_size, 1]) + + fetch_list = [avg_loss.name] + fetch_list.extend(static_param_name_list) + + out = exe.run(fluid.default_main_program(), + feed={"img": static_x_data, + "label": y_data}, + fetch_list=fetch_list) + + if batch_id >= self.batch_num: + break + + static_param_value = {} + static_out = out[0] + for i in range(1, len(out)): + static_param_value[static_param_name_list[i - 1]] = out[i] + + return static_x_data, static_out, static_param_init_value, static_param_value + + def test_mnist_no_params_filename(self): + self.save_dirname = "mnist.inference.model.noname" + self.model_filename = None + self.params_filename = None + # Phase 1. run and save static model + self.train_and_save_model() + + # Phase 2. load model & train dygraph + dy_x_data, dy_out, dy_param_init_value, dy_param_value = \ + self.load_and_train_dygraph() + + static_x_data, static_out, static_param_init_value, static_param_value = \ + self.load_and_train_static() + + # Phase 3. compare + self.assertTrue(np.array_equal(static_x_data, dy_x_data)) + + for key, value in six.iteritems(static_param_init_value): + key += core.loaded_var_suffix() + self.assertTrue(np.array_equal(value, dy_param_init_value[key])) + + self.assertTrue(np.allclose(static_out, dy_out)) + + for key, value in six.iteritems(static_param_value): + key += core.loaded_var_suffix() + self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5)) + + def test_mnist_with_params_filename(self): + self.save_dirname = "mnist.inference.model" + self.model_filename = "mnist.model" + self.params_filename = "mnist.params" + # Phase 1. run and save static model + self.train_and_save_model() + + # Phase 2. 
load model & train dygraph + dy_x_data, dy_out, dy_param_init_value, dy_param_value = \ + self.load_and_train_dygraph() + + static_x_data, static_out, static_param_init_value, static_param_value = \ + self.load_and_train_static() + + # Phase 3. compare + self.assertTrue(np.array_equal(static_x_data, dy_x_data)) + + for key, value in six.iteritems(static_param_init_value): + key += core.loaded_var_suffix() + self.assertTrue(np.array_equal(value, dy_param_init_value[key])) + + self.assertTrue(np.allclose(static_out, dy_out)) + + for key, value in six.iteritems(static_param_value): + key += core.loaded_var_suffix() + self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_static_runner_while.py b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_while.py new file mode 100644 index 0000000000000000000000000000000000000000..f15fe74d1ab5ae18478ae0bf69088344cedea424 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_while.py @@ -0,0 +1,235 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import contextlib +import numpy as np +import six + +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +from test_imperative_base import new_program_scope + +import paddle.fluid.transpiler.details.program_utils as pu + + +def while_softmax_regression(img): + def cond(i, times, pred): + return i < times + + def body(i, times, pred): + pred = fluid.layers.fc(input=pred, size=10, act='softmax') + i = i + 1 + return [i, times, pred] + + i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) + times = fluid.layers.fill_constant(shape=[1], dtype='int64', value=5) + pred = fluid.layers.fc(input=img, size=10, act='softmax') + i, times, pred = fluid.layers.while_loop( + cond=cond, body=body, loop_vars=[i, times, pred]) + return pred + + +class TestImperativeStaticModelRunnerWhile(unittest.TestCase): + def setUp(self): + self.seed = 90 + self.batch_size = 32 + self.batch_num = 50 + self.save_dirname = "while.inference.model" + self.model_filename = None + self.params_filename = None + + def _random_batch_reader(self): + def _get_random_images_and_labels(image_shape, label_shape): + image = np.random.random(size=image_shape).astype('float32') + label = np.random.random(size=label_shape).astype('int64') + return image, label + + def __reader__(): + for _ in range(self.batch_num): + batch_image, batch_label = _get_random_images_and_labels( + [self.batch_size, 784], [self.batch_size, 1]) + yield batch_image, batch_label + + return __reader__ + + def train_and_save_model(self): + startup_program = fluid.default_startup_program() + main_program = fluid.default_main_program() + + img = fluid.data(name='img', shape=[None, 784], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], 
dtype='int64') + + pred = while_softmax_regression(img) + + loss = fluid.layers.cross_entropy(input=pred, label=label) + avg_loss = fluid.layers.mean(loss) + + optimizer = fluid.optimizer.SGD(learning_rate=0.001) + optimizer.minimize(avg_loss) + + # pu.program_to_code(main_program, skip_op_callstack=True) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(startup_program) + + loader = fluid.io.DataLoader.from_generator( + feed_list=[img, label], capacity=5, iterable=True) + loader.set_batch_generator(self._random_batch_reader(), places=place) + + for data in loader(): + exe.run(main_program, feed=data, fetch_list=[avg_loss]) + + fluid.io.save_inference_model( + self.save_dirname, ["img"], [pred], + exe, + model_filename=self.model_filename, + params_filename=self.params_filename) + + def load_and_train_dygraph(self): + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.dygraph.guard(place): + fluid.default_startup_program().random_seed = self.seed + fluid.default_main_program().random_seed = self.seed + np.random.seed(self.seed) + + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = True + + while_net = fluid.dygraph.static_runner.StaticModelRunner( + self.save_dirname) + + dy_param_init_value = {} + for param in while_net.parameters(): + dy_param_init_value[param.name] = param.numpy() + + sgd = fluid.optimizer.SGD(learning_rate=0.001, + parameter_list=while_net.parameters()) + + train_loader = fluid.io.DataLoader.from_generator(capacity=10) + train_loader.set_batch_generator( + self._random_batch_reader(), places=place) + + while_net.train() + + for data in train_loader(): + img = data[0] + label = data[1] + label.stop_gradient = True + + cost = while_net(inputs=img) + + loss = fluid.layers.cross_entropy(cost, label) + avg_loss = fluid.layers.mean(loss) + + avg_loss.backward(backward_strategy) + sgd.minimize(avg_loss) + while_net.clear_gradients() + + dy_out = avg_loss.numpy() + dy_param_value = {} + for param in while_net.parameters(): + dy_param_value[param.name] = param.numpy() + + return dy_out, dy_param_init_value, dy_param_value + + def load_and_train_static(self): + with new_program_scope(): + fluid.default_startup_program().random_seed = self.seed + fluid.default_main_program().random_seed = self.seed + np.random.seed(self.seed) + + img = fluid.data(name='img', shape=[None, 784], dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + + pred = while_softmax_regression(img) + + loss = fluid.layers.cross_entropy(input=pred, label=label) + avg_loss = fluid.layers.mean(loss) + + optimizer = fluid.optimizer.SGD(learning_rate=0.001) + optimizer.minimize(avg_loss) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + fluid.io.load_params( + exe, + self.save_dirname, + main_program=fluid.default_main_program(), + filename=self.params_filename) + + static_param_init_value = {} + static_param_name_list = [] + for param in fluid.default_main_program().all_parameters(): + static_param_name_list.append(param.name) + static_param_init_value[param.name] = fluid.executor._fetch_var( + param.name) + + loader = fluid.io.DataLoader.from_generator( + feed_list=[img, label], capacity=5, iterable=True) + loader.set_batch_generator( + self._random_batch_reader(), places=place) + + for data in 
loader(): + fetch_list = [avg_loss.name] + fetch_list.extend(static_param_name_list) + + out = exe.run(fluid.default_main_program(), + feed=data, + fetch_list=fetch_list) + + static_param_value = {} + static_out = out[0] + for i in range(1, len(out)): + static_param_value[static_param_name_list[i - 1]] = out[i] + + return static_out, static_param_init_value, static_param_value + + def test_while_no_params_filename(self): + # Phase 1. run and save static model + self.train_and_save_model() + + # Phase 2. load model & train dygraph + dy_out, dy_param_init_value, dy_param_value = \ + self.load_and_train_dygraph() + + static_out, static_param_init_value, static_param_value = \ + self.load_and_train_static() + + # Phase 3. compare + for key, value in six.iteritems(static_param_init_value): + key += core.loaded_var_suffix() + self.assertTrue(np.array_equal(value, dy_param_init_value[key])) + + self.assertTrue(np.allclose(static_out, dy_out)) + + for key, value in six.iteritems(static_param_value): + key += core.loaded_var_suffix() + self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_run_program_op.py b/python/paddle/fluid/tests/unittests/test_run_program_op.py new file mode 100644 index 0000000000000000000000000000000000000000..55810faff13e272a03df91269de28f9643068316 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_run_program_op.py @@ -0,0 +1,341 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import contextlib +import unittest +import numpy as np +import six + +import paddle.fluid as fluid +from paddle import compat as cpt +from paddle.fluid import core, framework, executor + + +@contextlib.contextmanager +def program_scope_guard(): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + with fluid.unique_name.guard(): + yield + + +# NOTE: Because RunProgramOp has a special output of type std::vector, +# the OpTest framework cannot be used for RunProgramOp. 
The variable type cannot be specified +# when creating output variables in OpTest; the default type is LoDTensor +# NOTE: the gradient test method in OpTest also cannot be used for RunProgramOp, +# because it holds a BlockDesc type attr, and OperatorFactory can't parse this attr type +# when creating the Operator, so here we compare gradients with the static graph results +# NOTE: Here we rewrite a simple unittest framework for RunProgramOp +class RunProgramOpTest(unittest.TestCase): + def build_model(self): + raise NotImplementedError( + "RunProgramOp test should implement build_model") + + def check_output(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + # TODO: RunProgramOp is not recommended for use in static mode now + self.expect_outs = self.run_static_model(place, is_test=True) + self.check_output_with_place(place) + + def check_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + # TODO: RunProgramOp is not recommended for use in static mode now + self.expect_grads = self.run_static_model(place, is_test=False) + self.check_grad_with_place(place) + + def run_static_model(self, place, is_test=True): + with program_scope_guard(): + startup_program = fluid.default_startup_program() + main_program = fluid.default_main_program() + + self.build_model() + + exe = fluid.Executor(place) + exe.run(startup_program) + + if is_test: + fetch_list = self.output_names['Out'] + else: + fetch_list = self.get_param_grad_names() + + outs = exe.run(main_program, + feed=self.inputs['X'], + fetch_list=fetch_list) + return outs + + def get_program_desc(self): + with program_scope_guard(): + fwd_op_num = self.build_model() + return fluid.default_main_program().desc, fwd_op_num + + def prepare_attrs(self): + return { + 'global_block': self.program_desc.block(0), + 'start_op_index': 0, + 'end_op_index': self.fwd_op_num + } + + def get_param_grad_names(self): + grad_names = [] + for var_name in self.inputs['Params']: + grad_names.append(var_name + core.grad_var_suffix()) + return grad_names + + def check_output_with_place(self, place): + # Step 1. run op + actual_outs = self.calc_dygraph_output(place) + + # Step 2. compare output + for expect_v, actual_v in six.moves.zip(self.expect_outs, actual_outs): + self.assertTrue(np.allclose(expect_v, actual_v.numpy(), atol=1e-5)) + + def check_grad_with_place(self, place): + # Step 1. calc grads + actual_grads = self.calc_dygraph_grad(place) + + # Step 2. 
compare grads + for expect_v, actual_v in six.moves.zip(self.expect_grads, + actual_grads): + np.testing.assert_array_almost_equal(expect_v, actual_v) + self.assertTrue(np.allclose(expect_v, actual_v, atol=1e-5)) + + def prepare_dygraph_input(self, place, return_param_list=False): + def create_var_base(is_input, name, np_value, stop_gradient): + var = core.VarBase( + value=np_value, name=name, place=place, zero_copy=True) + var.stop_gradient = stop_gradient + return var + + # build inputs + inputs = {} + param_list = [] + inputs['X'] = [] + for name, np_value in self.inputs['X'].items(): + var = create_var_base(True, name, np_value, True) + inputs['X'].append(var) + inputs['Params'] = [] + for name, np_value in self.inputs['Params'].items(): + var = create_var_base(True, name, np_value, False) + inputs['Params'].append(var) + if return_param_list: + param_list.append(var) + + if return_param_list: + return inputs, param_list + return inputs + + def prepare_dygraph_output(self): + def create_var_base(is_input, name): + var = framework._varbase_creator(dtype=None, shape=None, name=name) + var.stop_gradient = False + return var + + # build outputs + outputs = {} + outputs['Out'] = [] + for name in self.output_names['Out']: + outputs['Out'].append(create_var_base(False, name)) + + outputs['OutScope'] = framework._varbase_creator( + type=core.VarDesc.VarType.STEP_SCOPES, + name="program_out_scope", + persistable=True) + inner_scope = core.Scope() + outputs['OutScope'].value().set_scope(inner_scope) + return outputs + + def calc_dygraph_output(self, place): + with fluid.dygraph.guard(place): + inputs = self.prepare_dygraph_input(place) + outputs = self.prepare_dygraph_output() + + framework._dygraph_tracer().trace_op( + type=self.op_type, + inputs=inputs, + outputs=outputs, + attrs=self.attrs) + return outputs['Out'] + + def calc_dygraph_grad(self, place): + with fluid.dygraph.guard(place): + # Step 1. run forward + inputs, input_param_list = self.prepare_dygraph_input(place, True) + outputs = self.prepare_dygraph_output() + + framework._dygraph_tracer().trace_op( + type=self.op_type, + inputs=inputs, + outputs=outputs, + attrs=self.attrs) + + for param in input_param_list: + var_type = self._get_grad_vartype(param.name) + if var_type is None: + continue + param._set_grad_type(var_type) + + # Step 2. run backward + # NOTE: in unittest, only support single output now + actual_outs = outputs['Out'] + assert len(actual_outs) == 1 + actual_outs[0].backward() + + # Step 3. 
prepare grads + grads = [] + for param in input_param_list: + grad = param.gradient() + grads.append(grad) + return grads + + def _get_grad_vartype(self, name): + assert self.program_desc is not None + grad_name = name + core.grad_var_suffix() + for i in six.moves.range(self.program_desc.num_blocks()): + block = self.program_desc.block(i) + var_desc = block.find_var_recursive(cpt.to_bytes(grad_name)) + return var_desc.type() if var_desc is not None else None + + +class TestRunProgramOpWithFC(RunProgramOpTest): + def setUp(self): + self.op_type = "run_program" + self.dtype = np.float32 + self.input_names = { + 'X': ['img'], + 'Params': ['weight_param', 'bias_param'] + } + self.output_names = {'Out': ['fc_0.tmp_2']} + + self.inputs = { + 'X': { + self.input_names['X'][0]: np.random.random((32, 1, 28, 28)) + .astype(self.dtype) + }, + 'Params': { + self.input_names['Params'][0]: np.random.random( + (784, 10)).astype(self.dtype), + self.input_names['Params'][1]: np.random.random( + (32, 10)).astype(self.dtype) + } + } + + self.program_desc, self.fwd_op_num = self.get_program_desc() + + self.attrs = self.prepare_attrs() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad() + + def build_model(self): + # 1. simple model + img = fluid.data( + name=self.input_names['X'][0], + shape=[None, 1, 28, 28], + dtype='float32') + weight_attr = fluid.ParamAttr( + name=self.input_names['Params'][0], + learning_rate=0.5, + initializer=fluid.initializer.NumpyArrayInitializer(self.inputs[ + 'Params'][self.input_names['Params'][0]]), + trainable=True) + bias_attr = fluid.ParamAttr( + name=self.input_names['Params'][1], + learning_rate=0.5, + initializer=fluid.initializer.NumpyArrayInitializer(self.inputs[ + 'Params'][self.input_names['Params'][1]]), + trainable=True) + pred = fluid.layers.fc(input=img, + size=10, + param_attr=weight_attr, + bias_attr=bias_attr, + act='relu') + # 2. get forward op num + fwd_op_num = fluid.default_main_program().global_block().desc.op_size() + # 3. append backward + grads = fluid.backward.gradients(targets=[pred], inputs=[img]) + + return fwd_op_num + + +class TestRunProgramOpWithEmbedding(RunProgramOpTest): + def setUp(self): + self.op_type = "run_program" + self.dtype = np.float32 + self.input_names = {'X': ['x'], 'Params': ['emb_weight']} + self.output_names = {'Out': ['reduce_sum_0.tmp_0']} + + self.inputs = { + 'X': { + 'x': np.array([[1, 3, 0, 4, 7]]).astype("int64") + }, + 'Params': { + 'emb_weight': np.random.random(size=(10, 16)).astype("float32") + } + } + + self.program_desc, self.fwd_op_num = self.get_program_desc() + + self.attrs = self.prepare_attrs() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + # NOTE: fetch does not support SelectedRows, so we cannot compare + # sparse gradients with static graph mode; only run dygraph here + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + # TODO: RunProgramOp is not recommended for use in static mode now + self.calc_dygraph_grad(place) + + def build_model(self): + # 1. simple model + x = fluid.layers.data( + name=self.input_names['X'][0], shape=[5], dtype='int64') + emb = fluid.input.embedding( + input=x, + size=[10, 16], + param_attr=fluid.ParamAttr( + name="emb_weight", + learning_rate=10, + initializer=fluid.initializer.NumpyArrayInitializer(self.inputs[ + 'Params'][self.input_names['Params'][0]])), + is_sparse=True) + y = fluid.layers.reduce_sum(emb, dim=-1) + # 2. 
get forward op num + fwd_op_num = fluid.default_main_program().global_block().desc.op_size() + # 3. append backward + grads = fluid.backward.gradients(targets=[y], inputs=[x]) + + return fwd_op_num + + +if __name__ == "__main__": + unittest.main()
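A note on the parameter-name mapping used in the comparisons above: parameters loaded through StaticModelRunner are renamed with core.loaded_var_suffix(), so a static-graph parameter name plus that suffix indexes the corresponding dygraph parameter dict. A minimal sketch of that check as a standalone helper (the helper name and signature are illustrative, not part of the patch):

import numpy as np
import six
from paddle.fluid import core

def assert_params_match(static_params, dygraph_params, atol=None):
    # static_params: {static-graph name: ndarray}
    # dygraph_params: {loaded (suffixed) name: ndarray}
    for name, static_value in six.iteritems(static_params):
        loaded_name = name + core.loaded_var_suffix()  # map to the dygraph name
        dygraph_value = dygraph_params[loaded_name]
        if atol is None:
            assert np.array_equal(static_value, dygraph_value)
        else:
            assert np.allclose(static_value, dygraph_value, atol=atol)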
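For the static-graph baselines, the fetch list built from the parameter name list must actually be passed to exe.run(): the first fetched value is the loss and the remaining values line up one-to-one with the parameter names. A small sketch, assuming exe, main_program, feed_data, avg_loss and static_param_name_list are set up as in the tests above:

fetch_list = [avg_loss.name] + static_param_name_list
out = exe.run(main_program, feed=feed_data, fetch_list=fetch_list)
static_out = out[0]  # fetched loss value
static_param_value = {
    name: value
    for name, value in zip(static_param_name_list, out[1:])
}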
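Because the RunProgramOp tests drive the op by hand instead of through OpTest, the forward flow is spread over several helpers. A condensed sketch of the forward pass only, assuming `test` is a RunProgramOpTest instance whose program_desc, fwd_op_num, inputs and output_names have already been prepared as in setUp above:

import paddle.fluid as fluid
from paddle.fluid import framework

def run_forward_in_dygraph(test, place):
    with fluid.dygraph.guard(place):
        inputs = test.prepare_dygraph_input(place)   # X / Params as VarBase
        outputs = test.prepare_dygraph_output()      # Out list plus OutScope
        # attrs carry global_block plus the [start_op_index, end_op_index) range to run
        framework._dygraph_tracer().trace_op(
            type='run_program',
            inputs=inputs,
            outputs=outputs,
            attrs=test.prepare_attrs())
        return [var.numpy() for var in outputs['Out']]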