From 0f8c9dbe61762092a701ac035445dbae31b27338 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Sun, 6 Aug 2017 15:37:36 +0800 Subject: [PATCH] device context pointer --- paddle/operators/CMakeLists.txt | 2 +- paddle/operators/gaussian_random_op.cc | 11 +++++------ paddle/operators/gaussian_random_op.cu | 26 +++++++++++++------------- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 8887dc6dbd..3b60df0218 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -55,7 +55,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) -op_library(guassian_random_op SRCS guassain_random_op.cc guassian_random_op.cu) +op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 7afc0cd56b..f5fd902c5f 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -12,9 +12,9 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/operators/gaussian_random_op.h" #include "glog/logging.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/random_op.h" namespace paddle { namespace operators { @@ -22,7 +22,7 @@ namespace operators { template class GaussianRandomOpKernel : public framework::OpKernel { -public: + public: void Compute(const framework::KernelContext& context) const override { auto mean = context.op_.GetAttr("mean"); auto std = context.op_.GetAttr("std"); @@ -40,7 +40,7 @@ public: }; class GaussianRandomOp : public framework::OperatorWithKernel { -protected: + protected: void InferShape( const std::vector& inputs, const std::vector& outputs) const override { @@ -54,7 +54,7 @@ protected: }; class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { -public: + public: GaussianRandomOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { @@ -74,8 +74,7 @@ The eqution : Out = GaussianRandom(Shape=(d0, d1, ...), Dtype, mean, std) } // namespace operators } // namespace paddle -REGISTER_OP(gaussian_random, - paddle::operators::GaussianRandomOp, +REGISTER_OP(gaussian_random, paddle::operators::GaussianRandomOp, paddle::operators::GaussianRandomOpMaker); typedef paddle::operators::GaussianRandomOpKernel -class GaussianRandomOpKernel : public framework::OpKernel { -public: + +template +class GaussianRandomOpKernel + : public framework::OpKernel { + public: void Compute(const framework::KernelContext& context) const override { auto mean = context.op_.GetAttr("mean"); auto std = context.op_.GetAttr("std"); auto* output = context.Output(0)->GetMutable(); T* r = output->mutable_data(context.GetPlace()); - auto ctx = static_cast - (context.device_context_); - // generator need to modify context + auto ctx = + static_cast(context.device_context_); + // generator need to modify context auto g = const_cast(ctx)->RandGenerator(); curandGenerateNormal(g, r, framework::product(output->dims()), mean, std); - } }; - + } // namespace operators } // namespace paddle - -typedef paddle::operators::GaussianRandomOpKernel - RandomOpKernel_GPU_float; +typedef paddle::operators::GaussianRandomOpKernel + RandomOpKernel_GPU_float; REGISTER_OP_GPU_KERNEL(gaussian_random, GaussianRandomOpKernel_GPU_float); \ No newline at end of file -- GitLab