Unverified commit 1711407d, authored by Xiaoxu Chen, committed by GitHub

support pow with scalar input, square, cast, var, size operators for deepxde (#46024)

* add reduce_mean, reduce_sum primitive ops
* add ne_p gt_p primitive operators
* add ge_p abs_p primitive operators
* add cast primitive operators
* add pow, square prim2orig rules
* add elementwise_div orig2prim rule (see the usage sketch below)
Parent 267d71a4
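For context, the sketch below (not part of this diff) exercises the newly supported lowerings end to end in a static program; it assumes the paddle.incubate.autograd API already used by the TestGrad cases further down (enable_prim, grad, prim2orig).

import paddle

paddle.enable_static()
paddle.incubate.autograd.enable_prim()

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[2, 3], dtype='float32')
    x.stop_gradient = False
    # square and cast are lowered through the new pow_p / cast_p orig2prim rules.
    y = paddle.square(paddle.cast(x, paddle.float64))
    x_grad = paddle.incubate.autograd.grad(y, x)
    # Before lowering, the block should contain primitive ops such as
    # cast_p, pow_p and fill_constant_p.
    print([op.type for op in main.block(0).ops])
    # prim2orig lowers the primitives back to runnable original operators.
    paddle.incubate.autograd.prim2orig()
    print([op.type for op in main.block(0).ops])
paddle.incubate.autograd.disable_prim()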
......@@ -36,7 +36,8 @@ set(PRIM_OP_SRCS
pow_p_op.cc
max_p_op.cc
erf_p_op.cc
abs_p_op.cc)
abs_p_op.cc
cast_p_op.cc)
cc_test(
prim_op_test
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class CastPrimOp : public framework::OperatorBase {
public:
CastPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator cast_p should not be excuted directly"));
}
};
class CastPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of cast_p op.");
AddOutput("Y", "(Tensor), The output tensor of cast_p op.");
AddAttr<int>("dtype", "output data type");
AddComment(R"DOC(Autograd primitive cast_p operator.)DOC");
}
};
class CastPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};
class CastPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_type = static_cast<framework::proto::VarType::Type>(
PADDLE_GET_CONST(int, ctx->GetAttr("dtype")));
ctx->SetOutputDataType("Y", out_type);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(cast_p,
paddle::operators::CastPrimOp,
paddle::operators::CastPrimOpMaker,
paddle::operators::CastPrimOpShapeInference,
paddle::operators::CastPrimOpVarTypeInference);
......@@ -433,6 +433,42 @@ class TestAbsPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class TestCastPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'cast_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {
'X': X,
}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {'dtype': paddle.float64}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
# Set transpose
check_dot = lambda v: True
Y_BAR = paddle.static.data(name='Y_BAR', shape=[5, 6], dtype='float')
self.transpose_args = (check_dot, Y_BAR)
self.transpose_out_shape_map = {0: X}
self.all_ops = [
# prim op:
'cast_p',
# jvp op:
'cast_p',
# transpose op:
'cast_p'
]
class TestLogPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
......
......@@ -110,6 +110,26 @@ class TestElementWiseMulOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestElementWiseDivOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'elementwise_div'
X = paddle.static.data(name='X', shape=[8, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[8, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['elementwise_div', 'div_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestMatmulV2Orig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
......@@ -786,5 +806,78 @@ class TestReduceMeanOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestSizeOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'size'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'Input': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(
dtype=paddle.int64)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['size', 'fill_constant_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestCastOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'cast'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'in_dtype': X.dtype, 'out_dtype': paddle.float64}
self.orig2prim_args = (X, )
self.all_ops = ['cast', 'cast_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestPowScalarOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'pow'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'factor': 2.}
self.orig2prim_args = (None, X)
self.all_ops = ['pow', 'pow_p', 'fill_constant_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestSquareOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'square'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['square', 'pow_p', 'fill_constant_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
if __name__ == '__main__':
unittest.main()
......@@ -670,5 +670,25 @@ class TestMaxPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Z']: 0}
class TestCastPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'cast_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {
'X': X,
}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'dtype': paddle.int64}
self.prim2orig_args = (X, )
self.all_ops = ['cast_p', 'cast']
self.out_map = {self.output['Y']: 0}
if __name__ == '__main__':
unittest.main()
......@@ -257,6 +257,8 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
(np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
('multiply', paddle.multiply,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
('div', paddle.divide,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
('add', paddle.add,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('input_not_sequence', paddle.tanh,
......@@ -300,7 +302,21 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
(np.random.rand(200, 345), ), None, 'float32'),
('abs', paddle.abs, (np.random.uniform(-10, 10,
(200, 345)), ), None, 'float32'),
))
('cast_float', lambda x: paddle.cast(x, paddle.float64),
(np.random.rand(10, 20), ), None, 'float32'),
('cast_int', lambda x: paddle.cast(x, paddle.int32),
(np.random.rand(10, 20), ), None, 'float32'),
('square', paddle.square, (np.random.rand(100), ), None, 'float32'),
('pow_scalar', lambda x: paddle.pow(x, 2),
(np.random.rand(20, 30), ), None, 'float32'),
('var', paddle.var, (np.random.rand(200, 324), ), None, 'float32'),
('var_with_axis', lambda x: paddle.var(x, axis=1),
(np.random.rand(10, 20, 30), ), None, 'float32'),
('var_without_unbiased',
lambda x: paddle.var(x, axis=1, unbiased=False),
(np.random.rand(10, 20, 30), ), None, 'float32'),
('var_with_keepdim', lambda x: paddle.var(x, axis=1, keepdim=True),
(np.random.rand(10, 20, 30), ), None, 'float32')))
class TestGrad(unittest.TestCase):
def setUp(self):
......
......@@ -44,6 +44,9 @@ paddle.enable_static()
('erf', primops.erf, randn(2, 3), {}, (2, 3), 'float64'),
('abs', primops.abs, randn(2, 3), {}, (2, 3), 'float64'),
('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'),
('cast', primops.cast, randn(2, 3), {
'dtype': paddle.int64
}, (2, 3), 'int64'),
('reshape', primops.reshape, randn(2, 3), {
'shape': (3, 2)
}, (3, 2), 'float64'),
......
......@@ -382,3 +382,15 @@ def max(x, y, out=None):
@REGISTER_FN('erf_p', 'X', 'Y')
def erf(x, out=None):
return _simple_unop(LayerHelper('erf_p', **locals()))
@REGISTER_FN('cast_p', 'X', 'Y')
def cast(x, dtype, out=None):
helper = LayerHelper('cast_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type=helper.layer_type,
inputs={'X': x},
outputs={'Y': out},
attrs={'dtype': dtype})
return out
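For reference, a minimal use of the wrapper above in a static program; this is a sketch, not part of the diff, and it assumes the primops module is importable as paddle.incubate.autograd.primops, the path exercised by the primops test shown above.

import paddle
from paddle.incubate.autograd import primops

paddle.enable_static()
x = paddle.static.data(name='x', shape=[2, 3], dtype='float64')
# Appends a single cast_p op; the output variable is created with the target dtype.
y = primops.cast(x, dtype=paddle.int64)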
......@@ -80,7 +80,7 @@ def op_position_inputs(op):
"""
args = _primop_position_argnames.lookup(op.type)
assert args is not None, 'args should not be None in op_position_inputs().'
assert args is not None, f'args of {op.type} should not be None in op_position_inputs().'
*input_names, _ = args
inputs = []
......
......@@ -158,6 +158,13 @@ def elementwise_mul_orig2prim(op, x, y):
return z
@REGISTER_ORIG2PRIM('elementwise_div')
def elementwise_div_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return primops.div(x, y)
@REGISTER_ORIG2PRIM('tanh')
def tanh_orig2prim(op, x):
return tanh(x)
......@@ -322,6 +329,11 @@ def p_norm_orig2prim(op, x):
raise RuntimeError('Only support lower l2/l1 norm currently')
@REGISTER_ORIG2PRIM('cast')
def cast_orig2prim(op, x):
return primops.cast(x, paddle.dtype(op.attr('out_dtype')))
# TODO: support broadcast
@REGISTER_ORIG2PRIM('where')
def select_orig2prim(op, condition, x, y):
......@@ -356,15 +368,27 @@ def ge_orig2prim(op, x, y):
return ge(x, y)
# paddle.pow API use "elementwise_pow" operator when y is a Tensor.
@REGISTER_ORIG2PRIM('elementwise_pow')
def elementwise_pow_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
z = primops.pow(x, y)
return z
# paddle.pow API use "pow" operator when y is a scalar.
@REGISTER_ORIG2PRIM('pow')
def pow_orig2prim(op, x, y):
# x is the factorTensor defined in the paddle phi op; currently it is None.
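# For example, paddle.pow(x, 2.) is recorded as a 'pow' op with attr factor=2.
# and without the factorTensor input, so x is None, y holds the data tensor,
# and the scalar exponent is materialized with fill_const before calling pow_p.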
return primops.pow(y, fill_const(op.attr('factor'), y.shape, y.dtype))
@REGISTER_ORIG2PRIM('square')
def square_orig2prim(op, x):
return primops.pow(x, fill_const(2., x.shape, x.dtype))
@REGISTER_ORIG2PRIM('elementwise_max')
def elementwise_max_orig2prim(op, x, y):
if x.shape != y.shape:
......@@ -415,6 +439,12 @@ def reduce_mean_orig2prim(op, x):
return div(sum, norm)
@REGISTER_ORIG2PRIM('size')
def size_orig2prim(op, x):
return fill_const(functools.reduce(operator.mul, x.shape), (1, ),
paddle.int64)
## Register prim2orig lower rules
@REGISTER_PRIM2ORIG('add_p')
def add_prim2orig(op, x, y):
......@@ -592,6 +622,11 @@ def max_prim2orig(op, x, y):
return paddle.maximum(x, y)
@REGISTER_PRIM2ORIG('cast_p')
def cast_prim2orig(op, x):
return paddle.cast(x, paddle.dtype(op.attr('dtype')))
## Register linearize rules
@REGISTER_JVP('add_p')
def add_jvp(op, x_dot, y_dot):
......@@ -928,6 +963,12 @@ def max_jvp(op, x_dot, y_dot):
return select(eq(y, z), y_dot, x_dot)
@REGISTER_JVP('cast_p')
def cast_jvp(op, x_dot):
y = op_position_output(op)
return primops.cast(x_dot, y.dtype)
## Register transpose rules
......@@ -1132,3 +1173,9 @@ def select_transpose(op, check_dot, z_bar):
y_bar = select(cond, zeros_y, z_bar) if check_dot(y) else None
return cond_bar, x_bar, y_bar
@REGISTER_TRANSPOSE('cast_p')
def cast_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
return primops.cast(y_bar, x.dtype)