提交 7fab7ddd 编写于 作者: Y Yu Yang

Merge branch 'develop' of github.com:baidu/Paddle into feature/refactorize_framework_proto

...@@ -36,8 +36,8 @@ include(simd) ...@@ -36,8 +36,8 @@ include(simd)
################################ Configurations ####################################### ################################ Configurations #######################################
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND}) option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF)
option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND}) option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF)
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON)
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
......
...@@ -46,5 +46,6 @@ cc_library(paddle_pybind SHARED ...@@ -46,5 +46,6 @@ cc_library(paddle_pybind SHARED
cross_entropy_op cross_entropy_op
recurrent_op recurrent_op
uniform_random_op uniform_random_op
gaussian_random_op
fill_zeros_like_op) fill_zeros_like_op)
endif(WITH_PYTHON) endif(WITH_PYTHON)
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include <boost/variant.hpp>
#include <functional> #include <functional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -23,6 +22,7 @@ limitations under the License. */ ...@@ -23,6 +22,7 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -14,12 +14,12 @@ limitations under the License. */ ...@@ -14,12 +14,12 @@ limitations under the License. */
#pragma once #pragma once
#include <boost/variant.hpp>
#include <initializer_list> #include <initializer_list>
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
#include "paddle/framework/dim.h" #include "paddle/framework/dim.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
......
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <boost/variant.hpp>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -26,6 +25,7 @@ limitations under the License. */ ...@@ -26,6 +25,7 @@ limitations under the License. */
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "paddle/platform/variant.h"
#include "paddle/utils/Error.h" #include "paddle/utils/Error.h"
namespace paddle { namespace paddle {
......
...@@ -40,7 +40,9 @@ USE_OP(softmax); ...@@ -40,7 +40,9 @@ USE_OP(softmax);
USE_OP(rowwise_add); USE_OP(rowwise_add);
USE_OP(fill_zeros_like); USE_OP(fill_zeros_like);
USE_OP_WITHOUT_KERNEL(recurrent_op); USE_OP_WITHOUT_KERNEL(recurrent_op);
USE_OP(gaussian_random);
USE_OP(uniform_random); USE_OP(uniform_random);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -79,11 +79,11 @@ class Tensor { ...@@ -79,11 +79,11 @@ class Tensor {
inline const DDim& dims() const; inline const DDim& dims() const;
/*! Resize the dimensions of the memory block. */ /*! Resize the dimensions of the memory block. */
inline void Resize(const DDim& dims); inline Tensor& Resize(const DDim& dims);
/*! The internal of two tensors share the same memory block. */ /*! The internal of two tensors share the same memory block. */
template <typename T> template <typename T>
inline void ShareDataWith(const Tensor& src); inline Tensor& ShareDataWith(const Tensor& src);
/** /**
* @brief Copy the content of external tensor to a new place. * @brief Copy the content of external tensor to a new place.
......
...@@ -23,9 +23,11 @@ template <typename T> ...@@ -23,9 +23,11 @@ template <typename T>
inline void Tensor::check_memory_size() const { inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tenosr holds no memory. Call Tensor::mutable_data first."); holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE(holder_->size(), product(dims_) * sizeof(T) + offset_, PADDLE_ENFORCE_GE(
holder_->size(), product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data " "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory."); "first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data already stored.");
} }
template <typename T> template <typename T>
...@@ -78,9 +80,10 @@ inline T* Tensor::mutable_data(platform::Place place) { ...@@ -78,9 +80,10 @@ inline T* Tensor::mutable_data(platform::Place place) {
} }
template <typename T> template <typename T>
inline void Tensor::ShareDataWith(const Tensor& src) { inline Tensor& Tensor::ShareDataWith(const Tensor& src) {
src.check_memory_size<T>(); src.check_memory_size<T>();
*this = src; *this = src;
return *this;
} }
template <typename T> template <typename T>
...@@ -136,7 +139,10 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { ...@@ -136,7 +139,10 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
return dst; return dst;
} }
inline void Tensor::Resize(const DDim& dims) { dims_ = dims; } inline Tensor& Tensor::Resize(const DDim& dims) {
dims_ = dims;
return *this;
}
inline const DDim& Tensor::dims() const { return dims_; } inline const DDim& Tensor::dims() const { return dims_; }
......
...@@ -96,6 +96,11 @@ TEST(Layer, kmaxSeqScoreLayer) { ...@@ -96,6 +96,11 @@ TEST(Layer, kmaxSeqScoreLayer) {
MatrixPtr inValue = MatrixPtr inValue =
Matrix::create(subSeqStartPosition.back(), 1, false, false); Matrix::create(subSeqStartPosition.back(), 1, false, false);
std::vector<bool> mode = {false};
#ifndef PADDLE_ONLY_CPU
mode.push_back(true);
#endif
for (auto hasSubseq : {false, true}) { for (auto hasSubseq : {false, true}) {
vector<vector<int>> groundTruth; vector<vector<int>> groundTruth;
inValue->randomizeUniform(); inValue->randomizeUniform();
...@@ -104,7 +109,7 @@ TEST(Layer, kmaxSeqScoreLayer) { ...@@ -104,7 +109,7 @@ TEST(Layer, kmaxSeqScoreLayer) {
hasSubseq ? subSeqStartPosition : seqStartPosition, hasSubseq ? subSeqStartPosition : seqStartPosition,
beamSize); beamSize);
for (auto useGpu : {false, true}) { for (auto useGpu : mode) {
TestConfig config; TestConfig config;
config.layerConfig.set_type("kmax_seq_score"); config.layerConfig.set_type("kmax_seq_score");
config.layerConfig.set_beam_size(beamSize); config.layerConfig.set_beam_size(beamSize);
......
...@@ -41,6 +41,8 @@ function(op_library TARGET) ...@@ -41,6 +41,8 @@ function(op_library TARGET)
endif() endif()
endfunction() endfunction()
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_library(net_op SRCS net_op.cc DEPS op_registry) cc_library(net_op SRCS net_op.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
...@@ -53,6 +55,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) ...@@ -53,6 +55,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu)
op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory.h>
#include <cstring>
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
namespace paddle {
namespace operators {
// Implementation of CPU copy
template <typename T>
void CPUGather(const T* params, const int* indices, const int slice_size,
const int index_size, T* output) {
const size_t slice_bytes = slice_size * sizeof(T);
for (size_t i = 0; i < index_size; ++i) {
int index_ = indices[i];
memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes);
}
}
// Implementation of GPU copy:
template <typename T>
void GPUGather(const T* src, const int* index, const int slice_size,
const int index_size, T* output);
/**
* Return a new tensor from source tensor, gathered according to index
* input[src]: type-T source Tensor
* input[index]: type-int index Tensor (1-D)
* return: output tensor
*/
template <typename T>
void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
const paddle::framework::Tensor* index,
paddle::framework::Tensor* output) {
// check index of shape 1-D
PADDLE_ENFORCE(index->dims().size() == 1);
int index_size = index->dims()[0];
auto src_dims = src->dims();
paddle::framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
int slice_size = 1;
for (size_t i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// Gathering
if (platform::is_cpu_place(place)) {
CPUGather<T>(src->data<T>(), index->data<int>(), slice_size, index_size,
output->data<T>());
}
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gather.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
#include <gtest/gtest.h>
#include <iostream>
#include <string>
TEST(Gather, GatherData) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::operators;
Tensor* src = new Tensor();
Tensor* index = new Tensor();
Tensor* output = new Tensor();
int* p_src = nullptr;
int* p_index = nullptr;
p_src = src->mutable_data<int>(make_ddim({3, 4}), CPUPlace());
p_index = index->mutable_data<int>(make_ddim({2}), CPUPlace());
for (size_t i = 0; i < 12; ++i) p_src[i] = i;
p_index[0] = 1;
p_index[1] = 0;
int* p_output = output->mutable_data<int>(make_ddim({2, 4}), CPUPlace());
Gather<int>(CPUPlace(), src, index, output);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <random>
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
// TODO(dzh): attribute does not support unsigned int.
// And we need a global random seed configuration.
int seed = context.op_.GetAttr<int>("seed");
if (seed == 0) {
seed = std::random_device()();
}
std::mt19937 g(seed);
std::normal_distribution<T> distribution(mean, std);
ssize_t size = framework::product(tensor->dims());
for (int i = 0; i < size; ++i) {
data[i] = distribution(g);
}
}
};
class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0);
auto dims = GetAttr<std::vector<int>>("dims");
PADDLE_ENFORCE(dims.size() > 0UL,
"dims can be one int or array. dims must be set.");
tensor->Resize(framework::make_ddim(dims));
}
};
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GaussianRandomOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "output matrix of random op");
AddComment(R"DOC(
GaussianRandom operator.
Use to initialize tensor with gaussian random generator.
)DOC");
AddAttr<std::vector<int>>("dims", "The dimension of random tensor.");
AddAttr<float>("mean", "mean value of random.").SetDefault(.0f);
AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed of generator."
"0 means use system wide seed")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <random>
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
int seed = context.op_.GetAttr<int>("seed");
if (seed == 0) {
seed = std::random_device()();
}
curandGenerator_t g;
PADDLE_ENFORCE(platform::dynload::curandCreateGenerator(
&g, CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed));
curandGenerateNormal(g, data, framework::product(tensor->dims()), mean,
std);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <boost/variant.hpp>
#include <iostream> #include <iostream>
#include "paddle/platform/variant.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <boost/config.hpp>
#ifndef PADDLE_ONLY_CPU
// Because boost's variadic templates has bug on nvcc, boost will disable
// variadic template support when GPU enabled on nvcc.
// Define BOOST_NO_CXX11_VARIADIC_TEMPLATES on gcc/clang to generate same
// function symbols.
//
// https://github.com/PaddlePaddle/Paddle/issues/3386
#ifndef BOOST_NO_CXX11_VARIADIC_TEMPLATES
#define BOOST_NO_CXX11_VARIADIC_TEMPLATES
#endif
#endif
#include <boost/variant.hpp>
...@@ -21,5 +21,7 @@ py_test(gradient_checker SRCS gradient_checker.py) ...@@ -21,5 +21,7 @@ py_test(gradient_checker SRCS gradient_checker.py)
py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py)
py_test(test_operator SRCS test_operator.py) py_test(test_operator SRCS test_operator.py)
# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
import unittest
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
import numpy
class GaussianRandomTest(unittest.TestCase):
def test_cpu(self):
self.gaussian_random_test(place=core.CPUPlace())
def test_gpu(self):
if core.is_compile_gpu():
self.gaussian_random_test(place=core.GPUPlace(0))
def gaussian_random_test(self, place):
scope = core.Scope()
scope.new_var("Out").get_tensor()
op = Operator(
"gaussian_random",
Out="Out",
dims=[1000, 784],
mean=.0,
std=1.,
seed=10)
op.infer_shape(scope)
context = core.DeviceContext.create(place)
op.run(scope, context)
tensor = numpy.array(scope.find_var("Out").get_tensor())
self.assertAlmostEqual(numpy.mean(tensor), .0, delta=0.1)
self.assertAlmostEqual(numpy.std(tensor), 1., delta=0.1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册