Remove the decoder's last mel projection layer and append a waveform decoder after it
Propose two types of decoders: one based on BigVGAN and one based on iSTFTNet
Employ the Snake activation function → paper
An AdaIN module is added after each activation function, as in the StyleTTS decoder (see the sketch below)
Replace the mel-discriminator in [StyleTTS]
ref: https://github.com/NVIDIA/BigVGAN/blob/main/models.py#L124
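For reference, a minimal sketch of the two building blocks named above. `Snake` mirrors the inline expression `x + (1/alpha) * sin^2(alpha * x)` used in the Generator code further down, and `AdaIN1d` mirrors the StyleTTS-style adaptive instance norm; the module names here are illustrative, not taken from the codebase:

```python
import torch
import torch.nn as nn

class Snake(nn.Module):
    """Snake activation: x + (1/alpha) * sin^2(alpha * x), learnable per-channel alpha."""
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))  # broadcasts over (B, C, T)

    def forward(self, x):
        return x + (1.0 / self.alpha) * torch.sin(self.alpha * x) ** 2

class AdaIN1d(nn.Module):
    """Adaptive instance norm: normalize per channel, then scale/shift from a style vector."""
    def __init__(self, style_dim, channels):
        super().__init__()
        self.norm = nn.InstanceNorm1d(channels, affine=False)
        self.fc = nn.Linear(style_dim, channels * 2)  # predicts gamma and beta

    def forward(self, x, s):
        h = self.fc(s).unsqueeze(-1)      # (B, 2C, 1)
        gamma, beta = h.chunk(2, dim=1)   # (B, C, 1) each
        return (1 + gamma) * self.norm(x) + beta
```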
input: aligned ASR features h, pitch curve p, energy curve n, style vector s
output: waveform (in place of the mel spectrogram)
model architecture
| Submodule | Input | Layer | Norm | Output Shape | Output |
|---|---|---|---|---|---|
| (smoothing) | p, n | F.conv1d | - | (B, min_F) | p, n |
| F0_conv | p | Conv1D | weight norm | (B, 1, min_F//2) | p_ |
| N_conv | n | Conv1D | weight norm | (B, 1, min_F//2) | n_ |
| encode | [h, p_, n_], s | AdainResBlk1d | - | (B, 1024, min_F//2) | x |
| asr_res | h | Conv1D | weight norm | (B, 64, min_F//2) | h_ |
| decode | [x, h_, p_, n_], s | 4 x AdainResBlk1d | - | (B, 512, min_F//2) | x_ |
| generator (BigVGAN) | x_, s, p | Generator (code below) | weight norm | (B, 1, T) | waveform |
| generator (iSTFTNet) | | | | | |
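The `(smoothing)` row applies `F.conv1d` with a fixed (non-learned) kernel to the pitch/energy curves before they are downsampled. A minimal runnable sketch, assuming a simple moving-average kernel; the actual kernel in StyleTTS may differ, and `smooth_curve` is a hypothetical helper name:

```python
import torch
import torch.nn.functional as F

def smooth_curve(c, width=3):
    """Moving-average smoothing of a (B, F) curve via F.conv1d.
    The averaging kernel is an assumption; the actual fixed kernel may differ."""
    kernel = torch.ones(1, 1, width) / width  # fixed, non-learned kernel
    return F.conv1d(c.unsqueeze(1), kernel, padding=width // 2).squeeze(1)

p = torch.randn(2, 100)      # pitch curve (B, F)
p_smooth = smooth_curve(p)   # (B, F), same length thanks to padding
```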
The BigVGAN model takes only a mel spectrogram as input
The decoder's generator additionally takes the output of decoder.decode, the style vector, and the pitch curve as inputs
The model architecture is complex, so it is unclear what processing it performs → needs checking
```python
# Generator from the StyleTTS decoder: a HiFi-GAN-style upsampler with a harmonic
# source module and AdaIN residual blocks. AdaINResBlock1, SourceModuleHnNSF, and
# init_weights come from the surrounding StyleTTS codebase.
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm


class Generator(torch.nn.Module):
    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates,
                 upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        resblock = AdaINResBlock1

        # Harmonic-plus-noise source module driven by the upsampled F0 curve.
        self.m_source = SourceModuleHnNSF(
            sampling_rate=24000,
            upsample_scale=np.prod(upsample_rates),
            harmonic_num=8, voiced_threshod=10)  # (sic: upstream parameter spelling)

        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
        self.noise_convs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.noise_res = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(upsample_initial_channel // (2 ** i),
                                upsample_initial_channel // (2 ** (i + 1)),
                                k, u, padding=(u // 2 + u % 2), output_padding=u % 2)))
            if i + 1 < len(upsample_rates):
                # Downsample the harmonic source to match the feature rate at this stage.
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(
                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0,
                    padding=(stride_f0 + 1) // 2))
                self.noise_res.append(resblock(c_cur, 7, [1, 3, 5], style_dim))
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
                self.noise_res.append(resblock(c_cur, 11, [1, 3, 5], style_dim))

        self.resblocks = nn.ModuleList()

        # Per-channel alphas for the Snake activation at each resolution.
        self.alphas = nn.ParameterList()
        self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1)))
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d, style_dim))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))  # ch from last iteration
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, s, f0):
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # (B, T * prod(rates), 1)
        har_source, noi_source, uv = self.m_source(f0)
        har_source = har_source.transpose(1, 2)

        for i in range(self.num_upsamples):
            # Snake activation: x + (1/alpha) * sin^2(alpha * x).
            x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
            x_source = self.noise_convs[i](har_source)
            x_source = self.noise_res[i](x_source, s)
            x = self.ups[i](x)
            x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x, s)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x, s)
            x = xs / self.num_kernels

        x = x + (1 / self.alphas[i + 1]) * (torch.sin(self.alphas[i + 1] * x) ** 2)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        # Note: this Generator has no conv_pre; the upstream call to
        # remove_weight_norm(self.conv_pre) was a bug and is dropped here.
        remove_weight_norm(self.conv_post)
```
```python
# BigVGAN generator, from the NVIDIA repo referenced above. AMPBlock1/AMPBlock2,
# the `activations` module, the alias-free Activation1d, and init_weights all come
# from that repo; upstream comments are kept as-is.
import torch
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm


class BigVGAN(torch.nn.Module):
    # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
    def __init__(self, h):
        super(BigVGAN, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        # pre conv
        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))

        # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
        resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2

        # transposed conv-based upsamplers. does not apply anti-aliasing
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(nn.ModuleList([
                weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
                                            h.upsample_initial_channel // (2 ** (i + 1)),
                                            k, u, padding=(k - u) // 2))
            ]))

        # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))

        # post conv
        if h.activation == "snake":  # periodic nonlinearity with snake function and anti-aliasing
            activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
            self.activation_post = Activation1d(activation=activation_post)
        elif h.activation == "snakebeta":  # periodic nonlinearity with snakebeta function and anti-aliasing
            activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
            self.activation_post = Activation1d(activation=activation_post)
        else:
            raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))

        # weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        # pre conv
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            # upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # AMP blocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # post conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            for l_i in l:
                remove_weight_norm(l_i)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
```
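For comparison with the Generator above, BigVGAN's forward takes only the mel spectrogram. A hedged usage sketch: `SimpleNamespace` stands in for the repo's AttrDict-style config `h`, and the field values echo the repo's 24 kHz / 100-band config but should be treated as placeholders here:

```python
import torch
from types import SimpleNamespace

# Placeholder config mirroring the fields BigVGAN reads from `h`.
h = SimpleNamespace(
    num_mels=100,
    upsample_rates=[4, 4, 2, 2, 2, 2],              # prod = 256
    upsample_kernel_sizes=[8, 8, 4, 4, 4, 4],
    upsample_initial_channel=1536,
    resblock='1',
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    activation='snakebeta',
    snake_logscale=True,
)

bigvgan = BigVGAN(h)
mel = torch.randn(2, h.num_mels, 50)   # (B, n_mels, frames) -- mel is the only input
wav = bigvgan(mel)                     # (B, 1, frames * prod(h.upsample_rates))
```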