未验证 提交 79f8eeca 编写于 作者: Z zyfncg 提交者: GitHub

[Pten] Add selected_rows kernel for Full (#39465)

* Add selected_rows kernel for full

* remove fill_constant register in fluid

* fix bug without GPU

* add jit_kernel_helper dependency for fc

* do some refactor

* add unittest for ops signatures

* add coverage unittest

* fix merge conflict

* fix full selectew_rows bug
上级 eec6ef81
...@@ -288,7 +288,7 @@ function(append_op_util_declare TARGET) ...@@ -288,7 +288,7 @@ function(append_op_util_declare TARGET)
string(REGEX MATCH "(PT_REGISTER_BASE_KERNEL_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}") string(REGEX MATCH "(PT_REGISTER_BASE_KERNEL_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}")
string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}") string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}")
string(REPLACE "PT_REGISTER_BASE_KERNEL_NAME" "PT_DECLARE_BASE_KERNEL_NAME" util_declare "${util_declare}") string(REPLACE "PT_REGISTER_BASE_KERNEL_NAME" "PT_DECLARE_BASE_KERNEL_NAME" util_declare "${util_declare}")
string(APPEND util_declare ");") string(APPEND util_declare ");\n")
file(APPEND ${op_utils_header} "${util_declare}") file(APPEND ${op_utils_header} "${util_declare}")
endfunction() endfunction()
......
...@@ -26,7 +26,7 @@ limitations under the License. */ ...@@ -26,7 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP(fill_constant); USE_OP_ITSELF(fill_constant);
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "paddle/fluid/framework/new_executor/standalone_executor.h" #include "paddle/fluid/framework/new_executor/standalone_executor.h"
USE_OP(fill_constant); USE_OP_ITSELF(fill_constant);
USE_OP(uniform_random); USE_OP(uniform_random);
USE_OP(lookup_table); USE_OP(lookup_table);
USE_OP(transpose2); USE_OP(transpose2);
......
...@@ -178,16 +178,6 @@ REGISTER_OPERATOR( ...@@ -178,16 +178,6 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int16_t>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::bfloat16>,
ops::FillConstantKernel<paddle::platform::complex<float>>,
ops::FillConstantKernel<paddle::platform::complex<double>>);
REGISTER_OP_VERSION(fill_constant) REGISTER_OP_VERSION(fill_constant)
.AddCheckpoint( .AddCheckpoint(
R"ROC( R"ROC(
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_constant_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int16_t>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::complex<float>>,
ops::FillConstantKernel<paddle::platform::complex<double>>);
...@@ -39,7 +39,7 @@ if (WITH_ASCEND_CL) ...@@ -39,7 +39,7 @@ if (WITH_ASCEND_CL)
else() else()
math_library(beam_search DEPS math_function) math_library(beam_search DEPS math_function)
endif() endif()
math_library(fc DEPS blas) math_library(fc DEPS blas jit_kernel_helper)
math_library(matrix_bit_code) math_library(matrix_bit_code)
math_library(unpooling) math_library(unpooling)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/pten/common/scalar.h" #include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/selected_rows.h"
#include "paddle/pten/infermeta/nullary.h" #include "paddle/pten/infermeta/nullary.h"
#include "paddle/pten/kernels/empty_kernel.h" #include "paddle/pten/kernels/empty_kernel.h"
...@@ -30,6 +31,13 @@ void FullKernel(const Context& dev_ctx, ...@@ -30,6 +31,13 @@ void FullKernel(const Context& dev_ctx,
DataType dtype, DataType dtype,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context>
void FullSR(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DataType dtype,
SelectedRows* out);
template <typename T, typename Context> template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx, void FullLikeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/full_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/pten/backends/gpu/gpu_context.h"
#endif
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/common/bfloat16.h"
#include "paddle/pten/common/complex.h"
namespace pten {
template <typename T, typename Context>
void FullSR(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DataType dtype,
SelectedRows* out) {
pten::FullKernel<T>(dev_ctx, shape, val, dtype, out->mutable_value());
}
} // namespace pten
PT_REGISTER_KERNEL(full_sr,
CPU,
ALL_LAYOUT,
pten::FullSR,
float,
double,
uint8_t,
int16_t,
int,
int64_t,
bool,
pten::dtype::float16,
pten::dtype::bfloat16,
pten::dtype::complex<float>,
pten::dtype::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_KERNEL(full_sr,
GPU,
ALL_LAYOUT,
pten::FullSR,
float,
double,
uint8_t,
int16_t,
int,
int64_t,
bool,
pten::dtype::float16,
pten::dtype::complex<float>,
pten::dtype::complex<double>) {}
#endif
...@@ -66,6 +66,57 @@ KernelSignature FillConstantOpArgumentMapping( ...@@ -66,6 +66,57 @@ KernelSignature FillConstantOpArgumentMapping(
} }
} }
} }
} else if (ctx.IsSelectedRowsOutput("Out")) {
if (ctx.HasInput("ShapeTensor")) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature(
"full_sr", {}, {"ShapeTensor", "ValueTensor", "dtype"}, {"Out"});
} else {
const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value"));
if (str_value.empty()) {
return KernelSignature(
"full_sr", {}, {"ShapeTensor", "value", "dtype"}, {"Out"});
} else {
return KernelSignature(
"full_sr", {}, {"ShapeTensor", "str_value", "dtype"}, {"Out"});
}
}
} else if (ctx.InputSize("ShapeTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature("full_sr",
{},
{"ShapeTensorList", "ValueTensor", "dtype"},
{"Out"});
} else {
const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value"));
if (str_value.empty()) {
return KernelSignature(
"full_sr", {}, {"ShapeTensorList", "value", "dtype"}, {"Out"});
} else {
return KernelSignature("full_sr",
{},
{"ShapeTensorList", "str_value", "dtype"},
{"Out"});
}
}
} else {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature(
"full_sr", {}, {"shape", "ValueTensor", "dtype"}, {"Out"});
} else {
const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value"));
if (str_value.empty()) {
return KernelSignature(
"full_sr", {}, {"shape", "value", "dtype"}, {"Out"});
} else {
return KernelSignature(
"full_sr", {}, {"shape", "str_value", "dtype"}, {"Out"});
}
}
}
} }
return KernelSignature("unregistered", {}, {}, {}); return KernelSignature("unregistered", {}, {}, {});
} }
......
...@@ -2,3 +2,4 @@ add_subdirectory(api) ...@@ -2,3 +2,4 @@ add_subdirectory(api)
add_subdirectory(common) add_subdirectory(common)
add_subdirectory(core) add_subdirectory(core)
add_subdirectory(kernels) add_subdirectory(kernels)
add_subdirectory(ops_signature)
cc_test(test_op_signature SRCS test_op_signature.cc DEPS op_utils)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/tests/ops_signature/test_op_signature.h"
#include <gtest/gtest.h>
#include <memory>
#include <unordered_set>
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/ops/compat/signatures.h"
namespace pten {
namespace tests {
// The unittests in this file are just order to pass the CI-Coverage,
// so it isn't necessary to check the all cases.
TEST(ARG_MAP, fill_constant) {
TestArgumentMappingContext arg_case1(
{"ShapeTensor", "ValueTensor"}, {}, {}, {}, {"Out"});
auto signature1 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case1);
ASSERT_EQ(signature1.name, "full_sr");
TestArgumentMappingContext arg_case2(
{"ShapeTensor"},
{},
{{"str_value", paddle::any{std::string{"10"}}}},
{},
{"Out"});
auto signature2 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case2);
ASSERT_EQ(signature2.name, "full_sr");
TestArgumentMappingContext arg_case3(
{"ShapeTensor"},
{},
{{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}},
{},
{"Out"});
auto signature3 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case3);
ASSERT_EQ(signature3.name, "full_sr");
TestArgumentMappingContext arg_case4(
{"ShapeTensorList", "ValueTensor"}, {}, {}, {}, {"Out"});
auto signature4 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case4);
ASSERT_EQ(signature4.name, "full_sr");
TestArgumentMappingContext arg_case5(
{"ShapeTensorList"},
{},
{{"str_value", paddle::any{std::string{"10"}}}},
{},
{"Out"});
auto signature5 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case5);
ASSERT_EQ(signature5.name, "full_sr");
TestArgumentMappingContext arg_case6(
{"ShapeTensorList"},
{},
{{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}},
{},
{"Out"});
auto signature6 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case6);
ASSERT_EQ(signature6.name, "full_sr");
TestArgumentMappingContext arg_case7(
{"ValueTensor"},
{},
{{"shape", paddle::any{std::vector<int64_t>{2, 3}}}},
{},
{"Out"});
auto signature7 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case7);
ASSERT_EQ(signature7.name, "full_sr");
TestArgumentMappingContext arg_case8(
{},
{},
{{"shape", paddle::any{std::vector<int64_t>{2, 3}}},
{"value", paddle::any{0}},
{"str_value", paddle::any{std::string{""}}}},
{},
{"Out"});
auto signature8 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case8);
ASSERT_EQ(signature8.name, "full_sr");
TestArgumentMappingContext arg_case9(
{},
{},
{{"shape", paddle::any{std::vector<int64_t>{2, 3}}},
{"str_value", paddle::any{std::string{"10"}}}},
{},
{"Out"});
auto signature9 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case9);
ASSERT_EQ(signature9.name, "full_sr");
}
} // namespace tests
} // namespace pten
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <gtest/gtest.h>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
namespace tests {
class TestArgumentMappingContext : public pten::ArgumentMappingContext {
public:
TestArgumentMappingContext(
std::unordered_set<std::string> dense_tensor_ins,
std::unordered_set<std::string> sr_ins,
std::unordered_map<std::string, paddle::any> op_attrs,
std::unordered_set<std::string> dense_tensor_outs,
std::unordered_set<std::string> sr_outs = {})
: dense_tensor_inputs(dense_tensor_ins),
selected_rows_inputs(sr_ins),
attrs(op_attrs),
dense_tensor_outputs(dense_tensor_outs),
selected_rows_outputs(sr_outs) {}
bool HasInput(const std::string& name) const override {
return dense_tensor_inputs.count(name) > 0 ||
selected_rows_inputs.count(name) > 0;
}
bool HasOutput(const std::string& name) const override {
return dense_tensor_outputs.count(name) > 0 ||
selected_rows_outputs.count(name) > 0;
}
bool HasAttr(const std::string& name) const override {
return attrs.count(name) > 0;
}
paddle::any Attr(const std::string& name) const override {
return attrs.at(name);
}
size_t InputSize(const std::string& name) const override {
return dense_tensor_inputs.size() + selected_rows_inputs.size();
}
size_t OutputSize(const std::string& name) const override {
return dense_tensor_outputs.size() + selected_rows_outputs.size();
}
bool IsDenseTensorInput(const std::string& name) const override {
return dense_tensor_inputs.count(name) > 0;
}
bool IsSelectedRowsInput(const std::string& name) const override {
return selected_rows_inputs.count(name) > 0;
}
bool IsDenseTensorOutput(const std::string& name) const override {
return dense_tensor_outputs.count(name) > 0;
}
bool IsSelectedRowsOutput(const std::string& name) const override {
return selected_rows_outputs.count(name) > 0;
}
private:
const std::unordered_set<std::string> dense_tensor_inputs;
const std::unordered_set<std::string> selected_rows_inputs;
const std::unordered_map<std::string, paddle::any> attrs;
const std::unordered_set<std::string> dense_tensor_outputs;
const std::unordered_set<std::string> selected_rows_outputs;
};
} // namespace tests
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册