Unverified commit 5593858d, authored by dzhwinter, committed by GitHub

Feature/use cudnn (#7141)

* "add c++ side kernel selection"

* "add multiple kernel op test"

* "kernel selection only support cudnn"

* "better formatter"

* "small fix with UseCPU"

* "depends on change interface Get(Place, Library)"

* "fix CI"

* "fix python cudnn test"

* "leave the register cudnn op to another PR"

* "fix CI"

* "use all kernel by default"

* "fix CI"
Parent: 59116442
@@ -73,8 +73,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
 cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
 cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
-cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
+cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece operator)
 cc_test(init_test SRCS init_test.cc DEPS init)
 cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
@@ -37,6 +37,28 @@ auto KernelNHWC = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
 auto KernelNCHW = OpKernelType(proto::DataType::FP64, platform::CPUPlace(),
                                DataLayout::kNCHW, LibraryType::kPlain);
+// TODO(dzhwinter): Only for testing multiple op kernel.
+// Dummy transform function for library_type
+// should be removed.
+auto KernelPlain = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0),
+                                DataLayout::kAnyLayout, LibraryType::kPlain);
+auto KernelCUDNN = OpKernelType(proto::DataType::FP32, platform::CUDAPlace(0),
+                                DataLayout::kAnyLayout, LibraryType::kCUDNN);
+void DummyTrans(const platform::DeviceContext* ctx,
+                const KernelTypePair& kernel_pair, const Variable& in,
+                Variable* out) {
+  PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
+  PADDLE_ENFORCE(
+      platform::places_are_same_class(kernel_pair.first.place_,
+                                      kernel_pair.second.place_),
+      "TransDataType Only Support DataType transform on same place!");
+  auto src = in.Get<Tensor>();
+  auto* dst = out->GetMutable<Tensor>();
+  *dst = src;
+}
 void TransDataType(const platform::DeviceContext* ctx,
                    const KernelTypePair& kernel_pair, const Variable& in,
                    Variable* out) {
@@ -121,6 +143,8 @@ std::vector<int> NCHW2NHWC = {0, 2, 3, 1};
 }
 REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType);
+REGISTER_DATA_TRANSFORM_FN(f::KernelPlain, f::KernelCUDNN, f::DummyTrans);
+REGISTER_DATA_TRANSFORM_FN(f::KernelCUDNN, f::KernelPlain, f::DummyTrans);
 REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW,
                            std::bind(f::TransDataLayout, NHWC2NCHW,
                                      std::placeholders::_1,
...
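For readers skimming the diff, the following is a minimal, self-contained model (not Paddle code; all type names are simplified stand-ins) of what REGISTER_DATA_TRANSFORM_FN-style registration and the GetNullable lookup used later in operator.cc amount to: a map keyed on a (source, destination) kernel-type pair, with an identity transform playing the role of DummyTrans.

// Standalone sketch of a transform registry keyed by a
// (source kernel type, destination kernel type) pair.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <utility>

struct KernelType {            // simplified stand-in for OpKernelType
  std::string place, library;  // e.g. {"CUDA", "Plain"}
  bool operator<(const KernelType& o) const {
    return std::tie(place, library) < std::tie(o.place, o.library);
  }
};
using KernelTypePair = std::pair<KernelType, KernelType>;
using TransformFn = std::function<void(const std::string&, std::string*)>;

std::map<KernelTypePair, TransformFn>& Registry() {
  static std::map<KernelTypePair, TransformFn> r;
  return r;
}
// Returns nullptr when no transform is registered, like GetNullable.
const TransformFn* GetNullable(const KernelTypePair& key) {
  auto it = Registry().find(key);
  return it == Registry().end() ? nullptr : &it->second;
}

int main() {
  KernelType plain{"CUDA", "Plain"}, cudnn{"CUDA", "CUDNN"};
  TransformFn dummy = [](const std::string& in, std::string* out) {
    *out = in;  // identity copy, as DummyTrans does for same-place tensors
  };
  Registry()[{plain, cudnn}] = dummy;  // Plain -> cuDNN
  Registry()[{cudnn, plain}] = dummy;  // cuDNN -> Plain
  if (const TransformFn* fn = GetNullable({plain, cudnn})) {
    std::string out;
    (*fn)("tensor-bytes", &out);
    std::cout << "transformed: " << out << std::endl;
  }
  return 0;
}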
@@ -15,6 +15,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/framework/init.h"
+#include "paddle/framework/operator.h"
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/place.h"
 #include "paddle/string/piece.h"
@@ -24,7 +25,6 @@ namespace framework {
 std::once_flag gflags_init_flag;
-// TODO(qijun) move init gflags to init.cc
 void InitGflags(std::vector<std::string> &argv) {
   std::call_once(gflags_init_flag, [&]() {
     int argc = argv.size();
@@ -72,6 +72,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
     LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
   }
   platform::DeviceContextPool::Init(places);
+  framework::UseALL();
   return true;
 }
...
@@ -12,13 +12,16 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/framework/op_registry.h"
+#include <glog/logging.h>
 #include <gtest/gtest.h>
+#include "paddle/framework/op_registry.h"
 namespace pd = paddle::framework;
 namespace paddle {
 namespace framework {
 class CosineOp : public OperatorBase {
  public:
  using OperatorBase::OperatorBase;
@@ -252,7 +255,6 @@ TEST(OperatorRegistrar, CPU) {
   op->Run(scope, cpu_place);
 }
-#ifdef PADDLE_WITH_CUDA
 TEST(OperatorRegistrar, CUDA) {
   paddle::framework::proto::OpDesc op_desc;
   paddle::platform::CUDAPlace cuda_place(0);
@@ -263,4 +265,131 @@ TEST(OperatorRegistrar, CUDA) {
   op->Run(scope, cuda_place);
 }
-#endif
+static int op_test_value = 0;
+using paddle::platform::DeviceContext;
+using paddle::platform::CPUDeviceContext;
+using paddle::platform::CUDADeviceContext;
+namespace paddle {
+namespace framework {
+class OpWithMultiKernelTest : public OperatorWithKernel {
+ public:
+  using OperatorWithKernel::OperatorWithKernel;
+ protected:
+  void InferShape(InferShapeContext* ctx) const override {}
+  framework::OpKernelType GetActualKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
+  }
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::OpKernelType& kernel) const override {
+    return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0),
+                                   kernel.data_layout_,
+                                   framework::LibraryType::kCUDNN);
+  }
+};
+template <typename DeviceContext, typename T>
+class OpMultiKernelTest : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const;
+};
+template <typename T>
+class OpMultiKernelTest<CPUDeviceContext, T>
+    : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const {
+    ++op_test_value;
+  }
+};
+template <typename T>
+class OpMultiKernelTest<CUDADeviceContext, T>
+    : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const {
+    --op_test_value;
+  }
+};
+template <typename DeviceContext, typename T>
+class OpMultiKernelTest2 : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const;
+};
+template <typename T>
+class OpMultiKernelTest2<CPUDeviceContext, T>
+    : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const {
+    op_test_value += 10;
+  }
+};
+template <typename T>
+class OpMultiKernelTest2<CUDADeviceContext, T>
+    : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const {
+    op_test_value -= 10;
+  }
+};
+}  // namespace framework
+}  // namespace paddle
+REGISTER_OP_WITHOUT_GRADIENT(op_with_multi_kernel,
+                             paddle::framework::OpWithMultiKernelTest,
+                             paddle::framework::OpKernelTestMaker);
+REGISTER_OP_KERNEL(
+    op_with_multi_kernel, CPU, paddle::platform::CPUPlace,
+    paddle::framework::OpMultiKernelTest<CPUDeviceContext, float>);
+REGISTER_OP_KERNEL(
+    op_with_multi_kernel, MKLDNN, paddle::platform::CPUPlace,
+    paddle::framework::OpMultiKernelTest2<CPUDeviceContext, float>);
+REGISTER_OP_KERNEL(
+    op_with_multi_kernel, CUDA, paddle::platform::CUDAPlace,
+    paddle::framework::OpMultiKernelTest<CUDADeviceContext, float>);
+REGISTER_OP_KERNEL(
+    op_with_multi_kernel, CUDNN, paddle::platform::CUDAPlace,
+    paddle::framework::OpMultiKernelTest2<CUDADeviceContext, float>);
+TEST(OperatorRegistrar, OpWithMultiKernel) {
+  paddle::framework::proto::OpDesc op_desc;
+  paddle::platform::CUDAPlace cuda_place(0);
+  paddle::platform::CPUPlace cpu_place;
+  paddle::framework::Scope scope;
+  op_desc.set_type("op_with_multi_kernel");
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
+  // use all available kernels
+  paddle::framework::UseALL();
+  op->Run(scope, cuda_place);
+  EXPECT_EQ(op_test_value, -10);
+  // remove cuda kernels
+  paddle::framework::UseCPU();
+  op->Run(scope, cpu_place);
+  EXPECT_EQ(op_test_value, -9);
+  // add cuda kernels
+  paddle::framework::UseCUDA();
+  op->Run(scope, cuda_place);
+  EXPECT_EQ(op_test_value, -10);
+  // use cudnn kernel
+  paddle::framework::UseCUDNN();
+  op->Run(scope, cuda_place);
+  EXPECT_EQ(op_test_value, -20);
+}
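To see why these expectations hold: the plain CPU kernel does ++, the plain CUDA kernel --, the MKLDNN kernel +=10, and the cuDNN kernel -=10. Starting from 0, UseALL dispatches to cuDNN (-10), UseCPU on the CPU place to the plain CPU kernel (-9), UseCUDA to the plain CUDA kernel (-10), and UseCUDNN to cuDNN again (-20). Below is a minimal sketch of that dispatch, with strings standing in for OpKernelType, the compile-guarded MKLDNN entry omitted, and the transform check elided (DummyTrans makes Plain and cuDNN mutually convertible here); it is a toy model, not the framework code.

#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

static int value = 0;

int main() {
  using Key = std::string;  // stand-in for OpKernelType
  std::map<Key, std::function<void()>> kernels = {
      {"CPU/Plain", [] { ++value; }},        // OpMultiKernelTest, CPU
      {"CPU/MKLDNN", [] { value += 10; }},   // OpMultiKernelTest2, CPU
      {"CUDA/Plain", [] { --value; }},       // OpMultiKernelTest, CUDA
      {"CUDA/CUDNN", [] { value -= 10; }}};  // OpMultiKernelTest2, CUDA

  // Highest-priority candidate with a registered kernel wins; the real loop
  // additionally requires a registered transform from the actual key.
  auto run = [&](const Key& actual, const std::vector<Key>& priority) {
    for (const Key& candidate : priority) {
      if (candidate == actual || kernels.count(candidate)) {
        kernels.at(candidate)();
        return;
      }
    }
  };

  run("CUDA/Plain", {"CUDA/CUDNN", "CUDA/Plain", "CPU/Plain"});  // UseALL
  assert(value == -10);
  run("CPU/Plain", {"CPU/Plain"});                               // UseCPU
  assert(value == -9);
  run("CUDA/Plain", {"CUDA/Plain", "CPU/Plain"});                // UseCUDA
  assert(value == -10);
  run("CUDA/Plain", {"CUDA/CUDNN", "CUDA/Plain", "CPU/Plain"});  // UseCUDNN
  assert(value == -20);
  return 0;
}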
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <glog/logging.h>
 #include <algorithm>
 #include <atomic>
@@ -25,6 +26,53 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
+std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
+void UseCPU() {
+  kKernelPriority.clear();
+  /*Plain CPU*/
+  auto pair0 = std::make_tuple(platform::CPUPlace(), LibraryType::kPlain);
+  kKernelPriority.insert(kKernelPriority.begin(), pair0);
+}
+void UseMKLDNN() {
+  UseCPU();
+#if PADDLE_WITH_MKLML
+  {
+    /*MKLDNN Kernel*/
+    auto pair0 = std::make_tuple(platform::CPUPlace(), LibraryType::kMKLDNN);
+    kKernelPriority.insert(kKernelPriority.begin(), pair0);
+  }
+#endif
+}
+void UseCUDA() {
+  UseMKLDNN();
+#if PADDLE_WITH_CUDA
+  /*Plain GPU*/
+  auto pair0 = std::make_tuple(platform::CUDAPlace(0), LibraryType::kPlain);
+  kKernelPriority.insert(kKernelPriority.begin(), pair0);
+#endif
+}
+void UseCUDNN() {
+  UseCUDA();
+#if PADDLE_WITH_CUDA
+  if (platform::dynload::HasCUDNN()) {
+    /*CUDNN Kernel*/
+    auto pair0 = std::make_tuple(platform::CUDAPlace(0), LibraryType::kCUDNN);
+    kKernelPriority.insert(kKernelPriority.begin(), pair0);
+  }
+#endif
+}
+void UseALL() {
+  UseCPU();
+  UseMKLDNN();
+  UseCUDA();
+  UseCUDNN();
+}
 std::string OperatorBase::Input(const std::string& name) const {
   auto& ins = Inputs(name);
   PADDLE_ENFORCE_LE(ins.size(), 1UL,
@@ -402,6 +450,12 @@ const platform::DeviceContext* GetDeviceContext(
   }
 }
+const platform::DeviceContext* GetDeviceContext(
+    const framework::OpKernelType& kernel) {
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  return pool.Get(kernel.place_);
+}
 void OperatorWithKernel::Run(const Scope& scope,
                              const platform::Place& place) const {
   RuntimeInferShapeContext infer_shape_ctx(*this, scope);
@@ -422,13 +476,8 @@ void OperatorWithKernel::Run(const Scope& scope,
   ExecutionContext ctx(*this, scope, *dev_ctx);
   auto actual_kernel_key = GetActualKernelType(ctx);
-  auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);
-  auto kernel_iter = kernels.find(expected_kernel_key);
-  if (kernel_iter == kernels.end()) {
-    PADDLE_THROW("The operator %s does not support %s", type_,
-                 expected_kernel_key);
-  }
+  auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);
   if (actual_kernel_key == expected_kernel_key) {
     PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_,
@@ -436,9 +485,24 @@ void OperatorWithKernel::Run(const Scope& scope,
                       "CPU and other devices. For example, multi-GPU model "
                       "parallelism will failed.");
   } else {
+    // find the best key candidate
+    const DataTransformFnMap& trans_map = DataTransformFnMap::Instance();
+    for (auto& candidate : kKernelPriority) {
+      auto candidate_key =
+          OpKernelType(actual_kernel_key.data_type_, std::get<0>(candidate),
+                       actual_kernel_key.data_layout_, std::get<1>(candidate));
+      auto candidate_pair = std::make_pair(actual_kernel_key, candidate_key);
+      if ((actual_kernel_key == candidate_key) ||
+          (kernels.count(candidate_key) &&
+           trans_map.GetNullable(candidate_pair))) {
+        expected_kernel_key = candidate_key;
+        break;
+      }
+    }
     auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key);
-    const DataTransformFn* trans_fun =
-        DataTransformFnMap::Instance().GetNullable(kernel_pair);
+    const DataTransformFn* trans_fun = trans_map.GetNullable(kernel_pair);
     if (trans_fun) {
       auto input_vars = this->InputVars();
       // TODO(qijun) filter the input vars that do not need to be transformed
@@ -471,7 +535,20 @@ void OperatorWithKernel::Run(const Scope& scope,
     }
   }
-  kernel_iter->second->Compute(ctx);
+  VLOG(10) << "Actual kernel: " << actual_kernel_key
+           << "Expected kernel: " << expected_kernel_key;
+  auto kernel_iter = kernels.find(expected_kernel_key);
+  if (kernel_iter == kernels.end()) {
+    PADDLE_THROW("The operator %s does not support %s", type_,
+                 expected_kernel_key);
+  }
+  auto* expected_dev_ctx = GetDeviceContext(expected_kernel_key);
+  ExecutionContext expected_ctx(*this, scope, *expected_dev_ctx);
+  kernel_iter->second->Compute(expected_ctx);
 }
 OpKernelType OperatorWithKernel::GetActualKernelType(
...
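A compiling toy version of the prepend-based chain above, with Place/Library reduced to strings and the #if guards dropped. Because each Use* function first calls the next-weaker one and UseCPU() clears the list, the most recently enabled backend always sits at the head; for the same reason, UseALL() ends in the same state as calling UseCUDNN() alone when CUDA and cuDNN are available.

#include <iostream>
#include <string>
#include <tuple>
#include <vector>

using Entry = std::tuple<std::string, std::string>;  // (place, library)
static std::vector<Entry> kKernelPriority;

void UseCPU() {
  kKernelPriority.clear();
  kKernelPriority.insert(kKernelPriority.begin(), Entry{"CPU", "Plain"});
}
void UseMKLDNN() {
  UseCPU();
  kKernelPriority.insert(kKernelPriority.begin(), Entry{"CPU", "MKLDNN"});
}
void UseCUDA() {
  UseMKLDNN();
  kKernelPriority.insert(kKernelPriority.begin(), Entry{"CUDA", "Plain"});
}
void UseCUDNN() {
  UseCUDA();
  kKernelPriority.insert(kKernelPriority.begin(), Entry{"CUDA", "CUDNN"});
}

int main() {
  UseCUDNN();  // the state UseALL() also ends in
  for (const Entry& e : kKernelPriority)
    std::cout << std::get<0>(e) << "/" << std::get<1>(e) << "\n";
  // Output, highest priority first:
  //   CUDA/CUDNN
  //   CUDA/Plain
  //   CPU/MKLDNN
  //   CPU/Plain
  return 0;
}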
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <algorithm>
 #include <atomic>
 #include <string>
+#include <tuple>
 #include <unordered_map>
 #include <vector>
@@ -52,10 +53,33 @@ constexpr char kGradVarSuffix[] = "@GRAD";
 /// Variables with this suffix are supposed to be filled up with zeros.
 constexpr char kZeroVarSuffix[] = "@ZERO";
-// define some kernel hint
-const std::string kUseCPU = "use_cpu";
-const std::string kUseCUDNN = "use_cudnn";
-const std::string kUseMKLDNN = "use_mkldnn";
+// define some kernel priority
+extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
+/**
+ * @brief Use cpu kernel only
+ */
+void UseCPU();
+/**
+ * @brief Perfer MKLDNN kernel than Plain CPU kernel
+ */
+void UseMKLDNN();
+/**
+ * @brief Perfer CUDA kernel than Plain CPU kernel
+ */
+void UseCUDA();
+/**
+ * @brief Perfer cudnn kernel than Plain CUDA kernel
+ */
+void UseCUDNN();
+/**
+ * @brief Use all available kernels
+ */
+void UseALL();
 inline std::string GradVarName(const std::string& var_name) {
   return var_name + kGradVarSuffix;
...
@@ -315,10 +315,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
-REGISTER_OP_KERNEL(conv2d, CUDNN, paddle::platform::CUDAPlace,
-                   paddle::operators::CudnnConvOpKernel<float>,
-                   paddle::operators::CudnnConvOpKernel<double>);
+// TODO(dzhwinter) : below register should be removed
 REGISTER_OP_CUDA_KERNEL(conv2d_cudnn,
                         paddle::operators::CudnnConvOpKernel<float>,
                         paddle::operators::CudnnConvOpKernel<double>);
...
@@ -62,12 +62,25 @@ class ConvOp : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::OpKernelType& kernel) const override {
+    return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0),
+                                   kernel.data_layout_,
+                                   framework::LibraryType::kCUDNN);
+  }
 };
 class ConvOpGrad : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::OpKernelType& kernel) const override {
+    return framework::OpKernelType(kernel.data_type_, platform::CUDAPlace(0),
+                                   kernel.data_layout_,
+                                   framework::LibraryType::kCUDNN);
+  }
 };
 template <typename DeviceContext, typename T>
...
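Worth noting: these overrides unconditionally ask for a cuDNN kernel on CUDAPlace(0); it is the priority loop in OperatorWithKernel::Run above that degrades gracefully when no such kernel, or no transform to it, is registered. The toy program below illustrates only the shape of that hook's contract; the types and the fallback are hypothetical simplifications, not Paddle APIs.

#include <iostream>
#include <set>
#include <string>

struct KernelType { std::string place, library; };

// Shape of ConvOp::GetExpectedKernelType: keep everything else from the
// actual kernel type, but request the cuDNN library on the GPU.
KernelType Expected(const KernelType& /*actual*/) {
  return KernelType{"CUDA", "CUDNN"};
}

int main() {
  // Pretend only a plain CUDA kernel was compiled in (no cuDNN).
  std::set<std::string> registered = {"CUDA/Plain"};
  KernelType actual{"CUDA", "Plain"};
  KernelType want = Expected(actual);
  if (!registered.count(want.place + "/" + want.library)) {
    want = actual;  // the framework's priority loop falls back in this spirit
  }
  std::cout << "dispatch to " << want.place << "/" << want.library << "\n";
  return 0;
}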
@@ -23,11 +23,6 @@ void BindConstValue(pybind11::module& m) {
   m.def("kTempVarName", [] { return framework::kTempVarName; });
   m.def("kGradVarSuffix", [] { return framework::kGradVarSuffix; });
   m.def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });
-  // for kernel_hint key
-  m.def("kUseCPU", [] { return framework::kUseCPU; });
-  m.def("kUseCUDNN", [] { return framework::kUseCUDNN; });
-  m.def("kUseMKLDNN", [] { return framework::kUseMKLDNN; });
 }
 }  // namespace pybind
...
@@ -430,6 +430,12 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("init_glog", framework::InitGLOG);
   m.def("init_devices", &framework::InitDevices);
+  m.def("use_cpu", framework::UseCPU);
+  m.def("use_mkldnn", framework::UseMKLDNN);
+  m.def("use_cuda", framework::UseCUDA);
+  m.def("use_cudnn", framework::UseCUDNN);
+  m.def("use_all", framework::UseALL);
   m.def("is_compile_gpu", IsCompileGPU);
   m.def("set_feed_variable", framework::SetFeedVariable);
   m.def("get_fetch_variable", framework::GetFetchVariable);
...
@@ -17,10 +17,6 @@ TEMP_VAR_NAME = core.kTempVarName()
 GRAD_VAR_SUFFIX = core.kGradVarSuffix()
 ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
-USE_CPU = core.kUseCPU()
-USE_CUDNN = core.kUseMKLDNN()
-USE_MKLDNN = core.kUseMKLDNN()
 def grad_var_name(var_name):
     """
...
 import unittest
 import numpy as np
+import paddle.v2.fluid.core as core
 from op_test import OpTest
@@ -47,6 +49,7 @@ def conv2d_forward_naive(input, filter, group, conv_param):
 class TestConv2dOp(OpTest):
     def setUp(self):
+        core.use_cuda()
         self.init_op_type()
         self.init_group()
         self.init_dilation()
@@ -167,26 +170,31 @@ class TestWithDilation(TestConv2dOp):
 #----------------Conv2dCudnn----------------
 class TestCudnn(TestConv2dOp):
     def init_op_type(self):
+        core.use_cudnn()
         self.op_type = "conv2d_cudnn"
 class TestCudnnWithPad(TestWithPad):
     def init_op_type(self):
+        core.use_cudnn()
         self.op_type = "conv2d_cudnn"
 class TestCudnnWithStride(TestWithStride):
     def init_op_type(self):
+        core.use_cudnn()
         self.op_type = "conv2d_cudnn"
 class TestCudnnWithGroup(TestWithGroup):
     def init_op_type(self):
+        core.use_cudnn()
         self.op_type = "conv2d_cudnn"
 class TestCudnnWith1x1(TestWith1x1):
     def init_op_type(self):
+        core.use_cudnn()
         self.op_type = "conv2d_cudnn"
...