未验证 提交 62176f63 编写于 作者: X Xiaoxu Chen 提交者: GitHub

add mean,sum,ge,gt,ne,abs primitive operators for supporting deepxde (#45888)

* add reduce_mean,reduce_sum primitive ops

* add ne_p gt_p primitive operators

* add ge_p abs_p primitive oparators
上级 e43e4825
......@@ -8,7 +8,7 @@ register_operators()
set(PRIM_OP_SRCS
reshape_p_op.cc
broadcast_p_op.cc
reduce_p_op.cc
reduce_sum_p_op.cc
transpose_p_op.cc
split_p_op.cc
concat_p_op.cc
......@@ -30,9 +30,13 @@ set(PRIM_OP_SRCS
log_p_op.cc
select_p_op.cc
eq_p_op.cc
gt_p_op.cc
ge_p_op.cc
ne_p_op.cc
pow_p_op.cc
max_p_op.cc
erf_p_op.cc)
erf_p_op.cc
abs_p_op.cc)
cc_test(
prim_op_test
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class AbsPrimOp : public framework::OperatorBase {
public:
AbsPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator abs_p should not be excuted directly"));
}
};
class AbsPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of abs_p op.");
AddOutput("Y", "(Tensor), The output tensor of abs_p op.");
AddComment(R"DOC(Autograd primitive abs_p operator.)DOC");
}
};
class AbsPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};
class AbsPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(abs_p,
paddle::operators::AbsPrimOp,
paddle::operators::AbsPrimOpMaker,
paddle::operators::AbsPrimOpShapeInference,
paddle::operators::AbsPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class GePrimOp : public framework::OperatorBase {
public:
GePrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator ge_p should not be excuted directly"));
}
};
class GePrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of ge_p op.");
AddInput("Y", "(Tensor), The input tensor of ge_p op.");
AddOutput("Z", "(Tensor), The output tensor of ge_p op.");
AddComment(R"DOC(
Autograd primitive ge_p operator.
)DOC");
}
};
class GePrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank,
y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank,
y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i],
y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i,
x_shape[i],
y_shape[i]));
}
PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class GePrimOpVarTypeInference : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type,
y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type,
y_type));
PADDLE_ENFORCE_EQ(x_dtype,
y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype,
y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, framework::proto::VarType::BOOL);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(ge_p,
paddle::operators::GePrimOp,
paddle::operators::GePrimOpMaker,
paddle::operators::GePrimOpShapeInference,
paddle::operators::GePrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class GtPrimOp : public framework::OperatorBase {
public:
GtPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator gt_p should not be excuted directly"));
}
};
class GtPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of gt_p op.");
AddInput("Y", "(Tensor), The input tensor of gt_p op.");
AddOutput("Z", "(Tensor), The output tensor of gt_p op.");
AddComment(R"DOC(
Autograd primitive gt_p operator.
)DOC");
}
};
class GtPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank,
y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank,
y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i],
y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i,
x_shape[i],
y_shape[i]));
}
PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class GtPrimOpVarTypeInference : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type,
y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type,
y_type));
PADDLE_ENFORCE_EQ(x_dtype,
y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype,
y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, framework::proto::VarType::BOOL);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(gt_p,
paddle::operators::GtPrimOp,
paddle::operators::GtPrimOpMaker,
paddle::operators::GtPrimOpShapeInference,
paddle::operators::GtPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class NePrimOp : public framework::OperatorBase {
public:
NePrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator ne_p should not be excuted directly"));
}
};
class NePrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of ne_p op.");
AddInput("Y", "(Tensor), The input tensor of ne_p op.");
AddOutput("Z", "(Tensor), The output tensor of ne_p op.");
AddComment(R"DOC(
Autograd primitive ne_p operator.
)DOC");
}
};
class NePrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank,
y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank,
y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i],
y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i,
x_shape[i],
y_shape[i]));
}
PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class NePrimOpVarTypeInference : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type,
y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type,
y_type));
PADDLE_ENFORCE_EQ(x_dtype,
y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype,
y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, framework::proto::VarType::BOOL);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(ne_p,
paddle::operators::NePrimOp,
paddle::operators::NePrimOpMaker,
paddle::operators::NePrimOpShapeInference,
paddle::operators::NePrimOpVarTypeInference);
......@@ -18,7 +18,7 @@
USE_OP_ITSELF(reshape_p);
USE_OP_ITSELF(broadcast_p);
USE_OP_ITSELF(reduce_p);
USE_OP_ITSELF(reduce_sum_p);
USE_OP_ITSELF(transpose_p);
USE_OP_ITSELF(split_p);
USE_OP_ITSELF(concat_p);
......@@ -130,7 +130,7 @@ TEST(PrimOp, broadcast_p) {
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, reduce_p) {
TEST(PrimOp, reduce_sum_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
......@@ -141,7 +141,7 @@ TEST(PrimOp, reduce_p) {
NewVar(block, x0, shape);
AppendOp(block,
"reduce_p",
"reduce_sum_p",
{{"X", {x0}}},
{{"Y", {x1}}},
{{"axis", std::vector<int64_t>{0, 2}}, {"keepdim", false}});
......@@ -151,7 +151,7 @@ TEST(PrimOp, reduce_p) {
ASSERT_EQ(shapes.size(), 1UL);
ASSERT_EQ(shapes[0], 4L);
AppendOp(block,
"reduce_p",
"reduce_sum_p",
{{"X", {x0}}},
{{"Y", {x2}}},
{{"axis", std::vector<int64_t>{0, 2}}, {"keepdim", true}});
......
......@@ -24,9 +24,9 @@ class VarDesc;
namespace paddle {
namespace operators {
class ReducePrimOp : public framework::OperatorBase {
class ReduceSumPrimOp : public framework::OperatorBase {
public:
ReducePrimOp(const std::string &type,
ReduceSumPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
......@@ -34,15 +34,15 @@ class ReducePrimOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator reduce_p should not be excuted directly"));
"Prim operator reduce_sum_p should not be excuted directly"));
}
};
class ReducePrimOpMaker : public framework::OpProtoAndCheckerMaker {
class ReduceSumPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of reduce_p op.");
AddOutput("Y", "(Tensor), The output tensor of reduce_p op.");
AddInput("X", "(Tensor), The input tensor of reduce_sum_p op.");
AddOutput("Y", "(Tensor), The output tensor of reduce_sum_p op.");
AddAttr<std::vector<int64_t>>(
"axis",
"(std::vector<int64_t>) The axis along which to reduce on. Must be in "
......@@ -53,12 +53,12 @@ class ReducePrimOpMaker : public framework::OpProtoAndCheckerMaker {
"If true, retain the reduced axis with length 1.")
.SetDefault(false);
AddComment(R"DOC(
Autograd primitive reduce_p operator.
Autograd primitive reduce_sum_p operator.
)DOC");
}
};
class ReducePrimOpShapeInference : public framework::InferShapeBase {
class ReduceSumPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
......@@ -87,7 +87,7 @@ class ReducePrimOpShapeInference : public framework::InferShapeBase {
}
};
class ReducePrimOpVarTypeInference
class ReduceSumPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
......@@ -101,8 +101,8 @@ class ReducePrimOpVarTypeInference
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(reduce_p,
paddle::operators::ReducePrimOp,
paddle::operators::ReducePrimOpMaker,
paddle::operators::ReducePrimOpShapeInference,
paddle::operators::ReducePrimOpVarTypeInference);
REGISTER_OPERATOR(reduce_sum_p,
paddle::operators::ReduceSumPrimOp,
paddle::operators::ReduceSumPrimOpMaker,
paddle::operators::ReduceSumPrimOpShapeInference,
paddle::operators::ReduceSumPrimOpVarTypeInference);
......@@ -32,4 +32,4 @@ from . import dist_pnorm
from . import dist_slice
from . import dist_fused_feedforward
from . import dist_fused_attention
from . import dist_reduce_p
from . import dist_reduce_sum_p
......@@ -33,21 +33,21 @@ from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedReducePrimtive(DistributedOperatorImplContainer):
class DistributedReduceSumPrimtive(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedReducePrimtive, self).__init__(op_type)
super(DistributedReduceSumPrimtive, self).__init__(op_type)
register_distributed_operator_impl_container(
DistributedReducePrimtive("reduce_p"))
DistributedReduceSumPrimtive("reduce_sum_p"))
# Batch Dimension Reduce Primitive
class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
# Batch Dimension ReduceSum Primitive
class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReducePrimtiveImpl0, self).__init__(name)
super(DistributedReduceSumPrimtiveImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
......@@ -149,4 +149,5 @@ class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
register_distributed_operator_impl(
"reduce_p", DistributedReducePrimtiveImpl0("batch_dimension_reduce_p"))
"reduce_sum_p",
DistributedReduceSumPrimtiveImpl0("batch_dimension_reduce_sum_p"))
......@@ -78,7 +78,7 @@ class TestPrimDistOp(unittest.TestCase):
outputs={'Z': self.w_grad},
attrs=self.attrs)
op = self.layer_help.append_op(type="reduce_p",
op = self.layer_help.append_op(type="reduce_sum_p",
inputs={'X': self.tmp2},
outputs={'Y': self.batch_reduced},
attrs={"axis": [0]})
......
......@@ -400,6 +400,39 @@ class TestErfPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class TestAbsPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'abs_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {
'X': X,
}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
self.all_ops = [
# prim op:
'abs_p',
# jvp op:
'select_p',
'ge_p',
'fill_constant_p',
'fill_constant_p',
'sub_p',
# transpose op:
]
class TestLogPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
......@@ -503,7 +536,7 @@ class TestBroadcastPJVPAndTranspose(TestAddPJVPAndTranspose):
# jvp op:
'broadcast_p',
# transpose op:
'reduce_p',
'reduce_sum_p',
'reshape_p'
]
......@@ -650,11 +683,11 @@ class TestConcatPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class TestReducePJVPAndTranspose(TestAddPJVPAndTranspose):
class TestReduceSumPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'reduce_p'
self.op_type = 'reduce_sum_p'
X = paddle.static.data(name='X', shape=[2, 3, 4, 5], dtype='float64')
self.prim_input = {'X': X}
self.prim_output = {
......@@ -682,9 +715,9 @@ class TestReducePJVPAndTranspose(TestAddPJVPAndTranspose):
self.all_ops = [
# prim op:
'reduce_p',
'reduce_sum_p',
# jvp op:
'reduce_p',
'reduce_sum_p',
# transpose op:
'reshape_p',
'broadcast_p',
......@@ -978,6 +1011,96 @@ class TestEqPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class TestGtPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'gt_p'
X = paddle.static.data(name='X', shape=[4, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
self.all_ops = [
# prim op:
'gt_p',
# jvp op:
'fill_constant_p',
# transpose op:
]
class TestGePJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'ge_p'
X = paddle.static.data(name='X', shape=[4, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
self.all_ops = [
# prim op:
'ge_p',
# jvp op:
'fill_constant_p',
# transpose op:
]
class TestNePJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'ne_p'
X = paddle.static.data(name='X', shape=[4, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
self.all_ops = [
# prim op:
'ne_p',
# jvp op:
'fill_constant_p',
# transpose op:
]
class TestPowPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
......
......@@ -228,6 +228,26 @@ class TestErfOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestAbsOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'abs'
X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
self.input = {
'X': X,
}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['abs', 'abs_p']
self.out_map = {0: self.output['Out']}
class TestLogOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
......@@ -381,7 +401,9 @@ class TestPNormOrig2Prim1(TestElementWiseAddOrig2Prim):
}
self.orig2prim_args = (X, )
self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p']
self.all_ops = [
'p_norm', 'reshape_p', 'sqrt_p', 'reduce_sum_p', 'mul_p'
]
self.out_map = {0: self.output['Out']}
......@@ -404,7 +426,9 @@ class TestPNormOrig2Prim2(TestElementWiseAddOrig2Prim):
}
self.orig2prim_args = (X, )
self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p']
self.all_ops = [
'p_norm', 'reshape_p', 'sqrt_p', 'reduce_sum_p', 'mul_p'
]
self.out_map = {0: self.output['Out']}
......@@ -539,6 +563,63 @@ class TestEqualOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestNeOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'not_equal'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['not_equal', 'ne_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestGtOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'greater_than'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['greater_than', 'gt_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestGeOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'greater_equal'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['greater_equal', 'ge_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestPowOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
......@@ -624,5 +705,45 @@ class TestGeluApproximateOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestReduceSumOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'reduce_sum'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [0, 1], 'keep_dim': False}
self.orig2prim_args = (X, )
self.all_ops = ['reduce_sum', 'reduce_sum_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestReduceMeanOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'reduce_mean'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [0, 1], 'keep_dim': False}
self.orig2prim_args = (X, )
self.all_ops = [
'reduce_mean', 'reduce_sum_p', 'fill_constant_p', 'div_p'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
if __name__ == '__main__':
unittest.main()
......@@ -244,6 +244,26 @@ class TestErfPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Y']: 0}
class TestAbsPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'abs_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {
'X': X,
}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, )
self.all_ops = ['abs_p', 'abs']
self.out_map = {self.output['Y']: 0}
class TestLogPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
......@@ -375,7 +395,7 @@ class TestConcatPPrim2Orig(TestAddPPrim2Orig):
class TestReducePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'reduce_p'
self.op_type = 'reduce_sum_p'
X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')
self.input = {'X': X}
......@@ -386,7 +406,7 @@ class TestReducePPrim2Orig(TestAddPPrim2Orig):
self.attrs = {'axis': [1], 'keepdim': True}
self.prim2orig_args = (X, )
self.all_ops = ['reduce_p', 'reduce_sum']
self.all_ops = ['reduce_sum_p', 'reduce_sum']
self.out_map = {self.output['Y']: 0}
......@@ -555,6 +575,63 @@ class TestEqPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Z']: 0}
class TestNePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'ne_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['ne_p', 'not_equal']
self.out_map = {self.output['Z']: 0}
class TestGtPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'gt_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['gt_p', 'greater_than']
self.out_map = {self.output['Z']: 0}
class TestGePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'ge_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['ge_p', 'greater_equal']
self.out_map = {self.output['Z']: 0}
class TestPowPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
......
......@@ -150,6 +150,8 @@ class TestWithoutProgramGuard(unittest.TestCase):
(np.random.rand(3, 3), np.random.rand(3, 3)),
(np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'),
('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
('abs', paddle.abs, (np.random.uniform(-10, 10,
(10, 10)), ), None, 'float32'),
))
# paddle.where, paddle.pow, paddle.maximum has no double grad definition,
# can not compute forward grad use double trick
......@@ -283,6 +285,21 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
(np.random.rand(200, 189), ), None, 'float32'),
('gelu_approximate', lambda x: paddle.nn.functional.gelu(x, True),
(np.random.rand(200, 189), ), None, 'float32'),
('sum', paddle.sum, (np.random.rand(200, 345), ), None, 'float32'),
('sum_with_axis', lambda x: paddle.sum(x, axis=1),
(np.random.rand(200, 345), ), None, 'float32'),
('sum_with_keepdim', lambda x: paddle.sum(x, keepdim=True),
(np.random.rand(200, 345), ), None, 'float32'),
('mean', paddle.mean, (np.random.rand(200, 345), ), None, 'float32'),
('mean_with_axis', lambda x: paddle.mean(x, axis=1),
(np.random.rand(200, 345), ), None, 'float32'),
('mean_with_keepdim', lambda x: paddle.mean(x, keepdim=True),
(np.random.rand(200, 345), ), None, 'float32'),
('mean_with_axis_keepdim',
lambda x: paddle.mean(x, axis=0, keepdim=True),
(np.random.rand(200, 345), ), None, 'float32'),
('abs', paddle.abs, (np.random.uniform(-10, 10,
(200, 345)), ), None, 'float32'),
))
class TestGrad(unittest.TestCase):
......
......@@ -42,6 +42,7 @@ paddle.enable_static()
('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'),
('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'),
('erf', primops.erf, randn(2, 3), {}, (2, 3), 'float64'),
('abs', primops.abs, randn(2, 3), {}, (2, 3), 'float64'),
('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'),
('reshape', primops.reshape, randn(2, 3), {
'shape': (3, 2)
......@@ -58,10 +59,10 @@ paddle.enable_static()
('concat_axis1', primops.concat, ((randn(2, 3), randn(2, 3)), ), {
'axis': 1
}, (2, 6), 'float64'),
('reduce_axis1', primops.reduce, randn(2, 3), {
('reduce_axis1', primops.reduce_sum, randn(2, 3), {
'axis': (1, )
}, (2, ), 'float64'),
('reduce_axis01', primops.reduce, randn(2, 3), {
('reduce_axis01', primops.reduce_sum, randn(2, 3), {
'axis': (0, 1)
}, (1, ), 'float64'),
('split', primops.split, randn(2, 3), {
......@@ -99,6 +100,9 @@ paddle.enable_static()
('select', primops.select,
(randn(2, 3) > 0, randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('eq', primops.eq, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('ne', primops.ne, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('gt', primops.gt, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('ge', primops.ge, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('pow', primops.pow, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('max', primops.max, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
))
......
......@@ -290,8 +290,8 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'index_select'
]
self.orig2prim_ops = [
'broadcast_p', 'add_p', 'reshape_p', 'mul_p', 'reduce_p', 'sqrt_p',
'broadcast_p', 'sub_p', 'concat_p', 'gather_p'
'broadcast_p', 'add_p', 'reshape_p', 'mul_p', 'reduce_sum_p',
'sqrt_p', 'broadcast_p', 'sub_p', 'concat_p', 'gather_p'
]
self.linearize_ops = self.orig2prim_ops + [
# call fill_const() in linearize() function
......@@ -306,7 +306,7 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'mul_p',
'mul_p',
'add_p',
'reduce_p',
'reduce_sum_p',
'fill_constant_p', # 'sqrt_p', Will not append sqrt_p op when apply JVP for sqrt_p
'mul_p',
'div_p',
......@@ -326,7 +326,7 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'fill_constant_p',
'mul_p',
# transposed op
'reduce_p',
'reduce_sum_p',
'reshape_p',
'reshape_p',
'mul_p',
......@@ -334,7 +334,7 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'reshape_p',
'broadcast_p',
'div_p',
'reduce_p',
'reduce_sum_p',
'reshape_p',
'fill_constant_p',
'sub_p',
......
......@@ -137,6 +137,11 @@ def exp(x, out=None):
return _simple_unop(LayerHelper('exp_p', **locals()))
@REGISTER_FN('abs_p', 'X', 'Y')
def abs(x, out=None):
return _simple_unop(LayerHelper('abs_p', **locals()))
@REGISTER_FN('reshape_p', 'X', 'Y')
def reshape(x, shape, out=None):
return _manipulation_unop(LayerHelper('reshape_p', **locals()))
......@@ -193,15 +198,17 @@ def concat(xs, axis=0, out=None):
return out
@REGISTER_FN('reduce_p', 'X', 'Y')
def reduce(x, axis, keepdim=False, out=None):
@REGISTER_FN('reduce_sum_p', 'X', 'Y')
def reduce_sum(x, axis=None, keepdim=False, out=None):
axes = axis or tuple(range(0, len(x.shape)))
axes = (axes, ) if isinstance(axes, int) else axes
if not isinstance(axis, (tuple, list)):
raise TypeError(f'axis must be tuple or list, but got {type(axis)}')
if not isinstance(keepdim, bool):
raise TypeError(f'keepdim must be bool, but got {type(keepdim)}')
attrs = {'axis': axis, 'keepdim': keepdim}
helper = LayerHelper('reduce_p', **locals())
attrs = {'axis': axis, 'keepdim': keepdim}
helper = LayerHelper('reduce_sum_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......@@ -347,6 +354,21 @@ def eq(x, y, out=None):
return _simple_binop(LayerHelper('eq_p', **locals()))
@REGISTER_FN('gt_p', 'X', 'Y', 'Z')
def gt(x, y, out=None):
return _simple_binop(LayerHelper('gt_p', **locals()))
@REGISTER_FN('ge_p', 'X', 'Y', 'Z')
def ge(x, y, out=None):
return _simple_binop(LayerHelper('ge_p', **locals()))
@REGISTER_FN('ne_p', 'X', 'Y', 'Z')
def ne(x, y, out=None):
return _simple_binop(LayerHelper('ne_p', **locals()))
@REGISTER_FN('pow_p', 'X', 'Y', 'Z')
def pow(x, y, out=None):
return _simple_binop(LayerHelper('pow_p', **locals()))
......
......@@ -11,16 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import functools
import math
import operator
import typing
import paddle
from . import primops
from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather,
matmul, mul, neg, reduce, reshape, scatter_add, set_value,
from .primops import (add, broadcast, concat, cos, div, eq, erf, exp,
fill_const, gather, ge, gt, log, matmul, max, mul, ne,
neg, reduce_sum, reshape, scatter_add, select, set_value,
sin, slice_assign, slice_select, split, sqrt, sub, tanh,
transpose, log, select, eq, max, erf)
transpose)
from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
lookup_orig2prim, lookup_prim2orig, lookup_transpose,
......@@ -177,6 +180,11 @@ def erf_orig2prim(op, x):
return erf(x)
@REGISTER_ORIG2PRIM('abs')
def abs_orig2prim(op, x):
return primops.abs(x)
@REGISTER_ORIG2PRIM('log')
def log_orig2prim(op, x):
return log(x)
......@@ -294,9 +302,9 @@ def p_norm_orig2prim(op, x):
x = reshape(x, shape=[num_el(x.shape)])
if abs(op.attr('porder') - 2.0) < 1e-5:
return sqrt(reduce(mul(x, x), axis=[0]))
return sqrt(reduce_sum(mul(x, x), axis=[0]))
elif abs(op.attr('porder') - 1.0) < 1e-5:
return reduce(sqrt(mul(x, x)), axis=[0])
return reduce_sum(sqrt(mul(x, x)), axis=[0])
else:
raise RuntimeError('Only support lower l2/l1 norm currently')
......@@ -314,6 +322,27 @@ def equal_orig2prim(op, x, y):
return eq(x, y)
@REGISTER_ORIG2PRIM('not_equal')
def ne_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return ne(x, y)
@REGISTER_ORIG2PRIM('greater_than')
def gt_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return gt(x, y)
@REGISTER_ORIG2PRIM('greater_equal')
def ge_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return ge(x, y)
@REGISTER_ORIG2PRIM('elementwise_pow')
def elementwise_pow_orig2prim(op, x, y):
if x.shape != y.shape:
......@@ -354,6 +383,25 @@ def gelu_orig2prim(op, x):
erf(mul(x, fill_const(1 / math.sqrt(2.), x.shape, x.dtype)))))
@REGISTER_ORIG2PRIM('reduce_sum')
def reduce_sum_orig2prim(op, x):
axes = tuple(range(0, len(
x.shape))) if op.attr('reduce_all') else op.attr('dim')
return reduce_sum(x, axis=axes, keepdim=op.attr('keep_dim'))
@REGISTER_ORIG2PRIM('reduce_mean')
def reduce_mean_orig2prim(op, x):
axes = tuple(range(0, len(
x.shape))) if op.attr('reduce_all') else op.attr('dim')
sum = reduce_sum(x, axis=axes, keepdim=op.attr('keep_dim'))
norm = fill_const(shape=sum.shape,
value=functools.reduce(operator.mul,
[x.shape[axis] for axis in axes]),
dtype=sum.dtype)
return div(sum, norm)
## Register prim2orig lower rules
@REGISTER_PRIM2ORIG('add_p')
def add_prim2orig(op, x, y):
......@@ -405,6 +453,11 @@ def erf_prim2orig(op, x):
return paddle.erf(x)
@REGISTER_PRIM2ORIG('abs_p')
def abs_prim2orig(op, x):
return paddle.abs(x)
@REGISTER_PRIM2ORIG('log_p')
def log_prim2orig(op, x):
return paddle.log(x)
......@@ -440,7 +493,7 @@ def concat_prim2orig(op, xs):
return paddle.concat(xs, axis=op.attr('axis'))
@REGISTER_PRIM2ORIG('reduce_p')
@REGISTER_PRIM2ORIG('reduce_sum_p')
def reduce_prim2orig(op, x):
return paddle.sum(x, axis=op.attr('axis'), keepdim=op.attr('keepdim'))
......@@ -501,6 +554,21 @@ def eq_prim2orig(op, x, y):
return paddle.equal(x, y)
@REGISTER_PRIM2ORIG('gt_p')
def gt_prim2orig(op, x, y):
return paddle.greater_than(x, y)
@REGISTER_PRIM2ORIG('ge_p')
def ge_prim2orig(op, x, y):
return paddle.greater_equal(x, y)
@REGISTER_PRIM2ORIG('ne_p')
def ne_prim2orig(op, x, y):
return paddle.not_equal(x, y)
@REGISTER_PRIM2ORIG('pow_p')
def pow_prim2orig(op, x, y):
return paddle.pow(x, y)
......@@ -616,6 +684,14 @@ def erf_jvp(op, x_dot):
mul(x_dot, exp(neg(primops.pow(x, fill_const(2., x.shape, x.dtype))))))
@REGISTER_JVP('abs_p')
def abs_jvp(op, x_dot):
if x_dot is None:
return None
x, = op_position_inputs(op)
return select(ge(x, fill_const(0., x.shape, x.dtype)), x_dot, neg(x_dot))
@REGISTER_JVP('log_p')
def log_jvp(op, x_dot):
if x_dot is None:
......@@ -665,8 +741,8 @@ def concat_jvp(op, xs_dot):
return linear_jvp(op, xs_dot, axis=axis)
@REGISTER_JVP('reduce_p')
def reduce_jvp(op, x_dot):
@REGISTER_JVP('reduce_sum_p')
def reduce_sum_jvp(op, x_dot):
if x_dot is None:
return None
axis = op.attr('axis')
......@@ -765,6 +841,33 @@ def eq_jvp(op, x_dot, y_dot):
return z_dot
@REGISTER_JVP('gt_p')
def gt_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, _ = op_position_inputs(op)
z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype)
return z_dot
@REGISTER_JVP('ge_p')
def ge_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, _ = op_position_inputs(op)
z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype)
return z_dot
@REGISTER_JVP('ne_p')
def ne_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, _ = op_position_inputs(op)
z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype)
return z_dot
@REGISTER_JVP('pow_p')
def pow_jvp(op, x_dot, y_dot):
......@@ -873,7 +976,7 @@ def broadcast_transpose(op, check_dot, y_bar):
keepdim = [(bat + i) for i, s in enumerate(x.shape) if s == 1]
axis += keepdim
# TODO: Change it. keepdim boolean
out = reduce(y_bar, axis=axis, keepdim=False)
out = reduce_sum(y_bar, axis=axis, keepdim=False)
return reshape(out, x.shape)
......@@ -908,8 +1011,8 @@ def concat_transpose(op, check_dot, y_bar):
return split(y_bar, num_or_sections=sections, axis=axis)
@REGISTER_TRANSPOSE('reduce_p')
def reduce_transpose(op, check_dot, y_bar):
@REGISTER_TRANSPOSE('reduce_sum_p')
def reduce_sum_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x), 'check_dot(x) must be True'
axes = op.attr('axis')
......
......@@ -12,18 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import paddle
from paddle.fluid import framework as framework
from paddle.fluid.framework import default_main_program
from paddle.fluid.framework import Operator
from paddle import compat as cpt
from .primops import fill_const, add
from .primreg import op_position_inputs, op_position_output, lookup_orig2prim, lookup_prim2orig
from .primrules import _orig2prim, _prim2orig, _jvp, _transpose
from .utils import get_input_var_list, get_output_var_list, flatten, flatten_and_remove_none
from collections import OrderedDict
from paddle.fluid import framework as framework
from paddle.fluid.framework import Operator, default_main_program
from paddle.incubate.autograd.utils import as_tensors
from .primops import add, fill_const
from .primreg import (lookup_orig2prim, lookup_prim2orig, op_position_inputs,
op_position_output)
from .primrules import _jvp, _orig2prim, _prim2orig, _transpose
from .utils import (flatten, flatten_and_remove_none, get_input_var_list,
get_output_var_list)
def topo_path(xs, ys, block=None):
""" Returns the list of ops on the path from `xs` to `ys` in topological
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册