Unverified commit 3da97b45, authored by Chen Weihang, committed by GitHub

[Eager] Polish generated code details (#42512)

* polish code details

* remove needless prefix

* revert needless change

* polish grad func generated format
Parent 5acd764d
@@ -178,7 +178,7 @@ def GetForwardFunctionName(string):

def GetIndent(num):
    tab = "  "
    return "".join([tab for i in range(num)])
...
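For reference, GetIndent just repeats a two-space tab, so callers can ask for any nesting depth. A tiny runnable sketch of its behaviour (illustrative only, not part of the diff):

    def GetIndent(num):
        tab = "  "
        return "".join([tab for i in range(num)])

    assert GetIndent(0) == ""
    assert GetIndent(2) == "    "  # two tabs of two spaces each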
@@ -55,58 +55,49 @@ def ParseArguments():
########################
## Code Gen Templates ##
########################

SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
"""  void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
    {} = egr::TensorWrapper({}, full_reserved, {});
  }}
"""

PLAIN_TENSOR_MEMBER_TEMPLATE = \
"""  egr::TensorWrapper {};
"""

CLEAR_TENSOR_WRAPPER_TEMPLATE = \
"""  {}.clear();
"""

SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
"""  void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{
    for(const auto& eager_tensor : {}) {{
      {}.emplace_back(egr::TensorWrapper(eager_tensor, full_reserved, {}));
    }};
  }}
"""

VECTOR_TENSOR_MEMBER_TEMPLATE = \
"""  std::vector<egr::TensorWrapper> {};
"""

CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE = \
"""  for (auto& tw : {}) {{
    tw.clear();
  }}
"""

SET_ATTR_METHOD_TEMPLATE = \
"""  void SetAttribute{}({} {}) {{
    {} = {};
  }}
"""

ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE = \
"""  {} {} = {};
"""

ATTRIBUTE_MEMBER_TEMPLATE = \
"""  {} {};
"""

NODE_DECLARATION_TEMPLATE = \
@@ -114,140 +105,114 @@ NODE_DECLARATION_TEMPLATE = \
class {} : public egr::GradNodeBase {{
 public:
  {}() : egr::GradNodeBase() {{}}
  {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
      egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
  ~{}() override = default;

  virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> operator()(
      paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize>& grads, bool create_graph = false, bool is_new_grad = false) override;

  std::string name() override {{ return \"{}\"; }}

  void ClearTensorWrappers() override {{
    {}
    SetIsTensorWrappersCleared(true);
  }}

  std::shared_ptr<GradNodeBase> Copy() const override {{
    auto copied_node = std::shared_ptr<{}>(new {}(*this));
    return copied_node;
  }}

  // SetTensorWrapperX, SetTensorWrapperY, ...
  {}
  // SetAttributes
  {}
 private:
  // TensorWrappers
  {}
  // Attributes
  {}}};
"""
GRAD_FUNCTION_TEMPLATE = \
"""
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> {}::operator()(paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize>& grads, bool create_graph, bool is_new_grad) {{
  // Fill Zero For GradIn Tensors
{}
  // Apply Gradient Hooks
  auto hooked_grads = ApplyGradientHooks(grads);

  // Collect GradIn Tensors, Attrs and Recovered TensorWrappers
{}
  // Call grad_api function
  VLOG(3) << \"Final State Running: \" << \"{}\";
{}
  // Get Output
{}
  // Get GradIn autograd_meta
{}
  // Get GradOut autograd_meta
{}
  // Compute Require Grad
{}
  // Create Grad Node
{}
  // Return
{}
}}
"""
FORWARD_FUNCTION_TEMPLATE = \
"""
{} {}({}) {{
  // Dygraph Record Event
{}
  // AMP Logic
{}
  // Get Input AutoGradMeta
{}
  // Forward API Call
  VLOG(3) << \"Final State Running: \" << \"{}\";
{}
  // Get Outputs
{}
  // Get Output AutoGradMeta
{}
  bool trace_backward = egr::Controller::Instance().HasGrad();
  bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});

  // Check Inplace if needed
{}{}
  // Node Creation
{}

  // Returns
  return {};
}}
"""
FORWARD_BODY_TEMPLATE = \
"""  if(require_any_grad) {{
{}
    egr::EagerUtils::PassStopGradient({});

    // Node Construction
{}
    // SetAttributes if needed
{}
    // Set TensorWrappers for Forward Inputs if needed
{}
    // SetGradOutMeta & SetEdges
{}
    // SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
{}
{}
{}
    // Set TensorWrappers for Forward Outputs if needed
{}
  }}
"""

NAMESPACE_WRAPPER_TEMPLATE = \
@@ -340,30 +305,29 @@ extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_
CHECK_INPLACE_TEMPLATE = \
"""
  egr::EagerUtils::CheckInplace({}, {}, require_any_grad);
"""

BUMP_INPLACE_VERSION_TEMPLATE = \
"""
  // Bump Inplace Version
  {}.bump_inplace_version();
  VLOG(3) << \"Tensor(\" << {}.name() << \") uses Inplace Strategy.\";
"""

AMP_LOGIC_TEMPLATE = \
"""  if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
    VLOG(5) << "Check and Prepare For AMP";
    {}
    paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
    {}
    {}
    {}
    {{
      paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentTracer(), paddle::imperative::AmpLevel::O0);
      {}
    }}
  }}
"""

CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE = \
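These templates are ordinary Python format strings: each {} is a positional slot that the generator fills through str.format, and doubled braces {{ }} survive as literal C++ braces. A minimal sketch of how one of them expands, with made-up argument values ("X", "x", "x_", "false" are placeholders chosen for this sketch, not taken from the diff):

    # Minimal sketch with made-up values; the real arguments are produced by the
    # generator classes further down in eager_gen.py.
    print(SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format("X", "x", "x_", "x", "false"))
    #   void SetTensorWrapperX(const paddle::experimental::Tensor& x, bool full_reserved) {
    #     x_ = egr::TensorWrapper(x, full_reserved, false);
    #   }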
@@ -1045,7 +1009,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
        self.GenerateNodeCreationCodes()
        node_creation_str = self.node_creation_str

        dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
        forward_function_name = GetDygraphForwardFunctionName(forward_api_name)

        # Forward amp logic
@@ -1055,8 +1019,8 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
        amp_tensors_vector_optional_list_str = "".join(
            amp_tensors_vector_optional_list)
        amp_get_dst_dtype_str = f"auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
        amp_autocast_list_str = " ".join(
            amp_autocast_list) + " " + " ".join(
                amp_autocast_optional_list)
        amp_inputs_call_args_str = ", ".join(amp_inputs_call_list)
        amp_call_str = f"return {forward_function_name}({amp_inputs_call_args_str});"
...
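The visible change in the first hunk above is the trailing \n appended to dygraph_event_str. Because these fragments are spliced verbatim into the forward-function template, a fragment without its own newline would be glued onto whatever the template emits next. A small, self-contained sketch of that effect (the fragment strings below are made up):

    # Made-up fragments, only to show why the trailing "\n" matters when pieces
    # of generated C++ are concatenated.
    without_newline = "  RecordEvent dygraph_entrance_record_event;"
    with_newline = "  RecordEvent dygraph_entrance_record_event;\n"
    next_piece = "  // AMP Logic\n"
    print(without_newline + next_piece)  # both statements land on one line
    print(with_newline + next_piece)     # two lines, as intended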
@@ -66,12 +66,13 @@ PARSE_PYTHON_C_TENSORS_TEMPLATE = \
PARSE_PYTHON_C_ARGS_TEMPLATE = \
"""    PyObject* {}_obj = PyTuple_GET_ITEM(args, {});
    {} {} = {}({}_obj, \"{}\", {});
"""

RECORD_EVENT_TEMPLATE = \
"paddle::platform::RecordEvent {}(\"{} {}\", paddle::platform::TracerEventType::Operator, 1);"

RETURN_INPLACE_PYOBJECT_TEMPLATE = \
@@ -84,33 +85,27 @@ RETURN_INPLACE_PYOBJECT_TEMPLATE = \
PYTHON_C_FUNCTION_TEMPLATE = \
"""
static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) {{
  {}
  PyThreadState *tstate = nullptr;
  try {{
    VLOG(6) << "Running Eager Final State API: {}";
    // Get EagerTensors from args
{}
    // Parse Attributes if needed
{}
    tstate = PyEval_SaveThread();
    // Set Device ID
{}
    auto out = {}({});
    PyEval_RestoreThread(tstate);
    tstate = nullptr;
{}
  }} catch(...) {{
    if (tstate) {{
      PyEval_RestoreThread(tstate);
    }}
@@ -118,13 +113,10 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
    return nullptr;
  }}
}}
"""
FUNCTION_SET_DEVICE_TEMPLATE = \
"""{}    if (paddle::platform::is_gpu_place(place)) {{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      phi::backends::gpu::SetDeviceId(place.device);
      VLOG(1) <<"CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId() << " from " << (int)place.device;
@@ -309,7 +301,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
                "false")

        parse_attributes_str = ""
        expected_place_str = "    auto place = egr::Controller::Instance().GetExpectedPlace();\n"

        # Generate Python-C Attributes Parsing Logic
        for name, atype, _, pos in orig_forward_attrs_list:
...
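Tying these two hunks together: expected_place_str is substituted into the leading {} slot of FUNCTION_SET_DEVICE_TEMPLATE, so the indentation and trailing newline baked into that string decide how the generated block lines up. A cut-down, hypothetical version of the composition (DEMO_SET_DEVICE_TEMPLATE below is an abbreviation written for this sketch, not the template defined above):

    # Hypothetical, shortened template; it only demonstrates the {} substitution
    # and the {{ }} brace escaping used throughout these generators.
    DEMO_SET_DEVICE_TEMPLATE = "{}    if (paddle::platform::is_gpu_place(place)) {{ phi::backends::gpu::SetDeviceId(place.device); }}\n"
    expected_place_str = "    auto place = egr::Controller::Instance().GetExpectedPlace();\n"
    print(DEMO_SET_DEVICE_TEMPLATE.format(expected_place_str))
    #     auto place = egr::Controller::Instance().GetExpectedPlace();
    #     if (paddle::platform::is_gpu_place(place)) { phi::backends::gpu::SetDeviceId(place.device); }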
@@ -22,13 +22,10 @@ using AbstractAutogradMeta = paddle::experimental::AbstractAutogradMeta;
/**
 *
 * AutogradMeta is what records the backward info for a tensor. When we run a
 * computation graph eagerly, we cannot build a static paddle program as static
 * mode does, so we need a new way to record forward info so that backward can
 * be traced once all forward computation is finished. This requires our
 * AutogradMeta class to record the following main members:
 *
 * 1. grad_op:
 * Grad_op indicates the grad operation of the forward op
 *
@@ -38,28 +35,24 @@ using AbstractAutogradMeta = paddle::experimental::AbstractAutogradMeta;
 * backward computation
 *
 * NOTE: grad should only be available when the current tensor is a leaf tensor,
 * and for a non-leaf tensor grad is only available while the user sets the
 * `retain_grad` option to `true`.
 *
 * TODO(jiabin) : support hooks
 * 3. hooks:
 * Hooks are computation logic attached only to the backward operation; they
 * are registered by the user and run before the accumulator.
 *
 * 4. overrided_stop_gradient_
 * This member is used to finish some auto-prune related work; it indicates
 * that a user-set stop_gradient should override the result indicated by the
 * framework. All non-parameter tensors' stop_gradient properties should be
 * true. We will pass stop_gradient along when we find one that needs it.
 *
 * NOTE: AutogradMeta inherits from AbstractAutogradMeta, which is defined
 * among the tensor's dependencies; we did this to avoid an additional
 * dependency on Autograd. In eager execution, we cast AbstractAutogradMeta to
 * AutogradMeta to use it.
 *
 * **/
@@ -119,7 +112,7 @@ class AutogradMeta : public AbstractAutogradMeta {
    return std::make_pair(out_slot_id_, out_rank_);
  }

  bool IsInitialized() const { return grad_node_.get(); }

  // TODO(jiabin): This may cause error, since -1 still can indicate true;
  bool StopGradient() const { return stop_gradient_ != 0; }
@@ -140,7 +133,7 @@ class AutogradMeta : public AbstractAutogradMeta {
  void SetPersistable(bool persistable) { persistable_ = persistable; }

  bool RetainGrads() const { return retain_grads_; }

  void SetRetainGrads(bool value) { retain_grads_ = value; }
@@ -156,7 +149,7 @@ class AutogradMeta : public AbstractAutogradMeta {
  /**
   * Why we need slot id here?
   * Because in paddle most operators' inputs and outputs
   * are assembled in the form of {"slot name", vector<tensor>},
   * so it's better for us to set a slot id to fit this format. **/
  size_t out_slot_id_;
...
@@ -111,11 +111,10 @@ void EmptyStringTensorInitializer(TensorObject* self, const std::string& name,
  // Note(zhoushunjie): Only support CPUPlace when create StringTensor
  auto actual_place = platform::CPUPlace();
  // Allocate memory
  paddle::experimental::DefaultAllocator string_allocator(actual_place);
  std::shared_ptr<phi::StringTensor> string_tensor =
      std::make_shared<phi::StringTensor>(&string_allocator,
                                          phi::StringTensorMeta{ddims});
  if (phi::product(ddims) > 0) {
    string_tensor->mutable_data(actual_place);
  }
@@ -184,8 +183,7 @@ void InitTensorWithTensor(TensorObject* self,
                          const std::string& name) {
  self->tensor.set_name(name);
  if (place == src.place()) {
    self->tensor.set_impl(src.impl());
    VLOG(4) << "Same place, do ShareDataWith";
  } else {
    self->tensor.set_impl(src.copy_to(place, true).impl());
...
@@ -27,9 +27,7 @@ typedef struct {
} TensorObject;

typedef struct {
  PyObject_HEAD PyObject* container;
  PyObject* non_differentiable;
  PyObject* dirty_tensors;
  bool materialize_grads;
...