|
NeRF
import torch
import torch.nn as nn
import torch.nn.functional as F
class NeRF(nn.Module):
    """MLP mapping a point (and a view direction) to an RGB colour and alpha.

    A ``D``-layer, ``W``-wide ReLU trunk processes the point channels; after
    every trunk layer index listed in ``skips`` the raw point input is
    concatenated back onto the hidden state. ``alpha`` is predicted from the
    trunk output alone, while ``rgb`` additionally conditions on the view
    direction through one ``W // 2``-wide hidden layer.

    Args:
        D: Number of linear layers in the point trunk.
        W: Hidden width of the point trunk.
        input_ch: Number of channels describing a point (3 for raw xyz; more
            if the caller encodes the point beforehand).
        input_ch_views: Number of channels describing a view direction.
        skips: Trunk layer indices after which the point input is
            re-concatenated. Every index should be < ``D - 1``; otherwise the
            final trunk output is wider than ``W`` and ``alpha_linear`` fails.
    """

    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, skips=(4,)):
        super().__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        # Copy into a fresh list so instances never alias the default or the
        # caller's sequence.
        self.skips = list(skips)
        # Trunk layer k (k >= 1) consumes the output of layer k - 1. When
        # k - 1 is a skip index, that output was widened by the re-concatenated
        # point input, hence the W + input_ch fan-in.
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)]
            + [
                nn.Linear(W + input_ch, W) if i in self.skips else nn.Linear(W, W)
                for i in range(D - 1)
            ]
        )
        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])
        self.feature_linear = nn.Linear(W, W)
        self.alpha_linear = nn.Linear(W, 1)
        self.rgb_linear = nn.Linear(W // 2, 3)

    def forward(self, x, viewdirs=None):
        """Evaluate the field.

        Args:
            x: Tensor whose last dimension carries the point channels.
                - If ``viewdirs`` is None, the last dimension must be exactly
                  ``input_ch + input_ch_views`` wide (points followed by view
                  directions).
                - If ``viewdirs`` is given, only the leading ``input_ch``
                  channels of ``x`` are used.
            viewdirs: Optional tensor with ``input_ch_views`` channels in its
                last dimension and the same leading dimensions as ``x``.

        Returns:
            ``(rgb, alpha)``:
                - ``rgb`` has last dimension 3 and is squashed by a sigmoid.
                - ``alpha`` has last dimension 1 and is clamped at 0 by a ReLU.
        """
        if viewdirs is None:
            input_pts, input_views = torch.split(
                x, [self.input_ch, self.input_ch_views], dim=-1
            )
        else:
            input_pts, input_views = x[..., : self.input_ch], viewdirs

        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        alpha = F.relu(self.alpha_linear(h))
        feature = self.feature_linear(h)

        h = torch.cat([feature, input_views], -1)
        for layer in self.views_linears:
            h = F.relu(layer(h))
        rgb = torch.sigmoid(self.rgb_linear(h))
        return rgb, alpha
# Usage example: evaluate a freshly initialised NeRF on 1024 random samples.
model = NeRF()
points = torch.rand(1024, 6)  # columns 0:3 = xyz, columns 3:6 = view direction
# Pass the packed (xyz + viewdir) tensor as `x` and its view-direction columns
# as `viewdirs`; the model takes xyz from the leading columns of `x`.
rgb, density = model(points, points[:, 3:])
Currently there is no media on this page
import torch
import torch.nn as nn
import torch.nn.functional as F
class NeRF(nn.Module):
    """MLP mapping a point (and a view direction) to an RGB colour and alpha.

    A ``D``-layer, ``W``-wide ReLU trunk processes the point channels; after
    every trunk layer index listed in ``skips`` the raw point input is
    concatenated back onto the hidden state. ``alpha`` is predicted from the
    trunk output alone, while ``rgb`` additionally conditions on the view
    direction through one ``W // 2``-wide hidden layer.

    Args:
        D: Number of linear layers in the point trunk.
        W: Hidden width of the point trunk.
        input_ch: Number of channels describing a point (3 for raw xyz; more
            if the caller encodes the point beforehand).
        input_ch_views: Number of channels describing a view direction.
        skips: Trunk layer indices after which the point input is
            re-concatenated. Every index should be < ``D - 1``; otherwise the
            final trunk output is wider than ``W`` and ``alpha_linear`` fails.
    """

    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, skips=(4,)):
        super().__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        # Copy into a fresh list so instances never alias the default or the
        # caller's sequence.
        self.skips = list(skips)
        # Trunk layer k (k >= 1) consumes the output of layer k - 1. When
        # k - 1 is a skip index, that output was widened by the re-concatenated
        # point input, hence the W + input_ch fan-in.
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)]
            + [
                nn.Linear(W + input_ch, W) if i in self.skips else nn.Linear(W, W)
                for i in range(D - 1)
            ]
        )
        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])
        self.feature_linear = nn.Linear(W, W)
        self.alpha_linear = nn.Linear(W, 1)
        self.rgb_linear = nn.Linear(W // 2, 3)

    def forward(self, x, viewdirs=None):
        """Evaluate the field.

        Args:
            x: Tensor whose last dimension carries the point channels.
                - If ``viewdirs`` is None, the last dimension must be exactly
                  ``input_ch + input_ch_views`` wide (points followed by view
                  directions).
                - If ``viewdirs`` is given, only the leading ``input_ch``
                  channels of ``x`` are used.
            viewdirs: Optional tensor with ``input_ch_views`` channels in its
                last dimension and the same leading dimensions as ``x``.

        Returns:
            ``(rgb, alpha)``:
                - ``rgb`` has last dimension 3 and is squashed by a sigmoid.
                - ``alpha`` has last dimension 1 and is clamped at 0 by a ReLU.
        """
        if viewdirs is None:
            input_pts, input_views = torch.split(
                x, [self.input_ch, self.input_ch_views], dim=-1
            )
        else:
            input_pts, input_views = x[..., : self.input_ch], viewdirs

        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        alpha = F.relu(self.alpha_linear(h))
        feature = self.feature_linear(h)

        h = torch.cat([feature, input_views], -1)
        for layer in self.views_linears:
            h = F.relu(layer(h))
        rgb = torch.sigmoid(self.rgb_linear(h))
        return rgb, alpha
# Usage example: evaluate a freshly initialised NeRF on 1024 random samples.
model = NeRF()
points = torch.rand(1024, 6)  # columns 0:3 = xyz, columns 3:6 = view direction
# Pass the packed (xyz + viewdir) tensor as `x` and its view-direction columns
# as `viewdirs`; the model takes xyz from the leading columns of `x`.
rgb, density = model(points, points[:, 3:])
|