未验证 提交 ad8beaaf 编写于 作者: X Xiaoxu Chen 提交者: GitHub

[cherry-pick] add abs,mean,sum,ge,gt,pow,etc higher-order differentiation operators (#46184)

* [cherry-pick] extend reduce_sum,reduce_sum,eq,ne,ge,abs,pow,etc higher order operators

* add reduce_mean,reduce_sum primitive ops
* add ne_p gt_p primitive operators
* add ge_p abs_p primitive oparators
* add cast primitive operators
* add pow,square prim2oirg rules
* add elementwise_div orig2prim rule

* [cherry-pick] add mean,sum,ge,gt,ne,abs,etc higher-order differentiation operators(#45888)

* add reduce_mean,reduce_sum primitive ops

* add ne_p gt_p primitive operators

* add ge_p abs_p primitive oparators
上级 45a3c656
......@@ -8,7 +8,7 @@ register_operators()
set(PRIM_OP_SRCS
reshape_p_op.cc
broadcast_p_op.cc
reduce_p_op.cc
reduce_sum_p_op.cc
transpose_p_op.cc
split_p_op.cc
concat_p_op.cc
......@@ -30,9 +30,14 @@ set(PRIM_OP_SRCS
log_p_op.cc
select_p_op.cc
eq_p_op.cc
gt_p_op.cc
ge_p_op.cc
ne_p_op.cc
pow_p_op.cc
max_p_op.cc
erf_p_op.cc)
erf_p_op.cc
abs_p_op.cc
cast_p_op.cc)
cc_test(
prim_op_test
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class AbsPrimOp : public framework::OperatorBase {
public:
AbsPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator abs_p should not be excuted directly"));
}
};
class AbsPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of abs_p op.");
AddOutput("Y", "(Tensor), The output tensor of abs_p op.");
AddComment(R"DOC(Autograd primitive abs_p operator.)DOC");
}
};
class AbsPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};
class AbsPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(abs_p,
paddle::operators::AbsPrimOp,
paddle::operators::AbsPrimOpMaker,
paddle::operators::AbsPrimOpShapeInference,
paddle::operators::AbsPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class CastPrimOp : public framework::OperatorBase {
public:
CastPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator cast_p should not be excuted directly"));
}
};
class CastPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of cast_p op.");
AddOutput("Y", "(Tensor), The output tensor of cast_p op.");
AddAttr<int>("dtype", "output data type");
AddComment(R"DOC(Autograd primitive cast_p operator.)DOC");
}
};
class CastPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};
class CastPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_type = static_cast<framework::proto::VarType::Type>(
PADDLE_GET_CONST(int, ctx->GetAttr("dtype")));
ctx->SetOutputDataType("Y", out_type);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(cast_p,
paddle::operators::CastPrimOp,
paddle::operators::CastPrimOpMaker,
paddle::operators::CastPrimOpShapeInference,
paddle::operators::CastPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class GePrimOp : public framework::OperatorBase {
public:
GePrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator ge_p should not be excuted directly"));
}
};
class GePrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of ge_p op.");
AddInput("Y", "(Tensor), The input tensor of ge_p op.");
AddOutput("Z", "(Tensor), The output tensor of ge_p op.");
AddComment(R"DOC(
Autograd primitive ge_p operator.
)DOC");
}
};
class GePrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank,
y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank,
y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i],
y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i,
x_shape[i],
y_shape[i]));
}
PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class GePrimOpVarTypeInference : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type,
y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type,
y_type));
PADDLE_ENFORCE_EQ(x_dtype,
y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype,
y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, framework::proto::VarType::BOOL);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(ge_p,
paddle::operators::GePrimOp,
paddle::operators::GePrimOpMaker,
paddle::operators::GePrimOpShapeInference,
paddle::operators::GePrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class GtPrimOp : public framework::OperatorBase {
public:
GtPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator gt_p should not be excuted directly"));
}
};
class GtPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of gt_p op.");
AddInput("Y", "(Tensor), The input tensor of gt_p op.");
AddOutput("Z", "(Tensor), The output tensor of gt_p op.");
AddComment(R"DOC(
Autograd primitive gt_p operator.
)DOC");
}
};
class GtPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank,
y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank,
y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i],
y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i,
x_shape[i],
y_shape[i]));
}
PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class GtPrimOpVarTypeInference : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type,
y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type,
y_type));
PADDLE_ENFORCE_EQ(x_dtype,
y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype,
y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, framework::proto::VarType::BOOL);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(gt_p,
paddle::operators::GtPrimOp,
paddle::operators::GtPrimOpMaker,
paddle::operators::GtPrimOpShapeInference,
paddle::operators::GtPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class NePrimOp : public framework::OperatorBase {
public:
NePrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator ne_p should not be excuted directly"));
}
};
class NePrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of ne_p op.");
AddInput("Y", "(Tensor), The input tensor of ne_p op.");
AddOutput("Z", "(Tensor), The output tensor of ne_p op.");
AddComment(R"DOC(
Autograd primitive ne_p operator.
)DOC");
}
};
class NePrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank,
y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank,
y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i],
y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i,
x_shape[i],
y_shape[i]));
}
PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class NePrimOpVarTypeInference : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type,
y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type,
y_type));
PADDLE_ENFORCE_EQ(x_dtype,
y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype,
y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, framework::proto::VarType::BOOL);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(ne_p,
paddle::operators::NePrimOp,
paddle::operators::NePrimOpMaker,
paddle::operators::NePrimOpShapeInference,
paddle::operators::NePrimOpVarTypeInference);
......@@ -18,7 +18,7 @@
USE_OP_ITSELF(reshape_p);
USE_OP_ITSELF(broadcast_p);
USE_OP_ITSELF(reduce_p);
USE_OP_ITSELF(reduce_sum_p);
USE_OP_ITSELF(transpose_p);
USE_OP_ITSELF(split_p);
USE_OP_ITSELF(concat_p);
......@@ -130,7 +130,7 @@ TEST(PrimOp, broadcast_p) {
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, reduce_p) {
TEST(PrimOp, reduce_sum_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
......@@ -141,7 +141,7 @@ TEST(PrimOp, reduce_p) {
NewVar(block, x0, shape);
AppendOp(block,
"reduce_p",
"reduce_sum_p",
{{"X", {x0}}},
{{"Y", {x1}}},
{{"axis", std::vector<int64_t>{0, 2}}, {"keepdim", false}});
......@@ -151,7 +151,7 @@ TEST(PrimOp, reduce_p) {
ASSERT_EQ(shapes.size(), 1UL);
ASSERT_EQ(shapes[0], 4L);
AppendOp(block,
"reduce_p",
"reduce_sum_p",
{{"X", {x0}}},
{{"Y", {x2}}},
{{"axis", std::vector<int64_t>{0, 2}}, {"keepdim", true}});
......
......@@ -24,25 +24,25 @@ class VarDesc;
namespace paddle {
namespace operators {
class ReducePrimOp : public framework::OperatorBase {
class ReduceSumPrimOp : public framework::OperatorBase {
public:
ReducePrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
ReduceSumPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator reduce_p should not be excuted directly"));
"Prim operator reduce_sum_p should not be excuted directly"));
}
};
class ReducePrimOpMaker : public framework::OpProtoAndCheckerMaker {
class ReduceSumPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of reduce_p op.");
AddOutput("Y", "(Tensor), The output tensor of reduce_p op.");
AddInput("X", "(Tensor), The input tensor of reduce_sum_p op.");
AddOutput("Y", "(Tensor), The output tensor of reduce_sum_p op.");
AddAttr<std::vector<int64_t>>(
"axis",
"(std::vector<int64_t>) The axis along which to reduce on. Must be in "
......@@ -53,12 +53,12 @@ class ReducePrimOpMaker : public framework::OpProtoAndCheckerMaker {
"If true, retain the reduced axis with length 1.")
.SetDefault(false);
AddComment(R"DOC(
Autograd primitive reduce_p operator.
Autograd primitive reduce_sum_p operator.
)DOC");
}
};
class ReducePrimOpShapeInference : public framework::InferShapeBase {
class ReduceSumPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
......@@ -87,7 +87,7 @@ class ReducePrimOpShapeInference : public framework::InferShapeBase {
}
};
class ReducePrimOpVarTypeInference
class ReduceSumPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
......@@ -101,8 +101,8 @@ class ReducePrimOpVarTypeInference
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(reduce_p,
paddle::operators::ReducePrimOp,
paddle::operators::ReducePrimOpMaker,
paddle::operators::ReducePrimOpShapeInference,
paddle::operators::ReducePrimOpVarTypeInference);
REGISTER_OPERATOR(reduce_sum_p,
paddle::operators::ReduceSumPrimOp,
paddle::operators::ReduceSumPrimOpMaker,
paddle::operators::ReduceSumPrimOpShapeInference,
paddle::operators::ReduceSumPrimOpVarTypeInference);
......@@ -32,4 +32,4 @@ from . import dist_pnorm
from . import dist_slice
from . import dist_fused_feedforward
from . import dist_fused_attention
from . import dist_reduce_p
from . import dist_reduce_sum_p
......@@ -33,21 +33,21 @@ from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedReducePrimtive(DistributedOperatorImplContainer):
class DistributedReduceSumPrimtive(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedReducePrimtive, self).__init__(op_type)
super(DistributedReduceSumPrimtive, self).__init__(op_type)
register_distributed_operator_impl_container(
DistributedReducePrimtive("reduce_p"))
DistributedReduceSumPrimtive("reduce_sum_p"))
# Batch Dimension Reduce Primitive
class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
# Batch Dimension ReduceSum Primitive
class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReducePrimtiveImpl0, self).__init__(name)
super(DistributedReduceSumPrimtiveImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
......@@ -149,4 +149,5 @@ class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
register_distributed_operator_impl(
"reduce_p", DistributedReducePrimtiveImpl0("batch_dimension_reduce_p"))
"reduce_sum_p",
DistributedReduceSumPrimtiveImpl0("batch_dimension_reduce_sum_p"))
......@@ -78,7 +78,7 @@ class TestPrimDistOp(unittest.TestCase):
outputs={'Z': self.w_grad},
attrs=self.attrs)
op = self.layer_help.append_op(type="reduce_p",
op = self.layer_help.append_op(type="reduce_sum_p",
inputs={'X': self.tmp2},
outputs={'Y': self.batch_reduced},
attrs={"axis": [0]})
......
......@@ -400,6 +400,75 @@ class TestErfPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class TestAbsPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'abs_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {
'X': X,
}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
self.all_ops = [
# prim op:
'abs_p',
# jvp op:
'select_p',
'ge_p',
'fill_constant_p',
'fill_constant_p',
'sub_p',
# transpose op:
]
class TestCastPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'cast_p'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
self.prim_input = {
'X': X,
}
self.prim_output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {'dtype': paddle.float64}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
self.jvp_args = (X_DOT, )
self.jvp_out_shape_map = {0: self.prim_output['Y']}
# Set transpose
check_dot = lambda v: True
Y_BAR = paddle.static.data(name='Y_BAR', shape=[5, 6], dtype='float')
self.transpose_args = (check_dot, Y_BAR)
self.transpose_out_shape_map = {0: X}
self.all_ops = [
# prim op:
'cast_p',
# jvp op:
'cast_p',
# transpose op:
'cast_p'
]
class TestLogPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
......@@ -503,7 +572,7 @@ class TestBroadcastPJVPAndTranspose(TestAddPJVPAndTranspose):
# jvp op:
'broadcast_p',
# transpose op:
'reduce_p',
'reduce_sum_p',
'reshape_p'
]
......@@ -650,11 +719,11 @@ class TestConcatPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class TestReducePJVPAndTranspose(TestAddPJVPAndTranspose):
class TestReduceSumPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'reduce_p'
self.op_type = 'reduce_sum_p'
X = paddle.static.data(name='X', shape=[2, 3, 4, 5], dtype='float64')
self.prim_input = {'X': X}
self.prim_output = {
......@@ -682,9 +751,9 @@ class TestReducePJVPAndTranspose(TestAddPJVPAndTranspose):
self.all_ops = [
# prim op:
'reduce_p',
'reduce_sum_p',
# jvp op:
'reduce_p',
'reduce_sum_p',
# transpose op:
'reshape_p',
'broadcast_p',
......@@ -978,6 +1047,96 @@ class TestEqPJVPAndTranspose(TestAddPJVPAndTranspose):
]
class TestGtPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'gt_p'
X = paddle.static.data(name='X', shape=[4, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
self.all_ops = [
# prim op:
'gt_p',
# jvp op:
'fill_constant_p',
# transpose op:
]
class TestGePJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'ge_p'
X = paddle.static.data(name='X', shape=[4, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
self.all_ops = [
# prim op:
'ge_p',
# jvp op:
'fill_constant_p',
# transpose op:
]
class TestNePJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
# Set prim op
self.op_type = 'ne_p'
X = paddle.static.data(name='X', shape=[4, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64')
self.prim_input = {'X': X, 'Y': Y}
self.prim_output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.prim_attrs = {}
# Set JVP
X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64')
Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64')
self.jvp_args = (X_DOT, Y_DOT)
self.jvp_out_shape_map = {0: self.prim_output['Z']}
self.all_ops = [
# prim op:
'ne_p',
# jvp op:
'fill_constant_p',
# transpose op:
]
class TestPowPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self):
......
......@@ -110,6 +110,26 @@ class TestElementWiseMulOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestElementWiseDivOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'elementwise_div'
X = paddle.static.data(name='X', shape=[8, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[8, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['elementwise_div', 'div_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestMatmulV2Orig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
......@@ -229,6 +249,26 @@ class TestErfOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestAbsOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'abs'
X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
self.input = {
'X': X,
}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['abs', 'abs_p']
self.out_map = {0: self.output['Out']}
class TestLogOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
......@@ -422,7 +462,9 @@ class TestPNormOrig2Prim1(TestElementWiseAddOrig2Prim):
}
self.orig2prim_args = (X, )
self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p']
self.all_ops = [
'p_norm', 'reshape_p', 'sqrt_p', 'reduce_sum_p', 'mul_p'
]
self.out_map = {0: self.output['Out']}
......@@ -445,7 +487,9 @@ class TestPNormOrig2Prim2(TestElementWiseAddOrig2Prim):
}
self.orig2prim_args = (X, )
self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p']
self.all_ops = [
'p_norm', 'reshape_p', 'sqrt_p', 'reduce_sum_p', 'mul_p'
]
self.out_map = {0: self.output['Out']}
......@@ -580,6 +624,63 @@ class TestEqualOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestNeOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'not_equal'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['not_equal', 'ne_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestGtOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'greater_than'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['greater_than', 'gt_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestGeOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'greater_equal'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float')
self.input = {'X': X, 'Y': Y}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.orig2prim_args = (X, Y)
self.all_ops = ['greater_equal', 'ge_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestPowOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
......@@ -665,5 +766,118 @@ class TestGeluApproximateOrig2Prim(TestElementWiseAddOrig2Prim):
self.out_map = {0: self.output['Out']}
class TestReduceSumOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'reduce_sum'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [0, 1], 'keep_dim': False}
self.orig2prim_args = (X, )
self.all_ops = ['reduce_sum', 'reduce_sum_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestReduceMeanOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'reduce_mean'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': [0, 1], 'keep_dim': False}
self.orig2prim_args = (X, )
self.all_ops = [
'reduce_mean', 'reduce_sum_p', 'fill_constant_p', 'div_p'
]
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestSizeOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'size'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'Input': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(
dtype=paddle.int64)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['size', 'fill_constant_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestCastOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'cast'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'in_dtype': X.dtype, 'out_dtype': paddle.float64}
self.orig2prim_args = (X, )
self.all_ops = ['cast', 'cast_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestPowScalarOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'pow'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'factor': 2.}
self.orig2prim_args = (None, X)
self.all_ops = ['pow', 'pow_p', 'fill_constant_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
class TestSquareOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'square'
X = paddle.static.data(name='X', shape=[5, 8], dtype='float')
self.input = {'X': X}
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.orig2prim_args = (X, )
self.all_ops = ['square', 'pow_p', 'fill_constant_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}
if __name__ == '__main__':
unittest.main()
......@@ -244,6 +244,26 @@ class TestErfPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Y']: 0}
class TestAbsPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'abs_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {
'X': X,
}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}
self.prim2orig_args = (X, )
self.all_ops = ['abs_p', 'abs']
self.out_map = {self.output['Y']: 0}
class TestLogPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
......@@ -375,7 +395,7 @@ class TestConcatPPrim2Orig(TestAddPPrim2Orig):
class TestReducePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'reduce_p'
self.op_type = 'reduce_sum_p'
X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')
self.input = {'X': X}
......@@ -386,7 +406,7 @@ class TestReducePPrim2Orig(TestAddPPrim2Orig):
self.attrs = {'axis': [1], 'keepdim': True}
self.prim2orig_args = (X, )
self.all_ops = ['reduce_p', 'reduce_sum']
self.all_ops = ['reduce_sum_p', 'reduce_sum']
self.out_map = {self.output['Y']: 0}
......@@ -555,6 +575,63 @@ class TestEqPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Z']: 0}
class TestNePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'ne_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['ne_p', 'not_equal']
self.out_map = {self.output['Z']: 0}
class TestGtPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'gt_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['gt_p', 'greater_than']
self.out_map = {self.output['Z']: 0}
class TestGePPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'ge_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')
self.input = {'X': X, 'Y': Y}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype='bool')
}
self.attrs = {}
self.prim2orig_args = (X, Y)
self.all_ops = ['ge_p', 'greater_equal']
self.out_map = {self.output['Z']: 0}
class TestPowPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
......@@ -593,5 +670,25 @@ class TestMaxPPrim2Orig(TestAddPPrim2Orig):
self.out_map = {self.output['Z']: 0}
class TestCastPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'cast_p'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
self.input = {
'X': X,
}
self.output = {
'Y':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'dtype': paddle.int64}
self.prim2orig_args = (X, )
self.all_ops = ['cast_p', 'cast']
self.out_map = {self.output['Y']: 0}
if __name__ == '__main__':
unittest.main()
......@@ -150,6 +150,8 @@ class TestWithoutProgramGuard(unittest.TestCase):
(np.random.rand(3, 3), np.random.rand(3, 3)),
(np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'),
('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
('abs', paddle.abs, (np.random.uniform(-10, 10,
(10, 10)), ), None, 'float32'),
))
# paddle.where, paddle.pow, paddle.maximum has no double grad definition,
# can not compute forward grad use double trick
......@@ -255,6 +257,8 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
(np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
('multiply', paddle.multiply,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
('div', paddle.divide,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
('add', paddle.add,
(np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
('input_not_sequence', paddle.tanh,
......@@ -283,7 +287,36 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
(np.random.rand(200, 189), ), None, 'float32'),
('gelu_approximate', lambda x: paddle.nn.functional.gelu(x, True),
(np.random.rand(200, 189), ), None, 'float32'),
))
('sum', paddle.sum, (np.random.rand(200, 345), ), None, 'float32'),
('sum_with_axis', lambda x: paddle.sum(x, axis=1),
(np.random.rand(200, 345), ), None, 'float32'),
('sum_with_keepdim', lambda x: paddle.sum(x, keepdim=True),
(np.random.rand(200, 345), ), None, 'float32'),
('mean', paddle.mean, (np.random.rand(200, 345), ), None, 'float32'),
('mean_with_axis', lambda x: paddle.mean(x, axis=1),
(np.random.rand(200, 345), ), None, 'float32'),
('mean_with_keepdim', lambda x: paddle.mean(x, keepdim=True),
(np.random.rand(200, 345), ), None, 'float32'),
('mean_with_axis_keepdim',
lambda x: paddle.mean(x, axis=0, keepdim=True),
(np.random.rand(200, 345), ), None, 'float32'),
('abs', paddle.abs, (np.random.uniform(-10, 10,
(200, 345)), ), None, 'float32'),
('cast_float', lambda x: paddle.cast(x, paddle.float64),
(np.random.rand(10, 20), ), None, 'float32'),
('cast_int', lambda x: paddle.cast(x, paddle.int32),
(np.random.rand(10, 20), ), None, 'float32'),
('square', paddle.square, (np.random.rand(100), ), None, 'float32'),
('pow_scalar', lambda x: paddle.pow(x, 2),
(np.random.rand(20, 30), ), None, 'float32'),
('var', paddle.var, (np.random.rand(200, 324), ), None, 'float32'),
('var_with_axis', lambda x: paddle.var(x, axis=1),
(np.random.rand(10, 20, 30), ), None, 'float32'),
('var_without_unbiased',
lambda x: paddle.var(x, axis=1, unbiased=False),
(np.random.rand(10, 20, 30), ), None, 'float32'),
('var_with_keepdim', lambda x: paddle.var(x, axis=1, keepdim=True),
(np.random.rand(10, 20, 30), ), None, 'float32')))
class TestGrad(unittest.TestCase):
def setUp(self):
......
......@@ -42,7 +42,11 @@ paddle.enable_static()
('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'),
('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'),
('erf', primops.erf, randn(2, 3), {}, (2, 3), 'float64'),
('abs', primops.abs, randn(2, 3), {}, (2, 3), 'float64'),
('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'),
('cast', primops.cast, randn(2, 3), {
'dtype': paddle.int64
}, (2, 3), 'int64'),
('reshape', primops.reshape, randn(2, 3), {
'shape': (3, 2)
}, (3, 2), 'float64'),
......@@ -58,10 +62,10 @@ paddle.enable_static()
('concat_axis1', primops.concat, ((randn(2, 3), randn(2, 3)), ), {
'axis': 1
}, (2, 6), 'float64'),
('reduce_axis1', primops.reduce, randn(2, 3), {
('reduce_axis1', primops.reduce_sum, randn(2, 3), {
'axis': (1, )
}, (2, ), 'float64'),
('reduce_axis01', primops.reduce, randn(2, 3), {
('reduce_axis01', primops.reduce_sum, randn(2, 3), {
'axis': (0, 1)
}, (1, ), 'float64'),
('split', primops.split, randn(2, 3), {
......@@ -99,6 +103,9 @@ paddle.enable_static()
('select', primops.select,
(randn(2, 3) > 0, randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('eq', primops.eq, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('ne', primops.ne, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('gt', primops.gt, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('ge', primops.ge, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'),
('pow', primops.pow, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
('max', primops.max, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
))
......
......@@ -290,8 +290,8 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'index_select'
]
self.orig2prim_ops = [
'broadcast_p', 'add_p', 'reshape_p', 'mul_p', 'reduce_p', 'sqrt_p',
'broadcast_p', 'sub_p', 'concat_p', 'gather_p'
'broadcast_p', 'add_p', 'reshape_p', 'mul_p', 'reduce_sum_p',
'sqrt_p', 'broadcast_p', 'sub_p', 'concat_p', 'gather_p'
]
self.linearize_ops = self.orig2prim_ops + [
# call fill_const() in linearize() function
......@@ -306,7 +306,7 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'mul_p',
'mul_p',
'add_p',
'reduce_p',
'reduce_sum_p',
'fill_constant_p', # 'sqrt_p', Will not append sqrt_p op when apply JVP for sqrt_p
'mul_p',
'div_p',
......@@ -326,7 +326,7 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'fill_constant_p',
'mul_p',
# transposed op
'reduce_p',
'reduce_sum_p',
'reshape_p',
'reshape_p',
'mul_p',
......@@ -334,7 +334,7 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'reshape_p',
'broadcast_p',
'div_p',
'reduce_p',
'reduce_sum_p',
'reshape_p',
'fill_constant_p',
'sub_p',
......
......@@ -137,6 +137,11 @@ def exp(x, out=None):
return _simple_unop(LayerHelper('exp_p', **locals()))
@REGISTER_FN('abs_p', 'X', 'Y')
def abs(x, out=None):
return _simple_unop(LayerHelper('abs_p', **locals()))
@REGISTER_FN('reshape_p', 'X', 'Y')
def reshape(x, shape, out=None):
return _manipulation_unop(LayerHelper('reshape_p', **locals()))
......@@ -193,15 +198,17 @@ def concat(xs, axis=0, out=None):
return out
@REGISTER_FN('reduce_p', 'X', 'Y')
def reduce(x, axis, keepdim=False, out=None):
@REGISTER_FN('reduce_sum_p', 'X', 'Y')
def reduce_sum(x, axis=None, keepdim=False, out=None):
axes = axis or tuple(range(0, len(x.shape)))
axes = (axes, ) if isinstance(axes, int) else axes
if not isinstance(axis, (tuple, list)):
raise TypeError(f'axis must be tuple or list, but got {type(axis)}')
if not isinstance(keepdim, bool):
raise TypeError(f'keepdim must be bool, but got {type(keepdim)}')
attrs = {'axis': axis, 'keepdim': keepdim}
helper = LayerHelper('reduce_p', **locals())
attrs = {'axis': axis, 'keepdim': keepdim}
helper = LayerHelper('reduce_sum_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......@@ -347,6 +354,21 @@ def eq(x, y, out=None):
return _simple_binop(LayerHelper('eq_p', **locals()))
@REGISTER_FN('gt_p', 'X', 'Y', 'Z')
def gt(x, y, out=None):
return _simple_binop(LayerHelper('gt_p', **locals()))
@REGISTER_FN('ge_p', 'X', 'Y', 'Z')
def ge(x, y, out=None):
return _simple_binop(LayerHelper('ge_p', **locals()))
@REGISTER_FN('ne_p', 'X', 'Y', 'Z')
def ne(x, y, out=None):
return _simple_binop(LayerHelper('ne_p', **locals()))
@REGISTER_FN('pow_p', 'X', 'Y', 'Z')
def pow(x, y, out=None):
return _simple_binop(LayerHelper('pow_p', **locals()))
......@@ -360,3 +382,15 @@ def max(x, y, out=None):
@REGISTER_FN('erf_p', 'X', 'Y')
def erf(x, out=None):
return _simple_unop(LayerHelper('erf_p', **locals()))
@REGISTER_FN('cast_p', 'X', 'Y')
def cast(x, dtype, out=None):
helper = LayerHelper('cast_p', **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type=helper.layer_type,
inputs={'X': x},
outputs={'Y': out},
attrs={'dtype': dtype})
return out
......@@ -80,7 +80,7 @@ def op_position_inputs(op):
"""
args = _primop_position_argnames.lookup(op.type)
assert args is not None, 'args should not be None in op_position_inputs().'
assert args is not None, f'args of {op.type} should not be None in op_position_inputs().'
*input_names, _ = args
inputs = []
......
......@@ -11,16 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import functools
import math
import operator
import typing
import paddle
from . import primops
from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather,
matmul, mul, neg, reduce, reshape, scatter_add, set_value,
from .primops import (add, broadcast, concat, cos, div, eq, erf, exp,
fill_const, gather, ge, gt, log, matmul, max, mul, ne,
neg, reduce_sum, reshape, scatter_add, select, set_value,
sin, slice_assign, slice_select, split, sqrt, sub, tanh,
transpose, log, select, eq, max, erf)
transpose)
from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
lookup_orig2prim, lookup_prim2orig, lookup_transpose,
......@@ -155,6 +158,13 @@ def elementwise_mul_orig2prim(op, x, y):
return z
@REGISTER_ORIG2PRIM('elementwise_div')
def elementwise_div_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return primops.div(x, y)
@REGISTER_ORIG2PRIM('tanh')
def tanh_orig2prim(op, x):
return tanh(x)
......@@ -180,6 +190,11 @@ def erf_orig2prim(op, x):
return erf(x)
@REGISTER_ORIG2PRIM('abs')
def abs_orig2prim(op, x):
return primops.abs(x)
@REGISTER_ORIG2PRIM('log')
def log_orig2prim(op, x):
return log(x)
......@@ -307,13 +322,18 @@ def p_norm_orig2prim(op, x):
x = reshape(x, shape=[num_el(x.shape)])
if abs(op.attr('porder') - 2.0) < 1e-5:
return sqrt(reduce(mul(x, x), axis=[0]))
return sqrt(reduce_sum(mul(x, x), axis=[0]))
elif abs(op.attr('porder') - 1.0) < 1e-5:
return reduce(sqrt(mul(x, x)), axis=[0])
return reduce_sum(sqrt(mul(x, x)), axis=[0])
else:
raise RuntimeError('Only support lower l2/l1 norm currently')
@REGISTER_ORIG2PRIM('cast')
def cast_orig2prim(op, x):
return primops.cast(x, paddle.dtype(op.attr('out_dtype')))
# TODO: support broadcast
@REGISTER_ORIG2PRIM('where')
def select_orig2prim(op, condition, x, y):
......@@ -327,15 +347,48 @@ def equal_orig2prim(op, x, y):
return eq(x, y)
@REGISTER_ORIG2PRIM('not_equal')
def ne_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return ne(x, y)
@REGISTER_ORIG2PRIM('greater_than')
def gt_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return gt(x, y)
@REGISTER_ORIG2PRIM('greater_equal')
def ge_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
return ge(x, y)
# paddle.pow API use "elementwise_pow" operator when y is a Tensor.
@REGISTER_ORIG2PRIM('elementwise_pow')
def elementwise_pow_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
z = primops.pow(x, y)
return z
# paddle.pow API use "pow" operator when y is a scalar.
@REGISTER_ORIG2PRIM('pow')
def pow_orig2prim(op, x, y):
# x is factorTensor defined in paddle phi op. Currently it is None.
return primops.pow(y, fill_const(op.attr('factor'), y.shape, y.dtype))
@REGISTER_ORIG2PRIM('square')
def square_orig2prim(op, x):
return primops.pow(x, fill_const(2., x.shape, x.dtype))
@REGISTER_ORIG2PRIM('elementwise_max')
def elementwise_max_orig2prim(op, x, y):
if x.shape != y.shape:
......@@ -367,6 +420,31 @@ def gelu_orig2prim(op, x):
erf(mul(x, fill_const(1 / math.sqrt(2.), x.shape, x.dtype)))))
@REGISTER_ORIG2PRIM('reduce_sum')
def reduce_sum_orig2prim(op, x):
axes = tuple(range(0, len(
x.shape))) if op.attr('reduce_all') else op.attr('dim')
return reduce_sum(x, axis=axes, keepdim=op.attr('keep_dim'))
@REGISTER_ORIG2PRIM('reduce_mean')
def reduce_mean_orig2prim(op, x):
axes = tuple(range(0, len(
x.shape))) if op.attr('reduce_all') else op.attr('dim')
sum = reduce_sum(x, axis=axes, keepdim=op.attr('keep_dim'))
norm = fill_const(shape=sum.shape,
value=functools.reduce(operator.mul,
[x.shape[axis] for axis in axes]),
dtype=sum.dtype)
return div(sum, norm)
@REGISTER_ORIG2PRIM('size')
def size_orig2prim(op, x):
return fill_const(functools.reduce(operator.mul, x.shape), (1, ),
paddle.int64)
## Register prim2orig lower rules
@REGISTER_PRIM2ORIG('add_p')
def add_prim2orig(op, x, y):
......@@ -418,6 +496,11 @@ def erf_prim2orig(op, x):
return paddle.erf(x)
@REGISTER_PRIM2ORIG('abs_p')
def abs_prim2orig(op, x):
return paddle.abs(x)
@REGISTER_PRIM2ORIG('log_p')
def log_prim2orig(op, x):
return paddle.log(x)
......@@ -453,7 +536,7 @@ def concat_prim2orig(op, xs):
return paddle.concat(xs, axis=op.attr('axis'))
@REGISTER_PRIM2ORIG('reduce_p')
@REGISTER_PRIM2ORIG('reduce_sum_p')
def reduce_prim2orig(op, x):
return paddle.sum(x, axis=op.attr('axis'), keepdim=op.attr('keepdim'))
......@@ -514,6 +597,21 @@ def eq_prim2orig(op, x, y):
return paddle.equal(x, y)
@REGISTER_PRIM2ORIG('gt_p')
def gt_prim2orig(op, x, y):
return paddle.greater_than(x, y)
@REGISTER_PRIM2ORIG('ge_p')
def ge_prim2orig(op, x, y):
return paddle.greater_equal(x, y)
@REGISTER_PRIM2ORIG('ne_p')
def ne_prim2orig(op, x, y):
return paddle.not_equal(x, y)
@REGISTER_PRIM2ORIG('pow_p')
def pow_prim2orig(op, x, y):
return paddle.pow(x, y)
......@@ -524,6 +622,11 @@ def max_prim2orig(op, x, y):
return paddle.maximum(x, y)
@REGISTER_PRIM2ORIG('cast_p')
def cast_prim2orig(op, x):
return paddle.cast(x, paddle.dtype(op.attr('dtype')))
## Register linearize rules
@REGISTER_JVP('add_p')
def add_jvp(op, x_dot, y_dot):
......@@ -629,6 +732,14 @@ def erf_jvp(op, x_dot):
mul(x_dot, exp(neg(primops.pow(x, fill_const(2., x.shape, x.dtype))))))
@REGISTER_JVP('abs_p')
def abs_jvp(op, x_dot):
if x_dot is None:
return None
x, = op_position_inputs(op)
return select(ge(x, fill_const(0., x.shape, x.dtype)), x_dot, neg(x_dot))
@REGISTER_JVP('log_p')
def log_jvp(op, x_dot):
if x_dot is None:
......@@ -678,8 +789,8 @@ def concat_jvp(op, xs_dot):
return linear_jvp(op, xs_dot, axis=axis)
@REGISTER_JVP('reduce_p')
def reduce_jvp(op, x_dot):
@REGISTER_JVP('reduce_sum_p')
def reduce_sum_jvp(op, x_dot):
if x_dot is None:
return None
axis = op.attr('axis')
......@@ -778,6 +889,33 @@ def eq_jvp(op, x_dot, y_dot):
return z_dot
@REGISTER_JVP('gt_p')
def gt_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, _ = op_position_inputs(op)
z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype)
return z_dot
@REGISTER_JVP('ge_p')
def ge_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, _ = op_position_inputs(op)
z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype)
return z_dot
@REGISTER_JVP('ne_p')
def ne_jvp(op, x_dot, y_dot):
if x_dot is None and y_dot is None:
return None
x, _ = op_position_inputs(op)
z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype)
return z_dot
@REGISTER_JVP('pow_p')
def pow_jvp(op, x_dot, y_dot):
......@@ -825,6 +963,12 @@ def max_jvp(op, x_dot, y_dot):
return select(eq(y, z), y_dot, x_dot)
@REGISTER_JVP('cast_p')
def cast_jvp(op, x_dot):
y = op_position_output(op)
return primops.cast(x_dot, y.dtype)
## Register transpose rules
......@@ -886,7 +1030,7 @@ def broadcast_transpose(op, check_dot, y_bar):
keepdim = [(bat + i) for i, s in enumerate(x.shape) if s == 1]
axis += keepdim
# TODO: Change it. keepdim boolean
out = reduce(y_bar, axis=axis, keepdim=False)
out = reduce_sum(y_bar, axis=axis, keepdim=False)
return reshape(out, x.shape)
......@@ -921,8 +1065,8 @@ def concat_transpose(op, check_dot, y_bar):
return split(y_bar, num_or_sections=sections, axis=axis)
@REGISTER_TRANSPOSE('reduce_p')
def reduce_transpose(op, check_dot, y_bar):
@REGISTER_TRANSPOSE('reduce_sum_p')
def reduce_sum_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x), 'check_dot(x) must be True'
axes = op.attr('axis')
......@@ -1029,3 +1173,9 @@ def select_transpose(op, check_dot, z_bar):
y_bar = select(cond, zeros_y, z_bar) if check_dot(y) else None
return cond_bar, x_bar, y_bar
@REGISTER_TRANSPOSE('cast_p')
def cast_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
return primops.cast(y_bar, x.dtype)
......@@ -12,18 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import paddle
from paddle.fluid import framework as framework
from paddle.fluid.framework import default_main_program
from paddle.fluid.framework import Operator
from paddle import compat as cpt
from .primops import fill_const, add
from .primreg import op_position_inputs, op_position_output, lookup_orig2prim, lookup_prim2orig
from .primrules import _orig2prim, _prim2orig, _jvp, _transpose
from .utils import get_input_var_list, get_output_var_list, flatten, flatten_and_remove_none
from collections import OrderedDict
from paddle.fluid import framework as framework
from paddle.fluid.framework import Operator, default_main_program
from paddle.incubate.autograd.utils import as_tensors
from .primops import add, fill_const
from .primreg import (lookup_orig2prim, lookup_prim2orig, op_position_inputs,
op_position_output)
from .primrules import _jvp, _orig2prim, _prim2orig, _transpose
from .utils import (flatten, flatten_and_remove_none, get_input_var_list,
get_output_var_list)
def topo_path(xs, ys, block=None):
""" Returns the list of ops on the path from `xs` to `ys` in topological
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册