Tengine 推理引擎 | API源码解读

Tengine由OPEN AI LAB主导开发,该项目实现了深度学习神经网络模型在嵌入式设备上的快速、高效部署需求。为实现在众多AIoT应用中的跨平台部署,本项目使用C 语言进行核心模块开发,针对嵌入式设备资源有限的特点进行了深度框架裁剪。同时采用了完全分离的前后端设计,有利于 CPU、GPU、NPU 等异构计算单元的快速移植和部署,降低评估、迁移成本。

最近终于空出时间看了看tengine的源码,请教虫叔给了比较合适的源码学习路线

基本api -> 创建图 -> 预运行 -> 切图 -> 异构调度

我也不确定我能看懂多少,总之是花了几天时间先看了看基本api部分,可能是项目太庞大了,看的时候各个结构体之间的调用有点乱,于是做了下面的思维导图,主要是为了能够梳理基本api有关的结构体及其属性。内容上如果有疏漏,请大家多多指教。

这里分析了大多能用到的,包括实际使用的、调试可以用的一些api的实现。文章很长,有兴趣的可以慢慢阅读,还可以关注一波,后续还会继续分享tengine的源码学习。

Context 相关API

这里的context我理解是将graph、scheduler和device关联在一起的机制,方便在不同device上以不同的scheduler执行graph

- create_context: 创建context
- destroy_context: 销毁context
- get_ir_graph_context: 返回指定graph的context
- get_context_device: 返回指定context的device
- add_context_device: 通过查找device名,将context->device设置为device
- set_context_device: 给无device的context设置device和device_option
- remove_context_device: 按照device name删除context->device
// 创建context context_t create_context(const char* context_name, int empty_context) { // 分配上下文结构体的空间 struct context* context = (struct context*)sys_malloc(sizeof(struct context)); // 初始化context, 为context->name分配内存空间,将context_name拷贝到该内存空间中,将context的其他成员置NULL init_ir_context(context, context_name); // 给context的调度器赋予默认的scheduler context->scheduler = find_default_scheduler(); // 如果empty_context是0,让context->device设定为CPU if (0 == empty_context) { context->device = find_default_device(); } return context; } // 销毁指定的context void destroy_context(context_t context) { struct context* ctx = (struct context*)context; if (NULL == context) { return; } // 这里只释放掉了name、default_options、device_options, // 而scheduler和device没有在这里释放,理由是在切图之后,每个subgraph在不同的设备上,一个context没法释放所有的subgraph,所以scheduler和device的释放由graph去控制 if (NULL != ctx->name) { sys_free(ctx->name); } if (NULL != ctx->default_options) { sys_free(ctx->default_options); } if (NULL != ctx->device_options) { sys_free(ctx->device_options); } sys_free(ctx); } // 返回指定graph的context struct context* get_ir_graph_context(struct graph* ir_graph) { return ir_graph->attribute->context; } // 返回指定context的device struct device* get_context_device(context_t context, int index) { struct context* ctx = (struct context*)context; if (NULL == ctx) { TLOG_ERR("Tengine: Context pointer is null.\n"); return NULL; } // 这里我理解是在正常切图之后,每个subgraph对应自己的context,context只对应一个设备 if (NULL != ctx->device && 0 == index) { // 【核心代码】,返回ctx->device return ctx->device; } return NULL; } // 通过查找device名,将context->device设置为device int add_context_device(context_t context, const char* dev_name) { struct context* ctx = (struct context*)context; if (NULL == ctx) { TLOG_ERR("Tengine: Context pointer is null.\n"); return -1; } if (NULL != ctx->device) { TLOG_ERR("Tengine: Context(%s) is not multi-device collaborative.\n", ctx->name); return -1; } // 【核心代码】 struct device* selected_device = find_device_via_name(dev_name); if (NULL == selected_device) { TLOG_ERR("Tengine: Device(%s) 
is not found(may not registered).\n", dev_name); return -1; } // 【核心代码】 ctx->device = selected_device; return 0; } // 给无device的context设置device和device_option int set_context_device(context_t context, const char* dev_name, const void* dev_option, size_t dev_opt_size) { struct context* ctx = (struct context*)context; // 如果ctx指针为空,说明context指针也为空,context没有创建成功 if (NULL == ctx) { TLOG_ERR("Tengine: Context pointer is null.\n"); return -1; } // 如果ctx已经设置了设备,返回 if (NULL != ctx->device) { TLOG_ERR("Tengine: A device(%s) has been set for this context(%s).\n", ctx->device->name, ctx->name); return -1; } // 【核心代码】 通过设备名找到设备 struct device* selected_device = find_device_via_name(dev_name); if (NULL == selected_device) { TLOG_ERR("Tengine: Device(%s) is not found(may not registered).\n", dev_name); return -1; } // 【核心代码】 ctx的设备指向selected_device ctx->device = selected_device; // 如果设置了dev_option 在ctx中为device_options分配空间,将dev_option拷贝到ctx->device_options中 if (NULL != dev_option) { ctx->device_options = sys_malloc(dev_opt_size); memcpy(ctx->device_options, dev_option, dev_opt_size); } return 0; } // 按照device name删除context->device int remove_context_device(context_t context, const char* dev_name) { struct context* ctx = (struct context*)context; if (NULL == ctx) { TLOG_ERR("Tengine: Context pointer is null.\n"); return -1; } if (NULL == dev_name) { TLOG_ERR("Tengine: Device name is null.\n"); return 0; } if (NULL == ctx->device) { TLOG_ERR("Tengine: Context(%s) does not has any device.\n", ctx->name, dev_name); return -1; } // ???迷惑操作 已提交PR,第二个参数应该是dev_name if (0 == strcmp(ctx->device->name, ctx->device->name)) { // 【核心代码】 对ctx的device置空相当于删除device ctx->device = NULL; return 0; } TLOG_ERR("Tengine: Context(%s) does not has a device named %s.\n", ctx->name, dev_name); return -1; }

