diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f97200dc..e2b52db8 100755 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -17,7 +17,7 @@ jobs: max-parallel: 10 matrix: python-version: [3.7] - tensorflow-version: [2.3.0] + tensorflow-version: [2.3.1] steps: - uses: actions/checkout@master - uses: actions/setup-python@v1 diff --git a/.gitignore b/.gitignore index d43284a1..ac548057 100755 --- a/.gitignore +++ b/.gitignore @@ -42,5 +42,4 @@ dump_baker/ dump_ljspeech/ dump_kss/ dump_libritts/ -/examples/*/* /notebooks/test_saved/ diff --git a/README.md b/README.md index de98f57c..08acb6c7 100755 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ :zany_face: TensorFlowTTS provides real-time state-of-the-art speech synthesis architectures such as Tacotron-2, Melgan, Multiband-Melgan, FastSpeech, FastSpeech2 based-on TensorFlow 2. With Tensorflow 2, we can speed-up training/inference progress, optimizer further by using [fake-quantize aware](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide) and [pruning](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras), make TTS models can be run faster than real-time and be able to deploy on mobile devices or embedded systems. ## What's new +- 2020/11/24 **(NEW!)** Add HiFi-GAN vocoder. See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/hifigan) - 2020/11/19 **(NEW!)** Add Multi-GPU gradient accumulator. See [here](https://github.com/TensorSpeech/TensorFlowTTS/pull/377) - 2020/08/23 Add Parallel WaveGAN tensorflow implementation. See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/parallel_wavegan) - 2020/08/23 Add MBMelGAN G + ParallelWaveGAN G example. See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/multiband_pwgan) @@ -85,6 +86,7 @@ TensorFlowTTS currently provides the following architectures: 4. **Multi-band MelGAN** released with the paper [Multi-band MelGAN: Faster Waveform Generation for High-Quality Text-to-Speech](https://arxiv.org/abs/2005.05106) by Geng Yang, Shan Yang, Kai Liu, Peng Fang, Wei Chen, Lei Xie. 5. **FastSpeech2** released with the paper [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558) by Yi Ren, Chenxu Hu, Xu Tan, Tao Qin, Sheng Zhao, Zhou Zhao, Tie-Yan Liu. 6. **Parallel WaveGAN** released with the paper [Parallel WaveGAN: A fast waveform generation model based on generative adversarial networks with multi-resolution spectrogram](https://arxiv.org/abs/1910.11480) by Ryuichi Yamamoto, Eunwoo Song, Jae-Min Kim. +7. **HiFi-GAN** released with the paper [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646) by Jungil Kong, Jaehyeon Kim, Jaekyoung Bae. We are also implementing some techniques to improve quality and convergence speed from the following papers: @@ -217,6 +219,7 @@ To know how to train model from scratch or fine-tune with other datasets/languag - For Multiband-MelGAN tutorial, pls see [examples/multiband_melgan](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/multiband_melgan) - For Parallel WaveGAN tutorial, pls see [examples/parallel_wavegan](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/parallel_wavegan) - For Multiband-MelGAN Generator + Parallel WaveGAN Discriminator tutorial, pls see [examples/multiband_pwgan](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/multiband_pwgan) +- For HiFi-GAN tutorial, pls see [examples/hifigan](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/hifigan) # Abstract Class Explaination ## Abstract DataLoader Tensorflow-based dataset diff --git a/examples/hifigan/README.md b/examples/hifigan/README.md new file mode 100755 index 00000000..c2e7ca35 --- /dev/null +++ b/examples/hifigan/README.md @@ -0,0 +1,65 @@ +# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis +Based on the script [`train_hifigan.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/hifigan/train_hifigan.py). + +## Training HiFi-GAN from scratch with LJSpeech dataset. +This example code show you how to train MelGAN from scratch with Tensorflow 2 based on custom training loop and tf.function. The data used for this example is LJSpeech, you can download the dataset at [link](https://keithito.com/LJ-Speech-Dataset/). + +### Step 1: Create Tensorflow based Dataloader (tf.dataset) +First, you need define data loader based on AbstractDataset class (see [`abstract_dataset.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/tensorflow_tts/datasets/abstract_dataset.py)). On this example, a dataloader read dataset from path. I use suffix to classify what file is a audio and mel-spectrogram (see [`audio_mel_dataset.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/melgan/audio_mel_dataset.py)). If you already have preprocessed version of your target dataset, you don't need to use this example dataloader, you just need refer my dataloader and modify **generator function** to adapt with your case. Normally, a generator function should return [audio, mel]. + +### Step 2: Training from scratch +After you re-define your dataloader, pls modify an input arguments, train_dataset and valid_dataset from [`train_hifigan.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/hifigan/train_hifigan.py). Here is an example command line to training HiFi-GAN from scratch: + +First, you need training generator with only stft loss: + +```bash +CUDA_VISIBLE_DEVICES=0 python examples/hifigan/train_hifigan.py \ + --train-dir ./dump/train/ \ + --dev-dir ./dump/valid/ \ + --outdir ./examples/hifigan/exp/train.hifigan.v1/ \ + --config ./examples/hifigan/conf/hifigan.v1.yaml \ + --use-norm 1 + --generator_mixed_precision 1 \ + --resume "" +``` + +Then resume and start training generator + discriminator: + +```bash +CUDA_VISIBLE_DEVICES=0 python examples/hifigan/train_hifigan.py \ + --train-dir ./dump/train/ \ + --dev-dir ./dump/valid/ \ + --outdir ./examples/hifigan/exp/train.hifigan.v1/ \ + --config ./examples/hifigan/conf/hifigan.v1.yaml \ + --use-norm 1 + --resume ./examples/hifigan/exp/train.hifigan.v1/checkpoints/ckpt-100000 +``` + +IF you want to use MultiGPU to training you can replace `CUDA_VISIBLE_DEVICES=0` by `CUDA_VISIBLE_DEVICES=0,1,2,3` for example. You also need to tune the `batch_size` for each GPU (in config file) by yourself to maximize the performance. Note that MultiGPU now support for Training but not yet support for Decode. + +In case you want to resume the training progress, please following below example command line: + +```bash +--resume ./examples/hifigan/exp/train.hifigan.v1/checkpoints/ckpt-100000 +``` + +If you want to finetune a model, use `--pretrained` like this with the filename of the generator +```bash +--pretrained ptgenerator.h5 +``` + +**IMPORTANT NOTES**: + +- When training generator only, we enable mixed precision to speed-up training progress. +- We don't apply mixed precision when training both generator and discriminator. (Discriminator include group-convolution, which cause discriminator slower when enable mixed precision). +- 100k here is a *discriminator_train_start_steps* parameters from [hifigan.v1.yaml](https://github.com/tensorspeech/TensorflowTTS/tree/master/examples/hifigan/conf/hifigan.v1.yaml) + + +## Reference + +1. https://github.com/descriptinc/melgan-neurips +2. https://github.com/kan-bayashi/ParallelWaveGAN +3. https://github.com/tensorflow/addons +4. [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646) +5. [MelGAN: Generative Adversarial Networks for Conditional Waveform Synthesis](https://arxiv.org/abs/1910.06711) +6. [Parallel WaveGAN: A fast waveform generation model based on generative adversarial networks with multi-resolution spectrogram](https://arxiv.org/abs/1910.11480) \ No newline at end of file diff --git a/examples/hifigan/conf/hifigan.v1.yaml b/examples/hifigan/conf/hifigan.v1.yaml new file mode 100755 index 00000000..d89a879b --- /dev/null +++ b/examples/hifigan/conf/hifigan.v1.yaml @@ -0,0 +1,116 @@ + +# This is the hyperparameter configuration file for Hifigan. +# Please make sure this is adjusted for the LJSpeech dataset. If you want to +# apply to the other dataset, you might need to carefully change some parameters. +# This configuration performs 4000k iters. + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +sampling_rate: 22050 # Sampling rate of dataset. +hop_size: 256 # Hop size. +format: "npy" + + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +model_type: "hifigan_generator" + +hifigan_generator_params: + out_channels: 1 + kernel_size: 7 + filters: 512 + use_bias: true + upsample_scales: [8, 8, 2, 2] + stacks: 3 + stack_kernel_size: [3, 7, 11] + stack_dilation_rate: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + use_final_nolinear_activation: true + is_weight_norm: false + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +hifigan_discriminator_params: + out_channels: 1 # Number of output channels (number of subbands). + period_scales: [2, 3, 5, 7, 11] # List of period scales. + n_layers: 5 # Number of layer of each period discriminator. + kernel_size: 5 # Kernel size. + strides: 3 # Strides + filters: 8 # In Conv filters of each period discriminator + filter_scales: 4 # Filter scales. + max_filters: 1024 # maximum filters of period discriminator's conv. + is_weight_norm: false # Use weight-norm or not. + +melgan_discriminator_params: + out_channels: 1 # Number of output channels. + scales: 3 # Number of multi-scales. + downsample_pooling: "AveragePooling1D" # Pooling type for the input downsampling. + downsample_pooling_params: # Parameters of the above pooling function. + pool_size: 4 + strides: 2 + kernel_sizes: [5, 3] # List of kernel size. + filters: 16 # Number of channels of the initial conv layer. + max_downsample_filters: 1024 # Maximum number of channels of downsampling layers. + downsample_scales: [4, 4, 4, 4] # List of downsampling scales. + nonlinear_activation: "LeakyReLU" # Nonlinear activation function. + nonlinear_activation_params: # Parameters of nonlinear activation function. + alpha: 0.2 + is_weight_norm: false # Use weight-norm or not. + +########################################################### +# STFT LOSS SETTING # +########################################################### +stft_loss_params: + fft_lengths: [1024, 2048, 512] # List of FFT size for STFT-based loss. + frame_steps: [120, 240, 50] # List of hop size for STFT-based loss + frame_lengths: [600, 1200, 240] # List of window length for STFT-based loss. + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_feat_match: 10.0 +lambda_adv: 4.0 + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 16 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. +batch_max_steps: 8192 # Length of each audio in batch for training. Make sure dividable by hop_size. +batch_max_steps_valid: 81920 # Length of each audio for validation. Make sure dividable by hope_size. +remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. +allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. +is_shuffle: true # shuffle dataset after each epoch. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + lr_fn: "PiecewiseConstantDecay" + lr_params: + boundaries: [100000, 200000, 300000, 400000, 500000, 600000, 700000] + values: [0.0005, 0.0005, 0.00025, 0.000125, 0.0000625, 0.00003125, 0.000015625, 0.000001] + amsgrad: false + +discriminator_optimizer_params: + lr_fn: "PiecewiseConstantDecay" + lr_params: + boundaries: [100000, 200000, 300000, 400000, 500000] + values: [0.00025, 0.000125, 0.0000625, 0.00003125, 0.000015625, 0.000001] + amsgrad: false + +gradient_accumulation_steps: 1 # should be even number or 1. +########################################################### +# INTERVAL SETTING # +########################################################### +discriminator_train_start_steps: 100000 # steps begin training discriminator +train_max_steps: 4000000 # Number of training steps. +save_interval_steps: 20000 # Interval steps to save checkpoint. +eval_interval_steps: 5000 # Interval steps to evaluate the network. +log_interval_steps: 200 # Interval steps to record the training log. + +########################################################### +# OTHER SETTING # +########################################################### +num_save_intermediate_results: 1 # Number of batch to be saved as intermediate results. diff --git a/examples/hifigan/conf/hifigan.v2.yaml b/examples/hifigan/conf/hifigan.v2.yaml new file mode 100755 index 00000000..0baedb73 --- /dev/null +++ b/examples/hifigan/conf/hifigan.v2.yaml @@ -0,0 +1,116 @@ + +# This is the hyperparameter configuration file for Hifigan. +# Please make sure this is adjusted for the LJSpeech dataset. If you want to +# apply to the other dataset, you might need to carefully change some parameters. +# This configuration performs 4000k iters. + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +sampling_rate: 22050 # Sampling rate of dataset. +hop_size: 256 # Hop size. +format: "npy" + + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +model_type: "hifigan_generator" + +hifigan_generator_params: + out_channels: 1 + kernel_size: 7 + filters: 128 + use_bias: true + upsample_scales: [8, 8, 2, 2] + stacks: 3 + stack_kernel_size: [3, 7, 11] + stack_dilation_rate: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + use_final_nolinear_activation: true + is_weight_norm: false + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +hifigan_discriminator_params: + out_channels: 1 # Number of output channels (number of subbands). + period_scales: [2, 3, 5, 7, 11] # List of period scales. + n_layers: 5 # Number of layer of each period discriminator. + kernel_size: 5 # Kernel size. + strides: 3 # Strides + filters: 8 # In Conv filters of each period discriminator + filter_scales: 4 # Filter scales. + max_filters: 512 # maximum filters of period discriminator's conv. + is_weight_norm: false # Use weight-norm or not. + +melgan_discriminator_params: + out_channels: 1 # Number of output channels. + scales: 3 # Number of multi-scales. + downsample_pooling: "AveragePooling1D" # Pooling type for the input downsampling. + downsample_pooling_params: # Parameters of the above pooling function. + pool_size: 4 + strides: 2 + kernel_sizes: [5, 3] # List of kernel size. + filters: 16 # Number of channels of the initial conv layer. + max_downsample_filters: 512 # Maximum number of channels of downsampling layers. + downsample_scales: [4, 4, 4, 4] # List of downsampling scales. + nonlinear_activation: "LeakyReLU" # Nonlinear activation function. + nonlinear_activation_params: # Parameters of nonlinear activation function. + alpha: 0.2 + is_weight_norm: false # Use weight-norm or not. + +########################################################### +# STFT LOSS SETTING # +########################################################### +stft_loss_params: + fft_lengths: [1024, 2048, 512] # List of FFT size for STFT-based loss. + frame_steps: [120, 240, 50] # List of hop size for STFT-based loss + frame_lengths: [600, 1200, 240] # List of window length for STFT-based loss. + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_feat_match: 10.0 +lambda_adv: 4.0 + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 16 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. +batch_max_steps: 8192 # Length of each audio in batch for training. Make sure dividable by hop_size. +batch_max_steps_valid: 81920 # Length of each audio for validation. Make sure dividable by hope_size. +remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. +allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. +is_shuffle: true # shuffle dataset after each epoch. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + lr_fn: "PiecewiseConstantDecay" + lr_params: + boundaries: [100000, 200000, 300000, 400000, 500000, 600000, 700000] + values: [0.0005, 0.0005, 0.00025, 0.000125, 0.0000625, 0.00003125, 0.000015625, 0.000001] + amsgrad: false + +discriminator_optimizer_params: + lr_fn: "PiecewiseConstantDecay" + lr_params: + boundaries: [100000, 200000, 300000, 400000, 500000] + values: [0.00025, 0.000125, 0.0000625, 0.00003125, 0.000015625, 0.000001] + amsgrad: false + +gradient_accumulation_steps: 1 # should be even number or 1. +########################################################### +# INTERVAL SETTING # +########################################################### +discriminator_train_start_steps: 100000 # steps begin training discriminator +train_max_steps: 4000000 # Number of training steps. +save_interval_steps: 20000 # Interval steps to save checkpoint. +eval_interval_steps: 5000 # Interval steps to evaluate the network. +log_interval_steps: 200 # Interval steps to record the training log. + +########################################################### +# OTHER SETTING # +########################################################### +num_save_intermediate_results: 1 # Number of batch to be saved as intermediate results. diff --git a/examples/hifigan/train_hifigan.py b/examples/hifigan/train_hifigan.py new file mode 100755 index 00000000..362c1072 --- /dev/null +++ b/examples/hifigan/train_hifigan.py @@ -0,0 +1,325 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 TensorFlowTTS Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Train Hifigan.""" + +import tensorflow as tf + +physical_devices = tf.config.list_physical_devices("GPU") +for i in range(len(physical_devices)): + tf.config.experimental.set_memory_growth(physical_devices[i], True) + +import sys + +sys.path.append(".") + +import argparse +import logging +import os + +import numpy as np +import soundfile as sf +import yaml +from tqdm import tqdm + +import tensorflow_tts +from examples.melgan.audio_mel_dataset import AudioMelDataset +from examples.melgan.train_melgan import collater +from examples.melgan_stft.train_melgan_stft import MultiSTFTMelganTrainer +from tensorflow_tts.configs import ( + HifiGANDiscriminatorConfig, + HifiGANGeneratorConfig, + MelGANDiscriminatorConfig, +) +from tensorflow_tts.models import ( + TFHifiGANGenerator, + TFHifiGANMultiPeriodDiscriminator, + TFMelGANMultiScaleDiscriminator, +) +from tensorflow_tts.utils import return_strategy + + +class TFHifiGANDiscriminator(tf.keras.Model): + def __init__(self, multiperiod_dis, multiscale_dis, **kwargs): + super().__init__(**kwargs) + self.multiperiod_dis = multiperiod_dis + self.multiscale_dis = multiscale_dis + + def call(self, x): + outs = [] + period_outs = self.multiperiod_dis(x) + scale_outs = self.multiscale_dis(x) + outs.extend(period_outs) + outs.extend(scale_outs) + return outs + + +def main(): + """Run training process.""" + parser = argparse.ArgumentParser( + description="Train Hifigan (See detail in examples/hifigan/train_hifigan.py)" + ) + parser.add_argument( + "--train-dir", + default=None, + type=str, + help="directory including training data. ", + ) + parser.add_argument( + "--dev-dir", + default=None, + type=str, + help="directory including development data. ", + ) + parser.add_argument( + "--use-norm", default=1, type=int, help="use norm mels for training or raw." + ) + parser.add_argument( + "--outdir", type=str, required=True, help="directory to save checkpoints." + ) + parser.add_argument( + "--config", type=str, required=True, help="yaml format configuration file." + ) + parser.add_argument( + "--resume", + default="", + type=str, + nargs="?", + help='checkpoint file path to resume training. (default="")', + ) + parser.add_argument( + "--verbose", + type=int, + default=1, + help="logging level. higher is more logging. (default=1)", + ) + parser.add_argument( + "--generator_mixed_precision", + default=0, + type=int, + help="using mixed precision for generator or not.", + ) + parser.add_argument( + "--discriminator_mixed_precision", + default=0, + type=int, + help="using mixed precision for discriminator or not.", + ) + parser.add_argument( + "--pretrained", + default="", + type=str, + nargs="?", + help="path of .h5 melgan generator to load weights from", + ) + args = parser.parse_args() + + # return strategy + STRATEGY = return_strategy() + + # set mixed precision config + if args.generator_mixed_precision == 1 or args.discriminator_mixed_precision == 1: + tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True}) + + args.generator_mixed_precision = bool(args.generator_mixed_precision) + args.discriminator_mixed_precision = bool(args.discriminator_mixed_precision) + + args.use_norm = bool(args.use_norm) + + # set logger + if args.verbose > 1: + logging.basicConfig( + level=logging.DEBUG, + stream=sys.stdout, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + stream=sys.stdout, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + stream=sys.stdout, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # check directory existence + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + + # check arguments + if args.train_dir is None: + raise ValueError("Please specify --train-dir") + if args.dev_dir is None: + raise ValueError("Please specify either --valid-dir") + + # load and save config + with open(args.config) as f: + config = yaml.load(f, Loader=yaml.Loader) + config.update(vars(args)) + config["version"] = tensorflow_tts.__version__ + with open(os.path.join(args.outdir, "config.yml"), "w") as f: + yaml.dump(config, f, Dumper=yaml.Dumper) + for key, value in config.items(): + logging.info(f"{key} = {value}") + + # get dataset + if config["remove_short_samples"]: + mel_length_threshold = config["batch_max_steps"] // config[ + "hop_size" + ] + 2 * config["hifigan_generator_params"].get("aux_context_window", 0) + else: + mel_length_threshold = None + + if config["format"] == "npy": + audio_query = "*-wave.npy" + mel_query = "*-raw-feats.npy" if args.use_norm is False else "*-norm-feats.npy" + audio_load_fn = np.load + mel_load_fn = np.load + else: + raise ValueError("Only npy are supported.") + + # define train/valid dataset + train_dataset = AudioMelDataset( + root_dir=args.train_dir, + audio_query=audio_query, + mel_query=mel_query, + audio_load_fn=audio_load_fn, + mel_load_fn=mel_load_fn, + mel_length_threshold=mel_length_threshold, + ).create( + is_shuffle=config["is_shuffle"], + map_fn=lambda items: collater( + items, + batch_max_steps=tf.constant(config["batch_max_steps"], dtype=tf.int32), + hop_size=tf.constant(config["hop_size"], dtype=tf.int32), + ), + allow_cache=config["allow_cache"], + batch_size=config["batch_size"] + * STRATEGY.num_replicas_in_sync + * config["gradient_accumulation_steps"], + ) + + valid_dataset = AudioMelDataset( + root_dir=args.dev_dir, + audio_query=audio_query, + mel_query=mel_query, + audio_load_fn=audio_load_fn, + mel_load_fn=mel_load_fn, + mel_length_threshold=mel_length_threshold, + ).create( + is_shuffle=config["is_shuffle"], + map_fn=lambda items: collater( + items, + batch_max_steps=tf.constant( + config["batch_max_steps_valid"], dtype=tf.int32 + ), + hop_size=tf.constant(config["hop_size"], dtype=tf.int32), + ), + allow_cache=config["allow_cache"], + batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, + ) + + # define trainer + trainer = MultiSTFTMelganTrainer( + steps=0, + epochs=0, + config=config, + strategy=STRATEGY, + is_generator_mixed_precision=args.generator_mixed_precision, + is_discriminator_mixed_precision=args.discriminator_mixed_precision, + ) + + with STRATEGY.scope(): + # define generator and discriminator + generator = TFHifiGANGenerator( + HifiGANGeneratorConfig(**config["hifigan_generator_params"]), + name="hifigan_generator", + ) + + multiperiod_discriminator = TFHifiGANMultiPeriodDiscriminator( + HifiGANDiscriminatorConfig(**config["hifigan_discriminator_params"]), + name="hifigan_multiperiod_discriminator", + ) + multiscale_discriminator = TFMelGANMultiScaleDiscriminator( + MelGANDiscriminatorConfig( + **config["melgan_discriminator_params"], + name="melgan_multiscale_discriminator", + ) + ) + + discriminator = TFHifiGANDiscriminator( + multiperiod_discriminator, + multiscale_discriminator, + name="hifigan_discriminator", + ) + + # dummy input to build model. + fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32) + y_hat = generator(fake_mels) + discriminator(y_hat) + + if len(args.pretrained) > 1: + generator.load_weights(args.pretrained) + logging.info( + f"Successfully loaded pretrained weight from {args.pretrained}." + ) + + generator.summary() + discriminator.summary() + + # define optimizer + generator_lr_fn = getattr( + tf.keras.optimizers.schedules, config["generator_optimizer_params"]["lr_fn"] + )(**config["generator_optimizer_params"]["lr_params"]) + discriminator_lr_fn = getattr( + tf.keras.optimizers.schedules, + config["discriminator_optimizer_params"]["lr_fn"], + )(**config["discriminator_optimizer_params"]["lr_params"]) + + gen_optimizer = tf.keras.optimizers.Adam( + learning_rate=generator_lr_fn, + amsgrad=config["generator_optimizer_params"]["amsgrad"], + ) + dis_optimizer = tf.keras.optimizers.Adam( + learning_rate=discriminator_lr_fn, + amsgrad=config["discriminator_optimizer_params"]["amsgrad"], + ) + + trainer.compile( + gen_model=generator, + dis_model=discriminator, + gen_optimizer=gen_optimizer, + dis_optimizer=dis_optimizer, + ) + + # start training + try: + trainer.fit( + train_dataset, + valid_dataset, + saved_path=os.path.join(config["outdir"], "checkpoints/"), + resume=args.resume, + ) + except KeyboardInterrupt: + trainer.save_checkpoint() + logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.") + + +if __name__ == "__main__": + main() diff --git a/examples/melgan.stft/README.md b/examples/melgan_stft/README.md similarity index 74% rename from examples/melgan.stft/README.md rename to examples/melgan_stft/README.md index c1eadfb7..803ad73f 100755 --- a/examples/melgan.stft/README.md +++ b/examples/melgan_stft/README.md @@ -1,23 +1,23 @@ # MelGAN STFT: MelGAN With Multi Resolution STFT Loss -Based on the script [`train_melgan_stft.py`](https://github.com/dathudeptrai/TensorflowTTS/tree/master/examples/melgan.stft/train_melgan_stft.py). +Based on the script [`train_melgan_stft.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/melgan_stft/train_melgan_stft.py). ## Training MelGAN STFT from scratch with LJSpeech dataset. This example code show you how to train MelGAN from scratch with Tensorflow 2 based on custom training loop and tf.function. The data used for this example is LJSpeech, you can download the dataset at [link](https://keithito.com/LJ-Speech-Dataset/). ### Step 1: Create Tensorflow based Dataloader (tf.dataset) -Please see detail at [examples/melgan/](https://github.com/dathudeptrai/TensorflowTTS/tree/master/examples/melgan#step-1-create-tensorflow-based-dataloader-tfdataset) +Please see detail at [examples/melgan/](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/melgan#step-1-create-tensorflow-based-dataloader-tfdataset) ### Step 2: Training from scratch -After you re-define your dataloader, pls modify an input arguments, train_dataset and valid_dataset from [`train_melgan_stft.py`](https://github.com/dathudeptrai/TensorflowTTS/tree/master/examples/melgan.stft/train_melgan_stft.py). Here is an example command line to training melgan-stft from scratch: +After you re-define your dataloader, pls modify an input arguments, train_dataset and valid_dataset from [`train_melgan_stft.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/melgan_stft/train_melgan_stft.py). Here is an example command line to training melgan-stft from scratch: First, you need training generator with only stft loss: ```bash -CUDA_VISIBLE_DEVICES=0 python examples/melgan/train_melgan.py \ +CUDA_VISIBLE_DEVICES=0 python examples/melgan_stft/train_melgan.py \ --train-dir ./dump/train/ \ --dev-dir ./dump/valid/ \ - --outdir ./examples/melgan.stft/exp/train.melgan.stft.v1/ \ - --config ./examples/melgan.stft/conf/melgan.stft.v1.yaml \ + --outdir ./examples/melgan_stft/exp/train.melgan_stft.v1/ \ + --config ./examples/melgan_stft/conf/melgan_stft.v1.yaml \ --use-norm 1 --generator_mixed_precision 1 \ --resume "" @@ -26,13 +26,13 @@ CUDA_VISIBLE_DEVICES=0 python examples/melgan/train_melgan.py \ Then resume and start training generator + discriminator: ```bash -CUDA_VISIBLE_DEVICES=0 python examples/melgan/train_melgan.py \ +CUDA_VISIBLE_DEVICES=0 python examples/melgan_stft/train_melgan.py \ --train-dir ./dump/train/ \ --dev-dir ./dump/valid/ \ - --outdir ./examples/melgan/exp/train.melgan.v1/ \ - --config ./examples/melgan/conf/melgan.v1.yaml \ + --outdir ./examples/melgan_stft/exp/train.melgan_stft.v1/ \ + --config ./examples/melgan_stft/conf/melgan_stft.v1.yaml \ --use-norm 1 - --resume ./examples/melgan.stft/exp/train.melgan.stft.v1/checkpoints/ckpt-100000 + --resume ./examples/melgan_stft/exp/train.melgan_stft.v1/checkpoints/ckpt-100000 ``` IF you want to use MultiGPU to training you can replace `CUDA_VISIBLE_DEVICES=0` by `CUDA_VISIBLE_DEVICES=0,1,2,3` for example. You also need to tune the `batch_size` for each GPU (in config file) by yourself to maximize the performance. Note that MultiGPU now support for Training but not yet support for Decode. @@ -40,7 +40,7 @@ IF you want to use MultiGPU to training you can replace `CUDA_VISIBLE_DEVICES=0` In case you want to resume the training progress, please following below example command line: ```bash ---resume ./examples/melgan.stft/exp/train.melgan.stft.v1/checkpoints/ckpt-100000 +--resume ./examples/melgan_stft/exp/train.melgan_stft.v1/checkpoints/ckpt-100000 ``` If you want to finetune a model, use `--pretrained` like this with the filename of the generator @@ -52,14 +52,14 @@ If you want to finetune a model, use `--pretrained` like this with the filename - When training generator only, we enable mixed precision to speed-up training progress. - We don't apply mixed precision when training both generator and discriminator. (Discriminator include group-convolution, which cause discriminator slower when enable mixed precision). -- 100k here is a *discriminator_train_start_steps* parameters from [melgan.stft.v1.yaml](https://github.com/dathudeptrai/TensorflowTTS/tree/master/examples/melgan.stft/conf/melgan.stft.v1.yaml) +- 100k here is a *discriminator_train_start_steps* parameters from [melgan_stft.v1.yaml](https://github.com/tensorspeech/TensorflowTTS/tree/master/examples/melgan_stft/conf/melgan_stft.v1.yaml) ## Finetune MelGAN STFT with ljspeech pretrained on other languages Just load pretrained model and training from scratch with other languages. **DO NOT FORGET** re-preprocessing on your dataset if needed. A hop_size should be 256 if you want to use our pretrained. ## Learning Curves -Here is a learning curves of melgan based on this config [`melgan.stft.v1.yaml`](https://github.com/dathudeptrai/TensorflowTTS/tree/master/examples/melgan.stft/conf/melgan.stft.v1.yaml) +Here is a learning curves of melgan based on this config [`melgan_stft.v1.yaml`](https://github.com/tensorspeech/TensorflowTTS/tree/master/examples/melgan_stft/conf/melgan_stft.v1.yaml) @@ -68,12 +68,12 @@ Here is a learning curves of melgan based on this config [`melgan.stft.v1.yaml`] ## Some important notes * We apply learning rate = 1e-3 when training generator only then apply lr = 1e-4 for both G and D. -* See [examples/melgan](https://github.com/dathudeptrai/TensorflowTTS/tree/master/examples/melgan#some-important-notes) for more notes. +* See [examples/melgan](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/melgan#some-important-notes) for more notes. ## Pretrained Models and Audio samples | Model | Conf | Lang | Fs [Hz] | Mel range [Hz] | FFT / Hop / Win [pt] | # iters | | :------ | :---: | :---: | :----: | :--------: | :---------------: | :-----: | -| [melgan.stft.v1](https://drive.google.com/drive/folders/1xUkDjbciupEkM3N4obiJAYySTo6J9z6b?usp=sharing) | [link](https://github.com/dathudeptrai/TensorflowTTS/tree/master/examples/melgan.stft/conf/melgan.stft.v1.yaml) | EN | 22.05k | 80-7600 | 1024 / 256 / None | 1900k | +| [melgan_stft.v1](https://drive.google.com/drive/folders/1xUkDjbciupEkM3N4obiJAYySTo6J9z6b?usp=sharing) | [link](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/melgan_stft/conf/melgan_stft.v1.yaml) | EN | 22.05k | 80-7600 | 1024 / 256 / None | 1900k | ## Reference diff --git a/examples/melgan.stft/conf/melgan.stft.v1.yaml b/examples/melgan_stft/conf/melgan_stft.v1.yaml similarity index 100% rename from examples/melgan.stft/conf/melgan.stft.v1.yaml rename to examples/melgan_stft/conf/melgan_stft.v1.yaml diff --git a/examples/melgan.stft/fig/melgan.stft.v1.eval.png b/examples/melgan_stft/fig/melgan.stft.v1.eval.png similarity index 100% rename from examples/melgan.stft/fig/melgan.stft.v1.eval.png rename to examples/melgan_stft/fig/melgan.stft.v1.eval.png diff --git a/examples/melgan.stft/fig/melgan.stft.v1.train.png b/examples/melgan_stft/fig/melgan.stft.v1.train.png similarity index 100% rename from examples/melgan.stft/fig/melgan.stft.v1.train.png rename to examples/melgan_stft/fig/melgan.stft.v1.train.png diff --git a/examples/melgan.stft/train_melgan_stft.py b/examples/melgan_stft/train_melgan_stft.py similarity index 98% rename from examples/melgan.stft/train_melgan_stft.py rename to examples/melgan_stft/train_melgan_stft.py index 978eb78f..a2820f20 100755 --- a/examples/melgan.stft/train_melgan_stft.py +++ b/examples/melgan_stft/train_melgan_stft.py @@ -114,6 +114,12 @@ def compute_per_example_generator_losses(self, batch, outputs): sc_loss, mag_loss = calculate_2d_loss( audios, tf.squeeze(y_hat, -1), self.stft_loss ) + + # trick to prevent loss expoded here + sc_loss = tf.where(sc_loss >= 15.0, 0.0, sc_loss) + mag_loss = tf.where(mag_loss >= 15.0, 0.0, mag_loss) + + # compute generator loss gen_loss = 0.5 * (sc_loss + mag_loss) if self.steps >= self.config["discriminator_train_start_steps"]: diff --git a/tensorflow_tts/configs/__init__.py b/tensorflow_tts/configs/__init__.py index 36a33b2d..e2d4ba65 100755 --- a/tensorflow_tts/configs/__init__.py +++ b/tensorflow_tts/configs/__init__.py @@ -8,6 +8,10 @@ MultiBandMelGANDiscriminatorConfig, MultiBandMelGANGeneratorConfig, ) +from tensorflow_tts.configs.hifigan import ( + HifiGANGeneratorConfig, + HifiGANDiscriminatorConfig, +) from tensorflow_tts.configs.tacotron2 import Tacotron2Config from tensorflow_tts.configs.parallel_wavegan import ParallelWaveGANGeneratorConfig from tensorflow_tts.configs.parallel_wavegan import ParallelWaveGANDiscriminatorConfig diff --git a/tensorflow_tts/configs/hifigan.py b/tensorflow_tts/configs/hifigan.py new file mode 100644 index 00000000..965cd993 --- /dev/null +++ b/tensorflow_tts/configs/hifigan.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 TensorflowTTS Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""HifiGAN Config object.""" + + +class HifiGANGeneratorConfig(object): + """Initialize HifiGAN Generator Config.""" + + def __init__( + self, + out_channels=1, + kernel_size=7, + filters=128, + use_bias=True, + upsample_scales=[8, 8, 2, 2], + stacks=3, + stack_kernel_size=[3, 7, 11], + stack_dilation_rate=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"alpha": 0.2}, + padding_type="REFLECT", + use_final_nolinear_activation=True, + is_weight_norm=True, + initializer_seed=42, + **kwargs + ): + """Init parameters for HifiGAN Generator model.""" + self.out_channels = out_channels + self.kernel_size = kernel_size + self.filters = filters + self.use_bias = use_bias + self.upsample_scales = upsample_scales + self.stacks = stacks + self.stack_kernel_size = stack_kernel_size + self.stack_dilation_rate = stack_dilation_rate + self.nonlinear_activation = nonlinear_activation + self.nonlinear_activation_params = nonlinear_activation_params + self.padding_type = padding_type + self.use_final_nolinear_activation = use_final_nolinear_activation + self.is_weight_norm = is_weight_norm + self.initializer_seed = initializer_seed + + +class HifiGANDiscriminatorConfig(object): + """Initialize HifiGAN Discriminator Config.""" + + def __init__( + self, + out_channels=1, + period_scales=[2, 3, 5, 7, 11], + n_layers=5, + kernel_size=5, + strides=3, + filters=8, + filter_scales=4, + max_filters=1024, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"alpha": 0.2}, + is_weight_norm=True, + initializer_seed=42, + **kwargs + ): + """Init parameters for MelGAN Discriminator model.""" + self.out_channels = out_channels + self.period_scales = period_scales + self.n_layers = n_layers + self.kernel_size = kernel_size + self.strides = strides + self.filters = filters + self.filter_scales = filter_scales + self.max_filters = max_filters + self.nonlinear_activation = nonlinear_activation + self.nonlinear_activation_params = nonlinear_activation_params + self.is_weight_norm = is_weight_norm + self.initializer_seed = initializer_seed diff --git a/tensorflow_tts/inference/auto_config.py b/tensorflow_tts/inference/auto_config.py index 931c0066..220a942b 100644 --- a/tensorflow_tts/inference/auto_config.py +++ b/tensorflow_tts/inference/auto_config.py @@ -23,6 +23,7 @@ FastSpeech2Config, MelGANGeneratorConfig, MultiBandMelGANGeneratorConfig, + HifiGANGeneratorConfig, Tacotron2Config, ParallelWaveGANGeneratorConfig, ) @@ -33,6 +34,7 @@ ("fastspeech2", FastSpeech2Config), ("multiband_melgan_generator", MultiBandMelGANGeneratorConfig), ("melgan_generator", MelGANGeneratorConfig), + ("hifigan_generator", HifiGANGeneratorConfig), ("tacotron2", Tacotron2Config), ("parallel_wavegan_generator", ParallelWaveGANGeneratorConfig) ] diff --git a/tensorflow_tts/inference/auto_model.py b/tensorflow_tts/inference/auto_model.py index 497c2240..9eec91fe 100644 --- a/tensorflow_tts/inference/auto_model.py +++ b/tensorflow_tts/inference/auto_model.py @@ -23,6 +23,7 @@ FastSpeech2Config, MelGANGeneratorConfig, MultiBandMelGANGeneratorConfig, + HifiGANGeneratorConfig, Tacotron2Config, ParallelWaveGANGeneratorConfig, ) @@ -32,6 +33,7 @@ TFFastSpeech2, TFMelGANGenerator, TFMBMelGANGenerator, + TFHifiGANGenerator, TFTacotron2, TFParallelWaveGANGenerator, ) @@ -39,11 +41,12 @@ TF_MODEL_MAPPING = OrderedDict( [ - (FastSpeechConfig, TFFastSpeech), (FastSpeech2Config, TFFastSpeech2), + (FastSpeechConfig, TFFastSpeech), (MultiBandMelGANGeneratorConfig, TFMBMelGANGenerator), (MelGANGeneratorConfig, TFMelGANGenerator), (Tacotron2Config, TFTacotron2), + (HifiGANGeneratorConfig, TFHifiGANGenerator), (ParallelWaveGANGeneratorConfig, TFParallelWaveGANGenerator) ] ) diff --git a/tensorflow_tts/models/__init__.py b/tensorflow_tts/models/__init__.py index 6578969f..26d8f36c 100755 --- a/tensorflow_tts/models/__init__.py +++ b/tensorflow_tts/models/__init__.py @@ -7,6 +7,11 @@ ) from tensorflow_tts.models.mb_melgan import TFPQMF from tensorflow_tts.models.mb_melgan import TFMBMelGANGenerator +from tensorflow_tts.models.hifigan import ( + TFHifiGANGenerator, + TFHifiGANMultiPeriodDiscriminator, + TFHifiGANPeriodDiscriminator +) from tensorflow_tts.models.tacotron2 import TFTacotron2 from tensorflow_tts.models.parallel_wavegan import TFParallelWaveGANGenerator from tensorflow_tts.models.parallel_wavegan import TFParallelWaveGANDiscriminator diff --git a/tensorflow_tts/models/hifigan.py b/tensorflow_tts/models/hifigan.py new file mode 100644 index 00000000..155c508f --- /dev/null +++ b/tensorflow_tts/models/hifigan.py @@ -0,0 +1,356 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Hifigan Authors and TensorflowTTS Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hifi Modules.""" + +import numpy as np +import tensorflow as tf + +from tensorflow_tts.models.melgan import TFReflectionPad1d +from tensorflow_tts.models.melgan import TFConvTranspose1d + +from tensorflow_tts.utils import GroupConv1D +from tensorflow_tts.utils import WeightNormalization + +from tensorflow_tts.models import TFMelGANGenerator + + +class TFHifiResBlock(tf.keras.layers.Layer): + """Tensorflow Hifigan resblock 1 module.""" + + def __init__( + self, + kernel_size, + filters, + dilation_rate, + use_bias, + nonlinear_activation, + nonlinear_activation_params, + is_weight_norm, + initializer_seed, + **kwargs + ): + """Initialize TFHifiResBlock module. + Args: + kernel_size (int): Kernel size. + filters (int): Number of filters. + dilation_rate (list): List dilation rate. + use_bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + is_weight_norm (bool): Whether to use weight norm or not. + """ + super().__init__(**kwargs) + self.blocks_1 = [] + self.blocks_2 = [] + + for i in range(len(dilation_rate)): + self.blocks_1.append( + [ + TFReflectionPad1d((kernel_size - 1) // 2 * dilation_rate[i]), + tf.keras.layers.Conv1D( + filters=filters, + kernel_size=kernel_size, + dilation_rate=dilation_rate[i], + use_bias=use_bias, + ), + ] + ) + self.blocks_2.append( + [ + TFReflectionPad1d((kernel_size - 1) // 2 * 1), + tf.keras.layers.Conv1D( + filters=filters, + kernel_size=kernel_size, + dilation_rate=1, + use_bias=use_bias, + ), + ] + ) + + self.activation = getattr(tf.keras.layers, nonlinear_activation)( + **nonlinear_activation_params + ) + + # apply weightnorm + if is_weight_norm: + self._apply_weightnorm(self.blocks_1) + self._apply_weightnorm(self.blocks_2) + + def call(self, x): + """Calculate forward propagation. + Args: + x (Tensor): Input tensor (B, T, C). + Returns: + Tensor: Output tensor (B, T, C). + """ + for c1, c2 in zip(self.blocks_1, self.blocks_2): + xt = self.activation(x) + for c in c1: + xt = c(xt) + xt = self.activation(xt) + for c in c2: + xt = c(xt) + x = xt + x + return x + + def _apply_weightnorm(self, list_layers): + """Try apply weightnorm for all layer in list_layers.""" + for i in range(len(list_layers)): + try: + layer_name = list_layers[i].name.lower() + if "conv1d" in layer_name or "dense" in layer_name: + list_layers[i] = WeightNormalization(list_layers[i]) + except Exception: + pass + + +class TFHifiGANGenerator(tf.keras.Model): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + # check hyper parameter is valid or not + assert ( + config.stacks + == len(config.stack_kernel_size) + == len(config.stack_dilation_rate) + ) + + # add initial layer + layers = [] + layers += [ + TFReflectionPad1d( + (config.kernel_size - 1) // 2, + padding_type=config.padding_type, + name="first_reflect_padding", + ), + tf.keras.layers.Conv1D( + filters=config.filters, + kernel_size=config.kernel_size, + use_bias=config.use_bias, + ), + ] + + for i, upsample_scale in enumerate(config.upsample_scales): + # add upsampling layer + layers += [ + getattr(tf.keras.layers, config.nonlinear_activation)( + **config.nonlinear_activation_params + ), + TFConvTranspose1d( + filters=config.filters // (2 ** (i + 1)), + kernel_size=upsample_scale * 2, + strides=upsample_scale, + padding="same", + is_weight_norm=config.is_weight_norm, + initializer_seed=config.initializer_seed, + name="conv_transpose_._{}".format(i), + ), + ] + + # ad residual stack layer + for j in range(config.stacks): + layers += [ + TFHifiResBlock( + kernel_size=config.stack_kernel_size[j], + filters=config.filters // (2 ** (i + 1)), + dilation_rate=config.stack_dilation_rate[j], + use_bias=config.use_bias, + nonlinear_activation=config.nonlinear_activation, + nonlinear_activation_params=config.nonlinear_activation_params, + is_weight_norm=config.is_weight_norm, + initializer_seed=config.initializer_seed, + name="hifigan_resblock_._{}._._{}".format(i, j), + ) + ] + # add final layer + layers += [ + getattr(tf.keras.layers, config.nonlinear_activation)( + **config.nonlinear_activation_params + ), + TFReflectionPad1d( + (config.kernel_size - 1) // 2, + padding_type=config.padding_type, + name="last_reflect_padding", + ), + tf.keras.layers.Conv1D( + filters=config.out_channels, + kernel_size=config.kernel_size, + use_bias=config.use_bias, + dtype=tf.float32, + ), + ] + if config.use_final_nolinear_activation: + layers += [tf.keras.layers.Activation("tanh", dtype=tf.float32)] + + if config.is_weight_norm is True: + self._apply_weightnorm(layers) + + self.hifigan = tf.keras.models.Sequential(layers) + + def call(self, mels, **kwargs): + """Calculate forward propagation. + Args: + c (Tensor): Input tensor (B, T, channels) + Returns: + Tensor: Output tensor (B, T ** prod(upsample_scales), out_channels) + """ + return self.inference(mels) + + @tf.function( + input_signature=[ + tf.TensorSpec(shape=[None, None, 80], dtype=tf.float32, name="mels") + ] + ) + def inference(self, mels): + return self.hifigan(mels) + + @tf.function( + input_signature=[ + tf.TensorSpec(shape=[1, None, 80], dtype=tf.float32, name="mels") + ] + ) + def inference_tflite(self, mels): + return self.hifigan(mels) + + def _apply_weightnorm(self, list_layers): + """Try apply weightnorm for all layer in list_layers.""" + for i in range(len(list_layers)): + try: + layer_name = list_layers[i].name.lower() + if "conv1d" in layer_name or "dense" in layer_name: + list_layers[i] = WeightNormalization(list_layers[i]) + except Exception: + pass + + def _build(self): + """Build model by passing fake input.""" + fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32) + self(fake_mels) + + +class TFHifiGANPeriodDiscriminator(tf.keras.layers.Layer): + """Tensorflow Hifigan period discriminator module.""" + + def __init__( + self, + period, + out_channels=1, + n_layers=5, + kernel_size=5, + strides=3, + filters=8, + filter_scales=4, + max_filters=1024, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"alpha": 0.2}, + initializer_seed=42, + is_weight_norm=False, + **kwargs + ): + super().__init__(**kwargs) + self.period = period + self.out_filters = out_channels + self.convs = [] + + for i in range(n_layers): + self.convs.append( + tf.keras.layers.Conv2D( + filters=min(filters * (filter_scales ** (i + 1)), max_filters), + kernel_size=(kernel_size, 1), + strides=(strides, 1), + padding="same", + ) + ) + self.conv_post = tf.keras.layers.Conv2D( + filters=out_channels, kernel_size=(3, 1), padding="same", + ) + self.activation = getattr(tf.keras.layers, nonlinear_activation)( + **nonlinear_activation_params + ) + + if is_weight_norm: + self._apply_weightnorm(self.convs) + self.conv_post = WeightNormalization(self.conv_post) + + def call(self, x): + """Calculate forward propagation. + Args: + x (Tensor): Input noise signal (B, T, 1). + Returns: + List: List of output tensors. + """ + shape = tf.shape(x) + n_pad = tf.convert_to_tensor(0, dtype=tf.int32) + if shape[1] % self.period != 0: + n_pad = self.period - (shape[1] % self.period) + x = tf.pad(x, [[0, 0], [0, n_pad], [0, 0]], "REFLECT") + x = tf.reshape( + x, [shape[0], (shape[1] + n_pad) // self.period, self.period, x.shape[2]] + ) + for layer in self.convs: + x = layer(x) + x = self.activation(x) + x = self.conv_post(x) + x = tf.reshape(x, [shape[0], -1, self.out_filters]) + return [x] + + def _apply_weightnorm(self, list_layers): + """Try apply weightnorm for all layer in list_layers.""" + for i in range(len(list_layers)): + try: + layer_name = list_layers[i].name.lower() + if "conv1d" in layer_name or "dense" in layer_name: + list_layers[i] = WeightNormalization(list_layers[i]) + except Exception: + pass + + +class TFHifiGANMultiPeriodDiscriminator(tf.keras.Model): + """Tensorflow Hifigan Multi Period discriminator module.""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.discriminator = [] + + # add discriminator + for i in range(len(config.period_scales)): + self.discriminator += [ + TFHifiGANPeriodDiscriminator( + config.period_scales[i], + out_channels=config.out_channels, + n_layers=config.n_layers, + kernel_size=config.kernel_size, + strides=config.strides, + filters=config.filters, + filter_scales=config.filter_scales, + max_filters=config.max_filters, + nonlinear_activation=config.nonlinear_activation, + nonlinear_activation_params=config.nonlinear_activation_params, + initializer_seed=config.initializer_seed, + is_weight_norm=config.is_weight_norm, + name="hifigan_period_discriminator_._{}".format(i), + ) + ] + + def call(self, x): + """Calculate forward propagation. + Args: + x (Tensor): Input noise signal (B, T, 1). + Returns: + List: list of each discriminator outputs + """ + outs = [] + for f in self.discriminator: + outs += [f(x)] + return outs diff --git a/test/test_auto.py b/test/test_auto.py index d39d8d8d..e964d9af 100644 --- a/test/test_auto.py +++ b/test/test_auto.py @@ -54,11 +54,13 @@ def test_auto_processor(mapper_path): "./examples/fastspeech2/conf/fastspeech2.kss.v1.yaml", "./examples/fastspeech2/conf/fastspeech2.kss.v2.yaml", "./examples/melgan/conf/melgan.v1.yaml", - "./examples/melgan.stft/conf/melgan.stft.v1.yaml", + "./examples/melgan_stft/conf/melgan_stft.v1.yaml", "./examples/multiband_melgan/conf/multiband_melgan.v1.yaml", "./examples/tacotron2/conf/tacotron2.v1.yaml", "./examples/tacotron2/conf/tacotron2.kss.v1.yaml", "./examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml", + "./examples/hifigan/conf/hifigan.v1.yaml", + "./examples/hifigan/conf/hifigan.v2.yaml", ] ) def test_auto_model(config_path): diff --git a/test/test_hifigan.py b/test/test_hifigan.py new file mode 100644 index 00000000..933d3dea --- /dev/null +++ b/test/test_hifigan.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 TensorFlowTTS Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import pytest +import tensorflow as tf + +from tensorflow_tts.configs import ( + HifiGANDiscriminatorConfig, + HifiGANGeneratorConfig, + MelGANDiscriminatorConfig, +) +from tensorflow_tts.models import ( + TFHifiGANGenerator, + TFHifiGANMultiPeriodDiscriminator, + TFMelGANMultiScaleDiscriminator, +) + +from examples.hifigan.train_hifigan import TFHifiGANDiscriminator + +os.environ["CUDA_VISIBLE_DEVICES"] = "" + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", +) + + +def make_hifigan_generator_args(**kwargs): + defaults = dict( + out_channels=1, + kernel_size=7, + filters=128, + use_bias=True, + upsample_scales=[8, 8, 2, 2], + stacks=3, + stack_kernel_size=[3, 7, 11], + stack_dilation_rate=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"alpha": 0.2}, + padding_type="REFLECT", + use_final_nolinear_activation=True, + is_weight_norm=True, + initializer_seed=42, + ) + defaults.update(kwargs) + return defaults + + +def make_hifigan_discriminator_args(**kwargs): + defaults_multisperiod = dict( + out_channels=1, + period_scales=[2, 3, 5, 7, 11], + n_layers=5, + kernel_size=5, + strides=3, + filters=8, + filter_scales=4, + max_filters=1024, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"alpha": 0.2}, + is_weight_norm=True, + initializer_seed=42, + ) + defaults_multisperiod.update(kwargs) + defaults_multiscale = dict( + out_channels=1, + scales=3, + downsample_pooling="AveragePooling1D", + downsample_pooling_params={"pool_size": 4, "strides": 2,}, + kernel_sizes=[5, 3], + filters=16, + max_downsample_filters=1024, + use_bias=True, + downsample_scales=[4, 4, 4, 4], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"alpha": 0.2}, + padding_type="REFLECT", + ) + defaults_multiscale.update(kwargs) + return [defaults_multisperiod, defaults_multiscale] + + +@pytest.mark.parametrize( + "dict_g, dict_d, dict_loss", + [ + ({}, {}, {}), + ({"kernel_size": 3}, {}, {}), + ({"filters": 1024}, {}, {}), + ({"stack_kernel_size": [1, 2, 3]}, {}, {}), + ({"stack_kernel_size": [3, 5, 7], "stacks": 3}, {}, {}), + ({"upsample_scales": [4, 4, 4, 4]}, {}, {}), + ({"upsample_scales": [8, 8, 2, 2]}, {}, {}), + ({"filters": 1024, "upsample_scales": [8, 8, 2, 2]}, {}, {}), + ], +) +def test_hifigan_trainable(dict_g, dict_d, dict_loss): + batch_size = 4 + batch_length = 4096 + args_g = make_hifigan_generator_args(**dict_g) + args_d_p, args_d_s = make_hifigan_discriminator_args(**dict_d) + + args_g = HifiGANGeneratorConfig(**args_g) + args_d_p = HifiGANDiscriminatorConfig(**args_d_p) + args_d_s = MelGANDiscriminatorConfig(**args_d_s) + + generator = TFHifiGANGenerator(args_g) + + discriminator_p = TFHifiGANMultiPeriodDiscriminator(args_d_p) + discriminator_s = TFMelGANMultiScaleDiscriminator(args_d_s) + discriminator = TFHifiGANDiscriminator(discriminator_p, discriminator_s)