提交 6cd94cc7 编写于 作者: F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev_refactor_tensor

...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License # limitations under the License
cmake_minimum_required(VERSION 3.0) cmake_minimum_required(VERSION 3.0)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR}) set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR})
......
...@@ -290,8 +290,22 @@ function(go_library TARGET_NAME) ...@@ -290,8 +290,22 @@ function(go_library TARGET_NAME)
set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}") set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}")
endif() endif()
# Add dummy code to support `make target_name` under Terminal Command
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c) set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
# Force a rebuild on every build invocation: this custom command
# declares an OUTPUT file that it never creates, so the output is
# always missing and the command is always re-run. Each run touches
# the dummy source file of the library.
#
# `${CMAKE_COMMAND}` is used instead of a bare `cmake` so that the
# CMake executable that configured this build tree is the one that
# runs, regardless of what is on PATH. VERBATIM makes argument
# escaping identical on every platform.
add_custom_command(
  OUTPUT dummy_rebuild_${TARGET_NAME}
  COMMAND ${CMAKE_COMMAND} -E touch "${dummyfile}"
  VERBATIM
)
# Create a custom target that depends on the custom command output
# file, so the custom command can be referenced as a dependency by
# `add_dependencies`.
add_custom_target(rebuild_${TARGET_NAME}
  DEPENDS dummy_rebuild_${TARGET_NAME}
)
# Add dummy code to support `make target_name` under Terminal Command
file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
if (go_library_SHARED OR go_library_shared) if (go_library_SHARED OR go_library_shared)
add_library(${TARGET_NAME} SHARED ${dummyfile}) add_library(${TARGET_NAME} SHARED ${dummyfile})
...@@ -302,6 +316,12 @@ function(go_library TARGET_NAME) ...@@ -302,6 +316,12 @@ function(go_library TARGET_NAME)
add_dependencies(${TARGET_NAME} ${go_library_DEPS}) add_dependencies(${TARGET_NAME} ${go_library_DEPS})
endif(go_library_DEPS) endif(go_library_DEPS)
# The "source file" of the library is `${dummyfile}`, which never
# changes, so the target would never be rebuilt. Make the target
# depend on the custom command that touches the library "source
# file", so that a rebuild always happens.
add_dependencies(${TARGET_NAME} rebuild_${TARGET_NAME})
set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}") set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}")
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package connection package connection
import ( import (
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING) if(WITH_TESTING)
go_test(master_test) go_test(master_test)
endif() endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
go_library(paddle_master SHARED DEPS paddle_go_optimizer) go_library(paddle_master SHARED DEPS paddle_go_optimizer)
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
/* /*
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master_test package master_test
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import "sync" import "sync"
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master package master
import "testing" import "testing"
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING) if(WITH_TESTING)
go_test(pserver_test DEPS paddle_go_optimizer) go_test(pserver_test DEPS paddle_go_optimizer)
endif() endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING) if(WITH_TESTING)
go_test(pserver_client_test DEPS paddle_go_optimizer) go_test(pserver_client_test DEPS paddle_go_optimizer)
endif() endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
target_link_libraries(paddle_go_optimizer stdc++ m) target_link_libraries(paddle_go_optimizer stdc++ m)
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main package main
/* /*
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient paddle_go_optimizer) cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient paddle_go_optimizer)
add_style_check_target(test_cclient test_cclient.c) add_style_check_target(test_cclient test_cclient.c)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client package client
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client_test package client_test
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client package client
import ( import (
...@@ -66,10 +80,10 @@ func (p *EtcdClient) List() []Server { ...@@ -66,10 +80,10 @@ func (p *EtcdClient) List() []Server {
for { for {
for i := 0; i < psDesired; i++ { for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
cancel()
psKey := pserver.PsPath + strconv.Itoa(i) psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey) resp, err := p.client.Get(ctx, psKey)
cancel()
if err != nil { if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err) log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout) time.Sleep(p.timeout)
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver package pserver
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver package pserver
// #cgo CFLAGS: -I ../../ // #cgo CFLAGS: -I ../../
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver package pserver
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver package pserver
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver_test package pserver_test
import ( import (
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING) if(WITH_TESTING)
go_test(network_helper_test) go_test(network_helper_test)
endif() endif()
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper package networkhelper
import ( import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper package networkhelper
import "testing" import "testing"
......
# ddim lib # ddim lib
cc_library(enforce SRCS enforce.cc DEPS glog)
cc_test(enforce_test SRCS enforce_test.cc DEPS enforce)
cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory)
cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc) cc_test(scope_test SRCS scope_test.cc)
proto_library(attr_type SRCS attr_type.proto) proto_library(attr_type SRCS attr_type.proto)
proto_library(op_proto SRCS op_proto.proto DEPS attr_type) proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type) proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor) cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc enforce) cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
#include <functional> #include <functional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/framework/enforce.h" #include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -41,6 +42,35 @@ class DefaultValueSetter { ...@@ -41,6 +42,35 @@ class DefaultValueSetter {
T default_value_; T default_value_;
}; };
// Value checker that accepts an attribute value only when it is a member of
// a fixed set of allowed values.
template <typename T>
class EnumInContainer {
 public:
  // Stores a copy of the allowed values in `c`.
  explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}

  // Enforces, through PADDLE_ENFORCE, that `val` is one of the allowed
  // values. The failure message lists the value and the allowed set.
  void operator()(T& val) const {
    PADDLE_ENFORCE(container_.find(val) != container_.end(),
                   "Value %s is not in enum container %s", val,
                   ContainerDebugString());
  }

 private:
  // Renders the allowed values as "[a, b, c]". Elements are streamed with
  // operator<< in the iteration order of the unordered_set.
  std::string ContainerDebugString() const {
    std::ostringstream sout;
    sout << "[";
    bool first = true;
    for (const auto& v : container_) {
      // Emit the separator before every element except the first, so no
      // trailing separator is produced.
      if (!first) {
        sout << ", ";
      }
      sout << v;
      first = false;
    }
    sout << "]";
    return sout.str();
  }

  std::unordered_set<T> container_;
};
// check whether a certain attribute fit its limits // check whether a certain attribute fit its limits
// an attribute can have more than one limits // an attribute can have more than one limits
template <typename T> template <typename T>
...@@ -50,6 +80,11 @@ class TypedAttrChecker { ...@@ -50,6 +80,11 @@ class TypedAttrChecker {
public: public:
TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {} TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}
// Adds a limit requiring the attribute value to be one of the values in
// `range`. Returns this checker so further limits can be chained.
TypedAttrChecker& InEnum(const std::unordered_set<T>& range) {
  EnumInContainer<T> membership_check(range);
  value_checkers_.push_back(membership_check);
  return *this;
}
TypedAttrChecker& LargerThan(const T& lower_bound) { TypedAttrChecker& LargerThan(const T& lower_bound) {
value_checkers_.push_back(LargerThanChecker<T>(lower_bound)); value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
return *this; return *this;
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h" #include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
#include "paddle/framework/dim.h" #include "paddle/framework/dim.h"
#include "paddle/framework/enforce.h" #include "paddle/platform/enforce.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
...@@ -119,17 +119,6 @@ int arity(const DDim& ddim); ...@@ -119,17 +119,6 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&); std::ostream& operator<<(std::ostream&, const DDim&);
// Copies the extents of `dims` into an Eigen::DSizes of compile-time rank
// NDIMS. PADDLE_ENFORCE checks that arity(dims) equals NDIMS.
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(const DDim& dims) {
  const int num_dims = arity(dims);
  PADDLE_ENFORCE(num_dims == NDIMS, "DDim and NDIMS must be same");
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> result;
  int axis = 0;
  while (axis < num_dims) {
    result[axis] = dims[axis];
    ++axis;
  }
  return result;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace framework {
// EigenDim converts paddle::framework::DDim into Eigen::DSizes.
//
// D is the compile-time rank of the resulting Eigen::DSizes.
template <int D>
struct EigenDim {
  using Type = Eigen::DSizes<Eigen::DenseIndex, D>;

  // Copies every extent of `dims` into a rank-D Eigen::DSizes.
  // PADDLE_ENFORCE checks that arity(dims) equals D.
  static Type From(const DDim& dims) {
    // arity(dims) does not change between the check and the loop, so it is
    // evaluated once instead of on every iteration.
    const int rank = arity(dims);
    PADDLE_ENFORCE(rank == D, "D must match arity(DDim)");
    Type ret;
    for (int d = 0; d < rank; d++) {
      ret[d] = dims[d];
    }
    return ret;
  }
};
// Interprets a paddle::framework::Tensor as an Eigen::TensorMap, in mutable
// (Type) and read-only (ConstType) flavours.
//
// T is the element type and D the rank of the Eigen view. MajorType and
// IndexType are passed through to Eigen::Tensor.
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
struct EigenTensor {
  // TODO(qijun) Now, the default type is unaligned, and we will make a
  // benchmark on the speed of the aligned and unaligned versions in future.
  using Type = Eigen::TensorMap<Eigen::Tensor<T, D, MajorType, IndexType>>;

  using ConstType =
      Eigen::TensorMap<Eigen::Tensor<const T, D, MajorType, IndexType>>;

  // Mutable view over tensor.data<T>() with the extents given by `dims`.
  static Type From(Tensor& tensor, DDim dims) {
    auto extents = EigenDim<D>::From(dims);
    return Type(tensor.data<T>(), extents);
  }

  // Mutable view using the tensor's own dims_.
  static Type From(Tensor& tensor) {
    return From(tensor, tensor.dims_);
  }

  // Read-only view over tensor.data<T>() with the extents given by `dims`.
  static ConstType From(const Tensor& tensor, DDim dims) {
    auto extents = EigenDim<D>::From(dims);
    return ConstType(tensor.data<T>(), extents);
  }

  // Read-only view using the tensor's own dims_.
  static ConstType From(const Tensor& tensor) {
    return From(tensor, tensor.dims_);
  }
};
// EigenVector is the rank-1 EigenTensor plus a Flatten helper that views a
// tensor of any shape as one long vector.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
  using Parent = EigenTensor<T, 1, MajorType, IndexType>;

  // Reshapes `tensor` into a one-dimensional vector whose length is the
  // product of all of the tensor's dimensions.
  static typename Parent::Type Flatten(Tensor& tensor) {
    const int length = static_cast<int>(product(tensor.dims_));
    return Parent::From(tensor, make_ddim({length}));
  }

  // Read-only counterpart of Flatten for const tensors.
  static typename Parent::ConstType Flatten(const Tensor& tensor) {
    const int length = static_cast<int>(product(tensor.dims_));
    return Parent::From(tensor, make_ddim({length}));
  }
};
// EigenMatrix is EigenTensor fixed to rank 2; MajorType and IndexType keep the
// same defaults as EigenTensor.
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = EigenTensor<T, 2, MajorType, IndexType>;
} // namespace framework
} // namespace paddle
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "paddle/framework/eigen.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
// EigenDim<3>::From must carry each DDim extent over unchanged.
TEST(EigenDim, From) {
  const DDim source = make_ddim({1, 2, 3});
  EigenDim<3>::Type ed = EigenDim<3>::From(source);

  ASSERT_EQ(1, ed[0]);
  ASSERT_EQ(2, ed[1]);
  ASSERT_EQ(3, ed[2]);
}
// A 1x2x3 tensor filled with 0..5 must be visible through EigenTensor with the
// same shape and row-major element order.
TEST(Eigen, Tensor) {
  Tensor t;
  float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
  const int numel = 1 * 2 * 3;
  for (int i = 0; i < numel; ++i) {
    p[i] = static_cast<float>(i);
  }

  EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);

  ASSERT_EQ(1, et.dimension(0));
  ASSERT_EQ(2, et.dimension(1));
  ASSERT_EQ(3, et.dimension(2));

  for (int i = 0; i < 1; ++i) {
    for (int j = 0; j < 2; ++j) {
      for (int k = 0; k < 3; ++k) {
        ASSERT_NEAR((i * 2 + j) * 3 + k, et(i, j, k), 1e-6f);
      }
    }
  }
}
// A rank-1 tensor of six floats must map onto an EigenVector of length six
// holding the same values.
TEST(Eigen, VectorFrom) {
  const int length = 6;
  Tensor t;
  float* p = t.mutable_data<float>(make_ddim({6}), platform::CPUPlace());
  for (int i = 0; i < length; ++i) {
    p[i] = static_cast<float>(i);
  }

  EigenVector<float>::Type ev = EigenVector<float>::From(t);

  ASSERT_EQ(6, ev.dimension(0));
  for (int i = 0; i < length; ++i) {
    ASSERT_NEAR(i, ev(i), 1e-6f);
  }
}
// Flatten must expose a 1x2x3 tensor as a single vector of six elements in
// storage order.
TEST(Eigen, VectorFlatten) {
  const int numel = 1 * 2 * 3;
  Tensor t;
  float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
  for (int i = 0; i < numel; ++i) {
    p[i] = static_cast<float>(i);
  }

  EigenVector<float>::Type ev = EigenVector<float>::Flatten(t);

  ASSERT_EQ(1 * 2 * 3, ev.dimension(0));
  for (int i = 0; i < numel; ++i) {
    ASSERT_NEAR(i, ev(i), 1e-6f);
  }
}
// A 2x3 tensor must be visible through EigenMatrix with the same shape and
// row-major element order.
TEST(Eigen, Matrix) {
  const int rows = 2;
  const int cols = 3;
  Tensor t;
  float* p = t.mutable_data<float>(make_ddim({2, 3}), platform::CPUPlace());
  for (int i = 0; i < rows * cols; ++i) {
    p[i] = static_cast<float>(i);
  }

  EigenMatrix<float>::Type em = EigenMatrix<float>::From(t);

  ASSERT_EQ(2, em.dimension(0));
  ASSERT_EQ(3, em.dimension(1));

  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      ASSERT_NEAR(i * 3 + j, em(i, j), 1e-6f);
    }
  }
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/enforce.h"
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <paddle/string/printf.h>
#include <exception>
#include <sstream>
namespace paddle {
namespace framework {
/**
 * @brief Exception thrown when an enforced condition does not hold.
 *
 * Derives from std::exception. The text returned by what() is the caller's
 * message followed by the source location, formatted as
 * "<msg> at [<file>:<line>];".
 */
class EnforceNotMet : public std::exception {
 public:
  EnforceNotMet(const std::string& msg, const char* file, int fileline)
      : all_msg_(Describe(msg, file, fileline)) {}

  const char* what() const noexcept override { return all_msg_.c_str(); }

 private:
  // Joins the message and its source location into the final what() text.
  static std::string Describe(const std::string& msg, const char* file,
                              int fileline) {
    std::ostringstream text;
    text << msg << " at [" << file << ":" << fileline << "];";
    return text.str();
  }

  std::string all_msg_;
};
// From https://stackoverflow.com/questions/30130930/
// __builtin_expect is a GCC/Clang compiler builtin; it is not part of the
// C++11 standard. An enforced condition is expected to hold in most
// situations, so marking its failure branch with `UNLIKELY` tells the compiler
// which path to favour when generating code.
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
/**
 * @brief Throw an EnforceNotMet exception, automatically filled with __FILE__
 *        and __LINE__ of the call site.
 *
 * All macro arguments are forwarded to ::paddle::string::Sprintf and the
 * resulting string becomes the exception message.
 * NOTE(review): the original comment says any type that can be serialized to
 * std::ostream may be passed; that depends on Sprintf, which is defined
 * elsewhere - confirm.
 *
 * The do { ... } while (0) wrapper makes the macro behave as a single
 * statement that the caller terminates with `;`.
 */
#define PADDLE_THROW(...) \
do { \
throw ::paddle::framework::EnforceNotMet( \
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
} while (0)
/**
 * @brief Enforce a condition; on failure report an error whose message is
 *        built from the remaining arguments via ::paddle::string::Sprintf.
 *
 * With NDEBUG defined, a failed condition throws EnforceNotMet through
 * PADDLE_THROW. Without NDEBUG the check is delegated to glog's CHECK and the
 * formatted message is streamed into it.
 *
 * Both variants expand to exactly one statement WITHOUT a trailing semicolon:
 * the caller's own `;` terminates it. A semicolon inside the macro would
 * produce a stray empty statement and break unbraced
 * `if (...) PADDLE_ENFORCE(...); else ...` constructs.
 */
#ifdef NDEBUG
#define PADDLE_ENFORCE(condition, ...) \
  do {                                 \
    if (UNLIKELY(!(condition))) {      \
      PADDLE_THROW(__VA_ARGS__);       \
    }                                  \
  } while (0)
#else
#define PADDLE_ENFORCE(condition, ...) \
  CHECK(condition) << ::paddle::string::Sprintf(__VA_ARGS__)
#endif
} // namespace framework
} // namespace paddle
...@@ -19,7 +19,10 @@ ...@@ -19,7 +19,10 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void PlainNet::CompleteAddOp() { void PlainNet::CompleteAddOp(bool calc) {
add_op_done_ = true;
if (!calc) return;
std::unordered_set<std::string> input_set; std::unordered_set<std::string> input_set;
std::unordered_set<std::string> output_set; std::unordered_set<std::string> output_set;
std::unordered_set<std::string> temp_output; std::unordered_set<std::string> temp_output;
...@@ -52,7 +55,6 @@ void PlainNet::CompleteAddOp() { ...@@ -52,7 +55,6 @@ void PlainNet::CompleteAddOp() {
} }
attrs_["temporary_index"] = tmp_index; attrs_["temporary_index"] = tmp_index;
add_op_done_ = true;
} }
std::string PlainNet::DebugString() const { std::string PlainNet::DebugString() const {
......
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#include <paddle/framework/op_desc.pb.h> #include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/operator.h> #include <paddle/framework/operator.h>
#include "paddle/framework/net_proto.pb.h"
#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
...@@ -41,7 +40,7 @@ namespace framework { ...@@ -41,7 +40,7 @@ namespace framework {
class Net : public OperatorBase { class Net : public OperatorBase {
public: public:
virtual void AddOp(const OperatorPtr& op) = 0; virtual void AddOp(const OperatorPtr& op) = 0;
virtual void CompleteAddOp() = 0; virtual void CompleteAddOp(bool calc) = 0;
}; };
using NetPtr = std::shared_ptr<Net>; using NetPtr = std::shared_ptr<Net>;
...@@ -86,7 +85,7 @@ class PlainNet : public Net { ...@@ -86,7 +85,7 @@ class PlainNet : public Net {
ops_.push_back(op); ops_.push_back(op);
} }
void CompleteAddOp() override; void CompleteAddOp(bool calculate = true) override;
std::string DebugString() const override; std::string DebugString() const override;
......
...@@ -63,5 +63,5 @@ TEST(OpKernel, all) { ...@@ -63,5 +63,5 @@ TEST(OpKernel, all) {
ASSERT_EQ(2, infer_shape_cnt); ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt); ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), paddle::framework::EnforceNotMet); ASSERT_THROW(net->AddOp(op2), std::runtime_error);
} }
...@@ -91,7 +91,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -91,7 +91,7 @@ TEST(OpRegistry, IllegalAttr) {
try { try {
paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (std::runtime_error& err) {
caught = true; caught = true;
std::string msg = "larger_than check fail"; std::string msg = "larger_than check fail";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -138,7 +138,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -138,7 +138,7 @@ TEST(OpRegistry, CustomChecker) {
try { try {
paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (std::runtime_error& err) {
caught = true; caught = true;
std::string msg = "Attribute 'test_attr' is required!"; std::string msg = "Attribute 'test_attr' is required!";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -157,7 +157,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -157,7 +157,7 @@ TEST(OpRegistry, CustomChecker) {
try { try {
paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) { } catch (std::runtime_error& err) {
caught = true; caught = true;
std::string msg = "'test_attr' must be even!"; std::string msg = "'test_attr' must be even!";
const char* err_msg = err.what(); const char* err_msg = err.what();
...@@ -196,7 +196,7 @@ TEST(ProtoMaker, DuplicatedAttr) { ...@@ -196,7 +196,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
pd::OpProto op_proto; pd::OpProto op_proto;
pd::OpAttrChecker op_checker; pd::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet); ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
} }
class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker { class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
...@@ -212,5 +212,5 @@ TEST(ProtoMaker, DuplicatedInOut) { ...@@ -212,5 +212,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
pd::OpProto op_proto; pd::OpProto op_proto;
pd::OpAttrChecker op_checker; pd::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet); ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
} }
...@@ -19,9 +19,8 @@ limitations under the License. */ ...@@ -19,9 +19,8 @@ limitations under the License. */
#include <memory> #include <memory>
#include <typeindex> #include <typeindex>
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
#include "paddle/framework/tensor_types.h"
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
...@@ -35,6 +34,15 @@ struct CastToPyBufferImpl; ...@@ -35,6 +34,15 @@ struct CastToPyBufferImpl;
namespace framework { namespace framework {
class Tensor { class Tensor {
template <bool less, size_t i, typename... args>
friend struct paddle::pybind::details::CastToPyBufferImpl;
template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor;
template <typename T, int MajorType, typename IndexType>
friend struct EigenVector;
public: public:
Tensor() : offset_(0) {} Tensor() : offset_(0) {}
...@@ -46,7 +54,7 @@ class Tensor { ...@@ -46,7 +54,7 @@ class Tensor {
} }
template <typename T> template <typename T>
T* raw_data() const { T* data() {
EnforceSufficientMemory<T>(); EnforceSufficientMemory<T>();
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
...@@ -86,66 +94,6 @@ class Tensor { ...@@ -86,66 +94,6 @@ class Tensor {
offset_); offset_);
} }
// Returns an Eigen tensor map of rank NDIMS over this tensor's buffer,
// interpreting it with the caller-supplied shape `new_dims`.
// ToEigenDSizes enforces that `new_dims` has exactly NDIMS dimensions.
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) {
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
return typename TTypes<T, NDIMS>::Tensor(raw_data<T>(), dims);
}
// Returns an Eigen tensor map of rank NDIMS over this tensor's buffer using
// the tensor's own dims_. ToEigenDSizes enforces that dims_ has exactly NDIMS
// dimensions.
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor tensor() {
return typename TTypes<T, NDIMS>::Tensor(
raw_data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}
// Flattens to rank 1: views the whole buffer as a single vector whose length
// is the product of all entries of dims_.
template <typename T>
typename TTypes<T>::Flat flat() {
return shaped<T, 1>(make_ddim({static_cast<int>(product(dims_))}));
}
// Views the tensor as a TTypes Vec (rank 1) using its own dims_; goes through
// tensor<T, 1>(), so dims_ must already have exactly one dimension.
template <typename T>
typename TTypes<T>::Vec vec() {
return tensor<T, 1>();
}
// Views the tensor as a TTypes Matrix (rank 2) using its own dims_; goes
// through tensor<T, 2>(), so dims_ must already have exactly two dimensions.
template <typename T>
typename TTypes<T>::Matrix matrix() {
return tensor<T, 2>();
}
// const versions of all the methods above.

// Const overload of shaped(): maps the buffer returned by data<T>() with the
// caller-supplied shape `new_dims`.
// NOTE(review): this const overload returns the mutable TTypes::Tensor map
// type, while flat() const converts its result to TTypes::ConstFlat - confirm
// whether TTypes::ConstTensor was intended here.
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) const {
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
return typename TTypes<T, NDIMS>::Tensor(data<T>(), dims);
}
// Const overload of tensor(): returns a read-only Eigen map of rank NDIMS over
// this tensor's buffer using its own dims_. ToEigenDSizes enforces that dims_
// has exactly NDIMS dimensions.
//
// The read-only map type declared by TTypes is `ConstTensor` (there is no
// `ConstantTensor` member), and the returned object is built as that same
// read-only type so a const Tensor never hands out a mutable view.
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor tensor() const {
  return typename TTypes<T, NDIMS>::ConstTensor(
      data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}
// Const overload of flat(): views the whole buffer as one vector whose length
// is the product of all entries of dims_, via shaped<T, 1>() const.
template <typename T>
typename TTypes<T>::ConstFlat flat() const {
return shaped<T, 1>(make_ddim({static_cast<int>(product(dims_))}));
}
// Const overload of vec(): rank-1 view built by tensor<T, 1>() const, so dims_
// must already have exactly one dimension.
template <typename T>
typename TTypes<T>::ConstVec vec() const {
return tensor<T, 1>();
}
// Const overload of matrix(): rank-2 view built by tensor<T, 2>() const, so
// dims_ must already have exactly two dimensions.
template <typename T>
typename TTypes<T>::ConstMatrix matrix() const {
return tensor<T, 2>();
}
template <typename T> template <typename T>
void ShareDataWith(const Tensor& src) { void ShareDataWith(const Tensor& src) {
src.EnforceSufficientMemory<T>(); src.EnforceSufficientMemory<T>();
...@@ -232,8 +180,6 @@ class Tensor { ...@@ -232,8 +180,6 @@ class Tensor {
// byte offset between PlaceHolder::ptr_ and where tensor's data really // byte offset between PlaceHolder::ptr_ and where tensor's data really
// begins. // begins.
size_t offset_; size_t offset_;
template <bool less, size_t i, typename... args>
friend struct paddle::pybind::details::CastToPyBufferImpl;
}; };
} // namespace framework } // namespace framework
......
...@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) { ...@@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
bool caught = false; bool caught = false;
try { try {
src_tensor.data<double>(); src_tensor.data<double>();
} catch (paddle::framework::EnforceNotMet err) { } catch (std::runtime_error& err) {
caught = true; caught = true;
std::string msg = std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first."; "Tenosr holds no memory. Call Tensor::mutable_data first.";
...@@ -107,7 +107,7 @@ TEST(Tensor, ShareDataWith) { ...@@ -107,7 +107,7 @@ TEST(Tensor, ShareDataWith) {
bool caught = false; bool caught = false;
try { try {
dst_tensor.ShareDataWith<float>(src_tensor); dst_tensor.ShareDataWith<float>(src_tensor);
} catch (EnforceNotMet err) { } catch (std::runtime_error& err) {
caught = true; caught = true;
std::string msg = std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first."; "Tenosr holds no memory. Call Tensor::mutable_data first.";
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace framework {
// TTypes bundles the Eigen::TensorMap aliases used for a scalar type T.
// Every alias is row-major, uses IndexType for indexing and is declared with
// the Eigen::Aligned option; each mutable alias has a Const* twin over const T.
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
struct TTypes {
  // Rank-<NDIMS> tensor of scalar type T.
  using Tensor =
      Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
                       Eigen::Aligned>;
  using ConstTensor = Eigen::TensorMap<
      Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>,
      Eigen::Aligned>;

  // Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
  using Scalar = Eigen::TensorMap<
      Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
      Eigen::Aligned>;
  using ConstScalar = Eigen::TensorMap<
      Eigen::TensorFixedSize<const T, Eigen::Sizes<>, Eigen::RowMajor,
                             IndexType>,
      Eigen::Aligned>;

  // Rank-1 tensor (vector) of scalar type T. Flat and Vec name the same type.
  using Flat =
      Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
                       Eigen::Aligned>;
  using ConstFlat =
      Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>,
                       Eigen::Aligned>;
  using Vec =
      Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
                       Eigen::Aligned>;
  using ConstVec =
      Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>,
                       Eigen::Aligned>;

  // Rank-2 tensor (matrix) of scalar type T.
  using Matrix =
      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
                       Eigen::Aligned>;
  using ConstMatrix =
      Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>,
                       Eigen::Aligned>;
};
} // namespace framework
} // namespace paddle
...@@ -36,6 +36,7 @@ if(WITH_GPU) ...@@ -36,6 +36,7 @@ if(WITH_GPU)
add_simple_unittest(MulOpTest) add_simple_unittest(MulOpTest)
add_simple_unittest(CosSimOpTest) add_simple_unittest(CosSimOpTest)
add_simple_unittest(RowConvOpTest) add_simple_unittest(RowConvOpTest)
add_simple_unittest(CropOpTest)
endif() endif()
add_simple_unittest(ConvOpTest) add_simple_unittest(ConvOpTest)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "CropOp.h"
#include "paddle/function/TensorShape.h"
#include "paddle/math/Vector.h"
namespace paddle {
/**
 * \brief CPU implementation of Crop.
 *
 * Copies a window of the 4-D input into the 4-D output one row (innermost
 * dimension) at a time. The window starts at crop_corner[1..3] along
 * dimensions 1..3 of the input; crop_corner[0] is not read, so dimension 0 is
 * never cropped.
 *
 * All extents and offsets are kept in size_t so that the flat element offsets
 * cannot overflow a 32-bit int on large tensors.
 *
 * \param[out] outputs   destination buffer laid out as outShape.
 * \param[in]  inputs    source buffer laid out as inShape.
 * \param[in]  inShape   shape of the input tensor.
 * \param[in]  outShape  shape of the output (cropped) tensor.
 * \param[in]  conf      config holding the "crop_corner" vector.
 */
template <>
void Crop<DEVICE_TYPE_CPU>(real* outputs,
                           const real* inputs,
                           const TensorShape inShape,
                           const TensorShape outShape,
                           const FuncConfig& conf) {
  std::vector<uint32_t> crop_corner =
      conf.get<std::vector<uint32_t>>("crop_corner");
  const size_t cCrop = crop_corner[1];
  const size_t hCrop = crop_corner[2];
  const size_t wCrop = crop_corner[3];

  const size_t num = inShape[0];
  const size_t inC = inShape[1];
  const size_t inH = inShape[2];
  const size_t inW = inShape[3];

  const size_t outC = outShape[1];
  const size_t outH = outShape[2];
  const size_t outW = outShape[3];

  for (size_t n = 0; n < num; n++) {
    for (size_t c = 0; c < outC; c++) {
      for (size_t h = 0; h < outH; h++) {
        // First element of output row (n, c, h) and of the matching,
        // corner-shifted input row.
        const size_t outoff = ((n * outC + c) * outH + h) * outW;
        const size_t inoff =
            ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop;
        memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real));
      }
    }
  }
}
/**
 * \brief CPU implementation of CropGrad.
 *
 * Accumulates inGrad (laid out as inShape, the cropped tensor) into the
 * window of outGrad (laid out as outShape, the uncropped tensor) that starts
 * at crop_corner[1..3] along dimensions 1..3. Elements of outGrad outside the
 * window are left untouched; crop_corner[0] is not read.
 *
 * All extents and offsets are kept in size_t so that the flat element offsets
 * cannot overflow a 32-bit int on large tensors.
 *
 * \param[in]     inGrad    gradient buffer laid out as inShape (read only).
 * \param[in,out] outGrad   gradient buffer laid out as outShape; inGrad is
 *                          added into it.
 * \param[in]     inShape   shape of inGrad.
 * \param[in]     outShape  shape of outGrad.
 * \param[in]     conf      config holding the "crop_corner" vector.
 */
template <>
void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad,
                               real* outGrad,
                               const TensorShape inShape,
                               const TensorShape outShape,
                               const FuncConfig& conf) {
  std::vector<uint32_t> crop_corner =
      conf.get<std::vector<uint32_t>>("crop_corner");
  const size_t cCrop = crop_corner[1];
  const size_t hCrop = crop_corner[2];
  const size_t wCrop = crop_corner[3];

  const size_t num = outShape[0];
  const size_t outC = outShape[1];
  const size_t outH = outShape[2];
  const size_t outW = outShape[3];

  const size_t inC = inShape[1];
  const size_t inH = inShape[2];
  const size_t inW = inShape[3];

  for (size_t n = 0; n < num; n++) {
    for (size_t c = 0; c < inC; c++) {
      for (size_t h = 0; h < inH; h++) {
        // Start of the corner-shifted row in outGrad and of row (n, c, h) in
        // inGrad.
        const size_t outoff =
            ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop;
        const size_t inoff = ((n * inC + c) * inH + h) * inW;
        // CpuVector takes a non-const pointer; inG is only read from below.
        CpuVector inG(inW, const_cast<real*>(inGrad + inoff));
        CpuVector outG(inW, outGrad + outoff);
        outG += inG;
      }
    }
  }
}
/**
 * \brief Crop input according to the specified corner and shape.
 *        The input and output are 4D tensors. In CropFunc, we only
 *        crop the 2nd to 4th dimension.
 *
 * Argument in this Function:
 * \param conf_    Function config; its "crop_corner" entry is the 4-element
 *                 corner at which the cropping window starts.
 * \param inputs   A 4D tensor, only one input.
 * \param outputs  A 4D tensor, the output value after cropping. Its shape is
 *                 the shape of the cropping window.
 *
 * For example,
 * Input(2,2,2,3) = [
 *                    [ [[1,2,3], [3,4,5]],
 *                      [[2,3,5], [1,6,7]] ],
 *                    [ [[4,3,1], [1,8,7]],
 *                      [[3,8,9], [2,3,5]] ]
 *                  ]  # the input shape is (2,2,2,3)
 *
 * if corner = (0,1,1) and crop_shape = (2,1,2)
 * Output(2,2,1,2) = [
 *                     [ [[4,5]],
 *                       [[6,7]] ],
 *                     [ [[8,7]],
 *                       [[3,5]] ]
 *                   ]  # the output shape is (2,2,1,2)
 */
template <DeviceType Device>
class CropFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override { conf_ = config; }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());
    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);

    TensorShape inShape = inputs[0].shape();
    TensorShape outShape = outputs[0].shape();

    // The kernels index the raw buffers without any bounds checks, so make
    // sure the requested window lies inside the input before running them.
    const std::vector<uint32_t> corner =
        conf_.get<std::vector<uint32_t>>("crop_corner");
    CHECK_EQ(4UL, corner.size()) << "crop_corner must have 4 elements";
    // The kernels iterate over inShape[0] samples and write each of them.
    CHECK_LE(inShape[0], outShape[0])
        << "output holds fewer samples than the input";
    for (size_t d = 1; d < 4; ++d) {
      CHECK_LE(corner[d] + outShape[d], inShape[d])
          << "crop window exceeds the input along dimension " << d;
    }

    Crop<Device>(outputs[0].data<real>(),
                 inputs[0].data<real>(),
                 inShape,
                 outShape,
                 conf_);
  }

private:
  FuncConfig conf_;
};
/**
 * \brief The backward propagation of cropping Function.
 *
 * Argument in this Function:
 * \param conf_    The same meaning as in CropFunc: its "crop_corner" entry is
 *                 the corner at which the cropping window starts.
 * \param inputs   The gradient with respect to the output value of CropFunc
 *                 (shape of the cropping window).
 * \param outputs  The gradient with respect to the input value of CropFunc
 *                 (shape of the uncropped tensor); must be an ADD_TO argument
 *                 because the incoming gradient is accumulated into it.
 */
template <DeviceType Device>
class CropGradFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override { conf_ = config; }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());
    CHECK_EQ(outputs[0].getArgType(), ADD_TO);

    TensorShape outShape = outputs[0].shape();
    TensorShape inShape = inputs[0].shape();

    // The kernels index the raw buffers without any bounds checks, so make
    // sure the window being written lies inside the output gradient.
    const std::vector<uint32_t> corner =
        conf_.get<std::vector<uint32_t>>("crop_corner");
    CHECK_EQ(4UL, corner.size()) << "crop_corner must have 4 elements";
    // The kernels iterate over outShape[0] samples and read each of them
    // from the incoming gradient.
    CHECK_LE(outShape[0], inShape[0])
        << "incoming gradient holds fewer samples than the output gradient";
    for (size_t d = 1; d < 4; ++d) {
      CHECK_LE(corner[d] + inShape[d], outShape[d])
          << "crop window exceeds the output gradient along dimension " << d;
    }

    CropGrad<Device>(inputs[0].data<real>(),
                     outputs[0].data<real>(),
                     inShape,
                     outShape,
                     conf_);
  }

private:
  FuncConfig conf_;
};
// Register the CPU variants of the forward (Crop) and backward (CropGrad)
// functions.
REGISTER_TYPED_FUNC(Crop, CPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc);
// The GPU variants are registered only when PADDLE_ONLY_CPU is not defined.
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(Crop, GPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc);
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
/**
 * \brief This function crops inputs according to the specified start point
 *        and shape.
 *
 * \param[out] outputs   save results.
 * \param[in]  inputs    input data.
 * \param[in]  inShape   the shape of input tensor.
 * \param[in]  outShape  the shape of output (cropped) tensor.
 * \param[in]  conf      the cropping config; the implementations read its
 *                       "crop_corner" entry.
 */
template <DeviceType Device>
void Crop(real* outputs,
const real* inputs,
const TensorShape inShape,
const TensorShape outShape,
const FuncConfig& conf);
/**
 * \brief Cropping operation backward.
 *
 * \param[in]     inGrad    incoming gradient, laid out as inShape; it is only
 *                          read.
 * \param[in,out] outGrad   gradient laid out as outShape; the implementations
 *                          add inGrad into its cropping window.
 * \param[in]     inShape   the shape of inGrad.
 * \param[in]     outShape  the shape of outGrad.
 * \param[in]     conf      the cropping config; the implementations read its
 *                          "crop_corner" entry.
 */
template <DeviceType Device>
void CropGrad(const real* inGrad,
real* outGrad,
const TensorShape inShape,
const TensorShape outShape,
const FuncConfig& conf);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "CropOp.h"
namespace paddle {
// Forward crop kernel. Each thread whose global index is below nthreads
// handles one output element: the flat output index is split into
// (n, c, h, w) and the element at the corner-shifted position of the input is
// copied over.
__global__ void KeCrop(real* outputs, const real* inputs,
                       int inC, int inH, int inW,
                       int cropC, int cropH, int cropW,
                       int outC, int outH, int outW, int nthreads) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= nthreads) {
    return;
  }

  // Peel the flat output index apart, innermost (width) axis first.
  int rest = idx;
  const int w = rest % outW;
  rest /= outW;
  const int h = rest % outH;
  rest /= outH;
  const int c = rest % outC;
  const int n = rest / outC;

  const int src =
      ((n * inC + c + cropC) * inH + h + cropH) * inW + cropW + w;
  outputs[idx] = inputs[src];
}
// GPU implementation of Crop: reads the crop corner from the config, unpacks
// the 4-D input/output shapes and launches KeCrop with one thread per output
// element in blocks of 1024 threads.
// NOTE(review): the element count is multiplied in int and the kernel indexes
// with int, so tensors beyond the int range are presumably unsupported -
// confirm.
template <>
void Crop<DEVICE_TYPE_GPU>(real* outputs,
                           const real* inputs,
                           const TensorShape inShape,
                           const TensorShape outShape,
                           const FuncConfig& conf) {
  const std::vector<uint32_t> corner =
      conf.get<std::vector<uint32_t>>("crop_corner");
  const int cropC = corner[1];
  const int cropH = corner[2];
  const int cropW = corner[3];

  const int num = inShape[0];
  const int inC = inShape[1];
  const int inH = inShape[2];
  const int inW = inShape[3];

  const int outC = outShape[1];
  const int outH = outShape[2];
  const int outW = outShape[3];

  const size_t nth = num * outC * outH * outW;
  const int blockSize = 1024;
  // Round up so that a partially filled last block is still launched.
  const int gridSize = (nth + blockSize - 1) / blockSize;

  KeCrop<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
      outputs, inputs, inC, inH, inW, cropC, cropH, cropW,
      outC, outH, outW, nth);
  CHECK_SYNC("Crop");
}
// Backward crop kernel. Each thread whose global index is below nthreads
// handles one element of inGrad: the flat index is split into (n, c, h, w)
// over the in* extents and the value is added to the corner-shifted position
// of outGrad.
__global__ void KeCropDiff(const real* inGrad, real* outGrad,
                           int inC, int inH, int inW,
                           int cropC, int cropH, int cropW,
                           int outC, int outH, int outW, int nthreads) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= nthreads) {
    return;
  }

  // Peel the flat inGrad index apart, innermost (width) axis first.
  int rest = idx;
  const int w = rest % inW;
  rest /= inW;
  const int h = rest % inH;
  rest /= inH;
  const int c = rest % inC;
  const int n = rest / inC;

  const int dst =
      ((n * outC + c + cropC) * outH + h + cropH) * outW + cropW + w;
  outGrad[dst] += inGrad[idx];
}
// GPU implementation of CropGrad: reads the crop corner from the config,
// unpacks the 4-D shapes and launches KeCropDiff with one thread per inGrad
// element in blocks of 1024 threads.
// NOTE(review): the element count is multiplied in int and the kernel indexes
// with int, so tensors beyond the int range are presumably unsupported -
// confirm.
template <>
void CropGrad<DEVICE_TYPE_GPU>(const real* inGrad,
                               real* outGrad,
                               const TensorShape inShape,
                               const TensorShape outShape,
                               const FuncConfig& conf) {
  const std::vector<uint32_t> corner =
      conf.get<std::vector<uint32_t>>("crop_corner");
  const int cropC = corner[1];
  const int cropH = corner[2];
  const int cropW = corner[3];

  const int num = outShape[0];
  const int outC = outShape[1];
  const int outH = outShape[2];
  const int outW = outShape[3];

  const int inC = inShape[1];
  const int inH = inShape[2];
  const int inW = inShape[3];

  const size_t nth = num * inC * inH * inW;
  const int blockSize = 1024;
  // Round up so that a partially filled last block is still launched.
  const int gridSize = (nth + blockSize - 1) / blockSize;

  KeCropDiff<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
      inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW,
      outC, outH, outW, nth);
  CHECK_SYNC("CropGrad");
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "FunctionTest.h"
namespace paddle {
// Compares the CPU and GPU implementations of Crop and CropGrad over a grid of
// batch sizes, channel counts and image sizes. The crop corner is (0, 1, 1, 1)
// and the cropped tensor is always numSamples x 2 x 3 x 3.
// For the gradient test the roles are swapped: the cropped shape is the input
// and the full shape is the ADD_TO output.
// The channel list previously contained the value 5 twice, which only re-ran
// identical configurations; each distinct value is now listed once.
TEST(Crop, real) {
  for (size_t numSamples : {5, 32}) {
    for (size_t channels : {5, 32}) {
      for (size_t imgSizeH : {5, 33, 100}) {
        for (size_t imgSizeW : {5, 32, 96}) {
          VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
                  << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
          for (bool test_grad : {false, true}) {
            CpuGpuFuncCompare compare(
                test_grad ? "CropGrad" : "Crop",
                FuncConfig()
                    .set<std::vector<uint32_t>>("crop_corner", {0, 1, 1, 1})
                    .set<std::vector<uint32_t>>("crop_shape", {0, 2, 3, 3}));
            TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
            TensorShape outDims{numSamples, 2, 3, 3};
            compare.addInputs(
                BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
            compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT,
                                         test_grad ? inDims : outDims,
                                         test_grad ? ADD_TO : ASSIGN_TO),
                               test_grad ? ADD_TO : ASSIGN_TO);
            compare.run();
          }
        }
      }
    }
  }
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "CropLayer.h"
#include "paddle/utils/Stat.h"
namespace paddle {
// Registers CropLayer for the layer type `crop`.
// NOTE(review): the exact registry semantics come from REGISTER_LAYER, which
// is defined elsewhere - confirm.
REGISTER_LAYER(crop, CropLayer);
bool CropLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
CHECK_LE(static_cast<int>(inputLayers_.size()), 2);
CHECK_GE(static_cast<int>(inputLayers_.size()), 1);
crop_axis_ = config_.axis();
for (int i = 0; i < config_.offset_size(); i++) {
crop_offsets_.push_back(config_.offset(i));
}
// 1. get input_0 shape
auto& input0_img_conf = config_.inputs(0).image_conf();
inDims_ = TensorShape({0,
input0_img_conf.channels(),
input0_img_conf.has_img_size_y()
? input0_img_conf.img_size_y()
: input0_img_conf.img_size(),
input0_img_conf.img_size()});
// 2. get target dims from config
if (config_.inputs_size() == 1) {
targetDims_ = TensorShape({config_.shape(0),
config_.shape(1),
config_.shape(2),
config_.shape(3)});
} else {
// 2. get input_1 shape
auto& input1_img_conf = config_.inputs(1).image_conf();
targetDims_ = TensorShape({0,
input1_img_conf.channels(),
input1_img_conf.has_img_size_y()
? input1_img_conf.img_size_y()
: input1_img_conf.img_size(),
input1_img_conf.img_size()});
}
// 3. get final crop corner
int dimSize = 4;
crop_corner_ = {0, 0, 0, 0};
for (int i = 0; i < dimSize; i++) {
if (i >= crop_axis_) {
if (crop_offsets_.size() > 1) {
crop_corner_[i] = crop_offsets_[i - crop_axis_];
} else {
crop_corner_[i] = crop_offsets_[0];
}
}
}
outDims_ = TensorShape(4);
createFunction(
forward_, "Crop", FuncConfig().set("crop_corner", crop_corner_));
createFunction(
backward_, "CropGrad", FuncConfig().set("crop_corner", crop_corner_));
return true;
}
/**
 * Compute outDims_ for the current batch and publish the output frame size.
 *
 * When a reference input (input 1) is configured, targetDims_ is refreshed
 * from it. The final shape takes inDims_ for dimensions before crop_axis_ and
 * targetDims_ for dimensions from crop_axis_ onwards.
 */
void CropLayer::setOutDims() {
  // get target dims from input_1
  if (config_.inputs_size() == 2) {
    // inputLayers_[1] only exists when a reference input is configured, so
    // it must not be touched in the single-input case.
    MatrixPtr input = inputLayers_[1]->getOutputValue();
    size_t batchSize = input->getHeight();
    targetDims_.setDim(0, batchSize);
    // NOTE(review): the channel count is read from inputs(0) while height and
    // width come from input 1 -- confirm whether inputs(1) was intended.
    int ch = config_.inputs(0).image_conf().channels();
    if (ch != 0) targetDims_.setDim(1, ch);
    int h = inputLayers_[1]->getOutput().getFrameHeight();
    if (h != 0) targetDims_.setDim(2, h);
    int w = inputLayers_[1]->getOutput().getFrameWidth();
    if (w != 0) targetDims_.setDim(3, w);
  }

  // get final crop shape from target dims and crop axis
  std::vector<uint32_t> crop_shape;
  int dimSize = 4;
  for (int i = 0; i < dimSize; i++) {
    if (i >= crop_axis_) {
      crop_shape.push_back(targetDims_[i]);
    } else {
      crop_shape.push_back(inDims_[i]);
    }
  }

  outDims_.reshape(
      {crop_shape[0], crop_shape[1], crop_shape[2], crop_shape[3]});
  output_.setFrameHeight(crop_shape[2]);
  output_.setFrameWidth(crop_shape[3]);
}
/**
 * Refresh inDims_ from the current output of input 0: the batch dimension is
 * the matrix height, and the frame height/width replace dimensions 2 and 3
 * whenever they are non-zero.
 */
void CropLayer::setInDims() {
  const int frameHeight = inputLayers_[0]->getOutput().getFrameHeight();
  const int frameWidth = inputLayers_[0]->getOutput().getFrameWidth();

  // dimension 0 is the batch size of the incoming matrix
  inDims_.setDim(0, inputLayers_[0]->getOutputValue()->getHeight());

  // a frame size of 0 means "not provided": keep the configured value
  if (frameHeight != 0) {
    inDims_.setDim(2, frameHeight);
  }
  if (frameWidth != 0) {
    inDims_.setDim(3, frameWidth);
  }
}
/**
 * Forward pass: recompute the input/output shapes for the current batch,
 * resize the output buffer and run the "Crop" function on input 0.
 */
void CropLayer::forward(PassType passType) {
  Layer::forward(passType);

  // Shapes depend on the current batch, so refresh them on every call.
  setInDims();
  setOutDims();

  // The output matrix is (batch) x (channels * height * width).
  int size = outDims_[1] * outDims_[2] * outDims_[3];
  resetOutput(outDims_[0], size);

  REGISTER_TIMER_INFO("CropForward", getName().c_str());

  BufferArgs inputs;
  BufferArgs outputs;
  inputs.addArg(*getInputValue(0), inDims_);
  outputs.addArg(*getOutputValue(), outDims_, ASSIGN_TO);
  forward_[0]->calc(inputs, outputs);
}
void CropLayer::backward(const UpdateCallback& callback) {
(void)callback;
REGISTER_TIMER_INFO("CropBackward", getName().c_str());
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getOutputGrad(), outDims_);
outputs.addArg(*getInputGrad(0), inDims_, ADD_TO);
backward_[0]->calc(inputs, outputs);
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
namespace paddle {
/**
* \brief This layer crops its input according to the specified configuration.
* input_0: the input to be cropped
* input_1: optional reference input; when present, the target size is taken
* from it
* axis: the first dimension to be cropped; dimensions before it are kept
* offset: the cropping offset, either one value shared by all cropped
* dimensions or one value per cropped dimension
* shape: the target shape, used when no reference input layer is set
*/
class CropLayer : public Layer {
public:
explicit CropLayer(const LayerConfig& config) : Layer(config) {}
~CropLayer() {}
// Reads axis/offset/shape from the config and creates the Crop and CropGrad
// functions.
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
protected:
// Recompute outDims_ (and the output frame height/width) for the current
// batch.
void setOutDims();
// Recompute inDims_ from the current output of input 0.
void setInDims();
// First dimension to crop, read from the layer config (`axis`).
int32_t crop_axis_;
// Offsets as listed in the layer config (`offset`).
std::vector<uint32_t> crop_offsets_;
// Per-dimension (4 entries) start corner passed to the crop functions.
std::vector<uint32_t> crop_corner_;
// Shape of input 0.
TensorShape inDims_;
// Target shape taken from the `shape` config or from the reference input.
TensorShape targetDims_;
// Shape of the cropped output.
TensorShape outDims_;
};
} // namespace paddle
...@@ -1802,6 +1802,34 @@ TEST(Layer, RowConvLayer) { ...@@ -1802,6 +1802,34 @@ TEST(Layer, RowConvLayer) {
} }
} }
// Gradient check for the crop layer: a 4x16x16 input is cropped, starting at
// axis 2 with zero offsets, against a 2x8x8 reference input.
TEST(Layer, CropLayer) {
  TestConfig config;

  // crop parameters
  config.layerConfig.set_axis(2);
  config.layerConfig.add_offset(0);
  config.layerConfig.add_offset(0);

  // input_0: the data to be cropped (4 * 16 * 16 = 1024 values per sample)
  config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0});
  ImageConfig* croppedImg =
      config.layerConfig.add_inputs()->mutable_image_conf();
  croppedImg->set_channels(4);
  croppedImg->set_img_size(16);

  // input_1: the reference input (2 * 8 * 8 = 128 values per sample)
  config.inputDefs.push_back({INPUT_DATA, "layer_1", 128, 0});
  ImageConfig* referenceImg =
      config.layerConfig.add_inputs()->mutable_image_conf();
  referenceImg->set_channels(2);
  referenceImg->set_img_size(8);

  // layer identity
  config.layerConfig.set_type("crop");
  config.layerConfig.set_name("cropLayer");

  const bool deviceFlags[] = {false, true};
  for (bool useGpu : deviceFlags) {
    testLayerGrad(config, "crop", 100, false, useGpu, false);
  }
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
initMain(argc, argv); initMain(argc, argv);
......
## Design # Region-based Heterogeneous Memory Management
### Usage Please check out the [design documentation](http://gangliao.me) to find out more details about
buddy memory allocator for both CPU and GPU.
To allocate 4KB CPU memory:
```cpp
p = memory::Alloc(platform::CPUPlace(), 4*1024);
```
To allocate 4KB memory on the 3rd GPU:
```cpp
p = memory::Alloc(platform::GPUPlace(2), 4*1024);
```
To free memory and check the so-far used amount of memory on a place:
```cpp
auto pl = platform::GPUPlace(0);
p = memory::Alloc(pl, 4*1024);
cout << memory::Used(pl);
memory::Free(pl, p);
```
### API
In `paddle/memory/memory.h` we have:
```cpp
namespace memory {
template <typename Place> void* Alloc(Place, size_t);
template <typename Place> void Free(Place, void*);
template <typename Place> size_t Used(Place);
} // namespace memory
```
These function templates have specializations on either `platform::CPUPlace` or `platform::GPUPlace`:
```cpp
template<>
void* Alloc<CPUPlace>(CPUPlace p, size_t size) {
return GetCPUBuddyAllocator()->Alloc(size);
}
```
and
```cpp
template<>
void* Alloc<GPUPlace>(GPUPlace p, size_t size) {
return GetGPUBuddyAllocator(p.id)->Alloc(size);
}
```
Similar specializations exist for `Free` and `Used`.
### Implementation
`GetCPUBuddyAllocator` and `GetGPUBuddyAllocator` are singletons.
```cpp
BuddyAllocator* GetCPUBuddyAllocator() {
static BuddyAllocator* a = NULL;
if (a == NULL) {
a = new BuddyAllocator(new CPUAllocator /*backup allocator*/, ...);
}
return a;
}
BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
static BuddyAllocator** as = NULL;
if (as == NULL) {
as = new BuddyAllocator*[platform::NumGPUs()];
for (int gpu = 0; gpu < platform::NumGPUs(); gpu++) {
as[gpu] = new BuddyAllocator(new GPUAllocator(gpu) /* backup allocator */, ...);
}
}
return as[gpu_id];
}
```
#### `BuddyAllocator`
`BuddyAllocator` implements the buddy allocation algorithm. Its constructor takes parameters only related with the algorithm:
```cpp
BuddyAllocator::BuddyAllocator(initial_pool_size, max_pool_size) {
...
}
```
Please be aware that **`BuddyAllocator` always allocates aligned memory**, aligned on 32 bytes, which can hold a `BuddyAllocator::Block` object:
```cpp
class BuddyAllocator {
private:
struct Block {
size_t size;
Block *left, *right;
size_t index; // allocator id
};
...
};
```
Because `BuddyAllocator` keeps the meta-data of each block, it can trace the used memory -- it records the amount returned by `Alloc` and the amount freed in `Free`. In contrast, `CPUAllocator` and `GPUAllocator` don't know the size of a freed memory block and cannot do the trace.
#### System Allocators
The `GPUAllocator` and `CPUAllocator` are called *system allocators*. They work as the fallback allocators of `BuddyAllocator`.
## Justification
I got inspiration from Majel and Caffe2, though the above design looks different from both.
### Caffe2
In Caffe2, `Tensor<Context>::mutable_data()` allocates the memory. In particular, [`Tensor<Context>::mutable_data`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/tensor.h#L523) calls [`Tensor<Context>::raw_mutable_data`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/tensor.h#L459), which in turn calls [`Context::New`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/tensor.h#L479).
There are two implementations of `Context`:
1. [`CPUContext`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.h#L105), whose [`New` method](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.h#L131) calls [`g_cpu_allocator.get()->New(size_t)`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.cc#L15) to allocate the memory.
1. [`CUDAContext`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L99), which has a data member [`int gpu_id_`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L202). This looks very similar to class `majel::GPUPlace`, who also has an `int id_` data member. `CUDAContext::New(size_t)` calls [`g_cub_allocator->DeviceAllocate(&ptr, nbytes)`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.cu#L355) to allocate the memory.
### Majel
In Majel, there are basically two allocator types:
1. `cpu::SystemAllocator`, which has similar functionality to `caffe2::CPUContext::New/Delete`.
1. `gpu::SystemAllocator`, which has similar functionality to `caffe2::CUDAContext::New/Delete`.
However, memory allocation is not via these two allocators. Instead, these two allocators are defined in hidden namespaces.
In Majel there are hidden global variables like:
1. `cpu::SystemAllocator g_cpu_allocator`, and
1. `vector<gpu::SystemAllocator*> g_gpu_allocators(NUM_GPUS)`.
Programs allocate memory via a BuddyAllocator, which can take the `g_cpu_allocator` or a `g_gpu_allocators[gpu_id]` as its *fallback allocator*, so that if BuddyAllocator cannot find a block in its memory pool, it extends its memory pool by calling the fallback allocator's `New(size_t)`.
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/memory/detail/system_allocator.h" #include "paddle/memory/detail/system_allocator.h"
#include "paddle/platform/assert.h" #include "paddle/platform/assert.h"
#include "paddle/platform/error.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
#include <stdlib.h> // for malloc and free #include <stdlib.h> // for malloc and free
...@@ -128,8 +128,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) { ...@@ -128,8 +128,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) {
// process is terminating, in which case we don't care if // process is terminating, in which case we don't care if
// cudaFree succeeds. // cudaFree succeeds.
if (err != cudaErrorCudartUnloading) { if (err != cudaErrorCudartUnloading) {
platform::throw_on_error(err, PADDLE_ENFORCE(err, "cudaFree{Host} failed in GPUAllocator::Free.");
"cudaFree{Host} failed in GPUAllocator::Free.");
} }
} }
......
...@@ -27,7 +27,8 @@ function(op_library TARGET) ...@@ -27,7 +27,8 @@ function(op_library TARGET)
endif() endif()
list(LENGTH cu_srcs cu_srcs_len) list(LENGTH cu_srcs cu_srcs_len)
if (${cu_srcs_len} EQUAL 0) list(LENGTH op_library_DEPS dep_len)
if (${cu_srcs_len} EQUAL 0 AND ${dep_len} EQUAL 0)
message(WARNING "The op library ${TARGET} not support GPU!") message(WARNING "The op library ${TARGET} not support GPU!")
endif() endif()
...@@ -47,3 +48,9 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu) ...@@ -47,3 +48,9 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu)
op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc)
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
# One-hot cross entropy operator: CPU (.cc) and GPU (.cu) sources.
op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
# fc_op has no .cu source of its own; it is composed from the operators listed
# in DEPS, which is why op_library only warns about missing GPU support when
# both the .cu list and DEPS are empty.
op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
softmax_op net)
# SGD operator: CPU (.cc) and GPU (.cu) sources.
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
...@@ -29,8 +30,10 @@ public: ...@@ -29,8 +30,10 @@ public:
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
output->flat<T>().device(*(context.GetEigenDevice<Place>())) = framework::EigenVector<T>::Flatten(*output).device(
input0.flat<T>() + input1.flat<T>(); *(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(input0) +
framework::EigenVector<T>::Flatten(input1);
} }
}; };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
// Operator computing the cross entropy between a probability matrix X and
// integer class labels (one label per row of X).
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
protected:
  // Validates the inputs and sets the output shape.
  //
  // inputs[0]  (X):     rank-2 tensor, one row per sample.
  // inputs[1]  (label): rank-1 tensor.
  // outputs[0] (Y):     set to a rank-1 tensor with X's first dimension.
  void InferShape(
      const std::vector<const framework::Tensor *> &inputs,
      const std::vector<framework::Tensor *> &outputs) const override {
    PADDLE_ENFORCE(inputs.size() == 2,
                   "Input size of OnehotCrossEntropyOp must be two");
    PADDLE_ENFORCE(outputs.size() == 1,
                   "Output size of OnehotCrossEntropyOp must be one");
    PADDLE_ENFORCE(inputs[0] != nullptr && inputs[1] != nullptr,
                   "Inputs of OnehotCrossEntropyOp must all be set");
    PADDLE_ENFORCE(outputs[0] != nullptr,
                   "Outputs of OnehotCrossEntropyOp must all be set");
    PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
    // The label is the second input; the output shape is assigned below and
    // must not be what this check inspects.
    PADDLE_ENFORCE(inputs[1]->dims().size() == 1,
                   "label's dimension must be 1.");
    outputs[0]->set_dims(framework::make_ddim({inputs[0]->dims()[0]}));
  }
};
// Declares the proto of the onehot cross entropy operator: two inputs
// ("X" and "label"), one output ("Y") and the operator comment.
class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
OnehotCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
// Inputs are declared in the order the operator expects them:
// X first, label second.
AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp");
AddOutput("Y", "The output of OnehotCrossEntropyOp");
AddComment(R"DOC(
OnehotCrossEntropy Operator.
Y[i] = -log(X[i][j])
)DOC");
}
};
} // namespace operators
} // namespace paddle
// Register the operator under the name `onehot_cross_entropy` together with
// its proto maker.
REGISTER_OP(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOp,
paddle::operators::OnehotCrossEntropyOpMaker);
// Register the float kernel instantiated for CPUPlace.
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOpKernel<::paddle::platform::CPUPlace,
float>);
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/framework/op_registry.h"
// Register the float kernel instantiated for GPUPlace.
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
paddle::operators::OnehotCrossEntropyOpKernel<
::paddle::platform::GPUPlace, float>);
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
// Kernel of the onehot cross entropy operator.
//
// For every row i of X it writes Y[i] = -log(max(X[i][label[i]], threshold)),
// where the threshold (LOG_THRESHOLD) keeps the argument of log away from 0.
//
// NOTE(review): the computation is a plain loop over raw data pointers, while
// the kernel is also registered for GPUPlace -- confirm that those pointers
// are dereferenceable on the host in that configuration.
template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public framework::OpKernel {
public:
  // Lower bound applied to the probability before taking the logarithm.
  constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }

  void Compute(const framework::KernelContext& context) const override {
    // Bind by const reference: the tensor is only read here, so there is no
    // reason to copy it.
    const auto& X = context.Input(0)->Get<framework::Tensor>();
    const T* X_data = X.data<T>();
    const int* label_data =
        context.Input(1)->Get<framework::Tensor>().data<int>();
    auto* Y = context.Output(0)->GetMutable<framework::Tensor>();

    Y->mutable_data<T>(context.GetPlace());

    T* Y_data = Y->data<T>();

    int batch_size = X.dims()[0];
    int class_num = X.dims()[1];
    const T threshold = LOG_THRESHOLD();  // loop-invariant

    // Y[i] = -log(X[i][j])
    for (int i = 0; i < batch_size; ++i) {
      Y_data[i] = -std::log(
          std::max(X_data[i * class_num + label_data[i]], threshold));
    }
  }
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
// Fully connected operator expressed as a small net of existing operators:
//   before_act = mul(X, W)
//   before_act = rowwise_add(before_act, b)   (only when b is provided)
//   Y          = activation(before_act)
class FullyConnectedOp : public framework::PlainNet {
public:
  void Init() override {
    // Step 1: matrix multiplication of the input with the weight.
    AddOp(framework::OpRegistry::CreateOp(
        "mul", {Input("X"), Input("W")}, {Output("before_act")}, {}));

    // Step 2: optional bias, added in place on the pre-activation value.
    const auto bias = Input("b");
    const bool hasBias = bias != framework::OperatorBase::EMPTY_VAR_NAME();
    if (hasBias) {
      AddOp(framework::OpRegistry::CreateOp(
          "rowwise_add",
          {Output("before_act"), Input("b")},
          {Output("before_act")},
          {}));
    }

    // Step 3: the activation operator is looked up by the attribute's value.
    const auto activationType = GetAttr<std::string>("activation");
    AddOp(framework::OpRegistry::CreateOp(
        activationType, {Output("before_act")}, {Output("Y")}, {}));

    CompleteAddOp(false);
  }
};
// Declares the proto of the fc operator: inputs X, W and b, outputs Y and
// before_act, and the `activation` attribute.
class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FullyConnectedOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input of fc operator");
AddInput("W", "the weight of fc operator");
AddInput("b", "the bias of fc operator");
AddOutput("Y", "the output of fc operator");
// NOTE(review): the trailing `true` presumably marks before_act as an
// intermediate/temporary output -- confirm against AddOutput's signature.
AddOutput(
"before_act", "the before activation output of fc operator", true);
// The activation defaults to sigmoid and is restricted to the two listed
// values.
AddAttr<std::string>("activation", "The activation key for fc layer")
.SetDefault("sigmoid")
.InEnum({"sigmoid", "softmax"});
//! TODO(yuyang18): Complete comment;
AddComment("FullyConnected Operator");
}
};
} // namespace operators
} // namespace paddle
// The fc operator is built from these operators (see FullyConnectedOp::Init),
// so each of them is referenced here with USE_OP.
USE_OP(mul);
USE_OP(rowwise_add);
USE_OP(sigmoid);
USE_OP(softmax);
// Register the operator under the name `fc` together with its proto maker.
REGISTER_OP(fc,
paddle::operators::FullyConnectedOp,
paddle::operators::FullyConnectedOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
namespace paddle {
namespace operators {
// Operator applying one plain SGD step to a parameter tensor.
class SGDOp : public framework::OperatorWithKernel {
protected:
  // Validates the inputs and sets the output shape.
  //
  // inputs[0]  (param): the parameter to update.
  // inputs[1]  (grad):  its gradient; must have the same dims as param.
  // outputs[0] (param_out): receives the dims of param.
  void InferShape(
      const std::vector<const framework::Tensor *> &inputs,
      const std::vector<framework::Tensor *> &outputs) const override {
    PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
    PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
    PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] must be set");
    PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] must be set");
    PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] must be set");
    PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
                   "Two input of SGD Op's dimension must be same.");
    outputs[0]->set_dims(inputs[0]->dims());
  }
};
// Declares the proto of the sgd operator: inputs "param" and "grad", output
// "param_out" and the float attribute "learning_rate".
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
// Inputs are declared in the order InferShape expects them:
// param first, grad second.
AddInput("param", "input parameter");
AddInput("grad", "input gradient");
AddOutput("param_out", "output parameter");
AddAttr<float>("learning_rate", "learning rate of sgd");
AddComment(R"DOC(
Simplest sgd algorithm.
param_out = param - learning_rate * grad;
)DOC");
}
};
} // namespace operators
} // namespace paddle
// Register the operator under the name `sgd` together with its proto maker.
REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker);
// The kernel type is given a single-token alias before it is handed to the
// registration macro.
typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float>
SGDOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float);
#include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
// Register the float kernel instantiated for GPUPlace; the kernel type is
// given a single-token alias before it is handed to the registration macro.
typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float> SGDOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float);
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
// Kernel of the sgd operator: param_out = param - learning_rate * grad,
// evaluated element-wise on the Eigen device that belongs to Place.
template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel {
public:
  void Compute(const framework::KernelContext& ctx) const override {
    auto weights = ctx.Input("param")->Get<framework::Tensor>();
    auto gradient = ctx.Input("grad")->Get<framework::Tensor>();
    auto* updated = ctx.Output(0)->GetMutable<framework::Tensor>();
    float step = ctx.op_.GetAttr<float>("learning_rate");

    // The output buffer has to exist before it is viewed as an Eigen vector.
    updated->mutable_data<T>(ctx.GetPlace());

    // Flat (1-D) views over the three tensors.
    auto flatOut = framework::EigenVector<T>::Flatten(*updated);
    auto flatWeights = framework::EigenVector<T>::Flatten(weights);
    auto flatGradient = framework::EigenVector<T>::Flatten(gradient);

    flatOut.device(*(ctx.GetEigenDevice<Place>())) =
        flatWeights - step * flatGradient;
  }
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/framework/op_registry.h>
USE_OP(sgd);
// The sgd operator must have a proto entry in the registry.
TEST(SGDOp, GetOpProto) {
  auto& registry = paddle::framework::OpRegistry::protos();
  auto entry = registry.find("sgd");
  ASSERT_NE(entry, registry.end());
}
...@@ -8,6 +8,8 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) ...@@ -8,6 +8,8 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
add_subdirectory(dynload) add_subdirectory(dynload)
cc_test(enforce_test SRCS enforce_test.cc)
IF(WITH_GPU) IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader) set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
ELSE() ELSE()
......
...@@ -22,7 +22,6 @@ limitations under the License. */ ...@@ -22,7 +22,6 @@ limitations under the License. */
#endif #endif
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/platform/error.h"
DEFINE_double(fraction_of_cpu_memory_to_use, 1, DEFINE_double(fraction_of_cpu_memory_to_use, 1,
"Default use 100% of CPU memory for PaddlePaddle," "Default use 100% of CPU memory for PaddlePaddle,"
......
...@@ -11,12 +11,13 @@ limitations under the License. */ ...@@ -11,12 +11,13 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h" #include "paddle/platform/dynload/curand.h"
#include "paddle/platform/error.h"
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
...@@ -71,8 +72,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -71,8 +72,7 @@ class CUDADeviceContext : public DeviceContext {
public: public:
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) { explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_); GPUPlaceGuard guard(gpu_place_);
paddle::platform::throw_on_error(cudaStreamCreate(&stream_), PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
"cudaStreamCreate failed");
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_)); eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
} }
...@@ -83,7 +83,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -83,7 +83,7 @@ class CUDADeviceContext : public DeviceContext {
} }
void Wait() { void Wait() {
paddle::platform::throw_on_error(cudaStreamSynchronize(stream_), PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
"cudaStreamSynchronize failed"); "cudaStreamSynchronize failed");
} }
...@@ -94,11 +94,10 @@ class CUDADeviceContext : public DeviceContext { ...@@ -94,11 +94,10 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t cublas_handle() { cublasHandle_t cublas_handle() {
if (!blas_handle_) { if (!blas_handle_) {
GPUPlaceGuard guard(gpu_place_); GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) == PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
CUBLAS_STATUS_SUCCESS,
"cublasCreate failed"); "cublasCreate failed");
PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream( PADDLE_ENFORCE(
blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS, paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
"cublasSetStream failed"); "cublasSetStream failed");
} }
return blas_handle_; return blas_handle_;
...@@ -107,11 +106,10 @@ class CUDADeviceContext : public DeviceContext { ...@@ -107,11 +106,10 @@ class CUDADeviceContext : public DeviceContext {
cudnnHandle_t cudnn_handle() { cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) { if (!dnn_handle_) {
GPUPlaceGuard guard(gpu_place_); GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) == PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
CUDNN_STATUS_SUCCESS,
"cudnnCreate failed"); "cudnnCreate failed");
PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream( PADDLE_ENFORCE(
dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS, paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
"cudnnSetStream failed"); "cudnnSetStream failed");
} }
return dnn_handle_; return dnn_handle_;
...@@ -121,15 +119,14 @@ class CUDADeviceContext : public DeviceContext { ...@@ -121,15 +119,14 @@ class CUDADeviceContext : public DeviceContext {
if (!rand_generator_) { if (!rand_generator_) {
GPUPlaceGuard guard(gpu_place_); GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator( PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) == &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
CURAND_STATUS_SUCCESS,
"curandCreateGenerator failed"); "curandCreateGenerator failed");
PADDLE_ENFORCE( PADDLE_ENFORCE(
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed( paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS, rand_generator_, random_seed_),
"curandSetPseudoRandomGeneratorSeed failed"); "curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream( PADDLE_ENFORCE(
rand_generator_, stream_) == CURAND_STATUS_SUCCESS, paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
"curandSetStream failed"); "curandSetStream failed");
} }
return rand_generator_; return rand_generator_;
...@@ -138,26 +135,23 @@ class CUDADeviceContext : public DeviceContext { ...@@ -138,26 +135,23 @@ class CUDADeviceContext : public DeviceContext {
~CUDADeviceContext() { ~CUDADeviceContext() {
Wait(); Wait();
if (blas_handle_) { if (blas_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) == PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
CUBLAS_STATUS_SUCCESS,
"cublasDestroy failed"); "cublasDestroy failed");
} }
if (dnn_handle_) { if (dnn_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) == PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
CUDNN_STATUS_SUCCESS,
"cudnnDestroy failed"); "cudnnDestroy failed");
} }
if (rand_generator_) { if (rand_generator_) {
PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator( PADDLE_ENFORCE(
rand_generator_) == CURAND_STATUS_SUCCESS, paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
"curandDestroyGenerator failed"); "curandDestroyGenerator failed");
} }
eigen_stream_.reset(); eigen_stream_.reset();
eigen_device_.reset(); eigen_device_.reset();
paddle::platform::throw_on_error(cudaStreamDestroy(stream_), PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
"cudaStreamDestroy failed");
} }
private: private:
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include <string> #include <string>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/framework/enforce.h" #include "paddle/platform/enforce.h"
DEFINE_string(cudnn_dir, "", DEFINE_string(cudnn_dir, "",
"Specify path for loading libcudnn.so. For instance, " "Specify path for loading libcudnn.so. For instance, "
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/string/printf.h>
#include <sstream>
#include <stdexcept>
#include <string>
#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#include <cublas_v2.h>
#include <cudnn.h>
#include <curand.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#endif // PADDLE_ONLY_CPU
namespace paddle {
namespace platform {
// UNLIKELY(condition) evaluates to the truth value of `condition` while
// hinting to the compiler that this value is expected to be false.
// Enforce checks almost always pass, so their failure branches are marked as
// unlikely, which lets the compiler keep the error path off the hot path.
// __builtin_expect is a GCC/Clang builtin, not part of the C++ standard, so
// other compilers fall back to a plain boolean conversion without any hint.
// For more details, please check https://stackoverflow.com/a/43870188/724872.
#if defined(__GNUC__) || defined(__clang__)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
#define UNLIKELY(condition) (static_cast<bool>(condition))
#endif
#ifndef PADDLE_ONLY_CPU
// Raises thrust::system_error for a failed CUDA runtime call.
//   e:    status returned by the CUDA call; any non-zero value is a failure.
//   args: printf-style format string followed by its values, forwarded to
//         string::Sprintf to build the message.
// Returns normally when `e` is zero.
// Note: __FILE__ and __LINE__ are expanded inside this function, so the
// " at [file:line];" suffix names this function's own location, not the
// location of the caller.
template <typename... Args>
inline void throw_on_error(cudaError_t e, const Args&... args) {
  if (UNLIKELY(e)) {
    const std::string user_message = string::Sprintf(args...);
    const std::string location =
        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__);
    throw thrust::system_error(e, thrust::cuda_category(),
                               user_message + location);
  }
}
// Raises thrust::system_error when a cuRAND call did not succeed.
//   stat: status returned by the cuRAND call.
//   args: printf-style format string followed by its values, forwarded to
//         string::Sprintf to build the message.
// Every failure is reported with the fixed code cudaErrorLaunchFailure; the
// cuRAND status value itself is not carried in the exception.
// The " at [file:line];" suffix names this function's own location because
// __FILE__ and __LINE__ are expanded here.
template <typename... Args>
inline void throw_on_error(curandStatus_t stat, const Args&... args) {
  if (stat == CURAND_STATUS_SUCCESS) {
    return;
  }
  const std::string user_message = string::Sprintf(args...);
  const std::string location =
      string::Sprintf(" at [%s:%s];", __FILE__, __LINE__);
  throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
                             user_message + location);
}
// Throws std::runtime_error describing a failed cuDNN call.
//   stat: status returned by the cuDNN call; CUDNN_STATUS_SUCCESS returns
//         normally.
//   args: printf-style format string followed by its values, forwarded to
//         string::Sprintf to build the message.
// The message has the form
//   "<cudnn status text>, <formatted args> at [file:line];"
// where file/line are this function's own location (__FILE__ and __LINE__
// are expanded here).
template <typename... Args>
inline void throw_on_error(cudnnStatus_t stat, const Args&... args) {
  if (stat == CUDNN_STATUS_SUCCESS) {
    return;
  }
  // Put ", " between the cuDNN status text and the caller's message so the
  // two do not run together; this matches the "CUBLAS: ..., " prefix style
  // used by the cuBLAS overload.
  const std::string status_text =
      std::string(platform::dynload::cudnnGetErrorString(stat)) + ", ";
  throw std::runtime_error(
      status_text + string::Sprintf(args...) +
      string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
}
// Throws std::runtime_error describing a failed cuBLAS call.
//   stat: status returned by the cuBLAS call; CUBLAS_STATUS_SUCCESS returns
//         normally.
//   args: printf-style format string followed by its values, forwarded to
//         string::Sprintf to build the message.
// The message has the form
//   "CUBLAS: <reason>, <formatted args> at [file:line];"
// where file/line are this function's own location (__FILE__ and __LINE__
// are expanded here). A status without a dedicated description is reported
// as an unknown error together with its numeric value, so the message always
// identifies cuBLAS as the source of the failure.
template <typename... Args>
inline void throw_on_error(cublasStatus_t stat, const Args&... args) {
  std::string err;
  switch (stat) {
    case CUBLAS_STATUS_SUCCESS:
      return;
    case CUBLAS_STATUS_NOT_INITIALIZED:
      err = "CUBLAS: not initialized, ";
      break;
    case CUBLAS_STATUS_ALLOC_FAILED:
      err = "CUBLAS: alloc failed, ";
      break;
    case CUBLAS_STATUS_INVALID_VALUE:
      err = "CUBLAS: invalid value, ";
      break;
    case CUBLAS_STATUS_ARCH_MISMATCH:
      err = "CUBLAS: arch mismatch, ";
      break;
    case CUBLAS_STATUS_MAPPING_ERROR:
      err = "CUBLAS: mapping error, ";
      break;
    case CUBLAS_STATUS_EXECUTION_FAILED:
      err = "CUBLAS: execution failed, ";
      break;
    case CUBLAS_STATUS_INTERNAL_ERROR:
      err = "CUBLAS: internal error, ";
      break;
    case CUBLAS_STATUS_NOT_SUPPORTED:
      err = "CUBLAS: not supported, ";
      break;
    case CUBLAS_STATUS_LICENSE_ERROR:
      err = "CUBLAS: license error, ";
      break;
    default:
      // Previously an unrecognised status produced a message with no prefix
      // at all; keep the failure attributable to cuBLAS.
      err = "CUBLAS: unknown error (status " +
            std::to_string(static_cast<int>(stat)) + "), ";
      break;
  }
  throw std::runtime_error(err + string::Sprintf(args...) +
                           string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
}
#endif // PADDLE_ONLY_CPU
template <typename... Args>
inline void throw_on_error(int stat, const Args&... args) {
if (UNLIKELY(!(stat))) {
throw std::runtime_error(
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
}
}
// PADDLE_THROW(fmt, ...) unconditionally throws std::runtime_error. The
// message is the printf-style formatted text followed by " at [file:line];"
// for the place where the macro is used (__FILE__ and __LINE__ expand at the
// point of use).
// All names are fully qualified so the macro expands correctly from any
// namespace, not only from code nested inside namespace paddle.
// The do { ... } while (0) wrapper makes the expansion a single statement.
#define PADDLE_THROW(...)                                                \
  do {                                                                   \
    throw ::std::runtime_error(                                          \
        ::paddle::string::Sprintf(__VA_ARGS__) +                         \
        ::paddle::string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \
  } while (0)
/**
* @brief Enforce a condition by forwarding it to
* ::paddle::platform::throw_on_error, which is overloaded on the type of
* `condition`.
*
* What is thrown on failure is decided by the selected overload; the
* overloads visible in this header throw std::runtime_error or
* thrust::system_error.
*
* @param condition value to check: a boolean/integer condition or a CUDA,
* cuRAND, cuDNN or cuBLAS status code.
* @param ... printf-style format string followed by its values. At least the
* format string should be supplied, because the expansion always writes a
* comma after `condition`.
*/
#define PADDLE_ENFORCE(condition, ...) \
do { \
::paddle::platform::throw_on_error(condition, __VA_ARGS__); \
} while (0)
} // namespace platform
} // namespace paddle
...@@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <gtest/gtest.h> #include "paddle/platform/enforce.h"
#include <paddle/framework/enforce.h> #include "gtest/gtest.h"
TEST(ENFORCE, OK) { TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
...@@ -23,10 +23,11 @@ TEST(ENFORCE, FAILED) { ...@@ -23,10 +23,11 @@ TEST(ENFORCE, FAILED) {
bool in_catch = false; bool in_catch = false;
try { try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (paddle::framework::EnforceNotMet err) { } catch (const std::runtime_error& error) {
// your error handling code here
in_catch = true; in_catch = true;
std::string msg = "Enforce is not ok 123 at all"; std::string msg = "Enforce is not ok 123 at all";
const char* what = err.what(); const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) { for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]); ASSERT_EQ(what[i], msg[i]);
} }
......
#pragma once
#include <sstream>
#include <stdexcept>
#include <string>
#ifndef PADDLE_ONLY_CPU
#include <cublas_v2.h>
#include <cudnn.h>
#include <curand.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#endif // PADDLE_ONLY_CPU
namespace paddle {
namespace platform {
#ifndef PADDLE_ONLY_CPU
// Throws thrust::system_error carrying `e` and `message` when the CUDA call
// reported a non-zero status; returns normally for a zero status.
inline void throw_on_error(cudaError_t e, const char* message) {
  if (!e) {
    return;
  }
  throw thrust::system_error(e, thrust::cuda_category(), message);
}
// Throws thrust::system_error with `message` when the cuRAND call did not
// return CURAND_STATUS_SUCCESS. Every failure is reported with the fixed
// code cudaErrorLaunchFailure; the cuRAND status value is not carried along.
inline void throw_on_error(curandStatus_t stat, const char* message) {
  if (stat == CURAND_STATUS_SUCCESS) {
    return;
  }
  throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
                             message);
}
// Throws std::runtime_error with the text
// "<cudnnGetErrorString(stat)>, <message>" when the cuDNN call did not return
// CUDNN_STATUS_SUCCESS; returns normally on success.
inline void throw_on_error(cudnnStatus_t stat, const char* message) {
  if (stat != CUDNN_STATUS_SUCCESS) {
    std::stringstream text;
    text << cudnnGetErrorString(stat) << ", " << message;
    throw std::runtime_error(text.str());
  }
}
// Throws std::runtime_error with the text "<CUBLAS: reason>, <message>" when
// the cuBLAS call did not return CUBLAS_STATUS_SUCCESS; returns normally on
// success. A status without a dedicated description contributes an empty
// reason, so the text then starts with ", ".
inline void throw_on_error(cublasStatus_t stat, const char* message) {
  const char* reason = "";
  switch (stat) {
    case CUBLAS_STATUS_SUCCESS:
      return;
    case CUBLAS_STATUS_NOT_INITIALIZED:
      reason = "CUBLAS: not initialized";
      break;
    case CUBLAS_STATUS_ALLOC_FAILED:
      reason = "CUBLAS: alloc failed";
      break;
    case CUBLAS_STATUS_INVALID_VALUE:
      reason = "CUBLAS: invalid value";
      break;
    case CUBLAS_STATUS_ARCH_MISMATCH:
      reason = "CUBLAS: arch mismatch";
      break;
    case CUBLAS_STATUS_MAPPING_ERROR:
      reason = "CUBLAS: mapping error";
      break;
    case CUBLAS_STATUS_EXECUTION_FAILED:
      reason = "CUBLAS: execution failed";
      break;
    case CUBLAS_STATUS_INTERNAL_ERROR:
      reason = "CUBLAS: internal error";
      break;
    case CUBLAS_STATUS_NOT_SUPPORTED:
      reason = "CUBLAS: not supported";
      break;
    case CUBLAS_STATUS_LICENSE_ERROR:
      reason = "CUBLAS: license error";
      break;
    default:
      break;
  }
  std::stringstream text;
  text << reason << ", " << message;
  throw std::runtime_error(text.str());
}
// Convenience overload: checks a cuBLAS status with an empty message by
// delegating to the (cublasStatus_t, const char*) overload.
inline void throw_on_error(cublasStatus_t stat) {
  throw_on_error(stat, "");
}
#endif // PADDLE_ONLY_CPU
// Throws std::runtime_error with the text "<message>, stat = <stat>" when
// `stat` is non-zero; returns normally when `stat` is zero.
inline void throw_on_error(int stat, const char* message) {
  if (stat == 0) {
    return;
  }
  std::string what(message);
  what += ", stat = ";
  what += std::to_string(stat);
  throw std::runtime_error(what);
}
} // namespace platform
} // namespace paddle
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/platform/error.h" #include "paddle/platform/enforce.h"
DEFINE_double(fraction_of_gpu_memory_to_use, 0.95, DEFINE_double(fraction_of_gpu_memory_to_use, 0.95,
"Default use 95% of GPU memory for PaddlePaddle," "Default use 95% of GPU memory for PaddlePaddle,"
...@@ -25,7 +25,7 @@ namespace platform { ...@@ -25,7 +25,7 @@ namespace platform {
int GetDeviceCount() { int GetDeviceCount() {
int count; int count;
throw_on_error( PADDLE_ENFORCE(
cudaGetDeviceCount(&count), cudaGetDeviceCount(&count),
"cudaGetDeviceCount failed in paddle::platform::GetDeviceCount"); "cudaGetDeviceCount failed in paddle::platform::GetDeviceCount");
return count; return count;
...@@ -33,19 +33,19 @@ int GetDeviceCount() { ...@@ -33,19 +33,19 @@ int GetDeviceCount() {
int GetCurrentDeviceId() { int GetCurrentDeviceId() {
int device_id; int device_id;
throw_on_error( PADDLE_ENFORCE(
cudaGetDevice(&device_id), cudaGetDevice(&device_id),
"cudaGetDevice failed in paddle::platform::GetCurrentDeviceId"); "cudaGetDevice failed in paddle::platform::GetCurrentDeviceId");
return device_id; return device_id;
} }
void SetDeviceId(int id) { void SetDeviceId(int id) {
throw_on_error(cudaSetDevice(id), PADDLE_ENFORCE(cudaSetDevice(id),
"cudaSetDevice failed in paddle::platform::SetDeviceId"); "cudaSetDevice failed in paddle::platform::SetDeviceId");
} }
void GpuMemoryUsage(size_t& available, size_t& total) { void GpuMemoryUsage(size_t& available, size_t& total) {
throw_on_error(cudaMemGetInfo(&available, &total), PADDLE_ENFORCE(cudaMemGetInfo(&available, &total),
"cudaMemGetInfo failed in paddle::platform::GetMemoryUsage"); "cudaMemGetInfo failed in paddle::platform::GetMemoryUsage");
} }
......
cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
add_op mul_op rowwise_add_op sigmoid_op softmax_op) add_op fc_op sgd_op cross_entropy_op)
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include <Python.h> #include <Python.h>
#include <paddle/framework/op_registry.h> #include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
#include <paddle/framework/scope.h> #include <paddle/framework/scope.h>
#include <paddle/pybind/tensor_bind.h> #include <paddle/pybind/tensor_bind.h>
#include <pybind11/numpy.h> #include <pybind11/numpy.h>
...@@ -26,10 +27,9 @@ namespace py = pybind11; ...@@ -26,10 +27,9 @@ namespace py = pybind11;
namespace pd = paddle::framework; namespace pd = paddle::framework;
USE_OP(add_two); USE_OP(add_two);
USE_OP(softmax); USE_OP(onehot_cross_entropy);
USE_OP(mul); USE_OP_WITHOUT_KERNEL(fc);
USE_OP(rowwise_add); USE_OP(sgd);
USE_OP(sigmoid);
PYBIND11_PLUGIN(core) { PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of Paddle Paddle"); py::module m("core", "C++ core of Paddle Paddle");
...@@ -53,7 +53,9 @@ PYBIND11_PLUGIN(core) { ...@@ -53,7 +53,9 @@ PYBIND11_PLUGIN(core) {
self.mutable_data<int>(paddle::platform::CPUPlace()); self.mutable_data<int>(paddle::platform::CPUPlace());
}) })
.def("set", paddle::pybind::PyTensorSetFromArray<float>) .def("set", paddle::pybind::PyTensorSetFromArray<float>)
.def("set", paddle::pybind::PyTensorSetFromArray<int>); .def("set", paddle::pybind::PyTensorSetFromArray<int>)
.def("shape",
[](pd::Tensor& self) { return pd::vectorize(self.dims()); });
py::class_<pd::Variable>(m, "Variable", R"DOC(Variable Class. py::class_<pd::Variable>(m, "Variable", R"DOC(Variable Class.
...@@ -83,15 +85,16 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -83,15 +85,16 @@ All parameter, weight, gradient are variables in Paddle.
//! @note: Be careful! PyBind will return std::string as an unicode, not //! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python. //! Python str. If you want a str object, you should cast them in Python.
m.def("get_all_op_protos", []() -> std::vector<std::string> { m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
auto& protos = pd::OpRegistry::protos(); auto& protos = pd::OpRegistry::protos();
std::vector<std::string> ret_values; std::vector<py::bytes> ret_values;
for (auto it = protos.begin(); it != protos.end(); ++it) { for (auto it = protos.begin(); it != protos.end(); ++it) {
PADDLE_ENFORCE(it->second.IsInitialized(), PADDLE_ENFORCE(it->second.IsInitialized(),
"OpProto must all be initialized"); "OpProto must all be initialized");
ret_values.emplace_back(); std::string str;
PADDLE_ENFORCE(it->second.SerializeToString(&ret_values.back()), PADDLE_ENFORCE(it->second.SerializeToString(&str),
"Serialize OpProto Error. This could be a bug of Paddle."); "Serialize OpProto Error. This could be a bug of Paddle.");
ret_values.push_back(py::bytes(str));
} }
return ret_values; return ret_values;
}); });
...@@ -101,9 +104,15 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -101,9 +104,15 @@ All parameter, weight, gradient are variables in Paddle.
.def("empty", pd::OperatorBase::EMPTY_VAR_NAME) .def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
.def("temp", pd::OperatorBase::TMP_VAR_NAME); .def("temp", pd::OperatorBase::TMP_VAR_NAME);
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static("cpu_context", []() -> paddle::platform::DeviceContext* {
return new paddle::platform::CPUDeviceContext();
});
py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator") py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
.def("__str__", &pd::OperatorBase::DebugString) .def("__str__", &pd::OperatorBase::DebugString)
.def_static("create", [](const std::string& protobin) { .def_static("create",
[](py::bytes protobin) {
pd::OpDesc desc; pd::OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc"); "Cannot parse user input to OpDesc");
...@@ -111,7 +120,10 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -111,7 +120,10 @@ All parameter, weight, gradient are variables in Paddle.
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
return pd::OpRegistry::CreateOp(desc); return pd::OpRegistry::CreateOp(desc);
}); })
.def("infer_shape", &pd::OperatorBase::InferShape)
.def("run", &pd::OperatorBase::Run)
.def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; });
return m.ptr(); return m.ptr();
} }
...@@ -476,6 +476,12 @@ message LayerConfig { ...@@ -476,6 +476,12 @@ message LayerConfig {
// controls the scope of pooling operation. can be set > 0. // controls the scope of pooling operation. can be set > 0.
// leave empty or set to -1 to disable this stride pooling. // leave empty or set to -1 to disable this stride pooling.
optional int32 seq_pool_stride = 53 [default = -1]; optional int32 seq_pool_stride = 53 [default = -1];
// for crop layer
optional int32 axis = 54 [default = 2];
repeated uint32 offset = 55;
repeated uint32 shape = 56;
} }
message EvaluatorConfig { message EvaluatorConfig {
......
...@@ -1998,6 +1998,23 @@ class PadLayer(LayerBase): ...@@ -1998,6 +1998,23 @@ class PadLayer(LayerBase):
self.config.size = out_ch * out_h * out_w self.config.size = out_ch * out_h * out_w
@config_layer('crop')
class CropLayer(LayerBase):
    """Config-parser class registered for the 'crop' layer type.

    Stores the crop parameters (axis, offset, shape) in the layer config and
    fills in the image geometry of the first input.
    """

    def __init__(self, name, inputs, axis, offset, shape, **xargs):
        """
        :param name: layer name, passed on to LayerBase.
        :param inputs: input layers, passed on to LayerBase.
        :param axis: value stored in config.axis.
        :param offset: sequence appended to config.offset.
        :param shape: sequence appended to config.shape, or None to leave
                      config.shape empty.
        :param xargs: remaining keyword arguments, passed on to LayerBase.
        """
        super(CropLayer, self).__init__(name, 'crop', 0, inputs=inputs, **xargs)
        self.config.axis = axis
        self.config.offset.extend(offset)
        # crop_layer() passes shape=None by default; skip it explicitly
        # instead of relying on how the repeated field treats extend(None).
        if shape is not None:
            self.config.shape.extend(shape)

        # get channel, width and height from input_0 layer
        input_layer = self.get_input_layer(0)
        image_conf = self.config.inputs[0].image_conf
        image_conf.img_size = input_layer.width
        image_conf.img_size_y = input_layer.height
        # Floor division keeps the channel count an integer even where `/`
        # performs true division.
        image_conf.channels = input_layer.size // (input_layer.width *
                                                   input_layer.height)
@config_layer('batch_norm') @config_layer('batch_norm')
class BatchNormLayer(LayerBase): class BatchNormLayer(LayerBase):
layer_type = 'batch_norm' layer_type = 'batch_norm'
......
...@@ -127,6 +127,7 @@ __all__ = [ ...@@ -127,6 +127,7 @@ __all__ = [
'dropout_layer', 'dropout_layer',
'prelu_layer', 'prelu_layer',
'gated_unit_layer', 'gated_unit_layer',
'crop_layer',
] ]
...@@ -218,6 +219,7 @@ class LayerType(object): ...@@ -218,6 +219,7 @@ class LayerType(object):
SMOOTH_L1 = 'smooth_l1' SMOOTH_L1 = 'smooth_l1'
PRELU = 'prelu' PRELU = 'prelu'
CROP_LAYER = 'crop'
@staticmethod @staticmethod
def is_layer_type(type_name): def is_layer_type(type_name):
...@@ -3171,11 +3173,11 @@ def memory(name, ...@@ -3171,11 +3173,11 @@ def memory(name,
@wrap_bias_attr_default() @wrap_bias_attr_default()
@wrap_act_default( @wrap_act_default(param_names=['gate_act'], act=SigmoidActivation())
param_names=['gate_act', 'state_act'], act=SigmoidActivation()) @wrap_act_default(param_names=['state_act'], act=TanhActivation())
@wrap_act_default(act=TanhActivation()) @wrap_act_default(act=TanhActivation())
@wrap_name_default('lstm_step') @wrap_name_default('lstm_step')
@layer_support() @layer_support(ERROR_CLIPPING, DROPOUT)
def lstm_step_layer(input, def lstm_step_layer(input,
state, state,
size=None, size=None,
...@@ -3529,12 +3531,7 @@ def SubsequenceInput(input): ...@@ -3529,12 +3531,7 @@ def SubsequenceInput(input):
@wrap_name_default("recurrent_group") @wrap_name_default("recurrent_group")
def recurrent_group(step, def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
input,
reverse=False,
name=None,
targetInlink=None,
is_generating=False):
""" """
Recurrent layer group is an extremely flexible recurrent unit in Recurrent layer group is an extremely flexible recurrent unit in
PaddlePaddle. As long as the user defines the calculation done within a PaddlePaddle. As long as the user defines the calculation done within a
...@@ -3600,21 +3597,12 @@ def recurrent_group(step, ...@@ -3600,21 +3597,12 @@ def recurrent_group(step,
:type targetInlink: LayerOutput|SubsequenceInput :type targetInlink: LayerOutput|SubsequenceInput
:param is_generating: If is generating, none of input type should be LayerOutput;
else, for training or testing, one of the input type must
be LayerOutput.
:type is_generating: bool
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
""" """
model_type('recurrent_nn') model_type('recurrent_nn')
def is_single_input(x): if isinstance(input, LayerOutput) or isinstance(input, StaticInput):
return isinstance(x, LayerOutput) or isinstance(x, StaticInput)
if is_single_input(input):
input = [input] input = [input]
assert isinstance(input, collections.Sequence) assert isinstance(input, collections.Sequence)
...@@ -3628,13 +3616,8 @@ def recurrent_group(step, ...@@ -3628,13 +3616,8 @@ def recurrent_group(step,
in_links=map(lambda x: x.name, in_links), in_links=map(lambda x: x.name, in_links),
seq_reversed=reverse) seq_reversed=reverse)
in_args = [] in_args = []
has_LayerOutput = False
for each_input in input: for each_input in input:
assert is_single_input(each_input) if isinstance(each_input, StaticInput): # StaticInput
if isinstance(each_input, LayerOutput):
in_args.append(each_input)
has_LayerOutput = True
else: # StaticInput
mem_name = "__%s_memory__" % each_input.input.name mem_name = "__%s_memory__" % each_input.input.name
mem = memory( mem = memory(
name=None, name=None,
...@@ -3642,24 +3625,26 @@ def recurrent_group(step, ...@@ -3642,24 +3625,26 @@ def recurrent_group(step,
boot_layer=each_input.input) boot_layer=each_input.input)
mem.set_input(mem) mem.set_input(mem)
in_args.append(mem) in_args.append(mem)
else:
assert (is_generating != has_LayerOutput) in_args.append(each_input)
layer_outs = step(*in_args) layer_outs = step(*in_args)
if isinstance(layer_outs, LayerOutput): if isinstance(layer_outs, LayerOutput):
layer_outs = [layer_outs] layer_outs = [layer_outs]
for ot in layer_outs: for layer_out in layer_outs:
assert isinstance(ot, LayerOutput) assert isinstance(
ot.reverse = reverse layer_out, LayerOutput
RecurrentLayerGroupSetOutLink(ot.name) ), "Type of step function's return value must be LayerOutput."
layer_out.reverse = reverse
RecurrentLayerGroupSetOutLink(layer_out.name)
RecurrentLayerGroupEnd(name=name) RecurrentLayerGroupEnd(name=name)
for layer_out in layer_outs: for layer_out in layer_outs:
# Thee previous full_name is the name is the rnn group # The previous full_name is the name inside the recurrent group.
# We need a full_name outside the rnn group # We need a full_name outside the recurrent group.
layer_out.full_name = MakeLayerNameInSubmodel(layer_out.name) layer_out.full_name = MakeLayerNameInSubmodel(layer_out.name)
if len(layer_outs) == 1: if len(layer_outs) == 1:
...@@ -3682,7 +3667,20 @@ class BaseGeneratedInput(object): ...@@ -3682,7 +3667,20 @@ class BaseGeneratedInput(object):
class GeneratedInput(BaseGeneratedInput): class GeneratedInput(BaseGeneratedInput):
def after_real_step(self, input): def after_real_step(self, input):
return maxid_layer(input=input, name='__beam_search_predict__') if isinstance(input, LayerOutput):
input = [input]
elif isinstance(input, collections.Sequence):
input = list(input)
if len(input) > 1:
logger.info(
("More than one layers inside the recurrent_group "
"are returned as outputs of the entire recurrent_group "
"PLEASE garantee the first output is probability of "
"the predicted next word."))
return [maxid_layer(
input=input[0], name='__beam_search_predict__')] + (
input[1:] if len(input) > 1 else [])
def before_real_step(self): def before_real_step(self):
predict_id = memory( predict_id = memory(
...@@ -3869,6 +3867,7 @@ def beam_search(step, ...@@ -3869,6 +3867,7 @@ def beam_search(step,
:type step: callable :type step: callable
:param input: Input data for the recurrent unit, which should include the :param input: Input data for the recurrent unit, which should include the
previously generated words as a GeneratedInput object. previously generated words as a GeneratedInput object.
In beam_search, none of the input's type should be LayerOutput.
:type input: list :type input: list
:param bos_id: Index of the start symbol in the dictionary. The start symbol :param bos_id: Index of the start symbol in the dictionary. The start symbol
is a special token for NLP task, which indicates the is a special token for NLP task, which indicates the
...@@ -3910,15 +3909,18 @@ def beam_search(step, ...@@ -3910,15 +3909,18 @@ def beam_search(step,
real_input = [] real_input = []
for i, each_input in enumerate(input): for i, each_input in enumerate(input):
assert isinstance(each_input, StaticInput) or isinstance( assert not isinstance(each_input, LayerOutput), (
each_input, BaseGeneratedInput) "in beam_search, "
"none of the input should has a type of LayerOutput.")
if isinstance(each_input, BaseGeneratedInput): if isinstance(each_input, BaseGeneratedInput):
assert generated_input_index == -1 assert generated_input_index == -1, ("recurrent_group accepts "
"only one GeneratedInput.")
generated_input_index = i generated_input_index = i
else: else:
real_input.append(each_input) real_input.append(each_input)
assert generated_input_index != -1 assert generated_input_index != -1, "No GeneratedInput is given."
gipt = input[generated_input_index] gipt = input[generated_input_index]
...@@ -3939,17 +3941,11 @@ def beam_search(step, ...@@ -3939,17 +3941,11 @@ def beam_search(step,
predict = gipt.after_real_step(step(*args)) predict = gipt.after_real_step(step(*args))
eos_layer(input=predict, eos_id=eos_id, name=eos_name) eos_layer(input=predict[0], eos_id=eos_id, name=eos_name)
return predict return predict
tmp = recurrent_group( return recurrent_group(
step=__real_step__, step=__real_step__, input=real_input, reverse=False, name=name)
input=real_input,
reverse=False,
name=name,
is_generating=True)
return tmp
def __cost_input__(input, label, weight=None): def __cost_input__(input, label, weight=None):
...@@ -5970,3 +5966,52 @@ def gated_unit_layer(input, ...@@ -5970,3 +5966,52 @@ def gated_unit_layer(input,
name="%s_gated_act" % name, name="%s_gated_act" % name,
input=dotmul_operator(input_proj, gate), input=dotmul_operator(input_proj, gate),
layer_attr=layer_attr) layer_attr=layer_attr)
@wrap_name_default()
@layer_support()
def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None):
    """
    The crop layer crops images by offset and shape. The crop shape can be set
    explicitly through the argument 'shape' or taken from a reference input
    layer.

    The example usage is:

    .. code-block:: python

        crop = crop_layer(input=[image_input, reference_input], axis=2, offset=[2, 3])

    :param input: The input layer. If two inputs are given, the second input
                  is regarded as the reference input.
    :type input: LayerOutput | Sequence
    :param offset: The crop offset.
    :type offset: Sequence
    :param axis: The start axis to be cropped. For an image input layer:

        - 0: batch size
        - 1: channels
        - 2: height
        - 3: width
    :type axis: int
    :param shape: The shape to be cropped. Default is None.
    :type shape: Sequence | None
    :param name: Name of this layer.
    :type name: basestring
    :param layer_attr: Extra layer attributes; converted with
                       ExtraLayerAttribute.to_kwargs and passed on to the
                       layer config.
    :type layer_attr: ExtraLayerAttribute | None
    :return: LayerOutput object.
    :rtype: LayerOutput
    """
    # Accept a single layer as well as a sequence of layers.
    if isinstance(input, LayerOutput):
        input = [input]
    else:
        assert isinstance(input, collections.Sequence)
    l = Layer(
        inputs=[x.name for x in input],
        axis=axis,
        offset=offset,
        shape=shape,
        name=name,
        type=LayerType.CROP_LAYER,
        **ExtraLayerAttribute.to_kwargs(layer_attr))
    return LayerOutput(
        name=name,
        layer_type=LayerType.CROP_LAYER,
        parents=input,
        size=l.config.size)
...@@ -614,18 +614,17 @@ def simple_lstm(input, ...@@ -614,18 +614,17 @@ def simple_lstm(input,
@wrap_name_default('lstm_unit') @wrap_name_default('lstm_unit')
def lstmemory_unit(input, def lstmemory_unit(input,
memory_boot=None, out_memory=None,
name=None, name=None,
size=None, size=None,
param_attr=None, param_attr=None,
act=None, act=None,
gate_act=None, gate_act=None,
state_act=None, state_act=None,
mixed_bias_attr=None, input_proj_bias_attr=None,
input_proj_layer_attr=None,
lstm_bias_attr=None, lstm_bias_attr=None,
mixed_layer_attr=None, lstm_layer_attr=None):
lstm_layer_attr=None,
get_output_layer_attr=None):
""" """
Define calculations that a LSTM unit performs during a single time step. Define calculations that a LSTM unit performs during a single time step.
This function itself is not a recurrent layer, so it can not be This function itself is not a recurrent layer, so it can not be
...@@ -662,8 +661,8 @@ def lstmemory_unit(input, ...@@ -662,8 +661,8 @@ def lstmemory_unit(input,
:param input: input layer name. :param input: input layer name.
:type input: LayerOutput :type input: LayerOutput
:param memory_boot: the initialization state of the LSTM cell. :param out_memory: output of previous time step
:type memory_boot: LayerOutput | None :type out_memory: LayerOutput | None
:param name: lstmemory unit name. :param name: lstmemory unit name.
:type name: basestring :type name: basestring
:param size: lstmemory unit size. :param size: lstmemory unit size.
...@@ -676,33 +675,35 @@ def lstmemory_unit(input, ...@@ -676,33 +675,35 @@ def lstmemory_unit(input,
:type gate_act: BaseActivation :type gate_act: BaseActivation
:param state_act: lstm state activiation type. :param state_act: lstm state activiation type.
:type state_act: BaseActivation :type state_act: BaseActivation
:param mixed_bias_attr: bias parameter attribute of mixed layer. :param input_proj_bias_attr: bias attribute for input-to-hidden projection.
False means no bias, None means default bias. False means no bias, None means default bias.
:type mixed_bias_attr: ParameterAttribute|False :type input_proj_bias_attr: ParameterAttribute|False|None
:param input_proj_layer_attr: extra layer attribute for input to hidden
projection of the LSTM unit, such as dropout, error clipping.
:type input_proj_layer_attr: ExtraLayerAttribute
:param lstm_bias_attr: bias parameter attribute of lstm layer. :param lstm_bias_attr: bias parameter attribute of lstm layer.
False means no bias, None means default bias. False means no bias, None means default bias.
:type lstm_bias_attr: ParameterAttribute|False :type lstm_bias_attr: ParameterAttribute|False
:param mixed_layer_attr: mixed layer's extra attribute.
:type mixed_layer_attr: ExtraLayerAttribute
:param lstm_layer_attr: lstm layer's extra attribute. :param lstm_layer_attr: lstm layer's extra attribute.
:type lstm_layer_attr: ExtraLayerAttribute :type lstm_layer_attr: ExtraLayerAttribute
:param get_output_layer_attr: get output layer's extra attribute.
:type get_output_layer_attr: ExtraLayerAttribute
:return: lstmemory unit name. :return: lstmemory unit name.
:rtype: LayerOutput :rtype: LayerOutput
""" """
if size is None: if size is None:
assert input.size % 4 == 0 assert input.size % 4 == 0
size = input.size / 4 size = input.size / 4
if out_memory is None:
out_mem = memory(name=name, size=size) out_mem = memory(name=name, size=size)
state_mem = memory( else:
name="%s_state" % name, size=size, boot_layer=memory_boot) out_mem = out_memory
state_mem = memory(name="%s_state" % name, size=size)
with mixed_layer( with mixed_layer(
name="%s_input_recurrent" % name, name="%s_input_recurrent" % name,
size=size * 4, size=size * 4,
bias_attr=mixed_bias_attr, bias_attr=input_proj_bias_attr,
layer_attr=mixed_layer_attr, layer_attr=input_proj_layer_attr,
act=IdentityActivation()) as m: act=IdentityActivation()) as m:
m += identity_projection(input=input) m += identity_projection(input=input)
m += full_matrix_projection(input=out_mem, param_attr=param_attr) m += full_matrix_projection(input=out_mem, param_attr=param_attr)
...@@ -717,11 +718,7 @@ def lstmemory_unit(input, ...@@ -717,11 +718,7 @@ def lstmemory_unit(input,
gate_act=gate_act, gate_act=gate_act,
state_act=state_act, state_act=state_act,
layer_attr=lstm_layer_attr) layer_attr=lstm_layer_attr)
get_output_layer( get_output_layer(name='%s_state' % name, input=lstm_out, arg_name='state')
name='%s_state' % name,
input=lstm_out,
arg_name='state',
layer_attr=get_output_layer_attr)
return lstm_out return lstm_out
...@@ -730,17 +727,16 @@ def lstmemory_unit(input, ...@@ -730,17 +727,16 @@ def lstmemory_unit(input,
def lstmemory_group(input, def lstmemory_group(input,
size=None, size=None,
name=None, name=None,
memory_boot=None, out_memory=None,
reverse=False, reverse=False,
param_attr=None, param_attr=None,
act=None, act=None,
gate_act=None, gate_act=None,
state_act=None, state_act=None,
mixed_bias_attr=None, input_proj_bias_attr=None,
input_proj_layer_attr=None,
lstm_bias_attr=None, lstm_bias_attr=None,
mixed_layer_attr=None, lstm_layer_attr=None):
lstm_layer_attr=None,
get_output_layer_attr=None):
""" """
lstm_group is a recurrent_group version of Long Short Term Memory. It lstm_group is a recurrent_group version of Long Short Term Memory. It
does exactly the same calculation as the lstmemory layer (see lstmemory in does exactly the same calculation as the lstmemory layer (see lstmemory in
...@@ -774,8 +770,8 @@ def lstmemory_group(input, ...@@ -774,8 +770,8 @@ def lstmemory_group(input,
:type size: int :type size: int
:param name: name of the lstmemory group. :param name: name of the lstmemory group.
:type name: basestring :type name: basestring
:param memory_boot: the initialization state of LSTM cell. :param out_memory: output of previous time step
:type memory_boot: LayerOutput | None :type out_memory: LayerOutput | None
:param reverse: is lstm reversed :param reverse: is lstm reversed
:type reverse: bool :type reverse: bool
:param param_attr: Parameter config, None if use default. :param param_attr: Parameter config, None if use default.
...@@ -786,18 +782,17 @@ def lstmemory_group(input, ...@@ -786,18 +782,17 @@ def lstmemory_group(input,
:type gate_act: BaseActivation :type gate_act: BaseActivation
:param state_act: lstm state activiation type. :param state_act: lstm state activiation type.
:type state_act: BaseActivation :type state_act: BaseActivation
:param mixed_bias_attr: bias parameter attribute of mixed layer.
False means no bias, None means default bias.
:type mixed_bias_attr: ParameterAttribute|False
:param lstm_bias_attr: bias parameter attribute of lstm layer. :param lstm_bias_attr: bias parameter attribute of lstm layer.
False means no bias, None means default bias. False means no bias, None means default bias.
:type lstm_bias_attr: ParameterAttribute|False :type lstm_bias_attr: ParameterAttribute|False
:param mixed_layer_attr: mixed layer's extra attribute. :param input_proj_bias_attr: bias attribute for input-to-hidden projection.
:type mixed_layer_attr: ExtraLayerAttribute False means no bias, None means default bias.
:type input_proj_bias_attr: ParameterAttribute|False|None
:param input_proj_layer_attr: extra layer attribute for input to hidden
projection of the LSTM unit, such as dropout, error clipping.
:type input_proj_layer_attr: ExtraLayerAttribute
:param lstm_layer_attr: lstm layer's extra attribute. :param lstm_layer_attr: lstm layer's extra attribute.
:type lstm_layer_attr: ExtraLayerAttribute :type lstm_layer_attr: ExtraLayerAttribute
:param get_output_layer_attr: get output layer's extra attribute.
:type get_output_layer_attr: ExtraLayerAttribute
:return: the lstmemory group. :return: the lstmemory group.
:rtype: LayerOutput :rtype: LayerOutput
""" """
...@@ -805,18 +800,17 @@ def lstmemory_group(input, ...@@ -805,18 +800,17 @@ def lstmemory_group(input,
def __lstm_step__(ipt): def __lstm_step__(ipt):
return lstmemory_unit( return lstmemory_unit(
input=ipt, input=ipt,
memory_boot=memory_boot,
name=name, name=name,
size=size, size=size,
mixed_bias_attr=mixed_bias_attr,
mixed_layer_attr=mixed_layer_attr,
param_attr=param_attr,
lstm_bias_attr=lstm_bias_attr,
act=act, act=act,
gate_act=gate_act, gate_act=gate_act,
state_act=state_act, state_act=state_act,
out_memory=out_memory,
input_proj_bias_attr=input_proj_bias_attr,
input_proj_layer_attr=input_proj_layer_attr,
param_attr=param_attr,
lstm_layer_attr=lstm_layer_attr, lstm_layer_attr=lstm_layer_attr,
get_output_layer_attr=get_output_layer_attr) lstm_bias_attr=lstm_bias_attr)
return recurrent_group( return recurrent_group(
name='%s_recurrent_group' % name, name='%s_recurrent_group' % name,
......
...@@ -104,7 +104,7 @@ layers { ...@@ -104,7 +104,7 @@ layers {
} }
bias_parameter_name: "lstm_bias" bias_parameter_name: "lstm_bias"
active_gate_type: "sigmoid" active_gate_type: "sigmoid"
active_state_type: "sigmoid" active_state_type: "tanh"
} }
layers { layers {
name: "__lstm_group_0___state@__lstm_group_0___recurrent_group" name: "__lstm_group_0___state@__lstm_group_0___recurrent_group"
...@@ -183,7 +183,7 @@ layers { ...@@ -183,7 +183,7 @@ layers {
} }
bias_parameter_name: "lstm_bias" bias_parameter_name: "lstm_bias"
active_gate_type: "sigmoid" active_gate_type: "sigmoid"
active_state_type: "sigmoid" active_state_type: "tanh"
} }
layers { layers {
name: "__lstm_group_1___state@__lstm_group_1___recurrent_group" name: "__lstm_group_1___state@__lstm_group_1___recurrent_group"
......
...@@ -258,7 +258,7 @@ layers { ...@@ -258,7 +258,7 @@ layers {
} }
bias_parameter_name: "___lstm_group_0__@__lstm_group_0___recurrent_group.wbias" bias_parameter_name: "___lstm_group_0__@__lstm_group_0___recurrent_group.wbias"
active_gate_type: "sigmoid" active_gate_type: "sigmoid"
active_state_type: "sigmoid" active_state_type: "tanh"
} }
layers { layers {
name: "__lstm_group_0___state@__lstm_group_0___recurrent_group" name: "__lstm_group_0___state@__lstm_group_0___recurrent_group"
......
...@@ -20,12 +20,13 @@ lstm1 = lstmemory_group( ...@@ -20,12 +20,13 @@ lstm1 = lstmemory_group(
input=m1, input=m1,
param_attr=lstm_param, param_attr=lstm_param,
lstm_bias_attr=lstm_bias, lstm_bias_attr=lstm_bias,
mixed_bias_attr=False) input_proj_bias_attr=False)
lstm2 = lstmemory_group( lstm2 = lstmemory_group(
input=m2, input=m2,
param_attr=lstm_param, param_attr=lstm_param,
lstm_bias_attr=lstm_bias, lstm_bias_attr=lstm_bias,
mixed_bias_attr=False) input_proj_bias_attr=False)
softmax_param = ParamAttr(name='softmax_param') softmax_param = ParamAttr(name='softmax_param')
......
from paddle.trainer_config_helpers import *
# Trainer config that builds a conv -> max-pool feature map, crops it against a
# second data layer with crop_layer, and exports the cropped layer as output.
settings(batch_size=1000, learning_rate=1e-5)

data = data_layer(name='data', size=2016, height=48, width=42)
# NOTE(review): this layer reuses the name 'data' of the layer above; confirm
# whether a distinct layer name was intended here.
refernce_data = data_layer(name='data', size=768, height=16, width=16)

conv = img_conv_layer(
    input=data,
    filter_size=3,
    num_channels=1,
    num_filters=16,
    padding=1,
    act=LinearActivation(),
    bias_attr=True)

pool = img_pool_layer(input=conv, pool_size=2, stride=2, pool_type=MaxPooling())

crop = crop_layer(input=[pool, refernce_data], axis=2)

# Export the cropped layer: it is the only layer this config builds for output.
outputs(crop)
...@@ -116,7 +116,7 @@ def reader_creator(data_file, ...@@ -116,7 +116,7 @@ def reader_creator(data_file,
data = batch['data'] data = batch['data']
labels = batch['label'] labels = batch['label']
for sample, label in itertools.izip(data, batch['label']): for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label) yield sample, int(label) - 1
if use_xmap: if use_xmap:
return xmap_readers(mapper, reader, cpu_count(), buffered_size) return xmap_readers(mapper, reader, cpu_count(), buffered_size)
......
...@@ -217,6 +217,10 @@ def create_op_creation_method(op_proto): ...@@ -217,6 +217,10 @@ def create_op_creation_method(op_proto):
return core.Operator.create(opdesc.SerializeToString()) return core.Operator.create(opdesc.SerializeToString())
__impl__.__doc__ = get_docstring_from_op_proto(op_proto) __impl__.__doc__ = get_docstring_from_op_proto(op_proto)
__impl__.all_input_args = [var.name for var in op_proto.inputs]
__impl__.all_output_args = [var.name for var in op_proto.outputs]
__impl__.all_attr_args = [attr.name for attr in op_proto.attrs]
return __impl__ return __impl__
......
add_python_test(test_framework test_protobuf.py test_scope.py add_python_test(test_framework test_protobuf.py test_scope.py
test_default_scope_funcs.py test_op_creation_methods.py test_default_scope_funcs.py test_op_creation_methods.py
test_tensor.py) test_tensor.py test_fc_op.py test_add_two_op.py test_sgd_op.py test_cross_entropy_op.py)
import paddle.v2.framework.core as core
import unittest
import numpy
import paddle.v2.framework.create_op_creation_methods as creation
class OpTestMeta(type):
    """
    Metaclass for operator tests.

    It injects a `test_all` method into the user's operator test class so that
    the Python unittest module runs that method.

    `test_all` reads the values stored on `self` and uses them to create and
    run an operator, then checks the operator's outputs:

    * `self.type` selects the creation method in `creation.op_creations`.
    * For every input name of the operator, an attribute of the same name on
      `self` is copied into a tensor in a fresh scope; inputs without such an
      attribute are passed to the operator as "@EMPTY@".
    * For every output name of the operator, the attribute of the same name on
      `self` is the expected value; every output must be declared.
    * For every attribute name of the operator, an attribute of the same name
      on `self` (if present) is forwarded to the creation method.

    See `test_add_two_op` for example usage.
    """

    def __new__(cls, name, bases, attrs):
        obj = super(OpTestMeta, cls).__new__(cls, name, bases, attrs)

        def test_all(self):
            func = getattr(creation.op_creations, self.type, None)
            self.assertIsNotNone(func)

            # Every output is compared against an attribute of the same name
            # after the op runs. Report undeclared outputs up front, before
            # any operator is created or run, instead of crashing later on a
            # missing attribute or a missing scope variable.
            undeclared = [
                out_name for out_name in func.all_output_args
                if not hasattr(self, out_name)
            ]
            self.assertFalse(
                undeclared,
                "expected value is not set on the test case for output(s): %s"
                % ", ".join(undeclared))

            scope = core.Scope(None)
            kwargs = dict()

            for in_name in func.all_input_args:
                if hasattr(self, in_name):
                    kwargs[in_name] = in_name
                    var = scope.create_var(in_name).get_tensor()
                    arr = getattr(self, in_name)
                    var.set_dims(arr.shape)
                    var.set(arr)
                else:
                    kwargs[in_name] = "@EMPTY@"

            for out_name in func.all_output_args:
                kwargs[out_name] = out_name
                # Only the variable is created here; its tensor is left for
                # infer_shape/run to fill in.
                scope.create_var(out_name).get_tensor()

            for attr_name in func.all_attr_args:
                if hasattr(self, attr_name):
                    kwargs[attr_name] = getattr(self, attr_name)

            op = func(**kwargs)

            op.infer_shape(scope)

            ctx = core.DeviceContext.cpu_context()
            op.run(scope, ctx)

            for out_name in func.all_output_args:
                actual = numpy.array(scope.get_var(out_name).get_tensor())
                expect = getattr(self, out_name)
                numpy.testing.assert_almost_equal(actual, expect)

        obj.test_all = test_all
        return obj
import unittest
from op_test_util import OpTestMeta
import numpy
class TestAddOp(unittest.TestCase):
    """Operator test for `add_two`; the test method comes from OpTestMeta."""
    __metaclass__ = OpTestMeta

    def setUp(self):
        # Both inputs share one shape; the expected output is their
        # element-wise sum.
        dims = (342, 345)
        self.type = "add_two"
        self.X = numpy.random.random(dims).astype("float32")
        self.Y = numpy.random.random(dims).astype("float32")
        self.Out = self.X + self.Y


if __name__ == '__main__':
    unittest.main()
import unittest
import numpy
from op_test_util import OpTestMeta
class TestSGD(unittest.TestCase):
    """Operator test for `onehot_cross_entropy`.

    NOTE(review): the class is named TestSGD although `self.type` selects the
    "onehot_cross_entropy" operator; confirm whether the name is intentional.
    """
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "onehot_cross_entropy"
        batch_size = 100
        class_num = 10
        self.X = numpy.random.random((batch_size, class_num)).astype("float32")
        # Every sample is labelled with class index 5.
        self.label = 5 * numpy.ones(batch_size).astype("int32")
        # Expected output: negative log of the entry of X picked by the label,
        # computed one row at a time.
        losses = [
            -numpy.log(self.X[row][self.label[row]])
            for row in range(batch_size)
        ]
        self.Y = numpy.array(losses).astype("float32")


if __name__ == "__main__":
    unittest.main()
import paddle.v2.framework.core as core
import unittest
import numpy
import paddle.v2.framework.create_op_creation_methods as creation
class TestFc(unittest.TestCase):
    """Builds an `fc` operator by hand, checks the inferred shape of its
    output "Y", and runs it on the CPU device context."""

    def test_fc(self):
        scope = core.Scope(None)

        # Input X: dimensions are set and float storage allocated, but no
        # values are written into it.
        x_tensor = scope.create_var("X").get_tensor()
        x_tensor.set_dims([1000, 784])
        x_tensor.alloc_float()

        # Weight W: allocated and then filled with random float32 values.
        w_tensor = scope.create_var("W").get_tensor()
        w_tensor.set_dims([784, 100])
        w_tensor.alloc_float()
        w_tensor.set(numpy.random.random((784, 100)).astype("float32"))

        # TODO: set a real numpy array on X as well.

        op = creation.op_creations.fc(X="X", Y="Y", W="W")

        # Create every output variable of the operator that the scope does
        # not hold yet.
        for out_name in op.outputs():
            if scope.get_var(out_name) is None:
                scope.create_var(out_name).get_tensor()

        y_tensor = scope.get_var("Y").get_tensor()

        op.infer_shape(scope)
        self.assertEqual([1000, 100], y_tensor.shape())

        op.run(scope, core.DeviceContext.cpu_context())

        # TODO: once the operator has run, check the values of Y against the
        # expected result.


if __name__ == '__main__':
    unittest.main()
import unittest
import numpy
from op_test_util import OpTestMeta
class TestSGD(unittest.TestCase):
    """Operator test for `sgd`; the test method comes from OpTestMeta."""
    __metaclass__ = OpTestMeta

    def setUp(self):
        dims = (342, 345)
        step = 0.1
        self.type = "sgd"
        self.param = numpy.random.random(dims).astype("float32")
        self.grad = numpy.random.random(dims).astype("float32")
        self.learning_rate = step
        # Expected update: param - learning_rate * grad.
        self.param_out = self.param - step * self.grad


if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册