Unverified commit 61ec0b95, authored by QI JUN, committed by GitHub

Refine device context (#6433)

This change mainly makes the following fixes (illustrated by the sketch below):

- take `DeviceContext` instead of `Place` as the template parameter of math functors and OpKernels
- remove the `eigen_device` interface from the base class `DeviceContext`
- remove the `GetEigenDevice` interface from `ExecutionContext` and the base class `DeviceContext`
- remove the unused `platform::EigenDeviceConverter`
- rename `REGISTER_OP_GPU_KERNEL` to `REGISTER_OP_CUDA_KERNEL`
- rename `USE_GPU_ONLY_OP` to `USE_CUDA_ONLY_OP`
Parent 7902ad65
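To make the new convention concrete, here is a minimal sketch of the kernel pattern after this change. It is not part of the diff: the `my_scale` operator, `MyScaleKernel`, and the doubling computation are hypothetical and used only for illustration; only the `DeviceContext` template parameter, the `device_context<...>()` / `eigen_device()` accessors, and the renamed registration macro reflect what this commit introduces.

```cpp
// Hypothetical operator; the DeviceContext pattern and registration macros
// are the ones this commit switches the framework to.
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

// Before this commit a kernel was `template <typename Place, typename T>`
// and fetched its Eigen device with `ctx.GetEigenDevice<Place>()`.
// Now it is parameterized on the concrete DeviceContext type.
template <typename DeviceContext, typename T>
class MyScaleKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

    auto x = framework::EigenVector<T>::Flatten(*in);
    auto y = framework::EigenVector<T>::Flatten(*out);
    // The Eigen device now comes from the concrete context object,
    // obtained through ExecutionContext::device_context<DeviceContext>().
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    y.device(place) = x * static_cast<T>(2);
  }
};

}  // namespace operators
}  // namespace paddle

// CPU kernels are now registered with CPUDeviceContext instead of CPUPlace;
// the GPU macro is renamed to REGISTER_OP_CUDA_KERNEL and, in a .cu file,
// takes CUDADeviceContext as the template argument.
namespace ops = paddle::operators;
REGISTER_OP_CPU_KERNEL(
    my_scale, ops::MyScaleKernel<paddle::platform::CPUDeviceContext, float>);
// In the corresponding .cu file:
// REGISTER_OP_CUDA_KERNEL(
//     my_scale,
//     ops::MyScaleKernel<paddle::platform::CUDADeviceContext, float>);
```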
@@ -181,8 +181,8 @@ class OpKernelRegistrar : public Registrar {
     return 0; \
   }
-#define REGISTER_OP_GPU_KERNEL(op_type, ...) \
-  REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
+#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
+  REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::GPUPlace, __VA_ARGS__)
 #define REGISTER_OP_CPU_KERNEL(op_type, ...) \
   REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
@@ -217,7 +217,7 @@ class OpKernelRegistrar : public Registrar {
 #else
 #define USE_OP_KERNEL(op_type) \
   USE_OP_DEVICE_KERNEL(op_type, CPU); \
-  USE_OP_DEVICE_KERNEL(op_type, GPU)
+  USE_OP_DEVICE_KERNEL(op_type, CUDA)
 #endif

 #define USE_NO_KERNEL_OP(op_type) USE_OP_ITSELF(op_type);
@@ -226,9 +226,9 @@ class OpKernelRegistrar : public Registrar {
   USE_OP_ITSELF(op_type); \
   USE_OP_DEVICE_KERNEL(op_type, CPU);

-#define USE_GPU_ONLY_OP(op_type) \
+#define USE_CUDA_ONLY_OP(op_type) \
   USE_OP_ITSELF(op_type); \
-  USE_OP_DEVICE_KERNEL(op_type, GPU)
+  USE_OP_DEVICE_KERNEL(op_type, CUDA)

 #define USE_OP(op_type) \
   USE_OP_ITSELF(op_type); \
......
@@ -22,20 +22,6 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-template <>
-Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
-    platform::CPUPlace, Eigen::DefaultDevice>() const {
-  return *device_context_.GetEigenDevice<platform::CPUPlace>();
-}
-
-#ifdef PADDLE_WITH_CUDA
-template <>
-Eigen::GpuDevice&
-ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
-  return *device_context_.GetEigenDevice<platform::GPUPlace>();
-}
-#endif
-
 std::string OperatorBase::Input(const std::string& name) const {
   auto& ins = Inputs(name);
   PADDLE_ENFORCE_LE(ins.size(), 1UL,
@@ -429,7 +415,7 @@ void OperatorWithKernel::Run(const Scope& scope,
 }

 OpKernelType OperatorWithKernel::GetKernelType(
     const ExecutionContext& ctx) const {
-  return OpKernelType(IndicateDataType(ctx), ctx.device_context());
+  return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
 }

 DataType OperatorWithKernel::IndicateDataType(
     const ExecutionContext& ctx) const {
......
@@ -276,17 +276,25 @@ class ExecutionContext {
     out_tensor->set_lod(in_tensor.lod());
   }

-  template <typename PlaceType,
-            typename DeviceType = typename platform::EigenDeviceConverter<
-                PlaceType>::EigenDeviceType>
-  DeviceType& GetEigenDevice() const;
-
   platform::Place GetPlace() const { return device_context_.GetPlace(); }

+  template <typename DeviceContextType>
+  const DeviceContextType& device_context() const {
+    return *reinterpret_cast<const DeviceContextType*>(&device_context_);
+  }
+
   const platform::DeviceContext& device_context() const {
     return device_context_;
   }

+#ifdef PADDLE_WITH_CUDA
+  const inline platform::CUDADeviceContext& cuda_device_context() const {
+    PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
+    return *reinterpret_cast<const platform::CUDADeviceContext*>(
+        &device_context_);
+  }
+#endif
+
   //! Get actual name vector for this input.
   const std::vector<std::string>& Inputs(const std::string& name) const {
     return op_.Inputs(name);
@@ -297,14 +305,6 @@ class ExecutionContext {
     return op_.Outputs(name);
   }

-#ifdef PADDLE_WITH_CUDA
-  const inline platform::CUDADeviceContext& cuda_device_context() const {
-    PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
-    return *reinterpret_cast<const platform::CUDADeviceContext*>(
-        &device_context_);
-  }
-#endif
-
  private:
   const OperatorBase& op_;
   const Scope& scope_;
......
@@ -115,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {}
   OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
-    return OpKernelType(DataType::FP32, ctx.device_context());
+    return OpKernelType(DataType::FP32, ctx.GetPlace());
   }
 };
......
@@ -138,7 +138,7 @@ function(op_library TARGET)
   if ("${TARGET}" STREQUAL "nccl_op")
     set(pybind_flag 1)
     # It's enough to just adding one operator to pybind
-    file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n")
+    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n")
   endif()

   # reduce_op contains several operators
......
@@ -57,7 +57,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
-        ctx.device_context());
+        ctx.GetPlace());
   }
 };
......
@@ -104,5 +104,6 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
 // FIXME(typhoonzero): types of T is for inference data.
 // label data is always int64
-REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
-                       paddle::operators::AccuracyOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(accuracy,
+                        paddle::operators::AccuracyOpCUDAKernel<float>,
+                        paddle::operators::AccuracyOpCUDAKernel<double>);
@@ -21,7 +21,7 @@ namespace operators {

 using Tensor = framework::Tensor;

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class AccuracyKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -611,16 +611,17 @@ REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker,
 REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad,
             ops::ActivationOpGrad);

 #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
   REGISTER_OP_CPU_KERNEL( \
-      act_type, \
-      ops::ActivationKernel<paddle::platform::CPUPlace, ops::functor<float>>, \
-      ops::ActivationKernel<paddle::platform::CPUPlace, \
+      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
+                                      ops::functor<float>>, \
+      ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                             ops::functor<double>>); \
   REGISTER_OP_CPU_KERNEL( \
-      act_type##_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, \
-                                                 ops::grad_functor<float>>, \
-      ops::ActivationGradKernel<paddle::platform::CPUPlace, \
+      act_type##_grad, \
+      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
+                                ops::grad_functor<float>>, \
+      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
                                 ops::grad_functor<double>>);

 FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
@@ -17,16 +17,17 @@

 namespace ops = paddle::operators;

-#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
-  REGISTER_OP_GPU_KERNEL( \
-      act_type, \
-      ops::ActivationKernel<paddle::platform::GPUPlace, ops::functor<float>>, \
-      ops::ActivationKernel<paddle::platform::GPUPlace, \
-                            ops::functor<double>>); \
-  REGISTER_OP_GPU_KERNEL( \
-      act_type##_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, \
-                                                 ops::grad_functor<float>>, \
-      ops::ActivationGradKernel<paddle::platform::GPUPlace, \
-                                ops::grad_functor<double>>);
+#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
+  REGISTER_OP_CUDA_KERNEL( \
+      act_type, ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
+                                      ops::functor<float>>, \
+      ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
+                            ops::functor<double>>); \
+  REGISTER_OP_CUDA_KERNEL( \
+      act_type##_grad, \
+      ops::ActivationGradKernel<paddle::platform::CUDADeviceContext, \
+                                ops::grad_functor<float>>, \
+      ops::ActivationGradKernel<paddle::platform::CUDADeviceContext, \
+                                ops::grad_functor<double>>);

-FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL);
+FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
@@ -19,7 +19,7 @@

 namespace paddle {
 namespace operators {

-template <typename Place, typename Functor>
+template <typename DeviceContext, typename Functor>
 class ActivationKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
  public:
@@ -32,18 +32,19 @@ class ActivationKernel
     auto x = framework::EigenVector<T>::Flatten(*X);
     auto y = framework::EigenVector<T>::Flatten(*Y);
-    auto place = context.GetEigenDevice<Place>();
+    auto* place =
+        context.template device_context<DeviceContext>().eigen_device();

     Functor functor;
     auto attrs = functor.GetAttrs();
     for (auto& attr : attrs) {
       *attr.second = context.Attr<float>(attr.first);
     }
-    functor(place, x, y);
+    functor(*place, x, y);
   }
 };

-template <typename Place, typename Functor>
+template <typename DeviceContext, typename Functor>
 class ActivationGradKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
  public:
@@ -59,13 +60,14 @@ class ActivationGradKernel
     auto x = framework::EigenVector<T>::Flatten(*X);
     auto y = framework::EigenVector<T>::Flatten(*Y);
     auto dx = framework::EigenVector<T>::Flatten(*dX);
-    auto place = context.GetEigenDevice<Place>();
+    auto* place =
+        context.template device_context<DeviceContext>().eigen_device();

     Functor functor;
     auto attrs = functor.GetAttrs();
     for (auto& attr : attrs) {
       *attr.second = context.Attr<float>(attr.first);
     }
-    functor(place, x, y, dy, dx);
+    functor(*place, x, y, dy, dx);
   }
 };
......
@@ -109,5 +109,5 @@ $$
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker);
 REGISTER_OP_CPU_KERNEL(
-    adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUPlace, float>,
-    ops::AdadeltaOpKernel<paddle::platform::CPUPlace, double>);
+    adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, double>);
@@ -16,6 +16,6 @@
 #include "paddle/operators/adadelta_op.h"

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(
-    adadelta, ops::AdadeltaOpKernel<paddle::platform::GPUPlace, float>,
-    ops::AdadeltaOpKernel<paddle::platform::GPUPlace, double>);
+REGISTER_OP_CUDA_KERNEL(
+    adadelta, ops::AdadeltaOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::AdadeltaOpKernel<paddle::platform::CUDADeviceContext, double>);
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class AdadeltaOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -51,7 +51,7 @@ class AdadeltaOpKernel : public framework::OpKernel<T> {
         framework::EigenVector<T>::Flatten(*avg_squared_grad_out_tensor);
     auto avg_squared_update_out =
         framework::EigenVector<T>::Flatten(*avg_squared_update_out_tensor);
-    auto place = ctx.GetEigenDevice<Place>();
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();

     avg_squared_grad_out.device(place) =
         rho * avg_squared_grad + (1 - rho) * grad.square();
......
@@ -100,8 +100,8 @@ size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
 }  // namespace

 template <typename T>
-struct SparseAdagradFunctor<platform::CPUPlace, T> {
-  void operator()(const platform::DeviceContext& context,
+struct SparseAdagradFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& context,
                   const framework::SelectedRows& grad,
                   const framework::Tensor& learning_rate, T epsilon,
                   framework::Tensor* moment, framework::Tensor* param) {
@@ -120,7 +120,7 @@ struct SparseAdagradFunctor<platform::CPUPlace, T> {
             {static_cast<int64_t>(merge_rows.size()), grad_width}),
         context.GetPlace());

-    math::SetConstant<platform::CPUPlace, T> constant_functor;
+    math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
     constant_functor(context, grad_merge->mutable_value(), 0.0);

     auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
@@ -144,9 +144,9 @@ struct SparseAdagradFunctor<platform::CPUPlace, T> {
     auto gs =
         framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
     auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
-    gs.device(*context.GetEigenDevice<platform::CPUPlace>()) = gm * gm;
+    gs.device(*context.eigen_device()) = gm * gm;

-    math::SelectedRowsAddToTensor<platform::CPUPlace, T> functor;
+    math::SelectedRowsAddToTensor<platform::CPUDeviceContext, T> functor;
     functor(context, *grad_square, moment);

     // 3. update parameter
@@ -164,13 +164,13 @@ struct SparseAdagradFunctor<platform::CPUPlace, T> {
   }
 };

-template struct SparseAdagradFunctor<platform::CPUPlace, float>;
-template struct SparseAdagradFunctor<platform::CPUPlace, double>;
+template struct SparseAdagradFunctor<platform::CPUDeviceContext, float>;
+template struct SparseAdagradFunctor<platform::CPUDeviceContext, double>;

 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker);
 REGISTER_OP_CPU_KERNEL(
-    adagrad, ops::AdagradOpKernel<paddle::platform::CPUPlace, float>,
-    ops::AdagradOpKernel<paddle::platform::CPUPlace, double>);
+    adagrad, ops::AdagradOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::AdagradOpKernel<paddle::platform::CPUDeviceContext, double>);
@@ -72,8 +72,8 @@ __global__ void SparseAdagradFunctorKernel(const T* grad, const int64_t* rows,
 }  // namespace

 template <typename T>
-struct SparseAdagradFunctor<platform::GPUPlace, T> {
-  void operator()(const platform::DeviceContext& context,
+struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& context,
                   const framework::SelectedRows& grad,
                   const framework::Tensor& learning_rate, T epsilon,
                   framework::Tensor* moment, framework::Tensor* param) {
@@ -92,7 +92,7 @@ struct SparseAdagradFunctor<platform::GPUPlace, T> {
             {static_cast<int64_t>(merge_rows.size()), grad_width}),
         context.GetPlace());

-    math::SetConstant<platform::GPUPlace, T> constant_functor;
+    math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
     constant_functor(context, grad_merge->mutable_value(), 0.0);

     auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
@@ -119,9 +119,9 @@ struct SparseAdagradFunctor<platform::GPUPlace, T> {
     auto gs =
         framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
     auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
-    gs.device(*context.GetEigenDevice<platform::GPUPlace>()) = gm * gm;
+    gs.device(*context.eigen_device()) = gm * gm;

-    math::SelectedRowsAddToTensor<platform::GPUPlace, T> functor;
+    math::SelectedRowsAddToTensor<platform::CUDADeviceContext, T> functor;
     functor(context, *grad_square, moment);

     // 3. update parameter
@@ -139,13 +139,13 @@ struct SparseAdagradFunctor<platform::GPUPlace, T> {
   }
 };

-template struct SparseAdagradFunctor<platform::GPUPlace, float>;
-template struct SparseAdagradFunctor<platform::GPUPlace, double>;
+template struct SparseAdagradFunctor<platform::CUDADeviceContext, float>;
+template struct SparseAdagradFunctor<platform::CUDADeviceContext, double>;

 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(
-    adagrad, ops::AdagradOpKernel<paddle::platform::GPUPlace, float>,
-    ops::AdagradOpKernel<paddle::platform::GPUPlace, double>);
+REGISTER_OP_CUDA_KERNEL(
+    adagrad, ops::AdagradOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::AdagradOpKernel<paddle::platform::CUDADeviceContext, double>);
@@ -19,15 +19,15 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 struct SparseAdagradFunctor {
-  void operator()(const platform::DeviceContext& context,
+  void operator()(const DeviceContext& context,
                   const framework::SelectedRows& grad,
                   const framework::Tensor& learning_rate, T epsilon,
                   framework::Tensor* moment, framework::Tensor* param);
 };

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class AdagradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -52,11 +52,11 @@ class AdagradOpKernel : public framework::OpKernel<T> {
       auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
       auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
-      auto place = ctx.GetEigenDevice<Place>();
+      auto* place = ctx.template device_context<DeviceContext>().eigen_device();

-      moment_out.device(place) = moment + grad * grad;
+      moment_out.device(*place) = moment + grad * grad;
       Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
-      param_out.device(place) =
+      param_out.device(*place) =
           param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       auto* param_tensor = ctx.Input<framework::Tensor>("Param");
@@ -65,8 +65,9 @@ class AdagradOpKernel : public framework::OpKernel<T> {
       auto* moment_tensor = ctx.Input<framework::Tensor>("Moment");
       PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor);

-      SparseAdagradFunctor<Place, T> functor;
-      functor(ctx.device_context(), *ctx.Input<framework::SelectedRows>("Grad"),
+      SparseAdagradFunctor<DeviceContext, T> functor;
+      functor(ctx.template device_context<DeviceContext>(),
+              *ctx.Input<framework::SelectedRows>("Grad"),
               *ctx.Input<framework::Tensor>("LearningRate"), epsilon,
               moment_out_tensor, param_out_tensor);
     } else {
......
@@ -128,6 +128,6 @@ $$
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
-REGISTER_OP_CPU_KERNEL(adam,
-                       ops::AdamOpKernel<paddle::platform::CPUPlace, float>,
-                       ops::AdamOpKernel<paddle::platform::CPUPlace, double>);
+REGISTER_OP_CPU_KERNEL(
+    adam, ops::AdamOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::AdamOpKernel<paddle::platform::CPUDeviceContext, double>);
@@ -16,6 +16,6 @@
 #include "paddle/operators/adam_op.h"

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(adam,
-                       ops::AdamOpKernel<paddle::platform::GPUPlace, float>,
-                       ops::AdamOpKernel<paddle::platform::GPUPlace, double>);
+REGISTER_OP_CUDA_KERNEL(
+    adam, ops::AdamOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::AdamOpKernel<paddle::platform::CUDADeviceContext, double>);
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class AdamOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -52,17 +52,17 @@ class AdamOpKernel : public framework::OpKernel<T> {
     auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
     auto moment1_out = framework::EigenVector<T>::Flatten(*moment1_out_tensor);
     auto moment2_out = framework::EigenVector<T>::Flatten(*moment2_out_tensor);
-    auto place = ctx.GetEigenDevice<Place>();
+    auto* place = ctx.template device_context<DeviceContext>().eigen_device();

-    moment1_out.device(place) = beta1 * moment1 + (1 - beta1) * grad;
-    moment2_out.device(place) = beta2 * moment2 + (1 - beta2) * grad.square();
+    moment1_out.device(*place) = beta1 * moment1 + (1 - beta1) * grad;
+    moment2_out.device(*place) = beta2 * moment2 + (1 - beta2) * grad.square();

     // All of these are tensors of 1 element
     auto lr_t = lr * (1 - beta2_pow).sqrt() / (1 - beta1_pow);
     // Eigen does not support automatic broadcast
     // Get dimensions of moment vector to broadcast lr_t
     Eigen::DSizes<int, 1> m_dsize(moment1_out_tensor->numel());
-    param_out.device(place) =
+    param_out.device(*place) =
         param -
         lr_t.broadcast(m_dsize) *
             (moment1_out / (moment2_out.sqrt() + epsilon));
......
@@ -127,6 +127,6 @@ division by 0 error.
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(adamax, ops::AdamaxOp, ops::AdamaxOpMaker);
-REGISTER_OP_CPU_KERNEL(adamax,
-                       ops::AdamaxOpKernel<paddle::platform::CPUPlace, float>,
-                       ops::AdamaxOpKernel<paddle::platform::CPUPlace, double>);
+REGISTER_OP_CPU_KERNEL(
+    adamax, ops::AdamaxOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::AdamaxOpKernel<paddle::platform::CPUDeviceContext, double>);
@@ -16,6 +16,6 @@
 #include "paddle/operators/adamax_op.h"

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(adamax,
-                       ops::AdamaxOpKernel<paddle::platform::GPUPlace, float>,
-                       ops::AdamaxOpKernel<paddle::platform::GPUPlace, double>);
+REGISTER_OP_CUDA_KERNEL(
+    adamax, ops::AdamaxOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::AdamaxOpKernel<paddle::platform::CUDADeviceContext, double>);
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class AdamaxOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -51,14 +51,14 @@ class AdamaxOpKernel : public framework::OpKernel<T> {
     auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
     auto inf_norm_out =
         framework::EigenVector<T>::Flatten(*inf_norm_out_tensor);
-    auto place = ctx.GetEigenDevice<Place>();
+    auto* place = ctx.template device_context<DeviceContext>().eigen_device();

-    moment_out.device(place) = beta1 * moment + (1 - beta1) * grad;
-    inf_norm_out.device(place) =
+    moment_out.device(*place) = beta1 * moment + (1 - beta1) * grad;
+    inf_norm_out.device(*place) =
         grad.abs().cwiseMax((beta2 * inf_norm) + epsilon);
     auto lr_t = lr / (1 - beta1_pow);
     Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
-    param_out.device(place) =
+    param_out.device(*place) =
         param - lr_t.broadcast(m_dsize) * (moment_out / inf_norm_out);
   }
 };
......
@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class AucKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -135,7 +135,8 @@ The required data format for this layer is one of the following:
 };

 template <typename T>
-class BatchNormKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
+class BatchNormKernel<platform::CPUDeviceContext, T>
+    : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     const float epsilon = ctx.Attr<float>("epsilon");
@@ -318,12 +319,12 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
       PADDLE_THROW("can't find Y@GRAD");
     }
     return framework::OpKernelType(framework::ToDataType(t->type()),
-                                   ctx.device_context());
+                                   ctx.GetPlace());
   }
 };

 template <typename T>
-class BatchNormGradKernel<platform::CPUPlace, T>
+class BatchNormGradKernel<platform::CPUDeviceContext, T>
     : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -436,8 +437,9 @@ class BatchNormGradKernel<platform::CPUPlace, T>
 namespace ops = paddle::operators;
 REGISTER_OP(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
             batch_norm_grad, ops::BatchNormGradOp);
-REGISTER_OP_CPU_KERNEL(batch_norm,
-                       ops::BatchNormKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    batch_norm,
+    ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>);
 REGISTER_OP_CPU_KERNEL(
     batch_norm_grad,
-    ops::BatchNormGradKernel<paddle::platform::CPUPlace, float>);
+    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>);
@@ -47,7 +47,8 @@ void ExtractNCWHD(const framework::DDim &dims,
 }

 template <typename T>
-class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
+class BatchNormKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
@@ -121,11 +122,12 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
     saved_mean->mutable_data<T>(ctx.GetPlace());
     saved_variance->mutable_data<T>(ctx.GetPlace());

-    math::SetConstant<platform::GPUPlace, T> functor;
-    functor(ctx.device_context(), saved_mean, 0);
-    functor(ctx.device_context(), saved_variance, 0);
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    math::SetConstant<platform::CUDADeviceContext, T> functor;
+    functor(dev_ctx, saved_mean, 0);
+    functor(dev_ctx, saved_variance, 0);

-    auto handle = ctx.cuda_device_context().cudnn_handle();
+    auto handle = dev_ctx.cudnn_handle();

     // Now, depending on whether we are running test or not, we have two paths.
     if (is_test) {
@@ -171,7 +173,7 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
 };

 template <typename T>
-class BatchNormGradKernel<platform::GPUPlace, T>
+class BatchNormGradKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
@@ -244,11 +246,12 @@ class BatchNormGradKernel<platform::GPUPlace, T>
     const void *saved_mean_data = saved_mean->template data<T>();
     const void *saved_var_data = saved_var->template data<T>();

+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
-        ctx.cuda_device_context().cudnn_handle(), mode_,
-        CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
-        CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), data_desc_,
-        x->template data<T>(), data_desc_, d_y->template data<T>(), data_desc_,
+        dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
+        CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
+        CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
+        data_desc_, d_y->template data<T>(), data_desc_,
         d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
         scale->template data<T>(),
         d_scale->template mutable_data<T>(ctx.GetPlace()),
@@ -266,8 +269,9 @@ class BatchNormGradKernel<platform::GPUPlace, T>
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(batch_norm,
-                       ops::BatchNormKernel<paddle::platform::GPUPlace, float>);
-REGISTER_OP_GPU_KERNEL(
+REGISTER_OP_CUDA_KERNEL(
+    batch_norm,
+    ops::BatchNormKernel<paddle::platform::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(
     batch_norm_grad,
-    ops::BatchNormGradKernel<paddle::platform::GPUPlace, float>);
+    ops::BatchNormGradKernel<paddle::platform::CUDADeviceContext, float>);
@@ -34,13 +34,13 @@ inline TensorFormat StringToTensorFormat(const std::string& str) {
   }
 }

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class BatchNormKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override;
 };

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class BatchNormGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override;
......
@@ -159,9 +159,12 @@ REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp,
             ops::BilinearTensorProductOpGrad);
 REGISTER_OP_CPU_KERNEL(
     bilinear_tensor_product,
-    ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, float>,
-    ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, double>);
+    ops::BilinearTensorProductKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::BilinearTensorProductKernel<paddle::platform::CPUDeviceContext,
+                                     double>);
 REGISTER_OP_CPU_KERNEL(
     bilinear_tensor_product_grad,
-    ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, float>,
-    ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, double>);
+    ops::BilinearTensorProductGradKernel<paddle::platform::CPUDeviceContext,
+                                         float>,
+    ops::BilinearTensorProductGradKernel<paddle::platform::CPUDeviceContext,
+                                         double>);
@@ -16,11 +16,15 @@ limitations under the License. */
 #include "paddle/operators/bilinear_tensor_product_op.h"

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(
+REGISTER_OP_CUDA_KERNEL(
     bilinear_tensor_product,
-    ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, float>,
-    ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, double>);
-REGISTER_OP_GPU_KERNEL(
+    ops::BilinearTensorProductKernel<paddle::platform::CUDADeviceContext,
+                                     float>,
+    ops::BilinearTensorProductKernel<paddle::platform::CUDADeviceContext,
+                                     double>);
+REGISTER_OP_CUDA_KERNEL(
     bilinear_tensor_product_grad,
-    ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, float>,
-    ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, double>);
+    ops::BilinearTensorProductGradKernel<paddle::platform::CUDADeviceContext,
+                                         float>,
+    ops::BilinearTensorProductGradKernel<paddle::platform::CUDADeviceContext,
+                                         double>);
@@ -27,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class BilinearTensorProductKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -46,7 +46,8 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
     int out_dim = weight_dims[0];
     auto x_dim = weight_dims[1];
     auto y_dim = weight_dims[2];
-    auto place = ctx.GetEigenDevice<Place>();
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();

     // Create the intermediate variable to caculate the result of
     // Input(X) multiplied by Input(Weight_i), the formula is:
@@ -60,9 +61,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
       auto output_col_vec = output_mat.chip(i, 1);
       Tensor weight_mat =
           weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim}));
-      math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
-                           batch_size, y_dim, x_dim, 1, x->data<T>(),
-                           weight_mat.data<T>(), 0, left_mul.data<T>());
+      math::gemm<DeviceContext, T>(dev_ctx, CblasNoTrans, CblasNoTrans,
+                                   batch_size, y_dim, x_dim, 1, x->data<T>(),
+                                   weight_mat.data<T>(), 0, left_mul.data<T>());
       output_col_vec.device(place) =
           (left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
     }
@@ -74,7 +75,7 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
   }
 };

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -96,8 +97,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
     auto x_mat = EigenMatrix<T>::From(*x);
     auto y_mat = EigenMatrix<T>::From(*y);
     auto d_out_mat = EigenMatrix<T>::From(*d_out);
-    auto place = ctx.GetEigenDevice<Place>();
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
     // Create the intermediate variable to caculate the Output(Y@Grad).
     Tensor x_scale;
     x_scale.mutable_data<T>(framework::make_ddim({batch_size, x_dim}),
@@ -110,18 +111,18 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
                             ctx.GetPlace());
     auto y_scale_mat = EigenMatrix<T>::From(y_scale);

-    math::SetConstant<Place, T> set_zero;
+    math::SetConstant<DeviceContext, T> set_zero;

     // Set Output(X@Grad) be zero.
     if (d_x) {
       d_x->mutable_data<T>(ctx.GetPlace());
-      set_zero(ctx.device_context(), d_x, static_cast<T>(0));
+      set_zero(dev_ctx, d_x, static_cast<T>(0));
     }

     // Set Output(Y@Grad) be zero.
     if (d_y) {
       d_y->mutable_data<T>(ctx.GetPlace());
-      set_zero(ctx.device_context(), d_y, static_cast<T>(0));
+      set_zero(dev_ctx, d_y, static_cast<T>(0));
     }

     // Caculate the Output(X@Grad) and Output(Y@Grad).
@@ -137,18 +138,18 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
             output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                 .broadcast(bcast_for_x) *
             y_mat;
-        math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasTrans,
-                             batch_size, x_dim, y_dim, 1, y_scale.data<T>(),
-                             weight_i.data<T>(), 1, d_x->data<T>());
+        math::gemm<DeviceContext, T>(
+            dev_ctx, CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
+            y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
       }

       if (d_y) {
         x_scale_mat.device(place) =
             output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                 .broadcast(bcast_for_y) *
             x_mat;
-        math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
-                             batch_size, y_dim, x_dim, 1, x_scale.data<T>(),
-                             weight_i.data<T>(), 1, d_y->data<T>());
+        math::gemm<DeviceContext, T>(
+            dev_ctx, CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
+            x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
       }
     }
   }
@@ -165,9 +166,9 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
           output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
               .broadcast(bcast_for_weight) *
           x_mat;
-      math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
-                           x_dim, y_dim, batch_size, 1, x_scale.data<T>(),
-                           y->data<T>(), 0, d_weight_i.data<T>());
+      math::gemm<DeviceContext, T>(dev_ctx, CblasTrans, CblasNoTrans, x_dim,
+                                   y_dim, batch_size, 1, x_scale.data<T>(),
+                                   y->data<T>(), 0, d_weight_i.data<T>());
     }
   }
......
@@ -68,7 +68,7 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
 }  // namespace paddle

 namespace ops = paddle::operators;
-using CPU = paddle::platform::CPUPlace;
+using CPU = paddle::platform::CPUDeviceContext;
 REGISTER_OP_WITH_KERNEL(cast, ops::CastOpGradMaker, ops::CastOpInferShape,
                         ops::CastOpProtoMaker);
 REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
......
@@ -16,7 +16,7 @@

 template <typename T>
 using CastOpKernel =
-    paddle::operators::CastOpKernel<paddle::platform::GPUPlace, T>;
+    paddle::operators::CastOpKernel<paddle::platform::CUDADeviceContext, T>;

-REGISTER_OP_GPU_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
-                       CastOpKernel<int>, CastOpKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
+                        CastOpKernel<int>, CastOpKernel<int64_t>);
@@ -27,13 +27,13 @@ struct CastOpTransformFunctor {
   HOSTDEVICE OutT operator()(InT in) const { return static_cast<OutT>(in); }
 };

-template <typename Place, typename InT>
+template <typename DeviceContext, typename InT>
 struct CastOpFunctor {
   const framework::Tensor* in_;
   framework::Tensor* out_;
-  const platform::DeviceContext& ctx_;
+  const DeviceContext& ctx_;
   CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
-                const platform::DeviceContext& ctx)
+                const DeviceContext& ctx)
       : in_(in), out_(out), ctx_(ctx) {}

   template <typename OutT>
@@ -42,13 +42,13 @@ struct CastOpFunctor {
     auto numel = in_->numel();
     auto* in_end = in_begin + numel;
     auto* out_begin = out_->mutable_data<OutT>(ctx_.GetPlace());
-    platform::Transform<Place> trans;
+    platform::Transform<DeviceContext> trans;
     trans(ctx_, in_begin, in_end, out_begin,
           CastOpTransformFunctor<InT, OutT>());
   }
 };

-template <typename Place, typename InT>
+template <typename DeviceContext, typename InT>
 class CastOpKernel : public framework::OpKernel<InT> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -56,7 +56,8 @@ class CastOpKernel : public framework::OpKernel<InT> {
     auto* out = context.Output<framework::Tensor>("Out");
     framework::VisitDataType(
         static_cast<framework::DataType>(context.Attr<int>("out_dtype")),
-        CastOpFunctor<Place, InT>(in, out, context.device_context()));
+        CastOpFunctor<DeviceContext, InT>(
+            in, out, context.template device_context<DeviceContext>()));
   }
 };
......
@@ -23,7 +23,7 @@ namespace operators {
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class ChunkEvalKernel : public framework::OpKernel<T> {
  public:
   struct Segment {
......
@@ -71,4 +71,5 @@ namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
                              ops::ClipByNormOpMaker);
 REGISTER_OP_CPU_KERNEL(
-    clip_by_norm, ops::ClipByNormKernel<paddle::platform::CPUPlace, float>);
+    clip_by_norm,
+    ops::ClipByNormKernel<paddle::platform::CPUDeviceContext, float>);
@@ -15,5 +15,6 @@
 #include "paddle/operators/clip_by_norm_op.h"

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(
-    clip_by_norm, ops::ClipByNormKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_CUDA_KERNEL(
+    clip_by_norm,
+    ops::ClipByNormKernel<paddle::platform::CUDADeviceContext, float>);
@@ -26,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class ClipByNormKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -38,7 +38,8 @@ class ClipByNormKernel : public framework::OpKernel<T> {
     auto x = EigenVector<T>::Flatten(*input);
     auto out = EigenVector<T>::Flatten(*output);
     auto x_norm = x.square().sum().sqrt();
-    auto place = context.GetEigenDevice<Place>();
+    auto& place =
+        *context.template device_context<DeviceContext>().eigen_device();

     auto temp = (x_norm <= max_norm).template cast<T>().eval();
     auto scaling = temp + (static_cast<T>(1) - temp) * max_norm / x_norm;
......
@@ -83,7 +83,7 @@ class ClipOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
             ops::ClipOpGrad);
-REGISTER_OP_CPU_KERNEL(clip,
-                       ops::ClipKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(clip_grad,
-                       ops::ClipGradKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>);
+REGISTER_OP_CPU_KERNEL(
+    clip_grad, ops::ClipGradKernel<paddle::platform::CPUDeviceContext, float>);
@@ -15,7 +15,7 @@
 #include "paddle/operators/clip_op.h"

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(clip,
-                       ops::ClipKernel<paddle::platform::GPUPlace, float>);
-REGISTER_OP_GPU_KERNEL(clip_grad,
-                       ops::ClipGradKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_CUDA_KERNEL(
+    clip, ops::ClipKernel<paddle::platform::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(
+    clip_grad, ops::ClipGradKernel<paddle::platform::CUDADeviceContext, float>);
@@ -55,7 +55,7 @@ class ClipGradFunctor {
   T max_;
 };

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class ClipKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -66,13 +66,13 @@ class ClipKernel : public framework::OpKernel<T> {
     T* out_data = out->mutable_data<T>(context.GetPlace());
     const T* x_data = x->data<T>();
     int64_t numel = x->numel();
-    Transform<Place> trans;
-    trans(context.device_context(), x_data, x_data + numel, out_data,
-          ClipFunctor<T>(min, max));
+    Transform<DeviceContext> trans;
+    trans(context.template device_context<DeviceContext>(), x_data,
+          x_data + numel, out_data, ClipFunctor<T>(min, max));
   }
 };

-template <typename Place, typename T>
+template <typename DeviceContext, typename T>
 class ClipGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -86,9 +86,9 @@ class ClipGradKernel : public framework::OpKernel<T> {
       auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
       const T* d_out_data = d_out->data<T>();
       const T* x_data = x->data<T>();
-      Transform<Place> trans;
-      trans(context.device_context(), d_out_data, d_out_data + numel, x_data,
-            d_x_data, ClipGradFunctor<T>(min, max));
+      Transform<DeviceContext> trans;
+      trans(context.template device_context<DeviceContext>(), d_out_data,
+            d_out_data + numel, x_data, d_x_data, ClipGradFunctor<T>(min, max));
     }
   }
 };
......
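clip_op follows the same recipe for platform::Transform: the functor objects are unchanged, but Transform is now keyed on DeviceContext and receives the typed device context. A condensed sketch with a hypothetical AbsFunctor (illustrative only; just the Transform plumbing mirrors this commit):

// Hypothetical element-wise functor and kernel; the functor is made up,
// only the device-context plumbing follows this change.
template <typename T>
struct AbsFunctor {
  HOSTDEVICE T operator()(const T& v) const {
    return v < static_cast<T>(0) ? -v : v;
  }
};

template <typename DeviceContext, typename T>
class AbsValueKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<framework::Tensor>("X");
    auto* out = context.Output<framework::Tensor>("Out");
    T* out_data = out->mutable_data<T>(context.GetPlace());
    const T* x_data = x->data<T>();
    int64_t numel = x->numel();
    // Transform<DeviceContext> takes the typed device context directly,
    // replacing Transform<Place> plus the untyped context.device_context().
    platform::Transform<DeviceContext> trans;
    trans(context.template device_context<DeviceContext>(), x_data,
          x_data + numel, out_data, AbsFunctor<T>());
  }
};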
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
#include "paddle/operators/compare_op.h" #include "paddle/operators/compare_op.h"
REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, GPU, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(greater_than, GPU, REGISTER_LOGICAL_KERNEL(greater_than, CUDA,
paddle::operators::GreaterThanFunctor); paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_KERNEL(greater_equal, GPU, REGISTER_LOGICAL_KERNEL(greater_equal, CUDA,
paddle::operators::GreaterEqualFunctor); paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
...@@ -59,7 +59,7 @@ struct EqualFunctor { ...@@ -59,7 +59,7 @@ struct EqualFunctor {
} }
}; };
template <typename Place, typename Functor> template <typename DeviceContext, typename Functor>
class CompareOpKernel class CompareOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> { : public framework::OpKernel<typename Functor::ELEM_TYPE> {
public: public:
...@@ -69,24 +69,23 @@ class CompareOpKernel ...@@ -69,24 +69,23 @@ class CompareOpKernel
auto* y = context.Input<framework::Tensor>("Y"); auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
Functor binary_func; Functor binary_func;
platform::Transform<Place> trans; platform::Transform<DeviceContext> trans;
trans(context.device_context(), x->data<T>(), x->data<T>() + x->numel(), trans(context.template device_context<DeviceContext>(), x->data<T>(),
y->data<T>(), out->mutable_data<bool>(context.GetPlace()), x->data<T>() + x->numel(), y->data<T>(),
binary_func); out->mutable_data<bool>(context.GetPlace()), binary_func);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define REGISTER_LOGICAL_KERNEL(op_type, dev, functor) \ #define REGISTER_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \ REGISTER_OP_##dev##_KERNEL( \
op_type, \ op_type, ::paddle::operators::CompareOpKernel< \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \ ::paddle::platform::dev##DeviceContext, functor<int>>, \
functor<int>>, \ ::paddle::operators::CompareOpKernel< \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \ ::paddle::platform::dev##DeviceContext, functor<int64_t>>, \
functor<int64_t>>, \ ::paddle::operators::CompareOpKernel< \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \ ::paddle::platform::dev##DeviceContext, functor<float>>, \
functor<float>>, \ ::paddle::operators::CompareOpKernel< \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \ ::paddle::platform::dev##DeviceContext, functor<double>>);
functor<double>>);
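Because REGISTER_LOGICAL_KERNEL now pastes dev into both the registration macro name and the device-context type, one call still covers all four element types for a device. Expanded by hand, the less_than registration from compare_op.cu above becomes roughly:

// Hand expansion of REGISTER_LOGICAL_KERNEL(less_than, CUDA,
// paddle::operators::LessThanFunctor), shown only to make the
// dev##DeviceContext token pasting concrete.
REGISTER_OP_CUDA_KERNEL(
    less_than,
    ::paddle::operators::CompareOpKernel<
        ::paddle::platform::CUDADeviceContext,
        paddle::operators::LessThanFunctor<int>>,
    ::paddle::operators::CompareOpKernel<
        ::paddle::platform::CUDADeviceContext,
        paddle::operators::LessThanFunctor<int64_t>>,
    ::paddle::operators::CompareOpKernel<
        ::paddle::platform::CUDADeviceContext,
        paddle::operators::LessThanFunctor<float>>,
    ::paddle::operators::CompareOpKernel<
        ::paddle::platform::CUDADeviceContext,
        paddle::operators::LessThanFunctor<double>>);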
...@@ -14,7 +14,8 @@ limitations under the License. */ ...@@ -14,7 +14,8 @@ limitations under the License. */
#include "paddle/operators/concat_op.h" #include "paddle/operators/concat_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(concat, REGISTER_OP_CUDA_KERNEL(
ops::ConcatKernel<paddle::platform::GPUPlace, float>); concat, ops::ConcatKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
concat_grad, ops::ConcatGradKernel<paddle::platform::GPUPlace, float>); concat_grad,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ConcatKernel : public framework::OpKernel<T> { class ConcatKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -43,7 +43,7 @@ class ConcatKernel : public framework::OpKernel<T> { ...@@ -43,7 +43,7 @@ class ConcatKernel : public framework::OpKernel<T> {
} }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ConcatGradKernel : public framework::OpKernel<T> { class ConcatGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
......
...@@ -57,18 +57,20 @@ REGISTER_OP(conv2d_cudnn, ops::ConvOp, ops::CudnnConv2DOpMaker, ...@@ -57,18 +57,20 @@ REGISTER_OP(conv2d_cudnn, ops::ConvOp, ops::CudnnConv2DOpMaker,
REGISTER_OP(conv3d_cudnn, ops::ConvOp, ops::CudnnConv3DOpMaker, REGISTER_OP(conv3d_cudnn, ops::ConvOp, ops::CudnnConv3DOpMaker,
conv3d_cudnn_grad, ops::ConvOpGrad); conv3d_cudnn_grad, ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv2d_cudnn, REGISTER_OP_CPU_KERNEL(
ops::GemmConvKernel<paddle::platform::CPUPlace, float>, conv2d_cudnn,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>); ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_cudnn_grad, conv2d_cudnn_grad,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>, ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>); ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(conv3d_cudnn, REGISTER_OP_CPU_KERNEL(
ops::GemmConvKernel<paddle::platform::CPUPlace, float>, conv3d_cudnn,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>); ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d_cudnn_grad, conv3d_cudnn_grad,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>, ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>); ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -118,7 +118,8 @@ class CudnnConvOpKernel : public framework::OpKernel<T> { ...@@ -118,7 +118,8 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
} }
// ------------------- cudnn conv algorithm --------------------- // ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo; cudnnConvolutionFwdAlgo_t algo;
auto handle = ctx.cuda_device_context().cudnn_handle(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
...@@ -238,7 +239,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> { ...@@ -238,7 +239,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
workspace_size_limit = user_workspace_size * 1024 * 1024; workspace_size_limit = user_workspace_size * 1024 * 1024;
} }
auto handle = ctx.cuda_device_context().cudnn_handle(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
if (input_grad) { if (input_grad) {
PADDLE_ENFORCE( PADDLE_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
...@@ -313,16 +315,16 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> { ...@@ -313,16 +315,16 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_GPU_KERNEL(conv2d_cudnn, REGISTER_OP_CUDA_KERNEL(conv2d_cudnn,
paddle::operators::CudnnConvOpKernel<float>, paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>); paddle::operators::CudnnConvOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv2d_cudnn_grad, REGISTER_OP_CUDA_KERNEL(conv2d_cudnn_grad,
paddle::operators::CudnnConvGradOpKernel<float>, paddle::operators::CudnnConvGradOpKernel<float>,
paddle::operators::CudnnConvGradOpKernel<double>); paddle::operators::CudnnConvGradOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_cudnn, REGISTER_OP_CUDA_KERNEL(conv3d_cudnn,
paddle::operators::CudnnConvOpKernel<float>, paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>); paddle::operators::CudnnConvOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_cudnn_grad, REGISTER_OP_CUDA_KERNEL(conv3d_cudnn_grad,
paddle::operators::CudnnConvGradOpKernel<float>, paddle::operators::CudnnConvGradOpKernel<float>,
paddle::operators::CudnnConvGradOpKernel<double>); paddle::operators::CudnnConvGradOpKernel<double>);
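For the cuDNN kernels the handle and stream now come from an explicitly requested CUDADeviceContext instead of the ctx.cuda_device_context() shorthand. A stripped-down sketch of that access pattern; the kernel name is made up and all descriptor setup and cuDNN calls are elided:

// Illustration only: how a cuDNN-backed kernel obtains its handle and stream
// after this change.
template <typename T>
class SomeCudnnKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto handle = dev_ctx.cudnn_handle();
    auto stream = dev_ctx.stream();
    // ... descriptor setup and the cuDNN convolution calls would go here,
    // issued on `handle`; any raw CUDA kernels are launched on `stream` ...
  }
};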
...@@ -235,16 +235,18 @@ namespace ops = paddle::operators; ...@@ -235,16 +235,18 @@ namespace ops = paddle::operators;
REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops::ConvOpGrad); ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>, conv2d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>); ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>, conv3d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>); ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -16,16 +16,18 @@ ...@@ -16,16 +16,18 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d, REGISTER_OP_CUDA_KERNEL(
ops::GemmConvKernel<paddle::platform::GPUPlace, float>, conv2d, ops::GemmConvKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>); ops::GemmConvKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>, conv2d_grad,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>); ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_GPU_KERNEL(conv3d, REGISTER_OP_CUDA_KERNEL(
ops::GemmConvKernel<paddle::platform::GPUPlace, float>, conv3d, ops::GemmConvKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>); ops::GemmConvKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>, conv3d_grad,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>); ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, double>);
...@@ -72,7 +72,7 @@ class ConvOpGrad : public framework::OperatorWithKernel { ...@@ -72,7 +72,7 @@ class ConvOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class GemmConvKernel : public framework::OpKernel<T> { class GemmConvKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -141,9 +141,10 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -141,9 +141,10 @@ class GemmConvKernel : public framework::OpKernel<T> {
int in_step = static_cast<int>(input->dims()[1]) / groups; int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups; int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<Place, T> vol2col; math::Vol2ColFunctor<DeviceContext, T> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
auto& dev_ctx = context.template device_context<DeviceContext>();
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
...@@ -157,27 +158,26 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -157,27 +158,26 @@ class GemmConvKernel : public framework::OpKernel<T> {
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) { } else if (data_dim == 2U) {
// im2col // im2col
im2col(context.device_context(), in_slice, dilations, strides, im2col(dev_ctx, in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0], std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]}, paddings[1]},
&col); &col);
} else if (data_dim == 3U) { } else if (data_dim == 3U) {
// vol2col // vol2col
vol2col(context.device_context(), in_slice, dilations, strides, vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col);
paddings, &col);
} }
// gemm // gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), filter_slice, false, math::matmul<DeviceContext, T>(dev_ctx, filter_slice, false, col_matrix,
col_matrix, false, T(1.0), &out_slice, T(0.0)); false, T(1.0), &out_slice, T(0.0));
} }
} }
} }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class GemmConvGradKernel : public framework::OpKernel<T> { class GemmConvGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -256,14 +256,15 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -256,14 +256,15 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
} }
math::SetConstant<Place, T> set_zero; math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), input_grad, static_cast<T>(0)); set_zero(dev_ctx, input_grad, static_cast<T>(0));
math::Col2VolFunctor<Place, T> col2vol; math::Col2VolFunctor<DeviceContext, T> col2vol;
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im; math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch = Tensor out_grad_batch =
...@@ -282,18 +283,17 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -282,18 +283,17 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix.ShareDataWith(in_grad_slice); col_matrix.ShareDataWith(in_grad_slice);
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
} }
math::matmul<Place, T>(context.device_context(), filter_slice, true, math::matmul<DeviceContext, T>(dev_ctx, filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix, out_grad_slice, false, T(1.0),
T(0.0)); &col_matrix, T(0.0));
if (is_expand && data_dim == 2U) { if (is_expand && data_dim == 2U) {
col2im(context.device_context(), col, dilations, strides, col2im(dev_ctx, col, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0], std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]}, paddings[1]},
&in_grad_slice); &in_grad_slice);
} else if (is_expand && data_dim == 3U) { } else if (is_expand && data_dim == 3U) {
col2vol(context.device_context(), col, dilations, strides, paddings, col2vol(dev_ctx, col, dilations, strides, paddings, &in_grad_slice);
&in_grad_slice);
} }
} }
} }
...@@ -303,9 +303,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -303,9 +303,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
filter_grad->mutable_data<T>(context.GetPlace()); filter_grad->mutable_data<T>(context.GetPlace());
Tensor filter_grad_ = *filter_grad; Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape); filter_grad_.Resize(filter_matrix_shape);
set_zero(context.device_context(), filter_grad, static_cast<T>(0)); set_zero(dev_ctx, filter_grad, static_cast<T>(0));
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
math::Vol2ColFunctor<Place, T> vol2col; math::Vol2ColFunctor<DeviceContext, T> vol2col;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch = Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape); output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
...@@ -321,21 +321,20 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -321,21 +321,20 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix.ShareDataWith(col); col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) { } else if (data_dim == 2U) {
im2col(context.device_context(), in_slice, dilations, strides, im2col(dev_ctx, in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0], std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]}, paddings[1]},
&col); &col);
} else if (data_dim == 3U) { } else if (data_dim == 3U) {
vol2col(context.device_context(), in_slice, dilations, strides, vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col);
paddings, &col);
} }
// gemm // gemm
Tensor filter_grad_slice = Tensor filter_grad_slice =
filter_grad_.Slice(g * out_step, (g + 1) * out_step); filter_grad_.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), out_grad_slice, math::matmul<DeviceContext, T>(dev_ctx, out_grad_slice, false,
false, col_matrix, true, T(1.0), col_matrix, true, T(1.0),
&filter_grad_slice, T(1.0)); &filter_grad_slice, T(1.0));
} }
} }
} }
......
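The pattern running through GemmConvKernel and GemmConvGradKernel is to fetch the typed device context once per Compute call and hand the same reference to every math functor (im2col, vol2col, col2im, col2vol, SetConstant, matmul) instead of repeating context.device_context() at each call site. A minimal sketch of that convention, using a hypothetical helper with shapes and the batching loop omitted:

// Hypothetical helper, not part of this commit: zero a tensor, then
// accumulate a GEMM result into it, reusing one typed device context.
template <typename DeviceContext, typename T>
void ZeroThenAccumulate(const framework::ExecutionContext& context,
                        const framework::Tensor& a,
                        const framework::Tensor& b,
                        framework::Tensor* out) {
  // One device context fetched up front and passed to every functor.
  auto& dev_ctx = context.template device_context<DeviceContext>();
  math::SetConstant<DeviceContext, T> set_zero;
  set_zero(dev_ctx, out, static_cast<T>(0));
  // out += a * b, mirroring the filter-gradient accumulation above.
  math::matmul<DeviceContext, T>(dev_ctx, a, false, b, false,
                                 static_cast<T>(1.0), out,
                                 static_cast<T>(1.0));
}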
...@@ -111,7 +111,8 @@ __global__ void ConvShiftDy(const T *x, const T *dout, int x_width, int y_width, ...@@ -111,7 +111,8 @@ __global__ void ConvShiftDy(const T *x, const T *dout, int x_width, int y_width,
} // namespace } // namespace
template <typename T> template <typename T>
class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> { class ConvShiftKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X"); const Tensor *X = context.Input<Tensor>("X");
...@@ -132,7 +133,8 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> { ...@@ -132,7 +133,8 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
dim3 grid_dim(num_x_blocks, batch_size); dim3 grid_dim(num_x_blocks, batch_size);
auto stream = context.cuda_device_context().stream(); auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
ConvShiftForward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>( ConvShiftForward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
x_data, y_data, x_width, y_width, y_half_width, batch_size, out_data); x_data, y_data, x_width, y_width, y_half_width, batch_size, out_data);
...@@ -140,7 +142,7 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> { ...@@ -140,7 +142,7 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
}; };
template <typename T> template <typename T>
class ConvShiftGradKernel<platform::GPUPlace, T> class ConvShiftGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
...@@ -159,8 +161,9 @@ class ConvShiftGradKernel<platform::GPUPlace, T> ...@@ -159,8 +161,9 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
int y_width = Y->dims()[1]; int y_width = Y->dims()[1];
int y_half_width = (y_width - 1) / 2; int y_half_width = (y_width - 1) / 2;
auto &device_ctx = context.cuda_device_context(); auto &device_ctx =
math::SetConstant<platform::GPUPlace, T> zero; context.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
const int x_per_block = 256; const int x_per_block = 256;
int num_x_blocks = DivUp(x_width, x_per_block); int num_x_blocks = DivUp(x_width, x_per_block);
...@@ -186,8 +189,9 @@ class ConvShiftGradKernel<platform::GPUPlace, T> ...@@ -186,8 +189,9 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv_shift, REGISTER_OP_CUDA_KERNEL(
ops::ConvShiftKernel<paddle::platform::GPUPlace, float>); conv_shift,
REGISTER_OP_GPU_KERNEL( ops::ConvShiftKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
conv_shift_grad, conv_shift_grad,
ops::ConvShiftGradKernel<paddle::platform::GPUPlace, float>); ops::ConvShiftGradKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -18,13 +18,13 @@ ...@@ -18,13 +18,13 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ConvShiftKernel : public framework::OpKernel<T> { class ConvShiftKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override; void Compute(const framework::ExecutionContext &context) const override;
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ConvShiftGradKernel : public framework::OpKernel<T> { class ConvShiftGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override; void Compute(const framework::ExecutionContext &context) const override;
......
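conv_shift keeps its kernels declared generically in the header and defines per-device partial specializations in the .cc and .cu files; after this change the specialization key is the device-context type rather than the place. The shape of that split, sketched with a made-up kernel name and the bodies elided:

// Illustration of the header-declaration / per-device-specialization layout.
// MyShiftKernel is not a real operator; only the template structure matters.
template <typename DeviceContext, typename T>
class MyShiftKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override;
};

// In the .cu file: the CUDA specialization keys on CUDADeviceContext and can
// take the stream straight from the typed device context.
template <typename T>
class MyShiftKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto stream =
        context.template device_context<platform::CUDADeviceContext>()
            .stream();
    // ... launch the device kernels on `stream` ...
  }
};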
...@@ -61,12 +61,13 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp, ...@@ -61,12 +61,13 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn, conv2d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>, ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>); ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad, conv2d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>, ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>); ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp, REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad, ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
...@@ -74,9 +75,10 @@ REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp, ...@@ -74,9 +75,10 @@ REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn, conv3d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>, ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>); ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn_grad, conv3d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>, ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>); ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
...@@ -83,7 +83,8 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> { ...@@ -83,7 +83,8 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
} }
// ------------------- cudnn conv algorithm --------------------- // ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionBwdDataAlgo_t algo; cudnnConvolutionBwdDataAlgo_t algo;
auto handle = ctx.cuda_device_context().cudnn_handle(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
// Get the algorithm // Get the algorithm
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc, handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
...@@ -165,7 +166,8 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -165,7 +166,8 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
workspace_size_limit = user_workspace_size * 1024 * 1024; workspace_size_limit = user_workspace_size * 1024 * 1024;
} }
auto handle = ctx.cuda_device_context().cudnn_handle(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
if (input_grad) { if (input_grad) {
// choose backward algorithm for data // choose backward algorithm for data
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
...@@ -234,16 +236,16 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -234,16 +236,16 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn, REGISTER_OP_CUDA_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>, ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>); ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad, REGISTER_OP_CUDA_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>, ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>); ops::CudnnConvTransposeGradOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn, REGISTER_OP_CUDA_KERNEL(conv3d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>, ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>); ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad, REGISTER_OP_CUDA_KERNEL(conv3d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>, ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>); ops::CudnnConvTransposeGradOpKernel<double>);
...@@ -197,21 +197,23 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker, ...@@ -197,21 +197,23 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_transpose, conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>, ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>); ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_transpose_grad, conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>, ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>); ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker, REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
conv3d_transpose_grad, ops::ConvTransposeOpGrad); conv3d_transpose_grad, ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d_transpose, conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>, ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>); ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv3d_transpose_grad, conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>, ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>); ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
...@@ -16,20 +16,24 @@ ...@@ -16,20 +16,24 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
conv2d_transpose, conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>, ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>); ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
conv2d_transpose_grad, conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>, ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>); float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
conv3d_transpose, conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>, ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>); ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
conv3d_transpose_grad, conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>, ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>); float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
double>);
...@@ -52,7 +52,7 @@ class ConvTransposeOpGrad : public framework::OperatorWithKernel { ...@@ -52,7 +52,7 @@ class ConvTransposeOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override; void InferShape(framework::InferShapeContext* ctx) const override;
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class GemmConvTransposeKernel : public framework::OpKernel<T> { class GemmConvTransposeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -109,11 +109,12 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -109,11 +109,12 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
math::SetConstant<Place, T> set_zero; math::SetConstant<DeviceContext, T> set_zero;
set_zero(context.device_context(), output, static_cast<T>(0)); auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, output, static_cast<T>(0));
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im; math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
math::Col2VolFunctor<Place, T> col2vol; math::Col2VolFunctor<DeviceContext, T> col2vol;
std::vector<int> dilations({1, 1, 1}); std::vector<int> dilations({1, 1, 1});
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
...@@ -127,29 +128,27 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -127,29 +128,27 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// col_matrix = filter * input_batch // col_matrix = filter * input_batch
// of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w) // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
math::matmul<Place, T>(context.device_context(), filter, true, math::matmul<DeviceContext, T>(dev_ctx, filter, true, input_batch, false,
input_batch, false, static_cast<T>(1.0), static_cast<T>(1.0), &col_matrix,
&col_matrix, static_cast<T>(0.0)); static_cast<T>(0.0));
if (data_dim == 2U) { if (data_dim == 2U) {
// col2im: col_matrix -> dy // col2im: col_matrix -> dy
// from (c * k_h * k_w, h * w) to (c, o_h, o_w) // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
col2im(context.device_context(), col, col2im(dev_ctx, col, std::vector<int>{dilations[0], dilations[1]},
std::vector<int>{dilations[0], dilations[1]}, strides, strides, std::vector<int>{paddings[0], paddings[1], paddings[0],
std::vector<int>{paddings[0], paddings[1], paddings[0], paddings[1]},
paddings[1]},
&output_batch); &output_batch);
} else if (data_dim == 3U) { } else if (data_dim == 3U) {
// col2vol: col_matrix -> dy // col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
col2vol(context.device_context(), col, dilations, strides, paddings, col2vol(dev_ctx, col, dilations, strides, paddings, &output_batch);
&output_batch);
} }
} }
} }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class GemmConvTransposeGradKernel : public framework::OpKernel<T> { class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -206,6 +205,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -206,6 +205,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// convolution transpose grad on input: // convolution transpose grad on input:
// im2col + gemm (similar to conv-forward) // im2col + gemm (similar to conv-forward)
// input need to compute gradient // input need to compute gradient
auto& dev_ctx = context.template device_context<DeviceContext>();
if (input_grad || filter_grad) { if (input_grad || filter_grad) {
Tensor col; Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace()); col.mutable_data<T>(col_shape, context.GetPlace());
...@@ -217,19 +217,19 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -217,19 +217,19 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
Tensor filter_grad_; Tensor filter_grad_;
math::SetConstant<Place, T> set_zero; math::SetConstant<DeviceContext, T> set_zero;
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
math::Vol2ColFunctor<Place, T> vol2col; math::Vol2ColFunctor<DeviceContext, T> vol2col;
std::vector<int> dilations({1, 1, 1}); std::vector<int> dilations({1, 1, 1});
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), input_grad, static_cast<T>(0)); set_zero(dev_ctx, input_grad, static_cast<T>(0));
} }
if (filter_grad) { // filter size (m, c, k_h, k_w) if (filter_grad) { // filter size (m, c, k_h, k_w)
filter_grad->mutable_data<T>(context.GetPlace()); filter_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), filter_grad, static_cast<T>(0)); set_zero(dev_ctx, filter_grad, static_cast<T>(0));
filter_grad_ = *filter_grad; filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape); filter_grad_.Resize(filter_matrix_shape);
} }
...@@ -242,7 +242,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -242,7 +242,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
if (data_dim == 2U) { if (data_dim == 2U) {
// im2col: dy -> col matrix // im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w) // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
im2col(context.device_context(), output_grad_batch, im2col(dev_ctx, output_grad_batch,
std::vector<int>{dilations[0], dilations[1]}, strides, std::vector<int>{dilations[0], dilations[1]}, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0], std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]}, paddings[1]},
...@@ -250,8 +250,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -250,8 +250,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
} else if (data_dim == 3U) { } else if (data_dim == 3U) {
// vol2col: dy -> col_matrix // vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
vol2col(context.device_context(), output_grad_batch, dilations, vol2col(dev_ctx, output_grad_batch, dilations, strides, paddings,
strides, paddings, &col); &col);
} }
if (input_grad) { if (input_grad) {
...@@ -263,9 +263,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -263,9 +263,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or // or
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m, // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
// d, h, w) // d, h, w)
math::matmul<Place, T>(context.device_context(), filter, false, math::matmul<DeviceContext, T>(
col_matrix, false, static_cast<T>(1.0), dev_ctx, filter, false, col_matrix, false, static_cast<T>(1.0),
&input_grad_batch, static_cast<T>(0.0)); &input_grad_batch, static_cast<T>(0.0));
} }
if (filter_grad) { if (filter_grad) {
// input batch // input batch
...@@ -275,9 +275,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -275,9 +275,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or // or
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d * // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
// k_h * k_w) // k_h * k_w)
math::matmul<Place, T>(context.device_context(), in_batch, false, math::matmul<DeviceContext, T>(dev_ctx, in_batch, false, col_matrix,
col_matrix, true, static_cast<T>(1.0), true, static_cast<T>(1.0),
&filter_grad_, static_cast<T>(1.0)); &filter_grad_, static_cast<T>(1.0));
} }
} }
} }
......
...@@ -155,7 +155,8 @@ class CosSimOpGrad : public framework::OperatorWithKernel { ...@@ -155,7 +155,8 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(cos_sim, ops::CosSimOp, ops::CosSimOpMaker, cos_sim_grad, REGISTER_OP(cos_sim, ops::CosSimOp, ops::CosSimOpMaker, cos_sim_grad,
ops::CosSimOpGrad); ops::CosSimOpGrad);
REGISTER_OP_CPU_KERNEL(cos_sim,
ops::CosSimKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
cos_sim_grad, ops::CosSimGradKernel<paddle::platform::CPUPlace, float>); cos_sim, ops::CosSimKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
cos_sim_grad,
ops::CosSimGradKernel<paddle::platform::CPUDeviceContext, float>);
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#include "paddle/operators/cos_sim_op.h" #include "paddle/operators/cos_sim_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(cos_sim, REGISTER_OP_CUDA_KERNEL(
ops::CosSimKernel<paddle::platform::GPUPlace, float>); cos_sim, ops::CosSimKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
cos_sim_grad, ops::CosSimGradKernel<paddle::platform::GPUPlace, float>); cos_sim_grad,
ops::CosSimGradKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -27,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -27,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename DeviceContext, typename T>
class CosSimKernel : public framework::OpKernel<T> { class CosSimKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -51,7 +51,8 @@ class CosSimKernel : public framework::OpKernel<T> { ...@@ -51,7 +51,8 @@ class CosSimKernel : public framework::OpKernel<T> {
auto y_norm = EigenVector<T>::Flatten(*out_y_norm); auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
// compute // compute
auto place = context.GetEigenDevice<Place>(); auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto row_along = Eigen::array<int, 1>({{1}}); auto row_along = Eigen::array<int, 1>({{1}});
x_norm.device(place) = x.square().sum(row_along).sqrt(); x_norm.device(place) = x.square().sum(row_along).sqrt();
y_norm.device(place) = y.square().sum(row_along).sqrt(); y_norm.device(place) = y.square().sum(row_along).sqrt();
...@@ -66,7 +67,7 @@ class CosSimKernel : public framework::OpKernel<T> { ...@@ -66,7 +67,7 @@ class CosSimKernel : public framework::OpKernel<T> {
} }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class CosSimGradKernel : public framework::OpKernel<T> { class CosSimGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -96,7 +97,8 @@ class CosSimGradKernel : public framework::OpKernel<T> { ...@@ -96,7 +97,8 @@ class CosSimGradKernel : public framework::OpKernel<T> {
auto z_bcast = z.broadcast(bcast_cols); auto z_bcast = z.broadcast(bcast_cols);
auto dz_bcast = dz.broadcast(bcast_cols); auto dz_bcast = dz.broadcast(bcast_cols);
auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols); auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols);
auto place = context.GetEigenDevice<Place>(); auto& place =
*context.template device_context<DeviceContext>().eigen_device();
if (rows_x == rows_y) { if (rows_x == rows_y) {
auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols); auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols);
auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols); auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols);
......
...@@ -135,5 +135,6 @@ namespace ops = paddle::operators; ...@@ -135,5 +135,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(crf_decoding, ops::CRFDecodingOp, REGISTER_OP_WITHOUT_GRADIENT(crf_decoding, ops::CRFDecodingOp,
ops::CRFDecodingOpMaker); ops::CRFDecodingOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
crf_decoding, ops::CRFDecodingOpKernel<paddle::platform::CPUPlace, float>, crf_decoding,
ops::CRFDecodingOpKernel<paddle::platform::CPUPlace, double>); ops::CRFDecodingOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CRFDecodingOpKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -24,7 +24,7 @@ using framework::LoDTensor; ...@@ -24,7 +24,7 @@ using framework::LoDTensor;
using framework::LoD; using framework::LoD;
using framework::Tensor; using framework::Tensor;
template <typename Place, typename T> template <typename DeviceContext, typename T>
class CRFDecodingOpKernel : public framework::OpKernel<T> { class CRFDecodingOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -44,8 +44,8 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> { ...@@ -44,8 +44,8 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
const size_t seq_num = lod[level].size() - 1; const size_t seq_num = lod[level].size() - 1;
int64_t* path = decoded_path->mutable_data<int64_t>(platform::CPUPlace()); int64_t* path = decoded_path->mutable_data<int64_t>(platform::CPUPlace());
math::SetConstant<platform::CPUPlace, int64_t>()(ctx.device_context(), math::SetConstant<DeviceContext, int64_t>()(
decoded_path, 0); ctx.template device_context<DeviceContext>(), decoded_path, 0);
for (size_t i = 0; i < seq_num; ++i) { for (size_t i = 0; i < seq_num; ++i) {
int start_pos = static_cast<int>(lod[level][i]); int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]); int end_pos = static_cast<int>(lod[level][i + 1]);
......
...@@ -133,5 +133,5 @@ class CropOpGrad : public framework::OperatorWithKernel { ...@@ -133,5 +133,5 @@ class CropOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(crop, ops::CropOp, ops::CropOpMaker, crop_grad, ops::CropOpGrad); REGISTER_OP(crop, ops::CropOp, ops::CropOpMaker, crop_grad, ops::CropOpGrad);
REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel<float>); REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_CPU_KERNEL(crop_grad, REGISTER_OP_CPU_KERNEL(
ops::CropGradKernel<paddle::platform::CPUPlace, float>); crop_grad, ops::CropGradKernel<paddle::platform::CPUDeviceContext, float>);
...@@ -16,6 +16,6 @@ ...@@ -16,6 +16,6 @@
#include "paddle/operators/crop_op.h" #include "paddle/operators/crop_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(crop, ops::CropKernel<float>); REGISTER_OP_CUDA_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_GPU_KERNEL(crop_grad, REGISTER_OP_CUDA_KERNEL(
ops::CropGradKernel<paddle::platform::GPUPlace, float>); crop_grad, ops::CropGradKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -49,7 +49,7 @@ class CropKernel : public framework::OpKernel<T> { ...@@ -49,7 +49,7 @@ class CropKernel : public framework::OpKernel<T> {
} }
}; };
template <typename Place, typename T, size_t D> template <typename DeviceContext, typename T, size_t D>
void CropGradFunction(const framework::ExecutionContext& context) { void CropGradFunction(const framework::ExecutionContext& context) {
auto* d_x = context.Output<Tensor>(framework::GradVarName("X")); auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
if (d_x != nullptr) { if (d_x != nullptr) {
...@@ -63,12 +63,13 @@ void CropGradFunction(const framework::ExecutionContext& context) { ...@@ -63,12 +63,13 @@ void CropGradFunction(const framework::ExecutionContext& context) {
} }
auto d_x_tensor = EigenTensor<T, D>::From(*d_x); auto d_x_tensor = EigenTensor<T, D>::From(*d_x);
auto d_out_tensor = EigenTensor<T, D>::From(*d_out); auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
d_x_tensor.device(context.GetEigenDevice<Place>()) = d_x_tensor.device(
*context.template device_context<DeviceContext>().eigen_device()) =
d_out_tensor.pad(paddings, 0); d_out_tensor.pad(paddings, 0);
} }
} }
template <typename Place, typename T> template <typename DeviceContext, typename T>
class CropGradKernel : public framework::OpKernel<T> { class CropGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -76,22 +77,22 @@ class CropGradKernel : public framework::OpKernel<T> { ...@@ -76,22 +77,22 @@ class CropGradKernel : public framework::OpKernel<T> {
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size(); context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
switch (rank) { switch (rank) {
case 1: case 1:
CropGradFunction<Place, T, 1>(context); CropGradFunction<DeviceContext, T, 1>(context);
break; break;
case 2: case 2:
CropGradFunction<Place, T, 2>(context); CropGradFunction<DeviceContext, T, 2>(context);
break; break;
case 3: case 3:
CropGradFunction<Place, T, 3>(context); CropGradFunction<DeviceContext, T, 3>(context);
break; break;
case 4: case 4:
CropGradFunction<Place, T, 4>(context); CropGradFunction<DeviceContext, T, 4>(context);
break; break;
case 5: case 5:
CropGradFunction<Place, T, 5>(context); CropGradFunction<DeviceContext, T, 5>(context);
break; break;
case 6: case 6:
CropGradFunction<Place, T, 6>(context); CropGradFunction<DeviceContext, T, 6>(context);
break; break;
default: default:
PADDLE_THROW( PADDLE_THROW(
......
...@@ -53,8 +53,9 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> { ...@@ -53,8 +53,9 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
Tensor* y = ctx.Output<Tensor>("Y"); Tensor* y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace()); y->mutable_data<T>(ctx.GetPlace());
math::CrossEntropyFunctor<platform::GPUPlace, T>()( math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
ctx.device_context(), y, x, label, ctx.Attr<bool>("soft_label")); ctx.template device_context<platform::CUDADeviceContext>(), y, x, label,
ctx.Attr<bool>("soft_label"));
} }
}; };
...@@ -80,15 +81,17 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> { ...@@ -80,15 +81,17 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
int block = 512; int block = 512;
int grid = (batch_size * class_num + block - 1) / block; int grid = (batch_size * class_num + block - 1) / block;
auto stream = ctx.cuda_device_context().stream();
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto stream = dev_ctx.stream();
if (ctx.Attr<bool>("soft_label")) { if (ctx.Attr<bool>("soft_label")) {
auto* label_data = label->data<T>(); auto* label_data = label->data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>( SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
dx_data, dy_data, x_data, label_data, batch_size, class_num); dx_data, dy_data, x_data, label_data, batch_size, class_num);
} else { } else {
math::SetConstant<platform::GPUPlace, T> functor; math::SetConstant<platform::CUDADeviceContext, T> functor;
functor(ctx.device_context(), dx, 0); functor(dev_ctx, dx, 0);
auto* label_data = label->data<int64_t>(); auto* label_data = label->data<int64_t>();
grid = (batch_size + block - 1) / block; grid = (batch_size + block - 1) / block;
CrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>( CrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
...@@ -101,8 +104,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> { ...@@ -101,8 +104,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>, REGISTER_OP_CUDA_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>,
ops::CrossEntropyOpCUDAKernel<double>); ops::CrossEntropyOpCUDAKernel<double>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad, REGISTER_OP_CUDA_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpCUDAKernel<float>, ops::CrossEntropyGradientOpCUDAKernel<float>,
ops::CrossEntropyGradientOpCUDAKernel<double>); ops::CrossEntropyGradientOpCUDAKernel<double>);
...@@ -37,8 +37,9 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> { ...@@ -37,8 +37,9 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
Tensor* y = ctx.Output<Tensor>("Y"); Tensor* y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace()); y->mutable_data<T>(ctx.GetPlace());
math::CrossEntropyFunctor<platform::CPUPlace, T>()( math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
ctx.device_context(), y, x, labels, ctx.Attr<bool>("soft_label")); ctx.template device_context<platform::CPUDeviceContext>(), y, x, labels,
ctx.Attr<bool>("soft_label"));
} }
}; };
...@@ -61,7 +62,8 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> { ...@@ -61,7 +62,8 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
auto lbl_mat = EigenMatrix<T>::From(*label); auto lbl_mat = EigenMatrix<T>::From(*label);
auto dx_mat = EigenMatrix<T>::From(*dx); auto dx_mat = EigenMatrix<T>::From(*dx);
dx_mat.device(ctx.GetEigenDevice<platform::CPUPlace>()) = dx_mat.device(*ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device()) =
-(lbl_mat * -(lbl_mat *
dy_mat.broadcast(Eigen::DSizes<int64_t, 2>(1, class_num)) / x_mat); dy_mat.broadcast(Eigen::DSizes<int64_t, 2>(1, class_num)) / x_mat);
} else { } else {
...@@ -70,8 +72,8 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> { ...@@ -70,8 +72,8 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const int64_t* label_data = label->data<int64_t>(); const int64_t* label_data = label->data<int64_t>();
math::SetConstant<platform::CPUPlace, T> functor; math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(ctx.device_context(), dx, 0); functor(ctx.template device_context<platform::CPUDeviceContext>(), dx, 0);
for (int64_t i = 0; i < batch_size; ++i) { for (int64_t i = 0; i < batch_size; ++i) {
PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
......
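The cross-entropy kernels above show the core pattern of this refactor: an OpKernel is templated on a DeviceContext type rather than a Place, fetches the typed context with ctx.template device_context<DeviceContext>(), and hands it directly to the math functors. A minimal sketch of that pattern, using a hypothetical ExampleZeroFillKernel rather than any real Paddle operator:

template <typename DeviceContext, typename T>
class ExampleZeroFillKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    // The typed context replaces the old ctx.device_context() + Place pair.
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    // Math functors are now specialized on the DeviceContext, not on a Place.
    math::SetConstant<DeviceContext, T> set_zero;
    set_zero(dev_ctx, out, static_cast<T>(0));
  }
};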
...@@ -99,4 +99,4 @@ REGISTER_OP_WITHOUT_GRADIENT(decayed_adagrad, ops::DecayedAdagradOp, ...@@ -99,4 +99,4 @@ REGISTER_OP_WITHOUT_GRADIENT(decayed_adagrad, ops::DecayedAdagradOp,
ops::DecayedAdagradOpMaker); ops::DecayedAdagradOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
decayed_adagrad, decayed_adagrad,
ops::DecayedAdagradOpKernel<paddle::platform::CPUPlace, float>); ops::DecayedAdagradOpKernel<paddle::platform::CPUDeviceContext, float>);
...@@ -16,6 +16,6 @@ ...@@ -16,6 +16,6 @@
#include "paddle/operators/decayed_adagrad_op.h" #include "paddle/operators/decayed_adagrad_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
decayed_adagrad, decayed_adagrad,
ops::DecayedAdagradOpKernel<paddle::platform::GPUPlace, float>); ops::DecayedAdagradOpKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename DeviceContext, typename T>
class DecayedAdagradOpKernel : public framework::OpKernel<T> { class DecayedAdagradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -43,7 +43,7 @@ class DecayedAdagradOpKernel : public framework::OpKernel<T> { ...@@ -43,7 +43,7 @@ class DecayedAdagradOpKernel : public framework::OpKernel<T> {
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor); auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor); auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto place = ctx.GetEigenDevice<Place>(); auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
moment_out.device(place) = decay * moment + (1 - decay) * grad * grad; moment_out.device(place) = decay * moment + (1 - decay) * grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel()); Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
......
...@@ -100,6 +100,8 @@ namespace ops = paddle::operators; ...@@ -100,6 +100,8 @@ namespace ops = paddle::operators;
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad, REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
ops::DropoutOpGrad<float>); ops::DropoutOpGrad<float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
dropout, ops::CPUDropoutKernel<paddle::platform::CPUPlace, float, float>); dropout,
ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
dropout_grad, ops::DropoutGradKernel<paddle::platform::CPUPlace, float>); dropout_grad,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
...@@ -58,7 +58,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> { ...@@ -58,7 +58,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
auto X = EigenMatrix<T>::Reshape(*x, 1); auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1); auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto place = context.GetEigenDevice<Place>(); auto& place = *context.template device_context<Place>().eigen_device();
if (!context.Attr<bool>("is_test")) { if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask"); auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace()); auto* mask_data = mask->mutable_data<T>(context.GetPlace());
...@@ -80,7 +80,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> { ...@@ -80,7 +80,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float, float>); dropout,
REGISTER_OP_GPU_KERNEL( ops::GPUDropoutKernel<paddle::platform::CUDADeviceContext, float, float>);
dropout_grad, ops::DropoutGradKernel<paddle::platform::GPUPlace, float>); REGISTER_OP_CUDA_KERNEL(
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T, typename AttrType> template <typename DeviceContext, typename T, typename AttrType>
class CPUDropoutKernel : public framework::OpKernel<T> { class CPUDropoutKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -55,13 +55,14 @@ class CPUDropoutKernel : public framework::OpKernel<T> { ...@@ -55,13 +55,14 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
} else { } else {
auto X = EigenMatrix<T>::Reshape(*x, 1); auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1); auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto place = context.GetEigenDevice<Place>(); auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Y.device(place) = X * dropout_prob; Y.device(place) = X * dropout_prob;
} }
} }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class DropoutGradKernel : public framework::OpKernel<T> { class DropoutGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -77,7 +78,8 @@ class DropoutGradKernel : public framework::OpKernel<T> { ...@@ -77,7 +78,8 @@ class DropoutGradKernel : public framework::OpKernel<T> {
auto dX = EigenMatrix<T>::Reshape(*grad_x, 1); auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
auto dY = EigenMatrix<T>::Reshape(*grad_y, 1); auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);
auto place = context.GetEigenDevice<Place>(); auto& place =
*context.template device_context<DeviceContext>().eigen_device();
dX.device(place) = dY * M; dX.device(place) = dY * M;
} }
}; };
......
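The dropout kernels illustrate the second recurring change: ctx.GetEigenDevice<Place>() is gone, and the Eigen device is reached through the typed context's eigen_device() accessor instead. Sketched in isolation (X, Y and dropout_prob as in the CPU kernel above; only the device lookup differs from the old code):

// before: auto place = context.GetEigenDevice<Place>();
auto& place =
    *context.template device_context<DeviceContext>().eigen_device();
Y.device(place) = X * dropout_prob;  // the Eigen expressions themselves are unchanged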
...@@ -34,13 +34,13 @@ REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker, ...@@ -34,13 +34,13 @@ REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker,
elementwise_add_grad, ops::ElementwiseOpGrad); elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_add, elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, float>, ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, double>, ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, int>, ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, int64_t>); ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_add_grad, elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, float>, ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, double>, ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, int>, ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, int64_t>); ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -17,15 +17,16 @@ ...@@ -17,15 +17,16 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_add, elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, float>, ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, double>, ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, int>, ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, int64_t>); ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad, elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, float>, ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, double>, ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, int>, ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, int64_t>); ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
...@@ -24,7 +24,7 @@ struct AddFunctor { ...@@ -24,7 +24,7 @@ struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; } inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ElementwiseAddKernel : public framework::OpKernel<T> { class ElementwiseAddKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -34,8 +34,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> { ...@@ -34,8 +34,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto* y = ctx.Input<Tensor>("Y"); auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out"); auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
TransformFunctor<AddFunctor<T>, T, Place> functor( TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
x, y, z, ctx.device_context(), AddFunctor<T>()); x, y, z, ctx.template device_context<DeviceContext>(), AddFunctor<T>());
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims = y->dims(); auto y_dims = y->dims();
...@@ -137,11 +137,11 @@ struct ElementwiseAddBroadCast2GradFunctor { ...@@ -137,11 +137,11 @@ struct ElementwiseAddBroadCast2GradFunctor {
} }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ElementwiseAddGradKernel : public framework::OpKernel<T> { class ElementwiseAddGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseAddGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseAddGradFunctor<T>,
ElementwiseAddOneGradFunctor<T>, ElementwiseAddOneGradFunctor<T>,
ElementwiseAddBroadCastGradFunctor<T>, ElementwiseAddBroadCastGradFunctor<T>,
ElementwiseAddBroadCast2GradFunctor<T>>(ctx); ElementwiseAddBroadCast2GradFunctor<T>>(ctx);
......
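The registration blocks change in lockstep: each kernel is registered against a concrete DeviceContext type, and GPU registrations use the renamed REGISTER_OP_CUDA_KERNEL macro. A hedged sketch for a hypothetical my_op with float and double kernels:

REGISTER_OP_CPU_KERNEL(
    my_op, ops::MyOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MyOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    my_op, ops::MyOpKernel<paddle::platform::CUDADeviceContext, float>,
    ops::MyOpKernel<paddle::platform::CUDADeviceContext, double>);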
...@@ -35,13 +35,13 @@ REGISTER_OP(elementwise_div, ops::ElementwiseOp, ops::ElementwiseDivOpMaker, ...@@ -35,13 +35,13 @@ REGISTER_OP(elementwise_div, ops::ElementwiseOp, ops::ElementwiseDivOpMaker,
elementwise_div_grad, ops::ElementwiseOpGrad); elementwise_div_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_div, elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, float>, ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, double>, ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, int>, ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, int64_t>); ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_div_grad, elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, float>, ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, double>, ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, int>, ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, int64_t>); ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -17,15 +17,16 @@ ...@@ -17,15 +17,16 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_div, elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, float>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, double>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, int>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, int64_t>); ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad, elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, float>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, double>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, int>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, int64_t>); ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
...@@ -19,11 +19,11 @@ ...@@ -19,11 +19,11 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ElementwiseDivKernel : public framework::OpKernel<T> { class ElementwiseDivKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenDivFunctor, Place, T>(ctx); ElementwiseCompute<EigenDivFunctor, DeviceContext, T>(ctx);
} }
}; };
...@@ -102,11 +102,11 @@ struct ElementwiseDivBroadCast2GradFunctor { ...@@ -102,11 +102,11 @@ struct ElementwiseDivBroadCast2GradFunctor {
} }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ElementwiseDivGradKernel : public framework::OpKernel<T> { class ElementwiseDivGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseDivGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseDivGradFunctor<T>,
ElementwiseDivGradFunctor<T>, ElementwiseDivGradFunctor<T>,
ElementwiseDivBroadCastGradFunctor<T>, ElementwiseDivBroadCastGradFunctor<T>,
ElementwiseDivBroadCast2GradFunctor<T>>(ctx); ElementwiseDivBroadCast2GradFunctor<T>>(ctx);
......
...@@ -36,13 +36,13 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker, ...@@ -36,13 +36,13 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker,
elementwise_mul_grad, ops::ElementwiseOpGrad); elementwise_mul_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_mul, elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>, ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>, ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, int>, ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, int64_t>); ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad, elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>, ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>, ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, int>, ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, int64_t>); ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -17,15 +17,16 @@ ...@@ -17,15 +17,16 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_mul, elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>, ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, double>, ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, int>, ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, int64_t>); ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad, elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>, ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, double>, ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, int>, ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, int64_t>); ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ElementwiseMulKernel : public framework::OpKernel<T> { class ElementwiseMulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenMulFunctor, Place, T>(ctx); ElementwiseCompute<EigenMulFunctor, DeviceContext, T>(ctx);
} }
}; };
...@@ -101,11 +101,11 @@ struct ElementwiseMulBroadCast2GradFunctor { ...@@ -101,11 +101,11 @@ struct ElementwiseMulBroadCast2GradFunctor {
} }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ElementwiseMulGradKernel : public framework::OpKernel<T> { class ElementwiseMulGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseMulGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseMulGradFunctor<T>,
ElementwiseMulGradFunctor<T>, ElementwiseMulGradFunctor<T>,
ElementwiseMulBroadCastGradFunctor<T>, ElementwiseMulBroadCastGradFunctor<T>,
ElementwiseMulBroadCast2GradFunctor<T>>(ctx); ElementwiseMulBroadCast2GradFunctor<T>>(ctx);
......
...@@ -59,17 +59,17 @@ inline void get_mid_dims(const framework::DDim& x_dims, ...@@ -59,17 +59,17 @@ inline void get_mid_dims(const framework::DDim& x_dims,
} }
} }
template <typename T, typename Place> template <typename T, typename DeviceContext>
class RowwiseTransformIterator; class RowwiseTransformIterator;
template <typename T, typename Place> template <typename T, typename DeviceContext>
class MidWiseTransformIterator; class MidWiseTransformIterator;
template <typename T> template <typename T>
class RowwiseTransformIterator<T, platform::CPUPlace> { class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
public: public:
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {} RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() { RowwiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
++i_; ++i_;
if (UNLIKELY(i_ == n_)) { if (UNLIKELY(i_ == n_)) {
i_ = 0; i_ = 0;
...@@ -77,13 +77,13 @@ class RowwiseTransformIterator<T, platform::CPUPlace> { ...@@ -77,13 +77,13 @@ class RowwiseTransformIterator<T, platform::CPUPlace> {
return *this; return *this;
} }
bool operator==( bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const { rhs) const {
return (ptr_ + i_) == &(*rhs); return (ptr_ + i_) == &(*rhs);
} }
bool operator!=( bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const { rhs) const {
return (ptr_ + i_) != &(*rhs); return (ptr_ + i_) != &(*rhs);
} }
...@@ -96,12 +96,12 @@ class RowwiseTransformIterator<T, platform::CPUPlace> { ...@@ -96,12 +96,12 @@ class RowwiseTransformIterator<T, platform::CPUPlace> {
}; };
template <typename T> template <typename T>
class MidWiseTransformIterator<T, platform::CPUPlace> { class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
public: public:
MidWiseTransformIterator(const T* ptr, int n, int post) MidWiseTransformIterator(const T* ptr, int n, int post)
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
MidWiseTransformIterator<T, platform::CPUPlace>& operator++() { MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
++j_; ++j_;
i_ = j_ / post_; i_ = j_ / post_;
if (UNLIKELY(i_ == n_)) { if (UNLIKELY(i_ == n_)) {
...@@ -111,13 +111,13 @@ class MidWiseTransformIterator<T, platform::CPUPlace> { ...@@ -111,13 +111,13 @@ class MidWiseTransformIterator<T, platform::CPUPlace> {
return *this; return *this;
} }
bool operator==( bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const { rhs) const {
return (ptr_ + i_) == &(*rhs); return (ptr_ + i_) == &(*rhs);
} }
bool operator!=( bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const { rhs) const {
return (ptr_ + i_) != &(*rhs); return (ptr_ + i_) != &(*rhs);
} }
...@@ -133,12 +133,12 @@ class MidWiseTransformIterator<T, platform::CPUPlace> { ...@@ -133,12 +133,12 @@ class MidWiseTransformIterator<T, platform::CPUPlace> {
#ifdef __NVCC__ #ifdef __NVCC__
template <typename T> template <typename T>
class RowwiseTransformIterator<T, platform::GPUPlace> class RowwiseTransformIterator<T, platform::CUDADeviceContext>
: public thrust::iterator_adaptor< : public thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> { RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
public: public:
typedef thrust::iterator_adaptor< typedef thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
super_t; super_t;
HOSTDEVICE RowwiseTransformIterator(const T* x, int n) HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
: super_t(x), begin_(x), n_(n){}; : super_t(x), begin_(x), n_(n){};
...@@ -153,12 +153,12 @@ class RowwiseTransformIterator<T, platform::GPUPlace> ...@@ -153,12 +153,12 @@ class RowwiseTransformIterator<T, platform::GPUPlace>
}; };
template <typename T> template <typename T>
class MidWiseTransformIterator<T, platform::GPUPlace> class MidWiseTransformIterator<T, platform::CUDADeviceContext>
: public thrust::iterator_adaptor< : public thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> { MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
public: public:
typedef thrust::iterator_adaptor< typedef thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
super_t; super_t;
HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post) HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
: super_t(x), begin_(x), n_(n), post_(post){}; : super_t(x), begin_(x), n_(n), post_(post){};
...@@ -174,12 +174,11 @@ class MidWiseTransformIterator<T, platform::GPUPlace> ...@@ -174,12 +174,11 @@ class MidWiseTransformIterator<T, platform::GPUPlace>
}; };
#endif #endif
template <typename Functor, typename T, typename Place> template <typename Functor, typename T, typename DeviceContext>
class TransformFunctor { class TransformFunctor {
public: public:
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z, const platform::DeviceContext& ctx, framework::Tensor* z, const DeviceContext& ctx, Functor func)
Functor func)
: x_(x->data<T>()), : x_(x->data<T>()),
y_(y->data<T>()), y_(y->data<T>()),
z_(z->mutable_data<T>(ctx.GetPlace())), z_(z->mutable_data<T>(ctx.GetPlace())),
...@@ -188,20 +187,20 @@ class TransformFunctor { ...@@ -188,20 +187,20 @@ class TransformFunctor {
func_(func) {} func_(func) {}
inline void Run() const { inline void Run() const {
platform::Transform<Place> trans; platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_, y_, z_, func_); trans(ctx_, x_, x_ + nx_, y_, z_, func_);
} }
inline void RunRowWise(int n, int pre) const { inline void RunRowWise(int n, int pre) const {
platform::Transform<Place> trans; platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, Place>(y_, n), z_, trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
func_); z_, func_);
} }
inline void RunMidWise(int n, int pre, int post) const { inline void RunMidWise(int n, int pre, int post) const {
platform::Transform<Place> trans; platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator<T, Place>(y_, n, post), trans(ctx_, x_, x_ + nx_,
z_, func_); MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
} }
private: private:
...@@ -209,22 +208,24 @@ class TransformFunctor { ...@@ -209,22 +208,24 @@ class TransformFunctor {
const T* y_; const T* y_;
T* z_; T* z_;
int64_t nx_; int64_t nx_;
const platform::DeviceContext& ctx_; const DeviceContext& ctx_;
Functor func_; Functor func_;
}; };
#define EIGEN_FUNCTOR(name, eigen_op) \ #define EIGEN_FUNCTOR(name, eigen_op) \
struct Eigen##name##Functor { \ struct Eigen##name##Functor { \
template <typename Place, typename T> \ template <typename DeviceContext, typename T> \
inline void Run(const framework::Tensor* x, const framework::Tensor* y, \ inline void Run(const framework::Tensor* x, const framework::Tensor* y, \
framework::Tensor* z, \ framework::Tensor* z, \
const framework::ExecutionContext& ctx) { \ const framework::ExecutionContext& ctx) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \ auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \ auto y_e = framework::EigenVector<T>::Flatten(*y); \
auto z_e = framework::EigenVector<T>::Flatten(*z); \ auto z_e = framework::EigenVector<T>::Flatten(*z); \
z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_e); \ z_e.device( \
*ctx.template device_context<DeviceContext>().eigen_device()) = \
eigen_op(x_e, y_e); \
} \ } \
template <typename Place, typename T> \ template <typename DeviceContext, typename T> \
inline void RunBroadCast(const framework::Tensor* x, \ inline void RunBroadCast(const framework::Tensor* x, \
const framework::Tensor* y, framework::Tensor* z, \ const framework::Tensor* y, framework::Tensor* z, \
const framework::ExecutionContext& ctx, int pre, \ const framework::ExecutionContext& ctx, int pre, \
...@@ -235,9 +236,11 @@ class TransformFunctor { ...@@ -235,9 +236,11 @@ class TransformFunctor {
auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n)) \ auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n)) \
.broadcast(Eigen::DSizes<int, 2>(pre, 1)) \ .broadcast(Eigen::DSizes<int, 2>(pre, 1)) \
.reshape(Eigen::DSizes<int, 1>(x_e.size())); \ .reshape(Eigen::DSizes<int, 1>(x_e.size())); \
z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_bcast); \ z_e.device( \
*ctx.template device_context<DeviceContext>().eigen_device()) = \
eigen_op(x_e, y_bcast); \
} \ } \
template <typename Place, typename T> \ template <typename DeviceContext, typename T> \
inline void RunBroadCast2(const framework::Tensor* x, \ inline void RunBroadCast2(const framework::Tensor* x, \
const framework::Tensor* y, \ const framework::Tensor* y, \
framework::Tensor* z, \ framework::Tensor* z, \
...@@ -249,11 +252,13 @@ class TransformFunctor { ...@@ -249,11 +252,13 @@ class TransformFunctor {
auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1)) \ auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1)) \
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post)) \ .broadcast(Eigen::DSizes<int, 3>(pre, 1, post)) \
.reshape(Eigen::DSizes<int, 1>(x_e.size())); \ .reshape(Eigen::DSizes<int, 1>(x_e.size())); \
z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_bcast); \ z_e.device( \
*ctx.template device_context<DeviceContext>().eigen_device()) = \
eigen_op(x_e, y_bcast); \
} \ } \
} }
template <class functor, typename Place, typename T> template <class functor, typename DeviceContext, typename T>
void ElementwiseCompute(const framework::ExecutionContext& ctx) { void ElementwiseCompute(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
...@@ -269,7 +274,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) { ...@@ -269,7 +274,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) {
if (x_dims == y_dims) { if (x_dims == y_dims) {
functor f; functor f;
f.template Run<Place, T>(x, y, z, ctx); f.template Run<DeviceContext, T>(x, y, z, ctx);
return; return;
} }
...@@ -282,11 +287,11 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) { ...@@ -282,11 +287,11 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) {
get_mid_dims(x_dims, y_dims, axis, pre, n, post); get_mid_dims(x_dims, y_dims, axis, pre, n, post);
if (post == 1) { if (post == 1) {
functor f; functor f;
f.template RunBroadCast<Place, T>(x, y, z, ctx, pre, n); f.template RunBroadCast<DeviceContext, T>(x, y, z, ctx, pre, n);
return; return;
} else { } else {
functor f; functor f;
f.template RunBroadCast2<Place, T>(x, y, z, ctx, pre, n, post); f.template RunBroadCast2<DeviceContext, T>(x, y, z, ctx, pre, n, post);
return; return;
} }
} }
...@@ -303,8 +308,9 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL); ...@@ -303,8 +308,9 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL);
#define EIGEN_DIV(x, y) ((x) / (y)) #define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV); EIGEN_FUNCTOR(Div, EIGEN_DIV);
template <typename Place, typename T, typename functor, typename functor1, template <typename DeviceContext, typename T, typename functor,
typename broadcastfunctor, typename broadcast2functor> typename functor1, typename broadcastfunctor,
typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
...@@ -313,7 +319,7 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { ...@@ -313,7 +319,7 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
auto* out = ctx.Input<Tensor>("Out"); auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto place = ctx.GetEigenDevice<Place>(); auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims = y->dims(); auto y_dims = y->dims();
......
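In elementwise_op_function.h the broadcast iterators, TransformFunctor and platform::Transform are all re-specialized on the DeviceContext, but calling them from a kernel remains a one-liner. Roughly (x, y, z and AddFunctor as in the elementwise_add kernel shown earlier; this is a sketch, not a new API):

TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
    x, y, z, ctx.template device_context<DeviceContext>(), AddFunctor<T>());
functor.Run();                      // same-shape case
// functor.RunRowWise(n, pre);      // or the broadcast variants
// functor.RunMidWise(n, pre, post);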
...@@ -34,13 +34,13 @@ REGISTER_OP(elementwise_sub, ops::ElementwiseOp, ops::ElementwiseSubOpMaker, ...@@ -34,13 +34,13 @@ REGISTER_OP(elementwise_sub, ops::ElementwiseOp, ops::ElementwiseSubOpMaker,
elementwise_sub_grad, ops::ElementwiseOpGrad); elementwise_sub_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_sub, elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, float>, ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, double>, ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, int>, ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, int64_t>); ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad, elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, float>, ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, double>, ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, int>, ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, int64_t>); ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -17,15 +17,16 @@ ...@@ -17,15 +17,16 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_sub, elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, float>, ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, double>, ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, int>, ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, int64_t>); ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_sub_grad, elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, float>, ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, double>, ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, int>, ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, int64_t>); ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ElementwiseSubKernel : public framework::OpKernel<T> { class ElementwiseSubKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenSubFunctor, Place, T>(ctx); ElementwiseCompute<EigenSubFunctor, DeviceContext, T>(ctx);
} }
}; };
...@@ -101,11 +101,11 @@ struct ElementwiseSubBroadCast2GradFunctor { ...@@ -101,11 +101,11 @@ struct ElementwiseSubBroadCast2GradFunctor {
} }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ElementwiseSubGradKernel : public framework::OpKernel<T> { class ElementwiseSubGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseSubGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseSubGradFunctor<T>,
ElementwiseSubOneGradFunctor<T>, ElementwiseSubOneGradFunctor<T>,
ElementwiseSubBroadCastGradFunctor<T>, ElementwiseSubBroadCastGradFunctor<T>,
ElementwiseSubBroadCast2GradFunctor<T>>(ctx); ElementwiseSubBroadCast2GradFunctor<T>>(ctx);
......
...@@ -130,7 +130,8 @@ class ExpandGradOp : public framework::OperatorWithKernel { ...@@ -130,7 +130,8 @@ class ExpandGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(expand, ops::ExpandOp, ops::ExpandOpMaker, expand_grad, REGISTER_OP(expand, ops::ExpandOp, ops::ExpandOpMaker, expand_grad,
ops::ExpandGradOp); ops::ExpandGradOp);
REGISTER_OP_CPU_KERNEL(expand,
ops::ExpandKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
expand_grad, ops::ExpandGradKernel<paddle::platform::CPUPlace, float>); expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
expand_grad,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
#include "paddle/operators/expand_op.h" #include "paddle/operators/expand_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(expand, REGISTER_OP_CUDA_KERNEL(
ops::ExpandKernel<paddle::platform::GPUPlace, float>); expand, ops::ExpandKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
expand_grad, ops::ExpandGradKernel<paddle::platform::GPUPlace, float>); expand_grad,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -56,7 +56,7 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor, ...@@ -56,7 +56,7 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>; using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ExpandKernel : public framework::OpKernel<T> { class ExpandKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -83,12 +83,13 @@ class ExpandKernel : public framework::OpKernel<T> { ...@@ -83,12 +83,13 @@ class ExpandKernel : public framework::OpKernel<T> {
auto x = EigenTensor<T, Rank>::From(*in0); auto x = EigenTensor<T, Rank>::From(*in0);
out0->mutable_data<T>(context.GetPlace()); out0->mutable_data<T>(context.GetPlace());
auto y = EigenTensor<T, Rank>::From(*out0); auto y = EigenTensor<T, Rank>::From(*out0);
auto place = context.GetEigenDevice<Place>(); auto& place =
*context.template device_context<DeviceContext>().eigen_device();
y.device(place) = x.broadcast(bcast_dims); y.device(place) = x.broadcast(bcast_dims);
} }
}; };
template <typename Place, typename T> template <typename DeviceContext, typename T>
class ExpandGradKernel : public framework::OpKernel<T> { class ExpandGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -164,7 +165,8 @@ class ExpandGradKernel : public framework::OpKernel<T> { ...@@ -164,7 +165,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
reduce_dims[i] = reduce_dims_vec[i]; reduce_dims[i] = reduce_dims_vec[i];
} }
auto out_grad = EigenVector<T>::Flatten(*in0); auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(context.GetEigenDevice<Place>()) = x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions()); out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions());
} }
}; };
......
...@@ -100,8 +100,11 @@ REGISTER_OPERATOR(fill_constant_batch_size_like, ...@@ -100,8 +100,11 @@ REGISTER_OPERATOR(fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpMaker); ops::FillConstantBatchSizeLikeOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_constant_batch_size_like, fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, float>, ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, double>, float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, int>, ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, double>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
int>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
int64_t>); int64_t>);
...@@ -16,10 +16,13 @@ ...@@ -16,10 +16,13 @@
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
fill_constant_batch_size_like, fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, float>, ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, double>, float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, int>, ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, double>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
int>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
int64_t>); int64_t>);
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename DeviceContext, typename T>
class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> { class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -27,8 +27,9 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> { ...@@ -27,8 +27,9 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
auto value = ctx.Attr<float>("value"); auto value = ctx.Attr<float>("value");
math::SetConstant<Place, T> setter; math::SetConstant<DeviceContext, T> setter;
setter(ctx.device_context(), out, static_cast<T>(value)); setter(ctx.template device_context<DeviceContext>(), out,
static_cast<T>(value));
} }
}; };
......
...@@ -54,8 +54,9 @@ namespace ops = paddle::operators; ...@@ -54,8 +54,9 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp, REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp,
ops::FillZerosLikeOpMaker); ops::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_zeros_like, ops::FillZerosLikeKernel<paddle::platform::CPUPlace, int>, fill_zeros_like,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, int64_t>, ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>, ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, double>, ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, bool>); ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
...@@ -16,9 +16,10 @@ ...@@ -16,9 +16,10 @@
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
fill_zeros_like, ops::FillZerosLikeKernel<paddle::platform::GPUPlace, int>, fill_zeros_like,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, int64_t>, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int>,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, float>, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, double>, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, bool>); ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
...@@ -19,15 +19,16 @@ limitations under the License. */ ...@@ -19,15 +19,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename DeviceContext, typename T>
class FillZerosLikeKernel : public framework::OpKernel<T> { class FillZerosLikeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<framework::Tensor>("Y"); auto* out = context.Output<framework::Tensor>("Y");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
math::SetConstant<Place, T> setter; math::SetConstant<DeviceContext, T> setter;
setter(context.device_context(), out, static_cast<T>(0)); setter(context.template device_context<DeviceContext>(), out,
static_cast<T>(0));
} }
}; };
......
...@@ -135,5 +135,5 @@ The paper that proposed Follow The Regularized Leader (FTRL): ...@@ -135,5 +135,5 @@ The paper that proposed Follow The Regularized Leader (FTRL):
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ftrl, ops::FTRLOp, ops::FTRLOpMaker); REGISTER_OP_WITHOUT_GRADIENT(ftrl, ops::FTRLOp, ops::FTRLOpMaker);
REGISTER_OP_CPU_KERNEL(ftrl, REGISTER_OP_CPU_KERNEL(
ops::FTRLOpKernel<paddle::platform::CPUPlace, float>); ftrl, ops::FTRLOpKernel<paddle::platform::CPUDeviceContext, float>);
...@@ -15,5 +15,5 @@ specific language governing permissions and limitations under the License. */ ...@@ -15,5 +15,5 @@ specific language governing permissions and limitations under the License. */
#include "paddle/operators/ftrl_op.h" #include "paddle/operators/ftrl_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(ftrl, REGISTER_OP_CUDA_KERNEL(
ops::FTRLOpKernel<paddle::platform::GPUPlace, float>); ftrl, ops::FTRLOpKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -24,7 +24,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -24,7 +24,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename DeviceContext, typename T>
class FTRLOpKernel : public framework::OpKernel<T> { class FTRLOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -53,7 +53,7 @@ class FTRLOpKernel : public framework::OpKernel<T> { ...@@ -53,7 +53,7 @@ class FTRLOpKernel : public framework::OpKernel<T> {
auto p_out = EigenVector<T>::Flatten(*param_out); auto p_out = EigenVector<T>::Flatten(*param_out);
auto s_acc_out = EigenVector<T>::Flatten(*sq_accum_out); auto s_acc_out = EigenVector<T>::Flatten(*sq_accum_out);
auto l_acc_out = EigenVector<T>::Flatten(*lin_accum_out); auto l_acc_out = EigenVector<T>::Flatten(*lin_accum_out);
auto place = ctx.GetEigenDevice<Place>(); auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> grad_dsize(grad->numel()); Eigen::DSizes<int, 1> grad_dsize(grad->numel());
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
using framework::Tensor; using framework::Tensor;
using platform::Place; using platform::DeviceContext;
#define CUDA_1D_KERNEL_LOOP(i, n) \ #define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
......
...@@ -49,7 +49,8 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -49,7 +49,8 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
dX->mutable_data<T>(ctx.GetPlace()); dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX); auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto place = ctx.GetEigenDevice<platform::GPUPlace>(); auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0)); dxt.device(place) = dxt.constant(static_cast<T>(0));
GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX); GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
...@@ -60,5 +61,5 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -60,5 +61,5 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gather, ops::GatherOpCUDAKernel<float>); REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>); REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>);
...@@ -53,7 +53,8 @@ class GatherGradientOpKernel : public framework::OpKernel<T> { ...@@ -53,7 +53,8 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
dX->mutable_data<T>(ctx.GetPlace()); dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX); auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto place = ctx.GetEigenDevice<platform::CPUPlace>(); auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0)); dxt.device(place) = dxt.constant(static_cast<T>(0));
ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX); ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
......
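The gather kernels show that not every call site needs the typed context yet: zeroing the gradient goes through the new eigen_device() path, while GPUScatterAssign / ScatterAssign still accept the plain ctx.device_context(). In sketch form, following the CUDA grad kernel above (dX, dO and Index as in that diff):

auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
                   .eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));         // zero-fill via Eigen
GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);  // untyped context still accepted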
(Diffs for the remaining changed files are collapsed.)