Engine 相关

- get_tengine_version: 返回tengine版本
- init_tengine: 初始化tengine的解析器、各种算子和CPU设备,只能初始化一次
- release_tengine: 释放tengine
// 返回tengine版本 const char* get_tengine_version(void) { static char buf[128]; snprintf(buf, 128, "%s-%s", tengine_lite_version, ver_postfix); buf[127] = 0x0; /* save moving */ return buf; } // 初始化tengine,只能初始化一次 int init_tengine(void) { if (0 != init_flag) // 如果已经初始化了,直接返回 { return 0; } // 注册所有的算子 int ret = register_all_op_prototype(); if (0 != ret) { TLOG_ERR("Tengine: Register operator failed: %d\n", ret); return ret; } // 注册解析器器 和 所有的解析器算子 ret = register_all_serializer(); if (0 != ret) { TLOG_ERR("Tengine: Register serializer failed: %d\n", ret); return ret; } // 注册cpu设备,虽然函数名是all devices但实际调用只注册了cpu ret = register_all_devices(); if (0 != ret) { TLOG_ERR("Tengine: Register neural network devices failed: %d\n", ret); return ret; } // 全局init_flag,调用了一次后,通过该变量判断是否调用过,只能调用一次 init_flag++; return ret; } // 释放tengine void release_tengine(void) { // 如果tengine没有初始化过,也就不用释放 if (0 == init_flag) { return; } int ret = unregister_all_op_prototype(); // 注销掉所有的算子 if (0 != ret) { TLOG_ERR("Tengine: Unregister operator failed: %d\n", ret); } // 释放掉算子的注册器:(遍历寄存器vector)当内部算子寄存器vector仍有元素时,将寄存器vector的0号元素的算子注销,。 // 当没有元素后,释放掉寄存器vector,寄存器置NULL ret = release_op_registry(); if (0 != ret) { TLOG_ERR("Tengine: Release operator prototype registry failed: %d\n", ret); } // 注销掉所有算子加载器、注销掉算子解析器 ret = unregister_all_serializer(); if (0 != ret) { TLOG_ERR("Tengine: Unregister serializer failed: %d\n", ret); } // 释放解析器寄存器,(遍历寄存器vector)当内部解析器寄存器vector仍有元素时,将寄存器vector的0号元素解析器注销, // 没有元素后,释放掉寄存器vector,寄存器值NULL ret = release_serializer_registry(); if (0 != ret) { TLOG_ERR("Tengine: Release serializer registry failed: %d\n", ret); } // 注销掉cpu设备 ret = unregister_all_devices(); if (0 != ret) { TLOG_ERR("Tengine: Unregister neural network devices failed: %d\n", ret); } // 注销后,重置为没有初始化的状态 init_flag = 0; }

