未验证 提交 75bd3507 编写于 作者: C Chen Weihang 提交者: GitHub

Implement StaticModelRunner to support dygraph fine-tune static graph pre-training model (#23171)

* static model runner basic implement, test=develop

* add run program op to execute loaded program, test=develop

* refactor static model runner & run program op, test=develop

* reset engine.cc to resolve conflict

* adapt the change of dygraph double grad, test=develop

* refactor impl to solve control flow error, test=develop

* clear debug code, test=develop

* fix ci str compatible error & checkout dygraph grad maker & add example, test=develop

* hide api & add op test, test=develop

* fix run program op test places error, test=develop

* fix program by review comment, test=develop

* delete change var desc name, test=develop

* fix other program by review comment, test=develop

* remove _static_graph_guard, test=develop

* add selectedrows test, test=develop

* remove desc parser, test=develop

* fix detail program, test=develop

* change socpe create & add test, test=develop
上级 9297f49e
...@@ -70,11 +70,6 @@ void ExecutorPrepareContext::PrepareUnusedVars( ...@@ -70,11 +70,6 @@ void ExecutorPrepareContext::PrepareUnusedVars(
force_disable_gc = true; force_disable_gc = true;
} }
#endif #endif
force_disable_gc_ = force_disable_gc;
if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
return;
}
// If gc is enabled and block size > 1 // If gc is enabled and block size > 1
if (prog_.Size() > 1) { if (prog_.Size() > 1) {
operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
...@@ -84,6 +79,12 @@ void ExecutorPrepareContext::PrepareUnusedVars( ...@@ -84,6 +79,12 @@ void ExecutorPrepareContext::PrepareUnusedVars(
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
prog_, block_id_, ops_); prog_, block_id_, ops_);
} }
force_disable_gc_ = force_disable_gc;
if (GetEagerDeletionThreshold() < 0 || force_disable_gc_) {
return;
}
unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars); unused_vars_ = GetUnusedVars(prog_.Block(block_id_), ops_, keep_vars);
} }
...@@ -412,9 +413,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( ...@@ -412,9 +413,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
return result; return result;
} }
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
bool create_local_scope, bool create_vars, Scope* scope, int64_t start_op_index,
bool keep_kids) { int64_t end_op_index,
bool create_local_scope,
bool create_vars, bool keep_kids) {
platform::RecordBlock b(kProgramId); platform::RecordBlock b(kProgramId);
PADDLE_ENFORCE_NOT_NULL(scope); PADDLE_ENFORCE_NOT_NULL(scope);
Scope* local_scope = scope; Scope* local_scope = scope;
...@@ -446,7 +449,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -446,7 +449,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
#endif #endif
} }
for (auto& op : ctx->ops_) { for (int64_t i = start_op_index; i < end_op_index; ++i) {
auto& op = ctx->ops_[i];
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
if (gc) { if (gc) {
DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get()); DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
...@@ -471,6 +475,15 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -471,6 +475,15 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} }
} }
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars,
bool keep_kids) {
int64_t start_op_index = 0;
int64_t end_op_index = ctx->ops_.size();
RunPartialPreparedContext(ctx, scope, start_op_index, end_op_index,
create_local_scope, create_vars, keep_kids);
}
void Executor::RunPreparedContext( void Executor::RunPreparedContext(
ExecutorPrepareContext* ctx, Scope* scope, ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets, std::map<std::string, const LoDTensor*>* feed_targets,
......
...@@ -115,6 +115,12 @@ class Executor { ...@@ -115,6 +115,12 @@ class Executor {
void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id); void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
void RunPartialPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
int64_t start_op_index, int64_t end_op_index,
bool create_local_scope = true,
bool create_vars = true,
bool keep_kids = false);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true, bool create_local_scope = true,
bool create_vars = true, bool keep_kids = false); bool create_vars = true, bool keep_kids = false);
......
...@@ -64,6 +64,9 @@ constexpr char kZeroVarSuffix[] = "@ZERO"; ...@@ -64,6 +64,9 @@ constexpr char kZeroVarSuffix[] = "@ZERO";
/// Variables with this suffix are the new Gradient. /// Variables with this suffix are the new Gradient.
constexpr char kNewGradSuffix[] = "@NEWGRAD@"; constexpr char kNewGradSuffix[] = "@NEWGRAD@";
/// Variables with this suffix are the loaded from pre-train model.
constexpr char kLoadedVarSuffix[] = "@LOADED";
/// RuntimeContext is used to relate input/output names of Operator with /// RuntimeContext is used to relate input/output names of Operator with
/// the corresponding variables in name scope. /// the corresponding variables in name scope.
/// If an Op has attribute kEnableCacheRuntimeContext, it means that in a same /// If an Op has attribute kEnableCacheRuntimeContext, it means that in a same
......
...@@ -200,11 +200,12 @@ void BasicEngine::Execute() { ...@@ -200,11 +200,12 @@ void BasicEngine::Execute() {
iter != accumulators_.end(), true, iter != accumulators_.end(), true,
platform::errors::NotFound("Cannot find gradient of variable %s", platform::errors::NotFound("Cannot find gradient of variable %s",
var->Name())); var->Name()));
if (!var->OverridedStopGradient() && iter->second->RefCnt() == 1) { if (!var->OverridedStopGradient() && iter->second->RefCnt() == 1) {
continue; continue;
} }
var = std::make_shared<VariableWrapper>("Gtmp@"); var = std::make_shared<VariableWrapper>(var->Name());
need_accu_var_list_.emplace_back(iter->second.get(), var); need_accu_var_list_.emplace_back(iter->second.get(), var);
} }
} }
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/run_program_op.h"
#include <string>
namespace paddle {
namespace operators {
class RunProgramOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
platform::errors::NotFound(
"Input(X) of RunProgramOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInputs("Params"), true,
platform::errors::NotFound(
"Input(Params) of RunProgramOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutputs("Out"), true,
platform::errors::NotFound(
"Output(Out) of RunProgramOp should not be null."));
}
protected:
/* [Why use single type kernel]:
*
* This op is similar to a control flow op, it doses not need
* a op kernel, but in order to make it execute under dynamic
* graph mode, implement it with op kernel.
*
* So whether the kernel data type is int, float or other type,
* which has no effect on its execution logic, so directly
* specified a data type here.
*
* Of course, the data type here is also not important.
*/
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return expected_kernel_type;
}
};
class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(vector<LoDTensor>)"
"The input tensors of RunProgram operator, also the feed targets "
"of loaded program.")
.AsDuplicable();
AddInput("Params",
"(vector<LoDTensor or SelecetedRows>)"
"The input parameter of RunProgram operator, also the parameters "
"of the loaded program.")
.AsDuplicable();
AddOutput("Out",
"(vector<LoDTensor>)"
"The output tensors of RunProgram operator, also the fetch "
"targets of the loaded program.")
.AsDuplicable();
AddOutput("OutScope",
"(StepScopeVar)"
"A vector of execution scope in RunProgram operator, which "
"contains at most one scope."
"NOTE: Do not use Scope directly because Scope output is not "
"currently supported.");
AddAttr<BlockDesc*>("global_block",
"(BlockDesc *)"
"The global block of executed program desc.");
AddAttr<int64_t>("start_op_index",
"(int64_t)"
"The index of the op to start execution");
AddAttr<int64_t>("end_op_index",
"(int64_t)"
"The index of the op to stop execution");
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training.")
.SetDefault(false);
AddComment(R"DOC(
RunProgram operator.
The RunProgram operator receives a program's feed targets, fetch targets,
and parameters, and receives the forward and backward program desc
as attributes, and then executes the program by executor.
NOTE: This operator is added so that the inference model stored by
`fluid.io.save_inference_model` under the static graph mode can be loaded
under the dynamic graph mode for fine-tuning or inferencing.
)DOC");
}
};
class RunProgramGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
platform::errors::NotFound(
"Input(X) of RunProgramGradOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInputs("Params"), true,
platform::errors::NotFound(
"Input(Params) of RunProgramGradOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInputs(framework::GradVarName("Out")), true,
platform::errors::NotFound(
"Input(Out@GRAD) of RunProgramGradOp should not be null."));
// NOTE: The X@GRAD and Params@GRAD may not exist,
// because they can be set stop_gradient = True
}
protected:
/* see [Why use single type kernel] */
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return expected_kernel_type;
}
};
template <typename T>
class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("run_program_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Params", this->Input("Params"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetInput("OutScope", this->Output("OutScope"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Params"),
this->InputGrad("Params"));
grad_op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(run_program, ops::RunProgramOp, ops::RunProgramOpMaker,
ops::RunProgramGradOpMaker<paddle::framework::OpDesc>,
ops::RunProgramGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(run_program_grad, ops::RunProgramGradOp);
/* see [Why use single type kernel] */
REGISTER_OP_CPU_KERNEL(
run_program,
ops::RunProgramOpKernel<paddle::platform::CPUDeviceContext, float>)
REGISTER_OP_CPU_KERNEL(
run_program_grad,
ops::RunProgramGradOpKernel<paddle::platform::CPUDeviceContext, float>)
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
/* see [Why use single type kernel] */
REGISTER_OP_CUDA_KERNEL(
run_program,
ops::RunProgramOpKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
run_program_grad,
ops::RunProgramGradOpKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace operators {
using StepScopeVar = std::vector<framework::Scope *>;
using BlockDesc = framework::BlockDesc;
using Variable = framework::Variable;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
namespace details {
// all input vars should be LoDTensor & is initialized
static void CheckInputVarStatus(const Variable &var,
const std::string &var_name) {
PADDLE_ENFORCE_EQ(
var.IsType<LoDTensor>(), true,
platform::errors::InvalidArgument(
"The input variable %s of "
"RunProgram(Grad)Op(StaticModelRunner) holds "
"wrong type. Expect type is LoDTensor, but receive type is %s.",
var_name, platform::demangle(framework::ToTypeName(var.Type()))));
PADDLE_ENFORCE_EQ(
var.Get<LoDTensor>().IsInitialized(), true,
platform::errors::InvalidArgument("The tensor in input variable %s of "
"RunProgram(Grad)Op(StaticModelRunner) "
"is not initialized.",
var_name));
}
static void CheckOutputVarStatus(const Variable &src_var,
const Variable &dst_var,
const std::string &var_name) {
if (dst_var.IsType<LoDTensor>()) {
PADDLE_ENFORCE_EQ(
src_var.IsType<LoDTensor>(), true,
platform::errors::InvalidArgument(
"The output variable %s get from "
"RunProgram(Grad)Op(StaticModelRunner)'s internal scope holds "
"wrong type. Expect type is LoDTensor, but receive type is %s.",
var_name,
platform::demangle(framework::ToTypeName(src_var.Type()))));
PADDLE_ENFORCE_EQ(src_var.Get<LoDTensor>().IsInitialized(), true,
platform::errors::InvalidArgument(
"The tensor in output variable %s get from "
"RunProgram(Grad)Op(StaticModelRunner)'s internal "
"scope is not initialized.",
var_name));
} else if (dst_var.IsType<SelectedRows>()) {
PADDLE_ENFORCE_EQ(
src_var.IsType<SelectedRows>(), true,
platform::errors::InvalidArgument(
"The output variable %s get from "
"RunProgram(Grad)Op(StaticModelRunner)'s internal scope holds "
"wrong type. Expect type is SelectedRows, but receive type is %s.",
var_name,
platform::demangle(framework::ToTypeName(src_var.Type()))));
PADDLE_ENFORCE_EQ(src_var.Get<SelectedRows>().value().IsInitialized(), true,
platform::errors::InvalidArgument(
"The tensor in output variable %s get from "
"RunProgram(Grad)Op(StaticModelRunner)'s "
"internal scope is not initialized.",
var_name));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The RunProgram(Grad)Op(StaticModelRunner) only support output "
"variable of type LoDTensor or SelectedRows, "
"but received variable %s's type is %s",
var_name, platform::demangle(framework::ToTypeName(dst_var.Type()))));
}
}
static void VariableShare(const Variable &src_var, Variable *dst_var) {
// The previous check ensures that the variable type can only be LoDTensor
auto *lod_tensor = dst_var->GetMutable<LoDTensor>();
lod_tensor->ShareDataWith(src_var.Get<LoDTensor>());
lod_tensor->set_lod(src_var.Get<LoDTensor>().lod());
}
static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
const std::vector<std::string> &var_names,
framework::Scope *scope) {
for (size_t i = 0; i < vars.size(); ++i) {
auto *var = scope->Var(var_names[i]);
CheckInputVarStatus(*vars[i], var_names[i]);
VariableShare(*vars[i], var);
}
}
static void VariableCopy(const Variable &src_var,
const platform::Place &dst_place, Variable *dst_var) {
// The previous check ensures that the variable type can only be LoDTensor or
// SelectedRows
if (src_var.IsType<LoDTensor>()) {
auto *lod_tensor = dst_var->GetMutable<LoDTensor>();
TensorCopySync(src_var.Get<LoDTensor>(), dst_place, lod_tensor);
lod_tensor->set_lod(src_var.Get<LoDTensor>().lod());
} else if (src_var.IsType<SelectedRows>()) {
auto *selected_rows = dst_var->GetMutable<SelectedRows>();
TensorCopySync(src_var.Get<SelectedRows>().value(), dst_place,
selected_rows->mutable_value());
selected_rows->set_rows(src_var.Get<SelectedRows>().rows());
selected_rows->set_height(src_var.Get<SelectedRows>().height());
}
}
static void ShareVarsFromScope(const std::vector<Variable *> &vars,
const std::vector<std::string> &var_names,
framework::Scope *scope) {
for (size_t i = 0; i < vars.size(); ++i) {
auto *var = scope->FindVar(var_names[i]);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("The output variable %s is not in "
"RunProgram(Grad)Op(StaticModelRunner)'"
"s internal scope.",
var_names[i]));
CheckOutputVarStatus(*var, *vars[i], var_names[i]);
VariableShare(*var, vars[i]);
}
}
static void CopyVarsFromScope(const std::vector<Variable *> &vars,
const std::vector<std::string> &var_names,
const platform::Place &dst_place,
framework::Scope *scope) {
for (size_t i = 0; i < vars.size(); ++i) {
if (var_names[i] == framework::kEmptyVarName) {
VLOG(2) << "find variable name is " << framework::kEmptyVarName
<< ", skip it!";
continue;
}
auto *var = scope->FindVar(var_names[i]);
// NOTE: Here skip not found var is dangerous, if a bug is caused here,
// the result is grad calculation error, which will be very hidden!
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("The output variable %s is not in "
"RunProgram(Grad)Op(StaticModelRunner)'"
"s internal scope.",
var_names[i]));
CheckOutputVarStatus(*var, *vars[i], var_names[i]);
VariableCopy(*var, dst_place, vars[i]);
}
}
static void AppendSkipDeletionVars(
std::vector<std::string> *all_vars,
const std::vector<std::string> &append_vars) {
for (auto &var : append_vars) {
all_vars->emplace_back(var);
}
}
} // namespace details
template <typename DeviceContext, typename T>
class RunProgramOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
VLOG(2) << "RunProgramOpKernel Compute";
// Step 1. prepare inputs, outputs, attrs
auto &input_vars = ctx.MultiInputVar("X");
auto &param_vars = ctx.MultiInputVar("Params");
auto output_vars = ctx.MultiOutputVar("Out");
auto input_var_names = ctx.InputNames("X");
auto param_names = ctx.InputNames("Params");
auto output_var_names = ctx.OutputNames("Out");
auto *block = ctx.Attr<BlockDesc *>("global_block");
auto *program = block->Program();
auto start_op_index = ctx.Attr<int64_t>("start_op_index");
auto end_op_index = ctx.Attr<int64_t>("end_op_index");
auto is_test = ctx.Attr<bool>("is_test");
// NOTE(chenweihang): In order not to add new variable type, use vector
// here. Originally, here can use scope directly.
auto *out_scope_vec = ctx.Output<StepScopeVar>("OutScope");
PADDLE_ENFORCE_EQ(
out_scope_vec->size(), 1,
platform::errors::InvalidArgument(
"The OutScope of RunProgramGradOp should only hold one scope."));
// Step 2. prepare executor and init persistable variables
framework::Executor exe(ctx.GetPlace());
// skip delete vars
std::vector<std::string> skip_vars;
details::AppendSkipDeletionVars(&skip_vars, output_var_names);
VLOG(2) << "Prepare to skip " << skip_vars.size()
<< " var(s): " << string::join_strings(skip_vars, ' ');
auto exe_ctx = exe.Prepare(*program, 0, skip_vars);
framework::Scope &scope = *(out_scope_vec->front());
// share input_vars & parameters into scope
details::ShareVarsIntoScope(input_vars, input_var_names, &scope);
details::ShareVarsIntoScope(param_vars, param_names, &scope);
// Step 3. run ops
exe.RunPartialPreparedContext(exe_ctx.get(), &scope, start_op_index,
end_op_index, /*create_local_scope=*/false,
/*create_vars=*/true, /*keep_kids=*/!is_test);
// Step 4. Get Output
details::ShareVarsFromScope(output_vars, output_var_names, &scope);
// Debug info: scope info when run end
VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front());
}
};
template <typename DeviceContext, typename T>
class RunProgramGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
VLOG(2) << "RunProgramGradOpKernel Compute";
// Step 1. prepare inputs and outputs
auto &output_grad_vars = ctx.MultiInputVar(framework::GradVarName("Out"));
auto input_grad_vars = ctx.MultiOutputVar(framework::GradVarName("X"));
auto param_grad_vars = ctx.MultiOutputVar(framework::GradVarName("Params"));
// if all output vars are set to stop_gradient, grad op no need to executed
if (input_grad_vars.empty() && param_grad_vars.empty()) return;
auto output_grad_var_names = ctx.InputNames(framework::GradVarName("Out"));
// NOTE: after PR22939 [Add double grad] merged, the grad op maker's
// SetOutput will set to None if the input var stop_gradient=True,
// it will cause an NotFound error when ctx.OutputNames() is called
std::vector<std::string> input_grad_var_names;
std::vector<std::string> param_grad_names;
if (!input_grad_vars.empty()) {
input_grad_var_names = ctx.OutputNames(framework::GradVarName("X"));
}
if (!param_grad_vars.empty()) {
param_grad_names = ctx.OutputNames(framework::GradVarName("Params"));
}
auto *block = ctx.Attr<BlockDesc *>("global_block");
auto *program = block->Program();
auto orig_end_op_index = ctx.Attr<int64_t>("end_op_index");
// NOTE: skip `shape` and `fill_constant` op created by
// fluid.backward.gradients,
// one forward output will generate one `shape` and `fill_constant`
int64_t start_op_index = orig_end_op_index + (output_grad_vars.size() * 2);
int64_t end_op_index = block->OpSize();
auto *out_scope_vec = ctx.Input<StepScopeVar>("OutScope");
PADDLE_ENFORCE_EQ(
out_scope_vec->size(), 1,
platform::errors::InvalidArgument(
"The OutScope of RunProgramGradOp should only hold one scope."));
// Step 2. prepare executor and scope
framework::Executor exe(ctx.GetPlace());
// skip delete vars
std::vector<std::string> skip_vars;
details::AppendSkipDeletionVars(&skip_vars, input_grad_var_names);
details::AppendSkipDeletionVars(&skip_vars, param_grad_names);
VLOG(2) << "Prepare to skip " << skip_vars.size()
<< " var(s): " << string::join_strings(skip_vars, ' ');
auto exe_ctx = exe.Prepare(*program, 0, skip_vars);
auto &scope = *(out_scope_vec->front());
details::ShareVarsIntoScope(output_grad_vars, output_grad_var_names,
&scope);
// Debug info: scope info when run end
VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front());
// Step 3. run ops
exe.RunPartialPreparedContext(exe_ctx.get(), &scope, start_op_index,
end_op_index, /*create_local_scope=*/false,
/*create_vars=*/true, /*keep_kids=*/false);
// Step 4. copy outputs
details::CopyVarsFromScope(input_grad_vars, input_grad_var_names,
ctx.GetPlace(), &scope);
details::CopyVarsFromScope(param_grad_vars, param_grad_names,
ctx.GetPlace(), &scope);
}
};
} // namespace operators
} // namespace paddle
...@@ -621,6 +621,10 @@ void BindImperative(py::module *m_ptr) { ...@@ -621,6 +621,10 @@ void BindImperative(py::module *m_ptr) {
return self.MutableGradVar()->Get<framework::LoDTensor>(); return self.MutableGradVar()->Get<framework::LoDTensor>();
}, },
py::return_value_policy::reference) py::return_value_policy::reference)
.def("_set_grad_type",
[](imperative::VarBase &self, framework::proto::VarType::Type type) {
self.MutableGradVarBase()->SetType(type);
})
.def("_grad_ivar", .def("_grad_ivar",
[](const imperative::VarBase &self) { [](const imperative::VarBase &self) {
auto &grad_var = self.GradVarBase(); auto &grad_var = self.GradVarBase();
......
...@@ -989,7 +989,11 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -989,7 +989,11 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE_EQ(self.IsType<framework::ReaderHolder>(), true); PADDLE_ENFORCE_EQ(self.IsType<framework::ReaderHolder>(), true);
return self.GetMutable<framework::ReaderHolder>(); return self.GetMutable<framework::ReaderHolder>();
}, },
py::return_value_policy::reference); py::return_value_policy::reference)
.def("set_scope", [](Variable &self, Scope &scope) {
auto scope_vec = self.GetMutable<std::vector<framework::Scope *>>();
scope_vec->emplace_back(&scope);
});
BindReader(&m); BindReader(&m);
...@@ -1180,6 +1184,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1180,6 +1184,8 @@ All parameter, weight, gradient are variables in Paddle.
[]() { return std::string(framework::kEmptyVarName); }); []() { return std::string(framework::kEmptyVarName); });
m.def("grad_var_suffix", m.def("grad_var_suffix",
[]() { return std::string(framework::kGradVarSuffix); }); []() { return std::string(framework::kGradVarSuffix); });
m.def("loaded_var_suffix",
[]() { return std::string(framework::kLoadedVarSuffix); });
m.def_submodule( m.def_submodule(
"var_names", "var_names",
"The module will return special predefined variable name in Paddle") "The module will return special predefined variable name in Paddle")
......
...@@ -44,6 +44,9 @@ from .backward_strategy import * ...@@ -44,6 +44,9 @@ from .backward_strategy import *
from . import jit from . import jit
from .jit import * from .jit import *
from . import static_runner
from .static_runner import StaticModelRunner
__all__ = [] __all__ = []
__all__ += layers.__all__ __all__ += layers.__all__
__all__ += base.__all__ __all__ += base.__all__
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import logging
import numpy as np
import os
import six
from . import layers
from .. import core
from .. import framework
from .. import backward
from .base import switch_to_static_graph
from ... import compat as cpt
# Set Log level
logging.getLogger().setLevel(logging.WARNING)
# DESIGN IDEA: Add an special operator, execute static program inside operator.
#
# Op's Inputs:
# - the input variable of the user feed
# - the necessary parameters of the network
# Op's Outputs:
# - the output variable of fetch
#
# This op receives a complete program desc, internally creates scope
# and executor, executes this program. Key points:
#
# 1. Data Sharing:
# The varBase of the dynamic graph is not in the scope, so before the op
# executes the program internally, create persistent variables with the
# same name as feed, parameters, and fetch in the scope, and share the
# LoDTensor of the op input.
#
# 2. Forward and Backward Separation:
# Because the dynamic graph op performs the forward and backward separately,
# the forward program is used as the execution object of the forward op,
# and the reverse program is used as the execution object of the grad op.
class StaticModelRunner(layers.Layer):
"""
A Dynamic graph Layer for loading inference program and related parameters,
and then performing fine-tune training or inference.
The loaded program and parameters are saved by `fluid.io.save_inference_model`.
.. note::
**1. Dynamic graph mode do not support LoDTensor.
All original static graph model's feed targets or parametars
that depend on LoD are temporarily unavailable.**
**2. All saved inference model's feed targets need be given.**
**3. The ``stop_gradient`` information is lost and can not be recovered.**
**4. The parameter's ``trainable`` information is lost and can not be recovered.**
**5. Double gradient model is not supported now.**
**6. Now only supports loading models saved by `fluid.io.save_inference_model`.**
Args:
model_dir(str): The directory path where the model is saved.
model_filename(str, optional): The file name of saved inference program.
If set to None, a default filename is
:code:`__model__`.
The default value is None.
params_filename(str, optional): The file name of saved all related parameters.
If set to None, parameters are saved
in separate files.
The default value is None.
Returns:
Layer: A Layer object can run loaded program.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
BATCH_SIZE = 32
BATCH_NUM = 20
SAVE_DIRNAME = "fc.inference.model"
def random_batch_reader():
def _get_random_images_and_labels(image_shape, label_shape):
image = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return image, label
def __reader__():
for _ in range(BATCH_NUM):
batch_image, batch_label = _get_random_images_and_labels(
[BATCH_SIZE, 784], [BATCH_SIZE, 1])
yield batch_image, batch_label
return __reader__
def train_and_save_static_model(place):
img = fluid.data(name='img', shape=[None, 784], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
pred = fluid.layers.fc(input=img, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=pred, label=label)
avg_loss = fluid.layers.mean(loss)
optimizer = fluid.optimizer.SGD(learning_rate=0.001)
optimizer.minimize(avg_loss)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
loader = fluid.io.DataLoader.from_generator(
feed_list=[img, label], capacity=5, iterable=True)
loader.set_batch_generator(random_batch_reader(), places=place)
for data in loader():
exe.run(
fluid.default_main_program(),
feed=data,
fetch_list=[avg_loss])
# save model by fluid.io.save_inference_model
fluid.io.save_inference_model(
SAVE_DIRNAME, ["img"], [pred], exe)
# Step 1. train and save inference model in static graph mode
place = fluid.CPUPlace()
train_and_save_static_model(place)
# Step 2. load inference model in dygraph and fine-tune
with fluid.dygraph.guard(place):
fc = fluid.dygraph.static_runner.StaticModelRunner(SAVE_DIRNAME)
sgd = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=fc.parameters())
train_loader = fluid.io.DataLoader.from_generator(capacity=5)
train_loader.set_batch_generator(
random_batch_reader(), places=place)
for data in train_loader():
img = data[0]
label = data[1]
label.stop_gradient = True
cost = fc(inputs=img)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
sgd.minimize(avg_loss)
"""
def __init__(self, model_dir, model_filename=None, params_filename=None):
super(StaticModelRunner, self).__init__()
# Step 0. key variable definitions
self._load_program_desc = None
self._program_desc = None
self._inner_scope = core.Scope()
# the layer outputs var desc
self._output_descs = []
# input, output, params name list
self._input_names = []
self._output_names = []
self._param_names = []
# train or eval flag
self._is_test = False
# Step 1. load program desc from disk
# the saved model hold feed, fetch & scale op, no need, can be remove
self._load_program_desc = self._load_static_model(model_dir,
model_filename)
# Step 2. set all `is_test` attributes to False
self._change_is_test_status(False)
# Step 3. load all parameters
self._load_persisitable_dict(model_dir, params_filename)
# Step 4. generate backwar program desc
self._program_desc = self._append_backward_desc()
# Step 5. recheck parameters stop gradients
self._recheck_stop_gradients()
def train(self):
# TODO: remove global train_mode setting
framework._dygraph_tracer().train_mode()
self._is_test = False
self._change_is_test_status(False)
def eval(self):
# TODO: remove global train_mode setting
framework._dygraph_tracer().eval_mode()
self._is_test = True
self._change_is_test_status(True)
def forward(self, inputs):
"""
Executed forward part of StaticModelRunner Layer.
Generally execute directly using the Layer object.
Args:
inputs(np.ndarray|Variable|list[np.ndarray|Variable]): the inputs of StaticModelRunner
Returns:
Variable|list[Variable]: The forward outputs of StaticModelRunner Layer.
"""
# Step 1. prepare inputs, outputs, attrs
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
input_vars = []
for i, value in enumerate(inputs):
if not isinstance(value, (np.ndarray, core.VarBase)):
raise TypeError(
"The type of inputs.value in StaticModelRunner.forward must be numpy array or Variable(VarBase), but received %s."
% type(value))
# NOTE: In order to unify the API, firstly convert the input to VarBase
if isinstance(value, np.ndarray):
var = core.VarBase(
value=value,
name=self._input_names[i],
persistable=False,
place=framework._current_expected_place(),
zero_copy=True)
else:
var = value
# TODO: here may have important name set by user
var.name = self._input_names[i]
input_vars.append(var)
params = []
for param in self._parameters.values():
params.append(param)
output_vars = []
for var_desc in self._output_descs:
var = core.VarBase(var_desc.dtype(),
var_desc.shape(),
var_desc.name(), var_desc.type(), False)
output_vars.append(var)
# hold forward variables
tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
"program_out_scope",
core.VarDesc.VarType.STEP_SCOPES, True)
tmp_scope_vec.value().set_scope(self._inner_scope)
# Step 2. run prorgam by op
framework._dygraph_tracer().trace_op(
type='run_program',
inputs={'X': input_vars,
'Params': params},
outputs={'Out': output_vars,
'OutScope': tmp_scope_vec},
attrs={
'global_block': self._program_desc.block(0),
'start_op_index': 0,
'end_op_index': self._load_program_desc.block(0).op_size(),
'is_test': self._is_test
})
# NOTE: [ why need set param's gradient type here ]
# if user set sparse gradient mode, the param's gradient
# will be SelectedRows, not LoDTensor. But tracer will just
# set param grad VarBase by forward VarBase(LoDTensor)
# If we don't change grad_var type here, RunProgramOp need
# transform SelectedRows to LoDTensor forcely, it may not
# be user wanted result.
for param in params:
grad_name = param.name + core.grad_var_suffix()
grad_var = self._program_desc.block(0).find_var(
cpt.to_bytes(grad_name))
# NOTE: cannot find var desc maybe no problem, such as in batch_norm
if grad_var is None:
continue
param._set_grad_type(grad_var.type())
# Step 3. prepare output, keep same form with inputs
outs = output_vars
if len(output_vars) == 1:
outs = output_vars[0]
return outs
def _load_static_model(self, model_dir, model_filename=None):
# Step 1. dir and filename check
load_dirname = os.path.normpath(model_dir)
if not os.path.isdir(load_dirname):
raise ValueError("There is no directory named '%s'" % load_dirname)
if model_filename is not None:
model_filename = os.path.basename(model_filename)
else:
model_filename = "__model__"
model_filename = os.path.join(load_dirname, model_filename)
# Step 2. parse program desc
with open(model_filename, "rb") as f:
program_desc_str = f.read()
program_desc = core.ProgramDesc(program_desc_str)
if not core._is_program_version_supported(program_desc._version()):
raise ValueError("Unsupported program version: %d\n" %
program_desc._version())
# Step 3.
# - remove feed, fetch and useless scale-1 op
# - remove op_callstack attr
ops_to_remove = []
root_block = program_desc.block(0)
for i in six.moves.range(root_block.op_size()):
op = root_block.op(i)
if op.type() == 'feed':
ops_to_remove.append(i)
feed_var_name = cpt.to_bytes(op.input('X')[0])
root_block._remove_var(feed_var_name)
self._input_names.append(cpt.to_bytes(op.output('Out')[0]))
elif op.type() == 'scale' and op.output('Out')[0].startswith(
'save_infer_model/scale_'):
ops_to_remove.append(i)
out_var_name = cpt.to_bytes(op.output('Out')[0])
root_block._remove_var(out_var_name)
self._output_names.append(cpt.to_bytes(op.input('X')[0]))
self._output_descs.append(
root_block.find_var(cpt.to_bytes(op.input('X')[0])))
elif op.type() == 'fetch' and op.input('X')[0].startswith(
'save_infer_model/scale_'):
ops_to_remove.append(i)
fetch_var_name = cpt.to_bytes(op.output('Out')[0])
root_block._remove_var(fetch_var_name)
else:
if op.has_attr("op_callstack"):
op.remove_attr("op_callstack")
for op_idx in reversed(ops_to_remove):
root_block._remove_op(op_idx, op_idx + 1)
return program_desc
@switch_to_static_graph
def _append_backward_desc(self):
assert self._load_program_desc is not None, "The StaticModelRunner not initialized properly."
program_desc_copy = core.ProgramDesc(self._load_program_desc)
# Step 1. prepare program and related var
# NOTE: To reuse backward interfaces, build Program firstly.
# Originally, there is no need to build a program, but need to almost
# rewrite a series of methods for append_backward for program_desc.
# Therefore, in order to reuse the method of backward.py, build the program here.
fwd_op_num = program_desc_copy.block(0).op_size()
program = self._build_program_by_desc(program_desc_copy)
# TODO: could the targets be in sub block?
targets = []
for out in self._output_descs:
targets.append(program.global_block().var(out.name()))
# Step 2. append backward
backward.gradients(targets=targets, inputs=[])
return program.desc
def _load_persisitable_dict(self, model_dir, params_filename=None):
load_dirname = os.path.normpath(model_dir)
assert self._load_program_desc is not None, "The StaticModelRunner not initialized properly."
persis_vars = self._get_persis_vars(self._load_program_desc)
load_var_map = {}
for each_var in persis_vars:
orig_each_name = each_var.name()
# append suffix
self._append_loaded_suffix_to_param(each_var)
# create output varbase
new_var = framework.ParamBase(
shape=each_var.shape(),
dtype=each_var.dtype(),
name=each_var.name(),
type=each_var.type(),
persistable=True)
if params_filename is None:
if not self._is_parameter(each_var):
continue
# logging.info("persis var name %s" % each_var.name())
framework._dygraph_tracer().trace_op(
type='load',
inputs={},
outputs={'Out': new_var},
attrs={
'file_path': os.path.join(load_dirname, orig_each_name)
})
new_var.stop_gradient = False
self.add_parameter(name=new_var.name, parameter=new_var)
self._param_names.append(new_var.name)
else:
load_var_map[each_var.name()] = new_var
if params_filename is not None:
load_var_list = []
for name in sorted(load_var_map.keys()):
load_var_list.append(load_var_map[name])
framework._dygraph_tracer().trace_op(
type='load_combine',
inputs={},
outputs={'Out': load_var_list},
attrs={
'file_path': os.path.join(load_dirname, params_filename)
})
for each_var in persis_vars:
if not self._is_parameter(each_var):
continue
param = load_var_map[each_var.name()]
param.stop_gradient = False
self.add_parameter(name=param.name, parameter=param)
self._param_names.append(param.name)
def _recheck_stop_gradients(self):
assert self._program_desc is not None, "The StaticModelRunner not initialized properly."
# NOTE: After loading the model, the stop_gradient information
# of the original variable is lost, but if a parameter does not
# have a corresponding @GRAD variable in the backward program,
# it can be said that it is also stop_gradient
all_var_names = self._get_all_var_names(self._program_desc)
for param_name in self._parameters:
param_grad_name = param_name + core.grad_var_suffix()
if param_grad_name not in all_var_names:
logging.info("set %s stop gradient = True" % param_grad_name)
self._parameters[param_name].stop_gradient = True
def _get_all_var_names(self, program_desc):
all_var_names = set()
for i in six.moves.range(program_desc.num_blocks()):
block = program_desc.block(i)
for var in block.all_vars():
logging.info(var.name())
all_var_names.add(var.name())
return all_var_names
def _get_persis_vars(self, program_desc):
persis_vars = []
for i in six.moves.range(program_desc.num_blocks()):
block = program_desc.block(i)
persis_vars.extend(
list(filter(self._is_persistable, block.all_vars())))
return persis_vars
@switch_to_static_graph
def _build_program_by_desc(self, program_desc):
prog = framework.Program()
prog.desc = program_desc
prog.blocks = [
framework.Block(prog, i)
for i in six.moves.range(prog.desc.num_blocks())
]
prog._sync_with_cpp()
return prog
def _is_persistable(self, var_desc):
if var_desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var_desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var_desc.type() == core.VarDesc.VarType.READER or \
var_desc.type() == core.VarDesc.VarType.RAW:
return False
return var_desc.persistable()
def _is_parameter(self, persis_var_desc):
assert self._load_program_desc is not None, "The StaticModelRunner not initialized properly."
# 1. firstly, param should be input of op
input_ops = [] # op can be repeated
for block_idx in six.moves.range(self._load_program_desc.num_blocks()):
block = self._load_program_desc.block(block_idx)
for op_idx in six.moves.range(block.op_size()):
op = block.op(op_idx)
# NOTE: parameter is the input of a certain op
if persis_var_desc.name() in op.input_arg_names():
input_ops.append(op)
# 2. secondly, param should not be output of op or be same op's output
for block_idx in six.moves.range(self._load_program_desc.num_blocks()):
block = self._load_program_desc.block(block_idx)
for op_idx in six.moves.range(block.op_size()):
op = block.op(op_idx)
if persis_var_desc.name() in op.output_arg_names():
# such as batch_norm_op
if op in input_ops:
continue
else:
return False
return True
def _change_is_test_status(self, is_test):
# change all `is_test` attributes
assert self._load_program_desc is not None, "The StaticModelRunner not initialized properly."
for i in six.moves.range(self._load_program_desc.num_blocks()):
block = self._load_program_desc.block(i)
for j in six.moves.range(block.op_size()):
op = block.op(j)
if op.has_attr('is_test'):
op._set_attr('is_test', is_test)
def _append_loaded_suffix(self, name):
"""
Append grad suffix to the given variable name
e.g. x ==> x@LOADED
"""
suffix = core.loaded_var_suffix()
name = cpt.to_text(name)
if suffix not in name:
name = name + suffix
return name
def _append_loaded_suffix_to_param(self, param_desc):
old_name = param_desc.name()
new_name = self._append_loaded_suffix(param_desc.name())
param_desc.set_name(new_name)
for block_idx in six.moves.range(self._load_program_desc.num_blocks()):
block = self._load_program_desc.block(block_idx)
for op_idx in six.moves.range(block.op_size()):
op = block.op(op_idx)
op._rename_input(old_name, new_name)
op._rename_output(old_name, new_name)
...@@ -329,12 +329,12 @@ def _fetch_var(name, scope=None, return_numpy=True): ...@@ -329,12 +329,12 @@ def _fetch_var(name, scope=None, return_numpy=True):
Returns: Returns:
LodTensor|numpy.ndarray LodTensor|numpy.ndarray
""" """
assert isinstance(name, str) assert isinstance(name, six.string_types)
if scope is None: if scope is None:
scope = global_scope() scope = global_scope()
assert isinstance(scope, core._Scope) assert isinstance(scope, core._Scope)
var = scope.find_var(name) var = scope.find_var(_to_name_str(name))
assert var is not None, ( assert var is not None, (
"Cannot find " + name + " in scope. Perhaps you need to make the" "Cannot find " + name + " in scope. Perhaps you need to make the"
" variable persistable by using var.persistable = True in your" " variable persistable by using var.persistable = True in your"
......
...@@ -124,11 +124,6 @@ class OpDescCreationMethod(object): ...@@ -124,11 +124,6 @@ class OpDescCreationMethod(object):
new_attr.bools.extend(user_defined_attr) new_attr.bools.extend(user_defined_attr)
elif attr.type == framework_pb2.LONGS: elif attr.type == framework_pb2.LONGS:
new_attr.longs.extend(user_defined_attr) new_attr.longs.extend(user_defined_attr)
elif attr.type == framework_pb2.INT_PAIRS:
for p in user_defined_attr:
pair = new_attr.int_pairs.add()
pair.first = p[0]
pair.second = p[1]
else: else:
raise NotImplementedError( raise NotImplementedError(
"A not supported attribute type: %s." % ( "A not supported attribute type: %s." % (
......
...@@ -193,6 +193,8 @@ list(REMOVE_ITEM TEST_OPS test_basic_lstm_api) ...@@ -193,6 +193,8 @@ list(REMOVE_ITEM TEST_OPS test_basic_lstm_api)
list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op) list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op)
list(REMOVE_ITEM TEST_OPS test_imperative_debug_string) list(REMOVE_ITEM TEST_OPS test_imperative_debug_string)
list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass) list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass)
list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_mnist)
list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_while)
if (APPLE OR WIN32) if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_dataset) list(REMOVE_ITEM TEST_OPS test_dataset)
...@@ -269,6 +271,10 @@ py_test_modules(test_install_check MODULES test_install_check ENVS ...@@ -269,6 +271,10 @@ py_test_modules(test_install_check MODULES test_install_check ENVS
FLAGS_cudnn_deterministic=1 SERIAL) FLAGS_cudnn_deterministic=1 SERIAL)
set_tests_properties(test_install_check PROPERTIES LABELS "RUN_TYPE=DIST") set_tests_properties(test_install_check PROPERTIES LABELS "RUN_TYPE=DIST")
py_test_modules(test_imperative_debug_string MODULES test_imperative_debug_string ENVS FLAGS_dygraph_debug=1) py_test_modules(test_imperative_debug_string MODULES test_imperative_debug_string ENVS FLAGS_dygraph_debug=1)
py_test_modules(test_imperative_static_runner_mnist MODULES test_imperative_static_runner_mnist ENVS
FLAGS_cudnn_deterministic=1)
py_test_modules(test_imperative_static_runner_while MODULES test_imperative_static_runner_while ENVS
FLAGS_cudnn_deterministic=1)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
# FIXME(typhoonzero): add these tests back # FIXME(typhoonzero): add these tests back
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transformer") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transformer")
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import contextlib
import numpy as np
import six
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from test_imperative_base import new_program_scope
def convolutional_neural_network(img):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu")
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
act="relu")
prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
return prediction
def static_train_net(img, label):
prediction = convolutional_neural_network(img)
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
optimizer = fluid.optimizer.SGD(learning_rate=0.001)
optimizer.minimize(avg_loss)
return prediction, avg_loss
class TestImperativeStaticModelRunnerMnist(unittest.TestCase):
def setUp(self):
self.seed = 90
self.epoch_num = 1
self.batch_size = 128
self.batch_num = 50
def reader_decorator(self, reader):
def _reader_impl():
for item in reader():
image = np.array(item[0]).reshape(1, 28, 28)
label = np.array(item[1]).astype('int64').reshape(1)
yield image, label
return _reader_impl
def train_and_save_model(self):
startup_program = fluid.default_startup_program()
main_program = fluid.default_main_program()
img = fluid.data(name='img', shape=[None, 1, 28, 28], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
prediction, avg_loss = static_train_net(img, label)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
exe.run(startup_program)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=100),
batch_size=self.batch_size)
for _ in range(0, self.epoch_num):
for batch_id, data in enumerate(train_reader()):
exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[avg_loss])
if batch_id > self.batch_num:
break
fluid.io.save_inference_model(
self.save_dirname, ["img"], [prediction],
exe,
model_filename=self.model_filename,
params_filename=self.params_filename)
def load_and_train_dygraph(self):
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = self.seed
fluid.default_main_program().random_seed = self.seed
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
mnist = fluid.dygraph.static_runner.StaticModelRunner(
model_dir=self.save_dirname,
model_filename=self.model_filename,
params_filename=self.params_filename)
dy_param_init_value = {}
for param in mnist.parameters():
dy_param_init_value[param.name] = param.numpy()
sgd = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=mnist.parameters())
train_reader = paddle.batch(
self.reader_decorator(paddle.dataset.mnist.train()),
batch_size=self.batch_size,
drop_last=True)
train_loader = fluid.io.DataLoader.from_generator(capacity=10)
train_loader.set_sample_list_generator(train_reader, places=place)
mnist.train()
for epoch in range(self.epoch_num):
for batch_id, data in enumerate(train_loader()):
img = data[0]
label = data[1]
label.stop_gradient = True
cost = mnist(inputs=img)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward(backward_strategy)
sgd.minimize(avg_loss)
mnist.clear_gradients()
if batch_id >= self.batch_num:
break
dy_x_data = img.numpy()
dy_out = avg_loss.numpy()
dy_param_value = {}
for param in mnist.parameters():
dy_param_value[param.name] = param.numpy()
return dy_x_data, dy_out, dy_param_init_value, dy_param_value
def load_and_train_static(self):
with new_program_scope():
fluid.default_startup_program().random_seed = self.seed
fluid.default_main_program().random_seed = self.seed
img = fluid.data(
name='img', shape=[None, 1, 28, 28], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
prediction, avg_loss = static_train_net(img, label)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_params(
exe,
self.save_dirname,
main_program=fluid.default_main_program(),
filename=self.params_filename)
static_param_init_value = {}
static_param_name_list = []
for param in fluid.default_main_program().all_parameters():
static_param_name_list.append(param.name)
static_param_init_value[param.name] = fluid.executor._fetch_var(
param.name)
train_reader = paddle.batch(
self.reader_decorator(paddle.dataset.mnist.train()),
batch_size=self.batch_size,
drop_last=True)
for epoch in range(self.epoch_num):
for batch_id, data in enumerate(train_reader()):
static_x_data = np.array([x[0] for x in data])
y_data = np.array([x[1] for x in data]).reshape(
[self.batch_size, 1])
fetch_list = [avg_loss.name]
fetch_list.extend(static_param_name_list)
out = exe.run(fluid.default_main_program(),
feed={"img": static_x_data,
"label": y_data},
fetch_list=fetch_list)
if batch_id >= self.batch_num:
break
static_param_value = {}
static_out = out[0]
for i in range(1, len(out)):
static_param_value[static_param_name_list[i - 1]] = out[i]
return static_x_data, static_out, static_param_init_value, static_param_value
def test_mnist_no_params_filename(self):
self.save_dirname = "mnist.inference.model.noname"
self.model_filename = None
self.params_filename = None
# Phase 1. run and save static model
self.train_and_save_model()
# Phase 2. load model & train dygraph
dy_x_data, dy_out, dy_param_init_value, dy_param_value = \
self.load_and_train_dygraph()
static_x_data, static_out, static_param_init_value, static_param_value = \
self.load_and_train_static()
# Phase 3. compare
self.assertTrue(np.array_equal(static_x_data, dy_x_data))
for key, value in six.iteritems(static_param_init_value):
key += core.loaded_var_suffix()
self.assertTrue(np.array_equal(value, dy_param_init_value[key]))
self.assertTrue(np.allclose(static_out, dy_out))
for key, value in six.iteritems(static_param_value):
key += core.loaded_var_suffix()
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
def test_mnist_with_params_filename(self):
self.save_dirname = "mnist.inference.model"
self.model_filename = "mnist.model"
self.params_filename = "mnist.params"
# Phase 1. run and save static model
self.train_and_save_model()
# Phase 2. load model & train dygraph
dy_x_data, dy_out, dy_param_init_value, dy_param_value = \
self.load_and_train_dygraph()
static_x_data, static_out, static_param_init_value, static_param_value = \
self.load_and_train_static()
# Phase 3. compare
self.assertTrue(np.array_equal(static_x_data, dy_x_data))
for key, value in six.iteritems(static_param_init_value):
key += core.loaded_var_suffix()
self.assertTrue(np.array_equal(value, dy_param_init_value[key]))
self.assertTrue(np.allclose(static_out, dy_out))
for key, value in six.iteritems(static_param_value):
key += core.loaded_var_suffix()
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import contextlib
import numpy as np
import six
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from test_imperative_base import new_program_scope
import paddle.fluid.transpiler.details.program_utils as pu
def while_softmax_regression(img):
def cond(i, times, pred):
return i < times
def body(i, times, pred):
pred = fluid.layers.fc(input=pred, size=10, act='softmax')
i = i + 1
return [i, times, pred]
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
times = fluid.layers.fill_constant(shape=[1], dtype='int64', value=5)
pred = fluid.layers.fc(input=img, size=10, act='softmax')
i, times, pred = fluid.layers.while_loop(
cond=cond, body=body, loop_vars=[i, times, pred])
return pred
class TestImperativeStaticModelRunnerWhile(unittest.TestCase):
def setUp(self):
self.seed = 90
self.batch_size = 32
self.batch_num = 50
self.save_dirname = "while.inference.model"
self.model_filename = None
self.params_filename = None
def _random_batch_reader(self):
def _get_random_images_and_labels(image_shape, label_shape):
image = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return image, label
def __reader__():
for _ in range(self.batch_num):
batch_image, batch_label = _get_random_images_and_labels(
[self.batch_size, 784], [self.batch_size, 1])
yield batch_image, batch_label
return __reader__
def train_and_save_model(self):
startup_program = fluid.default_startup_program()
main_program = fluid.default_main_program()
img = fluid.data(name='img', shape=[None, 784], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
pred = while_softmax_regression(img)
loss = fluid.layers.cross_entropy(input=pred, label=label)
avg_loss = fluid.layers.mean(loss)
optimizer = fluid.optimizer.SGD(learning_rate=0.001)
optimizer.minimize(avg_loss)
# pu.program_to_code(main_program, skip_op_callstack=True)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
loader = fluid.io.DataLoader.from_generator(
feed_list=[img, label], capacity=5, iterable=True)
loader.set_batch_generator(self._random_batch_reader(), places=place)
for data in loader():
exe.run(main_program, feed=data, fetch_list=[avg_loss])
fluid.io.save_inference_model(
self.save_dirname, ["img"], [pred],
exe,
model_filename=self.model_filename,
params_filename=self.params_filename)
def load_and_train_dygraph(self):
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.dygraph.guard(place):
fluid.default_startup_program().random_seed = self.seed
fluid.default_main_program().random_seed = self.seed
np.random.seed(self.seed)
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
while_net = fluid.dygraph.static_runner.StaticModelRunner(
self.save_dirname)
dy_param_init_value = {}
for param in while_net.parameters():
dy_param_init_value[param.name] = param.numpy()
sgd = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=while_net.parameters())
train_loader = fluid.io.DataLoader.from_generator(capacity=10)
train_loader.set_batch_generator(
self._random_batch_reader(), places=place)
while_net.train()
for data in train_loader():
img = data[0]
label = data[1]
label.stop_gradient = True
cost = while_net(inputs=img)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward(backward_strategy)
sgd.minimize(avg_loss)
while_net.clear_gradients()
dy_out = avg_loss.numpy()
dy_param_value = {}
for param in while_net.parameters():
dy_param_value[param.name] = param.numpy()
return dy_out, dy_param_init_value, dy_param_value
def load_and_train_static(self):
with new_program_scope():
fluid.default_startup_program().random_seed = self.seed
fluid.default_main_program().random_seed = self.seed
np.random.seed(self.seed)
img = fluid.data(name='img', shape=[None, 784], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
pred = while_softmax_regression(img)
loss = fluid.layers.cross_entropy(input=pred, label=label)
avg_loss = fluid.layers.mean(loss)
optimizer = fluid.optimizer.SGD(learning_rate=0.001)
optimizer.minimize(avg_loss)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_params(
exe,
self.save_dirname,
main_program=fluid.default_main_program(),
filename=self.params_filename)
static_param_init_value = {}
static_param_name_list = []
for param in fluid.default_main_program().all_parameters():
static_param_name_list.append(param.name)
static_param_init_value[param.name] = fluid.executor._fetch_var(
param.name)
loader = fluid.io.DataLoader.from_generator(
feed_list=[img, label], capacity=5, iterable=True)
loader.set_batch_generator(
self._random_batch_reader(), places=place)
for data in loader():
fetch_list = [avg_loss.name]
fetch_list.extend(static_param_name_list)
out = exe.run(fluid.default_main_program(),
feed=data,
fetch_list=[avg_loss])
static_param_value = {}
static_out = out[0]
for i in range(1, len(out)):
static_param_value[static_param_name_list[i - 1]] = out[i]
return static_out, static_param_init_value, static_param_value
def test_while_no_params_filename(self):
# Phase 1. run and save static model
self.train_and_save_model()
# # Phase 2. load model & train dygraph
dy_out, dy_param_init_value, dy_param_value = \
self.load_and_train_dygraph()
static_out, static_param_init_value, static_param_value = \
self.load_and_train_static()
# Phase 3. compare
for key, value in six.iteritems(static_param_init_value):
key += core.loaded_var_suffix()
self.assertTrue(np.array_equal(value, dy_param_init_value[key]))
self.assertTrue(np.allclose(static_out, dy_out))
for key, value in six.iteritems(static_param_value):
key += core.loaded_var_suffix()
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import contextlib
import unittest
import numpy as np
import six
import paddle.fluid as fluid
from paddle import compat as cpt
from paddle.fluid import core, framework, executor
@contextlib.contextmanager
def program_scope_guard():
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
with fluid.unique_name.guard():
yield
# NOTE: Because RunProgramOp has a special output of type std::vector<Scope *>,
# the OpTest cannot be used in RunProgramOp. The variable type cannot be specified
# when creating output variables in OpTest, default type is LoDTensor
# NOTE: the gradient test method in OpTest also cannot be used for RunProgramOp,
# because it hold BlockDesc type attr, OperatorFactory can't parse this attr type
# when create Operator, so here compare gradients with static graph
# NOTE: Here rewrite a simple unittest framework for RunProgramOp
class RunProgramOpTest(unittest.TestCase):
def build_model(self):
raise NotImplementedError(
"RunProgramOp test should implement build_model")
def check_output(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
# TODO: RunProgramOp is not recommended for use in static mode now
self.expect_outs = self.run_static_model(place, is_test=True)
self.check_output_with_place(place)
def check_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
# TODO: RunProgramOp is not recommended for use in static mode now
self.expect_grads = self.run_static_model(place, is_test=False)
self.check_grad_with_place(place)
def run_static_model(self, place, is_test=True):
with program_scope_guard():
startup_program = fluid.default_startup_program()
main_program = fluid.default_main_program()
self.build_model()
exe = fluid.Executor(place)
exe.run(startup_program)
if is_test:
fetch_list = self.output_names['Out']
else:
fetch_list = self.get_param_grad_names()
outs = exe.run(main_program,
feed=self.inputs['X'],
fetch_list=fetch_list)
return outs
def get_program_desc(self):
with program_scope_guard():
fwd_op_num = self.build_model()
return fluid.default_main_program().desc, fwd_op_num
def prepare_attrs(self):
return {
'global_block': self.program_desc.block(0),
'start_op_index': 0,
'end_op_index': self.fwd_op_num
}
def get_param_grad_names(self):
grad_names = []
for var_name in self.inputs['Params']:
grad_names.append(var_name + core.grad_var_suffix())
return grad_names
def check_output_with_place(self, place):
# Step 1. run op
actual_outs = self.calc_dygraph_output(place)
# Step 2. compare output
for expect_v, actual_v in six.moves.zip(self.expect_outs, actual_outs):
self.assertTrue(np.allclose(expect_v, actual_v.numpy(), atol=1e-5))
def check_grad_with_place(self, place):
# Step 1. calc grads
actual_grads = self.calc_dygraph_grad(place)
# Step 2. compare grads
for expect_v, actual_v in six.moves.zip(self.expect_grads,
actual_grads):
np.testing.assert_array_almost_equal(expect_v, actual_v)
self.assertTrue(np.allclose(expect_v, actual_v, atol=1e-5))
def prepare_dygraph_input(self, place, return_param_list=False):
def create_var_base(is_input, name, np_value, stop_gradient):
var = core.VarBase(
value=np_value, name=name, place=place, zero_copy=True)
var.stop_gradient = stop_gradient
return var
# build inputs
inputs = {}
param_list = []
inputs['X'] = []
for name, np_value in self.inputs['X'].items():
var = create_var_base(True, name, np_value, True)
inputs['X'].append(var)
inputs['Params'] = []
for name, np_value in self.inputs['Params'].items():
var = create_var_base(True, name, np_value, False)
inputs['Params'].append(var)
if return_param_list:
param_list.append(var)
if return_param_list:
return inputs, param_list
return inputs
def prepare_dygraph_output(self):
def create_var_base(is_input, name):
var = framework._varbase_creator(dtype=None, shape=None, name=name)
var.stop_gradient = False
return var
# build outputs
outputs = {}
outputs['Out'] = []
for name in self.output_names['Out']:
outputs['Out'].append(create_var_base(False, name))
outputs['OutScope'] = framework._varbase_creator(
type=core.VarDesc.VarType.STEP_SCOPES,
name="program_out_scope",
persistable=True)
inner_scope = core.Scope()
outputs['OutScope'].value().set_scope(inner_scope)
return outputs
def calc_dygraph_output(self, place):
with fluid.dygraph.guard(place):
inputs = self.prepare_dygraph_input(place)
outputs = self.prepare_dygraph_output()
framework._dygraph_tracer().trace_op(
type=self.op_type,
inputs=inputs,
outputs=outputs,
attrs=self.attrs)
return outputs['Out']
def calc_dygraph_grad(self, place):
with fluid.dygraph.guard(place):
# Step 1. run forward
inputs, input_param_list = self.prepare_dygraph_input(place, True)
outputs = self.prepare_dygraph_output()
framework._dygraph_tracer().trace_op(
type=self.op_type,
inputs=inputs,
outputs=outputs,
attrs=self.attrs)
for param in input_param_list:
var_type = self._get_grad_vartype(param.name)
if var_type is None:
continue
param._set_grad_type(var_type)
# Step 2. run backward
# NOTE: in unittest, only support single output now
actual_outs = outputs['Out']
assert len(actual_outs) == 1
actual_outs[0].backward()
# Step 3. prepare grads
grads = []
for param in input_param_list:
grad = param.gradient()
grads.append(grad)
return grads
def _get_grad_vartype(self, name):
assert self.program_desc is not None
grad_name = name + core.grad_var_suffix()
for i in six.moves.range(self.program_desc.num_blocks()):
block = self.program_desc.block(i)
var_desc = block.find_var_recursive(cpt.to_bytes(grad_name))
return var_desc.type() if var_desc is not None else None
class TestRunProgramOpWithFC(RunProgramOpTest):
def setUp(self):
self.op_type = "run_program"
self.dtype = np.float32
self.input_names = {
'X': ['img'],
'Params': ['weight_param', 'bias_param']
}
self.output_names = {'Out': ['fc_0.tmp_2']}
self.inputs = {
'X': {
self.input_names['X'][0]: np.random.random((32, 1, 28, 28))
.astype(self.dtype)
},
'Params': {
self.input_names['Params'][0]: np.random.random(
(784, 10)).astype(self.dtype),
self.input_names['Params'][1]: np.random.random(
(32, 10)).astype(self.dtype)
}
}
self.program_desc, self.fwd_op_num = self.get_program_desc()
self.attrs = self.prepare_attrs()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad()
def build_model(self):
# 1. simple model
img = fluid.data(
name=self.input_names['X'][0],
shape=[None, 1, 28, 28],
dtype='float32')
weight_attr = fluid.ParamAttr(
name=self.input_names['Params'][0],
learning_rate=0.5,
initializer=fluid.initializer.NumpyArrayInitializer(self.inputs[
'Params'][self.input_names['Params'][0]]),
trainable=True)
bias_attr = fluid.ParamAttr(
name=self.input_names['Params'][1],
learning_rate=0.5,
initializer=fluid.initializer.NumpyArrayInitializer(self.inputs[
'Params'][self.input_names['Params'][1]]),
trainable=True)
pred = fluid.layers.fc(input=img,
size=10,
param_attr=weight_attr,
bias_attr=bias_attr,
act='relu')
# 2. get forward op num
fwd_op_num = fluid.default_main_program().global_block().desc.op_size()
# 3. append backward
grads = fluid.backward.gradients(targets=[pred], inputs=[img])
return fwd_op_num
class TestRunProgramOpWithEmbedding(RunProgramOpTest):
def setUp(self):
self.op_type = "run_program"
self.dtype = np.float32
self.input_names = {'X': ['x'], 'Params': ['emb_weight']}
self.output_names = {'Out': ['reduce_sum_0.tmp_0']}
self.inputs = {
'X': {
'x': np.array([[1, 3, 0, 4, 7]]).astype("int64")
},
'Params': {
'emb_weight': np.random.random(size=(10, 16)).astype("float32")
}
}
self.program_desc, self.fwd_op_num = self.get_program_desc()
self.attrs = self.prepare_attrs()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
# NOTE: fecth not support SelectedRows, catnot compare
# sparse gradients with staic mode, only run dygraph
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
# TODO: RunProgramOp is not recommended for use in static mode now
self.calc_dygraph_grad(place)
def build_model(self):
# 1. simple model
x = fluid.layers.data(
name=self.input_names['X'][0], shape=[5], dtype='int64')
emb = fluid.input.embedding(
input=x,
size=[10, 16],
param_attr=fluid.ParamAttr(
name="emb_weight",
learning_rate=10,
initializer=fluid.initializer.NumpyArrayInitializer(self.inputs[
'Params'][self.input_names['Params'][0]])),
is_sparse=True)
y = fluid.layers.reduce_sum(emb, dim=-1)
# 2. get forward op num
fwd_op_num = fluid.default_main_program().global_block().desc.op_size()
# 3. append backward
grads = fluid.backward.gradients(targets=[y], inputs=[x])
return fwd_op_num
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册