diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h
index 3850062617823a830f21a00e74f9c80a0f382fa7..41728e16adf0ee34e71f342d9d96d3853460eb69 100644
--- a/dnn/include/megdnn/oprs/nn.h
+++ b/dnn/include/megdnn/oprs/nn.h
@@ -2048,6 +2048,53 @@ protected:
             const TensorLayout& doup, const TensorLayout& mask,
             const TensorLayout& dinp, size_t workspace_in_bytes);
 };
+class SoftmaxBase : public OperatorBase {
+    DEF_OPR_IMPL_CTOR(SoftmaxBase, OperatorBase);
+    DEF_OPR_PARAM(Softmax);
+
+protected:
+    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
+    void check_layout_fwd(const TensorLayout& input, const TensorLayout& output);
+};
+
+class SoftmaxForward : public SoftmaxBase {
+    DEF_OPR_IMPL(SoftmaxForward, SoftmaxBase, 1, 1);
+
+public:
+    /**
+     * \param[in] input input tensor
+     * \param[out] output output tensor
+     */
+    virtual void exec(
+            _megdnn_tensor_in input, _megdnn_tensor_out output,
+            _megdnn_workspace workspace) = 0;
+    void deduce_layout(const TensorLayout& input, TensorLayout& output);
+    virtual size_t get_workspace_in_bytes(
+            const TensorLayout& input, const TensorLayout& output) = 0;
+
+protected:
+    void check_exec(
+            const TensorLayout& input, const TensorLayout& output,
+            size_t workspace_in_bytes);
+};
+using Softmax = SoftmaxForward;
+
+class SoftmaxBackward : public SoftmaxBase {
+    DEF_OPR_IMPL(SoftmaxBackward, SoftmaxBase, 2, 1);
+
+public:
+    virtual void exec(
+            _megdnn_tensor_in input, _megdnn_tensor_in diff, _megdnn_tensor_out grad_x,
+            _megdnn_workspace workspace) = 0;
+    virtual size_t get_workspace_in_bytes(
+            const TensorLayout& input, const TensorLayout& diff,
+            const TensorLayout& grad_x) = 0;
+
+protected:
+    void check_exec(
+            const TensorLayout& input, const TensorLayout& diff,
+            const TensorLayout& grad_x, size_t workspace_in_bytes);
+};
 
 class RNNCellForward : public OperatorBase {
     DEF_OPR_PARAM(RNNCell);
diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py
index f6b36aa72f4cbd4d1d92981abeaf4b540812f089..b4feb266962626e969333021799081270df0801a 100755
--- a/dnn/scripts/opr_param_defs.py
+++ b/dnn/scripts/opr_param_defs.py
@@ -253,6 +253,10 @@ pdef('Axis').add_fields('int32', 'axis', 0)
     add_enum_alias('Format', 'Convolution')
 )
 
+(pdef('Softmax').
+ add_fields('int32', 'axis', -1)
+)
+
 (pdef('AdaptivePooling', version=0, is_legacy=True).
  add_enum_alias('Mode', 'PoolingV0').
  add_enum_alias('Format', 'ConvolutionV0')
diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h
index ad54f5885cc87fe5bb7bc791aa44431c674cf0dd..1e717915281e14c0714d0fd22f69c505377d6625 100644
--- a/dnn/src/common/handle_impl.h
+++ b/dnn/src/common/handle_impl.h
@@ -219,7 +219,9 @@ private:
     cb(RNN) \
     cb(RNNBackward) \
     cb(LSTM) \
-    cb(LSTMBackward)
+    cb(LSTMBackward) \
+    cb(SoftmaxForward) \
+    cb(SoftmaxBackward)
 // clang-format on
 
 /*!
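Quick reference for the numerics the new operator pair implements (the helper names below are hypothetical; this sketch is not part of the patch): along the canonicalized `axis`, forward computes `y_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))`, and backward computes `dx_i = y_i * (dy_i - sum_j y_j * dy_j)`. This is exactly the Reduce/Elemwise decomposition the naive kernels below are built from.

```cpp
#include <algorithm>
#include <cmath>

// Reference softmax over one contiguous lane of length n.
// Subtracting the max before exp() keeps the result finite for large inputs,
// mirroring the MAX-reduce + SUB + EXP + SUM-reduce + TRUE_DIV chain used by
// the naive forward kernel.
static void softmax_fwd_ref(const float* x, float* y, int n) {
    float mx = *std::max_element(x, x + n);
    float sum = 0.f;
    for (int i = 0; i < n; ++i) {
        y[i] = std::exp(x[i] - mx);
        sum += y[i];
    }
    for (int i = 0; i < n; ++i)
        y[i] /= sum;
}

// Reference gradient: dx_i = y_i * dy_i - y_i * sum_j(y_j * dy_j), i.e. the
// MUL, SUM-reduce and SUB steps of the naive backward kernel.
static void softmax_bwd_ref(const float* y, const float* dy, float* dx, int n) {
    float dot = 0.f;
    for (int i = 0; i < n; ++i)
        dot += y[i] * dy[i];
    for (int i = 0; i < n; ++i)
        dx[i] = y[i] * (dy[i] - dot);
}
```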
diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h
index f1558e757f854eb758fd04877d6c2183312d01fc..6a2e4d33f50be84594d6287e17e8a37b6ee9c840 100644
--- a/dnn/src/common/opr_trait.h
+++ b/dnn/src/common/opr_trait.h
@@ -145,6 +145,8 @@ DEF(RNNBackward, 10, true, true);
 DEF(LSTMCellForward, 10, true, true);
 DEF(LSTMForward, 8, true, true);
 DEF(LSTMBackward, 13, true, true);
+DEF(SoftmaxForward, 2, true, true);
+DEF(SoftmaxBackward, 3, true, false);
 }  // namespace megdnn
 
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/common/softmax.cpp b/dnn/src/common/softmax.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e72453d4df8f95b0d2ba502422d6f4d22f21f8f7
--- /dev/null
+++ b/dnn/src/common/softmax.cpp
@@ -0,0 +1,61 @@
+/**
+ * \file dnn/src/common/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "megdnn/oprs.h"
+
+#include "src/common/utils.h"
+
+namespace megdnn {
+
+void SoftmaxBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) {
+    megdnn_assert(
+            param().axis >= -static_cast<int>(src.ndim) &&
+                    param().axis < static_cast<int>(src.ndim),
+            "axis: %d ndim: %zu", param().axis, src.ndim);
+    megdnn_assert_contiguous(src);
+    dst = src;
+
+    dst.dtype = src.dtype;
+    dst.format = src.format;
+    dst.init_contiguous_stride();
+}
+
+void SoftmaxBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) {
+    TensorLayout dst_expected;
+    megdnn_assert_eq_dtype(src, dst);
+    deduce_layout_fwd(src, dst_expected);
+    megdnn_assert_eq_layout(dst_expected, dst);
+    megdnn_assert(src.dtype == dst.dtype);
+}
+
+void SoftmaxForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
+    deduce_layout_fwd(src, dst);
+}
+
+void SoftmaxForward::check_exec(
+        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
+    check_layout_fwd(src, dst);
+    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
+    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+}
+
+void SoftmaxBackward::check_exec(
+        const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
+        size_t workspace_in_bytes) {
+    megdnn_assert_eq_layout(src, diff);
+    megdnn_assert_eq_layout(src, grad);
+    auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
+    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+}
+
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
\ No newline at end of file
diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp
index 061af59646b75ad292d81bdd21ebb1e611a5c634..ac003713f298c0c86e811429ef8c1cf182bb66de 100644
--- a/dnn/src/cuda/handle_create.cpp
+++ b/dnn/src/cuda/handle_create.cpp
@@ -76,6 +76,7 @@
 #include "src/cuda/separable_filter/opr_impl.h"
 #include "src/cuda/sleep/opr_impl.h"
 #include "src/cuda/sliding_window_transpose/opr_impl.h"
+#include "src/cuda/softmax/opr_impl.h"
 #include "src/cuda/split/opr_impl.h"
 #include "src/cuda/svd/opr_impl.h"
 #include "src/cuda/tensor_remap/opr_impl.h"
@@ -221,6 +222,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormBackward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward);
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward);
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxBackward);
 
 template <typename Opr>
 std::unique_ptr<Opr> HandleImpl::create_operator() {
diff --git a/dnn/src/cuda/softmax/opr_impl.cpp b/dnn/src/cuda/softmax/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ceb2a716c584b34bd72928899d2d1ca3988c61c8
--- /dev/null
+++ b/dnn/src/cuda/softmax/opr_impl.cpp
@@ -0,0 +1,174 @@
+/**
+ * \file dnn/src/cuda/softmax/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "src/cuda/softmax/opr_impl.h"
+#include "src/cuda/handle.h"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
+
+int CanonicalAxis(const int axis, const int rank) {
+    if (axis < 0) {
+        return axis + rank;
+    }
+    return axis;
+}
+
+int SizeToAxis(const int axis, const size_t* dims) {
+    int size = 1;
+    for (int i = 0; i < axis; i++) {
+        size *= dims[i];
+    }
+    return size;
+}
+
+int SizeOutAxis(const int axis, const size_t* dims, const int ndim) {
+    int size = 1;
+    for (int i = axis + 1; i < ndim; i++) {
+        size *= dims[i];
+    }
+    return size;
+}
+
+std::vector<int> SoftmaxForwardImpl::init_mode(
+        _megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const {
+    auto dims = src.layout.shape;
+    const int rank = src.layout.ndim;
+    const int axis = CanonicalAxis(param().axis, rank);
+    const int dim = dims[axis];
+    const int N = SizeToAxis(axis, dims);
+    const int D = SizeOutAxis(axis, dims, rank);
+
+    mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
+                            : CUDNN_SOFTMAX_MODE_CHANNEL;
+
+    return {N, dim, D, 1};
+}
+
+int sc(const size_t x) {
+    return static_cast<int>(x);
+}
+
+cudnnDataType_t to_cudnn_dtype(
+        DType type, const param::Convolution::Format format = {}) {
+    switch (type.enumv()) {
+        case DTypeEnum::Float32:
+            return CUDNN_DATA_FLOAT;
+        case DTypeEnum::Float16:
+            return CUDNN_DATA_HALF;
+#if CUDNN_MAJOR >= 7
+        case DTypeEnum::Int32:
+        case DTypeEnum::QuantizedS32:
+            return CUDNN_DATA_INT32;
+#endif
+#if CUDNN_MAJOR >= 6
+        case DTypeEnum::QuantizedS8: {
+            if (format == param::Convolution::Format::NCHW4)
+                return CUDNN_DATA_INT8x4;
+#if CUDNN_VERSION >= 7500
+            else if (format == param::Convolution::Format::NCHW32)
+                return CUDNN_DATA_INT8x32;
+#endif
+            else
+                return CUDNN_DATA_INT8;
+        }
+
+        case DTypeEnum::Int8: {
+            if (format == param::Convolution::Format::NCHW4)
+                return CUDNN_DATA_INT8x4;
+#if CUDNN_VERSION >= 7500
+            else if (format == param::Convolution::Format::NCHW32)
+                return CUDNN_DATA_INT8x32;
+#endif
+            else
+                return CUDNN_DATA_INT8;
+        }
+#endif
+        default:
+#if CUDNN_MAJOR >= 6
+            megdnn_throw("dtype must be float16/float32/int8/int32");
+#else
+            megdnn_throw("dtype must be float16/float32");
+#endif
+    }
+}
+
+void SoftmaxForwardImpl::exec(
+        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
+    dt_float32 alpha = 1.0f, beta = 0.0f;
+    TensorDesc src_desc, dst_desc;
+
+    cudnnSoftmaxMode_t mode;
+    std::vector<int> tensor_dims = init_mode(src, mode);
+    const int dimA[] = {
+            sc(tensor_dims[0]), sc(tensor_dims[1]), sc(tensor_dims[2]),
+            sc(tensor_dims[3])};
+    const int strideA[] = {
+            sc(tensor_dims[1] * tensor_dims[2] * tensor_dims[3]),
+            sc(tensor_dims[2] * tensor_dims[3]), sc(tensor_dims[3]), 1};
+
+    cudnn_check(cudnnSetTensorNdDescriptor(
+            src_desc.desc, to_cudnn_dtype(src.layout.dtype), 4, dimA, strideA));
+    cudnn_check(cudnnSetTensorNdDescriptor(
+            dst_desc.desc, to_cudnn_dtype(dst.layout.dtype), 4, dimA, strideA));
+
+    cudnn_check(cudnnSoftmaxForward(
+            cudnn_handle(this->handle()), CUDNN_SOFTMAX_ACCURATE, mode, &alpha,
+            src_desc.desc, src.raw_ptr(), &beta, dst_desc.desc, dst.raw_ptr()));
+}
+
+//================================Softmax Backward============================
+
+std::vector<int> SoftmaxBackwardImpl::init_mode(
+        _megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const {
+    auto dims = src.layout.shape;
+    const int rank = src.layout.ndim;
+    const int axis = CanonicalAxis(param().axis, rank);
+    const int dim = dims[axis];
+    const int N = SizeToAxis(axis, dims);
+    const int D = SizeOutAxis(axis, dims, rank);
+
+    mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
+                            : CUDNN_SOFTMAX_MODE_CHANNEL;
+
+    return {N, dim, D, 1};
+}
+
+void SoftmaxBackwardImpl::exec(
+        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+        _megdnn_workspace workspace) {
+    {
+        dt_float32 alpha = 1.0f, beta = 0.0f;
+        TensorDesc src_desc, diff_desc, grad_desc;
+        cudnnSoftmaxMode_t mode;
+        std::vector<int> tensor_dims = init_mode(src, mode);
+
+        const int dimA[] = {
+                sc(tensor_dims[0]), sc(tensor_dims[1]), sc(tensor_dims[2]),
+                sc(tensor_dims[3])};
+        const int strideA[] = {
+                sc(tensor_dims[1] * tensor_dims[2] * tensor_dims[3]),
+                sc(tensor_dims[2] * tensor_dims[3]), sc(tensor_dims[3]), 1};
+
+        cudnn_check(cudnnSetTensorNdDescriptor(
+                src_desc.desc, to_cudnn_dtype(src.layout.dtype), 4, dimA, strideA));
+        cudnn_check(cudnnSetTensorNdDescriptor(
+                diff_desc.desc, to_cudnn_dtype(diff.layout.dtype), 4, dimA, strideA));
+        cudnn_check(cudnnSetTensorNdDescriptor(
+                grad_desc.desc, to_cudnn_dtype(grad.layout.dtype), 4, dimA, strideA));
+
+        cudnn_check(cudnnSoftmaxBackward(
+                cudnn_handle(this->handle()), CUDNN_SOFTMAX_ACCURATE, mode, &alpha,
+                src_desc.desc, src.raw_ptr(), diff_desc.desc, diff.raw_ptr(), &beta,
+                grad_desc.desc, grad.raw_ptr()));
+    }
+}
+
+// vim: syntax=cpp.doxygen
\ No newline at end of file
diff --git a/dnn/src/cuda/softmax/opr_impl.h b/dnn/src/cuda/softmax/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..803869d2848bc8046990a06e852547916bd08660
--- /dev/null
+++ b/dnn/src/cuda/softmax/opr_impl.h
@@ -0,0 +1,58 @@
+/**
+ * \file dnn/src/cuda/softmax/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+#include "megdnn/oprs.h"
+#include "src/common/algo_base.h"
+#include "src/common/metahelper.h"
+#include "src/cuda/cudnn_wrapper.h"
+#include "src/cuda/utils.h"
+
+namespace megdnn {
+namespace cuda {
+
+class SoftmaxForwardImpl final : public SoftmaxForward {
+public:
+    using SoftmaxForward::SoftmaxForward;
+
+    std::vector<int> init_mode(
+            _megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const;
+
+    virtual void exec(
+            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout&, /* src */
+            const TensorLayout& /* dst */) override {
+        return 0;
+    }
+};
+
+class SoftmaxBackwardImpl final : public SoftmaxBackward {
+public:
+    using SoftmaxBackward::SoftmaxBackward;
+
+    std::vector<int> init_mode(
+            _megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const;
+
+    size_t get_workspace_in_bytes(
+            const TensorLayout& /* input */, const TensorLayout& /* diff */,
+            const TensorLayout& /* grad_x */) override {
+        return 0;
+    }
+
+    virtual void exec(
+            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+            _megdnn_workspace workspace) override;
+};
+
+}  // namespace cuda
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp
index 8756816822eb5be04322e4ce565239c2cd2a3929..be79091bbedd2df30bf4652c4f511425b31f8a45 100644
--- a/dnn/src/naive/handle.cpp
+++ b/dnn/src/naive/handle.cpp
@@ -81,6 +81,7 @@
 #include "src/naive/separable_filter/opr_impl.h"
 #include "src/naive/sleep/opr_impl.h"
 #include "src/naive/sliding_window_transpose/opr_impl.h"
+#include "src/naive/softmax/opr_impl.h"
 #include "src/naive/split/opr_impl.h"
 #include "src/naive/svd/opr_impl.h"
 #include "src/naive/tensor_remap/opr_impl.h"
diff --git a/dnn/src/naive/softmax/opr_impl.cpp b/dnn/src/naive/softmax/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c1a2a8fb8ab5c033c5a3e48455dc941cb279f725
--- /dev/null
+++ b/dnn/src/naive/softmax/opr_impl.cpp
@@ -0,0 +1,116 @@
+/**
+ * \file dnn/src/naive/softmax/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "src/naive/softmax/opr_impl.h"
+
+#include
+#include "megdnn/dtype.h"
+#include "megdnn/tensor_iter.h"
+#include "src/common/elemwise_helper.cuh"
+#include "src/common/opr_delegate.h"
+#include "src/common/reduce_helper.h"
+#include "src/common/utils.h"
+#include "src/naive/elemwise/opr_impl.h"
+#include "src/naive/handle.h"
+#include "src/naive/lowbit_utils.h"
+
+using namespace megdnn;
+
+namespace {
+template <typename T>
+TensorND op_exec(_megdnn_tensor_in src, megdnn::dt_byte* workspace_ptr, const T& opr) {
+    TensorLayout dst_layout;
+    opr->deduce_layout(src.layout, dst_layout);
+    TensorND dst{workspace_ptr, dst_layout};
+    workspace_ptr += dst_layout.span().dist_byte();
+    auto new_workspace = Workspace{
+            workspace_ptr, opr->get_workspace_in_bytes(src.layout, dst_layout)};
+    workspace_ptr += opr->get_workspace_in_bytes(src.layout, dst_layout);
+    opr->exec(src, dst, new_workspace);
+    return dst;
+}
+
+}  // namespace
+
+namespace megdnn {
+namespace naive {
+
+//===============================Softmax Forward============================
+
+void SoftmaxForwardImpl::exec(
+        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
+    auto axis = param().axis;
+    if (axis < 0)
+        axis += src.layout.ndim;
+    check_exec(src.layout, dst.layout, workspace.size);
+    auto workspace_ptr = workspace.raw_ptr;
+
+    auto reduce_opr = handle()->create_operator<Reduce>();
+    reduce_opr->param().axis = axis;
+    reduce_opr->param().mode = Reduce::Mode::MAX;
+    reduce_opr->param().data_type = param::Reduce::DataType::DEFAULT;
+    TensorND max_tensor = op_exec(src, workspace_ptr, reduce_opr);
+
+    auto elemwise_opr = handle()->create_operator<Elemwise>();
+    elemwise_opr->param().mode = Elemwise::Mode::SUB;
+    elemwise_opr->exec({src, max_tensor}, dst);
+
+    elemwise_opr->param().mode = Elemwise::Mode::EXP;
+    TensorLayout exp_layout;
+    elemwise_opr->deduce_layout({src.layout}, exp_layout);
+    TensorND exp_tensor{workspace_ptr, exp_layout};
+    workspace_ptr += exp_layout.span().dist_byte();
+    elemwise_opr->exec({dst}, exp_tensor);
+
+    reduce_opr->param().mode = Reduce::Mode::SUM;
+    TensorND down_tensor = op_exec(exp_tensor, workspace_ptr, reduce_opr);
+
+    elemwise_opr->param().mode = Elemwise::Mode::TRUE_DIV;
+    elemwise_opr->exec({exp_tensor, down_tensor}, dst);
+}
+
+//=============================Softmax backward ============================
+
+void SoftmaxBackwardImpl::exec(
+        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+        _megdnn_workspace workspace) {
+    auto axis = param().axis;
+    if (axis < 0)
+        axis += src.layout.ndim;
+    check_exec(src.layout, diff.layout, grad.layout, workspace.size);
+    auto workspace_ptr = workspace.raw_ptr;
+    TensorLayout mulres = src.layout;
+    mulres.dtype = src.layout.dtype;
+    mulres.format = src.layout.format;
+    mulres.init_contiguous_stride();
+
+    TensorND mul_tensor{workspace_ptr, mulres};
+    workspace_ptr += mulres.span().dist_byte();
+    TensorND mul_tensor2{workspace_ptr, mulres};
+    workspace_ptr += mulres.span().dist_byte();
+
+    auto elemwise_opr = handle()->create_operator<Elemwise>();
+    elemwise_opr->param().mode = Elemwise::Mode::MUL;
+    elemwise_opr->exec({src, diff}, mul_tensor);
+
+    auto reduce_opr = handle()->create_operator<Reduce>();
+    reduce_opr->param().axis = axis;
+    reduce_opr->param().mode = Reduce::Mode::SUM;
+    reduce_opr->param().data_type = param::Reduce::DataType::DEFAULT;
+    TensorND sum_tensor = op_exec(mul_tensor, workspace_ptr, reduce_opr);
+
+    elemwise_opr->exec({sum_tensor, src}, mul_tensor2);
+
+    elemwise_opr->param().mode = Elemwise::Mode::SUB;
+    elemwise_opr->exec({mul_tensor, mul_tensor2}, grad);
+}
+}  // namespace naive
+}  // namespace megdnn
\ No newline at end of file
diff --git a/dnn/src/naive/softmax/opr_impl.h b/dnn/src/naive/softmax/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..932f8dcde6ccd9309fec0de431df5e0b7f709023
--- /dev/null
+++ b/dnn/src/naive/softmax/opr_impl.h
@@ -0,0 +1,45 @@
+/**
+ * \file dnn/src/naive/softmax/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace naive {
+
+class SoftmaxForwardImpl final : public SoftmaxForward {
+public:
+    using SoftmaxForward::SoftmaxForward;
+    void exec(
+            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout& src, const TensorLayout&) override {
+        return src.span().dist_byte() * 2;
+    }
+};
+
+class SoftmaxBackwardImpl final : public SoftmaxBackward {
+public:
+    using SoftmaxBackward::SoftmaxBackward;
+    void exec(
+            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad_x,
+            _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout& src, const TensorLayout&,
+            const TensorLayout&) override {
+        return src.span().dist_byte() * 3;
+    }
+};
+
+}  // namespace naive
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
\ No newline at end of file
diff --git a/dnn/test/common/softmax.h b/dnn/test/common/softmax.h
new file mode 100644
index 0000000000000000000000000000000000000000..aff5b3af43f9d9fbf03ff4752dd108094776e842
--- /dev/null
+++ b/dnn/test/common/softmax.h
@@ -0,0 +1,41 @@
+/**
+ * \file dnn/test/common/softmax.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+#include <vector>
+#include "megdnn/basic_types.h"
+#include "megdnn/opr_param_defs.h"
+
+namespace megdnn {
+namespace test {
+namespace softmax {
+
+struct TestArg {
+    param::Softmax param;
+    TensorShape ishape;
+    TestArg(param::Softmax param, TensorShape ishape) : param(param), ishape(ishape) {}
+};
+
+inline std::vector<TestArg> get_args() {
+    std::vector<TestArg> args;
+    using Param = param::Softmax;
+
+    for (int32_t axis = 0; axis < 5; axis++) {
+        args.emplace_back(Param{axis}, TensorShape{2, 23, 32, 30, 17});
+    }
+
+    return args;
+}
+
+}  // namespace softmax
+}  // namespace test
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
\ No newline at end of file
diff --git a/dnn/test/cuda/softmax.cpp b/dnn/test/cuda/softmax.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..75d19821047205f280c3ebf6ef75f89f34ceebac
--- /dev/null
+++ b/dnn/test/cuda/softmax.cpp
@@ -0,0 +1,71 @@
+/**
+ * \file dnn/test/cuda/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "test/cuda/fixture.h"
+
+#include "megdnn/tensor_iter.h"
+#include "test/common/checker.h"
+#include "test/common/softmax.h"
+
+#include "src/common/utils.h"
+#include "test/cuda/utils.h"
+
+// to check cudnn version
+#include <cudnn.h>
+#include "test/cuda/benchmark.h"
+
+namespace megdnn {
+namespace test {
+
+TEST_F(CUDA, SOFTMAX_FORWARD) {
+    auto args = softmax::get_args();
+    std::vector<DType> dtypes{dtype::Float16(), dtype::Float32()};
+
+    for (auto dtype : dtypes)
+        for (auto&& arg : args) {
+            auto param = arg.param;
+            auto src = arg.ishape;
+            Checker<Softmax> checker(handle_cuda());
+            if (dtype == dtype::BFloat16()) {
+                checker.set_epsilon(2e-2);
+            } else {
+                checker.set_epsilon(1e-2);
+            }
+            checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
+                    TensorShapeArray{src, {}});
+        }
+}
+
+TEST_F(CUDA, SOFTMAX_BACKWARD) {
+    auto args = softmax::get_args();
+    for (auto&& arg : args) {
+        Checker<SoftmaxBackward> checker(handle_cuda());
+        TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
+        TensorLayout olayout;
+
+        {
+            auto opr = handle_cuda()->create_operator<SoftmaxForward>();
+            opr->param() = arg.param;
+            opr->deduce_layout(ilayout, olayout);
+        }
+        auto set_dtype = [&checker](DType dtype) {
+            checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
+        };
+
+        set_dtype(dtype::Float32());
+        checker.set_epsilon(1e-3).set_param(arg.param).exec(
+                TensorShapeArray{ilayout, olayout, ilayout});
+    }
+}
+
+}  // namespace test
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/test/naive/softmax.cpp b/dnn/test/naive/softmax.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..62a3ec13cc0ec6795b16b9cdf58efb62ba278729
--- /dev/null
+++ b/dnn/test/naive/softmax.cpp
@@ -0,0 +1,56 @@
+/**
+ * \file dnn/test/naive/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "test/naive/fixture.h"
+
+#include "megdnn/oprs/nn.h"
+#include "test/common/checker.h"
+
+using namespace megdnn;
+using namespace test;
+
+TEST_F(NAIVE, SOFTMAX_FORWARD) {
+    Checker<Softmax> checker(handle(), /* check_dispatch */ false);
+
+    Softmax::Param param{0};
+
+    TensorND input = TensorValue(
+            {2, 2, 2, 2}, dtype::Float32(),
+            {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
+
+    TensorND output = TensorValue(
+            {2, 2, 2, 2}, dtype::Float32(),
+            {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
+             0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});
+
+    checker.set_param(param).exect(Testcase{input, {}}, Testcase{{}, output});
+}
+
+TEST_F(NAIVE, SOFTMAX_BACKWARD) {
+    Checker<SoftmaxBackward> checker(handle(), /* check_dispatch */ false);
+
+    Softmax::Param param{0};
+
+    TensorND input = TensorValue(
+            {2, 2, 2, 2}, dtype::Float32(),
+            {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
+             0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});
+
+    TensorND diff = TensorValue(
+            {2, 2, 2, 2}, dtype::Float32(),
+            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});
+
+    TensorND output = TensorValue(
+            {2, 2, 2, 2}, dtype::Float32(),
+            {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
+
+    checker.set_param(param).exect(Testcase{input, diff, {}}, Testcase{{}, {}, output});
+}
\ No newline at end of file
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index a7220993ff52be6ee92fe73e98ae1e9ae9c99cd7..e17eb800aad165a6a236a010d04689eb7ecf2d0f 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -1061,10 +1061,15 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
     """
     if axis is None:
         axis = _get_softmax_axis(len(inp.shape))
-    offset = inp.max(axis=axis, keepdims=True).detach()
-    cached = exp(inp - offset)
-    down = sum(cached, axis=axis, keepdims=True)
-    return cached / down
+    if isinstance(axis, list):
+        offset = inp.max(axis=axis, keepdims=True).detach()
+        cached = exp(inp - offset)
+        down = sum(cached, axis=axis, keepdims=True)
+        return cached / down
+    else:
+        op = builtin.Softmax(axis=axis,)
+        (output,) = apply(op, inp)
+        return output
 
 
 def layer_norm(
diff --git a/imperative/src/impl/ops/softmax.cpp b/imperative/src/impl/ops/softmax.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..68319e93d951709d48b0d629d0105b432ef8fc8d
--- /dev/null
+++ b/imperative/src/impl/ops/softmax.cpp
@@ -0,0 +1,52 @@
+/**
+ * \file imperative/src/impl/ops/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megbrain/opr/dnn/softmax.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+#include "../dnn_op_helper.h"
+#include "../op_trait.h"
+
+namespace mgb {
+namespace imperative {
+namespace {
+namespace softmax {
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& softmax = static_cast<const Softmax&>(def);
+    OperatorNodeConfig config{softmax.make_name()};
+    return opr::Softmax::make(inputs[0], softmax.param(), config);
+}
+
+std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
+    auto* node = &node_->cast_final_safe<opr::Softmax>();
+    return Softmax::make(node->param());
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef&, const SmallVector<LogicalTensorDesc>& inputs) {
+    SmallVector<LogicalTensorDesc> out_shapes(1);
+    auto&& i0 = inputs[0];
+    out_shapes[0] = {i0.layout, i0.comp_node};
+    return {out_shapes, true};
+}
+
+OP_TRAIT_REG(Softmax, Softmax, opr::Softmax)
+        .make_from_op_node(make_from_op_node)
+        .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .fallback();
+
+}  // namespace softmax
+}  // namespace
+}  // namespace imperative
+}  // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index 167210c18e280f87c8fa241d76fec884be53b0c2..894dd5fc7699f574fc0a0e0941faec5b255ff03e 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -354,6 +354,7 @@ def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
 def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>;
 def TQT: MgbHashableOp<"TQT", [TQTParam]>;
 def LSQ: MgbHashableOp<"LSQ", [LSQParam]>;
+def Softmax: MgbHashableOp<"Softmax", [SoftmaxParam]>;
 def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
   let extraArguments = (ins
     MgbDTypeAttr:$dtype
diff --git a/src/opr/impl/dnn/dnn.oprdecl b/src/opr/impl/dnn/dnn.oprdecl
index cda4eca765109a6860587af59a6b6f302ceb531b..c66850f001453f1e2492a67ac931b1c658777445 100644
--- a/src/opr/impl/dnn/dnn.oprdecl
+++ b/src/opr/impl/dnn/dnn.oprdecl
@@ -327,4 +327,7 @@ decl_opr('TQT',
 decl_opr('LSQ',
          inputs=[Doc('src','input tensor'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor'),Doc('grad_scale','grad scale tensor')],
          params='LSQ')
+decl_opr('Softmax',
+         inputs=[Doc('src','input tensor')],
+         params='Softmax')
 
 # vim: ft=python
diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h
index ea8227bab67e14be49bb5b6f4f53577a44c83520..2ac6df932fcb0831a5e769eace6917b1f5306277 100644
--- a/src/opr/impl/dnn/dnn.sereg.h
+++ b/src/opr/impl/dnn/dnn.sereg.h
@@ -25,6 +25,7 @@
 #include "megbrain/opr/dnn/roi_align.h"
 #include "megbrain/opr/dnn/roi_pooling.h"
 #include "megbrain/opr/dnn/sliding_window_transpose.h"
+#include "megbrain/opr/dnn/softmax.h"
 #include "megbrain/opr/dnn/tqt.h"
 #include "megbrain/serialization/sereg.h"
 #include "megdnn/opr_param_defs.h"
@@ -324,6 +325,19 @@ struct OprMaker {
     }
 };
 
+template <>
+struct OprMaker<opr::SoftmaxBackward, 2> {
+    using Param = opr::SoftmaxBackward::Param;
+    static cg::OperatorNodeBase* make(
+            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
+            const OperatorNodeConfig& config) {
+        MGB_MARK_USED_VAR(graph);
+        return opr::SoftmaxBackward::make(i[0], i[1], param, config)
+                .node()
+                ->owner_opr();
+    }
+};
+
 template <>
 struct OprLoadDumpImpl
         : public GeneralOprLoadDumpImpl<
@@ -720,6 +734,8 @@ MGB_SEREG_OPR(RNNForward, 3);
 MGB_SEREG_OPR(RNNBackward, 7);
 MGB_SEREG_OPR(LSTMForward, 4);
 MGB_SEREG_OPR(LSTMBackward, 9);
+MGB_SEREG_OPR(Softmax, 1);
+MGB_SEREG_OPR(SoftmaxBackward, 2);
 
 }  // namespace opr
 }  // namespace mgb
diff --git a/src/opr/impl/dnn/softmax.cpp b/src/opr/impl/dnn/softmax.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1064edbb2131642497b47dc3fe3fc96ade71eaa5
--- /dev/null
+++ b/src/opr/impl/dnn/softmax.cpp
@@ -0,0 +1,124 @@
+/**
+ * \file src/opr/impl/dnn/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megbrain/opr/dnn/softmax.h"
+
+#include "megbrain/graph/grad_impl.h"
+#include "megbrain/opr/internal/out_shape_by_sym_var.h"
+#include "megbrain/opr/utility.h"
+
+#include "../internal/megdnn_opr_wrapper.inl"
+
+using namespace mgb;
+using namespace opr;
+
+/* ==================== SoftmaxForward ==================== */
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(SoftmaxForward);
+
+SoftmaxForward::SoftmaxForward(
+        VarNode* inp, const Param& param, const OperatorNodeConfig& config)
+        : Super{inp->owner_graph(), config, "softmax", {inp}} {
+    init_megdnn_opr(*this, param);
+
+    add_input({inp});
+    output(0)->dtype(inp->dtype());
+}
+
+SymbolVar SoftmaxForward::make(
+        SymbolVar inp, const Param& param, const OperatorNodeConfig& config) {
+    auto out = inp.node()
+                       ->owner_graph()
+                       ->insert_opr(std::make_unique<SoftmaxForward>(
+                               inp.node(), param, config))
+                       ->output();
+
+    return out[0];
+}
+
+void SoftmaxForward::get_output_var_shape(
+        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
+    out_shape[0] = inp_shape[0];
+}
+
+size_t SoftmaxForward::get_workspace_size_bytes(
+        const TensorShapeArray& input_shapes,
+        const TensorShapeArray& output_shapes) const {
+    return megdnn_opr()->get_workspace_in_bytes(
+            {input_shapes[0], input(0)->dtype(), input(0)->format()},
+            {output_shapes[0], output(0)->dtype(), output(0)->format()});
+}
+
+void SoftmaxForward::scn_do_execute() {
+    megdnn_opr()->exec(
+            input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
+            intl::get_megdnn_workspace_from_var(output().back()));
+}
+
+#if MGB_ENABLE_GRAD
+MGB_IMPL_OPR_GRAD(SoftmaxForward) {
+    SymbolVar grad = SoftmaxBackward::make(opr.output(0), out_grad[0], opr.param());
+
+    return grad.node();
+}
+#endif
+
+/* ==================== SoftmaxBackward ==================== */
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(SoftmaxBackward);
+
+SoftmaxBackward::SoftmaxBackward(
+        VarNode* src, VarNode* diff, const Param& param,
+        const OperatorNodeConfig& config)
+        : Super({src->owner_graph(), config, "Softmax_backward", {src, diff}}, 0,
+                true) {
+    init_megdnn_opr(*this, param);
+    add_input({src, diff});
+}
+
+SymbolVar SoftmaxBackward::make(
+        SymbolVar src, SymbolVar diff, const Param& param,
+        const OperatorNodeConfig& config) {
+    auto out = src.node()
+                       ->owner_graph()
+                       ->insert_opr(std::make_unique<SoftmaxBackward>(
+                               src.node(), diff.node(), param, config))
+                       ->output();
+    return out[0];
+}
+
+void SoftmaxBackward::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto&& mgr = owner_graph()->static_infer_manager();
+    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
+    this->init_output_static_infer_desc_workspace(false);
+}
+
+void SoftmaxBackward::init_output_dtype() {
+    output(0)->dtype(input(0)->dtype());
+}
+
+size_t SoftmaxBackward::get_workspace_size_bytes(
+        const TensorShapeArray& input_shapes,
+        const TensorShapeArray& output_shapes) const {
+    return megdnn_opr()->get_workspace_in_bytes(
+            {input_shapes[0], input(0)->dtype(), input(0)->format()},
+            {input_shapes[1], input(1)->dtype(), input(1)->format()},
+            {output_shapes[0], output(0)->dtype(), output(0)->format()});
+}
+
+void SoftmaxBackward::scn_do_execute() {
+    megdnn_opr()->exec(
+            input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
+            output(0)->dev_tensor().as_megdnn(),
+            intl::get_megdnn_workspace_from_var(output().back()));
+}
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/include/megbrain/opr/dnn/softmax.h b/src/opr/include/megbrain/opr/dnn/softmax.h
new file mode 100644
index 0000000000000000000000000000000000000000..71261d0e51471d0b9be553092beb71f22fcb103d
--- /dev/null
+++ b/src/opr/include/megbrain/opr/dnn/softmax.h
@@ -0,0 +1,64 @@
+/**
+ * \file src/opr/include/megbrain/opr/dnn/softmax.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megdnn/oprs/nn.h"
+
+namespace mgb {
+namespace opr {
+
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        SoftmaxForward, intl::MegDNNOprWrapperFwd<megdnn::SoftmaxForward>) // {
+public:
+    MGE_WIN_DECLSPEC_FUC SoftmaxForward(
+            VarNode* src, const Param& param, const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
+            SymbolVar src, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+
+private:
+    void get_output_var_shape(
+            const TensorShapeArray& inp_shape,
+            TensorShapeArray& out_shape) const override;
+    size_t get_workspace_size_bytes(
+            const TensorShapeArray& input_shapes,
+            const TensorShapeArray& output_shapes) const override;
+    void scn_do_execute() override;
+};
+
+using Softmax = SoftmaxForward;
+
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        SoftmaxBackward, intl::MegDNNOprWrapperBwd<megdnn::SoftmaxBackward>) // {
+public:
+    MGE_WIN_DECLSPEC_FUC SoftmaxBackward(
+            VarNode* x, VarNode* y_grad, const Param& param,
+            const OperatorNodeConfig& config);
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
+            SymbolVar x, SymbolVar y_grad, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+
+private:
+    void init_output_static_infer_desc() override;
+    void init_output_dtype() override;
+    size_t get_workspace_size_bytes(
+            const TensorShapeArray& input_shapes,
+            const TensorShapeArray& output_shapes) const override;
+    void scn_do_execute() override;
+};
+
+}  // namespace opr
+}  // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/test/dnn/softmax.cpp b/src/opr/test/dnn/softmax.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ce888f4a8cf444a582c216fc2bfc7100f6cd3a2e
--- /dev/null
+++ b/src/opr/test/dnn/softmax.cpp
@@ -0,0 +1,65 @@
+/**
+ * \file src/opr/test/dnn/softmax.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megbrain/opr/dnn/softmax.h"
+#include "megbrain/comp_node_env.h"
+#include "megbrain/test/autocheck.h"
+
+using namespace std;
+using namespace mgb;
+
+namespace {
+using Param = opr::SoftmaxForward::Param;
+void run(int32_t axis) {
+    using Checker = AutoOprChecker<1, 1>;
+    Param param{axis};
+
+    auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
+        auto o0 = opr::SoftmaxForward::make(inputs[0], param);
+        return {o0};
+    };
+
+    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
+        auto opr =
+                MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu()))
+                        ->create_operator<megdnn::Softmax>();
+        opr->param() = param;
+        dest[0].dtype(dtype::Float32())
+                .comp_node(inp[0]->comp_node())
+                .resize(inp[0]->shape());
+        size_t wk_size =
+                opr->get_workspace_in_bytes(inp[0]->layout(), dest[0].layout());
+        std::unique_ptr<dt_byte[]> wk_store{new dt_byte[wk_size]};
+        opr->exec(inp[0]->as_megdnn(), dest[0].as_megdnn(), {wk_store.get(), wk_size});
+    };
+
+    auto gen = [&](HostTensorND& src) {
+        HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> src_gen(10.f);
+        src = *src_gen(src.shape(), src.comp_node());
+    };
+
+    Checker::RunOptions opt;
+    opt.numdiff_max_err = 1e-4;
+
+    Checker checker{make_graph, fwd};
+    checker.set_input_generator(0, gen);
+    checker.run({TensorShape{1, 2, 3, 4}}, opt)
+            .run({TensorShape{2, 3, 8, 8}}, opt)
+            .run({TensorShape{1, 3, 4, 4}}, opt);
+}
+
+}  // anonymous namespace
+
+TEST(TestOprDNN, SoftmaxForward) {
+    REQUIRE_GPU(1);
+    run(1);
+}
\ No newline at end of file
diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs
index 15a9c0dc353bc082656eb66c426210dd0afd1fb9..b6c46c39869a08da1726b03ee0c24047a71a2649 100644
--- a/src/serialization/impl/schema.fbs
+++ b/src/serialization/impl/schema.fbs
@@ -121,6 +121,7 @@ union OperatorParam {
     param.RNNCell = 87,
     param.RNN = 88,
     param.LSTM = 89,
+    param.Softmax = 90,
 }
 
 table Operator {
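End to end, the new pieces line up as: `param::Softmax` (one `axis` field) -> megdnn `SoftmaxForward`/`SoftmaxBackward` kernels (cuDNN and naive) -> graph operator `opr::Softmax` with a registered gradient -> imperative `builtin.Softmax`, which `functional.softmax` now dispatches to for an integer axis (a list of axes keeps the composite max/exp/sum path). A hedged sketch of driving the graph-level operator, with the surrounding graph setup assumed for illustration rather than taken from this patch:

```cpp
#include <memory>

#include "megbrain/graph.h"
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/opr/io.h"

using namespace mgb;

// Sketch: softmax over axis 1 of a host tensor fed through Host2DeviceCopy.
SymbolVar build_softmax(
        const std::shared_ptr<ComputingGraph>& graph,
        const std::shared_ptr<HostTensorND>& host_x) {
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);  // device input var
    opr::Softmax::Param param{1};                         // axis = 1
    return opr::Softmax::make(x, param);                  // y = softmax(x, axis=1)
}
```

The gradient needs no extra wiring at call sites: `MGB_IMPL_OPR_GRAD(SoftmaxForward)` above routes `cg::grad` through `SoftmaxBackward`, which takes the forward output rather than the input, matching the `dx = y * (dy - sum(y * dy))` formulation.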