Graph 相关

- create_graph: 利用context、模型格式、模型文件名创建graph
- prerun_graph: 计算图中每个节点的shape,使用device切图、优化图,配置cpu使用掩码、亲和性、计算精度
- prerun_graph_multithread: 计算图中每个节点的shape,使用device切图、优化图,配置cpu使用掩码、亲和性、计算精度
- run_graph: prerun之后,执行计算图
- postrun_graph: 释放graph资源
- set_graph_layout: 设置graph的layout_type属性,设置layout是NCHW还是NHWC的
- destroy_graph: 销毁graph
// 利用context、模型格式、模型文件名创建graph graph_t create_graph(context_t context, const char* model_format, const char* file_name, ...) { int is_new_context = 0; if (context == NULL) { // 如果context为空,则创建context,context的name为NULL,device为空,使用默认scheduler context = create_context(NULL, 1); is_new_context = 1; } // 创建一个全空的graph,graph的上下文设置为context ir_graph_t* ir_graph = create_ir_graph((struct context*)context); if (ir_graph == NULL) { if (is_new_context) { destroy_context(context); } return NULL; } ir_graph->attribute->private_context = is_new_context; // 如果设置了model_format, Example都设置为"tengine" if (NULL != model_format) { int ret = 0; // 通过模型格式tengine找到解析器 struct serializer* loader = find_serializer_via_name(model_format); if (loader == NULL) { TLOG_ERR("Tengine: No matched serializer(name: %s) found.\n", model_format); return create_graph_error; } va_list ap; va_start(ap, file_name); // p是model_format字符串在:之后的地址 const char* p = strchr(model_format, :); // load from file or memory // 一般我们自己指定tmfile文件 都是执行if流程,暂时没太搞清楚什么时候会执行else if (NULL == p) { // 解析器从模型文件加载加载模型,计算图结构转换至graph中 ret = loader->load_model(loader, ir_graph, file_name, ap); } else { if (p[1] != m) { TLOG_ERR("Tengine: Invalid postfix(%s) for model format: should m only.\n", p); return create_graph_error(ir_graph); } if (NULL == loader->load_mem) { TLOG_ERR("Tengine: Serializer(%s) does not support loading from memory.\n", loader->get_name(loader)); return create_graph_error(ir_graph); } int size = va_arg(ap, int); ret = loader->load_mem(loader, ir_graph, (void*)file_name, size, ap); } va_end(ap); if (0 != ret) { return create_graph_error(ir_graph); } // graph的设备指向默认设备CPU, // 从这可以看出graph的device和context的device可能会不一样, // 原因是在切图后,每个子图对应的device都可能不一致,到执行图的时候,是从context中获取device去执行,而不是之间从graph的device来执行 ir_graph->device = find_default_device(); } return ir_graph; } // 计算图中每个节点的shape,使用device切图、优化图,配置cpu使用掩码、亲和性、计算精度 int prerun_graph(graph_t graph) { struct options option; option.num_thread = 1; option.precision = -1; 
option.affinity = -1; option.cluster = TENGINE_CLUSTER_BIG; return prerun_graph_multithread(graph, option); } // 计算图中每个节点的shape,使用device切图、优化图,配置cpu使用掩码、亲和性、计算精度 int prerun_graph_multithread(graph_t graph, struct options option) { struct graph* ir_graph = (struct graph*)graph; // 计算图中每个节点的shape int ret = infer_ir_graph_shape(ir_graph); if (0 != ret) { ir_graph->status = GRAPH_STAT_ERROR; fprintf(stderr, "Tengine: Infer shape of graph failed(%d).\n", ret); return -1; } // 获取graph的context struct context* ctx = get_ir_graph_context(ir_graph); // 获取ctx的device struct device* dev = ctx->device; if (NULL == dev) // 如果ctx为空,默认为cpu { dev = find_default_device(); } if (NULL != dev && NULL != dev->optimizer) { if (NULL != dev->optimizer->split_graph) { // device的optimizer切图 if (0 != dev->optimizer->split_graph(ir_graph)) { ir_graph->status = GRAPH_STAT_ERROR; fprintf(stderr, "Tengine: Split graph via device(%s) failed.\n", dev->name); return -1; } } if (NULL != dev->optimizer->optimize_graph) { // device的optimizer优化图 ret = dev->optimizer->optimize_graph(ir_graph, -1); if (0 != ret) { ir_graph->status = GRAPH_STAT_ERROR; fprintf(stderr, "Tengine: Optimize graph via device(%s) failed.\n", dev->name); return -1; } } } check_cpu(); // 设置可用的CPU核心 size_t mask = get_cpu_cluster_mask(TENGINE_CLUSTER_BIG); if (0 <= option.cluster) { mask = get_cpu_cluster_mask(option.cluster); } int count = get_cpu_mask_count(mask); if (0 < option.num_thread && count > option.num_thread) { count = option.num_thread; } // 设置精度 int precision = TENGINE_MODE_FP32; if (0 <= option.precision && (TENGINE_MODE_FP32 == option.precision || TENGINE_MODE_FP16 == option.precision || TENGINE_MODE_HYBRID_INT8 == option.precision || TENGINE_MODE_UINT8 == option.precision || TENGINE_MODE_INT8 == option.precision)) { precision = option.precision; } ctx->default_options = sys_malloc(sizeof(struct cpu_option)); // cpu_option 指向 ctx的默认option struct cpu_option* opt = (struct cpu_option*)ctx->default_options; opt->dev_name 
= CPU_DEVICE_NAME; opt->num_thread = count; opt->cluster = TENGINE_CLUSTER_BIG; opt->precision = precision; opt->affinity = option.affinity; // 调度器指向ctx的scheduler struct scheduler* scheduler = ctx->scheduler; // scheduer 准备计算图 ret = scheduler->prerun(scheduler, ir_graph); if (0 != ret) { ir_graph->status = GRAPH_STAT_ERROR; fprintf(stderr, "Tengine: Scheduler(%s) prerun failed.\n", scheduler->name); return ret; } // 修改graph的状态为准备就绪 ir_graph->status = GRAPH_STAT_READY; // 设置cpu亲和性 if (0 != opt->affinity && 0 != (opt->affinity & mask)) { set_cpu_affine(opt->affinity); } // cpu亲和性 默认设置为mask else { set_cpu_affine(mask); } /* dump graph */ const char* env = getenv(TENGINE_DUMP_GRAPH); if (env && env[0] == 1) { set_log_level(LOG_INFO); dump_ir_graph(ir_graph); } return 0; } // prerun之后,执行计算图 int run_graph(graph_t graph, int block) { struct graph* ir_graph = (struct graph*)graph; // 获取graph的context struct context* context = get_ir_graph_context(ir_graph); // 获取graph中context的scheduler struct scheduler* scheduler = context->scheduler; // 更新graph的状态为running ir_graph->status = GRAPH_STAT_RUNNING; // scheduler执行graph if (scheduler->run(scheduler, ir_graph, block) < 0) { ir_graph->status = GRAPH_STAT_ERROR; return -1; } else { if (block) // 如果阻塞模式,更新graph的状态为准备就绪 ir_graph->status = GRAPH_STAT_READY; } return 0; } // 释放graph资源 int postrun_graph(graph_t graph) { struct graph* ir_graph = (struct graph*)graph; struct context* context = get_ir_graph_context(ir_graph); // 获取graph中的context的scheduler struct scheduler* scheduler = context->scheduler; // scheduler释放graph资源 if (scheduler->postrun(scheduler, ir_graph) < 0) { ir_graph->status = GRAPH_STAT_ERROR; return -1; } // 释放graph的设备 if (NULL != ir_graph->attribute->device_privacy) { release_vector((vector_t*)ir_graph->attribute->device_privacy); } // 更新graph状态为Done ir_graph->status = GRAPH_STAT_DONE; return 0; } // 设置graph的layout_type属性,设置layout是NCHW还是NHWC的 // 并不对数据进行调整,只是设置graph的属性 int set_graph_layout(graph_t graph, int layout_type) 
{ struct graph* ir_graph = (struct graph*)graph; if ((layout_type != TENGINE_LAYOUT_NCHW) && (layout_type != TENGINE_LAYOUT_NHWC)) { return -1; } ir_graph->graph_layout = layout_type; return 0; } // 销毁graph int destroy_graph(graph_t graph) { struct graph* ir_graph = (struct graph*)graph; // 销毁graph的context, context的scheduler、device也被销毁 if (ir_graph->attribute->private_context) destroy_context(ir_graph->attribute->context); // 销毁graph destroy_ir_graph(ir_graph); return 0; }

