Unverified commit 61ec0b95, authored by QI JUN and committed by GitHub

Refine device context (#6433)

This change mainly contains the following fixes:

- take `DeviceContext` as the template parameter of math functors and OpKernel instead of `Place`
- remove `eigen_device` interface in base class  `DeviceContext`
- remove `GetEigenDevice` interface in `ExecutionContext` and base class `DeviceContext`
- remove unused `platform::EigenDeviceConverter`
- rename `REGISTER_OP_GPU_KERNEL` to `REGISTER_OP_CUDA_KERNEL`
- rename `USE_GPU_ONLY_OP` to `USE_CUDA_ONLY_OP`
Parent 7902ad65
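To make the new convention concrete, here is a minimal before/after sketch of an operator kernel under this refactoring. The op name `my_scale`, the functor body, and the include paths are illustrative assumptions; the `device_context<DeviceContext>()`, `eigen_device()`, and registration-macro usage follow the hunks in this diff.

```cpp
// Hypothetical kernel: only the DeviceContext pattern is taken from this
// commit; everything op-specific here is made up for illustration.
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

// Before: template <typename Place, typename T> with
//   auto place = ctx.GetEigenDevice<Place>();
// After: the kernel is templated on the concrete DeviceContext and asks
// it for the Eigen device directly.
template <typename DeviceContext, typename T>
class MyScaleOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

    auto x = framework::EigenVector<T>::Flatten(*in);
    auto y = framework::EigenVector<T>::Flatten(*out);
    // eigen_device() now lives on the concrete device context; the
    // GetEigenDevice interface on ExecutionContext is gone.
    auto& place =
        *ctx.template device_context<DeviceContext>().eigen_device();
    y.device(place) = x * static_cast<T>(2);
  }
};

}  // namespace operators
}  // namespace paddle

// Registration: the CPU macro keeps its name and only the template
// argument changes; the GPU macro is renamed to REGISTER_OP_CUDA_KERNEL.
namespace ops = paddle::operators;
REGISTER_OP_CPU_KERNEL(
    my_scale, ops::MyScaleOpKernel<paddle::platform::CPUDeviceContext, float>);
// In the corresponding .cu file:
// REGISTER_OP_CUDA_KERNEL(
//     my_scale,
//     ops::MyScaleOpKernel<paddle::platform::CUDADeviceContext, float>);
```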
......@@ -181,8 +181,8 @@ class OpKernelRegistrar : public Registrar {
return 0; \
}
#define REGISTER_OP_GPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::GPUPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
......@@ -217,7 +217,7 @@ class OpKernelRegistrar : public Registrar {
#else
#define USE_OP_KERNEL(op_type) \
USE_OP_DEVICE_KERNEL(op_type, CPU); \
USE_OP_DEVICE_KERNEL(op_type, GPU)
USE_OP_DEVICE_KERNEL(op_type, CUDA)
#endif
#define USE_NO_KERNEL_OP(op_type) USE_OP_ITSELF(op_type);
......@@ -226,9 +226,9 @@ class OpKernelRegistrar : public Registrar {
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU);
#define USE_GPU_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, GPU)
#define USE_CUDA_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CUDA)
#define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \
......
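Note that the rename is purely mechanical: `REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::GPUPlace, ...)` still keys the kernel on `GPUPlace`, so only the macro spelling changes at call sites. A sketch of the migration, using op names that appear later in this diff:

```cpp
// Before this commit:
//   REGISTER_OP_GPU_KERNEL(accuracy, ...);
//   USE_GPU_ONLY_OP(ncclAllReduce);
// After this commit, the same registrations read:
REGISTER_OP_CUDA_KERNEL(accuracy,
                        paddle::operators::AccuracyOpCUDAKernel<float>,
                        paddle::operators::AccuracyOpCUDAKernel<double>);
USE_CUDA_ONLY_OP(ncclAllReduce);
```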
......@@ -22,20 +22,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <>
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return *device_context_.GetEigenDevice<platform::CPUPlace>();
}
#ifdef PADDLE_WITH_CUDA
template <>
Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return *device_context_.GetEigenDevice<platform::GPUPlace>();
}
#endif
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
......@@ -429,7 +415,7 @@ void OperatorWithKernel::Run(const Scope& scope,
}
OpKernelType OperatorWithKernel::GetKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.device_context());
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}
DataType OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
......
......@@ -276,17 +276,25 @@ class ExecutionContext {
out_tensor->set_lod(in_tensor.lod());
}
template <typename PlaceType,
typename DeviceType = typename platform::EigenDeviceConverter<
PlaceType>::EigenDeviceType>
DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); }
template <typename DeviceContextType>
const DeviceContextType& device_context() const {
return *reinterpret_cast<const DeviceContextType*>(&device_context_);
}
const platform::DeviceContext& device_context() const {
return device_context_;
}
#ifdef PADDLE_WITH_CUDA
const inline platform::CUDADeviceContext& cuda_device_context() const {
PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
return *reinterpret_cast<const platform::CUDADeviceContext*>(
&device_context_);
}
#endif
//! Get actual name vector for this input.
const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
......@@ -297,14 +305,6 @@ class ExecutionContext {
return op_.Outputs(name);
}
#ifdef PADDLE_WITH_CUDA
const inline platform::CUDADeviceContext& cuda_device_context() const {
PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
return *reinterpret_cast<const platform::CUDADeviceContext*>(
&device_context_);
}
#endif
private:
const OperatorBase& op_;
const Scope& scope_;
......
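With the hunk above, a kernel has three ways to reach its device context through `ExecutionContext`. A fragment-level sketch, assuming it sits inside an `OpKernel::Compute` body (variable names are illustrative):

```cpp
// 1. Place only; OpKernelType is now built from this instead of the
//    base-class device_context().
platform::Place place = ctx.GetPlace();

// 2. Statically typed device context; in a kernel templated on
//    DeviceContext this is the replacement for GetEigenDevice.
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
auto* eigen_dev = dev_ctx.eigen_device();  // Eigen::DefaultDevice*

// 3. CUDA-only shortcut, now declared next to the typed accessor.
#ifdef PADDLE_WITH_CUDA
auto& cuda_ctx = ctx.cuda_device_context();
auto stream = cuda_ctx.stream();
#endif
```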
......@@ -115,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
return OpKernelType(DataType::FP32, ctx.device_context());
return OpKernelType(DataType::FP32, ctx.GetPlace());
}
};
......
......@@ -138,7 +138,7 @@ function(op_library TARGET)
if ("${TARGET}" STREQUAL "nccl_op")
set(pybind_flag 1)
# It's enough to just add one operator to pybind
file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n")
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(ncclAllReduce);\n")
endif()
# reduce_op contains several operators
......
......@@ -57,7 +57,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
ctx.device_context());
ctx.GetPlace());
}
};
......
......@@ -104,5 +104,6 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
// FIXME(typhoonzero): the types of T are for the inference data.
// label data is always int64
REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
paddle::operators::AccuracyOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(accuracy,
paddle::operators::AccuracyOpCUDAKernel<float>,
paddle::operators::AccuracyOpCUDAKernel<double>);
......@@ -21,7 +21,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class AccuracyKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -611,16 +611,17 @@ REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker,
REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad,
ops::ActivationOpGrad);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, \
ops::ActivationKernel<paddle::platform::CPUPlace, ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::CPUPlace, \
ops::functor<double>>); \
REGISTER_OP_CPU_KERNEL( \
act_type##_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<paddle::platform::CPUPlace, \
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
ops::functor<double>>); \
REGISTER_OP_CPU_KERNEL( \
act_type##_grad, \
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
......@@ -17,16 +17,17 @@
namespace ops = paddle::operators;
#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_GPU_KERNEL( \
act_type, \
ops::ActivationKernel<paddle::platform::GPUPlace, ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::GPUPlace, \
ops::functor<double>>); \
REGISTER_OP_GPU_KERNEL( \
act_type##_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<paddle::platform::GPUPlace, \
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::CUDADeviceContext, \
ops::functor<double>>); \
REGISTER_OP_CUDA_KERNEL( \
act_type##_grad, \
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext, \
ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
......@@ -19,7 +19,7 @@
namespace paddle {
namespace operators {
template <typename Place, typename Functor>
template <typename DeviceContext, typename Functor>
class ActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
......@@ -32,18 +32,19 @@ class ActivationKernel
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
auto* place =
context.template device_context<DeviceContext>().eigen_device();
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(place, x, y);
functor(*place, x, y);
}
};
template <typename Place, typename Functor>
template <typename DeviceContext, typename Functor>
class ActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
......@@ -59,13 +60,14 @@ class ActivationGradKernel
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
auto* place =
context.template device_context<DeviceContext>().eigen_device();
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(place, x, y, dy, dx);
functor(*place, x, y, dy, dx);
}
};
......
......@@ -109,5 +109,5 @@ $$
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker);
REGISTER_OP_CPU_KERNEL(
adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUPlace, float>,
ops::AdadeltaOpKernel<paddle::platform::CPUPlace, double>);
adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::AdadeltaOpKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -16,6 +16,6 @@
#include "paddle/operators/adadelta_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
adadelta, ops::AdadeltaOpKernel<paddle::platform::GPUPlace, float>,
ops::AdadeltaOpKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_CUDA_KERNEL(
adadelta, ops::AdadeltaOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::AdadeltaOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class AdadeltaOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -51,7 +51,7 @@ class AdadeltaOpKernel : public framework::OpKernel<T> {
framework::EigenVector<T>::Flatten(*avg_squared_grad_out_tensor);
auto avg_squared_update_out =
framework::EigenVector<T>::Flatten(*avg_squared_update_out_tensor);
auto place = ctx.GetEigenDevice<Place>();
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
avg_squared_grad_out.device(place) =
rho * avg_squared_grad + (1 - rho) * grad.square();
......
......@@ -100,8 +100,8 @@ size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
} // namespace
template <typename T>
struct SparseAdagradFunctor<platform::CPUPlace, T> {
void operator()(const platform::DeviceContext& context,
struct SparseAdagradFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& grad,
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param) {
......@@ -120,7 +120,7 @@ struct SparseAdagradFunctor<platform::CPUPlace, T> {
{static_cast<int64_t>(merge_rows.size()), grad_width}),
context.GetPlace());
math::SetConstant<platform::CPUPlace, T> constant_functor;
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
constant_functor(context, grad_merge->mutable_value(), 0.0);
auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
......@@ -144,9 +144,9 @@ struct SparseAdagradFunctor<platform::CPUPlace, T> {
auto gs =
framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
gs.device(*context.GetEigenDevice<platform::CPUPlace>()) = gm * gm;
gs.device(*context.eigen_device()) = gm * gm;
math::SelectedRowsAddToTensor<platform::CPUPlace, T> functor;
math::SelectedRowsAddToTensor<platform::CPUDeviceContext, T> functor;
functor(context, *grad_square, moment);
// 3. update parameter
......@@ -164,13 +164,13 @@ struct SparseAdagradFunctor<platform::CPUPlace, T> {
}
};
template struct SparseAdagradFunctor<platform::CPUPlace, float>;
template struct SparseAdagradFunctor<platform::CPUPlace, double>;
template struct SparseAdagradFunctor<platform::CPUDeviceContext, float>;
template struct SparseAdagradFunctor<platform::CPUDeviceContext, double>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker);
REGISTER_OP_CPU_KERNEL(
adagrad, ops::AdagradOpKernel<paddle::platform::CPUPlace, float>,
ops::AdagradOpKernel<paddle::platform::CPUPlace, double>);
adagrad, ops::AdagradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::AdagradOpKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -72,8 +72,8 @@ __global__ void SparseAdagradFunctorKernel(const T* grad, const int64_t* rows,
} // namespace
template <typename T>
struct SparseAdagradFunctor<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& grad,
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param) {
......@@ -92,7 +92,7 @@ struct SparseAdagradFunctor<platform::GPUPlace, T> {
{static_cast<int64_t>(merge_rows.size()), grad_width}),
context.GetPlace());
math::SetConstant<platform::GPUPlace, T> constant_functor;
math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
constant_functor(context, grad_merge->mutable_value(), 0.0);
auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
......@@ -119,9 +119,9 @@ struct SparseAdagradFunctor<platform::GPUPlace, T> {
auto gs =
framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
gs.device(*context.GetEigenDevice<platform::GPUPlace>()) = gm * gm;
gs.device(*context.eigen_device()) = gm * gm;
math::SelectedRowsAddToTensor<platform::GPUPlace, T> functor;
math::SelectedRowsAddToTensor<platform::CUDADeviceContext, T> functor;
functor(context, *grad_square, moment);
// 3. update parameter
......@@ -139,13 +139,13 @@ struct SparseAdagradFunctor<platform::GPUPlace, T> {
}
};
template struct SparseAdagradFunctor<platform::GPUPlace, float>;
template struct SparseAdagradFunctor<platform::GPUPlace, double>;
template struct SparseAdagradFunctor<platform::CUDADeviceContext, float>;
template struct SparseAdagradFunctor<platform::CUDADeviceContext, double>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
adagrad, ops::AdagradOpKernel<paddle::platform::GPUPlace, float>,
ops::AdagradOpKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_CUDA_KERNEL(
adagrad, ops::AdagradOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::AdagradOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -19,15 +19,15 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
struct SparseAdagradFunctor {
void operator()(const platform::DeviceContext& context,
void operator()(const DeviceContext& context,
const framework::SelectedRows& grad,
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param);
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class AdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -52,11 +52,11 @@ class AdagradOpKernel : public framework::OpKernel<T> {
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto place = ctx.GetEigenDevice<Place>();
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
moment_out.device(place) = moment + grad * grad;
moment_out.device(*place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
param_out.device(place) =
param_out.device(*place) =
param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto* param_tensor = ctx.Input<framework::Tensor>("Param");
......@@ -65,8 +65,9 @@ class AdagradOpKernel : public framework::OpKernel<T> {
auto* moment_tensor = ctx.Input<framework::Tensor>("Moment");
PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor);
SparseAdagradFunctor<Place, T> functor;
functor(ctx.device_context(), *ctx.Input<framework::SelectedRows>("Grad"),
SparseAdagradFunctor<DeviceContext, T> functor;
functor(ctx.template device_context<DeviceContext>(),
*ctx.Input<framework::SelectedRows>("Grad"),
*ctx.Input<framework::Tensor>("LearningRate"), epsilon,
moment_out_tensor, param_out_tensor);
} else {
......
......@@ -128,6 +128,6 @@ $$
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
REGISTER_OP_CPU_KERNEL(adam,
ops::AdamOpKernel<paddle::platform::CPUPlace, float>,
ops::AdamOpKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
adam, ops::AdamOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::AdamOpKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -16,6 +16,6 @@
#include "paddle/operators/adam_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adam,
ops::AdamOpKernel<paddle::platform::GPUPlace, float>,
ops::AdamOpKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_CUDA_KERNEL(
adam, ops::AdamOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::AdamOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -52,17 +52,17 @@ class AdamOpKernel : public framework::OpKernel<T> {
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment1_out = framework::EigenVector<T>::Flatten(*moment1_out_tensor);
auto moment2_out = framework::EigenVector<T>::Flatten(*moment2_out_tensor);
auto place = ctx.GetEigenDevice<Place>();
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
moment1_out.device(place) = beta1 * moment1 + (1 - beta1) * grad;
moment2_out.device(place) = beta2 * moment2 + (1 - beta2) * grad.square();
moment1_out.device(*place) = beta1 * moment1 + (1 - beta1) * grad;
moment2_out.device(*place) = beta2 * moment2 + (1 - beta2) * grad.square();
// All of these are tensors of 1 element
auto lr_t = lr * (1 - beta2_pow).sqrt() / (1 - beta1_pow);
// Eigen does not support automatic broadcast
// Get dimensions of moment vector to broadcast lr_t
Eigen::DSizes<int, 1> m_dsize(moment1_out_tensor->numel());
param_out.device(place) =
param_out.device(*place) =
param -
lr_t.broadcast(m_dsize) *
(moment1_out / (moment2_out.sqrt() + epsilon));
......
......@@ -127,6 +127,6 @@ division by 0 error.
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adamax, ops::AdamaxOp, ops::AdamaxOpMaker);
REGISTER_OP_CPU_KERNEL(adamax,
ops::AdamaxOpKernel<paddle::platform::CPUPlace, float>,
ops::AdamaxOpKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
adamax, ops::AdamaxOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::AdamaxOpKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -16,6 +16,6 @@
#include "paddle/operators/adamax_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adamax,
ops::AdamaxOpKernel<paddle::platform::GPUPlace, float>,
ops::AdamaxOpKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_CUDA_KERNEL(
adamax, ops::AdamaxOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::AdamaxOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class AdamaxOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -51,14 +51,14 @@ class AdamaxOpKernel : public framework::OpKernel<T> {
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto inf_norm_out =
framework::EigenVector<T>::Flatten(*inf_norm_out_tensor);
auto place = ctx.GetEigenDevice<Place>();
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
moment_out.device(place) = beta1 * moment + (1 - beta1) * grad;
inf_norm_out.device(place) =
moment_out.device(*place) = beta1 * moment + (1 - beta1) * grad;
inf_norm_out.device(*place) =
grad.abs().cwiseMax((beta2 * inf_norm) + epsilon);
auto lr_t = lr / (1 - beta1_pow);
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
param_out.device(place) =
param_out.device(*place) =
param - lr_t.broadcast(m_dsize) * (moment_out / inf_norm_out);
}
};
......
......@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class AucKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......
......@@ -135,7 +135,8 @@ The required data format for this layer is one of the following:
};
template <typename T>
class BatchNormKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
class BatchNormKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
......@@ -318,12 +319,12 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::OpKernelType(framework::ToDataType(t->type()),
ctx.device_context());
ctx.GetPlace());
}
};
template <typename T>
class BatchNormGradKernel<platform::CPUPlace, T>
class BatchNormGradKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
......@@ -436,8 +437,9 @@ class BatchNormGradKernel<platform::CPUPlace, T>
namespace ops = paddle::operators;
REGISTER_OP(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
batch_norm_grad, ops::BatchNormGradOp);
REGISTER_OP_CPU_KERNEL(batch_norm,
ops::BatchNormKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
batch_norm,
ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::CPUPlace, float>);
ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -47,7 +47,8 @@ void ExtractNCWHD(const framework::DDim &dims,
}
template <typename T>
class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
class BatchNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -121,11 +122,12 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
math::SetConstant<platform::GPUPlace, T> functor;
functor(ctx.device_context(), saved_mean, 0);
functor(ctx.device_context(), saved_variance, 0);
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> functor;
functor(dev_ctx, saved_mean, 0);
functor(dev_ctx, saved_variance, 0);
auto handle = ctx.cuda_device_context().cudnn_handle();
auto handle = dev_ctx.cudnn_handle();
// Now, depending on whether we are running in test mode or not, we have two paths.
if (is_test) {
......@@ -171,7 +173,7 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
};
template <typename T>
class BatchNormGradKernel<platform::GPUPlace, T>
class BatchNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
......@@ -244,11 +246,12 @@ class BatchNormGradKernel<platform::GPUPlace, T>
const void *saved_mean_data = saved_mean->template data<T>();
const void *saved_var_data = saved_var->template data<T>();
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
ctx.cuda_device_context().cudnn_handle(), mode_,
CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), data_desc_,
x->template data<T>(), data_desc_, d_y->template data<T>(), data_desc_,
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, d_y->template data<T>(), data_desc_,
d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<T>(),
d_scale->template mutable_data<T>(ctx.GetPlace()),
......@@ -266,8 +269,9 @@ class BatchNormGradKernel<platform::GPUPlace, T>
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(batch_norm,
ops::BatchNormKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
batch_norm,
ops::BatchNormKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::GPUPlace, float>);
ops::BatchNormGradKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -34,13 +34,13 @@ inline TensorFormat StringToTensorFormat(const std::string& str) {
}
}
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class BatchNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class BatchNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
......
......@@ -159,9 +159,12 @@ REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp,
ops::BilinearTensorProductOpGrad);
REGISTER_OP_CPU_KERNEL(
bilinear_tensor_product,
ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, float>,
ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, double>);
ops::BilinearTensorProductKernel<paddle::platform::CPUDeviceContext, float>,
ops::BilinearTensorProductKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(
bilinear_tensor_product_grad,
ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, float>,
ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, double>);
ops::BilinearTensorProductGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::BilinearTensorProductGradKernel<paddle::platform::CPUDeviceContext,
double>);
......@@ -16,11 +16,15 @@ limitations under the License. */
#include "paddle/operators/bilinear_tensor_product_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
bilinear_tensor_product,
ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, float>,
ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
ops::BilinearTensorProductKernel<paddle::platform::CUDADeviceContext,
float>,
ops::BilinearTensorProductKernel<paddle::platform::CUDADeviceContext,
double>);
REGISTER_OP_CUDA_KERNEL(
bilinear_tensor_product_grad,
ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, float>,
ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, double>);
ops::BilinearTensorProductGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::BilinearTensorProductGradKernel<paddle::platform::CUDADeviceContext,
double>);
......@@ -27,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class BilinearTensorProductKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -46,7 +46,8 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
int out_dim = weight_dims[0];
auto x_dim = weight_dims[1];
auto y_dim = weight_dims[2];
auto place = ctx.GetEigenDevice<Place>();
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// Create the intermediate variable to calculate the result of
// Input(X) multiplied by Input(Weight_i), the formula is:
......@@ -60,9 +61,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
auto output_col_vec = output_mat.chip(i, 1);
Tensor weight_mat =
weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim}));
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
batch_size, y_dim, x_dim, 1, x->data<T>(),
weight_mat.data<T>(), 0, left_mul.data<T>());
math::gemm<DeviceContext, T>(dev_ctx, CblasNoTrans, CblasNoTrans,
batch_size, y_dim, x_dim, 1, x->data<T>(),
weight_mat.data<T>(), 0, left_mul.data<T>());
output_col_vec.device(place) =
(left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
}
......@@ -74,7 +75,7 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -96,8 +97,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
auto x_mat = EigenMatrix<T>::From(*x);
auto y_mat = EigenMatrix<T>::From(*y);
auto d_out_mat = EigenMatrix<T>::From(*d_out);
auto place = ctx.GetEigenDevice<Place>();
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// Create the intermediate variable to calculate the Output(Y@Grad).
Tensor x_scale;
x_scale.mutable_data<T>(framework::make_ddim({batch_size, x_dim}),
......@@ -110,18 +111,18 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
ctx.GetPlace());
auto y_scale_mat = EigenMatrix<T>::From(y_scale);
math::SetConstant<Place, T> set_zero;
math::SetConstant<DeviceContext, T> set_zero;
// Set Output(X@Grad) to zero.
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
set_zero(ctx.device_context(), d_x, static_cast<T>(0));
set_zero(dev_ctx, d_x, static_cast<T>(0));
}
// Set Output(Y@Grad) to zero.
if (d_y) {
d_y->mutable_data<T>(ctx.GetPlace());
set_zero(ctx.device_context(), d_y, static_cast<T>(0));
set_zero(dev_ctx, d_y, static_cast<T>(0));
}
// Calculate the Output(X@Grad) and Output(Y@Grad).
......@@ -137,18 +138,18 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_x) *
y_mat;
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasTrans,
batch_size, x_dim, y_dim, 1, y_scale.data<T>(),
weight_i.data<T>(), 1, d_x->data<T>());
math::gemm<DeviceContext, T>(
dev_ctx, CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
}
if (d_y) {
x_scale_mat.device(place) =
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_y) *
x_mat;
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
batch_size, y_dim, x_dim, 1, x_scale.data<T>(),
weight_i.data<T>(), 1, d_y->data<T>());
math::gemm<DeviceContext, T>(
dev_ctx, CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
}
}
}
......@@ -165,9 +166,9 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_weight) *
x_mat;
math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
x_dim, y_dim, batch_size, 1, x_scale.data<T>(),
y->data<T>(), 0, d_weight_i.data<T>());
math::gemm<DeviceContext, T>(dev_ctx, CblasTrans, CblasNoTrans, x_dim,
y_dim, batch_size, 1, x_scale.data<T>(),
y->data<T>(), 0, d_weight_i.data<T>());
}
}
......
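The bilinear hunks above show the general math-functor rule from the first bullet of the commit message: instantiate the functor with the DeviceContext type and pass the concrete context object, not the polymorphic base class. A condensed sketch, assuming a kernel body templated on `DeviceContext` (the tensor variables and gemm dimensions are illustrative):

```cpp
auto& dev_ctx = ctx.template device_context<DeviceContext>();

// Was: math::SetConstant<Place, T> zero; zero(ctx.device_context(), ...)
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, d_x, static_cast<T>(0));

// Was: math::gemm<Place, T>(ctx.device_context(), ...)
math::gemm<DeviceContext, T>(dev_ctx, CblasNoTrans, CblasTrans, batch_size,
                             x_dim, y_dim, 1, y_scale.data<T>(),
                             weight_i.data<T>(), 1, d_x->data<T>());
```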
......@@ -68,7 +68,7 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUPlace;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OP_WITH_KERNEL(cast, ops::CastOpGradMaker, ops::CastOpInferShape,
ops::CastOpProtoMaker);
REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
......
......@@ -16,7 +16,7 @@
template <typename T>
using CastOpKernel =
paddle::operators::CastOpKernel<paddle::platform::GPUPlace, T>;
paddle::operators::CastOpKernel<paddle::platform::CUDADeviceContext, T>;
REGISTER_OP_GPU_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
CastOpKernel<int>, CastOpKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
CastOpKernel<int>, CastOpKernel<int64_t>);
......@@ -27,13 +27,13 @@ struct CastOpTransformFunctor {
HOSTDEVICE OutT operator()(InT in) const { return static_cast<OutT>(in); }
};
template <typename Place, typename InT>
template <typename DeviceContext, typename InT>
struct CastOpFunctor {
const framework::Tensor* in_;
framework::Tensor* out_;
const platform::DeviceContext& ctx_;
const DeviceContext& ctx_;
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const platform::DeviceContext& ctx)
const DeviceContext& ctx)
: in_(in), out_(out), ctx_(ctx) {}
template <typename OutT>
......@@ -42,13 +42,13 @@ struct CastOpFunctor {
auto numel = in_->numel();
auto* in_end = in_begin + numel;
auto* out_begin = out_->mutable_data<OutT>(ctx_.GetPlace());
platform::Transform<Place> trans;
platform::Transform<DeviceContext> trans;
trans(ctx_, in_begin, in_end, out_begin,
CastOpTransformFunctor<InT, OutT>());
}
};
template <typename Place, typename InT>
template <typename DeviceContext, typename InT>
class CastOpKernel : public framework::OpKernel<InT> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -56,7 +56,8 @@ class CastOpKernel : public framework::OpKernel<InT> {
auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType(
static_cast<framework::DataType>(context.Attr<int>("out_dtype")),
CastOpFunctor<Place, InT>(in, out, context.device_context()));
CastOpFunctor<DeviceContext, InT>(
in, out, context.template device_context<DeviceContext>()));
}
};
......
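The cast hunk above also covers `platform::Transform`, which is now parameterized by the device context type and receives the concrete context as its first argument. A minimal sketch following the call sites in this diff (the function name and output type are illustrative assumptions):

```cpp
template <typename DeviceContext, typename InT>
void CastToFloat(const framework::ExecutionContext& ctx,
                 const framework::Tensor& in, framework::Tensor* out) {
  auto* in_begin = in.data<InT>();
  auto* out_begin = out->mutable_data<float>(ctx.GetPlace());
  platform::Transform<DeviceContext> trans;  // was platform::Transform<Place>
  trans(ctx.template device_context<DeviceContext>(), in_begin,
        in_begin + in.numel(), out_begin,
        CastOpTransformFunctor<InT, float>());
}
```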
......@@ -23,7 +23,7 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ChunkEvalKernel : public framework::OpKernel<T> {
public:
struct Segment {
......
......@@ -71,4 +71,5 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
ops::ClipByNormOpMaker);
REGISTER_OP_CPU_KERNEL(
clip_by_norm, ops::ClipByNormKernel<paddle::platform::CPUPlace, float>);
clip_by_norm,
ops::ClipByNormKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -15,5 +15,6 @@
#include "paddle/operators/clip_by_norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
clip_by_norm, ops::ClipByNormKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_CUDA_KERNEL(
clip_by_norm,
ops::ClipByNormKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -26,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ClipByNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -38,7 +38,8 @@ class ClipByNormKernel : public framework::OpKernel<T> {
auto x = EigenVector<T>::Flatten(*input);
auto out = EigenVector<T>::Flatten(*output);
auto x_norm = x.square().sum().sqrt();
auto place = context.GetEigenDevice<Place>();
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto temp = (x_norm <= max_norm).template cast<T>().eval();
auto scaling = temp + (static_cast<T>(1) - temp) * max_norm / x_norm;
......
......@@ -83,7 +83,7 @@ class ClipOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
ops::ClipOpGrad);
REGISTER_OP_CPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
clip_grad, ops::ClipGradKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -15,7 +15,7 @@
#include "paddle/operators/clip_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(clip,
ops::ClipKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(clip_grad,
ops::ClipGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_CUDA_KERNEL(
clip, ops::ClipKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
clip_grad, ops::ClipGradKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -55,7 +55,7 @@ class ClipGradFunctor {
T max_;
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ClipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -66,13 +66,13 @@ class ClipKernel : public framework::OpKernel<T> {
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int64_t numel = x->numel();
Transform<Place> trans;
trans(context.device_context(), x_data, x_data + numel, out_data,
ClipFunctor<T>(min, max));
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x_data,
x_data + numel, out_data, ClipFunctor<T>(min, max));
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ClipGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -86,9 +86,9 @@ class ClipGradKernel : public framework::OpKernel<T> {
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
const T* d_out_data = d_out->data<T>();
const T* x_data = x->data<T>();
Transform<Place> trans;
trans(context.device_context(), d_out_data, d_out_data + numel, x_data,
d_x_data, ClipGradFunctor<T>(min, max));
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), d_out_data,
d_out_data + numel, x_data, d_x_data, ClipGradFunctor<T>(min, max));
}
}
};
......
......@@ -14,10 +14,10 @@
#include "paddle/operators/compare_op.h"
REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, GPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(greater_than, GPU,
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(greater_than, CUDA,
paddle::operators::GreaterThanFunctor);
REGISTER_LOGICAL_KERNEL(greater_equal, GPU,
REGISTER_LOGICAL_KERNEL(greater_equal, CUDA,
paddle::operators::GreaterEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
......@@ -59,7 +59,7 @@ struct EqualFunctor {
}
};
template <typename Place, typename Functor>
template <typename DeviceContext, typename Functor>
class CompareOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
......@@ -69,24 +69,23 @@ class CompareOpKernel
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
Functor binary_func;
platform::Transform<Place> trans;
trans(context.device_context(), x->data<T>(), x->data<T>() + x->numel(),
y->data<T>(), out->mutable_data<bool>(context.GetPlace()),
binary_func);
platform::Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x->data<T>(),
x->data<T>() + x->numel(), y->data<T>(),
out->mutable_data<bool>(context.GetPlace()), binary_func);
}
};
} // namespace operators
} // namespace paddle
#define REGISTER_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
functor<int>>, \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
functor<int64_t>>, \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
functor<float>>, \
::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
functor<double>>);
#define REGISTER_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, ::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int>>, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int64_t>>, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, functor<float>>, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, functor<double>>);
......@@ -14,7 +14,8 @@ limitations under the License. */
#include "paddle/operators/concat_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(concat,
ops::ConcatKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
concat_grad, ops::ConcatGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_CUDA_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -43,7 +43,7 @@ class ConcatKernel : public framework::OpKernel<T> {
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ConcatGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
......
......@@ -57,18 +57,20 @@ REGISTER_OP(conv2d_cudnn, ops::ConvOp, ops::CudnnConv2DOpMaker,
REGISTER_OP(conv3d_cudnn, ops::ConvOp, ops::CudnnConv3DOpMaker,
conv3d_cudnn_grad, ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv2d_cudnn,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_cudnn,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_cudnn_grad,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(conv3d_cudnn,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_cudnn,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_cudnn_grad,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -118,7 +118,8 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
}
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo;
auto handle = ctx.cuda_device_context().cudnn_handle();
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
......@@ -238,7 +239,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
workspace_size_limit = user_workspace_size * 1024 * 1024;
}
auto handle = ctx.cuda_device_context().cudnn_handle();
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
if (input_grad) {
PADDLE_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
......@@ -313,16 +315,16 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_GPU_KERNEL(conv2d_cudnn,
paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv2d_cudnn_grad,
paddle::operators::CudnnConvGradOpKernel<float>,
paddle::operators::CudnnConvGradOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_cudnn,
paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_cudnn_grad,
paddle::operators::CudnnConvGradOpKernel<float>,
paddle::operators::CudnnConvGradOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv2d_cudnn,
paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv2d_cudnn_grad,
paddle::operators::CudnnConvGradOpKernel<float>,
paddle::operators::CudnnConvGradOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv3d_cudnn,
paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv3d_cudnn_grad,
paddle::operators::CudnnConvGradOpKernel<float>,
paddle::operators::CudnnConvGradOpKernel<double>);
......@@ -235,16 +235,18 @@ namespace ops = paddle::operators;
REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
conv2d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
conv3d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -16,16 +16,18 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_CUDA_KERNEL(
conv2d, ops::GemmConvKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
conv2d_grad,
ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_GPU_KERNEL(conv3d,
ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_CUDA_KERNEL(
conv3d, ops::GemmConvKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
conv3d_grad,
ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -72,7 +72,7 @@ class ConvOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override;
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class GemmConvKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -141,9 +141,10 @@ class GemmConvKernel : public framework::OpKernel<T> {
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<Place, T> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
math::Vol2ColFunctor<DeviceContext, T> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
auto& dev_ctx = context.template device_context<DeviceContext>();
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
......@@ -157,27 +158,26 @@ class GemmConvKernel : public framework::OpKernel<T> {
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(context.device_context(), in_slice, dilations, strides,
im2col(dev_ctx, in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(context.device_context(), in_slice, dilations, strides,
paddings, &col);
vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), filter_slice, false,
col_matrix, false, T(1.0), &out_slice, T(0.0));
math::matmul<DeviceContext, T>(dev_ctx, filter_slice, false, col_matrix,
false, T(1.0), &out_slice, T(0.0));
}
}
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class GemmConvGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -256,14 +256,15 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix.Resize(col_matrix_shape);
}
math::SetConstant<Place, T> set_zero;
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), input_grad, static_cast<T>(0));
set_zero(dev_ctx, input_grad, static_cast<T>(0));
math::Col2VolFunctor<Place, T> col2vol;
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
math::Col2VolFunctor<DeviceContext, T> col2vol;
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
......@@ -282,18 +283,17 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix.ShareDataWith(in_grad_slice);
col_matrix.Resize(col_matrix_shape);
}
math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
math::matmul<DeviceContext, T>(dev_ctx, filter_slice, true,
out_grad_slice, false, T(1.0),
&col_matrix, T(0.0));
if (is_expand && data_dim == 2U) {
col2im(context.device_context(), col, dilations, strides,
col2im(dev_ctx, col, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&in_grad_slice);
} else if (is_expand && data_dim == 3U) {
col2vol(context.device_context(), col, dilations, strides, paddings,
&in_grad_slice);
col2vol(dev_ctx, col, dilations, strides, paddings, &in_grad_slice);
}
}
}
......@@ -303,9 +303,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
filter_grad->mutable_data<T>(context.GetPlace());
Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape);
set_zero(context.device_context(), filter_grad, static_cast<T>(0));
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
math::Vol2ColFunctor<Place, T> vol2col;
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
math::Vol2ColFunctor<DeviceContext, T> vol2col;
for (int i = 0; i < batch_size; i++) {
Tensor out_grad_batch =
output_grad->Slice(i, i + 1).Resize(output_matrix_shape);
......@@ -321,21 +321,20 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
im2col(context.device_context(), in_slice, dilations, strides,
im2col(dev_ctx, in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
vol2col(context.device_context(), in_slice, dilations, strides,
paddings, &col);
vol2col(dev_ctx, in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor filter_grad_slice =
filter_grad_.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(context.device_context(), out_grad_slice,
false, col_matrix, true, T(1.0),
&filter_grad_slice, T(1.0));
math::matmul<DeviceContext, T>(dev_ctx, out_grad_slice, false,
col_matrix, true, T(1.0),
&filter_grad_slice, T(1.0));
}
}
}
......
......@@ -111,7 +111,8 @@ __global__ void ConvShiftDy(const T *x, const T *dout, int x_width, int y_width,
} // namespace
template <typename T>
class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
class ConvShiftKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
......@@ -132,7 +133,8 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
dim3 grid_dim(num_x_blocks, batch_size);
auto stream = context.cuda_device_context().stream();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
ConvShiftForward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
x_data, y_data, x_width, y_width, y_half_width, batch_size, out_data);
......@@ -140,7 +142,7 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
};
template <typename T>
class ConvShiftGradKernel<platform::GPUPlace, T>
class ConvShiftGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
......@@ -159,8 +161,9 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
int y_width = Y->dims()[1];
int y_half_width = (y_width - 1) / 2;
auto &device_ctx = context.cuda_device_context();
math::SetConstant<platform::GPUPlace, T> zero;
auto &device_ctx =
context.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
const int x_per_block = 256;
int num_x_blocks = DivUp(x_width, x_per_block);
......@@ -186,8 +189,9 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv_shift,
ops::ConvShiftKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
conv_shift,
ops::ConvShiftKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
conv_shift_grad,
ops::ConvShiftGradKernel<paddle::platform::GPUPlace, float>);
ops::ConvShiftGradKernel<paddle::platform::CUDADeviceContext, float>);
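The stream-fetch rewritten above — context.template device_context<platform::CUDADeviceContext>().stream() — is the CUDA-specific face of this refactor. A minimal standalone CUDA sketch of the same pattern, assuming nothing from Paddle (FakeCUDAContext and Scale are illustrative names only):

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for the typed device context that owns a stream.
struct FakeCUDAContext {
  cudaStream_t stream_;
  cudaStream_t stream() const { return stream_; }
};

__global__ void Scale(float* x, float a, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= a;
}

int main() {
  FakeCUDAContext ctx;
  cudaStreamCreate(&ctx.stream_);
  float h[4] = {1, 2, 3, 4};
  float* d = nullptr;
  cudaMalloc(&d, sizeof(h));
  cudaMemcpyAsync(d, h, sizeof(h), cudaMemcpyHostToDevice, ctx.stream());
  // Launch on the context's stream, exactly as the kernels above do.
  Scale<<<1, 4, 0, ctx.stream()>>>(d, 2.f, 4);
  cudaMemcpyAsync(h, d, sizeof(h), cudaMemcpyDeviceToHost, ctx.stream());
  cudaStreamSynchronize(ctx.stream());
  std::printf("%g %g\n", h[0], h[3]);  // 2 8
  cudaFree(d);
  cudaStreamDestroy(ctx.stream_);
}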
......@@ -18,13 +18,13 @@
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ConvShiftKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override;
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ConvShiftGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override;
......
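The header change above is the heart of the whole PR: kernels are templated on the concrete DeviceContext type instead of on Place. A minimal standalone sketch of why that composes, with mock types standing in for Paddle's (none of these names are real Paddle APIs):

#include <iostream>

// Stand-ins for platform::CPUDeviceContext / platform::CUDADeviceContext.
struct CPUDeviceContext {
  const char* name() const { return "CPU"; }
};
struct CUDADeviceContext {
  const char* name() const { return "CUDA"; }
};

// Mock of ExecutionContext::device_context<DeviceContextType>().
class ExecutionContext {
 public:
  explicit ExecutionContext(const void* dev_ctx) : dev_ctx_(dev_ctx) {}
  template <typename DeviceContextType>
  const DeviceContextType& device_context() const {
    return *static_cast<const DeviceContextType*>(dev_ctx_);
  }

 private:
  const void* dev_ctx_;
};

// One kernel body, written once against the DeviceContext parameter.
template <typename DeviceContext>
class DemoKernel {
 public:
  void Compute(const ExecutionContext& ctx) const {
    // Mirrors ctx.template device_context<DeviceContext>() in the diff.
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    std::cout << "running on " << dev_ctx.name() << "\n";
  }
};

int main() {
  CPUDeviceContext cpu;
  ExecutionContext ctx(&cpu);
  DemoKernel<CPUDeviceContext>().Compute(ctx);  // prints: running on CPU
}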
......@@ -61,12 +61,13 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
......@@ -74,9 +75,10 @@ REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
......@@ -83,7 +83,8 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
}
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionBwdDataAlgo_t algo;
auto handle = ctx.cuda_device_context().cudnn_handle();
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
// Get the algorithm
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
......@@ -165,7 +166,8 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
workspace_size_limit = user_workspace_size * 1024 * 1024;
}
auto handle = ctx.cuda_device_context().cudnn_handle();
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
if (input_grad) {
// choose backward algorithm for data
PADDLE_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
......@@ -234,16 +236,16 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv3d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv3d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);
......@@ -197,21 +197,23 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
REGISTER_OP_CPU_KERNEL(
conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
conv3d_transpose_grad, ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
......@@ -16,20 +16,24 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
conv2d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
conv2d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
double>);
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
conv3d_transpose,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
conv3d_transpose_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
double>);
......@@ -52,7 +52,7 @@ class ConvTransposeOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override;
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class GemmConvTransposeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -109,11 +109,12 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
filter.Resize(filter_matrix_shape);
output->mutable_data<T>(context.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(context.device_context(), output, static_cast<T>(0));
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, output, static_cast<T>(0));
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
math::Col2VolFunctor<Place, T> col2vol;
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
math::Col2VolFunctor<DeviceContext, T> col2vol;
std::vector<int> dilations({1, 1, 1});
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward
......@@ -127,29 +128,27 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// col_matrix = filter * input_batch
// of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
math::matmul<Place, T>(context.device_context(), filter, true,
input_batch, false, static_cast<T>(1.0),
&col_matrix, static_cast<T>(0.0));
math::matmul<DeviceContext, T>(dev_ctx, filter, true, input_batch, false,
static_cast<T>(1.0), &col_matrix,
static_cast<T>(0.0));
if (data_dim == 2U) {
// col2im: col_matrix -> dy
// from (c * k_h * k_w, h * w) to (c, o_h, o_w)
col2im(context.device_context(), col,
std::vector<int>{dilations[0], dilations[1]}, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
col2im(dev_ctx, col, std::vector<int>{dilations[0], dilations[1]},
strides, std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&output_batch);
} else if (data_dim == 3U) {
// col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
col2vol(context.device_context(), col, dilations, strides, paddings,
&output_batch);
col2vol(dev_ctx, col, dilations, strides, paddings, &output_batch);
}
}
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -206,6 +205,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// convolution transpose grad on input:
// im2col + gemm (similar to conv-forward)
// input need to compute gradient
auto& dev_ctx = context.template device_context<DeviceContext>();
if (input_grad || filter_grad) {
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
......@@ -217,19 +217,19 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
col_matrix.Resize(col_matrix_shape);
Tensor filter_grad_;
math::SetConstant<Place, T> set_zero;
math::SetConstant<DeviceContext, T> set_zero;
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
math::Vol2ColFunctor<Place, T> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
math::Vol2ColFunctor<DeviceContext, T> vol2col;
std::vector<int> dilations({1, 1, 1});
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), input_grad, static_cast<T>(0));
set_zero(dev_ctx, input_grad, static_cast<T>(0));
}
if (filter_grad) { // filter size (m, c, k_h, k_w)
filter_grad->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), filter_grad, static_cast<T>(0));
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape);
}
......@@ -242,7 +242,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
if (data_dim == 2U) {
// im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
im2col(context.device_context(), output_grad_batch,
im2col(dev_ctx, output_grad_batch,
std::vector<int>{dilations[0], dilations[1]}, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
......@@ -250,8 +250,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
} else if (data_dim == 3U) {
// vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
vol2col(context.device_context(), output_grad_batch, dilations,
strides, paddings, &col);
vol2col(dev_ctx, output_grad_batch, dilations, strides, paddings,
&col);
}
if (input_grad) {
......@@ -263,9 +263,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
// d, h, w)
math::matmul<Place, T>(context.device_context(), filter, false,
col_matrix, false, static_cast<T>(1.0),
&input_grad_batch, static_cast<T>(0.0));
math::matmul<DeviceContext, T>(
dev_ctx, filter, false, col_matrix, false, static_cast<T>(1.0),
&input_grad_batch, static_cast<T>(0.0));
}
if (filter_grad) {
// input batch
......@@ -275,9 +275,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
// k_h * k_w)
math::matmul<Place, T>(context.device_context(), in_batch, false,
col_matrix, true, static_cast<T>(1.0),
&filter_grad_, static_cast<T>(1.0));
math::matmul<DeviceContext, T>(dev_ctx, in_batch, false, col_matrix,
true, static_cast<T>(1.0),
&filter_grad_, static_cast<T>(1.0));
}
}
}
......
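As a quick shape check on the gemm + col2im path in GemmConvTransposeKernel above, with arbitrarily chosen sizes (plain arithmetic, nothing Paddle-specific):

#include <iostream>

int main() {
  // Assumed sizes, per the comments in the kernel above:
  // input (m, h, w), filter viewed as (m, c * k_h * k_w), output (c, o_h, o_w).
  int m = 4, h = 5, w = 5;      // input channels and spatial size
  int c = 8, k_h = 3, k_w = 3;  // output channels and kernel size
  // gemm: filter^T (c*k_h*k_w, m) * input_batch (m, h*w)
  //   -> col_matrix (c*k_h*k_w, h*w)
  std::cout << "col_matrix: " << c * k_h * k_w << " x " << h * w << "\n";
  // col2im then scatters those 72 x 25 entries into the (c, o_h, o_w) output.
}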
......@@ -155,7 +155,8 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(cos_sim, ops::CosSimOp, ops::CosSimOpMaker, cos_sim_grad,
ops::CosSimOpGrad);
REGISTER_OP_CPU_KERNEL(cos_sim,
ops::CosSimKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
cos_sim_grad, ops::CosSimGradKernel<paddle::platform::CPUPlace, float>);
cos_sim, ops::CosSimKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
cos_sim_grad,
ops::CosSimGradKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -16,7 +16,8 @@
#include "paddle/operators/cos_sim_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(cos_sim,
ops::CosSimKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
cos_sim_grad, ops::CosSimGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_CUDA_KERNEL(
cos_sim, ops::CosSimKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
cos_sim_grad,
ops::CosSimGradKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -27,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class CosSimKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -51,7 +51,8 @@ class CosSimKernel : public framework::OpKernel<T> {
auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
// compute
auto place = context.GetEigenDevice<Place>();
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto row_along = Eigen::array<int, 1>({{1}});
x_norm.device(place) = x.square().sum(row_along).sqrt();
y_norm.device(place) = y.square().sum(row_along).sqrt();
......@@ -66,7 +67,7 @@ class CosSimKernel : public framework::OpKernel<T> {
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class CosSimGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -96,7 +97,8 @@ class CosSimGradKernel : public framework::OpKernel<T> {
auto z_bcast = z.broadcast(bcast_cols);
auto dz_bcast = dz.broadcast(bcast_cols);
auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols);
auto place = context.GetEigenDevice<Place>();
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
if (rows_x == rows_y) {
auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols);
auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols);
......
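With GetEigenDevice removed, kernels now pull the Eigen device out of the typed context, as the cos_sim kernels above do. A standalone sketch of that idiom, assuming Eigen's unsupported Tensor module; FakeCPUDeviceContext is a stand-in, not a Paddle type:

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

// Stand-in for a device context owning an Eigen device.
struct FakeCPUDeviceContext {
  Eigen::DefaultDevice* eigen_device() const { return &device_; }
  mutable Eigen::DefaultDevice device_;
};

int main() {
  FakeCPUDeviceContext dev_ctx;
  Eigen::Tensor<float, 1> x(3), y(3);
  x.setValues({1.f, 2.f, 3.f});
  // New idiom: auto& place = *ctx.template device_context<...>().eigen_device();
  auto& place = *dev_ctx.eigen_device();
  y.device(place) = x * x;  // expression evaluated on the chosen device
  std::cout << y(0) << ' ' << y(2) << '\n';  // 1 9
}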
......@@ -135,5 +135,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(crf_decoding, ops::CRFDecodingOp,
ops::CRFDecodingOpMaker);
REGISTER_OP_CPU_KERNEL(
crf_decoding, ops::CRFDecodingOpKernel<paddle::platform::CPUPlace, float>,
ops::CRFDecodingOpKernel<paddle::platform::CPUPlace, double>);
crf_decoding,
ops::CRFDecodingOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CRFDecodingOpKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -24,7 +24,7 @@ using framework::LoDTensor;
using framework::LoD;
using framework::Tensor;
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class CRFDecodingOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -44,8 +44,8 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
const size_t seq_num = lod[level].size() - 1;
int64_t* path = decoded_path->mutable_data<int64_t>(platform::CPUPlace());
math::SetConstant<platform::CPUPlace, int64_t>()(ctx.device_context(),
decoded_path, 0);
math::SetConstant<DeviceContext, int64_t>()(
ctx.template device_context<DeviceContext>(), decoded_path, 0);
for (size_t i = 0; i < seq_num; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
......
......@@ -133,5 +133,5 @@ class CropOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(crop, ops::CropOp, ops::CropOpMaker, crop_grad, ops::CropOpGrad);
REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_CPU_KERNEL(crop_grad,
ops::CropGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
crop_grad, ops::CropGradKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -16,6 +16,6 @@
#include "paddle/operators/crop_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_GPU_KERNEL(crop_grad,
ops::CropGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_CUDA_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_CUDA_KERNEL(
crop_grad, ops::CropGradKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -49,7 +49,7 @@ class CropKernel : public framework::OpKernel<T> {
}
};
template <typename Place, typename T, size_t D>
template <typename DeviceContext, typename T, size_t D>
void CropGradFunction(const framework::ExecutionContext& context) {
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
......@@ -63,12 +63,13 @@ void CropGradFunction(const framework::ExecutionContext& context) {
}
auto d_x_tensor = EigenTensor<T, D>::From(*d_x);
auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
d_x_tensor.device(context.GetEigenDevice<Place>()) =
d_x_tensor.device(
*context.template device_context<DeviceContext>().eigen_device()) =
d_out_tensor.pad(paddings, 0);
}
}
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class CropGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -76,22 +77,22 @@ class CropGradKernel : public framework::OpKernel<T> {
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
switch (rank) {
case 1:
CropGradFunction<Place, T, 1>(context);
CropGradFunction<DeviceContext, T, 1>(context);
break;
case 2:
CropGradFunction<Place, T, 2>(context);
CropGradFunction<DeviceContext, T, 2>(context);
break;
case 3:
CropGradFunction<Place, T, 3>(context);
CropGradFunction<DeviceContext, T, 3>(context);
break;
case 4:
CropGradFunction<Place, T, 4>(context);
CropGradFunction<DeviceContext, T, 4>(context);
break;
case 5:
CropGradFunction<Place, T, 5>(context);
CropGradFunction<DeviceContext, T, 5>(context);
break;
case 6:
CropGradFunction<Place, T, 6>(context);
CropGradFunction<DeviceContext, T, 6>(context);
break;
default:
PADDLE_THROW(
......
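The switch in CropGradKernel above is the usual bridge from a run-time rank to a compile-time one: each case instantiates the function template at a fixed D. A stripped-down sketch with illustrative names and no Paddle dependencies:

#include <cstddef>
#include <iostream>
#include <stdexcept>

// Rank D is a compile-time constant inside the body, so Eigen-style
// fixed-rank tensors could be used here.
template <typename DeviceContext, typename T, size_t D>
void CropGradFunction() {
  std::cout << "instantiated with rank " << D << "\n";
}

// The run-time rank is mapped onto the template parameter by the switch.
template <typename DeviceContext, typename T>
void CropGrad(int rank) {
  switch (rank) {
    case 1: CropGradFunction<DeviceContext, T, 1>(); break;
    case 2: CropGradFunction<DeviceContext, T, 2>(); break;
    case 3: CropGradFunction<DeviceContext, T, 3>(); break;
    default: throw std::runtime_error("unsupported rank");
  }
}

struct FakeDeviceContext {};

int main() { CropGrad<FakeDeviceContext, float>(2); }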
......@@ -53,8 +53,9 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
Tensor* y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace());
math::CrossEntropyFunctor<platform::GPUPlace, T>()(
ctx.device_context(), y, x, label, ctx.Attr<bool>("soft_label"));
math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
ctx.template device_context<platform::CUDADeviceContext>(), y, x, label,
ctx.Attr<bool>("soft_label"));
}
};
......@@ -80,15 +81,17 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
int block = 512;
int grid = (batch_size * class_num + block - 1) / block;
auto stream = ctx.cuda_device_context().stream();
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto stream = dev_ctx.stream();
if (ctx.Attr<bool>("soft_label")) {
auto* label_data = label->data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
dx_data, dy_data, x_data, label_data, batch_size, class_num);
} else {
math::SetConstant<platform::GPUPlace, T> functor;
functor(ctx.device_context(), dx, 0);
math::SetConstant<platform::CUDADeviceContext, T> functor;
functor(dev_ctx, dx, 0);
auto* label_data = label->data<int64_t>();
grid = (batch_size + block - 1) / block;
CrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
......@@ -101,8 +104,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>,
ops::CrossEntropyOpCUDAKernel<double>);
REGISTER_OP_GPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpCUDAKernel<float>,
ops::CrossEntropyGradientOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>,
ops::CrossEntropyOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpCUDAKernel<float>,
ops::CrossEntropyGradientOpCUDAKernel<double>);
......@@ -37,8 +37,9 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
Tensor* y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace());
math::CrossEntropyFunctor<platform::CPUPlace, T>()(
ctx.device_context(), y, x, labels, ctx.Attr<bool>("soft_label"));
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
ctx.template device_context<platform::CPUDeviceContext>(), y, x, labels,
ctx.Attr<bool>("soft_label"));
}
};
......@@ -61,7 +62,8 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
auto lbl_mat = EigenMatrix<T>::From(*label);
auto dx_mat = EigenMatrix<T>::From(*dx);
dx_mat.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
dx_mat.device(*ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device()) =
-(lbl_mat *
dy_mat.broadcast(Eigen::DSizes<int64_t, 2>(1, class_num)) / x_mat);
} else {
......@@ -70,8 +72,8 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
const T* x_data = x->data<T>();
const int64_t* label_data = label->data<int64_t>();
math::SetConstant<platform::CPUPlace, T> functor;
functor(ctx.device_context(), dx, 0);
math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(ctx.template device_context<platform::CPUDeviceContext>(), dx, 0);
for (int64_t i = 0; i < batch_size; ++i) {
PADDLE_ASSERT(label_data[i] >= 0 && label_data[i] < class_num);
......
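In the hard-label branch above, SetConstant zeroes dx and only the label column of each row is then filled in. A plain C++ sketch with made-up numbers, assuming the usual hard-label gradient dx[i][label[i]] = -dy[i] / x[i][label[i]] (the loop body itself is elided from this hunk):

#include <iostream>
#include <vector>

int main() {
  const int batch_size = 2, class_num = 3;
  std::vector<float> x = {0.2f, 0.5f, 0.3f, 0.1f, 0.1f, 0.8f};  // softmax rows
  std::vector<float> dy = {1.f, 1.f};
  std::vector<long> label = {1, 2};
  std::vector<float> dx(batch_size * class_num, 0.f);  // SetConstant(..., 0)
  for (int i = 0; i < batch_size; ++i) {
    int idx = i * class_num + static_cast<int>(label[i]);
    dx[idx] = -dy[i] / x[idx];  // d/dx of -log(x[label]), scaled by dy
  }
  std::cout << dx[1] << ' ' << dx[5] << '\n';  // -2 -1.25
}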
......@@ -99,4 +99,4 @@ REGISTER_OP_WITHOUT_GRADIENT(decayed_adagrad, ops::DecayedAdagradOp,
ops::DecayedAdagradOpMaker);
REGISTER_OP_CPU_KERNEL(
decayed_adagrad,
ops::DecayedAdagradOpKernel<paddle::platform::CPUPlace, float>);
ops::DecayedAdagradOpKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -16,6 +16,6 @@
#include "paddle/operators/decayed_adagrad_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
decayed_adagrad,
ops::DecayedAdagradOpKernel<paddle::platform::GPUPlace, float>);
ops::DecayedAdagradOpKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class DecayedAdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -43,7 +43,7 @@ class DecayedAdagradOpKernel : public framework::OpKernel<T> {
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto place = ctx.GetEigenDevice<Place>();
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
moment_out.device(place) = decay * moment + (1 - decay) * grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
......
......@@ -100,6 +100,8 @@ namespace ops = paddle::operators;
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
ops::DropoutOpGrad<float>);
REGISTER_OP_CPU_KERNEL(
dropout, ops::CPUDropoutKernel<paddle::platform::CPUPlace, float, float>);
dropout,
ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float, float>);
REGISTER_OP_CPU_KERNEL(
dropout_grad, ops::DropoutGradKernel<paddle::platform::CPUPlace, float>);
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -58,7 +58,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto place = context.GetEigenDevice<Place>();
auto& place = *context.template device_context<Place>().eigen_device();
if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
......@@ -80,7 +80,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float, float>);
REGISTER_OP_GPU_KERNEL(
dropout_grad, ops::DropoutGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_CUDA_KERNEL(
dropout,
ops::GPUDropoutKernel<paddle::platform::CUDADeviceContext, float, float>);
REGISTER_OP_CUDA_KERNEL(
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T, typename AttrType>
template <typename DeviceContext, typename T, typename AttrType>
class CPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -55,13 +55,14 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
} else {
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto place = context.GetEigenDevice<Place>();
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Y.device(place) = X * dropout_prob;
}
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class DropoutGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -77,7 +78,8 @@ class DropoutGradKernel : public framework::OpKernel<T> {
auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);
auto place = context.GetEigenDevice<Place>();
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
dX.device(place) = dY * M;
}
};
......
......@@ -34,13 +34,13 @@ REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker,
elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, double>,
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, int>,
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, int64_t>);
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, double>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, int>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, int64_t>);
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -17,15 +17,16 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, double>,
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, int>,
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, int64_t>);
REGISTER_OP_GPU_KERNEL(
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, double>,
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, int>,
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, int64_t>);
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
......@@ -24,7 +24,7 @@ struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ElementwiseAddKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -34,8 +34,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
TransformFunctor<AddFunctor<T>, T, Place> functor(
x, y, z, ctx.device_context(), AddFunctor<T>());
TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
x, y, z, ctx.template device_context<DeviceContext>(), AddFunctor<T>());
auto x_dims = x->dims();
auto y_dims = y->dims();
......@@ -137,11 +137,11 @@ struct ElementwiseAddBroadCast2GradFunctor {
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ElementwiseAddGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseAddGradFunctor<T>,
ElementwiseGradCompute<DeviceContext, T, ElementwiseAddGradFunctor<T>,
ElementwiseAddOneGradFunctor<T>,
ElementwiseAddBroadCastGradFunctor<T>,
ElementwiseAddBroadCast2GradFunctor<T>>(ctx);
......
......@@ -35,13 +35,13 @@ REGISTER_OP(elementwise_div, ops::ElementwiseOp, ops::ElementwiseDivOpMaker,
elementwise_div_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, double>,
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, int>,
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, int64_t>);
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, int64_t>);
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -17,15 +17,16 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, double>,
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, int>,
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, int64_t>);
REGISTER_OP_GPU_KERNEL(
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, double>,
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, int>,
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, int64_t>);
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
......@@ -19,11 +19,11 @@
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ElementwiseDivKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenDivFunctor, Place, T>(ctx);
ElementwiseCompute<EigenDivFunctor, DeviceContext, T>(ctx);
}
};
......@@ -102,11 +102,11 @@ struct ElementwiseDivBroadCast2GradFunctor {
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ElementwiseDivGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseDivGradFunctor<T>,
ElementwiseGradCompute<DeviceContext, T, ElementwiseDivGradFunctor<T>,
ElementwiseDivGradFunctor<T>,
ElementwiseDivBroadCastGradFunctor<T>,
ElementwiseDivBroadCast2GradFunctor<T>>(ctx);
......
......@@ -36,13 +36,13 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker,
elementwise_mul_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, int>,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, int64_t>);
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, int64_t>);
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -17,15 +17,16 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, double>,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, int>,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, int64_t>);
REGISTER_OP_GPU_KERNEL(
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, double>,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, int>,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, int64_t>);
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
......@@ -18,11 +18,11 @@
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ElementwiseMulKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenMulFunctor, Place, T>(ctx);
ElementwiseCompute<EigenMulFunctor, DeviceContext, T>(ctx);
}
};
......@@ -101,11 +101,11 @@ struct ElementwiseMulBroadCast2GradFunctor {
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ElementwiseMulGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseMulGradFunctor<T>,
ElementwiseGradCompute<DeviceContext, T, ElementwiseMulGradFunctor<T>,
ElementwiseMulGradFunctor<T>,
ElementwiseMulBroadCastGradFunctor<T>,
ElementwiseMulBroadCast2GradFunctor<T>>(ctx);
......
......@@ -59,17 +59,17 @@ inline void get_mid_dims(const framework::DDim& x_dims,
}
}
template <typename T, typename Place>
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
template <typename T, typename Place>
template <typename T, typename DeviceContext>
class MidWiseTransformIterator;
template <typename T>
class RowwiseTransformIterator<T, platform::CPUPlace> {
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
public:
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() {
RowwiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
++i_;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
......@@ -77,13 +77,13 @@ class RowwiseTransformIterator<T, platform::CPUPlace> {
return *this;
}
bool operator==(
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
rhs) const {
return (ptr_ + i_) != &(*rhs);
}
......@@ -96,12 +96,12 @@ class RowwiseTransformIterator<T, platform::CPUPlace> {
};
template <typename T>
class MidWiseTransformIterator<T, platform::CPUPlace> {
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
public:
MidWiseTransformIterator(const T* ptr, int n, int post)
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
MidWiseTransformIterator<T, platform::CPUPlace>& operator++() {
MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
++j_;
i_ = j_ / post_;
if (UNLIKELY(i_ == n_)) {
......@@ -111,13 +111,13 @@ class MidWiseTransformIterator<T, platform::CPUPlace> {
return *this;
}
bool operator==(
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
rhs) const {
return (ptr_ + i_) != &(*rhs);
}
......@@ -133,12 +133,12 @@ class MidWiseTransformIterator<T, platform::CPUPlace> {
#ifdef __NVCC__
template <typename T>
class RowwiseTransformIterator<T, platform::GPUPlace>
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
: public thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> {
RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
public:
typedef thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::GPUPlace>, const T*>
RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
super_t;
HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
: super_t(x), begin_(x), n_(n){};
......@@ -153,12 +153,12 @@ class RowwiseTransformIterator<T, platform::GPUPlace>
};
template <typename T>
class MidWiseTransformIterator<T, platform::GPUPlace>
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
: public thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> {
MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
public:
typedef thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::GPUPlace>, const T*>
MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
super_t;
HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
: super_t(x), begin_(x), n_(n), post_(post){};
......@@ -174,12 +174,11 @@ class MidWiseTransformIterator<T, platform::GPUPlace>
};
#endif
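These iterators let the smaller tensor y be read cyclically while Transform walks the larger x, so row broadcasting never materializes an expanded copy of y. A self-contained sketch of the row-wise CPU case (simplified; the real class also carries the comparison operators shown above):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <vector>

template <typename T>
class CyclicRowIterator {
 public:
  using iterator_category = std::input_iterator_tag;
  using value_type = T;
  using difference_type = std::ptrdiff_t;
  using pointer = const T*;
  using reference = const T&;

  CyclicRowIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
  CyclicRowIterator& operator++() {
    if (++i_ == n_) i_ = 0;  // wrap: y restarts at each new row of x
    return *this;
  }
  const T& operator*() const { return ptr_[i_]; }

 private:
  const T* ptr_;
  int i_, n_;
};

int main() {
  std::vector<float> x = {1, 2, 3, 10, 20, 30};  // shape (2, 3): pre=2, n=3
  std::vector<float> y = {100, 200, 300};        // shape (3,)
  std::vector<float> z(x.size());
  std::transform(x.begin(), x.end(), CyclicRowIterator<float>(y.data(), 3),
                 z.begin(), [](float a, float b) { return a + b; });
  for (float v : z) std::cout << v << ' ';  // 101 202 303 110 220 330
  std::cout << '\n';
}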
template <typename Functor, typename T, typename Place>
template <typename Functor, typename T, typename DeviceContext>
class TransformFunctor {
public:
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z, const platform::DeviceContext& ctx,
Functor func)
framework::Tensor* z, const DeviceContext& ctx, Functor func)
: x_(x->data<T>()),
y_(y->data<T>()),
z_(z->mutable_data<T>(ctx.GetPlace())),
......@@ -188,20 +187,20 @@ class TransformFunctor {
func_(func) {}
inline void Run() const {
platform::Transform<Place> trans;
platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_, y_, z_, func_);
}
inline void RunRowWise(int n, int pre) const {
platform::Transform<Place> trans;
trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, Place>(y_, n), z_,
func_);
platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
z_, func_);
}
inline void RunMidWise(int n, int pre, int post) const {
platform::Transform<Place> trans;
trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator<T, Place>(y_, n, post),
z_, func_);
platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_,
MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
}
private:
......@@ -209,22 +208,24 @@ class TransformFunctor {
const T* y_;
T* z_;
int64_t nx_;
const platform::DeviceContext& ctx_;
const DeviceContext& ctx_;
Functor func_;
};
#define EIGEN_FUNCTOR(name, eigen_op) \
struct Eigen##name##Functor { \
template <typename Place, typename T> \
template <typename DeviceContext, typename T> \
inline void Run(const framework::Tensor* x, const framework::Tensor* y, \
framework::Tensor* z, \
const framework::ExecutionContext& ctx) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
auto z_e = framework::EigenVector<T>::Flatten(*z); \
z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_e); \
z_e.device( \
*ctx.template device_context<DeviceContext>().eigen_device()) = \
eigen_op(x_e, y_e); \
} \
template <typename Place, typename T> \
template <typename DeviceContext, typename T> \
inline void RunBroadCast(const framework::Tensor* x, \
const framework::Tensor* y, framework::Tensor* z, \
const framework::ExecutionContext& ctx, int pre, \
......@@ -235,9 +236,11 @@ class TransformFunctor {
auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n)) \
.broadcast(Eigen::DSizes<int, 2>(pre, 1)) \
.reshape(Eigen::DSizes<int, 1>(x_e.size())); \
z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_bcast); \
z_e.device( \
*ctx.template device_context<DeviceContext>().eigen_device()) = \
eigen_op(x_e, y_bcast); \
} \
template <typename Place, typename T> \
template <typename DeviceContext, typename T> \
inline void RunBroadCast2(const framework::Tensor* x, \
const framework::Tensor* y, \
framework::Tensor* z, \
......@@ -249,11 +252,13 @@ class TransformFunctor {
auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1)) \
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post)) \
.reshape(Eigen::DSizes<int, 1>(x_e.size())); \
z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_bcast); \
z_e.device( \
*ctx.template device_context<DeviceContext>().eigen_device()) = \
eigen_op(x_e, y_bcast); \
} \
}
template <class functor, typename Place, typename T>
template <class functor, typename DeviceContext, typename T>
void ElementwiseCompute(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor;
......@@ -269,7 +274,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) {
if (x_dims == y_dims) {
functor f;
f.template Run<Place, T>(x, y, z, ctx);
f.template Run<DeviceContext, T>(x, y, z, ctx);
return;
}
......@@ -282,11 +287,11 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) {
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
if (post == 1) {
functor f;
f.template RunBroadCast<Place, T>(x, y, z, ctx, pre, n);
f.template RunBroadCast<DeviceContext, T>(x, y, z, ctx, pre, n);
return;
} else {
functor f;
f.template RunBroadCast2<Place, T>(x, y, z, ctx, pre, n, post);
f.template RunBroadCast2<DeviceContext, T>(x, y, z, ctx, pre, n, post);
return;
}
}
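ElementwiseCompute reduces every broadcast to a (pre, n, post) triple: pre multiplies the x dimensions before the matched axis, n is y's extent along it, and post multiplies the dimensions after it. A plain sketch of that arithmetic for x of shape (2, 3, 4) and y of shape (3,) at axis = 1 (shape validation omitted):

#include <iostream>

int main() {
  int x_dims[] = {2, 3, 4};  // x shape
  int rank = 3, axis = 1, y_len = 3;
  int pre = 1, n = y_len, post = 1;
  for (int i = 0; i < axis; ++i) pre *= x_dims[i];
  for (int i = axis + 1; i < rank; ++i) post *= x_dims[i];
  // post == 1 would select RunBroadCast/RunRowWise; here post == 4,
  // so the RunBroadCast2/RunMidWise path is taken.
  std::cout << "pre=" << pre << " n=" << n << " post=" << post << "\n";
}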
......@@ -303,8 +308,9 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL);
#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);
template <typename Place, typename T, typename functor, typename functor1,
typename broadcastfunctor, typename broadcast2functor>
template <typename DeviceContext, typename T, typename functor,
typename functor1, typename broadcastfunctor,
typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor;
......@@ -313,7 +319,7 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto place = ctx.GetEigenDevice<Place>();
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto x_dims = x->dims();
auto y_dims = y->dims();
......
......@@ -34,13 +34,13 @@ REGISTER_OP(elementwise_sub, ops::ElementwiseOp, ops::ElementwiseSubOpMaker,
elementwise_sub_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, double>,
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, int>,
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, int64_t>);
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, double>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, int64_t>);
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -17,15 +17,16 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, double>,
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, int>,
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, int64_t>);
REGISTER_OP_GPU_KERNEL(
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, double>,
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, int>,
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, int64_t>);
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
......@@ -18,11 +18,11 @@
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ElementwiseSubKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenSubFunctor, Place, T>(ctx);
ElementwiseCompute<EigenSubFunctor, DeviceContext, T>(ctx);
}
};
......@@ -101,11 +101,11 @@ struct ElementwiseSubBroadCast2GradFunctor {
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ElementwiseSubGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseSubGradFunctor<T>,
ElementwiseGradCompute<DeviceContext, T, ElementwiseSubGradFunctor<T>,
ElementwiseSubOneGradFunctor<T>,
ElementwiseSubBroadCastGradFunctor<T>,
ElementwiseSubBroadCast2GradFunctor<T>>(ctx);
......
......@@ -130,7 +130,8 @@ class ExpandGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(expand, ops::ExpandOp, ops::ExpandOpMaker, expand_grad,
ops::ExpandGradOp);
REGISTER_OP_CPU_KERNEL(expand,
ops::ExpandKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
expand_grad, ops::ExpandGradKernel<paddle::platform::CPUPlace, float>);
expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
expand_grad,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -17,7 +17,8 @@
#include "paddle/operators/expand_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(expand,
ops::ExpandKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
expand_grad, ops::ExpandGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_CUDA_KERNEL(
expand, ops::ExpandKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
expand_grad,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -56,7 +56,7 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ExpandKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -83,12 +83,13 @@ class ExpandKernel : public framework::OpKernel<T> {
auto x = EigenTensor<T, Rank>::From(*in0);
out0->mutable_data<T>(context.GetPlace());
auto y = EigenTensor<T, Rank>::From(*out0);
auto place = context.GetEigenDevice<Place>();
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
y.device(place) = x.broadcast(bcast_dims);
}
};
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class ExpandGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -164,7 +165,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(context.GetEigenDevice<Place>()) =
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions());
}
};
......
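The Expand kernels above are a broadcast/sum-reduce pair. A standalone Eigen sketch of that round trip, assuming the unsupported Tensor module (the real grad kernel also reshapes around the reduction, omitted here):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> x(1, 3);
  x.setValues({{1.f, 2.f, 3.f}});
  // Forward: broadcast factors per dimension, like the expand_times attr.
  Eigen::array<Eigen::Index, 2> bcast{{2, 1}};
  Eigen::Tensor<float, 2> y = x.broadcast(bcast);  // shape (2, 3)
  // Backward: summing dy over the expanded dimension recovers x's shape.
  Eigen::array<Eigen::Index, 1> reduce_dims{{0}};
  Eigen::Tensor<float, 1> dx = y.sum(reduce_dims);
  std::cout << dx(0) << ' ' << dx(1) << ' ' << dx(2) << '\n';  // 2 4 6
}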
......@@ -100,8 +100,11 @@ REGISTER_OPERATOR(fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, double>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace, int>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUPlace,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
double>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
int>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
int64_t>);
......@@ -16,10 +16,13 @@
#include "paddle/framework/op_registry.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, double>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace, int>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::GPUPlace,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
float>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
double>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
int>,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CUDADeviceContext,
int64_t>);
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -27,8 +27,9 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
out->mutable_data<T>(ctx.GetPlace());
auto value = ctx.Attr<float>("value");
math::SetConstant<Place, T> setter;
setter(ctx.device_context(), out, static_cast<T>(value));
math::SetConstant<DeviceContext, T> setter;
setter(ctx.template device_context<DeviceContext>(), out,
static_cast<T>(value));
}
};
......
......@@ -54,8 +54,9 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp,
ops::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_zeros_like, ops::FillZerosLikeKernel<paddle::platform::CPUPlace, int>,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, bool>);
fill_zeros_like,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
......@@ -16,9 +16,10 @@
#include "paddle/framework/op_registry.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
fill_zeros_like, ops::FillZerosLikeKernel<paddle::platform::GPUPlace, int>,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, float>,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, double>,
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, bool>);
REGISTER_OP_CUDA_KERNEL(
fill_zeros_like,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
......@@ -19,15 +19,16 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class FillZerosLikeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<framework::Tensor>("Y");
out->mutable_data<T>(context.GetPlace());
math::SetConstant<Place, T> setter;
setter(context.device_context(), out, static_cast<T>(0));
math::SetConstant<DeviceContext, T> setter;
setter(context.template device_context<DeviceContext>(), out,
static_cast<T>(0));
}
};
......
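The math functors called here (math::SetConstant) follow the same scheme: the concrete DeviceContext is a template parameter, and the typed context is passed as the first runtime argument. A simplified, hypothetical sketch of what such a functor looks like on the implementation side — not the actual SetConstant from the math library — where the context type both selects the specialization and supplies the Eigen device:

template <typename DeviceContext, typename T>
struct SetConstantSketch {
  void operator()(const DeviceContext& ctx, framework::Tensor* tensor,
                  T value) {
    auto t = framework::EigenVector<T>::Flatten(*tensor);
    // The typed context exposes eigen_device() directly.
    t.device(*ctx.eigen_device()) = t.constant(value);
  }
};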
......@@ -135,5 +135,5 @@ The paper that proposed Follow The Regularized Leader (FTRL):
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ftrl, ops::FTRLOp, ops::FTRLOpMaker);
REGISTER_OP_CPU_KERNEL(ftrl,
ops::FTRLOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
ftrl, ops::FTRLOpKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -15,5 +15,5 @@ specific language governing permissions and limitations under the License. */
#include "paddle/operators/ftrl_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(ftrl,
ops::FTRLOpKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_CUDA_KERNEL(
ftrl, ops::FTRLOpKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -24,7 +24,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class FTRLOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
......@@ -53,7 +53,7 @@ class FTRLOpKernel : public framework::OpKernel<T> {
auto p_out = EigenVector<T>::Flatten(*param_out);
auto s_acc_out = EigenVector<T>::Flatten(*sq_accum_out);
auto l_acc_out = EigenVector<T>::Flatten(*lin_accum_out);
auto place = ctx.GetEigenDevice<Place>();
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace operators {
using framework::Tensor;
using platform::Place;
using platform::DeviceContext;
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
......
......@@ -49,7 +49,8 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto place = ctx.GetEigenDevice<platform::GPUPlace>();
auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
......@@ -60,5 +61,5 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gather, ops::GatherOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>);
......@@ -53,7 +53,8 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto place = ctx.GetEigenDevice<platform::CPUPlace>();
auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
......
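Device-specific kernels such as GatherOpCUDAKernel and GatherGradOpCUDAKernel have no DeviceContext template parameter, so they name the concrete context type themselves. A sketch of that access pattern, reusing the variable names from the hunk above:

// Fetch the typed context once, then reuse it for the Eigen device.
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto& place = *dev_ctx.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));  // zero the gradient
// Helpers such as GPUScatterAssign continue to take ctx.device_context().
GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);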