澄清在PyTorch上理解TorchScript和JIT

只想阐明我对JIT和TorchScripts工作方式的理解,并阐明一个特定的例子。

因此,如果我没记错,torch.jit.script会将我的方法或模块转换为TorchScript。我可以在python以外的环境中使用TorchScript编译模块,但也可以在python中使用经过改进和优化的模块。与torch.jit.trace类似的情况是跟踪权重和操作,但大致遵循类似的想法。

如果是这种情况,通常,TorchScripted模块应至少与python解释器的典型推理时间一样快。在进行一些实验时,我发现它通常比典型的解释器推理时间慢,并且在读了一点后发现,显然TorchScripted模块需要“预热”以达到最佳性能。这样做的时候,我没有发现推理时间有任何变化,虽然更好,但不足以对典型的做事方式(python解释器)进行改进。此外,我使用了一个名为torch_tvm的第三方库,应该启用该库,以使通过任何方式添加模块的推理时间都减少一半。

到目前为止,这一切都没有发生,我真的无法说出原因。

以下是我的示例代码,以防万一我做错了事-

class TrialC(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1024,2048)
        self.l2 = nn.Linear(2048,4096)
        self.l3 = nn.Linear(4096,4096)
        self.l4 = nn.Linear(4096,2048)
        self.l5 = nn.Linear(2048,1024)

    def forward(self,input):
        out = self.l1(input)
        out = self.l2(out)
        out = self.l3(out)
        out = self.l4(out)
        out = self.l5(out)
        return out 

if __name__ == '__main__':
    # Trial inference input 
    TrialC_input = torch.randn(1,1024)
    warmup = 10

    # Record time for typical inference 
    model = TrialC()
    start = time.time()
    model_out = model(TrialC_input)
    elapsed = time.time() - start 

    # Record the 10th inference time (10 warmup) for the optimized model in TorchScript 
    script_model = torch.jit.script(TrialC())
    for i in range(warmup):
        start_2 = time.time()
        model_out_check_2 = script_model(TrialC_input)
        elapsed_2 = time.time() - start_2

    # Record the 10th inference time (10 warmup) for the optimized model in TorchScript + tvm optimization
    torch_tvm.enable()
    script_model_2 = torch.jit.trace(TrialC(),torch.randn(1,1024))
    for i in range(warmup):
        start_3 = time.time()
        model_out_check_3 = script_model_2(TrialC_input)
        elapsed_3 = time.time() - start_3 
    
    print("Regular model inference time: {}s\nJIT compiler inference time: {}s\nJIT Compiler with tvm: {}s".format(elapsed,elapsed_2,elapsed_3))

以下是我的CPU上上述代码的结果-

Regular model inference time: 0.10335588455200195s
JIT compiler inference time: 0.11449170112609863s
JIT Compiler with tvm: 0.10834860801696777s

对此将提供任何帮助或澄清!

akinschen 回答:澄清在PyTorch上理解TorchScript和JIT

暂时没有好的解决方案,如果你有好的解决方案,请发邮件至:iooj@foxmail.com
本文链接:https://www.f2er.com/1351849.html

大家都在问