Commit 5917e09c authored by qiaolongfei

tmp work

Parent ab9545aa
@@ -55,6 +55,10 @@ class OpRegistry {
                   const std::string& grad_op_type) {
     OperatorRegistrar<OpType, ProtoMakerType> reg(op_type.c_str());
     reg.info.grad_op_type_ = grad_op_type;
+    auto proto = reg.info.Proto();
+    std::cout << "====== " << op_type << " =======" << std::endl;
+    std::cout << proto.SerializeAsString() << std::endl;
+    std::cout << "=============" << std::endl;
     ShapeInferenceMap::Instance().CreateOpWithKernel(reg.info, op_type);
     // register gradient op
     if (!grad_op_type.empty()) {
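Note on the debug block added above: SerializeAsString() emits the binary protobuf wire format, so the bytes written to std::cout are not human-readable. If readable output is wanted, protobuf's standard DebugString() accessor is the usual alternative — a minimal sketch against the same reg.info.Proto() value as in the hunk (a drop-in fragment, not a full file):

    // Sketch: human-readable variant of the debug print above.
    // DebugString() is protobuf's standard text rendering of a message.
    auto proto = reg.info.Proto();
    std::cout << "====== " << op_type << " =======" << std::endl;
    std::cout << proto.DebugString() << std::endl;
    std::cout << "=============" << std::endl;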
@@ -598,9 +598,9 @@ class OperatorWithKernel : public OperatorBase {
     });
   }

-  protected:
   virtual void InferShape(InferShapeContextBase* ctx) const = 0;

+ protected:
   // indicate kernel DataType by input data. Defaultly all input data must be
   // same.
   virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
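The hunk above appears to move the protected: label below the pure-virtual InferShape declaration, making the hook publicly callable — the new infer_shape Python binding further down invokes op->InferShape(&ctx) from outside the class, which requires public access. For orientation, a concrete operator would override it roughly as follows. This is a hypothetical sketch: the context accessors GetInputDim/SetOutputDim are assumed names, not taken from this commit.

    // Hypothetical sum-style operator; accessor names are assumptions.
    class SumOp : public OperatorWithKernel {
     public:
      void InferShape(InferShapeContextBase* ctx) const override {
        // All inputs of "X" share one shape; "Out" takes that same shape.
        auto in_dim = ctx->GetInputDim("X");
        ctx->SetOutputDim("Out", in_dim);
      }
    };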
@@ -37,10 +37,13 @@ ShapeInferenceMap& ShapeInferenceMap::Instance() {
 void ShapeInferenceMap::CreateOpWithKernel(const OpInfo& op_info,
                                            const std::string& op_type) {
-  const VariableNameMap inputs =
-      ConvertOpProtoVarsToVarNameMap(op_info.Proto().inputs());
+  auto proto = op_info.Proto();
+  std::cout << "========= " << op_type << " in======" << std::endl;
+  std::cout << proto.SerializeAsString() << std::endl;
+  std::cout << "========= " << op_type << " out======" << std::endl;
+  const VariableNameMap inputs = ConvertOpProtoVarsToVarNameMap(proto.inputs());
   const VariableNameMap outputs =
-      ConvertOpProtoVarsToVarNameMap(op_info.Proto().outputs());
+      ConvertOpProtoVarsToVarNameMap(proto.outputs());
   auto* op = op_info.Creator()(op_type, inputs, outputs, {});
   auto* op_with_kernel = dynamic_cast<OperatorWithKernel*>(op);
   auto it = op_shape_inference_map_.find(op_type);
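One caveat in CreateOpWithKernel: dynamic_cast<OperatorWithKernel*>(op) returns nullptr whenever the created operator is not kernel-based, so the result deserves a guard before it is stored in op_shape_inference_map_. A defensive sketch — the map insert and the ownership handling are assumptions about the surrounding code, which the hunk truncates:

    auto* op_with_kernel = dynamic_cast<OperatorWithKernel*>(op);
    if (op_with_kernel == nullptr) {
      delete op;  // assumption: this path owns the freshly created op
      return;     // plain OperatorBase: nothing to cache for shape inference
    }
    op_shape_inference_map_[op_type] = op_with_kernel;  // assumed insert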
@@ -27,14 +27,6 @@ class ShapeInferenceMap {
  public:
   static ShapeInferenceMap& Instance();

-  const OperatorBase* GetOperator(const std::string& op_type) {
-    auto it = op_shape_inference_map_.find(op_type);
-    if (it == op_shape_inference_map_.end()) {
-      PADDLE_THROW("op with kernel for Op(%s) is not registered", op_type);
-    }
-    return it->second;
-  }
-
   void CreateOpWithKernel(const OpInfo& op_info, const std::string& op_type);

   OperatorWithKernel* GetOpWithKernel(const std::string& op_type) {
@@ -223,6 +223,15 @@ All parameter, weight, gradient are variables in Paddle.
                        desc.InitializationErrorString());
         return OpRegistry::CreateOp(desc);
       })
+      .def("infer_shape",
+           [](const OpDescBind &op_desc, BlockDescBind &block) {
+             auto &shape_inference_map = ShapeInferenceMap::Instance();
+             auto *op = shape_inference_map.GetOpWithKernel(op_desc.Type());
+             if (op != nullptr) {
+               auto ctx = CompileTimeInferShapeContext(op_desc, block);
+               op->InferShape(&ctx);
+             }
+           })
       .def("backward",
            [](const OperatorBase &forwardOp,
              const std::unordered_set<std::string> &no_grad_vars) {
import unittest

import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator


class TestInferShape(unittest.TestCase):
    def test_sum_op(self):
        prog = core.ProgramDesc.__create_program_desc__()
        self.assertIsNotNone(prog)
        block = prog.block(0)
        self.assertIsNotNone(block)

        # prepare input/output
        x1 = block.new_var("x1")
        x1.set_shape([10, 20])
        x2 = block.new_var("x2")
        x2.set_shape([10, 20])
        out = block.new_var("out")

        # prepare the operator
        sum_op_desc = block.append_op()
        sum_op_desc.set_type("sum")
        sum_op_desc.set_input("X", ["x1", "x2"])
        sum_op_desc.set_output("Out", ["out"])

        sum_op = Operator("sum", X=["x1", "x2"], Out="out")
        sum_op.infer_shape(sum_op_desc, block)
        print(out.shape())  # summing two [10, 20] inputs should yield [10, 20]


if __name__ == '__main__':
    unittest.main()