Unverified commit 5efaaaa3, authored by jiangfan06, committed by GitHub

[XPU] Add element_mul_add_fuse_pass and elementwise_madd_xpu kernel (#56629)

Parent 6dd9a024
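The pass rewrites an elementwise_mul followed by an elementwise_add into a single addcmul_xpu op. A minimal numpy sketch of the equivalence (the tensor names x, y, w mirror the pass inputs and are illustrative only, not Paddle API):

import numpy as np

x = np.random.rand(2, 8, 16).astype(np.float32)
y = np.random.rand(2, 8, 16).astype(np.float32)
w = np.random.rand(2, 8, 16).astype(np.float32)

mul_out = x * y        # elementwise_mul
add_out = mul_out + w  # elementwise_add

# addcmul_xpu computes the same result in one fused kernel: out = x * y + w
assert np.allclose(add_out, x * y + w)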
......@@ -290,6 +290,8 @@ if(WITH_XPU)
pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(fast_layernorm_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(elementwise_mul_add_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
endif()
cc_library(
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
/*
Fuse elementwise_mul + elementwise_add into a single addcmul_xpu op.
For example:
graph:
x y
\ /
\ /
elementwise_mul w
\ /
\ /
elementwise_add
|
|
output
------------------------------------------------------
After the pass is applied:
x y w
\ | /
\ | /
addcmul_xpu
|
|
output
*/
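// In other words, the fused op computes out = x * y + w in a single kernel,
// with x, y and w required to have identical static shapes (checked below).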
struct ElementwiseMulAddFusePass : public PatternBase {
ElementwiseMulAddFusePass(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(elementwise_mul);
PATTERN_DECL_NODE(elementwise_add);
// declare variable node's name
PATTERN_DECL_NODE(mul_x);
PATTERN_DECL_NODE(mul_y);
PATTERN_DECL_NODE(mul_out);
PATTERN_DECL_NODE(add_w);
PATTERN_DECL_NODE(add_out);
};
ElementwiseMulAddFusePass::ElementwiseMulAddFusePass(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto elementwise_mul =
pattern->NewNode(elementwise_mul_repr())->assert_is_op("elementwise_mul");
auto elementwise_add =
pattern->NewNode(elementwise_add_repr())->assert_is_op("elementwise_add");
auto mul_x = pattern->NewNode(mul_x_repr())
->AsInput()
->assert_is_op_input("elementwise_mul", "X");
auto mul_y = pattern->NewNode(mul_y_repr())
->AsInput()
->assert_is_op_input("elementwise_mul", "Y");
auto mul_out = pattern->NewNode(mul_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_mul", "Out")
->assert_is_op_input("elementwise_add", "X")
->assert_has_n_outputs(1);
elementwise_mul->LinksFrom({mul_x, mul_y}).LinksTo({mul_out});
auto add_w = pattern->NewNode(add_w_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto add_out = pattern->NewNode(add_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add", "Out");
elementwise_add->LinksFrom({mul_out, add_w}).LinksTo({add_out});
}
/*
Special case: elementwise_add reuses the mul input x as its second operand:
graph:
x y
\ /
\ /
elementwise_mul x
\ /
\ /
elementwise_add
|
|
output
------------------------------------------------------
After the pass is applied:
x y
\ /
\ /
addcmul_xpu
|
|
output
*/
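// Here x feeds both the multiply and the add, so the fused computation is
// out = x * y + x and addcmul_xpu receives x twice (as x and as w).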
struct ElementwiseMulAddFuseXYPattern : public PatternBase {
ElementwiseMulAddFuseXYPattern(PDPattern* pattern,
const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(elementwise_mul);
PATTERN_DECL_NODE(elementwise_add);
// declare variable node's name
PATTERN_DECL_NODE(mul_x);
PATTERN_DECL_NODE(mul_y);
PATTERN_DECL_NODE(mul_out);
PATTERN_DECL_NODE(add_out);
};
ElementwiseMulAddFuseXYPattern::ElementwiseMulAddFuseXYPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto elementwise_mul =
pattern->NewNode(elementwise_mul_repr())->assert_is_op("elementwise_mul");
auto elementwise_add =
pattern->NewNode(elementwise_add_repr())->assert_is_op("elementwise_add");
auto mul_x = pattern->NewNode(mul_x_repr())
->AsInput()
->assert_is_op_input("elementwise_mul", "X")
->assert_is_op_input("elementwise_add", "Y");
auto mul_y = pattern->NewNode(mul_y_repr())
->AsInput()
->assert_is_op_input("elementwise_mul", "Y");
auto mul_out = pattern->NewNode(mul_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_mul", "Out")
->assert_is_op_input("elementwise_add", "X");
elementwise_mul->LinksFrom({mul_x, mul_y}).LinksTo({mul_out});
auto add_out = pattern->NewNode(add_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add", "Out");
elementwise_add->LinksFrom({mul_out, mul_x}).LinksTo({add_out});
}
} // namespace patterns
class ElementwiseMulAddFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
void FuseElementwiseMulAdd(ir::Graph* graph) const;
void FuseElementwiseMulAddWithOnlyXY(ir::Graph* graph) const;
const std::string name_scope_{"elementwise_mul_add_fuse_pass"};
};
void ElementwiseMulAddFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
FuseElementwiseMulAdd(graph);
FuseElementwiseMulAddWithOnlyXY(graph);
}
void ElementwiseMulAddFusePass::FuseElementwiseMulAdd(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::ElementwiseMulAddFusePass pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ElementwiseMulAddFusePass";
// declare operator node's name
GET_IR_NODE(elementwise_mul);
GET_IR_NODE(elementwise_add);
// declare variable node's name
GET_IR_NODE(mul_x);
GET_IR_NODE(mul_y);
GET_IR_NODE(mul_out);
GET_IR_NODE(add_w);
GET_IR_NODE(add_out);
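// Fuse only when x, y and w are fp16/fp32 and share the same fully static
// shape, so the fused addcmul_xpu never has to broadcast.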
bool flag = true;
auto var_type = mul_x->Var()->GetDataType();
if (var_type != proto::VarType::FP16 && var_type != proto::VarType::FP32) {
flag = false;
}
auto x_shape = mul_x->Var()->GetShape();
auto y_shape = mul_y->Var()->GetShape();
auto w_shape = add_w->Var()->GetShape();
if (x_shape.size() == y_shape.size() && x_shape.size() == w_shape.size()) {
for (size_t i = 0; i < x_shape.size(); ++i) {
if (x_shape[i] != y_shape[i] || x_shape[i] != w_shape[i] ||
x_shape[i] == -1) {
flag = false;
}
}
} else {
flag = false;
}
if (flag) {
auto* block = elementwise_mul->Op()->Block();
// delete useless node
std::unordered_set<const Node*> delete_nodes;
// Generate addcmul_xpu op
framework::OpDesc fused_op_desc(block);
fused_op_desc.SetType("addcmul_xpu");
fused_op_desc.SetInput("x", {mul_x->Name()});
fused_op_desc.SetInput("y", {mul_y->Name()});
fused_op_desc.SetInput("w", {add_w->Name()});
fused_op_desc.SetOutput("out", {add_out->Name()});
auto* fused_op = graph->CreateOpNode(&fused_op_desc);
IR_NODE_LINK_TO(mul_x, fused_op);
IR_NODE_LINK_TO(mul_y, fused_op);
IR_NODE_LINK_TO(add_w, fused_op);
IR_NODE_LINK_TO(fused_op, add_out);
delete_nodes.insert({elementwise_mul, elementwise_add, mul_out});
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
}
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void ElementwiseMulAddFusePass::FuseElementwiseMulAddWithOnlyXY(
ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::ElementwiseMulAddFuseXYPattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ElementwiseMulAddFusePass";
// declare operator node's name
GET_IR_NODE(elementwise_mul);
GET_IR_NODE(elementwise_add);
// declare variable node's name
GET_IR_NODE(mul_x);
GET_IR_NODE(mul_y);
GET_IR_NODE(mul_out);
GET_IR_NODE(add_out);
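// Same dtype/shape constraints as above, applied to x and y only because
// the add reuses x as its second operand.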
bool flag = true;
auto var_type = mul_x->Var()->GetDataType();
if (var_type != proto::VarType::FP16 && var_type != proto::VarType::FP32) {
flag = false;
}
auto x_shape = mul_x->Var()->GetShape();
auto y_shape = mul_y->Var()->GetShape();
if (x_shape.size() == y_shape.size()) {
for (size_t i = 0; i < x_shape.size(); ++i) {
if (x_shape[i] != y_shape[i] || x_shape[i] == -1) {
flag = false;
}
}
} else {
flag = false;
}
if (flag) {
auto* block = elementwise_mul->Op()->Block();
// delete useless node
std::unordered_set<const Node*> delete_nodes;
// Generate addcmul_xpu op
framework::OpDesc fused_op_desc(block);
fused_op_desc.SetType("addcmul_xpu");
fused_op_desc.SetInput("x", {mul_x->Name()});
fused_op_desc.SetInput("y", {mul_y->Name()});
fused_op_desc.SetInput("w", {mul_x->Name()});
fused_op_desc.SetOutput("out", {add_out->Name()});
auto* fused_op = graph->CreateOpNode(&fused_op_desc);
IR_NODE_LINK_TO(mul_x, fused_op);
IR_NODE_LINK_TO(mul_y, fused_op);
IR_NODE_LINK_TO(fused_op, add_out);
delete_nodes.insert({elementwise_mul, elementwise_add, mul_out});
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
}
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(elementwise_mul_add_fuse_pass,
paddle::framework::ir::ElementwiseMulAddFusePass);
REGISTER_PASS_CAPABILITY(elementwise_mul_add_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("elementwise_add", 0)
.LE("elementwise_add", 1)
.GE("elementwise_mul", 0)
.LE("elementwise_mul", 1));
......@@ -552,6 +552,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"fast_layernorm_xpu_fuse_pass",
"yolo_box_xpu_fuse_pass",
"fast_where_xpu_fuse_pass",
"elementwise_mul_add_fuse_pass",
"link_xpu_op_max_pass",
"delete_isolated_node_pass",
// "auto_mixed_precision_pass",
......
......@@ -23,6 +23,15 @@
func : add_layernorm_xpu
data_type : x
- op : addcmul_xpu
args : (Tensor x, Tensor y, Tensor w)
output : Tensor(out)
infer_meta :
func : AddCMulXPUInferMeta
kernel :
func : addcmul_xpu
data_type : x
- op : conv1d_xpu
args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, str padding_algorithm, int dilations, int strides, int groups, int act_type, float act_param)
output : Tensor(out), Tensor(out_max)
......
......@@ -36,6 +36,8 @@ XPUOpMap& get_kl2_ops() {
{"adam_dense_param_sparse_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"adagrad", XPUKernelSet({phi::DataType::FLOAT32})},
{"addcmul_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"arg_max",
XPUKernelSet({phi::DataType::INT32,
phi::DataType::FLOAT32,
......@@ -161,6 +163,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT64,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::INT64,
phi::DataType::INT32})},
{"conv2d_grad",
......
......@@ -821,6 +821,15 @@ void FastLayernormXPUInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}
void AddCMulXPUInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& w,
MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
void FusedScaleBiasReluConvBnstatsInferMeta(
const MetaTensor& x,
const MetaTensor& w,
......
......@@ -201,6 +201,11 @@ void FastLayernormXPUInferMeta(const MetaTensor& x,
float epsilon,
MetaTensor* out);
void AddCMulXPUInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& w,
MetaTensor* out);
void FusedScaleBiasReluConvBnstatsInferMeta(
const MetaTensor& x,
const MetaTensor& w,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void AddCMulXPUKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& w,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto* x_data = x.data<T>();
const auto* y_data = y.data<T>();
const auto* w_data = w.data<T>();
auto* out_data = ctx.template Alloc<T>(out);
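// Prefer the handwritten plugin kernel when Paddle is built with
// PADDLE_WITH_XPU_PLUGIN; otherwise fall back to the generic xdnn addcmul,
// which computes out = w + 1.0f * x * y elementwise.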
#ifdef PADDLE_WITH_XPU_PLUGIN
int r = xpu::plugin::fast_addcmul(ctx.x_context(),
reinterpret_cast<const XPUType*>(w_data),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<const XPUType*>(y_data),
reinterpret_cast<XPUType*>(out_data),
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_addcmul");
#else
int r = xpu::addcmul(ctx.x_context(),
reinterpret_cast<const XPUType*>(w_data),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<const XPUType*>(y_data),
reinterpret_cast<XPUType*>(out_data),
1.0f,
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "addcmul");
#endif
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(addcmul_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::AddCMulXPUKernel,
float,
phi::dtype::float16) {}
......@@ -119,4 +119,6 @@ PD_REGISTER_KERNEL(concat,
double,
phi::dtype::float16,
int64_t,
int) {}
int,
int8_t,
bool) {}
......@@ -114,6 +114,9 @@ DLL_EXPORT int fast_embedding(Context* ctx,
int64_t ym,
int64_t padding_idx,
TID start_index = 0);
template <typename T>
DLL_EXPORT int fast_addcmul(
Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len);
} // namespace plugin
} // namespace api
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
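// In-place multiply-accumulate on local memory: x[i] = x[i] * y[i] + x[i],
// processing 32 elements per iteration with two float32x16 SIMD mac ops.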
template <typename T>
static inline __device__ void primitive_addcmul(T* x, const T* y, int len) {
float32x16_t vx0;
float32x16_t vy0;
float32x16_t vx1;
float32x16_t vy1;
for (int i = 0; i < len; i += 32) {
vload2_lm(x + i, vx0, vx1);
vload2_lm(y + i, vy0, vy1);
vx0 = vvmac_float32x16(vx0, vy0, vx0);
vx1 = vvmac_float32x16(vx1, vy1, vx1);
vstore2_lm(x + i, vx0, vx1);
}
mfence_lm();
}
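// Cluster kernel for the w == x case: each hardware thread streams
// buf_len-sized tiles of x and y from global to local memory, applies
// primitive_addcmul, and writes the result back to z.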
template <typename T>
__global__ void fast_addcmul(const T* x, const T* y, T* z, int64_t len) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
const int buf_len = 512 / sizeof(T);
__simd__ float local_x_after_cast[buf_len];
__simd__ float local_y_after_cast[buf_len];
T* local_x = (T*)(local_x_after_cast);
T* local_y = (T*)(local_y_after_cast);
int loop = 0;
for (int64_t i = tid * buf_len; i < len; i += nthreads * buf_len) {
int read_len = min(static_cast<int64_t>(buf_len), len - i);
GM2LM_ASYNC(x + i, local_x, read_len * sizeof(T));
GM2LM(y + i, local_y, read_len * sizeof(T));
primitive_addcmul<T>(local_x, local_y, read_len);
LM2GM_ASYNC(local_x, z + i, read_len * sizeof(T));
mfence_lm();
#ifndef __XPU3__
loop++;
if ((loop & 0xF) == 0) {
sync_all();
}
#endif
}
}
#define _XPU_DEF__FAST_ADDCMUL_(DTYPE) \
template __global__ void fast_addcmul<DTYPE>( \
const DTYPE* x, const DTYPE* y, DTYPE* z, int64_t len);
_XPU_DEF__FAST_ADDCMUL_(float);
_XPU_DEF__FAST_ADDCMUL_(float16);
} // namespace plugin
} // namespace xpu2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include "xpu/refactor/util/vector_util.h"
namespace xpu2 {
namespace plugin {
template <typename T>
__attribute__((global)) void fast_addcmul(const T* x,
const T* y,
T* z,
int64_t len);
} // namespace plugin
} // namespace xpu2
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T>
static int xpu2_wrapper(
Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len) {
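// The plugin kernel only handles the w == x pattern (out = x * y + x);
// any other operand combination falls back to the stock addcmul.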
if (x == w) {
xpu2::plugin::fast_addcmul<T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, z, len);
} else {
return addcmul(ctx, w, x, y, z, 1.0f, len);
}
return SUCCESS;
}
template <typename T>
int fast_addcmul(
Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "fast_addcmul", T);
WRAPPER_DUMP_PARAM4(ctx, w, x, y, z);
WRAPPER_DUMP_PARAM2(ctx, len, ctx->_l3_mgr.get_size());
WRAPPER_DUMP(ctx);
WRAPPER_CHECK_4PTRS(ctx, T, len, w, x, y, z);
if (ctx->dev().type() == api::kXPU2) {
return xpu2_wrapper<T>(ctx, w, x, y, z, len);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int fast_addcmul(
Context*, const float*, const float*, const float*, float*, int64_t);
template int fast_addcmul(Context*,
const float16*,
const float16*,
const float16*,
float16*,
int64_t);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestElementwiseMulAddFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["addcmul_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
x_shape = draw(
st.lists(
st.integers(min_value=1, max_value=4), min_size=3, max_size=4
)
)
def generate_data(shape):
return np.random.random(shape).astype(np.float32)
mul_op = OpConfig(
"elementwise_mul",
inputs={"X": ["mul_x"], "Y": ["mul_y"]},
outputs={"Out": ["mul_out"]},
)
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["mul_out"], "Y": ["add_w"]},
outputs={"Out": ["add_out"]},
)
ops = [mul_op, add_op]
program_config = ProgramConfig(
ops=ops,
inputs={
"mul_x": TensorConfig(data_gen=partial(generate_data, x_shape)),
"mul_y": TensorConfig(data_gen=partial(generate_data, x_shape)),
"add_w": TensorConfig(data_gen=partial(generate_data, x_shape)),
},
weights={},
outputs=["add_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["elementwise_mul_add_fuse_pass"],
)
if __name__ == "__main__":
unittest.main()