import argparse
import os
import sys
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from builder import load_pretrained_model
sys.path.append(os.path.join(os.path.dirname(__file__), "language_model"))
from omni_speech2s_llama import OmniSpeech2SLlamaForCausalLM
from omni_speech.constants import SPEECH_TOKEN_INDEX, IGNORE_INDEX
class SpeechDataset(Dataset):
    def __init__(self, input_dir, target_dir, transform=None):
        self.input_files = sorted([os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith(".mp3")])
        self.target_files = sorted([os.path.join(target_dir, f) for f in os.listdir(target_dir) if f.endswith(".mp3")])
        assert len(self.input_files) == len(self.target_files), "Number of input and target mp3 files does not match."
        self.transform = transform

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_path = self.input_files[idx]
        target_path = self.target_files[idx]
        waveform_in, sr_in = torchaudio.load(input_path)
        waveform_target, sr_target = torchaudio.load(target_path)
        # Downmix multi-channel audio to mono
        if waveform_in.size(0) > 1:
            waveform_in = waveform_in.mean(dim=0, keepdim=True)
        if waveform_target.size(0) > 1:
            waveform_target = waveform_target.mean(dim=0, keepdim=True)
        # (channels, time) -> (time, channels)
        waveform_in = waveform_in.transpose(0, 1)
        waveform_target = waveform_target.transpose(0, 1)
        length_in = waveform_in.size(0)
        length_target = waveform_target.size(0)
        if self.transform:
            waveform_in = self.transform(waveform_in)
            waveform_target = self.transform(waveform_target)
        # Minimal placeholder prompt/labels: a single speech token
        input_ids = torch.tensor([SPEECH_TOKEN_INDEX], dtype=torch.long)
        labels = input_ids.clone()
        # Placeholder target units: random unit IDs sized by duration, assuming
        # 16 kHz audio and a 160-sample (10 ms) hop per unit frame. Real training
        # needs discrete units extracted from the target speech; LLaMA-Omni uses
        # HuBERT k-means units with a 1000-unit vocabulary (see the extract_units
        # sketch after this class).
        unit_vocab_size = 1000
        tgt_units_length = max(1, int(length_target / 160))
        target_units = torch.randint(low=0, high=unit_vocab_size, size=(tgt_units_length,), dtype=torch.long)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "speech": waveform_in,  # (time, channels)
            "speech_lengths": torch.tensor(length_in, dtype=torch.long),
            "tgt_units": target_units
        }
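# A hedged sketch of deriving real (non-random) unit targets. `km` is an
# assumed, separately fitted k-means codebook (e.g. sklearn MiniBatchKMeans
# trained offline on HuBERT features), and the layer choice is illustrative,
# not necessarily the repo's exact recipe.
def extract_units(waveform_16k, km):
    """waveform_16k: (1, time) mono audio at 16 kHz; km: fitted k-means model."""
    bundle = torchaudio.pipelines.HUBERT_BASE
    hubert = bundle.get_model().eval()
    with torch.no_grad():
        # extract_features returns per-layer features, each (batch, frames, dim)
        feats, _ = hubert.extract_features(waveform_16k)
    layer_feats = feats[5].squeeze(0)              # (frames, dim)
    units = km.predict(layer_feats.cpu().numpy())  # (frames,) discrete unit IDs
    return torch.as_tensor(units, dtype=torch.long)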
def collate_fn(batch):
    input_ids = torch.stack([item["input_ids"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    # Zero-pad waveforms along the time axis to the longest in the batch
    max_len = max(item["speech"].size(0) for item in batch)
    padded_speech = []
    speech_lengths = []
    for item in batch:
        seq = item["speech"]
        pad_size = max_len - seq.size(0)
        if pad_size > 0:
            pad = torch.zeros((pad_size, seq.size(1)), dtype=seq.dtype)
            seq = torch.cat([seq, pad], dim=0)
        padded_speech.append(seq)
        speech_lengths.append(item["speech_lengths"])
    speech = torch.stack(padded_speech)
    speech_lengths = torch.stack(speech_lengths)
    # Pad unit targets with IGNORE_INDEX (-100) so the loss skips padding
    max_units = max(item["tgt_units"].size(0) for item in batch)
    padded_units = []
    for item in batch:
        units = item["tgt_units"]
        pad_size = max_units - units.size(0)
        if pad_size > 0:
            pad = torch.full((pad_size,), fill_value=IGNORE_INDEX, dtype=units.dtype)
            units = torch.cat([units, pad], dim=0)
        padded_units.append(units)
    tgt_units = torch.stack(padded_units)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "speech": speech,
        "speech_lengths": speech_lengths,
        "tgt_units": tgt_units
    }
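# Alternative to the manual padding above: a minimal sketch using
# torch.nn.utils.rnn.pad_sequence. The function name is ours, not part of
# the original script; it produces the same batch layout as collate_fn.
from torch.nn.utils.rnn import pad_sequence

def collate_fn_pad_sequence(batch):
    return {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "labels": torch.stack([item["labels"] for item in batch]),
        # (batch, max_time, channels), zero-padded along time
        "speech": pad_sequence([item["speech"] for item in batch], batch_first=True),
        "speech_lengths": torch.stack([item["speech_lengths"] for item in batch]),
        # (batch, max_units), padded with IGNORE_INDEX so the loss skips padding
        "tgt_units": pad_sequence([item["tgt_units"] for item in batch],
                                  batch_first=True, padding_value=IGNORE_INDEX),
    }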
def main():
    parser = argparse.ArgumentParser(description="Train OmniSpeech2S model for speech-to-speech (mp3 in, mp3 out)")
    parser.add_argument("--model_path", type=str, required=True, help="Pretrained model directory")
    parser.add_argument("--model_base", type=str, default=None, help="Base model path when using LoRA")
    parser.add_argument("--train_input_dir", type=str, required=True, help="Directory of input mp3 files")
    parser.add_argument("--train_target_dir", type=str, required=True, help="Directory of target mp3 files")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    tokenizer, model, context_len = load_pretrained_model(
        args.model_path,
        args.model_base,
        is_lora=False,
        s2s=True,
        load_8bit=False,
        device=args.device
    )
    # If the model exposes a speech-generator initializer, pass the args through
    if hasattr(model, "initialize_speech_generator"):
        model.initialize_speech_generator(args)
    model.train()
    model.to(args.device)

    dataset = SpeechDataset(args.train_input_dir, args.train_target_dir)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    for epoch in range(args.num_epochs):
        for batch in dataloader:
            input_ids = batch["input_ids"].to(args.device)
            labels = batch["labels"].to(args.device)
            speech = batch["speech"].to(args.device)
            speech_lengths = batch["speech_lengths"].to(args.device)
            tgt_units = batch["tgt_units"].to(args.device)
            optimizer.zero_grad()
            outputs = model(
                input_ids=input_ids,
                labels=labels,
                speech=speech,
                speech_lengths=speech_lengths,
                tgt_units=tgt_units
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch} loss (last batch): {loss.item():.4f}")
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pt")
if __name__ == "__main__":
    main()
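# Example invocation (script name and paths are hypothetical):
#   python train_omnispeech2s_mp3.py \
#       --model_path ./Llama-3.1-8B-Omni \
#       --train_input_dir ./data/input_mp3 \
#       --train_target_dir ./data/target_mp3 \
#       --batch_size 4 --num_epochs 10 --learning_rate 1e-4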
# References (ictnlp/LLaMA-Omni, main branch):
#   LLaMA-Omni/omni_speech/model/speech_generator/speech_generator.py
#   LLaMA-Omni/omni_speech/model/language_model/omni_speech2s_llama.py
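# The excerpt below (from a LLaMA-Omni-style dataset __getitem__) shows how the
# original repo builds Whisper mel features and masked labels, in contrast to
# the placeholder __getitem__ above. Judging from the calls it makes, running it
# would require `import re`, `import numpy as np`, `import whisper`
# (openai-whisper), and `tokenizer_speech_token` from the LLaMA-Omni codebase,
# plus `self.mel_size` / `self.tokenizer` on the enclosing class.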
speech = item['audio']['array']
speech = whisper.pad_or_trim(speech).astype(np.float32)
speech = whisper.log_mel_spectrogram(speech, n_mels=self.mel_size).permute(1, 0)

# Build labels: clone input_ids, then mask every position with -100
labels = input_ids.clone()
labels[:] = -100  # positions left at -100 are ignored by the loss

# Find assistant answers in the prompt and tokenize them
assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
pattern = re.escape(assistant_header)
assistant_indices = [m.start() for m in re.finditer(pattern, prompt)]
for idx in assistant_indices:
    answer_start = idx + len(assistant_header)
    answer_end = prompt.find("<|eot_id|>", answer_start) + len("<|eot_id|>")
    answer_text = prompt[answer_start:answer_end]
    answer_ids = tokenizer_speech_token(answer_text, self.tokenizer, return_tensors='pt')
    # Set labels at the answer token positions
    labels_start = len(tokenizer_speech_token(prompt[:answer_start], self.tokenizer, return_tensors='pt'))
    labels_end = labels_start + answer_ids.shape[0]
    labels[labels_start:labels_end] = answer_ids

return input_ids, speech, torch.LongTensor([speech.shape[0]]), labels