diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 89726bf9859e71ee04c2f9380554090845fd44e5..2ced43f9e6c60da642f7a6252f889d9c9ab9748f 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -166,6 +166,8 @@ function(op_library TARGET) # Append first implemented MKLDNN activation operator if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n") + elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n") else() file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") endif() diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index c701a2ad63048f69e8443fee7434928698aa20cc..e4c471d86b7bff1bfb3b697ab24219144b4667f5 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -118,8 +118,9 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context) cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context) +cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog - shape_inference data_transform lod_tensor profiler transfer_scope_cache) + shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) @@ -191,7 +192,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) -cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) +cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto op_kernel_type) cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) cc_test(tuple_test SRCS tuple_test.cc ) diff --git a/paddle/fluid/framework/op_kernel_type.cc b/paddle/fluid/framework/op_kernel_type.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d4801e4a0eed7083e671e1d49b8628dfb280cf9 --- /dev/null +++ b/paddle/fluid/framework/op_kernel_type.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_kernel_type.h" + +namespace paddle { +namespace framework { + +size_t OpKernelType::Hash::operator()(const OpKernelType& key) const { + int cur_loc = 0; + + int place = key.place_.which(); + cur_loc += OpKernelType::kPlaceBits; + + int data_type = static_cast(key.data_type_) << cur_loc; + cur_loc += OpKernelType::kPrimaryDTypeBits; + + int data_layout = static_cast(key.data_layout_) << cur_loc; + cur_loc += OpKernelType::kLayoutBits; + + int library_type = static_cast(key.library_type_) << cur_loc; + cur_loc += OpKernelType::kLibBits; + + int customized_value = key.customized_type_value_; + PADDLE_ENFORCE(customized_value < (1 << OpKernelType::kCustomizeBits)); + customized_value = customized_value << cur_loc; + cur_loc += OpKernelType::kCustomizeBits; + PADDLE_ENFORCE(cur_loc < 64); + + std::hash hasher; + return hasher(place + data_type + data_layout + library_type + + customized_value); +} + +bool OpKernelType::operator==(const OpKernelType& o) const { + return platform::places_are_same_class(place_, o.place_) && + data_type_ == o.data_type_ && data_layout_ == o.data_layout_ && + library_type_ == o.library_type_ && + customized_type_value_ == o.customized_type_value_; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/op_kernel_type.h b/paddle/fluid/framework/op_kernel_type.h index ac0330218973123771367ed5ba9477c90143a043..9edc1a3e150027b5a3dbd8483dc8b58d1d4ab918 100644 --- a/paddle/fluid/framework/op_kernel_type.h +++ b/paddle/fluid/framework/op_kernel_type.h @@ -24,54 +24,55 @@ limitations under the License. */ namespace paddle { namespace framework { -struct OpKernelType { - struct Hash { - size_t operator()(const OpKernelType& key) const { - int place = key.place_.which(); - int data_type = static_cast(key.data_type_) << LEFT_SHIFT; - int data_layout = static_cast(key.data_layout_) << (LEFT_SHIFT * 2); - int library_type = static_cast(key.library_type_) - << (LEFT_SHIFT * 3); - - std::hash hasher; - return hasher(place + data_type + data_layout + library_type); - } - }; +class OpKernelType { + public: + constexpr static int kDefaultCustomizedTypeValue = 0; - // place, data_type, library_type kinds less than 2^8 - constexpr static int LEFT_SHIFT = 8; - - proto::VarType::Type data_type_; - DataLayout data_layout_; - platform::Place place_; - LibraryType library_type_; + // In total should be smaller than 64. + constexpr static int kPlaceBits = 4; + constexpr static int kPrimaryDTypeBits = 8; + constexpr static int kLayoutBits = 4; + constexpr static int kLibBits = 4; + constexpr static int kCustomizeBits = 4; OpKernelType(proto::VarType::Type data_type, platform::Place place, DataLayout data_layout = DataLayout::kAnyLayout, - LibraryType library_type = LibraryType::kPlain) + LibraryType library_type = LibraryType::kPlain, + int customized_type_value = kDefaultCustomizedTypeValue) : data_type_(data_type), data_layout_(data_layout), place_(place), - library_type_(library_type) {} + library_type_(library_type), + customized_type_value_(customized_type_value) {} OpKernelType(proto::VarType::Type data_type, const platform::DeviceContext& dev_ctx, DataLayout data_layout = DataLayout::kAnyLayout, - LibraryType library_type = LibraryType::kPlain) + LibraryType library_type = LibraryType::kPlain, + int customized_type_value = kDefaultCustomizedTypeValue) : data_type_(data_type), data_layout_(data_layout), place_(dev_ctx.GetPlace()), - library_type_(library_type) {} + library_type_(library_type), + customized_type_value_(customized_type_value) {} + + virtual ~OpKernelType() {} + + struct Hash { + size_t operator()(const OpKernelType& key) const; + }; size_t hash_key() const { return Hash()(*this); } - bool operator==(const OpKernelType& o) const { - return platform::places_are_same_class(place_, o.place_) && - data_type_ == o.data_type_ && data_layout_ == o.data_layout_ && - library_type_ == o.library_type_; - } + bool operator==(const OpKernelType& o) const; bool operator!=(const OpKernelType& o) const { return !(*this == o); } + + proto::VarType::Type data_type_; + DataLayout data_layout_; + platform::Place place_; + LibraryType library_type_; + int customized_type_value_; }; inline std::ostream& operator<<(std::ostream& os, diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 0e6e74293c30d5f8caa58fe6bfa63657d2669b46..36673e48c2047bca54f604b082dfec123f1e2c82 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -35,6 +35,7 @@ limitations under the License. */ namespace paddle { namespace framework { + class Registrar { public: // In our design, various kinds of classes, e.g., operators and kernels, @@ -78,7 +79,7 @@ struct OpKernelRegistrarFunctor; template inline void RegisterKernelClass(const char* op_type, const char* library_type, - Func func) { + int customized_type_value, Func func) { std::string library(library_type); std::string data_layout = "ANYLAYOUT"; if (library == "MKLDNN") { @@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type, } OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), StringToDataLayout(data_layout), - StringToLibraryType(library_type)); + StringToLibraryType(library_type), customized_type_value); OperatorWithKernel::AllOpKernels()[op_type][key] = func; } @@ -95,22 +96,26 @@ struct OpKernelRegistrarFunctor { using KERNEL_TYPE = typename std::tuple_element>::type; - void operator()(const char* op_type, const char* library_type) const { + void operator()(const char* op_type, const char* library_type, + int customized_type_value) const { using T = typename KERNEL_TYPE::ELEMENT_TYPE; RegisterKernelClass( - op_type, library_type, [](const framework::ExecutionContext& ctx) { + op_type, library_type, customized_type_value, + + [](const framework::ExecutionContext& ctx) { KERNEL_TYPE().Compute(ctx); }); constexpr auto size = std::tuple_size>::value; OpKernelRegistrarFunctor func; - func(op_type, library_type); + func(op_type, library_type, customized_type_value); } }; template struct OpKernelRegistrarFunctor { - void operator()(const char* op_type, const char* library_type) const {} + void operator()(const char* op_type, const char* library_type, + int customized_type_value) const {} }; // User can register many kernel in one place. The data type could be @@ -118,9 +123,10 @@ struct OpKernelRegistrarFunctor { template class OpKernelRegistrar : public Registrar { public: - explicit OpKernelRegistrar(const char* op_type, const char* library_type) { + explicit OpKernelRegistrar(const char* op_type, const char* library_type, + int customized_type_value) { OpKernelRegistrarFunctor func; - func(op_type, library_type); + func(op_type, library_type, customized_type_value); } }; @@ -130,17 +136,19 @@ struct OpKernelRegistrarFunctorEx; template class OpKernelRegistrarEx : public Registrar { public: - explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) { + explicit OpKernelRegistrarEx(const char* op_type, const char* library_type, + int customized_type_value) { OpKernelRegistrarFunctorEx func; - func(op_type, library_type); + func(op_type, library_type, customized_type_value); } }; template struct OpKernelRegistrarFunctorEx { - void operator()(const char* op_type, const char* library_type) const {} + void operator()(const char* op_type, const char* library_type, + int customized_type_value) const {} }; template @@ -153,18 +161,21 @@ struct OpKernelRegistrarFunctorEx>::type; - void operator()(const char* op_type, const char* library_type) const { - RegisterKernelClass(op_type, library_type, Functor()); + void operator()(const char* op_type, const char* library_type, + int customized_type_value) const { + RegisterKernelClass(op_type, library_type, + customized_type_value, Functor()); constexpr auto size = std::tuple_size>::value; OpKernelRegistrarFunctorEx= size, I + 2, DataTypeAndKernelType...> func; - func(op_type, library_type); + func(op_type, library_type, customized_type_value); } }; +// clang-format off /** * check if MACRO is used in GLOBAL NAMESPACE. */ @@ -199,42 +210,64 @@ struct OpKernelRegistrarFunctorEx \ - __op_kernel_registrar_##op_type##_##library_type##__(#op_type, \ - #library_type); \ - int TouchOpKernelRegistrar_##op_type##_##library_type() { \ - __op_kernel_registrar_##op_type##_##library_type##__.Touch(); \ - return 0; \ +#define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(op_type, library_type, \ + place_class, customized_name, \ + customized_type_value, ...) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \ + "REGISTER_OP_KERNEL must be called in " \ + "global namespace"); \ + static ::paddle::framework::OpKernelRegistrar \ + __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\ + #op_type, #library_type, customized_type_value); \ + int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\ + __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \ + .Touch(); \ + return 0; \ } +#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \ + REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( \ + op_type, library_type, place_class, DEFAULT_TYPE, \ + ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \ + __VA_ARGS__) + #define REGISTER_OP_CUDA_KERNEL(op_type, ...) \ REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__) #define REGISTER_OP_CPU_KERNEL(op_type, ...) \ REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) -#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_op_kernel_##op_type##_##library_type##__, \ - "REGISTER_OP_KERNEL_EX must be called in global namespace"); \ - static ::paddle::framework::OpKernelRegistrarEx \ - __op_kernel_registrar_##op_type##_##library_type##__(#op_type, \ - #library_type); \ - int TouchOpKernelRegistrar_##op_type##_##library_type() { \ - __op_kernel_registrar_##op_type##_##library_type##__.Touch(); \ - return 0; \ +#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, \ + customized_name, \ + customized_type_value, \ + ...) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \ + "REGISTER_OP_KERNEL_EX must be called in " \ + "global namespace"); \ + static ::paddle::framework::OpKernelRegistrarEx \ + __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\ + #op_type, #library_type, customized_type_value); \ + int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\ + __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \ + .Touch(); \ + return 0; \ } #define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \ - REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \ - __VA_ARGS__) + REGISTER_OP_KERNEL_EX( \ + op_type, CUDA, ::paddle::platform::CUDAPlace, DEFAULT_TYPE, \ + ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \ + __VA_ARGS__) -#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \ - REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) +#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \ + REGISTER_OP_KERNEL_EX( \ + op_type, CPU, ::paddle::platform::CPUPlace, DEFAULT_TYPE, \ + ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \ + __VA_ARGS__) /** * Macro to mark what Operator and Kernel @@ -248,13 +281,19 @@ struct OpKernelRegistrarFunctorEx("scale", "scale of cosine op"); + AddAttr("kernel_sub_type", "kernels with different implementations.") + .SetDefault(0); AddComment("This is test op"); } }; @@ -95,6 +97,8 @@ TEST(OperatorBase, all) { namespace paddle { namespace framework { +static int special_type_value = 1; + class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: void Make() { @@ -103,11 +107,14 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { AddAttr("scale", "scale of cosine op") .SetDefault(1.0) .GreaterThan(0.0); + AddAttr("kernel_sub_type", "kernels with different implementations.") + .SetDefault(0); AddComment("This is test op"); } }; static int cpu_kernel_run_num = 0; +static int cpu_kernel2_run_num = 0; class OpWithKernelTest : public OperatorWithKernel { public: @@ -117,7 +124,10 @@ class OpWithKernelTest : public OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override {} OpKernelType GetExpectedKernelType( const ExecutionContext& ctx) const override { - return OpKernelType(proto::VarType::FP32, ctx.GetPlace()); + int sub_type = ctx.Attr("kernel_sub_type"); + return OpKernelType(proto::VarType::FP32, ctx.GetPlace(), + framework::DataLayout::kAnyLayout, + framework::LibraryType::kPlain, sub_type); } }; @@ -132,6 +142,17 @@ class CPUKernelTest : public OpKernel { } }; +template +class CPUKernel2Test : public OpKernel { + public: + void Compute(const ExecutionContext& ctx) const { + std::cout << ctx.op().DebugString() << std::endl; + cpu_kernel2_run_num++; + ASSERT_EQ(ctx.op().Input("x"), "IN1"); + ASSERT_EQ(ctx.op().Output("y"), "OUT1"); + } +}; + class OpKernelTestMultiInputsProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: @@ -142,6 +163,8 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker AddAttr("scale", "scale of cosine op") .SetDefault(1.0) .GreaterThan(0.0); + AddAttr("kernel_sub_type", "kernels with different implementations.") + .SetDefault(0); AddComment("This is test op"); } }; @@ -189,9 +212,15 @@ class CPUKernalMultiInputsTest : public OpKernel { REGISTER_OP_WITHOUT_GRADIENT( op_with_kernel, paddle::framework::OpWithKernelTest, paddle::framework::OpKernelTestProtoAndCheckerMaker); + REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( + op_with_kernel, CPU, paddle::platform::CPUPlace, MY_SPECIAL_NAME, + paddle::framework::special_type_value, + paddle::framework::CPUKernel2Test); + // test with single input TEST(OpKernel, all) { paddle::framework::InitDevices(true); @@ -211,7 +240,19 @@ TEST(OpKernel, all) { auto op = paddle::framework::OpRegistry::CreateOp(op_desc); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); op->Run(scope, cpu_place); + // kerne_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called. + ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); + ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0); + + attr = op_desc.mutable_attrs()->Add(); + attr->set_name("kernel_sub_type"); + attr->set_type(paddle::framework::proto::AttrType::INT); + attr->set_i(1); + auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc); + op2->Run(scope, cpu_place); + // kerne_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not called. ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); + ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1); } REGISTER_OP_WITHOUT_GRADIENT( diff --git a/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc index 453f222f1f1e3f3b9ee8fa7bd49f4cab2286e7ea..b086c910d38a243d98315f2d6eb82ecc0ec5c06d 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_prelu_op.cc @@ -90,5 +90,4 @@ TEST(prelu_op, test_scalar) { } // namespace inference } // namespace paddle -// USE_OP(prelu); -USE_CPU_ONLY_OP(prelu); +USE_OP(prelu); diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index e822785ad6f4f6f67b72141f3e7b04aefa72e58b..95443e813327c1247ac530c4d2e68b3607ff0e73 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1,4 +1,4 @@ nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu avg_pool_op_plugin.cu - DEPS enforce tensorrt_engine) + DEPS enforce tensorrt_engine prelu) diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu index e8f4254402a5d8a5e6c5a2384bf9fbe48341956e..3075e87ea6d719a3f49d14c8c4b8015f7d688a50 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -14,92 +14,16 @@ #include #include +#include #include "glog/logging.h" #include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h" +#include "paddle/fluid/operators/math/prelu.h" namespace paddle { namespace inference { namespace tensorrt { namespace plugin { -static const int CUDA_NUM_THREADS = 1024; -static const int CUDA_MAX_NUM_BLOCKS = 65535; -inline static int GET_NUM_BLOCKS(const int N) { - return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; -} - -__global__ void PReluChannelWiseKernel(const float *input, const float *alpha, - float *output, int channel, - size_t spatial_size) { - size_t offset = blockIdx.x * spatial_size; - const float *in = input + offset; - float *out = output + offset; - float scale = alpha[blockIdx.x % channel]; - - for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { - float x = in[i]; - out[i] = (x > 0) ? x : scale * x; - } -} - -__global__ void PReluElementWiseKernel(const float *input, const float *alpha, - float *output, size_t spatial_size) { - size_t offset = blockIdx.x * spatial_size; - const float *in = input + offset; - const float *scale = alpha + offset; - float *out = output + offset; - - for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { - float x = in[i]; - out[i] = (x > 0) ? x : scale[i] * x; - } -} - -__global__ void PReluScalarKernel(const float *input, const float *alpha, - float *output, size_t spatial_size) { - size_t offset = blockIdx.x * spatial_size; - const float *in = input + offset; - float scale = *alpha; - float *out = output + offset; - - for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { - float x = in[i]; - out[i] = (x > 0) ? x : scale * x; - } -} - -static inline void PReluChannelWise(cudaStream_t stream, const float *input, - const float *alpha, float *output, - int batch_size, - const nvinfer1::Dims &dims) { - size_t unroll = batch_size * dims.d[0]; - size_t spatial_size = dims.d[1] * dims.d[2]; - CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); - PReluChannelWiseKernel<<>>( - input, alpha, output, dims.d[0], spatial_size); -} - -static inline void PReluElementWise(cudaStream_t stream, const float *input, - const float *alpha, float *output, - int batch_size, - const nvinfer1::Dims &dims) { - size_t unroll = batch_size * dims.d[0]; - size_t spatial_size = dims.d[1] * dims.d[2]; - CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); - PReluElementWiseKernel<<>>( - input, alpha, output, spatial_size); -} - -static inline void PReluScalar(cudaStream_t stream, const float *input, - const float *alpha, float *output, - int batch_size, const nvinfer1::Dims &dims) { - size_t unroll = batch_size * dims.d[0]; - size_t spatial_size = dims.d[1] * dims.d[2]; - CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); - PReluScalarKernel<<>>( - input, alpha, output, spatial_size); -} - nvinfer1::Dims PReluPlugin::getOutputDimensions(int index, const nvinfer1::Dims *inputDims, int nbInputs) { @@ -110,19 +34,31 @@ nvinfer1::Dims PReluPlugin::getOutputDimensions(int index, return output_dims; } -int PReluPlugin::enqueue(int batchSize, const void *const *inputs, +int PReluPlugin::enqueue(int batch_size, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream) { // input dims is CHW. const auto &input_dims = this->getInputDims(0); const float *input = reinterpret_cast(inputs[0]); const float *alpha = reinterpret_cast(alpha_.get().values); float *output = reinterpret_cast(outputs)[0]; + + std::vector input_shape; + input_shape.push_back(batch_size); + for (int i = 0; i < input_dims.nbDims; i++) { + input_shape.push_back(input_dims.d[i]); + } + if (mode_ == "channel") { - PReluChannelWise(stream, input, alpha, output, batchSize, input_dims); + operators::math::PreluChannelWiseDirectCUDAFunctor + prelu_channel_wise; + prelu_channel_wise(stream, input, alpha, output, input_shape); } else if (mode_ == "element") { - PReluElementWise(stream, input, alpha, output, batchSize, input_dims); + operators::math::PreluElementWiseDirectCUDAFunctor + prelu_element_wise; + prelu_element_wise(stream, input, alpha, output, input_shape); } else { - PReluScalar(stream, input, alpha, output, batchSize, input_dims); + operators::math::PreluScalarDirectCUDAFunctor prelu_scalar; + prelu_scalar(stream, input, alpha, output, input_shape); } return cudaGetLastError() != cudaSuccess; } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 3458df16062eb8840e67371ccf65566d9fccc532..2808133e844a6587ea3ec3fc0732caac2a6e433a 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -71,7 +71,7 @@ endif() set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions) if (WITH_GPU) - set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv) + set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu) endif() # FIXME(typhoonzero): operator deps may not needed. diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 05e268bf6a8d9a2562a4c278d317f75dac28e52c..ce45dd58419ab20cccf00544288b79d869515578 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -491,8 +491,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { namespace ops = paddle::operators; -REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace, - ops::ConvMKLDNNOpKernel); - -REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, - ops::ConvMKLDNNGradOpKernel); +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, + ::paddle::platform::CPUPlace, FP32, + ops::kConvMKLDNNFP32, + ops::ConvMKLDNNOpKernel); + +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN, + ::paddle::platform::CPUPlace, FP32, + ops::kConvMKLDNNFP32, + ops::ConvMKLDNNGradOpKernel); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 342525be49e28f1785e25d4daad38c3c81b4774f..7455b9492f054b32ee7fb1fc90b1a344367ceb81 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -74,6 +74,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; framework::LibraryType library{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready std::string data_format = ctx.Attr("data_format"); @@ -89,6 +91,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( platform::CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; + customized_type_value = kConvMKLDNNFP32; } #endif @@ -105,7 +108,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( } return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, - library); + library, customized_type_value); } void Conv2DOpMaker::Make() { @@ -342,6 +345,8 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; framework::LibraryType library_{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready std::string data_format = ctx.Attr("data_format"); @@ -357,12 +362,13 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( platform::CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; + customized_type_value = kConvMKLDNNFP32; } #endif return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), - layout_, library_); + layout_, library_, customized_type_value); } } // namespace operators diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index e69814001e4da5d10e51ee57c1dbe291338b8b49..249f308c13ff5636fbaa6747b28cab7886b7e736 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -27,6 +27,8 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +constexpr int kConvMKLDNNFP32 = 1; +constexpr int kConvMKLDNNINT8 = 2; // Base convolution operator definations for other conv // like operators to reuse the implementation. diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc index e01070c7b8ed4374cf8a61cfde4de940b4ea38b2..dd64cc327fc383937bc9a9d6e7daa0cec488e4cc 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc @@ -177,11 +177,19 @@ struct CudnnRNNCache { seed_)); CUDNN_ENFORCE(platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_)); + +#if CUDNN_VERSION >= 6000 CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6( handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT, is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT)); +#else + CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor( + rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT, + is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, + CUDNN_DATA_FLOAT)); +#endif CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&w_desc_)); CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_)); diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 972dcf5494e9acd47e7ff615db45f056a43724a6..0dbcc442dfa1a395cdb0ffbd69eb78ad66cfaa17 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -150,14 +150,14 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { "Output(W@Grad should not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "Output(X@Grad should not be null."); - if (!ctx->Attrs().Get("is_sparse")) { - if (ctx->HasOutput(framework::GradVarName("Bias"))) { - ctx->SetOutputDim(framework::GradVarName("Bias"), - ctx->GetInputDim("Bias")); - } - ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); + + if (ctx->HasOutput(framework::GradVarName("Bias"))) { + ctx->SetOutputDim(framework::GradVarName("Bias"), + ctx->GetInputDim("Bias")); } + ctx->SetOutputDim(framework::GradVarName("W"), ctx->GetInputDim("W")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); } protected: diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index 07ff8f947e59d2954783e2ba537bfce3cb320f22..b73a32af89e882ac02623dd1d312f400a78fc47a 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -185,7 +185,6 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel { ctx.Output(framework::GradVarName("W")); w_grad->set_rows(real_rows); // Build a map of id -> row_index to speed up finding the index of one id - w_grad->SyncIndex(); w_grad->set_height(w.dims()[0]); auto* w_grad_value = w_grad->mutable_value(); framework::DDim temp_dim(w.dims()); diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index db91628d94e80b9db25370df9f36456ed8c2d34e..8e8f83a6353bb3827d9e465624838646f657c12a 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -59,6 +59,7 @@ math_library(matrix_bit_code) math_library(unpooling) math_library(vol2col) +math_library(prelu) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc index 71b9293eeded77553ca06a8574cca3941fa36b6a..5a6e64b6f87d33249f0153e5f391deaf78e53de5 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.cc +++ b/paddle/fluid/operators/math/matrix_bit_code.cc @@ -89,6 +89,8 @@ template void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, const framework::Tensor& weight, const framework::Tensor& input) { + auto blas = + GetBlas(platform::CPUDeviceContext()); size_t num_samples = tmat->dims()[0]; size_t tmat_width = tmat->dims()[1]; size_t input_width = input.dims()[1]; @@ -99,13 +101,12 @@ void MatrixBitCodeFunctor::Mul(framework::Tensor* tmat, for (size_t i = 0; i < num_samples; ++i) { auto code = code_table_->get_code(i); int code_length = code->get_length(); + const T* input_row = input_value + input_width * i; for (int j = 0; j < code_length; ++j) { size_t index = code->calc_index(j); + const T* weight_row = weight_value + weight_width * index; T sum = static_cast(0.0); - for (size_t k = 0; k < input_width; ++k) { - sum += weight_value[weight_width * index + k] * - input_value[input_width * i + k]; - } + sum = blas.DOT(input_width, weight_row, input_row); tmat_value[i * tmat_width + j] += sum; } } @@ -115,6 +116,8 @@ template void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, framework::Tensor* weight, const framework::Tensor& input) { + auto blas = + GetBlas(platform::CPUDeviceContext()); size_t num_samples = tmat.dims()[0]; size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; @@ -122,16 +125,25 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, auto tmat_value = tmat.data(); auto weight_value = weight->data(); auto input_value = input.data(); + + std::unordered_map>> ops; + for (size_t i = 0; i < num_samples; ++i) { auto code = code_table_->get_code(i); int code_length = code->get_length(); + const T* input_value_row = input_value + input_width * i; + const T* tmat_row = tmat_value + i * tmat_width; for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - - for (size_t k = 0; k < input_width; ++k) { - weight_value[weight_width * index + k] += - tmat_value[i * tmat_width + j] * input_value[input_width * i + k]; - } + ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row); + } + } + for (auto& op : ops) { + auto& op_in_row = op.second; + for (auto& pair : op_in_row) { + auto& scale = pair.first; + auto* input_row = pair.second; + T* weight_row = weight_value + op.first * weight_width; + blas.AXPY(input_width, scale, input_row, weight_row); } } } @@ -140,6 +152,8 @@ template void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, framework::SelectedRows* weight, const framework::Tensor& input) { + auto blas = + GetBlas(platform::CPUDeviceContext()); size_t num_samples = tmat.dims()[0]; size_t input_width = input.dims()[1]; size_t tmat_width = tmat.dims()[1]; @@ -147,17 +161,28 @@ void MatrixBitCodeFunctor::MulGradWeight(const framework::Tensor& tmat, auto tmat_value = tmat.data(); auto weight_value = weight->mutable_value()->data(); auto input_value = input.data(); + + std::unordered_map>> ops; + ops.reserve(weight->rows().size()); + for (size_t i = 0; i < num_samples; ++i) { auto code = code_table_->get_code(i); int code_length = code->get_length(); + const T* input_value_row = input_value + input_width * i; + const T* tmat_row = tmat_value + i * tmat_width; for (int j = 0; j < code_length; ++j) { - size_t index = code->calc_index(j); - for (size_t k = 0; k < input_width; ++k) { - int64_t row_index = weight->GetIndexFromId(static_cast(index)); - weight_value[row_index * weight_width + k] += - tmat_value[i * tmat_width + j] * input_value[input_width * i + k]; - } + ops[code->calc_index(j)].emplace_back(tmat_row[j], input_value_row); + } + } + + for (auto& row : weight->rows()) { + auto& op_in_row = ops[row]; + for (auto& pair : op_in_row) { + auto& scale = pair.first; + auto* input_row = pair.second; + blas.AXPY(input_width, scale, input_row, weight_value); } + weight_value += weight_width; } } diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h index c30bb52641e865efe57659a551bc4b493634c6b9..35ca73802b48982ddf3ed7485b56f50221c9f28c 100644 --- a/paddle/fluid/operators/math/matrix_bit_code.h +++ b/paddle/fluid/operators/math/matrix_bit_code.h @@ -13,10 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device_context.h" #if defined(_WIN32) diff --git a/paddle/fluid/operators/math/prelu.cu b/paddle/fluid/operators/math/prelu.cu new file mode 100644 index 0000000000000000000000000000000000000000..701a802080f65ea32b95402682dc46362ccf0966 --- /dev/null +++ b/paddle/fluid/operators/math/prelu.cu @@ -0,0 +1,148 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/prelu.h" + +namespace paddle { +namespace operators { +namespace math { + +static const int CUDA_NUM_THREADS = 1024; +static const int CUDA_MAX_NUM_BLOCKS = 65535; +inline static int GET_NUM_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +template +__global__ void PReluChannelWiseKernel(const T *input, const T *alpha, + T *output, int channel, + size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const T *in = input + offset; + T *out = output + offset; + T scale = alpha[blockIdx.x % channel]; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + T x = in[i]; + out[i] = (x > 0) ? x : scale * x; + } +} + +template +__global__ void PReluElementWiseKernel(const T *input, const T *alpha, + T *output, size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const T *in = input + offset; + const T *scale = alpha + offset; + T *out = output + offset; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + T x = in[i]; + out[i] = (x > 0) ? x : scale[i] * x; + } +} + +template +__global__ void PReluScalarKernel(const T *input, const T *alpha, T *output, + size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + const T *in = input + offset; + T scale = *alpha; + T *out = output + offset; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + T x = in[i]; + out[i] = (x > 0) ? x : scale * x; + } +} + +template +static inline void PReluChannelWise(cudaStream_t stream, const T *input, + const T *alpha, T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluChannelWiseKernel<<>>( + input, alpha, output, input_shape[1], spatial_size); +} + +template +static inline void PReluElementWise(cudaStream_t stream, const T *input, + const T *alpha, T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluElementWiseKernel<<>>( + input, alpha, output, spatial_size); +} + +template +static inline void PReluScalar(cudaStream_t stream, const T *input, + const T *alpha, T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluScalarKernel<<>>( + input, alpha, output, spatial_size); +} + +template +void PreluChannelWiseDirectCUDAFunctor::operator()( + cudaStream_t stream, const T *input, const T *alpha, T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluChannelWiseKernel<<>>( + input, alpha, output, input_shape[1], spatial_size); +} + +template +void PreluElementWiseDirectCUDAFunctor::operator()( + cudaStream_t stream, const T *input, const T *alpha, T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluElementWiseKernel<<>>( + input, alpha, output, spatial_size); +} + +template +void PreluScalarDirectCUDAFunctor::operator()(cudaStream_t stream, + const T *input, const T *alpha, + T *output, + std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluScalarKernel<<>>( + input, alpha, output, spatial_size); +} + +template class PreluChannelWiseDirectCUDAFunctor; +template class PreluChannelWiseDirectCUDAFunctor; + +template class PreluElementWiseDirectCUDAFunctor; +template class PreluElementWiseDirectCUDAFunctor; + +template class PreluScalarDirectCUDAFunctor; +template class PreluScalarDirectCUDAFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/prelu.h b/paddle/fluid/operators/math/prelu.h new file mode 100644 index 0000000000000000000000000000000000000000..3237c6d4cbf956aafb4046ea2ffa42efe62e7b28 --- /dev/null +++ b/paddle/fluid/operators/math/prelu.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +#ifdef PADDLE_WITH_CUDA +template +class PreluChannelWiseDirectCUDAFunctor { + public: + void operator()(cudaStream_t stream, const T *input, const T *alpha, + T *output, std::vector input_shape); +}; + +template +class PreluElementWiseDirectCUDAFunctor { + public: + void operator()(cudaStream_t stream, const T *input, const T *alpha, + T *output, std::vector input_shape); +}; + +template +class PreluScalarDirectCUDAFunctor { + public: + void operator()(cudaStream_t stream, const T *input, const T *alpha, + T *output, std::vector input_shape); +}; +#endif + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index 58cfbb76e93a1c15c9b7cf9f9e596066c29b7ebb..64d94ab6044c1992145062319120b0372f5061c0 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -58,7 +58,7 @@ class PReluOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input("X")->type()), - platform::CPUPlace()); + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..36b5259ae5106914f5668625cad535ebc8aa72ec --- /dev/null +++ b/paddle/fluid/operators/prelu_op.cu @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/prelu.h" +#include "paddle/fluid/operators/prelu_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class CUDAPReluKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* alpha = context.Input("Alpha"); + auto* out = context.Output("Out"); + + const T* x_ptr = x->data(); + T* o_ptr = out->mutable_data(context.GetPlace()); + + const T* alpha_ptr = alpha->data(); + auto& mode = context.Attr("mode"); + + int numel = x->numel(); + auto dim = x->dims(); + std::vector input_shape = framework::vectorize2int(dim); + + if (mode == "channel") { + math::PreluChannelWiseDirectCUDAFunctor prelu_channel_wise; + prelu_channel_wise(context.cuda_device_context().stream(), x_ptr, + alpha_ptr, o_ptr, input_shape); + } else if (mode == "element") { + math::PreluElementWiseDirectCUDAFunctor prelu_element_wise; + prelu_element_wise(context.cuda_device_context().stream(), x_ptr, + alpha_ptr, o_ptr, input_shape); + } else { + math::PreluScalarDirectCUDAFunctor prelu_scalar; + prelu_scalar(context.cuda_device_context().stream(), x_ptr, alpha_ptr, + o_ptr, input_shape); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + prelu, ops::CUDAPReluKernel, + ops::CUDAPReluKernel); diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index 213cd8a9ce094512cea6f6405492ec8feff11516..550fe2edee13d628e761eca194809823537a4024 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -125,8 +125,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); __macro(cudnnRNNBackwardWeights); \ __macro(cudnnRNNForwardInference); \ __macro(cudnnDestroyDropoutDescriptor); \ - __macro(cudnnDestroyRNNDescriptor); \ - __macro(cudnnSetRNNDescriptor_v6); + __macro(cudnnDestroyRNNDescriptor); CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) @@ -165,6 +164,12 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif +// APIs in R6 +#if CUDNN_VERSION >= 6000 +#define CUDNN_DNN_ROUTINE_EACH_R6(__macro) __macro(cudnnSetRNNDescriptor_v6); +CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) +#endif + #if CUDNN_VERSION >= 7001 #define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \ __macro(cudnnSetConvolutionGroupCount); \ diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 912ba8f005e297ddfdc9b44e88c318e2551ba343..6299b166af8a5f65cf587ae282c955f33db0044b 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -442,8 +442,6 @@ EOF make install -j 8 if [ "$1" == "cp27-cp27m" ]; then pip install --user ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl - set -e - python -c "import paddle.fluid" elif [ "$1" == "cp35-cp35m" ]; then pip3.5 install --user ${INSTALL_PREFIX:-/paddle/build}/opt/paddle/share/wheels/*.whl elif [ "$1" == "cp36-cp36m" ]; then