pytorch中的where和gather的介绍

说明

本内容节选自《深度学习框架pytorch》,目前本专栏限时打折中,欢迎订阅


正文

本文将介绍在使用深度学习框架pytorch中的两个常用的函数,这两个函数在搭建模型的时候会经常用到,下面我们看一下它的具体用法

torch.where(condition,x,y)

首先condition和x、y他们三个的shape应该是一致的

然后这个方法的意思就是说判断condition,如果conditon为1,则取x,否则取y

pytorch中的where和gather的介绍

gather就是一个查表得过程

torch.gather(input,dim,index,out=None)

这个表示在input的dim维度查找索引index所对应的值

pytorch中的where和gather的介绍


分享到:


相關文章: