diff --git a/dnn/src/naive/softmax/opr_impl.cpp b/dnn/src/naive/softmax/opr_impl.cpp
index 0991eba629284730895a234f5d5b4d05aad42b2c..de99b7a40be9fbc9327ff85d8442ce3e11f65d6a 100644
--- a/dnn/src/naive/softmax/opr_impl.cpp
+++ b/dnn/src/naive/softmax/opr_impl.cpp
@@ -10,96 +10,120 @@
 #include "src/naive/elemwise/opr_impl.h"
 #include "src/naive/handle.h"
 #include "src/naive/lowbit_utils.h"
-
-using namespace megdnn;
-
-namespace {
-template <typename T>
-TensorND op_exec(_megdnn_tensor_in src, megdnn::dt_byte* workspace_ptr, const T& opr) {
-    TensorLayout dst_layout;
-    opr->deduce_layout(src.layout, dst_layout);
-    TensorND dst{workspace_ptr, dst_layout};
-    workspace_ptr += dst_layout.span().dist_byte();
-    auto new_workspace = Workspace{
-            workspace_ptr, opr->get_workspace_in_bytes(src.layout, dst_layout)};
-    workspace_ptr += opr->get_workspace_in_bytes(src.layout, dst_layout);
-    opr->exec(src, dst, new_workspace);
-    return dst;
-}
-
-}  // namespace
-
 namespace megdnn {
 namespace naive {
 //===============================Softmax Forward============================
+size_t SoftmaxForwardImpl::get_workspace_in_bytes(
+        const TensorLayout& src, const TensorLayout& dst) {
+    int32_t axis = param().axis;
+    int32_t ndim = src.ndim;
+    if (axis < 0)
+        axis += ndim;
+    megdnn_assert(axis >= 0, "is not a valid axis=%d for dim=%d", axis, ndim);
+
+    reduce_opr = handle()->create_operator<Reduce>();
+    elemwise_opr = handle()->create_operator<Elemwise>();
+
+    reduce_opr->param().axis = axis;
+    reduce_opr->param().data_type = param::Reduce::DataType::DEFAULT;
+    reduce_opr->param().mode = Reduce::Mode::MAX;
+
+    reduce_opr->param().mode = Reduce::Mode::MAX;
+    size_t max_workspace = reduce_opr->get_workspace_in_bytes(src, dst);
+    reduce_opr->param().mode = Reduce::Mode::SUM;
+    size_t sum_workspace = reduce_opr->get_workspace_in_bytes(src, dst);
+    reduce_worksize = max_workspace > sum_workspace ? max_workspace : sum_workspace;
+
+    return WorkspaceBundle(nullptr, {src.span().dist_byte(), reduce_worksize})
+            .total_size_in_bytes();
+}
+
 void SoftmaxForwardImpl::exec(
         _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
-    auto axis = param().axis;
-    if (axis < 0)
-        axis += src.layout.ndim;
     check_exec(src.layout, dst.layout, workspace.size);
-    auto workspace_ptr = workspace.raw_ptr;
-    auto reduce_opr = handle()->create_operator<Reduce>();
-    reduce_opr->param().axis = axis;
+    WorkspaceBundle workspace_bundle{
+            workspace.raw_ptr, {src.layout.span().dist_byte(), reduce_worksize}};
+
+    TensorLayout tmp_layout;
     reduce_opr->param().mode = Reduce::Mode::MAX;
-    reduce_opr->param().data_type = param::Reduce::DataType::DEFAULT;
-    TensorND max_tensor = op_exec(src, workspace_ptr, reduce_opr);
+    reduce_opr->deduce_layout(src.layout, tmp_layout);
+    TensorND max_tensor{workspace_bundle.get_workspace(0).raw_ptr, tmp_layout};
+    reduce_opr->exec(src, max_tensor, workspace_bundle.get_workspace(1));
 
-    auto elemwise_opr = handle()->create_operator<Elemwise>();
     elemwise_opr->param().mode = Elemwise::Mode::SUB;
     elemwise_opr->exec({src, max_tensor}, dst);
 
+    // no broadcast
     elemwise_opr->param().mode = Elemwise::Mode::EXP;
-    TensorLayout exp_layout;
-    elemwise_opr->deduce_layout({src.layout}, exp_layout);
-    TensorND exp_tensor{workspace_ptr, exp_layout};
-    workspace_ptr += exp_layout.span().dist_byte();
-    elemwise_opr->exec({dst}, exp_tensor);
+    elemwise_opr->exec({dst}, dst);
 
     reduce_opr->param().mode = Reduce::Mode::SUM;
-    TensorND down_tensor = op_exec(exp_tensor, workspace_ptr, reduce_opr);
+    reduce_opr->deduce_layout(src.layout, tmp_layout);
+
+    TensorND deno_tensor{workspace_bundle.get_workspace(0).raw_ptr, tmp_layout};
+    reduce_opr->exec(dst, deno_tensor, workspace_bundle.get_workspace(1));
 
     elemwise_opr->param().mode = Elemwise::Mode::TRUE_DIV;
-    elemwise_opr->exec({exp_tensor, down_tensor}, dst);
+    elemwise_opr->exec({dst, deno_tensor}, dst);
 }
 //=============================Softmax backward ============================
+size_t SoftmaxBackwardImpl::get_workspace_in_bytes(
+        const TensorLayout& src, const TensorLayout& diff, const TensorLayout&) {
+    int32_t axis = param().axis;
+    int32_t ndim = src.ndim;
+    if (axis < 0)
+        axis += ndim;
+    megdnn_assert(axis >= 0, "is not a valid axis=%d for dim=%d", axis, ndim);
+
+    reduce_opr = handle()->create_operator<Reduce>();
+    elemwise_opr = handle()->create_operator<Elemwise>();
+    reduce_opr->param().axis = axis;
+    reduce_opr->param().data_type = param::Reduce::DataType::DEFAULT;
+    reduce_opr->param().mode = Reduce::Mode::SUM;
+    reduce_worksize = reduce_opr->get_workspace_in_bytes(src, diff);
+
+    return WorkspaceBundle(
+                   nullptr,
+                   {src.span().dist_byte(), src.span().dist_byte(), reduce_worksize})
+            .total_size_in_bytes();
+}
+
 void SoftmaxBackwardImpl::exec(
         _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
         _megdnn_workspace workspace) {
-    auto axis = param().axis;
-    if (axis < 0)
-        axis += src.layout.ndim;
     check_exec(src.layout, diff.layout, grad.layout, workspace.size);
-    auto workspace_ptr = workspace.raw_ptr;
-    TensorLayout mulres = src.layout;
-    mulres.dtype = src.layout.dtype;
-    mulres.format = src.layout.format;
-    mulres.init_contiguous_stride();
-
-    TensorND mul_tensor{workspace_ptr, mulres};
-    workspace_ptr += mulres.span().dist_byte();
-    TensorND mul_tensor2{workspace_ptr, mulres};
-    workspace_ptr += mulres.span().dist_byte();
-
-    auto elemwise_opr = handle()->create_operator<Elemwise>();
+
+    WorkspaceBundle workspace_bundle{
+            workspace.raw_ptr,
+            {src.layout.span().dist_byte(), src.layout.span().dist_byte(),
+             reduce_worksize}};
+
+    TensorLayout mul_layout = src.layout;
+    mul_layout.dtype = src.layout.dtype;
+    mul_layout.format = src.layout.format;
+    mul_layout.init_contiguous_stride();
+
+    TensorND mul_lhs_tensor{workspace_bundle.get_workspace(0).raw_ptr, mul_layout};
+    TensorND mul_rhs_tensor{workspace_bundle.get_workspace(1).raw_ptr, mul_layout};
+
     elemwise_opr->param().mode = Elemwise::Mode::MUL;
-    elemwise_opr->exec({src, diff}, mul_tensor);
+    elemwise_opr->exec({src, diff}, mul_lhs_tensor);
 
-    auto reduce_opr = handle()->create_operator<Reduce>();
-    reduce_opr->param().axis = axis;
-    reduce_opr->param().mode = Reduce::Mode::SUM;
-    reduce_opr->param().data_type = param::Reduce::DataType::DEFAULT;
-    TensorND sum_tensor = op_exec(mul_tensor, workspace_ptr, reduce_opr);
+    TensorLayout sum_layout;
+    reduce_opr->deduce_layout(mul_lhs_tensor.layout, sum_layout);
+    TensorND sum_tensor{grad.raw_ptr(), sum_layout};
+    reduce_opr->exec(mul_lhs_tensor, sum_tensor, workspace_bundle.get_workspace(2));
 
-    elemwise_opr->exec({sum_tensor, src}, mul_tensor2);
+    // broadcast occurs in this elemwise mul
+    elemwise_opr->exec({sum_tensor, src}, mul_rhs_tensor);
 
     elemwise_opr->param().mode = Elemwise::Mode::SUB;
-    elemwise_opr->exec({mul_tensor, mul_tensor2}, grad);
+    elemwise_opr->exec({mul_lhs_tensor, mul_rhs_tensor}, grad);
 }
 }  // namespace naive
 }  // namespace megdnn
