From 5caa6fc517d5e13c59d142439ddf8113a9108d0e Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Mon, 22 Nov 2021 14:26:33 +0800 Subject: [PATCH] [PTen] Add variable transform to/from ptenTensor and add cast kernel (#36916) * add cast kernel * add cast cuda kernel * add cast kernel * make cast kernel output dtype undefined * get cast dtype from vardesc * move cast to manipulation and add test case * add castinfershape * avoid reinitilaze variable * InitializeVariable support datatype * merge develop branch * fix merge bug * revert modify initializeVariable * revert modify on InitializeVariable * revert modify on InitializeVariable * mutable support reset dtype * enable make pten tensor from variable when def_arg.type is undefined * fix build pten ctx start_idx error * copy pten out tensor to variable * merge develop branch * fix non pten kernel cast failed * add reset allocation place for remake tensor * fix inplace realloc error * add mutable on pten kernles and remove unused cast files * rename function names * fix output type error * fix conflict with develop branch * set data type to variable with pten's dtype * fix test_cast_api type mismatch * densorTensro mutable_data support 0 bytes value * fix the inplace bug of reshape kernel * fix pten.backend != variable.place when moving storage, palce mismatch bug * fix conflict with develop branch * Fix bug of paddle::experimental::MovesStorage * fix ReMakePtenDenseTensor place mismatch bug * Revert "fix ReMakePtenDenseTensor place mismatch bug" This reverts commit 86336032f60b8a15eacd2c1ff2fa513f5d8dfd1a. * fix ReMakePtenDenseTensor place mismatch bug * reverts the set_lod interface, test=develop * modify by the review options * modify error message * add & for const input arguments * add reference in params * elementwise_sub add mutable_data * fix ResetHolderWithType check size bug * add dependence pten_tensor to test_cast_api object * remove unused code to pass ci coverage Co-authored-by: Chen Weihang Co-authored-by: YuanRisheng Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> --- paddle/fluid/framework/operator.cc | 121 ++++++++++++--- paddle/fluid/framework/operator.h | 2 + paddle/fluid/framework/tensor.cc | 6 +- paddle/fluid/framework/tensor.h | 4 +- paddle/fluid/imperative/prepared_operator.cc | 135 ++++++++++++----- paddle/fluid/operators/cast_op.h | 29 +++- paddle/fluid/operators/reshape_op.cc | 11 +- paddle/pten/api/ext/dispatch.h | 46 ++++++ paddle/pten/api/include/manipulation.h | 2 + paddle/pten/api/lib/manipulation.cc | 34 +++++ paddle/pten/api/lib/utils/storage.h | 13 +- paddle/pten/api/lib/utils/tensor_utils.cc | 140 +++++++++++++++--- paddle/pten/api/lib/utils/tensor_utils.h | 5 + paddle/pten/core/compat_utils.h | 6 +- paddle/pten/core/convert_utils.cc | 20 +++ paddle/pten/core/convert_utils.h | 5 + paddle/pten/core/dense_tensor.cc | 11 +- paddle/pten/core/dense_tensor.h | 6 +- paddle/pten/core/kernel_context.h | 11 +- paddle/pten/core/kernel_utils.h | 1 + paddle/pten/include/manipulation.h | 14 ++ paddle/pten/infermeta/unary.cc | 6 + paddle/pten/infermeta/unary.h | 2 + paddle/pten/kernels/cpu/manipulation.cc | 64 +++++--- paddle/pten/kernels/cpu/manipulation.h | 7 + paddle/pten/kernels/cpu/math.cc | 6 + paddle/pten/kernels/cuda/manipulation.cu | 64 +++++--- paddle/pten/kernels/cuda/manipulation.h | 7 + paddle/pten/kernels/cuda/math.cu | 4 + .../pten/kernels/functions/cpu/elementwise.h | 1 + paddle/pten/kernels/functions/eigen/dot.h | 1 + .../kernels/functions/eigen/elementwise.h | 1 + paddle/pten/kernels/functions/eigen/mean.h | 1 + .../kernels/functions/general/manipulation.h | 3 +- .../pten/kernels/functions/math/cast_func.h | 48 ++++++ paddle/pten/kernels/xpu/manipulation.cc | 3 +- paddle/pten/tests/api/CMakeLists.txt | 2 +- paddle/pten/tests/api/test_cast_api.cc | 69 +++++++++ paddle/pten/tests/kernels/CMakeLists.txt | 1 + .../pten/tests/kernels/test_cast_dev_api.cc | 74 +++++++++ 40 files changed, 837 insertions(+), 149 deletions(-) create mode 100644 paddle/pten/kernels/functions/math/cast_func.h create mode 100644 paddle/pten/tests/api/test_cast_api.cc create mode 100644 paddle/pten/tests/kernels/test_cast_dev_api.cc diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index b5a649c206e..f2615694cfb 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1183,6 +1183,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } BuildPtenKernelContext(*runtime_ctx, dev_ctx); (*pt_kernel_)(pt_kernel_context_.get()); + + WriteBackToOutputs(runtime_ctx); + pt_kernel_context_->ClearData(); } else { (*kernel_func_)( @@ -1808,50 +1811,98 @@ void OperatorWithKernel::BuildPtenKernelContext( for (size_t i = 0; i < input_names.size(); ++i) { auto& in_def = input_defs.at(i); auto& ins_vector = ctx.inputs.at(input_names[i]); - if (pt_kernel_context_->InputsSize() <= i) { + + // calcute the start and end index of the input tensors + size_t start_idx = + (i == 0 ? 0 : pt_kernel_context_->InputRangeAt(i - 1).second); + size_t end_idx = start_idx + ins_vector.size(); + + // The current size of input/output in pt_kernel_context_ is at least equal + // the start_idx. For the reason of reusing the allocted of inputs or + // outputs in pt_kernel_context_, the current size of input/output can be + // greater then the index of which the tensort wanted to set to, so it will + // use ReMakePtenDenseTensorFromVar to make pten tensor. + if (pt_kernel_context_->InputsSize() == start_idx) { paddle::SmallVector> tmp_inputs; for (auto* var : ins_vector) { tmp_inputs.emplace_back( experimental::MakePtenTensorBaseFromVar(*var, in_def)); } pt_kernel_context_->EmplaceBackInputs(std::move(tmp_inputs)); - } else { + } else if (pt_kernel_context_->InputsSize() > start_idx) { size_t input_size = pt_kernel_context_->InputsSize(); for (size_t j = 0; j < ins_vector.size(); ++j) { - if (input_size > i + j) { + if (input_size > start_idx + j) { experimental::ReMakePtenDenseTensorFromVar( *ins_vector[j], in_def, - pt_kernel_context_->MutableInputAt(i + j)); + pt_kernel_context_->MutableInputAt(start_idx + + j)); + // TODO(chentianyu03): When multi input kernel, open this code + /* + } else { + pt_kernel_context_->EmplaceBackInputWithoutSetRange( + experimental::MakePtenTensorBaseFromVar(*ins_vector[j], + in_def)); + */ } - // TODO(chenweihang): adapt multi-input case later } pt_kernel_context_->MutableInputRangeAt(i) = - std::make_pair(i, i + ins_vector.size()); + std::make_pair(start_idx, end_idx); + } else { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Error start index when trying to set new tensor to inputs, start " + "index is `%d`, but current pt_kernel_context_.inputs.size() is " + "`%d`.", + start_idx, pt_kernel_context_->InputsSize())); } } for (size_t i = 0; i < output_names.size(); ++i) { auto& out_def = output_defs.at(i); auto& outs_vector = ctx.outputs.at(output_names[i]); - if (pt_kernel_context_->OutputsSize() <= i) { + + size_t start_idx = + (i == 0 ? 0 : pt_kernel_context_->OutputRangeAt(i - 1).second); + size_t end_idx = start_idx + outs_vector.size(); + + // The current size of input/output in pt_kernel_context_ is at least equal + // the start_idx. For the reason of reusing the allocted of inputs or + // outputs in pt_kernel_context_, the current size of input/output can be + // greater then the index of which the tensort wanted to set to, so it will + // use ReMakePtenDenseTensorFromVar to make pten tensor. + if (pt_kernel_context_->OutputsSize() == start_idx) { paddle::SmallVector> tmp_outputs; for (auto* var : outs_vector) { tmp_outputs.emplace_back( experimental::MakePtenTensorBaseFromVar(var, out_def)); } pt_kernel_context_->EmplaceBackOutputs(std::move(tmp_outputs)); - } else { + } else if (pt_kernel_context_->OutputsSize() > start_idx) { size_t output_size = pt_kernel_context_->OutputsSize(); for (size_t j = 0; j < outs_vector.size(); ++j) { - if (output_size > i + j) { + if (output_size > start_idx + j) { experimental::ReMakePtenDenseTensorFromVar( outs_vector[j], out_def, - pt_kernel_context_->MutableOutputAt(i + j)); + pt_kernel_context_->MutableOutputAt(start_idx + + j)); + + // TODO(chentianyu03): When multi output kernel, open this code + /* + } else { + pt_kernel_context_->EmplaceBackOutputWithoutSetRange( + experimental::MakePtenTensorBaseFromVar(outs_vector[j], + out_def)); + */ } - // TODO(chenweihang): adapt multi-output case later } pt_kernel_context_->MutableOutputRangeAt(i) = - std::make_pair(i, i + outs_vector.size()); + std::make_pair(start_idx, end_idx); + } else { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Error start index when trying to set new tensor to inputs, start " + "index is `%d`, but current pt_kernel_context_.outputs.size() is " + "`%d`.", + start_idx, pt_kernel_context_->OutputsSize())); } } @@ -1883,14 +1934,23 @@ void OperatorWithKernel::BuildPtenKernelContext( } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector)) && - std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { - // Emplace Back Attr according to the type of Pten_Kernel args. - const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); - const std::vector vector_int64_attr(vector_int_attr.begin(), - vector_int_attr.end()); - pt_kernel_context_->EmplaceBackAttr(vector_int64_attr); + std::type_index(typeid(pten::DataType))) { + auto data_type = pten::TransToPtenDataType( + static_cast( + BOOST_GET_CONST(int, attr))); + pt_kernel_context_->EmplaceBackAttr(data_type); + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + // Emplace Back Attr according to the type of Pten_Kernel args. + const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); + const std::vector vector_int64_attr(vector_int_attr.begin(), + vector_int_attr.end()); + pt_kernel_context_->EmplaceBackAttr(vector_int64_attr); + } + // TODO(YuanRisheng) Need support vector attr + } else { PADDLE_THROW(platform::errors::Unimplemented( "unsupported cast op attribute `%s` when construct " @@ -1901,5 +1961,26 @@ void OperatorWithKernel::BuildPtenKernelContext( } } +void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const { + // auto& input_names = std::get<0>(pt_kernel_signature_->args); + // auto& attr_names = std::get<1>(pt_kernel_signature_->args); + auto& output_names = std::get<2>(pt_kernel_signature_->args); + + // pt_kernel_context_ + + for (size_t i = 0; i < output_names.size(); ++i) { + auto& outs_vector = ctx->outputs.at(output_names[i]); + + auto& range_pair = pt_kernel_context_->OutputRangeAt(i); + auto pten_outs = + pt_kernel_context_->MutableOutputBetween( + range_pair.first, range_pair.second); + + for (size_t j = 0; j < pten_outs.size(); ++j) { + experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]); + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 4c071b777fe..6a5bac393ed 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -589,6 +589,8 @@ class OperatorWithKernel : public OperatorBase { void BuildPtenKernelContext(const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const; + void WriteBackToOutputs(RuntimeContext* ctx) const; + protected: mutable std::unique_ptr kernel_type_; mutable std::unique_ptr kernel_func_; diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index fbd7aa588d4..8d927b87c9a 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -204,10 +204,12 @@ void Tensor::ResetHolder(std::shared_ptr holder) { } void Tensor::ResetHolderWithType(std::shared_ptr holder, - const proto::VarType::Type type) { - ResetHolder(holder); + const proto::VarType::Type& type) { type_ = type; + ResetHolder(holder); } +void Tensor::set_type(const proto::VarType::Type& type) { type_ = type; } + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 539859c45c9..e889de8552d 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -271,7 +271,9 @@ class Tensor { void ResetHolder(std::shared_ptr holder); void ResetHolderWithType(std::shared_ptr holder, - const proto::VarType::Type type); + const proto::VarType::Type& type); + + void set_type(const proto::VarType::Type& type); TensorInplaceVersion& InplaceVersionCounter() { return *inplace_version_counter_; diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 32ee8aceee8..521f85d9429 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -295,7 +295,16 @@ static void BuildDygraphPtenKernelContext( for (size_t i = 0; i < input_names.size(); ++i) { auto& in_def = input_defs.at(i); auto& ins_vector = ins.at(input_names[i]); - if (kernel_ctx->InputsSize() <= i) { + + size_t start_idx = (i == 0 ? 0 : kernel_ctx->InputRangeAt(i - 1).second); + size_t end_idx = start_idx + ins_vector.size(); + + // The current size of input/output in pt_kernel_context_ is at least equal + // the start_idx. For the reason of reusing the allocted of inputs or + // outputs in pt_kernel_context_, the current size of input/output can be + // greater then the index of which the tensort wanted to set to, so it will + // use ReMakePtenDenseTensorFromVar to make pten tensor. + if (kernel_ctx->InputsSize() == start_idx) { paddle::SmallVector> tmp_inputs; for (const auto& var : ins_vector) { const auto& variable = var->Var(); @@ -303,25 +312,45 @@ static void BuildDygraphPtenKernelContext( experimental::MakePtenTensorBaseFromVar(variable, in_def)); } kernel_ctx->EmplaceBackInputs(std::move(tmp_inputs)); - } else { + } else if (kernel_ctx->InputsSize() > start_idx) { size_t input_size = kernel_ctx->InputsSize(); for (size_t j = 0; j < ins_vector.size(); ++j) { - if (input_size > i + j) { + if (input_size > start_idx + j) { experimental::ReMakePtenDenseTensorFromVar( ins_vector[j]->Var(), in_def, - kernel_ctx->MutableInputAt(i + j)); + kernel_ctx->MutableInputAt(start_idx + j)); + // TODO(chentianyu03): When multi input kernel, open this code + /* + } else { + kernel_ctx->EmplaceBackInputWithoutSetRange( + experimental::MakePtenTensorBaseFromVar(ins_vector[j]->Var(), + in_def)); + */ } - // TODO(chenweihang): adapt multi-input case later } - kernel_ctx->MutableInputRangeAt(i) = - std::make_pair(i, i + ins_vector.size()); + kernel_ctx->MutableInputRangeAt(i) = std::make_pair(start_idx, end_idx); + } else { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Error start index when trying to set new tensor to inputs, start " + "index is `%d`, but current pt_kernel_context_.inputs.size() is " + "`%d`.", + start_idx, kernel_ctx->InputsSize())); } } for (size_t i = 0; i < output_names.size(); ++i) { auto& out_def = output_defs.at(i); auto& outs_vector = outs.at(output_names[i]); - if (kernel_ctx->OutputsSize() <= i) { + + size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second); + size_t end_idx = start_idx + outs_vector.size(); + + // The current size of input/output in pt_kernel_context_ is at least equal + // the start_idx. For the reason of reusing the allocted of inputs or + // outputs in pt_kernel_context_, the current size of input/output can be + // greater then the index of which the tensort wanted to set to, so it will + // use ReMakePtenDenseTensorFromVar to make pten tensor. + if (kernel_ctx->OutputsSize() == start_idx) { paddle::SmallVector> tmp_outputs; for (auto& var : outs_vector) { auto* variable = var->MutableVar(); @@ -329,18 +358,29 @@ static void BuildDygraphPtenKernelContext( experimental::MakePtenTensorBaseFromVar(variable, out_def)); } kernel_ctx->EmplaceBackOutputs(std::move(tmp_outputs)); - } else { + } else if (kernel_ctx->OutputsSize() > start_idx) { size_t output_size = kernel_ctx->OutputsSize(); for (size_t j = 0; j < outs_vector.size(); ++j) { if (output_size > i + j) { experimental::ReMakePtenDenseTensorFromVar( outs_vector[j]->MutableVar(), out_def, kernel_ctx->MutableOutputAt(i + j)); + // TODO(chentianyu03): When multi output kernel, open this code + /* + } else { + kernel_ctx->EmplaceBackOutputWithoutSetRange( + experimental::MakePtenTensorBaseFromVar( + outs_vector[j]->MutableVar(), out_def)); + */ } - // TODO(chenweihang): adapt multi-output case later } - kernel_ctx->MutableOutputRangeAt(i) = - std::make_pair(i, i + outs_vector.size()); + kernel_ctx->MutableOutputRangeAt(i) = std::make_pair(start_idx, end_idx); + } else { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Error start index when trying to set new tensor to inputs, start " + "index is `%d`, but current pt_kernel_context_.outputs.size() is " + "`%d`.", + start_idx, kernel_ctx->OutputsSize())); } } @@ -372,14 +412,22 @@ static void BuildDygraphPtenKernelContext( } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector)) && - std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { - // Emplace Back Attr according to the type of Pten_Kernel args. - const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); - const std::vector vector_int64_attr(vector_int_attr.begin(), - vector_int_attr.end()); - kernel_ctx->EmplaceBackAttr(vector_int64_attr); + std::type_index(typeid(pten::DataType))) { + auto data_type = pten::TransToPtenDataType( + static_cast( + BOOST_GET_CONST(int, attr))); + kernel_ctx->EmplaceBackAttr(data_type); + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + // Emplace Back Attr according to the type of Pten_Kernel args. + const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); + const std::vector vector_int64_attr(vector_int_attr.begin(), + vector_int_attr.end()); + kernel_ctx->EmplaceBackAttr(vector_int64_attr); + } + // TODO(YuanRisheng) Need support vector attr } else { PADDLE_THROW(platform::errors::Unimplemented( "unsupported cast op attribute `%s` when construct " @@ -390,6 +438,26 @@ static void BuildDygraphPtenKernelContext( } } +template +static void WriteBackToOutputs( + const framework::KernelSignature& pt_kernel_signature, + const NameVarMap& outs, pten::KernelContext* kernel_ctx) { + auto& output_names = std::get<2>(pt_kernel_signature.args); + + for (size_t i = 0; i < output_names.size(); ++i) { + auto& outs_vector = outs.at(output_names[i]); + + auto& range_pair = kernel_ctx->OutputRangeAt(i); + auto pten_outs = kernel_ctx->MutableOutputBetween( + range_pair.first, range_pair.second); + + for (size_t j = 0; j < pten_outs.size(); ++j) { + experimental::MakeVariableFromPtenTensor(pten_outs[j], + outs_vector[j]->MutableVar()); + } + } +} + template static void PreparedOpRunImpl( const framework::OperatorBase& op, const framework::RuntimeContext& ctx, @@ -414,19 +482,6 @@ static void PreparedOpRunImpl( op.Type(), outs, dev_ctx->GetPlace()); } - /*For profiling/benchmark only*/ - if (FLAGS_benchmark) { - dev_ctx->Wait(); -#if defined(PADDLE_WITH_CUDA) - PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError()); - VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error"; -#endif -#if defined(PADDLE_WITH_HIP) - PADDLE_ENFORCE_CUDA_SUCCESS(hipGetLastError()); - VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error"; -#endif - } - /** * [ Why need handle complex gradient to real gradient? ] * @@ -463,6 +518,20 @@ static void PreparedOpRunPtImpl( pt_kernel(pt_kernel_context); + if (FLAGS_benchmark) { + dev_ctx->Wait(); +#if defined(PADDLE_WITH_CUDA) + PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError()); + VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error"; +#endif +#if defined(PADDLE_WITH_HIP) + PADDLE_ENFORCE_CUDA_SUCCESS(hipGetLastError()); + VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error"; +#endif + } + + WriteBackToOutputs(pt_kernel_signature, outs, pt_kernel_context); + // Ensure that it does not affect the VarBase life cycle management pt_kernel_context->ClearData(); diff --git a/paddle/fluid/operators/cast_op.h b/paddle/fluid/operators/cast_op.h index cd60c7707cb..bf0e81a23bf 100644 --- a/paddle/fluid/operators/cast_op.h +++ b/paddle/fluid/operators/cast_op.h @@ -18,6 +18,10 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/transform.h" +#include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/include/core.h" +#include "paddle/pten/include/manipulation.h" + namespace paddle { namespace operators { @@ -53,11 +57,26 @@ class CastOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); auto* out = context.Output("Out"); - framework::VisitDataType( - static_cast( - context.Attr("out_dtype")), - CastOpFunctor( - in, out, context.template device_context())); + + auto out_dtype = context.Attr("out_dtype"); + // todo: not used in_dtype + auto in_dtype = context.Attr("in_dtype"); + + auto& dev_ctx = context.device_context(); + out->mutable_data(dev_ctx.GetPlace(), + static_cast(out_dtype)); + + auto pt_x = paddle::experimental::MakePtenDenseTensor(*in); + auto pt_out = paddle::experimental::MakePtenDenseTensor(*out); + + auto pt_out_dtype = pten::TransToPtenDataType( + static_cast(out_dtype)); + auto pt_in_dtype = pten::TransToPtenDataType( + static_cast(in_dtype)); + + // call new kernel + pten::Cast(dev_ctx, *pt_x.get(), pt_out_dtype, pt_in_dtype, + pt_out.get()); } }; diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 1a8725bd988..901a25b6f30 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -552,14 +552,13 @@ class Reshape2Op : public ReshapeOp { const framework::ExecutionContext &ctx) const override { auto multi_inputs = ctx.MultiInput("ShapeTensor"); if (multi_inputs.size() > 0) { - return framework::KernelSignature( - "reshape2.mulhost.mid", {"X", "ShapeTensor"}, {}, {"XShape", "Out"}); + return framework::KernelSignature("reshape2.mulhost", + {"X", "ShapeTensor"}, {}, {"Out"}); } else if (ctx.HasInput("Shape")) { - return framework::KernelSignature("reshape2.host.mid", {"X", "Shape"}, {}, - {"XShape", "Out"}); + return framework::KernelSignature("reshape2.host", {"X", "Shape"}, {}, + {"Out"}); } else { - return framework::KernelSignature("reshape2.mid", {"X"}, {"shape"}, - {"XShape", "Out"}); + return framework::KernelSignature("reshape2", {"X"}, {"shape"}, {"Out"}); } } }; diff --git a/paddle/pten/api/ext/dispatch.h b/paddle/pten/api/ext/dispatch.h index 2b90bd77943..3b40a39af53 100644 --- a/paddle/pten/api/ext/dispatch.h +++ b/paddle/pten/api/ext/dispatch.h @@ -195,4 +195,50 @@ namespace paddle { // TODO(chenweihang): Add more Marcos in the future if needed +#define PD_VISIT_ALL_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::pten::DataType::BOOL, bool, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::pten::DataType::INT8, int8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::pten::DataType::UINT8, uint8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::pten::DataType::INT16, int16_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::pten::DataType::UINT16, uint16_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::pten::DataType::INT32, int32_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::pten::DataType::UINT32, uint32_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::pten::DataType::INT64, int64_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::pten::DataType::UINT64, uint64_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::pten::DataType::BFLOAT16, \ + paddle::experimental::bfloat16, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::pten::DataType::FLOAT16, \ + paddle::experimental::float16, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::pten::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::pten::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::pten::DataType::COMPLEX64, \ + paddle::experimental::complex64, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::pten::DataType::COMPLEX128, \ + paddle::experimental::complex128, \ + __VA_ARGS__) \ + default: \ + PADDLE_THROW(paddle::platform::errors::InvalidArgument( \ + "Invalid enum data type `%d`.", static_cast(__dtype__))); \ + } \ + }() + } // namespace paddle diff --git a/paddle/pten/api/include/manipulation.h b/paddle/pten/api/include/manipulation.h index e09e113732a..579fa5cdf94 100644 --- a/paddle/pten/api/include/manipulation.h +++ b/paddle/pten/api/include/manipulation.h @@ -21,6 +21,8 @@ namespace experimental { PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis); +PD_DLL_DECL Tensor cast(const Tensor& x, DataType out_dtype); + PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector& shape); } // namespace experimental } // namespace paddle diff --git a/paddle/pten/api/lib/manipulation.cc b/paddle/pten/api/lib/manipulation.cc index 3d9dba0458b..62affde5ec2 100644 --- a/paddle/pten/api/lib/manipulation.cc +++ b/paddle/pten/api/lib/manipulation.cc @@ -60,6 +60,40 @@ PD_DLL_DECL Tensor flatten(const Tensor& x, int start_axis, int stop_axis) { return out; } +PD_DLL_DECL Tensor cast(const Tensor& x, DataType out_dtype) { + // 1. Get kernel signature and kernel + auto kernel_key_set = ParseKernelKeyByInputArgs(x); + auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); + auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( + "cast", kernel_key); + + // 2. Get Device Context + auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + auto kernel_context = pten::KernelContext(dev_ctx); + + // 3. Auto data transform + auto dense_x = std::dynamic_pointer_cast(x.impl()); + kernel_context.EmplaceBackInput(dense_x); + kernel_context.EmplaceBackAttr(out_dtype); + kernel_context.EmplaceBackAttr(dense_x->meta().dtype); + + // 4. InferShape + auto out_meta = CastInferMeta(dense_x->meta(), out_dtype); + + // 5. Prepare outputs + Tensor out; + const auto allocator = std::make_shared( + pten::TransToFluidPlace(kernel_key.backend())); + auto dense_out = std::make_shared(allocator, out_meta); + kernel_context.EmplaceBackOutput(dense_out); + out.set_impl(dense_out); + + // 6. Call kernel + kernel(&kernel_context); + + return out; +} + PD_DLL_DECL Tensor reshape(const Tensor& x, const std::vector& shape) { // 1. Get kernel signature and kernel auto kernel_key_set = ParseKernelKeyByInputArgs(x); diff --git a/paddle/pten/api/lib/utils/storage.h b/paddle/pten/api/lib/utils/storage.h index b05ae4cb084..e98c5a82fed 100644 --- a/paddle/pten/api/lib/utils/storage.h +++ b/paddle/pten/api/lib/utils/storage.h @@ -76,10 +76,6 @@ class SharedStorage : public pten::Storage { // system, we need to allow the SharedStorage realloc, // and it can be removed after the compatibility phase is over in the future void Realloc(size_t n) override { - if (data() != nullptr) { - PADDLE_THROW(paddle::platform::errors::Unavailable( - "The external shared storage cannot be reallocated.")); - } ResetAllocation(paddle::memory::AllocShared(place(), n), 0); } @@ -109,9 +105,16 @@ class SharedStorage : public pten::Storage { size_ = allocation->size(); } + // Temporary method: For compatible with fluid Tensor and improve performance + void ResetAllocationPlace(const paddle::platform::Place& place) { + data_ = pten::Allocation(nullptr, place); + } + // Temporary method: For compatible with fluid Tensor and improve performance void Reset() { - allocation_.reset(); + if (allocation_ != nullptr) { + allocation_.reset(); + } data_.Clear(); size_ = 0; } diff --git a/paddle/pten/api/lib/utils/tensor_utils.cc b/paddle/pten/api/lib/utils/tensor_utils.cc index 4936006d26f..d4be9574783 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.cc +++ b/paddle/pten/api/lib/utils/tensor_utils.cc @@ -54,6 +54,49 @@ std::unique_ptr MakePtenDenseTensor( std::move(meta)); } +std::unique_ptr MakePtenDenseTensor( + const paddle::framework::Tensor& tensor, + const pten::TensorArgDef& arg_def) { + pten::DenseTensorMeta meta{arg_def.dtype, + tensor.dims(), + pten::TransToPtenDataLayout(tensor.layout())}; + + if (tensor.IsInitialized() && + tensor.place() == pten::TransToFluidPlace(arg_def.backend)) { + auto shared_storage = + pten::make_intrusive(tensor.Holder(), tensor.offset()); + return std::make_unique(std::move(shared_storage), + std::move(meta)); + } else { + return std::make_unique( + std::move(pten::make_intrusive( + pten::TransToFluidPlace(arg_def.backend))), + std::move(meta)); + } +} + +std::unique_ptr MakePtenDenseTensor( + const paddle::framework::LoDTensor& tensor, + const pten::TensorArgDef& arg_def) { + pten::DenseTensorMeta meta{arg_def.dtype, + tensor.dims(), + pten::TransToPtenDataLayout(tensor.layout()), + pten::TransToPtenLoD(tensor.lod())}; + + if (tensor.IsInitialized() && + tensor.place() == pten::TransToFluidPlace(arg_def.backend)) { + auto shared_storage = + pten::make_intrusive(tensor.Holder(), tensor.offset()); + return std::make_unique(std::move(shared_storage), + std::move(meta)); + } else { + return std::make_unique( + std::move(pten::make_intrusive( + pten::TransToFluidPlace(arg_def.backend))), + std::move(meta)); + } +} + std::unique_ptr MakePtenTensorBaseFromVar( const framework::Variable& variable, const pten::TensorArgDef& arg_def) { auto expected_place = pten::TransToFluidPlace(arg_def.backend); @@ -93,17 +136,12 @@ std::unique_ptr MakePtenTensorBaseFromVar( // KernelContext to original tensor if (variable->template IsType()) { auto* tensor = variable->template GetMutable(); - tensor->mutable_data(pten::TransToFluidPlace(arg_def.backend), - pten::TransToProtoVarType(arg_def.dtype)); - return MakePtenDenseTensor(*tensor); + return MakePtenDenseTensor(*tensor, arg_def); } else if (variable->template IsType()) { auto* tensor = variable->template GetMutable(); - tensor->mutable_value()->mutable_data( - pten::TransToFluidPlace(arg_def.backend), - pten::TransToProtoVarType(arg_def.dtype)); // TODO(chenweihang): adapt SelectedRows by xiaowei's design, // here the row and height will lost in output! - return MakePtenDenseTensor(tensor->value()); + return MakePtenDenseTensor(tensor->value(), arg_def); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported shared output `%s` type now when call pt kernel.", @@ -122,6 +160,7 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::Tensor* dst) { platform::errors::InvalidArgument( "The destination Tensor is nullptr when move storage.")); dst->Resize(src->dims()); + dst->set_type(pten::TransToProtoVarType(src->dtype())); auto storage = src->release(); std::shared_ptr holder( new TensorStorage(std::move(storage))); @@ -142,40 +181,53 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::LoDTensor* dst) { } void ReMakePtenDenseTensor(const paddle::framework::Tensor& src, + const pten::TensorArgDef& arg_def, pten::DenseTensor* dst) { auto* meta = pten::CompatibleDenseTensorUtils::GetMutableMeta(dst); meta->dims = src.dims(); // Since the type of DenseTensorMeta is const, const_cast must be used - const_cast(meta->dtype) = pten::TransToPtenDataType(src.type()); + const_cast(meta->dtype) = arg_def.dtype; // Since the type of DenseTensorMeta is const, const_cast must be used const_cast(meta->layout) = pten::TransToPtenDataLayout(src.layout()); + auto* shared_storage = static_cast( pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst)); PADDLE_ENFORCE_NOT_NULL( shared_storage, platform::errors::NotFound( "Target DenseTensor's shared storage is nullptr.")); - shared_storage->ResetAllocation(src.Holder(), src.offset()); + + if (src.IsInitialized()) { + shared_storage->ResetAllocation(src.Holder(), src.offset()); + } } void ReMakePtenDenseTensor(const paddle::framework::LoDTensor& src, + const pten::TensorArgDef& arg_def, pten::DenseTensor* dst) { auto* meta = pten::CompatibleDenseTensorUtils::GetMutableMeta(dst); meta->dims = src.dims(); // Since the type of DenseTensorMeta is const, const_cast must be used - const_cast(meta->dtype) = pten::TransToPtenDataType(src.type()); + const_cast(meta->dtype) = arg_def.dtype; // Since the type of DenseTensorMeta is const, const_cast must be used const_cast(meta->layout) = pten::TransToPtenDataLayout(src.layout()); SetLoD(&(meta->lod), src.lod()); + auto* shared_storage = static_cast( pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(dst)); PADDLE_ENFORCE_NOT_NULL( shared_storage, platform::errors::NotFound( "Target DenseTensor's shared storage is nullptr.")); - shared_storage->ResetAllocation(src.Holder(), src.offset()); + if (src.IsInitialized() && + src.place() == pten::TransToFluidPlace(arg_def.backend)) { + shared_storage->ResetAllocation(src.Holder(), src.offset()); + } else { + shared_storage->ResetAllocationPlace( + pten::TransToFluidPlace(arg_def.backend)); + } } void ReMakePtenDenseTensorFromVar(const framework::Variable& variable, @@ -188,9 +240,9 @@ void ReMakePtenDenseTensorFromVar(const framework::Variable& variable, if (!platform::is_same_place(tensor.place(), expected_place)) { framework::LoDTensor tmp_tensor; framework::TensorCopySync(tensor, expected_place, &tmp_tensor); - ReMakePtenDenseTensor(tmp_tensor, dst); + ReMakePtenDenseTensor(tmp_tensor, arg_def, dst); } else { - ReMakePtenDenseTensor(tensor, dst); + ReMakePtenDenseTensor(tensor, arg_def, dst); } } else if (variable.IsType()) { // TODO(chenweihang): now we don't deal with row and height @@ -200,9 +252,9 @@ void ReMakePtenDenseTensorFromVar(const framework::Variable& variable, framework::Tensor tmp_tensor; TensorCopySync(tensor.value(), expected_place, &tmp_tensor); // TODO(chenweihang): adapt SelectedRows by xiaowei's design - ReMakePtenDenseTensor(tmp_tensor, dst); + ReMakePtenDenseTensor(tmp_tensor, arg_def, dst); } else { - ReMakePtenDenseTensor(tensor.value(), dst); + ReMakePtenDenseTensor(tensor.value(), arg_def, dst); } } else { PADDLE_THROW(platform::errors::Unimplemented( @@ -218,18 +270,12 @@ void ReMakePtenDenseTensorFromVar(framework::Variable* variable, // KernelContext to original tensor if (variable->template IsType()) { auto* tensor = variable->template GetMutable(); - // TODO(chenweihang): use original var type if arg_def.dtype is UNDEFINED - tensor->mutable_data(pten::TransToFluidPlace(arg_def.backend), - pten::TransToProtoVarType(arg_def.dtype)); - ReMakePtenDenseTensor(*tensor, dst); + ReMakePtenDenseTensor(*tensor, arg_def, dst); } else if (variable->template IsType()) { auto* tensor = variable->template GetMutable(); - tensor->mutable_value()->mutable_data( - pten::TransToFluidPlace(arg_def.backend), - pten::TransToProtoVarType(arg_def.dtype)); // TODO(chenweihang): adapt SelectedRows by xiaowei's design, // here the row and height will lost in output! - ReMakePtenDenseTensor(tensor->value(), dst); + ReMakePtenDenseTensor(tensor->value(), arg_def, dst); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported shared output `%s` type now when call pt kernel.", @@ -237,5 +283,53 @@ void ReMakePtenDenseTensorFromVar(framework::Variable* variable, } } +static bool IsSameAllocation(const std::shared_ptr& a, + const std::shared_ptr& b) { + return a->ptr() == b->ptr() && a->size() == b->size() && + platform::is_same_place(a->place(), b->place()); +} + +void MakeVariableFromPtenTensor(pten::DenseTensor* src, + framework::Variable* variable) { + if (variable->IsType()) { + auto* tensor = variable->GetMutable(); + + auto dtype = pten::TransToProtoVarType(src->dtype()); + tensor->Resize(src->dims()); + SetLoD(tensor->mutable_lod(), src->lod()); + + // here dynamic_cast is slow + auto* storage = static_cast( + pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src)); + + if (!tensor->IsInitialized() || + (tensor->IsInitialized() && + !IsSameAllocation(tensor->Holder(), storage->GetAllocation()))) { + tensor->ResetHolderWithType(std::move(storage->GetAllocation()), dtype); + } else { + // Even the pten tensor and Variable have the same Alloctation (both have + // the same pointer address, same size and same place) + // but there is possible that they do not have the same data_type. + // so, here we set the variable's type with the pten tensor dtype. + tensor->set_type(dtype); + } + + } else if (variable->IsType()) { + auto* tensor = variable->GetMutable(); + auto dtype = pten::TransToProtoVarType(src->dtype()); + + if (!tensor->value().IsInitialized()) { + auto storage = dynamic_cast( + pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src)); + tensor->mutable_value()->ResetHolderWithType( + std::move(storage->GetAllocation()), dtype); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported shared input `%s` type now when call pt kernel.", + framework::ToTypeName(variable->Type()))); + } +} + } // namespace experimental } // namespace paddle diff --git a/paddle/pten/api/lib/utils/tensor_utils.h b/paddle/pten/api/lib/utils/tensor_utils.h index c1840d97fd2..62d4cab02b6 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.h +++ b/paddle/pten/api/lib/utils/tensor_utils.h @@ -55,9 +55,11 @@ void MovesStorage(pten::DenseTensor* src, paddle::framework::LoDTensor* dst); */ void ReMakePtenDenseTensor(const paddle::framework::Tensor& src, + const pten::TensorArgDef& arg_def, pten::DenseTensor* dst); void ReMakePtenDenseTensor(const paddle::framework::LoDTensor& src, + const pten::TensorArgDef& arg_def, pten::DenseTensor* dst); void ReMakePtenDenseTensorFromVar(const framework::Variable& variable, @@ -68,5 +70,8 @@ void ReMakePtenDenseTensorFromVar(framework::Variable* variable, const pten::TensorArgDef& arg_def, pten::DenseTensor* dst); +void MakeVariableFromPtenTensor(pten::DenseTensor* src, + framework::Variable* variable); + } // namespace experimental } // namespace paddle diff --git a/paddle/pten/core/compat_utils.h b/paddle/pten/core/compat_utils.h index 6c8eeec6553..c61b96546ec 100644 --- a/paddle/pten/core/compat_utils.h +++ b/paddle/pten/core/compat_utils.h @@ -42,8 +42,10 @@ class CompatibleDenseTensorUtils { // only can deal with SharedStorage now static void ClearStorage(DenseTensor* tensor) { // use static_cast to improve performance, replace by dynamic_cast later - static_cast(tensor->storage_.get()) - ->Reset(); + if (tensor->storage_ != nullptr) { + static_cast(tensor->storage_.get()) + ->Reset(); + } } static DenseTensor Slice(DenseTensor* tensor, diff --git a/paddle/pten/core/convert_utils.cc b/paddle/pten/core/convert_utils.cc index 32f2497dd18..92709647dac 100644 --- a/paddle/pten/core/convert_utils.cc +++ b/paddle/pten/core/convert_utils.cc @@ -160,4 +160,24 @@ paddle::framework::DataLayout TransToFluidDataLayout(const DataLayout& layout) { } } +paddle::framework::LoD TransToFluidLoD(const pten::LoD& lod) { + paddle::framework::LoD out; + out.reserve(lod.size()); + + for (auto& elem : lod) { + out.emplace_back(elem); + } + return out; +} + +pten::LoD TransToPtenLoD(const paddle::framework::LoD& lod) { + pten::LoD out; + out.reserve(lod.size()); + + for (auto& elem : lod) { + out.emplace_back(elem); + } + return out; +} + } // namespace pten diff --git a/paddle/pten/core/convert_utils.h b/paddle/pten/core/convert_utils.h index aa79cb240dd..0b807c48bc1 100644 --- a/paddle/pten/core/convert_utils.h +++ b/paddle/pten/core/convert_utils.h @@ -17,10 +17,12 @@ limitations under the License. */ #include "paddle/pten/common/backend.h" #include "paddle/pten/common/data_type.h" #include "paddle/pten/common/layout.h" +#include "paddle/pten/core/tensor_meta.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/platform/place.h" // TODO(chenweihang): this file may need to be removed @@ -40,4 +42,7 @@ paddle::framework::proto::VarType::Type TransToProtoVarType( const DataType& dtype); paddle::framework::DataLayout TransToFluidDataLayout(const DataLayout& layout); +paddle::framework::LoD TransToFluidLoD(const pten::LoD& lod); +pten::LoD TransToPtenLoD(const paddle::framework::LoD& lod); + } // namespace pten diff --git a/paddle/pten/core/dense_tensor.cc b/paddle/pten/core/dense_tensor.cc index c080b865a51..3237576cb64 100644 --- a/paddle/pten/core/dense_tensor.cc +++ b/paddle/pten/core/dense_tensor.cc @@ -69,7 +69,9 @@ void* DenseTensor::mutable_data(size_t request_bytes) { bytes)); bytes = request_bytes; } - if (storage_->size() < bytes) { + if (storage_->size() < bytes || storage_->size() == 0) { + VLOG(10) << "mutbale data realloc, original size: " << storage_->size() + << ", new size: " << bytes; storage_->Realloc(bytes); } return storage_->data(); @@ -81,6 +83,8 @@ T* DenseTensor::mutable_data() { // execution system, we have to reset the datatype in mutable_data. // When the compatibility phase is over in the future, we can delete it if (meta_.dtype == DataType::UNDEFINED) { + VLOG(10) << "change data type in mutbale_data, target dtype - " + << paddle::experimental::CppTypeToDataType::Type(); const_cast(meta_.dtype) = paddle::experimental::CppTypeToDataType::Type(); } @@ -120,12 +124,13 @@ void DenseTensor::set_meta(DenseTensorMeta&& meta) { meta_ = std::move(meta); } -void DenseTensor::Resize(const DDim& dims, const LoD& lod) { +void DenseTensor::Resize(const DDim& dims) { meta_.dims = dims; - meta_.lod = lod; mutable_data(); } +void DenseTensor::ResetLoD(const LoD& lod) { meta_.lod = lod; } + #define DATA_MEMBER_FUNC_INSTANTIATION(dtype) \ template dtype* DenseTensor::mutable_data(); \ template const dtype* DenseTensor::data() const; diff --git a/paddle/pten/core/dense_tensor.h b/paddle/pten/core/dense_tensor.h index 42fed722d0d..92c8e3d4bdb 100644 --- a/paddle/pten/core/dense_tensor.h +++ b/paddle/pten/core/dense_tensor.h @@ -127,7 +127,11 @@ class DenseTensor : public TensorBase, /// larger than the original value, the storage area will be reallocated. /// \param dims The new dims of the dense tensor. /// \param lod The new lod of the dense tensor. - void Resize(const DDim& dims, const LoD& lod = {}); + void Resize(const DDim& dims); + + /// \brief Change the lod information in the metadata. + /// \param lod The new lod of the dense tensor. + void ResetLoD(const LoD& lod); /// \brief Returns the actual storage size occupied by tensor, may be larger /// than its shape dims. diff --git a/paddle/pten/core/kernel_context.h b/paddle/pten/core/kernel_context.h index 973640906e0..4f4d673dfe6 100644 --- a/paddle/pten/core/kernel_context.h +++ b/paddle/pten/core/kernel_context.h @@ -58,6 +58,10 @@ class KernelContext { input_range_.emplace_back(std::pair(index, index + 1)); } + void EmplaceBackInputWithoutSetRange(std::shared_ptr input) { + inputs_.emplace_back(std::move(input)); + } + void EmplaceBackInputs( paddle::SmallVector> inputs) { int index = inputs_.size(); @@ -76,6 +80,10 @@ class KernelContext { output_range_.emplace_back(std::pair(index, index + 1)); } + void EmplaceBackOutputWithoutSetRange(std::shared_ptr output) { + outputs_.emplace_back(std::move(output)); + } + void EmplaceBackOutputs( paddle::SmallVector> outputs) { int index = outputs_.size(); @@ -171,9 +179,6 @@ class KernelContext { size_t OutputsSize() const { return outputs_.size(); } size_t AttrsSize() const { return attrs_.size(); } - private: - bool IsDuplicable() const { return input_range_.size() != inputs_.size(); } - private: // DeviceContext base class DeviceContext* dev_ctx_; diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h index 23143c06244..794857dba73 100644 --- a/paddle/pten/core/kernel_utils.h +++ b/paddle/pten/core/kernel_utils.h @@ -207,6 +207,7 @@ struct KernelImpl { PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); /* Output Helpers */ diff --git a/paddle/pten/include/manipulation.h b/paddle/pten/include/manipulation.h index 4900c78e63a..f6a7fcd3882 100644 --- a/paddle/pten/include/manipulation.h +++ b/paddle/pten/include/manipulation.h @@ -37,6 +37,20 @@ DenseTensor Flatten(const ContextT& dev_ctx, return dense_out; } +template +DenseTensor Cast(const ContextT& dev_ctx, + const DenseTensor& x, + DataType out_dtype, + DataType in_dtype) { + auto out_meta = CastInferMeta(x.meta(), out_dtype); + const auto allocator = + std::make_shared( + dev_ctx.GetPlace()); + pten::DenseTensor dense_out(allocator, out_meta); + Cast(dev_ctx, x, out_dtype, in_dtype, &dense_out); + return dense_out; +} + template DenseTensor Reshape(const ContextT& dev_ctx, const DenseTensor& x, diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index ea6e97db346..945a0b4e23f 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -74,6 +74,12 @@ DenseTensorMeta FlattenInferShape(const DenseTensorMeta& x_meta, return return_meta; } +DenseTensorMeta CastInferMeta(const DenseTensorMeta& x_meta, + const DataType out_dtype) { + DenseTensorMeta out_meta(out_dtype, x_meta.dims, x_meta.layout); + return out_meta; +} + DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, DataType dtype, DataLayout layout) { diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h index 4e22c9bf2d8..92c14d43ea9 100644 --- a/paddle/pten/infermeta/unary.h +++ b/paddle/pten/infermeta/unary.h @@ -40,6 +40,8 @@ DenseTensorMeta ReductionInferShape(const DenseTensorMeta& x_meta); DenseTensorMeta FlattenInferShape(const DenseTensorMeta& x_meta, int start_axis, int stop_axis); +DenseTensorMeta CastInferMeta(const DenseTensorMeta& x_meta, + const DataType out_dtype); DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, DataType dtype, diff --git a/paddle/pten/kernels/cpu/manipulation.cc b/paddle/pten/kernels/cpu/manipulation.cc index 79b2c96dcac..c7027e487b0 100644 --- a/paddle/pten/kernels/cpu/manipulation.cc +++ b/paddle/pten/kernels/cpu/manipulation.cc @@ -13,9 +13,11 @@ // limitations under the License. #include "paddle/pten/kernels/cpu/manipulation.h" +#include "paddle/pten/api/ext/dispatch.h" #include "paddle/pten/infermeta/unary.h" #include "paddle/pten/kernels/cpu/utils.h" #include "paddle/pten/kernels/functions/general/manipulation.h" +#include "paddle/pten/kernels/functions/math/cast_func.h" namespace pten { @@ -44,27 +46,17 @@ void FlattenWithXShape(const CPUContext& dev_ctx, general::SetXShape(x, xshape); } -void ReshapeFromVectorValImpl(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out, - bool set_lod) { - auto out_meta = InferShapeFromVecValue(x.meta(), shape); - if (&x != out) { - pten::Copy(dev_ctx, x, false, out); - } - if (set_lod) { - out->Resize(out_meta.dims, out_meta.lod); - } else { - out->Resize(out_meta.dims); - } -} - void ReshapeFromVectorVal(const CPUContext& dev_ctx, const DenseTensor& x, const std::vector& shape, DenseTensor* out) { - ReshapeFromVectorValImpl(dev_ctx, x, shape, out, false); + auto out_meta = InferShapeFromVecValue(x.meta(), shape); + if (&x == out) { + out->Resize(out_meta.dims); + return; + } + pten::Copy(dev_ctx, x, false, out); + out->Resize(out_meta.dims); } void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, @@ -72,8 +64,8 @@ void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, const std::vector& shape, DenseTensor* xshape, DenseTensor* out) { - ReshapeFromVectorVal(dev_ctx, x, shape, out); general::SetXShape(x, xshape); + ReshapeFromVectorVal(dev_ctx, x, shape, out); } void ReshapeFromDT(const CPUContext& dev_ctx, @@ -83,7 +75,8 @@ void ReshapeFromDT(const CPUContext& dev_ctx, auto* shape_data = shape.data(); auto vector_shape = std::vector(shape_data, shape_data + shape.numel()); - ReshapeFromVectorValImpl(dev_ctx, x, vector_shape, out, true); + ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); + out->ResetLoD(x.lod()); } void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, @@ -91,8 +84,8 @@ void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, const DenseTensor& shape, DenseTensor* xshape, DenseTensor* out) { - ReshapeFromDT(dev_ctx, x, shape, out); general::SetXShape(x, xshape); + ReshapeFromDT(dev_ctx, x, shape, out); } void ReshapeFromVectorDT(const CPUContext& dev_ctx, @@ -119,8 +112,20 @@ void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx, const std::vector& shape, DenseTensor* xshape, DenseTensor* out) { - ReshapeFromVectorDT(dev_ctx, x, shape, out); general::SetXShape(x, xshape); + ReshapeFromVectorDT(dev_ctx, x, shape, out); +} + +template +void Cast(const CPUContext& dev_ctx, + const DenseTensor& x, + DataType out_dtype, + DataType in_dtype, + DenseTensor* out) { + PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] { + math::CastKernelImpl( + dev_ctx, x, out); + })); } } // namespace pten @@ -151,6 +156,23 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid", int8_t, int, int64_t) {} +PT_REGISTER_KERNEL("cast", + CPU, + ANY, + pten::Cast, + float, + double, + int, + int64_t, + int16_t, + bool, + uint8_t, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); +} // TODO(yuanrisheng): "reshape2" is compatible with old kernel // architecture, kernel_name should be "reshape". diff --git a/paddle/pten/kernels/cpu/manipulation.h b/paddle/pten/kernels/cpu/manipulation.h index 435139e1fdf..3dce249c545 100644 --- a/paddle/pten/kernels/cpu/manipulation.h +++ b/paddle/pten/kernels/cpu/manipulation.h @@ -29,6 +29,13 @@ void Flatten(const CPUContext& dev_ctx, int stop_axis, DenseTensor* out); +template +void Cast(const CPUContext& dev_ctx, + const DenseTensor& x, + DataType out_dtype, + DataType in_dtype, + DenseTensor* out); + void ReshapeFromDT(const CPUContext& dev_ctx, const DenseTensor& x, const DenseTensor& shape, diff --git a/paddle/pten/kernels/cpu/math.cc b/paddle/pten/kernels/cpu/math.cc index 4d194bc069f..9b91aa347a4 100644 --- a/paddle/pten/kernels/cpu/math.cc +++ b/paddle/pten/kernels/cpu/math.cc @@ -70,6 +70,9 @@ void ElementwiseAdd(const CPUContext& dev_ctx, const DenseTensor& y, int axis, DenseTensor* out) { + // allocate memory for out + out->mutable_data(); + if (x.dims() == y.dims()) { SameDimsElementwiseCompute>()( dev_ctx, x, y, out); @@ -92,6 +95,9 @@ void ElementwiseSub(const CPUContext& dev_ctx, const DenseTensor& y, int axis, DenseTensor* out) { + // allocate memory for out + out->mutable_data(); + if (x.dims() == y.dims()) { SameDimsElementwiseCompute>()( dev_ctx, x, y, out); diff --git a/paddle/pten/kernels/cuda/manipulation.cu b/paddle/pten/kernels/cuda/manipulation.cu index d2315965b28..9b8f18dab4e 100644 --- a/paddle/pten/kernels/cuda/manipulation.cu +++ b/paddle/pten/kernels/cuda/manipulation.cu @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/pten/api/ext/dispatch.h" #include "paddle/pten/infermeta/unary.h" #include "paddle/pten/kernels/cuda/manipulation.h" #include "paddle/pten/kernels/cuda/utils.h" #include "paddle/pten/kernels/functions/general/manipulation.h" +#include "paddle/pten/kernels/functions/math/cast_func.h" namespace pten { @@ -44,27 +46,17 @@ void FlattenWithXShape(const CUDAContext& dev_ctx, general::SetXShape(x, xshape); } -void ReshapeFromVectorValImpl(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out, - bool set_lod) { - auto out_meta = InferShapeFromVecValue(x.meta(), shape); - if (&x != out) { - pten::Copy(dev_ctx, x, false, out); - } - if (set_lod) { - out->Resize(out_meta.dims, out_meta.lod); - } else { - out->Resize(out_meta.dims); - } -} - void ReshapeFromVectorVal(const CUDAContext& dev_ctx, const DenseTensor& x, const std::vector& shape, DenseTensor* out) { - ReshapeFromVectorValImpl(dev_ctx, x, shape, out, false); + auto out_meta = InferShapeFromVecValue(x.meta(), shape); + if (&x == out) { + out->Resize(out_meta.dims); + return; + } + pten::Copy(dev_ctx, x, false, out); + out->Resize(out_meta.dims); } void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, @@ -72,8 +64,8 @@ void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, const std::vector& shape, DenseTensor* xshape, DenseTensor* out) { - ReshapeFromVectorVal(dev_ctx, x, shape, out); general::SetXShape(x, xshape); + ReshapeFromVectorVal(dev_ctx, x, shape, out); } void ReshapeFromDT(const CUDAContext& dev_ctx, @@ -83,7 +75,8 @@ void ReshapeFromDT(const CUDAContext& dev_ctx, auto* shape_data = shape.data(); auto vector_shape = std::vector(shape_data, shape_data + shape.numel()); - ReshapeFromVectorValImpl(dev_ctx, x, vector_shape, out, true); + ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); + out->ResetLoD(x.lod()); } void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, @@ -91,8 +84,8 @@ void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, const DenseTensor& shape, DenseTensor* xshape, DenseTensor* out) { - ReshapeFromDT(dev_ctx, x, shape, out); general::SetXShape(x, xshape); + ReshapeFromDT(dev_ctx, x, shape, out); } void ReshapeFromVectorDT(const CUDAContext& dev_ctx, @@ -119,8 +112,20 @@ void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx, const std::vector& shape, DenseTensor* xshape, DenseTensor* out) { - ReshapeFromVectorDT(dev_ctx, x, shape, out); general::SetXShape(x, xshape); + ReshapeFromVectorDT(dev_ctx, x, shape, out); +} + +template +void Cast(const CUDAContext& dev_ctx, + const DenseTensor& x, + DataType out_dtype, + DataType in_dtype, + DenseTensor* out) { + PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] { + math::CastKernelImpl( + dev_ctx, x, out); + })); } } // namespace pten @@ -153,6 +158,23 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid", int8_t, int, int64_t) {} +// todo: Hip need support bfloat16 +PT_REGISTER_KERNEL("cast", + CUDA, + ANY, + pten::Cast, + float, + double, + int, + int64_t, + int16_t, + bool, + uint8_t, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); +} PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2", CUDA, diff --git a/paddle/pten/kernels/cuda/manipulation.h b/paddle/pten/kernels/cuda/manipulation.h index 40be7670baa..bb724beb2e3 100644 --- a/paddle/pten/kernels/cuda/manipulation.h +++ b/paddle/pten/kernels/cuda/manipulation.h @@ -33,6 +33,13 @@ void Flatten(const CUDAContext& dev_ctx, int stop_axis, DenseTensor* out); +template +void Cast(const CUDAContext& dev_ctx, + const DenseTensor& x, + DataType out_dtype, + DataType in_dtype, + DenseTensor* out); + void ReshapeFromDT(const CUDAContext& dev_ctx, const DenseTensor& x, const DenseTensor& shape, diff --git a/paddle/pten/kernels/cuda/math.cu b/paddle/pten/kernels/cuda/math.cu index 9191ad59ab2..92a1eeef923 100644 --- a/paddle/pten/kernels/cuda/math.cu +++ b/paddle/pten/kernels/cuda/math.cu @@ -134,6 +134,8 @@ void ElementwiseAdd(const CUDAContext& dev_ctx, std::vector outputs; inputs.emplace_back(&x); inputs.emplace_back(&y); + // allocate memory for out + out->mutable_data(); outputs.emplace_back(out); LaunchElementwiseCudaKernel( dev_ctx, inputs, &outputs, axis, general::AddFunctor()); @@ -149,6 +151,8 @@ void ElementwiseSub(const CUDAContext& dev_ctx, std::vector outputs; inputs.emplace_back(&x); inputs.emplace_back(&y); + // allocate memory for out + out->mutable_data(); outputs.emplace_back(out); LaunchElementwiseCudaKernel( dev_ctx, inputs, &outputs, axis, general::SubFunctor()); diff --git a/paddle/pten/kernels/functions/cpu/elementwise.h b/paddle/pten/kernels/functions/cpu/elementwise.h index b565b8403b9..98600f29910 100644 --- a/paddle/pten/kernels/functions/cpu/elementwise.h +++ b/paddle/pten/kernels/functions/cpu/elementwise.h @@ -147,6 +147,7 @@ void ElementwiseCompute(const paddle::platform::CPUDeviceContext &dev_ctx, int axis, Functor func, DenseTensor *z) { + z->mutable_data(); auto x_dims = x.dims(); auto y_dims = y.dims(); bool is_xsize_larger = true; diff --git a/paddle/pten/kernels/functions/eigen/dot.h b/paddle/pten/kernels/functions/eigen/dot.h index 300da4ae1f1..27a0b8cf329 100644 --- a/paddle/pten/kernels/functions/eigen/dot.h +++ b/paddle/pten/kernels/functions/eigen/dot.h @@ -28,6 +28,7 @@ void Dot(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* out) { + out->mutable_data(); if (1 == out->dims().size()) { auto eigen_out = pten::EigenScalar::From(*out); auto eigen_x = pten::EigenVector::Flatten(x); diff --git a/paddle/pten/kernels/functions/eigen/elementwise.h b/paddle/pten/kernels/functions/eigen/elementwise.h index e9854a2d5cd..dd42234118c 100644 --- a/paddle/pten/kernels/functions/eigen/elementwise.h +++ b/paddle/pten/kernels/functions/eigen/elementwise.h @@ -25,6 +25,7 @@ void ElementwiseAdd(const DevCtx& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* out) { + out->mutable_data(); auto eigen_x = pten::EigenVector::Flatten(x); auto eigen_y = pten::EigenVector::Flatten(y); auto eigen_z = pten::EigenVector::Flatten(*out); diff --git a/paddle/pten/kernels/functions/eigen/mean.h b/paddle/pten/kernels/functions/eigen/mean.h index ee4bf1653f2..e006c76a9f5 100644 --- a/paddle/pten/kernels/functions/eigen/mean.h +++ b/paddle/pten/kernels/functions/eigen/mean.h @@ -28,6 +28,7 @@ void Mean(const DevCtx& dev_ctx, const DenseTensor& x, DenseTensor* out) { // TODO(chenweihang): if we design new tensor, we should support // the low-level calc functor use new tensor as input, // which may be a big project! + out->mutable_data(); auto eigen_x = pten::EigenVector::Flatten(x); auto eigen_out = pten::EigenScalar::From(*out); diff --git a/paddle/pten/kernels/functions/general/manipulation.h b/paddle/pten/kernels/functions/general/manipulation.h index cade585792c..85f6b613ac6 100644 --- a/paddle/pten/kernels/functions/general/manipulation.h +++ b/paddle/pten/kernels/functions/general/manipulation.h @@ -26,7 +26,8 @@ inline void SetXShape(const DenseTensor& x, DenseTensor* xshape) { for (int i = 0; i < in_dims.size(); ++i) { xshape_dims[i + 1] = in_dims[i]; } - xshape->Resize(paddle::framework::make_ddim(xshape_dims), x.meta().lod); + xshape->Resize(paddle::framework::make_ddim(xshape_dims)); + xshape->ResetLoD(x.meta().lod); } } // namespace general diff --git a/paddle/pten/kernels/functions/math/cast_func.h b/paddle/pten/kernels/functions/math/cast_func.h new file mode 100644 index 00000000000..0a67736dbb2 --- /dev/null +++ b/paddle/pten/kernels/functions/math/cast_func.h @@ -0,0 +1,48 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/platform/transform.h" +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { +namespace math { + +template +struct CastOpTransformFunctor { + HOSTDEVICE OutT operator()(InT in) const { return static_cast(in); } +}; + +template +void CastKernelImpl(const DeviceContext& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + auto* in_begin = x.data(); + auto numel = x.numel(); + auto* in_end = in_begin + numel; + + auto* out_begin = out->mutable_data(); + + paddle::platform::Transform trans; + trans(dev_ctx, + in_begin, + in_end, + out_begin, + CastOpTransformFunctor()); +} + +} // namespace math + +} // namespace pten diff --git a/paddle/pten/kernels/xpu/manipulation.cc b/paddle/pten/kernels/xpu/manipulation.cc index 2a726e1cb25..d55def8b8a7 100644 --- a/paddle/pten/kernels/xpu/manipulation.cc +++ b/paddle/pten/kernels/xpu/manipulation.cc @@ -47,7 +47,8 @@ void FlattenWithXShape(const XPUContext& dev_ctx, for (int i = 0; i < in_dims.size(); ++i) { xshape_dims[i + 1] = in_dims[i]; } - xshape->Resize(paddle::framework::make_ddim(xshape_dims), x.meta().lod); + xshape->Resize(paddle::framework::make_ddim(xshape_dims)); + xshape->ResetLoD(x.lod()); } void ReshapeFromVectorVal(const XPUContext& dev_ctx, diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt index fdff473ddbb..9acf39f7c2b 100644 --- a/paddle/pten/tests/api/CMakeLists.txt +++ b/paddle/pten/tests/api/CMakeLists.txt @@ -7,13 +7,13 @@ endif() cc_test(test_pten_exception SRCS test_pten_exception.cc DEPS gtest) cc_test(test_framework_storage SRCS test_storage.cc DEPS pten_api_utils) cc_test(test_framework_tensor_utils SRCS test_tensor_utils.cc DEPS pten_api_utils) - cc_test(test_mean_api SRCS test_mean_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_dot_api SRCS test_dot_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_fill_api SRCS test_fill_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS pten_tensor pten_api pten_api_utils) +cc_test(test_cast_api SRCS test_cast_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_to_api SRCS test_to_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_slice_api SRCS test_slice_api.cc DEPS pten_tensor pten_api pten_api_utils) diff --git a/paddle/pten/tests/api/test_cast_api.cc b/paddle/pten/tests/api/test_cast_api.cc new file mode 100644 index 00000000000..46265d8568c --- /dev/null +++ b/paddle/pten/tests/api/test_cast_api.cc @@ -0,0 +1,69 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/api/include/manipulation.h" + +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +PT_DECLARE_MODULE(ManipulationCPU); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_DECLARE_MODULE(ManipulationCUDA); +#endif + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +// TODO(chenweihang): Remove this test after the API is used in the dygraph +TEST(API, cast) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + auto dense_x = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 4}), + pten::DataLayout::NCHW)); + auto* dense_x_data = dense_x->mutable_data(); + + for (int i = 0; i < dense_x->numel(); i++) { + dense_x_data[i] = i; + } + + paddle::experimental::Tensor x(dense_x); + pten::DataType out_dtype = pten::DataType::FLOAT64; + // 2. test API + auto out = paddle::experimental::cast(x, out_dtype); + + // 3. check result + std::vector expect_shape = {3, 4}; + ASSERT_EQ(out.shape().size(), size_t(2)); + ASSERT_EQ(out.shape()[0], expect_shape[0]); + ASSERT_EQ(out.shape()[1], expect_shape[1]); + ASSERT_EQ(out.numel(), 12); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::FLOAT64); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto* dense_out_data = dense_out->data(); + for (int i = 0; i < dense_x->numel(); i++) { + ASSERT_NEAR(dense_out_data[i], static_cast(dense_x_data[i]), 1e-6f); + } +} diff --git a/paddle/pten/tests/kernels/CMakeLists.txt b/paddle/pten/tests/kernels/CMakeLists.txt index 8a66fd18609..b9a47ee21c3 100644 --- a/paddle/pten/tests/kernels/CMakeLists.txt +++ b/paddle/pten/tests/kernels/CMakeLists.txt @@ -4,5 +4,6 @@ cc_test(test_fill_dev_api SRCS test_fill_dev_api.cc DEPS pten pten_api_utils) cc_test(test_flatten_dev_api SRCS test_flatten_dev_api.cc DEPS pten pten_api_utils) cc_test(test_mean_dev_api SRCS test_mean_dev_api.cc DEPS pten pten_api_utils) cc_test(test_scale_dev_api SRCS test_scale_dev_api.cc DEPS pten pten_api_utils) +cc_test(test_cast_dev_api SRCS test_cast_dev_api.cc DEPS pten pten_api_utils) cc_test(test_elementwise_dev_api SRCS test_elementwise_dev_api.cc DEPS pten pten_api_utils) cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_utils) diff --git a/paddle/pten/tests/kernels/test_cast_dev_api.cc b/paddle/pten/tests/kernels/test_cast_dev_api.cc new file mode 100644 index 00000000000..2471529ba9b --- /dev/null +++ b/paddle/pten/tests/kernels/test_cast_dev_api.cc @@ -0,0 +1,74 @@ + +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/include/manipulation.h" + +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +PT_DECLARE_MODULE(ManipulationCPU); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_DECLARE_MODULE(ManipulationCUDA); +#endif + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +TEST(DEV_API, cast) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + pten::DenseTensor dense_x(alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 4}), + pten::DataLayout::NCHW)); + auto* dense_x_data = dense_x.mutable_data(); + + float sum = 0.0; + for (size_t i = 0; i < 12; ++i) { + dense_x_data[i] = i * 1.0; + sum += i * 1.0; + } + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + + pten::DataType out_dtype = pten::DataType::FLOAT64; + pten::DataType in_dtype = pten::DataType::FLOAT32; + // 2. test API + auto out = pten::Cast( + *(static_cast(dev_ctx)), + dense_x, + out_dtype, + in_dtype); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 3); + ASSERT_EQ(out.dims()[1], 4); + ASSERT_EQ(out.meta().dtype, pten::DataType::FLOAT64); + ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW); + + auto actual_result = out.data(); + for (size_t i = 0; i < 12; ++i) { + ASSERT_NEAR(actual_result[i], static_cast(dense_x_data[i]), 1e-6f); + } +} -- GitLab