未验证 提交 d7be46b3 编写于 作者: Z zhangyikun02 提交者: GitHub

add implement of resnet_basic_block op for XPU2, test=kunlun (#44143)

上级 337bb47b
...@@ -258,7 +258,8 @@ class ResNetBasicBlockOp : public framework::OperatorWithKernel { ...@@ -258,7 +258,8 @@ class ResNetBasicBlockOp : public framework::OperatorWithKernel {
class ResNetBasicBlockOpMaker : public framework::OpProtoAndCheckerMaker { class ResNetBasicBlockOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
// has_shortcut = True: X else: X // has_shortcut = True: else:
// X X
// / / // / /
// | | | | // | | | |
// CONV1 | CONV1 | // CONV1 | CONV1 |
......
...@@ -505,6 +505,14 @@ XPUOpMap& get_kl2_ops() { ...@@ -505,6 +505,14 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sequence_conv_grad", {"sequence_conv_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
// Fused op
{"resnet_basic_block_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"resnet_basic_block",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
}; };
return s_xpu2_kernels; return s_xpu2_kernels;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册