提交 438bca9c 编写于 作者: M minqiyang

Implement Runtime Var Type Inference

test=develop
上级 ca392c7e
......@@ -16,6 +16,8 @@ limitations under the License. */
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/inplace_op_inference.h"
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
......@@ -80,6 +81,19 @@ class InferVarTypeContext {
block_->FindRecursiveOrCreateVar(name).SetDataType(type);
}
inline std::vector<proto::VarType::Type> GetDataTypes(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
}
inline void SetDataTypes(
const std::string& name,
const std::vector<proto::VarType::Type>& multiple_data_type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
}
inline std::vector<int64_t> GetShape(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetShape();
......@@ -101,17 +115,11 @@ class InferVarTypeContext {
block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
}
private:
protected:
const OpDesc* op_;
BlockDesc* block_;
};
// infer var type context for imperative mode
class RuntimeInferVarTypeContext : public InferVarTypeContext {
public:
RuntimeInferVarTypeContext() : InferVarTypeContext(nullptr, nullptr) {}
};
class VarTypeInference {
public:
virtual ~VarTypeInference() {}
......
......@@ -220,7 +220,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
}
VLOG(3) << "apply op grad: " << Type();
std::vector<framework::VariableValueMap> tmp_grad_outputs;
std::vector<VarBasePtrMap> tmp_grad_outputs;
if (backward_id_ > 0) {
VLOG(3) << "py_layer_grad";
tmp_grad_outputs.resize(1);
......@@ -246,23 +246,59 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
// Allocate a new variable
Variable* tmp_var = new framework::Variable();
tmp_var->GetMutable<framework::LoDTensor>();
outputs.emplace_back(tmp_var);
VarBase* tmp_var_base =
new VarBase(it.second[i]->Name(), tmp_var, nullptr, true);
outputs.emplace_back(tmp_var_base);
}
}
// Run grad op
framework::RuntimeContext ctx(grad_input_vars_[k], tmp_grad_outputs[k]);
// No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_);
// grad_op_desc->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc);
// auto& info =
// framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
// if (info.infer_var_type_) {
// framework::RuntimeInferVarTypeContext infer_var_type_ctx(
// this, &grad_inputs, &outputs, &attrs_map);
// info.infer_var_type_(infer_var_type_ctx);
// }
framework::OperatorWithKernel* op_kernel =
dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
// Run grad op
framework::VariableValueMap grad_invars_map;
framework::VariableValueMap grad_outvars_map;
for (const auto& it : grad_input_vars_[k]) {
auto& grad_invars = grad_invars_map[it.first];
grad_invars.reserve(it.second.size());
for (const VarBase* grad_inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
grad_op_desc->Type(), grad_inp->Name());
grad_invars.emplace_back(grad_inp->var_);
}
}
for (const auto& it : tmp_grad_outputs[k]) {
auto& grad_outvars = grad_outvars_map[it.first];
grad_outvars.reserve(it.second.size());
for (VarBase* grad_out : it.second) {
PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
grad_op_desc->Type(), grad_out->Name());
grad_outvars.emplace_back(grad_out->var_);
}
}
framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
framework::Scope scope;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
p.op.RuntimeInferShape(scope, place_, ctx);
......@@ -279,8 +315,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
framework::Variable* grad = outputs[i];
framework::Variable* orig_grad = origin_outputs[i];
framework::Variable* grad = outputs[i]->var_;
framework::Variable* orig_grad = origin_outputs[i]->var_;
AddTo(grad, orig_grad, place_);
delete grad;
}
......@@ -328,28 +364,35 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
int PyLayer::NumFuncs() { return py_funcs_.size(); }
std::vector<Variable*> PyLayer::Apply(int func_id,
const std::vector<VarBase*>& inputs) {
std::vector<framework::Variable*> invars;
for (const VarBase* in : inputs) {
invars.push_back(in->var_);
}
std::vector<framework::Variable*> PyLayer::Apply(
int func_id, const std::vector<VarBase*>& inputs) {
PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
return CallPythonFunc(py_funcs_[func_id], invars);
return CallPythonFunc(py_funcs_[func_id], inputs);
}
std::vector<Variable*> PyLayer::ApplyGrad(
int func_id, const std::vector<framework::Variable*>& inputs) {
std::vector<VarBase*> PyLayer::ApplyGrad(int func_id,
const std::vector<VarBase*>& inputs) {
PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
return CallPythonFunc(py_funcs_[func_id], inputs);
auto rets = CallPythonFunc(py_funcs_[func_id], inputs);
std::vector<VarBase*> outs;
outs.reserve(rets.size());
for (size_t i = 0U; i != rets.size(); ++i) {
outs.emplace_back(new VarBase(
string::Sprintf("%s_out_%d", framework::GradVarName(PyLayer::kFwdOut),
i),
rets[i], nullptr, true));
}
return outs;
}
std::vector<framework::Variable*> PyLayer::CallPythonFunc(
const py::object& callable, const std::vector<framework::Variable*>& ins) {
const py::object& callable, const std::vector<VarBase*>& ins) {
py::gil_scoped_acquire guard;
py::tuple in_args(ins.size());
for (size_t i = 0; i < ins.size(); ++i) {
const framework::LoDTensor& t = ins[i]->Get<framework::LoDTensor>();
const framework::LoDTensor& t = ins[i]->var_->Get<framework::LoDTensor>();
in_args[i] = t.IsInitialized() ? py::cast(t) : py::cast(nullptr);
}
VLOG(3) << "pyfunc in " << py::len(in_args);
......@@ -359,6 +402,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
auto ret_tuple = py::cast<py::tuple>(ret);
size_t ret_num = py::len(ret_tuple);
std::vector<framework::Variable*> outs;
outs.reserve(ret_num);
VLOG(3) << "pyfunc out " << ret_num;
for (size_t i = 0; i < ret_num; ++i) {
try {
......@@ -369,7 +413,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
auto* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(*py_out_tensor);
tensor->set_lod(py_out_tensor->lod());
outs.push_back(var);
outs.emplace_back(var);
} catch (py::cast_error&) {
PADDLE_THROW("The %d-th output must be LoDTensor", i);
}
......
......@@ -18,14 +18,16 @@
#include "paddle/fluid/framework/python_headers.h"
// clang-format on
#include <map> // NOLINT
#include <string> // NOLINT
#include <vector> // NOLINT
#include <memory> // NOLINT
#include <map> // NOLINT
#include <string> // NOLINT
#include <vector> // NOLINT
#include <memory> // NOLINT
#include <unordered_map> // NOLINT
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -184,6 +186,10 @@ class VarBase {
}
}
inline void SetDType(framework::proto::VarType::Type type) {
auto tensor = var_->GetMutable<framework::LoDTensor>();
tensor->mutable_data(place_, dtype_);
}
inline framework::proto::VarType::Type DType() const { return dtype_; }
inline void SetStopGradient(bool stop_gradient) {
......@@ -328,9 +334,9 @@ class PYBIND11_HIDDEN OpBase {
std::map<std::string, std::vector<int>> pre_ops_out_idx_;
// Inputs to a vector of bwd ops.
std::vector<framework::VariableValueMap> grad_input_vars_;
std::vector<VarBasePtrMap> grad_input_vars_;
// Outputs to a vector of bwd ops.
std::vector<framework::VariableValueMap> grad_output_vars_;
std::vector<VarBasePtrMap> grad_output_vars_;
std::vector<py::object> backward_hooks_;
};
......@@ -359,12 +365,130 @@ class PyLayer {
static std::vector<framework::Variable*> Apply(
int func_id, const std::vector<VarBase*>& inputs);
static std::vector<framework::Variable*> ApplyGrad(
int func_id, const std::vector<framework::Variable*>& inputs);
static std::vector<VarBase*> ApplyGrad(int func_id,
const std::vector<VarBase*>& inputs);
private:
static std::vector<framework::Variable*> CallPythonFunc(
const py::object& callable, const std::vector<framework::Variable*>& ins);
const py::object& callable, const std::vector<VarBase*>& ins);
};
// infer var type context for imperative mode
class PYBIND11_HIDDEN RuntimeInferVarTypeContext
: public framework::InferVarTypeContext {
public:
RuntimeInferVarTypeContext(imperative::OpBase* op,
const imperative::VarBasePtrMap* inputs,
imperative::VarBasePtrMap* outputs,
const framework::AttributeMap* attrs_map)
: InferVarTypeContext(nullptr, nullptr),
op_(op),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs_map),
input_names_(),
output_names_(),
var_set_() {
input_names_.reserve(inputs_->size());
for (auto& it : *inputs_) {
for (imperative::VarBase* var : it.second) {
input_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var;
}
}
output_names_.reserve(outputs_->size());
for (auto& it : *outputs_) {
for (imperative::VarBase* var : it.second) {
output_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var;
}
}
}
framework::Attribute GetAttr(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(attrs_);
return attrs_->at(name);
}
inline bool HasVar(const std::string& name) const {
return var_set_.count(name) > 0;
}
inline bool HasInput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(inputs_);
return inputs_->count(name) > 0;
}
inline bool HasOutput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(outputs_);
return outputs_->count(name) > 0;
}
inline const std::vector<std::string>& Input(const std::string& name) const {
return input_names_.at(name);
}
inline const std::vector<std::string>& Output(const std::string& name) const {
return output_names_.at(name);
}
inline framework::proto::VarType::Type GetType(
const std::string& name) const {
return var_set_.at(name)->DType();
}
inline void SetType(const std::string& name,
framework::proto::VarType::Type type) {
var_set_[name]->SetDType(type);
}
inline framework::proto::VarType::Type GetDataType(
const std::string& name) const {
return var_set_.at(name)->DType();
}
inline void SetDataType(const std::string& name,
framework::proto::VarType::Type type) {
var_set_[name]->SetDType(type);
}
inline std::vector<framework::proto::VarType::Type> GetDataTypes(
const std::string& name) const {
PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType");
}
inline void SetDataTypes(
const std::string& name,
const std::vector<framework::proto::VarType::Type>& multiple_data_type) {
PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType");
}
inline std::vector<int64_t> GetShape(const std::string& name) const {
PADDLE_THROW("Do not handle Shape in runtime InferVarType");
}
inline void SetShape(const std::string& name,
const std::vector<int64_t>& dims) {
PADDLE_THROW("Do not handle Shape in runtime InferVarType");
}
inline int32_t GetLoDLevel(const std::string& name) const {
PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
}
inline void SetLoDLevel(const std::string& name, int32_t lod_level) {
PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
}
private:
imperative::OpBase* op_;
const imperative::VarBasePtrMap* inputs_;
imperative::VarBasePtrMap* outputs_;
const framework::AttributeMap* attrs_;
std::unordered_map<std::string, std::vector<std::string>> input_names_;
std::unordered_map<std::string, std::vector<std::string>> output_names_;
std::unordered_map<std::string, imperative::VarBase*> var_set_;
};
} // namespace imperative
......
......@@ -19,6 +19,7 @@
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -160,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
}
std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs,
VarBasePtrMap& outputs,
framework::AttributeMap attrs_map,
const platform::Place expected_place,
const bool stop_gradient) {
......@@ -228,6 +229,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework::OpRegistry::CreateOp(op->Type(), invars_name_map,
outvars_name_map, attrs_map);
if (info.infer_var_type_) {
RuntimeInferVarTypeContext infer_var_type_ctx(op, &inputs, &outputs,
&attrs_map);
info.infer_var_type_(infer_var_type_ctx);
}
// TODO(minqiyang): Support infer var type in imperative mode
// Run forward op
VLOG(3) << "tracer running " << op->Type();
......@@ -278,12 +285,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
auto fwd_var_it = current_vars_map.find(grad_invar);
PADDLE_ENFORCE(fwd_var_it != current_vars_map.end());
// Forward inputs or outputs.
grad_in_vars.emplace_back(fwd_var_it->second->var_);
grad_in_vars.emplace_back(fwd_var_it->second);
} else {
VarBase* var = current_vars_map[var_it->second];
InitGrad(var, prepared_op.GetDeviceContext());
// Douts.
grad_in_vars.emplace_back(var->grads_->var_);
grad_in_vars.emplace_back(var->grads_);
}
vars_saved_for_backward.insert(it.first);
......@@ -300,7 +307,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op->Type());
VarBase* var = current_vars_map[var_it->second];
InitGrad(var, prepared_op.GetDeviceContext());
grad_out_vars.push_back(var->grads_->var_);
grad_out_vars.push_back(var->grads_);
}
}
}
......@@ -342,23 +349,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
auto& grad_output_vars =
op->grad_output_vars_[0][framework::GradVarName(PyLayer::kFwdOut)];
for (const VarBase* inp : inputs) {
grad_input_vars.push_back(inp->var_);
for (VarBase* inp : inputs) {
grad_input_vars.push_back(inp);
}
for (VarBase* out : outputs) {
grad_input_vars.push_back(out->var_);
grad_input_vars.push_back(out);
}
// TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
platform::CPUPlace place;
for (VarBase* out : outputs) {
InitGrad(out, platform::DeviceContextPool::Instance().Get(place));
grad_input_vars.push_back(out->grads_->var_);
grad_input_vars.push_back(out->grads_);
}
for (VarBase* inp : inputs) {
InitGrad(inp, platform::DeviceContextPool::Instance().Get(place));
grad_output_vars.push_back(inp->grads_->var_);
grad_output_vars.push_back(inp->grads_);
}
}
return outputs;
......
......@@ -25,6 +25,7 @@ class VarBase;
class OpBase;
typedef std::map<std::string, std::vector<VarBase*>> VarBasePtrMap;
typedef std::map<std::string, std::vector<const VarBase*>> ConstVarBasePtrMap;
typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
} // namespace imperative
......
......@@ -365,19 +365,16 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
class WhileGradOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto p_names = op_desc.Input(kX);
auto pg_ig_names = op_desc.Output(framework::GradVarName(kX));
void operator()(framework::InferVarTypeContext &ctx) const override {
auto p_names = ctx.Input(kX);
auto pg_ig_names = ctx.Output(framework::GradVarName(kX));
for (size_t i = 0; i < p_names.size(); ++i) {
auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
auto *g_var = block->FindVarRecursive(pg_ig_names[i]);
if (g_var != nullptr) { // Gradient could be @EMPTY@
if (ctx.HasVar(pg_ig_names[i])) {
VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
<< " type: " << p_var.GetType();
g_var->SetType(p_var.GetType());
g_var->SetDataType(p_var.GetDataType());
<< " type: " << ctx.GetType(p_names[i]);
ctx.SetType(pg_ig_names[i], ctx.GetType(p_names[i]));
ctx.SetDataType(pg_ig_names[i], ctx.GetDataType(p_names[i]));
}
}
}
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/split_ids_op.h"
#include <memory>
namespace paddle {
namespace operators {
......
......@@ -60,12 +60,9 @@ class NCCLInitOp : public framework::OperatorBase {
class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Communicator").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
void operator()(framework::InferVarTypeContext &ctx) const override {
auto out_var_name = ctx.Output("Communicator").front();
ctx.SetType(out_var_name, framework::proto::VarType::RAW);
}
};
......
......@@ -14,8 +14,11 @@
#include "paddle/fluid/operators/py_func_op.h"
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
......
......@@ -123,7 +123,7 @@ class CustomReaderInferShape : public framework::InferShapeBase {
class CustomReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::InferVarTypeContext& ctx) const override {
void operator()(framework::InferVarTypeContext& ctx) const override {
auto& out_var_name = ctx.Output("Out")[0];
PADDLE_ENFORCE(ctx.HasVar(out_var_name));
ctx.SetType(out_var_name, framework::proto::VarType::READER);
......
......@@ -51,7 +51,7 @@ class ReadInferShape : public framework::InferShapeBase {
class ReadInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::InferVarTypeContext& ctx) const override {
void operator()(framework::InferVarTypeContext& ctx) const override {
bool infer_out = boost::get<bool>(ctx.GetAttr("infer_out"));
if (infer_out) {
std::string reader_name = ctx.Input("Reader")[0];
......
......@@ -98,11 +98,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
}
}
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const {
std::string reader_name = op_desc.Output("Out")[0];
framework::VarDesc* reader = block->FindVarRecursive(reader_name);
reader->SetType(framework::proto::VarType::READER);
void FileReaderInferVarType::operator()(
framework::InferVarTypeContext& ctx) const {
std::string reader_name = ctx.Output("Out")[0];
ctx.SetType(reader_name, framework::proto::VarType::READER);
}
void DecoratedReaderInferShape::operator()(
......@@ -125,13 +124,11 @@ void DecoratedReaderInferShape::operator()(
}
void DecoratedReaderInferVarType::operator()(
const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
std::string out_reader_name = op_desc.Output("Out")[0];
framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
out_reader->SetType(framework::proto::VarType::READER);
out_reader->SetDataTypes(in_reader->GetDataTypes());
framework::InferVarTypeContext& ctx) const {
const std::string& in_reader_name = ctx.Input("UnderlyingReader")[0];
const std::string& out_reader_name = ctx.Output("Out")[0];
ctx.SetType(out_reader_name, framework::proto::VarType::READER);
ctx.SetDataTypes(out_reader_name, ctx.GetDataTypes(in_reader_name));
}
void DecoratedReaderMakerBase::Make() {
......
......@@ -14,7 +14,9 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/scale_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/detail/safe_ref.h"
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/split_selected_rows_op.h"
#include <memory>
namespace paddle {
namespace operators {
......
......@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/fluid/operators/sum_op.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
......
......@@ -38,7 +38,7 @@ void BindTracer(pybind11::module* m) {
.def("trace",
[](imperative::Tracer& self, imperative::OpBase* op,
const imperative::VarBasePtrMap& inputs,
const imperative::VarBasePtrMap& outputs,
imperative::VarBasePtrMap& outputs,
framework::AttributeMap attrs_map,
const platform::CPUPlace expected_place,
const bool stop_gradient = false) {
......@@ -48,7 +48,7 @@ void BindTracer(pybind11::module* m) {
.def("trace",
[](imperative::Tracer& self, imperative::OpBase* op,
const imperative::VarBasePtrMap& inputs,
const imperative::VarBasePtrMap& outputs,
imperative::VarBasePtrMap& outputs,
framework::AttributeMap attrs_map,
const platform::CUDAPlace expected_place,
const bool stop_gradient = false) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册