# Mr Fix It AMERICA
# MrFiXitX.tv
# . ____--MsX-- {*} --MrX--____.
# TJ@MrFiXitX.tv
# NeRF (Neural Radiance Fields)

import torch
import torch.nn as nn
import torch.nn.functional as F

class NeRF(nn.Module):
    """NeRF-style MLP mapping a position and a view direction to (rgb, alpha).

    A trunk of ``D`` linear+ReLU layers of width ``W`` processes the position.
    After every trunk layer whose index is listed in ``skips``, the raw position
    is concatenated back onto the hidden state (skip connection). From the final
    hidden state the model predicts a non-negative ``alpha`` (ReLU) and a
    ``W``-wide feature vector. The feature is concatenated with the view
    direction, passed through one linear+ReLU layer of width ``W // 2``, and
    mapped to ``rgb`` in (0, 1) by a sigmoid.

    Args:
        D: Number of linear layers in the position trunk.
        W: Hidden width of the position trunk.
        input_ch: Number of position channels expected at the start of ``x``.
        input_ch_views: Number of view-direction channels.
        skips: Trunk layer indices after which the position is re-injected.
            Indices should lie in ``range(D - 1)``: the layer following a skip
            is built with ``W + input_ch`` inputs, and there is no layer after
            index ``D - 1`` to absorb the wider tensor.
    """

    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, skips=(4,)):
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        # Store a private copy so callers (or other instances) never share or
        # mutate the same list object.
        self.skips = list(skips)
        # Trunk layer i + 1 follows the skip at index i, so it takes the
        # concatenated [position, hidden] input.
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)]
            + [
                nn.Linear(W + input_ch, W) if i in self.skips else nn.Linear(W, W)
                for i in range(D - 1)
            ]
        )
        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])
        self.feature_linear = nn.Linear(W, W)
        self.alpha_linear = nn.Linear(W, 1)
        self.rgb_linear = nn.Linear(W // 2, 3)

    def forward(self, x, viewdirs=None):
        """Evaluate the field.

        Args:
            x: Tensor whose last dimension starts with the ``input_ch``
                position channels. When ``viewdirs`` is None, the channels
                after them are used as the view direction (packed input).
            viewdirs: Optional view-direction tensor with ``input_ch_views``
                channels in its last dimension. When given, it is used
                instead of any trailing channels of ``x``.

        Returns:
            Tuple ``(rgb, alpha)``: ``rgb`` has last dimension 3 with values
            in (0, 1); ``alpha`` has last dimension 1 and is non-negative.
        """
        input_pts = x[..., :self.input_ch]
        input_views = x[..., self.input_ch:] if viewdirs is None else viewdirs

        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                # Re-inject the raw position so deeper layers still see it.
                h = torch.cat([input_pts, h], -1)

        alpha = F.relu(self.alpha_linear(h))
        feature = self.feature_linear(h)
        h = torch.cat([feature, input_views], -1)

        for layer in self.views_linears:
            h = F.relu(layer(h))

        rgb = torch.sigmoid(self.rgb_linear(h))
        return rgb, alpha

# Usage example: run the model on 1024 random samples, each row packing a
# 3-D position (xyz) followed by a 3-D view direction.
if __name__ == "__main__":
    model = NeRF()
    points = torch.rand(1024, 6)  # xyz + viewdir
    # Pass the full packed tensor as ``x`` (forward takes the position from its
    # leading columns) and the view-direction columns as ``viewdirs``.
    rgb, density = model(points, points[:, 3:])

# Currently there is no media on this page.

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class NeRF(nn.Module):
    """NeRF-style MLP mapping a position and a view direction to (rgb, alpha).

    A trunk of ``D`` linear+ReLU layers of width ``W`` processes the position.
    After every trunk layer whose index is listed in ``skips``, the raw position
    is concatenated back onto the hidden state (skip connection). From the final
    hidden state the model predicts a non-negative ``alpha`` (ReLU) and a
    ``W``-wide feature vector. The feature is concatenated with the view
    direction, passed through one linear+ReLU layer of width ``W // 2``, and
    mapped to ``rgb`` in (0, 1) by a sigmoid.

    Args:
        D: Number of linear layers in the position trunk.
        W: Hidden width of the position trunk.
        input_ch: Number of position channels expected at the start of ``x``.
        input_ch_views: Number of view-direction channels.
        skips: Trunk layer indices after which the position is re-injected.
            Indices should lie in ``range(D - 1)``: the layer following a skip
            is built with ``W + input_ch`` inputs, and there is no layer after
            index ``D - 1`` to absorb the wider tensor.
    """

    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, skips=(4,)):
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        # Store a private copy so callers (or other instances) never share or
        # mutate the same list object.
        self.skips = list(skips)
        # Trunk layer i + 1 follows the skip at index i, so it takes the
        # concatenated [position, hidden] input.
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)]
            + [
                nn.Linear(W + input_ch, W) if i in self.skips else nn.Linear(W, W)
                for i in range(D - 1)
            ]
        )
        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])
        self.feature_linear = nn.Linear(W, W)
        self.alpha_linear = nn.Linear(W, 1)
        self.rgb_linear = nn.Linear(W // 2, 3)

    def forward(self, x, viewdirs=None):
        """Evaluate the field.

        Args:
            x: Tensor whose last dimension starts with the ``input_ch``
                position channels. When ``viewdirs`` is None, the channels
                after them are used as the view direction (packed input).
            viewdirs: Optional view-direction tensor with ``input_ch_views``
                channels in its last dimension. When given, it is used
                instead of any trailing channels of ``x``.

        Returns:
            Tuple ``(rgb, alpha)``: ``rgb`` has last dimension 3 with values
            in (0, 1); ``alpha`` has last dimension 1 and is non-negative.
        """
        input_pts = x[..., :self.input_ch]
        input_views = x[..., self.input_ch:] if viewdirs is None else viewdirs

        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                # Re-inject the raw position so deeper layers still see it.
                h = torch.cat([input_pts, h], -1)

        alpha = F.relu(self.alpha_linear(h))
        feature = self.feature_linear(h)
        h = torch.cat([feature, input_views], -1)

        for layer in self.views_linears:
            h = F.relu(layer(h))

        rgb = torch.sigmoid(self.rgb_linear(h))
        return rgb, alpha
 
# Usage example: run the model on 1024 random samples, each row packing a
# 3-D position (xyz) followed by a 3-D view direction.
if __name__ == "__main__":
    model = NeRF()
    points = torch.rand(1024, 6)  # xyz + viewdir
    # Pass the full packed tensor as ``x`` (forward takes the position from its
    # leading columns) and the view-direction columns as ``viewdirs``.
    rgb, density = model(points, points[:, 3:])