Node 相关

- set_graph_input_node: 设置graph的输入节点,告诉graph哪些节点是输入节点
- set_graph_output_node: 设置graph的输出节点,告诉graph哪些节点是输出节点,代码、解读参考set_graph_input_node
- get_graph_input_node_number: 返回graph的输入节点数量
- get_graph_input_node: 返回graph的输入节点的第idx个节点
- get_graph_output_node_number: 返回graph的输出节点数量,代码、解读参考get_graph_input_node_number
- get_graph_output_node: 返回graph的输出节点的第idx个节点,代码、解读参考get_graph_input_node
- get_graph_input_tensor: 通过input_idx和tensor_idx 获得graph的输入节点的输出张量(有点迷惑)
- create_graph_node: 通过节点名称和op名称创建节点
- get_graph_node: 通过节点名称 返回节点
- get_graph_node_by_idx: 通过节点索引 返回节点
- get_graph_node_num: 获得graph节点数量
- get_node_input_number: 返回指定节点的输入数量
- get_node_output_number: 返回指定节点的输出数量,代码、解读参考get_node_input_number
- get_node_name: 返回指定节点的节点名
- get_node_op: 返回指定节点的op名称
- get_node_device: 返回指定节点的device
// 设置graph的输入节点 // 告诉graph哪些节点为输入节点 int set_graph_input_node(graph_t graph, const char* input_nodes[], int input_number) { struct graph* ir_graph = (struct graph*)graph; int16_t* input_node_indexes; // 为输入节点的id指针分配内存 input_node_indexes = (int16_t*)sys_malloc(sizeof(int16_t) * input_number); // 分配失败,返回 if (input_node_indexes == NULL) { return -1; } // 对每一个输入节点进行设置 for (int i = 0; i < input_number; i++) { // 通过第i个节点的名字 查找节点在graph中的id int node_idx = get_ir_node_index_from_name(ir_graph, input_nodes[i]); // 如果索引小于0,说明没找到该节点名,因此要释放掉内存,然后返回 if (node_idx < 0) { sys_free(input_node_indexes); return -1; } // 如果找到了,为第i个id赋值为node_idx input_node_indexes[i] = node_idx; } // 为graph设置输入节点 int ret = set_ir_graph_input_node(ir_graph, input_node_indexes, input_number); sys_free(input_node_indexes); return ret; } // 返回graph的输入节点数量 int get_graph_input_node_number(graph_t graph) { struct graph* ir_graph = (struct graph*)graph; return ir_graph->input_num; } // 返回graph的输入节点的第idx个节点 node_t get_graph_input_node(graph_t graph, int idx) { struct graph* ir_graph = (struct graph*)graph; // 如果idx超过范围,则返回空 if (idx < 0 || idx >= ir_graph->input_num) { return NULL; } // input_nodes是graph的输入节点的id列表,使用idx索引列表的第几个输入节点 return get_ir_graph_node(ir_graph, ir_graph->input_nodes[idx]); } // 通过input_idx和tensor_idx 获得graph的输入节点的输出张量(有点迷惑) tensor_t get_graph_input_tensor(graph_t graph, int input_idx, int tensor_idx) { struct graph* ir_graph = (struct graph*)graph; // 如果input_idx超过graph的input node的数量,返回空 if (input_idx < 0 || input_idx >= ir_graph->input_num) { return NULL; } // 先获取到输入节点列表的第input_idx个节点的id int input_node_idx = ir_graph->input_nodes[input_idx]; // 获取graph的input_idx个输入节点 struct node* ir_node = ir_graph->node_list[input_node_idx]; // 如果没有找到,返回空 if (tensor_idx < 0 || tensor_idx >= ir_node->output_num) { return NULL; } return get_ir_graph_tensor(ir_node->graph, ir_node->output_tensors[tensor_idx]); } // 通过节点名称和op名称创建节点 node_t create_graph_node(graph_t graph, const char* node_name, const char* 
op_name) { struct graph* ir_graph = (struct graph*)graph; // 先通过节点名 查找graph中是否已经有同名的节点,如果有 则返回空,创建失败 int node_idx = get_ir_node_index_from_name(ir_graph, node_name); if (node_idx >= 0) { return NULL; } // 通过op名获取op类型 int op_type = get_op_type_from_name(op_name); if (op_type < 0) { return NULL; } // 创建节点 return create_ir_node(ir_graph, node_name, op_type, 1); } // 通过节点名称 返回节点 node_t get_graph_node(graph_t graph, const char* node_name) { struct graph* ir_graph = (struct graph*)graph; // 返回名为node_name的节点的索引 int node_idx = get_ir_node_index_from_name(ir_graph, node_name); if (node_idx < 0) { return NULL; } // 返回node_list的第node_idx个节点 return ir_graph->node_list[node_idx]; } // 通过节点索引 返回节点 node_t get_graph_node_by_idx(graph_t graph, int idx) { struct graph* ir_graph = (struct graph*)graph; if (idx < 0 || idx >= ir_graph->node_num) return NULL; return ir_graph->node_list[idx]; } // 返回graph节点数量 int get_graph_node_num(graph_t graph) { struct graph* ir_graph = (struct graph*)graph; return ir_graph->node_num; } // 返回指定节点的输入数量 int get_node_input_number(node_t node) { struct node* ir_node = (struct node*)node; return ir_node->input_num; } // 返回指定节点的节点名 const char* get_node_name(node_t node) { struct node* ir_node = (struct node*)node; // 如果指定了node->name,则返回节点名 if (ir_node->name) { return ir_node->name; } // 如果没有指定节点名,返回NULL。猜测这里create_ir_node_name_from_index是没实现完 ir_node->name = create_ir_node_name_from_index(ir_node->index); return ir_node->name; } // 返回指定节点的op名称 const char* get_node_op(node_t node) { struct node* ir_node = (struct node*)node; // 返回节点的op类型 int op_type = ir_node->op.type; // 根据op类型返回op名称 return get_op_name_from_type(op_type); } // 返回指定节点的device const char* get_node_device(node_t node) { // 获取node对应的graph struct node* ir_node = (struct node*)node; struct graph* graph = ir_node->graph; // node节点对应的graph的子图数 int subgraph_count = get_vector_num(graph->subgraph_list); // 如果有子图 if (subgraph_count > 0) { if (0 <= ir_node->subgraph_idx) // 这里的subgraph_idx理解的不是很透彻 { // 
获得node的子图列表中的第subgraph_idx个子图, 返回子图设备名 struct subgraph* subgraph = get_ir_graph_subgraph(graph, ir_node->subgraph_idx); if (subgraph->device) { return subgraph->device->name; } } } else { return graph->device->name; } return NULL; }

