未验证 提交 4c438d30 编写于 作者: J Jiabin Yang 提交者: GitHub

Support rsqrt_p (#46369)

* support rsqrt_p

* refine code and ut

* add_prim_rsqrt

* fix ut
上级 9a291685
......@@ -37,7 +37,8 @@ set(PRIM_OP_SRCS
max_p_op.cc
erf_p_op.cc
abs_p_op.cc
cast_p_op.cc)
cast_p_op.cc
rsqrt_p_op.cc)
cc_test(
prim_op_test
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class RsqrtPrimOp : public framework::OperatorBase {
public:
RsqrtPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator rsqrt_p should not be excuted directly"));
}
};
class RsqrtPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of rsqrt_p op.");
AddOutput("Y", "(Tensor), The output tensor of rsqrt_p op.");
AddComment(R"DOC(
Autograd primitive rsqrt_p operator.
)DOC");
}
};
class RsqrtPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};
class RsqrtPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(rsqrt_p,
paddle::operators::RsqrtPrimOp,
paddle::operators::RsqrtPrimOpMaker,
paddle::operators::RsqrtPrimOpShapeInference,
paddle::operators::RsqrtPrimOpVarTypeInference);
......@@ -241,6 +241,39 @@ class TestSqrtPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class TestRSqrtPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'rsqrt_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {
'X': X,
}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
self.all_ops = [
# prim op:
'rsqrt_p',
# jvp op:
'div_p',
'div_p',
'mul_p',
'fill_constant_p',
# 'sqrt_p',
# transpose op:
]
class TestTanhPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
......
......@@ -879,5 +879,26 @@ class TestSquareOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestRSqrtOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'rsqrt'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {
'X': X,
}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['rsqrt', 'rsqrt_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
if __name__ == '__main__':
unittest.main()
......@@ -690,5 +690,25 @@ class TestCastPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Y']: 0}
class TestRsqrtPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'rsqrt_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {
'X': X,
}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, )
self.all_ops = ['rsqrt_p', 'rsqrt']
self.out_map = {self.output['Y']: 0}
if __name__ == '__main__':
unittest.main()
......@@ -152,6 +152,7 @@ class TestWithoutProgramGuard(unittest.TestCase):
('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
('abs', paddle.abs, (np.random.uniform(-10, 10,
(10, 10)), ), None, 'float32'),
('rsqrt', paddle.rsqrt, (np.random.rand(100, 200), ), None, 'float32'),
))
# paddle.where, paddle.pow, paddle.maximum has no double grad definition,
# can not compute forward grad use double trick
......@@ -267,6 +268,7 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
(np.random.rand(3, 3), np.random.rand(3, 3)),
(np.random.rand(3, 3), ), 'float64'),
('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'),
('rsqrt', paddle.rsqrt, (np.random.rand(100, 200), ), None, 'float32'),
('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'),
('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'),
# In where op, grad of condition computed by paddle.static.gradients is None,
......
......@@ -48,15 +48,16 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
A = paddle.tanh(X0)
B = paddle.tanh(X1)
Y = paddle.add(A, B)
C = paddle.rsqrt(B)
Y = paddle.add(A, C)
self.orig_xs = [X0, X1]
self.orig_ys = [
Y,
]
self.orig_ops = ['tanh', 'tanh', 'elementwise_add']
self.orig2prim_ops = ['tanh_p', 'tanh_p', 'add_p']
self.orig_ops = ['tanh', 'tanh', 'elementwise_add', 'rsqrt']
self.orig2prim_ops = ['tanh_p', 'tanh_p', 'add_p', 'rsqrt_p']
self.linearize_ops = self.orig2prim_ops + [
# call fill_const() in linearize() function
'fill_constant_p',
......@@ -71,6 +72,10 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
'fill_constant_p',
'mul_p',
'add_p',
'fill_constant_p',
'div_p',
'div_p',
'mul_p',
]
self.transpose_ops = self.orig2prim_ops + [
# call fill_const() in transpose() function
......@@ -84,6 +89,10 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
'mul_p',
'sub_p',
'fill_constant_p',
'mul_p',
'div_p',
'div_p',
'fill_constant_p',
# transposed op
'mul_p',
'mul_p'
......@@ -92,13 +101,16 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
'tanh', 'tanh', 'add_p', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'sub_p', 'fill_constant',
'elementwise_mul', 'sub_p', 'fill_constant', 'elementwise_mul',
'elementwise_mul'
'elementwise_mul', 'rsqrt', 'fill_constant', 'elementwise_div',
'elementwise_div', 'elementwise_mul'
]
self.prim2orig_ops = [
'tanh', 'tanh', 'elementwise_add', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'elementwise_sub',
'fill_constant', 'elementwise_mul', 'elementwise_sub',
'fill_constant', 'elementwise_mul', 'elementwise_mul'
'fill_constant', 'elementwise_mul', 'elementwise_mul', 'rsqrt',
'fill_constant', 'elementwise_div', 'elementwise_div',
'elementwise_mul'
]
def test_run(self):
......
......@@ -394,3 +394,8 @@ def cast(x, dtype, out=None):
outputs={'Y': out},
attrs={'dtype': dtype})
return out
@REGISTER_FN('rsqrt_p', 'X', 'Y')
def rsqrt(x, out=None):
return _simple_unop(LayerHelper('rsqrt_p', **locals()))
......@@ -23,7 +23,7 @@ from .primops import (add, broadcast, concat, cos, div, eq, erf, exp,
fill_const, gather, ge, gt, log, matmul, max, mul, ne,
neg, reduce_sum, reshape, scatter_add, select, set_value,
sin, slice_assign, slice_select, split, sqrt, sub, tanh,
transpose)
transpose, rsqrt)
from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
lookup_orig2prim, lookup_prim2orig, lookup_transpose,
......@@ -252,6 +252,11 @@ def sqrt_orig2prim(op, x):
return sqrt(x)
@REGISTER_ORIG2PRIM('rsqrt')
def rsqrt_orig2prim(op, x):
return rsqrt(x)
@REGISTER_ORIG2PRIM('matmul_v2')
def matmul_v2_orig2prim(op, x, y):
......@@ -456,6 +461,11 @@ def sub_prim2orig(op, x, y):
return paddle.subtract(x, y)
@REGISTER_PRIM2ORIG('rsqrt_p')
def rsqrt_prim2orig(op, x):
return paddle.rsqrt(x)
@REGISTER_PRIM2ORIG('mul_p')
def mul_prim2orig(op, x, y):
return paddle.multiply(x, y)
......@@ -969,6 +979,17 @@ def cast_jvp(op, x_dot):
return primops.cast(x_dot, y.dtype)
@REGISTER_JVP('rsqrt_p')
def rsqrt_jvp(op, x_dot):
if x_dot is None:
return None
y = op_position_output(op)
x = op_position_inputs(op)
c2 = fill_const(value=-2.0, shape=y.shape, dtype=y.dtype)
y_dot = mul(x_dot, div(div(y, x), c2))
return y_dot
## Register transpose rules
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册