Unverified · Commit 4e3c535a · Author: Guo Sheng · Committer: GitHub

Add support for dynamic_decode(while) training. (#22231) (#22574)

* Add support for dynamic_decode(while) training. test=develop

* Fix assign_op and tensor_array_read_write_op after solving conflict. test=develop

* Fix test_rnn_decode_api.py. test=develop

* Refine docs for apis in rnn.py. test=develop

* Adjust outputs of dynamic_decode. test=develop

* Remove the force_cpu update in assign_op. test=develop

* Remove the force_cpu update in assign_op. test=develop

* Make RNNCell.get_initial_states support batch_dim_idx argument. test=develop

* Rename _create_array_outof_while as _create_array_out_of_while in rnn.py.

test=release/1.7
Parent a8f85f2c
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/array_operator.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
...@@ -152,6 +154,21 @@ class ReadFromArrayOp : public ArrayOp {
out_tensor->set_lod(x_array[offset].lod());
} else {
VLOG(10) << "offset " << offset << " >= " << x_array.size();
// set grad of the written tensor to 0 when used as write_to_array_grad
auto *fw_var = scope.FindVar(Input("X_W"));
if (fw_var == nullptr) return;
auto &fw_var_tensor = fw_var->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["dtype"] = fw_var_tensor.type();
attrs["shape"] = framework::vectorize<int>(fw_var_tensor.dims());
attrs["value"] = 0.0f;
auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", {}, {{"Out", {Output("Out")}}}, attrs);
zero_op->Run(scope, place);
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
out_tensor->set_lod(fw_var_tensor.lod());
}
}
};
...@@ -163,6 +180,10 @@ class ReadFromArrayProtoMaker : public framework::OpProtoAndCheckerMaker {
AddInput("I",
"(Tensor) the subscript index in tensor array. The number of "
"elements should be 1");
AddInput("X_W",
"(Tensor) the writed tensor when used as the grad op of "
"write_to_array. We use this to fill zero gradient.")
.AsDispensable();
AddOutput("Out", "(LoDTensor) the tensor will be read from."); AddOutput("Out", "(LoDTensor) the tensor will be read from.");
AddComment(R"DOC( AddComment(R"DOC(
ReadFromArray Operator. ReadFromArray Operator.
...@@ -199,6 +220,7 @@ class WriteToArrayGradMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetType("read_from_array");
grad_op->SetInput("I", this->Input("I"));
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetInput("X_W", this->Input("X"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
...
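
Taken together, the C++ changes mean that when `read_from_array` runs as the grad op of `write_to_array` and the gradient array has no tensor at the requested offset, the output is no longer left undefined: it is filled with zeros shaped and typed like the forward-written tensor passed in through `X_W`. A minimal sketch of a program whose backward pass inserts these grad ops, assuming the paddle.fluid 1.7 Python API (names and sizes are illustrative):

```python
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[8], dtype='float32')
h = fluid.layers.fc(x, size=8)  # a parameter, so backward has gradients to build
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
arr = fluid.layers.array_write(h, i)  # forward: write_to_array op
y = fluid.layers.array_read(arr, i)   # forward: read_from_array op
loss = fluid.layers.reduce_mean(y)
# append_backward adds read_from_array as write_to_array's grad op; with
# X_W wired in by WriteToArrayGradMaker above, a missing slot in the
# gradient array yields zeros shaped like the forward value, not an error.
fluid.backward.append_backward(loss)
```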
This diff is collapsed.
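
The collapsed diff carries the Python-side work (the commit message points at rnn.py and dynamic_decode). As a hedged sketch of what the commit enables, here is a teacher-forced decoder built from the fluid 1.7 decoding pieces this PR touches (TrainingHelper, BasicDecoder, dynamic_decode, RNNCell.get_initial_states); shapes, sizes, and variable names are illustrative:

```python
import paddle.fluid as fluid
from paddle.fluid.layers import (BasicDecoder, GRUCell, TrainingHelper,
                                 dynamic_decode)

# ground-truth target embeddings and lengths for teacher forcing
trg_emb = fluid.layers.data(name='trg_emb', shape=[-1, -1, 128],
                            dtype='float32', append_batch_size=False)
trg_len = fluid.layers.data(name='trg_len', shape=[-1], dtype='int64',
                            append_batch_size=False)

cell = GRUCell(hidden_size=128)
decoder = BasicDecoder(cell,
                       helper=TrainingHelper(trg_emb, trg_len),
                       output_fn=lambda x: fluid.layers.fc(x, size=1000))
# batch_dim_idx is the argument this commit adds to get_initial_states
init_states = cell.get_initial_states(batch_ref=trg_emb, batch_dim_idx=0)
# dynamic_decode now returns (outputs, final_states, sequence_lengths),
# and its while-based loop is differentiable, so the graph can train.
outputs, final_states, seq_lens = dynamic_decode(decoder, inits=init_states)
loss = fluid.layers.reduce_mean(outputs.cell_outputs)  # toy loss on logits
fluid.backward.append_backward(loss)
```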
...@@ -787,6 +787,7 @@ def argmin(x, axis=0):
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis})
out.stop_gradient = True
return out
...@@ -846,6 +847,7 @@ def argmax(x, axis=0):
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis})
out.stop_gradient = True
return out
...
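
Both Python hunks make the same one-line change: argmin and argmax now mark their output with stop_gradient = True. The output holds integer indices, which are not differentiable, so backward should never try to propagate through them (for example, the ids sampled inside a decoding step). A minimal sketch of the resulting behavior, assuming the fluid 1.7 API:

```python
import paddle.fluid as fluid

logits = fluid.layers.data(name='logits', shape=[1000], dtype='float32')
ids = fluid.layers.argmax(logits, axis=1)  # int64 indices, not differentiable
assert ids.stop_gradient  # set unconditionally after this commit
```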