Commit 2a900a69 authored by Megvii Engine Team

perf(imperative): improve reduce op performance

GitOrigin-RevId: 26d982a7b8f0092d3f93e4d9cbac5b44e326427d
Parent c2293815
...
@@ -180,7 +180,7 @@ public:
     virtual void exec(
             _megdnn_tensor_in src, _megdnn_tensor_out dst,
             _megdnn_workspace workspace) = 0;
-    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
+    MGE_WIN_DECLSPEC_FUC void deduce_layout(const TensorLayout& src, TensorLayout& dst);
     virtual size_t get_workspace_in_bytes(
             const TensorLayout& src, const TensorLayout& dst) = 0;
...
...
@@ -384,22 +384,18 @@ def _reduce(mode):
         data = self
         if axis is None:
             assert not keepdims, "can not set axis=None and keepdims=True"
-            result = _reduce_to_scalar(builtin.Reduce(mode=mode), data)
+            (result,) = apply(builtin.Reduce(mode=mode), data)
         elif isinstance(axis, collections.abc.Iterable):
             axis = _normalize_axis(self.ndim, axis, reverse=True)
             for ai in axis:
-                op = builtin.Reduce(mode=mode, axis=ai)
+                op = builtin.Reduce(mode=mode, axis=ai, keepdim=keepdims)
                 (data,) = apply(op, data)
-                if not keepdims:
-                    data = squeeze_cpp(data, ai)
             result = data
         else:
             # builtin.Reduce already accept negtive axis
-            op = builtin.Reduce(mode=mode, axis=axis)
+            op = builtin.Reduce(mode=mode, axis=axis, keepdim=keepdims)
             (result,) = apply(op, data)
-            if not keepdims:
-                result = squeeze_cpp(result, axis)
         return result
     return f
...
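
The Python-side change above stops issuing a separate squeeze after each per-axis reduce and instead forwards `keepdims` into the builtin op. The following numpy sketch (illustrative only; numpy stands in for the MegEngine tensor API) shows the shape semantics the patched `_reduce` preserves:

```python
# Illustrative sketch of keepdims semantics, using numpy as a stand-in.
import numpy as np

x = np.arange(24, dtype="float32").reshape(2, 3, 4)

# keepdims=True: every reduced axis stays as a size-1 dimension.
assert np.sum(x, axis=(0, 2), keepdims=True).shape == (1, 3, 1)

# keepdims=False: reduced axes are dropped, which is what the removed
# squeeze_cpp call used to do one axis at a time.
assert np.sum(x, axis=(0, 2)).shape == (3,)
```
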
...
@@ -271,24 +271,34 @@ std::optional<ValueRefList> reduce_grad_rule(
     if (reduce.mode != Reduce::Mode::SUM) {
         return {};
     }
-    if (inputs.size() != 1) {
+    auto axis = reduce.axis;
+    if (inputs.size() != 1 || axis == INT_MAX) {
         return {};
     }
     std::array<ValueRef, 1> input_shapes;
     if (inputs_require_grad[0]) {
         input_shapes[0] = get_shape(inputs[0]);
     }
+    if (axis < 0) {
+        axis = (*inputs[0].shape()).ndim + axis;
+    }
     auto maker = CustomGradMaker(backward, inputs.size());
+    auto keepdim = reduce.keepdim || axis == INT_MAX;
     maker.output_size(1).output_captured(0, false);
-    maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
-        mgb_assert(grads.size() == 1);
-        ValueRef grad = grads[0];
-        SmallVector<ValueRef> ret(1);
-        if (grad && shapes[0]) {
-            ret[0] = broadcast_to(grad, shapes[0]);
-        }
-        return ret;
-    });
+    maker.backward(
+            [shapes = std::move(input_shapes), axis, keepdim](Span<ValueRef> grads) {
+                mgb_assert(grads.size() == 1);
+                ValueRef grad = grads[0];
+                if (!keepdim) {
+                    auto&& grad_op = AddAxis::make(std::vector<int32_t>({axis}));
+                    grad = imperative::apply(*grad_op, grad)[0];
+                }
+                SmallVector<ValueRef> ret(1);
+                if (grad && shapes[0]) {
+                    ret[0] = broadcast_to(grad, shapes[0]);
+                }
+                return ret;
+            });
     maker.finalize();
     return imperative::apply(ApplyOp(op), inputs);
 }
...
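
With `keepdim` false, the incoming gradient of a SUM reduction no longer carries the reduced axis, so the updated backward rule re-inserts it (the `AddAxis` apply above) before broadcasting back to the input shape. A minimal numpy sketch of that reasoning, assuming a single reduced axis:

```python
# Illustrative sketch; numpy stands in for the ValueRef-based grad machinery.
import numpy as np

x = np.random.rand(2, 3, 4).astype("float32")
axis = 1

y = x.sum(axis=axis)      # keepdim=False: y has shape (2, 4)
grad_y = np.ones_like(y)

# (2, 4) cannot broadcast to (2, 3, 4) directly; the reduced axis has to be
# restored first, which is the role of the AddAxis step in the grad rule.
grad_x = np.broadcast_to(np.expand_dims(grad_y, axis), x.shape)
assert grad_x.shape == x.shape
```
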
...
@@ -36,8 +36,8 @@ EncodedSubgraph generate_batchnorm_backward_graph(DType dtype, CompNode device)
                 op, Subgraph::vars_t({(Subgraph::var_t)args...}), 1)[0];
     };
-    auto prod = Reduce::make(megdnn::param::Reduce(Reduce::Mode::PRODUCT, 0));
-    auto sum = Reduce::make(megdnn::param::Reduce(Reduce::Mode::SUM));
+    auto prod = Reduce::make(megdnn::param::Reduce(Reduce::Mode::PRODUCT, 0), true);
+    auto sum = Reduce::make(megdnn::param::Reduce(Reduce::Mode::SUM), true);
     auto sub = Elemwise::make(Elemwise::Mode::SUB);
     auto mul = Elemwise::make(Elemwise::Mode::MUL);
     auto div = Elemwise::make(Elemwise::Mode::TRUE_DIV);
...
...
@@ -9,10 +9,16 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
+#include "megbrain/graph/symbol_var.h"
 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/proxy_graph_detail.h"
 #include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megbrain/opr/io.h"
+#include "megbrain/opr/tensor_manip.h"
+#include "megdnn/dtype.h"
+#include "../blob_manager_impl.h"
 #include "../dnn_op_helper.h"
 #include "../op_trait.h"
...
@@ -22,18 +28,41 @@ namespace {
 namespace reduce {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& reduce = static_cast<const Reduce&>(def);
-    OperatorNodeConfig config{reduce.make_name()};
+    auto comp_node = inputs[0]->comp_node();
+    OperatorNodeConfig config{reduce.make_name(), comp_node, inputs[0]->dtype()};
     if (inputs.size() > 1) {
         return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
-    } else {
-        return opr::Reduce::make(
-                inputs[0], reduce.param(), (cg::VarNode*)nullptr, config);
     }
+    using Param = megdnn::param::Reduce;
+    auto param = reduce.param();
+    if (param.axis < 0) {
+        param.axis = inputs[0]->shape().ndim + param.axis;
+    }
+    SymbolVar target_shape = (cg::VarNode*)nullptr;
+    if (param.axis == INT_MAX) {
+        DTypeScalar vi{1};
+        // auto graph = ComputingGraph::make();
+        auto graph = inputs[0]->owner_graph();
+        target_shape = opr::ImmutableTensor::make(*graph, vi, config);
+    }
+    auto res = opr::Reduce::make(inputs[0], param, target_shape, config);
+    if (!reduce.keepdim && param.axis != INT_MAX) {
+        using Desc = opr::AxisAddRemove::AxisDesc;
+        std::vector<Desc> remove_param;
+        remove_param.push_back(Desc::make_remove(param.axis));
+        OperatorNodeConfig remove_config{
+                def.make_name(), comp_node, inputs[0]->dtype()};
+        return opr::AxisAddRemove::make(res, remove_param, remove_config);
+    }
+    return res;
 }
 
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     auto* node = &node_->cast_final_safe<opr::Reduce>();
-    return Reduce::make(node->param());
+    return Reduce::make(node->param(), true);
 }
 
 // TODO: using this for apply_on_physical_tensor
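
In the graph-building path the reduce node itself keeps the reduced axis with extent 1 (which is why the old Python helper squeezed afterwards), so when `keepdim` is false the patch appends an `AxisAddRemove` node to drop it; the full reduction (`axis == INT_MAX`) instead uses a single-element target shape. A numpy sketch of the two-step equivalence (illustrative only):

```python
# Illustrative sketch of "reduce with the axis kept, then remove the axis".
import numpy as np

x = np.random.rand(2, 3, 4).astype("float32")
axis = 1

kept = x.sum(axis=axis, keepdims=True)   # what the reduce node produces: (2, 1, 4)
removed = np.squeeze(kept, axis=axis)    # what the appended axis-remove does

assert kept.shape == (2, 1, 4)
assert removed.shape == x.sum(axis=axis).shape == (2, 4)
```
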
...
@@ -57,21 +86,159 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
         return {Tensor::make(
                 inputs[0]->blob(), inputs[0]->offset(), inputs[0]->layout())};
     }
-    return proxy_graph_detail::apply_on_physical_tensor(
-            def, inputs, output_descs, validated);
+    auto size = inputs.size();
+    if (size > 1) {
+        return proxy_graph_detail::apply_on_physical_tensor(
+                def, inputs, output_descs, validated);
+    }
+    auto comp_node = inputs[0]->comp_node();
+    using TensorND = megdnn::TensorND;
+    auto&& op_def = def.cast_final_safe<Reduce>();
+    SmallVector<TensorND> inp_tensornds;
+    inp_tensornds.reserve(inputs.size());
+    auto src = inputs[0]->layout();
+    DnnOprCaller<megdnn::Reduce> dnn_op(comp_node);
+    dnn_op.op->param() = op_def.param();
+    auto axis = op_def.param().axis;
+    auto keepdim = op_def.keepdim;
+    if (axis < 0) {
+        axis = inputs[0]->layout().ndim + axis;
+    }
+    dnn_op.op->param().axis = axis == INT_MAX ? 0 : axis;
+    if (axis == INT_MAX) {
+        src.shape[0] = src.total_nr_elems();
+        src.ndim = 1;
+        src.init_contiguous_stride();
+    }
+    TensorLayout layout{src.dtype};
+    dnn_op.op->deduce_layout(src, layout);
+    if (inputs[0]->layout().is_empty()) {
+        inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src);
+        auto mode = op_def.param().mode;
+        DnnOprCaller<megdnn::Fill> fill_op(comp_node);
+        if (!keepdim && src.ndim > 1) {
+            layout.remove_axis_inplace(axis);
+            layout.init_contiguous_stride();
+        }
+        DeviceTensorND out =
+                BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout);
+        std::string err_msg;
+        switch (mode) {
+            case Reduce::Mode::SUM:
+                if (!out.empty()) {
+                    fill_op.op->param() = 0;
+                    fill_op.op->exec(out.as_megdnn(), {});
+                }
+                break;
+            case Reduce::Mode::PRODUCT:
+                if (!out.empty()) {
+                    fill_op.op->param() = 1;
+                    fill_op.op->exec(out.as_megdnn(), {});
+                }
+                break;
+            case Reduce::Mode::MEAN:
+                err_msg = "mean";
+                break;
+            case Reduce::Mode::MIN:
+                err_msg = "min";
+                break;
+            case Reduce::Mode::MAX:
+                err_msg = "max";
+                break;
+            case Reduce::Mode::SUM_SQR:
+                err_msg = "sum_sqr";
+                break;
+            default:
+                mgb_throw(MegBrainError, "bad reduce mode");
+        }
+        if (!err_msg.empty()) {
+            mgb_throw(
+                    MegBrainError, "empty input is not allowed for reduce mode: %s",
+                    err_msg.c_str());
+        }
+        return {Tensor::make(out)};
+    }
+    auto dnn_ten = inputs[0]->dnn_tensor();
+    dnn_ten.layout = src;
+    inp_tensornds.push_back(dnn_ten);
+    megdnn::Workspace dnn_wk;
+    auto wk_size = dnn_op.op->get_workspace_in_bytes(src, layout);
+    if (wk_size != 0) {
+        auto wk = Blob::make(comp_node, wk_size);
+        dnn_wk.raw_ptr = wk->storage().get();
+        dnn_wk.size = wk_size;
+    }
+    DeviceTensorND out =
+            BlobManager::inst()->alloc_workspace_with_defrag(comp_node, layout);
+    dnn_op.op->exec(inp_tensornds[0], out.as_megdnn(), dnn_wk);
+    if (!keepdim && src.ndim > 1) {
+        auto out_layout = out.layout();
+        out_layout.remove_axis_inplace(axis);
+        out_layout.init_contiguous_stride();
+        out.resize(out_layout);
+    }
+    return {Tensor::make(out)};
 }
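
The new single-input fast path above calls the megdnn reduce kernel directly instead of routing through the proxy graph, and it handles empty inputs by filling the output with the reduction's identity element (0 for SUM, 1 for PRODUCT), throwing for modes that have no defined result on empty data. The identity-element convention, sketched with numpy for illustration:

```python
# Illustrative sketch of the empty-input convention; numpy only.
import numpy as np

empty = np.zeros((0, 4), dtype="float32")

# Reducing over an empty axis yields the identity element of the operation,
# matching what the Fill-based branch above writes into the output.
assert np.array_equal(empty.sum(axis=0), np.zeros(4, dtype="float32"))
assert np.array_equal(empty.prod(axis=0), np.ones(4, dtype="float32"))

# min/max (and mean) over an empty axis have no defined value; numpy raises,
# and the patched code throws a MegBrainError for those modes.
try:
    empty.min(axis=0)
except ValueError:
    pass
```
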
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
-    auto [output_descs, validated] =
-            proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
-    if (inputs.size() == 2 && !output_descs[0].layout.ndim) {
-        if (!inputs[1].value.empty()) {
-            cg::copy_tensor_value_to_shape(output_descs[0].layout, inputs[1].value);
-            output_descs[0].layout.init_contiguous_stride();
-        }
-    }
-    return {output_descs, validated};
+    auto&& op_def = def.cast_final_safe<Reduce>();
+    auto axis = op_def.param().axis;
+    auto keepdim = op_def.keepdim;
+    size_t size = inputs.size();
+    SmallVector<LogicalTensorDesc> dests(size);
+    if (size > 1) {
+        auto [output_descs, validated] =
+                proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
+        if (!inputs[1].value.empty()) {
+            cg::copy_tensor_value_to_shape(output_descs[0].layout, inputs[1].value);
+            output_descs[0].layout.init_contiguous_stride();
+        }
+        return {output_descs, validated};
+    }
+    if (axis < 0) {
+        axis = inputs[0].layout.ndim + axis;
+    }
+    if (axis == INT_MAX || inputs[0].layout.ndim == 1) {
+        TensorLayout layout{inputs[0].layout.dtype};
+        layout.shape[0] = 1;
+        layout.ndim = 1;
+        dests[0].layout = layout;
+        dests[0].comp_node = inputs[0].comp_node;
+    } else {
+        for (size_t i = 0; i < size; ++i) {
+            dests[i].comp_node = inputs[i].comp_node;
+            dests[i].layout = inputs[i].layout;
+            if (not keepdim && dests[i].layout.ndim > 1) {
+                dests[i].layout.remove_axis_inplace(axis);
+            } else {
+                dests[i].layout.shape[axis] = 1;
+            }
+            dests[i].layout.init_contiguous_stride();
+        }
+    }
+    return {dests, true};
 }
 
 SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
...
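
The shape-inference branch above follows a simple rule: a full reduction (or a 1-d input) yields a single-element layout; otherwise the reduced axis is removed when `keepdim` is false and kept with extent 1 when it is true. A small self-contained sketch of that rule (hypothetical helper, not MegEngine code):

```python
INT_MAX = 2**31 - 1  # sentinel axis value meaning "reduce over all elements"

def reduced_shape(shape, axis, keepdim):
    # Hypothetical helper mirroring the inference logic sketched above.
    if axis == INT_MAX or len(shape) == 1:
        return (1,)
    if axis < 0:
        axis += len(shape)
    if keepdim:
        return shape[:axis] + (1,) + shape[axis + 1:]
    return shape[:axis] + shape[axis + 1:]

assert reduced_shape((2, 3, 4), 1, keepdim=True) == (2, 1, 4)
assert reduced_shape((2, 3, 4), -2, keepdim=False) == (2, 4)
assert reduced_shape((2, 3, 4), INT_MAX, keepdim=False) == (1,)
```
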
...
@@ -92,7 +92,13 @@ ValueRefList remove_axis_rule(
 ValueRefList reduce_rule(
         const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask,
         const Type<ScalarValue>& scalar_type) {
+    bool keepdim = reduce.keepdim;
+    auto axis = reduce.axis;
     if (inputs.size() == 1) {
+        if (axis == INT_MAX || (inputs[0].shape()->ndim == 1 && keepdim == false)) {
+            // CompNode device = *inputs[0].device();
+            return {scalar_type.make(imperative::apply(reduce, inputs)[0])};
+        }
         return imperative::apply(reduce, inputs);
     }
     mgb_assert(inputs.size() == 2);
...
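
The scalar rule above wraps the result as a scalar exactly when the reduction collapses everything: a full reduction (`axis == INT_MAX`) or reducing the only axis of a 1-d input without `keepdim`. In numpy terms (for illustration only), these are the cases where the result degenerates to a 0-d value:

```python
import numpy as np

x = np.arange(6, dtype="float32").reshape(2, 3)
v = np.arange(3, dtype="float32")

assert x.sum().ndim == 0                       # full reduction -> scalar
assert v.sum(axis=0).ndim == 0                 # 1-d input, axis reduced, no keepdim -> scalar
assert v.sum(axis=0, keepdims=True).ndim == 1  # keepdim retains a size-1 tensor
assert x.sum(axis=0).ndim == 1                 # partial reduction stays a tensor
```
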
...
@@ -43,6 +43,7 @@ EncodedSubgraph make_backward_graph_from_forward(
         const EncodedSubgraph& forward, const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad);
+
 EncodedSubgraph make_from_computing_graph(
         const VarNodeArray& inputs, const VarNodeArray& outputs);
...
...
@@ -26,7 +26,11 @@ def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
 }];
 }
 
-def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>;
+def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>{
+  let extraArguments = (ins
+    MgbBoolAttr:$keepdim
+  );
+}
 
 def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
   let inputs = (ins AnyType:$inputs);
...
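
The `keepdim` attribute declared here is what the earlier call sites rely on: the generated `Reduce::make` takes an extra boolean (passed as `true` in the batchnorm subgraph and in `make_from_op_node`), and the Python binding exposes it as the `keepdim` keyword used by the patched `_reduce`. A hedged construction sketch; passing the mode as the string "SUM" is an assumption about the generated binding, not something this diff shows:

```python
# Sketch only: keyword names follow the patched _reduce above; the string form
# of the mode argument is an assumption about the generated binding.
from megengine.core.ops import builtin

op_keep = builtin.Reduce(mode="SUM", axis=1, keepdim=True)
op_drop = builtin.Reduce(mode="SUM", axis=1, keepdim=False)
```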