提交 b94584cf 编写于 作者: D dangqingqing

Rename recurrent_network_op recurrent_op.

上级 79e89ef4
......@@ -55,7 +55,5 @@ op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(recurrent_network_op SRCS recurrent_network_op.cc DEPS op_desc
tensor op_registry operator net)
cc_test(recurrent_network_op_test SRCS recurrent_network_op_test.cc DEPS
recurrent_network_op gtest mul_op add_op)
op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net)
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
......@@ -12,7 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/recurrent_network_op.h"
#include "paddle/operators/recurrent_op.h"
#include <glog/logging.h>
#include <cstring>
......@@ -108,8 +108,13 @@ void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
std::shared_ptr<Scope> scope = scopes[step_id];
std::shared_ptr<Scope> linked_scope = scopes[step_id + offset];
for (auto& attr : memories) {
PADDLE_ENFORCE(scope->HasVariable(attr.pre_var),
"the pre-memory [%s] is not in scope.",
attr.pre_var);
PADDLE_ENFORCE(linked_scope->HasVariable(attr.var),
"the memory [%s] is not in linked scope.",
attr.var);
auto mem = scope->GetVariable(attr.pre_var)->GetMutable<Tensor>();
// maybe share variable is better?
auto linked_mem = linked_scope->GetVariable(attr.var)->GetMutable<Tensor>();
if (infer_shape_mode) {
mem->Resize(linked_mem->dims());
......@@ -295,12 +300,12 @@ public:
const auto& name = RecurrentOp::kArgName;
// inputs and outputs stored in proto
AddInputs(name.inlinks,
"the input that need to be segmented for each step.");
"the inputs that need to be segmented for each step.");
AddInputs(name.boot_memories, "variables to initialize memories.");
AddInput(name.step_net, "network shared by all steps.");
AddOutputs(name.outlinks,
"the output that need to concated for all steps.");
"the outputs that need to concated for all steps.");
AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap
......
......@@ -18,7 +18,7 @@
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/recurrent_network_op.h"
#include "paddle/operators/recurrent_op.h"
namespace paddle {
namespace operators {
......
cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
add_op fc_op sgd_op cross_entropy_op recurrent_network_op)
add_op fc_op sgd_op cross_entropy_op recurrent_op)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册