Tensor 相关

- get_node_input_tensor: 返回指定节点的第idx个输入张量
- get_node_output_tensor: 返回指定节点的第idx个输出张量,代码、解读参考get_node_input_tensor
- set_node_input_tensor: 设置指定节点的第idx输入张量为tensor,如果idx超过了节点输入数量,则为节点扩充空间,把tensor加进去
- set_node_output_tensor: 设置指定节点的第idx输出张量为tensor,代码、解读参考set_node_input_tensor
- create_graph_tensor: 创建一个tensor,添加到graph的tensor_list中
- get_graph_tensor: 返回graph中名为tensor_name的张量
- get_tensor_name: 返回指定tensor的name
- set_tensor_shape: 设置tensor的shape,这里并不改变data,只是改变tensor结构体的属性
- get_tensor_shape: 返回tensor的shape,shape写入dims[]数组中,dim_number需要和tensor的dim_num一致才会正常返回
- get_tensor_buffer_size: 返回tensor的缓冲区大小,就是element_size*element_num
- get_tensor_buffer: 获取指定tensor的data
- set_tensor_buffer: 重置tensor,tensor的data指向buffer
- get_tensor_data: 将tensor的data拷贝到output_data中
- set_tensor_data: 设置tensor的data为input_data,与set_tensor_buffer不同的是,set_tensor_data是直接将input_data拷贝过来,而不是指向input_data
- get_tensor_data_type: 返回tensor的数据类型
- set_tensor_data_type: 设置tensor的数据类型,只是修改了tensor的结构体属性,没有对数据进行改变
- get_tensor_layout: 返回tensor的layout,是NCHW还是NHWC
- set_tensor_layout: 设置tensor的layout,只是修改了tensor的结构体属性,没有对数据进行改变
// 返回指定节点的第idx个输入张量 tensor_t get_node_input_tensor(node_t node, int input_idx) { struct node* ir_node = (struct node*)node; // 如果idx越界 返回空 if (input_idx < 0 || input_idx >= ir_node->input_num) { return NULL; } // 通过指定节点相关的graph,返回其输入的idx的张量 return get_ir_graph_tensor(ir_node->graph, ir_node->input_tensors[input_idx]); } // 设置指定节点的第idx输入张量为tensor,如果idx超过了节点输入数量,则为节点扩充空间,把tensor加进去 int set_node_input_tensor(node_t node, int input_idx, tensor_t tensor) { struct node* ir_node = (struct node*)node; struct tensor* ir_tensor = (struct tensor*)tensor; return set_ir_node_input_tensor(ir_node, input_idx, ir_tensor); } // 创建一个tensor,添加到graph的tensor_list中 tensor_t create_graph_tensor(graph_t graph, const char* tensor_name, int data_type) { struct graph* ir_graph = (struct graph*)graph; // 创建一个张量tensor,其数据类型为data_type,其layout为graph->graph_layout,tensor的名称设置为tensor_name // graph的tensor列表扩充一个空间,将新的tensor添加进去 return create_ir_tensor(ir_graph, tensor_name, data_type); } // 返回graph中名为tensor_name的张量 tensor_t get_graph_tensor(graph_t graph, const char* tensor_name) { struct graph* ir_graph = (struct graph*)graph; // 遍历graph中所有节点 for (int i = 0; i < ir_graph->node_num; i++) { // 获取到第i个节点 struct node* ir_node = get_ir_graph_node((ir_graph_t*)graph, i); if (NULL == ir_node) { continue; } else { // 遍历node的所有输入张量 for (int j = 0; j < ir_node->input_num; j++) { // 获取到node所对应graph的第j个输入张量 struct tensor* ir_tensor = get_ir_graph_tensor(ir_node->graph, ir_node->input_tensors[j]); // 如果张量存在,而且张量的name也存在,而且张量的name为要查找的name,返回当前张量 if (ir_tensor && ir_tensor->name && !strcmp(ir_tensor->name, tensor_name)) return (tensor_t)ir_tensor; } // 遍历node的所有输出张量,与上面类似 for (int j = 0; j < ir_node->output_num; j++) { struct tensor* ir_tensor = get_ir_graph_tensor(ir_node->graph, ir_node->output_tensors[j]); if (ir_tensor && ir_tensor->name && !strcmp(ir_tensor->name, tensor_name)) return (tensor_t)ir_tensor; } } } return NULL; } // 返回指定tensor的name const char* get_tensor_name(tensor_t tensor) { struct tensor* 
ir_tensor = (struct tensor*)tensor; if (ir_tensor->name == NULL) ir_tensor->name = create_ir_tensor_name_from_index(ir_tensor->index); return ir_tensor->name; } // 设置tensor的shape,这里并不改变data,只是改变tensor结构体的属性 int set_tensor_shape(tensor_t tensor, const int dims[], int dim_number) { struct tensor* ir_tensor = (struct tensor*)tensor; return set_ir_tensor_shape(ir_tensor, dims, dim_number); } // 返回tensor的shape,shape写入dims[]数组中, dim_number需要和tensor的dim_num一致才会正常返回 int get_tensor_shape(tensor_t tensor, int dims[], int dim_number) { struct tensor* ir_tensor = (struct tensor*)tensor; if (dim_number < ir_tensor->dim_num) { return -1; } for (int i = 0; i < ir_tensor->dim_num; i++) dims[i] = ir_tensor->dims[i]; return ir_tensor->dim_num; } // 返回tensor的缓冲区大小 int get_tensor_buffer_size(tensor_t tensor) { struct tensor* ir_tensor = (struct tensor*)tensor; return (int)(ir_tensor->elem_size * ir_tensor->elem_num); } // 获取指定tensor的data void* get_tensor_buffer(tensor_t tensor) { struct tensor* ir_tensor = (struct tensor*)tensor; // TODO: take dev mem into consideration return ir_tensor->data; } // 重置tensor,tensor的data指向buffer int set_tensor_buffer(tensor_t tensor, void* buffer, int buffer_size) { struct tensor* ir_tensor = (struct tensor*)tensor; // 获得tensor的缓冲区大小 int tensor_size = get_tensor_buffer_size(tensor); if (tensor_size != buffer_size) { fprintf(stderr, "Tengine: Size of tensor != size of buffer(%d vs %d).\n", tensor_size, buffer_size); return -1; } if (ir_tensor->data && ir_tensor->free_host_mem) sys_free(ir_tensor->data); ir_tensor->free_host_mem = 0; ir_tensor->internal_allocated = 0; ir_tensor->data = buffer; return 0; } // 将tensor的data拷贝到output_data中 int get_tensor_data(tensor_t tensor, void* output_data, int data_size) { struct tensor* ir_tensor = (struct tensor*)tensor; int tensor_size = get_tensor_buffer_size(tensor); // 需要满足output_data的大小足够大,可以装得下tensor的data if (data_size < tensor_size) { return -1; } // 将tensor的data拷贝到output_data空间中 if (ir_tensor->data) { 
memcpy(output_data, ir_tensor->data, tensor_size); return 0; } if (ir_tensor->dev_mem == NULL) { return -1; } // TODO: handle dev_mem case return -1; } // 设置tensor的data为input_data int set_tensor_data(tensor_t tensor, const void* input_data, int data_size) { struct tensor* ir_tensor = (struct tensor*)tensor; int tensor_size = get_tensor_buffer_size(tensor); // 保证tensor的存储空间足够大 if (data_size > tensor_size) { return -1; } // 将input_data的数据拷贝到tensor_data的空间中 if (ir_tensor->data) { memcpy(ir_tensor->data, input_data, tensor_size); return 0; } // TODO: handle dev_mem case return -1; } // 返回tensor的数据类型 int get_tensor_data_type(tensor_t tensor) { struct tensor* ir_tensor = (struct tensor*)tensor; return ir_tensor->data_type; } // 设置tensor的数据类型,只是修改了tensor的结构体属性,没有对数据进行改变 int set_tensor_data_type(tensor_t tensor, int data_type) { struct tensor* ir_tensor = (struct tensor*)tensor; ir_tensor->data_type = data_type; return 0; } // 返回tensor的layout,是NCHW还是NHWC int get_tensor_layout(tensor_t tensor) { struct tensor* ir_tensor = (struct tensor*)tensor; return ir_tensor->layout; } // 只是修改了tensor的结构体属性,没有对数据进行改变 int set_tensor_layout(tensor_t tensor, int layout) { struct tensor* ir_tensor = (struct tensor*)tensor; ir_tensor->layout = layout; return 0; }

往期文章