diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 25c545d3f9bb9dad81845c42a3b35bfaa649f0e5..e1e122091f7759b1a68f1f982bc2a35e8241f9f0 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -22,14 +22,14 @@ namespace framework {
 template <>
 Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
     platform::CPUPlace, Eigen::DefaultDevice>() const {
-  return *device_context_->get_eigen_device<platform::CPUPlace>();
+  return *device_context_->get_eigen_device<Eigen::DefaultDevice>();
 }
 
 #ifndef PADDLE_ONLY_CPU
 template <>
 Eigen::GpuDevice&
 ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
-  return *device_context_->get_eigen_device<platform::GPUPlace>();
+  return *device_context_->get_eigen_device<Eigen::GpuDevice>();
 }
 #endif
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 0970797e02b061787bc1c124554e301822ca9370..4600b06009bcef7d0774d25b816aac4733f30795 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -139,9 +139,9 @@
 // Macro for define a clone method.
 // If you are writing an kernel operator, `Clone` will be defined when you
 // register it. i.e. `Clone` method is not needed to define by yourself.
-#define DEFINE_OP_CLONE_METHOD(cls)                                            \
-  std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final {     \
-    return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
+#define DEFINE_OP_CLONE_METHOD(cls)                       \
+  std::unique_ptr<OperatorBase> Clone() const final {     \
+    return std::unique_ptr<OperatorBase>(new cls(*this)); \
   }
 
 // Macro for define a default constructor for Operator.
@@ -331,6 +331,21 @@ class InferShapeContext {
   const Scope& scope_;
 };
 
+template <typename T>
+struct EigenDeviceConverter;
+
+template <>
+struct EigenDeviceConverter<platform::CPUPlace> {
+  using EigenDeviceType = Eigen::DefaultDevice;
+};
+
+#ifndef PADDLE_ONLY_CPU
+template <>
+struct EigenDeviceConverter<platform::GPUPlace> {
+  using EigenDeviceType = Eigen::GpuDevice;
+};
+#endif
+
 class ExecutionContext : public InferShapeContext {
  public:
   ExecutionContext(const OperatorBase& op, const Scope& scope,
@@ -338,8 +353,8 @@
       : InferShapeContext(op, scope), device_context_(device_context) {}
 
   template <typename PlaceType,
-            typename DeviceType = typename platform::EigenDeviceConverter<
-                PlaceType>::EigenDeviceType>
+            typename DeviceType =
+                typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
   DeviceType& GetEigenDevice() const;
 
   platform::Place GetPlace() const { return device_context_->GetPlace(); }
diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
index e713b5a21151891a7e83382aa811120cfb611ded..ffa5c26da3b8a28bc01598d8607ee7ad241e1d30 100644
--- a/paddle/operators/activation_op.cc
+++ b/paddle/operators/activation_op.cc
@@ -14,26 +14,6 @@
 
 #include "paddle/operators/activation_op.h"
 
-// #define FILL_ACTIVATION_OP                                                  \
-//  public:                                                                    \
-//   using framework::OperatorWithKernel::OperatorWithKernel;                  \
-//                                                                             \
-//  protected:                                                                 \
-//   void InferShape(const framework::InferShapeContext &ctx) const override { \
-//     ctx.Output<framework::Tensor>("Y")->Resize(                             \
-//         ctx.Input<framework::Tensor>("X")->dims());                         \
-//   }
-
-// #define FILL_ACTIVATION_GRAD_OP                                             \
-//  public:                                                                    \
-//   using framework::OperatorWithKernel::OperatorWithKernel;                  \
-//                                                                             \
-//  protected:                                                                 \
-//   void InferShape(const framework::InferShapeContext &ctx) const override { \
-//     ctx.Output<framework::Tensor>(framework::GradVarName("X"))              \
-//         ->Resize(ctx.Input<framework::Tensor>("Y")->dims());                \
-//   }
-
 namespace paddle {
 namespace operators {
 
@@ -59,10 +39,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
   }
 };
 
-// class SigmoidOp : public framework::OperatorWithKernel {
-//   FILL_ACTIVATION_OP
-// };
-
 class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SigmoidOpMaker(framework::OpProto *proto,
@@ -74,14 +50,6 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-// class SigmoidOpGrad : public framework::OperatorWithKernel {
-//   FILL_ACTIVATION_GRAD_OP
-// };
-
-// class ExpOp : public framework::OperatorWithKernel {
-//   FILL_ACTIVATION_OP
-// };
-
 class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@@ -92,14 +60,6 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-// class ExpOpGrad : public framework::OperatorWithKernel {
-//   FILL_ACTIVATION_GRAD_OP
-// };
-
-// class ReluOp : public framework::OperatorWithKernel {
-//   FILL_ACTIVATION_OP
-// };
-
 class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@@ -110,36 +70,33 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-// class ReluOpGrad : public framework::OperatorWithKernel {
-//   FILL_ACTIVATION_GRAD_OP
-// };
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad,
             ops::ActivationOpGrad);
+REGISTER_OP_CPU_KERNEL(sigmoid,
+                       ops::ActivationKernel<paddle::platform::CPUPlace, float,
+                                             ops::SigmoidFunctor>);
 REGISTER_OP_CPU_KERNEL(
-    sigmoid,
-    ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::Sigmoid>);
-REGISTER_OP_CPU_KERNEL(sigmoid_grad,
-                       ops::ActivationGradKernel<paddle::platform::CPUPlace,
-                                                 float, ops::SigmoidGrad>);
+    sigmoid_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
+                                            ops::SigmoidGradFunctor>);
 
 REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
             ops::ActivationOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    exp, ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::Exp>);
+    exp,
+    ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::ExpFunctor>);
+REGISTER_OP_CPU_KERNEL(exp_grad,
+                       ops::ActivationGradKernel<paddle::platform::CPUPlace,
+                                                 float, ops::ExpGradFunctor>);
+
+REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
+            ops::ActivationOpGrad);
+REGISTER_OP_CPU_KERNEL(relu,
+                       ops::ActivationKernel<paddle::platform::CPUPlace, float,
+                                             ops::ReluFunctor<float>>);
 REGISTER_OP_CPU_KERNEL(
-    exp_grad,
-    ops::ActivationGradKernel<paddle::platform::CPUPlace, float, ops::ExpGrad>);
-
-// REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
-//             ops::ActivationOpGrad);
-// REGISTER_OP_CPU_KERNEL(relu,
-//                        ops::ReluKernel<paddle::platform::CPUPlace, float>);
-// REGISTER_OP_CPU_KERNEL(relu_grad,
-//                        ops::ReluGradKernel<paddle::platform::CPUPlace, float>);
+    relu_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
+                                         ops::ReluGradFunctor<float>>);
diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu
index 55d9f52124d3f705b41112b3ad5301aa7f365956..3b2c147f466c391ef84547ef65353c06861ef68c 100644
--- a/paddle/operators/activation_op.cu
+++ b/paddle/operators/activation_op.cu
@@ -18,15 +18,21 @@
 namespace ops = paddle::operators;
 
 REGISTER_OP_GPU_KERNEL(sigmoid,
-                       ops::SigmoidKernel<paddle::platform::GPUPlace, float>);
+                       ops::ActivationKernel<paddle::platform::GPUPlace, float,
+                                             ops::SigmoidFunctor>);
 REGISTER_OP_GPU_KERNEL(
-    sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::GPUPlace, float>);
+    sigmoid_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
+                                            ops::SigmoidGradFunctor>);
 
-REGISTER_OP_GPU_KERNEL(exp, ops::ExpKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    exp,
+    ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::ExpFunctor>);
 REGISTER_OP_GPU_KERNEL(exp_grad,
-                       ops::ExpGradKernel<paddle::platform::GPUPlace, float>);
-
+                       ops::ActivationGradKernel<paddle::platform::GPUPlace,
+                                                 float, ops::ExpGradFunctor>);
 REGISTER_OP_GPU_KERNEL(relu,
-                       ops::ReluKernel<paddle::platform::GPUPlace, float>);
-REGISTER_OP_GPU_KERNEL(relu_grad,
-                       ops::ReluGradKernel<paddle::platform::GPUPlace, float>);
+                       ops::ActivationKernel<paddle::platform::GPUPlace, float,
+                                             ops::ReluFunctor<float>>);
+REGISTER_OP_GPU_KERNEL(
+    relu_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
+                                         ops::ReluGradFunctor<float>>);
diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h
index 7d5c5bb26f4c5d0fa5f0ee3835be7ca17bcf46f7..0b7e171e722e62d987675033b7c48f762048a61d 100644
--- a/paddle/operators/activation_op.h
+++ b/paddle/operators/activation_op.h
@@ -15,42 +15,6 @@
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-// #include "paddle/operators/math/activation_functor.h"
-
-// #define ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) ACTIVATION_NAME##Kernel
-
-// #define DEFINE_ACTIVATION_KERNEL(ACTIVATION_NAME)                              \
-//   template <typename Place, typename T>                                        \
-//   class ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) : public framework::OpKernel { \
-//    public:                                                                     \
-//     void Compute(const framework::ExecutionContext& context) const override {  \
-//       auto* X = context.Input<framework::Tensor>("X");                         \
-//       auto* Y = context.Output<framework::Tensor>("Y");                        \
-//       Y->mutable_data<T>(context.GetPlace());                                  \
-//       math::ACTIVATION_NAME<Place, T> functor;                                 \
-//       auto* device_context = context.device_context();                         \
-//       functor(*device_context, *X, Y);                                         \
-//     }                                                                          \
-//   };
-
-// #define DEFINE_ACTIVATION_GRAD_KERNEL(ACTIVATION_GRAD_NAME)                    \
-//   template <typename Place, typename T>                                        \
-//   class ACTIVATION_KERNEL_NAME(ACTIVATION_GRAD_NAME)                           \
-//       : public framework::OpKernel {                                           \
-//    public:                                                                     \
-//     void Compute(const framework::ExecutionContext& context) const override {  \
-//       auto* X = context.Input<framework::Tensor>("X");                         \
-//       auto* Y = context.Input<framework::Tensor>("Y");                         \
-//       auto* dY =                                                               \
-//           context.Input<framework::Tensor>(framework::GradVarName("Y"));       \
-//       auto* dX =                                                               \
-//           context.Output<framework::Tensor>(framework::GradVarName("X"));      \
-//       dX->mutable_data<T>(context.GetPlace());                                 \
-//       math::ACTIVATION_GRAD_NAME<Place, T> functor;                            \
-//       auto* device_context = context.device_context();                         \
-//       functor(*device_context, *X, *Y, *dY, dX);                               \
-//     }                                                                          \
-//   };
 
 namespace paddle {
 namespace operators {
 
@@ -91,59 +55,49 @@ class ActivationGradKernel : public framework::OpKernel {
   }
 };
 
-struct Sigmoid {
+struct SigmoidFunctor {
   template <typename Device, typename X, typename Y>
   void operator()(Device d, X x, Y y) {
     y.device(d) = 1. / (1. + (-x).exp());
   }
 };
 
-struct SigmoidGrad {
+struct SigmoidGradFunctor {
   template <typename Device, typename X, typename Y, typename dY, typename dX>
   void operator()(Device d, X x, Y y, dY dy, dX dx) {
     dx.device(d) = dy * y * (1. - y);
   }
 };
 
-struct Exp {
+struct ExpFunctor {
   template <typename Device, typename X, typename Y>
   void operator()(Device d, X x, Y y) {
     y.device(d) = x.exp();
   }
 };
 
-struct ExpGrad {
+struct ExpGradFunctor {
   template <typename Device, typename X, typename Y, typename dY, typename dX>
   void operator()(Device d, X x, Y y, dY dy, dX dx) {
     dx.device(d) = y;
   }
 };
 
-// template <typename Device, typename X, typename Y, typename T>
-// struct Relu {
-//   void operator()(Device d, X x, Y y) {
-//     y.device(d) = x.cwiseMax(static_cast<T>(0));
-//   }
-// };
-
-// template <typename Device, typename X, typename Y, typename dY,
-//           typename dX, typename T>
-// struct ReluGrad {
-//   void operator()(Device d, X x, Y y, dY dy, dX dx) {
-//     dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
-//   }
-// };
-
-// DEFINE_ACTIVATION_KERNEL(Sigmoid);
-
-// DEFINE_ACTIVATION_GRAD_KERNEL(SigmoidGrad);
-
-// DEFINE_ACTIVATION_KERNEL(Exp);
-
-// DEFINE_ACTIVATION_GRAD_KERNEL(ExpGrad);
-
-// DEFINE_ACTIVATION_KERNEL(Relu);
+template <typename T>
+struct ReluFunctor {
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) {
+    y.device(d) = x.cwiseMax(static_cast<T>(0));
+  }
+};
 
-// DEFINE_ACTIVATION_GRAD_KERNEL(ReluGrad);
+template <typename T>
+struct ReluGradFunctor {
+  template <typename Device, typename X, typename Y, typename dY, typename dX>
+  void operator()(Device d, X x, Y y, dY dy, dX dx) {
+    dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
+  }
+};
 
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/activation_functor.h b/paddle/operators/math/activation_functor.h
deleted file mode 100644
index 1e9bdd142ee2dabf9113aed26105aab575b730f8..0000000000000000000000000000000000000000
--- a/paddle/operators/math/activation_functor.h
+++ /dev/null
@@ -1,96 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#pragma once
-#include "paddle/framework/eigen.h"
-#include "paddle/framework/tensor.h"
-
-namespace paddle {
-namespace operators {
-namespace math {
-
-template <typename Place, typename T>
-struct Sigmoid {
-  void operator()(const platform::DeviceContext& device_context,
-                  const framework::Tensor& X, framework::Tensor* Y) {
-    auto x = framework::EigenVector<T>::Flatten(X);
-    auto y = framework::EigenVector<T>::Flatten(*Y);
-    auto* place = device_context.template get_eigen_device<Place>();
-    y.device(*place) = 1. / (1. + (-x).exp());
-  }
-};
-
-template <typename Place, typename T>
-struct SigmoidGrad {
-  void operator()(const platform::DeviceContext& device_context,
-                  const framework::Tensor& X, const framework::Tensor& Y,
-                  const framework::Tensor& dY, framework::Tensor* dX) {
-    auto dx = framework::EigenVector<T>::Flatten(*dX);
-    auto y = framework::EigenVector<T>::Flatten(Y);
-    auto dy = framework::EigenVector<T>::Flatten(dY);
-    auto* place = device_context.template get_eigen_device<Place>();
-    dx.device(*place) = dy * y * (1. - y);
-  }
-};
-
-template <typename Place, typename T>
-struct Exp {
-  void operator()(const platform::DeviceContext& device_context,
-                  const framework::Tensor& input, framework::Tensor* output) {
-    auto x = framework::EigenVector<T>::Flatten(input);
-    auto y = framework::EigenVector<T>::Flatten(*output);
-    auto* place = device_context.template get_eigen_device<Place>();
-    y.device(*place) = x.exp();
-  }
-};
-
-template <typename Place, typename T>
-struct ExpGrad {
-  void operator()(const platform::DeviceContext& device_context,
-                  const framework::Tensor& X, const framework::Tensor& Y,
-                  const framework::Tensor& dY, framework::Tensor* dX) {
-    auto dx = framework::EigenVector<T>::Flatten(*dX);
-    auto y = framework::EigenVector<T>::Flatten(Y);
-    auto* place = device_context.template get_eigen_device<Place>();
-    dx.device(*place) = y;
-  }
-};
-
-template <typename Place, typename T>
-struct Relu {
-  void operator()(const platform::DeviceContext& device_context,
-                  const framework::Tensor& input, framework::Tensor* output) {
-    auto x = framework::EigenVector<T>::Flatten(input);
-    auto y = framework::EigenVector<T>::Flatten(*output);
-    auto* place = device_context.template get_eigen_device<Place>();
-    y.device(*place) = x.cwiseMax(static_cast<T>(0));
-  }
-};
-
-template <typename Place, typename T>
-struct ReluGrad {
-  void operator()(const platform::DeviceContext& device_context,
-                  const framework::Tensor& X, const framework::Tensor& Y,
-                  const framework::Tensor& dY, framework::Tensor* dX) {
-    auto dx = framework::EigenVector<T>::Flatten(*dX);
-    auto dy = framework::EigenVector<T>::Flatten(dY);
-    auto x = framework::EigenVector<T>::Flatten(X);
-    auto* place = device_context.template get_eigen_device<Place>();
-    dx.device(*place) = dy * (x > static_cast<T>(0)).template cast<T>();
-  }
-};
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc
index cf5c3eec8123f955fc91780a903c5dd17b99efc2..ad212c5b2c47312743362db4926c80bf056e100d 100644
--- a/paddle/platform/device_context.cc
+++ b/paddle/platform/device_context.cc
@@ -16,8 +16,8 @@ namespace paddle {
 namespace platform {
 
 template <>
-Eigen::DefaultDevice*
-DeviceContext::get_eigen_device<CPUPlace>() const {
+Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
+    const {
   return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
 }
 
@@ -91,8 +91,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
 };
 
 template <>
-Eigen::GpuDevice* DeviceContext::get_eigen_device<GPUPlace>()
-    const {
+Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
   return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
 }
 
diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
index a46ba4c7032c91d2bd37f20e4f0df566007c96f0..11528e1194e4516891034fa8febdac3ba6eed204 100644
--- a/paddle/platform/device_context.h
+++ b/paddle/platform/device_context.h
@@ -27,29 +27,12 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
-template <typename T>
-struct EigenDeviceConverter;
-
-template <>
-struct EigenDeviceConverter<platform::CPUPlace> {
-  using EigenDeviceType = Eigen::DefaultDevice;
-};
-
-#ifndef PADDLE_ONLY_CPU
-template <>
-struct EigenDeviceConverter<platform::GPUPlace> {
-  using EigenDeviceType = Eigen::GpuDevice;
-};
-#endif
-
 class DeviceContext {
  public:
  virtual ~DeviceContext() {}
  virtual Place GetPlace() const = 0;
 
-  template <typename PlaceType,
-            typename DeviceType =
-                typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
+  template <typename DeviceType>
   DeviceType* get_eigen_device() const;
 };
 
diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc
index d71e0aae58b18e2e18535a9731d46cb4dad173d2..5883a55272f0f24c94d48bc43c62ddb7bef15465 100644
--- a/paddle/platform/device_context_test.cc
+++ b/paddle/platform/device_context_test.cc
@@ -24,7 +24,7 @@ TEST(Device, Init) {
   for (int i = 0; i < count; i++) {
     DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i));
     Eigen::GpuDevice* gpu_device =
-        device_context->template get_eigen_device<GPUPlace>();
+        device_context->template get_eigen_device<Eigen::GpuDevice>();
     ASSERT_NE(nullptr, gpu_device);
     delete device_context;
   }
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index bed35d78221f7bf9de69be8b147870375b5bbcd9..bd964c5d0797f03a173e6d869603c3a0a2616af0 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -56,7 +56,7 @@ USE_OP(sum);
 USE_OP(reshape);
 USE_OP(sigmoid);
 USE_OP(exp);
-// USE_OP(relu);
+USE_OP(relu);
 
 namespace paddle {
 namespace framework {
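
Editor's note: the EigenDeviceConverter move from paddle/platform/device_context.h to paddle/framework/operator.h keeps the same trait technique, so it is worth spelling out: a place tag is mapped to an Eigen device type through a default template argument, letting call sites write GetEigenDevice<CPUPlace>() without ever naming the device. Below is a minimal self-contained sketch of that technique under stated assumptions: CPUPlace and the function-local static device are illustrative stand-ins for the paddle::platform types and for the device owned by DeviceContext, and a GPU specialization would mirror the #ifndef PADDLE_ONLY_CPU branch in the diff.

// converter_sketch.cc -- build with: g++ -I<eigen-include-dir> converter_sketch.cc
#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

struct CPUPlace {};  // stand-in for paddle::platform::CPUPlace

template <typename PlaceType>
struct EigenDeviceConverter;

template <>
struct EigenDeviceConverter<CPUPlace> {
  using EigenDeviceType = Eigen::DefaultDevice;
};

// The default template argument resolves PlaceType -> device type at compile
// time, so callers never spell out the second parameter.
template <typename PlaceType,
          typename DeviceType =
              typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType& GetEigenDevice() {
  static DeviceType device;  // Paddle instead pulls this off a DeviceContext
  return device;
}

int main() {
  Eigen::DefaultDevice& d = GetEigenDevice<CPUPlace>();
  float buf[3] = {1.f, 2.f, 3.f};
  Eigen::TensorMap<Eigen::Tensor<float, 1>> t(buf, 3);
  t.device(d) = t * 2.f;  // run an expression on the resolved device
  std::cout << buf[0] << ' ' << buf[1] << ' ' << buf[2] << '\n';  // 2 4 6
  return 0;
}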
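
A second note, on the pattern the activation files converge on: each activation is now a stateless functor whose templated operator() owns the element-wise math, while the generic ActivationKernel / ActivationGradKernel merely flattens the tensors and hands the Eigen device plus the flattened expressions to the functor, so one kernel template serves every activation on both CPU and GPU. The following self-contained sketch mirrors that shape outside Paddle; RunActivation and the function-local Eigen::DefaultDevice are hypothetical stand-ins for ActivationKernel::Compute and ExecutionContext::GetEigenDevice<Place>(), the functor bodies are copied from activation_op.h, and only Eigen's unsupported Tensor module is required.

// activation_sketch.cc -- build with: g++ -I<eigen-include-dir> activation_sketch.cc
#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

// Functor bodies as in activation_op.h: all of the math lives here.
struct SigmoidFunctor {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) {
    y.device(d) = 1. / (1. + (-x).exp());
  }
};

template <typename T>
struct ReluFunctor {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) {
    y.device(d) = x.cwiseMax(static_cast<T>(0));
  }
};

// Hypothetical stand-in for ActivationKernel::Compute: map the raw buffers
// into rank-1 Eigen expressions, pick a device, and defer to the functor. In
// Paddle the device comes from ExecutionContext::GetEigenDevice<Place>() and
// the flattening from framework::EigenVector<T>::Flatten.
template <typename Functor>
void RunActivation(float* in, float* out, int n) {
  Eigen::TensorMap<Eigen::Tensor<float, 1>> x(in, n);
  Eigen::TensorMap<Eigen::Tensor<float, 1>> y(out, n);
  Eigen::DefaultDevice cpu;  // a GPU kernel would pass an Eigen::GpuDevice
  Functor()(cpu, x, y);
}

int main() {
  float in[4] = {-2.f, -0.5f, 0.5f, 2.f};
  float out[4] = {0.f};

  RunActivation<SigmoidFunctor>(in, out, 4);
  for (float v : out) std::cout << v << ' ';  // ~0.119 0.378 0.622 0.881
  std::cout << '\n';

  RunActivation<ReluFunctor<float>>(in, out, 4);
  for (float v : out) std::cout << v << ' ';  // 0 0 0.5 2
  std::cout << '\n';
  return 0;
}

With this shape, adding a new activation reduces to writing one functor pair and naming it in the REGISTER_OP_CPU_KERNEL / REGISTER_OP_GPU_KERNEL lines, which is exactly what the relu hunks above do.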