Training script (run in Google Colab):

import argparse
import os
import sys
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

from builder import load_pretrained_model

sys.path.append(os.path.join(os.path.dirname(__file__), "language_model"))
from omni_speech2s_llama import OmniSpeech2SLlamaForCausalLM

class SpeechDataset(Dataset):
    def __init__(self, input_dir, target_dir, transform=None):
        self.input_files = sorted([os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith(".mp3")])
        self.target_files = sorted([os.path.join(target_dir, f) for f in os.listdir(target_dir) if f.endswith(".mp3")])
        assert len(self.input_files) == len(self.target_files), "Input and target mp3 file counts do not match."
        self.transform = transform

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_path = self.input_files[idx]
        target_path = self.target_files[idx]
        waveform_in, sr_in = torchaudio.load(input_path)
        waveform_target, sr_target = torchaudio.load(target_path)
        
        if waveform_in.size(0) > 1:
            waveform_in = waveform_in.mean(dim=0, keepdim=True)
        if waveform_target.size(0) > 1:
            waveform_target = waveform_target.mean(dim=0, keepdim=True)
        
        waveform_in = waveform_in.transpose(0, 1)    
        waveform_target = waveform_target.transpose(0, 1)
        
        length_in = waveform_in.size(0)
        length_target = waveform_target.size(0)
        
        if self.transform:
            waveform_in = self.transform(waveform_in)
            waveform_target = self.transform(waveform_target)
        
        from omni_speech.constants import SPEECH_TOKEN_INDEX

        # The prompt here is a single speech placeholder token; a real dataset would
        # build the full chat template around it.
        input_ids = torch.tensor([SPEECH_TOKEN_INDEX], dtype=torch.long)
        labels = input_ids.clone()

        # Placeholder discrete target units: random IDs, roughly one unit per 160 samples.
        # In the actual LLaMA-Omni pipeline these come from discrete HuBERT (k-means) units
        # of the target speech.
        unit_vocab_size = 1000
        tgt_units_length = max(1, int(length_target / 160))
        target_units = torch.randint(low=0, high=unit_vocab_size, size=(tgt_units_length,), dtype=torch.long)
        
        return {
            "input_ids": input_ids,
            "labels": labels,
            "speech": waveform_in,               # (time, channels)
            "speech_lengths": torch.tensor(length_in, dtype=torch.long),
            "tgt_units": target_units
        }
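
# NOTE (illustrative addition): torchaudio.load returns whatever sample rate the MP3
# was encoded with, while Whisper-style encoders and HuBERT unit extraction expect
# 16 kHz mono audio. A minimal resampling helper, sketched here as an assumption
# (the 16 kHz target and the helper name are not part of the original pipeline); it
# operates on the raw (channels, time) tensor returned by torchaudio.load, i.e.
# before the transpose done in __getitem__.
import torchaudio.functional as AF

def resample_to_16k(waveform: torch.Tensor, sr: int, target_sr: int = 16000) -> torch.Tensor:
    if sr == target_sr:
        return waveform
    return AF.resample(waveform, orig_freq=sr, new_freq=target_sr)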

def collate_fn(batch):
    input_ids = torch.stack([item["input_ids"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    
    max_len = max(item["speech"].size(0) for item in batch)
    padded_speech = []
    speech_lengths = []
    for item in batch:
        seq = item["speech"]
        pad_size = max_len - seq.size(0)
        if pad_size > 0:
            pad = torch.zeros((pad_size, seq.size(1)), dtype=seq.dtype)
            seq = torch.cat([seq, pad], dim=0)
        padded_speech.append(seq)
        speech_lengths.append(item["speech_lengths"])
    speech = torch.stack(padded_speech)
    speech_lengths = torch.stack(speech_lengths)
    
    max_units = max(item["tgt_units"].size(0) for item in batch)
    padded_units = []
    for item in batch:
        units = item["tgt_units"]
        pad_size = max_units - units.size(0)
        if pad_size > 0:
            pad = torch.full((pad_size,), fill_value=-100, dtype=units.dtype)
            units = torch.cat([units, pad], dim=0)
        padded_units.append(units)
    tgt_units = torch.stack(padded_units)
    
    return {
        "input_ids": input_ids,
        "labels": labels,
        "speech": speech,
        "speech_lengths": speech_lengths,
        "tgt_units": tgt_units
    }
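
# Illustrative sanity check for collate_fn (not called during training): two dummy
# items of different lengths get padded to a common length, and the shorter unit
# sequence is padded with -100 so the unit loss can presumably ignore those
# positions. Shapes below are assumptions chosen for illustration only.
def _collate_sanity_check():
    dummy = [
        {"input_ids": torch.tensor([0]), "labels": torch.tensor([0]),
         "speech": torch.zeros(16000, 1), "speech_lengths": torch.tensor(16000),
         "tgt_units": torch.randint(0, 1000, (100,))},
        {"input_ids": torch.tensor([0]), "labels": torch.tensor([0]),
         "speech": torch.zeros(24000, 1), "speech_lengths": torch.tensor(24000),
         "tgt_units": torch.randint(0, 1000, (150,))},
    ]
    out = collate_fn(dummy)
    print(out["speech"].shape)     # torch.Size([2, 24000, 1]) -> padded to the longest clip
    print(out["tgt_units"].shape)  # torch.Size([2, 150])      -> shorter item padded with -100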

def main():
    parser = argparse.ArgumentParser(description="Train the OmniSpeech2S model for speech-to-speech (mp3 in, mp3 out)")
    parser.add_argument("--model_path", type=str, required=True, help="Pretrained model directory")
    parser.add_argument("--model_base", type=str, default=None, help="Base model path when using LoRA")
    parser.add_argument("--train_input_dir", type=str, required=True, help="Folder of input mp3 files")
    parser.add_argument("--train_target_dir", type=str, required=True, help="Folder of target mp3 files")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    tokenizer, model, context_len = load_pretrained_model(
        args.model_path,
        args.model_base,
        is_lora=False,
        s2s=True,
        load_8bit=False,
        device=args.device
    )
    
    if hasattr(model, "initialize_speech_generator"):
        model.initialize_speech_generator(args)
    
    model.train()
    model.to(args.device)

    dataset = SpeechDataset(args.train_input_dir, args.train_target_dir)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    for epoch in range(args.num_epochs):
        for batch in dataloader:
            input_ids = batch["input_ids"].to(args.device)
            labels = batch["labels"].to(args.device)
            speech = batch["speech"].to(args.device)
            speech_lengths = batch["speech_lengths"].to(args.device)
            tgt_units = batch["tgt_units"].to(args.device)

            optimizer.zero_grad()
            outputs = model(
                input_ids=input_ids,
                labels=labels,
                speech=speech,
                speech_lengths=speech_lengths,
                tgt_units=tgt_units
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            print(f"Epoch {epoch} Loss: {loss.item()}")

        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pt")

if __name__ == "__main__":
    main()
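
The random target_units in SpeechDataset are placeholders only. LLaMA-Omni supervises its speech generator with discrete HuBERT units (k-means cluster IDs over HuBERT features, with a 1000-unit vocabulary matching unit_vocab_size above). A rough offline-extraction sketch follows; it assumes a pre-fitted scikit-learn k-means model saved at a hypothetical path kmeans_1000.joblib, and the HuBERT layer choice and helper name are assumptions rather than the repository's actual extraction script.

import joblib
import torch
import torchaudio

bundle = torchaudio.pipelines.HUBERT_BASE              # expects 16 kHz mono input
hubert = bundle.get_model().eval()
kmeans = joblib.load("kmeans_1000.joblib")             # hypothetical pre-fitted k-means (1000 clusters)

@torch.no_grad()
def extract_units(path: str, layer: int = 6) -> torch.Tensor:
    wav, sr = torchaudio.load(path)
    wav = wav.mean(dim=0, keepdim=True)                # downmix to mono
    if sr != bundle.sample_rate:
        wav = torchaudio.functional.resample(wav, sr, bundle.sample_rate)
    feats, _ = hubert.extract_features(wav, num_layers=layer)
    frame_feats = feats[-1].squeeze(0)                 # (frames, dim) from the chosen layer
    units = kmeans.predict(frame_feats.cpu().numpy())  # (frames,) cluster IDs
    return torch.as_tensor(units, dtype=torch.long)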

Reference files in the LLaMA-Omni repo (ictnlp/LLaMA-Omni, main branch):
omni_speech/model/speech_generator/speech_generator.py
omni_speech/model/language_model/omni_speech2s_llama.py

# Excerpt from a dataset __getitem__ (requires `re`, `numpy as np`, `whisper`, and
# `tokenizer_speech_token`; `prompt` and `input_ids` are built earlier in the method).
speech = item['audio']['array']
speech = whisper.pad_or_trim(speech).astype(np.float32)                            # pad/trim to 30 s
speech = whisper.log_mel_spectrogram(speech, n_mels=self.mel_size).permute(1, 0)   # (frames, n_mels)

# Build the labels tensor
labels = input_ids.clone()  # same shape as input_ids
labels[:] = -100            # initialize every position to -100 (ignored by the loss)

# Locate each assistant answer in the prompt and tokenize it
assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
pattern = re.escape(assistant_header)
assistant_indices = [m.start() for m in re.finditer(pattern, prompt)]
for idx in assistant_indices:
    answer_start = idx + len(assistant_header)
    answer_end = prompt.find("<|eot_id|>", answer_start) + len("<|eot_id|>")

    answer_text = prompt[answer_start:answer_end]
    answer_ids = tokenizer_speech_token(answer_text, self.tokenizer, return_tensors='pt')

    # Write the answer token IDs into labels at the matching positions
    labels_start = len(tokenizer_speech_token(prompt[:answer_start], self.tokenizer, return_tensors='pt'))
    labels_end = labels_start + answer_ids.shape[0]
    labels[labels_start:labels_end] = answer_ids

return input_ids, speech, torch.LongTensor([speech.shape[0]]), labels
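
For context on the -100 fill value used for labels above: PyTorch's cross-entropy loss skips positions whose target equals ignore_index, which defaults to -100, so only the assistant-answer tokens written back into labels contribute to the loss. A tiny self-contained illustration with toy numbers (unrelated to the real vocabulary):

import torch
import torch.nn.functional as F

logits = torch.randn(5, 10)                      # 5 positions, vocabulary of 10
labels = torch.tensor([-100, -100, 3, 7, -100])  # only positions 2 and 3 are supervised
loss = F.cross_entropy(logits, labels)           # ignore_index defaults to -100
print(loss)                                      # averaged over the two unmasked positions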