提交 0079fa32 编写于 作者: Y Yan Chunwei 提交者: GitHub

Rnn make stepnet member (#3469)

* make stepnet member

* add pybind support

* fix Inputs Outputs

* remove unique_ptr
上级 80de7e5e
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor_py.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/string/to_string.h"
......@@ -241,6 +242,11 @@ All parameter, weight, gradient are variables in Paddle.
const std::shared_ptr<operators::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
})
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
})
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp();
......@@ -248,6 +254,29 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator(net);
// recurrent_op
py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
rnn(m, "RecurrentOp");
rnn.def_static(
"create",
[](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
})
.def("set_stepnet",
[](operators::RecurrentOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.set_stepnet(net);
});
ExposeOperator(rnn);
m.def("unique_integer", UniqueIntegerGenerator);
m.def("is_compile_gpu", IsCompileGPU);
......
......@@ -66,6 +66,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op)
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
op_library(uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu)
......@@ -36,15 +36,13 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
true /*infer_shape_mode*/);
InitMemories(step_scopes[0], true /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (size_t i = 0; i < seq_len_; i++) {
if (i > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
true /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
(*stepnet_)->InferShape(*step_scopes[i]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
......@@ -56,7 +54,6 @@ void RecurrentAlgorithm::Run(const Scope& scope,
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
false /*infer_shape_mode*/);
InitMemories(step_scopes[0], false /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
for (size_t step_id = 0; step_id < seq_len_; step_id++) {
// create output alias variables
......@@ -64,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
false /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
false /*infer_shape_mode*/);
......@@ -78,18 +75,16 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
auto step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
// Now all variables in scope must be created outside of op.
auto net_var = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope",
arg_->step_net);
auto net_op = net_var->GetMutable<NetOp>();
PADDLE_ENFORCE(!net_op->Outputs().empty(), "net_op has no outputs");
PADDLE_ENFORCE_NOT_NULL(stepnet_);
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs");
if (seq_len_ > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
auto& step_scope = scope.NewScope();
// create step net's temp inputs
for (auto& input : net_op->Inputs()) {
for (auto& input : (*stepnet_)->Inputs()) {
// the weight are located in parent scope
for (auto& var_name : input.second) {
if (!step_scope.FindVar(var_name)) {
......@@ -98,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
}
}
// create stepnet's outputs
for (const auto& output : net_op->Outputs()) {
for (const auto& output : (*stepnet_)->Outputs()) {
for (auto& var_name : output.second) {
step_scope.NewVar(var_name);
}
......@@ -140,9 +135,8 @@ RecurrentOp::RecurrentOp(const std::string& type,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, &stepnet_);
}
class RecurrentAlgorithmProtoAndCheckerMaker
......@@ -158,7 +152,6 @@ class RecurrentAlgorithmProtoAndCheckerMaker
.AsDuplicable();
AddInput(name.boot_memories, "variables to initialize memories.")
.AsDuplicable();
AddInput(name.step_net, "network shared by all steps.");
AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
.AsDuplicable();
......@@ -180,14 +173,12 @@ void RecurrentGradientAlgorithm::Run(
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
false /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
false /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
LinkBootMemoryGradients(step_scopes[0], false);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
......@@ -219,14 +210,12 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
true /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
true /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
(*stepnet_)->InferShape(*step_scopes[step_id]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
......@@ -238,9 +227,8 @@ RecurrentGradientOp::RecurrentGradientOp(
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, &stepnet_);
}
} // namespace operators
......
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/framework/operator.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"
namespace paddle {
......@@ -33,7 +34,11 @@ class RecurrentAlgorithm {
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;
void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = arg;
stepnet_ = stepnet;
}
/**
* InferShape must be called before Run.
......@@ -58,7 +63,8 @@ class RecurrentAlgorithm {
void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
private:
std::unique_ptr<rnn::Argument> arg_;
std::shared_ptr<NetOp>* stepnet_;
rnn::Argument* arg_;
mutable size_t seq_len_;
};
......@@ -74,7 +80,11 @@ class RecurrentGradientAlgorithm {
* operator.
*/
public:
void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = std::move(arg);
stepnet_ = stepnet;
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;
......@@ -95,8 +105,9 @@ class RecurrentGradientAlgorithm {
}
private:
std::unique_ptr<rnn::Argument> arg_;
rnn::Argument* arg_;
mutable size_t seq_len_;
std::shared_ptr<NetOp>* stepnet_;
};
class RecurrentOp final : public framework::OperatorBase {
......@@ -115,10 +126,15 @@ class RecurrentOp final : public framework::OperatorBase {
alg_.Run(scope, dev_ctx);
}
void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
const NetOp* stepnet() const { return stepnet_.get(); }
static const rnn::ArgumentName kArgName;
private:
RecurrentAlgorithm alg_;
rnn::Argument arg_;
std::shared_ptr<NetOp> stepnet_;
};
class RecurrentGradientOp final : public framework::OperatorBase {
......@@ -141,8 +157,13 @@ class RecurrentGradientOp final : public framework::OperatorBase {
static const rnn::ArgumentName kArgName;
void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
const NetOp* stepnet() const { return stepnet_.get(); }
private:
RecurrentGradientAlgorithm alg_;
std::shared_ptr<NetOp> stepnet_;
rnn::Argument arg_;
};
} // namespace operators
......
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "paddle/operators/recurrent_op.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/framework/ddim.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
using namespace paddle::framework;
class RecurrentGradientAlgorithmTest : public ::testing::Test {
protected:
virtual void SetUp() override {
CreateGlobalVariables();
CreateStepScopes();
CreateStepNet();
CreateRNNGradientAlgorithm();
// segment inputs
SegmentInputs();
// link forward memories
LinkeMemories();
}
virtual void TearDown() override {}
void CreateGlobalVariables() {
// inputs: x
LOG(INFO) << "create global variable x";
Variable* x = scope_.NewVar("x");
DDim dims =
make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/});
x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
// inputs: h_boot
LOG(INFO) << "create global variable h_boot";
Variable* h_boot = scope_.NewVar("h_boot");
h_boot->GetMutable<Tensor>()->mutable_data<float>(
make_ddim({20 /*batch size*/, 30 /*input dim*/}), platform::CPUPlace());
// inputs: w
LOG(INFO) << "create global variable w";
Variable* w = scope_.NewVar("rnn/w");
w->GetMutable<Tensor>()->mutable_data<float>(make_ddim({30, 30}),
platform::CPUPlace());
// inputs: h_grad
LOG(INFO) << "create variable h_grad";
Variable* dh = scope_.NewVar("h_grad");
dh->GetMutable<Tensor>()->mutable_data<float>(make_ddim({10, 20, 30}),
platform::CPUPlace());
// inputs: step_scopes
LOG(INFO) << "create variable step_scopes";
scope_.NewVar("step_scopes");
// inputs: step_net
LOG(INFO) << "create variable step_net";
scope_.NewVar("step_net");
// outputs: w_grad
LOG(INFO) << "create global variable w_grad";
scope_.NewVar("rnn/w_grad");
// outputs: x_grad
LOG(INFO) << "create global variable x_grad";
scope_.NewVar("x_grad");
// outputs: h_boot_grad
LOG(INFO) << "create global variable h_boot_grad";
scope_.NewVar("h_boot_grad");
}
void CreateStepScopes() {
auto step_scopes =
scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
for (int i = 0; i < 10; ++i) {
auto& scope = scope_.NewScope();
auto pre_t = scope.NewVar("rnn/pre_h")->GetMutable<Tensor>();
pre_t->mutable_data<float>({20, 30}, platform::CPUPlace());
auto tensor = scope.NewVar("rnn/h")->GetMutable<Tensor>();
tensor->mutable_data<float>({20, 30}, platform::CPUPlace());
// for unit test of ConcatOutputs
auto xg = scope.NewVar("rnn/x_grad")->GetMutable<Tensor>();
xg->mutable_data<float>({20, 30}, platform::CPUPlace());
step_scopes->emplace_back(&scope);
}
// last time step
auto g = (*step_scopes)[9]->NewVar("rnn/h_pre_grad")->GetMutable<Tensor>();
g->mutable_data<float>({20, 30}, platform::CPUPlace());
}
void CreateRNNGradientAlgorithm() {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
arg->step_net = "step_net";
arg->step_scopes = "step_scopes";
rnn::Link inlink;
inlink.external = "h_grad";
inlink.internal = "rnn/h_grad";
arg->inlinks = std::vector<rnn::Link>{inlink};
rnn::Link outlink;
outlink.external = "x_grad";
outlink.internal = "rnn/x_grad";
arg->outlinks = std::vector<rnn::Link>{outlink};
rnn::MemoryAttr mem_attr;
mem_attr.pre_var = "rnn/h_pre_grad";
mem_attr.var = "rnn/h_grad";
mem_attr.boot_var = "h_boot_grad";
arg->memories = std::vector<rnn::MemoryAttr>{mem_attr};
rnn_grad_algo_.Init(std::move(arg));
}
void CreateStepNet() {
LOG(INFO) << "create variable step_net";
Variable* var = scope_.NewVar("step_net");
auto net = var->GetMutable<NetOp>();
// TODO(qingqing) modify backward op create for RNNOp unit test
// and the unit test will be removed to Python.
// net->AddOp(OpRegistry::CreateOp("mul", {"X", {"rnn/h_pre", "rnn/w",
// "rnn/s_grad"}}, {"Y", {"rnn/h_pre_grad", "rnn/w_grad"}}, {}));
// net->AddOp(OpRegistry::CreateOp("add_two", {"X", {"rnn/h_grad"}},
// {"Y", {"rnn/x_grad"}}, {"Out", "rnn/s_grad"}}, {}));
net->CompleteAddOp();
}
void SegmentInputs() {
LOG(INFO) << "segment inputs";
std::vector<std::string> inlinks = {"x"};
std::vector<std::string> inlinks_alias = {"rnn/x"};
rnn::Link inlink;
inlink.external = "x";
inlink.internal = "rnn/x";
auto step_scopes =
scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10,
true /*infer_shape_mode*/);
}
void LinkeMemories() {
LOG(INFO) << "link memories";
rnn::MemoryAttr mem_attr;
mem_attr.pre_var = "rnn/h_pre";
mem_attr.var = "rnn/h";
mem_attr.boot_var = "boot_h";
std::vector<rnn::MemoryAttr> memories;
memories.push_back(mem_attr);
auto step_scopes =
scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
for (int i = 1; i < 10; ++i) {
rnn::LinkMemories(*step_scopes, memories, i, -1,
true /*infer_shape_mode*/);
}
}
Scope scope_;
RecurrentGradientAlgorithm rnn_grad_algo_;
};
// TEST_F(RecurrentGradientAlgorithmTest, Run) {
// platform::CPUDeviceContext ctx;
// rnn_grad_algo_.Run(scope_, ctx);
// }
} // namespace operators
} // namespace paddle
TEST(RecurrentOp, LinkMemories) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::operators;
// create and init step scopes
size_t len = 10;
std::vector<Scope*> step_scopes;
for (size_t i = 0; i < len; ++i) {
auto scope = new Scope();
scope->NewVar("pre_h");
auto tensor = scope->NewVar("h")->GetMutable<Tensor>();
float* data = tensor->mutable_data<float>({15, 20}, CPUPlace());
for (size_t j = 0; j < 15 * 20; ++j) {
data[j] = rand() * (1. / (double)RAND_MAX);
}
step_scopes.push_back(scope);
}
// create MemoryAttr
rnn::MemoryAttr mem_attr;
mem_attr.pre_var = "pre_h";
mem_attr.var = "h";
mem_attr.boot_var = "boot_h";
std::vector<rnn::MemoryAttr> memories;
memories.push_back(mem_attr);
for (size_t i = 1; i < len; ++i) {
rnn::LinkMemories(step_scopes, memories, i, -1, false
/*infer_shape_mode*/);
}
// check
for (size_t i = 0; i < len - 1; ++i) {
const float* a =
step_scopes[i]->FindVar("h")->GetMutable<Tensor>()->data<float>();
const float* b = step_scopes[i + 1]
->FindVar("pre_h")
->GetMutable<Tensor>()
->data<float>();
for (size_t j = 0; j < 15 * 20; ++j) {
ASSERT_FLOAT_EQ(a[j], b[j]);
}
}
for (int i = len - 2; i >= 0; --i) {
rnn::LinkMemories(step_scopes, memories, i, 1, false
/*infer_shape_mode*/);
}
// check
for (int i = len - 2; i >= 0; --i) {
const float* a =
step_scopes[i]->FindVar("pre_h")->GetMutable<Tensor>()->data<float>();
const float* b =
step_scopes[i + 1]->FindVar("h")->GetMutable<Tensor>()->data<float>();
for (size_t j = 0; j < 15 * 20; ++j) {
ASSERT_FLOAT_EQ(a[j], b[j]);
}
}
for (auto s : step_scopes) {
delete s;
}
}
USE_OP(add_two);
USE_OP(mul);
USE_OP_ITSELF(recurrent_op);
......@@ -106,7 +106,6 @@ void LinkMemories(const std::vector<Scope*>& scopes,
void InitArgument(const ArgumentName& name, Argument* arg,
const framework::OperatorBase& op) {
arg->step_net = op.Input(name.step_net);
arg->step_scopes = op.Output(name.step_scopes);
auto inlinks = op.Inputs(name.inlinks);
......
......@@ -23,7 +23,7 @@ class OpDescCreationMethod(object):
"""
A Functor object to convert user input(use key word args) to OpDesc based on
OpProto.
:param op_proto: The OpProto object.
:type op_proto: op_proto_pb2.OpProto
"""
......@@ -177,4 +177,26 @@ class OperatorFactory(object):
return self.get_op_info(type).attrs
class __RecurrentOp__(object):
__proto__ = None
type = 'recurrent_op'
def __init__(self):
# cache recurrent_op's proto
if self.__proto__ is None:
for op_proto in get_all_op_protos():
if op_proto.type == self.type:
self.__proto__ = op_proto
def __call__(self, *args, **kwargs):
if self.type not in args and 'type' not in kwargs:
kwargs['type'] = self.type
# create proto
create_method = OpDescCreationMethod(self.__proto__)
proto = create_method(*args, **kwargs)
# create rnnop
return core.RecurrentOp.create(proto.SerializeToString())
Operator = OperatorFactory() # Default global factory
RecurrentOp = __RecurrentOp__()
......@@ -2,7 +2,7 @@ import logging
import paddle.v2.framework.core as core
import unittest
import numpy as np
from paddle.v2.framework.op import Operator
from paddle.v2.framework.op import Operator, RecurrentOp
def py_sigmoid(x):
......@@ -98,11 +98,11 @@ class TestRecurrentOp(unittest.TestCase):
def forward(self):
self.scope = core.Scope()
self.create_global_variables()
self.create_rnn_op()
self.create_step_net()
rnn_op = self.create_rnn_op()
ctx = core.DeviceContext.create(core.CPUPlace())
rnn_op.infer_shape(self.scope)
rnn_op.run(self.scope, ctx)
self.rnnop.infer_shape(self.scope)
self.rnnop.run(self.scope, ctx)
return np.array(self.scope.find_var("h").get_tensor())
def create_global_variables(self):
......@@ -128,8 +128,7 @@ class TestRecurrentOp(unittest.TestCase):
def create_rnn_op(self):
# create RNNOp
rnnop = Operator(
"recurrent_op",
self.rnnop = RecurrentOp(
# inputs
inlinks=["x"],
boot_memories=["h_boot"],
......@@ -142,14 +141,9 @@ class TestRecurrentOp(unittest.TestCase):
outlink_alias=["h@alias"],
pre_memories=["h@pre"],
memories=["h@alias"])
return rnnop
def create_step_net(self):
var = self.scope.new_var("stepnet")
stepnet = var.get_net()
# x_fc_op = Operator("fc", X="x@alias", W="W", Y="Wx")
# h_fc_op = Operator("fc", X="h@pre", W="U", Y="Uh")
stepnet = core.Net.create()
x_fc_op = Operator("mul", X="x@alias", Y="W", Out="Wx")
h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
sum_op = Operator("add_two", X="Wx", Y="Uh", Out="sum")
......@@ -158,6 +152,7 @@ class TestRecurrentOp(unittest.TestCase):
for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
stepnet.add_op(op)
stepnet.complete_add_op(True)
self.rnnop.set_stepnet(stepnet)
def test_forward(self):
print 'test recurrent op forward'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册