自制深度学习框架--构建自己的计算图
PNNX
PyTorch Neural Network eXchange(PNNX)是PyTorch模型互操作性的开放标准.
PNNX为PyTorch提供了一种开源的模型格式,它定义了与PyTorch相匹配的数据流图和运算操作。
我们的框架在PNNX之上封装了一层更加易用和简单的计算图格式,
PyTorch训练好一个模型之后,然后模型需要转换到PNNX格式,然后PNNX格式我们再去读取,形成计算图。
PNNX的格式定义
Operator(操作符)
- Inputs: std::vector
,输入操作数 - Outputs: std::vector
,输出操作数 - Type: std::string,运算符的类型
- Name: std::string,运算符的名称
- Params: std::map,存放运算符的所有参数,例如卷积运算的stride, padding, kernel size
- Attrs: std::map,存放运算符所需的具体权重属性,例如卷积的权重w和偏移量b
Operand(操作数)
- Producer: operator,产生这个操作数的运算符,表示运算符的输出,只能有一个生产者
- Customer: operator,下一个操作需要该操作数作为输入的运算符,表示运算符的输入,可以有多个消费者
- Name: std::string,操作数的名称
- shape: std::vector
,操作数的维度
定义我们自己的Operator和Operand
我们给自己的神经网络推理框架定义了RuntimeOperator和RuntimeOperand
RuntimeOperand1
2
3
4
5
6struct RuntimeOperand {
std::string name; // 操作数的名称
std::vector<int32_t> shapes; // 操作数的形状
std::vector<std::shared_ptr<Tensor<float>>> datas; // 存储操作数
RuntimeDataType type = RuntimeDataType::kTypeUnknown; // 操作数的类型,一般是float
};
RuntimeOperator1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23struct RuntimeOperator {
int32_t meet_num = 0; // 计算节点被相连接节点访问到的次数
~RuntimeOperator() {
for (const auto ¶m : this->params) {
delete param.second;
}
}
std::string name; // 计算节点的名称
std::string type; // 计算节点的类型
std::shared_ptr<Layer> layer; // 计算节点对应的计算Layer
std::vector<std::string> output_names; // 节点的输出节点名称
std::shared_ptr<RuntimeOperand> output_operands; // 节点的输出操作数
std::map<std::string, std::shared_ptr<RuntimeOperand>> input_operands; // 节点的输入操作数
std::vector<std::shared_ptr<RuntimeOperand>> input_operands_seq; // 节点的输入操作数,顺序排列
std::map<std::string, std::shared_ptr<RuntimeOperator>> output_operators; // 输出节点的名字和节点对应
std::map<std::string, RuntimeParameter *> params; // 算子的参数信息
std::map<std::string, std::shared_ptr<RuntimeAttribute> > attribute; // 算子的属性信息,内含权重信息
};
同时定义了一些类型的状态1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59enum class RuntimeParameterType {
kParameterUnknown = 0,
kParameterBool = 1,
kParameterInt = 2,
kParameterFloat = 3,
kParameterString = 4,
kParameterIntArray = 5,
kParameterFloatArray = 6,
kParameterStringArray = 7,
};
enum class InferStatus {
kInferUnknown = -1,
kInferFailedInputEmpty = 1,
kInferFailedWeightParameterError = 2,
kInferFailedBiasParameterError = 3,
kInferFailedStrideParameterError = 4,
kInferFailedDimensionParameterError = 5,
kInferFailedChannelParameterError = 6,
kInferFailedInputOutSizeAdaptingError = 6,
kInferFailedOutputSizeError = 7,
kInferFailedOperationUnknown = 8,
kInferFailedYoloStageNumberError = 9,
kInferSuccess = 0,
};
enum class ParseParameterAttrStatus {
kParameterMissingUnknown = -1,
kParameterMissingStride = 1,
kParameterMissingPadding = 2,
kParameterMissingKernel = 3,
kParameterMissingUseBias = 4,
kParameterMissingInChannel = 5,
kParameterMissingOutChannel = 6,
kParameterMissingEps = 7,
kParameterMissingNumFeatures = 8,
kParameterMissingDim = 9,
kParameterMissingExpr = 10,
kParameterMissingOutHW = 11,
kParameterMissingShape = 12,
kParameterMissingGroups = 13,
kParameterMissingScale = 14,
kParameterMissingResizeMode = 15,
kAttrMissingBias = 21,
kAttrMissingWeight = 22,
kAttrMissingRunningMean = 23,
kAttrMissingRunningVar = 24,
kAttrMissingOutFeatures = 25,
kAttrMissingYoloStrides = 26,
kAttrMissingYoloAnchorGrides = 27,
kAttrMissingYoloGrides = 28,
kParameterAttrParseSuccess = 0
};
定义了一些参数1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57struct RuntimeParameter { // 计算节点中的参数信息 基类
virtual ~RuntimeParameter() = default;
explicit RuntimeParameter(RuntimeParameterType type = RuntimeParameterType::kParameterUnknown) : type(type) {
}
RuntimeParameterType type = RuntimeParameterType::kParameterUnknown;
};
struct RuntimeParameterInt : public RuntimeParameter {
RuntimeParameterInt() : RuntimeParameter(RuntimeParameterType::kParameterInt) {
}
int value = 0;
};
struct RuntimeParameterFloat : public RuntimeParameter {
RuntimeParameterFloat() : RuntimeParameter(RuntimeParameterType::kParameterFloat) {
}
float value = 0.f;
};
struct RuntimeParameterString : public RuntimeParameter {
RuntimeParameterString() : RuntimeParameter(RuntimeParameterType::kParameterString) {
}
std::string value;
};
struct RuntimeParameterIntArray : public RuntimeParameter {
RuntimeParameterIntArray() : RuntimeParameter(RuntimeParameterType::kParameterIntArray) {
}
std::vector<int> value;
};
struct RuntimeParameterFloatArray : public RuntimeParameter {
RuntimeParameterFloatArray() : RuntimeParameter(RuntimeParameterType::kParameterFloatArray) {
}
std::vector<float> value;
};
struct RuntimeParameterStringArray : public RuntimeParameter {
RuntimeParameterStringArray() : RuntimeParameter(RuntimeParameterType::kParameterStringArray) {
}
std::vector<std::string> value;
};
struct RuntimeParameterBool : public RuntimeParameter {
RuntimeParameterBool() : RuntimeParameter(RuntimeParameterType::kParameterBool) {
}
bool value = false;
};
最终,定义我们的计算图1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91class RuntimeGraph {
public:
/**
* 计算图的初始化
* @return 是否初始化成功
*/
bool Init();
/**
* 初始化计算图
* @param param_path 计算图的结构文件
* @param bin_path 计算图中的权重文件
*/
RuntimeGraph(std::string param_path, std::string bin_path);
/**
* 设置权重文件
* @param bin_path 权重文件路径
*/
void set_bin_path(const std::string &bin_path);
/**
* 设置结构文件
* @param param_path 结构文件路径
*/
void set_param_path(const std::string ¶m_path);
/**
* 返回结构文件
* @return 返回结构文件
*/
const std::string ¶m_path() const;
/**
* 返回权重文件
* @return 返回权重文件
*/
const std::string &bin_path() const;
const std::vector<std::shared_ptr<RuntimeOperator>> operators() const;
private:
/**
* 初始化kuiper infer计算图节点中的输入操作数
* @param inputs pnnx中的输入操作数
* @param runtime_operator 计算图节点
*/
static void InitInputOperators(const std::vector<pnnx::Operand *> &inputs,
const std::shared_ptr<RuntimeOperator> &runtime_operator);
/**
* 初始化kuiper infer计算图节点中的输出操作数
* @param outputs pnnx中的输出操作数
* @param runtime_operator 计算图节点
*/
static void InitOutputOperators(const std::vector<pnnx::Operand *> &outputs,
const std::shared_ptr<RuntimeOperator> &runtime_operator);
/**
* 初始化kuiper infer计算图中的节点属性
* @param attrs pnnx中的节点属性
* @param runtime_operator 计算图节点
*/
static void InitGraphAttrs(const std::map<std::string, pnnx::Attribute> &attrs,
const std::shared_ptr<RuntimeOperator> &runtime_operator);
/**
* 初始化kuiper infer计算图中的节点参数
* @param params pnnx中的参数属性
* @param runtime_operator 计算图节点
*/
static void InitGraphParams(const std::map<std::string, pnnx::Parameter> ¶ms,
const std::shared_ptr<RuntimeOperator> &runtime_operator);
private:
enum class GraphState {
NeedInit = -2,
NeedBuild = -1,
Complete = 0,
};
GraphState graph_state_ = GraphState::NeedInit;
std::string input_name_; // 计算图输入节点的名称
std::string output_name_; // 计算图输出节点的名称
std::string param_path_; // 计算图的结构文件
std::string bin_path_; // 计算图的权重文件
std::map<std::string, std::shared_ptr<RuntimeOperator>> input_operators_maps_; // 保存输入节点
std::map<std::string, std::shared_ptr<RuntimeOperator>> output_operators_maps_; // 保存输出节点
std::vector<std::shared_ptr<RuntimeOperator>> operators_; // 计算图的计算节点
std::unique_ptr<pnnx::Graph> graph_; // pnnx的graph
};
从PNNX计算图到KuiperInfer计算图的过程
加载PNNX的计算图
1
2this->graph_ = std::make_unique<pnnx::Graph>();
int load_result = this->graph_->load(param_path_, bin_path_);获取PNNX计算图中的运算符
1
std::vector<pnnx::Operator *> operators = this->graph_->ops;
遍历PNNX计算图中的运算符,构建我们的计算图
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47// 根据const pnnx::Operator *op 去赋值std::shared_ptr<RuntimeOperator> runtime_operator
for (const pnnx::Operator *op : operators) {
if (!op) { // 空的计算节点
LOG(ERROR) << "Meet the empty node";
continue;
} else {
std::shared_ptr<RuntimeOperator> runtime_operator = std::make_shared<RuntimeOperator>();
// 初始化算子的名称
runtime_operator->name = op->name;
runtime_operator->type = op->type;
// 初始化算子中的input,对操作符号operator赋予runtimeoperand作为输入,输入是根据pnnx::operand来的
const std::vector<pnnx::Operand *> &inputs = op->inputs;
if (!inputs.empty()) {
InitInputOperators(inputs, runtime_operator);
}
// 记录输出operand中的名称
// 有一个pnnx::operator 来自与load_graph这个操作
// load_graph pnnx::operators数组 进行遍历 pnnx::operator
// 每一个遍历中operator,我们再初始化自己的kuiperinfer::RuntimeOperator
/// RuntimeOperator根据pnnx::operator赋予inputs和outputs
const std::vector<pnnx::Operand *> &outputs = op->outputs;
if (!outputs.empty()) {
InitOutputOperators(outputs, runtime_operator);
}
// 初始化算子中的attribute(权重)
//没一个pnnx::operator里面有一个权重,我们根据pnnx::Attr这个权重去初始化RuntimeAttr
/// 初始化RutimeAttr之后呢,存放在runtime_operator
const std::map<std::string, pnnx::Attribute> &attrs = op->attrs;
if (!attrs.empty()) {
InitGraphAttrs(attrs, runtime_operator);
}
// 初始化算子中的parameter
// 根据const pnnx::Operator *op 去赋值std::shared_ptr<RuntimeOperator> runtime_operator
// 先得到pnnx::parameter再根据这个去赋值RuntimeOperator中的RuntimeParameter
const std::map<std::string, pnnx::Parameter> ¶ms = op->params;
if (!params.empty()) {
InitGraphParams(params, runtime_operator);
}
// runtime_operator初始化玩成了,存放到一个vector中
this->operators_.push_back(runtime_operator);
}
}
初始化RuntimeOperator的输入
初始化RuntimeOperator中的RuntimeOperator.input_operands和RuntimeOperator.input_operands_seq两个属性。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32void RuntimeGraph::InitInputOperators(const std::vector<pnnx::Operand *> &inputs,
const std::shared_ptr<RuntimeOperator> &runtime_operator) {
// 遍历PNNX的操作数operands
for (const pnnx::Operand *input : inputs) {
if (!input) {
continue;
}
// 得到pnnx操作数对应的生产者
const pnnx::Operator *producer = input->producer;
// 初始化runtime_operand
std::shared_ptr<RuntimeOperand> runtime_operand = std::make_shared<RuntimeOperand>();
runtime_operand->name = producer->name; // 名称
runtime_operand->shapes = input->shape; // 形状
switch (input->type) { // 类型
case 1: {
runtime_operand->type = RuntimeDataType::kTypeFloat32;
break;
}
case 0: {
runtime_operand->type = RuntimeDataType::kTypeUnknown;
break;
}
default: {
LOG(FATAL) << "Unknown input operand type: " << input->type;
}
}
// runtime_operand放入到KuiperInfer的运算符中
runtime_operator->input_operands.insert({producer->name, runtime_operand});
runtime_operator->input_operands_seq.push_back(runtime_operand);
}
}
初始化RuntimeOperator中的输出
初始化RuntimeOperator.output_names属性1
2
3
4
5
6
7
8
9
10
11
12
13
14
15void RuntimeGraph::InitOutputOperators(const std::vector<pnnx::Operand *> &outputs,
const std::shared_ptr<RuntimeOperator> &runtime_operator) {
// 遍历pnnx操作数operands
for (const pnnx::Operand *output : outputs) {
if (!output) { // 空的操作数
continue;
}
// 得到pnnx操作数对应的消费者
const auto &consumers = output->consumers;
// 初始化RuntimeOperator.output_names属性
for (const auto &c : consumers) {
runtime_operator->output_names.push_back(c->name);
}
}
}
初始化RuntimeOperator的权重(Attr)属性1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23void RuntimeGraph::InitGraphAttrs(const std::map<std::string, pnnx::Attribute> &attrs,
const std::shared_ptr<RuntimeOperator> &runtime_operator) {
for (const auto &pair : attrs) {
const std::string &name = pair.first;
// 1.得到pnnx中的Attribute
const pnnx::Attribute &attr = pair.second;
switch (attr.type) {
case 1: {
// 2. 根据Pnnx的Attribute初始化KuiperInferOperator中的Attribute
std::shared_ptr<RuntimeAttribute> runtime_attribute = std::make_shared<RuntimeAttribute>();
runtime_attribute->type = RuntimeDataType::kTypeFloat32;
// 2.1 赋值权重weight(此处的data是std::vector<uchar>类型)
runtime_attribute->weight_data = attr.data;
runtime_attribute->shape = attr.shape;
runtime_operator->attribute.insert({name, runtime_attribute});
break;
}
default : {
LOG(FATAL) << "Unknown attribute type";
}
}
}
}
初始化RuntimeOperator的参数(Param)属性1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67void RuntimeGraph::InitGraphParams(const std::map<std::string, pnnx::Parameter> ¶ms,
const std::shared_ptr<RuntimeOperator> &runtime_operator) {
for (const auto &pair : params) {
const std::string &name = pair.first;
const pnnx::Parameter ¶meter = pair.second;
const int type = parameter.type;
// 根据PNNX的Parameter去初始化KuiperInfer::RuntimeOperator中的Parameter
switch (type) {
case int(RuntimeParameterType::kParameterUnknown): {
RuntimeParameter *runtime_parameter = new RuntimeParameter;
runtime_operator->params.insert({name, runtime_parameter});
break;
}
// 在这应该使用派生类RuntimeParameterBool
case int(RuntimeParameterType::kParameterBool): {
RuntimeParameterBool *runtime_parameter = new RuntimeParameterBool;
runtime_parameter->value = parameter.b;
runtime_operator->params.insert({name, runtime_parameter});
break;
}
// 在这应该使用派生类RuntimeParameterInt
case int(RuntimeParameterType::kParameterInt): {
RuntimeParameterInt *runtime_parameter = new RuntimeParameterInt;
runtime_parameter->value = parameter.i;
runtime_operator->params.insert({name, runtime_parameter});
break;
}
case int(RuntimeParameterType::kParameterFloat): {
RuntimeParameterFloat *runtime_parameter = new RuntimeParameterFloat;
runtime_parameter->value = parameter.f;
runtime_operator->params.insert({name, runtime_parameter});
break;
}
case int(RuntimeParameterType::kParameterString): {
RuntimeParameterString *runtime_parameter = new RuntimeParameterString;
runtime_parameter->value = parameter.s;
runtime_operator->params.insert({name, runtime_parameter});
break;
}
case int(RuntimeParameterType::kParameterIntArray): {
RuntimeParameterIntArray *runtime_parameter = new RuntimeParameterIntArray;
runtime_parameter->value = parameter.ai;
runtime_operator->params.insert({name, runtime_parameter});
break;
}
case int(RuntimeParameterType::kParameterFloatArray): {
RuntimeParameterFloatArray *runtime_parameter = new RuntimeParameterFloatArray;
runtime_parameter->value = parameter.af;
runtime_operator->params.insert({name, runtime_parameter});
break;
}
case int(RuntimeParameterType::kParameterStringArray): {
RuntimeParameterStringArray *runtime_parameter = new RuntimeParameterStringArray;
runtime_parameter->value = parameter.as;
runtime_operator->params.insert({name, runtime_parameter});
break;
}
default: {
LOG(FATAL) << "Unknown parameter type";
}
}
}
}

