python

超轻量级php框架startmvc

PyTorch 模型可视化的例子

更新时间:2020-07-24 18:30:01 作者:startmvc
如下所示：一. visualize.py from graphviz import Digraph import torch from torch.autograd import Variable def make_dot(var, param…

如下所示:

一. visualize.py


from graphviz import Digraph
import torch
from torch.autograd import Variable
 
 
def make_dot(var, params=None):
    """Produce a Graphviz representation of a PyTorch autograd graph.

    Walks the autograd graph backwards, starting from ``var.grad_fn``. Each
    graph node becomes a Graphviz node, and each dependency becomes an edge
    from producer to consumer.

    Node colouring:
      * light blue -- graph nodes exposing a ``.variable`` attribute. In
        PyTorch these are typically the ``AccumulateGrad`` nodes of leaf
        tensors that require grad. They are labelled with the parameter name
        (if known) and the tensor size.
      * orange -- tensors listed in a node's ``saved_tensors``. This is only
        present on nodes that expose that attribute, e.g. custom
        ``torch.autograd.Function`` nodes. They are labelled with their size.
      * default -- every other node, labelled with its Python type name.

    Args:
        var: output tensor/Variable whose autograd history is drawn.
        params: optional mapping of ``name -> tensor`` used to label leaf
            nodes, e.g. ``dict(model.named_parameters())``. Leaves that are
            not found in the mapping get an empty name.

    Returns:
        graphviz.Digraph: the graph; call ``.view()`` or ``.render()`` on it.
        If ``var`` has no ``grad_fn`` (it does not require grad), the graph is
        returned with no nodes.

    Raises:
        TypeError: if any value in ``params`` is not a tensor/Variable.
    """
    param_map = {}
    if params is not None:
        # Validate every value, not just the first. Raise instead of assert
        # so the check survives ``python -O``.
        for value in params.values():
            if not isinstance(value, Variable):
                raise TypeError(
                    "params values must be tensors/Variables, got %s"
                    % type(value).__name__
                )
        param_map = {id(v): k for k, v in params.items()}

    # NOTE(review): Graphviz documents ``ranksep`` as a graph attribute; as a
    # node attribute it likely has no effect. Consider moving it to
    # ``graph_attr`` if tighter rank spacing is wanted.
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

    def size_to_str(size):
        """Format a torch.Size-like iterable as ``(d0, d1, ...)``."""
        return '(' + ', '.join('%d' % v for v in size) + ')'

    def add_node(fn):
        """Emit a single Graphviz node for ``fn``, coloured by its kind."""
        if torch.is_tensor(fn):
            dot.node(str(id(fn)), size_to_str(fn.size()), fillcolor='orange')
        elif hasattr(fn, 'variable'):
            u = fn.variable
            # .get(): a leaf missing from ``params`` gets an empty label
            # instead of raising KeyError.
            name = param_map.get(id(u), '')
            node_name = '%s\n %s' % (name, size_to_str(u.size()))
            dot.node(str(id(fn)), node_name, fillcolor='lightblue')
        else:
            dot.node(str(id(fn)), str(type(fn).__name__))

    root = getattr(var, 'grad_fn', None)
    if root is None:
        # Nothing to draw: ``var`` was not produced by a differentiable op.
        return dot

    # Explicit stack instead of recursion, so deep networks do not hit
    # Python's recursion limit. Each node is expanded once; its incoming edges
    # are emitted when it is expanded, giving the same edge set as a DFS.
    seen = set()
    stack = [root]
    while stack:
        fn = stack.pop()
        if fn in seen:
            continue
        seen.add(fn)
        add_node(fn)
        # next_functions holds (grad_fn, output_index) pairs; None marks an
        # input that does not require grad.
        for entry in getattr(fn, 'next_functions', ()):
            parent = entry[0]
            if parent is not None:
                dot.edge(str(id(parent)), str(id(fn)))
                stack.append(parent)
        for saved in getattr(fn, 'saved_tensors', ()):
            dot.edge(str(id(saved)), str(id(fn)))
            stack.append(saved)
    return dot

二. 使用步骤


import torch
from torch.autograd import Variable
from models import *
from visualize import make_dot
if __name__ == "__main__":
    # Example driver: build the model, run one forward pass, draw the graph.
    # Variable() has been a no-op wrapper since PyTorch 0.4; a plain tensor
    # carries the autograd history itself.
    # NOTE(review): assumes GeneratorUNet (from models) accepts a
    # (1, 3, 256, 256) input -- confirm against models.py.
    x = torch.rand(1, 3, 256, 256)
    model = GeneratorUNet()
    y = model(x)
    g = make_dot(y)
    # Renders the graph and opens it in the system's default viewer.
    g.view()

三. 效果展示

以上这篇 PyTorch 模型可视化的例子就是小编分享给大家的全部内容了，希望能给大家一个参考，也希望大家多多支持脚本之家。

pytorch 模型 可视化