未验证 提交 2fe4bf2f 编写于 作者: Z zyfncg 提交者: GitHub

Optimize the performanece of sum api (#42231)

* optimize the performanece of sum api

* optimize IsDenseTensorInput

* remove debug log
上级 51ea349c
...@@ -70,6 +70,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -70,6 +70,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
} }
bool IsDenseTensorInput(const std::string& name) const override { bool IsDenseTensorInput(const std::string& name) const override {
auto var_type = ctx_.GetInputVarType(name);
return var_type == proto::VarType::LOD_TENSOR;
}
bool IsDenseTensorInputs(const std::string& name) const override {
auto var_types = ctx_.GetInputsVarType(name); auto var_types = ctx_.GetInputsVarType(name);
return std::all_of(var_types.begin(), var_types.end(), return std::all_of(var_types.begin(), var_types.end(),
[](const proto::VarType::Type& type) { [](const proto::VarType::Type& type) {
...@@ -78,11 +83,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -78,11 +83,8 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
} }
bool IsSelectedRowsInput(const std::string& name) const override { bool IsSelectedRowsInput(const std::string& name) const override {
auto var_types = ctx_.GetInputsVarType(name); auto var_type = ctx_.GetInputVarType(name);
return std::all_of(var_types.begin(), var_types.end(), return var_type == proto::VarType::SELECTED_ROWS;
[](const proto::VarType::Type& type) {
return type == proto::VarType::SELECTED_ROWS;
});
} }
bool IsDenseTensorVectorInput(const std::string& name) const override { bool IsDenseTensorVectorInput(const std::string& name) const override {
......
...@@ -365,6 +365,11 @@ std::vector<DDim> InterpretercoreInferShapeContext::GetInputsDim( ...@@ -365,6 +365,11 @@ std::vector<DDim> InterpretercoreInferShapeContext::GetInputsDim(
return GetDims(vars); return GetDims(vars);
} }
proto::VarType::Type InterpretercoreInferShapeContext::GetInputVarType(
const std::string& name) const {
return GetVarType(InputVars(name).at(0));
}
std::vector<proto::VarType::Type> std::vector<proto::VarType::Type>
InterpretercoreInferShapeContext::GetInputsVarType( InterpretercoreInferShapeContext::GetInputsVarType(
const std::string& name) const { const std::string& name) const {
......
...@@ -100,6 +100,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext { ...@@ -100,6 +100,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
std::vector<DDim> GetInputsDim(const std::string& name) const override; std::vector<DDim> GetInputsDim(const std::string& name) const override;
proto::VarType::Type GetInputVarType(const std::string& name) const override;
std::vector<proto::VarType::Type> GetInputsVarType( std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override; const std::string& name) const override;
......
...@@ -245,6 +245,10 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -245,6 +245,10 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool IsRunMKLDNNKernel() const override; bool IsRunMKLDNNKernel() const override;
proto::VarType::Type GetInputVarType(const std::string &name) const override {
return GetVarType(Inputs(name).at(0));
}
std::vector<proto::VarType::Type> GetInputsVarType( std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const override { const std::string &name) const override {
return GetVarTypes(Inputs(name)); return GetVarTypes(Inputs(name));
......
...@@ -979,6 +979,10 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -979,6 +979,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
return GetDims(vars); return GetDims(vars);
} }
proto::VarType::Type GetInputVarType(const std::string& name) const override {
return GetVarType(InputVars(name).at(0));
}
std::vector<proto::VarType::Type> GetInputsVarType( std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override { const std::string& name) const override {
return GetVarTypes(InputVars(name)); return GetVarTypes(InputVars(name));
......
...@@ -479,6 +479,11 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -479,6 +479,11 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
} }
bool IsDenseTensorInput(const std::string& name) const override { bool IsDenseTensorInput(const std::string& name) const override {
const auto* var = ctx_.InputVar(name);
return var->IsType<phi::DenseTensor>();
}
bool IsDenseTensorInputs(const std::string& name) const override {
auto vars = ctx_.MultiInputVar(name); auto vars = ctx_.MultiInputVar(name);
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) { return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
return var->IsType<phi::DenseTensor>(); return var->IsType<phi::DenseTensor>();
...@@ -486,10 +491,8 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -486,10 +491,8 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
} }
bool IsSelectedRowsInput(const std::string& name) const override { bool IsSelectedRowsInput(const std::string& name) const override {
auto vars = ctx_.MultiInputVar(name); const auto* var = ctx_.InputVar(name);
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) { return var->IsType<phi::SelectedRows>();
return var->IsType<phi::SelectedRows>();
});
} }
bool IsDenseTensorVectorInput(const std::string& name) const override { bool IsDenseTensorVectorInput(const std::string& name) const override {
......
...@@ -65,6 +65,8 @@ class InferShapeContext { ...@@ -65,6 +65,8 @@ class InferShapeContext {
virtual bool HasOutput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0;
virtual bool HasAttr(const std::string &name) const = 0; virtual bool HasAttr(const std::string &name) const = 0;
virtual proto::VarType::Type GetInputVarType(
const std::string &name) const = 0;
virtual std::vector<proto::VarType::Type> GetInputsVarType( virtual std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const = 0; const std::string &name) const = 0;
virtual std::vector<proto::VarType::Type> GetOutputsVarType( virtual std::vector<proto::VarType::Type> GetOutputsVarType(
......
...@@ -300,6 +300,15 @@ class DygraphInferShapeContext : public framework::InferShapeContext { ...@@ -300,6 +300,15 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
return vec_res; return vec_res;
} }
framework::proto::VarType::Type GetInputVarType(
const std::string& name) const override {
auto it = var_map_in_->find(name);
PADDLE_ENFORCE_NE(
it, var_map_in_->end(),
platform::errors::NotFound("can not find [%s] in input", name));
return framework::ToVarType(it->second[0]->Var().Type());
}
std::vector<framework::proto::VarType::Type> GetInputsVarType( std::vector<framework::proto::VarType::Type> GetInputsVarType(
const std::string& name) const override { const std::string& name) const override {
std::vector<framework::proto::VarType::Type> vec_res; std::vector<framework::proto::VarType::Type> vec_res;
......
...@@ -89,6 +89,12 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference { ...@@ -89,6 +89,12 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
BOOST_GET_CONST(int, ctx->GetAttr("out_dtype"))); BOOST_GET_CONST(int, ctx->GetAttr("out_dtype")));
if (data_type >= 0) { if (data_type >= 0) {
ctx->SetOutputDataType("Out", data_type); ctx->SetOutputDataType("Out", data_type);
} else {
auto x_type = ctx->GetInputDataType("X");
if (x_type == framework::proto::VarType::BOOL ||
x_type == framework::proto::VarType::INT32) {
ctx->SetOutputDataType("Out", framework::proto::VarType::INT64);
}
} }
} }
}; };
......
...@@ -63,6 +63,12 @@ bool ProtoArgumentMappingContext::IsDenseTensorInput( ...@@ -63,6 +63,12 @@ bool ProtoArgumentMappingContext::IsDenseTensorInput(
const std::string& name) const { const std::string& name) const {
return true; return true;
} }
bool ProtoArgumentMappingContext::IsDenseTensorInputs(
const std::string& name) const {
return true;
}
bool ProtoArgumentMappingContext::IsSelectedRowsInput( bool ProtoArgumentMappingContext::IsSelectedRowsInput(
const std::string& name) const { const std::string& name) const {
return false; return false;
......
...@@ -41,6 +41,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext { ...@@ -41,6 +41,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext {
size_t OutputSize(const std::string& name) const override; size_t OutputSize(const std::string& name) const override;
bool IsDenseTensorInput(const std::string& name) const override; bool IsDenseTensorInput(const std::string& name) const override;
bool IsDenseTensorInputs(const std::string& name) const override;
bool IsSelectedRowsInput(const std::string& name) const override; bool IsSelectedRowsInput(const std::string& name) const override;
bool IsDenseTensorVectorInput(const std::string& name) const override; bool IsDenseTensorVectorInput(const std::string& name) const override;
......
...@@ -91,6 +91,7 @@ class ArgumentMappingContext { ...@@ -91,6 +91,7 @@ class ArgumentMappingContext {
virtual size_t OutputSize(const std::string& name) const = 0; virtual size_t OutputSize(const std::string& name) const = 0;
virtual bool IsDenseTensorInput(const std::string& name) const = 0; virtual bool IsDenseTensorInput(const std::string& name) const = 0;
virtual bool IsDenseTensorInputs(const std::string& name) const = 0;
virtual bool IsSelectedRowsInput(const std::string& name) const = 0; virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
// For compatibility with LoDTensorArray // For compatibility with LoDTensorArray
virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0; virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0;
......
...@@ -2260,8 +2260,7 @@ void SumRawInferMeta(const MetaTensor& x, ...@@ -2260,8 +2260,7 @@ void SumRawInferMeta(const MetaTensor& x,
if (dtype != DataType::UNDEFINED) { if (dtype != DataType::UNDEFINED) {
out_dtype = dtype; out_dtype = dtype;
} else { } else {
if (x.dtype() == DataType::BOOL || x.dtype() == DataType::INT32 || if (x.dtype() == DataType::BOOL || x.dtype() == DataType::INT32) {
x.dtype() == DataType::INT64) {
out_dtype = DataType::INT64; out_dtype = DataType::INT64;
} else { } else {
out_dtype = x.dtype(); out_dtype = x.dtype();
......
...@@ -29,6 +29,9 @@ void SumRawKernel(const Context& dev_ctx, ...@@ -29,6 +29,9 @@ void SumRawKernel(const Context& dev_ctx,
bool reduce_all, bool reduce_all,
DataType out_dtype, DataType out_dtype,
DenseTensor* out) { DenseTensor* out) {
if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
out_dtype = out->dtype();
}
phi::Reduce<CPUContext, T, phi::funcs::SumFunctor>( phi::Reduce<CPUContext, T, phi::funcs::SumFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
} }
......
...@@ -27,6 +27,9 @@ void SumRawKernel(const Context& dev_ctx, ...@@ -27,6 +27,9 @@ void SumRawKernel(const Context& dev_ctx,
bool reduce_all, bool reduce_all,
DataType out_dtype, DataType out_dtype,
DenseTensor* out) { DenseTensor* out) {
if (out_dtype == DataType::UNDEFINED && out->dtype() != x.dtype()) {
out_dtype = out->dtype();
}
phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>( phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace phi { namespace phi {
KernelSignature SumOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature SumOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) { if (ctx.IsDenseTensorInputs("X")) {
return KernelSignature("add_n", {"X"}, {}, {"Out"}); return KernelSignature("add_n", {"X"}, {}, {"Out"});
} }
return KernelSignature("unregistered", {}, {}, {}); return KernelSignature("unregistered", {}, {}, {});
......
...@@ -68,6 +68,10 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -68,6 +68,10 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
return dense_tensor_inputs.count(name) > 0; return dense_tensor_inputs.count(name) > 0;
} }
bool IsDenseTensorInputs(const std::string& name) const override {
return dense_tensor_inputs.count(name) > 0;
}
bool IsSelectedRowsInput(const std::string& name) const override { bool IsSelectedRowsInput(const std::string& name) const override {
return selected_rows_inputs.count(name) > 0; return selected_rows_inputs.count(name) > 0;
} }
......
...@@ -1132,15 +1132,10 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): ...@@ -1132,15 +1132,10 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
else: else:
reduce_all_flag = False reduce_all_flag = False
def get_dtype(x, dtype): dtype_flag = False
if dtype is not None: if dtype is not None:
return (True, dtype) dtype_flag = True
src_type = convert_dtype(x.dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if src_type in ['bool','int32', 'int64']:
return (True, 'int64')
return (False, src_type)
dtype_flag, dtype = get_dtype(x, dtype)
if in_dygraph_mode(): if in_dygraph_mode():
if reduce_all_flag: if reduce_all_flag:
...@@ -1148,17 +1143,14 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): ...@@ -1148,17 +1143,14 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
else: else:
axis = axis if axis != None and axis != [] else [0] axis = axis if axis != None and axis != [] else [0]
out_dtype = convert_np_dtype_to_dtype_(dtype) return _C_ops.final_state_sum(x, axis, dtype, keepdim)
out = _C_ops.final_state_sum(x, axis, out_dtype, keepdim)
return out
if _in_legacy_dygraph(): if _in_legacy_dygraph():
axis = axis if axis != None and axis != [] else [0] axis = axis if axis != None and axis != [] else [0]
if dtype_flag: if dtype_flag:
return _C_ops.reduce_sum(x, 'dim', axis, 'keep_dim', keepdim, return _C_ops.reduce_sum(x, 'dim', axis, 'keep_dim', keepdim,
'reduce_all', reduce_all_flag, 'in_dtype', 'reduce_all', reduce_all_flag, 'in_dtype',
x.dtype, 'out_dtype', x.dtype, 'out_dtype', dtype)
convert_np_dtype_to_dtype_(dtype))
else: else:
return _C_ops.reduce_sum(x, 'dim', axis, 'keep_dim', keepdim, return _C_ops.reduce_sum(x, 'dim', axis, 'keep_dim', keepdim,
'reduce_all', reduce_all_flag) 'reduce_all', reduce_all_flag)
...@@ -1172,7 +1164,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): ...@@ -1172,7 +1164,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
if dtype_flag: if dtype_flag:
attrs.update({ attrs.update({
'in_dtype': x.dtype, 'in_dtype': x.dtype,
'out_dtype': convert_np_dtype_to_dtype_(dtype) 'out_dtype': dtype
}) })
check_variable_and_dtype( check_variable_and_dtype(
...@@ -1186,7 +1178,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): ...@@ -1186,7 +1178,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
helper = LayerHelper('sum', **locals()) helper = LayerHelper('sum', **locals())
if dtype_flag: if dtype_flag:
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
dtype=convert_np_dtype_to_dtype_(dtype)) dtype=dtype)
else: else:
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册