未验证 提交 d7be46b3 编写于 作者: Z zhangyikun02 提交者: GitHub

add implement of resnet_basic_block op for XPU2, test=kunlun (#44143)

上级 337bb47b
......@@ -258,24 +258,25 @@ class ResNetBasicBlockOp : public framework::OperatorWithKernel {
class ResNetBasicBlockOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
// has_shortcut = True: X else: X
// / /
// | | | |
// CONV1 | CONV1 |
// | | | |
// BN1 | BN1 |
// | | | |
// RELU1 | RELU1 |
// | | | |
// CONV2 CONV3 CONV2 |
// | | | |
// BN2 BN3 BN2 |
// \ / \ /
// ADD ADD
// | |
// RELU RELU
// | |
// Y Y
// has_shortcut = True: else:
// X X
// / /
// | | | |
// CONV1 | CONV1 |
// | | | |
// BN1 | BN1 |
// | | | |
// RELU1 | RELU1 |
// | | | |
// CONV2 CONV3 CONV2 |
// | | | |
// BN2 BN3 BN2 |
// \ / \ /
// ADD ADD
// | |
// RELU RELU
// | |
// Y Y
AddInput("X", "Input tensor of conv 1");
AddInput("Filter1", "Filter tensor of conv 1");
AddInput("Scale1", "Scale tensor of bn 1");
......
......@@ -505,6 +505,14 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sequence_conv_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
// Fused op
{"resnet_basic_block_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"resnet_basic_block",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
};
return s_xpu2_kernels;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册