\ No newline at end of file
diff --git a/dnn/src/naive/softmax/opr_impl.h b/dnn/src/naive/softmax/opr_impl.h
index e0a7827716a742a549074a9a9ccc04bf741e2a97..30f414b0e3ae924be52801df5002fe9c6f9a01cf 100644
--- a/dnn/src/naive/softmax/opr_impl.h
+++ b/dnn/src/naive/softmax/opr_impl.h
@@ -1,6 +1,6 @@
 #pragma once
 #include "megdnn/oprs.h"
-
+#include "src/common/utils.h"
 namespace megdnn {
 namespace naive {
 
@@ -11,9 +11,12 @@ public:
             _megdnn_tensor_in src, _megdnn_tensor_out dst,
             _megdnn_workspace workspace) override;
     size_t get_workspace_in_bytes(
-            const TensorLayout& src, const TensorLayout&) override {
-        return src.span().dist_byte() * 2;
-    }
+            const TensorLayout& src, const TensorLayout& dst) override;
+
+private:
+    size_t reduce_worksize = 0;
+    std::unique_ptr<Reduce> reduce_opr;
+    std::unique_ptr<Elemwise> elemwise_opr;
 };
 
 class SoftmaxBackwardImpl : public SoftmaxBackward {
@@ -23,10 +26,13 @@ public:
             _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad_x,
             _megdnn_workspace workspace) override;
     size_t get_workspace_in_bytes(
-            const TensorLayout& src, const TensorLayout&,
-            const TensorLayout&) override {
-        return src.span().dist_byte() * 3;
-    }
+            const TensorLayout& src, const TensorLayout& diff,
+            const TensorLayout&) override;
+
+private:
+    size_t reduce_worksize = 0;
+    std::unique_ptr<Reduce> reduce_opr;
+    std::unique_ptr<Elemwise> elemwise_opr;
 };
 
 } // namespace naive