提交 b7b956f0 编写于 作者: G Guo Sheng 提交者: GitHub

Merge branch 'develop' into add-ClipLayer

...@@ -17,10 +17,14 @@ ...@@ -17,10 +17,14 @@
- id: detect-private-key - id: detect-private-key
files: (?!.*third_party)^.*$ | (?!.*book)^.*$ files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- id: end-of-file-fixer - id: end-of-file-fixer
- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git - repo: local
sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29
hooks: hooks:
- id: clang-formater - id: clang-format
name: clang-format
description: Format files with ClangFormat.
entry: clang-format -i
language: system
files: \.(c|cc|cxx|cpp|h|hpp|hxx)$
- repo: https://github.com/PaddlePaddle/pre-commit-golang - repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 8337620115c25ff8333f1b1a493bd031049bd7c0 sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks: hooks:
......
...@@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag) ...@@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag)
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif() endif()
# TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem.
# Use Debug mode instead for now.
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9)
set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE)
endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang which havs different version numbers. # Apple Clang is a different compiler than upstream Clang which havs different version numbers.
......
...@@ -104,6 +104,11 @@ cross_channel_norm ...@@ -104,6 +104,11 @@ cross_channel_norm
------------------ ------------------
.. autoclass:: paddle.v2.layer.cross_channel_norm .. autoclass:: paddle.v2.layer.cross_channel_norm
:noindex: :noindex:
row_l2_norm
-----------
.. autoclass:: paddle.v2.layer.row_l2_norm
:noindex:
Recurrent Layers Recurrent Layers
================ ================
......
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/memory/memcpy.h" #include "paddle/memory/memcpy.h"
namespace paddle { namespace paddle {
...@@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) { ...@@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) {
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>( holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size)); boost::get<platform::CPUPlace>(place), size));
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
} }
#ifndef PADDLE_ONLY_CPU #else
else if (platform::is_gpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>( holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), size)); boost::get<platform::GPUPlace>(place), size));
} }
......
...@@ -20,16 +20,16 @@ namespace paddle { ...@@ -20,16 +20,16 @@ namespace paddle {
namespace framework { namespace framework {
template <> template <>
Eigen::DefaultDevice* ExecutionContext::GetEigenDevice< Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>(); return *device_context_.get_eigen_device<Eigen::DefaultDevice>();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice* Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>(); return *device_context_.get_eigen_device<Eigen::GpuDevice>();
} }
#endif #endif
......
...@@ -253,7 +253,7 @@ class ExecutionContext : public OperatorContext { ...@@ -253,7 +253,7 @@ class ExecutionContext : public OperatorContext {
template <typename PlaceType, template <typename PlaceType,
typename DeviceType = typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType> typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const; DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); } platform::Place GetPlace() const { return device_context_.GetPlace(); }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Layer.h"
namespace paddle {
/**
* A layer for L2 normalization in each row,
* \f[
* out[i] = \frac{in[i]}{\sqrt{\sum_{k=1}^N in[k]^{2}}}
* \f]
* where the size of \f$in\f$ is (batchSize x dataDim),
* and the size of \f$out\f$ is (batchSize x dataDim).
*/
class RowL2NormLayer : public Layer {
protected:
MatrixPtr inSquare_;
MatrixPtr l2NormReciprocal_;
MatrixPtr dotSum_;
public:
explicit RowL2NormLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(row_l2_norm, RowL2NormLayer);
bool RowL2NormLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 1U);
return true;
}
void RowL2NormLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr inV = getInputValue(0);
/* malloc memory for the output_ if necessary */
size_t batchSize = inV->getHeight();
size_t dataDim = getSize();
CHECK_EQ(dataDim, inV->getWidth());
resetOutput(batchSize, dataDim);
MatrixPtr outV = getOutputValue();
Matrix::resizeOrCreate(inSquare_, batchSize, dataDim, false, useGpu_);
inV->square2(*inSquare_);
Matrix::resizeOrCreate(l2NormReciprocal_, batchSize, 1, false, useGpu_);
inSquare_->rowSum(*l2NormReciprocal_);
l2NormReciprocal_->sqrt2(*l2NormReciprocal_);
l2NormReciprocal_->scalarDiv(*l2NormReciprocal_, 1.0);
outV->rowScale(0, *inV, *l2NormReciprocal_);
}
void RowL2NormLayer::backward(const UpdateCallback& callback) {
MatrixPtr inV = getInputValue(0);
MatrixPtr inG = getInputGrad(0);
MatrixPtr outV = getOutputValue();
MatrixPtr outG = getOutputGrad();
size_t batchSize = inV->getHeight();
// inG[ij] += outG[ij] / l2NormReciprocal
// inG[ij] += -inV[ij] * l2NormReciprocal * l2NormReciprocal * DotMul(outG[i],
// inV[i])
if (inG) {
Matrix::resizeOrCreate(dotSum_, batchSize, 1, false, useGpu_);
dotSum_->zeroMem();
dotSum_->rowDotMul(0, *outG, *outV);
dotSum_->dotMul(*dotSum_, *l2NormReciprocal_);
dotSum_->dotMul(*dotSum_, *l2NormReciprocal_);
inSquare_->rowScale(0, *inV, *dotSum_);
inG->sub(*inSquare_);
inG->addRowScale(0, *outG, *l2NormReciprocal_);
}
}
} // namespace paddle
...@@ -1916,6 +1916,19 @@ TEST(Layer, ClipLayer) { ...@@ -1916,6 +1916,19 @@ TEST(Layer, ClipLayer) {
} }
} }
TEST(Layer, RowL2NormLayer) {
const size_t batchSize = 128;
const size_t size = 512;
TestConfig config;
config.layerConfig.set_type("row_l2_norm");
config.layerConfig.set_size(size);
config.inputDefs.push_back({INPUT_DATA, "input", size, 0});
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "row_l2_norm", batchSize, false, useGpu, false);
}
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
initMain(argc, argv); initMain(argc, argv);
......
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h" #include "paddle/operators/add_op.h"
......
...@@ -28,8 +28,7 @@ public: ...@@ -28,8 +28,7 @@ public:
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(context.GetEigenDevice<Place>()) =
*(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(*input0) + framework::EigenVector<T>::Flatten(*input0) +
framework::EigenVector<T>::Flatten(*input1); framework::EigenVector<T>::Flatten(*input1);
} }
......
#define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h" #include "paddle/operators/cross_entropy_op.h"
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
......
...@@ -27,7 +27,7 @@ public: ...@@ -27,7 +27,7 @@ public:
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) = EigenScalar<T>::From(*output).device(context.GetEigenDevice<Place>()) =
EigenVector<T>::Flatten(*input).mean(); EigenVector<T>::Flatten(*input).mean();
} }
}; };
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/mul_op.h" #include "paddle/operators/mul_op.h"
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
\ No newline at end of file
...@@ -29,7 +29,7 @@ public: ...@@ -29,7 +29,7 @@ public:
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) = EigenMatrix<T>::From(*output).device(context.GetEigenDevice<Place>()) =
EigenMatrix<T>::From(*context.Input<Tensor>("X")) EigenMatrix<T>::From(*context.Input<Tensor>("X"))
.contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")), .contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
dim_pair); dim_pair);
......
#define EIGEN_USE_GPU
#include "paddle/operators/rowwise_add_op.h" #include "paddle/operators/rowwise_add_op.h"
REGISTER_OP_GPU_KERNEL(rowwise_add, REGISTER_OP_GPU_KERNEL(rowwise_add,
......
...@@ -33,7 +33,7 @@ public: ...@@ -33,7 +33,7 @@ public:
const int rest_size = input.size() / bias_size; const int rest_size = input.size() / bias_size;
Eigen::DSizes<int, 1> one_d(input.size()); Eigen::DSizes<int, 1> one_d(input.size());
Eigen::DSizes<int, 1> bcast(rest_size); Eigen::DSizes<int, 1> bcast(rest_size);
output.reshape(one_d).device(*(context.GetEigenDevice<Place>())) = output.reshape(one_d).device(context.GetEigenDevice<Place>()) =
input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d); input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
} }
}; };
......
#define EIGEN_USE_GPU
#include "paddle/operators/sgd_op.h" #include "paddle/operators/sgd_op.h"
REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
\ No newline at end of file
...@@ -29,7 +29,7 @@ public: ...@@ -29,7 +29,7 @@ public:
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) = EigenVector<T>::Flatten(*param_out).device(ctx.GetEigenDevice<Place>()) =
EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad); EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
} }
}; };
......
#define EIGEN_USE_GPU
#include "paddle/operators/sigmoid_op.h" #include "paddle/operators/sigmoid_op.h"
REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
...@@ -27,8 +27,7 @@ public: ...@@ -27,8 +27,7 @@ public:
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(context.GetEigenDevice<Place>()) =
*(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp()); 1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
} }
}; };
......
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/softmax_op.h" #include "paddle/operators/softmax_op.h"
......
...@@ -46,9 +46,9 @@ public: ...@@ -46,9 +46,9 @@ public:
.reshape(batch_by_one) .reshape(batch_by_one)
.broadcast(one_by_class)); .broadcast(one_by_class));
softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp(); softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(*(context.GetEigenDevice<Place>())) = softmax.device(context.GetEigenDevice<Place>()) =
(softmax * (softmax *
softmax.sum(along_class) softmax.sum(along_class)
.inverse() .inverse()
......
...@@ -144,12 +144,12 @@ inline void throw_on_error(T e) { ...@@ -144,12 +144,12 @@ inline void throw_on_error(T e) {
throw_on_error(e, ""); throw_on_error(e, "");
} }
#define PADDLE_THROW(...) \ #define PADDLE_THROW(...) \
do { \ do { \
throw ::paddle::platform::EnforceNotMet( \ throw ::paddle::platform::EnforceNotMet( \
std::make_exception_ptr( \ std::make_exception_ptr( \
std::runtime_error(string::Sprintf(__VA_ARGS__))), \ std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \ __FILE__, __LINE__); \
} while (0) } while (0)
#define PADDLE_ENFORCE(...) \ #define PADDLE_ENFORCE(...) \
......
...@@ -20,6 +20,8 @@ limitations under the License. */ ...@@ -20,6 +20,8 @@ limitations under the License. */
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/pybind/tensor_bind.h" #include "paddle/pybind/tensor_bind.h"
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
...@@ -55,6 +57,14 @@ static size_t UniqueIntegerGenerator() { ...@@ -55,6 +57,14 @@ static size_t UniqueIntegerGenerator() {
return generator.fetch_add(1); return generator.fetch_add(1);
} }
bool IsCompileGPU() {
#ifdef PADDLE_ONLY_CPU
return false;
#else
return true;
#endif
}
PYBIND11_PLUGIN(core) { PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle"); py::module m("core", "C++ core of PaddlePaddle");
...@@ -69,15 +79,27 @@ PYBIND11_PLUGIN(core) { ...@@ -69,15 +79,27 @@ PYBIND11_PLUGIN(core) {
self.Resize(pd::make_ddim(dim)); self.Resize(pd::make_ddim(dim));
}) })
.def("alloc_float", .def("alloc_float",
[](pd::Tensor& self) { [](pd::Tensor& self, paddle::platform::GPUPlace& place) {
self.mutable_data<float>(paddle::platform::CPUPlace()); self.mutable_data<float>(place);
})
.def("alloc_float",
[](pd::Tensor& self, paddle::platform::CPUPlace& place) {
self.mutable_data<float>(place);
}) })
.def("alloc_int", .def("alloc_int",
[](pd::Tensor& self) { [](pd::Tensor& self, paddle::platform::CPUPlace& place) {
self.mutable_data<int>(paddle::platform::CPUPlace()); self.mutable_data<int>(place);
}) })
.def("set", paddle::pybind::PyTensorSetFromArray<float>) .def("alloc_int",
.def("set", paddle::pybind::PyTensorSetFromArray<int>) [](pd::Tensor& self, paddle::platform::GPUPlace& place) {
self.mutable_data<int>(place);
})
.def("set", paddle::pybind::PyCPUTensorSetFromArray<float>)
.def("set", paddle::pybind::PyCPUTensorSetFromArray<int>)
#ifndef PADDLE_ONLY_CPU
.def("set", paddle::pybind::PyCUDATensorSetFromArray<float>)
.def("set", paddle::pybind::PyCUDATensorSetFromArray<int>)
#endif
.def("shape", .def("shape",
[](pd::Tensor& self) { return pd::vectorize(self.dims()); }); [](pd::Tensor& self) { return pd::vectorize(self.dims()); });
...@@ -136,11 +158,27 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -136,11 +158,27 @@ All parameter, weight, gradient are variables in Paddle.
"The module will return special predefined variable name in Paddle") "The module will return special predefined variable name in Paddle")
.def("empty", pd::OperatorBase::EMPTY_VAR_NAME) .def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
.def("temp", pd::OperatorBase::TMP_VAR_NAME); .def("temp", pd::OperatorBase::TMP_VAR_NAME);
// clang-format off
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext") py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static("cpu_context", []() -> paddle::platform::DeviceContext* { .def_static("create",
return new paddle::platform::CPUDeviceContext(); [](paddle::platform::CPUPlace& place)
}); -> paddle::platform::DeviceContext* {
return new paddle::platform::CPUDeviceContext();
})
.def_static("create",
[](paddle::platform::GPUPlace& place)
-> paddle::platform::DeviceContext* {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("GPUPlace is not supported in CPU device.");
#else
return new paddle::platform::CUDADeviceContext(place);
#endif
});
// clang-format on
py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>());
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>());
py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base( py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base(
m, "Operator"); m, "Operator");
...@@ -176,5 +214,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -176,5 +214,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("unique_integer", UniqueIntegerGenerator); m.def("unique_integer", UniqueIntegerGenerator);
m.def("is_compile_gpu", IsCompileGPU);
return m.ptr(); return m.ptr();
} }
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <paddle/framework/tensor.h> #include <string>
#include <pybind11/numpy.h> #include "paddle/framework/tensor.h"
#include <pybind11/pybind11.h> #include "paddle/memory/memcpy.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
namespace py = pybind11; namespace py = pybind11;
...@@ -40,9 +42,6 @@ template <size_t I, typename... ARGS> ...@@ -40,9 +42,6 @@ template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<true, I, ARGS...> { struct CastToPyBufferImpl<true, I, ARGS...> {
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type; using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
py::buffer_info operator()(framework::Tensor &tensor) { py::buffer_info operator()(framework::Tensor &tensor) {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
"Only CPU tensor can cast to numpy array");
if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) { if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
auto dim_vec = framework::vectorize(tensor.dims()); auto dim_vec = framework::vectorize(tensor.dims());
std::vector<size_t> dims_outside; std::vector<size_t> dims_outside;
...@@ -56,12 +55,17 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -56,12 +55,17 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
strides[i - 1] = sizeof(CUR_TYPE) * prod; strides[i - 1] = sizeof(CUR_TYPE) * prod;
prod *= dims_outside[i - 1]; prod *= dims_outside[i - 1];
} }
framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
} else if (paddle::platform::is_cpu_place(tensor.holder_->place())) {
dst_tensor = tensor;
}
return py::buffer_info( return py::buffer_info(
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()), dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()),
sizeof(CUR_TYPE), sizeof(CUR_TYPE),
py::format_descriptor<CUR_TYPE>::format(), py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(tensor.dims()), (size_t)framework::arity(dst_tensor.dims()),
dims_outside, dims_outside,
strides); strides);
} else { } else {
...@@ -77,9 +81,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) { ...@@ -77,9 +81,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
} }
template <typename T> template <typename T>
void PyTensorSetFromArray( void PyCPUTensorSetFromArray(
framework::Tensor &self, framework::Tensor &self,
py::array_t<T, py::array::c_style | py::array::forcecast> array) { py::array_t<T, py::array::c_style | py::array::forcecast> array,
paddle::platform::CPUPlace &place) {
std::vector<int> dims; std::vector<int> dims;
dims.reserve(array.ndim()); dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) { for (size_t i = 0; i < array.ndim(); ++i) {
...@@ -87,9 +92,28 @@ void PyTensorSetFromArray( ...@@ -87,9 +92,28 @@ void PyTensorSetFromArray(
} }
self.Resize(framework::make_ddim(dims)); self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace()); auto *dst = self.mutable_data<T>(place);
std::memcpy(dst, array.data(), sizeof(T) * array.size()); std::memcpy(dst, array.data(), sizeof(T) * array.size());
} }
#ifndef PADDLE_ONLY_CPU
template <typename T>
void PyCUDATensorSetFromArray(
framework::Tensor &self,
py::array_t<T, py::array::c_style | py::array::forcecast> array,
paddle::platform::GPUPlace &place) {
std::vector<int> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
dims.push_back((int)array.shape()[i]);
}
self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place);
paddle::platform::GpuMemcpySync(
dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice);
}
#endif
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
...@@ -148,7 +148,7 @@ cat >> /paddle/build/Dockerfile <<EOF ...@@ -148,7 +148,7 @@ cat >> /paddle/build/Dockerfile <<EOF
ADD *.deb / ADD *.deb /
# run paddle version to install python packages first # run paddle version to install python packages first
RUN apt-get update &&\ RUN apt-get update &&\
apt-get install -y python-pip && pip install -U pip && \ apt-get install -y wget python-pip && pip install -U pip && \
dpkg -i /*.deb ; apt-get install -f -y && \ dpkg -i /*.deb ; apt-get install -f -y && \
apt-get clean -y && \ apt-get clean -y && \
rm -f /*.deb && \ rm -f /*.deb && \
......
...@@ -2768,6 +2768,16 @@ class SumToOneNormLayer(LayerBase): ...@@ -2768,6 +2768,16 @@ class SumToOneNormLayer(LayerBase):
self.set_layer_size(input_layer0.size) self.set_layer_size(input_layer0.size)
@config_layer('row_l2_norm')
class RowL2NormLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
super(RowL2NormLayer, self).__init__(
name, 'row_l2_norm', 0, inputs=inputs, **xargs)
config_assert(len(self.inputs) == 1, 'RowL2NormLayer must have 1 input')
input_layer = self.get_input_layer(0)
self.set_layer_size(input_layer.size)
@config_layer('cos_vm') @config_layer('cos_vm')
class CosSimVecMatLayer(LayerBase): class CosSimVecMatLayer(LayerBase):
def __init__(self, name, size, inputs, cos_scale=1.0, device=None): def __init__(self, name, size, inputs, cos_scale=1.0, device=None):
......
...@@ -76,6 +76,7 @@ __all__ = [ ...@@ -76,6 +76,7 @@ __all__ = [
'trans_layer', 'trans_layer',
'rotate_layer', 'rotate_layer',
'sum_to_one_norm_layer', 'sum_to_one_norm_layer',
'row_l2_norm_layer',
'get_output_layer', 'get_output_layer',
'LayerType', 'LayerType',
'context_projection', 'context_projection',
...@@ -161,6 +162,7 @@ class LayerType(object): ...@@ -161,6 +162,7 @@ class LayerType(object):
BATCH_NORM_LAYER = 'batch_norm' BATCH_NORM_LAYER = 'batch_norm'
NORM_LAYER = 'norm' NORM_LAYER = 'norm'
SUM_TO_ONE_NORM_LAYER = 'sum_to_one_norm' SUM_TO_ONE_NORM_LAYER = 'sum_to_one_norm'
ROW_L2_NORM_LAYER = 'row_l2_norm'
ADDTO_LAYER = 'addto' ADDTO_LAYER = 'addto'
CONCAT_LAYER = 'concat' CONCAT_LAYER = 'concat'
...@@ -2891,6 +2893,42 @@ def sum_to_one_norm_layer(input, name=None, layer_attr=None): ...@@ -2891,6 +2893,42 @@ def sum_to_one_norm_layer(input, name=None, layer_attr=None):
name, LayerType.SUM_TO_ONE_NORM_LAYER, parents=[input], size=input.size) name, LayerType.SUM_TO_ONE_NORM_LAYER, parents=[input], size=input.size)
@wrap_name_default()
@layer_support()
def row_l2_norm_layer(input, name=None, layer_attr=None):
"""
A layer for L2-normalization in each row.
.. math::
out[i] = \frac{in[i]}{\sqrt{\sum_{k=1}^N in[k]^{2}}}
where the size of :math:`in` is (batchSize x dataDim) ,
and the size of :math:`out` is a (batchSize x dataDim) .
The example usage is:
.. code-block:: python
row_l2_norm_layer = row_l2_norm_layer(input=layer)
:param input: Input layer.
:type input: LayerOutput
:param name: Layer name.
:type name: basestring
:param layer_attr: extra layer attributes.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
"""
Layer(
name=name,
type=LayerType.ROW_L2_NORM_LAYER,
inputs=[input.name],
**ExtraAttr.to_kwargs(layer_attr))
return LayerOutput(
name, LayerType.ROW_L2_NORM_LAYER, parents=[input], size=input.size)
@wrap_name_default("addto") @wrap_name_default("addto")
@wrap_act_default(act=LinearActivation()) @wrap_act_default(act=LinearActivation())
@wrap_bias_attr_default(has_bias=False) @wrap_bias_attr_default(has_bias=False)
......
...@@ -7,6 +7,6 @@ test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight ...@@ -7,6 +7,6 @@ test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
test_recursive_topology test_gated_unit_layer test_clip_layer) test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer)
export whole_configs=(test_split_datasource) export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "input"
type: "data"
size: 300
active_type: ""
}
layers {
name: "__row_l2_norm_layer_0__"
type: "row_l2_norm"
size: 300
active_type: ""
inputs {
input_layer_name: "input"
}
}
input_layer_names: "input"
output_layer_names: "__row_l2_norm_layer_0__"
sub_models {
name: "root"
layer_names: "input"
layer_names: "__row_l2_norm_layer_0__"
input_layer_names: "input"
output_layer_names: "__row_l2_norm_layer_0__"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
data = data_layer(name='input', size=300)
row_l2_norm = row_l2_norm_layer(input=data)
outputs(row_l2_norm)
...@@ -8,7 +8,6 @@ add_python_test(test_framework ...@@ -8,7 +8,6 @@ add_python_test(test_framework
test_fc_op.py test_fc_op.py
test_add_two_op.py test_add_two_op.py
test_sgd_op.py test_sgd_op.py
test_cross_entropy_op.py
test_mul_op.py test_mul_op.py
test_mean_op.py test_mean_op.py
test_sigmoid_op.py test_sigmoid_op.py
......
...@@ -26,40 +26,45 @@ class OpTestMeta(type): ...@@ -26,40 +26,45 @@ class OpTestMeta(type):
scope = core.Scope() scope = core.Scope()
kwargs = dict() kwargs = dict()
places = []
places.append(core.CPUPlace())
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
for in_name in func.all_input_args: for place in places:
if hasattr(self, in_name): for in_name in func.all_input_args:
kwargs[in_name] = in_name if hasattr(self, in_name):
var = scope.new_var(in_name).get_tensor() kwargs[in_name] = in_name
arr = getattr(self, in_name) var = scope.new_var(in_name).get_tensor()
var.set_dims(arr.shape) arr = getattr(self, in_name)
var.set(arr) var.set_dims(arr.shape)
else: var.set(arr, place)
kwargs[in_name] = "@EMPTY@" else:
kwargs[in_name] = "@EMPTY@"
for out_name in func.all_output_args: for out_name in func.all_output_args:
if hasattr(self, out_name): if hasattr(self, out_name):
kwargs[out_name] = out_name kwargs[out_name] = out_name
scope.new_var(out_name).get_tensor() scope.new_var(out_name).get_tensor()
for attr_name in func.all_attr_args: for attr_name in func.all_attr_args:
if hasattr(self, attr_name): if hasattr(self, attr_name):
kwargs[attr_name] = getattr(self, attr_name) kwargs[attr_name] = getattr(self, attr_name)
op = func(**kwargs) op = func(**kwargs)
op.infer_shape(scope) op.infer_shape(scope)
ctx = core.DeviceContext.cpu_context() ctx = core.DeviceContext.create(place)
op.run(scope, ctx) op.run(scope, ctx)
for out_name in func.all_output_args: for out_name in func.all_output_args:
actual = numpy.array(scope.find_var(out_name).get_tensor()) actual = numpy.array(scope.find_var(out_name).get_tensor())
expect = getattr(self, out_name) expect = getattr(self, out_name)
# TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul # TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
# has some diff, and could not pass unittest. So I set decimal 3 here. # has some diff, and could not pass unittest. So I set decimal 3 here.
# And I will check this in future. # And I will check this in future.
numpy.testing.assert_almost_equal(actual, expect, decimal=3) numpy.testing.assert_almost_equal(actual, expect, decimal=3)
obj.test_all = test_all obj.test_all = test_all
return obj return obj
...@@ -8,8 +8,8 @@ class TestAddOp(unittest.TestCase): ...@@ -8,8 +8,8 @@ class TestAddOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "add_two" self.type = "add_two"
self.X = numpy.random.random((342, 345)).astype("float32") self.X = numpy.random.random((102, 105)).astype("float32")
self.Y = numpy.random.random((342, 345)).astype("float32") self.Y = numpy.random.random((102, 105)).astype("float32")
self.Out = self.X + self.Y self.Out = self.X + self.Y
......
...@@ -7,17 +7,19 @@ import paddle.v2.framework.create_op_creation_methods as creation ...@@ -7,17 +7,19 @@ import paddle.v2.framework.create_op_creation_methods as creation
class TestFc(unittest.TestCase): class TestFc(unittest.TestCase):
def test_fc(self): def test_fc(self):
scope = core.Scope() scope = core.Scope()
place = core.CPUPlace()
x = scope.new_var("X") x = scope.new_var("X")
x_tensor = x.get_tensor() x_tensor = x.get_tensor()
x_tensor.set_dims([1000, 784]) x_tensor.set_dims([1000, 784])
x_tensor.alloc_float() x_tensor.alloc_float(place)
w = scope.new_var("W") w = scope.new_var("W")
w_tensor = w.get_tensor() w_tensor = w.get_tensor()
w_tensor.set_dims([784, 100]) w_tensor.set_dims([784, 100])
w_tensor.alloc_float() w_tensor.alloc_float(place)
w_tensor.set(numpy.random.random((784, 100)).astype("float32")) w_tensor.set(numpy.random.random((784, 100)).astype("float32"), place)
# Set a real numpy array here. # Set a real numpy array here.
# x_tensor.set(numpy.array([])) # x_tensor.set(numpy.array([]))
...@@ -32,7 +34,7 @@ class TestFc(unittest.TestCase): ...@@ -32,7 +34,7 @@ class TestFc(unittest.TestCase):
op.infer_shape(scope) op.infer_shape(scope)
self.assertEqual([1000, 100], tensor.shape()) self.assertEqual([1000, 100], tensor.shape())
ctx = core.DeviceContext.cpu_context() ctx = core.DeviceContext.create(place)
op.run(scope, ctx) op.run(scope, ctx)
......
...@@ -8,8 +8,8 @@ class TestMulOp(unittest.TestCase): ...@@ -8,8 +8,8 @@ class TestMulOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "mul" self.type = "mul"
self.X = np.random.random((32, 784)).astype("float32") self.X = np.random.random((32, 84)).astype("float32")
self.Y = np.random.random((784, 100)).astype("float32") self.Y = np.random.random((84, 100)).astype("float32")
self.Out = np.dot(self.X, self.Y) self.Out = np.dot(self.X, self.Y)
......
...@@ -8,8 +8,8 @@ class TestRowwiseAddOp(unittest.TestCase): ...@@ -8,8 +8,8 @@ class TestRowwiseAddOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "rowwise_add" self.type = "rowwise_add"
self.X = np.random.random((32, 784)).astype("float32") self.X = np.random.random((32, 84)).astype("float32")
self.b = np.random.random(784).astype("float32") self.b = np.random.random(84).astype("float32")
self.Out = np.add(self.X, self.b) self.Out = np.add(self.X, self.b)
......
...@@ -8,8 +8,8 @@ class TestSGD(unittest.TestCase): ...@@ -8,8 +8,8 @@ class TestSGD(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "sgd" self.type = "sgd"
self.param = numpy.random.random((342, 345)).astype("float32") self.param = numpy.random.random((102, 105)).astype("float32")
self.grad = numpy.random.random((342, 345)).astype("float32") self.grad = numpy.random.random((102, 105)).astype("float32")
self.learning_rate = 0.1 self.learning_rate = 0.1
self.param_out = self.param - self.learning_rate * self.grad self.param_out = self.param - self.learning_rate * self.grad
......
...@@ -7,16 +7,17 @@ class TestScope(unittest.TestCase): ...@@ -7,16 +7,17 @@ class TestScope(unittest.TestCase):
def test_int_tensor(self): def test_int_tensor(self):
scope = core.Scope() scope = core.Scope()
var = scope.new_var("test_tensor") var = scope.new_var("test_tensor")
place = core.CPUPlace()
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set_dims([1000, 784]) tensor.set_dims([1000, 784])
tensor.alloc_int() tensor.alloc_int(place)
tensor_array = numpy.array(tensor) tensor_array = numpy.array(tensor)
self.assertEqual((1000, 784), tensor_array.shape) self.assertEqual((1000, 784), tensor_array.shape)
tensor_array[3, 9] = 1 tensor_array[3, 9] = 1
tensor_array[19, 11] = 2 tensor_array[19, 11] = 2
tensor.set(tensor_array) tensor.set(tensor_array, place)
tensor_array_2 = numpy.array(tensor) tensor_array_2 = numpy.array(tensor)
self.assertEqual(1.0, tensor_array_2[3, 9]) self.assertEqual(1.0, tensor_array_2[3, 9])
...@@ -25,16 +26,18 @@ class TestScope(unittest.TestCase): ...@@ -25,16 +26,18 @@ class TestScope(unittest.TestCase):
def test_float_tensor(self): def test_float_tensor(self):
scope = core.Scope() scope = core.Scope()
var = scope.new_var("test_tensor") var = scope.new_var("test_tensor")
place = core.CPUPlace()
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set_dims([1000, 784]) tensor.set_dims([1000, 784])
tensor.alloc_float() tensor.alloc_float(place)
tensor_array = numpy.array(tensor) tensor_array = numpy.array(tensor)
self.assertEqual((1000, 784), tensor_array.shape) self.assertEqual((1000, 784), tensor_array.shape)
tensor_array[3, 9] = 1.0 tensor_array[3, 9] = 1.0
tensor_array[19, 11] = 2.0 tensor_array[19, 11] = 2.0
tensor.set(tensor_array) tensor.set(tensor_array, place)
tensor_array_2 = numpy.array(tensor) tensor_array_2 = numpy.array(tensor)
self.assertAlmostEqual(1.0, tensor_array_2[3, 9]) self.assertAlmostEqual(1.0, tensor_array_2[3, 9])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册