提交 be8bef9b 编写于 作者: C caoying03

Merge branch 'develop' into add_config_helper_for_resize_layer

# Design Doc: Gradient Operators Registration
## The Problem Posed
In our current operator registration mechanism, for each operator, the programmer should register a *gradient operator creator* function, which takes a C++ operator instance, and returns the corresponding gradient instance.
However, as we decided to separate the *compilation* and *execution* of DL models, we need to reshape the creator to take a protobuf `OpDesc` message, and returns a corresponding message.
More than that, the new registration mechanism need to support the fact that an operators' gradient computation might be a composition of operators.
## Current Implementation
OpInfos store in a association map which key is the operator type. The `grad_op_type` indicate associated gradient operator type. Operator can create gradient operator by `OpInfo::creator_` of gradient. The pseudo code is
```cpp
struct OpInfo {
std::function<OperatorBase*(...)> creator_;
std::string grad_op_type_;
...
};
map<string, OpInfo> OpInfoMap;
OperatorBase* CreateGradientOperator(const OperatorBase& op) {
return OpInfoMap.at(op.Type()).creator_(...);
}
```
## Proposed Solution
The mapping relationship between an operator and its gradient operators is a function. The interface of that function is:
```cpp
// (OpDesc) --> vector<OpDesc>
using GradOpDescMaker = std::function<std::vector<OpDesc>(const OpDesc&)>;
```
The function take a `OpDesc` of the forward operator and return one or many gradient operator descriptions.
The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_` field. The `OpInfo` should be
```cpp
struct OpInfo {
GradOpDescMaker grad_op_maker_;
...
};
```
The `grad_op_maker_ ` is `nullptr` if the operator does not have associated gradient operators.
We should chagne register macros at the same time. In the current solution, there is no difference between forwarding operators and backward operators. So `REGISTER_OP` just register one operator. If the `REGISTER_OPERATOR ` contains `OpProtoAndCheckerMaker` and `GradOpDescMaker`, we just list them in the same macro. It can be done by a macro contains `__VA_ARGS__`.
The user interface should be
```cpp
vector<OpDesc> MinusOpGradMaker(OpDesc) {...}
REGISTER_OPERATOR(minus, MinusOp, MinusOpProtoAndCheckerMaker, SumOpGradMaker);
// Developers can still manually implement gradient operator.
REGISTER_OPERATOR(minus_grad, MinusGradOp);
```
The interface of current `REGISTER_OP` macro could not be changed. In `REGISTER_OP`, it will invoke `REGISTER_OPERATOR` two times and generate GradOpDescMaker inside.
```cpp
REGISTER_OP(minus, MinusOp, MinusOpProtoAndCheckerMaker, minus_grad, MinusGradOp);
```
......@@ -22,14 +22,14 @@ cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op)
py_proto_compile(framework_py_proto SRCS framework.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_proto_maker.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
namespace details {
enum OpInfoFillType {
kOperator = 0,
kOpProtoAndCheckerMaker = 1,
kGradOpDescMaker = 2
};
template <typename T>
struct OpInfoFillTypeID {
static constexpr OpInfoFillType ID() {
return std::is_base_of<OperatorBase, T>::value
? kOperator
: (std::is_base_of<OpProtoAndCheckerMaker, T>::value
? kOpProtoAndCheckerMaker
: (std::is_base_of<GradOpDescMakerBase, T>::value
? kGradOpDescMaker
: static_cast<OpInfoFillType>(-1)));
}
};
template <typename T, OpInfoFillType = OpInfoFillTypeID<T>::ID()>
struct OpInfoFiller;
template <size_t I, bool at_end, typename... ARGS>
class OperatorRegistrarRecursive;
template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, false, ARGS...> {
public:
using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {
OpInfoFiller<T> fill;
fill(op_type, info);
constexpr auto size = sizeof...(ARGS);
OperatorRegistrarRecursive<I + 1, I + 1 == size, ARGS...> reg(op_type,
info);
(void)(reg);
}
};
template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, true, ARGS...> {
public:
OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {}
};
template <typename T>
struct OpInfoFiller<T, kOperator> {
void operator()(const char* op_type, OpInfo* info) const {
info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs) {
return new T(type, inputs, outputs, attrs);
};
}
};
template <typename T>
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->proto_ = new OpProto;
info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_);
maker.Validate();
info->proto_->set_type(op_type);
PADDLE_ENFORCE(
info->proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, info->proto_->InitializationErrorString());
}
};
template <typename T>
struct OpInfoFiller<T, kGradOpDescMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->grad_op_maker_ = new T();
}
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -3,7 +3,7 @@
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
USE_OP(add);
USE_OP(sum);
namespace paddle {
namespace framework {
......@@ -41,17 +41,24 @@ namespace f = paddle::framework;
TEST(GradOpBuilder, AddTwo) {
std::shared_ptr<f::OperatorBase> add_op(f::OpRegistry::CreateOp(
"add", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {}));
"sum", {{"X", {"x", "y"}}}, {{"Out", {"out"}}}, {}));
std::shared_ptr<f::OperatorBase> grad_add_op =
f::OpRegistry::CreateGradOp(*add_op);
EXPECT_EQ(grad_add_op->Inputs().size(), 4UL);
EXPECT_EQ(grad_add_op->Outputs().size(), 2UL);
EXPECT_EQ(grad_add_op->Input("X"), "x");
EXPECT_EQ(grad_add_op->Input("Y"), "y");
EXPECT_EQ(grad_add_op->Input("Out"), "out");
EXPECT_EQ(grad_add_op->Inputs().size(), 1UL);
EXPECT_EQ(grad_add_op->Outputs().size(), 1UL);
EXPECT_EQ(grad_add_op->Input(f::GradVarName("Out")), f::GradVarName("out"));
EXPECT_EQ(grad_add_op->Output(f::GradVarName("X")), f::GradVarName("x"));
EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
auto &outputs = grad_add_op->Outputs(f::GradVarName("X"));
EXPECT_EQ(2UL, outputs.size());
auto in_output = [&outputs](const std::string &name) {
for (auto &output_name : outputs) {
if (output_name == name) return true;
}
return false;
};
EXPECT_TRUE(in_output(f::GradVarName("x")));
EXPECT_TRUE(in_output(f::GradVarName("y")));
}
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
......
......@@ -17,8 +17,8 @@
#include <map>
#include <string>
#include <unordered_map>
#include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.h"
namespace paddle {
namespace framework {
......@@ -29,11 +29,18 @@ using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
class GradOpDescMakerBase {
public:
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<OpDescBind> operator()(const OpDescBind&) const = 0;
};
struct OpInfo {
OpCreator creator_;
std::string grad_op_type_;
OpProto* proto_;
OpAttrChecker* checker_;
GradOpDescMakerBase* grad_op_maker_{nullptr};
OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr;
......
......@@ -21,49 +21,42 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attribute.h"
#include "paddle/framework/details/op_registry.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_proto_maker.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
namespace paddle {
namespace framework {
template <typename... ARGS>
struct OperatorRegistrar {
explicit OperatorRegistrar(const char* op_type) : op_type(op_type) {
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type);
static_assert(sizeof...(ARGS) != 0,
"OperatorRegistrar should be invoked at least by OpClass");
details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
}
~OperatorRegistrar() { OpInfoMap::Instance().Insert(op_type, info); }
const char* op_type;
OpInfo info;
};
class OpRegistry {
public:
template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) {
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type);
OpInfo op_info;
op_info.creator_ = [](
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs) {
return new OpType(type, inputs, outputs, attrs);
};
op_info.grad_op_type_ = grad_op_type;
if (std::type_index(typeid(ProtoMakerType)) !=
std::type_index(typeid(NOPMaker))) {
op_info.proto_ = new OpProto;
op_info.checker_ = new OpAttrChecker;
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
maker.Validate();
op_info.proto_->set_type(op_type);
PADDLE_ENFORCE(
op_info.proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_info.proto_->InitializationErrorString());
} else {
op_info.proto_ = nullptr;
op_info.checker_ = nullptr;
}
OpInfoMap::Instance().Insert(op_type, op_info);
OperatorRegistrar<OpType, ProtoMakerType> reg(op_type.c_str());
reg.info.grad_op_type_ = grad_op_type;
// register gradient op
if (!grad_op_type.empty()) {
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
OperatorRegistrar<GradOpType> grad_reg(grad_op_type.c_str());
}
}
......
......@@ -173,3 +173,14 @@ TEST(OpRegistry, CustomChecker) {
int test_attr = op->Attr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
class CosineOpComplete : public paddle::framework::CosineOp {
public:
DEFINE_OP_CONSTRUCTOR(CosineOpComplete, paddle::framework::CosineOp);
DEFINE_OP_CLONE_METHOD(CosineOpComplete);
};
TEST(OperatorRegistrar, Test) {
using namespace paddle::framework;
OperatorRegistrar<CosineOpComplete, CosineOpProtoAndCheckerMaker> reg("cos");
}
\ No newline at end of file
......@@ -245,5 +245,12 @@ std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
return res;
}
std::ostream& operator<<(std::ostream& os,
const OperatorWithKernel::OpKernelKey& kernel_key) {
os << "place[" << kernel_key.place_ << "]:data_type[" << kernel_key.data_type_
<< "]";
return os;
}
} // namespace framework
} // namespace paddle
......@@ -478,9 +478,25 @@ class OperatorWithKernel : public OperatorBase {
this->InferShape(&infer_shape_ctx);
ExecutionContext ctx(*this, scope, dev_ctx);
auto& opKernel = AllOpKernels().at(type_).at(
OpKernelKey(IndicateDataType(ctx), dev_ctx));
opKernel->Compute(ctx);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW("op[%s] has no kernel", type_);
}
// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
auto kernel_iter = kernels.find(kernel_key);
if (kernel_iter == kernels.end()) {
PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_,
kernel_key);
}
kernel_iter->second->Compute(ctx);
}
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
......@@ -529,5 +545,8 @@ class OperatorWithKernel : public OperatorBase {
}
};
std::ostream& operator<<(std::ostream& os,
const OperatorWithKernel::OpKernelKey& kernel_key);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/add_op.h"
namespace paddle {
namespace operators {
class AddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of AddOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of AddOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of AddOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims, y_dims,
"Two input of Add Op's dimension must be same.");
ctx->SetOutputDim("Out", x_dims);
}
};
class AddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AddOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op");
AddOutput("Out", "The output of add op");
AddComment(R"DOC(
Two Element Add Operator.
The equation is: Out = X + Y
)DOC");
}
};
class AddOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(add, ops::AddOp, ops::AddOpMaker, add_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add, ops::AddKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/add_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(add, ops::AddKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class AddKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input0 = context.Input<Tensor>("X");
auto* input1 = context.Input<Tensor>("Y");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto X = EigenVector<T>::Flatten(*input0);
auto Y = EigenVector<T>::Flatten(*input1);
auto Z = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
Z.device(place) = X + Y;
}
};
} // namespace operators
} // namespace paddle
......@@ -43,8 +43,10 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensors of sum operator.").AsDuplicable();
AddOutput("Out", "the output tensor of sum operator.");
AddInput("X", "the input tensors of sum operator.")
.AsDuplicable()
.NotInGradient();
AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
AddComment(R"DOC(
Sum the input tensors.
......
import unittest
import numpy as np
from op_test import OpTest
class TestAddOp(OpTest):
def setUp(self):
self.op_type = "add"
self.inputs = {
'X': np.random.random((102, 105)).astype("float32"),
'Y': np.random.random((102, 105)).astype("float32")
}
self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
......@@ -15,7 +15,7 @@ class PySimpleCond(object):
for i in range(1, 10, 2):
array[i] = 0
self.cond = np.array(array)
self.x = np.ones(shape=(10, 1))
self.x = np.ones(shape=(10, 1)).astype("float32")
def forward(self):
self.index_t = np.where(self.cond == 1)
......
import unittest
import numpy as np
import paddle.v2.framework.core as core
from op_test import get_numeric_gradient
from op_test import create_op
class GetNumericGradientTest(unittest.TestCase):
def test_add_op(self):
x = np.random.random((10, 1)).astype("float32")
y = np.random.random((10, 1)).astype("float32")
z = x + y
scope = core.Scope()
add_op = create_op(scope, "add", {'X': x, 'Y': y}, {'Out': z}, dict())
arr = get_numeric_gradient(scope, add_op, {'X': x,
'Y': y}, 'X', ['Out'])
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4)
def test_softmax_op(self):
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - np.max(x)
exps = np.exp(shiftx)
return exps / np.sum(exps)
def label_softmax_grad(Y, dY):
dX = Y * 0.0
for i in range(Y.shape[0]):
d = np.dot(Y[i, :], dY[i, :])
dX[i, :] = Y[i, :] * (dY[i, :] - d)
return dX
X = np.random.random((2, 2)).astype("float32")
Y = np.apply_along_axis(stable_softmax, 1, X)
dY = np.ones(Y.shape)
dX = label_softmax_grad(Y, dY)
scope = core.Scope()
softmax_op = create_op(scope, "softmax", {"X": X}, {"Y": Y}, dict())
arr = get_numeric_gradient(scope, softmax_op, {"X": X}, "X", "Y")
np.testing.assert_almost_equal(arr, dX, decimal=1e-2)
if __name__ == "__main__":
unittest.main()
......@@ -15,7 +15,7 @@ def fc(X, W, Y):
class TestNet(unittest.TestCase):
def test_net_all(self):
net = core.Net.create()
op1 = Operator("add", X="X", Y="Y", Out="Out")
op1 = Operator("sum", X=["X", "Y"], Out="Out")
net.append_op(op1)
net2 = core.Net.create()
......@@ -26,7 +26,7 @@ class TestNet(unittest.TestCase):
expected = '''
Op(plain_net), inputs:{all[W, X, Y]}, outputs:{all[Out, fc.out, pre_activation]}.
Op(add), inputs:{X[X], Y[Y]}, outputs:{Out[Out]}.
Op(sum), inputs:{X[X, Y]}, outputs:{Out[Out]}.
Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}.
Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}.
Op(mul), inputs:{X[X], Y[W]}, outputs:{Out[pre_activation]}.
......
......@@ -193,10 +193,10 @@ class TestOpDescCreationMethod(unittest.TestCase):
class TestOpCreations(unittest.TestCase):
def test_all(self):
add_op = op.Operator("add", X="a", Y="b", Out="z")
add_op = op.Operator("sum", X=["a", "b"], Out="z")
self.assertIsNotNone(add_op)
# Invoke C++ DebugString()
self.assertEqual('Op(add), inputs:{X[a], Y[b]}, outputs:{Out[z]}.',
self.assertEqual('Op(sum), inputs:{X[a, b]}, outputs:{Out[z]}.',
str(add_op))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册