diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 41cf1dc385aa8ab3160e5e15213d28494f5da69d..b646ad5f3b67e22b933a56d88e4a1bf6e74a124e 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include "paddle/framework/op_desc.h" #include "paddle/framework/var_desc.h" +#include "paddle/platform/macros.h" namespace paddle { namespace framework { @@ -34,9 +35,6 @@ class BlockDescBind { BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) : prog_(prog), desc_(desc), need_update_(false) {} - BlockDescBind(const BlockDescBind &o) = delete; - BlockDescBind &operator=(const BlockDescBind &o) = delete; - int32_t ID() const { return desc_->idx(); } int32_t Parent() const { return desc_->parent_idx(); } @@ -68,6 +66,8 @@ class BlockDescBind { std::deque> ops_; std::unordered_map> vars_; + + DISABLE_COPY_AND_ASSIGN(BlockDescBind); }; } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 470336d367566b342af6a7dac9c9cdf465101ea2..9672e540c8dd7bc3ee7993200e606320087d9ba7 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -20,6 +20,7 @@ #include "paddle/framework/attribute.h" #include "paddle/framework/op_desc.h" #include "paddle/framework/type_defs.h" +#include "paddle/platform/macros.h" namespace paddle { namespace framework { @@ -67,11 +68,6 @@ class OpInfoMap { public: static OpInfoMap& Instance(); - OpInfoMap(const OpInfoMap& o) = delete; - OpInfoMap(OpInfoMap&& o) = delete; - OpInfoMap& operator=(const OpInfoMap& o) = delete; - OpInfoMap& operator=(OpInfoMap&& o) = delete; - bool Has(const std::string& op_type) const { return map_.find(op_type) != map_.end(); } @@ -107,6 +103,8 @@ class OpInfoMap { private: OpInfoMap() = default; std::unordered_map map_; + + DISABLE_COPY_AND_ASSIGN(OpInfoMap); }; } // namespace framework diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index 06ffcd4b15078f62ea8b7a3714e73de799530785..9b34a06aeff94e6fa855f6f287a73889e2a4faee 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/framework/framework.pb.h" +#include "paddle/platform/macros.h" namespace paddle { namespace framework { @@ -26,9 +27,6 @@ class ProgramDescBind { public: static ProgramDescBind &Instance(ProgramDesc *prog); - ProgramDescBind(const ProgramDescBind &o) = delete; - ProgramDescBind &operator=(const ProgramDescBind &o) = delete; - BlockDescBind *AppendBlock(const BlockDescBind &parent); BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); } @@ -46,6 +44,8 @@ class ProgramDescBind { ProgramDesc *prog_; std::vector> blocks_; + + DISABLE_COPY_AND_ASSIGN(ProgramDescBind); }; } // namespace framework } // namespace paddle diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index c93b03e48130afe9568089b6a7586c4185d1d5b4..7047f0d55e9844aec19892631fe4b5b387bf89ca 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include "paddle/framework/variable.h" +#include "paddle/platform/macros.h" namespace paddle { namespace framework { @@ -38,11 +39,6 @@ class Scope { Scope() {} ~Scope(); - // Disable Copy, Assign, Move. - Scope(const Scope& other) = delete; - Scope& operator=(const Scope& other) = delete; - Scope(Scope&& other) = delete; - /// Create a sub-scope. Returns a reference other than a pointer so /// to prevent from manual deletion. /// Mark it to const because that new kid scope cannot change parent scope. @@ -73,6 +69,8 @@ class Scope { std::unordered_map vars_; mutable std::list kids_; Scope const* parent_{nullptr}; + + DISABLE_COPY_AND_ASSIGN(Scope); }; } // namespace framework diff --git a/paddle/framework/tensor_array.h b/paddle/framework/tensor_array.h index e76f33d2c0818aa6091d6d5edb1cdfd1e9122476..22ae6a966f90c47fe8b4bbaaf5eb227c39d84173 100644 --- a/paddle/framework/tensor_array.h +++ b/paddle/framework/tensor_array.h @@ -47,13 +47,6 @@ class TensorArray { // max number of values allowed to store. const size_t MAX_SIZE{100000}; - /* - * Inputs: - * - value_shared: share memory between tensors. - */ - explicit TensorArray(bool values_shared = true) - : values_shared_(values_shared) {} - /* * Read the value at location `index` in the `TensorArray`. */ @@ -111,7 +104,6 @@ class TensorArray { private: mutable std::vector values_; - bool values_shared_; }; // class TensorArray } // namespace framework diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 80de229c333f645fb3098b97fa076c6b77bb7ca9..04c4c24951f5db572486ded5edfc26948a821682 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -30,36 +30,39 @@ using LoDTensor = framework::LoDTensor; void RecurrentAlgorithm::Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const { - auto step_scopes = GetStepScopes(scope); - rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, - false /*infer_shape_mode*/); - InitMemories(step_scopes[0], false /*infer_shape_mode*/); + auto* input0 = scope.FindVar(arg_->inlinks[0]); + PADDLE_ENFORCE_NOT_NULL(input0); + size_t seq_len = input0->GetMutable()->dims()[0]; + PADDLE_ENFORCE_GT(seq_len, 0); - for (size_t step_id = 0; step_id < seq_len_; step_id++) { - // create output alias variables + CreateScopes(scope, seq_len); + auto& step_scopes = GetStepScopes(scope); + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len); + InitMemories(step_scopes[0]); + + for (size_t step_id = 0; step_id < seq_len; step_id++) { if (step_id > 0) { - rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, - false /*infer_shape_mode*/); + rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1); } (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); } - rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, - false /*infer_shape_mode*/); + rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len); } -void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { +void RecurrentAlgorithm::CreateScopes(const Scope& scope, + size_t seq_len) const { // TODO(superjom) Only two scopes are needed for inference, this case will be // supported later. - auto step_scopes_var = scope.FindVar(arg_->step_scopes); + auto* step_scopes_var = scope.FindVar(arg_->step_scopes); PADDLE_ENFORCE(step_scopes_var != nullptr, ""); - auto step_scopes = step_scopes_var->GetMutable>(); + auto* step_scopes = step_scopes_var->GetMutable>(); // Now all variables in scope must be created outside of op. PADDLE_ENFORCE_NOT_NULL(stepnet_); PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs"); - if (seq_len_ > step_scopes->size()) { - for (size_t i = step_scopes->size(); i < seq_len_; ++i) { + if (seq_len > step_scopes->size()) { + for (size_t i = step_scopes->size(); i < seq_len; ++i) { auto& step_scope = scope.NewScope(); // create step net's temp inputs @@ -82,8 +85,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { } } -void RecurrentAlgorithm::InitMemories(Scope* step_scope, - bool infer_shape_mode) const { +void RecurrentAlgorithm::InitMemories(Scope* step_scope) const { for (auto& attr : arg_->memories) { auto* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable(); PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, @@ -91,12 +93,9 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope, attr.boot_var); auto* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable(); - if (infer_shape_mode) { - pre_mem->Resize(boot_mem->dims()); - PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2); - } else { - pre_mem->ShareDataWith(*boot_mem); - } + pre_mem->Resize(boot_mem->dims()); + PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2); + pre_mem->ShareDataWith(*boot_mem); } } @@ -146,23 +145,23 @@ class RecurrentAlgorithmProtoAndCheckerMaker void RecurrentGradientAlgorithm::Run( const Scope& scope, const platform::DeviceContext& dev_ctx) const { - auto step_scopes = GetStepScopes(scope); - rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, - false /*infer_shape_mode*/); - for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { - if (static_cast(step_id) != seq_len_ - 1) { - rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, - false /*infer_shape_mode*/); + auto* input0 = scope.FindVar(arg_->inlinks[0]); + PADDLE_ENFORCE_NOT_NULL(input0); + size_t seq_len = input0->GetMutable()->dims()[0]; + auto& step_scopes = GetStepScopes(scope); + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len); + for (int step_id = seq_len - 1; step_id >= 0; --step_id) { + if (step_id != seq_len - 1) { + rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); } (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); } - LinkBootMemoryGradients(step_scopes[0], false); - rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, - false /*infer_shape_mode*/); + rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len); + LinkBootMemoryGradients(step_scopes[0]); } void RecurrentGradientAlgorithm::LinkBootMemoryGradients( - Scope* step_scope, bool infer_shape_mode) const { + Scope* step_scope) const { for (auto& attr : arg_->memories) { PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr, "memory variable [%s] does not exists", attr.var); @@ -171,11 +170,8 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients( auto* mem_grad = step_scope->NewVar(attr.var)->GetMutable(); auto* boot_mem_grad = step_scope->NewVar(attr.boot_var)->GetMutable(); - if (infer_shape_mode) { - boot_mem_grad->Resize(mem_grad->dims()); - } else { - boot_mem_grad->ShareDataWith(*mem_grad); - } + boot_mem_grad->Resize(mem_grad->dims()); + boot_mem_grad->ShareDataWith(*mem_grad); } } diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index c6b9a5533eece9057449b5c875ddcb3cefe716f0..253d7e3284360ceaddce9ef5f8f9a3ea4793d740 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -48,7 +48,7 @@ class RecurrentAlgorithm { * NOTE the scopes are reused in both the forward and backward, so just * create once and expand its size if more steps need. */ - void CreateScopes(const framework::Scope& scope) const; + void CreateScopes(const framework::Scope& scope, size_t seq_len) const; const std::vector& GetStepScopes( const framework::Scope& scope) const { @@ -56,12 +56,11 @@ class RecurrentAlgorithm { ->GetMutable>(); } - void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const; + void InitMemories(framework::Scope* step_scopes) const; private: std::unique_ptr* stepnet_; rnn::Argument* arg_; - mutable size_t seq_len_; }; class RecurrentGradientAlgorithm { @@ -86,8 +85,7 @@ class RecurrentGradientAlgorithm { void Run(const framework::Scope& scope, const platform::DeviceContext& dev_ctx) const; - void LinkBootMemoryGradients(framework::Scope* step_scopes, - bool infer_shape_mode) const; + void LinkBootMemoryGradients(framework::Scope* step_scopes) const; protected: inline const std::vector& GetStepScopes( @@ -98,7 +96,6 @@ class RecurrentGradientAlgorithm { private: rnn::Argument* arg_; - mutable size_t seq_len_; std::unique_ptr* stepnet_; }; @@ -123,6 +120,7 @@ class RecurrentOp : public framework::OperatorBase { void set_stepnet(std::unique_ptr net) { stepnet_ = std::move(net); } + const OperatorBase& stepnet() const { return *stepnet_; } static const rnn::ArgumentName kArgName; diff --git a/paddle/operators/rnn/recurrent_op_utils.cc b/paddle/operators/rnn/recurrent_op_utils.cc index a767009d2366e20d2ebd35f562b8df7d408f2d4e..ef317a71f12c6de974bd8715bb08122b761fae37 100644 --- a/paddle/operators/rnn/recurrent_op_utils.cc +++ b/paddle/operators/rnn/recurrent_op_utils.cc @@ -25,7 +25,7 @@ using LoDTensor = framework::LoDTensor; void SegmentInputs(const std::vector& step_scopes, const std::vector& inlinks, - const size_t seq_len, bool infer_shape_mode) { + const size_t seq_len) { PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided."); for (size_t i = 0; i < inlinks.size(); ++i) { // global inputs @@ -41,11 +41,9 @@ void SegmentInputs(const std::vector& step_scopes, for (size_t j = 0; j < seq_len; j++) { Tensor* step_input = step_scopes[j]->NewVar(inlinks[i])->GetMutable(); - if (!infer_shape_mode) { - // The input of operators of each step is Tensor here. - // Maybe need to modify Slice function. - *step_input = input->Slice(j, j + 1); - } + // The input of operators of each step is Tensor here. + // Maybe need to modify Slice function. + *step_input = input->Slice(j, j + 1); step_input->Resize(step_dims); } } @@ -53,39 +51,35 @@ void SegmentInputs(const std::vector& step_scopes, void ConcatOutputs(const std::vector& step_scopes, const std::vector& outlinks, - const size_t seq_len, bool infer_shape_mode) { + const size_t seq_len) { for (size_t i = 0; i < outlinks.size(); i++) { - auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]); + auto* output_var = step_scopes[0]->parent().FindVar(outlinks[i]); PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.", outlinks[i]); LoDTensor* output = output_var->GetMutable(); - if (infer_shape_mode) { - auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]); - PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]); - f::DDim step_dims = - step_scope_var->template GetMutable()->dims(); - std::vector dims_vec = vectorize(step_dims); - dims_vec.insert(dims_vec.begin(), seq_len); - output->Resize(f::make_ddim(dims_vec)); - } else { - output->mutable_data(platform::CPUPlace()); - for (size_t j = 0; j < seq_len; j++) { - LoDTensor* step_output = - step_scopes[j]->FindVar(outlinks[i])->GetMutable(); - // TODO(luotao02) data type and platform::DeviceContext() should set - // correctly - (output->Slice(j, j + 1)) - .CopyFrom(*step_output, platform::CPUPlace()); - } + auto* step_scope_var = step_scopes[0]->FindVar(outlinks[i]); + PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]); + f::DDim step_dims = + step_scope_var->template GetMutable()->dims(); + std::vector dims_vec = vectorize(step_dims); + dims_vec.insert(dims_vec.begin(), seq_len); + output->Resize(f::make_ddim(dims_vec)); + output->mutable_data(platform::CPUPlace()); + for (size_t j = 0; j < seq_len; j++) { + LoDTensor* step_output = + step_scopes[j]->FindVar(outlinks[i])->GetMutable(); + // TODO(luotao02) data type and platform::DeviceContext() should set + // correctly + (output->Slice(j, j + 1)) + .CopyFrom(*step_output, platform::CPUPlace()); } } } void LinkMemories(const std::vector& scopes, const std::vector& memories, - const size_t step_id, const int offset, - bool infer_shape_mode) { + const size_t step_id, const int offset) { PADDLE_ENFORCE_LT(step_id, scopes.size(), "step [%d] is out of range of step scopes' size [%d]", step_id, scopes.size()); @@ -95,16 +89,13 @@ void LinkMemories(const std::vector& scopes, step_id + offset, scopes.size(), "offset [%d] is out of range, it must be less than (%d - %d)", offset, scopes.size(), step_id); - auto scope = scopes[step_id]; - auto linked_scope = scopes[step_id + offset]; + auto* scope = scopes[step_id]; + auto* linked_scope = scopes[step_id + offset]; for (auto& attr : memories) { - auto mem = scope->FindVar(attr.pre_var)->GetMutable(); - auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable(); - if (infer_shape_mode) { - mem->Resize(linked_mem->dims()); - } else { - mem->ShareDataWith(*linked_mem); - } + auto* mem = scope->FindVar(attr.pre_var)->GetMutable(); + auto* linked_mem = linked_scope->FindVar(attr.var)->GetMutable(); + mem->Resize(linked_mem->dims()); + mem->ShareDataWith(*linked_mem); } } @@ -115,11 +106,11 @@ void InitArgument(const ArgumentName& name, Argument* arg, arg->inlinks = op.Inputs(name.inlinks); arg->outlinks = op.Outputs(name.outlinks); - auto boot_memories = + auto& boot_memories = is_grad ? op.Outputs(name.boot_memories) : op.Inputs(name.boot_memories); // attributes - auto memories = op.Attr>(name.memories); - auto pre_memories = op.Attr>(name.pre_memories); + auto& memories = op.Attr>(name.memories); + auto& pre_memories = op.Attr>(name.pre_memories); PADDLE_ENFORCE(memories.size() == boot_memories.size(), "the size of memories, boot_memories don't match:%d,%d", diff --git a/paddle/operators/rnn/recurrent_op_utils.h b/paddle/operators/rnn/recurrent_op_utils.h index 9c777f1e9067a3e2ceb9d23f7bf7d3c73343c91f..fd17b9b88915cf458ff2836b5c5d8f84cd9b65b5 100644 --- a/paddle/operators/rnn/recurrent_op_utils.h +++ b/paddle/operators/rnn/recurrent_op_utils.h @@ -64,18 +64,18 @@ struct ArgumentName { */ void SegmentInputs(const std::vector& step_scopes, const std::vector& inlinks, - const size_t seq_len, bool infer_shape_mode); + const size_t seq_len); /** * Process outputs of step nets and merge to variables. */ void ConcatOutputs(const std::vector& step_scopes, const std::vector& outlinks, - const size_t seq_len, bool infer_shape_mode); + const size_t seq_len); void LinkMemories(const std::vector& step_scopes, const std::vector& memories, const size_t step_id, - const int offset, bool infer_shape_mode); + const int offset); void InitArgument(const ArgumentName& name, Argument* arg, const framework::OperatorBase& op, bool is_grad = false); diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc index 5d76313aeb96c0c8204f64aee1057f753ec85d6b..c54843faa698654dafac786979045bebf0ebc95d 100644 --- a/paddle/operators/sum_op.cc +++ b/paddle/operators/sum_op.cc @@ -22,14 +22,15 @@ class SumOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null"); auto x_dims = ctx->GetInputsDim("X"); - PADDLE_ENFORCE(!x_dims.empty(), "Input(X) of SumOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SumOp should not be null."); - auto in_dim = x_dims[0]; size_t N = x_dims.size(); PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1."); + + auto in_dim = x_dims[0]; for (size_t i = 1; i < N; i++) { auto dim = x_dims[i]; PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape"); diff --git a/paddle/platform/macros.h b/paddle/platform/macros.h index 4a04a38c0c6f905639004dea2f4416ecc57c8620..feae7bdd77e3a0d02f33fb33991648408f542d0e 100644 --- a/paddle/platform/macros.h +++ b/paddle/platform/macros.h @@ -16,8 +16,10 @@ limitations under the License. */ // Disable the copy and assignment operator for a class. #ifndef DISABLE_COPY_AND_ASSIGN -#define DISABLE_COPY_AND_ASSIGN(classname) \ - private: \ - classname(const classname&) = delete; \ - classname& operator=(const classname&) = delete +#define DISABLE_COPY_AND_ASSIGN(classname) \ + private: \ + classname(const classname&) = delete; \ + classname(const classname&&) = delete; \ + classname& operator=(const classname&) = delete; \ + classname& operator=(const classname&&) = delete #endif diff --git a/python/paddle/v2/framework/tests/test_recurrent_op.py b/python/paddle/v2/framework/tests/test_recurrent_op.py index 92161ae5dd93d34d898a2027435cc5e55611bcd0..1f114432c09f29fab6cd56de00dff341785ae0e4 100644 --- a/python/paddle/v2/framework/tests/test_recurrent_op.py +++ b/python/paddle/v2/framework/tests/test_recurrent_op.py @@ -16,14 +16,17 @@ class PySimpleRNN(object): ''' def __init__(self, input_dim=30, batch_size=50, weight_dim=15, sent_len=11): - self.x = np.random.normal(size=(sent_len, batch_size, input_dim)) - self.W = np.random.normal(size=(input_dim, input_dim)) - self.U = np.random.normal(size=(input_dim, input_dim)) - self.h_boot = np.random.normal(size=(batch_size, input_dim)) + self.x = np.random.normal(size=(sent_len, batch_size, + input_dim)).astype("float32") + self.W = np.random.normal(size=(input_dim, input_dim)).astype("float32") + self.U = np.random.normal(size=(input_dim, input_dim)).astype("float32") + self.h_boot = np.random.normal(size=(batch_size, + input_dim)).astype("float32") # memories self.mems = [ - np.zeros(shape=(batch_size, input_dim)) for i in range(sent_len) + np.zeros(shape=(batch_size, input_dim)).astype("float32") + for i in range(sent_len) ] def forward(self): @@ -36,7 +39,7 @@ class PySimpleRNN(object): return [self.x[i] for i in range(self.x.shape[0])] def concat_outputs(self): - return np.array(self.mems) + return np.array(self.mems).astype("float32") def step(self, step_id, x): ''' @@ -47,8 +50,8 @@ class PySimpleRNN(object): pre_mem = self.mems[step_id - 1] else: pre_mem = self.h_boot - xW = np.matmul(x, self.W) - hU = np.matmul(pre_mem, self.U) + xW = np.matmul(x, self.W).astype("float32") + hU = np.matmul(pre_mem, self.U).astype("float32") sum = xW + hU self.mems[step_id] = py_sigmoid(sum) @@ -102,7 +105,8 @@ class RecurrentOpTest(unittest.TestCase): self.create_step_net() ctx = core.DeviceContext.create(core.CPUPlace()) self.rnnop.run(self.scope, ctx) - return np.array(self.scope.find_var("h@mem").get_tensor()) + return np.array(self.scope.find_var("h@mem").get_tensor()).astype( + "float32") def create_global_variables(self): # create inlink @@ -142,7 +146,7 @@ class RecurrentOpTest(unittest.TestCase): stepnet = core.Net.create() x_fc_op = Operator("mul", X="x", Y="W", Out="Wx") h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh") - sum_op = Operator("add", X="Wx", Y="Uh", Out="sum") + sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum") sig_op = Operator("sigmoid", X="sum", Y="h@mem") for op in [x_fc_op, h_fc_op, sum_op, sig_op]: @@ -179,7 +183,7 @@ class RecurrentGradientOpTest(unittest.TestCase): stepnet = core.Net.create() x_fc_op = Operator("mul", X="x@alias", Y="W", Out="Wx") h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh") - sum_op = Operator("add", X="Wx", Y="Uh", Out="sum") + sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum") sig_op = Operator("sigmoid", X="sum", Y="h@alias") for op in [x_fc_op, h_fc_op, sum_op, sig_op]: @@ -197,7 +201,4 @@ class RecurrentGradientOpTest(unittest.TestCase): if __name__ == '__main__': - exit( - 0 - ) # FIXME(yuyang18): InferShape has been removed, this unittest may error unittest.main()