未验证 提交 22342d51 编写于 作者: X Xiaoxu Chen 提交者: GitHub

add sin,cos,exp primitive operators (#44345)

上级 0d51fcf1
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class CosPrimOp : public framework::OperatorBase {
public:
CosPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator cos_p should not be excuted directly"));
}
};
class CosPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of cos_p op.");
AddOutput("Y", "(Tensor), The output tensor of cos_p op.");
AddComment(R"DOC(Autograd primitive cos_p operator.)DOC");
}
};
class CosPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};
class CosPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(cos_p,
paddle::operators::CosPrimOp,
paddle::operators::CosPrimOpMaker,
paddle::operators::CosPrimOpShapeInference,
paddle::operators::CosPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class ExpPrimOp : public framework::OperatorBase {
public:
ExpPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator exp_p should not be excuted directly"));
}
};
class ExpPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of exp_p op.");
AddOutput("Y", "(Tensor), The output tensor of exp_p op.");
AddComment(R"DOC(Autograd primitive exp_p operator.)DOC");
}
};
class ExpPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};
class ExpPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(exp_p,
paddle::operators::ExpPrimOp,
paddle::operators::ExpPrimOpMaker,
paddle::operators::ExpPrimOpShapeInference,
paddle::operators::ExpPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class SinPrimOp : public framework::OperatorBase {
public:
SinPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator sin_p should not be excuted directly"));
}
};
class SinPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of sin_p op.");
AddOutput("Y", "(Tensor), The output tensor of sin_p op.");
AddComment(R"DOC(Autograd primitive sin_p operator.)DOC");
}
};
class SinPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};
class SinPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(sin_p,
paddle::operators::SinPrimOp,
paddle::operators::SinPrimOpMaker,
paddle::operators::SinPrimOpShapeInference,
paddle::operators::SinPrimOpVarTypeInference);
......@@ -273,6 +273,97 @@ class TestTanhPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class TestSinPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'sin_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {
'X': X,
}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
self.all_ops = [
# prim op:
'sin_p',
# jvp op:
'mul_p',
'cos_p',
# transpose op:
]
class TestCosPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'cos_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {
'X': X,
}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
self.all_ops = [
# prim op:
'cos_p',
# jvp op:
'mul_p',
'sin_p',
'fill_constant_p',
'sub_p'
# transpose op:
]
class TestExpPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'exp_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {
'X': X,
}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
self.all_ops = [
# prim op:
'exp_p',
# jvp op:
'mul_p',
# transpose op:
]
class TestReshapePJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
......
......@@ -148,6 +148,66 @@ class TestTanhOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestSinOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'sin'
X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
self.input = {
'X': X,
}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['sin', 'sin_p']
self.out_map = {0: self.output['Out']}
class TestCosOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'cos'
X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
self.input = {
'X': X,
}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['cos', 'cos_p']
self.out_map = {0: self.output['Out']}
class TestExpOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'exp'
X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
self.input = {
'X': X,
}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['exp', 'exp_p']
self.out_map = {0: self.output['Out']}
class TestReshape2Orig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
......
......@@ -164,6 +164,66 @@ class TestTanhPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Y']: 0}
class TestSinPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'sin_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {
'X': X,
}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, )
self.all_ops = ['sin_p', 'sin']
self.out_map = {self.output['Y']: 0}
class TestCosPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'cos_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {
'X': X,
}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, )
self.all_ops = ['cos_p', 'cos']
self.out_map = {self.output['Y']: 0}
class TestExpPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'exp_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {
'X': X,
}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, )
self.all_ops = ['exp_p', 'exp']
self.out_map = {self.output['Y']: 0}
class TestReshapePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
......
......@@ -17,7 +17,6 @@ import unittest
import numpy as np
import paddle
from paddle.incubate.autograd import primapi
import config
import utils
......@@ -135,9 +134,8 @@ class TestWithoutProgramGuard(unittest.TestCase):
@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
(('matmul', paddle.matmul,
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), (
('matmul', paddle.matmul,
(np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
('multiply', paddle.multiply,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
......@@ -147,7 +145,8 @@ class TestWithoutProgramGuard(unittest.TestCase):
(np.random.rand(5, 5), ), None, 'float64'),
('input_gradients_not_none', paddle.matmul,
(np.random.rand(3, 3), np.random.rand(3, 3)),
(np.random.rand(3, 3), np.random.rand(3, 3)), 'float64')))
(np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'),
))
class TestForwardGrad(unittest.TestCase):
@classmethod
......@@ -219,7 +218,8 @@ class TestForwardGrad(unittest.TestCase):
self.xs, self.v, stop_gradient=False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
ys_grad = primapi.forward_grad(ys, static_xs, static_v)
ys_grad = paddle.incubate.autograd.forward_grad(
ys, static_xs, static_v)
paddle.incubate.autograd.prim2orig(mp.block(0))
exe = paddle.static.Executor()
exe.run(sp)
......@@ -229,15 +229,144 @@ class TestForwardGrad(unittest.TestCase):
def test_illegal_param(self):
paddle.incubate.autograd.enable_prim()
with self.assertRaises(TypeError):
primapi.forward_grad(1, paddle.static.data('inputs', shape=[1]))
paddle.incubate.autograd.forward_grad(
1, paddle.static.data('inputs', shape=[1]))
with self.assertRaises(TypeError):
primapi.forward_grad(paddle.static.data('targets', shape=[1]), 1)
paddle.incubate.autograd.forward_grad(
paddle.static.data('targets', shape=[1]), 1)
paddle.incubate.autograd.disable_prim()
@utils.place(config.DEVICES)
@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), (
('matmul', paddle.matmul,
(np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
('multiply', paddle.multiply,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
('add', paddle.add,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('input_not_sequence', paddle.tanh,
(np.random.rand(5, 5), ), None, 'float64'),
('input_gradients_not_none', paddle.matmul,
(np.random.rand(3, 3), np.random.rand(3, 3)),
(np.random.rand(3, 3), ), 'float64'),
('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'),
('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'),
('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'),
))
class TestGrad(unittest.TestCase):
def setUp(self):
paddle.enable_static()
paddle.incubate.autograd.enable_prim()
def tearDown(self):
paddle.incubate.autograd.disable_prim()
paddle.disable_static()
@classmethod
def setUpClass(cls):
cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs)
cls._rtol = config.TOLERANCE.get(str(
cls.dtype)).get("first_order_grad").get("rtol")
cls._atol = config.TOLERANCE.get(str(
cls.dtype)).get("first_order_grad").get("atol")
def test_grad(self):
def expected():
paddle.incubate.autograd.disable_prim()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
_, ys_grad = paddle.incubate.autograd.vjp(
self.fun, static_xs, static_v)
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=ys_grad)
paddle.incubate.autograd.enable_prim()
return out
def actual():
paddle.incubate.autograd.enable_prim()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
ys_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v)
paddle.incubate.autograd.prim2orig(mp.block(0))
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=ys_grad)
paddle.incubate.autograd.disable_prim()
return out
actual = actual()
expected = expected()
self.assertEqual(type(actual), type(expected))
for i, j in zip(actual, expected):
np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)
def test_illegal_param(self):
paddle.incubate.autograd.enable_prim()
with self.assertRaises(TypeError):
paddle.incubate.autograd.grad(
1, paddle.static.data('inputs', shape=[1]))
with self.assertRaises(TypeError):
paddle.incubate.autograd.grad(
paddle.static.data('targets', shape=[1]), 1)
paddle.incubate.autograd.disable_prim()
def test_disable_prim(self):
def expected():
paddle.incubate.autograd.disable_prim()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
ys_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v)
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=ys_grad)
paddle.incubate.autograd.enable_prim()
return out
def actual():
paddle.incubate.autograd.disable_prim()
sp = paddle.static.Program()
mp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
feed, static_xs, static_v = utils.gen_static_data_and_feed(
self.xs, self.v, stop_gradient=False)
ys = self.fun(*static_xs) if isinstance(
static_xs, typing.Sequence) else self.fun(static_xs)
ys_grad = paddle.static.gradients(ys, static_xs, static_v)
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(mp, feed=feed, fetch_list=ys_grad)
paddle.incubate.autograd.enable_prim()
return out
actual = actual()
expected = expected()
self.assertEqual(type(actual), type(expected))
for i, j in zip(actual, expected):
np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)
class TestGradWithHigherOrder(unittest.TestCase):
def setUp(self):
paddle.enable_static()
paddle.incubate.autograd.enable_prim()
......@@ -346,44 +475,6 @@ class TestGrad(unittest.TestCase):
np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5)
paddle.incubate.autograd.disable_prim()
def test_disable_prim(self):
def actual(x: np.array):
paddle.incubate.autograd.disable_prim()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
var_x = paddle.static.data('x', shape=x.shape, dtype=x.dtype)
var_x.stop_gradient = False
y = paddle.tanh(var_x)
y_grad = paddle.incubate.autograd.grad(y, var_x)
y_second_grad = paddle.incubate.autograd.grad(y_grad, var_x)
exe = paddle.static.Executor()
exe.run(startup)
return exe.run(main,
feed={'x': x},
fetch_list=[y_grad, y_second_grad])
def expect(x: np.array):
paddle.incubate.autograd.disable_prim()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
var_x = paddle.static.data('x', shape=x.shape, dtype=x.dtype)
var_x.stop_gradient = False
y = paddle.tanh(var_x)
y_grad = paddle.static.gradients(y, var_x)
y_second_grad = paddle.static.gradients(y_grad, var_x)
exe = paddle.static.Executor()
exe.run(startup)
return exe.run(main,
feed={'x': x},
fetch_list=[y_grad, y_second_grad])
x = np.random.randn(100, 200)
for i, j in zip(actual(x), expect(x)):
np.testing.assert_allclose(i, j)
if __name__ == '__main__':
unittest.main()
......@@ -11,154 +11,128 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import uuid
import numpy as np
import paddle
from paddle.incubate.autograd.primops import (neg, set_value, add, sub, mul,
div, sqrt, tanh, reshape,
broadcast, transpose, split,
concat, reduce, matmul,
slice_select, slice_assign,
gather, scatter_add, fill_const)
from paddle.incubate.autograd.primx import Transform, topo_path, orig2prim, prim2orig
from paddle.incubate.autograd.utils import enable_prim, disable_prim, prim_enabled
class TestPyPrimOps(unittest.TestCase):
""" Test Python wrappers of primitive ops. """
def setUp(self):
from numpy.random import randint, randn
from paddle.incubate.autograd import primops, primx
from paddle.incubate.autograd import utils as prim_utils
import config
import utils
paddle.enable_static()
@utils.place(config.DEVICES)
@utils.parameterize(
(utils.TEST_CASE_NAME, 'op', 'args', 'kwargs', 'expected_shape',
'expected_dtype'),
(
('add', primops.add, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('sub', primops.sub, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('mul', primops.mul, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('div', primops.div, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('sub', primops.sub, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('sqrt', primops.sqrt, randn(2, 3), {}, (2, 3), 'float64'),
('tanh', primops.tanh, randn(2, 3), {}, (2, 3), 'float64'),
('sin', primops.sin, randn(2, 3), {}, (2, 3), 'float64'),
('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'),
('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'),
('reshape', primops.reshape, randn(2, 3), {
'shape': (3, 2)
}, (3, 2), 'float64'),
('broadcast', primops.broadcast, randn(2), {
'shape': (3, 2)
}, (3, 2), 'float64'),
('transpose', primops.transpose, randn(2, 3), {
'axis': (1, 0)
}, (3, 2), 'float64'),
('concat_axis0', primops.concat, ((randn(2, 3), randn(2, 3)), ), {
'axis': 0
}, (4, 3), 'float64'),
('concat_axis1', primops.concat, ((randn(2, 3), randn(2, 3)), ), {
'axis': 1
}, (2, 6), 'float64'),
('reduce_axis1', primops.reduce, randn(2, 3), {
'axis': (1, )
}, (2, ), 'float64'),
('reduce_axis01', primops.reduce, randn(2, 3), {
'axis': (0, 1)
}, (1, ), 'float64'),
('split', primops.split, randn(2, 3), {
'num_or_sections': [1, 2],
'axis': 1
}, ((2, 1), (2, 2)), ('float64', 'float64')),
('matmul', primops.matmul, (randn(2, 3), randn(3, 2)), {},
(2, 2), 'float64'),
('slice_select', primops.slice_select, randn(3, 2), {
'axis': [0],
'starts': [0],
'ends': [2],
'strides': [1]
}, (2, 2), 'float64'),
('slice_assign', primops.slice_assign, (randn(2, 3), randn(2, 2)), {
'axis': [1],
'starts': [1],
'ends': [3],
'strides': [1]
}, (2, 3), 'float64'),
('gather', primops.gather, (randn(3, 2), randint(0, 2,
(5, ), np.int32)), {
'axis': 0
}, (5, 2), 'float64'),
('scatter_add', primops.scatter_add,
(randn(3, 2), randn(5, 2), randint(0, 2, (5, ), np.int32)), {
'axis': 0
}, (3, 2), 'float64'),
('fill_const', primops.fill_const, (), {
'value': 10,
'shape': (3, 2),
'dtype': paddle.float32
}, (3, 2), 'float32'),
('neg', primops.neg, randn(2, 3), {}, (2, 3), 'float64'),
))
class TestPrimops(unittest.TestCase):
@classmethod
def setUpClass(cls):
paddle.enable_static()
def test_ops(self):
A = np.random.rand(1)
B = np.random.rand(2)
C = np.random.rand(2, 3)
D = np.random.rand(2, 3)
E = np.random.rand(3, 2)
a = paddle.static.data(name='A', shape=A.shape, dtype='float32')
b = paddle.static.data(name='B', shape=B.shape, dtype='float32')
c = paddle.static.data(name='C', shape=C.shape, dtype='float32')
d = paddle.static.data(name='D', shape=D.shape, dtype='float32')
e = paddle.static.data(name='E', shape=E.shape, dtype='float32')
add_1 = add(a, a)
self.assertEqual(add_1.dtype, a.dtype)
self.assertEqual(add_1.shape, a.shape)
add_2 = add(c, d)
self.assertEqual(add_2.dtype, c.dtype)
self.assertEqual(add_2.shape, c.shape)
sub_1 = sub(c, d)
self.assertEqual(sub_1.dtype, c.dtype)
self.assertEqual(sub_1.shape, c.shape)
mul_1 = mul(c, d)
self.assertEqual(mul_1.dtype, c.dtype)
self.assertEqual(mul_1.shape, c.shape)
div_1 = div(c, d)
self.assertEqual(div_1.dtype, c.dtype)
self.assertEqual(div_1.shape, c.shape)
sqrt_1 = sqrt(b)
self.assertEqual(sqrt_1.dtype, b.dtype)
self.assertEqual(sqrt_1.shape, b.shape)
tanh_1 = tanh(d)
self.assertEqual(tanh_1.dtype, d.dtype)
self.assertEqual(tanh_1.shape, d.shape)
reshape_1 = reshape(c, d.shape)
self.assertEqual(reshape_1.dtype, c.dtype)
self.assertEqual(reshape_1.shape, d.shape)
broadcast_1 = broadcast(b, e.shape)
self.assertEqual(broadcast_1.dtype, b.dtype)
self.assertEqual(broadcast_1.shape, e.shape)
transpose_1 = transpose(c, axis=[1, 0])
self.assertEqual(transpose_1.dtype, c.dtype)
self.assertEqual(transpose_1.shape, e.shape)
split_1_0, split_1_1 = split(c, num_or_sections=[1, 2], axis=1)
self.assertEqual(split_1_0.dtype, c.dtype)
self.assertEqual(split_1_0.shape, (2, 1))
self.assertEqual(split_1_1.shape, (2, 2))
concat_1 = concat([c, d], axis=0)
self.assertEqual(concat_1.dtype, c.dtype)
self.assertEqual(concat_1.shape, (4, 3))
reduce_1 = reduce(d, axis=[1])
self.assertEqual(reduce_1.dtype, d.dtype)
self.assertEqual(reduce_1.shape, (2, ))
reduce_2 = reduce(c, axis=[0, 1])
self.assertEqual(reduce_2.dtype, c.dtype)
self.assertEqual(reduce_2.shape, (1, ))
# TODO: reduce + keepdim
matmul_1 = matmul(d, e)
self.assertEqual(matmul_1.dtype, d.dtype)
self.assertEqual(matmul_1.shape, (2, 2))
slice_select_1 = slice_select(e,
axis=[0],
starts=[0],
ends=[2],
strides=[1])
self.assertEqual(slice_select_1.dtype, e.dtype)
self.assertEqual(slice_select_1.shape, (2, 2))
slice_select_2 = slice_select(d,
axis=[0, 1],
starts=[0, 1],
ends=[2, 3],
strides=[1, 2])
self.assertEqual(slice_select_2.dtype, d.dtype)
self.assertEqual(slice_select_2.shape, (2, 1))
y = broadcast(b, [2, 2])
slice_assign_1 = slice_assign(d,
y,
axis=[1],
starts=[1],
ends=[3],
strides=[1])
self.assertEqual(slice_assign_1.dtype, d.dtype)
self.assertEqual(slice_assign_1.shape, d.shape)
index = paddle.static.data('index', shape=[5], dtype='int32')
gather_1 = gather(e, index, axis=0)
self.assertEqual(gather_1.dtype, e.dtype)
self.assertEqual(gather_1.shape, (5, 2))
y = paddle.rand([5, 2], dtype='float32')
scatter_add_1 = scatter_add(e, y, index, axis=0)
self.assertEqual(scatter_add_1.dtype, e.dtype)
self.assertEqual(scatter_add_1.shape, e.shape)
fill_const_1 = fill_const(value=10, shape=a.shape, dtype=a.dtype)
self.assertEqual(fill_const_1.shape, a.shape)
self.assertEqual(fill_const_1.dtype, a.dtype)
neg_1 = neg(x=b)
self.assertEqual(neg_1.shape, b.shape)
self.assertEqual(neg_1.dtype, b.dtype)
set_value_1 = set_value(d,
a,
axis=[1],
starts=[1],
ends=[3],
strides=[1],
out=d)
self.assertEqual(set_value_1.shape, d.shape)
self.assertEqual(set_value_1.dtype, d.dtype)
@classmethod
def tearDownClass(cls):
paddle.disable_static()
def test_prim_ops(self):
program = paddle.static.Program()
with paddle.static.program_guard(program):
args = self._as_tuple(self.args)
args = self.arr2var(args)
results = self.op(*args, **self.kwargs)
results = self._as_tuple(results)
expected_shape = self._as_tuple(self.expected_shape)
expected_dtype = self._as_tuple(self.expected_dtype)
for r, shape, dtype in zip(results, expected_shape, expected_dtype):
self.assertEqual(r.shape, shape)
self.assertEqual(str(r.dtype).split('.')[1], dtype)
def arr2var(self, arr):
"""convert numpy ndarray to paddle Variable recursively."""
return [
paddle.static.data(f'x{uuid.uuid4()}', v.shape, v.dtype)
if isinstance(v, np.ndarray) else self.arr2var(v) for v in arr
]
def _as_tuple(self, input):
if isinstance(input, (tuple, list)) and len(input) == 0:
return input
if not isinstance(input, (tuple, list)) or all(
isinstance(i, int) for i in input):
return (input, )
return input
if __name__ == '__main__':
......
......@@ -122,6 +122,21 @@ def tanh(x, out=None):
return _simple_unop(LayerHelper('tanh_p', **locals()))
@REGISTER_FN('sin_p', 'X', 'Y')
def sin(x, out=None):
return _simple_unop(LayerHelper('sin_p', **locals()))
@REGISTER_FN('cos_p', 'X', 'Y')
def cos(x, out=None):
return _simple_unop(LayerHelper('cos_p', **locals()))
@REGISTER_FN('exp_p', 'X', 'Y')
def exp(x, out=None):
return _simple_unop(LayerHelper('exp_p', **locals()))
@REGISTER_FN('reshape_p', 'X', 'Y')
def reshape(x, shape, out=None):
return _manipulation_unop(LayerHelper('reshape_p', **locals()))
......
......@@ -15,13 +15,15 @@ import typing
import paddle
from .primreg import REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, REGISTER_JVP, REGISTER_TRANSPOSE
from .primreg import (lookup_fn, lookup_orig2prim, lookup_prim2orig, lookup_jvp,
lookup_transpose, op_position_inputs, op_position_output)
from .primops import (neg, add, sub, mul, div, sqrt, tanh, reshape, broadcast,
transpose, split, concat, reduce, matmul, slice_select,
slice_assign, gather, scatter_add, fill_const, set_value)
from .utils import get_input_var_list, get_output_var_list, INT_DTYPE_2_STRING
from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather,
matmul, mul, neg, reduce, reshape, scatter_add, set_value,
sin, slice_assign, slice_select, split, sqrt, sub, tanh,
transpose)
from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
lookup_orig2prim, lookup_prim2orig, lookup_transpose,
op_position_inputs, op_position_output)
from .utils import INT_DTYPE_2_STRING, get_input_var_list, get_output_var_list
def _orig2prim(op, *args):
......@@ -149,6 +151,21 @@ def tanh_orig2prim(op, x):
return tanh(x)
@REGISTER_ORIG2PRIM('sin')
def sin_orig2prim(op, x):
return sin(x)
@REGISTER_ORIG2PRIM('cos')
def cos_orig2prim(op, x):
return cos(x)
@REGISTER_ORIG2PRIM('exp')
def exp_orig2prim(op, x):
return exp(x)
@REGISTER_ORIG2PRIM('fill_zeros_like')
def fill_zeros_like_orig2prim(op, x):
return fill_const(value=0.0, shape=x.shape, dtype=x.dtype)
......@@ -301,6 +318,21 @@ def tanh_prim2orig(op, x):
return paddle.tanh(x)
@REGISTER_PRIM2ORIG('sin_p')
def sin_prim2orig(op, x):
return paddle.sin(x)
@REGISTER_PRIM2ORIG('cos_p')
def cos_prim2orig(op, x):
return paddle.cos(x)
@REGISTER_PRIM2ORIG('exp_p')
def exp_prim2orig(op, x):
return paddle.exp(x)
@REGISTER_PRIM2ORIG('reshape_p')
def reshape_prim2orig(op, x):
return paddle.reshape(x, shape=op.attr('shape'))
......@@ -453,6 +485,30 @@ def tanh_jvp(op, x_dot):
return y_dot
@REGISTER_JVP('sin_p')
def sin_jvp(op, x_dot):
if x_dot is None:
return None
x, = op_position_inputs(op)
return mul(x_dot, cos(x))
@REGISTER_JVP('cos_p')
def cos_jvp(op, x_dot):
if x_dot is None:
return None
x, = op_position_inputs(op)
return mul(x_dot, neg(sin(x)))
@REGISTER_JVP('exp_p')
def exp_jvp(op, x_dot):
if x_dot is None:
return None
y = op_position_output(op)
return mul(x_dot, y)
@REGISTER_JVP('reshape_p')
def reshape_jvp(op, x_dot):
if x_dot is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册