Unverified · Commit a6072055 authored by Chen Weihang, committed by GitHub

[Complex] Handle complex to real after type promotion (#29855)

* try to add fwd op input dtypes

* refactor base impl

* return tmp_ins after dygraph prepare data

* fix typo found in debug

* polish comment & add complex net test

* revert detail change

* fix unittest failed

* add complex kernel condition control

* fix xpu test failed & polish comment

* polish details by review comments
Parent 1a304e6c
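The user-visible behavior this commit enables, shown as a minimal dygraph sketch (not part of the diff; it mirrors the new complex net test added below and assumes the Paddle 2.x imperative API): a float32 parameter is promoted to a complex dtype in the forward pass, and after backward its gradient comes back in the real forward dtype instead of staying complex.

import numpy as np
import paddle

paddle.set_device('cpu')
theta = paddle.to_tensor(
    np.random.randn(4, 4).astype('float32'), stop_gradient=False)
A = paddle.to_tensor(np.random.randn(4, 4) + np.random.randn(4, 4) * 1j)

# float32 + complex128 promotes the forward output to complex128
loss = paddle.add(theta, A).real()
loss.backward()

# dout is complex, but theta's gradient is cast back to its forward dtype
print(theta.gradient().dtype)  # float32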
@@ -109,5 +109,30 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
  }
}
void TransComplexToReal(const proto::VarType::Type& dst_type,
const proto::VarType::Type& src_type, const Tensor& in,
Tensor* out) {
auto& pool = platform::DeviceContextPool::Instance();
auto* ctx = pool.Get(in.place());
out->Resize(in.dims());
// complex -> real
switch (src_type) {
case proto::VarType::COMPLEX64:
framework::VisitDataType(dst_type,
CastDataType<platform::complex64>(in, out, ctx));
break;
case proto::VarType::COMPLEX128:
framework::VisitDataType(
dst_type, CastDataType<platform::complex128>(in, out, ctx));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when casting complex tensor to real "
"data type.",
DataTypeToString(src_type)));
}
}
}  // namespace framework
}  // namespace paddle
@@ -33,5 +33,20 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
                   const OpKernelType& expected_kernel_type, const Tensor& in,
                   Tensor* out);
/**
* Transform a complex gradient to a real data type.
*
* If complex type promotion occurred in the forward op, the grad output of
* this op has a complex data type, but the corresponding forward input
* variable may be a real type. In that case the grad input needs to be cast
* to the same type as the forward input, and this cast is executed at the
* end of the grad op.
*
* Note: callers of this function must ensure that dst_type is real and
* src_type is complex.
*/
void TransComplexToReal(const proto::VarType::Type& dst_type,
const proto::VarType::Type& src_type, const Tensor& in,
Tensor* out);
}  // namespace framework
}  // namespace paddle
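At the data level, the cast that TransComplexToReal performs keeps only the real part of the complex gradient. A small numpy sketch of the expected effect (an illustration consistent with grad_x = np.real(grad_out) in the new elementwise_add test at the end of this diff, not Paddle code):

import numpy as np

dout = np.array([1.0 + 2.0j, 3.0 - 4.0j], dtype=np.complex64)
dx = dout.real.astype(np.float32)  # numpy analogue of the complex64 -> float32 cast
print(dx)  # [1. 3.]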
@@ -24,6 +24,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -1110,6 +1111,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1110,6 +1111,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// there is inplace variable has been transferred. // there is inplace variable has been transferred.
TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope); TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
} }
// See [ Why need handle complex gradient to real gradient? ]
// Only handle the case where the current kernel data type is complex
if (framework::IsComplexType(kernel_type_->data_type_)) {
HandleComplexGradToRealGrad(scope, runtime_ctx);
}
  if (FLAGS_enable_unused_var_check) {
    // skip op that uses mkldnn because it has different memory reuse strategy.
    // use attr here because some GradMakers (like ActivationGradOpMaker) add
@@ -1255,6 +1263,73 @@ void OperatorWithKernel::TransferInplaceVarsBack(
  }
}
void OperatorWithKernel::HandleComplexGradToRealGrad(
const Scope& scope, RuntimeContext* ctx) const {
for (auto& var_name_item : Outputs()) {
std::vector<Variable*>& output_vars = ctx->outputs[var_name_item.first];
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
// 1. find grad_var & check whether is complex tensor
auto var_name = var_name_item.second[i];
auto orig_var_name = GradOriginalVarName(var_name);
// only focus on gradient var
if (var_name == orig_var_name) {
continue;
}
auto* grad_var = output_vars[i];
// skip nullptr var
if (grad_var == nullptr) {
continue;
}
// don't process LoDTensorArray temporarily,
// add support if necessary for complex number calculations in the future
if (!VarIsTensor(*grad_var)) {
continue;
}
auto* grad_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(grad_var);
// skip nullptr tensor
if (grad_tensor == nullptr || !grad_tensor->IsInitialized()) {
continue;
}
// only focus on complex dtype now
auto src_type = grad_tensor->type();
if (!IsComplexType(src_type)) {
continue;
}
// 2. find forward var & check whether need to cast
auto* var = scope.FindVar(orig_var_name);
// if forward var not exists, do nothing
if (var == nullptr) {
continue;
}
if (!VarIsTensor(*var)) {
continue;
}
const auto* tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
PADDLE_ENFORCE_NOT_NULL(
tensor,
platform::errors::Unavailable(
"Forward tensor is nullptr when handle complex data to real."));
// only need record type, the allocation may have been released
auto dst_type = tensor->saved_type();
// only focus on real dtype and need casting
if (IsComplexType(dst_type)) {
continue;
}
// 3. cast complex grad to real grad
VLOG(6) << "Transform " << framework::DataTypeToString(src_type)
<< " var `" << var_name << "` to "
<< framework::DataTypeToString(dst_type)
<< " real var in static graph.";
Tensor out;
TransComplexToReal(dst_type, src_type, *grad_tensor, &out);
SetTensorToVariable(*grad_var, out, grad_var);
}
}
}
Scope* OperatorWithKernel::PrepareData(
    const Scope& scope, const OpKernelType& expected_kernel_key,
    std::vector<std::string>* transfered_inplace_vars,
...
@@ -545,6 +545,9 @@ class OperatorWithKernel : public OperatorBase {
  void ChooseKernel(const RuntimeContext& ctx, const Scope& scope,
                    const platform::Place& place) const;
void HandleComplexGradToRealGrad(const Scope& scope,
RuntimeContext* ctx) const;
  /* Inner assist methods */
  // indicate kernel DataType by input data.
  // By default all input data must be same.
...
@@ -197,6 +197,24 @@ class Tensor {
    return type_;
  }
/**
* [ Get the saved type of the tensor ]
*
* After the introduction of complex number calculations, ops that support
* complex numbers generally support type promotion, e.g.
* x(float32) + y(complex64) = out(complex64), so the grad tensors should be
* dout(complex64), dx(float32) and dy(complex64). Recognizing that dx must
* be float32 relies on the type of the forward tensor x. However, many of
* our ops register an InplaceInferer, which reuses the tensor memory of x
* for out in order to save storage.
*
* In this case the dim and type information recorded by x still exist, but
* because x has become an uninitialized tensor, its recorded type can no
* longer be obtained with x.type(). The type is still valid here, so
* saved_type() is added to expose it. This method SHOULD NOT be called in
* general scenarios.
*/
proto::VarType::Type saved_type() const { return type_; }
  // memory size returns the holding memory size in byte.
  size_t memory_size() const;
@@ -232,6 +250,7 @@ class Tensor {
  void ResetHolderWithType(std::shared_ptr<memory::Allocation> holder,
                           const proto::VarType::Type type);
  TensorInplaceVersion& InplaceVersionCounter() {
    return inplace_version_counter_;
  }
...
@@ -23,6 +23,7 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/op_base.h"
@@ -239,6 +240,7 @@ void BasicEngine::Execute() {
      if (var->OverridedStopGradient() || iter->second->RefCnt() > 1) {
        auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
        tmp_var->SetType(var->Type());
tmp_var->SetForwardDataType(var->ForwardDataType());
        var = tmp_var;
        need_accu_var_list_.emplace_back(iter->second.get(), var);
        VLOG(10) << "create temporary var of " << var->Name()
...
@@ -16,6 +16,7 @@
#include <algorithm>
#include <queue>
#include <utility>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable_helper.h"
@@ -356,9 +357,31 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
  }
  VLOG(5) << LayerDebugString(op.Type(), ins, outs);
auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs);
prepared_op.Run(ins, outs, attrs);
/**
* [ Why need temporary inputs here? ]
*
* PrepareData should not change original input tensor inplace.
* Suppose the user defines a tensor(int), enters an op to execute,
* and then this op rewrites GetExpectedKernelForVar, and converts
* this tensor to float type during execution. After the dynamic
* graph is executed, the user-defined variable will be lost, and
* the user cannot get the originally defined int tensor, because
* it has been converted to float. This should be regarded as a bug
* in certain usage scenarios.
*
* In static graph mode, when op is executed, a temporary scope
* `transfer_scope` is created before PrepareData, the data after
* transform is stored in the temporary scope, and then discarded
* after the execution of op, but the original input is directly
* overwritten in the previous dynamic graph implementation.
*/
auto expected_kernel_key =
GetExpectedKernelKey<VarType>(ins, outs, *op_kernel, place, attrs);
auto prepared_op = PreparedOp::Prepare(*op_kernel, expected_kernel_key);
auto tmp_ins = PrepareData<VarType>(*op_kernel, ins, expected_kernel_key);
prepared_op.Run(tmp_ins, outs, attrs);
  VLOG(4) << LayerDebugString(op.Type(), ins, outs);
}
...
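A hedged dygraph sketch of the guarantee described in the comment above (it assumes the forward elementwise_add promotes real inputs to complex through the same GetKernelTypeForVar/PrepareData mechanism): the promotion cast is applied to a temporary input, so the tensor the user created keeps its original dtype after the op runs.

import numpy as np
import paddle

x = paddle.to_tensor(np.ones([2, 2]) + 1j * np.ones([2, 2]))  # complex128
y = paddle.ones([2, 2], dtype='float32')

out = paddle.add(y, x)  # y is cast to complex128 on a temporary copy only
print(out.dtype)        # complex128
print(y.dtype)          # still float32; the original input is not rewritten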
@@ -201,6 +201,14 @@ class VarBase {
  framework::proto::VarType::Type DataType() const { return var_->DataType(); }
void SetForwardDataType(framework::proto::VarType::Type data_type) {
var_->SetForwardDataType(data_type);
}
framework::proto::VarType::Type ForwardDataType() const {
return var_->ForwardDataType();
}
  const platform::Place Place() const { return var_->Place(); }
  void ClearGradient();
...
@@ -857,6 +857,7 @@ void PartialGradTask::RunEachOp(OpBase *op) {
      auto new_grad_var = std::make_shared<VarBase>(true, grad_var->Name());
      new_grad_var->SetOverridedStopGradient(false);
new_grad_var->SetForwardDataType(grad_var->ForwardDataType());
      if (new_grad_var_iter->second->TotalRefCnt() > 1) {
        grads_to_accumulate_.emplace_back(new_grad_var_iter->second.get(),
                                          new_grad_var->SharedVar());
...
@@ -16,12 +16,10 @@
#include <sstream>
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/imperative/infer_shape_context.h"
#include "paddle/fluid/imperative/infer_var_type_context.h"
DECLARE_bool(use_mkldnn);
namespace paddle {
namespace imperative {
@@ -36,26 +34,32 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
}

template <typename VarType>
static void PrepareData(const platform::Place& place,
                        const NameVarMap<VarType>& ins,
                        const framework::OperatorWithKernel& op,
                        const framework::OpKernelType& expected_kernel_key) {
  for (const auto& name_pair : ins) {
    for (const auto& var_base : name_pair.second) {
      const auto* tensor = GetTensorFromVar(var_base->Var());
      if (tensor && tensor->IsInitialized()) {
        auto kernel_type_for_var = op.GetKernelTypeForVar(
            name_pair.first, *tensor, expected_kernel_key);
        if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
          continue;
        } else {
          VLOG(3) << "Transform Variable " << var_base->Name() << " from "
                  << kernel_type_for_var << " to " << expected_kernel_key;
          framework::Tensor out;
          TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
                        &out);
          SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
        }
      }
    }
  }
}

template <typename VarType>
static void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) {
  for (auto& pair : outs) {
    for (auto& var : pair.second) {
      if (var == nullptr) {
        continue;
      }
      if (var->ForwardDataType() ==
          static_cast<framework::proto::VarType::Type>(-1)) {
        VLOG(6) << "Var (" << var->Name()
                << ")'s forward data type is not set.";
        continue;
      }
      if (!framework::IsComplexType(var->DataType()) ||
          framework::IsComplexType(var->ForwardDataType())) {
        continue;
      }
      const auto* tensor = GetTensorFromVar(var->Var());
      if (tensor && tensor->IsInitialized()) {
        VLOG(6) << "Transform " << framework::DataTypeToString(var->DataType())
                << " var `" << var->Name() << "` to "
                << framework::DataTypeToString(var->ForwardDataType())
                << " real var in dynamic graph.";
        framework::Tensor out;
        framework::TransComplexToReal(var->ForwardDataType(), var->DataType(),
                                      *tensor, &out);
        SetTensorToVariable(var->Var(), out, var->MutableVar());
      }
    }
  }
}
@@ -63,18 +67,20 @@ static void PrepareData(const platform::Place& place,
PreparedOp::PreparedOp(const framework::OperatorBase& op,
                       const framework::RuntimeContext& ctx,
                       const framework::OpKernelType& kernel_type,
                       const framework::OperatorWithKernel::OpKernelFunc& func,
                       platform::DeviceContext* dev_ctx)
    : op_(op), ctx_(ctx), func_(func), dev_ctx_(dev_ctx) {}
    : op_(op),
      ctx_(ctx),
      kernel_type_(kernel_type),
      func_(func),
      dev_ctx_(dev_ctx) {}

template <typename VarType>
PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
                         const NameVarMap<VarType>& outs,
                         const framework::OperatorWithKernel& op,
                         platform::Place place,
                         const framework::AttributeMap& attrs) {
PreparedOp PreparedOp::Prepare(
    const framework::OperatorWithKernel& op,
    const framework::OpKernelType& expected_kernel_key) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  auto* dev_ctx = pool.Get(expected_kernel_key.place_);
  // check if op[type] has kernel registered.
  auto& all_op_kernels = op.AllOpKernels();
@@ -89,62 +95,20 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
  auto& kernels = kernels_iter->second;
  framework::RuntimeContext ctx({}, {});
#ifdef PADDLE_WITH_MKLDNN
// MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
// GetKernelType functions, so we need to copy the attributes there.
// Const qualifier of Attrs had to be discarded to overwrite it.
if (FLAGS_use_mkldnn) {
auto& mutable_op_attrs = const_cast<framework::AttributeMap&>(op.Attrs());
mutable_op_attrs = attrs;
}
#endif
auto expected_kernel_key =
op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
  auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_XPU
if (kernel_iter == kernels.end() &&
is_xpu_place(expected_kernel_key.place_)) {
expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
  // TODO(jiabin): Add operator.cc's line 1000 part back when we need that case
  PADDLE_ENFORCE_NE(kernel_iter, kernels.end(),
                    platform::errors::NotFound(
                        "Operator %s does not have kernel for %s.", op.Type(),
                        KernelTypeToString(expected_kernel_key)));
  if (!(expected_kernel_key.place_ == place)) {
    dev_ctx = pool.Get(expected_kernel_key.place_);
    place = dev_ctx->GetPlace();
  }
  PrepareData<VarType>(place, ins, op, expected_kernel_key);
  return PreparedOp(op, ctx, kernel_iter->second, dev_ctx);
  return PreparedOp(op, ctx, expected_kernel_key, kernel_iter->second, dev_ctx);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs) {
return PrepareOpImpl<VarBase>(ins, outs, op, place, attrs);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs) {
return PrepareOpImpl<VariableWrapper>(ins, outs, op, place, attrs);
}

template <typename VarType>
static void PreparedOpRunImpl(
    const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type,
    const framework::OperatorWithKernel::OpKernelFunc& func,
    platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs) {
@@ -158,19 +122,36 @@ static void PreparedOpRunImpl(
  func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
                                        attrs));
/**
* [ Why need handle complex gradient to real gradient? ]
*
* After the introduction of complex number calculations, Ops that support
* complex number calculations generally support type promotion, such as
* x(float32) + y(complex64) = out(complex64), then the type of the grad
* tensor should be dout(complex64), dx(float32), dy (complex64).
*
* But because dout is complex64, dx is also complex64 after the grad op
* kernel executes, so we need to recognize this situation and convert dx
* back to float32. HandleComplexGradToRealGrad does exactly that.
*/
if (framework::IsComplexType(kernel_type.data_type_)) {
HandleComplexGradToRealGrad<VarType>(outs);
}
}

void PreparedOp::Run(const NameVarMap<VarBase>& ins,
                     const NameVarMap<VarBase>& outs,
                     const framework::AttributeMap& attrs) {
  PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, ins, outs, attrs);
  PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
                             outs, attrs);
}

void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
                     const NameVarMap<VariableWrapper>& outs,
                     const framework::AttributeMap& attrs) {
  PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_, ins, outs,
                                     attrs);
  PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
                                     ins, outs, attrs);
}

}  // namespace imperative
...
@@ -21,9 +21,12 @@
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace framework {
class Tensor;
@@ -39,24 +42,120 @@ namespace imperative {
const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
template <typename VarType>
static void SetForwardDataTypeOfGradVar(const std::shared_ptr<VarType>& var);
template <>
void SetForwardDataTypeOfGradVar<VariableWrapper>(
const std::shared_ptr<VariableWrapper>& var) {
if (var->HasGradVar()) {
auto grad_var = var->GetGradVar();
VLOG(6) << "Set grad var (" << grad_var->Name() << ") dtype to ("
<< framework::DataTypeToString(var->DataType()) << ").";
grad_var->SetForwardDataType(var->DataType());
}
}
template <>
void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) {
if (var->HasGradVar()) {
auto& shared_var = var->SharedVar();
SetForwardDataTypeOfGradVar<VariableWrapper>(shared_var);
}
}
#ifdef PADDLE_WITH_XPU
static void ReplaceXPUKernelIfNotExists(
const framework::OperatorWithKernel& op,
framework::OpKernelType* expected_kernel_key) {
auto& all_op_kernels = op.AllOpKernels();
auto kernels_iter = all_op_kernels.find(op.Type());
PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
platform::errors::NotFound(
"There are no kernels which are registered in the %s operator.",
op.Type()));
auto& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(*expected_kernel_key);
if (kernel_iter == kernels.end() &&
is_xpu_place(expected_kernel_key->place_)) {
expected_kernel_key->place_ = platform::CPUPlace();
}
}
#endif
template <typename VarType>
framework::OpKernelType GetExpectedKernelKey(
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op, const platform::Place& place,
const framework::AttributeMap& attrs) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
framework::RuntimeContext ctx({}, {});
#ifdef PADDLE_WITH_MKLDNN
// MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
// GetKernelType functions, so we need to copy the attributes there.
// Const qualifier of Attrs had to be discarded to overwrite it.
if (FLAGS_use_mkldnn) {
auto& mutable_op_attrs = const_cast<framework::AttributeMap&>(op.Attrs());
mutable_op_attrs = attrs;
}
#endif
auto expected_kernel_key =
op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs));
#ifdef PADDLE_WITH_XPU
ReplaceXPUKernelIfNotExists(op, &expected_kernel_key);
#endif
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
return expected_kernel_key;
}
template <typename VarType>
NameVarMap<VarType> PrepareData(
const framework::OperatorWithKernel& op, const NameVarMap<VarType>& ins,
const framework::OpKernelType& expected_kernel_key) {
NameVarMap<VarType> tmp_ins(ins);
for (auto& name_pair : tmp_ins) {
for (auto& var_base : name_pair.second) {
const auto* tensor = GetTensorFromVar(var_base->Var());
SetForwardDataTypeOfGradVar(var_base);
if (tensor && tensor->IsInitialized()) {
auto kernel_type_for_var = op.GetKernelTypeForVar(
name_pair.first, *tensor, expected_kernel_key);
if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
continue;
} else {
VLOG(3) << "Transform Variable " << var_base->Name() << " from "
<< kernel_type_for_var << " to " << expected_kernel_key;
framework::Tensor out;
auto tmp_var = std::make_shared<VarType>(var_base->Name());
tmp_var->SetType(var_base->Type());
TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
&out);
SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar());
var_base = tmp_var;
}
}
}
}
return tmp_ins;
}
class PreparedOp {
 public:
  PreparedOp(const framework::OperatorBase& op,
             const framework::RuntimeContext& ctx,
             const framework::OpKernelType& kernel_type,
             const framework::OperatorWithKernel::OpKernelFunc& func,
             platform::DeviceContext* dev_ctx);

  static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
                            const NameVarMap<VarBase>& outs,
                            const framework::OperatorWithKernel& op,
                            const platform::Place& place,
                            const framework::AttributeMap& attrs);
  static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins,
                            const NameVarMap<VariableWrapper>& outs,
                            const framework::OperatorWithKernel& op,
                            const platform::Place& place,
                            const framework::AttributeMap& attrs);
  static PreparedOp Prepare(const framework::OperatorWithKernel& op,
                            const framework::OpKernelType& expected_kernel_key);
  void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out,
           const framework::AttributeMap& attrs);
@@ -68,6 +167,7 @@ class PreparedOp {
 private:
  const framework::OperatorBase& op_;
  const framework::RuntimeContext& ctx_;
  framework::OpKernelType kernel_type_;
  framework::OperatorWithKernel::OpKernelFunc func_;
  platform::DeviceContext* dev_ctx_;
};
...
@@ -32,27 +32,6 @@ namespace framework = paddle::framework;
namespace paddle {
namespace imperative {
static framework::RuntimeContext PrepareRuntimeContext(
const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
framework::VariableValueMap inputs, outputs;
for (auto& in_pair : ins) {
auto& in_ctx = inputs[in_pair.first];
in_ctx.reserve(in_pair.second.size());
for (auto& in_var : in_pair.second) {
in_ctx.emplace_back(in_var->MutableVar());
}
}
for (auto& out_pair : outs) {
auto& out_ctx = outputs[out_pair.first];
out_ctx.reserve(out_pair.second.size());
for (auto& out_var : out_pair.second) {
out_ctx.emplace_back(out_var->MutableVar());
}
}
return framework::RuntimeContext(std::move(inputs), std::move(outputs));
}
static framework::VariableNameMap CreateVarNameMap(
    const framework::OpInfo& op_info, const std::string& op_type,
    const NameVarBaseMap& varbase_map, bool is_input) {
@@ -111,11 +90,12 @@ TEST(test_prepare_op, test_prepare_op) {
      CreateVarNameMap(info, "split", outs, false);
  auto op = framework::OpRegistry::CreateOp("split", var_in_map, var_out_map,
                                            split_attr_map);
  framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
  auto expected_kernel_key = GetExpectedKernelKey<imperative::VarBase>(
      ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), place,
      split_attr_map);
  ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare(
                              ins, outs,
                              dynamic_cast<framework::OperatorWithKernel&>(*op),
                              place, split_attr_map));
                              expected_kernel_key));
}

const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
@@ -161,13 +141,15 @@ TEST(test_prepare_op, test_prepare_data) {
      CreateVarNameMap(info, op_type, outs, false);
  auto op = framework::OpRegistry::CreateOp(op_type, var_in_map, var_out_map,
                                            attr_map);
  framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
  // test if it can be transformed to GPU place
  PreparedOp prepared_op = PreparedOp::Prepare(
      ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), gpu_place,
      attr_map);
  auto expected_kernel_key = GetExpectedKernelKey<imperative::VarBase>(
      ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), gpu_place,
      attr_map);
  imperative::NameVarBaseMap tmp_ins = PrepareData<imperative::VarBase>(
      dynamic_cast<framework::OperatorWithKernel&>(*op), ins,
      expected_kernel_key);
  for (const auto& name_pair : ins) {
  for (const auto& name_pair : tmp_ins) {
    for (const auto& vb : name_pair.second) {
      ASSERT_TRUE(platform::is_same_place(
          vb->Var().Get<framework::LoDTensor>().place(), gpu_place));
@@ -208,13 +190,15 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) {
  auto op = framework::OpRegistry::CreateOp(op_type, var_in_map, var_out_map,
                                            attr_map);
  framework::RuntimeContext ctx = PrepareRuntimeContext(ins, outs);
  // test if it never transferred on GPU place
  PreparedOp prepared_op = PreparedOp::Prepare(
      ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), cpu_place,
      attr_map);
  auto expected_kernel_key = GetExpectedKernelKey<imperative::VarBase>(
      ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), cpu_place,
      attr_map);
  imperative::NameVarBaseMap tmp_ins = PrepareData<imperative::VarBase>(
      dynamic_cast<framework::OperatorWithKernel&>(*op), ins,
      expected_kernel_key);
  for (const auto& name_pair : ins) {
  for (const auto& name_pair : tmp_ins) {
    for (const auto& vb : name_pair.second) {
      ASSERT_TRUE(platform::is_same_place(
          vb->Var().Get<framework::LoDTensor>().place(), cpu_place));
...
@@ -122,10 +122,6 @@ class VariableWrapper {
  framework::proto::VarType::Type Type() const { return type_; }
void SetDataType(framework::proto::VarType::Type data_type) {
data_type_ = data_type;
}
  std::shared_ptr<VariableWrapper> GetGradVar() const {
    return grad_var_.lock();
  }
@@ -140,6 +136,10 @@ class VariableWrapper {
  bool HasGradVar() const { return !grad_var_.expired(); }
void SetDataType(framework::proto::VarType::Type data_type) {
data_type_ = data_type;
}
  framework::proto::VarType::Type DataType() const {
    const framework::Tensor* tensor = nullptr;
    if (var_.IsInitialized()) {
@@ -160,6 +160,14 @@ class VariableWrapper {
    }
  }
void SetForwardDataType(framework::proto::VarType::Type data_type) {
fwd_data_type_ = data_type;
}
framework::proto::VarType::Type ForwardDataType() const {
return fwd_data_type_;
}
  const platform::Place Place() const {
    const framework::Tensor* tensor = nullptr;
    auto place =
@@ -306,6 +314,13 @@ class VariableWrapper {
  framework::proto::VarType::Type type_{framework::proto::VarType::LOD_TENSOR};
  framework::proto::VarType::Type data_type_{framework::proto::VarType::FP32};
// See [ Why need handle complex gradient to real gradient? ]
// Used for grad var to get the data type of its corresponding forward var,
// if inconsistent, the data type of grad var needs to be casted to be
// consistent with forward var
framework::proto::VarType::Type fwd_data_type_{
static_cast<framework::proto::VarType::Type>(-1)};
  std::weak_ptr<VariableWrapper> grad_var_;
  std::weak_ptr<GradOpNode> grad_node_;
...
@@ -269,7 +269,7 @@ void Copy<platform::CUDAPlace, platform::CUDAPlace>(
  if (UNLIKELY(num == 0)) return;
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
          << dst_place << " by stream(" << stream << ")";
  if (dst_place == src_place) {
    platform::SetDeviceId(src_place.device);
    if (stream) {
...
@@ -352,7 +352,8 @@ class ElementwiseOpDoubleGradWithoutDXDY
                     "ElementwiseOpDoubleGradWithoutDXDY");
      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
    } else {
      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
      input_data_type =
          OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY");
    }
#ifdef PADDLE_WITH_MKLDNN
@@ -364,6 +365,19 @@ class ElementwiseOpDoubleGradWithoutDXDY
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
if (framework::IsComplexType(expected_kernel_type.data_type_)) {
// only promote the inputs' types when there is a complex input
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}
};

template <typename T>
...
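With IndicateOrPromoteVarDataTypes, the double-grad kernel key is chosen from the promoted type of DDX and DDY rather than from DDX alone. A numpy analogue of the promotion rule, used here only as an illustration (Paddle's promotion table is assumed to behave like numpy for these cases):

import numpy as np

print(np.promote_types(np.float32, np.complex64))  # complex64
print(np.promote_types(np.float64, np.complex64))  # complex128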
@@ -855,7 +855,7 @@ class OpTest(unittest.TestCase):
                               place,
                               no_check_set=None,
                               inplace_atol=None):
        """Chech the inplace correctness of given op (self.op_type).
        """Check the inplace correctness of given op (self.op_type).
        Run the op twice with same inputs, one enable inplace and another disable, compare their outputs.
        Args:
@@ -935,7 +935,7 @@ class OpTest(unittest.TestCase):
                                    fwd_res,
                                    grad_op_desc,
                                    inplace_atol=None):
        """Chech the inplace correctness of given grad_op_desc.
        """Check the inplace correctness of given grad_op_desc.
        Run the grad op twice with same inputs, one enable inplace and another disable, compare their outputs.
        It works like _check_forward_inplace, but the way to construct program and feed_map differs.
@@ -1291,7 +1291,6 @@ class OpTest(unittest.TestCase):
    def _assert_is_close(self, numeric_grads, analytic_grads, names,
                         max_relative_error, msg_prefix):
        for a, b, name in six.moves.zip(numeric_grads, analytic_grads, names):
            # It asserts np.abs(a - b) / np.abs(a) < max_relative_error, in which
            # max_relative_error is 1e-7. According to the value of np.abs(a), we
...
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
class Optimization_ex1(paddle.nn.Layer):
def __init__(self,
shape,
param_attr=paddle.nn.initializer.Uniform(
low=-5., high=5.),
dtype='float32'):
super(Optimization_ex1, self).__init__()
self.theta = self.create_parameter(
shape=shape, attr=param_attr, dtype=dtype, is_bias=False)
self.A = paddle.to_tensor(
np.random.randn(4, 4) + np.random.randn(4, 4) * 1j)
def forward(self):
loss = paddle.add(self.theta, self.A)
return loss.real()
class TestComplexSimpleNet(unittest.TestCase):
def setUp(self):
self.devices = ['cpu']
if core.is_compiled_with_cuda():
self.devices.append('gpu')
self.iter = 10
self.learning_rate = 0.5
self.theta_size = [4, 4]
def train(self, device):
paddle.set_device(device)
myLayer = Optimization_ex1(self.theta_size)
optimizer = paddle.optimizer.Adam(
learning_rate=self.learning_rate, parameters=myLayer.parameters())
for itr in range(self.iter):
loss = myLayer()
loss.backward()
optimizer.step()
optimizer.clear_grad()
def test_train_success(self):
for dev in self.devices:
self.train(dev)
if __name__ == '__main__':
unittest.main()
@@ -441,7 +441,8 @@ class TestAddOp(unittest.TestCase):
class TestComplexElementwiseAddOp(OpTest):
    def setUp(self):
        self.op_type = "elementwise_add"
        self.init_base_dtype()
        self.dtype = np.float64
        self.shape = (2, 3, 4, 5)
        self.init_input_output()
        self.init_grad_input_output()
@@ -456,17 +457,15 @@ class TestComplexElementwiseAddOp(OpTest):
        self.dtype = np.float64

    def init_input_output(self):
        self.x = np.random.random(
            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
                (2, 3, 4, 5)).astype(self.dtype)
        self.y = np.random.random(
            (2, 3, 4, 5)).astype(self.dtype) + 1J * np.random.random(
                (2, 3, 4, 5)).astype(self.dtype)
        self.x = np.random.random(self.shape).astype(
            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
        self.y = np.random.random(self.shape).astype(
            self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
        self.out = self.x + self.y

    def init_grad_input_output(self):
        self.grad_out = np.ones((2, 3, 4, 5), self.dtype) + 1J * np.ones(
            (2, 3, 4, 5), self.dtype)
        self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones(
            self.shape, self.dtype)
        self.grad_x = self.grad_out
        self.grad_y = self.grad_out
@@ -497,6 +496,20 @@ class TestComplexElementwiseAddOp(OpTest):
            user_defined_grad_outputs=[self.grad_out])
class TestRealComplexElementwiseAddOp(TestComplexElementwiseAddOp):
def init_input_output(self):
self.x = np.random.random(self.shape).astype(self.dtype)
self.y = np.random.random(self.shape).astype(
self.dtype) + 1J * np.random.random(self.shape).astype(self.dtype)
self.out = self.x + self.y
def init_grad_input_output(self):
self.grad_out = np.ones(self.shape, self.dtype) + 1J * np.ones(
self.shape, self.dtype)
self.grad_x = np.real(self.grad_out)
self.grad_y = self.grad_out
if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
@@ -519,7 +519,7 @@ class TestStridedSliceAPI(unittest.TestCase):
            np.random.randn(2, 10), place=paddle.CUDAPinnedPlace())
        self.assertTrue(x.place.is_cuda_pinned_place())
        y = x[:, ::2]
        self.assertFalse(x.place.is_cuda_pinned_place())
        self.assertTrue(x.place.is_cuda_pinned_place())
        self.assertFalse(y.place.is_cuda_pinned_place())
...