From 91ebc4608139a44521b133a3783a20da76502876 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 29 Dec 2020 06:17:30 -0600 Subject: [PATCH] [Cherry-pick] Complex network execute support (#29905) * [Complex] Add support for complex grad accumulated (#29889) * add support for complex grad accumulated * add unittest for coverage * update test dtype * remove useless blank line * [Complex] Handle complex to real after type promotion (#29855) * try to add fwd op input dtypes * refactor base impl * return tmp_ins after dygraph prepare data * fix typo found in debug * polish comment & add complex net test * revert detail change * fix unittest failed * add complex kernel condition control * fix xpu test failed & polish comment * polish details by review comments * Complex op test (#29753) * delete no need to calculate inputs in dygraph op_test * delete no need to calculate inputs in dygraph op_test * change grad elementwise_mul for complex types (#29757) * add conj op for complex types * add conj for complex types * add more test case * add conj_op test * modify conj api and impl * add complex type for fill_constant_op xpu * add setConstant for complex type * remove complex conj test file * user define grad for test_conj_op * add test case for static mode of conj api * modify conj doc * change input args name to x * remove useless codes * conj support real types * add conj test case for real number * delete no need to calculate inputs in dygraph op_test * delete no need to calculate inputs in dygraph op_test * modify grad of mul for complex types * fix the grads of inputs args order not match bug * change the grad of div when complex types (#29804) * change the grad of div when complex types * fix the grads of inputs args order not match bug Co-authored-by: chentianyu03 --- paddle/fluid/framework/data_type_transform.cc | 25 ++++ paddle/fluid/framework/data_type_transform.h | 15 ++ paddle/fluid/framework/operator.cc | 75 ++++++++++ paddle/fluid/framework/operator.h | 3 + paddle/fluid/framework/tensor.h | 19 +++ paddle/fluid/imperative/basic_engine.cc | 2 + .../fluid/imperative/gradient_accumulator.cc | 6 + paddle/fluid/imperative/layer.cc | 27 +++- paddle/fluid/imperative/layer.h | 8 + .../fluid/imperative/partial_grad_engine.cc | 1 + paddle/fluid/imperative/prepared_operator.cc | 137 ++++++++---------- paddle/fluid/imperative/prepared_operator.h | 122 ++++++++++++++-- .../fluid/imperative/tests/test_prepare_op.cc | 44 ++---- paddle/fluid/imperative/variable_wrapper.h | 23 ++- paddle/fluid/memory/memcpy.cc | 2 +- .../elementwise/elementwise_div_op.cu | 39 +++++ .../elementwise/elementwise_div_op.h | 43 ++++++ .../elementwise/elementwise_mul_op.cu | 30 ++++ .../elementwise/elementwise_mul_op.h | 42 ++++++ .../operators/elementwise/elementwise_op.h | 16 +- paddle/fluid/operators/math/blas_impl.cu.h | 18 +++ paddle/fluid/operators/math/blas_impl.h | 24 +-- .../operators/math/selected_rows_functor.cc | 6 + .../operators/math/selected_rows_functor.cu | 2 + .../reduce_ops/reduce_sum_op.part.cu | 4 +- paddle/fluid/platform/cuda_primitives.h | 17 ++- paddle/fluid/platform/dynload/cublas.h | 2 + .../paddle/fluid/tests/unittests/op_test.py | 11 +- .../test_complex_grad_accumulated.py | 101 +++++++++++++ .../tests/unittests/test_complex_simplenet.py | 72 +++++++++ .../unittests/test_elementwise_add_op.py | 31 ++-- .../unittests/test_elementwise_div_op.py | 60 ++++++++ .../unittests/test_elementwise_mul_op.py | 70 ++++++++- .../tests/unittests/test_strided_slice_op.py | 2 +- 34 files changed, 943 insertions(+), 156 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_complex_grad_accumulated.py create mode 100644 python/paddle/fluid/tests/unittests/test_complex_simplenet.py diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index 3d56152c237..30a2ac2c6f6 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -109,5 +109,30 @@ void TransDataType(const OpKernelType& kernel_type_for_var, } } +void TransComplexToReal(const proto::VarType::Type& dst_type, + const proto::VarType::Type& src_type, const Tensor& in, + Tensor* out) { + auto& pool = platform::DeviceContextPool::Instance(); + auto* ctx = pool.Get(in.place()); + out->Resize(in.dims()); + + // complex -> real + switch (src_type) { + case proto::VarType::COMPLEX64: + framework::VisitDataType(dst_type, + CastDataType(in, out, ctx)); + break; + case proto::VarType::COMPLEX128: + framework::VisitDataType( + dst_type, CastDataType(in, out, ctx)); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when casting complex tensor to real " + "data type.", + DataTypeToString(src_type))); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/data_type_transform.h b/paddle/fluid/framework/data_type_transform.h index b42b2f594aa..499b133dadb 100644 --- a/paddle/fluid/framework/data_type_transform.h +++ b/paddle/fluid/framework/data_type_transform.h @@ -33,5 +33,20 @@ void TransDataType(const OpKernelType& kernel_type_for_var, const OpKernelType& expected_kernel_type, const Tensor& in, Tensor* out); +/** + * Transform complex gradient to real data type. + * + * If complex type promotion occurred in forward op, the grad output of + * this op is complex data type, but the input variable may be real type, + * in this case the grad input need to be cast to type same with input, + * this casting executed at the end of grad op. + * + * note: call this function need to ensure that dst_type is real and + * src_type is complex + */ +void TransComplexToReal(const proto::VarType::Type& dst_type, + const proto::VarType::Type& src_type, const Tensor& in, + Tensor* out); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 7b40a5977a0..bb8ff55d6c9 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/data_transform.h" +#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -1110,6 +1111,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // there is inplace variable has been transferred. TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope); } + + // See [ Why need handle complex gradient to real gradient? ] + // Only handle the case where the current kernel data type is complex + if (framework::IsComplexType(kernel_type_->data_type_)) { + HandleComplexGradToRealGrad(scope, runtime_ctx); + } + if (FLAGS_enable_unused_var_check) { // skip op that uses mkldnn because it has different memory reuse strategy. // use attr here because some GradMakers (like ActivationGradOpMaker) add @@ -1255,6 +1263,73 @@ void OperatorWithKernel::TransferInplaceVarsBack( } } +void OperatorWithKernel::HandleComplexGradToRealGrad( + const Scope& scope, RuntimeContext* ctx) const { + for (auto& var_name_item : Outputs()) { + std::vector& output_vars = ctx->outputs[var_name_item.first]; + for (size_t i = 0; i < var_name_item.second.size(); ++i) { + // 1. find grad_var & check whether is complex tensor + auto var_name = var_name_item.second[i]; + auto orig_var_name = GradOriginalVarName(var_name); + // only focus on gradient var + if (var_name == orig_var_name) { + continue; + } + auto* grad_var = output_vars[i]; + // skip nullptr var + if (grad_var == nullptr) { + continue; + } + // don't process LoDTensorArray temporarily, + // add support if necessary for complex number calculations in the future + if (!VarIsTensor(*grad_var)) { + continue; + } + auto* grad_tensor = + GetMutableLoDTensorOrSelectedRowsValueFromVar(grad_var); + // skip nullptr tensor + if (grad_tensor == nullptr || !grad_tensor->IsInitialized()) { + continue; + } + // only focus on complex dtype now + auto src_type = grad_tensor->type(); + if (!IsComplexType(src_type)) { + continue; + } + + // 2. find forward var & check whether need to cast + auto* var = scope.FindVar(orig_var_name); + // if forward var not exists, do nothing + if (var == nullptr) { + continue; + } + if (!VarIsTensor(*var)) { + continue; + } + const auto* tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var); + PADDLE_ENFORCE_NOT_NULL( + tensor, + platform::errors::Unavailable( + "Forward tensor is nullptr when handle complex data to real.")); + // only need record type, the allocation may have been released + auto dst_type = tensor->saved_type(); + // only focus on real dtype and need casting + if (IsComplexType(dst_type)) { + continue; + } + + // 3. cast complex grad to real grad + VLOG(6) << "Transform " << framework::DataTypeToString(src_type) + << " var `" << var_name << "` to " + << framework::DataTypeToString(dst_type) + << " real var in static graph."; + Tensor out; + TransComplexToReal(dst_type, src_type, *grad_tensor, &out); + SetTensorToVariable(*grad_var, out, grad_var); + } + } +} + Scope* OperatorWithKernel::PrepareData( const Scope& scope, const OpKernelType& expected_kernel_key, std::vector* transfered_inplace_vars, diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 652d5330f2b..fd1cc18b951 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -545,6 +545,9 @@ class OperatorWithKernel : public OperatorBase { void ChooseKernel(const RuntimeContext& ctx, const Scope& scope, const platform::Place& place) const; + void HandleComplexGradToRealGrad(const Scope& scope, + RuntimeContext* ctx) const; + /* Inner assist methods */ // indicate kernel DataType by input data. // By default all input data must be same. diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 0a4885ea325..76119e7c708 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -197,6 +197,24 @@ class Tensor { return type_; } + /** + * [Add method get the saved type of tensor] + * + * After the introduction of complex number calculations, Ops that support + * complex number calculations generally support type promotion, such as + * x(float32) + y(complex64) = out(complex64), then the type of the grad + * tensor should be dout(complex64), dx(float32), dy (complex64), but the + * type of dx to be recognized to be float32 by the grad Op relay on the type + * of forward tensor x. But many of our ops have registered InplaceInferer, + * covering the tensor memory of x with out, so as to save storage. + * + * In this case, the dim and type information recorded by x still exist, + * but because x becomes an uninitialized tensor, The type of x record cannot + * be obtained with x.type(), but the type is still valid here, so we + * add saved_type(), This method SHOULD NOT be called by general scenarios. + */ + proto::VarType::Type saved_type() const { return type_; } + // memory size returns the holding memory size in byte. size_t memory_size() const; @@ -232,6 +250,7 @@ class Tensor { void ResetHolderWithType(std::shared_ptr holder, const proto::VarType::Type type); + TensorInplaceVersion& InplaceVersionCounter() { return inplace_version_counter_; } diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index f97ab4f4e05..0a43a0307d2 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -23,6 +23,7 @@ #include #include #include + #include "paddle/fluid/imperative/gradient_accumulator.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/op_base.h" @@ -239,6 +240,7 @@ void BasicEngine::Execute() { if (var->OverridedStopGradient() || iter->second->RefCnt() > 1) { auto tmp_var = std::make_shared(var->Name()); tmp_var->SetType(var->Type()); + tmp_var->SetForwardDataType(var->ForwardDataType()); var = tmp_var; need_accu_var_list_.emplace_back(iter->second.get(), var); VLOG(10) << "create temporary var of " << var->Name() diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index 66c4d1c5f55..bc38e3b59b6 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -25,6 +25,8 @@ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/complex128.h" +#include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/profiler.h" @@ -161,6 +163,10 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) { PADDLE_TENSOR_ADD(float); PADDLE_TENSOR_ADD(double); + // NOTE(chenweihang): only support complex grad tensor accumulated, + // support selected rows if needed in the future + PADDLE_TENSOR_ADD(platform::complex64); + PADDLE_TENSOR_ADD(platform::complex128); #undef PADDLE_TENSOR_ADD diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 6f490c3c2be..57c3a3aae11 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -16,6 +16,7 @@ #include #include #include + #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/variable_helper.h" @@ -326,9 +327,31 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, } VLOG(5) << LayerDebugString(op.Type(), ins, outs); - auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs); - prepared_op.Run(ins, outs, attrs); + /** + * [ Why need temporary inputs here? ] + * + * PrepareData should not change original input tensor inplace. + * Suppose the user defines a tensor(int), enters an op to execute, + * and then this op rewrites GetExpectedKernelForVar, and converts + * this tensor to float type during execution. After the dynamic + * graph is executed, the user-defined variable will be lost, and + * the user cannot get the originally defined int tensor, because + * it has been converted to float, this should be regarded as a bug + * in certain usage scenarios + * + * In static graph mode, when op is executed, a temporary scope + * `transfer_scope` is created before PrepareData, the data after + * transform is stored in the temporary scope, and then discarded + * after the execution of op, but the original input is directly + * overwritten in the previous dynamic graph implemention. + */ + auto expected_kernel_key = + GetExpectedKernelKey(ins, outs, *op_kernel, place, attrs); + auto prepared_op = PreparedOp::Prepare(*op_kernel, expected_kernel_key); + auto tmp_ins = PrepareData(*op_kernel, ins, expected_kernel_key); + + prepared_op.Run(tmp_ins, outs, attrs); VLOG(4) << LayerDebugString(op.Type(), ins, outs); } diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 1a974ab346e..9821c359de8 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -201,6 +201,14 @@ class VarBase { framework::proto::VarType::Type DataType() const { return var_->DataType(); } + void SetForwardDataType(framework::proto::VarType::Type data_type) { + var_->SetForwardDataType(data_type); + } + + framework::proto::VarType::Type ForwardDataType() const { + return var_->ForwardDataType(); + } + const platform::Place Place() const { return var_->Place(); } void ClearGradient(); diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index d8f828ede25..149a38e2586 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -857,6 +857,7 @@ void PartialGradTask::RunEachOp(OpBase *op) { auto new_grad_var = std::make_shared(true, grad_var->Name()); new_grad_var->SetOverridedStopGradient(false); + new_grad_var->SetForwardDataType(grad_var->ForwardDataType()); if (new_grad_var_iter->second->TotalRefCnt() > 1) { grads_to_accumulate_.emplace_back(new_grad_var_iter->second.get(), new_grad_var->SharedVar()); diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index c58b1e9596f..ba4b1d4c980 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -16,12 +16,10 @@ #include -#include "paddle/fluid/imperative/execution_context.h" +#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/imperative/infer_shape_context.h" #include "paddle/fluid/imperative/infer_var_type_context.h" -DECLARE_bool(use_mkldnn); - namespace paddle { namespace imperative { @@ -36,26 +34,32 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var) { } template -static void PrepareData(const platform::Place& place, - const NameVarMap& ins, - const framework::OperatorWithKernel& op, - const framework::OpKernelType& expected_kernel_key) { - for (const auto& name_pair : ins) { - for (const auto& var_base : name_pair.second) { - const auto* tensor = GetTensorFromVar(var_base->Var()); +static void HandleComplexGradToRealGrad(const NameVarMap& outs) { + for (auto& pair : outs) { + for (auto& var : pair.second) { + if (var == nullptr) { + continue; + } + if (var->ForwardDataType() == + static_cast(-1)) { + VLOG(6) << "Var (" << var->Name() + << ")'s forward data type is not set."; + continue; + } + if (!framework::IsComplexType(var->DataType()) || + framework::IsComplexType(var->ForwardDataType())) { + continue; + } + const auto* tensor = GetTensorFromVar(var->Var()); if (tensor && tensor->IsInitialized()) { - auto kernel_type_for_var = op.GetKernelTypeForVar( - name_pair.first, *tensor, expected_kernel_key); - if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) { - continue; - } else { - VLOG(3) << "Transform Variable " << var_base->Name() << " from " - << kernel_type_for_var << " to " << expected_kernel_key; - framework::Tensor out; - TransformData(expected_kernel_key, kernel_type_for_var, *tensor, - &out); - SetTensorToVariable(var_base->Var(), out, var_base->MutableVar()); - } + VLOG(6) << "Transform " << framework::DataTypeToString(var->DataType()) + << " var `" << var->Name() << "` to " + << framework::DataTypeToString(var->ForwardDataType()) + << " real var in dynamic graph."; + framework::Tensor out; + framework::TransComplexToReal(var->ForwardDataType(), var->DataType(), + *tensor, &out); + SetTensorToVariable(var->Var(), out, var->MutableVar()); } } } @@ -63,18 +67,20 @@ static void PrepareData(const platform::Place& place, PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, + const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, platform::DeviceContext* dev_ctx) - : op_(op), ctx_(ctx), func_(func), dev_ctx_(dev_ctx) {} - -template -PreparedOp PrepareOpImpl(const NameVarMap& ins, - const NameVarMap& outs, - const framework::OperatorWithKernel& op, - platform::Place place, - const framework::AttributeMap& attrs) { + : op_(op), + ctx_(ctx), + kernel_type_(kernel_type), + func_(func), + dev_ctx_(dev_ctx) {} + +PreparedOp PreparedOp::Prepare( + const framework::OperatorWithKernel& op, + const framework::OpKernelType& expected_kernel_key) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(place); + auto* dev_ctx = pool.Get(expected_kernel_key.place_); // check if op[type] has kernel registered. auto& all_op_kernels = op.AllOpKernels(); @@ -89,62 +95,20 @@ PreparedOp PrepareOpImpl(const NameVarMap& ins, auto& kernels = kernels_iter->second; framework::RuntimeContext ctx({}, {}); -#ifdef PADDLE_WITH_MKLDNN - // MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and - // GetKernelType functions, so we need to copy the attributes there. - // Const qualifier of Attrs had to be discarded to overwrite it. - if (FLAGS_use_mkldnn) { - auto& mutable_op_attrs = const_cast(op.Attrs()); - mutable_op_attrs = attrs; - } -#endif - auto expected_kernel_key = - op.GetExpectedKernelType(DygraphExecutionContext( - op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs)); - VLOG(3) << "expected_kernel_key:" << expected_kernel_key; - auto kernel_iter = kernels.find(expected_kernel_key); -#ifdef PADDLE_WITH_XPU - if (kernel_iter == kernels.end() && - is_xpu_place(expected_kernel_key.place_)) { - expected_kernel_key.place_ = platform::CPUPlace(); - kernel_iter = kernels.find(expected_kernel_key); - } -#endif // TODO(jiabin): Add operator.cc's line 1000 part back when we need that case PADDLE_ENFORCE_NE(kernel_iter, kernels.end(), platform::errors::NotFound( "Operator %s does not have kernel for %s.", op.Type(), KernelTypeToString(expected_kernel_key))); - if (!(expected_kernel_key.place_ == place)) { - dev_ctx = pool.Get(expected_kernel_key.place_); - place = dev_ctx->GetPlace(); - } - - PrepareData(place, ins, op, expected_kernel_key); - return PreparedOp(op, ctx, kernel_iter->second, dev_ctx); -} - -PreparedOp PreparedOp::Prepare(const NameVarMap& ins, - const NameVarMap& outs, - const framework::OperatorWithKernel& op, - const platform::Place& place, - const framework::AttributeMap& attrs) { - return PrepareOpImpl(ins, outs, op, place, attrs); -} - -PreparedOp PreparedOp::Prepare(const NameVarMap& ins, - const NameVarMap& outs, - const framework::OperatorWithKernel& op, - const platform::Place& place, - const framework::AttributeMap& attrs) { - return PrepareOpImpl(ins, outs, op, place, attrs); + return PreparedOp(op, ctx, expected_kernel_key, kernel_iter->second, dev_ctx); } template static void PreparedOpRunImpl( const framework::OperatorBase& op, const framework::RuntimeContext& ctx, + const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, platform::DeviceContext* dev_ctx, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs) { @@ -158,19 +122,36 @@ static void PreparedOpRunImpl( func(DygraphExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs)); + + /** + * [ Why need handle complex gradient to real gradient? ] + * + * After the introduction of complex number calculations, Ops that support + * complex number calculations generally support type promotion, such as + * x(float32) + y(complex64) = out(complex64), then the type of the grad + * tensor should be dout(complex64), dx(float32), dy (complex64). + * + * But because the dout is complex64, the dx is also complex64 after + * grad op kernel executed, we need to recognize this situation and + * convert dx to float32 type. HandleComplexGradToRealGrad does this thing. + */ + if (framework::IsComplexType(kernel_type.data_type_)) { + HandleComplexGradToRealGrad(outs); + } } void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs) { - PreparedOpRunImpl(op_, ctx_, func_, dev_ctx_, ins, outs, attrs); + PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, + outs, attrs); } void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs) { - PreparedOpRunImpl(op_, ctx_, func_, dev_ctx_, ins, outs, - attrs); + PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, + ins, outs, attrs); } } // namespace imperative diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 3bf032e642b..7952c453ee8 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -21,9 +21,12 @@ #include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/imperative/execution_context.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/type_defs.h" +DECLARE_bool(use_mkldnn); + namespace paddle { namespace framework { class Tensor; @@ -39,24 +42,120 @@ namespace imperative { const framework::Tensor* GetTensorFromVar(const framework::Variable& var); +template +static void SetForwardDataTypeOfGradVar(const std::shared_ptr& var); + +template <> +void SetForwardDataTypeOfGradVar( + const std::shared_ptr& var) { + if (var->HasGradVar()) { + auto grad_var = var->GetGradVar(); + VLOG(6) << "Set grad var (" << grad_var->Name() << ") dtype to (" + << framework::DataTypeToString(var->DataType()) << ")."; + grad_var->SetForwardDataType(var->DataType()); + } +} + +template <> +void SetForwardDataTypeOfGradVar(const std::shared_ptr& var) { + if (var->HasGradVar()) { + auto& shared_var = var->SharedVar(); + SetForwardDataTypeOfGradVar(shared_var); + } +} + +#ifdef PADDLE_WITH_XPU +static void ReplaceXPUKernelIfNotExists( + const framework::OperatorWithKernel& op, + framework::OpKernelType* expected_kernel_key) { + auto& all_op_kernels = op.AllOpKernels(); + auto kernels_iter = all_op_kernels.find(op.Type()); + PADDLE_ENFORCE_NE( + kernels_iter, all_op_kernels.end(), + platform::errors::NotFound( + "There are no kernels which are registered in the %s operator.", + op.Type())); + + auto& kernels = kernels_iter->second; + auto kernel_iter = kernels.find(*expected_kernel_key); + if (kernel_iter == kernels.end() && + is_xpu_place(expected_kernel_key->place_)) { + expected_kernel_key->place_ = platform::CPUPlace(); + } +} +#endif + +template +framework::OpKernelType GetExpectedKernelKey( + const NameVarMap& ins, const NameVarMap& outs, + const framework::OperatorWithKernel& op, const platform::Place& place, + const framework::AttributeMap& attrs) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(place); + framework::RuntimeContext ctx({}, {}); + +#ifdef PADDLE_WITH_MKLDNN + // MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and + // GetKernelType functions, so we need to copy the attributes there. + // Const qualifier of Attrs had to be discarded to overwrite it. + if (FLAGS_use_mkldnn) { + auto& mutable_op_attrs = const_cast(op.Attrs()); + mutable_op_attrs = attrs; + } +#endif + + auto expected_kernel_key = + op.GetExpectedKernelType(DygraphExecutionContext( + op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs)); +#ifdef PADDLE_WITH_XPU + ReplaceXPUKernelIfNotExists(op, &expected_kernel_key); +#endif + VLOG(3) << "expected_kernel_key:" << expected_kernel_key; + + return expected_kernel_key; +} + +template +NameVarMap PrepareData( + const framework::OperatorWithKernel& op, const NameVarMap& ins, + const framework::OpKernelType& expected_kernel_key) { + NameVarMap tmp_ins(ins); + for (auto& name_pair : tmp_ins) { + for (auto& var_base : name_pair.second) { + const auto* tensor = GetTensorFromVar(var_base->Var()); + SetForwardDataTypeOfGradVar(var_base); + if (tensor && tensor->IsInitialized()) { + auto kernel_type_for_var = op.GetKernelTypeForVar( + name_pair.first, *tensor, expected_kernel_key); + if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) { + continue; + } else { + VLOG(3) << "Transform Variable " << var_base->Name() << " from " + << kernel_type_for_var << " to " << expected_kernel_key; + framework::Tensor out; + auto tmp_var = std::make_shared(var_base->Name()); + tmp_var->SetType(var_base->Type()); + TransformData(expected_kernel_key, kernel_type_for_var, *tensor, + &out); + SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar()); + var_base = tmp_var; + } + } + } + } + return tmp_ins; +} + class PreparedOp { public: PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, + const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, platform::DeviceContext* dev_ctx); - static PreparedOp Prepare(const NameVarMap& ins, - const NameVarMap& outs, - const framework::OperatorWithKernel& op, - const platform::Place& place, - const framework::AttributeMap& attrs); - - static PreparedOp Prepare(const NameVarMap& ins, - const NameVarMap& outs, - const framework::OperatorWithKernel& op, - const platform::Place& place, - const framework::AttributeMap& attrs); + static PreparedOp Prepare(const framework::OperatorWithKernel& op, + const framework::OpKernelType& expected_kernel_key); void Run(const NameVarMap& in, const NameVarMap& out, const framework::AttributeMap& attrs); @@ -68,6 +167,7 @@ class PreparedOp { private: const framework::OperatorBase& op_; const framework::RuntimeContext& ctx_; + framework::OpKernelType kernel_type_; framework::OperatorWithKernel::OpKernelFunc func_; platform::DeviceContext* dev_ctx_; }; diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc index f226c63f0c4..b9ad5306f03 100644 --- a/paddle/fluid/imperative/tests/test_prepare_op.cc +++ b/paddle/fluid/imperative/tests/test_prepare_op.cc @@ -32,27 +32,6 @@ namespace framework = paddle::framework; namespace paddle { namespace imperative { -static framework::RuntimeContext PrepareRuntimeContext( - const NameVarBaseMap& ins, const NameVarBaseMap& outs) { - framework::VariableValueMap inputs, outputs; - for (auto& in_pair : ins) { - auto& in_ctx = inputs[in_pair.first]; - in_ctx.reserve(in_pair.second.size()); - for (auto& in_var : in_pair.second) { - in_ctx.emplace_back(in_var->MutableVar()); - } - } - - for (auto& out_pair : outs) { - auto& out_ctx = outputs[out_pair.first]; - out_ctx.reserve(out_pair.second.size()); - for (auto& out_var : out_pair.second) { - out_ctx.emplace_back(out_var->MutableVar()); - } - } - return framework::RuntimeContext(std::move(inputs), std::move(outputs)); -} - static framework::VariableNameMap CreateVarNameMap( const framework::OpInfo& op_info, const std::string& op_type, const NameVarBaseMap& varbase_map, bool is_input) { @@ -111,11 +90,12 @@ TEST(test_prepare_op, test_prepare_op) { CreateVarNameMap(info, "split", outs, false); auto op = framework::OpRegistry::CreateOp("split", var_in_map, var_out_map, split_attr_map); - framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs); + auto expected_kernel_key = GetExpectedKernelKey( + ins, outs, dynamic_cast(*op), place, + split_attr_map); ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare( - ins, outs, dynamic_cast(*op), - place, split_attr_map)); + expected_kernel_key)); } const framework::Tensor* GetTensorFromVar(const framework::Variable& var); @@ -161,13 +141,15 @@ TEST(test_prepare_op, test_prepare_data) { CreateVarNameMap(info, op_type, outs, false); auto op = framework::OpRegistry::CreateOp(op_type, var_in_map, var_out_map, attr_map); - framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs); // test if it can be transformed to GPU place - PreparedOp prepared_op = PreparedOp::Prepare( + auto expected_kernel_key = GetExpectedKernelKey( ins, outs, dynamic_cast(*op), gpu_place, attr_map); - for (const auto& name_pair : ins) { + imperative::NameVarBaseMap tmp_ins = PrepareData( + dynamic_cast(*op), ins, + expected_kernel_key); + for (const auto& name_pair : tmp_ins) { for (const auto& vb : name_pair.second) { ASSERT_TRUE(platform::is_same_place( vb->Var().Get().place(), gpu_place)); @@ -208,13 +190,15 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) { auto op = framework::OpRegistry::CreateOp(op_type, var_in_map, var_out_map, attr_map); - framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs); // test if it never transferred on GPU place - PreparedOp prepared_op = PreparedOp::Prepare( + auto expected_kernel_key = GetExpectedKernelKey( ins, outs, dynamic_cast(*op), cpu_place, attr_map); - for (const auto& name_pair : ins) { + imperative::NameVarBaseMap tmp_ins = PrepareData( + dynamic_cast(*op), ins, + expected_kernel_key); + for (const auto& name_pair : tmp_ins) { for (const auto& vb : name_pair.second) { ASSERT_TRUE(platform::is_same_place( vb->Var().Get().place(), cpu_place)); diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h index d8373042078..ca9d5bc3ad7 100644 --- a/paddle/fluid/imperative/variable_wrapper.h +++ b/paddle/fluid/imperative/variable_wrapper.h @@ -122,10 +122,6 @@ class VariableWrapper { framework::proto::VarType::Type Type() const { return type_; } - void SetDataType(framework::proto::VarType::Type data_type) { - data_type_ = data_type; - } - std::shared_ptr GetGradVar() const { return grad_var_.lock(); } @@ -140,6 +136,10 @@ class VariableWrapper { bool HasGradVar() const { return !grad_var_.expired(); } + void SetDataType(framework::proto::VarType::Type data_type) { + data_type_ = data_type; + } + framework::proto::VarType::Type DataType() const { const framework::Tensor* tensor = nullptr; if (var_.IsInitialized()) { @@ -160,6 +160,14 @@ class VariableWrapper { } } + void SetForwardDataType(framework::proto::VarType::Type data_type) { + fwd_data_type_ = data_type; + } + + framework::proto::VarType::Type ForwardDataType() const { + return fwd_data_type_; + } + const platform::Place Place() const { const framework::Tensor* tensor = nullptr; auto place = @@ -306,6 +314,13 @@ class VariableWrapper { framework::proto::VarType::Type type_{framework::proto::VarType::LOD_TENSOR}; framework::proto::VarType::Type data_type_{framework::proto::VarType::FP32}; + // See [ Why need handle complex gradient to real gradient? ] + // Used for grad var to get the data type of its corresponding forward var, + // if inconsistent, the data type of grad var needs to be casted to be + // consistent with forward var + framework::proto::VarType::Type fwd_data_type_{ + static_cast(-1)}; + std::weak_ptr grad_var_; std::weak_ptr grad_node_; diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index 8a04f74c6de..10e8bb1f4a7 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -269,7 +269,7 @@ void Copy( if (UNLIKELY(num == 0)) return; VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to " - << dst_place << " by thream(" << stream << ")"; + << dst_place << " by stream(" << stream << ")"; if (dst_place == src_place) { platform::SetDeviceId(src_place.device); if (stream) { diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index df5a2115c3b..96583d06571 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -75,6 +75,45 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y, } } +template <> +__global__ void SimpleElemwiseDivGradCUDAKernel( + const paddle::platform::complex64* x, const paddle::platform::complex64* y, + const paddle::platform::complex64* out, + const paddle::platform::complex64* dout, int64_t size, + paddle::platform::complex64* dx, paddle::platform::complex64* dy) { + int col = blockIdx.x * blockDim.x + threadIdx.x; + + while (col < size) { + paddle::platform::complex64 o = dout[col]; + paddle::platform::complex64 y_conj(y[col].real, -y[col].imag); + paddle::platform::complex64 out_div_y_conj((out[col] / y[col]).real, + -(out[col] / y[col]).imag); + dx[col] = o / y_conj; + dy[col] = -o * out_div_y_conj; + col += blockDim.x * gridDim.x; + } +} + +template <> +__global__ void SimpleElemwiseDivGradCUDAKernel( + const paddle::platform::complex128* x, + const paddle::platform::complex128* y, + const paddle::platform::complex128* out, + const paddle::platform::complex128* dout, int64_t size, + paddle::platform::complex128* dx, paddle::platform::complex128* dy) { + int col = blockIdx.x * blockDim.x + threadIdx.x; + + while (col < size) { + paddle::platform::complex128 o = dout[col]; + paddle::platform::complex128 y_conj(y[col].real, -y[col].imag); + paddle::platform::complex128 out_div_y_conj((out[col] / y[col]).real, + -(out[col] / y[col]).imag); + dx[col] = o / y_conj; + dy[col] = -o * out_div_y_conj; + col += blockDim.x * gridDim.x; + } +} + template typename std::enable_if< std::is_same::value>::type diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 1d016fba34b..d824014713d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -73,6 +73,27 @@ struct DivGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; } }; +template <> +struct DivGradDX { + HOSTDEVICE paddle::platform::complex64 operator()( + paddle::platform::complex64 x, paddle::platform::complex64 y, + paddle::platform::complex64 out, paddle::platform::complex64 dout) const { + paddle::platform::complex64 y_conj(y.real, -y.imag); + return dout / y_conj; + } +}; + +template <> +struct DivGradDX { + HOSTDEVICE paddle::platform::complex128 operator()( + paddle::platform::complex128 x, paddle::platform::complex128 y, + paddle::platform::complex128 out, + paddle::platform::complex128 dout) const { + paddle::platform::complex128 y_conj(y.real, -y.imag); + return dout / y_conj; + } +}; + template struct DivGradDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { @@ -80,6 +101,28 @@ struct DivGradDY { } }; +template <> +struct DivGradDY { + HOSTDEVICE paddle::platform::complex64 operator()( + paddle::platform::complex64 x, paddle::platform::complex64 y, + paddle::platform::complex64 out, paddle::platform::complex64 dout) const { + paddle::platform::complex64 out_div_y_conj((out / y).real, -(out / y).imag); + return -dout * out_div_y_conj; + } +}; + +template <> +struct DivGradDY { + HOSTDEVICE paddle::platform::complex128 operator()( + paddle::platform::complex128 x, paddle::platform::complex128 y, + paddle::platform::complex128 out, + paddle::platform::complex128 dout) const { + paddle::platform::complex128 out_div_y_conj((out / y).real, + -(out / y).imag); + return -dout * out_div_y_conj; + } +}; + template struct DivDoubleDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index b3b4b054490..5b598ab2d78 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -75,6 +75,36 @@ static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y, } } +template <> +__global__ void SimpleElemwiseMulGradCUDAKernel( + const plat::complex64* x, const plat::complex64* y, + const plat::complex64* out, const plat::complex64* dout, int64_t size, + plat::complex64* dx, plat::complex64* dy) { + int col = blockIdx.x * blockDim.x + threadIdx.x; + + while (col < size) { + plat::complex64 o = dout[col]; + dx[col] = plat::complex64(y[col].real, -y[col].imag) * o; + dy[col] = plat::complex64(x[col].real, -x[col].imag) * o; + col += blockDim.x * gridDim.x; + } +} + +template <> +__global__ void SimpleElemwiseMulGradCUDAKernel( + const plat::complex128* x, const plat::complex128* y, + const plat::complex128* out, const plat::complex128* dout, int64_t size, + plat::complex128* dx, plat::complex128* dy) { + int col = blockIdx.x * blockDim.x + threadIdx.x; + + while (col < size) { + plat::complex128 o = dout[col]; + dx[col] = plat::complex128(y[col].real, -y[col].imag) * o; + dy[col] = plat::complex128(x[col].real, -x[col].imag) * o; + col += blockDim.x * gridDim.x; + } +} + template typename std::enable_if< std::is_same::value>::type diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index a5bd7221c75..66a9e6dd0fc 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -132,11 +132,53 @@ struct MulGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } }; +template <> +struct MulGradDX { + HOSTDEVICE paddle::platform::complex64 operator()( + paddle::platform::complex64 x, paddle::platform::complex64 y, + paddle::platform::complex64 out, paddle::platform::complex64 dout) const { + paddle::platform::complex64 y_conj(y.real, -y.imag); + return dout * y_conj; + } +}; + +template <> +struct MulGradDX { + HOSTDEVICE paddle::platform::complex128 operator()( + paddle::platform::complex128 x, paddle::platform::complex128 y, + paddle::platform::complex128 out, + paddle::platform::complex128 dout) const { + paddle::platform::complex128 y_conj(y.real, -y.imag); + return dout * y_conj; + } +}; + template struct MulGradDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; } }; +template <> +struct MulGradDY { + HOSTDEVICE paddle::platform::complex64 operator()( + paddle::platform::complex64 x, paddle::platform::complex64 y, + paddle::platform::complex64 out, paddle::platform::complex64 dout) const { + paddle::platform::complex64 x_conj(x.real, -x.imag); + return dout * x_conj; + } +}; + +template <> +struct MulGradDY { + HOSTDEVICE paddle::platform::complex128 operator()( + paddle::platform::complex128 x, paddle::platform::complex128 y, + paddle::platform::complex128 out, + paddle::platform::complex128 dout) const { + paddle::platform::complex128 x_conj(x.real, -x.imag); + return dout * x_conj; + } +}; + template typename std::enable_if< std::is_same::value>::type diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index abafedf2057..d799abf92d9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -352,7 +352,8 @@ class ElementwiseOpDoubleGradWithoutDXDY "ElementwiseOpDoubleGradWithoutDXDY"); input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); } else { - input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); + input_data_type = + OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY"); } #ifdef PADDLE_WITH_MKLDNN @@ -364,6 +365,19 @@ class ElementwiseOpDoubleGradWithoutDXDY #endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const { + if (framework::IsComplexType(expected_kernel_type.data_type_)) { + // only promote inputs’s types when contains complex input + return framework::OpKernelType(tensor.type(), tensor.place(), + tensor.layout()); + } else { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } + } }; template diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 53e07d2ba4e..c44c15adb13 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -275,6 +275,15 @@ struct CUBlas { reinterpret_cast(C), ldc)); } + static void AXPY(cublasHandle_t handle, int n, const complex64 *alpha, + const complex64 *X, const int incX, complex64 *Y, + const int incY) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCaxpy( + handle, n, reinterpret_cast(alpha), + reinterpret_cast(X), incX, + reinterpret_cast(Y), incY)); + } + static void GEMM_STRIDED_BATCH(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, @@ -362,6 +371,15 @@ struct CUBlas { reinterpret_cast(C), ldc)); } + static void AXPY(cublasHandle_t handle, int n, const complex128 *alpha, + const complex128 *X, const int incX, complex128 *Y, + const int incY) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZaxpy( + handle, n, reinterpret_cast(alpha), + reinterpret_cast(X), incX, + reinterpret_cast(Y), incY)); + } + static void GEMM_STRIDED_BATCH(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 32aced7619c..5ccdeabf96b 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -295,6 +295,13 @@ struct CBlas { template <> struct CBlas { + template + static void AXPY(int n, const paddle::platform::complex64 alpha, + const paddle::platform::complex64 *X, const int incX, + paddle::platform::complex64 *Y, const int incY) { + platform::dynload::cblas_caxpy(n, &alpha, X, incX, Y, incY); + } + template static void VCOPY(ARGS... args) { platform::dynload::cblas_ccopy(args...); @@ -415,6 +422,13 @@ struct CBlas { template <> struct CBlas { + template + static void AXPY(int n, const paddle::platform::complex128 alpha, + const paddle::platform::complex128 *X, const int incX, + paddle::platform::complex128 *Y, const int incY) { + platform::dynload::cblas_zaxpy(n, &alpha, X, incX, Y, incY); + } + template static void VCOPY(ARGS... args) { platform::dynload::cblas_zcopy(args...); @@ -598,11 +612,6 @@ struct CBlas { cblas_ccopy(args...); } - template - static void VADD(ARGS... args) { - vcAdd(args...); - } - template static void AXPY(int n, const paddle::platform::complex64 alpha, const paddle::platform::complex64 *X, const int incX, @@ -641,11 +650,6 @@ struct CBlas { cblas_zcopy(args...); } - template - static void VADD(ARGS... args) { - vzAdd(args...); - } - template static void AXPY(int n, const paddle::platform::complex128 alpha, const paddle::platform::complex128 *X, const int incX, diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index c2595beb0cb..21b60119dca 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -18,6 +18,8 @@ limitations under the License. */ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/complex128.h" +#include "paddle/fluid/platform/complex64.h" namespace paddle { namespace operators { @@ -548,6 +550,10 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; template struct MergeAverage; template struct MergeAverage; diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index 35bd02ad35b..26e9a0de606 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -448,6 +448,8 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; template __global__ void UpdateToTensorKernel(const T* selected_rows, diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu index 0d689d710a1..f2bee6dddc3 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu @@ -23,4 +23,6 @@ using CUDAReduceSumGradKernel = REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel); + CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel); diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h index 4d9673e9646..72430a3f753 100644 --- a/paddle/fluid/platform/cuda_primitives.h +++ b/paddle/fluid/platform/cuda_primitives.h @@ -15,6 +15,8 @@ limitations under the License. */ #pragma once #include #include +#include "paddle/fluid/platform/complex128.h" +#include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -126,9 +128,22 @@ CUDA_ATOMIC_WRAPPER(Add, float16) { return ret; } } - #endif +CUDA_ATOMIC_WRAPPER(Add, complex64) { + float *real = reinterpret_cast(address); + float *imag = real + 1; + return complex64(CudaAtomicAdd(real, val.real), + CudaAtomicAdd(imag, val.imag)); +} + +CUDA_ATOMIC_WRAPPER(Add, complex128) { + double *real = reinterpret_cast(address); + double *imag = real + 1; + return complex128(CudaAtomicAdd(real, val.real), + CudaAtomicAdd(imag, val.imag)); +} + // For atomicMax USE_CUDA_ATOMIC(Max, int); USE_CUDA_ATOMIC(Max, unsigned int); diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index 66032075f29..96e16894c78 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -55,6 +55,8 @@ extern void *cublas_dso_handle; #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasSaxpy_v2); \ __macro(cublasDaxpy_v2); \ + __macro(cublasCaxpy_v2); \ + __macro(cublasZaxpy_v2); \ __macro(cublasSscal_v2); \ __macro(cublasDscal_v2); \ __macro(cublasScopy_v2); \ diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index bd38bae42e0..e3e84a73301 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -855,7 +855,7 @@ class OpTest(unittest.TestCase): place, no_check_set=None, inplace_atol=None): - """Chech the inplace correctness of given op (self.op_type). + """Check the inplace correctness of given op (self.op_type). Run the op twice with same inputs, one enable inplace and another disable, compare their outputs. Args: @@ -935,7 +935,7 @@ class OpTest(unittest.TestCase): fwd_res, grad_op_desc, inplace_atol=None): - """Chech the inplace correctness of given grad_op_desc. + """Check the inplace correctness of given grad_op_desc. Run the grad op twice with same inputs, one enable inplace and another disable, compare their outputs. It works like _check_forward_inplace, but the way to construct program and feed_map differs. @@ -1291,7 +1291,6 @@ class OpTest(unittest.TestCase): def _assert_is_close(self, numeric_grads, analytic_grads, names, max_relative_error, msg_prefix): - for a, b, name in six.moves.zip(numeric_grads, analytic_grads, names): # It asserts np.abs(a - b) / np.abs(a) < max_relative_error, in which # max_relative_error is 1e-7. According to the value of np.abs(a), we @@ -1544,6 +1543,10 @@ class OpTest(unittest.TestCase): grad_outputs = [] for grad_out_value in user_defined_grad_outputs: grad_outputs.append(paddle.to_tensor(grad_out_value)) + # delete the inputs which no need to calculate grad + for no_grad_val in no_grad_set: + del (inputs[no_grad_val]) + grad_inputs = paddle.grad( outputs=fluid.layers.utils.flatten(outputs), inputs=fluid.layers.utils.flatten(inputs), @@ -1612,7 +1615,7 @@ class OpTest(unittest.TestCase): targets = [ outputs[name] for name in outputs if name in output_names ] - inputs = [inputs[name] for name in inputs if name in input_to_check] + inputs = [inputs[name] for name in input_to_check if name in inputs] grad_inputs = paddle.static.gradients(targets, inputs, grad_outputs, no_grad_set) fetch_list = grad_inputs diff --git a/python/paddle/fluid/tests/unittests/test_complex_grad_accumulated.py b/python/paddle/fluid/tests/unittests/test_complex_grad_accumulated.py new file mode 100644 index 00000000000..106b9fe15a3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_complex_grad_accumulated.py @@ -0,0 +1,101 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle + +import paddle.fluid.core as core + + +class Optimization_ex1(paddle.nn.Layer): + def __init__(self, + shape, + dtype, + param_attr=paddle.nn.initializer.Uniform( + low=-5., high=5.)): + super(Optimization_ex1, self).__init__() + + self.theta0 = self.create_parameter( + shape=shape, attr=param_attr, dtype=dtype, is_bias=False) + self.theta1 = self.create_parameter( + shape=shape, attr=param_attr, dtype=dtype, is_bias=False) + self.A = paddle.to_tensor( + np.random.random((4, 4)).astype(dtype) + np.random.random((4, 4)) + .astype(dtype) * 1j) + self.B = paddle.to_tensor( + np.random.random((4, 4)).astype(dtype) + np.random.random( + (4, 4)).astype(dtype) * 1j, + stop_gradient=False) + print(self.A) + + def forward(self, mode=1): + jj = paddle.to_tensor(np.array([1j]).astype(np.complex64)) + if mode == 1: + # run all calc in one step + loss = paddle.sum(self.A + (self.theta0 + self.theta1 * jj)) * ( + paddle.sum(self.A + (self.theta0 + self.theta1 * jj)).conj()) + return loss.real() + elif mode == 2: + # run in two step + self.theta = self.theta0 + self.theta1 * jj + loss = paddle.sum(self.A + self.theta) * ( + paddle.sum(self.A + self.theta).conj()) + return loss.real() + elif mode == 3: + # run without param + loss = paddle.sum(self.A + self.B) * ( + paddle.sum(self.A + self.B).conj()) + return loss.real() + else: + raise NotImplementedError + + +class TestComplexGradAccumulated(unittest.TestCase): + def setUp(self): + self.devices = ['cpu'] + if core.is_compiled_with_cuda(): + self.devices.append('gpu') + self.dtypes = ['float32', 'float64'] + self.theta_size = [4, 4] + + def run_backward(self, device, dtype, mode): + paddle.set_device(device) + + myLayer = Optimization_ex1(self.theta_size, dtype) + + loss = myLayer(mode) + loss.backward() + + def test_case_one_step(self): + for dev in self.devices: + for dtype in self.dtypes: + self.run_backward(dev, dtype, 1) + + def test_case_two_step(self): + for dev in self.devices: + for dtype in self.dtypes: + self.run_backward(dev, dtype, 2) + + def test_case_non_param(self): + for dev in self.devices: + for dtype in self.dtypes: + self.run_backward(dev, dtype, 3) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_complex_simplenet.py b/python/paddle/fluid/tests/unittests/test_complex_simplenet.py new file mode 100644 index 00000000000..4016f810624 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_complex_simplenet.py @@ -0,0 +1,72 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle + +import paddle.fluid.core as core + + +class Optimization_ex1(paddle.nn.Layer): + def __init__(self, + shape, + param_attr=paddle.nn.initializer.Uniform( + low=-5., high=5.), + dtype='float32'): + super(Optimization_ex1, self).__init__() + + self.theta = self.create_parameter( + shape=shape, attr=param_attr, dtype=dtype, is_bias=False) + self.A = paddle.to_tensor( + np.random.randn(4, 4) + np.random.randn(4, 4) * 1j) + + def forward(self): + loss = paddle.add(self.theta, self.A) + return loss.real() + + +class TestComplexSimpleNet(unittest.TestCase): + def setUp(self): + self.devices = ['cpu'] + if core.is_compiled_with_cuda(): + self.devices.append('gpu') + self.iter = 10 + self.learning_rate = 0.5 + self.theta_size = [4, 4] + + def train(self, device): + paddle.set_device(device) + + myLayer = Optimization_ex1(self.theta_size) + optimizer = paddle.optimizer.Adam( + learning_rate=self.learning_rate, parameters=myLayer.parameters()) + + for itr in range(self.iter): + loss = myLayer() + loss.backward() + + optimizer.step() + optimizer.clear_grad() + + def test_train_success(self): + for dev in self.devices: + self.train(dev) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index 717ffb76536..318ef9fd39a 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -431,7 +431,8 @@ class TestAddOp(unittest.TestCase): class TestComplexElementwiseAddOp(OpTest): def setUp(self): self.op_type = "elementwise_add" - self.init_base_dtype() + self.dtype = np.float64 + self.shape = (2, 3, 4, 5) self.init_input_output() self.init_grad_input_output() @@ -446,17 +447,15 @@ class TestComplexElementwiseAddOp(OpTest): self.dtype = np.float64 def init_input_output(self): - self.x = np.random.random( - (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random( - (2, 3, 4, 5)).astype(self.dtype) - self.y = np.random.random( - (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random( - (2, 3, 4, 5)).astype(self.dtype) + self.x = np.random.random(self.shape).astype( + self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype) + self.y = np.random.random(self.shape).astype( + self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype) self.out = self.x + self.y def init_grad_input_output(self): - self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones( - (2, 3, 4, 5), self.dtype) + self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones( + self.shape, self.dtype) self.grad_x = self.grad_out self.grad_y = self.grad_out @@ -487,5 +486,19 @@ class TestComplexElementwiseAddOp(OpTest): user_defined_grad_outputs=[self.grad_out]) +class TestRealComplexElementwiseAddOp(TestComplexElementwiseAddOp): + def init_input_output(self): + self.x = np.random.random(self.shape).astype(self.dtype) + self.y = np.random.random(self.shape).astype( + self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype) + self.out = self.x + self.y + + def init_grad_input_output(self): + self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones( + self.shape, self.dtype) + self.grad_x = np.real(self.grad_out) + self.grad_y = self.grad_out + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py index 3cfbac8b613..f93802c47c9 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py @@ -261,5 +261,65 @@ class TestDivideOp(unittest.TestCase): self.assertEqual((np_z == z_expected).all(), True) +class TestComplexElementwiseDivOp(OpTest): + def setUp(self): + self.op_type = "elementwise_div" + self.init_base_dtype() + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + self.y = np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + self.out = self.x / self.y + + def init_grad_input_output(self): + self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones( + (2, 3, 4, 5), self.dtype) + self.grad_x = self.grad_out / np.conj(self.y) + self.grad_y = -self.grad_out * np.conj(self.x / self.y / self.y) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py index fd2fe73ad51..f69fa7084ed 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py @@ -13,13 +13,17 @@ # limitations under the License. from __future__ import print_function + import unittest + import numpy as np -from op_test import OpTest, skip_check_grad_ci +import paddle +import paddle.fluid as fluid import paddle.fluid.core as core +from paddle.fluid import Program, compiler, program_guard from paddle.fluid.op import Operator -import paddle.fluid as fluid -from paddle.fluid import compiler, Program, program_guard + +from op_test import OpTest, skip_check_grad_ci class ElementwiseMulOp(OpTest): @@ -241,5 +245,65 @@ class TestElementwiseMulOpError(unittest.TestCase): self.assertRaises(TypeError, fluid.layers.elementwise_mul, x2, y2) +class TestComplexElementwiseMulOp(OpTest): + def setUp(self): + self.op_type = "elementwise_mul" + self.init_base_dtype() + self.init_input_output() + self.init_grad_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': -1, 'use_mkldnn': False} + self.outputs = {'Out': self.out} + + def init_base_dtype(self): + self.dtype = np.float64 + + def init_input_output(self): + self.x = np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + self.y = np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random( + (2, 3, 4, 5)).astype(self.dtype) + self.out = self.x * self.y + + def init_grad_input_output(self): + self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones( + (2, 3, 4, 5), self.dtype) + self.grad_x = self.grad_out * np.conj(self.y) + self.grad_y = self.grad_out * np.conj(self.x) + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[self.grad_x, self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + user_defined_grads=[self.grad_y], + user_defined_grad_outputs=[self.grad_out]) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_out]) + + if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py index 71550c8f247..8b2cf56c886 100644 --- a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -519,7 +519,7 @@ class TestStridedSliceAPI(unittest.TestCase): np.random.randn(2, 10), place=paddle.CUDAPinnedPlace()) self.assertTrue(x.place.is_cuda_pinned_place()) y = x[:, ::2] - self.assertFalse(x.place.is_cuda_pinned_place()) + self.assertTrue(x.place.is_cuda_pinned_place()) self.assertFalse(y.place.is_cuda_pinned_place()) -- GitLab