Commit b94584cf authored by dangqingqing

Rename recurrent_network_op to recurrent_op.

Parent 79e89ef4
...
@@ -55,7 +55,5 @@ op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
 op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
-op_library(recurrent_network_op SRCS recurrent_network_op.cc DEPS op_desc
-    tensor op_registry operator net)
-cc_test(recurrent_network_op_test SRCS recurrent_network_op_test.cc DEPS
-    recurrent_network_op gtest mul_op add_op)
+op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net)
+cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
...
@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/operators/recurrent_network_op.h"
+#include "paddle/operators/recurrent_op.h"
 #include <glog/logging.h>
 #include <cstring>
...
@@ -108,8 +108,13 @@ void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
   std::shared_ptr<Scope> scope = scopes[step_id];
   std::shared_ptr<Scope> linked_scope = scopes[step_id + offset];
   for (auto& attr : memories) {
+    PADDLE_ENFORCE(scope->HasVariable(attr.pre_var),
+                   "the pre-memory [%s] is not in scope.",
+                   attr.pre_var);
+    PADDLE_ENFORCE(linked_scope->HasVariable(attr.var),
+                   "the memory [%s] is not in linked scope.",
+                   attr.var);
     auto mem = scope->GetVariable(attr.pre_var)->GetMutable<Tensor>();
-    // maybe share variable is better?
     auto linked_mem = linked_scope->GetVariable(attr.var)->GetMutable<Tensor>();
     if (infer_shape_mode) {
       mem->Resize(linked_mem->dims());
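The guards added above make LinkMemories fail fast with a readable message when a memory variable is missing from a step scope, instead of dereferencing a null pointer later in GetVariable. Below is a minimal standalone sketch of the same check-before-access pattern; MockScope and Enforce are stand-ins invented for illustration, not PaddlePaddle's actual Scope class or PADDLE_ENFORCE macro.

    // Minimal sketch of the check-before-access pattern from the diff above.
    // MockScope is a mock (a name -> fake-size map), not the real Scope; only
    // the guard logic mirrors the PADDLE_ENFORCE calls in LinkMemories.
    #include <cstdio>
    #include <cstdlib>
    #include <map>
    #include <string>

    struct MockScope {
      std::map<std::string, int> vars;  // variable name -> fake "dims"
      bool HasVariable(const std::string& name) const {
        return vars.count(name) > 0;
      }
      int* GetVariable(const std::string& name) { return &vars.at(name); }
    };

    // Abort with a message now rather than crash later on a missing variable,
    // which is what the added PADDLE_ENFORCE guards achieve.
    static void Enforce(bool cond, const char* fmt, const char* arg) {
      if (!cond) {
        std::fprintf(stderr, fmt, arg);
        std::fprintf(stderr, "\n");
        std::exit(1);
      }
    }

    int main() {
      MockScope scope, linked_scope;
      scope.vars["h@pre"] = 0;      // pre-memory in the current step scope
      linked_scope.vars["h"] = 16;  // memory in the linked (offset) scope

      const std::string pre_var = "h@pre", var = "h";
      Enforce(scope.HasVariable(pre_var),
              "the pre-memory [%s] is not in scope.", pre_var.c_str());
      Enforce(linked_scope.HasVariable(var),
              "the memory [%s] is not in linked scope.", var.c_str());

      // Safe to dereference now; mimics mem->Resize(linked_mem->dims()).
      *scope.GetVariable(pre_var) = *linked_scope.GetVariable(var);
      std::printf("linked memory size: %d\n", *scope.GetVariable(pre_var));
      return 0;
    }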
...
@@ -295,12 +300,12 @@ public:
   const auto& name = RecurrentOp::kArgName;
   // inputs and outputs stored in proto
   AddInputs(name.inlinks,
-            "the input that need to be segmented for each step.");
+            "the inputs that need to be segmented for each step.");
   AddInputs(name.boot_memories, "variables to initialize memories.");
   AddInput(name.step_net, "network shared by all steps.");
   AddOutputs(name.outlinks,
-             "the output that need to concated for all steps.");
+             "the outputs that need to concated for all steps.");
   AddOutput(name.step_scopes, "step scopes");
   // Attributes stored in AttributeMap
...
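For readers unfamiliar with the proto-maker calls in the hunk above: AddInput/AddOutput declare a single argument, while AddInputs/AddOutputs declare variable-length argument lists. The mock below is a self-contained illustration of that bookkeeping under those assumptions; it is not the real proto maker, which writes the declarations into a protobuf OpProto message.

    // Illustrative mock of AddInput(s)/AddOutput(s); MockProtoMaker and ArgDesc
    // are invented here and just record (name, comment, is-list) tuples.
    #include <cstdio>
    #include <string>
    #include <vector>

    struct ArgDesc {
      std::string name;
      std::string comment;
      bool is_list;  // true for AddInputs/AddOutputs (variable-length)
    };

    struct MockProtoMaker {
      std::vector<ArgDesc> inputs, outputs;
      void AddInput(const std::string& n, const std::string& c) {
        inputs.push_back({n, c, false});
      }
      void AddInputs(const std::string& n, const std::string& c) {
        inputs.push_back({n, c, true});
      }
      void AddOutput(const std::string& n, const std::string& c) {
        outputs.push_back({n, c, false});
      }
      void AddOutputs(const std::string& n, const std::string& c) {
        outputs.push_back({n, c, true});
      }
    };

    int main() {
      // Mirror the RecurrentOp argument registration shown in the diff.
      MockProtoMaker maker;
      maker.AddInputs("inlinks", "the inputs that need to be segmented for each step.");
      maker.AddInputs("boot_memories", "variables to initialize memories.");
      maker.AddInput("step_net", "network shared by all steps.");
      maker.AddOutputs("outlinks", "the outputs to be concatenated across all steps.");
      maker.AddOutput("step_scopes", "step scopes");

      for (const auto& in : maker.inputs)
        std::printf("input  %s%s: %s\n", in.name.c_str(),
                    in.is_list ? "[]" : "", in.comment.c_str());
      for (const auto& out : maker.outputs)
        std::printf("output %s%s: %s\n", out.name.c_str(),
                    out.is_list ? "[]" : "", out.comment.c_str());
      return 0;
    }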
...
@@ -18,7 +18,7 @@
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/tensor.h"
-#include "paddle/operators/recurrent_network_op.h"
+#include "paddle/operators/recurrent_op.h"
 namespace paddle {
 namespace operators {
...
 cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
-    add_op fc_op sgd_op cross_entropy_op recurrent_network_op)
+    add_op fc_op sgd_op cross_entropy_op recurrent_op)