Unverified commit 97dec7ca, authored by levi131, committed by GitHub

Lml/add prim ops (#41201)

* native commit for triple grad of sigmoid

* Updated unittests files

* init functional jacobian api

* Updated triple_test func

* Updated gradient_checker & test_script

* finish test with dtype float32

* add float64 test case

* polish code

* use atol=1e-5 with dtype float64

* fix for ci

* set timeout for test_jacobian

* fix dygraph grad to support high differential

* polish API docstring

* Updated gradient checker and some related files

* fix double grad strip error for high differential

* fix double grad strip error for high differential

* Add Sigmoid triple grad tests

* fix dygraph double grad dtype error when called in high differential scenario

* Updated triple grad tests func

* Use np.random to initialize ddx

* Updated triple_grad_check func

* add todo for gradient checker and refine some comments

* remove additional code

* add test for warning in backward.py

* format python code

* support multi input in triple gradient checker

* Add matmul triple grad kernel

* Updated comments of TODO

* Supported some special tests

* Change code-format to follow CI std

* Updated gradient_checker.py

* Fix conflicts

* Removed unnecessary printing log

* Change code style to follow CI std

* merge upstream

* add_p

* rm useless files

* add sub_p mul_p div_p

* add sqrt_p and tanh_p

* add reshape_p

* add broadcast_p

* add broadcast_p fill_constant_p matmul_p reduce_p reshape_p transpose_p

* add split_p and concat_p

* add gather_p and scatter_add_p

* add slice_select_p and slice_assign_p

* add multi input check for add_p, sub_p, mul_p, div_p

* update concat_p

* refine gather_p and scatter_add_p

* refine slice_assign_p and slice_select_p

* add 9 tests for prim ops

* add more tests and fix some bugs

* add more tests

* register proto

* add shape validity check for broadcast_p op, and add keepdim attr to reduce_p op proto

* support multi input and multi output for split_p and concat_p

* fix slice bug for slice_select_p and slice_assign_p

* dtype for axis attr should be long int

* update dtype for axis attr to int64_t

* update for iscan CI

* add more shape and dtype check

* change IndexTensor into int32 dtype
Parent b12af9e1
@@ -22,6 +22,7 @@ add_subdirectory(reduce_ops)
add_subdirectory(sequence_ops)
add_subdirectory(string)
add_subdirectory(jit)
add_subdirectory(prim_ops)
if(WITH_MKLDNN)
add_subdirectory(mkldnn)
endif()
include(operators)
if(WITH_UNITY_BUILD)
# Load Unity Build rules for operators in paddle/fluid/operators/prim_ops.
include(unity_build_rule.cmake)
endif()
register_operators()
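# PRIM_OP_SRCS below lists the primitive-op sources that are also compiled
# directly into prim_op_test, so the test binary carries the operator
# registrations it exercises.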
SET(PRIM_OP_SRCS
reshape_p_op.cc
broadcast_p_op.cc
reduce_p_op.cc
transpose_p_op.cc
split_p_op.cc
concat_p_op.cc
slice_select_p_op.cc
slice_assign_p_op.cc
gather_p_op.cc
scatter_add_p_op.cc
add_p_op.cc
sub_p_op.cc
mul_p_op.cc
div_p_op.cc
sqrt_p_op.cc
tanh_p_op.cc
matmul_p_op.cc
fill_constant_p_op.cc)
cc_test(prim_op_test SRCS prim_op_test.cc ${PRIM_OP_SRCS} DEPS op_registry)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class AddPrimOp : public framework::OperatorBase {
public:
AddPrimOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator add_p should not be excuted directly"));
}
};
class AddPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of add_p op.");
AddInput("Y", "(Tensor), The input tensor of add_p op.");
AddOutput("Z", "(Tensor), The output tensor of add_p op.");
AddComment(R"DOC(
Autograd primitive add_p operator.
)DOC");
}
};
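// Note: add_p (like the other elementwise prim ops sub_p, mul_p and div_p)
// requires its two operands to have identical shapes; there is no implicit
// broadcasting, so any broadcast has to be made explicit with a preceding
// broadcast_p op.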
class AddPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank, y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank, y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i], y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i, x_shape[i], y_shape[i]));
}
BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class AddPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type, y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type, y_type));
PADDLE_ENFORCE_EQ(x_dtype, y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype, y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, x_dtype);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(add_p, paddle::operators::AddPrimOp,
paddle::operators::AddPrimOpMaker,
paddle::operators::AddPrimOpShapeInference,
paddle::operators::AddPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class BroadcastPrimOp : public framework::OperatorBase {
public:
BroadcastPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator broadcast_p should not be excuted directly"));
}
};
class BroadcastPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of broadcast_p op.");
AddOutput("Y", "(Tensor), The output tensor of broadcast_p op.");
AddAttr<std::vector<int64_t>>(
"shape",
"(std::vector<int64_t>) Target shape of broadcast_p operator.");
AddComment(R"DOC(
Autograd primitive broadcast_p operator.
)DOC");
}
};
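// CheckShapeValid accepts x_shape as broadcastable to target_shape when every
// non-1 extent of x_shape occurs, in order, in target_shape (dimensions of
// extent 1 may map anywhere). For example, x_shape {3, 1} is valid for
// target_shape {3, 4, 5}, while {4, 3} is rejected against {3, 4, 5} because
// no extent 3 appears after the matched 4.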
static void CheckShapeValid(const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &target_shape) {
size_t x_rank = x_shape.size();
size_t target_rank = target_shape.size();
PADDLE_ENFORCE_GE(target_rank, x_rank,
platform::errors::InvalidArgument(
"The rank of target shape should be greater than or "
"equal to input tensor's dimensions, "
"but received %d and %d",
target_rank, x_rank));
std::vector<int64_t>::const_iterator it = target_shape.begin();
for (size_t i = 0; i < x_rank; i++, it++) {
if (x_shape[i] != 1) {
it = std::find(it, target_shape.end(), x_shape[i]);
}
PADDLE_ENFORCE_EQ(
it != target_shape.end(), true,
platform::errors::InvalidArgument(
"Invalid shape, can not broadcast input tensor into target shape,"
"the first dismatching shape %d is shape of input tensor at "
"dimension %d",
x_shape[i], i));
}
}
class BroadcastPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
auto x_shape = x_var->GetShape();
auto target_shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
CheckShapeValid(x_shape, target_shape);
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(target_shape);
}
};
class BroadcastPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(broadcast_p, paddle::operators::BroadcastPrimOp,
paddle::operators::BroadcastPrimOpMaker,
paddle::operators::BroadcastPrimOpShapeInference,
paddle::operators::BroadcastPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class ConcatPrimOp : public framework::OperatorBase {
public:
ConcatPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator concat_p should not be excuted directly"));
}
};
class ConcatPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("XS", "(Tensor), The input tensors of concat_p op.")
.AsDuplicable();
AddOutput("Y", "(Tensor), The output tensor of concat_p op.");
AddAttr<int64_t>("axis", "(int64_t), The axis along which to concat.");
AddComment(R"DOC(
Autograd primitive concat_p operator.
)DOC");
}
};
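// Shape inference below requires every input to share the first input's rank
// and to agree with it on all dimensions except `axis`; the output shape is
// the first input's shape with the `axis` extent replaced by the sum over all
// inputs, e.g. concatenating {3, 1, 5}, {3, 4, 5} and {3, 6, 5} along axis 1
// yields {3, 11, 5}.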
class ConcatPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
auto x_var_ptrs = ctx->GetInputVarPtrs("XS");
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
auto axis = ctx->Attrs().Get<int64_t>("axis");
int64_t cnt_along_axis = 0;
framework::VarDesc *first_x_var =
BOOST_GET(framework::VarDesc *, x_var_ptrs[0]);
auto first_x_shape = first_x_var->GetShape();
cnt_along_axis += first_x_shape[axis];
size_t first_x_rank = first_x_shape.size();
for (size_t i = 1; i < x_var_ptrs.size(); ++i) {
framework::VarDesc *x_var =
BOOST_GET(framework::VarDesc *, x_var_ptrs[i]);
auto x_shape = x_var->GetShape();
cnt_along_axis += x_shape[axis];
size_t x_rank = x_shape.size();
PADDLE_ENFORCE_EQ(
x_rank, first_x_rank,
platform::errors::InvalidArgument("The dimensions of %d input tensor "
"should be same as the dimensions "
"of 1st input tensor's, "
"but get %d and %d",
i + 1, x_rank, first_x_rank));
for (size_t j = 0; j < x_rank; ++j) {
if (j != size_t(axis)) {
PADDLE_ENFORCE_EQ(x_shape[j], first_x_shape[j],
platform::errors::InvalidArgument(
"The shape of %d input tensor at dimension %d "
"should be same as the 1st input tensor's, "
"but get %d and %d",
i + 1, j, x_shape[j], first_x_shape[j]));
}
}
}
std::vector<int64_t> y_shape(first_x_shape);
y_shape[axis] = cnt_along_axis;
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(y_shape);
}
};
class ConcatPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_names = Input(ctx, "XS");
auto y_name = Output(ctx, "Y")[0];
auto first_x_name = x_names[0];
auto first_x_type = GetType(ctx, first_x_name);
auto first_x_dtype = GetDataType(ctx, first_x_name);
for (size_t i = 1; i < x_names.size(); ++i) {
auto x_name = x_names[i];
auto x_type = GetType(ctx, x_name);
auto x_dtype = GetDataType(ctx, x_name);
PADDLE_ENFORCE_EQ(x_type, first_x_type,
platform::errors::InvalidArgument(
"The type of %d input tensor should be same as the "
"first input tensor's, "
"but get %d and %d",
i + 1, x_type, first_x_type));
PADDLE_ENFORCE_EQ(x_dtype, first_x_dtype,
platform::errors::InvalidArgument(
"The datatype of %d input tensor should be same as "
"the first input tensor's, "
"but get %d and %d",
i + 1, x_dtype, first_x_dtype));
}
SetType(ctx, y_name, GetType(ctx, first_x_name));
SetDataType(ctx, y_name, GetDataType(ctx, first_x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(concat_p, paddle::operators::ConcatPrimOp,
paddle::operators::ConcatPrimOpMaker,
paddle::operators::ConcatPrimOpShapeInference,
paddle::operators::ConcatPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class DivPrimOp : public framework::OperatorBase {
public:
DivPrimOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator div_p should not be excuted directly"));
}
};
class DivPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of div_p op.");
AddInput("Y", "(Tensor), The input tensor of div_p op.");
AddOutput("Z", "(Tensor), The output tensor of div_p op.");
AddComment(R"DOC(
Autograd primitive div_p operator.
)DOC");
}
};
class DivPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank, y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank, y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i], y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i, x_shape[i], y_shape[i]));
}
BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class DivPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type, y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type, y_type));
PADDLE_ENFORCE_EQ(x_dtype, y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype, y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, x_dtype);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(div_p, paddle::operators::DivPrimOp,
paddle::operators::DivPrimOpMaker,
paddle::operators::DivPrimOpShapeInference,
paddle::operators::DivPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class FillConstantPrimOp : public framework::OperatorBase {
public:
FillConstantPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator fill_constant_p should not be excuted directly"));
}
};
class FillConstantPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Y", "(Tensor), The output tensor of fill_constant_p op.");
AddAttr<float>("value", "(float) The value of output tensor.");
AddAttr<std::vector<int64_t>>(
"shape", "(std::vector<int64_t>) The shape of output tensor.");
AddAttr<int>("dtype", "(int) The dtype of output tensor.");
AddComment(R"DOC(
Autograd primitive fill_constant_p operator.
)DOC");
}
};
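// fill_constant_p has no inputs: the output shape comes from the `shape`
// attribute, and the `dtype` attribute carries the element type as an int
// that var-type inference casts back to proto::VarType::Type.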
class FillConstantPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
auto shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(shape);
}
};
class FillConstantPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto y_name = Output(ctx, "Y")[0];
auto data_type = static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, ctx->GetAttr("dtype")));
SetDataType(ctx, y_name, data_type);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(fill_constant_p, paddle::operators::FillConstantPrimOp,
paddle::operators::FillConstantPrimOpMaker,
paddle::operators::FillConstantPrimOpShapeInference,
paddle::operators::FillConstantPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class GatherPrimOp : public framework::OperatorBase {
public:
GatherPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator gather_p should not be excuted directly"));
}
};
class GatherPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of gather_p op.");
AddInput("IndexTensor",
"(Tensor), The index tensor of gather_p op, which is a 1D tensor.")
.AsDispensable();
AddOutput("Y", "(Tensor), The output tensor of gather_p op.");
AddAttr<int64_t>("axis", "(int64_t), The axis along which to gather.");
AddAttr<std::vector<int64_t>>(
"index", "(std::vector<int64_t>) The index of gather_p op")
.SetDefault({0});
AddComment(R"DOC(
Autograd primitive gather_p operator.
)DOC");
}
};
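// gather_p reads its indices either from the dispensable IndexTensor input (a
// 1D int32 tensor) or, when that input is absent, from the `index` attribute.
// Either way the output shape is the input shape with the extent at `axis`
// replaced by the number of indices, e.g. gathering 3 indices along axis 1 of
// a {6, 8, 10} tensor yields {6, 3, 10}.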
class GatherPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
int64_t num_index = 0;
if (ctx->HasInput("IndexTensor")) {
framework::InferShapeVarPtr index_var_ptr =
ctx->GetInputVarPtrs("IndexTensor")[0];
framework::VarDesc *index_var =
BOOST_GET(framework::VarDesc *, index_var_ptr);
auto index_shape = index_var->GetShape();
PADDLE_ENFORCE_EQ(index_shape.size(), 1,
platform::errors::InvalidArgument(
"The index tensor should be a 1D tensor,"
"but get rank %d",
index_shape.size()));
num_index = index_shape[0];
} else {
num_index = ctx->Attrs().Get<std::vector<int64_t>>("index").size();
}
auto axis = ctx->Attrs().Get<int64_t>("axis");
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
auto x_shape = x_var->GetShape();
x_shape[axis] = num_index;
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_shape);
}
};
class GatherPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
if (ctx->HasInput("IndexTensor")) {
auto index_name = Input(ctx, "IndexTensor")[0];
auto index_dtype = GetDataType(ctx, index_name);
PADDLE_ENFORCE_EQ(
index_dtype, framework::proto::VarType_Type_INT32,
platform::errors::InvalidArgument(
"The datatype of input tensor should be VarType_Type_INT32(%d), "
"but get %d",
framework::proto::VarType_Type_INT32, index_dtype));
}
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(gather_p, paddle::operators::GatherPrimOp,
paddle::operators::GatherPrimOpMaker,
paddle::operators::GatherPrimOpShapeInference,
paddle::operators::GatherPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class MatmulPrimOp : public framework::OperatorBase {
public:
MatmulPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator matmul_p should not be excuted directly"));
}
};
class MatmulPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of matmul_p op.");
AddInput("Y", "(Tensor), The input tensor of matmul_p op.");
AddOutput("Z", "(Tensor), The output tensor of matmul_p op.");
AddComment(R"DOC(
Autograd primitive matmul_p operator.
)DOC");
}
};
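// matmul_p accepts rank-2 inputs, {M, K} x {K, N} -> {M, N}, or rank-3
// batched inputs, {B, M, K} x {B, K, N} -> {B, M, N}; both operands must have
// the same rank, and for rank 3 the leading batch dimensions must match.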
class MatmulPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank, y_rank,
platform::errors::InvalidArgument(
"The two input tensor's dimension should be equal"
"But received first input tensor's dimension is %d, "
"and another input tensor's dimension is %d",
x_rank, y_rank));
PADDLE_ENFORCE_EQ(x_rank == 2 || x_rank == 3, true,
platform::errors::InvalidArgument(
"The input tensor's dimension should be 2 or 3"
"But received input tensor's dimension is %d",
x_rank));
PADDLE_ENFORCE_EQ(
x_shape[x_rank - 1], y_shape[y_rank - 2],
platform::errors::InvalidArgument(
"Invalid shape for matmul, the last dimension of first input and "
"the penultimate dimension for the second input should be same."
"But received %d and %d.",
x_shape[x_rank - 1], y_shape[y_rank - 2]));
if (x_rank == 2) {
std::vector<int64_t> z_shape{x_shape[x_rank - 2], y_shape[y_rank - 1]};
BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(z_shape);
} else {
PADDLE_ENFORCE_EQ(x_shape[0], y_shape[0],
platform::errors::InvalidArgument(
"Invalid shape for matmul when input tensor's "
"dimension is 3, the first dimension of first "
"input and the second input should be same."
"But received %d and %d.",
x_shape[0], y_shape[0]));
std::vector<int64_t> z_shape{x_shape[0], x_shape[x_rank - 2],
y_shape[y_rank - 1]};
BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(z_shape);
}
}
};
class MatmulPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type, y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type, y_type));
PADDLE_ENFORCE_EQ(x_dtype, y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype, y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, x_dtype);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(matmul_p, paddle::operators::MatmulPrimOp,
paddle::operators::MatmulPrimOpMaker,
paddle::operators::MatmulPrimOpShapeInference,
paddle::operators::MatmulPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class MulPrimOp : public framework::OperatorBase {
public:
MulPrimOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator mul_p should not be excuted directly"));
}
};
class MulPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of mul_p op.");
AddInput("Y", "(Tensor), The input tensor of mul_p op.");
AddOutput("Z", "(Tensor), The output tensor of mul_p op.");
AddComment(R"DOC(
Autograd primitive mul_p operator.
)DOC");
}
};
class MulPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank, y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank, y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i], y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i, x_shape[i], y_shape[i]));
}
BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class MulPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type, y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type, y_type));
PADDLE_ENFORCE_EQ(x_dtype, y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype, y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, x_dtype);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(mul_p, paddle::operators::MulPrimOp,
paddle::operators::MulPrimOpMaker,
paddle::operators::MulPrimOpShapeInference,
paddle::operators::MulPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
USE_OP_ITSELF(reshape_p);
USE_OP_ITSELF(broadcast_p);
USE_OP_ITSELF(reduce_p);
USE_OP_ITSELF(transpose_p);
USE_OP_ITSELF(split_p);
USE_OP_ITSELF(concat_p);
USE_OP_ITSELF(slice_select_p);
USE_OP_ITSELF(slice_assign_p);
USE_OP_ITSELF(gather_p);
USE_OP_ITSELF(scatter_add_p);
USE_OP_ITSELF(add_p);
USE_OP_ITSELF(sub_p);
USE_OP_ITSELF(mul_p);
USE_OP_ITSELF(div_p);
USE_OP_ITSELF(sqrt_p);
USE_OP_ITSELF(tanh_p);
USE_OP_ITSELF(matmul_p);
USE_OP_ITSELF(fill_constant_p);
namespace paddle {
namespace framework {
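// NewVar registers an FP32 LOD_TENSOR variable with the given shape in the
// block; with an empty shape (as AppendOp uses for op outputs) it only
// creates the variable and leaves type, dtype and shape to be filled in by
// inference.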
static void NewVar(BlockDesc *block, const std::string &name,
const std::vector<int64_t> &shape) {
auto *var_desc = block->Var(name);
if (shape.size() > 0) {
var_desc->SetShape(shape);
var_desc->SetType(proto::VarType::LOD_TENSOR);
var_desc->SetDataType(proto::VarType_Type_FP32);
}
}
static void AppendOp(BlockDesc *block, const std::string &type,
VariableNameMap inputs, VariableNameMap outputs,
AttributeMap attrs) {
auto &op_info = OpInfoMap::Instance().Get(type);
if (op_info.Checker()) {
op_info.Checker()->Check(&attrs);
}
auto *op = block->AppendOp();
op->SetType(type);
for (auto &pair : inputs) {
op->SetInput(pair.first, pair.second);
}
for (auto &pair : outputs) {
op->SetOutput(pair.first, pair.second);
for (auto &var_name : pair.second) {
if (!block->FindVarRecursive(var_name)) {
NewVar(block, var_name, {});
}
}
}
op->SetAttrMap(attrs);
op->InferVarType(block);
op->InferShape(*block);
}
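// Each test below builds a block in a static-graph ProgramDesc, appends a
// single prim op via AppendOp (which runs the attribute checker, var-type
// inference and shape inference), and then asserts the inferred type, dtype
// and shape of the outputs.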
TEST(PrimOp, reshape_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x0 = "x0";
std::string x1 = "x1";
NewVar(block, x0, shape);
AppendOp(block, "reshape_p", {{"X", {x0}}}, {{"Y", {x1}}},
{{"shape", std::vector<int64_t>{12, 5}}});
ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x1")->GetShape();
ASSERT_EQ(shapes.size(), 2UL);
ASSERT_EQ(shapes[0], 12L);
ASSERT_EQ(shapes[1], 5L);
}
TEST(PrimOp, broadcast_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 1};
std::string x0 = "x0";
std::string x1 = "x1";
NewVar(block, x0, shape);
AppendOp(block, "broadcast_p", {{"X", {x0}}}, {{"Y", {x1}}},
{{"shape", std::vector<int64_t>{3, 4, 5}}});
ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x1")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, reduce_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x0 = "x0";
std::string x1 = "x1";
std::string x2 = "x2";
NewVar(block, x0, shape);
AppendOp(block, "reduce_p", {{"X", {x0}}}, {{"Y", {x1}}},
{{"axis", std::vector<int64_t>{0, 2}}, {"keepdim", false}});
ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x1")->GetShape();
ASSERT_EQ(shapes.size(), 1UL);
ASSERT_EQ(shapes[0], 4L);
AppendOp(block, "reduce_p", {{"X", {x0}}}, {{"Y", {x2}}},
{{"axis", std::vector<int64_t>{0, 2}}, {"keepdim", true}});
ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32);
shapes = block->Var("x2")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 1L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 1L);
}
TEST(PrimOp, transpose_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x0 = "x0";
std::string x1 = "x1";
NewVar(block, x0, shape);
AppendOp(block, "transpose_p", {{"X", {x0}}}, {{"Y", {x1}}},
{{"axis", std::vector<int64_t>{2, 1, 0}}});
ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x1")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 5L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 3L);
}
TEST(PrimOp, split_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{6, 8, 10};
std::string x0 = "x0";
std::string x1 = "x1";
std::string x2 = "x2";
std::string x3 = "x3";
NewVar(block, x0, shape);
AppendOp(block, "split_p", {{"X", {x0}}}, {{"YS", {x1, x2, x3}}},
{{"axis", int64_t{1}},
{"num_or_sections", std::vector<int64_t>{2, 4, 2}}});
ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x1")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 6L);
ASSERT_EQ(shapes[1], 2L);
ASSERT_EQ(shapes[2], 10L);
ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32);
shapes = block->Var("x2")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 6L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 10L);
ASSERT_EQ(block->Var("x3")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x3")->GetDataType(), proto::VarType_Type_FP32);
shapes = block->Var("x3")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 6L);
ASSERT_EQ(shapes[1], 2L);
ASSERT_EQ(shapes[2], 10L);
std::string x4 = "x4";
std::string x5 = "x5";
AppendOp(
block, "split_p", {{"X", {x0}}}, {{"YS", {x4, x5}}},
{{"axis", int64_t{2}}, {"num_or_sections", std::vector<int64_t>{2}}});
ASSERT_EQ(block->Var("x4")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x4")->GetDataType(), proto::VarType_Type_FP32);
shapes = block->Var("x4")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 6L);
ASSERT_EQ(shapes[1], 8L);
ASSERT_EQ(shapes[2], 5L);
ASSERT_EQ(block->Var("x5")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x5")->GetDataType(), proto::VarType_Type_FP32);
shapes = block->Var("x5")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 6L);
ASSERT_EQ(shapes[1], 8L);
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, concat_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape_0{3, 1, 5};
std::vector<int64_t> shape_1{3, 4, 5};
std::vector<int64_t> shape_2{3, 6, 5};
std::string x0 = "x0";
std::string x1 = "x1";
std::string x2 = "x2";
std::string x3 = "x3";
NewVar(block, x0, shape_0);
NewVar(block, x1, shape_1);
NewVar(block, x2, shape_2);
AppendOp(block, "concat_p", {{"XS", {x0, x1, x2}}}, {{"Y", {x3}}},
{{"axis", int64_t{1}}});
ASSERT_EQ(block->Var("x3")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x3")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x3")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 11L);
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, slice_select_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{6, 8, 10};
std::string x0 = "x0";
std::string x1 = "x1";
NewVar(block, x0, shape);
AppendOp(block, "slice_select_p", {{"X", {x0}}}, {{"Y", {x1}}},
{{"axis", std::vector<int64_t>{0, 1, 2}},
{"starts", std::vector<int64_t>{0, 0, 0}},
{"ends", std::vector<int64_t>{5, 7, 9}},
{"strides", std::vector<int64_t>{2, 2, 2}}});
ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x1")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, slice_assign_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape_0{6, 8, 10};
std::vector<int64_t> shape_1{3, 4, 5};
std::string x0 = "x0";
std::string x1 = "x1";
std::string x2 = "x2";
NewVar(block, x0, shape_0);
NewVar(block, x1, shape_1);
AppendOp(block, "slice_assign_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}},
{{"axis", std::vector<int64_t>{0, 1, 2}},
{"starts", std::vector<int64_t>{0, 0, 0}},
{"ends", std::vector<int64_t>{5, 7, 9}},
{"strides", std::vector<int64_t>{2, 2, 2}}});
ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x2")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 6L);
ASSERT_EQ(shapes[1], 8L);
ASSERT_EQ(shapes[2], 10L);
}
TEST(PrimOp, gather_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{6, 8, 10};
std::string x0 = "x0";
std::string x1 = "x1";
NewVar(block, x0, shape);
AppendOp(block, "gather_p", {{"X", {x0}}}, {{"Y", {x1}}},
{{"axis", int64_t{1}}, {"index", std::vector<int64_t>{0, 2, 5}}});
ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x1")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 6L);
ASSERT_EQ(shapes[1], 3L);
ASSERT_EQ(shapes[2], 10L);
std::string index_t = "index_t";
std::string x2 = "x2";
auto *var_desc = block->Var(index_t);
var_desc->SetShape(std::vector<int64_t>{3});
var_desc->SetType(proto::VarType::LOD_TENSOR);
var_desc->SetDataType(proto::VarType_Type_INT32);
AppendOp(block, "gather_p", {{"X", {x0}}, {"IndexTensor", {index_t}}},
{{"Y", {x2}}}, {{"axis", int64_t{1}}});
ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32);
shapes = block->Var("x2")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 6L);
ASSERT_EQ(shapes[1], 3L);
ASSERT_EQ(shapes[2], 10L);
}
TEST(PrimOp, scatter_add_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape_0{6, 8, 10};
std::vector<int64_t> shape_1{6, 3, 10};
std::string x0 = "x0";
std::string x1 = "x1";
std::string x2 = "x2";
NewVar(block, x0, shape_0);
NewVar(block, x1, shape_1);
AppendOp(block, "scatter_add_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}},
{{"axis", int64_t{1}}, {"index", std::vector<int64_t>{0, 2, 5}}});
ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x2")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 6L);
ASSERT_EQ(shapes[1], 8L);
ASSERT_EQ(shapes[2], 10L);
std::string index_t = "index_t";
std::string x3 = "x3";
auto *var_desc = block->Var(index_t);
var_desc->SetShape(std::vector<int64_t>{3});
var_desc->SetType(proto::VarType::LOD_TENSOR);
var_desc->SetDataType(proto::VarType_Type_INT32);
AppendOp(block, "scatter_add_p",
{{"X", {x0}}, {"Y", {x1}}, {"IndexTensor", {index_t}}},
{{"Z", {x3}}}, {{"axis", int64_t{1}}});
ASSERT_EQ(block->Var("x3")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x3")->GetDataType(), proto::VarType_Type_FP32);
shapes = block->Var("x3")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 6L);
ASSERT_EQ(shapes[1], 8L);
ASSERT_EQ(shapes[2], 10L);
}
TEST(PrimOp, add_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x0 = "x0";
std::string x1 = "x1";
std::string x2 = "x2";
NewVar(block, x0, shape);
NewVar(block, x1, shape);
AppendOp(block, "add_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, {});
ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x2")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, sub_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x0 = "x0";
std::string x1 = "x1";
std::string x2 = "x2";
NewVar(block, x0, shape);
NewVar(block, x1, shape);
AppendOp(block, "sub_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, {});
ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x2")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, mul_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x0 = "x0";
std::string x1 = "x1";
std::string x2 = "x2";
NewVar(block, x0, shape);
NewVar(block, x1, shape);
AppendOp(block, "mul_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, {});
ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x2")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, div_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x0 = "x0";
std::string x1 = "x1";
std::string x2 = "x2";
NewVar(block, x0, shape);
NewVar(block, x1, shape);
AppendOp(block, "div_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, {});
ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x2")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, sqrt_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x0 = "x0";
std::string x1 = "x1";
NewVar(block, x0, shape);
AppendOp(block, "sqrt_p", {{"X", {x0}}}, {{"Y", {x1}}}, {});
ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x1")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, tanh_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape{3, 4, 5};
std::string x0 = "x0";
std::string x1 = "x1";
NewVar(block, x0, shape);
AppendOp(block, "tanh_p", {{"X", {x0}}}, {{"Y", {x1}}}, {});
ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x1")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
TEST(PrimOp, matmul_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::vector<int64_t> shape_0{3, 4, 5};
std::vector<int64_t> shape_1{3, 5, 8};
std::string x0 = "x0";
std::string x1 = "x1";
std::string x2 = "x2";
NewVar(block, x0, shape_0);
NewVar(block, x1, shape_1);
AppendOp(block, "matmul_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, {});
ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x2")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 8L);
std::vector<int64_t> shape_2{4, 5};
std::vector<int64_t> shape_3{5, 8};
std::string x3 = "x3";
std::string x4 = "x4";
std::string x5 = "x5";
NewVar(block, x3, shape_2);
NewVar(block, x4, shape_3);
AppendOp(block, "matmul_p", {{"X", {x3}}, {"Y", {x4}}}, {{"Z", {x5}}}, {});
ASSERT_EQ(block->Var("x5")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x5")->GetDataType(), proto::VarType_Type_FP32);
shapes = block->Var("x5")->GetShape();
ASSERT_EQ(shapes.size(), 2UL);
ASSERT_EQ(shapes[0], 4L);
ASSERT_EQ(shapes[1], 8L);
}
TEST(PrimOp, fill_constant_p) {
ProgramDesc program;
auto *block = program.MutableBlock(0);
std::string x0 = "x0";
AppendOp(block, "fill_constant_p", {{}}, {{"Y", {x0}}},
{{"value", 0.0f},
{"dtype", proto::VarType_Type_FP32},
{"shape", std::vector<int64_t>{3, 4, 5}}});
ASSERT_EQ(block->Var("x0")->GetType(), proto::VarType::LOD_TENSOR);
ASSERT_EQ(block->Var("x0")->GetDataType(), proto::VarType_Type_FP32);
auto shapes = block->Var("x0")->GetShape();
ASSERT_EQ(shapes.size(), 3UL);
ASSERT_EQ(shapes[0], 3L);
ASSERT_EQ(shapes[1], 4L);
ASSERT_EQ(shapes[2], 5L);
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class ReducePrimOp : public framework::OperatorBase {
public:
ReducePrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator reduce_p should not be excuted directly"));
}
};
class ReducePrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of reduce_p op.");
AddOutput("Y", "(Tensor), The output tensor of reduce_p op.");
AddAttr<std::vector<int64_t>>(
"axis",
"(std::vector<int64_t>) The axis along which to reduce on. Must be in "
"range [-rank(input), rank(input)]. If `axis[i] < 0`, the axis[i] to "
"reduce is `rank + axis[i]`.");
AddAttr<bool>("keepdim",
"(bool, default false) "
"If true, retain the reduced axis with length 1.")
.SetDefault(false);
AddComment(R"DOC(
Autograd primitive reduce_p operator.
)DOC");
}
};
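// With keepdim=true the reduced axes are retained with extent 1; otherwise
// they are marked with a sentinel and erased. Reducing away every axis leaves
// shape {1} rather than an empty shape. E.g. reducing {3, 4, 5} over axes
// {0, 2} yields {4} (keepdim=false) or {1, 4, 1} (keepdim=true).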
class ReducePrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
auto x_shape = x_var->GetShape();
auto axis = ctx->Attrs().Get<std::vector<int64_t>>("axis");
auto keepdim = ctx->Attrs().Get<bool>("keepdim");
if (keepdim) {
for (size_t i = 0; i < axis.size(); ++i) {
x_shape[axis[i]] = 1;
}
} else {
const int kDelFlag = -2;
for (size_t i = 0; i < axis.size(); ++i) {
x_shape[axis[i]] = kDelFlag;
}
x_shape.erase(remove(x_shape.begin(), x_shape.end(), kDelFlag),
x_shape.end());
}
if (!keepdim && x_shape.size() == 0) {
x_shape.push_back(1);
}
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_shape);
}
};
class ReducePrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(reduce_p, paddle::operators::ReducePrimOp,
paddle::operators::ReducePrimOpMaker,
paddle::operators::ReducePrimOpShapeInference,
paddle::operators::ReducePrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class ReshapePrimOp : public framework::OperatorBase {
public:
ReshapePrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator reshape_p should not be excuted directly"));
}
};
class ReshapePrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of reshape_p op.");
AddOutput("Y", "(Tensor), The output tensor of reshape_p op.");
AddAttr<std::vector<int64_t>>(
"shape", "(std::vector<int64_t>) Target shape of reshape_p operator.");
AddComment(R"DOC(
Autograd primitive reshape_p operator.
)DOC");
}
};
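// reshape_p only reinterprets the layout, so the element count of the input
// shape must equal that of the target shape; product() below computes that
// count.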
static int64_t product(const std::vector<int64_t> &shape) {
int64_t rslt = 1;
for (size_t i = 0; i < shape.size(); ++i) {
rslt *= shape[i];
}
return rslt;
}
class ReshapePrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
auto x_shape = x_var->GetShape();
auto shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
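    // Element-count check sketch: x_shape = [2, 3, 4] holds 24 elements, so a
    // target shape of [6, 4] passes while [5, 5] raises InvalidArgument.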
PADDLE_ENFORCE_EQ(product(x_shape), product(shape),
platform::errors::InvalidArgument(
"The input tensor can't be reshaped to target shape, "
"the input tensor has %d elements but target shape "
"contains %d elements",
product(x_shape), product(shape)));
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(shape);
}
};
class ReshapePrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(reshape_p, paddle::operators::ReshapePrimOp,
paddle::operators::ReshapePrimOpMaker,
paddle::operators::ReshapePrimOpShapeInference,
paddle::operators::ReshapePrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class ScatterAddPrimOp : public framework::OperatorBase {
public:
ScatterAddPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator scatter_add_p should not be excuted directly"));
}
};
class ScatterAddPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The tensor to apply scatter rule and add on.");
AddInput("Y", "(Tensor), The source tensor of scatter_add_p op.");
AddInput(
"IndexTensor",
"(Tensor), The index tensor of scatter_add_p op, which is a 1D tensor.")
.AsDispensable();
AddOutput("Z", "(Tensor), The output tensor of scatter_add_p op.");
AddAttr<int64_t>("axis",
"(int64_t), The axis along which to scatter and add.");
AddAttr<std::vector<int64_t>>(
"index", "(std::vector<int64_t>) The index of scatter_add_p op")
.SetDefault({0});
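    // Note (mirrors the inference rules below): when IndexTensor is provided
    // it supplies the indices and must be a 1-D int32 tensor; otherwise the
    // "index" attribute is used.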
AddComment(R"DOC(
Autograd primitive scatter_add_p operator.
)DOC");
}
};
class ScatterAddPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
int64_t num_index = 0;
if (ctx->HasInput("IndexTensor")) {
framework::InferShapeVarPtr index_var_ptr =
ctx->GetInputVarPtrs("IndexTensor")[0];
framework::VarDesc *index_var =
BOOST_GET(framework::VarDesc *, index_var_ptr);
auto index_shape = index_var->GetShape();
PADDLE_ENFORCE_EQ(index_shape.size(), 1,
platform::errors::InvalidArgument(
"The index tensor should be a 1D tensor,"
"but get rank %d",
index_shape.size()));
num_index = index_shape[0];
} else {
num_index = ctx->Attrs().Get<std::vector<int64_t>>("index").size();
}
auto axis = ctx->Attrs().Get<int64_t>("axis");
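    // Worked example of the checks below: x_shape = [4, 5], axis = 1 and
    // num_index = 3 require y_shape = [4, 3]; the output Z keeps X's shape.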
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank, y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank, y_rank));
PADDLE_ENFORCE_EQ(y_shape[axis], num_index,
platform::errors::InvalidArgument(
"The shape of source input tensor at scatter axis "
"should be equal to num_index, "
"but get %d and %d",
y_shape[axis], num_index));
for (size_t i = 0; i < x_rank; ++i) {
if (i != size_t(axis)) {
PADDLE_ENFORCE_EQ(
x_shape[i], y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i, x_rank, y_rank));
}
}
BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class ScatterAddPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type, y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type, y_type));
PADDLE_ENFORCE_EQ(x_dtype, y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype, y_dtype));
if (ctx->HasInput("IndexTensor")) {
auto index_name = Input(ctx, "IndexTensor")[0];
auto index_dtype = GetDataType(ctx, index_name);
PADDLE_ENFORCE_EQ(
index_dtype, framework::proto::VarType_Type_INT32,
platform::errors::InvalidArgument(
"The datatype of input tensor should be VarType_Type_INT32(%d), "
"but get %d",
framework::proto::VarType_Type_INT32, index_dtype));
}
SetType(ctx, z_name, GetType(ctx, x_name));
SetDataType(ctx, z_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(scatter_add_p, paddle::operators::ScatterAddPrimOp,
paddle::operators::ScatterAddPrimOpMaker,
paddle::operators::ScatterAddPrimOpShapeInference,
paddle::operators::ScatterAddPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class SliceAssignPrimOp : public framework::OperatorBase {
public:
SliceAssignPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator slice_assign_p should not be excuted directly"));
}
};
class SliceAssignPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The tensor to slice from and assign on.");
AddInput("Y", "(Tensor), The source tensor of slice_assign_p op.");
AddOutput("Z", "(Tensor), The output tensor of slice_assign_p op.");
AddAttr<std::vector<int64_t>>(
"axis", "(std::vector<int64_t>), The axis along which to gather.");
AddAttr<std::vector<int64_t>>(
"starts",
"(std::vector<int64_t>) The slice starts of slice_assign_p op");
AddAttr<std::vector<int64_t>>(
"ends", "(std::vector<int64_t>) The slice ends of slice_assign_p op");
AddAttr<std::vector<int64_t>>(
"strides",
"(std::vector<int64_t>) The slice strides of slice_assign_p op");
AddComment(R"DOC(
Autograd primitive slice_assign_p operator.
)DOC");
}
};
class SliceAssignPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
auto axis = ctx->Attrs().Get<std::vector<int64_t>>("axis");
auto starts = ctx->Attrs().Get<std::vector<int64_t>>("starts");
auto ends = ctx->Attrs().Get<std::vector<int64_t>>("ends");
auto strides = ctx->Attrs().Get<std::vector<int64_t>>("strides");
PADDLE_ENFORCE_EQ(
starts.size(), axis.size(),
platform::errors::InvalidArgument(
"Number of starts attribute and axis attribute should be same, "
"but get %d and %d",
starts.size(), axis.size()));
PADDLE_ENFORCE_EQ(
ends.size(), axis.size(),
platform::errors::InvalidArgument(
"Number of ends attribute and axis attribute should be same, "
"but get %d and %d",
ends.size(), axis.size()));
PADDLE_ENFORCE_EQ(
strides.size(), axis.size(),
platform::errors::InvalidArgument(
"Number of strides attribute and axis attribute should be same, "
"but get %d and %d",
strides.size(), axis.size()));
PADDLE_ENFORCE_EQ(x_rank, y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank, y_rank));
std::vector<int64_t> y_target_shape(x_shape);
for (size_t i = 0; i < axis.size(); ++i) {
y_target_shape[axis[i]] =
(ends[i] - starts[i] + strides[i] - 1) / strides[i];
}
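    // The loop above is a ceil-division: x_shape = [10], axis = [0],
    // starts = [0], ends = [10], strides = [3] gives y_target_shape = [4]
    // (indices 0, 3, 6, 9), so the source tensor Y must have shape [4].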
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(y_target_shape[i], y_shape[i],
platform::errors::InvalidArgument(
"The shape of source tensor of slice_assign_p op "
"at dimension %d should be %d, "
"but get %d",
i, y_target_shape[i], y_shape[i]));
}
BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class SliceAssignPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type, y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type, y_type));
PADDLE_ENFORCE_EQ(x_dtype, y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype, y_dtype));
SetType(ctx, z_name, GetType(ctx, x_name));
SetDataType(ctx, z_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(slice_assign_p, paddle::operators::SliceAssignPrimOp,
paddle::operators::SliceAssignPrimOpMaker,
paddle::operators::SliceAssignPrimOpShapeInference,
paddle::operators::SliceAssignPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class SliceSelectPrimOp : public framework::OperatorBase {
public:
SliceSelectPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator slice_select_p should not be excuted directly"));
}
};
class SliceSelectPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of slice_select_p op.");
AddOutput("Y", "(Tensor), The output tensor of slice_select_p op.");
AddAttr<std::vector<int64_t>>(
"axis", "(std::vector<int64_t>), The axis along which to gather.");
AddAttr<std::vector<int64_t>>(
"starts",
"(std::vector<int64_t>) The slice starts of slice_select_p op");
AddAttr<std::vector<int64_t>>(
"ends", "(std::vector<int64_t>) The slice ends of slice_select_p op");
AddAttr<std::vector<int64_t>>(
"strides",
"(std::vector<int64_t>) The slice strides of slice_select_p op");
AddComment(R"DOC(
Autograd primitive slice_select_p operator.
)DOC");
}
};
class SliceSelectPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
auto x_shape = x_var->GetShape();
auto axis = ctx->Attrs().Get<std::vector<int64_t>>("axis");
auto starts = ctx->Attrs().Get<std::vector<int64_t>>("starts");
auto ends = ctx->Attrs().Get<std::vector<int64_t>>("ends");
auto strides = ctx->Attrs().Get<std::vector<int64_t>>("strides");
PADDLE_ENFORCE_EQ(
starts.size(), axis.size(),
platform::errors::InvalidArgument(
"Number of starts attribute and axis attribute should be same, "
"but get %d and %d",
starts.size(), axis.size()));
PADDLE_ENFORCE_EQ(
ends.size(), axis.size(),
platform::errors::InvalidArgument(
"Number of ends attribute and axis attribute should be same, "
"but get %d and %d",
ends.size(), axis.size()));
PADDLE_ENFORCE_EQ(
strides.size(), axis.size(),
platform::errors::InvalidArgument(
"Number of strides attribute and axis attribute should be same, "
"but get %d and %d",
strides.size(), axis.size()));
for (size_t i = 0; i < axis.size(); ++i) {
x_shape[axis[i]] = (ends[i] - starts[i] + strides[i] - 1) / strides[i];
}
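    // Same ceil-division rule as slice_assign_p: x_shape = [10], axis = [0],
    // starts = [2], ends = [10], strides = [2] selects indices 2, 4, 6, 8,
    // hence y_shape = [4].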
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_shape);
}
};
class SliceSelectPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(slice_select_p, paddle::operators::SliceSelectPrimOp,
paddle::operators::SliceSelectPrimOpMaker,
paddle::operators::SliceSelectPrimOpShapeInference,
paddle::operators::SliceSelectPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class SplitPrimOp : public framework::OperatorBase {
public:
SplitPrimOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator split_p should not be excuted directly"));
}
};
class SplitPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of split_p op.");
AddOutput("YS", "(Tensor), The output tensors of split_p op.")
.AsDuplicable();
AddAttr<int64_t>("axis", "(int64_t), The axis along which to split.");
AddAttr<std::vector<int64_t>>(
"num_or_sections",
"(std::vector<int64_t>) If num_or_sections has only one element, then "
"num_or_sections indicates the number of equal sized sub-Tensors that "
"the input will be divided into. If num_or_sections has more then one "
"element, the length of it indicates the number of sub-Tensors and the "
"elements in it indicate the sizes of sub-Tensors’ dimension orderly. "
"The length of the vector must not be larger than the input's size of "
"specified axis.");
AddComment(R"DOC(
Autograd primitive split_p operator.
)DOC");
}
};
class SplitPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
auto y_var_ptrs = ctx->GetOutputVarPtrs("YS");
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
auto x_shape = x_var->GetShape();
auto axis = ctx->Attrs().Get<int64_t>("axis");
auto num_or_sections =
ctx->Attrs().Get<std::vector<int64_t>>("num_or_sections");
std::vector<int64_t> y_shape(x_shape);
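    // Two modes, sketched: num_or_sections = {4} splits x_shape = [8, 6]
    // along axis 0 into four outputs of shape [2, 6]; num_or_sections =
    // {2, 6} produces outputs of shape [2, 6] and [6, 6].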
if (num_or_sections.size() == 1) {
PADDLE_ENFORCE_EQ(x_shape[axis] % num_or_sections[0], 0,
platform::errors::InvalidArgument(
"The input tensor can't be devided equally into %d "
"parts equally along axis %d",
num_or_sections[0], axis));
y_shape[axis] = x_shape[axis] / num_or_sections[0];
for (size_t i = 0; i < size_t(num_or_sections[0]); ++i) {
BOOST_GET(framework::VarDesc *, y_var_ptrs[i])->SetShape(y_shape);
}
} else {
int64_t cnt_along_axis = 0;
for (size_t i = 0; i < num_or_sections.size(); ++i) {
y_shape[axis] = num_or_sections[i];
cnt_along_axis += num_or_sections[i];
BOOST_GET(framework::VarDesc *, y_var_ptrs[i])->SetShape(y_shape);
}
PADDLE_ENFORCE_EQ(
x_shape[axis], cnt_along_axis,
platform::errors::InvalidArgument(
"The input tensor has %d elements along axis %d, thus can't be "
"devided into %d tensor with %d elements totally.",
x_shape[axis], axis, num_or_sections.size(), cnt_along_axis));
}
}
};
class SplitPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_names = Output(ctx, "YS");
for (auto y_name : y_names) {
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(split_p, paddle::operators::SplitPrimOp,
paddle::operators::SplitPrimOpMaker,
paddle::operators::SplitPrimOpShapeInference,
paddle::operators::SplitPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class SqrtPrimOp : public framework::OperatorBase {
public:
SqrtPrimOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator sqrt_p should not be excuted directly"));
}
};
class SqrtPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of sqrt_p op.");
AddOutput("Y", "(Tensor), The output tensor of sqrt_p op.");
AddComment(R"DOC(
Autograd primitive sqrt_p operator.
)DOC");
}
};
class SqrtPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};
class SqrtPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(sqrt_p, paddle::operators::SqrtPrimOp,
paddle::operators::SqrtPrimOpMaker,
paddle::operators::SqrtPrimOpShapeInference,
paddle::operators::SqrtPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class SubPrimOp : public framework::OperatorBase {
public:
SubPrimOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator sub_p should not be excuted directly"));
}
};
class SubPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of sub_p op.");
AddInput("Y", "(Tensor), The input tensor of sub_p op.");
AddOutput("Z", "(Tensor), The output tensor of sub_p op.");
AddComment(R"DOC(
Autograd primitive sub_p operator.
)DOC");
}
};
class SubPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
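    // sub_p is strictly elementwise: there is no implicit broadcasting here
    // (the separate broadcast_p op presumably makes broadcasts explicit), so
    // both inputs must have identical shapes, e.g. [3, 4] - [3, 4] -> [3, 4].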
PADDLE_ENFORCE_EQ(x_rank, y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank, y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i], y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i, x_shape[i], y_shape[i]));
}
BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};
class SubPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type, y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type, y_type));
PADDLE_ENFORCE_EQ(x_dtype, y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype, y_dtype));
SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, x_dtype);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(sub_p, paddle::operators::SubPrimOp,
paddle::operators::SubPrimOpMaker,
paddle::operators::SubPrimOpShapeInference,
paddle::operators::SubPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class TanhPrimOp : public framework::OperatorBase {
public:
TanhPrimOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator tanh_p should not be excuted directly"));
}
};
class TanhPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of tanh_p op.");
AddOutput("Y", "(Tensor), The output tensor of tanh_p op.");
AddComment(R"DOC(
Autograd primitive tanh_p operator.
)DOC");
}
};
class TanhPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
}
};
class TanhPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(tanh_p, paddle::operators::TanhPrimOp,
paddle::operators::TanhPrimOpMaker,
paddle::operators::TanhPrimOpShapeInference,
paddle::operators::TanhPrimOpVarTypeInference);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class TransposePrimOp : public framework::OperatorBase {
public:
TransposePrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator transpose_p should not be excuted directly"));
}
};
class TransposePrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of transpose_p op.");
AddOutput("Y", "(Tensor), The output tensor of transpose_p op.");
AddAttr<std::vector<int64_t>>("axis",
"(std::vector<int64_t>) Tanspose axis.");
AddComment(R"DOC(
Autograd primitive transpose_p operator.
)DOC");
}
};
class TransposePrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
auto x_shape = x_var->GetShape();
auto axis = ctx->Attrs().Get<std::vector<int64_t>>("axis");
size_t x_rank = x_shape.size();
size_t axis_size = axis.size();
PADDLE_ENFORCE_EQ(x_rank, axis_size,
platform::errors::InvalidArgument(
"The input tensor's dimension "
"should be equal to the axis's size. "
"But received input tensor's dimension is %d, "
"axis's size is %d",
x_rank, axis_size));
std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis_size; i++) {
PADDLE_ENFORCE_GE(axis[i], 0,
platform::errors::InvalidArgument(
"The axis should be greater than or equal to 0."
"But received %d of axis[%d]",
axis[i], i));
PADDLE_ENFORCE_EQ(
axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1, true,
platform::errors::InvalidArgument(
"Each element of Attribute axis should "
"be a unique value range from 0 to (dims - 1), "
"where the dims is the axis's size, "
"unique value means this axis value can appear only once. "
"But received axis[%d] is %d, axis_size is %d, "
"count[axis[%d]] is %d",
i, axis[i], axis_size, i, count[axis[i]]));
}
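    // Permutation example of the rule below: x_shape = [2, 3, 4] with
    // axis = [2, 0, 1] yields y_shape = [4, 2, 3].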
std::vector<int64_t> y_shape(axis_size);
for (size_t i = 0; i < axis_size; i++) {
y_shape[i] = x_shape[axis[i]];
}
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(y_shape);
}
};
class TransposePrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(transpose_p, paddle::operators::TransposePrimOp,
paddle::operators::TransposePrimOpMaker,
paddle::operators::TransposePrimOpShapeInference,
paddle::operators::TransposePrimOpVarTypeInference);
register_unity_group(cc
reshape_p_op.cc
broadcast_p_op.cc
reduce_p_op.cc
transpose_p_op.cc
split_p_op.cc
concat_p_op.cc
slice_select_p_op.cc
slice_assign_p_op.cc
gather_p_op.cc
scatter_add_p_op.cc
add_p_op.cc
sub_p_op.cc
mul_p_op.cc
div_p_op.cc
sqrt_p_op.cc
tanh_p_op.cc
matmul_p_op.cc
fill_constant_p_op.cc
)