提交 18f0c40a 编写于 作者: Y Yang Yang(Tony) 提交者: Yu Yang

feature/while_grad_op (#5554)

* first commit

* Python API for while op

* Python Unittest for simple while_op forward

* fix out to be list

* Fix UT

* VarType

* Fix several bugs

* Fix bug

* Fix bug

* Fix Bug

* Fix bug

* Fix unittest

* Remove debug log

* Add comments

* add PADDLE_ENFORCE

* while_grad_op first commit

* Add `BlockDescBind::FindRecursiveOrCreateVar()` and fix bugs

* not sure how to setdim of while outputs

* push for test

* add executor vlog

* fix bug of while_op cond

* Several enhancement for code

1. Backward always infer shape & infer var type. Since there are RENAME
variables will be created when creating backward operator, but their
shape & var types are not inferenced.
2. Never use SomePtr-> directly, since every pointer could be nullptr if
it is a function return value. Add `detail::Ref` to cast pointer to
reference safely.
3. Enhance error message for backward.
4. Infer data type of variable in `sum` and `tensor_write`

* Fix bugs of while_op gradient

* Fix several bugs of while_op grad

* fix fill zeros like

* fix 3 >= 3

* fix place holder shouldn't be null

* fail on sum op

* Fix SumOp of TensorList

* clean up

* pass while test

* fix test_array_write_read

* pass sum op

* Support int/int64 for fill_constant_batch_size_like

* Fix compile
上级 08bc08d6
...@@ -270,6 +270,19 @@ static bool AllGradInSet(const std::vector<std::string>& names, ...@@ -270,6 +270,19 @@ static bool AllGradInSet(const std::vector<std::string>& names,
return false; return false;
} }
} }
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
sout << "All input {";
for (auto& name : names) {
sout << name << ",";
}
sout << "} is in {";
for (auto& name : set) {
sout << name << ",";
}
sout << "}";
VLOG(10) << sout.str();
}
return true; return true;
} }
...@@ -290,14 +303,12 @@ static void CreateGradVarInBlock( ...@@ -290,14 +303,12 @@ static void CreateGradVarInBlock(
auto ops = block_desc->AllOps(); auto ops = block_desc->AllOps();
for (size_t op_index = grad_op_start_index; op_index < ops.size(); for (size_t op_index = grad_op_start_index; op_index < ops.size();
++op_index) { ++op_index) {
bool need_infer_shape = false;
std::unordered_set<std::string> new_vars; std::unordered_set<std::string> new_vars;
ForEachVarName(ops[op_index]->Outputs(), ForEachVarName(ops[op_index]->Outputs(),
[&](const std::string& grad_var_name) { [&](const std::string& grad_var_name) {
if (block_desc->HasVar(grad_var_name)) { if (block_desc->HasVar(grad_var_name)) {
return false; return false;
} }
need_infer_shape = true;
auto var = block_desc->Var(grad_var_name); auto var = block_desc->Var(grad_var_name);
new_vars.insert(var->Name()); new_vars.insert(var->Name());
auto it = param_name_map.find(grad_var_name); auto it = param_name_map.find(grad_var_name);
...@@ -311,23 +322,21 @@ static void CreateGradVarInBlock( ...@@ -311,23 +322,21 @@ static void CreateGradVarInBlock(
grad_record.op_idx_ = static_cast<int>(op_index); grad_record.op_idx_ = static_cast<int>(op_index);
return false; /* not break */ return false; /* not break */
}); });
if (need_infer_shape) { ops[op_index]->InferVarType(block_desc);
ops[op_index]->InferVarType(block_desc); for (auto& arg : ops[op_index]->OutputArgumentNames()) {
for (auto& arg : ops[op_index]->OutputArgumentNames()) { if (new_vars.find(arg) == new_vars.end()) {
if (new_vars.find(arg) == new_vars.end()) { continue;
continue; }
} auto pname = FwdName(arg);
auto pname = FwdName(arg); auto* param = block_desc->FindVarRecursive(pname);
auto* param = block_desc->FindVarRecursive(pname); auto* grad = block_desc->FindVar(arg);
auto* grad = block_desc->FindVar(arg); if (param == nullptr) {
if (param == nullptr) { grad->SetDataType(DataType::FP32);
grad->SetDataType(DataType::FP32); } else {
} else { grad->SetDataType(param->GetDataType());
grad->SetDataType(param->GetDataType());
}
} }
ops[op_index]->InferShape(*block_desc);
} }
ops[op_index]->InferShape(*block_desc);
} }
} }
...@@ -387,6 +396,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -387,6 +396,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind& program_desc, int block_idx, ProgramDescBind& program_desc, int block_idx,
std::unordered_set<std::string>* no_grad_vars, std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) { std::unordered_map<std::string, std::string>* grad_to_var) {
VLOG(5) << "MakeBlockBackward";
BlockDescBind* cur_block = program_desc.MutableBlock(block_idx); BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
std::vector<OpDescBind*> op_descs = cur_block->AllOps(); std::vector<OpDescBind*> op_descs = cur_block->AllOps();
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops; std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
...@@ -394,9 +404,10 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -394,9 +404,10 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std::vector<std::unique_ptr<OpDescBind>> backward_descs; std::vector<std::unique_ptr<OpDescBind>> backward_descs;
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) { for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
VLOG(5) << "Making backward " << (*it)->Type() << " op";
std::vector<std::unique_ptr<OpDescBind>> op_grads; std::vector<std::unique_ptr<OpDescBind>> op_grads;
if ((*it)->Type() == "recurrent") { if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
int step_block_idx = (*it)->GetBlockAttr("step_block"); int step_block_idx = (*it)->GetBlockAttr("step_block");
BlockDescBind* backward_block = CreateStepBlock( BlockDescBind* backward_block = CreateStepBlock(
program_desc, no_grad_vars, grad_to_var, step_block_idx); program_desc, no_grad_vars, grad_to_var, step_block_idx);
...@@ -410,6 +421,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -410,6 +421,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
} }
if (VLOG_IS_ON(10)) {
std::ostringstream sout;
sout << "Made ";
for (auto& op_grad : op_grads) {
sout << op_grad->Type() << " ";
}
VLOG(10) << sout.str();
}
for (const auto& desc : op_grads) { for (const auto& desc : op_grads) {
for (const std::string& out_name : desc->OutputArgumentNames()) { for (const std::string& out_name : desc->OutputArgumentNames()) {
if (out_name.find("@GRAD") == std::string::npos) { if (out_name.find("@GRAD") == std::string::npos) {
...@@ -425,6 +445,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -425,6 +445,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs), op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs),
[](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); }); [](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); });
} }
VLOG(5) << "Appending Sums";
// Check whether some variables are written more than once // Check whether some variables are written more than once
std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops; std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
for (const auto& dup : dup_out_ops) { for (const auto& dup : dup_out_ops) {
...@@ -432,16 +454,22 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -432,16 +454,22 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
const std::vector<size_t> dup_op = dup.second; const std::vector<size_t> dup_op = dup.second;
if (out_name != kEmptyVarName && dup_op.size() > 1) { if (out_name != kEmptyVarName && dup_op.size() > 1) {
std::vector<std::string> sum_op_inputs; std::vector<std::string> sum_op_inputs;
std::string next_g_name = out_name;
for (size_t i = 0; i < dup_op.size(); ++i) { for (size_t i = 0; i < dup_op.size(); ++i) {
VLOG(10) << backward_descs[dup_op[i]]->Type() << " has " << out_name
<< " duplicated";
std::string new_name = out_name + "@RENAME@" + std::to_string(i); std::string new_name = out_name + "@RENAME@" + std::to_string(i);
backward_descs[dup_op[i]]->Rename(out_name, new_name); backward_descs[dup_op[i]]->RenameOutput(out_name, new_name);
backward_descs[dup_op[i]]->RenameInput(out_name, next_g_name);
sum_op_inputs.emplace_back(new_name); sum_op_inputs.emplace_back(new_name);
next_g_name = sum_op_inputs.back();
} }
std::unique_ptr<OpDescBind> sum_op(new OpDescBind( std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
"sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {})); "sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)}); pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
} }
} }
pending_sum_ops.sort( pending_sum_ops.sort(
[](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a, [](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a,
const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) { const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) {
...@@ -452,6 +480,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -452,6 +480,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std::move(p.second)); std::move(p.second));
} }
VLOG(5) << "MakeBlockBackward Finished";
return backward_descs; return backward_descs;
} }
......
...@@ -29,6 +29,8 @@ inline DataType ToDataType(std::type_index type) { ...@@ -29,6 +29,8 @@ inline DataType ToDataType(std::type_index type) {
return DataType::INT32; return DataType::INT32;
} else if (typeid(int64_t).hash_code() == type.hash_code()) { } else if (typeid(int64_t).hash_code() == type.hash_code()) {
return DataType::INT64; return DataType::INT64;
} else if (typeid(bool).hash_code() == type.hash_code()) {
return DataType::BOOL;
} else { } else {
PADDLE_THROW("Not supported"); PADDLE_THROW("Not supported");
} }
......
...@@ -60,8 +60,7 @@ void make_ddim(DDim& ddim, const int64_t* dims, int n) { ...@@ -60,8 +60,7 @@ void make_ddim(DDim& ddim, const int64_t* dims, int n) {
ddim = make_dim<9>(dims); ddim = make_dim<9>(dims);
break; break;
default: default:
throw std::invalid_argument( PADDLE_THROW("Dynamic dimensions must have between [1, 9] dimensions.");
"Dynamic dimensions must have between [1, 9] dimensions.");
} }
} }
......
...@@ -120,6 +120,7 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id, ...@@ -120,6 +120,7 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(10) << op->DebugString();
op->Run(*local_scope, *device); op->Run(*local_scope, *device);
} }
if (create_local_scope) { if (create_local_scope) {
......
...@@ -235,6 +235,23 @@ void OpDescBind::Rename(const std::string &old_name, ...@@ -235,6 +235,23 @@ void OpDescBind::Rename(const std::string &old_name,
need_update_ = true; need_update_ = true;
} }
void OpDescBind::RenameOutput(const std::string &old_name,
const std::string &new_name) {
for (auto &output : outputs_) {
std::replace(output.second.begin(), output.second.end(), old_name,
new_name);
}
need_update_ = true;
}
void OpDescBind::RenameInput(const std::string &old_name,
const std::string &new_name) {
for (auto &input : inputs_) {
std::replace(input.second.begin(), input.second.end(), old_name, new_name);
}
need_update_ = true;
}
struct SetAttrDescVisitor : public boost::static_visitor<void> { struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_; mutable OpDesc::Attr *attr_;
...@@ -448,7 +465,12 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs( ...@@ -448,7 +465,12 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const { DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name); auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
return framework::make_ddim(var->Shape()); try {
return framework::make_ddim(var->Shape());
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
std::rethrow_exception(std::current_exception());
}
} }
void CompileTimeInferShapeContext::SetDim(const std::string &name, void CompileTimeInferShapeContext::SetDim(const std::string &name,
......
...@@ -73,6 +73,10 @@ class OpDescBind { ...@@ -73,6 +73,10 @@ class OpDescBind {
void Rename(const std::string &old_name, const std::string &new_name); void Rename(const std::string &old_name, const std::string &new_name);
void RenameOutput(const std::string &old_name, const std::string &new_name);
void RenameInput(const std::string &old_name, const std::string &new_name);
// Only be used in C++ // Only be used in C++
const AttributeMap &GetAttrMap() const; const AttributeMap &GetAttrMap() const;
......
...@@ -403,19 +403,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -403,19 +403,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
void OperatorWithKernel::Run(const Scope& scope, void OperatorWithKernel::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const { const platform::DeviceContext& dev_ctx) const {
if (VLOG_IS_ON(1)) {
auto inputs = this->InputVars();
auto outputs = this->OutputVars(true);
std::ostringstream sout;
sout << "Run operator " << this->Type() << " From [";
std::ostream_iterator<std::string> out_it(sout, ",");
std::copy(inputs.begin(), inputs.end(), out_it);
sout << "] to [";
std::copy(outputs.begin(), outputs.end(), out_it);
sout << "]";
VLOG(1) << sout.str();
}
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
......
...@@ -38,11 +38,12 @@ Scope& Scope::NewScope() const { ...@@ -38,11 +38,12 @@ Scope& Scope::NewScope() const {
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
auto iter = vars_.find(name); auto iter = vars_.find(name);
if (iter != vars_.end()) { if (iter != vars_.end()) {
VLOG(3) << "Get existing variable " << name;
return iter->second; return iter->second;
} }
Variable* v = new Variable(); Variable* v = new Variable();
vars_[name] = v; vars_[name] = v;
VLOG(3) << "Create variable " << name << " on scope"; VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first); v->name_ = &(vars_.find(name)->first);
return v; return v;
} }
......
...@@ -53,6 +53,10 @@ class InferShapeContext { ...@@ -53,6 +53,10 @@ class InferShapeContext {
virtual bool IsRuntime() const = 0; virtual bool IsRuntime() const = 0;
// Note: In while op, we need this to be public
void SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims);
protected: protected:
virtual framework::DDim GetDim(const std::string &name) const = 0; virtual framework::DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0; virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;
...@@ -60,9 +64,6 @@ class InferShapeContext { ...@@ -60,9 +64,6 @@ class InferShapeContext {
std::vector<framework::DDim> GetDims( std::vector<framework::DDim> GetDims(
const std::vector<std::string> &names) const; const std::vector<std::string> &names) const;
void SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims);
std::vector<VarDesc::VarType> GetVarTypes( std::vector<VarDesc::VarType> GetVarTypes(
const std::vector<std::string> &names) const; const std::vector<std::string> &names) const;
......
...@@ -42,6 +42,7 @@ class ArrayOp : public framework::OperatorBase { ...@@ -42,6 +42,7 @@ class ArrayOp : public framework::OperatorBase {
} else { } else {
offset = static_cast<size_t>(*i_tensor.data<int64_t>()); offset = static_cast<size_t>(*i_tensor.data<int64_t>());
} }
VLOG(10) << " Offset = " << offset;
return offset; return offset;
} }
}; };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle {
namespace operators {
namespace detail {
/**
* Get Reference From Pointer with check. The error message is printf format,
* and passed by `args`
*/
template <typename T, typename... ARGS>
inline T &Ref(T *ptr, ARGS &&... args) {
PADDLE_ENFORCE(ptr != nullptr, args...);
return *ptr;
}
} // namespace detail
} // namespace operators
} // namespace paddle
...@@ -101,4 +101,7 @@ REGISTER_OPERATOR(fill_constant_batch_size_like, ...@@ -101,4 +101,7 @@ REGISTER_OPERATOR(fill_constant_batch_size_like,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_constant_batch_size_like, fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, float>, ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, double>); ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, double>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, int>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace,
int64_t>);
...@@ -19,4 +19,7 @@ namespace ops = paddle::operators; ...@@ -19,4 +19,7 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
fill_constant_batch_size_like, fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, float>, ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, double>); ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, double>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, int>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace,
int64_t>);
...@@ -54,5 +54,8 @@ namespace ops = paddle::operators; ...@@ -54,5 +54,8 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp, REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp,
ops::FillZerosLikeOpMaker); ops::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_zeros_like, fill_zeros_like, ops::FillZerosLikeKernel<paddle::platform::CPUPlace, int>,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>); ops::FillZerosLikeKernel<paddle::platform::CPUPlace, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, bool>);
...@@ -17,5 +17,8 @@ ...@@ -17,5 +17,8 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
fill_zeros_like, fill_zeros_like, ops::FillZerosLikeKernel<paddle::platform::GPUPlace, int>,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, float>); ops::FillZerosLikeKernel<paddle::platform::GPUPlace, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, float>,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, double>,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, bool>);
...@@ -250,6 +250,8 @@ void axpy<platform::CPUPlace, double>(const platform::DeviceContext& context, ...@@ -250,6 +250,8 @@ void axpy<platform::CPUPlace, double>(const platform::DeviceContext& context,
template struct SetConstant<platform::CPUPlace, float>; template struct SetConstant<platform::CPUPlace, float>;
template struct SetConstant<platform::CPUPlace, double>; template struct SetConstant<platform::CPUPlace, double>;
template struct SetConstant<platform::CPUPlace, int>; template struct SetConstant<platform::CPUPlace, int>;
template struct SetConstant<platform::CPUPlace, int64_t>;
template struct SetConstant<platform::CPUPlace, bool>;
#define DEFINE_CPU_TRANS(RANK) \ #define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUPlace, float, RANK>; \ template struct Transpose<platform::CPUPlace, float, RANK>; \
......
...@@ -256,6 +256,8 @@ void axpy<platform::GPUPlace, double>(const platform::DeviceContext& context, ...@@ -256,6 +256,8 @@ void axpy<platform::GPUPlace, double>(const platform::DeviceContext& context,
template struct SetConstant<platform::GPUPlace, float>; template struct SetConstant<platform::GPUPlace, float>;
template struct SetConstant<platform::GPUPlace, double>; template struct SetConstant<platform::GPUPlace, double>;
template struct SetConstant<platform::GPUPlace, int>; template struct SetConstant<platform::GPUPlace, int>;
template struct SetConstant<platform::GPUPlace, int64_t>;
template struct SetConstant<platform::GPUPlace, bool>;
#define DEFINE_GPU_TRANS(RANK) \ #define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::GPUPlace, float, RANK>; \ template struct Transpose<platform::GPUPlace, float, RANK>; \
......
...@@ -12,6 +12,7 @@ limitations under the License. */ ...@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/operators/sum_op.h" #include "paddle/operators/sum_op.h"
#include <vector> #include <vector>
#include "paddle/framework/var_type_inference.h" #include "paddle/framework/var_type_inference.h"
#include "paddle/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -59,13 +60,16 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -59,13 +60,16 @@ class SumOp : public framework::OperatorWithKernel {
x_vars[0]->Get<framework::SelectedRows>().value().type()), x_vars[0]->Get<framework::SelectedRows>().value().type()),
ctx.device_context()); ctx.device_context());
} else if (x_vars[0]->IsType<framework::LoDTensorArray>()) { } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
auto& array = x_vars[0]->Get<framework::LoDTensorArray>(); for (auto& x_var : x_vars) {
for (auto& each : array) { auto& array = x_var->Get<framework::LoDTensorArray>();
if (each.numel() != 0) { for (auto& each : array) {
return framework::OpKernelType(framework::ToDataType(each.type()), if (each.numel() != 0) {
ctx.device_context()); return framework::OpKernelType(framework::ToDataType(each.type()),
ctx.device_context());
}
} }
} }
PADDLE_THROW("Cannot find the input data type by all input data");
} }
PADDLE_THROW("Unexpected branch. Input type is %s", PADDLE_THROW("Unexpected branch. Input type is %s",
x_vars[0]->Type().name()); x_vars[0]->Type().name());
...@@ -96,6 +100,11 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -96,6 +100,11 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
auto& inputs = op_desc.Input("X"); auto& inputs = op_desc.Input("X");
auto var_type = framework::VarDesc::SELECTED_ROWS; auto var_type = framework::VarDesc::SELECTED_ROWS;
for (auto& name : op_desc.Input("X")) {
VLOG(10) << name << " "
<< block->FindRecursiveOrCreateVar(name)->GetType();
}
bool any_input_is_lod_tensor = std::any_of( bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string& name) { inputs.begin(), inputs.end(), [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name)->GetType() == return block->FindRecursiveOrCreateVar(name)->GetType() ==
...@@ -103,7 +112,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -103,7 +112,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
}); });
auto is_tensor_array = [block](const std::string& name) { auto is_tensor_array = [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name)->GetType() == return detail::Ref(block->FindRecursiveOrCreateVar(name)).GetType() ==
framework::VarDesc::LOD_TENSOR_ARRAY; framework::VarDesc::LOD_TENSOR_ARRAY;
}; };
...@@ -113,14 +122,26 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -113,14 +122,26 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
std::all_of(inputs.begin(), inputs.end(), is_tensor_array); std::all_of(inputs.begin(), inputs.end(), is_tensor_array);
if (any_input_is_tensor_array) { if (any_input_is_tensor_array) {
PADDLE_ENFORCE(all_inputs_are_tensor_array); if (!all_inputs_are_tensor_array) {
std::ostringstream os;
for (auto& each : inputs) {
os << " " << each << " type is "
<< detail::Ref(block->FindRecursiveOrCreateVar(each)).GetType()
<< "\n";
}
PADDLE_ENFORCE(all_inputs_are_tensor_array,
"Not all inputs are tensor array:\n%s", os.str());
}
var_type = framework::VarDesc::LOD_TENSOR_ARRAY; var_type = framework::VarDesc::LOD_TENSOR_ARRAY;
} else if (any_input_is_lod_tensor) { } else if (any_input_is_lod_tensor) {
var_type = framework::VarDesc::LOD_TENSOR; var_type = framework::VarDesc::LOD_TENSOR;
} }
auto out_var_name = op_desc.Output("Out").front(); auto out_var_name = op_desc.Output("Out").front();
block->FindRecursiveOrCreateVar(out_var_name)->SetType(var_type); auto& out_var = detail::Ref(block->FindRecursiveOrCreateVar(out_var_name));
out_var.SetType(var_type);
auto& in_var = detail::Ref(block->FindVarRecursive(inputs.front()));
out_var.SetDataType(in_var.GetDataType());
} }
}; };
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/array_operator.h" #include "paddle/operators/array_operator.h"
#include "paddle/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -33,6 +33,8 @@ class WriteToArrayOp : public ArrayOp { ...@@ -33,6 +33,8 @@ class WriteToArrayOp : public ArrayOp {
auto *out = auto *out =
scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensorArray>(); scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensorArray>();
if (offset >= out->size()) { if (offset >= out->size()) {
VLOG(10) << "Resize " << Output("Out") << " from " << out->size()
<< " to " << offset + 1;
out->resize(offset + 1); out->resize(offset + 1);
} }
auto *out_tensor = &out->at(offset); auto *out_tensor = &out->at(offset);
...@@ -85,11 +87,15 @@ class WriteToArrayInferVarType : public framework::VarTypeInference { ...@@ -85,11 +87,15 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDescBind &op_desc, void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override { framework::BlockDescBind *block) const override {
for (auto &out_var : op_desc.OutputArgumentNames()) { auto x_name = op_desc.Input("X")[0];
VLOG(10) << "Set Variable " << out_var << " as LOD_TENSOR_ARRAY"; auto out_name = op_desc.Output("Out")[0];
block->FindRecursiveOrCreateVar(out_var)->SetType( VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
framework::VarDesc::LOD_TENSOR_ARRAY); auto &out = detail::Ref(block->FindRecursiveOrCreateVar(out_name),
} "Cannot found %s", out_name);
out.SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
auto &x =
detail::Ref(block->FindVarRecursive(x_name), "Cannot found %s", x_name);
out.SetDataType(x.GetDataType());
} }
}; };
...@@ -107,11 +113,11 @@ class ReadFromArrayOp : public ArrayOp { ...@@ -107,11 +113,11 @@ class ReadFromArrayOp : public ArrayOp {
auto &x_array = x->Get<framework::LoDTensorArray>(); auto &x_array = x->Get<framework::LoDTensorArray>();
auto *out = scope.FindVar(Output("Out")); auto *out = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out != nullptr, "Out must be set"); PADDLE_ENFORCE(out != nullptr, "Out must be set");
auto *out_tesnor = out->GetMutable<framework::LoDTensor>(); auto *out_tensor = out->GetMutable<framework::LoDTensor>();
size_t offset = GetOffset(scope, dev_ctx); size_t offset = GetOffset(scope, dev_ctx);
PADDLE_ENFORCE_LT(offset, x_array.size()); PADDLE_ENFORCE_LT(offset, x_array.size());
out_tesnor->CopyFrom(x_array[offset], dev_ctx.GetPlace(), dev_ctx); out_tensor->CopyFrom(x_array[offset], dev_ctx.GetPlace(), dev_ctx);
out_tesnor->set_lod(x_array[offset].lod()); out_tensor->set_lod(x_array[offset].lod());
} }
}; };
......
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
#include <vector> #include <vector>
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,8 +28,9 @@ using LoDTensor = framework::LoDTensor; ...@@ -26,8 +28,9 @@ using LoDTensor = framework::LoDTensor;
constexpr char kStepBlock[] = "step_block"; constexpr char kStepBlock[] = "step_block";
constexpr char kCondition[] = "Condition"; constexpr char kCondition[] = "Condition";
constexpr char kStepScopes[] = "StepScopes"; constexpr char kStepScopes[] = "StepScopes";
constexpr char kParamGrads[] = "X@Grad";
constexpr char kParameters[] = "X"; constexpr char kParameters[] = "X";
constexpr char kParamGrads[] = "X@GRAD";
constexpr char kOutputs[] = "Out";
class WhileOp : public framework::OperatorBase { class WhileOp : public framework::OperatorBase {
public: public:
...@@ -71,9 +74,9 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -71,9 +74,9 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
kCondition, kCondition,
"(Bool) An scalar. When it's False, the While Op will be terminated.") "(Bool) An scalar. When it's False, the While Op will be terminated.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", AddOutput(kOutputs,
"A set of variables, which will be assigned with values " "A set of variables, which will be assigned with values "
"generated by perators inside the block of While Op.") "generated by the operators inside the block of While Op.")
.AsDuplicable(); .AsDuplicable();
AddOutput(kStepScopes, AddOutput(kStepScopes,
"(StepScopeVar) A vector of local scope, which size equals the " "(StepScopeVar) A vector of local scope, which size equals the "
...@@ -104,17 +107,64 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -104,17 +107,64 @@ class WhileGradOp : public framework::OperatorBase {
auto *step_scopes = auto *step_scopes =
scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>(); scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
auto outside_og_names = Inputs(framework::GradVarName(kOutputs));
auto inside_og_names =
Attr<std::vector<std::string>>("original_output_grad");
PADDLE_ENFORCE_EQ(outside_og_names.size(), inside_og_names.size());
for (auto cur_scope_iter = step_scopes->rbegin(); for (auto cur_scope_iter = step_scopes->rbegin();
cur_scope_iter != step_scopes->rend(); ++cur_scope_iter) { cur_scope_iter != step_scopes->rend(); ++cur_scope_iter) {
VLOG(3) << "Start backward at time_step "
<< cur_scope_iter - step_scopes->rbegin();
framework::Scope &cur_scope = **cur_scope_iter;
// Link OG from outside to inside
for (size_t i = 0; i < outside_og_names.size(); ++i) {
auto outside_og_name = outside_og_names[i];
auto inside_og_name = inside_og_names[i];
VLOG(10) << "Linking outside " << outside_og_name << " --> inside "
<< inside_og_name;
auto &og_outside = detail::Ref(scope.FindVar(outside_og_name));
auto &og_inside = detail::Ref(cur_scope.Var(inside_og_name));
if (og_outside.Type().hash_code() ==
typeid(framework::LoDTensor).hash_code()) {
auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
auto &inside_tensor =
detail::Ref(og_inside.GetMutable<framework::LoDTensor>());
inside_tensor.set_lod(outside_tensor.lod());
inside_tensor.ShareDataWith(outside_tensor);
} else if (og_outside.Type().hash_code() ==
typeid(framework::LoDTensorArray).hash_code()) {
auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
auto &inside_array =
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
VLOG(10) << outside_og_name << " size = " << outside_array.size();
inside_array.resize(outside_array.size());
for (size_t j = 0; j < inside_array.size(); ++j) {
VLOG(10) << j << " " << outside_array[j].numel();
if (outside_array[j].numel() != 0) {
inside_array[j].set_lod(outside_array[j].lod());
inside_array[j].ShareDataWith(outside_array[j]);
} else {
PADDLE_ENFORCE_EQ(inside_array[j].numel(), 0);
}
}
}
}
executor.Run(*program, *cur_scope_iter, block->ID(), false); executor.Run(*program, *cur_scope_iter, block->ID(), false);
auto &pg_names = Outputs(kParamGrads); auto &pg_names = Outputs(kParamGrads);
auto &p_names = Inputs(kParameters); auto &p_names = Inputs(kParameters);
PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size()); PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
for (size_t prog_id = 0; prog_id < pg_names.size(); ++prog_id) { for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
auto inside_grad_name = framework::GradVarName(p_names[prog_id]); if (pg_names[param_id] == framework::kEmptyVarName) {
continue; // iterator doesn't have gradient
}
auto inside_grad_name = framework::GradVarName(p_names[param_id]);
// // TODO(tonyyang-savil: Not sure we need the following // // TODO(tonyyang-svail): Not sure we need the following
// // If does not compute gradient of that variable inside rnn, // // If does not compute gradient of that variable inside rnn,
// just // just
// // continue // // continue
...@@ -126,7 +176,7 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -126,7 +176,7 @@ class WhileGradOp : public framework::OperatorBase {
// zero gradient variable in step 0 // zero gradient variable in step 0
if (cur_scope_iter == step_scopes->rbegin()) { if (cur_scope_iter == step_scopes->rbegin()) {
auto *var = (*cur_scope_iter)->FindVar(inside_grad_name); auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var, "Can not find var %s", inside_grad_name);
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>(); auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs; framework::AttributeMap attrs;
...@@ -135,27 +185,18 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -135,27 +185,18 @@ class WhileGradOp : public framework::OperatorBase {
attrs["value"] = 0.0f; attrs["value"] = 0.0f;
auto zero_op = framework::OpRegistry::CreateOp( auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", {}, {{"Out", {pg_names[prog_id]}}}, attrs); "fill_constant", {}, {{"Out", {pg_names[param_id]}}}, attrs);
zero_op->Run(scope, dev_ctx); zero_op->Run(scope, dev_ctx);
} }
} }
// sum gradient // sum gradient
auto *outside_var = scope.FindVar(pg_names[prog_id]); auto new_inside_name = cur_scope.Rename(inside_grad_name);
PADDLE_ENFORCE_NOT_NULL(outside_var);
auto &outside_tensor = *outside_var->GetMutable<framework::LoDTensor>();
std::string result_var_name;
auto *local_result_var = (*cur_scope_iter)->Var(&result_var_name);
auto &local_result_tensor =
*local_result_var->GetMutable<framework::LoDTensor>();
local_result_tensor.ShareDataWith(outside_tensor);
auto sum_op = framework::OpRegistry::CreateOp( auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {result_var_name, inside_grad_name}}}, "sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {result_var_name}}}, {}); {{"Out", {pg_names[param_id]}}}, {});
sum_op->Run(**cur_scope_iter, dev_ctx); sum_op->Run(cur_scope, dev_ctx);
cur_scope.Rename(new_inside_name, inside_grad_name);
} }
} }
} }
...@@ -169,29 +210,110 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -169,29 +210,110 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
virtual std::unique_ptr<framework::OpDescBind> Apply() const { virtual std::unique_ptr<framework::OpDescBind> Apply() const {
auto *grad = new framework::OpDescBind(); auto *grad = new framework::OpDescBind();
grad->SetType("while_grad"); grad->SetType("while_grad");
for (auto &input_param : this->InputNames()) { grad->SetInput(kParameters, Input(kParameters));
grad->SetInput(input_param, this->Input(input_param)); grad->SetOutput(
grad->SetOutput(framework::GradVarName(input_param), framework::GradVarName(kParameters),
this->InputGrad(input_param)); InputGrad(kParameters, /*do not drop empty gradient*/ false));
grad->SetInput(kOutputs, Output(kOutputs));
// OG should be re-calculated by step blocks, since many outputs of while op
// do not need to calculate gradients.
std::unordered_set<std::string> block_ins;
{
for (auto &p : Input(kParameters)) {
block_ins.insert(p);
}
for (auto &o : Output(kOutputs)) {
block_ins.insert(o);
}
} }
std::unordered_set<std::string> extra_inputs;
for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
for (auto &input_name : grad_block_[0]->Op(i)->InputArgumentNames()) {
if (block_ins.find(input_name) != block_ins.end()) {
continue;
}
extra_inputs.insert(input_name);
}
for (auto &output_param : this->OutputNames()) { for (auto &output_name : grad_block_[0]->Op(i)->OutputArgumentNames()) {
grad->SetInput(output_param, this->Output(output_param)); block_ins.insert(output_name);
if (output_param != kStepScopes) {
grad->SetInput(framework::GradVarName(output_param),
this->OutputGrad(output_param));
} }
} }
std::vector<std::string> extra_inputs_list;
extra_inputs_list.resize(extra_inputs.size());
std::copy(extra_inputs.begin(), extra_inputs.end(),
extra_inputs_list.begin());
grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
grad->SetInput(kStepScopes, Output(kStepScopes));
grad->SetAttrMap(this->Attrs()); grad->SetAttrMap(this->Attrs());
grad->SetBlockAttr(kStepBlock, *grad_block_[0]); grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
// record the original output gradient names, since the gradient name of
// while operator could be renamed.
grad->SetAttr("original_output_grad", extra_inputs_list);
return std::unique_ptr<framework::OpDescBind>(grad); return std::unique_ptr<framework::OpDescBind>(grad);
} }
}; };
class WhileGradOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override {
auto p_names = op_desc.Input(kParameters);
auto pg_names = op_desc.Output(framework::GradVarName(kParameters));
for (size_t i = 0; i < p_names.size(); ++i) {
auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
auto *g_var = block->FindVarRecursive(pg_names[i]);
if (g_var != nullptr) { // Gradient could be @EMPTY@
VLOG(5) << "Setting " << pg_names[i] << " following " << p_names[i]
<< " type: " << p_var.GetType();
g_var->SetType(p_var.GetType());
g_var->SetDataType(p_var.GetDataType());
}
}
}
};
class WhileGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
ctx->HasInputs(kParameters);
ctx->HasOutputs(framework::GradVarName(kParameters));
ctx->HasInputs(kOutputs);
ctx->HasInputs(framework::GradVarName(kOutputs));
auto p_names = ctx->Inputs(kParameters);
auto pg_names = ctx->Outputs(kParamGrads);
auto dims = ctx->GetInputsDim(kParameters);
auto var_types = ctx->GetInputsVarType(kParameters);
std::vector<std::string> names_to_set;
std::vector<framework::DDim> dims_to_set;
for (size_t i = 0; i < p_names.size(); ++i) {
if (pg_names[i] == framework::kEmptyVarName) {
continue;
}
if (var_types[i] == framework::VarDesc::LOD_TENSOR) {
names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims[i]);
} else if (var_types[i] == framework::VarDesc::LOD_TENSOR_ARRAY) {
// not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims[i]);
}
}
ctx->SetDims(names_to_set, dims_to_set);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OPERATOR(while, paddle::operators::WhileOp, REGISTER_OPERATOR(while, paddle::operators::WhileOp,
paddle::operators::WhileOpMaker, paddle::operators::WhileOpMaker,
paddle::operators::WhileGradOpDescMaker); paddle::operators::WhileGradOpDescMaker);
REGISTER_OPERATOR(while_grad, paddle::operators::WhileGradOp,
paddle::operators::WhileGradOpShapeInference,
paddle::operators::WhileGradOpVarTypeInference);
...@@ -12,9 +12,9 @@ def unique_name(prefix): ...@@ -12,9 +12,9 @@ def unique_name(prefix):
return "_".join([prefix, str(uid)]) return "_".join([prefix, str(uid)])
def _debug_string_(proto): def _debug_string_(proto, throw_on_error=True):
error_fields = list() error_fields = list()
if not proto.IsInitialized(error_fields): if not proto.IsInitialized(error_fields) and throw_on_error:
raise ValueError("{0} are not initialized\nThe message is {1}".format( raise ValueError("{0} are not initialized\nThe message is {1}".format(
error_fields, proto)) error_fields, proto))
return proto.__str__() return proto.__str__()
...@@ -101,9 +101,12 @@ class Variable(object): ...@@ -101,9 +101,12 @@ class Variable(object):
self.stop_gradient = stop_gradient self.stop_gradient = stop_gradient
def __str__(self): def __str__(self):
return self.to_string(True)
def to_string(self, throw_on_error):
protostr = self.desc.serialize_to_string() protostr = self.desc.serialize_to_string()
proto = framework_pb2.VarDesc.FromString(str(protostr)) proto = framework_pb2.VarDesc.FromString(str(protostr))
return _debug_string_(proto) return _debug_string_(proto, throw_on_error)
__repr__ = __str__ __repr__ = __str__
...@@ -291,10 +294,13 @@ class Operator(object): ...@@ -291,10 +294,13 @@ class Operator(object):
self.desc.infer_var_type(self.block.desc) self.desc.infer_var_type(self.block.desc)
self.desc.infer_shape(self.block.desc) self.desc.infer_shape(self.block.desc)
def __str__(self): def to_string(self, throw_on_error):
protostr = self.desc.serialize_to_string() protostr = self.desc.serialize_to_string()
proto = framework_pb2.OpDesc.FromString(str(protostr)) proto = framework_pb2.OpDesc.FromString(str(protostr))
return _debug_string_(proto) return _debug_string_(proto, throw_on_error)
def __str__(self):
return self.to_string(True)
__repr__ = __str__ __repr__ = __str__
...@@ -349,9 +355,12 @@ class Block(object): ...@@ -349,9 +355,12 @@ class Block(object):
self.program = program self.program = program
def __str__(self): def __str__(self):
return self.to_string(True)
def to_string(self, throw_on_error):
protostr = self.desc.serialize_to_string() protostr = self.desc.serialize_to_string()
proto = framework_pb2.BlockDesc.FromString(str(protostr)) proto = framework_pb2.BlockDesc.FromString(str(protostr))
return _debug_string_(proto) return _debug_string_(proto, throw_on_error)
__repr__ = __str__ __repr__ = __str__
...@@ -454,9 +463,12 @@ class Program(object): ...@@ -454,9 +463,12 @@ class Program(object):
self.current_block_idx = 0 self.current_block_idx = 0
def __str__(self): def __str__(self):
return self.to_string(True)
def to_string(self, throw_on_error):
protostr = self.desc.serialize_to_string() protostr = self.desc.serialize_to_string()
proto = framework_pb2.ProgramDesc.FromString(str(protostr)) proto = framework_pb2.ProgramDesc.FromString(str(protostr))
return _debug_string_(proto) return _debug_string_(proto, throw_on_error)
def clone(self): def clone(self):
p = Program() p = Program()
...@@ -512,7 +524,14 @@ class Program(object): ...@@ -512,7 +524,14 @@ class Program(object):
assert isinstance(target, Variable) assert isinstance(target, Variable)
if no_grad_set is None: if no_grad_set is None:
no_grad_set = set() no_grad_set = set()
param_to_grad_info = self.desc.append_backward(target.desc, no_grad_set) try:
param_to_grad_info = self.desc.append_backward(target.desc,
no_grad_set)
except Exception as e:
raise core.EnforceNotMet(
str(e) + "\nCurrent protobuf is\n{0}".format(
self.to_string(False)))
self.sync_with_cpp() self.sync_with_cpp()
return param_to_grad_info return param_to_grad_info
......
...@@ -66,10 +66,13 @@ def parse_graph(program, graph, var_dict, **kwargs): ...@@ -66,10 +66,13 @@ def parse_graph(program, graph, var_dict, **kwargs):
if not var_dict.has_key(var): if not var_dict.has_key(var):
var_dict[var] = "Feed" var_dict[var] = "Feed"
temp_id = 0
proto = framework_pb2.ProgramDesc.FromString( proto = framework_pb2.ProgramDesc.FromString(
program.desc.serialize_to_string()) program.desc.serialize_to_string())
for block in proto.blocks: for block in proto.blocks:
for op in block.ops: for op in block.ops:
op.type = op.type + "_" + str(temp_id)
temp_id += 1
graph.node(**draw_node(op)) graph.node(**draw_node(op))
for o in op.outputs: for o in op.outputs:
for arg in o.arguments: for arg in o.arguments:
...@@ -78,6 +81,7 @@ def parse_graph(program, graph, var_dict, **kwargs): ...@@ -78,6 +81,7 @@ def parse_graph(program, graph, var_dict, **kwargs):
for arg in e.arguments: for arg in e.arguments:
if var_dict.has_key(arg): if var_dict.has_key(arg):
graph.edge(**draw_edge(var_dict, op, e, arg)) graph.edge(**draw_edge(var_dict, op, e, arg))
break # only plot the first block
def draw_graph(startup_program, main_program, **kwargs): def draw_graph(startup_program, main_program, **kwargs):
......
...@@ -2,6 +2,7 @@ import unittest ...@@ -2,6 +2,7 @@ import unittest
import paddle.v2.fluid.layers as layers import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.executor import Executor from paddle.v2.fluid.executor import Executor
import paddle.v2.fluid.core as core import paddle.v2.fluid.core as core
from paddle.v2.fluid.backward import append_backward_ops
import numpy import numpy
...@@ -16,7 +17,7 @@ class TestWhileOp(unittest.TestCase): ...@@ -16,7 +17,7 @@ class TestWhileOp(unittest.TestCase):
i = layers.zeros(shape=[1], dtype='int64') i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True i.stop_gradient = True
init = layers.zeros(shape=[10], dtype='float32') init = layers.zeros(shape=[10], dtype='float32')
mem_array = layers.array_write(init, i=i) mem_array = layers.array_write(x=init, i=i)
data_array = layers.array_write(x=d0, i=i) data_array = layers.array_write(x=d0, i=i)
i = layers.increment(i) i = layers.increment(i)
...@@ -29,17 +30,23 @@ class TestWhileOp(unittest.TestCase): ...@@ -29,17 +30,23 @@ class TestWhileOp(unittest.TestCase):
i.stop_gradient = True i.stop_gradient = True
array_len = layers.fill_constant(shape=[1], dtype='int64', value=3) array_len = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len.stop_gradient = True
cond = layers.less_than(x=i, y=array_len) cond = layers.less_than(x=i, y=array_len)
while_op = layers.While(cond=cond) while_op = layers.While(cond=cond)
with while_op.block(): with while_op.block():
d = layers.array_read(array=data_array, i=i) d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i) prev = layers.array_read(array=mem_array, i=i)
i = layers.increment(x=i, in_place=True)
result = layers.sums(input=[d, prev]) result = layers.sums(input=[d, prev])
i = layers.increment(x=i, in_place=True)
layers.array_write(result, i=i, array=mem_array) layers.array_write(result, i=i, array=mem_array)
layers.less_than(x=i, y=array_len, cond=cond) layers.less_than(x=i, y=array_len, cond=cond)
sum_result = layers.array_read(mem_array, i=array_len)
sum_result = layers.array_read(array=mem_array, i=i)
loss = layers.mean(x=sum_result)
append_backward_ops(loss)
cpu = core.CPUPlace() cpu = core.CPUPlace()
exe = Executor(cpu) exe = Executor(cpu)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册