提交 3a5693e0 编写于 作者: Y Yu Yang

Add Skeleton of Double support

上级 d2edbe57
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <typeindex>
#include "paddle/framework/framework.pb.h"
namespace paddle {
namespace framework {
inline DataType ToDataType(std::type_index type) {
if (typeid(float).hash_code() == type.hash_code()) {
return DataType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) {
return DataType::FP64;
} else if (typeid(int).hash_code() == type.hash_code()) {
return DataType::INT32;
} else {
PADDLE_THROW("Not supported");
return static_cast<DataType>(-1);
}
}
} // namespace framework
} // namespace paddle
...@@ -104,8 +104,9 @@ template <typename PlaceType, typename KernelType> ...@@ -104,8 +104,9 @@ template <typename PlaceType, typename KernelType>
class OpKernelRegistrar : public Registrar { class OpKernelRegistrar : public Registrar {
public: public:
explicit OpKernelRegistrar(const char* op_type) { explicit OpKernelRegistrar(const char* op_type) {
OperatorWithKernel::OpKernelKey key; using T = typename KernelType::ELEMENT_TYPE;
key.place_ = PlaceType(); OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
PlaceType());
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType); OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
} }
}; };
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "op_info.h" #include "op_info.h"
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
...@@ -407,7 +408,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase { ...@@ -407,7 +408,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
const Scope& scope_; const Scope& scope_;
}; };
class OpKernel { class OpKernelBase {
public: public:
/** /**
* ExecutionContext is the only parameter of Kernel Run function. * ExecutionContext is the only parameter of Kernel Run function.
...@@ -418,33 +419,47 @@ class OpKernel { ...@@ -418,33 +419,47 @@ class OpKernel {
virtual void Compute(const ExecutionContext& context) const = 0; virtual void Compute(const ExecutionContext& context) const = 0;
virtual ~OpKernel() {} virtual ~OpKernelBase() = default;
};
template <typename T>
class OpKernel : public OpKernelBase {
public:
using ELEMENT_TYPE = T;
}; };
class OperatorWithKernel : public OperatorBase { class OperatorWithKernel : public OperatorBase {
public: public:
struct OpKernelKey { struct OpKernelKey {
platform::Place place_; platform::Place place_;
DataType data_type_;
OpKernelKey() = default; OpKernelKey(DataType data_type, platform::Place place)
explicit OpKernelKey(const platform::DeviceContext& dev_ctx) { : place_(place), data_type_(data_type) {}
place_ = dev_ctx.GetPlace();
} OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
bool operator==(const OpKernelKey& o) const { bool operator==(const OpKernelKey& o) const {
return platform::places_are_same_class(place_, o.place_); return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_;
} }
}; };
struct OpKernelHash { struct OpKernelHash {
std::hash<bool> hash_; std::hash<int> hash_;
size_t operator()(const OpKernelKey& key) const { size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_)); int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_);
// NOTE: Number of places limit to 16.
int pre_hash = data_type << 4 | (place & 0x0F);
return hash_(pre_hash);
} }
}; };
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>; std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
OpKernelHash>;
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs, OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs) const VariableNameMap& outputs, const AttributeMap& attrs)
...@@ -458,8 +473,10 @@ class OperatorWithKernel : public OperatorBase { ...@@ -458,8 +473,10 @@ class OperatorWithKernel : public OperatorBase {
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final { const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); ExecutionContext ctx(*this, scope, dev_ctx);
opKernel->Compute(ExecutionContext(*this, scope, dev_ctx)); auto& opKernel = AllOpKernels().at(type_).at(
OpKernelKey(IndicateDataType(ctx), dev_ctx));
opKernel->Compute(ctx);
} }
static std::unordered_map<std::string /* op_type */, OpKernelMap>& static std::unordered_map<std::string /* op_type */, OpKernelMap>&
...@@ -469,13 +486,43 @@ class OperatorWithKernel : public OperatorBase { ...@@ -469,13 +486,43 @@ class OperatorWithKernel : public OperatorBase {
} }
bool SupportGPU() const override { bool SupportGPU() const override {
OperatorWithKernel::OpKernelKey key; auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
key.place_ = platform::GPUPlace(); return std::any_of(op_kernels.begin(), op_kernels.end(),
return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0; [](OpKernelMap::const_reference kern_pair) {
return platform::is_gpu_place(kern_pair.first.place_);
});
} }
protected: protected:
virtual void InferShape(InferShapeContextBase* ctx) const = 0; virtual void InferShape(InferShapeContextBase* ctx) const = 0;
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
for (auto& input : this->inputs_) {
for (auto& ipt_name : input.second) {
auto* var = scope.FindVar(ipt_name);
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
"DataType of Paddle Op must be same.");
data_type = tmp;
}
}
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type);
}
}; };
} // namespace framework } // namespace framework
......
...@@ -29,20 +29,10 @@ limitations under the License. */ ...@@ -29,20 +29,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace pybind {
namespace details {
template <bool less, size_t i, typename... args>
struct CastToPyBufferImpl;
}
} // namespace pybind
namespace framework { namespace framework {
class Tensor { class Tensor {
public: public:
template <bool less, size_t i, typename... args>
friend struct pybind::details::CastToPyBufferImpl;
template <typename T, size_t D, int MajorType, typename IndexType> template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor; friend struct EigenTensor;
...@@ -119,6 +109,8 @@ class Tensor { ...@@ -119,6 +109,8 @@ class Tensor {
return holder_->place(); return holder_->place();
} }
std::type_index type() const { return holder_->type(); }
private: private:
template <typename T> template <typename T>
inline void check_memory_size() const; inline void check_memory_size() const;
......
...@@ -47,7 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata, ...@@ -47,7 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
} }
template <typename T> template <typename T>
class AccuracyOpCUDAKernel : public framework::OpKernel { class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
...@@ -35,7 +35,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -35,7 +35,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>; using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class AccuracyKernel : public framework::OpKernel { class AccuracyKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference"); auto* inference = ctx.Input<Tensor>("Inference");
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T, typename Functor> template <typename Place, typename T, typename Functor>
class ActivationKernel : public framework::OpKernel { class ActivationKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
...@@ -36,7 +36,7 @@ class ActivationKernel : public framework::OpKernel { ...@@ -36,7 +36,7 @@ class ActivationKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename Functor> template <typename Place, typename T, typename Functor>
class ActivationGradKernel : public framework::OpKernel { class ActivationGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
...@@ -202,7 +202,7 @@ struct SquareGradFunctor { ...@@ -202,7 +202,7 @@ struct SquareGradFunctor {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class BReluKernel : public framework::OpKernel { class BReluKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
...@@ -219,7 +219,7 @@ class BReluKernel : public framework::OpKernel { ...@@ -219,7 +219,7 @@ class BReluKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class BReluGradKernel : public framework::OpKernel { class BReluGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
...@@ -239,7 +239,7 @@ class BReluGradKernel : public framework::OpKernel { ...@@ -239,7 +239,7 @@ class BReluGradKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class SoftReluKernel : public framework::OpKernel { class SoftReluKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
...@@ -256,7 +256,7 @@ class SoftReluKernel : public framework::OpKernel { ...@@ -256,7 +256,7 @@ class SoftReluKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class SoftReluGradKernel : public framework::OpKernel { class SoftReluGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
...@@ -277,7 +277,7 @@ class SoftReluGradKernel : public framework::OpKernel { ...@@ -277,7 +277,7 @@ class SoftReluGradKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class PowKernel : public framework::OpKernel { class PowKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
...@@ -293,7 +293,7 @@ class PowKernel : public framework::OpKernel { ...@@ -293,7 +293,7 @@ class PowKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class PowGradKernel : public framework::OpKernel { class PowGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
...@@ -312,7 +312,7 @@ class PowGradKernel : public framework::OpKernel { ...@@ -312,7 +312,7 @@ class PowGradKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class STanhKernel : public framework::OpKernel { class STanhKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
...@@ -329,7 +329,7 @@ class STanhKernel : public framework::OpKernel { ...@@ -329,7 +329,7 @@ class STanhKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class STanhGradKernel : public framework::OpKernel { class STanhGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
......
...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class AddKernel : public framework::OpKernel { class AddKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* input0 = context.Input<Tensor>("X"); auto* input0 = context.Input<Tensor>("X");
......
...@@ -56,7 +56,7 @@ class ClipGradFunctor { ...@@ -56,7 +56,7 @@ class ClipGradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ClipKernel : public framework::OpKernel { class ClipKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max"); auto max = context.Attr<T>("max");
...@@ -73,7 +73,7 @@ class ClipKernel : public framework::OpKernel { ...@@ -73,7 +73,7 @@ class ClipKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ClipGradKernel : public framework::OpKernel { class ClipGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max"); auto max = context.Attr<T>("max");
......
...@@ -21,7 +21,7 @@ namespace paddle { ...@@ -21,7 +21,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ConcatKernel : public framework::OpKernel { class ConcatKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx.MultiInput<framework::Tensor>("X");
......
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class CosSimKernel : public framework::OpKernel { class CosSimKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
// get Tensor // get Tensor
...@@ -67,7 +67,7 @@ class CosSimKernel : public framework::OpKernel { ...@@ -67,7 +67,7 @@ class CosSimKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class CosSimGradKernel : public framework::OpKernel { class CosSimGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
// get Tensor // get Tensor
......
...@@ -27,7 +27,7 @@ using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>; ...@@ -27,7 +27,7 @@ using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::Tensor; using framework::Tensor;
template <typename T> template <typename T>
class CropKernel : public framework::OpKernel { class CropKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
...@@ -69,7 +69,7 @@ void CropGradFunction(const framework::ExecutionContext& context) { ...@@ -69,7 +69,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
} }
template <typename Place, typename T> template <typename Place, typename T>
class CropGradKernel : public framework::OpKernel { class CropGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
size_t rank = size_t rank =
......
...@@ -47,6 +47,12 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -47,6 +47,12 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Y", {x_dims[0], 1}); ctx->SetOutputDim("Y", {x_dims[0], 1});
ctx->ShareLoD("X", /*->*/ "Y"); ctx->ShareLoD("X", /*->*/ "Y");
} }
// CrossEntropy's data type just determined by "X"
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
}; };
class CrossEntropyGradientOp : public framework::OperatorWithKernel { class CrossEntropyGradientOp : public framework::OperatorWithKernel {
...@@ -87,6 +93,12 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -87,6 +93,12 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
} }
// CrossEntropy's data type just determined by "X"
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
}; };
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -53,7 +53,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, ...@@ -53,7 +53,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
} // namespace } // namespace
template <typename T> template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel { class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
...@@ -69,7 +69,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { ...@@ -69,7 +69,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
}; };
template <typename T> template <typename T>
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
...@@ -26,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -26,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T> template <typename T>
class CrossEntropyOpKernel : public framework::OpKernel { class CrossEntropyOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
...@@ -42,7 +42,7 @@ class CrossEntropyOpKernel : public framework::OpKernel { ...@@ -42,7 +42,7 @@ class CrossEntropyOpKernel : public framework::OpKernel {
}; };
template <typename T> template <typename T>
class CrossEntropyGradientOpKernel : public framework::OpKernel { class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
......
...@@ -47,7 +47,7 @@ struct MaskGenerator { ...@@ -47,7 +47,7 @@ struct MaskGenerator {
// Use std::random and thrust::random(thrust is a std library in CUDA) to // Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random. // implement uniform random.
template <typename Place, typename T, typename AttrType> template <typename Place, typename T, typename AttrType>
class GPUDropoutKernel : public framework::OpKernel { class GPUDropoutKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
......
...@@ -26,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -26,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T, typename AttrType> template <typename Place, typename T, typename AttrType>
class CPUDropoutKernel : public framework::OpKernel { class CPUDropoutKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
...@@ -62,7 +62,7 @@ class CPUDropoutKernel : public framework::OpKernel { ...@@ -62,7 +62,7 @@ class CPUDropoutKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel { class DropoutGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(context.Attr<bool>("is_training"), PADDLE_ENFORCE(context.Attr<bool>("is_training"),
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseAddKernel : public framework::OpKernel { class ElementwiseAddKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenAddFunctor, Place, T>(ctx); ElementwiseCompute<EigenAddFunctor, Place, T>(ctx);
...@@ -101,7 +101,7 @@ struct ElementwiseAddBroadCast2GradFunctor { ...@@ -101,7 +101,7 @@ struct ElementwiseAddBroadCast2GradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseAddGradKernel : public framework::OpKernel { class ElementwiseAddGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseAddGradFunctor<T>, ElementwiseGradCompute<Place, T, ElementwiseAddGradFunctor<T>,
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseDivKernel : public framework::OpKernel { class ElementwiseDivKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenDivFunctor, Place, T>(ctx); ElementwiseCompute<EigenDivFunctor, Place, T>(ctx);
...@@ -103,7 +103,7 @@ struct ElementwiseDivBroadCast2GradFunctor { ...@@ -103,7 +103,7 @@ struct ElementwiseDivBroadCast2GradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseDivGradKernel : public framework::OpKernel { class ElementwiseDivGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseDivGradFunctor<T>, ElementwiseGradCompute<Place, T, ElementwiseDivGradFunctor<T>,
......
...@@ -19,7 +19,7 @@ namespace paddle { ...@@ -19,7 +19,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseMulKernel : public framework::OpKernel { class ElementwiseMulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenMulFunctor, Place, T>(ctx); ElementwiseCompute<EigenMulFunctor, Place, T>(ctx);
...@@ -102,7 +102,7 @@ struct ElementwiseMulBroadCast2GradFunctor { ...@@ -102,7 +102,7 @@ struct ElementwiseMulBroadCast2GradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseMulGradKernel : public framework::OpKernel { class ElementwiseMulGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseMulGradFunctor<T>, ElementwiseGradCompute<Place, T, ElementwiseMulGradFunctor<T>,
......
...@@ -19,7 +19,7 @@ namespace paddle { ...@@ -19,7 +19,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseSubKernel : public framework::OpKernel { class ElementwiseSubKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenSubFunctor, Place, T>(ctx); ElementwiseCompute<EigenSubFunctor, Place, T>(ctx);
...@@ -102,7 +102,7 @@ struct ElementwiseSubBroadCast2GradFunctor { ...@@ -102,7 +102,7 @@ struct ElementwiseSubBroadCast2GradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseSubGradKernel : public framework::OpKernel { class ElementwiseSubGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseSubGradFunctor<T>, ElementwiseGradCompute<Place, T, ElementwiseSubGradFunctor<T>,
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class FillZerosLikeKernel : public framework::OpKernel { class FillZerosLikeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* output = context.Output<framework::Tensor>("Y"); auto* output = context.Output<framework::Tensor>("Y");
......
...@@ -24,7 +24,7 @@ namespace operators { ...@@ -24,7 +24,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class GatherOpKernel : public framework::OpKernel { class GatherOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X"); auto *X = ctx.Input<Tensor>("X");
...@@ -37,7 +37,7 @@ class GatherOpKernel : public framework::OpKernel { ...@@ -37,7 +37,7 @@ class GatherOpKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class GatherGradientOpKernel : public framework::OpKernel { class GatherGradientOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto *Index = ctx.Input<Tensor>("Index"); auto *Index = ctx.Input<Tensor>("Index");
......
...@@ -16,7 +16,7 @@ namespace paddle { ...@@ -16,7 +16,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
class CPUGaussianRandomKernel : public framework::OpKernel { class CPUGaussianRandomKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
float mean = context.Attr<float>("mean"); float mean = context.Attr<float>("mean");
......
...@@ -37,7 +37,7 @@ struct GaussianGenerator { ...@@ -37,7 +37,7 @@ struct GaussianGenerator {
}; };
template <typename T> template <typename T>
class GPUGaussianRandomKernel : public framework::OpKernel { class GPUGaussianRandomKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
......
...@@ -25,7 +25,7 @@ namespace operators { ...@@ -25,7 +25,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class GemmConv2DKernel : public framework::OpKernel { class GemmConv2DKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor* input = context.Input<Tensor>("Input");
...@@ -98,7 +98,7 @@ class GemmConv2DKernel : public framework::OpKernel { ...@@ -98,7 +98,7 @@ class GemmConv2DKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class GemmConvGrad2DKernel : public framework::OpKernel { class GemmConvGrad2DKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor* input = context.Input<Tensor>("Input");
......
...@@ -61,7 +61,7 @@ __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids, ...@@ -61,7 +61,7 @@ __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
} }
template <typename T> template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel { class LookupTableCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W"); auto table_t = context.Input<Tensor>("W");
...@@ -85,7 +85,7 @@ class LookupTableCUDAKernel : public framework::OpKernel { ...@@ -85,7 +85,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
}; };
template <typename T> template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel { class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids"); auto ids_t = context.Input<Tensor>("Ids");
......
...@@ -23,7 +23,7 @@ namespace operators { ...@@ -23,7 +23,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
class LookupTableKernel : public framework::OpKernel { class LookupTableKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W"); // float tensor auto table_t = context.Input<Tensor>("W"); // float tensor
...@@ -44,7 +44,7 @@ class LookupTableKernel : public framework::OpKernel { ...@@ -44,7 +44,7 @@ class LookupTableKernel : public framework::OpKernel {
}; };
template <typename T> template <typename T>
class LookupTableGradKernel : public framework::OpKernel { class LookupTableGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids"); auto ids_t = context.Input<Tensor>("Ids");
......
...@@ -90,7 +90,7 @@ __global__ void LSTMUnitGradientKernel(const int nthreads, const int dim, ...@@ -90,7 +90,7 @@ __global__ void LSTMUnitGradientKernel(const int nthreads, const int dim,
} }
template <typename T, typename AttrType = T> template <typename T, typename AttrType = T>
class LstmUnitOpCUDAKernel : public framework::OpKernel { class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
...@@ -121,7 +121,7 @@ class LstmUnitOpCUDAKernel : public framework::OpKernel { ...@@ -121,7 +121,7 @@ class LstmUnitOpCUDAKernel : public framework::OpKernel {
}; };
template <typename T, typename AttrType = T> template <typename T, typename AttrType = T>
class LstmUnitGradOpCUDAKernel : public framework::OpKernel { class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
...@@ -33,7 +33,7 @@ inline T tanh(T x) { ...@@ -33,7 +33,7 @@ inline T tanh(T x) {
} }
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class LstmUnitKernel : public framework::OpKernel { class LstmUnitKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
...@@ -76,7 +76,7 @@ class LstmUnitKernel : public framework::OpKernel { ...@@ -76,7 +76,7 @@ class LstmUnitKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class LstmUnitGradKernel : public framework::OpKernel { class LstmUnitGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
......
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class MeanKernel : public framework::OpKernel { class MeanKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("X"); auto* input = context.Input<Tensor>("X");
...@@ -45,7 +45,7 @@ class MeanKernel : public framework::OpKernel { ...@@ -45,7 +45,7 @@ class MeanKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MeanGradKernel : public framework::OpKernel { class MeanGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>(framework::GradVarName("Out")); auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class MinusKernel : public framework::OpKernel { class MinusKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* left_tensor = context.Input<framework::Tensor>("X"); auto* left_tensor = context.Input<framework::Tensor>("X");
......
...@@ -39,7 +39,7 @@ struct ModifiedHuberLossBackward { ...@@ -39,7 +39,7 @@ struct ModifiedHuberLossBackward {
}; };
template <typename T> template <typename T>
class ModifiedHuberLossGradGPUKernel : public framework::OpKernel { class ModifiedHuberLossGradGPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("Y"); auto* in0 = context.Input<Tensor>("Y");
......
...@@ -47,7 +47,7 @@ struct ModifiedHuberLossForward { ...@@ -47,7 +47,7 @@ struct ModifiedHuberLossForward {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ModifiedHuberLossKernel : public framework::OpKernel { class ModifiedHuberLossKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X"); auto* in0 = context.Input<Tensor>("X");
...@@ -73,7 +73,7 @@ class ModifiedHuberLossKernel : public framework::OpKernel { ...@@ -73,7 +73,7 @@ class ModifiedHuberLossKernel : public framework::OpKernel {
// CPU backward kernel // CPU backward kernel
template <typename T> template <typename T>
class ModifiedHuberLossGradCPUKernel : public framework::OpKernel { class ModifiedHuberLossGradCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("Y"); auto* in0 = context.Input<Tensor>("Y");
......
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class MulKernel : public framework::OpKernel { class MulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X"); const Tensor* x = context.Input<Tensor>("X");
...@@ -52,7 +52,7 @@ class MulKernel : public framework::OpKernel { ...@@ -52,7 +52,7 @@ class MulKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel { class MulGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims"); int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class MultiplexGPUKernel : public framework::OpKernel { class MultiplexGPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X");
...@@ -51,7 +51,7 @@ class MultiplexGPUKernel : public framework::OpKernel { ...@@ -51,7 +51,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MultiplexGradGPUKernel : public framework::OpKernel { class MultiplexGradGPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
...@@ -23,7 +23,7 @@ namespace paddle { ...@@ -23,7 +23,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class MultiplexCPUKernel : public framework::OpKernel { class MultiplexCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx.MultiInput<framework::Tensor>("X");
...@@ -48,7 +48,7 @@ class MultiplexCPUKernel : public framework::OpKernel { ...@@ -48,7 +48,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MultiplexGradCPUKernel : public framework::OpKernel { class MultiplexGradCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
......
...@@ -47,7 +47,7 @@ void PadFunction(const framework::ExecutionContext& context) { ...@@ -47,7 +47,7 @@ void PadFunction(const framework::ExecutionContext& context) {
} }
template <typename Place, typename T> template <typename Place, typename T>
class PadKernel : public framework::OpKernel { class PadKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size(); int rank = context.Input<Tensor>("X")->dims().size();
...@@ -97,7 +97,7 @@ void PadGradFunction(const framework::ExecutionContext& context) { ...@@ -97,7 +97,7 @@ void PadGradFunction(const framework::ExecutionContext& context) {
} }
template <typename Place, typename T> template <typename Place, typename T>
class PadGradKernel : public framework::OpKernel { class PadGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
size_t rank = size_t rank =
......
...@@ -40,7 +40,7 @@ class PReluFunctor { ...@@ -40,7 +40,7 @@ class PReluFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class PReluKernel : public framework::OpKernel { class PReluKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
...@@ -77,7 +77,7 @@ class PReluGradFunctor { ...@@ -77,7 +77,7 @@ class PReluGradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class PReluGradKernel : public framework::OpKernel { class PReluGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* dx = context.Output<Tensor>(framework::GradVarName("X")); auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
......
...@@ -21,7 +21,7 @@ namespace paddle { ...@@ -21,7 +21,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class RankLossKernel : public framework::OpKernel { class RankLossKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out"); auto* out_t = ctx.Output<framework::Tensor>("Out");
...@@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel { ...@@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class RankLossGradKernel : public framework::OpKernel { class RankLossGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* d_left_t = auto* d_left_t =
......
...@@ -21,7 +21,7 @@ namespace paddle { ...@@ -21,7 +21,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ReshapeKernel : public framework::OpKernel { class ReshapeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
...@@ -39,7 +39,7 @@ class ReshapeKernel : public framework::OpKernel { ...@@ -39,7 +39,7 @@ class ReshapeKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ReshapeGradKernel : public framework::OpKernel { class ReshapeGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
......
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class RowwiseAddKernel : public framework::OpKernel { class RowwiseAddKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output<Tensor>("Out"); auto out = context.Output<Tensor>("Out");
...@@ -50,7 +50,7 @@ class RowwiseAddKernel : public framework::OpKernel { ...@@ -50,7 +50,7 @@ class RowwiseAddKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class RowwiseAddGradKernel : public framework::OpKernel { class RowwiseAddGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* dout = context.Input<Tensor>(framework::GradVarName("Out")); auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class ScaleKernel : public framework::OpKernel { class ScaleKernel : public framework::OpKernel<T> {
public: public:
virtual void Compute(const framework::ExecutionContext& context) const { virtual void Compute(const framework::ExecutionContext& context) const {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
......
...@@ -24,7 +24,7 @@ namespace operators { ...@@ -24,7 +24,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class ScatterOpKernel : public framework::OpKernel { class ScatterOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto *Ref = ctx.Input<Tensor>("Ref"); auto *Ref = ctx.Input<Tensor>("Ref");
...@@ -40,7 +40,7 @@ class ScatterOpKernel : public framework::OpKernel { ...@@ -40,7 +40,7 @@ class ScatterOpKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ScatterGradientOpKernel : public framework::OpKernel { class ScatterGradientOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref")); auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
......
...@@ -38,7 +38,7 @@ enum SeqPoolType { ...@@ -38,7 +38,7 @@ enum SeqPoolType {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SequencePoolKernel : public framework::OpKernel { class SequencePoolKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X"); auto* in = context.Input<LoDTensor>("X");
...@@ -85,7 +85,7 @@ class SequencePoolKernel : public framework::OpKernel { ...@@ -85,7 +85,7 @@ class SequencePoolKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SequencePoolGradKernel : public framework::OpKernel { class SequencePoolGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X"); auto* in = context.Input<LoDTensor>("X");
......
...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel { class SGDOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<Tensor>("param");
......
...@@ -45,7 +45,7 @@ struct SmoothL1LossForward { ...@@ -45,7 +45,7 @@ struct SmoothL1LossForward {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class SmoothL1LossKernel : public framework::OpKernel { class SmoothL1LossKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X"); auto* in0 = context.Input<Tensor>("X");
...@@ -115,7 +115,7 @@ struct SmoothL1LossBackward { ...@@ -115,7 +115,7 @@ struct SmoothL1LossBackward {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class SmoothL1LossGradKernel : public framework::OpKernel { class SmoothL1LossGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("InsideWeight"); auto* in0 = context.Input<Tensor>("InsideWeight");
......
...@@ -26,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -26,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxKernel : public framework::OpKernel { class SoftmaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto X = context.Input<Tensor>("X"); auto X = context.Input<Tensor>("X");
...@@ -40,7 +40,7 @@ class SoftmaxKernel : public framework::OpKernel { ...@@ -40,7 +40,7 @@ class SoftmaxKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxGradKernel : public framework::OpKernel { class SoftmaxGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto Y = context.Input<Tensor>("Y"); auto Y = context.Input<Tensor>("Y");
......
...@@ -53,7 +53,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad, ...@@ -53,7 +53,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
} // namespace } // namespace
template <typename T> template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
...@@ -73,7 +73,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { ...@@ -73,7 +73,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
}; };
template <typename T> template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
......
...@@ -27,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -27,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T> template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()),
...@@ -47,7 +47,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { ...@@ -47,7 +47,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
}; };
template <typename T> template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* out_grad = const Tensor* out_grad =
......
...@@ -21,7 +21,7 @@ namespace paddle { ...@@ -21,7 +21,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SplitKernel : public framework::OpKernel { class SplitKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
......
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SquaredL2DistanceKernel : public framework::OpKernel { class SquaredL2DistanceKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X"); auto* in0 = context.Input<Tensor>("X");
...@@ -68,7 +68,7 @@ class SquaredL2DistanceKernel : public framework::OpKernel { ...@@ -68,7 +68,7 @@ class SquaredL2DistanceKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SquaredL2DistanceGradKernel : public framework::OpKernel { class SquaredL2DistanceGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("sub_result"); auto* in0 = context.Input<Tensor>("sub_result");
......
...@@ -22,7 +22,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -22,7 +22,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SumKernel : public framework::OpKernel { class SumKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto ins = context.MultiInput<Tensor>("X"); auto ins = context.MultiInput<Tensor>("X");
...@@ -43,7 +43,7 @@ class SumKernel : public framework::OpKernel { ...@@ -43,7 +43,7 @@ class SumKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SumGradKernel : public framework::OpKernel { class SumGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>(framework::GradVarName("Out")); auto* input = context.Input<Tensor>(framework::GradVarName("Out"));
......
...@@ -279,7 +279,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int* indices, ...@@ -279,7 +279,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
} }
template <typename T> template <typename T>
class TopkOpCUDAKernel : public framework::OpKernel { class TopkOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class TopkKernel : public framework::OpKernel { class TopkKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
// Get the top k elements of each row of input tensor // Get the top k elements of each row of input tensor
......
...@@ -38,7 +38,7 @@ void EigenTranspose(const framework::ExecutionContext& context, ...@@ -38,7 +38,7 @@ void EigenTranspose(const framework::ExecutionContext& context,
} }
template <typename Place, typename T> template <typename Place, typename T>
class TransposeKernel : public framework::OpKernel { class TransposeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X"); auto* x = context.Input<framework::Tensor>("X");
...@@ -73,7 +73,7 @@ class TransposeKernel : public framework::OpKernel { ...@@ -73,7 +73,7 @@ class TransposeKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class TransposeGradKernel : public framework::OpKernel { class TransposeGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* out_grad = auto* out_grad =
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
// Use std::random and thrust::random(thrust is a std library in CUDA) to // Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random. // implement uniform random.
template <typename T> template <typename T>
class CPUUniformRandomKernel : public framework::OpKernel { class CPUUniformRandomKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* tensor = ctx.Output<framework::Tensor>("Out"); auto* tensor = ctx.Output<framework::Tensor>("Out");
......
...@@ -40,7 +40,7 @@ struct UniformGenerator { ...@@ -40,7 +40,7 @@ struct UniformGenerator {
// Use std::random and thrust::random(thrust is a std library in CUDA) to // Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random. // implement uniform random.
template <typename T> template <typename T>
class GPUUniformRandomKernel : public framework::OpKernel { class GPUUniformRandomKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
......
...@@ -47,7 +47,7 @@ bool is_cpu_place(const Place &p) { ...@@ -47,7 +47,7 @@ bool is_cpu_place(const Place &p) {
} }
bool places_are_same_class(const Place &p1, const Place &p2) { bool places_are_same_class(const Place &p1, const Place &p2) {
return is_gpu_place(p1) == is_gpu_place(p2); return p1.which() == p2.which();
} }
std::ostream &operator<<(std::ostream &os, const Place &p) { std::ostream &operator<<(std::ostream &os, const Place &p) {
......
...@@ -42,7 +42,7 @@ template <size_t I, typename... ARGS> ...@@ -42,7 +42,7 @@ template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<true, I, ARGS...> { struct CastToPyBufferImpl<true, I, ARGS...> {
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type; using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
py::buffer_info operator()(framework::Tensor &tensor) { py::buffer_info operator()(framework::Tensor &tensor) {
if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) { if (std::type_index(typeid(CUR_TYPE)) == tensor.type()) {
auto dim_vec = framework::vectorize(tensor.dims()); auto dim_vec = framework::vectorize(tensor.dims());
std::vector<size_t> dims_outside; std::vector<size_t> dims_outside;
std::vector<size_t> strides; std::vector<size_t> strides;
...@@ -56,13 +56,13 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -56,13 +56,13 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
prod *= dims_outside[i - 1]; prod *= dims_outside[i - 1];
} }
framework::Tensor dst_tensor; framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.holder_->place())) { if (paddle::platform::is_gpu_place(tensor.place())) {
dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace()); dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
} else if (paddle::platform::is_cpu_place(tensor.holder_->place())) { } else if (paddle::platform::is_cpu_place(tensor.place())) {
dst_tensor = tensor; dst_tensor = tensor;
} }
return py::buffer_info( return py::buffer_info(
dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()), dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.place()),
sizeof(CUR_TYPE), py::format_descriptor<CUR_TYPE>::format(), sizeof(CUR_TYPE), py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides);
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册